当前位置: 首页 > news >正文

【Pytorch】13.搭建完整的CIFAR10模型

项目源码

已上传至githubCIFAR10Model,如果有帮助可以点个star

简介

在前文【Pytorch】10.CIFAR10模型搭建我们学习了用Module来模拟搭建CIFAR10的训练流程
本节将会加入损失函数,梯度下降,TensorBoard来完整搭建一个训练的模型

基本步骤

搭建神经网络最主要的流程是

  • 导入数据集(包括训练集和测试集)
  • 创建DataLoader
  • 创建自定义的神经网络
  • 选择损失函数与梯度下降算法
  • 进行n轮训练
  • n轮训练完成后通过测试集进行验证
  • 引入TensorBoard进行可视化
  • 保存每轮训练好的模型
    接下来将逐步拆解这每一个步骤

1.导入数据集

因为我们本文是要训练CIFAR10的模型,所以我们导入CIFAR10的数据集

# 1.创建训练数据集
train_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=True, download=True,transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=False, download=True,transform=torchvision.transforms.ToTensor())
# 记录数据集大小
train_size = len(train_dataset)
test_size = len(test_dataset)

分别导入训练集与测试集,并且分别记录训练集与测试集的大小
对参数的解释可以看【Pytorch】4.torchvision.datasets的使用这篇文章

2.创建DataLoader

DataLoader主要定义了如何在数据集中取数据的规则,具体讲解可以看【Pytorch】5.DataLoder的使用

# 2.创建dataloader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)

3.创建自定义的神经网络

在这里插入图片描述
我们可以在网上搜到CIFAR10的网络模型,通过网络模型来搭建网络,具体可以看【Pytorch】10.CIFAR10模型搭建

import torch
from torch import nnclass CIFAR10Model(nn.Module):def __init__(self):super(CIFAR10Model, self).__init__()self.conv1 = nn.Conv2d(3, 32, 5, padding=2)self.maxpool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(32, 32, 5, padding=2)self.maxpool2 = nn.MaxPool2d(2, 2)self.conv3 = nn.Conv2d(32, 64, 5, padding=2)self.maxpool3 = nn.MaxPool2d(2, 2)self.flatten = nn.Flatten()self.fc1 = nn.Linear(1024, 64)self.fc2 = nn.Linear(64, 10)def forward(self, x):x = self.conv1(x)x = self.maxpool1(x)x = self.conv2(x)x = self.maxpool2(x)x = self.conv3(x)x = self.maxpool3(x)x = self.flatten(x)x = self.fc1(x)x = self.fc2(x)return xif __name__ == '__main__':model = CIFAR10Model()input_test = torch.ones((64, 3, 32, 32))output_test = model(input_test)print(output_test.shape)

这里我们新创建了一个model.py用于专门存储网络结构,这样在我们的训练文件中,可以通过

from model import *# 3.创建神经网络
model = CIFAR10Model()

来导入我们自定义的神经网络

4.选择损失函数和梯度下降的方法

我们选择了交叉熵损失函数与SGD的梯度下降算法,具体讲解可以看【Pytorch】11.损失函数与梯度下降

# 4.设置损失函数与梯度下降算法
loss_fn = nn.CrossEntropyLoss()learn_rate = 1e-2
optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)

5.开始进行训练

首先将模型设置为训练模式

    model.train()

具体的训练流程分为以下几部

  • 从DataLoader中获取图片以及对应的编号
  • 将图片传入神经网络并获取输出
  • 将优化器清零
  • 计算损失函数
  • 进行梯度下降
  • 调用优化器进行更新
    for data in train_loader:# 训练基本流程inputs, labels = dataoutputs = model(inputs)optimizer.zero_grad()loss = loss_fn(outputs, labels)loss.backward()optimizer.step()

在基础训练的基础上,还安排了每进行100次训练就将训练数据print出来,并且写入tensorboard

 # 第i轮训练次数加一pre_train_step += 1pre_train_loss += loss.item()total_train_step += 1# 每100次输出一下if pre_train_step % 100 == 0:end_train_time = time.time()print(f'当前为第{i+1}轮训练,当前训练轮数为:{pre_train_step},已经过时间为:{end_train_time-start_time},当前训练次数的平均损失为:{pre_train_loss / pre_train_step}')# 添加可视化writer.add_scalar('train_loss', pre_train_loss / pre_train_step, total_train_step)print(f"----------------------------第{i + 1}轮训练完成----------------------------")

