当前位置: 首页 > news >正文

(动手学习深度学习)第13章 实战kaggle竞赛:CIFAR-10

  1. 导入相关库
import collections
import math
import os
import shutil
import pandas as pd
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
  1. 下载数据集
d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip','2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')# 如果使用完整的Kaggle竞赛的数据集,设置demo为False
demo = Trueif demo:data_dir = d2l.download_extract('cifar10_tiny')
else:data_dir = '../data/kaggle/cifar-10/'
  1. 整理数据集
# 查看数据集
def read_csv_labels(fname):"""读取‘fname’来给标签字典返回一个文件名"""with open(fname, 'r') as f:lines = f.readlines()[1:]  # readlines(): 每次读文档的一行,以后还需要逐步循环tokens = [l.rstrip().split(',') for l in lines]  # rstrip(): 删除字符串后面(右面)的空格或特殊字符, 还有lstrip(左面)、strip(两面)return dict((name, label) for name, label in tokens)labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
print('训练样本:', len(labels))
print('类别:', len(set(labels.values())))  # set(): 集合,里面不能包含重复的元素,接受一个list作为参数

在这里插入图片描述
将验证集从原始的训练集钟拆分出来

# 拆分数据集:训练集、验证集
def copyfile(filename, target_dir):"""将文件复制到目标目录"""os.makedirs(target_dir, exist_ok=True)  # 创建多层目录,exist_ok为True:在目标目录已存在的情况下不会触发FileExistsError异常。shutil.copy(filename, target_dir)  #拷贝文件,filename:要拷贝的文件;target_dir:目标文件夹def reorg_train_valid(data_dir, labels, valid_ratio):"""将验证集从原始训练集钟拆分出来"""# 训练数据集中样本数量最少的类别中的样本数# Counter: 计数器,返回一个字典,键为元素,值为元素个数;# .most_common(): 返回一个列表, 列表元素为(元素,出现次数),默认按出现频率排序# [-1]: 样本数量最少的类别(类别, 样本数),[-1][1]: 样本数数量最少的类别中的样本数n = collections.Counter(labels.values()).most_common()[-1][1]# 验证集中每个类别的样本数n_valid_per_label= max(1, math.floor((n * valid_ratio)))  # math.floor(): 向下取整  math.ceil(): 向上取整label_count = {}# 遍历原始训练集中的每个样本for train_file in os.listdir(os.path.join(data_dir, 'train')):label = labels[train_file.split('.')[0]]  # 从文件名中提取标签fname = os.path.join(data_dir, 'train', train_file)copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'train_valid', label))# 如果该类别的样本数还未达到在验证集中的设定数量,则将样本复制到验证集中if label not in label_count or label_count[label] < n_valid_per_label:copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'valid', label))label_count[label] = label_count.get(label, 0) + 1else:copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'train', label))return n_valid_per_label# reorg_test函数用来在预测期间整理测试集,以方便读取
def reorg_test(data_dir):"""在预测期间整理测试集,以方便读取"""# 遍历测试集中的每个样本for test_file in os.listdir(os.path.join(data_dir, 'test')):# 将测试集中的样本复制到新的目录结构中的 'test' 子目录下,标签为 'unknown'copyfile(os.path.join(data_dir, 'test', test_file),os.path.join(data_dir, 'train_valid_test', 'test', 'unknown'))
# 整个处理数据集函数
def reorg_cifar10_data(data_dir, valid_ratio):labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))reorg_train_valid(data_dir, labels, valid_ratio)reorg_test(data_dir)
  • 这个小规模数据集的批量大小是32,在实际的cifar-10数据集中,可以设为128
  • 将10%的训练样本作为调整超参数的验证集
batch_size = 32 if demo else 128
valid_ratio = 0.1
reorg_cifar10_data(data_dir, valid_ratio)
结果会生成一个train_valid_test的文件夹,里面有:
- test文件夹---unknow文件夹:5张没有标签的测试照片
- train_valid文件夹---10个类被的文件夹:每个文件夹包含所属类别的全部照片
- train文件夹--10个类别的文件夹:每个文件夹下包含90%的照片用于训练
- valid文件夹--10个类别的文件夹:每个文件夹下包含10%的照片用于验证
  1. 图像增广
transform_train = torchvision.transforms.Compose([# 原本图像是32*32,先放大成40*40, 在随机裁剪为32*32,实现训练数据的增强torchvision.transforms.Resize(40),torchvision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0), ratio=(1.0, 1.0)),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010])
])
transform_test = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),# 标准化图像的每个通道 : 消除评估结果中的随机性torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010])
])
  1. 加载数据集
