PyTorch入门学习(十七):完整的模型训练套路
目录
一、构建神经网络
二、数据准备
三、损失函数和优化器
四、训练模型
五、保存模型
一、构建神经网络
首先,需要构建一个神经网络模型。在示例代码中,构建了一个名为Tudui的卷积神经网络(CNN)模型。这个模型包括卷积层、池化层和全连接层,用于处理图像分类任务。
class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init()self.mode1 = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self, x):x = self.mode1(x)return x
二、数据准备
训练深度学习模型需要数据集。在示例中,使用CIFAR-10数据集作为示例数据。数据集的准备包括下载、预处理和分割成训练集和测试集。
import torch
import torchvision
from torch.utils.data import DataLoader# 准备数据集
train_data = torchvision.datasets.CIFAR10(root="D:\\Python_Project\\pytorch\\dataset2", train=True, transform=torchvision.transforms.ToTensor(), download=True)
test_data = torchvision.datasets.CIFAR10(root="D:\\Python_Project\\pytorch\\dataset2", train=False, transform=torchvision.transforms.ToTensor(), download=True)# 创建数据加载器
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)train_data_size = len(train_data)
test_data_size = len(test_data)
三、损失函数和优化器
在训练中,需要定义损失函数和优化器。损失函数用于度量模型的输出与真实标签之间的差距,而优化器用于更新模型的参数以减小损失。
loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)
四、训练模型
模型训练分为多轮迭代,每轮包括训练和测试步骤。在训练步骤中,通过反向传播算法更新模型参数,以最小化损失函数。在测试步骤中,用测试集验证模型性能。
for epoch in range(10): # 训练的轮数tudui.train()for data in train_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()tudui.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)total_test_loss += loss.item()accuracy = (outputs.argmax(1) == targets).sum()total_accuracy += accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))print("整体测试集上的正确率:{}".format(total_accuracy / test_data_size))
五、保存模型
最后,可以保存训练好的模型,以备后续使用。示例代码展示了两种保存模型的方式,包括保存整个模型和仅保存模型参数。
# 保存方式一
torch.save(tudui, "tudui_{}.pth".format(epoch))
# 保存方式二(官方推荐)
# torch.save(tudui.state_dict(), 'tudui_{}.pth'.format(epoch))
完整代码如下:
import torch
from torch import nn# 搭建神经网络
class Tudui(nn.Module):def __init__(self):super(Tudui,self).__init__()self.mode1 = nn.Sequential(nn.Conv2d(3,32,5,1,2),nn.MaxPool2d(2),nn.Conv2d(32,32,5,1,2),nn.MaxPool2d(2),nn.Conv2d(32,64,5,1,2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64*4*4,64),nn.Linear(64,10))def forward(self, x):x = self.mode1(x)return xif __name__ == '__main__':tudui = Tudui()input = torch.ones((64,3,32,32))output = tudui(input)print(output.shape)
import torch
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from P27_model import *
import time# 准备数据集
train_data = torchvision.datasets.CIFAR10(root="D:\\Python_Project\\pytorch\\dataset2",train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="D:\\Python_Project\\pytorch\\dataset2",train=False,transform=torchvision.transforms.ToTensor(),download=True)# length 长度
train_data_size = len(train_data)
test_data_size = len(test_data)# 如果train_data_size=10,训练数据集的长度为:10
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))# 利用DataLoader 来加载数据集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)# 创建网络模型
tudui = Tudui()# 损失函数
loss_fn = nn.CrossEntropyLoss()# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(),lr=learning_rate)# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0# 训练的轮数
epoch = 10# 添加tensorboard
writer = SummaryWriter("logs_train")
# 添加开始时间
strat_time = time.time()for i in range(epoch):print("----------第{}轮训练开始----------".format(i+1))# 训练步骤开始tudui.train() # 这两个层,只对一部分层起作用,比如 dropout层;如果有这些特殊的层,才需要调用这个语句for data in train_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)# 优化器优化模型optimizer.zero_grad() # 优化器,梯度清零loss.backward()optimizer.step()total_train_step = total_train_step + 1if total_train_step % 100 == 0:end_time = time.time() # 结束时间print(end_time - strat_time)print("训练次数:{}, Loss:{}".format(total_train_step, loss.item())) # 这里用到的 item()方法,有说法的,其实加不加都行,就是输出的形式不一样而已writer.add_scalar("train_loss", loss.item(),total_train_step)# 每训练完一轮,进行测试,在测试集上测试,以测试集的损失或者正确率,来评估有没有训练好,测试时,就不要调优了,就是以当前的模型,进行测试,所以不用再使用梯度(with no_grad 那句)# 测试步骤开始tudui.eval() # 这两个层,只对一部分层起作用,比如 dropout层;如果有这些特殊的层,才需要调用这个语句total_test_loss = 0total_accuracy = 0with torch.no_grad(): # 这样后面就没有梯度了, 测试的过程中,不需要更新参数,所以不需要梯度?for data in test_dataloader: # 在测试集中,选取数据imgs, targets = dataoutputs = tudui(imgs) # 分类的问题,是可以这样的,用一个output进行绘制loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + loss.item() # 为了查看总体数据上的 loss,创建的 total_test_loss,初始值是0accuracy = (outputs.argmax(1) == targets).sum() # 正确率,这是分类问题中,特有的一种,评价指标,语义分割之类的,不一定非要有这个东西,这里是存疑的,再看。total_accuracy = total_accuracy + accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))print("整体测试集上的正确率:{}".format(total_accuracy / test_data_size)) # 即便是输出了上一行的 loss,也不能很好的表现出效果。# 在分类问题上比较特有,通常使用正确率来表示优劣。因为其他问题,可以可视化地显示在tensorboard中。writer.add_scalar("test_loss", total_test_loss, total_test_step)writer.add_scalar("test_accuracy", total_accuracy / test_data_size, total_test_step)total_test_step = total_test_step + 1# print(total_test_step)# 保存方式一,其实后缀都可以自己取,习惯用 .pth。torch.save(tudui, "tudui_{}.pth".format(i))# 保存方式2(官方推荐)# torch.save(model.state_dict(), pth_dir + '/model_{}.pth'.format(i)print("模型已保存")writer.close()
参考资料:
视频教程:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】
相关文章:
PyTorch入门学习(十七):完整的模型训练套路
目录 一、构建神经网络 二、数据准备 三、损失函数和优化器 四、训练模型 五、保存模型 一、构建神经网络 首先,需要构建一个神经网络模型。在示例代码中,构建了一个名为Tudui的卷积神经网络(CNN)模型。这个模型包括卷积层、…...
《 Hello 算法 》 - 免费开源的数据结构与算法入门教程电子书,包含大量动画、图解,通俗易懂
这本学习算法的电子书应该是我看过这方面最好的书了,代码例子有多种编程语言,JavaScript 也支持。 《 Hello 算法 》,英文名称是 Hello algo,是一本关于编程中数据解构和算法入门的电子书,作者是毕业于上海交通大学的…...
数据库之事务
数据库之事务 事务的特点: ACID 原子性 一致性:数据库的完整性约束,不能被破坏 隔离性 持久性:数据一旦提交,事务的效果将会被永久的保留在数据库中。而且不会被回滚 主从复制 高可用 备份 权限控制 脏读&am…...
NOIP2023模拟12联测33 B. 游戏
NOIP2023模拟12联测33 B. 游戏 文章目录 NOIP2023模拟12联测33 B. 游戏题目大意思路code 题目大意 期望题 思路 二分答案 m i d mid mid ,我们只关注学生是否能够使得被抓的人数 ≤ m i d \le mid ≤mid 那我们就只关心 a > m i d a > mid a>mid 的房…...
代码随想录打卡第六十三天|84.柱状图中最大的矩形
84.柱状图中最大的矩形 题目:给定 n 个非负整数,用来表示柱状图中各个柱子的高度。每个柱子彼此相邻,且宽度为 1 。求在该柱状图中,能够勾勒出来的矩形的最大面积。 提示: 1 < heights.length <105 0 < h…...
python tempfile 模块使用
在Python中,tempfile 模块用于创建临时文件和目录,它们可以用于存储中间处理数据,不需要长期保存。该模块提供了几种不同的类和函数来创建临时文件和目录。 下面是几个常用的 tempfile 使用方法: 临时文件 使用 NamedTemporary…...
【软件测试】接口测试实战详解
最近找到了几个问题,都还比较有代表性。 作为一个初级测试,想学接口测试,但是一点头绪都没有。求教大神指点,有没有好的书或者工具推荐?如何做接口测试呢?接口测试有哪些工具做接口测试的流程一般是怎么样…...
轻量封装WebGPU渲染系统示例<20>- 美化一下元胞自动机之生命游戏(源码)
当前示例源码github地址: https://github.com/vilyLei/voxwebgpu/blob/feature/rendering/src/voxgpu/sample/GameOfLifePretty.ts 系统特性: 1. 用户态与系统态隔离。 2. 高频调用与低频调用隔离。 3. 面向用户的易用性封装。 4. 渲染数据(内外部相关资源)和渲染机制分离…...
Nodejs的安装以及配置(node-v12.16.1-x64.msi)
Nodejs的安装以及配置 1、安装 node-v12.16.1-x64.msi点击安装,注意以下步骤 本文设置nodejs的安装的路径:D:\soft\nodejs 继续点击next,选中Add to PATH ,旁边的英文告诉我们会把 环境变量 给我们配置好 当然也可以只选择 Nod…...
03【保姆级】-GO语言变量和数据类型和相互转换
03【保姆级】-GO语言变量和数据类型 一、变量1.1 变量的定义:1.2 变量的声明、初始化、赋值1.3 变量使用的注意事项 插播-关于fmt.Printf格式打印%的作用二、 变量的数据类型2.1整数的基本类型2.1.1 有符号类型 int8/16/32/642.1.2 无符号类型 int8/16/32/642.1.3 整…...
mermaid学习第一天/更改主题颜色和边框颜色/《需求解释流程图》
mermaid 在线官网: https://mermaid-js.github.io/ 在线学习文件: https://mermaid.js.org/syntax/quadrantChart.html 1、今天主要是想做需求解释的流程图,又不想自己画,就用了,框框不能直接进行全局配置࿰…...
SAP MASS增加PR字段-删除标识
MASS->BUS2105->发现没有找到PR删除标识字段 SAP MASS增加PR字段-删除标识 1.tcode:MASSOBJ 选中BUS2105 点“应用程序表” 点“字段列表” 2.选中一行进行参考 3.修改字段为删除标识 LOEKZ,保存即可。 4.然后MASS操作,批量设置删除标识&…...
【手把手教你】训练YOLOv8分割模型
1.下载文件 在github上下载YOLOV8模型的文件,搜索yolov8,star最多这个就是 2. 准备环境 环境要求python>3.8,PyTorch>1.8,自行安装ptyorch环境即可 2. 制作数据集 制作数据集,需要使用labelme这个包&#…...
物料主数据增强屏幕绘制器DUMP
问题描述 在做完物料主数据增强后,配置和代码传Q,在Q进入增强屏幕绘制器报错。 错误 CALLBACK_REJECTED_BY_WHITELIST RFC callback call rejected by positive list An RFC callback has been prevented due to no corresponding positive list en…...
vue 实现在线预览Excel-LuckyExcel/LuckySheet实现方案
一、准备工作 1. npm安装 luckyexcel npm i -D luckyexcel 2.引入luckysheet 注意:引入luckysheet,只能通过CDN或者直接引入静态资源的形式,不能npm install。 个人建议直接下载资源引入。我给你们提供一个下载资源的地址: …...
AIGPT重大升级,界面重新设计,功能更加饱满,用户体验升级
AIGPT AIGPT是一款功能强大的人工智能技术处理软件,不但拥有其他模型处理文本认知的能力还有AI绘画模型、拥有自身的插件库。 我们都知道使用ChatGPT是需要账号以及使用魔法的,实现其中的某一项对我们一般的初学者来说都是一次巨大的挑战,但…...
Web逆向-mtgsig1.2简单分析
{"a1": "1.2", # 加密版本"a2": new Date().valueOf() - serverTimeDiff, # 加密过程中用到的时间戳. 这次服主变坏了, 时间戳需要减去一个 serverTimeDiff(见a3) ! "a3": "这是把xxx信息加密后提交给服务器, 服主…...
【蓝桥杯省赛真题41】Scratch电脑开关机 蓝桥杯少儿编程scratch图形化编程 蓝桥杯省赛真题讲解
目录 scratch电脑开关机 一、题目要求 编程实现 二、案例分析 1、角色分析...
第10章 Java常用类
目录 内容说明 章节内容 一、Object类 二、String类和StringBuffer类 三、Math类和Random类...
Android 11 getPackageManager().getPackageInfo 返回null
Android11 之后, 在查找用户手机是否有安装app,进行查询包名是否存在时,getPackageManager().getPackageInfo()这个函数一直返回null ,Android 11增加了权限要求。 1、只是查询指定的App 包 只需要在Andro…...
铭豹扩展坞 USB转网口 突然无法识别解决方法
当 USB 转网口扩展坞在一台笔记本上无法识别,但在其他电脑上正常工作时,问题通常出在笔记本自身或其与扩展坞的兼容性上。以下是系统化的定位思路和排查步骤,帮助你快速找到故障原因: 背景: 一个M-pard(铭豹)扩展坞的网卡突然无法识别了,扩展出来的三个USB接口正常。…...
OpenLayers 可视化之热力图
注:当前使用的是 ol 5.3.0 版本,天地图使用的key请到天地图官网申请,并替换为自己的key 热力图(Heatmap)又叫热点图,是一种通过特殊高亮显示事物密度分布、变化趋势的数据可视化技术。采用颜色的深浅来显示…...
树莓派超全系列教程文档--(61)树莓派摄像头高级使用方法
树莓派摄像头高级使用方法 配置通过调谐文件来调整相机行为 使用多个摄像头安装 libcam 和 rpicam-apps依赖关系开发包 文章来源: http://raspberry.dns8844.cn/documentation 原文网址 配置 大多数用例自动工作,无需更改相机配置。但是,一…...
深入理解JavaScript设计模式之单例模式
目录 什么是单例模式为什么需要单例模式常见应用场景包括 单例模式实现透明单例模式实现不透明单例模式用代理实现单例模式javaScript中的单例模式使用命名空间使用闭包封装私有变量 惰性单例通用的惰性单例 结语 什么是单例模式 单例模式(Singleton Pattern&#…...
sqlserver 根据指定字符 解析拼接字符串
DECLARE LotNo NVARCHAR(50)A,B,C DECLARE xml XML ( SELECT <x> REPLACE(LotNo, ,, </x><x>) </x> ) DECLARE ErrorCode NVARCHAR(50) -- 提取 XML 中的值 SELECT value x.value(., VARCHAR(MAX))…...
论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(一)
宇树机器人多姿态起立控制强化学习框架论文解析 论文解读:交大&港大&上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(一) 论文解读:交大&港大&上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化…...
(转)什么是DockerCompose?它有什么作用?
一、什么是DockerCompose? DockerCompose可以基于Compose文件帮我们快速的部署分布式应用,而无需手动一个个创建和运行容器。 Compose文件是一个文本文件,通过指令定义集群中的每个容器如何运行。 DockerCompose就是把DockerFile转换成指令去运行。 …...
【C++从零实现Json-Rpc框架】第六弹 —— 服务端模块划分
一、项目背景回顾 前五弹完成了Json-Rpc协议解析、请求处理、客户端调用等基础模块搭建。 本弹重点聚焦于服务端的模块划分与架构设计,提升代码结构的可维护性与扩展性。 二、服务端模块设计目标 高内聚低耦合:各模块职责清晰,便于独立开发…...
Fabric V2.5 通用溯源系统——增加图片上传与下载功能
fabric-trace项目在发布一年后,部署量已突破1000次,为支持更多场景,现新增支持图片信息上链,本文对图片上传、下载功能代码进行梳理,包含智能合约、后端、前端部分。 一、智能合约修改 为了增加图片信息上链溯源,需要对底层数据结构进行修改,在此对智能合约中的农产品数…...
华为OD机考-机房布局
import java.util.*;public class DemoTest5 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseSystem.out.println(solve(in.nextLine()));}}priv…...