6.测试集验证

首先将模型设置为测试集模式

    model.eval()

首先通过with关键字来创建一个没有梯度的上下文
验证方法与训练集类似,但是没有计算梯度与更新优化器的步骤

 with torch.no_grad():for data in test_loader:# 测试集流程inputs, labels = dataoutputs = model(inputs)loss = loss_fn(outputs, labels)

然后通过torch.argmax用于计算所有标签的最大值

  • 参数为1时代表横向判断
  • 参数为0的代表纵向判断
    计算当前模型在训练集中的正确次数
            pre_accuracy += outputs.argmax(1).eq(labels).sum().item()

7.引入TensorBoard进行可视化

我们主要是通过Summary中的add_scalar来建立可视化函数来进行可视化的,具体可以看【Pytorch】2.TensorBoard的运用

# 创建TensorBoard
writer = SummaryWriter('./CIFAR10_logs')# 在训练集中,输出每一百次训练的损失函数平均值# 每100次输出一下if pre_train_step % 100 == 0:end_train_time = time.time()print(f'当前为第{i+1}轮训练,当前训练轮数为:{pre_train_step},已经过时间为:{end_train_time-start_time},当前训练次数的平均损失为:{pre_train_loss / pre_train_step}')# 添加可视化writer.add_scalar('train_loss', pre_train_loss / pre_train_step, total_train_step)# 在测试集中,输出模型在测试集中的正确率
pre_accuracy += outputs.argmax(1).eq(labels).sum().item()writer.add_scalar('test_accuracy', pre_accuracy / test_size, i)

8.保存模型

具体可以看【Pytorch】12.网络模型的加载、修改与保存

    # 保存每轮的训练模型torch.save(CIFAR10Model, f'./CIFAR10TrainModel{i}.pth')

完整代码

import time
import torch
import torchvision.transforms
from torch.utils.tensorboard import SummaryWriterfrom model import *# 1.创建训练数据集
train_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=True, download=True,transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=False, download=True,transform=torchvision.transforms.ToTensor())
# 记录数据集大小
train_size = len(train_dataset)
test_size = len(test_dataset)# 2.创建dataloader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)# 3.创建神经网络
model = CIFAR10Model()# 4.设置损失函数与梯度下降算法
loss_fn = nn.CrossEntropyLoss()learn_rate = 0.0001
optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)# 训练轮数
total_train_step = 0
total_test_step = 0# 训练轮数
epoch = 20# 创建TensorBoard
writer = SummaryWriter('./CIFAR10_logs')
# 5.开始训练
for i in range(epoch):# 将模型设置为训练模式print(f"----------------------------开启第{i+1}轮训练----------------------------")model.train()# 第i轮训练的次数pre_train_step = 0# 第i轮训练的总损失pre_train_loss = 0# 第i轮训练的起始时间start_time = time.time()for data in train_loader:# 训练基本流程inputs, labels = dataoutputs = model(inputs)optimizer.zero_grad()loss = loss_fn(outputs, labels)loss.backward()optimizer.step()# 第i轮训练次数加一pre_train_step += 1pre_train_loss += loss.item()total_train_step += 1# 每100次输出一下if pre_train_step % 100 == 0:end_train_time = time.time()print(f'当前为第{i+1}轮训练,当前训练轮数为:{pre_train_step},已经过时间为:{end_train_time-start_time},当前训练次数的平均损失为:{pre_train_loss / pre_train_step}')# 添加可视化writer.add_scalar('train_loss', pre_train_loss / pre_train_step, total_train_step)print(f"----------------------------第{i + 1}轮训练完成----------------------------")# 设置为测试模式model.eval()# 第i轮训练集的总损失pre_test_loss = 0# 第i轮训练集的总正确次数pre_accuracy = 0print(f"----------------------------开启第{i + 1}轮测试----------------------------")# 配置没有梯度下降的环境with torch.no_grad():for data in test_loader:# 测试集流程inputs, labels = dataoutputs = model(inputs)loss = loss_fn(outputs, labels)# 定义参数pre_test_loss += loss.item()# 记录训练集的总正确率# argmax(1)代表横向判断,argmax(0)代表纵向判断pre_accuracy += outputs.argmax(1).eq(labels).sum().item()# 记录测试集运行完后的事件end_test_time = time.time()print(f'当前为第{i + 1}轮测试,已经过时间:{end_test_time - start_time},当前测试集的平均损失为:{pre_test_loss / test_size},当前在测试集的正确率为:{pre_accuracy / test_size}')writer.add_scalar('test_accuracy', pre_accuracy / test_size, i)print(f"----------------------------第{i + 1}轮测试完成----------------------------")# 保存每轮的训练模型torch.save(CIFAR10Model, f'./CIFAR10TrainModel{i}.pth')print(f"----------------------------第{i + 1}轮模型保存完成----------------------------")writer.close()

