当前位置: 首页 > news >正文

【Pytorch】13.搭建完整的CIFAR10模型

项目源码

已上传至githubCIFAR10Model,如果有帮助可以点个star

简介

在前文【Pytorch】10.CIFAR10模型搭建我们学习了用Module来模拟搭建CIFAR10的训练流程
本节将会加入损失函数,梯度下降,TensorBoard来完整搭建一个训练的模型

基本步骤

搭建神经网络最主要的流程是

  • 导入数据集(包括训练集和测试集)
  • 创建DataLoader
  • 创建自定义的神经网络
  • 选择损失函数与梯度下降算法
  • 进行n轮训练
  • n轮训练完成后通过测试集进行验证
  • 引入TensorBoard进行可视化
  • 保存每轮训练好的模型
    接下来将逐步拆解这每一个步骤

1.导入数据集

因为我们本文是要训练CIFAR10的模型,所以我们导入CIFAR10的数据集

# 1.创建训练数据集
train_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=True, download=True,transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=False, download=True,transform=torchvision.transforms.ToTensor())
# 记录数据集大小
train_size = len(train_dataset)
test_size = len(test_dataset)

分别导入训练集与测试集,并且分别记录训练集与测试集的大小
对参数的解释可以看【Pytorch】4.torchvision.datasets的使用这篇文章

2.创建DataLoader

DataLoader主要定义了如何在数据集中取数据的规则,具体讲解可以看【Pytorch】5.DataLoder的使用

# 2.创建dataloader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)

3.创建自定义的神经网络

在这里插入图片描述
我们可以在网上搜到CIFAR10的网络模型,通过网络模型来搭建网络,具体可以看【Pytorch】10.CIFAR10模型搭建

import torch
from torch import nnclass CIFAR10Model(nn.Module):def __init__(self):super(CIFAR10Model, self).__init__()self.conv1 = nn.Conv2d(3, 32, 5, padding=2)self.maxpool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(32, 32, 5, padding=2)self.maxpool2 = nn.MaxPool2d(2, 2)self.conv3 = nn.Conv2d(32, 64, 5, padding=2)self.maxpool3 = nn.MaxPool2d(2, 2)self.flatten = nn.Flatten()self.fc1 = nn.Linear(1024, 64)self.fc2 = nn.Linear(64, 10)def forward(self, x):x = self.conv1(x)x = self.maxpool1(x)x = self.conv2(x)x = self.maxpool2(x)x = self.conv3(x)x = self.maxpool3(x)x = self.flatten(x)x = self.fc1(x)x = self.fc2(x)return xif __name__ == '__main__':model = CIFAR10Model()input_test = torch.ones((64, 3, 32, 32))output_test = model(input_test)print(output_test.shape)

这里我们新创建了一个model.py用于专门存储网络结构,这样在我们的训练文件中,可以通过

from model import *# 3.创建神经网络
model = CIFAR10Model()

来导入我们自定义的神经网络

4.选择损失函数和梯度下降的方法

我们选择了交叉熵损失函数与SGD的梯度下降算法,具体讲解可以看【Pytorch】11.损失函数与梯度下降

# 4.设置损失函数与梯度下降算法
loss_fn = nn.CrossEntropyLoss()learn_rate = 1e-2
optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)

5.开始进行训练

首先将模型设置为训练模式

    model.train()

具体的训练流程分为以下几部

  • 从DataLoader中获取图片以及对应的编号
  • 将图片传入神经网络并获取输出
  • 将优化器清零
  • 计算损失函数
  • 进行梯度下降
  • 调用优化器进行更新
    for data in train_loader:# 训练基本流程inputs, labels = dataoutputs = model(inputs)optimizer.zero_grad()loss = loss_fn(outputs, labels)loss.backward()optimizer.step()

在基础训练的基础上,还安排了每进行100次训练就将训练数据print出来,并且写入tensorboard

 # 第i轮训练次数加一pre_train_step += 1pre_train_loss += loss.item()total_train_step += 1# 每100次输出一下if pre_train_step % 100 == 0:end_train_time = time.time()print(f'当前为第{i+1}轮训练,当前训练轮数为:{pre_train_step},已经过时间为:{end_train_time-start_time},当前训练次数的平均损失为:{pre_train_loss / pre_train_step}')# 添加可视化writer.add_scalar('train_loss', pre_train_loss / pre_train_step, total_train_step)print(f"----------------------------第{i + 1}轮训练完成----------------------------")

