当前位置: 首页 > news >正文

秃姐学AI系列之:实战Kaggle比赛:狗的品种识别(ImageNet Dogs)

目录

前置准备

整理数据集

图片增广

读取数据集

微调预训练模型

训练函数

训练和验证模型

Kaggle提交结果


前置准备

常规导包

import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

使用小规模数据样本

d2l.DATA_HUB['dog_tiny'] = (d2l.DATA_URL + 'kaggle_dog_tiny.zip','0cb91d09b814ecdc07b50f31f8dcad3e81d6a86d')# 如果使用Kaggle比赛的完整数据集,请将下面的变量更改为False
demo = True
if demo:data_dir = d2l.download_extract('dog_tiny')
else:data_dir = os.path.join('..', 'data', 'dog-breed-identification')

整理数据集

复用上一个 CIFAR-10 的处理一样,即从原始训练集中拆分验证集,然后将图像移动到按标签分组的子文件夹中。

下面的reorg_dog_data函数读取训练数据标签、拆分验证集并整理训练集。

def reorg_dog_data(data_dir, valid_ratio):labels = d2l.read_csv_labels(os.path.join(data_dir, 'labels.csv'))d2l.reorg_train_valid(data_dir, labels, valid_ratio)d2l.reorg_test(data_dir)batch_size = 32 if demo else 128
valid_ratio = 0.1
reorg_dog_data(data_dir, valid_ratio)

图片增广

这个狗品种数据集是 ImageNet 数据集的子集,其图像大于 CIFAR-10 数据集的图像。

transform_train = torchvision.transforms.Compose([# 随机裁剪图像,所得图像为原始面积的0.08~1之间,高宽比在3/4和4/3之间。# 然后,缩放图像以创建224x224的新图像torchvision.transforms.RandomResizedCrop(224, scale=(0.08, 1.0),ratio=(3.0/4.0, 4.0/3.0)),torchvision.transforms.RandomHorizontalFlip(),# 随机更改亮度,对比度和饱和度torchvision.transforms.ColorJitter(brightness=0.4,contrast=0.4,saturation=0.4),# 添加随机噪声torchvision.transforms.ToTensor(),# 标准化图像的每个通道torchvision.transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])# 测试时,我们只使用确定性的图像预处理操作
transform_test = torchvision.transforms.Compose([torchvision.transforms.Resize(256),# 从图像中心裁切224x224大小的图片torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])

读取数据集

都和 CIFAR-10一样

train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train_valid_test', folder),transform=transform_train) for folder in ['train', 'train_valid']]valid_ds, test_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train_valid_test', folder),transform=transform_test) for folder in ['valid', 'test']]train_iter, train_valid_iter = [torch.utils.data.DataLoader(dataset, batch_size, shuffle=True, drop_last=True)for dataset in (train_ds, train_valid_ds)]valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False,drop_last=True)test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False,drop_last=False)

微调预训练模型

同样,本次比赛的数据集是ImageNet数据集的子集。 因此,我们可以使用 微调中讨论的方法在完整ImageNet数据集上选择预训练的模型,然后使用该模型提取图像特征,以便将其输入到定制的小规模输出网络中。

深度学习框架的高级API提供了在ImageNet数据集上预训练的各种模型。 在这里,我们选择预训练的ResNet-34模型,我们只需重复使用此模型的输出层(即提取的特征)的输入。 然后,我们可以用一个可以训练的小型自定义输出网络替换原始输出层,例如堆叠两个完全连接的图层。

def get_net(devices):finetune_net = nn.Sequential()finetune_net.features = torchvision.models.resnet34(pretrained=True)# 定义一个新的输出网络,共有120个输出类别finetune_net.output_new = nn.Sequential(nn.Linear(1000, 256),nn.ReLU(),nn.Linear(256, 120))# 将模型参数分配给用于计算的CPU或GPUfinetune_net = finetune_net.to(devices[0])# 冻结参数for param in finetune_net.features.parameters():param.requires_grad = Falsereturn finetune_net

在计算损失之前,我们首先获取预训练模型的输出层的输入,即提取的特征。 然后我们使用此特征作为我们小型自定义输出网络的输入来计算损失。

loss = nn.CrossEntropyLoss(reduction='none')def evaluate_loss(data_iter, net, devices):l_sum, n = 0.0, 0for features, labels in data_iter:features, labels = features.to(devices[0]), labels.to(devices[0])outputs = net(features)l = loss(outputs, labels)l_sum += l.sum()n += labels.numel()return (l_sum / n).to('cpu')

训练函数

模型训练函数train只迭代小型自定义输出网络的参数

