当前位置: 首页 > news >正文

学习pytorch18 pytorch完整的模型训练流程

pytorch完整的模型训练流程

  • 1. 流程
    • 1. 整理训练数据 使用CIFAR10数据集
    • 2. 搭建网络结构
    • 3. 构建损失函数
    • 4. 使用优化器
    • 5. 训练模型
    • 6. 测试数据 计算模型预测正确率
    • 7. 保存模型
  • 2. 代码
    • 1. model.py
    • 2. train.py
  • 3. 结果
    • tensorboard结果
      • 以下图片 颜色较浅的线是真实计算的值,颜色较深的线是做了平滑处理的值
      • 训练loss
      • 测试loss
      • 测试集正确率
  • 4. 需要注意的细节

1. 流程

1. 整理训练数据 使用CIFAR10数据集

train_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),download=True)

2. 搭建网络结构

在这里插入图片描述
model.py

3. 构建损失函数

loss_fn = nn.CrossEntropyLoss()

4. 使用优化器

learing_rate = 1e-2 # 0.01
optimizer = torch.optim.SGD(net.parameters(), lr=learing_rate)

5. 训练模型

output = net(imgs)    # 数据输入模型
loss = loss_fn(output, targets)  # 损失函数计算损失 看计算的输出和真实的标签误差是多少
# 优化器开始优化模型  1.梯度清零  2.反向传播  3.参数优化
optimizer.zero_grad()  # 利用优化器把梯度清零 全部设置为0
loss.backward()        # 设置计算的损失值的钩子,调用损失的反向传播,计算每个参数结点的参数
optimizer.step()       # 调用优化器的step()方法 对其中的参数进行优化  

6. 测试数据 计算模型预测正确率

output = net(imags)
# 计算测试集的正确率
preds = (output.argmax(1)==targets).sum()
accuracy += preds 
rate = accuracy/len(test_data)

调用模型输出tensor 数据类型的 argmax方法, argmax或获取一行或者一列数值中最大数值的下标位置,argmax(0) 是从列的维度取一列数值的最大值的下标,argmax(1) 是从行的维度取一行数值的最大值的下标
output.argmax(1)==targets 会输出如下图最后一行 [false, ture], 对应位置相同则为true,对应位置不同则为false;
调用sum()方法,计算求和,false值为0,true值为1.
最后计算得出测试集整体正确率: rate = accuracy/len(test_data)
在这里插入图片描述

7. 保存模型

torch.save(net, './net_epoch{}.pth'.format(i))

2. 代码

1. model.py

import torch
from torch import nn# 2. 搭建模型网络结构--神经网络
class Cifar10Net(nn.Module):def __init__(self):super(Cifar10Net, self).__init__()self.net = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(kernel_size=2),nn.Flatten(),nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self, x):x = self.net(x)return xif __name__ == '__main__':net = Cifar10Net()input = torch.ones((64, 3, 32, 32))output = net(input)print(output.shape)

2. train.py