6.测试集验证

首先将模型设置为测试集模式

    model.eval()

首先通过with关键字来创建一个没有梯度的上下文
验证方法与训练集类似,但是没有计算梯度与更新优化器的步骤

 with torch.no_grad():for data in test_loader:# 测试集流程inputs, labels = dataoutputs = model(inputs)loss = loss_fn(outputs, labels)

然后通过torch.argmax用于计算所有标签的最大值

  • 参数为1时代表横向判断
  • 参数为0的代表纵向判断
    计算当前模型在训练集中的正确次数
            pre_accuracy += outputs.argmax(1).eq(labels).sum().item()

7.引入TensorBoard进行可视化

我们主要是通过Summary中的add_scalar来建立可视化函数来进行可视化的,具体可以看【Pytorch】2.TensorBoard的运用

# 创建TensorBoard
writer = SummaryWriter('./CIFAR10_logs')# 在训练集中,输出每一百次训练的损失函数平均值# 每100次输出一下if pre_train_step % 100 == 0:end_train_time = time.time()print(f'当前为第{i+1}轮训练,当前训练轮数为:{pre_train_step},已经过时间为:{end_train_time-start_time},当前训练次数的平均损失为:{pre_train_loss / pre_train_step}')# 添加可视化writer.add_scalar('train_loss', pre_train_loss / pre_train_step, total_train_step)# 在测试集中,输出模型在测试集中的正确率
pre_accuracy += outputs.argmax(1).eq(labels).sum().item()writer.add_scalar('test_accuracy', pre_accuracy / test_size, i)

8.保存模型

具体可以看【Pytorch】12.网络模型的加载、修改与保存

    # 保存每轮的训练模型torch.save(CIFAR10Model, f'./CIFAR10TrainModel{i}.pth')

完整代码

import time
import torch
import torchvision.transforms
from torch.utils.tensorboard import SummaryWriterfrom model import *# 1.创建训练数据集
train_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=True, download=True,transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=False, download=True,transform=torchvision.transforms.ToTensor())
# 记录数据集大小
train_size = len(train_dataset)
test_size = len(test_dataset)# 2.创建dataloader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)# 3.创建神经网络
model = CIFAR10Model()# 4.设置损失函数与梯度下降算法
loss_fn = nn.CrossEntropyLoss()learn_rate = 0.0001
optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)# 训练轮数
total_train_step = 0
total_test_step = 0# 训练轮数
epoch = 20# 创建TensorBoard
writer = SummaryWriter('./CIFAR10_logs')
# 5.开始训练
for i in range(epoch):# 将模型设置为训练模式print(f"----------------------------开启第{i+1}轮训练----------------------------")model.train()# 第i轮训练的次数pre_train_step = 0# 第i轮训练的总损失pre_train_loss = 0# 第i轮训练的起始时间start_time = time.time()for data in train_loader:# 训练基本流程inputs, labels = dataoutputs = model(inputs)optimizer.zero_grad()loss = loss_fn(outputs, labels)loss.backward()optimizer.step()# 第i轮训练次数加一pre_train_step += 1pre_train_loss += loss.item()total_train_step += 1# 每100次输出一下if pre_train_step % 100 == 0:end_train_time = time.time()print(f'当前为第{i+1}轮训练,当前训练轮数为:{pre_train_step},已经过时间为:{end_train_time-start_time},当前训练次数的平均损失为:{pre_train_loss / pre_train_step}')# 添加可视化writer.add_scalar('train_loss', pre_train_loss / pre_train_step, total_train_step)print(f"----------------------------第{i + 1}轮训练完成----------------------------")# 设置为测试模式model.eval()# 第i轮训练集的总损失pre_test_loss = 0# 第i轮训练集的总正确次数pre_accuracy = 0print(f"----------------------------开启第{i + 1}轮测试----------------------------")# 配置没有梯度下降的环境with torch.no_grad():for data in test_loader:# 测试集流程inputs, labels = dataoutputs = model(inputs)loss = loss_fn(outputs, labels)# 定义参数pre_test_loss += loss.item()# 记录训练集的总正确率# argmax(1)代表横向判断,argmax(0)代表纵向判断pre_accuracy += outputs.argmax(1).eq(labels).sum().item()# 记录测试集运行完后的事件end_test_time = time.time()print(f'当前为第{i + 1}轮测试,已经过时间:{end_test_time - start_time},当前测试集的平均损失为:{pre_test_loss / test_size},当前在测试集的正确率为:{pre_accuracy / test_size}')writer.add_scalar('test_accuracy', pre_accuracy / test_size, i)print(f"----------------------------第{i + 1}轮测试完成----------------------------")# 保存每轮的训练模型torch.save(CIFAR10Model, f'./CIFAR10TrainModel{i}.pth')print(f"----------------------------第{i + 1}轮模型保存完成----------------------------")writer.close()