训练效果
在这里插入图片描述
在这里插入图片描述

相关文章:

【Pytorch】13.搭建完整的CIFAR10模型

项目源码 已上传至githubCIFAR10Model,如果有帮助可以点个star 简介 在前文【Pytorch】10.CIFAR10模型搭建我们学习了用Module来模拟搭建CIFAR10的训练流程 本节将会加入损失函数,梯度下降,TensorBoard来完整搭建一个训练的模型 基本步骤 搭建…...

护目镜佩戴自动识别预警摄像机

护目镜佩戴自动识别预警摄像机是一种智能监测设备,专门用于佩戴护目镜的工人进行作业时,能够自动识别有潜在风险的场景,并及时发出预警信号。该摄像机配备人脸识别和智能预警系统,可以检测危险情况并为工人提供实时安全保护&#…...

keep-alive的使用

Vue中的<keep-alive>组件是前端开发中的一个宝藏功能&#xff0c;它如同时光胶囊般保留组件的状态&#xff0c;让组件在切换时仿佛按下暂停键&#xff0c;再次回来时还能继续播放&#xff0c;极大地优化了用户体验和性能。&#x1f680;✨ 作用 状态保留&#xff1a;当包…...

【Linux】中的常见的重要指令(中)

目录 一、man指令 二、cp指令 三、cat指令 四、mv指令 五、more指令 六、less指令 七、head指令 八、tail指令 一、man指令 Linux的命令有很多参数&#xff0c;我们不可能全记住&#xff0c;我们可以通过查看联机手册获取帮助。访问Linux手册页的命令是 man 语法: m…...

营收净利双降、股东减持,大降价也救不了良品铺子

号称“高端零食第一股”的良品铺子(603719.SH)&#xff0c;正遭遇部分股东的“用脚投票”。 5月17日晚间&#xff0c;良品铺子连发两份减持公告&#xff0c;其控股股东宁波汉意创业投资合伙企业、持股5%以上股东达永有限公司&#xff0c;两者均计划减持。 其中&#xff0c;宁…...

【设计模式】设计模式的分类

通常设计模式的分类有创建型、行为型和结构型。 创建型 常用的有&#xff1a;单例模式、工厂模式&#xff08;工厂方法和抽象工厂&#xff09;、建造者模式。 不常用的有&#xff1a;原型模式。 创建型模式涉及到将对象实例化&#xff0c;这类模式都提供一个方法&#xff0c;将…...

TCP/UDP的连接机制

TCP/UDP的连接机制 TCP的连接机制 TCP&#xff08;Transmission Control Protocol&#xff09;是一种面向连接的协议&#xff0c;提供可靠的、按顺序的数据传输服务。TCP的连接机制包括连接建立、数据传输和连接终止。 1. 连接建立&#xff08;三次握手&#xff09; TCP通过…...

供应链金融模式学习资料

目录 产生背景 供应链金融的诞生 供应链金额的六大特征...

代码随想录-算法训练营day50【动态规划12:最佳买卖股票时机含冷冻期、买卖股票的最佳时机含手续费、股票问题总结】

代码随想录-035期-算法训练营【博客笔记汇总表】-CSDN博客 第九章 动态规划part12● 309.最佳买卖股票时机含冷冻期 ● 714.买卖股票的最佳时机含手续费 ●总结309.最佳买卖股票时机含冷冻期 本题加了一个冷冻期,状态就多了,有点难度,大家要把各个状态分清,思路才能清晰…...

Dilworth 定理

这是一个关于偏序集的定理&#xff0c;事实上它也可以扩展到图论&#xff0c;dp等中&#xff0c;是一个很有意思的东西 偏序集 偏序集是由集合 S S S以及其上的一个偏序关系 R R R定义的&#xff0c;记为 ( S , R ) (S,R) (S,R) 偏序关系&#xff1a; 对于一个二元关系 R ⊂…...

