(动手学习深度学习)第13章 实战kaggle竞赛:CIFAR-10
- 导入相关库
import collections
import math
import os
import shutil
import pandas as pd
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
- 下载数据集
d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip','2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')# 如果使用完整的Kaggle竞赛的数据集,设置demo为False
demo = Trueif demo:data_dir = d2l.download_extract('cifar10_tiny')
else:data_dir = '../data/kaggle/cifar-10/'
- 整理数据集
# 查看数据集
def read_csv_labels(fname):"""读取‘fname’来给标签字典返回一个文件名"""with open(fname, 'r') as f:lines = f.readlines()[1:] # readlines(): 每次读文档的一行,以后还需要逐步循环tokens = [l.rstrip().split(',') for l in lines] # rstrip(): 删除字符串后面(右面)的空格或特殊字符, 还有lstrip(左面)、strip(两面)return dict((name, label) for name, label in tokens)labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
print('训练样本:', len(labels))
print('类别:', len(set(labels.values()))) # set(): 集合,里面不能包含重复的元素,接受一个list作为参数

将验证集从原始的训练集钟拆分出来
# 拆分数据集:训练集、验证集
def copyfile(filename, target_dir):"""将文件复制到目标目录"""os.makedirs(target_dir, exist_ok=True) # 创建多层目录,exist_ok为True:在目标目录已存在的情况下不会触发FileExistsError异常。shutil.copy(filename, target_dir) #拷贝文件,filename:要拷贝的文件;target_dir:目标文件夹def reorg_train_valid(data_dir, labels, valid_ratio):"""将验证集从原始训练集钟拆分出来"""# 训练数据集中样本数量最少的类别中的样本数# Counter: 计数器,返回一个字典,键为元素,值为元素个数;# .most_common(): 返回一个列表, 列表元素为(元素,出现次数),默认按出现频率排序# [-1]: 样本数量最少的类别(类别, 样本数),[-1][1]: 样本数数量最少的类别中的样本数n = collections.Counter(labels.values()).most_common()[-1][1]# 验证集中每个类别的样本数n_valid_per_label= max(1, math.floor((n * valid_ratio))) # math.floor(): 向下取整 math.ceil(): 向上取整label_count = {}# 遍历原始训练集中的每个样本for train_file in os.listdir(os.path.join(data_dir, 'train')):label = labels[train_file.split('.')[0]] # 从文件名中提取标签fname = os.path.join(data_dir, 'train', train_file)copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'train_valid', label))# 如果该类别的样本数还未达到在验证集中的设定数量,则将样本复制到验证集中if label not in label_count or label_count[label] < n_valid_per_label:copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'valid', label))label_count[label] = label_count.get(label, 0) + 1else:copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'train', label))return n_valid_per_label# reorg_test函数用来在预测期间整理测试集,以方便读取
def reorg_test(data_dir):"""在预测期间整理测试集,以方便读取"""# 遍历测试集中的每个样本for test_file in os.listdir(os.path.join(data_dir, 'test')):# 将测试集中的样本复制到新的目录结构中的 'test' 子目录下,标签为 'unknown'copyfile(os.path.join(data_dir, 'test', test_file),os.path.join(data_dir, 'train_valid_test', 'test', 'unknown'))
# 整个处理数据集函数
def reorg_cifar10_data(data_dir, valid_ratio):labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))reorg_train_valid(data_dir, labels, valid_ratio)reorg_test(data_dir)
- 这个小规模数据集的批量大小是32,在实际的cifar-10数据集中,可以设为128
- 将10%的训练样本作为调整超参数的验证集
batch_size = 32 if demo else 128
valid_ratio = 0.1
reorg_cifar10_data(data_dir, valid_ratio)
结果会生成一个train_valid_test的文件夹,里面有:
- test文件夹---unknow文件夹:5张没有标签的测试照片
- train_valid文件夹---10个类被的文件夹:每个文件夹包含所属类别的全部照片
- train文件夹--10个类别的文件夹:每个文件夹下包含90%的照片用于训练
- valid文件夹--10个类别的文件夹:每个文件夹下包含10%的照片用于验证
- 图像增广
transform_train = torchvision.transforms.Compose([# 原本图像是32*32,先放大成40*40, 在随机裁剪为32*32,实现训练数据的增强torchvision.transforms.Resize(40),torchvision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0), ratio=(1.0, 1.0)),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010])
])
transform_test = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),# 标准化图像的每个通道 : 消除评估结果中的随机性torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010])
])
- 加载数据集
train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train_valid_test', folder),transform=transform_train) for folder in ['train', 'train_valid']
]
valid_ds, test_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train_valid_test', folder), transform=transform_test) for folder in ['valid', 'test']
]
- 定义迭代器,方便快速迭代数据
train_iter, train_valid_iter = [torch.utils.data.DataLoader(dataset, batch_size, shuffle=True, drop_last=True) for dataset in (train_ds, train_valid_ds)
]
valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False, drop_last=True
)
test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False, drop_last=False
)
- 定义模型与损失函数
# 对resnet18做微调,输入通道数为3, 输出类别数为10
def get_net():num_classes = 10net = d2l.resnet18(num_classes, in_channels=3)return net
# 查看网络模型
get_net()