训练效果
在这里插入图片描述
在这里插入图片描述

相关文章:

【Pytorch】13.搭建完整的CIFAR10模型

项目源码 已上传至githubCIFAR10Model,如果有帮助可以点个star 简介 在前文【Pytorch】10.CIFAR10模型搭建我们学习了用Module来模拟搭建CIFAR10的训练流程 本节将会加入损失函数,梯度下降,TensorBoard来完整搭建一个训练的模型 基本步骤 搭建…...

护目镜佩戴自动识别预警摄像机

护目镜佩戴自动识别预警摄像机是一种智能监测设备,专门用于佩戴护目镜的工人进行作业时,能够自动识别有潜在风险的场景,并及时发出预警信号。该摄像机配备人脸识别和智能预警系统,可以检测危险情况并为工人提供实时安全保护&#…...

keep-alive的使用

Vue中的<keep-alive>组件是前端开发中的一个宝藏功能&#xff0c;它如同时光胶囊般保留组件的状态&#xff0c;让组件在切换时仿佛按下暂停键&#xff0c;再次回来时还能继续播放&#xff0c;极大地优化了用户体验和性能。&#x1f680;✨ 作用 状态保留&#xff1a;当包…...

【Linux】中的常见的重要指令(中)

目录 一、man指令 二、cp指令 三、cat指令 四、mv指令 五、more指令 六、less指令 七、head指令 八、tail指令 一、man指令 Linux的命令有很多参数&#xff0c;我们不可能全记住&#xff0c;我们可以通过查看联机手册获取帮助。访问Linux手册页的命令是 man 语法: m…...

营收净利双降、股东减持,大降价也救不了良品铺子

号称“高端零食第一股”的良品铺子(603719.SH)&#xff0c;正遭遇部分股东的“用脚投票”。 5月17日晚间&#xff0c;良品铺子连发两份减持公告&#xff0c;其控股股东宁波汉意创业投资合伙企业、持股5%以上股东达永有限公司&#xff0c;两者均计划减持。 其中&#xff0c;宁…...

【设计模式】设计模式的分类

通常设计模式的分类有创建型、行为型和结构型。 创建型 常用的有&#xff1a;单例模式、工厂模式&#xff08;工厂方法和抽象工厂&#xff09;、建造者模式。 不常用的有&#xff1a;原型模式。 创建型模式涉及到将对象实例化&#xff0c;这类模式都提供一个方法&#xff0c;将…...

TCP/UDP的连接机制

TCP/UDP的连接机制 TCP的连接机制 TCP&#xff08;Transmission Control Protocol&#xff09;是一种面向连接的协议&#xff0c;提供可靠的、按顺序的数据传输服务。TCP的连接机制包括连接建立、数据传输和连接终止。 1. 连接建立&#xff08;三次握手&#xff09; TCP通过…...

供应链金融模式学习资料

目录 产生背景 供应链金融的诞生 供应链金额的六大特征...

代码随想录-算法训练营day50【动态规划12:最佳买卖股票时机含冷冻期、买卖股票的最佳时机含手续费、股票问题总结】

代码随想录-035期-算法训练营【博客笔记汇总表】-CSDN博客 第九章 动态规划part12● 309.最佳买卖股票时机含冷冻期 ● 714.买卖股票的最佳时机含手续费 ●总结309.最佳买卖股票时机含冷冻期 本题加了一个冷冻期,状态就多了,有点难度,大家要把各个状态分清,思路才能清晰…...

Dilworth 定理