BUUCTF---web---[BJDCTF2020]ZJCTF,不过如此

1、点开连接&#xff0c;页面出现了提示 传入一个参数text&#xff0c;里面的内容要包括I have a dream。 构造&#xff1a;?/textI have a dream。发现页面没有显示。这里推测可能得使用伪协议 在文件包含那一行&#xff0c;我们看到了next.php的提示&#xff0c;我们尝试读取…...

力扣刷题---2206. 将数组划分成相等数对【简单】

题目描述&#x1f357; 给你一个整数数组 nums &#xff0c;它包含 2 * n 个整数。 你需要将 nums 划分成 n 个数对&#xff0c;满足&#xff1a; 每个元素 只属于一个 数对。 同一数对中的元素 相等 。 如果可以将 nums 划分成 n 个数对&#xff0c;请你返回 true &#xf…...

2461. 长度为 K 子数组中的最大和(c++)

给你一个整数数组 nums 和一个整数 k 。请你从 nums 中满足下述条件的全部子数组中找出最大子数组和&#xff1a; 子数组的长度是 k&#xff0c;且子数组中的所有元素 各不相同 。 返回满足题面要求的最大子数组和。如果不存在子数组满足这些条件&#xff0c;返回 0 。 子数…...

range for

1. 基于范围的for循环语法 C11标准引入了基于范围的for循环特性&#xff0c;该特性隐藏了迭代器 的初始化和更新过程&#xff0c;让程序员只需要关心遍历对象本身&#xff0c;其语法也 比传统for循环简洁很多&#xff1a; for ( range_declaration : range_expression ) {loo…...

leetcode230 二叉搜索树中第K小的元素

题目 给定一个二叉搜索树的根节点 root &#xff0c;和一个整数 k &#xff0c;请你设计一个算法查找其中第 k 个最小元素&#xff08;从 1 开始计数&#xff09;。 示例 输入&#xff1a;root [5,3,6,2,4,null,null,1], k 3 输出&#xff1a;3 解析 这道题应该是能做出…...

.Net Core学习笔记 框架特性(注入、配置)

注&#xff1a;直接学习的.Net Core 6&#xff0c;此版本有没有startup.cs相关的内容 项目Program.cs文件中 是定义项目加载 启动的地方 //通过builder对项目进行配置、服务的加载 var builder WebApplication.CreateBuilder(args); builder.Services.AddControllers();//将…...

利用AI技术做电商网赚,这些百万级赛道流量,你还不知道?!

大家好&#xff0c;我是向阳 AI技术的飞速扩展已经势不可挡&#xff0c;不管你承不承认&#xff0c;AI 已经毫无争议的在互联网中占有一席之地了 无论你是做内容产业的&#xff0c;还是做电商的&#xff0c;你现在都躲不开 AI。 现在互联网行业的竞争就是这么残酷 互联网行业…...

leetcode-560 和为k的数组

一、题目描述 给你一个整数数组 nums 和一个整数 k &#xff0c;请你统计并返回 该数组中和为 k 的子数组的个数 。 子数组是数组中元素的连续非空序列。 注意&#xff1a;nums中的元素可为负数 输入&#xff1a;nums [1,1,1], k 2 输出&#xff1a;2输入&#xff1a;num…...

Spring Boot实战指南:从入门到企业级应用构建

目录 一、引言 二、快速入门 1. 使用Spring Initializr创建项目 三、Spring Boot基础概念与自动配置 1. 理解SpringBootApplication注解 2. 自动配置原理 3. 查看自动配置报告 四、Spring Boot核心特性及实战 1. 外部化配置 2. Actuator端点 3. 集成第三方库 五、Sp…...

OneAPI接入本地大模型+FastGPT调用本地大模型

将Ollama下载的本地大模型配置到OneAPI中&#xff0c;并通过FastGPT调用本地大模型完成对话。 OneAPI配置 新建令牌 新建渠道 FastGPT配置 配置docker-compose 配置令牌和OneAPI部署地址 配置config.json 配置调用的渠道名称和大模型名称 {"systemEnv": {&qu…...

Windows 10终极指南:免费开启HEIC缩略图预览功能

Windows 10终极指南&#xff1a;免费开启HEIC缩略图预览功能 【免费下载链接】windows-heic-thumbnails Enable Windows Explorer to display thumbnails for HEIC files 项目地址: https://gitcode.com/gh_mirrors/wi/windows-heic-thumbnails 还在为iPhone拍摄的照片在…...

