学习pytorch18 pytorch完整的模型训练流程
pytorch完整的模型训练流程
- 1. 流程
- 1. 整理训练数据 使用CIFAR10数据集
- 2. 搭建网络结构
- 3. 构建损失函数
- 4. 使用优化器
- 5. 训练模型
- 6. 测试数据 计算模型预测正确率
- 7. 保存模型
- 2. 代码
- 1. model.py
- 2. train.py
- 3. 结果
- tensorboard结果
- 以下图片 颜色较浅的线是真实计算的值,颜色较深的线是做了平滑处理的值
- 训练loss
- 测试loss
- 测试集正确率
- 4. 需要注意的细节
1. 流程
1. 整理训练数据 使用CIFAR10数据集
train_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),download=True)
2. 搭建网络结构
model.py
3. 构建损失函数
loss_fn = nn.CrossEntropyLoss()
4. 使用优化器
learing_rate = 1e-2 # 0.01
optimizer = torch.optim.SGD(net.parameters(), lr=learing_rate)
5. 训练模型
output = net(imgs) # 数据输入模型
loss = loss_fn(output, targets) # 损失函数计算损失 看计算的输出和真实的标签误差是多少
# 优化器开始优化模型 1.梯度清零 2.反向传播 3.参数优化
optimizer.zero_grad() # 利用优化器把梯度清零 全部设置为0
loss.backward() # 设置计算的损失值的钩子,调用损失的反向传播,计算每个参数结点的参数
optimizer.step() # 调用优化器的step()方法 对其中的参数进行优化
6. 测试数据 计算模型预测正确率
output = net(imags)
# 计算测试集的正确率
preds = (output.argmax(1)==targets).sum()
accuracy += preds
rate = accuracy/len(test_data)
调用模型输出tensor 数据类型的 argmax方法, argmax或获取一行或者一列数值中最大数值的下标位置,argmax(0) 是从列的维度取一列数值的最大值的下标,argmax(1) 是从行的维度取一行数值的最大值的下标
output.argmax(1)==targets 会输出如下图最后一行 [false, ture], 对应位置相同则为true,对应位置不同则为false;
调用sum()方法,计算求和,false值为0,true值为1.
最后计算得出测试集整体正确率: rate = accuracy/len(test_data)
7. 保存模型
torch.save(net, './net_epoch{}.pth'.format(i))
2. 代码
1. model.py
import torch
from torch import nn# 2. 搭建模型网络结构--神经网络
class Cifar10Net(nn.Module):def __init__(self):super(Cifar10Net, self).__init__()self.net = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(kernel_size=2),nn.Flatten(),nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self, x):x = self.net(x)return xif __name__ == '__main__':net = Cifar10Net()input = torch.ones((64, 3, 32, 32))output = net(input)print(output.shape)
2. train.py
import torch
import torchvision
from torch import nn
from torch.utils.tensorboard import SummaryWriterfrom p24_model import *# 1. 准备数据集
# 训练数据
from torch.utils.data import DataLoadertrain_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),download=True)
# 测试数据
test_data = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),download=True)# 查看数据大小--size
print("训练数据集大小:", len(train_data))
print("测试数据集大小:", len(test_data))
# 利用DataLoader来加载数据集
train_loader = DataLoader(dataset=train_data, batch_size=64)
test_loader = DataLoader(dataset=test_data, batch_size=64)# 2. 导入模型结构 创建模型
net = Cifar10Net()# 3. 创建损失函数 分类问题--交叉熵
loss_fn = nn.CrossEntropyLoss()# 4. 创建优化器
# learing_rate = 0.01
# 1e-2 = 1 * 10^(-2) = 0.01
learing_rate = 1e-2
print(learing_rate)
optimizer = torch.optim.SGD(net.parameters(), lr=learing_rate)# 设置训练网络的一些参数
epoch = 10 # 记录训练的轮数
total_train_step = 0 # 记录训练的次数
total_test_step = 0 # 记录测试的次数# 利用tensorboard显示训练loss趋势
writer = SummaryWriter('./train_logs')for i in range(epoch):# 训练步骤开始net.train() # 可以加可以不加 只有当模型结构有 Dropout BatchNorml层才会起作用for data in train_loader:imgs, targets = data # 获取数据output = net(imgs) # 数据输入模型loss = loss_fn(output, targets) # 损失函数计算损失 看计算的输出和真实的标签误差是多少# 优化器开始优化模型 1.梯度清零 2.反向传播 3.参数优化optimizer.zero_grad() # 利用优化器把梯度清零 全部设置为0loss.backward() # 设置计算的损失值,调用损失的反向传播,计算每个参数结点的参数optimizer.step() # 调用优化器的step()方法 对其中的参数进行优化# 优化一次 认为训练了一次total_train_step += 1if total_train_step % 100 == 0:print('训练次数: {} loss: {}'.format(total_train_step, loss))# 直接打印loss是tensor数据类型,打印loss.item()是打印的int或float真实数值, 真实数值方便做数据可视化【损失可视化】# print('训练次数: {} loss: {}'.format(total_train_step, loss.item()))writer.add_scalar('train-loss', loss.item(), global_step=total_train_step)# 利用现有模型做模型测试# 测试步骤开始total_test_loss = 0accuracy = 0net.eval() # 可以加可以不加 只有当模型结构有 Dropout BatchNorml层才会起作用with torch.no_grad():for data in test_loader:imags, targets = dataoutput = net(imags)loss = loss_fn(output, targets)total_test_loss += loss.item()# 计算测试集的正确率preds = (output.argmax(1)==targets).sum()accuracy += preds# writer.add_scalar('test-loss', total_test_loss, global_step=i+1)writer.add_scalar('test-loss', total_test_loss, global_step=total_test_step)writer.add_scalar('test-accracy', accuracy/len(test_data), total_test_step)total_test_step += 1print("---------test loss: {}--------------".format(total_test_loss))print("---------test accuracy: {}--------------".format(accuracy))# 保存每一个epoch训练得到的模型torch.save(net, './net_epoch{}.pth'.format(i))writer.close()
3. 结果
训练数据集大小: 50000
测试数据集大小: 10000
0.01
训练次数: 100 loss: 2.2905373573303223
训练次数: 200 loss: 2.2878968715667725
训练次数: 300 loss: 2.258394718170166
训练次数: 400 loss: 2.1968581676483154
训练次数: 500 loss: 2.0476632118225098
训练次数: 600 loss: 2.002145767211914
训练次数: 700 loss: 2.016021728515625
---------test loss: 316.382279753685--------------
训练次数: 800 loss: 1.8957302570343018
训练次数: 900 loss: 1.8659226894378662
训练次数: 1000 loss: 1.9004186391830444
训练次数: 1100 loss: 1.9708642959594727
......
tensorboard结果
安装tensorboard运行环境
pip install tensorboard
pip install opencv-python
pip install six
tensorboard --logdir=train_logs
以下图片 颜色较浅的线是真实计算的值,颜色较深的线是做了平滑处理的值
训练loss
测试loss
测试集正确率
4. 需要注意的细节
https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module
所有网络层继承于torch.nn.Module, net.train() net.eval() 在模型训练或测试之初 可以加可以不加 只有当模型结构有 Dropout BatchNorml层才会起作用,当模型有这两个网络层的时候,两个代码需要加上。
相关文章:

学习pytorch18 pytorch完整的模型训练流程
pytorch完整的模型训练流程 1. 流程1. 整理训练数据 使用CIFAR10数据集2. 搭建网络结构3. 构建损失函数4. 使用优化器5. 训练模型6. 测试数据 计算模型预测正确率7. 保存模型 2. 代码1. model.py2. train.py 3. 结果tensorboard结果以下图片 颜色较浅的线是真实计算的值&#x…...

电子学会C/C++编程等级考试2021年09月(五级)真题解析
C/C++等级考试(1~8级)全部真题・点这里 第1题:抓牛 农夫知道一头牛的位置,想要抓住它。农夫和牛都位于数轴上,农夫起始位于点N(0<=N<=100000),牛位于点K(0<=K<=100000)。农夫有两种移动方式: 1、从X移动到X-1或X+1,每次移动花费一分钟 2、从X移动到2*X,每…...

Halcon联合winform显示以及处理
在窗口中添加窗体和按钮,并在解决方案资源管理器中调加了导入Halcon导出的.cs文件,运行出现下图的问题: 问题1:CS0017 程序定义了多个入口点。使用/main(指定包含入口点的类型)进行编译。 解决方案1.: 右…...

【设计模式-4.3】行为型——责任链模式
说明:本文介绍设计模式中行为型设计模式中的,责任链模式; 审批流程 责任链模式属于行为型设计模式,关注于对象的行为。责任链模式非常典型的案例,就是审批流程的实现。如一个报销单的审批流程,根据报销单…...

单片机语言--C51语言的数据类型以及存储类型以及一些基本运算
C51语言 本文主要涉及C51语言的一些基本知识,比如C51语言的数据类型以及存储类型以及一些基本运算。 文章目录 C51语言一、 C51与标准C的比较二、 C51语言中的数据类型与存储类型2.1、C51的扩展数据类型2.2、数据存储类型 三、 C51的基本运算3.1 算术运算符3.2 逻辑…...

