(动手学习深度学习)第13章 实战kaggle竞赛:CIFAR-10
- 导入相关库
import collections
import math
import os
import shutil
import pandas as pd
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
- 下载数据集
d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip','2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')# 如果使用完整的Kaggle竞赛的数据集,设置demo为False
demo = Trueif demo:data_dir = d2l.download_extract('cifar10_tiny')
else:data_dir = '../data/kaggle/cifar-10/'
- 整理数据集
# 查看数据集
def read_csv_labels(fname):"""读取‘fname’来给标签字典返回一个文件名"""with open(fname, 'r') as f:lines = f.readlines()[1:] # readlines(): 每次读文档的一行,以后还需要逐步循环tokens = [l.rstrip().split(',') for l in lines] # rstrip(): 删除字符串后面(右面)的空格或特殊字符, 还有lstrip(左面)、strip(两面)return dict((name, label) for name, label in tokens)labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
print('训练样本:', len(labels))
print('类别:', len(set(labels.values()))) # set(): 集合,里面不能包含重复的元素,接受一个list作为参数

将验证集从原始的训练集钟拆分出来
# 拆分数据集:训练集、验证集
def copyfile(filename, target_dir):"""将文件复制到目标目录"""os.makedirs(target_dir, exist_ok=True) # 创建多层目录,exist_ok为True:在目标目录已存在的情况下不会触发FileExistsError异常。shutil.copy(filename, target_dir) #拷贝文件,filename:要拷贝的文件;target_dir:目标文件夹def reorg_train_valid(data_dir, labels, valid_ratio):"""将验证集从原始训练集钟拆分出来"""# 训练数据集中样本数量最少的类别中的样本数# Counter: 计数器,返回一个字典,键为元素,值为元素个数;# .most_common(): 返回一个列表, 列表元素为(元素,出现次数),默认按出现频率排序# [-1]: 样本数量最少的类别(类别, 样本数),[-1][1]: 样本数数量最少的类别中的样本数n = collections.Counter(labels.values()).most_common()[-1][1]# 验证集中每个类别的样本数n_valid_per_label= max(1, math.floor((n * valid_ratio))) # math.floor(): 向下取整 math.ceil(): 向上取整label_count = {}# 遍历原始训练集中的每个样本for train_file in os.listdir(os.path.join(data_dir, 'train')):label = labels[train_file.split('.')[0]] # 从文件名中提取标签fname = os.path.join(data_dir, 'train', train_file)copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'train_valid', label))# 如果该类别的样本数还未达到在验证集中的设定数量,则将样本复制到验证集中if label not in label_count or label_count[label] < n_valid_per_label:copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'valid', label))label_count[label] = label_count.get(label, 0) + 1else:copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'train', label))return n_valid_per_label# reorg_test函数用来在预测期间整理测试集,以方便读取
def reorg_test(data_dir):"""在预测期间整理测试集,以方便读取"""# 遍历测试集中的每个样本for test_file in os.listdir(os.path.join(data_dir, 'test')):# 将测试集中的样本复制到新的目录结构中的 'test' 子目录下,标签为 'unknown'copyfile(os.path.join(data_dir, 'test', test_file),os.path.join(data_dir, 'train_valid_test', 'test', 'unknown'))
# 整个处理数据集函数
def reorg_cifar10_data(data_dir, valid_ratio):labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))reorg_train_valid(data_dir, labels, valid_ratio)reorg_test(data_dir)
- 这个小规模数据集的批量大小是32,在实际的cifar-10数据集中,可以设为128
- 将10%的训练样本作为调整超参数的验证集
batch_size = 32 if demo else 128
valid_ratio = 0.1
reorg_cifar10_data(data_dir, valid_ratio)
结果会生成一个train_valid_test的文件夹,里面有:
- test文件夹---unknow文件夹:5张没有标签的测试照片
- train_valid文件夹---10个类被的文件夹:每个文件夹包含所属类别的全部照片
- train文件夹--10个类别的文件夹:每个文件夹下包含90%的照片用于训练
- valid文件夹--10个类别的文件夹:每个文件夹下包含10%的照片用于验证
- 图像增广
transform_train = torchvision.transforms.Compose([# 原本图像是32*32,先放大成40*40, 在随机裁剪为32*32,实现训练数据的增强torchvision.transforms.Resize(40),torchvision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0), ratio=(1.0, 1.0)),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010])
])
transform_test = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),# 标准化图像的每个通道 : 消除评估结果中的随机性torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010])
])
- 加载数据集
train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train_valid_test', folder),transform=transform_train) for folder in ['train', 'train_valid']
]
valid_ds, test_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train_valid_test', folder), transform=transform_test) for folder in ['valid', 'test']
]
- 定义迭代器,方便快速迭代数据
train_iter, train_valid_iter = [torch.utils.data.DataLoader(dataset, batch_size, shuffle=True, drop_last=True) for dataset in (train_ds, train_valid_ds)
]
valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False, drop_last=True
)
test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False, drop_last=False
)
- 定义模型与损失函数
# 对resnet18做微调,输入通道数为3, 输出类别数为10
def get_net():num_classes = 10net = d2l.resnet18(num_classes, in_channels=3)return net
# 查看网络模型
get_net()