# 使用交叉熵损失函数作为损失函数: 直接返回n分样本的loss
loss = nn.CrossEntropyLoss(reduction='none')
- 定义训练函数
# 定义训练函数
def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay):trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=wd)scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)num_batches, timer = len(train_iter), d2l.Timer()legend = ['train loss', 'train acc']if valid_iter is not None:legend.append('valid acc')animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], legend=legend)net = nn.DataParallel(net, device_ids=devices).to(devices[0])for epoch in range(num_epochs):net.train()metric = d2l.Accumulator(3)for i, (features, labels) in enumerate(train_iter):timer.start()l, acc = d2l.train_batch_ch13(net, features, labels, loss, trainer, devices)metric.add(l, acc, labels.shape[0])timer.stop()if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(metric[0]/ metric[2], metric[1] / metric[2], None))if valid_iter is not None:valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)animator.add(epoch+1, (None, None, valid_acc))scheduler.step()measures = (f'train loss {metric[0] / metric[2]:.3f},'f'train acc{metric[1] / metric[2]:.3f}')if valid_iter is not None:measures += f', valid acc {valid_acc:.3f}'print(measures + f'\n{metric[2] * num_epochs /timer.sum():.1f}'f'example/sec on {str(devices)}')
- 训练模型
- (数据集太小,导致精度不高)
import time# 在开头设置开始时间
start = time.perf_counter() # start = time.clock() python3.8之前可以# 训练和验证模型
devices, num_epochs, lr, wd = d2l.try_all_gpus(), 20, 2e-4, 5e-4
lr_period, lr_decay, net = 4, 0.9, get_net()
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)# 在程序运行结束的位置添加结束时间
end = time.perf_counter() # end = time.clock() python3.8之前可以# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f}')

10. 对测试集进行分类并提交结果
net, preds = get_net(), []
train(net ,train_valid_iter, None, num_epochs, lr, wd, devices, lr_period, lr_decay)
for X, _ in test_iter:y_hat = net(X.to(devices[0]))preds.extend(y_hat.argmax(dim=1).type(torch.int32).cpu().numpy())
sorted_ids = list(range(1, len(test_ds) + 1))
sorted_ids.sort(key=lambda x: str(x))
df = pd.DataFrame({'id' : sorted_ids, 'label': preds})
df['label'] = df['label'].apply(lambda x: train_valid_ds.classes[x])
df.to_csv('submission.csv', index=False)

