当前位置: 首页 > news >正文

【Pytorch】13.搭建完整的CIFAR10模型

项目源码

已上传至githubCIFAR10Model,如果有帮助可以点个star

简介

在前文【Pytorch】10.CIFAR10模型搭建我们学习了用Module来模拟搭建CIFAR10的训练流程
本节将会加入损失函数,梯度下降,TensorBoard来完整搭建一个训练的模型

基本步骤

搭建神经网络最主要的流程是

  • 导入数据集(包括训练集和测试集)
  • 创建DataLoader
  • 创建自定义的神经网络
  • 选择损失函数与梯度下降算法
  • 进行n轮训练
  • n轮训练完成后通过测试集进行验证
  • 引入TensorBoard进行可视化
  • 保存每轮训练好的模型
    接下来将逐步拆解这每一个步骤

1.导入数据集

因为我们本文是要训练CIFAR10的模型,所以我们导入CIFAR10的数据集

# 1.创建训练数据集
train_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=True, download=True,transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=False, download=True,transform=torchvision.transforms.ToTensor())
# 记录数据集大小
train_size = len(train_dataset)
test_size = len(test_dataset)

分别导入训练集与测试集,并且分别记录训练集与测试集的大小
对参数的解释可以看【Pytorch】4.torchvision.datasets的使用这篇文章

2.创建DataLoader

DataLoader主要定义了如何在数据集中取数据的规则,具体讲解可以看【Pytorch】5.DataLoder的使用

# 2.创建dataloader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)

3.创建自定义的神经网络

在这里插入图片描述
我们可以在网上搜到CIFAR10的网络模型,通过网络模型来搭建网络,具体可以看【Pytorch】10.CIFAR10模型搭建

import torch
from torch import nnclass CIFAR10Model(nn.Module):def __init__(self):super(CIFAR10Model, self).__init__()self.conv1 = nn.Conv2d(3, 32, 5, padding=2)self.maxpool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(32, 32, 5, padding=2)self.maxpool2 = nn.MaxPool2d(2, 2)self.conv3 = nn.Conv2d(32, 64, 5, padding=2)self.maxpool3 = nn.MaxPool2d(2, 2)self.flatten = nn.Flatten()self.fc1 = nn.Linear(1024, 64)self.fc2 = nn.Linear(64, 10)def forward(self, x):x = self.conv1(x)x = self.maxpool1(x)x = self.conv2(x)x = self.maxpool2(x)x = self.conv3(x)x = self.maxpool3(x)x = self.flatten(x)x = self.fc1(x)x = self.fc2(x)return xif __name__ == '__main__':model = CIFAR10Model()input_test = torch.ones((64, 3, 32, 32))output_test = model(input_test)print(output_test.shape)

这里我们新创建了一个model.py用于专门存储网络结构,这样在我们的训练文件中,可以通过

from model import *# 3.创建神经网络
model = CIFAR10Model()

来导入我们自定义的神经网络

4.选择损失函数和梯度下降的方法

我们选择了交叉熵损失函数与SGD的梯度下降算法,具体讲解可以看【Pytorch】11.损失函数与梯度下降

# 4.设置损失函数与梯度下降算法
loss_fn = nn.CrossEntropyLoss()learn_rate = 1e-2
optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)

5.开始进行训练

首先将模型设置为训练模式

    model.train()

具体的训练流程分为以下几部

  • 从DataLoader中获取图片以及对应的编号
  • 将图片传入神经网络并获取输出
  • 将优化器清零
  • 计算损失函数
  • 进行梯度下降
  • 调用优化器进行更新
    for data in train_loader:# 训练基本流程inputs, labels = dataoutputs = model(inputs)optimizer.zero_grad()loss = loss_fn(outputs, labels)loss.backward()optimizer.step()

在基础训练的基础上,还安排了每进行100次训练就将训练数据print出来,并且写入tensorboard

 # 第i轮训练次数加一pre_train_step += 1pre_train_loss += loss.item()total_train_step += 1# 每100次输出一下if pre_train_step % 100 == 0:end_train_time = time.time()print(f'当前为第{i+1}轮训练,当前训练轮数为:{pre_train_step},已经过时间为:{end_train_time-start_time},当前训练次数的平均损失为:{pre_train_loss / pre_train_step}')# 添加可视化writer.add_scalar('train_loss', pre_train_loss / pre_train_step, total_train_step)print(f"----------------------------第{i + 1}轮训练完成----------------------------")