# 使用交叉熵损失函数作为损失函数: 直接返回n分样本的loss
loss = nn.CrossEntropyLoss(reduction='none')
- 定义训练函数
# 定义训练函数
def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay):trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=wd)scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)num_batches, timer = len(train_iter), d2l.Timer()legend = ['train loss', 'train acc']if valid_iter is not None:legend.append('valid acc')animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], legend=legend)net = nn.DataParallel(net, device_ids=devices).to(devices[0])for epoch in range(num_epochs):net.train()metric = d2l.Accumulator(3)for i, (features, labels) in enumerate(train_iter):timer.start()l, acc = d2l.train_batch_ch13(net, features, labels, loss, trainer, devices)metric.add(l, acc, labels.shape[0])timer.stop()if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(metric[0]/ metric[2], metric[1] / metric[2], None))if valid_iter is not None:valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)animator.add(epoch+1, (None, None, valid_acc))scheduler.step()measures = (f'train loss {metric[0] / metric[2]:.3f},'f'train acc{metric[1] / metric[2]:.3f}')if valid_iter is not None:measures += f', valid acc {valid_acc:.3f}'print(measures + f'\n{metric[2] * num_epochs /timer.sum():.1f}'f'example/sec on {str(devices)}')
- 训练模型
- (数据集太小,导致精度不高)
import time# 在开头设置开始时间
start = time.perf_counter() # start = time.clock() python3.8之前可以# 训练和验证模型
devices, num_epochs, lr, wd = d2l.try_all_gpus(), 20, 2e-4, 5e-4
lr_period, lr_decay, net = 4, 0.9, get_net()
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)# 在程序运行结束的位置添加结束时间
end = time.perf_counter() # end = time.clock() python3.8之前可以# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f}')

10. 对测试集进行分类并提交结果
net, preds = get_net(), []
train(net ,train_valid_iter, None, num_epochs, lr, wd, devices, lr_period, lr_decay)
for X, _ in test_iter:y_hat = net(X.to(devices[0]))preds.extend(y_hat.argmax(dim=1).type(torch.int32).cpu().numpy())
sorted_ids = list(range(1, len(test_ds) + 1))
sorted_ids.sort(key=lambda x: str(x))
df = pd.DataFrame({'id' : sorted_ids, 'label': preds})
df['label'] = df['label'].apply(lambda x: train_valid_ds.classes[x])
df.to_csv('submission.csv', index=False)

