【Pytorch】13.搭建完整的CIFAR10模型
项目源码
已上传至githubCIFAR10Model,如果有帮助可以点个star
简介
在前文【Pytorch】10.CIFAR10模型搭建我们学习了用Module来模拟搭建CIFAR10的训练流程
本节将会加入损失函数,梯度下降,TensorBoard来完整搭建一个训练的模型
基本步骤
搭建神经网络最主要的流程是
- 导入数据集(包括训练集和测试集)
- 创建
DataLoader - 创建自定义的神经网络
- 选择损失函数与梯度下降算法
- 进行n轮训练
- n轮训练完成后通过测试集进行验证
- 引入
TensorBoard进行可视化 - 保存每轮训练好的模型
接下来将逐步拆解这每一个步骤
1.导入数据集
因为我们本文是要训练CIFAR10的模型,所以我们导入CIFAR10的数据集
# 1.创建训练数据集
train_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=True, download=True,transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=False, download=True,transform=torchvision.transforms.ToTensor())
# 记录数据集大小
train_size = len(train_dataset)
test_size = len(test_dataset)
分别导入训练集与测试集,并且分别记录训练集与测试集的大小
对参数的解释可以看【Pytorch】4.torchvision.datasets的使用这篇文章
2.创建DataLoader
DataLoader主要定义了如何在数据集中取数据的规则,具体讲解可以看【Pytorch】5.DataLoder的使用
# 2.创建dataloader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)
3.创建自定义的神经网络

