当前位置: 首页 > news >正文

PyTorch入门学习(十七):完整的模型训练套路

目录

一、构建神经网络

二、数据准备

三、损失函数和优化器

四、训练模型

五、保存模型


一、构建神经网络

首先,需要构建一个神经网络模型。在示例代码中,构建了一个名为Tudui的卷积神经网络(CNN)模型。这个模型包括卷积层、池化层和全连接层,用于处理图像分类任务。

class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init()self.mode1 = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self, x):x = self.mode1(x)return x

二、数据准备

训练深度学习模型需要数据集。在示例中,使用CIFAR-10数据集作为示例数据。数据集的准备包括下载、预处理和分割成训练集和测试集。

import torch
import torchvision
from torch.utils.data import DataLoader# 准备数据集
train_data = torchvision.datasets.CIFAR10(root="D:\\Python_Project\\pytorch\\dataset2", train=True, transform=torchvision.transforms.ToTensor(), download=True)
test_data = torchvision.datasets.CIFAR10(root="D:\\Python_Project\\pytorch\\dataset2", train=False, transform=torchvision.transforms.ToTensor(), download=True)# 创建数据加载器
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)train_data_size = len(train_data)
test_data_size = len(test_data)

三、损失函数和优化器

在训练中,需要定义损失函数和优化器。损失函数用于度量模型的输出与真实标签之间的差距,而优化器用于更新模型的参数以减小损失。

loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)

四、训练模型

模型训练分为多轮迭代,每轮包括训练和测试步骤。在训练步骤中,通过反向传播算法更新模型参数,以最小化损失函数。在测试步骤中,用测试集验证模型性能。

for epoch in range(10):  # 训练的轮数tudui.train()for data in train_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()tudui.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)total_test_loss += loss.item()accuracy = (outputs.argmax(1) == targets).sum()total_accuracy += accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))print("整体测试集上的正确率:{}".format(total_accuracy / test_data_size))

五、保存模型

最后,可以保存训练好的模型,以备后续使用。示例代码展示了两种保存模型的方式,包括保存整个模型和仅保存模型参数。

# 保存方式一
torch.save(tudui, "tudui_{}.pth".format(epoch))
# 保存方式二(官方推荐)
# torch.save(tudui.state_dict(), 'tudui_{}.pth'.format(epoch))

完整代码如下:

import torch
from torch import nn# 搭建神经网络
class Tudui(nn.Module):def __init__(self):super(Tudui,self).__init__()self.mode1 = nn.Sequential(nn.Conv2d(3,32,5,1,2),nn.MaxPool2d(2),nn.Conv2d(32,32,5,1,2),nn.MaxPool2d(2),nn.Conv2d(32,64,5,1,2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64*4*4,64),nn.Linear(64,10))def forward(self, x):x = self.mode1(x)return xif __name__ == '__main__':tudui = Tudui()input = torch.ones((64,3,32,32))output = tudui(input)print(output.shape)
import torch
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from P27_model import *
import time# 准备数据集
train_data = torchvision.datasets.CIFAR10(root="D:\\Python_Project\\pytorch\\dataset2",train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="D:\\Python_Project\\pytorch\\dataset2",train=False,transform=torchvision.transforms.ToTensor(),download=True)# length 长度
train_data_size = len(train_data)
test_data_size = len(test_data)# 如果train_data_size=10,训练数据集的长度为:10
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))# 利用DataLoader 来加载数据集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)# 创建网络模型
tudui = Tudui()# 损失函数
loss_fn = nn.CrossEntropyLoss()# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(),lr=learning_rate)# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0# 训练的轮数
epoch = 10# 添加tensorboard
writer = SummaryWriter("logs_train")
# 添加开始时间
strat_time = time.time()for i in range(epoch):print("----------第{}轮训练开始----------".format(i+1))# 训练步骤开始tudui.train()  # 这两个层,只对一部分层起作用,比如 dropout层;如果有这些特殊的层,才需要调用这个语句for data in train_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)# 优化器优化模型optimizer.zero_grad() # 优化器,梯度清零loss.backward()optimizer.step()total_train_step = total_train_step + 1if total_train_step % 100 == 0:end_time = time.time()  # 结束时间print(end_time - strat_time)print("训练次数:{}, Loss:{}".format(total_train_step, loss.item()))   # 这里用到的 item()方法,有说法的,其实加不加都行,就是输出的形式不一样而已writer.add_scalar("train_loss", loss.item(),total_train_step)# 每训练完一轮,进行测试,在测试集上测试,以测试集的损失或者正确率,来评估有没有训练好,测试时,就不要调优了,就是以当前的模型,进行测试,所以不用再使用梯度(with no_grad 那句)# 测试步骤开始tudui.eval()  # 这两个层,只对一部分层起作用,比如 dropout层;如果有这些特殊的层,才需要调用这个语句total_test_loss = 0total_accuracy = 0with torch.no_grad():     # 这样后面就没有梯度了,  测试的过程中,不需要更新参数,所以不需要梯度?for data in test_dataloader: # 在测试集中,选取数据imgs, targets = dataoutputs = tudui(imgs)   # 分类的问题,是可以这样的,用一个output进行绘制loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + loss.item()     # 为了查看总体数据上的 loss,创建的 total_test_loss,初始值是0accuracy = (outputs.argmax(1) == targets).sum()  # 正确率,这是分类问题中,特有的一种,评价指标,语义分割之类的,不一定非要有这个东西,这里是存疑的,再看。total_accuracy = total_accuracy + accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))print("整体测试集上的正确率:{}".format(total_accuracy / test_data_size))   # 即便是输出了上一行的 loss,也不能很好的表现出效果。# 在分类问题上比较特有,通常使用正确率来表示优劣。因为其他问题,可以可视化地显示在tensorboard中。writer.add_scalar("test_loss", total_test_loss, total_test_step)writer.add_scalar("test_accuracy", total_accuracy / test_data_size, total_test_step)total_test_step = total_test_step + 1# print(total_test_step)# 保存方式一,其实后缀都可以自己取,习惯用 .pth。torch.save(tudui, "tudui_{}.pth".format(i))# 保存方式2(官方推荐)# torch.save(model.state_dict(), pth_dir + '/model_{}.pth'.format(i)print("模型已保存")writer.close()