6.测试集验证

首先将模型设置为测试集模式

    model.eval()

首先通过with关键字来创建一个没有梯度的上下文
验证方法与训练集类似,但是没有计算梯度与更新优化器的步骤

 with torch.no_grad():for data in test_loader:# 测试集流程inputs, labels = dataoutputs = model(inputs)loss = loss_fn(outputs, labels)

然后通过torch.argmax用于计算所有标签的最大值

  • 参数为1时代表横向判断
  • 参数为0的代表纵向判断
    计算当前模型在训练集中的正确次数
            pre_accuracy += outputs.argmax(1).eq(labels).sum().item()

7.引入TensorBoard进行可视化

我们主要是通过Summary中的add_scalar来建立可视化函数来进行可视化的,具体可以看【Pytorch】2.TensorBoard的运用

# 创建TensorBoard
writer = SummaryWriter('./CIFAR10_logs')# 在训练集中,输出每一百次训练的损失函数平均值# 每100次输出一下if pre_train_step % 100 == 0:end_train_time = time.time()print(f'当前为第{i+1}轮训练,当前训练轮数为:{pre_train_step},已经过时间为:{end_train_time-start_time},当前训练次数的平均损失为:{pre_train_loss / pre_train_step}')# 添加可视化writer.add_scalar('train_loss', pre_train_loss / pre_train_step, total_train_step)# 在测试集中,输出模型在测试集中的正确率
pre_accuracy += outputs.argmax(1).eq(labels).sum().item()writer.add_scalar('test_accuracy', pre_accuracy / test_size, i)

8.保存模型

具体可以看【Pytorch】12.网络模型的加载、修改与保存

    # 保存每轮的训练模型torch.save(CIFAR10Model, f'./CIFAR10TrainModel{i}.pth')

完整代码

import time
import torch
import torchvision.transforms
from torch.utils.tensorboard import SummaryWriterfrom model import *# 1.创建训练数据集
train_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=True, download=True,transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10(root='../dataset', train=False, download=True,transform=torchvision.transforms.ToTensor())
# 记录数据集大小
train_size = len(train_dataset)
test_size = len(test_dataset)# 2.创建dataloader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)# 3.创建神经网络
model = CIFAR10Model()# 4.设置损失函数与梯度下降算法
loss_fn = nn.CrossEntropyLoss()learn_rate = 0.0001
optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)# 训练轮数
total_train_step = 0
total_test_step = 0# 训练轮数
epoch = 20# 创建TensorBoard
writer = SummaryWriter('./CIFAR10_logs')
# 5.开始训练
for i in range(epoch):# 将模型设置为训练模式print(f"----------------------------开启第{i+1}轮训练----------------------------")model.train()# 第i轮训练的次数pre_train_step = 0# 第i轮训练的总损失pre_train_loss = 0# 第i轮训练的起始时间start_time = time.time()for data in train_loader:# 训练基本流程inputs, labels = dataoutputs = model(inputs)optimizer.zero_grad()loss = loss_fn(outputs, labels)loss.backward()optimizer.step()# 第i轮训练次数加一pre_train_step += 1pre_train_loss += loss.item()total_train_step += 1# 每100次输出一下if pre_train_step % 100 == 0:end_train_time = time.time()print(f'当前为第{i+1}轮训练,当前训练轮数为:{pre_train_step},已经过时间为:{end_train_time-start_time},当前训练次数的平均损失为:{pre_train_loss / pre_train_step}')# 添加可视化writer.add_scalar('train_loss', pre_train_loss / pre_train_step, total_train_step)print(f"----------------------------第{i + 1}轮训练完成----------------------------")# 设置为测试模式model.eval()# 第i轮训练集的总损失pre_test_loss = 0# 第i轮训练集的总正确次数pre_accuracy = 0print(f"----------------------------开启第{i + 1}轮测试----------------------------")# 配置没有梯度下降的环境with torch.no_grad():for data in test_loader:# 测试集流程inputs, labels = dataoutputs = model(inputs)loss = loss_fn(outputs, labels)# 定义参数pre_test_loss += loss.item()# 记录训练集的总正确率# argmax(1)代表横向判断,argmax(0)代表纵向判断pre_accuracy += outputs.argmax(1).eq(labels).sum().item()# 记录测试集运行完后的事件end_test_time = time.time()print(f'当前为第{i + 1}轮测试,已经过时间:{end_test_time - start_time},当前测试集的平均损失为:{pre_test_loss / test_size},当前在测试集的正确率为:{pre_accuracy / test_size}')writer.add_scalar('test_accuracy', pre_accuracy / test_size, i)print(f"----------------------------第{i + 1}轮测试完成----------------------------")# 保存每轮的训练模型torch.save(CIFAR10Model, f'./CIFAR10TrainModel{i}.pth')print(f"----------------------------第{i + 1}轮模型保存完成----------------------------")writer.close()