import torch
import torchvision
from torch import nn
from torch.utils.tensorboard import SummaryWriterfrom p24_model import *# 1. 准备数据集
# 训练数据
from torch.utils.data import DataLoadertrain_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),download=True)
# 测试数据
test_data = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),download=True)# 查看数据大小--size
print("训练数据集大小:", len(train_data))
print("测试数据集大小:", len(test_data))
# 利用DataLoader来加载数据集
train_loader = DataLoader(dataset=train_data, batch_size=64)
test_loader = DataLoader(dataset=test_data, batch_size=64)# 2. 导入模型结构 创建模型
net = Cifar10Net()# 3. 创建损失函数  分类问题--交叉熵
loss_fn = nn.CrossEntropyLoss()# 4. 创建优化器
# learing_rate = 0.01
# 1e-2 = 1 * 10^(-2) = 0.01
learing_rate = 1e-2
print(learing_rate)
optimizer = torch.optim.SGD(net.parameters(), lr=learing_rate)# 设置训练网络的一些参数
epoch = 10   # 记录训练的轮数
total_train_step = 0  # 记录训练的次数
total_test_step = 0   # 记录测试的次数# 利用tensorboard显示训练loss趋势
writer = SummaryWriter('./train_logs')for i in range(epoch):# 训练步骤开始net.train()  # 可以加可以不加  只有当模型结构有 Dropout BatchNorml层才会起作用for data in train_loader:imgs, targets = data  # 获取数据output = net(imgs)    # 数据输入模型loss = loss_fn(output, targets)  # 损失函数计算损失 看计算的输出和真实的标签误差是多少# 优化器开始优化模型  1.梯度清零  2.反向传播  3.参数优化optimizer.zero_grad()  # 利用优化器把梯度清零 全部设置为0loss.backward()        # 设置计算的损失值,调用损失的反向传播,计算每个参数结点的参数optimizer.step()       # 调用优化器的step()方法 对其中的参数进行优化# 优化一次 认为训练了一次total_train_step += 1if total_train_step % 100 == 0:print('训练次数: {}   loss: {}'.format(total_train_step, loss))# 直接打印loss是tensor数据类型,打印loss.item()是打印的int或float真实数值, 真实数值方便做数据可视化【损失可视化】# print('训练次数: {}   loss: {}'.format(total_train_step, loss.item()))writer.add_scalar('train-loss', loss.item(), global_step=total_train_step)# 利用现有模型做模型测试# 测试步骤开始total_test_loss = 0accuracy = 0net.eval()  # 可以加可以不加  只有当模型结构有 Dropout BatchNorml层才会起作用with torch.no_grad():for data in test_loader:imags, targets = dataoutput = net(imags)loss = loss_fn(output, targets)total_test_loss += loss.item()# 计算测试集的正确率preds = (output.argmax(1)==targets).sum()accuracy += preds# writer.add_scalar('test-loss', total_test_loss, global_step=i+1)writer.add_scalar('test-loss', total_test_loss, global_step=total_test_step)writer.add_scalar('test-accracy', accuracy/len(test_data), total_test_step)total_test_step += 1print("---------test loss: {}--------------".format(total_test_loss))print("---------test accuracy: {}--------------".format(accuracy))# 保存每一个epoch训练得到的模型torch.save(net, './net_epoch{}.pth'.format(i))writer.close()

3. 结果

训练数据集大小: 50000
测试数据集大小: 10000
0.01
训练次数: 100   loss: 2.2905373573303223
训练次数: 200   loss: 2.2878968715667725
训练次数: 300   loss: 2.258394718170166
训练次数: 400   loss: 2.1968581676483154
训练次数: 500   loss: 2.0476632118225098
训练次数: 600   loss: 2.002145767211914
训练次数: 700   loss: 2.016021728515625
---------test loss: 316.382279753685--------------
训练次数: 800   loss: 1.8957302570343018
训练次数: 900   loss: 1.8659226894378662
训练次数: 1000   loss: 1.9004186391830444
训练次数: 1100   loss: 1.9708642959594727
......

tensorboard结果

安装tensorboard运行环境

pip install tensorboard
pip install opencv-python
pip install six
tensorboard --logdir=train_logs

以下图片 颜色较浅的线是真实计算的值,颜色较深的线是做了平滑处理的值

训练loss

在这里插入图片描述

测试loss

在这里插入图片描述

测试集正确率

在这里插入图片描述

4. 需要注意的细节

https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module

所有网络层继承于torch.nn.Module, net.train() net.eval() 在模型训练或测试之初 可以加可以不加 只有当模型结构有 Dropout BatchNorml层才会起作用,当模型有这两个网络层的时候,两个代码需要加上。
在这里插入图片描述

在这里插入图片描述

相关文章:

学习pytorch18 pytorch完整的模型训练流程