参考资料:

视频教程:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

相关文章:

PyTorch入门学习(十七):完整的模型训练套路

目录 一、构建神经网络 二、数据准备 三、损失函数和优化器 四、训练模型 五、保存模型 一、构建神经网络 首先,需要构建一个神经网络模型。在示例代码中,构建了一个名为Tudui的卷积神经网络(CNN)模型。这个模型包括卷积层、…...

《 Hello 算法 》 - 免费开源的数据结构与算法入门教程电子书,包含大量动画、图解,通俗易懂

这本学习算法的电子书应该是我看过这方面最好的书了,代码例子有多种编程语言,JavaScript 也支持。 《 Hello 算法 》,英文名称是 Hello algo,是一本关于编程中数据解构和算法入门的电子书,作者是毕业于上海交通大学的…...

数据库之事务

数据库之事务 事务的特点: ACID 原子性 一致性:数据库的完整性约束,不能被破坏 隔离性 持久性:数据一旦提交,事务的效果将会被永久的保留在数据库中。而且不会被回滚 主从复制 高可用 备份 权限控制 脏读&am…...

NOIP2023模拟12联测33 B. 游戏

NOIP2023模拟12联测33 B. 游戏 文章目录 NOIP2023模拟12联测33 B. 游戏题目大意思路code 题目大意 期望题 思路 二分答案 m i d mid mid ,我们只关注学生是否能够使得被抓的人数 ≤ m i d \le mid ≤mid 那我们就只关心 a > m i d a > mid a>mid 的房…...

代码随想录打卡第六十三天|84.柱状图中最大的矩形

84.柱状图中最大的矩形 题目&#xff1a;给定 n 个非负整数&#xff0c;用来表示柱状图中各个柱子的高度。每个柱子彼此相邻&#xff0c;且宽度为 1 。求在该柱状图中&#xff0c;能够勾勒出来的矩形的最大面积。 提示&#xff1a; 1 < heights.length <105 0 < h…...

python tempfile 模块使用

在Python中&#xff0c;tempfile 模块用于创建临时文件和目录&#xff0c;它们可以用于存储中间处理数据&#xff0c;不需要长期保存。该模块提供了几种不同的类和函数来创建临时文件和目录。 下面是几个常用的 tempfile 使用方法&#xff1a; 临时文件 使用 NamedTemporary…...

【软件测试】接口测试实战详解

最近找到了几个问题&#xff0c;都还比较有代表性。 作为一个初级测试&#xff0c;想学接口测试&#xff0c;但是一点头绪都没有。求教大神指点&#xff0c;有没有好的书或者工具推荐&#xff1f;如何做接口测试呢&#xff1f;接口测试有哪些工具做接口测试的流程一般是怎么样…...

