【Pytorch】13.搭建完整的CIFAR10模型
项目源码
已上传至githubCIFAR10Model,如果有帮助可以点个star
简介
在前文【Pytorch】10.CIFAR10模型搭建我们学习了用Module来模拟搭建CIFAR10的训练流程
本节将会加入损失函数,梯度下降,TensorBoard来完整搭建一个训练的模型
基本步骤
搭建神经网络最主要的流程是
- 导入数据集(包括训练集和测试集)
- 创建
DataLoader - 创建自定义的神经网络
- 选择损失函数与梯度下降算法
- 进行n轮训练
- n轮训练完成后通过测试集进行验证
- 引入
TensorBoard进行可视化 - 保存每轮训练好的模型
接下来将逐步拆解这每一个步骤
1.导入数据集
因为我们本文是要训练CIFAR10的模型,所以我们导入CIFAR10的数据集
# 1.创建训练数据集
train_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=True, download=True,transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=False, download=True,transform=torchvision.transforms.ToTensor())
# 记录数据集大小
train_size = len(train_dataset)
test_size = len(test_dataset)
分别导入训练集与测试集,并且分别记录训练集与测试集的大小
对参数的解释可以看【Pytorch】4.torchvision.datasets的使用这篇文章
2.创建DataLoader
DataLoader主要定义了如何在数据集中取数据的规则,具体讲解可以看【Pytorch】5.DataLoder的使用
# 2.创建dataloader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)
3.创建自定义的神经网络