别再只数步数了!深入聊聊ADXL345计步算法里的‘动态阈值’与‘最活跃轴’

别再只数步数了&#xff01;深入聊聊ADXL345计步算法里的‘动态阈值’与‘最活跃轴’ 当你盯着智能手环上的步数统计时&#xff0c;有没有想过这串数字背后藏着怎样的算法智慧&#xff1f;ADXL345作为一款经典的三轴加速度传感器&#xff0c;其计步算法远非简单的阈值比较那么简…...

PyTorch 2.8镜像快速部署:5分钟验证torch.cuda.is_available()并启动API服务

PyTorch 2.8镜像快速部署&#xff1a;5分钟验证torch.cuda.is_available()并启动API服务 1. 镜像概述与环境准备 PyTorch 2.8深度学习镜像是一个开箱即用的高性能计算环境&#xff0c;专为现代AI工作负载优化。这个预配置环境能让你跳过繁琐的安装过程&#xff0c;直接进入模…...

SEO_新手必看的SEO优化入门教程与常见误区

什么是SEO优化&#xff1f; SEO优化&#xff0c;全称搜索引擎优化&#xff0c;是指通过优化网站内容和结构&#xff0c;使其在搜索引擎&#xff08;如百度、谷歌&#xff09;中获得更高排名的一系列活动。SEO的目的是提高网站的自然流量&#xff0c;从而增加潜在客户和销售机会…...

你的文件真的‘上传’了吗?聊聊阿里云盘‘秒传’背后的隐私与安全考量

你的文件真的“上传”了吗&#xff1f;揭秘秒传技术背后的隐私博弈 第一次在阿里云盘体验“秒传”功能时&#xff0c;那种近乎魔法的速度确实令人惊叹——几个GB的文件眨眼间就完成了“上传”。但惊喜之余&#xff0c;一个更根本的问题浮现出来&#xff1a;我的文件真的被上传了…...

[具身智能-189]:ROS2的Node通信机制,为硬件的仿真平台与模型算法的分离以及他们之间标准化的通信提供了保障,在嵌入式系统,特别是具身智能开发中,解决“软硬耦合”这一顽疾。

ROS 2 的节点通信机制&#xff0c;本质上就是为了解决“软硬耦合”这一顽疾而生的。 它通过去中心化的架构和标准化的中间件&#xff08;DDS&#xff09;&#xff0c;让仿真平台&#xff08;如 Gazebo、Isaac Sim&#xff09;和模型算法&#xff08;如导航、感知&#xff09;能…...

电路设计与漫画艺术的跨界融合

1. 当电路遇见漫画&#xff1a;工程师的艺术表达在大多数人眼中&#xff0c;电路设计是冰冷的数据和复杂的公式&#xff0c;而漫画则是天马行空的创意表达。但作为一名从业十年的硬件工程师&#xff0c;我发现这两者其实有着惊人的相似之处——它们都需要严谨的结构设计&#x…...

AQM0802字符LCD轻量驱动库:裸机printf级显示方案

1. 项目概述AQM0802 是一款由旭化成&#xff08;AKM&#xff09;推出的超低功耗、单色字符型液晶显示模块&#xff0c;采用 COG&#xff08;Chip-on-Glass&#xff09;封装工艺&#xff0c;内置 KS0066 兼容控制器。其典型型号为 AQM0802A-YBW&#xff0c;具备 8 字符 2 行的显…...

3D Face HRN开源镜像:ModelScope官方cv_resnet50_face-reconstruction部署

3D Face HRN开源镜像&#xff1a;ModelScope官方cv_resnet50_face-reconstruction部署 1. 引言&#xff1a;从2D照片到3D人脸的魔法转换 你是否曾经想过&#xff0c;仅仅通过一张普通的2D人脸照片&#xff0c;就能生成精确的3D人脸模型&#xff1f;这在过去可能需要专业设备和…...

示波器测量UART波特率的原理与实践

1. 示波器测量串口波特率的原理与方法 1.1 串口通信基础 在嵌入式系统开发中&#xff0c;UART串口通信是最常用的调试接口之一。正确识别串口波特率对于设备调试和逆向工程具有重要意义。串口通信采用异步传输方式&#xff0c;其关键参数包括&#xff1a; 波特率&#xff1a;…...