相关文章:
(动手学习深度学习)第13章 实战kaggle竞赛:CIFAR-10
导入相关库 import collections import math import os import shutil import pandas as pd import torch import torchvision from torch import nn from d2l import torch as d2l下载数据集 d2l.DATA_HUB[cifar10_tiny] (d2l.DATA_URL kaggle_cifar10_tiny.zip,2068874e4…...
Go 语言中的map和内存泄漏
map在内存中总是会增长;它不会收缩。因此,如果map导致了一些内存问题,你可以尝试不同的选项,比如强制 Go 重新创建map或使用指针。 在 Go 中使用map时,我们需要了解map增长和收缩的一些重要特性。让我们深入探讨这一点…...
前缀和(c++,超详细,含二维)
前缀和与差分 当给定一段整数序列a1,a2,a3,a4,a5…an; 每次让我们求一段区间的和,正常做法是for循环遍历区间起始点到结束点,进行求和计算,但是当询问次数很多并且区间很长的时候 比如,10^5 个询问和10^6区间长度,相…...
详解FreeRTOS:二值信号量和计数信号量(高级篇—2)
目录 1、二值信号量 1.1、二值信号量运行机制 1.2、创建二值信号量 1...
持续集成交付CICD:Jenkins通过API触发流水线
目录 一、理论 1.HTTP请求 2.调用接口的方法 3.HTTP常见错误码 二、实验 1.Jenkins通过API触发流水线 三、问题 1.如何拿到上一次jenkinsfile文件进行自动触发流水线 一、理论 1.HTTP请求 (1)概念 HTTP超文本传输协议,是确保服务器…...
【Python】12 GPflow安装
概述 GPflow 是一个基于TensorFlow 在 Python 中构建高斯过程模型的包。高斯过程是一种监督学习模型。 高斯过程的一些优点是: 不确定性是高斯过程的固有部分。高斯过程可以在不知道答案时告诉您。适用于小型数据集。如果您的数据有限,高斯过程可以从…...
Ubuntu源码编译gdal3.6.2
在华为云申请了一台Ubuntu v18的机器,乱七八糟的不要装。 apt install build-essential pkg-config -y cmake-3.21.1 apt-get install openssl libssl-dev 过程参考:Yukon for PostgreSQL_格來羙、日出的博客-CSDN博客 zlib-1.2.9(不需要) 如果用系统的后面gd…...
【LeetCode】160. 相交链表
160. 相交链表 难度:简单 题目 给你两个单链表的头节点 headA 和 headB ,请你找出并返回两个单链表相交的起始节点。如果两个链表不存在相交节点,返回 null 。 图示两个链表在节点 c1 开始相交: 题目数据 保证 整个链式结构中…...
数据集笔记:NGSIM (next generation simulation)
1 数据集介绍 数据介绍s Next Generation Simulation (NGSIM) Open Data (transportation.gov) 数据地址:Next Generation Simulation (NGSIM) Vehicle Trajectories and Supporting Data | Department of Transportation - Data Portal 时间2005年到2006年间地…...
解决docker运行elastic服务端启动不成功
现象: 然后查看docker日志,发现有vm.max_map_count报错 ERROR: [1] bootstrap checks failed [1]: max virtual memory areas vm.max_map_count [65530] is too low, increase to at least [262144] 解决办法: 1. 宿主机(运行doc…...
mysql数据库中mysql database 数据被破坏产生的一系列问题
在执行sql脚本时,没有注意到sql脚本文件包含了对mysql 原始数据库的操作,执行了脚本。 脚本执行成功之后,登录或链接数据库查看数据时报错: The user specified as a definer (‘mysql.infoschema’‘localhost’) does not exis…...
基于变形卷积和注意机制的带钢表面缺陷快速检测网络DCAM-Net(论文阅读笔记)
原论文链接->DCAM-Net: A Rapid Detection Network for Strip Steel Surface Defects Based on Deformable Convolution and Attention Mechanism | IEEE Journals & Magazine | IEEE Xplore DCAM-Net: A Rapid Detection Network for Strip Steel Surface Defects Base…...
05-Spring Boot工程中简化开发的方式Lombok和dev-tools
简化开发的方式Lombok和dev-tools Lombok常用注解 Lombok用标签方式代替构造器、getter/setter、toString()等重复代码, 在程序编译的时候自动生成这些代码 注解名功能NoArgsConstructor生成无参构造方法AllArgsConstructor生产含所有属性的有参构造方法,如果不希望含所有属…...
AIGC 技术在淘淘秀场景的探索与实践
本文介绍了AIGC相关领域的爆发式增长,并探讨了淘宝秀秀(AI买家秀)的设计思路和技术方案。文章涵盖了图像生成、仿真形象生成和换背景方案,以及模型流程串联等关键技术。 文章还介绍了淘淘秀的使用流程和遇到的问题及处理方法。最后,文章展望…...
ANSYS网格无关性检查
网格精度对应力结果存在很大的影响,有时候可以发现,随着网格精度逐渐提高,所求得的最大应力值逐渐趋于收敛。 默认网格: 从默认网格下计算出的应力云图可以发现,出现了的三处应力奇异点,此时算出的应力值是…...
设计模式-责任链-笔记
动机(Motivation) 在软件构建过程中,一个请求可能被多个对象处理,但是每个请求在运行时只能有个接受者,如果显示指定,将必不可少地带来请求者与接受者的紧耦合。 如何使请求的发送者不需要指定具体的接受…...
SpringMvc请求原理流程
springmvc是用户和服务沟通的桥梁,官网提供了springmvc的全面使用和解释:DispatcherServlet :: Spring Framework 流程 1.Tomcat启动 2.解析web.xml文件,根据servlet-class找到DispatcherServlet,根据init-param来获取spring的…...
【开源】基于Vue.js的音乐偏好度推荐系统的设计和实现
项目编号: S 012 ,文末获取源码。 \color{red}{项目编号:S012,文末获取源码。} 项目编号:S012,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、系统设计2.1 功能模块设计2.1.1 音乐档案模块2.1…...
采集1688整店商品(店铺所有商品、店铺列表api)
返回数据: 请求链接 {"user": [],"items": {"item": [{"num_iid": "738354436678","title": "国产正品i13 promax全网通5G安卓智能手机源头厂家批发手机","pic_url": "http…...
IObit Unlocker丨解除占用程序软件
更多内容请收藏:https://rwx.tza-3.xyz 官网:IObit Unlocker “永远不用担心电脑上无法删除的文件。” 界面简单,支持简体中文,一看就会,只需要把无法删除/移动的文件或整个U盘拖到框里就行。 解锁率很高,…...
Zustand 状态管理库:极简而强大的解决方案
Zustand 是一个轻量级、快速和可扩展的状态管理库,特别适合 React 应用。它以简洁的 API 和高效的性能解决了 Redux 等状态管理方案中的繁琐问题。 核心优势对比 基本使用指南 1. 创建 Store // store.js import create from zustandconst useStore create((set)…...
2025年能源电力系统与流体力学国际会议 (EPSFD 2025)
2025年能源电力系统与流体力学国际会议(EPSFD 2025)将于本年度在美丽的杭州盛大召开。作为全球能源、电力系统以及流体力学领域的顶级盛会,EPSFD 2025旨在为来自世界各地的科学家、工程师和研究人员提供一个展示最新研究成果、分享实践经验及…...
基于uniapp+WebSocket实现聊天对话、消息监听、消息推送、聊天室等功能,多端兼容
基于 UniApp + WebSocket实现多端兼容的实时通讯系统,涵盖WebSocket连接建立、消息收发机制、多端兼容性配置、消息实时监听等功能,适配微信小程序、H5、Android、iOS等终端 目录 技术选型分析WebSocket协议优势UniApp跨平台特性WebSocket 基础实现连接管理消息收发连接…...
Swift 协议扩展精进之路:解决 CoreData 托管实体子类的类型不匹配问题(下)
概述 在 Swift 开发语言中,各位秃头小码农们可以充分利用语法本身所带来的便利去劈荆斩棘。我们还可以恣意利用泛型、协议关联类型和协议扩展来进一步简化和优化我们复杂的代码需求。 不过,在涉及到多个子类派生于基类进行多态模拟的场景下,…...
YSYX学习记录(八)
C语言,练习0: 先创建一个文件夹,我用的是物理机: 安装build-essential 练习1: 我注释掉了 #include <stdio.h> 出现下面错误 在你的文本编辑器中打开ex1文件,随机修改或删除一部分,之后…...
【数据分析】R版IntelliGenes用于生物标志物发现的可解释机器学习
禁止商业或二改转载,仅供自学使用,侵权必究,如需截取部分内容请后台联系作者! 文章目录 介绍流程步骤1. 输入数据2. 特征选择3. 模型训练4. I-Genes 评分计算5. 输出结果 IntelliGenesR 安装包1. 特征选择2. 模型训练和评估3. I-Genes 评分计…...
ABAP设计模式之---“简单设计原则(Simple Design)”
“Simple Design”(简单设计)是软件开发中的一个重要理念,倡导以最简单的方式实现软件功能,以确保代码清晰易懂、易维护,并在项目需求变化时能够快速适应。 其核心目标是避免复杂和过度设计,遵循“让事情保…...
【VLNs篇】07:NavRL—在动态环境中学习安全飞行
项目内容论文标题NavRL: 在动态环境中学习安全飞行 (NavRL: Learning Safe Flight in Dynamic Environments)核心问题解决无人机在包含静态和动态障碍物的复杂环境中进行安全、高效自主导航的挑战,克服传统方法和现有强化学习方法的局限性。核心算法基于近端策略优化…...
Git 3天2K星标:Datawhale 的 Happy-LLM 项目介绍(附教程)
引言 在人工智能飞速发展的今天,大语言模型(Large Language Models, LLMs)已成为技术领域的焦点。从智能写作到代码生成,LLM 的应用场景不断扩展,深刻改变了我们的工作和生活方式。然而,理解这些模型的内部…...
jmeter聚合报告中参数详解
sample、average、min、max、90%line、95%line,99%line、Error错误率、吞吐量Thoughput、KB/sec每秒传输的数据量 sample(样本数) 表示测试中发送的请求数量,即测试执行了多少次请求。 单位,以个或者次数表示。 示例:…...