我们可以在网上搜到CIFAR10的网络模型,通过网络模型来搭建网络,具体可以看【Pytorch】10.CIFAR10模型搭建
import torch
from torch import nnclass CIFAR10Model(nn.Module):def __init__(self):super(CIFAR10Model, self).__init__()self.conv1 = nn.Conv2d(3, 32, 5, padding=2)self.maxpool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(32, 32, 5, padding=2)self.maxpool2 = nn.MaxPool2d(2, 2)self.conv3 = nn.Conv2d(32, 64, 5, padding=2)self.maxpool3 = nn.MaxPool2d(2, 2)self.flatten = nn.Flatten()self.fc1 = nn.Linear(1024, 64)self.fc2 = nn.Linear(64, 10)def forward(self, x):x = self.conv1(x)x = self.maxpool1(x)x = self.conv2(x)x = self.maxpool2(x)x = self.conv3(x)x = self.maxpool3(x)x = self.flatten(x)x = self.fc1(x)x = self.fc2(x)return xif __name__ == '__main__':model = CIFAR10Model()input_test = torch.ones((64, 3, 32, 32))output_test = model(input_test)print(output_test.shape)
这里我们新创建了一个model.py用于专门存储网络结构,这样在我们的训练文件中,可以通过
from model import *# 3.创建神经网络
model = CIFAR10Model()
来导入我们自定义的神经网络
4.选择损失函数和梯度下降的方法
我们选择了交叉熵损失函数与SGD的梯度下降算法,具体讲解可以看【Pytorch】11.损失函数与梯度下降
# 4.设置损失函数与梯度下降算法
loss_fn = nn.CrossEntropyLoss()learn_rate = 1e-2
optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)
5.开始进行训练
首先将模型设置为训练模式
model.train()
具体的训练流程分为以下几部
- 从DataLoader中获取图片以及对应的编号
- 将图片传入神经网络并获取输出
- 将优化器清零
- 计算损失函数
- 进行梯度下降
- 调用优化器进行更新
for data in train_loader:# 训练基本流程inputs, labels = dataoutputs = model(inputs)optimizer.zero_grad()loss = loss_fn(outputs, labels)loss.backward()optimizer.step()
在基础训练的基础上,还安排了每进行100次训练就将训练数据print出来,并且写入tensorboard
# 第i轮训练次数加一pre_train_step += 1pre_train_loss += loss.item()total_train_step += 1# 每100次输出一下if pre_train_step % 100 == 0:end_train_time = time.time()print(f'当前为第{i+1}轮训练,当前训练轮数为:{pre_train_step},已经过时间为:{end_train_time-start_time},当前训练次数的平均损失为:{pre_train_loss / pre_train_step}')# 添加可视化writer.add_scalar('train_loss', pre_train_loss / pre_train_step, total_train_step)print(f"----------------------------第{i + 1}轮训练完成----------------------------")
6.测试集验证
首先将模型设置为测试集模式
model.eval()
首先通过with关键字来创建一个没有梯度的上下文
验证方法与训练集类似,但是没有计算梯度与更新优化器的步骤
with torch.no_grad():for data in test_loader:# 测试集流程inputs, labels = dataoutputs = model(inputs)loss = loss_fn(outputs, labels)
然后通过torch.argmax用于计算所有标签的最大值
- 参数为1时代表横向判断
- 参数为0的代表纵向判断
计算当前模型在训练集中的正确次数
pre_accuracy += outputs.argmax(1).eq(labels).sum().item()
7.引入TensorBoard进行可视化
我们主要是通过Summary中的add_scalar来建立可视化函数来进行可视化的,具体可以看【Pytorch】2.TensorBoard的运用
# 创建TensorBoard
writer = SummaryWriter('./CIFAR10_logs')# 在训练集中,输出每一百次训练的损失函数平均值# 每100次输出一下if pre_train_step % 100 == 0:end_train_time = time.time()print(f'当前为第{i+1}轮训练,当前训练轮数为:{pre_train_step},已经过时间为:{end_train_time-start_time},当前训练次数的平均损失为:{pre_train_loss / pre_train_step}')# 添加可视化writer.add_scalar('train_loss', pre_train_loss / pre_train_step, total_train_step)# 在测试集中,输出模型在测试集中的正确率
pre_accuracy += outputs.argmax(1).eq(labels).sum().item()writer.add_scalar('test_accuracy', pre_accuracy / test_size, i)
8.保存模型
具体可以看【Pytorch】12.网络模型的加载、修改与保存
# 保存每轮的训练模型torch.save(CIFAR10Model, f'./CIFAR10TrainModel{i}.pth')
完整代码
import time
import torch
import torchvision.transforms
from torch.utils.tensorboard import SummaryWriterfrom model import *# 1.创建训练数据集
train_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=True, download=True,transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=False, download=True,transform=torchvision.transforms.ToTensor())
# 记录数据集大小
train_size = len(train_dataset)
test_size = len(test_dataset)# 2.创建dataloader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)# 3.创建神经网络
model = CIFAR10Model()# 4.设置损失函数与梯度下降算法
loss_fn = nn.CrossEntropyLoss()learn_rate = 0.0001
optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)# 训练轮数
total_train_step = 0
total_test_step = 0# 训练轮数
epoch = 20# 创建TensorBoard
writer = SummaryWriter('./CIFAR10_logs')
# 5.开始训练
for i in range(epoch):# 将模型设置为训练模式print(f"----------------------------开启第{i+1}轮训练----------------------------")model.train()# 第i轮训练的次数pre_train_step = 0# 第i轮训练的总损失pre_train_loss = 0# 第i轮训练的起始时间start_time = time.time()for data in train_loader:# 训练基本流程inputs, labels = dataoutputs = model(inputs)optimizer.zero_grad()loss = loss_fn(outputs, labels)loss.backward()optimizer.step()# 第i轮训练次数加一pre_train_step += 1pre_train_loss += loss.item()total_train_step += 1# 每100次输出一下if pre_train_step % 100 == 0:end_train_time = time.time()print(f'当前为第{i+1}轮训练,当前训练轮数为:{pre_train_step},已经过时间为:{end_train_time-start_time},当前训练次数的平均损失为:{pre_train_loss / pre_train_step}')# 添加可视化writer.add_scalar('train_loss', pre_train_loss / pre_train_step, total_train_step)print(f"----------------------------第{i + 1}轮训练完成----------------------------")# 设置为测试模式model.eval()# 第i轮训练集的总损失pre_test_loss = 0# 第i轮训练集的总正确次数pre_accuracy = 0print(f"----------------------------开启第{i + 1}轮测试----------------------------")# 配置没有梯度下降的环境with torch.no_grad():for data in test_loader:# 测试集流程inputs, labels = dataoutputs = model(inputs)loss = loss_fn(outputs, labels)# 定义参数pre_test_loss += loss.item()# 记录训练集的总正确率# argmax(1)代表横向判断,argmax(0)代表纵向判断pre_accuracy += outputs.argmax(1).eq(labels).sum().item()# 记录测试集运行完后的事件end_test_time = time.time()print(f'当前为第{i + 1}轮测试,已经过时间:{end_test_time - start_time},当前测试集的平均损失为:{pre_test_loss / test_size},当前在测试集的正确率为:{pre_accuracy / test_size}')writer.add_scalar('test_accuracy', pre_accuracy / test_size, i)print(f"----------------------------第{i + 1}轮测试完成----------------------------")# 保存每轮的训练模型torch.save(CIFAR10Model, f'./CIFAR10TrainModel{i}.pth')print(f"----------------------------第{i + 1}轮模型保存完成----------------------------")writer.close()
训练效果