这是一个关于偏序集的定理&#xff0c;事实上它也可以扩展到图论&#xff0c;dp等中&#xff0c;是一个很有意思的东西 偏序集 偏序集是由集合 S S S以及其上的一个偏序关系 R R R定义的&#xff0c;记为 ( S , R ) (S,R) (S,R) 偏序关系&#xff1a; 对于一个二元关系 R ⊂…...

BUUCTF---web---[BJDCTF2020]ZJCTF,不过如此

1、点开连接&#xff0c;页面出现了提示 传入一个参数text&#xff0c;里面的内容要包括I have a dream。 构造&#xff1a;?/textI have a dream。发现页面没有显示。这里推测可能得使用伪协议 在文件包含那一行&#xff0c;我们看到了next.php的提示&#xff0c;我们尝试读取…...

力扣刷题---2206. 将数组划分成相等数对【简单】

题目描述&#x1f357; 给你一个整数数组 nums &#xff0c;它包含 2 * n 个整数。 你需要将 nums 划分成 n 个数对&#xff0c;满足&#xff1a; 每个元素 只属于一个 数对。 同一数对中的元素 相等 。 如果可以将 nums 划分成 n 个数对&#xff0c;请你返回 true &#xf…...

2461. 长度为 K 子数组中的最大和(c++)

给你一个整数数组 nums 和一个整数 k 。请你从 nums 中满足下述条件的全部子数组中找出最大子数组和&#xff1a; 子数组的长度是 k&#xff0c;且子数组中的所有元素 各不相同 。 返回满足题面要求的最大子数组和。如果不存在子数组满足这些条件&#xff0c;返回 0 。 子数…...

range for

1. 基于范围的for循环语法 C11标准引入了基于范围的for循环特性&#xff0c;该特性隐藏了迭代器 的初始化和更新过程&#xff0c;让程序员只需要关心遍历对象本身&#xff0c;其语法也 比传统for循环简洁很多&#xff1a; for ( range_declaration : range_expression ) {loo…...

leetcode230 二叉搜索树中第K小的元素

题目 给定一个二叉搜索树的根节点 root &#xff0c;和一个整数 k &#xff0c;请你设计一个算法查找其中第 k 个最小元素&#xff08;从 1 开始计数&#xff09;。 示例 输入&#xff1a;root [5,3,6,2,4,null,null,1], k 3 输出&#xff1a;3 解析 这道题应该是能做出…...

.Net Core学习笔记 框架特性(注入、配置)

注&#xff1a;直接学习的.Net Core 6&#xff0c;此版本有没有startup.cs相关的内容 项目Program.cs文件中 是定义项目加载 启动的地方 //通过builder对项目进行配置、服务的加载 var builder WebApplication.CreateBuilder(args); builder.Services.AddControllers();//将…...

利用AI技术做电商网赚,这些百万级赛道流量,你还不知道?!

大家好&#xff0c;我是向阳 AI技术的飞速扩展已经势不可挡&#xff0c;不管你承不承认&#xff0c;AI 已经毫无争议的在互联网中占有一席之地了 无论你是做内容产业的&#xff0c;还是做电商的&#xff0c;你现在都躲不开 AI。 现在互联网行业的竞争就是这么残酷 互联网行业…...

leetcode-560 和为k的数组

一、题目描述 给你一个整数数组 nums 和一个整数 k &#xff0c;请你统计并返回 该数组中和为 k 的子数组的个数 。 子数组是数组中元素的连续非空序列。 注意&#xff1a;nums中的元素可为负数 输入&#xff1a;nums [1,1,1], k 2 输出&#xff1a;2输入&#xff1a;num…...

Spring Boot实战指南:从入门到企业级应用构建

目录 一、引言 二、快速入门 1. 使用Spring Initializr创建项目 三、Spring Boot基础概念与自动配置 1. 理解SpringBootApplication注解 2. 自动配置原理 3. 查看自动配置报告 四、Spring Boot核心特性及实战 1. 外部化配置 2. Actuator端点 3. 集成第三方库 五、Sp…...

OneAPI接入本地大模型+FastGPT调用本地大模型

将Ollama下载的本地大模型配置到OneAPI中&#xff0c;并通过FastGPT调用本地大模型完成对话。 OneAPI配置 新建令牌 新建渠道 FastGPT配置 配置docker-compose 配置令牌和OneAPI部署地址 配置config.json 配置调用的渠道名称和大模型名称 {"systemEnv": {&qu…...