train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train_valid_test', folder),transform=transform_train) for folder in ['train', 'train_valid']
]
valid_ds, test_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train_valid_test', folder), transform=transform_test) for folder in ['valid', 'test']
]
  1. 定义迭代器,方便快速迭代数据
train_iter, train_valid_iter = [torch.utils.data.DataLoader(dataset, batch_size, shuffle=True, drop_last=True) for dataset in (train_ds, train_valid_ds)
]
valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False, drop_last=True
)
test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False, drop_last=False
)
  1. 定义模型与损失函数
# 对resnet18做微调,输入通道数为3, 输出类别数为10
def get_net():num_classes = 10net = d2l.resnet18(num_classes, in_channels=3)return net
# 查看网络模型
get_net()

在这里插入图片描述

# 使用交叉熵损失函数作为损失函数: 直接返回n分样本的loss
loss = nn.CrossEntropyLoss(reduction='none')
  1. 定义训练函数
# 定义训练函数
def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay):trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=wd)scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)num_batches, timer = len(train_iter), d2l.Timer()legend = ['train loss', 'train acc']if valid_iter is not None:legend.append('valid acc')animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], legend=legend)net = nn.DataParallel(net, device_ids=devices).to(devices[0])for epoch in range(num_epochs):net.train()metric = d2l.Accumulator(3)for i, (features, labels) in enumerate(train_iter):timer.start()l, acc = d2l.train_batch_ch13(net, features, labels, loss, trainer, devices)metric.add(l, acc, labels.shape[0])timer.stop()if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(metric[0]/ metric[2], metric[1] / metric[2], None))if valid_iter is not None:valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)animator.add(epoch+1, (None, None, valid_acc))scheduler.step()measures = (f'train loss {metric[0] / metric[2]:.3f},'f'train acc{metric[1] / metric[2]:.3f}')if valid_iter is not None:measures += f', valid acc {valid_acc:.3f}'print(measures + f'\n{metric[2] * num_epochs /timer.sum():.1f}'f'example/sec on {str(devices)}')
  1. 训练模型
    • (数据集太小,导致精度不高)
import time# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以# 训练和验证模型
devices, num_epochs, lr, wd = d2l.try_all_gpus(), 20, 2e-4, 5e-4
lr_period, lr_decay, net = 4, 0.9, get_net()
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f}')

在这里插入图片描述
10. 对测试集进行分类并提交结果

net, preds = get_net(), []
train(net ,train_valid_iter, None, num_epochs, lr, wd, devices, lr_period, lr_decay)
for X, _ in test_iter:y_hat = net(X.to(devices[0]))preds.extend(y_hat.argmax(dim=1).type(torch.int32).cpu().numpy())
sorted_ids = list(range(1, len(test_ds) + 1))
sorted_ids.sort(key=lambda x: str(x))
df = pd.DataFrame({'id' : sorted_ids, 'label': preds})
df['label'] = df['label'].apply(lambda x: train_valid_ds.classes[x])
df.to_csv('submission.csv', index=False)

在这里插入图片描述

相关文章:

(动手学习深度学习)第13章 实战kaggle竞赛:CIFAR-10