相关文章:
(动手学习深度学习)第13章 实战kaggle竞赛:CIFAR-10
导入相关库 import collections import math import os import shutil import pandas as pd import torch import torchvision from torch import nn from d2l import torch as d2l下载数据集 d2l.DATA_HUB[cifar10_tiny] (d2l.DATA_URL kaggle_cifar10_tiny.zip,2068874e4…...
Go 语言中的map和内存泄漏
map在内存中总是会增长;它不会收缩。因此,如果map导致了一些内存问题,你可以尝试不同的选项,比如强制 Go 重新创建map或使用指针。 在 Go 中使用map时,我们需要了解map增长和收缩的一些重要特性。让我们深入探讨这一点…...
前缀和(c++,超详细,含二维)
前缀和与差分 当给定一段整数序列a1,a2,a3,a4,a5…an; 每次让我们求一段区间的和,正常做法是for循环遍历区间起始点到结束点,进行求和计算,但是当询问次数很多并且区间很长的时候 比如,10^5 个询问和10^6区间长度,相…...
详解FreeRTOS:二值信号量和计数信号量(高级篇—2)
目录 1、二值信号量 1.1、二值信号量运行机制 1.2、创建二值信号量 1...
持续集成交付CICD:Jenkins通过API触发流水线
目录 一、理论 1.HTTP请求 2.调用接口的方法 3.HTTP常见错误码 二、实验 1.Jenkins通过API触发流水线 三、问题 1.如何拿到上一次jenkinsfile文件进行自动触发流水线 一、理论 1.HTTP请求 (1)概念 HTTP超文本传输协议,是确保服务器…...
【Python】12 GPflow安装
概述 GPflow 是一个基于TensorFlow 在 Python 中构建高斯过程模型的包。高斯过程是一种监督学习模型。 高斯过程的一些优点是: 不确定性是高斯过程的固有部分。高斯过程可以在不知道答案时告诉您。适用于小型数据集。如果您的数据有限,高斯过程可以从…...
Ubuntu源码编译gdal3.6.2
在华为云申请了一台Ubuntu v18的机器,乱七八糟的不要装。 apt install build-essential pkg-config -y cmake-3.21.1 apt-get install openssl libssl-dev 过程参考:Yukon for PostgreSQL_格來羙、日出的博客-CSDN博客 zlib-1.2.9(不需要) 如果用系统的后面gd…...
【LeetCode】160. 相交链表
160. 相交链表 难度:简单 题目 给你两个单链表的头节点 headA 和 headB ,请你找出并返回两个单链表相交的起始节点。如果两个链表不存在相交节点,返回 null 。 图示两个链表在节点 c1 开始相交: 题目数据 保证 整个链式结构中…...
数据集笔记:NGSIM (next generation simulation)
1 数据集介绍 数据介绍s Next Generation Simulation (NGSIM) Open Data (transportation.gov) 数据地址:Next Generation Simulation (NGSIM) Vehicle Trajectories and Supporting Data | Department of Transportation - Data Portal 时间2005年到2006年间地…...
解决docker运行elastic服务端启动不成功
现象: 然后查看docker日志,发现有vm.max_map_count报错 ERROR: [1] bootstrap checks failed [1]: max virtual memory areas vm.max_map_count [65530] is too low, increase to at least [262144] 解决办法: 1. 宿主机(运行doc…...
mysql数据库中mysql database 数据被破坏产生的一系列问题
在执行sql脚本时,没有注意到sql脚本文件包含了对mysql 原始数据库的操作,执行了脚本。 脚本执行成功之后,登录或链接数据库查看数据时报错: The user specified as a definer (‘mysql.infoschema’‘localhost’) does not exis…...
基于变形卷积和注意机制的带钢表面缺陷快速检测网络DCAM-Net(论文阅读笔记)
原论文链接->DCAM-Net: A Rapid Detection Network for Strip Steel Surface Defects Based on Deformable Convolution and Attention Mechanism | IEEE Journals & Magazine | IEEE Xplore DCAM-Net: A Rapid Detection Network for Strip Steel Surface Defects Base…...
05-Spring Boot工程中简化开发的方式Lombok和dev-tools
简化开发的方式Lombok和dev-tools Lombok常用注解 Lombok用标签方式代替构造器、getter/setter、toString()等重复代码, 在程序编译的时候自动生成这些代码 注解名功能NoArgsConstructor生成无参构造方法AllArgsConstructor生产含所有属性的有参构造方法,如果不希望含所有属…...
AIGC 技术在淘淘秀场景的探索与实践
本文介绍了AIGC相关领域的爆发式增长,并探讨了淘宝秀秀(AI买家秀)的设计思路和技术方案。文章涵盖了图像生成、仿真形象生成和换背景方案,以及模型流程串联等关键技术。 文章还介绍了淘淘秀的使用流程和遇到的问题及处理方法。最后,文章展望…...
ANSYS网格无关性检查
网格精度对应力结果存在很大的影响,有时候可以发现,随着网格精度逐渐提高,所求得的最大应力值逐渐趋于收敛。 默认网格: 从默认网格下计算出的应力云图可以发现,出现了的三处应力奇异点,此时算出的应力值是…...
设计模式-责任链-笔记
动机(Motivation) 在软件构建过程中,一个请求可能被多个对象处理,但是每个请求在运行时只能有个接受者,如果显示指定,将必不可少地带来请求者与接受者的紧耦合。 如何使请求的发送者不需要指定具体的接受…...
SpringMvc请求原理流程
springmvc是用户和服务沟通的桥梁,官网提供了springmvc的全面使用和解释:DispatcherServlet :: Spring Framework 流程 1.Tomcat启动 2.解析web.xml文件,根据servlet-class找到DispatcherServlet,根据init-param来获取spring的…...
【开源】基于Vue.js的音乐偏好度推荐系统的设计和实现
项目编号: S 012 ,文末获取源码。 \color{red}{项目编号:S012,文末获取源码。} 项目编号:S012,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、系统设计2.1 功能模块设计2.1.1 音乐档案模块2.1…...
采集1688整店商品(店铺所有商品、店铺列表api)
返回数据: 请求链接 {"user": [],"items": {"item": [{"num_iid": "738354436678","title": "国产正品i13 promax全网通5G安卓智能手机源头厂家批发手机","pic_url": "http…...
IObit Unlocker丨解除占用程序软件
更多内容请收藏:https://rwx.tza-3.xyz 官网:IObit Unlocker “永远不用担心电脑上无法删除的文件。” 界面简单,支持简体中文,一看就会,只需要把无法删除/移动的文件或整个U盘拖到框里就行。 解锁率很高,…...
SpringBoot-17-MyBatis动态SQL标签之常用标签
文章目录 1 代码1.1 实体User.java1.2 接口UserMapper.java1.3 映射UserMapper.xml1.3.1 标签if1.3.2 标签if和where1.3.3 标签choose和when和otherwise1.4 UserController.java2 常用动态SQL标签2.1 标签set2.1.1 UserMapper.java2.1.2 UserMapper.xml2.1.3 UserController.ja…...
web vue 项目 Docker化部署
Web 项目 Docker 化部署详细教程 目录 Web 项目 Docker 化部署概述Dockerfile 详解 构建阶段生产阶段 构建和运行 Docker 镜像 1. Web 项目 Docker 化部署概述 Docker 化部署的主要步骤分为以下几个阶段: 构建阶段(Build Stage):…...
css实现圆环展示百分比,根据值动态展示所占比例
代码如下 <view class""><view class"circle-chart"><view v-if"!!num" class"pie-item" :style"{background: conic-gradient(var(--one-color) 0%,#E9E6F1 ${num}%),}"></view><view v-else …...
页面渲染流程与性能优化
页面渲染流程与性能优化详解(完整版) 一、现代浏览器渲染流程(详细说明) 1. 构建DOM树 浏览器接收到HTML文档后,会逐步解析并构建DOM(Document Object Model)树。具体过程如下: (…...
ardupilot 开发环境eclipse 中import 缺少C++
目录 文章目录 目录摘要1.修复过程摘要 本节主要解决ardupilot 开发环境eclipse 中import 缺少C++,无法导入ardupilot代码,会引起查看不方便的问题。如下图所示 1.修复过程 0.安装ubuntu 软件中自带的eclipse 1.打开eclipse—Help—install new software 2.在 Work with中…...
IT供电系统绝缘监测及故障定位解决方案
随着新能源的快速发展,光伏电站、储能系统及充电设备已广泛应用于现代能源网络。在光伏领域,IT供电系统凭借其持续供电性好、安全性高等优势成为光伏首选,但在长期运行中,例如老化、潮湿、隐裂、机械损伤等问题会影响光伏板绝缘层…...
什么是VR全景技术
VR全景技术,全称为虚拟现实全景技术,是通过计算机图像模拟生成三维空间中的虚拟世界,使用户能够在该虚拟世界中进行全方位、无死角的观察和交互的技术。VR全景技术模拟人在真实空间中的视觉体验,结合图文、3D、音视频等多媒体元素…...
第一篇:Liunx环境下搭建PaddlePaddle 3.0基础环境(Liunx Centos8.5安装Python3.10+pip3.10)
第一篇:Liunx环境下搭建PaddlePaddle 3.0基础环境(Liunx Centos8.5安装Python3.10pip3.10) 一:前言二:安装编译依赖二:安装Python3.10三:安装PIP3.10四:安装Paddlepaddle基础框架4.1…...
react菜单,动态绑定点击事件,菜单分离出去单独的js文件,Ant框架
1、菜单文件treeTop.js // 顶部菜单 import { AppstoreOutlined, SettingOutlined } from ant-design/icons; // 定义菜单项数据 const treeTop [{label: Docker管理,key: 1,icon: <AppstoreOutlined />,url:"/docker/index"},{label: 权限管理,key: 2,icon:…...
「Java基本语法」变量的使用
变量定义 变量是程序中存储数据的容器,用于保存可变的数据值。在Java中,变量必须先声明后使用,声明时需指定变量的数据类型和变量名。 语法 数据类型 变量名 [ 初始值]; 示例:声明与初始化 public class VariableDemo {publi…...