训练效果
在这里插入图片描述
在这里插入图片描述

相关文章:

【Pytorch】13.搭建完整的CIFAR10模型

项目源码 已上传至githubCIFAR10Model,如果有帮助可以点个star 简介 在前文【Pytorch】10.CIFAR10模型搭建我们学习了用Module来模拟搭建CIFAR10的训练流程 本节将会加入损失函数,梯度下降,TensorBoard来完整搭建一个训练的模型 基本步骤 搭建…...

护目镜佩戴自动识别预警摄像机

护目镜佩戴自动识别预警摄像机是一种智能监测设备,专门用于佩戴护目镜的工人进行作业时,能够自动识别有潜在风险的场景,并及时发出预警信号。该摄像机配备人脸识别和智能预警系统,可以检测危险情况并为工人提供实时安全保护&#…...

keep-alive的使用

Vue中的<keep-alive>组件是前端开发中的一个宝藏功能&#xff0c;它如同时光胶囊般保留组件的状态&#xff0c;让组件在切换时仿佛按下暂停键&#xff0c;再次回来时还能继续播放&#xff0c;极大地优化了用户体验和性能。&#x1f680;✨ 作用 状态保留&#xff1a;当包…...

【Linux】中的常见的重要指令(中)

目录 一、man指令 二、cp指令 三、cat指令 四、mv指令 五、more指令 六、less指令 七、head指令 八、tail指令 一、man指令 Linux的命令有很多参数&#xff0c;我们不可能全记住&#xff0c;我们可以通过查看联机手册获取帮助。访问Linux手册页的命令是 man 语法: m…...

营收净利双降、股东减持,大降价也救不了良品铺子

号称“高端零食第一股”的良品铺子(603719.SH)&#xff0c;正遭遇部分股东的“用脚投票”。 5月17日晚间&#xff0c;良品铺子连发两份减持公告&#xff0c;其控股股东宁波汉意创业投资合伙企业、持股5%以上股东达永有限公司&#xff0c;两者均计划减持。 其中&#xff0c;宁…...

【设计模式】设计模式的分类

通常设计模式的分类有创建型、行为型和结构型。 创建型 常用的有&#xff1a;单例模式、工厂模式&#xff08;工厂方法和抽象工厂&#xff09;、建造者模式。 不常用的有&#xff1a;原型模式。 创建型模式涉及到将对象实例化&#xff0c;这类模式都提供一个方法&#xff0c;将…...

TCP/UDP的连接机制

TCP/UDP的连接机制 TCP的连接机制 TCP&#xff08;Transmission Control Protocol&#xff09;是一种面向连接的协议&#xff0c;提供可靠的、按顺序的数据传输服务。TCP的连接机制包括连接建立、数据传输和连接终止。 1. 连接建立&#xff08;三次握手&#xff09; TCP通过…...

供应链金融模式学习资料

目录 产生背景 供应链金融的诞生 供应链金额的六大特征...

代码随想录-算法训练营day50【动态规划12:最佳买卖股票时机含冷冻期、买卖股票的最佳时机含手续费、股票问题总结】

代码随想录-035期-算法训练营【博客笔记汇总表】-CSDN博客 第九章 动态规划part12● 309.最佳买卖股票时机含冷冻期 ● 714.买卖股票的最佳时机含手续费 ●总结309.最佳买卖股票时机含冷冻期 本题加了一个冷冻期,状态就多了,有点难度,大家要把各个状态分清,思路才能清晰…...