如何安全高效的文件管理?文件管理方法

文件的管理早已不只是办公场景中的需求。日常生活、在线学习以及个人收藏中&#xff0c;文件管理正逐渐成为我们数字生活中的基础。但与此同时&#xff0c;文件管理的混乱、低效以及安全性问题也频繁困扰着许多人。 文件管理的挑战与解决思路 挑战一&#xff1a;文件存储无序…...

【学习记录】深入解析 AI 交互中的五大核心概念:Prompt、Agent、MCP、Function Calling 与 Tools

&#x1f4cc; 引言 随着大语言模型&#xff08;LLM&#xff09;的发展&#xff0c;AI 已经不再只是“回答问题”的工具&#xff0c;而是可以主动执行任务、调用外部资源、甚至构建完整工作流的智能系统。 为了更好地理解和使用这些能力&#xff0c;我们需要了解 AI 交互中几…...

Tensorborad

一、tensorboard的基本操作 1.1 发展历史 TensorBoard 是 TensorFlow 生态中的官方可视化工具&#xff08;也可无缝集成 PyTorch&#xff09;&#xff0c;用于实时监控训练过程、可视化模型结构、分析数据分布、对比实验结果等。它通过网页端交互界面&#xff0c;将枯燥的训练…...

LeetCode 高频 SQL 50 题(基础版) 之 【高级查询和连接】· 上

题目&#xff1a;1731. 每位经理的下属员工数量 题解&#xff1a; select employee_id,name,reports_count,average_age from Employees t1,(select reports_to,count(*) reports_count,round(avg(age)) average_agefrom Employeeswhere reports_to is not nullgroup by repor…...

图像分类进阶:从基础到专业 (superior哥AI系列第10期)

图像分类进阶&#xff1a;从基础到专业 &#x1f680; 前言 &#x1f44b; 哈喽&#xff0c;各位深度学习的探索者们&#xff01;我是你们的老朋友superior哥 &#x1f60e; 经过前面九篇文章的学习&#xff0c;相信大家对深度学习的基础概念、神经网络架构、以及训练部署都…...

rocketmq索引

索引的理解 索引是什么, 索引实质是 相同数据的另一种存储结构 我们都知道读和写天然是存在矛盾的, 我们希望写的快,当然是顺序写的性能最高, 顺序写造成数据杂乱无章,没法按照一定的规律去找数。 如果想要找数的效率高, 必须要有结构组织的存放数据, 这样方便按规律找…...

Springboot——整合websocket并根据type区别处理

文章目录 前言架构思想项目结构代码实现依赖引入自定义注解定义具体的处理类定义 TypeAWebSocketHandler定义 TypeBWebSocketHandler 定义路由处理类配置类&#xff0c;绑定point制定前端页面编写测试接口方便跳转进入前端页面 测试验证结语 前言 之前写过一篇类似的博客&…...

【Git系列】如何同步原始仓库的更新到你的fork仓库?

&#x1f389;&#x1f389;&#x1f389;欢迎来到我们的博客&#xff01;无论您是第一次访问&#xff0c;还是我们的老朋友&#xff0c;我们都由衷地感谢您的到来。无论您是来寻找灵感、获取知识&#xff0c;还是单纯地享受阅读的乐趣&#xff0c;我们都希望您能在这里找到属于…...

找到每一个单词+模拟的思路和算法

如大家所知&#xff0c;我们可以对给定的字符串 sentence 进行一次遍历&#xff0c;找出其中的每一个单词&#xff0c;并根据题目的要求进行操作。 在寻找单词时&#xff0c;我们可以使用语言自带的 split() 函数&#xff0c;将空格作为分割字符&#xff0c;得到所有的单词。为…...

互联网大厂Java面试:从Spring Cloud到Kafka的技术考察

场景&#xff1a;互联网大厂Java求职者面试 面试官与谢飞机的对话 面试官&#xff1a;我们先从基础开始&#xff0c;谢飞机&#xff0c;你能简单介绍一下Java SE和Java EE的区别吗&#xff1f; 谢飞机&#xff1a;哦&#xff0c;这个简单。Java SE是标准版&#xff0c;适合桌…...