def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,lr_decay):# 只训练小型自定义输出网络net = nn.DataParallel(net, device_ids=devices).to(devices[0])trainer = torch.optim.SGD((param for param in net.parameters()if param.requires_grad), lr=lr,momentum=0.9, weight_decay=wd)scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)num_batches, timer = len(train_iter), d2l.Timer()legend = ['train loss']if valid_iter is not None:legend.append('valid loss')animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=legend)for epoch in range(num_epochs):metric = d2l.Accumulator(2)for i, (features, labels) in enumerate(train_iter):timer.start()features, labels = features.to(devices[0]), labels.to(devices[0])trainer.zero_grad()output = net(features)l = loss(output, labels).sum()l.backward()trainer.step()metric.add(l, labels.shape[0])timer.stop()if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(metric[0] / metric[1], None))measures = f'train loss {metric[0] / metric[1]:.3f}'if valid_iter is not None:valid_loss = evaluate_loss(valid_iter, net, devices)animator.add(epoch + 1, (None, valid_loss.detach().cpu()))scheduler.step()if valid_iter is not None:measures += f', valid loss {valid_loss:.3f}'print(measures + f'\n{metric[1] * num_epochs / timer.sum():.1f}'f' examples/sec on {str(devices)}')

训练和验证模型

现在我们可以训练和验证模型了,以下超参数都是可调的。 例如,我们可以增加迭代轮数。 另外,由于lr_periodlr_decay分别设置为2和0.9, 因此优化算法的学习速率将在每2个迭代后乘以0.9。

devices, num_epochs, lr, wd = d2l.try_all_gpus(), 10, 1e-4, 1e-4
lr_period, lr_decay, net = 2, 0.9, get_net(devices)
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,lr_decay)

Kaggle提交结果

最终所有标记的数据(包括验证集)都用于训练模型和对测试集进行分类。 我们将使用训练好的自定义输出网络进行分类。

net = get_net(devices)
train(net, train_valid_iter, None, num_epochs, lr, wd, devices, lr_period,lr_decay)preds = []
for data, label in test_iter:output = torch.nn.functional.softmax(net(data.to(devices[0])), dim=1)preds.extend(output.cpu().detach().numpy())
ids = sorted(os.listdir(os.path.join(data_dir, 'train_valid_test', 'test', 'unknown')))
with open('submission.csv', 'w') as f:f.write('id,' + ','.join(train_valid_ds.classes) + '\n')for i, output in zip(ids, preds):f.write(i.split('.')[0] + ',' + ','.join([str(num) for num in output]) + '\n')

相关文章:

秃姐学AI系列之:实战Kaggle比赛:狗的品种识别(ImageNet Dogs)

