(动手学习深度学习)第13章 实战kaggle竞赛:狗的品种识别
文章目录
- 1. 导入相关库
- 2. 加载数据集
- 3. 整理数据集
- 4. 图像增广
- 5. 读取数据
- 6. 微调预训练模型
- 7. 定义损失函数和评价损失函数
- 9. 训练模型
1. 导入相关库
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
2. 加载数据集
- 该数据集是完整数据集的小规模样本
# 下载数据集
d2l.DATA_HUB['dog_tiny'] = (d2l.DATA_URL + 'kaggle_dog_tiny.zip','0cb91d09b814ecdc07b50f31f8dcad3e81d6a86d')# 如果使用Kaggle比赛的完整数据集,请将下面的变量更改为False
demo = True
if demo:data_dir = d2l.download_extract('dog_tiny')
else:data_dir = os.path.join('..', 'data', 'dog-breed-identification')
3. 整理数据集
def reorg_dog_data(data_dir, valid_ratio):labels = d2l.read_csv_labels(os.path.join(data_dir, 'labels.csv'))d2l.reorg_train_valid(data_dir, labels, valid_ratio)d2l.reorg_test(data_dir)batch_size = 32 if demo else 128
valid_ratio = 0.1
reorg_dog_data(data_dir, valid_ratio)
4. 图像增广
transform_train = torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0)),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
transform_test = torchvision.transforms.Compose([torchvision.transforms.Resize(256),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
5. 读取数据
train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train_valid_test', folder),transform=transform_train) for folder in ['train', 'train_valid']
]
valid_ds, test_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train_valid_test', folder),transform=transform_test) for folder in ['valid', 'test']
]
train_iter, train_valid_iter = [torch.utils.data.DataLoader(dataset, batch_size, shuffle=True, drop_last=True) for dataset in (train_ds, train_valid_ds)
]
valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False, drop_last=True
)
test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False, drop_last=True
)
6. 微调预训练模型
def get_net(devices):finetune_net = nn.Sequential()finetune_net.features = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)# 定义一个新的输出网络,共有120个输出类别finetune_net.output_new = nn.Sequential(nn.Linear(1000, 256),nn.ReLU(),nn.Linear(256, 120))finetune_net = finetune_net.to(devices[0])# 冻结参数for param in finetune_net.features.parameters():param.requires_grad = Falsereturn finetune_net
# 查看网络模型
get_net(devices=d2l.try_all_gpus())

7. 定义损失函数和评价损失函数
# 定义损失函数
loss = nn.CrossEntropyLoss(reduction='none')def evaluate_loss(data_iter, net, device):l_sum, n = 0.0, 0for features, labels in data_iter:features, labels = features.to(device[0]), labels.to(device[0])outputs = net(features)l = loss(outputs, labels)l_sum += l.sum()n += labels.numel()return (l_sum / n).to('cpu')
- 定义训练函数
def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay):# 只训练小型定义输出网络net = nn.DataParallel(net, device_ids=devices).to(devices[0])trainer = torch.optim.SGD((param for param in net.parameters() if param.requires_grad),lr=lr, momentum=0.9, weight_decay=wd)scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)num_batches, timer = len(train_iter), d2l.Timer()legend = ['train loss']if valid_iter is not None:legend.append('valid loss')animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], legend=legend)for epoch in range(num_epochs):metric = d2l.Accumulator(2)for i, (features, labels) in enumerate(train_iter):timer.start()features, labels = features.to(devices[0]), labels.to(devices[0])trainer.zero_grad()output = net(features)l = loss(output, labels).sum()l.backward()trainer.step()metric.add(l, labels.shape[0])timer.stop()if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches, (metric[0] / metric[1], None))measures = f'train loss {metric[0] / metric[1]:.3f}'if valid_iter is not None :valid_loss = evaluate_loss(valid_iter, net, devices)animator.add(epoch + 1, (None, valid_loss.detach().cpu()))scheduler.step()if valid_iter is not None:measures += f', valid loss {valid_loss:.3f}'print(measures + f'\n{metric[1] * num_epochs / timer.sum():.1f}'f'examples/sec on {str(devices)}')
9. 训练模型
devices, num_epochs, lr, wd = d2l.try_all_gpus(), 10, 1e-4, 1e-4
lr_period, lr_decay, net, = 2, 0.9, get_net(devices)
import time# 在开头设置开始时间
start = time.perf_counter() # start = time.clock() python3.8之前可以train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)# 在程序运行结束的位置添加结束时间
end = time.perf_counter() # end = time.clock() python3.8之前可以# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f}')