Dilworth 定理

这是一个关于偏序集的定理&#xff0c;事实上它也可以扩展到图论&#xff0c;dp等中&#xff0c;是一个很有意思的东西 偏序集 偏序集是由集合 S S S以及其上的一个偏序关系 R R R定义的&#xff0c;记为 ( S , R ) (S,R) (S,R) 偏序关系&#xff1a; 对于一个二元关系 R ⊂…...

BUUCTF---web---[BJDCTF2020]ZJCTF,不过如此

1、点开连接&#xff0c;页面出现了提示 传入一个参数text&#xff0c;里面的内容要包括I have a dream。 构造&#xff1a;?/textI have a dream。发现页面没有显示。这里推测可能得使用伪协议 在文件包含那一行&#xff0c;我们看到了next.php的提示&#xff0c;我们尝试读取…...

力扣刷题---2206. 将数组划分成相等数对【简单】

题目描述&#x1f357; 给你一个整数数组 nums &#xff0c;它包含 2 * n 个整数。 你需要将 nums 划分成 n 个数对&#xff0c;满足&#xff1a; 每个元素 只属于一个 数对。 同一数对中的元素 相等 。 如果可以将 nums 划分成 n 个数对&#xff0c;请你返回 true &#xf…...

2461. 长度为 K 子数组中的最大和(c++)

给你一个整数数组 nums 和一个整数 k 。请你从 nums 中满足下述条件的全部子数组中找出最大子数组和&#xff1a; 子数组的长度是 k&#xff0c;且子数组中的所有元素 各不相同 。 返回满足题面要求的最大子数组和。如果不存在子数组满足这些条件&#xff0c;返回 0 。 子数…...

range for

1. 基于范围的for循环语法 C11标准引入了基于范围的for循环特性&#xff0c;该特性隐藏了迭代器 的初始化和更新过程&#xff0c;让程序员只需要关心遍历对象本身&#xff0c;其语法也 比传统for循环简洁很多&#xff1a; for ( range_declaration : range_expression ) {loo…...

leetcode230 二叉搜索树中第K小的元素

题目 给定一个二叉搜索树的根节点 root &#xff0c;和一个整数 k &#xff0c;请你设计一个算法查找其中第 k 个最小元素&#xff08;从 1 开始计数&#xff09;。 示例 输入&#xff1a;root [5,3,6,2,4,null,null,1], k 3 输出&#xff1a;3 解析 这道题应该是能做出…...

.Net Core学习笔记 框架特性(注入、配置)

注&#xff1a;直接学习的.Net Core 6&#xff0c;此版本有没有startup.cs相关的内容 项目Program.cs文件中 是定义项目加载 启动的地方 //通过builder对项目进行配置、服务的加载 var builder WebApplication.CreateBuilder(args); builder.Services.AddControllers();//将…...

利用AI技术做电商网赚,这些百万级赛道流量,你还不知道?!

大家好&#xff0c;我是向阳 AI技术的飞速扩展已经势不可挡&#xff0c;不管你承不承认&#xff0c;AI 已经毫无争议的在互联网中占有一席之地了 无论你是做内容产业的&#xff0c;还是做电商的&#xff0c;你现在都躲不开 AI。 现在互联网行业的竞争就是这么残酷 互联网行业…...

leetcode-560 和为k的数组

一、题目描述 给你一个整数数组 nums 和一个整数 k &#xff0c;请你统计并返回 该数组中和为 k 的子数组的个数 。 子数组是数组中元素的连续非空序列。 注意&#xff1a;nums中的元素可为负数 输入&#xff1a;nums [1,1,1], k 2 输出&#xff1a;2输入&#xff1a;num…...

Spring Boot实战指南:从入门到企业级应用构建

目录 一、引言 二、快速入门 1. 使用Spring Initializr创建项目 三、Spring Boot基础概念与自动配置 1. 理解SpringBootApplication注解 2. 自动配置原理 3. 查看自动配置报告 四、Spring Boot核心特性及实战 1. 外部化配置 2. Actuator端点 3. 集成第三方库 五、Sp…...

OneAPI接入本地大模型+FastGPT调用本地大模型