《每天一个Linux命令》 -- (5)通过sshkey密钥登录服务器
欢迎阅读《每天一个Linux命令》系列!在本篇文章中,将介绍通过密钥生成,使用公钥连接管理服务器。 概念 SSH 密钥是用于安全地访问远程服务器的一种方法。SSH 密钥由一对密钥组成:公钥和私钥。公钥存储在远程服务器上,…...
kubernetes的服务发现(二)
如前面的文章我们说了,kubernetes的服务发现是服务端发现模式。它有一个服务注册中心,使用DNS作为服务的注册表。每个集群都会运行一个DNS服务,默认是CoreDNS服务。每个服务都会在这个DNS中注册。注册的大致过程: 1、向kube-apise…...
【矩阵论】Chapter 4—特征值和特征向量知识点总结复习
文章目录 1 特征值和特征向量2 对角化3 Schur定理和正规矩阵4 Python求解 1 特征值和特征向量 定义 设 σ \sigma σ为数域 F F F上线性空间 V V V上的一个线性变换,一个非零向量 v ∈ V v\in V v∈V,如果存在一个 λ ∈ F \lambda \in F λ∈F使得 σ (…...

Linux 进程地址空间
知识回顾 在 C 语言的学习过程中,我们知道内存是可以被划分为栈区,堆区,全局数据区,字符常量区,代码区的。他的空间排布可能是下面的样子: 其中,全局数据区,可以划分为已初始化全局…...
websocket vue操作
let websocket: WebSocket; /** websocket测试 */ function connectWebsocket() {if (typeof WebSocket "undefined") {console.log("您的浏览器不支持WebSocket");return;}// let ip window.location.hostname ":8080";let ip "10.192…...

腾讯云CentOS8 jenkins war安装jenkins步骤文档
腾讯云CentOS8 jenkins war安装jenkins步骤文档 一、安装jdk 1.1 上传jdk-11.0.20_linux-x64_bin.tar.gz 1.2 解压jdk安装包文件 tar -zxvf jdk*.tar.gz 1.3 在/usr/local 目录下创建java目录 cd /usr/local mkdir java 1.4 切到java目录,把jdk解压文件改名为jd…...
Linux: glibc: net/if.h vs linux/if.h
最近看到一段代码改动,用net/if.h替换了linux/if.h。仔细看了看这两个的区别: https://stackoverflow.com/questions/20082433/what-is-the-difference-between-linux-if-h-and-net-if-h 从网上搜了一下看到如下的一个编译错误,如果同时使用这两个if.h文件,需要将net/if.h…...

使用Android Studio导入Android源码:基于全志H713 AOSP,方便解决编译、编码问题
文章目录 一、 篇头二、 操作步骤2.1 编译AOSP AS工程文件2.2 将AOSP导入Android Studio2.3 切到Project试图2.4 等待index结束2.5 下载缺失的JDK 1.82.6 导入完成 三、 导入AS的好处3.1 本文案例演示源码编译错误AS对比同文件其余地方的调用AS错误提示依赖AS做错误修正 一、 篇…...
python random详解
文章目录 random简单示例1. 生成随机浮点数:2. 生成指定范围内的随机整数:3. 从序列中随机选择元素:4. 打乱序列顺序: 常用的方法及其解释和例子:1. random():该方法返回一个0到1之间的随机浮点数。例如&am…...

java-两个列表进行比较,判断那些是需要新增的、删除的、和更新的
文章目录 前言两个列表进行比较,判断那些是需要新增的、删除的、和更新的 前言 如果您觉得有用的话,记得给博主点个赞,评论,收藏一键三连啊,写作不易啊^ _ ^。 而且听说点赞的人每天的运气都不会太差,实…...

【WPF.NET开发】WPF中的对话框
目录 1、消息框 2、通用对话框 3、自定义对话框 实现对话框 4、打开对话框的 UI 元素 4.1 菜单项 4.2 按钮 5、返回结果 5.1 模式对话框 5.2 处理响应 5.3 非模式对话框 Windows Presentation Foundation (WPF) 为你提供了自行设计对话框的方法。 对话框是窗口&…...

NLP项目实战01之电影评论分类
介绍: 欢迎来到本篇文章!在这里,我们将探讨一个常见而重要的自然语言处理任务——文本分类。具体而言,我们将关注情感分析任务,即通过分析电影评论的情感来判断评论是正面的、负面的。 展示: 训练展示如下…...
一款可无限扩展的软件定时器开源框架项目代码
摘自链接 时间片轮询架构如何稳定高效实现,取代传统的标志位判断方式,更优雅更方便地管理程序的时间触发操作。 可以在STM32单片机上运行。...

GRE与顺丰圆通快递盒子
1. DNS污染 随想: 在输入一串网址后,会发生如下变化如果你在系统中配置了 Hosts 文件,那么电脑会先查询 Hosts 文件如果 Hosts 里面没有这个别名,就通过域名服务器查询域名服务器回应了,那么你的电脑就可以根据域名服…...

12.Mysql 多表数据横向合并和纵向合并
Mysql 函数参考和扩展:Mysql 常用函数和基础查询、 Mysql 官网 Mysql 语法执行顺序如下,一定要清楚!!!运算符相关,可前往 Mysql 基础语法和执行顺序扩展。 (8) select (9) distinct (11)<columns_name…...

K8S认证|CKS题库+答案| 11. AppArmor
目录 11. AppArmor 免费获取并激活 CKA_v1.31_模拟系统 题目 开始操作: 1)、切换集群 2)、切换节点 3)、切换到 apparmor 的目录 4)、执行 apparmor 策略模块 5)、修改 pod 文件 6)、…...
Spring Boot 实现流式响应(兼容 2.7.x)
在实际开发中,我们可能会遇到一些流式数据处理的场景,比如接收来自上游接口的 Server-Sent Events(SSE) 或 流式 JSON 内容,并将其原样中转给前端页面或客户端。这种情况下,传统的 RestTemplate 缓存机制会…...
Objective-C常用命名规范总结
【OC】常用命名规范总结 文章目录 【OC】常用命名规范总结1.类名(Class Name)2.协议名(Protocol Name)3.方法名(Method Name)4.属性名(Property Name)5.局部变量/实例变量(Local / Instance Variables&…...
鱼香ros docker配置镜像报错:https://registry-1.docker.io/v2/
使用鱼香ros一件安装docker时的https://registry-1.docker.io/v2/问题 一键安装指令 wget http://fishros.com/install -O fishros && . fishros出现问题:docker pull 失败 网络不同,需要使用镜像源 按照如下步骤操作 sudo vi /etc/docker/dae…...
高防服务器能够抵御哪些网络攻击呢?
高防服务器作为一种有着高度防御能力的服务器,可以帮助网站应对分布式拒绝服务攻击,有效识别和清理一些恶意的网络流量,为用户提供安全且稳定的网络环境,那么,高防服务器一般都可以抵御哪些网络攻击呢?下面…...
【C++从零实现Json-Rpc框架】第六弹 —— 服务端模块划分
一、项目背景回顾 前五弹完成了Json-Rpc协议解析、请求处理、客户端调用等基础模块搭建。 本弹重点聚焦于服务端的模块划分与架构设计,提升代码结构的可维护性与扩展性。 二、服务端模块设计目标 高内聚低耦合:各模块职责清晰,便于独立开发…...

使用 SymPy 进行向量和矩阵的高级操作
在科学计算和工程领域,向量和矩阵操作是解决问题的核心技能之一。Python 的 SymPy 库提供了强大的符号计算功能,能够高效地处理向量和矩阵的各种操作。本文将深入探讨如何使用 SymPy 进行向量和矩阵的创建、合并以及维度拓展等操作,并通过具体…...
Go 语言并发编程基础:无缓冲与有缓冲通道
在上一章节中,我们了解了 Channel 的基本用法。本章将重点分析 Go 中通道的两种类型 —— 无缓冲通道与有缓冲通道,它们在并发编程中各具特点和应用场景。 一、通道的基本分类 类型定义形式特点无缓冲通道make(chan T)发送和接收都必须准备好࿰…...
CRMEB 中 PHP 短信扩展开发:涵盖一号通、阿里云、腾讯云、创蓝
目前已有一号通短信、阿里云短信、腾讯云短信扩展 扩展入口文件 文件目录 crmeb\services\sms\Sms.php 默认驱动类型为:一号通 namespace crmeb\services\sms;use crmeb\basic\BaseManager; use crmeb\services\AccessTokenServeService; use crmeb\services\sms\…...
快刀集(1): 一刀斩断视频片头广告
一刀流:用一个简单脚本,秒杀视频片头广告,还你清爽观影体验。 1. 引子 作为一个爱生活、爱学习、爱收藏高清资源的老码农,平时写代码之余看看电影、补补片,是再正常不过的事。 电影嘛,要沉浸,…...