相关文章:
(动手学习深度学习)第13章 实战kaggle竞赛:狗的品种识别
文章目录 1. 导入相关库2. 加载数据集3. 整理数据集4. 图像增广5. 读取数据6. 微调预训练模型7. 定义损失函数和评价损失函数9. 训练模型 1. 导入相关库 import os import torch import torchvision from torch import nn from d2l import torch as d2l2. 加载数据集 - 该数据…...
自定义注解+AOP
自定义注解与AOP(面向切面编程)的结合常常用于在应用程序中划定切面,以便在特定的方法或类上应用横切关注点。以下是一个简单的示例,演示了如何创建自定义注解,并使用Spring AOP来在被注解的方法上应用通知。 如何创建…...
Ribbon
在Spring Cloud中,Ribbon是一个用于客户端负载均衡的组件,它可以与其他服务发现组件(例如Eureka)集成,以提供更强大的负载均衡功能。Ribbon使得微服务架构中的客户端能够更加智能地调用其他服务的实例,从而…...
git -1
1.创建第一个仓库并配置local用户信息 git config git config --global 对当前用户所有仓库有效 git config --system 对系统所有登录的用户有效 git config --local 只对某个仓库有效 git config --list 显示配置 git config --list --global 所有仓库 git config --list…...
基于SSM+Vue的鲜花销售系统/网上花店系统
基于SSM的鲜花销售系统/网上花店系统的设计与实现~ 开发语言:Java数据库:MySQL技术:SpringMyBatisSpringMVC工具:IDEA/Ecilpse、Navicat、Maven 系统展示 主页 管理员界面 摘要 鲜花销售系统是一个基于SSM(Spring …...
安卓:Android Studio4.0~2023中正确的打开Android Device Monitor
Android Studio4.0~2023 中如何正确的打开Android Device Monitor(亲测有效) 前些天买了新电脑,安装了新版本的Android Studio4.0想试一试,结果就出现了一些问题。 问题引出: Android Device Monitor在工具栏中找不到,后来上网查…...
装备制造企业设备远程运维平台的建设-天拓四方分享
设备远程运维平台是一种基于互联网和物联网技术的设备管理平台,可以实现设备的远程监控、故障诊断、预警维护等功能。近年来,随着云计算、大数据、人工智能等技术的不断发展,设备远程运维平台的智能化程度越来越高,传统的设备运维…...
群晖NAS搭建WebDav服务做文件共享,可随时随地远程访问
文章目录 1. 在群晖套件中心安装WebDav Server套件1.1 安装完成后,启动webdav服务,并勾选HTTP复选框 2. 局域网测试WebDav服务2.1 下载RaiDrive客户端2.2 打开RaiDrive,设置界面语言可以选择中文2.3 点击添加按钮,新建虚拟驱动区2…...
c++调用Lua(table嵌套写法)
通过c调用lua接口将数据存储到虚拟栈中,就可以在lua脚本在虚拟栈中取得数据 c调用lua库,加载lua文件, lua_State* L;//定义一个全局变量***************************L luaL_newstate();luaL_openlibs(L);//打开Lua脚本文件std::string pat…...
算法复杂度分析
文章目录 有数据范围反推算法复杂度以及算法内容一般方法递归 有数据范围反推算法复杂度以及算法内容 c一秒可以算 1 0 7 10^7 107~ 1 0 8 10^8 108次 一般方法 看循环 有几层循环就可以初步分析O( n i n^i ni) 双指针算法除外O(n) 递归 公式法 根据公式的形式࿰…...
几款Java源码扫描工具(FindBugs、PMD、SonarQube、Fortify、WebInspect)
说明 有几个常用的Java源码扫描工具可以帮助您进行源代码分析和检查。以下是其中一些工具: FindBugs:FindBugs是一个静态分析工具,用于查找Java代码中的潜在缺陷和错误。它可以检测出空指针引用、资源未关闭、不良的代码实践等问题。FindBu…...
java springboot测试类鉴定虚拟MVC请求 返回内容与预期值是否相同
上文 java springboot测试类鉴定虚拟MVC运行值与预期值是否相同 中 我们验证了它HTTP的返回状态 简单说 校验了他 是否成功的状态 这次 我们来不对得到的内容 我们 直接改写测试类代码如下 package com.example.webdom;import org.junit.jupiter.api.Test; import org.springf…...
MongoDB随记
MongoDB 1、简单介绍2、基本术语3、shard分片概述背景架构路由功能chunk(数据分片)shard key(分片键值) 4、常用命令 1、简单介绍 MongoDB是一个分布式文件存储的数据库,介于关系数据库和非关系数据库之间,…...
839 - Not so Mobile (UVA)
题目链接如下: Online Judge 这道题刘汝佳的解法极其简洁,用了20来行就解决了问题。膜拜…… 他的解法如下:天平(UVa839紫书p157)_天平 uva 839_falldeep的博客-CSDN博客 我写了两个(都很冗长ÿ…...
php字符串处理函数的使用
php字符串处理函数的使用 trim() trim()函数的功能用于去除字符串首尾的空白字符(包括空格、制表符、换行符等)。它可以用于清理用户输入的数据或去除字符串中的多余空格。 <?php $char" holle world! ";echo trim($char) ?>str_repl…...
UEC++ day8
伤害系统 给敌人创建血条 首先添加一个UI界面用来显示敌人血条设置背景图像为黑色半透明 填充颜色 给敌人类添加两种状态表示血量与最大血量,添加一个UWidegtComponet组件与UProgressBar组件 UPROPERTY(EditAnywhere, BlueprintReadWrite, Category "Enemy …...
学习记录——ipv4、ipv6与ip、DNS、网络协议
文章目录 前情提要:网络协议和域名DNS协议、DNS污染Ipv4、Ipv6NAT协议,IP:端口,环节IP地址枯竭NAT-PT协议,加速Ipv6应用 前情提要: 本文仅做个人的学习记录以及理解,可能存在一些错误。 网络协…...
cefsharp119.4.30(cef119.4.3,Chromium119.0.6045.159)版本升级体验支持H264及其他多个H264版本
Cefsharp119.4.30,cef119.4.3,Chromium119.0.6045.159 此更新包括一个高优先级安全更新 This update includes a high priority security update. 说明:此版本119.4.3支持H264视频播放(需要联系我),其他版本。.NETFramework 4.6.2 NuGet Gallery | CefSharp.WinForms 119.…...
“index“ should always be multi-word
vue报错:Component name “index” should always be multi-word 分析:组件名要以驼峰格式命名,自定义的要以loginIndex.vue等这种方式命名,防止和html标签冲突,所以命名index.vue 会报错 解决:在.eslint…...
服务器64GB内存、8核CPU的MySQL 8配置参数
服务器64GB内存、8核CPU的MySQL 8配置参数可以按照以下步骤进行调优: 调整缓冲区相关参数: 增加innodb_buffer_pool_size的值,将其设置为4GB或更大,以加速频繁读取的操作。 – 2147483648 增加key_buffer_size的值,将…...
IDEA运行Tomcat出现乱码问题解决汇总
最近正值期末周,有很多同学在写期末Java web作业时,运行tomcat出现乱码问题,经过多次解决与研究,我做了如下整理: 原因: IDEA本身编码与tomcat的编码与Windows编码不同导致,Windows 系统控制台…...
rknn优化教程(二)
文章目录 1. 前述2. 三方库的封装2.1 xrepo中的库2.2 xrepo之外的库2.2.1 opencv2.2.2 rknnrt2.2.3 spdlog 3. rknn_engine库 1. 前述 OK,开始写第二篇的内容了。这篇博客主要能写一下: 如何给一些三方库按照xmake方式进行封装,供调用如何按…...
Unity3D中Gfx.WaitForPresent优化方案
前言 在Unity中,Gfx.WaitForPresent占用CPU过高通常表示主线程在等待GPU完成渲染(即CPU被阻塞),这表明存在GPU瓶颈或垂直同步/帧率设置问题。以下是系统的优化方案: 对惹,这里有一个游戏开发交流小组&…...
centos 7 部署awstats 网站访问检测
一、基础环境准备(两种安装方式都要做) bash # 安装必要依赖 yum install -y httpd perl mod_perl perl-Time-HiRes perl-DateTime systemctl enable httpd # 设置 Apache 开机自启 systemctl start httpd # 启动 Apache二、安装 AWStats࿰…...
ESP32读取DHT11温湿度数据
芯片:ESP32 环境:Arduino 一、安装DHT11传感器库 红框的库,别安装错了 二、代码 注意,DATA口要连接在D15上 #include "DHT.h" // 包含DHT库#define DHTPIN 15 // 定义DHT11数据引脚连接到ESP32的GPIO15 #define D…...
蓝牙 BLE 扫描面试题大全(2):进阶面试题与实战演练
前文覆盖了 BLE 扫描的基础概念与经典问题蓝牙 BLE 扫描面试题大全(1):从基础到实战的深度解析-CSDN博客,但实际面试中,企业更关注候选人对复杂场景的应对能力(如多设备并发扫描、低功耗与高发现率的平衡)和前沿技术的…...
使用LangGraph和LangSmith构建多智能体人工智能系统
现在,通过组合几个较小的子智能体来创建一个强大的人工智能智能体正成为一种趋势。但这也带来了一些挑战,比如减少幻觉、管理对话流程、在测试期间留意智能体的工作方式、允许人工介入以及评估其性能。你需要进行大量的反复试验。 在这篇博客〔原作者&a…...
GO协程(Goroutine)问题总结
在使用Go语言来编写代码时,遇到的一些问题总结一下 [参考文档]:https://www.topgoer.com/%E5%B9%B6%E5%8F%91%E7%BC%96%E7%A8%8B/goroutine.html 1. main()函数默认的Goroutine 场景再现: 今天在看到这个教程的时候,在自己的电…...
DeepSeek源码深度解析 × 华为仓颉语言编程精粹——从MoE架构到全场景开发生态
前言 在人工智能技术飞速发展的今天,深度学习与大模型技术已成为推动行业变革的核心驱动力,而高效、灵活的开发工具与编程语言则为技术创新提供了重要支撑。本书以两大前沿技术领域为核心,系统性地呈现了两部深度技术著作的精华:…...
C# winform教程(二)----checkbox
一、作用 提供一个用户选择或者不选的状态,这是一个可以多选的控件。 二、属性 其实功能大差不差,除了特殊的几个外,与button基本相同,所有说几个独有的 checkbox属性 名称内容含义appearance控件外观可以变成按钮形状checkali…...