轻量封装WebGPU渲染系统示例<20>- 美化一下元胞自动机之生命游戏(源码)

当前示例源码github地址: https://github.com/vilyLei/voxwebgpu/blob/feature/rendering/src/voxgpu/sample/GameOfLifePretty.ts 系统特性: 1. 用户态与系统态隔离。 2. 高频调用与低频调用隔离。 3. 面向用户的易用性封装。 4. 渲染数据(内外部相关资源)和渲染机制分离…...

Nodejs的安装以及配置(node-v12.16.1-x64.msi)

Nodejs的安装以及配置 1、安装 node-v12.16.1-x64.msi点击安装&#xff0c;注意以下步骤 本文设置nodejs的安装的路径&#xff1a;D:\soft\nodejs 继续点击next&#xff0c;选中Add to PATH &#xff0c;旁边的英文告诉我们会把 环境变量 给我们配置好 当然也可以只选择 Nod…...

03【保姆级】-GO语言变量和数据类型和相互转换

03【保姆级】-GO语言变量和数据类型 一、变量1.1 变量的定义&#xff1a;1.2 变量的声明、初始化、赋值1.3 变量使用的注意事项 插播-关于fmt.Printf格式打印%的作用二、 变量的数据类型2.1整数的基本类型2.1.1 有符号类型 int8/16/32/642.1.2 无符号类型 int8/16/32/642.1.3 整…...

mermaid学习第一天/更改主题颜色和边框颜色/《需求解释流程图》

mermaid 在线官网&#xff1a; https://mermaid-js.github.io/ 在线学习文件&#xff1a; https://mermaid.js.org/syntax/quadrantChart.html 1、今天主要是想做需求解释的流程图&#xff0c;又不想自己画&#xff0c;就用了&#xff0c;框框不能直接进行全局配置&#xff0…...

SAP MASS增加PR字段-删除标识

MASS->BUS2105->发现没有找到PR删除标识字段 SAP MASS增加PR字段-删除标识 1.tcode:MASSOBJ 选中BUS2105 点“应用程序表” 点“字段列表” 2.选中一行进行参考 3.修改字段为删除标识 LOEKZ&#xff0c;保存即可。 4.然后MASS操作&#xff0c;批量设置删除标识&…...

【手把手教你】训练YOLOv8分割模型

1.下载文件 在github上下载YOLOV8模型的文件&#xff0c;搜索yolov8&#xff0c;star最多这个就是 2. 准备环境 环境要求python>3.8&#xff0c;PyTorch>1.8&#xff0c;自行安装ptyorch环境即可 2. 制作数据集 制作数据集&#xff0c;需要使用labelme这个包&#…...

物料主数据增强屏幕绘制器DUMP

问题描述 在做完物料主数据增强后&#xff0c;配置和代码传Q&#xff0c;在Q进入增强屏幕绘制器报错。 错误 CALLBACK_REJECTED_BY_WHITELIST RFC callback call rejected by positive list An RFC callback has been prevented due to no corresponding positive list en…...

vue 实现在线预览Excel-LuckyExcel/LuckySheet实现方案

一、准备工作 1. npm安装 luckyexcel npm i -D luckyexcel 2.引入luckysheet 注意&#xff1a;引入luckysheet&#xff0c;只能通过CDN或者直接引入静态资源的形式&#xff0c;不能npm install。 个人建议直接下载资源引入。我给你们提供一个下载资源的地址&#xff1a; …...

AIGPT重大升级,界面重新设计,功能更加饱满,用户体验升级

AIGPT AIGPT是一款功能强大的人工智能技术处理软件&#xff0c;不但拥有其他模型处理文本认知的能力还有AI绘画模型、拥有自身的插件库。 我们都知道使用ChatGPT是需要账号以及使用魔法的&#xff0c;实现其中的某一项对我们一般的初学者来说都是一次巨大的挑战&#xff0c;但…...

Web逆向-mtgsig1.2简单分析