相关文章:
【Pytorch】13.搭建完整的CIFAR10模型
项目源码 已上传至githubCIFAR10Model,如果有帮助可以点个star 简介 在前文【Pytorch】10.CIFAR10模型搭建我们学习了用Module来模拟搭建CIFAR10的训练流程 本节将会加入损失函数,梯度下降,TensorBoard来完整搭建一个训练的模型 基本步骤 搭建…...
护目镜佩戴自动识别预警摄像机
护目镜佩戴自动识别预警摄像机是一种智能监测设备,专门用于佩戴护目镜的工人进行作业时,能够自动识别有潜在风险的场景,并及时发出预警信号。该摄像机配备人脸识别和智能预警系统,可以检测危险情况并为工人提供实时安全保护&#…...
keep-alive的使用
Vue中的<keep-alive>组件是前端开发中的一个宝藏功能,它如同时光胶囊般保留组件的状态,让组件在切换时仿佛按下暂停键,再次回来时还能继续播放,极大地优化了用户体验和性能。🚀✨ 作用 状态保留:当包…...
【Linux】中的常见的重要指令(中)
目录 一、man指令 二、cp指令 三、cat指令 四、mv指令 五、more指令 六、less指令 七、head指令 八、tail指令 一、man指令 Linux的命令有很多参数,我们不可能全记住,我们可以通过查看联机手册获取帮助。访问Linux手册页的命令是 man 语法: m…...
营收净利双降、股东减持,大降价也救不了良品铺子
号称“高端零食第一股”的良品铺子(603719.SH),正遭遇部分股东的“用脚投票”。 5月17日晚间,良品铺子连发两份减持公告,其控股股东宁波汉意创业投资合伙企业、持股5%以上股东达永有限公司,两者均计划减持。 其中,宁…...
【设计模式】设计模式的分类
通常设计模式的分类有创建型、行为型和结构型。 创建型 常用的有:单例模式、工厂模式(工厂方法和抽象工厂)、建造者模式。 不常用的有:原型模式。 创建型模式涉及到将对象实例化,这类模式都提供一个方法,将…...
TCP/UDP的连接机制
TCP/UDP的连接机制 TCP的连接机制 TCP(Transmission Control Protocol)是一种面向连接的协议,提供可靠的、按顺序的数据传输服务。TCP的连接机制包括连接建立、数据传输和连接终止。 1. 连接建立(三次握手) TCP通过…...
供应链金融模式学习资料
目录 产生背景 供应链金融的诞生 供应链金额的六大特征...
代码随想录-算法训练营day50【动态规划12:最佳买卖股票时机含冷冻期、买卖股票的最佳时机含手续费、股票问题总结】
代码随想录-035期-算法训练营【博客笔记汇总表】-CSDN博客 第九章 动态规划part12● 309.最佳买卖股票时机含冷冻期 ● 714.买卖股票的最佳时机含手续费 ●总结309.最佳买卖股票时机含冷冻期 本题加了一个冷冻期,状态就多了,有点难度,大家要把各个状态分清,思路才能清晰…...
Dilworth 定理
这是一个关于偏序集的定理,事实上它也可以扩展到图论,dp等中,是一个很有意思的东西 偏序集 偏序集是由集合 S S S以及其上的一个偏序关系 R R R定义的,记为 ( S , R ) (S,R) (S,R) 偏序关系: 对于一个二元关系 R ⊂…...
BUUCTF---web---[BJDCTF2020]ZJCTF,不过如此
1、点开连接,页面出现了提示 传入一个参数text,里面的内容要包括I have a dream。 构造:?/textI have a dream。发现页面没有显示。这里推测可能得使用伪协议 在文件包含那一行,我们看到了next.php的提示,我们尝试读取…...
力扣刷题---2206. 将数组划分成相等数对【简单】
题目描述🍗 给你一个整数数组 nums ,它包含 2 * n 个整数。 你需要将 nums 划分成 n 个数对,满足: 每个元素 只属于一个 数对。 同一数对中的元素 相等 。 如果可以将 nums 划分成 n 个数对,请你返回 true …...
2461. 长度为 K 子数组中的最大和(c++)
给你一个整数数组 nums 和一个整数 k 。请你从 nums 中满足下述条件的全部子数组中找出最大子数组和: 子数组的长度是 k,且子数组中的所有元素 各不相同 。 返回满足题面要求的最大子数组和。如果不存在子数组满足这些条件,返回 0 。 子数…...
range for
1. 基于范围的for循环语法 C11标准引入了基于范围的for循环特性,该特性隐藏了迭代器 的初始化和更新过程,让程序员只需要关心遍历对象本身,其语法也 比传统for循环简洁很多: for ( range_declaration : range_expression ) {loo…...
leetcode230 二叉搜索树中第K小的元素
题目 给定一个二叉搜索树的根节点 root ,和一个整数 k ,请你设计一个算法查找其中第 k 个最小元素(从 1 开始计数)。 示例 输入:root [5,3,6,2,4,null,null,1], k 3 输出:3 解析 这道题应该是能做出…...
.Net Core学习笔记 框架特性(注入、配置)
注:直接学习的.Net Core 6,此版本有没有startup.cs相关的内容 项目Program.cs文件中 是定义项目加载 启动的地方 //通过builder对项目进行配置、服务的加载 var builder WebApplication.CreateBuilder(args); builder.Services.AddControllers();//将…...
利用AI技术做电商网赚,这些百万级赛道流量,你还不知道?!
大家好,我是向阳 AI技术的飞速扩展已经势不可挡,不管你承不承认,AI 已经毫无争议的在互联网中占有一席之地了 无论你是做内容产业的,还是做电商的,你现在都躲不开 AI。 现在互联网行业的竞争就是这么残酷 互联网行业…...
leetcode-560 和为k的数组
一、题目描述 给你一个整数数组 nums 和一个整数 k ,请你统计并返回 该数组中和为 k 的子数组的个数 。 子数组是数组中元素的连续非空序列。 注意:nums中的元素可为负数 输入:nums [1,1,1], k 2 输出:2输入:num…...
Spring Boot实战指南:从入门到企业级应用构建
目录 一、引言 二、快速入门 1. 使用Spring Initializr创建项目 三、Spring Boot基础概念与自动配置 1. 理解SpringBootApplication注解 2. 自动配置原理 3. 查看自动配置报告 四、Spring Boot核心特性及实战 1. 外部化配置 2. Actuator端点 3. 集成第三方库 五、Sp…...
OneAPI接入本地大模型+FastGPT调用本地大模型
将Ollama下载的本地大模型配置到OneAPI中,并通过FastGPT调用本地大模型完成对话。 OneAPI配置 新建令牌 新建渠道 FastGPT配置 配置docker-compose 配置令牌和OneAPI部署地址 配置config.json 配置调用的渠道名称和大模型名称 {"systemEnv": {&qu…...
从按键消抖到外部中断:STM32 GPIO输入模式的‘避坑’指南与AFIO的隐藏用法
从按键消抖到外部中断:STM32 GPIO输入模式的‘避坑’指南与AFIO的隐藏用法 在嵌入式开发中,GPIO(通用输入输出)接口是与外部世界交互的第一道门槛。对于STM32开发者来说,GPIO配置看似简单,却暗藏诸多细节陷…...
Python安全开发之简易Xss检测工具(详细注释)
核心代码:import requests # requests 库 - HTTP 请求处理库 # 【常用功能】: # requests.get(url) - 发送 HTTP GET 请求 # requests.post(url, data) - 发送 HTTP POST 请求 # response.text - 获取响应体内容(字符串) #…...
Phi-3-mini-4k-instruct-gguf多场景落地:跨境电商多语言商品描述批量生成
Phi-3-mini-4k-instruct-gguf多场景落地:跨境电商多语言商品描述批量生成 1. 跨境电商的痛点与解决方案 跨境电商卖家每天面临的最大挑战之一,就是为同一款商品准备不同语言版本的描述。传统做法要么需要雇佣多语种文案人员,要么使用机械的…...
四管升降压电路实战解析:从拓扑原理到模式切换(附波形对比)
1. 四管升降压电路为何成为工程师的"瑞士军刀" 第一次接触四管升降压电路时,我正被一个光伏储能项目折磨得焦头烂额。太阳能板的输出电压在8V-18V剧烈波动,而系统需要稳定的12V供电。传统方案要用两个独立电路串联,直到老工程师扔给…...
PN5180 ISO15693协议栈实现与嵌入式NFC开发指南
1. PN5180库深度解析:面向嵌入式工程师的NFC ISO15693协议栈实现指南NXP PN5180是业界领先的多协议NFC控制器,支持ISO/IEC 14443 A/B、ISO/IEC 15693、Felica及NFC Forum Type 1–5标签。其核心优势在于高集成度射频前端、可编程调制解调器及灵活的主机接…...
【ZGC性能黄金阈值手册】:基于127个线上集群实测数据,定义堆大小/线程数/触发频率最优配比
第一章:ZGC性能黄金阈值的定义与行业意义ZGC(Z Garbage Collector)作为JDK 11引入的低延迟垃圾收集器,其核心设计目标是将GC暂停时间稳定控制在10毫秒以内,且不随堆大小线性增长。而“ZGC性能黄金阈值”并非官方术语&a…...
5分钟快速上手LosslessCut:零编码视频剪辑的终极指南
5分钟快速上手LosslessCut:零编码视频剪辑的终极指南 【免费下载链接】lossless-cut The swiss army knife of lossless video/audio editing 项目地址: https://gitcode.com/gh_mirrors/lo/lossless-cut 你是否曾因视频剪辑导致画质下降而烦恼?是…...
联想新品入局,AI智能终端市场格局生变
联想新品发布,直击Mac mini“养虾”痛点2026年3月31日,联想集团正式发布YOGA AI Mini与Think AI Tiny两款AI原生智能终端。其中,YOGA AI Mini面向个人消费市场,精准对标当下被众多用户用于运行OpenClaw的Mac mini。Mac mini虽因便…...
5G RedCap路由器如何选?关键特性解析与典型应用场景指南
1. 5G RedCap路由器选购的核心指标 第一次接触5G RedCap路由器时,我被参数表里密密麻麻的术语搞得头晕眼花。后来在工业现场实测了7款不同型号后,才发现真正影响使用体验的关键指标其实就这几个: 频段支持就像路由器的"语言能力"。…...
Graphormer效果对比评测:vs GCN、GAT、GIN在分子回归任务上的表现
Graphormer效果对比评测:vs GCN、GAT、GIN在分子回归任务上的表现 1. 引言 在药物发现和材料科学领域,准确预测分子属性是一个关键挑战。传统方法依赖昂贵的实验或复杂的量子化学计算,而图神经网络(GNN)提供了一种更高效的替代方案。本文将…...