我们可以在网上搜到CIFAR10的网络模型,通过网络模型来搭建网络,具体可以看【Pytorch】10.CIFAR10模型搭建
import torch
from torch import nnclass CIFAR10Model(nn.Module):def __init__(self):super(CIFAR10Model, self).__init__()self.conv1 = nn.Conv2d(3, 32, 5, padding=2)self.maxpool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(32, 32, 5, padding=2)self.maxpool2 = nn.MaxPool2d(2, 2)self.conv3 = nn.Conv2d(32, 64, 5, padding=2)self.maxpool3 = nn.MaxPool2d(2, 2)self.flatten = nn.Flatten()self.fc1 = nn.Linear(1024, 64)self.fc2 = nn.Linear(64, 10)def forward(self, x):x = self.conv1(x)x = self.maxpool1(x)x = self.conv2(x)x = self.maxpool2(x)x = self.conv3(x)x = self.maxpool3(x)x = self.flatten(x)x = self.fc1(x)x = self.fc2(x)return xif __name__ == '__main__':model = CIFAR10Model()input_test = torch.ones((64, 3, 32, 32))output_test = model(input_test)print(output_test.shape)
这里我们新创建了一个model.py用于专门存储网络结构,这样在我们的训练文件中,可以通过
from model import *# 3.创建神经网络
model = CIFAR10Model()
来导入我们自定义的神经网络
4.选择损失函数和梯度下降的方法
我们选择了交叉熵损失函数与SGD的梯度下降算法,具体讲解可以看【Pytorch】11.损失函数与梯度下降
# 4.设置损失函数与梯度下降算法
loss_fn = nn.CrossEntropyLoss()learn_rate = 1e-2
optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)
5.开始进行训练
首先将模型设置为训练模式
model.train()
具体的训练流程分为以下几部
- 从DataLoader中获取图片以及对应的编号
- 将图片传入神经网络并获取输出
- 将优化器清零
- 计算损失函数
- 进行梯度下降
- 调用优化器进行更新
for data in train_loader:# 训练基本流程inputs, labels = dataoutputs = model(inputs)optimizer.zero_grad()loss = loss_fn(outputs, labels)loss.backward()optimizer.step()
在基础训练的基础上,还安排了每进行100次训练就将训练数据print出来,并且写入tensorboard
# 第i轮训练次数加一pre_train_step += 1pre_train_loss += loss.item()total_train_step += 1# 每100次输出一下if pre_train_step % 100 == 0:end_train_time = time.time()print(f'当前为第{i+1}轮训练,当前训练轮数为:{pre_train_step},已经过时间为:{end_train_time-start_time},当前训练次数的平均损失为:{pre_train_loss / pre_train_step}')# 添加可视化writer.add_scalar('train_loss', pre_train_loss / pre_train_step, total_train_step)print(f"----------------------------第{i + 1}轮训练完成----------------------------")
6.测试集验证
首先将模型设置为测试集模式
model.eval()
首先通过with关键字来创建一个没有梯度的上下文
验证方法与训练集类似,但是没有计算梯度与更新优化器的步骤
with torch.no_grad():for data in test_loader:# 测试集流程inputs, labels = dataoutputs = model(inputs)loss = loss_fn(outputs, labels)
然后通过torch.argmax用于计算所有标签的最大值
- 参数为1时代表横向判断
- 参数为0的代表纵向判断
计算当前模型在训练集中的正确次数
pre_accuracy += outputs.argmax(1).eq(labels).sum().item()
7.引入TensorBoard进行可视化
我们主要是通过Summary中的add_scalar来建立可视化函数来进行可视化的,具体可以看【Pytorch】2.TensorBoard的运用
# 创建TensorBoard
writer = SummaryWriter('./CIFAR10_logs')# 在训练集中,输出每一百次训练的损失函数平均值# 每100次输出一下if pre_train_step % 100 == 0:end_train_time = time.time()print(f'当前为第{i+1}轮训练,当前训练轮数为:{pre_train_step},已经过时间为:{end_train_time-start_time},当前训练次数的平均损失为:{pre_train_loss / pre_train_step}')# 添加可视化writer.add_scalar('train_loss', pre_train_loss / pre_train_step, total_train_step)# 在测试集中,输出模型在测试集中的正确率
pre_accuracy += outputs.argmax(1).eq(labels).sum().item()writer.add_scalar('test_accuracy', pre_accuracy / test_size, i)
8.保存模型
具体可以看【Pytorch】12.网络模型的加载、修改与保存
# 保存每轮的训练模型torch.save(CIFAR10Model, f'./CIFAR10TrainModel{i}.pth')
完整代码
import time
import torch
import torchvision.transforms
from torch.utils.tensorboard import SummaryWriterfrom model import *# 1.创建训练数据集
train_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=True, download=True,transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=False, download=True,transform=torchvision.transforms.ToTensor())
# 记录数据集大小
train_size = len(train_dataset)
test_size = len(test_dataset)# 2.创建dataloader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)# 3.创建神经网络
model = CIFAR10Model()# 4.设置损失函数与梯度下降算法
loss_fn = nn.CrossEntropyLoss()learn_rate = 0.0001
optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)# 训练轮数
total_train_step = 0
total_test_step = 0# 训练轮数
epoch = 20# 创建TensorBoard
writer = SummaryWriter('./CIFAR10_logs')
# 5.开始训练
for i in range(epoch):# 将模型设置为训练模式print(f"----------------------------开启第{i+1}轮训练----------------------------")model.train()# 第i轮训练的次数pre_train_step = 0# 第i轮训练的总损失pre_train_loss = 0# 第i轮训练的起始时间start_time = time.time()for data in train_loader:# 训练基本流程inputs, labels = dataoutputs = model(inputs)optimizer.zero_grad()loss = loss_fn(outputs, labels)loss.backward()optimizer.step()# 第i轮训练次数加一pre_train_step += 1pre_train_loss += loss.item()total_train_step += 1# 每100次输出一下if pre_train_step % 100 == 0:end_train_time = time.time()print(f'当前为第{i+1}轮训练,当前训练轮数为:{pre_train_step},已经过时间为:{end_train_time-start_time},当前训练次数的平均损失为:{pre_train_loss / pre_train_step}')# 添加可视化writer.add_scalar('train_loss', pre_train_loss / pre_train_step, total_train_step)print(f"----------------------------第{i + 1}轮训练完成----------------------------")# 设置为测试模式model.eval()# 第i轮训练集的总损失pre_test_loss = 0# 第i轮训练集的总正确次数pre_accuracy = 0print(f"----------------------------开启第{i + 1}轮测试----------------------------")# 配置没有梯度下降的环境with torch.no_grad():for data in test_loader:# 测试集流程inputs, labels = dataoutputs = model(inputs)loss = loss_fn(outputs, labels)# 定义参数pre_test_loss += loss.item()# 记录训练集的总正确率# argmax(1)代表横向判断,argmax(0)代表纵向判断pre_accuracy += outputs.argmax(1).eq(labels).sum().item()# 记录测试集运行完后的事件end_test_time = time.time()print(f'当前为第{i + 1}轮测试,已经过时间:{end_test_time - start_time},当前测试集的平均损失为:{pre_test_loss / test_size},当前在测试集的正确率为:{pre_accuracy / test_size}')writer.add_scalar('test_accuracy', pre_accuracy / test_size, i)print(f"----------------------------第{i + 1}轮测试完成----------------------------")# 保存每轮的训练模型torch.save(CIFAR10Model, f'./CIFAR10TrainModel{i}.pth')print(f"----------------------------第{i + 1}轮模型保存完成----------------------------")writer.close()
训练效果