{"a1": "1.2", # 加密版本"a2": new Date().valueOf() - serverTimeDiff, # 加密过程中用到的时间戳. 这次服主变坏了, 时间戳需要减去一个 serverTimeDiff(见a3) ! "a3": "这是把xxx信息加密后提交给服务器, 服主…...

【蓝桥杯省赛真题41】Scratch电脑开关机 蓝桥杯少儿编程scratch图形化编程 蓝桥杯省赛真题讲解

目录 scratch电脑开关机 一、题目要求 编程实现 二、案例分析 1、角色分析...

第10章 Java常用类

目录 内容说明 章节内容 一、Object类 二、String类和StringBuffer类 三、Math类和Random类...

Android 11 getPackageManager().getPackageInfo 返回null

Android11 之后&#xff0c; 在查找用户手机是否有安装app&#xff0c;进行查询包名是否存在时&#xff0c;getPackageManager().getPackageInfo&#xff08;&#xff09;这个函数一直返回null &#xff0c;Android 11增加了权限要求。 1、只是查询指定的App 包 只需要在Andro…...

利用ngx_stream_return_module构建简易 TCP/UDP 响应网关

一、模块概述 ngx_stream_return_module 提供了一个极简的指令&#xff1a; return <value>;在收到客户端连接后&#xff0c;立即将 <value> 写回并关闭连接。<value> 支持内嵌文本和内置变量&#xff08;如 $time_iso8601、$remote_addr 等&#xff09;&a…...

UDP(Echoserver)

网络命令 Ping 命令 检测网络是否连通 使用方法: ping -c 次数 网址ping -c 3 www.baidu.comnetstat 命令 netstat 是一个用来查看网络状态的重要工具. 语法&#xff1a;netstat [选项] 功能&#xff1a;查看网络状态 常用选项&#xff1a; n 拒绝显示别名&#…...

Keil 中设置 STM32 Flash 和 RAM 地址详解

文章目录 Keil 中设置 STM32 Flash 和 RAM 地址详解一、Flash 和 RAM 配置界面(Target 选项卡)1. IROM1(用于配置 Flash)2. IRAM1(用于配置 RAM)二、链接器设置界面(Linker 选项卡)1. 勾选“Use Memory Layout from Target Dialog”2. 查看链接器参数(如果没有勾选上面…...

网络编程(UDP编程)

思维导图 UDP基础编程&#xff08;单播&#xff09; 1.流程图 服务器&#xff1a;短信的接收方 创建套接字 (socket)-----------------------------------------》有手机指定网络信息-----------------------------------------------》有号码绑定套接字 (bind)--------------…...

技术栈RabbitMq的介绍和使用

目录 1. 什么是消息队列&#xff1f;2. 消息队列的优点3. RabbitMQ 消息队列概述4. RabbitMQ 安装5. Exchange 四种类型5.1 direct 精准匹配5.2 fanout 广播5.3 topic 正则匹配 6. RabbitMQ 队列模式6.1 简单队列模式6.2 工作队列模式6.3 发布/订阅模式6.4 路由模式6.5 主题模式…...

【笔记】WSL 中 Rust 安装与测试完整记录

#工作记录 WSL 中 Rust 安装与测试完整记录 1. 运行环境 系统&#xff1a;Ubuntu 24.04 LTS (WSL2)架构&#xff1a;x86_64 (GNU/Linux)Rust 版本&#xff1a;rustc 1.87.0 (2025-05-09)Cargo 版本&#xff1a;cargo 1.87.0 (2025-05-06) 2. 安装 Rust 2.1 使用 Rust 官方安…...

Netty从入门到进阶(二)

二、Netty入门 1. 概述 1.1 Netty是什么 Netty is an asynchronous event-driven network application framework for rapid development of maintainable high performance protocol servers & clients. Netty是一个异步的、基于事件驱动的网络应用框架&#xff0c;用于…...

腾讯云V3签名

想要接入腾讯云的Api&#xff0c;必然先按其文档计算出所要求的签名。 之前也调用过腾讯云的接口&#xff0c;但总是卡在签名这一步&#xff0c;最后放弃选择SDK&#xff0c;这次终于自己代码实现。 可能腾讯云翻新了接口文档&#xff0c;现在阅读起来&#xff0c;清晰了很多&…...

【C++进阶篇】智能指针

C内存管理终极指南&#xff1a;智能指针从入门到源码剖析 一. 智能指针1.1 auto_ptr1.2 unique_ptr1.3 shared_ptr1.4 make_shared 二. 原理三. shared_ptr循环引用问题三. 线程安全问题四. 内存泄漏4.1 什么是内存泄漏4.2 危害4.3 避免内存泄漏 五. 最后 一. 智能指针 智能指…...

CRMEB 中 PHP 短信扩展开发:涵盖一号通、阿里云、腾讯云、创蓝

目前已有一号通短信、阿里云短信、腾讯云短信扩展 扩展入口文件 文件目录 crmeb\services\sms\Sms.php 默认驱动类型为&#xff1a;一号通 namespace crmeb\services\sms;use crmeb\basic\BaseManager; use crmeb\services\AccessTokenServeService; use crmeb\services\sms\…...