导入相关库 import collections import math import os import shutil import pandas as pd import torch import torchvision from torch import nn from d2l import torch as d2l下载数据集 d2l.DATA_HUB[cifar10_tiny] (d2l.DATA_URL kaggle_cifar10_tiny.zip,2068874e4…...

Go 语言中的map和内存泄漏

map在内存中总是会增长&#xff1b;它不会收缩。因此&#xff0c;如果map导致了一些内存问题&#xff0c;你可以尝试不同的选项&#xff0c;比如强制 Go 重新创建map或使用指针。 在 Go 中使用map时&#xff0c;我们需要了解map增长和收缩的一些重要特性。让我们深入探讨这一点…...

前缀和(c++,超详细,含二维)

前缀和与差分 当给定一段整数序列a1,a2,a3,a4,a5…an; 每次让我们求一段区间的和&#xff0c;正常做法是for循环遍历区间起始点到结束点&#xff0c;进行求和计算&#xff0c;但是当询问次数很多并且区间很长的时候 比如&#xff0c;10^5 个询问和10^6区间长度&#xff0c;相…...

详解FreeRTOS:二值信号量和计数信号量(高级篇—2)

目录 1、二值信号量 1.1、二值信号量运行机制 1.2、创建二值信号量 1...

持续集成交付CICD:Jenkins通过API触发流水线

目录 一、理论 1.HTTP请求 2.调用接口的方法 3.HTTP常见错误码 二、实验 1.Jenkins通过API触发流水线 三、问题 1.如何拿到上一次jenkinsfile文件进行自动触发流水线 一、理论 1.HTTP请求 &#xff08;1&#xff09;概念 HTTP超文本传输协议&#xff0c;是确保服务器…...

【Python】12 GPflow安装

概述 GPflow 是一个基于TensorFlow 在 Python 中构建高斯过程模型的包。高斯过程是一种监督学习模型。 高斯过程的一些优点是&#xff1a; 不确定性是高斯过程的固有部分。高斯过程可以在不知道答案时告诉您。适用于小型数据集。如果您的数据有限&#xff0c;高斯过程可以从…...

Ubuntu源码编译gdal3.6.2

在华为云申请了一台Ubuntu v18的机器,乱七八糟的不要装。 apt install build-essential pkg-config -y cmake-3.21.1 apt-get install openssl libssl-dev 过程参考&#xff1a;Yukon for PostgreSQL_格來羙、日出的博客-CSDN博客 zlib-1.2.9(不需要) 如果用系统的后面gd…...

【LeetCode】160. 相交链表

160. 相交链表 难度&#xff1a;简单 题目 给你两个单链表的头节点 headA 和 headB &#xff0c;请你找出并返回两个单链表相交的起始节点。如果两个链表不存在相交节点&#xff0c;返回 null 。 图示两个链表在节点 c1 开始相交&#xff1a; 题目数据 保证 整个链式结构中…...

数据集笔记:NGSIM (next generation simulation)

1 数据集介绍 数据介绍s Next Generation Simulation (NGSIM) Open Data (transportation.gov) 数据地址&#xff1a;Next Generation Simulation (NGSIM) Vehicle Trajectories and Supporting Data | Department of Transportation - Data Portal 时间2005年到2006年间地…...

解决docker运行elastic服务端启动不成功

现象&#xff1a; 然后查看docker日志&#xff0c;发现有vm.max_map_count报错 ERROR: [1] bootstrap checks failed [1]: max virtual memory areas vm.max_map_count [65530] is too low, increase to at least [262144] 解决办法&#xff1a; 1. 宿主机&#xff08;运行doc…...

mysql数据库中mysql database 数据被破坏产生的一系列问题

在执行sql脚本时&#xff0c;没有注意到sql脚本文件包含了对mysql 原始数据库的操作&#xff0c;执行了脚本。 脚本执行成功之后&#xff0c;登录或链接数据库查看数据时报错&#xff1a; The user specified as a definer (‘mysql.infoschema’‘localhost’) does not exis…...

基于变形卷积和注意机制的带钢表面缺陷快速检测网络DCAM-Net(论文阅读笔记)

原论文链接->DCAM-Net: A Rapid Detection Network for Strip Steel Surface Defects Based on Deformable Convolution and Attention Mechanism | IEEE Journals & Magazine | IEEE Xplore DCAM-Net: A Rapid Detection Network for Strip Steel Surface Defects Base…...

05-Spring Boot工程中简化开发的方式Lombok和dev-tools

简化开发的方式Lombok和dev-tools Lombok常用注解 Lombok用标签方式代替构造器、getter/setter、toString()等重复代码, 在程序编译的时候自动生成这些代码 注解名功能NoArgsConstructor生成无参构造方法AllArgsConstructor生产含所有属性的有参构造方法,如果不希望含所有属…...

AIGC 技术在淘淘秀场景的探索与实践

本文介绍了AIGC相关领域的爆发式增长&#xff0c;并探讨了淘宝秀秀(AI买家秀)的设计思路和技术方案。文章涵盖了图像生成、仿真形象生成和换背景方案&#xff0c;以及模型流程串联等关键技术。 文章还介绍了淘淘秀的使用流程和遇到的问题及处理方法。最后&#xff0c;文章展望…...

ANSYS网格无关性检查

网格精度对应力结果存在很大的影响&#xff0c;有时候可以发现&#xff0c;随着网格精度逐渐提高&#xff0c;所求得的最大应力值逐渐趋于收敛。 默认网格&#xff1a; 从默认网格下计算出的应力云图可以发现&#xff0c;出现了的三处应力奇异点&#xff0c;此时算出的应力值是…...

设计模式-责任链-笔记

动机&#xff08;Motivation&#xff09; 在软件构建过程中&#xff0c;一个请求可能被多个对象处理&#xff0c;但是每个请求在运行时只能有个接受者&#xff0c;如果显示指定&#xff0c;将必不可少地带来请求者与接受者的紧耦合。 如何使请求的发送者不需要指定具体的接受…...

SpringMvc请求原理流程

springmvc是用户和服务沟通的桥梁&#xff0c;官网提供了springmvc的全面使用和解释&#xff1a;DispatcherServlet :: Spring Framework 流程 1.Tomcat启动 2.解析web.xml文件&#xff0c;根据servlet-class找到DispatcherServlet&#xff0c;根据init-param来获取spring的…...

【开源】基于Vue.js的音乐偏好度推荐系统的设计和实现

项目编号&#xff1a; S 012 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S012&#xff0c;文末获取源码。} 项目编号&#xff1a;S012&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、系统设计2.1 功能模块设计2.1.1 音乐档案模块2.1…...

采集1688整店商品(店铺所有商品、店铺列表api)

返回数据&#xff1a; 请求链接 {"user": [],"items": {"item": [{"num_iid": "738354436678","title": "国产正品i13 promax全网通5G安卓智能手机源头厂家批发手机","pic_url": "http…...

IObit Unlocker丨解除占用程序软件

更多内容请收藏&#xff1a;https://rwx.tza-3.xyz 官网&#xff1a;IObit Unlocker “永远不用担心电脑上无法删除的文件。” 界面简单&#xff0c;支持简体中文&#xff0c;一看就会&#xff0c;只需要把无法删除/移动的文件或整个U盘拖到框里就行。 解锁率很高&#xff0c;…...

开发一款小程序游戏需要多少钱?

小程序游戏的开发成本因多种因素而异&#xff0c;无法提供具体的固定数字。以下是影响小程序游戏开发成本的一些关键因素&#xff1a; 游戏规模和复杂度&#xff1a; 小程序游戏可以是简单的休闲游戏&#xff0c;也可以是更复杂的策略游戏。规模和复杂度会影响开发所需的时间和…...

基于Vue+SpringBoot的校园电商物流云平台开源项目

项目编号&#xff1a; S 034 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S034&#xff0c;文末获取源码。} 项目编号&#xff1a;S034&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 商品数据模块2.3 快…...

庖丁解牛:NIO核心概念与机制详解 03 _ 缓冲区分配、包装和分片

文章目录 Pre概述缓冲区分配和包装 &#xff08;allocate 、 wrap&#xff09;缓冲区分片 (slice)缓冲区份片和数据共享只读缓冲区 &#xff08;asReadOnlyBuffer&#xff09;直接和间接缓冲区 (allocateDirect)内存映射文件 I/O将文件映射到内存(map) Pre 庖丁解牛&#xff1…...

002 OpenCV dft 傅里叶变换

目录 一、傅里叶变换 1.1 傅里叶变换概念 1.2 opencv中傅里叶变换 二、实验代码 一、环境 本文使用环境为&#xff1a; Windows10Python 3.9.17opencv-python 4.8.0.74 二、傅里叶变换 2.1 傅里叶变换概念 傅里叶变换&#xff08;Fourier Transform&#xff09;是一种…...

力扣:171. Excel 表列序号(Python3)

题目&#xff1a; 给你一个字符串 columnTitle &#xff0c;表示 Excel 表格中的列名称。返回 该列名称对应的列序号 。 例如&#xff1a; A -> 1 B -> 2 C -> 3 ... Z -> 26 AA -> 27 AB -> 28 ... 来源&#xff1a;力扣&#xff08;LeetCode&#xff09; …...

C++中结构体的初始化

C中结构体的初始化 结构体是一个由程序员定义的数据类型&#xff0c;可以容纳许多不同的数据值。在过去&#xff0c;面向对象编程的应用尚未普及之前&#xff0c;程序员通常使用这些从逻辑上连接在一起的数据组合到一个单元中。一旦结构体类型被声明并且其数据成员被标识&…...

vue3+vite+ts 发布自定义组件到npm

vue3vite 发布自定义组件到npm 初始化项目编写组件配置打包组件上传到npm测试组件库 初始化项目 // 创建项目 pnpm create vite vue-test-app --template vue-ts// 运行项目 cd vite vue-test-app pnpm install pnpm run dev编写组件 1、根目录下创建packages目录作为组件的开…...

mybatis使用xml形式配置

以这个注解形式的查询代码为例 Select("select * from emp where name like concat(%,#{name},%) and gender #{gender} and entrydate between #{begin} and #{end} order by update_time desc ")public List<Emp> list(String name, Short gender, LocalDat…...

开源简历生成器OpenResume

什么是 OpenResume &#xff1f; OpenResume 是一个功能强大的开源简历生成器和简历解析器。OpenResume 的目标是为每个人提供免费的现代专业简历设计&#xff0c;让任何人都能充满信心地申请工作。 OpenResume 有 5 个核心特点&#xff1a; 特征描述1. 实时UI更新当您输入简历…...

AI变现之Gpts搞流量+赚钱

文章目录 Gpts | 搞流量 + 赚钱1.项目介绍2.项目分析3.项目实操4.变现路径Gpts | 搞流量 + 赚钱 1.项目介绍 这两天 AI 圈最火的莫过于 OpenAI 开发者大会公布的一个爆炸产品 Gpts 了,大家都知道这个肯定是一个划时代的产品,也绝对是一个风口,虽然官方还没有公布到底怎么通…...