目录 前置准备 整理数据集 图片增广 读取数据集 微调预训练模型 训练函数 训练和验证模型 Kaggle提交结果 前置准备 常规导包 import os import torch import torchvision from torch import nn from d2l import torch as d2l 使用小规模数据样本 d2l.DATA_HUB[dog…...

图神经网络介绍3

1. 图同构网络:Weisfeiler-Lehman 测试与图神经网络的表达力 本节介绍一个关于图神经网络表达力的经典工作,以及随之产生的另一个重要的模型——图同构网络。图同构问题指的是验证两个图在拓扑结构上是否相同。Weisfeiler-Lehman 测试是一种有效的检验两…...

浅谈 React Fiber

想象一下,你正在搭建一个乐高积木城堡。 传统的搭建方式:一次性把所有积木拼好,如果中途发现某个地方拼错了,就需要拆掉重新拼。这个过程就像 React 15 之前的版本,一旦开始渲染,就很难中断,效…...

Winform实现石头剪刀布小游戏

1、电脑玩家类 using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks;namespace RockScissorsClothApp {public class Computer{public Card Play(){Random random new Random();int num random.Next(0, 3…...

计算机的错误计算(九十)

摘要 计算机的错误计算(八十九)探讨了反双曲余切函数 acoth(x)在 附近的计算精度问题。本节讨论绝对值为大数的反双曲余切函数值的计算精度问题。 Acoth(x) 函数的定义为: 其中 x 的绝对值大于 1 . 例1. 计算 acoth(1.000000000002e15) .…...

对游戏语音软件Oopz遭遇DDoS攻击后的一些建议

由于武汉天气太热,因此周末两天就没怎么出门。一直在家打《黑神话:悟空》,结果卡在广智这里一直打不过去,本来想找好友一起讨论下该怎么过,但又没有好的游戏语音软件。于是在网上搜索了一些信息,并偶然间发…...

解锁Android开发利器:MVVM架构_android的mvvm

// 从网络或其他数据源获取天气数据return Weather(city, "25C") }} 2.定义View:class WeatherActivity : AppCompatActivity() { private lateinit var viewModel: WeatherViewModel override fun onCreate(savedInstanceState: Bundle?) {super.onCre…...

llama.cpp demo

git clone https://github.com/ggerganov/llama.cpp cd llama.cpp 修改Makefile使能mfma参数 MK_CFLAGS -mfma -mf16c -mavx MK_CXXFLAGS -mfma -mf16c -mavx 安装python3依赖 cat ./requirements/requirements-convert_legacy_llama.txt numpy~1.26.4 sentencepie…...

OpenCV结构分析与形状描述符(19)查找二维点集的最小面积外接旋转矩形函数minAreaRect()的使用

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 找到一个包围输入的二维点集的最小面积旋转矩形。 该函数计算并返回指定点集的最小面积边界矩形(可能是旋转的)。开发者…...

[SWPU2019]Web1 超详细教程

老规矩先看源码,没找到啥提示,后面就是登录口对抗 弱口令试了几个不行,就注册了个账户登录进去 可以发布广告,能造成xss,但是没啥用啊感觉 查看广告信息的时候,注意到url当中存在id参数,可能存…...

【区块链通用服务平台及组件】基于向量数据库与 LLM 的智能合约 Copilot

智能合约是自动执行、无需信任的代码,可以在区块链上运行,确保了数据和程序的透明性和不可篡改性。然而, 智能合约的编写、调试和优化仍然是一个具有挑战性的过程,因为它需要高度的技术专长,且发布后的智能合约代码通常…...

mfc140u.dll丢失有啥方法能够进行修复?分享几种mfc140u.dll丢失的解决办法

你是否曾遇到过这样的情况:当你满怀期待地打开一个应用程序时,却被一个错误提示拦住了去路,提示信息中指出 mfc140u.dll 文件丢失。这个问题可能会让你感到困惑和无助,但是不要担心,本文将为你详细解读 mfc140u.dll 丢…...

【PyQt6 应用程序】在用户登录界面实现密码密文保存复用

在开发现代应用程序中,为用户提供既安全又便捷的登录体验是至关重要的。特别是在那些需要用户认证的应用中,实现一个功能丰富且用户友好的登录界面不仅能增强用户满意度,还能提升整体的安全性。基于PyQt6框架和QtDesigner,本文将展示如何在已有的用户登录页面基础上,进一步…...

赋能百业:多模态处理技术与大模型架构下的AI解决方案落地实践

赋能百业:多模态处理技术与大模型架构下的AI解决方案落地实践 AI 语音交互大模型其实有两种主流的做法: All in LLM多个模块组合, ASR+LLM+TTS实际应用中,这两种方案并不是要对立存在的,像永劫无间这种游戏的场景,用户要的是低延迟,无障碍交流。并且能够触发某些动作技…...

游戏论坛网站|基于Springboot+vue的游戏论坛网站系统游戏分享网站(源码+数据库+文档)

游戏论坛|游戏论坛系统|游戏分享网站 目录 基于Springbootvue的游戏论坛网站系统游戏分享网站 一、前言 二、系统设计 三、系统功能设计 四、数据库设计 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八、源码获取: 博主介绍:✌️大…...

【go】pprof 性能分析

前言 go pprof是 Go 语言提供的性能分析工具。它可以帮助开发者分析 Go 程序的性能问题,包括 CPU 使用情况、内存分配情况、阻塞情况等。 主要功能 CPU 性能分析 go pprof可以对程序的 CPU 使用情况进行分析。它通过在一定时间内对程序的执行进行采样&#xff0…...

Python | Leetcode Python题解之第397题整数替换

题目: 题解: class Solution:def integerReplacement(self, n: int) -> int:ans 0while n ! 1:if n % 2 0:ans 1n // 2elif n % 4 1:ans 2n // 2else:if n 3:ans 2n 1else:ans 2n n // 2 1return ans...

JDBC使用

7.2 创建JDBC应用 7.2.1 创建JDBC应用程序的步骤 使用JDBC操作数据库中的数据包括6个基本操作步骤: (1)载入JDBC驱动程序: 首先要在应用程序中加载驱动程序driver,使用Class.forName()方法加载特定的驱动程序&#xf…...

633. 平方数之和-LeetCode(C++)

633. 平方数之和 2024.9.11 题目 给定一个非负整数 c &#xff0c;你要判断是否存在两个整数 a 和 b&#xff0c;使得 a2 b2 c 。 0 < c < 2的31次方 - 1 示例 示例 1&#xff1a; 输入&#xff1a;c 5 输出&#xff1a;true 解释&#xff1a;1 * 1 2 * 2 5示…...

Linux shell编程学习笔记79:cpio命令——文件和目录归档工具(下)

在 Linux shell编程学习笔记78&#xff1a;cpio命令——文件和目录归档工具&#xff08;上&#xff09;-CSDN博客https://blog.csdn.net/Purpleendurer/article/details/142095476?spm1001.2014.3001.5501中&#xff0c;我们研究了 cpio命令 的功能、格式、选项说明 以及 cpi…...

工业安全零事故的智能守护者:一体化AI智能安防平台

前言&#xff1a; 通过AI视觉技术&#xff0c;为船厂提供全面的安全监控解决方案&#xff0c;涵盖交通违规检测、起重机轨道安全、非法入侵检测、盗窃防范、安全规范执行监控等多个方面&#xff0c;能够实现对应负责人反馈机制&#xff0c;并最终实现数据的统计报表。提升船厂…...

转转集团旗下首家二手多品类循环仓店“超级转转”开业

6月9日&#xff0c;国内领先的循环经济企业转转集团旗下首家二手多品类循环仓店“超级转转”正式开业。 转转集团创始人兼CEO黄炜、转转循环时尚发起人朱珠、转转集团COO兼红布林CEO胡伟琨、王府井集团副总裁祝捷等出席了开业剪彩仪式。 据「TMT星球」了解&#xff0c;“超级…...

Cloudflare 从 Nginx 到 Pingora:性能、效率与安全的全面升级

在互联网的快速发展中&#xff0c;高性能、高效率和高安全性的网络服务成为了各大互联网基础设施提供商的核心追求。Cloudflare 作为全球领先的互联网安全和基础设施公司&#xff0c;近期做出了一个重大技术决策&#xff1a;弃用长期使用的 Nginx&#xff0c;转而采用其内部开发…...

大模型多显卡多服务器并行计算方法与实践指南

一、分布式训练概述 大规模语言模型的训练通常需要分布式计算技术,以解决单机资源不足的问题。分布式训练主要分为两种模式: 数据并行:将数据分片到不同设备,每个设备拥有完整的模型副本 模型并行:将模型分割到不同设备,每个设备处理部分模型计算 现代大模型训练通常结合…...

爬虫基础学习day2

# 爬虫设计领域 工商&#xff1a;企查查、天眼查短视频&#xff1a;抖音、快手、西瓜 ---> 飞瓜电商&#xff1a;京东、淘宝、聚美优品、亚马逊 ---> 分析店铺经营决策标题、排名航空&#xff1a;抓取所有航空公司价格 ---> 去哪儿自媒体&#xff1a;采集自媒体数据进…...

项目部署到Linux上时遇到的错误(Redis,MySQL,无法正确连接,地址占用问题)

Redis无法正确连接 在运行jar包时出现了这样的错误 查询得知问题核心在于Redis连接失败&#xff0c;具体原因是客户端发送了密码认证请求&#xff0c;但Redis服务器未设置密码 1.为Redis设置密码&#xff08;匹配客户端配置&#xff09; 步骤&#xff1a; 1&#xff09;.修…...

08. C#入门系列【类的基本概念】:开启编程世界的奇妙冒险

C#入门系列【类的基本概念】&#xff1a;开启编程世界的奇妙冒险 嘿&#xff0c;各位编程小白探险家&#xff01;欢迎来到 C# 的奇幻大陆&#xff01;今天咱们要深入探索这片大陆上至关重要的 “建筑”—— 类&#xff01;别害怕&#xff0c;跟着我&#xff0c;保准让你轻松搞…...

计算机基础知识解析:从应用到架构的全面拆解

目录 前言 1、 计算机的应用领域&#xff1a;无处不在的数字助手 2、 计算机的进化史&#xff1a;从算盘到量子计算 3、计算机的分类&#xff1a;不止 “台式机和笔记本” 4、计算机的组件&#xff1a;硬件与软件的协同 4.1 硬件&#xff1a;五大核心部件 4.2 软件&#…...

智能职业发展系统:AI驱动的职业规划平台技术解析

智能职业发展系统&#xff1a;AI驱动的职业规划平台技术解析 引言&#xff1a;数字时代的职业革命 在当今瞬息万变的就业市场中&#xff0c;传统的职业规划方法已无法满足个人和企业的需求。据统计&#xff0c;全球每年有超过2亿人面临职业转型困境&#xff0c;而企业也因此遭…...

在 Visual Studio Code 中使用驭码 CodeRider 提升开发效率:以冒泡排序为例

目录 前言1 插件安装与配置1.1 安装驭码 CodeRider1.2 初始配置建议 2 示例代码&#xff1a;冒泡排序3 驭码 CodeRider 功能详解3.1 功能概览3.2 代码解释功能3.3 自动注释生成3.4 逻辑修改功能3.5 单元测试自动生成3.6 代码优化建议 4 驭码的实际应用建议5 常见问题与解决建议…...