相关文章:
【Pytorch】13.搭建完整的CIFAR10模型
项目源码 已上传至githubCIFAR10Model,如果有帮助可以点个star 简介 在前文【Pytorch】10.CIFAR10模型搭建我们学习了用Module来模拟搭建CIFAR10的训练流程 本节将会加入损失函数,梯度下降,TensorBoard来完整搭建一个训练的模型 基本步骤 搭建…...
护目镜佩戴自动识别预警摄像机
护目镜佩戴自动识别预警摄像机是一种智能监测设备,专门用于佩戴护目镜的工人进行作业时,能够自动识别有潜在风险的场景,并及时发出预警信号。该摄像机配备人脸识别和智能预警系统,可以检测危险情况并为工人提供实时安全保护&#…...
keep-alive的使用
Vue中的<keep-alive>组件是前端开发中的一个宝藏功能,它如同时光胶囊般保留组件的状态,让组件在切换时仿佛按下暂停键,再次回来时还能继续播放,极大地优化了用户体验和性能。🚀✨ 作用 状态保留:当包…...
【Linux】中的常见的重要指令(中)
目录 一、man指令 二、cp指令 三、cat指令 四、mv指令 五、more指令 六、less指令 七、head指令 八、tail指令 一、man指令 Linux的命令有很多参数,我们不可能全记住,我们可以通过查看联机手册获取帮助。访问Linux手册页的命令是 man 语法: m…...
营收净利双降、股东减持,大降价也救不了良品铺子
号称“高端零食第一股”的良品铺子(603719.SH),正遭遇部分股东的“用脚投票”。 5月17日晚间,良品铺子连发两份减持公告,其控股股东宁波汉意创业投资合伙企业、持股5%以上股东达永有限公司,两者均计划减持。 其中,宁…...
【设计模式】设计模式的分类
通常设计模式的分类有创建型、行为型和结构型。 创建型 常用的有:单例模式、工厂模式(工厂方法和抽象工厂)、建造者模式。 不常用的有:原型模式。 创建型模式涉及到将对象实例化,这类模式都提供一个方法,将…...
TCP/UDP的连接机制
TCP/UDP的连接机制 TCP的连接机制 TCP(Transmission Control Protocol)是一种面向连接的协议,提供可靠的、按顺序的数据传输服务。TCP的连接机制包括连接建立、数据传输和连接终止。 1. 连接建立(三次握手) TCP通过…...
供应链金融模式学习资料
目录 产生背景 供应链金融的诞生 供应链金额的六大特征...
代码随想录-算法训练营day50【动态规划12:最佳买卖股票时机含冷冻期、买卖股票的最佳时机含手续费、股票问题总结】
代码随想录-035期-算法训练营【博客笔记汇总表】-CSDN博客 第九章 动态规划part12● 309.最佳买卖股票时机含冷冻期 ● 714.买卖股票的最佳时机含手续费 ●总结309.最佳买卖股票时机含冷冻期 本题加了一个冷冻期,状态就多了,有点难度,大家要把各个状态分清,思路才能清晰…...
Dilworth 定理
这是一个关于偏序集的定理,事实上它也可以扩展到图论,dp等中,是一个很有意思的东西 偏序集 偏序集是由集合 S S S以及其上的一个偏序关系 R R R定义的,记为 ( S , R ) (S,R) (S,R) 偏序关系: 对于一个二元关系 R ⊂…...
BUUCTF---web---[BJDCTF2020]ZJCTF,不过如此
1、点开连接,页面出现了提示 传入一个参数text,里面的内容要包括I have a dream。 构造:?/textI have a dream。发现页面没有显示。这里推测可能得使用伪协议 在文件包含那一行,我们看到了next.php的提示,我们尝试读取…...
力扣刷题---2206. 将数组划分成相等数对【简单】
题目描述🍗 给你一个整数数组 nums ,它包含 2 * n 个整数。 你需要将 nums 划分成 n 个数对,满足: 每个元素 只属于一个 数对。 同一数对中的元素 相等 。 如果可以将 nums 划分成 n 个数对,请你返回 true …...
2461. 长度为 K 子数组中的最大和(c++)
给你一个整数数组 nums 和一个整数 k 。请你从 nums 中满足下述条件的全部子数组中找出最大子数组和: 子数组的长度是 k,且子数组中的所有元素 各不相同 。 返回满足题面要求的最大子数组和。如果不存在子数组满足这些条件,返回 0 。 子数…...
range for
1. 基于范围的for循环语法 C11标准引入了基于范围的for循环特性,该特性隐藏了迭代器 的初始化和更新过程,让程序员只需要关心遍历对象本身,其语法也 比传统for循环简洁很多: for ( range_declaration : range_expression ) {loo…...
leetcode230 二叉搜索树中第K小的元素
题目 给定一个二叉搜索树的根节点 root ,和一个整数 k ,请你设计一个算法查找其中第 k 个最小元素(从 1 开始计数)。 示例 输入:root [5,3,6,2,4,null,null,1], k 3 输出:3 解析 这道题应该是能做出…...
.Net Core学习笔记 框架特性(注入、配置)
注:直接学习的.Net Core 6,此版本有没有startup.cs相关的内容 项目Program.cs文件中 是定义项目加载 启动的地方 //通过builder对项目进行配置、服务的加载 var builder WebApplication.CreateBuilder(args); builder.Services.AddControllers();//将…...
利用AI技术做电商网赚,这些百万级赛道流量,你还不知道?!
大家好,我是向阳 AI技术的飞速扩展已经势不可挡,不管你承不承认,AI 已经毫无争议的在互联网中占有一席之地了 无论你是做内容产业的,还是做电商的,你现在都躲不开 AI。 现在互联网行业的竞争就是这么残酷 互联网行业…...
leetcode-560 和为k的数组
一、题目描述 给你一个整数数组 nums 和一个整数 k ,请你统计并返回 该数组中和为 k 的子数组的个数 。 子数组是数组中元素的连续非空序列。 注意:nums中的元素可为负数 输入:nums [1,1,1], k 2 输出:2输入:num…...
Spring Boot实战指南:从入门到企业级应用构建
目录 一、引言 二、快速入门 1. 使用Spring Initializr创建项目 三、Spring Boot基础概念与自动配置 1. 理解SpringBootApplication注解 2. 自动配置原理 3. 查看自动配置报告 四、Spring Boot核心特性及实战 1. 外部化配置 2. Actuator端点 3. 集成第三方库 五、Sp…...
OneAPI接入本地大模型+FastGPT调用本地大模型
将Ollama下载的本地大模型配置到OneAPI中,并通过FastGPT调用本地大模型完成对话。 OneAPI配置 新建令牌 新建渠道 FastGPT配置 配置docker-compose 配置令牌和OneAPI部署地址 配置config.json 配置调用的渠道名称和大模型名称 {"systemEnv": {&qu…...
stm32G473的flash模式是单bank还是双bank?
今天突然有人stm32G473的flash模式是单bank还是双bank?由于时间太久,我真忘记了。搜搜发现,还真有人和我一样。见下面的链接:https://shequ.stmicroelectronics.cn/forum.php?modviewthread&tid644563 根据STM32G4系列参考手…...
【大模型RAG】Docker 一键部署 Milvus 完整攻略
本文概要 Milvus 2.5 Stand-alone 版可通过 Docker 在几分钟内完成安装;只需暴露 19530(gRPC)与 9091(HTTP/WebUI)两个端口,即可让本地电脑通过 PyMilvus 或浏览器访问远程 Linux 服务器上的 Milvus。下面…...
OkHttp 中实现断点续传 demo
在 OkHttp 中实现断点续传主要通过以下步骤完成,核心是利用 HTTP 协议的 Range 请求头指定下载范围: 实现原理 Range 请求头:向服务器请求文件的特定字节范围(如 Range: bytes1024-) 本地文件记录:保存已…...
[10-3]软件I2C读写MPU6050 江协科技学习笔记(16个知识点)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16...
大学生职业发展与就业创业指导教学评价
这里是引用 作为软工2203/2204班的学生,我们非常感谢您在《大学生职业发展与就业创业指导》课程中的悉心教导。这门课程对我们即将面临实习和就业的工科学生来说至关重要,而您认真负责的教学态度,让课程的每一部分都充满了实用价值。 尤其让我…...
智能分布式爬虫的数据处理流水线优化:基于深度强化学习的数据质量控制
在数字化浪潮席卷全球的今天,数据已成为企业和研究机构的核心资产。智能分布式爬虫作为高效的数据采集工具,在大规模数据获取中发挥着关键作用。然而,传统的数据处理流水线在面对复杂多变的网络环境和海量异构数据时,常出现数据质…...
html css js网页制作成品——HTML+CSS榴莲商城网页设计(4页)附源码
目录 一、👨🎓网站题目 二、✍️网站描述 三、📚网站介绍 四、🌐网站效果 五、🪓 代码实现 🧱HTML 六、🥇 如何让学习不再盲目 七、🎁更多干货 一、👨…...
React---day11
14.4 react-redux第三方库 提供connect、thunk之类的函数 以获取一个banner数据为例子 store: 我们在使用异步的时候理应是要使用中间件的,但是configureStore 已经自动集成了 redux-thunk,注意action里面要返回函数 import { configureS…...
Java + Spring Boot + Mybatis 实现批量插入
在 Java 中使用 Spring Boot 和 MyBatis 实现批量插入可以通过以下步骤完成。这里提供两种常用方法:使用 MyBatis 的 <foreach> 标签和批处理模式(ExecutorType.BATCH)。 方法一:使用 XML 的 <foreach> 标签ÿ…...
C++.OpenGL (20/64)混合(Blending)
混合(Blending) 透明效果核心原理 #mermaid-svg-SWG0UzVfJms7Sm3e {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-SWG0UzVfJms7Sm3e .error-icon{fill:#552222;}#mermaid-svg-SWG0UzVfJms7Sm3e .error-text{fill…...