pytorch完整的模型训练流程 1. 流程1. 整理训练数据 使用CIFAR10数据集2. 搭建网络结构3. 构建损失函数4. 使用优化器5. 训练模型6. 测试数据 计算模型预测正确率7. 保存模型 2. 代码1. model.py2. train.py 3. 结果tensorboard结果以下图片 颜色较浅的线是真实计算的值&#x…...

电子学会C/C++编程等级考试2021年09月(五级)真题解析

C/C++等级考试(1~8级)全部真题・点这里 第1题:抓牛 农夫知道一头牛的位置,想要抓住它。农夫和牛都位于数轴上,农夫起始位于点N(0<=N<=100000),牛位于点K(0<=K<=100000)。农夫有两种移动方式: 1、从X移动到X-1或X+1,每次移动花费一分钟 2、从X移动到2*X,每…...

Halcon联合winform显示以及处理

在窗口中添加窗体和按钮&#xff0c;并在解决方案资源管理器中调加了导入Halcon导出的.cs文件&#xff0c;运行出现下图的问题&#xff1a; 问题1&#xff1a;CS0017 程序定义了多个入口点。使用/main(指定包含入口点的类型&#xff09;进行编译。 解决方案1.&#xff1a; 右…...

【设计模式-4.3】行为型——责任链模式

说明&#xff1a;本文介绍设计模式中行为型设计模式中的&#xff0c;责任链模式&#xff1b; 审批流程 责任链模式属于行为型设计模式&#xff0c;关注于对象的行为。责任链模式非常典型的案例&#xff0c;就是审批流程的实现。如一个报销单的审批流程&#xff0c;根据报销单…...

单片机语言--C51语言的数据类型以及存储类型以及一些基本运算

C51语言 本文主要涉及C51语言的一些基本知识&#xff0c;比如C51语言的数据类型以及存储类型以及一些基本运算。 文章目录 C51语言一、 C51与标准C的比较二、 C51语言中的数据类型与存储类型2.1、C51的扩展数据类型2.2、数据存储类型 三、 C51的基本运算3.1 算术运算符3.2 逻辑…...

《每天一个Linux命令》 -- (5)通过sshkey密钥登录服务器

欢迎阅读《每天一个Linux命令》系列&#xff01;在本篇文章中&#xff0c;将介绍通过密钥生成&#xff0c;使用公钥连接管理服务器。 概念 SSH 密钥是用于安全地访问远程服务器的一种方法。SSH 密钥由一对密钥组成&#xff1a;公钥和私钥。公钥存储在远程服务器上&#xff0c;…...

kubernetes的服务发现(二)

如前面的文章我们说了&#xff0c;kubernetes的服务发现是服务端发现模式。它有一个服务注册中心&#xff0c;使用DNS作为服务的注册表。每个集群都会运行一个DNS服务&#xff0c;默认是CoreDNS服务。每个服务都会在这个DNS中注册。注册的大致过程&#xff1a; 1、向kube-apise…...

【矩阵论】Chapter 4—特征值和特征向量知识点总结复习

文章目录 1 特征值和特征向量2 对角化3 Schur定理和正规矩阵4 Python求解 1 特征值和特征向量 定义 设 σ \sigma σ为数域 F F F上线性空间 V V V上的一个线性变换&#xff0c;一个非零向量 v ∈ V v\in V v∈V&#xff0c;如果存在一个 λ ∈ F \lambda \in F λ∈F使得 σ (…...

Linux 进程地址空间

知识回顾 在 C 语言的学习过程中&#xff0c;我们知道内存是可以被划分为栈区&#xff0c;堆区&#xff0c;全局数据区&#xff0c;字符常量区&#xff0c;代码区的。他的空间排布可能是下面的样子&#xff1a; 其中&#xff0c;全局数据区&#xff0c;可以划分为已初始化全局…...

websocket vue操作