将Ollama下载的本地大模型配置到OneAPI中&#xff0c;并通过FastGPT调用本地大模型完成对话。 OneAPI配置 新建令牌 新建渠道 FastGPT配置 配置docker-compose 配置令牌和OneAPI部署地址 配置config.json 配置调用的渠道名称和大模型名称 {"systemEnv": {&qu…...

浏览器访问 AWS ECS 上部署的 Docker 容器(监听 80 端口)

✅ 一、ECS 服务配置 Dockerfile 确保监听 80 端口 EXPOSE 80 CMD ["nginx", "-g", "daemon off;"]或 EXPOSE 80 CMD ["python3", "-m", "http.server", "80"]任务定义&#xff08;Task Definition&…...

设计模式和设计原则回顾

设计模式和设计原则回顾 23种设计模式是设计原则的完美体现,设计原则设计原则是设计模式的理论基石, 设计模式 在经典的设计模式分类中(如《设计模式:可复用面向对象软件的基础》一书中),总共有23种设计模式,分为三大类: 一、创建型模式(5种) 1. 单例模式(Sing…...

工业安全零事故的智能守护者:一体化AI智能安防平台

前言&#xff1a; 通过AI视觉技术&#xff0c;为船厂提供全面的安全监控解决方案&#xff0c;涵盖交通违规检测、起重机轨道安全、非法入侵检测、盗窃防范、安全规范执行监控等多个方面&#xff0c;能够实现对应负责人反馈机制&#xff0c;并最终实现数据的统计报表。提升船厂…...

Neo4j 集群管理:原理、技术与最佳实践深度解析

Neo4j 的集群技术是其企业级高可用性、可扩展性和容错能力的核心。通过深入分析官方文档,本文将系统阐述其集群管理的核心原理、关键技术、实用技巧和行业最佳实践。 Neo4j 的 Causal Clustering 架构提供了一个强大而灵活的基石,用于构建高可用、可扩展且一致的图数据库服务…...

linux 下常用变更-8

1、删除普通用户 查询用户初始UID和GIDls -l /home/ ###家目录中查看UID cat /etc/group ###此文件查看GID删除用户1.编辑文件 /etc/passwd 找到对应的行&#xff0c;YW343:x:0:0::/home/YW343:/bin/bash 2.将标红的位置修改为用户对应初始UID和GID&#xff1a; YW3…...

HTML前端开发:JavaScript 常用事件详解

作为前端开发的核心&#xff0c;JavaScript 事件是用户与网页交互的基础。以下是常见事件的详细说明和用法示例&#xff1a; 1. onclick - 点击事件 当元素被单击时触发&#xff08;左键点击&#xff09; button.onclick function() {alert("按钮被点击了&#xff01;&…...

MySQL中【正则表达式】用法

MySQL 中正则表达式通过 REGEXP 或 RLIKE 操作符实现&#xff08;两者等价&#xff09;&#xff0c;用于在 WHERE 子句中进行复杂的字符串模式匹配。以下是核心用法和示例&#xff1a; 一、基础语法 SELECT column_name FROM table_name WHERE column_name REGEXP pattern; …...

tree 树组件大数据卡顿问题优化

问题背景 项目中有用到树组件用来做文件目录&#xff0c;但是由于这个树组件的节点越来越多&#xff0c;导致页面在滚动这个树组件的时候浏览器就很容易卡死。这种问题基本上都是因为dom节点太多&#xff0c;导致的浏览器卡顿&#xff0c;这里很明显就需要用到虚拟列表的技术&…...

智能分布式爬虫的数据处理流水线优化:基于深度强化学习的数据质量控制

在数字化浪潮席卷全球的今天&#xff0c;数据已成为企业和研究机构的核心资产。智能分布式爬虫作为高效的数据采集工具&#xff0c;在大规模数据获取中发挥着关键作用。然而&#xff0c;传统的数据处理流水线在面对复杂多变的网络环境和海量异构数据时&#xff0c;常出现数据质…...

laravel8+vue3.0+element-plus搭建方法

创建 laravel8 项目 composer create-project --prefer-dist laravel/laravel laravel8 8.* 安装 laravel/ui composer require laravel/ui 修改 package.json 文件 "devDependencies": {"vue/compiler-sfc": "^3.0.7","axios": …...