let websocket: WebSocket; /** websocket测试 */ function connectWebsocket() {if (typeof WebSocket "undefined") {console.log("您的浏览器不支持WebSocket");return;}// let ip window.location.hostname ":8080";let ip "10.192…...

腾讯云CentOS8 jenkins war安装jenkins步骤文档

腾讯云CentOS8 jenkins war安装jenkins步骤文档 一、安装jdk 1.1 上传jdk-11.0.20_linux-x64_bin.tar.gz 1.2 解压jdk安装包文件 tar -zxvf jdk*.tar.gz 1.3 在/usr/local 目录下创建java目录 cd /usr/local mkdir java 1.4 切到java目录&#xff0c;把jdk解压文件改名为jd…...

Linux: glibc: net/if.h vs linux/if.h

最近看到一段代码改动,用net/if.h替换了linux/if.h。仔细看了看这两个的区别: https://stackoverflow.com/questions/20082433/what-is-the-difference-between-linux-if-h-and-net-if-h 从网上搜了一下看到如下的一个编译错误,如果同时使用这两个if.h文件,需要将net/if.h…...

使用Android Studio导入Android源码:基于全志H713 AOSP,方便解决编译、编码问题

文章目录 一、 篇头二、 操作步骤2.1 编译AOSP AS工程文件2.2 将AOSP导入Android Studio2.3 切到Project试图2.4 等待index结束2.5 下载缺失的JDK 1.82.6 导入完成 三、 导入AS的好处3.1 本文案例演示源码编译错误AS对比同文件其余地方的调用AS错误提示依赖AS做错误修正 一、 篇…...

python random详解

文章目录 random简单示例1. 生成随机浮点数&#xff1a;2. 生成指定范围内的随机整数&#xff1a;3. 从序列中随机选择元素&#xff1a;4. 打乱序列顺序&#xff1a; 常用的方法及其解释和例子&#xff1a;1. random()&#xff1a;该方法返回一个0到1之间的随机浮点数。例如&am…...

java-两个列表进行比较,判断那些是需要新增的、删除的、和更新的

文章目录 前言两个列表进行比较&#xff0c;判断那些是需要新增的、删除的、和更新的 前言 如果您觉得有用的话&#xff0c;记得给博主点个赞&#xff0c;评论&#xff0c;收藏一键三连啊&#xff0c;写作不易啊^ _ ^。   而且听说点赞的人每天的运气都不会太差&#xff0c;实…...

【WPF.NET开发】WPF中的对话框

目录 1、消息框 2、通用对话框 3、自定义对话框 实现对话框 4、打开对话框的 UI 元素 4.1 菜单项 4.2 按钮 5、返回结果 5.1 模式对话框 5.2 处理响应 5.3 非模式对话框 Windows Presentation Foundation (WPF) 为你提供了自行设计对话框的方法。 对话框是窗口&…...

NLP项目实战01之电影评论分类

介绍&#xff1a; 欢迎来到本篇文章&#xff01;在这里&#xff0c;我们将探讨一个常见而重要的自然语言处理任务——文本分类。具体而言&#xff0c;我们将关注情感分析任务&#xff0c;即通过分析电影评论的情感来判断评论是正面的、负面的。 展示&#xff1a; 训练展示如下…...

一款可无限扩展的软件定时器开源框架项目代码

摘自链接 时间片轮询架构如何稳定高效实现&#xff0c;取代传统的标志位判断方式&#xff0c;更优雅更方便地管理程序的时间触发操作。 可以在STM32单片机上运行。...

GRE与顺丰圆通快递盒子

1. DNS污染 随想&#xff1a; 在输入一串网址后&#xff0c;会发生如下变化如果你在系统中配置了 Hosts 文件&#xff0c;那么电脑会先查询 Hosts 文件如果 Hosts 里面没有这个别名&#xff0c;就通过域名服务器查询域名服务器回应了&#xff0c;那么你的电脑就可以根据域名服…...

12.Mysql 多表数据横向合并和纵向合并

Mysql 函数参考和扩展&#xff1a;Mysql 常用函数和基础查询、 Mysql 官网 Mysql 语法执行顺序如下&#xff0c;一定要清楚&#xff01;&#xff01;&#xff01;运算符相关&#xff0c;可前往 Mysql 基础语法和执行顺序扩展。 (8) select (9) distinct (11)<columns_name…...

线性回归与逻辑回归:深入解析机器学习的基石模型

目录 一、线性回归 二、逻辑回归 逻辑回归算法和 KNN 算法的区别 分类算法评价维度...

电脑待机怎么设置?让你的电脑更加节能

在日常使用电脑的过程中&#xff0c;合理设置待机模式是一项省电且环保的好习惯。然而&#xff0c;许多用户对于如何设置电脑待机感到困扰。那么电脑待机怎么设置呢&#xff1f;本文将深入探讨三种常用的电脑待机设置方法&#xff0c;通过详细的步骤&#xff0c;帮助用户更好地…...

数据库对象介绍与实践:视图、函数、存储过程、触发器和物化视图

文章目录 一、视图&#xff08;View&#xff09;1、概念2、基本操作1&#xff09;创建视图2&#xff09;修改视图3&#xff09;删除视图4&#xff09;使用视图 3、使用场景4、实践 二、函数&#xff08;Function&#xff09;1、概念2、基本操作1&#xff09;创建函数2&#xff…...

arm平台编译so文件回顾

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、几个点二、回顾过程 1.上来就执行Makefile2.编译第三方开源库.a文件 2.1 build.sh脚本2.2 Makefile3.最终编译三、其它知识点总结 前言 提示&#xff1a;这…...

【数据结构】顺序表的定义和运算

目录 1.初始化 2.插入 3.删除 4.查找 5.修改 6.长度 7.遍历 8.完整代码 &#x1f308;嗨&#xff01;我是Filotimo__&#x1f308;。很高兴与大家相识&#xff0c;希望我的博客能对你有所帮助。 &#x1f4a1;本文由Filotimo__✍️原创&#xff0c;首发于CSDN&#x1f4da;。 &…...

idea使用maven的package打包时提示“找不到符号”或“找不到包”

介绍&#xff1a;由于我们的项目是多模块开发项目&#xff0c;在打包时有些模块内容更新导致其他模块在引用该模块时不能正确引入。 情况一&#xff1a;找不到符号 情况一&#xff1a;找不到包 错误代码部分展示&#xff1a; Failure to find com.xxx.xxxx:xxx:pom:0.5 in …...

MetricBeat监控MySQL

目录 一、安装部署 二、开启mysql监控模块 三、编辑mysql配置文件 四、启动Metricbeat 五、查看监控图表 一、安装部署 metriceat的安装部署参考章节&#xff1a; Metricbeat安装使用&#xff0c;这里不再赘述。 二、开启mysql监控模块 进入metricbeat安装目录 ./metricb…...

Child Mind Institute - Detect Sleep States(2023年第一次Kaggle拿到了银牌总结)

感谢 感谢艾兄&#xff08;大佬带队&#xff09;、rich师弟&#xff08;师弟通过这次比赛机械转码成功、耐心学习&#xff09;、张同学&#xff08;也很有耐心的在学习&#xff09;&#xff0c;感谢开源方案&#xff08;开源就是银牌&#xff09;&#xff0c;在此基础上一个月…...

Esxi7Esxi8设置VMFSL虚拟闪存的大小

Esxi7Esxi8设置VMFSL虚拟闪存的大小 ESXi7,8 默认安装会分配一个 VMFSL(VMFS-L)(Local VMFS)很大空间(120G), 感觉很浪费, 实际给 8G 就可以了, 最少 6G , 经实验,给2G没法安装 . Esxi7是虚拟闪存的 修改的方法是: 在安装时修改 设置 autoPartitionOSDataSize8192 在cdromBoo…...

vue2+electron桌面端一体机应用

vue2+electron项目 前言:公司有一个项目需要用Vue转成exe,首先我使用vue-cli脚手架搭建vue2项目,然后安装electron 安装electron 这一步骤可以省略,安装electron-builder时会自动安装electron npm i electron 安装electron-builder vue add electron-builder 项目中多出…...