当前位置: 首页 > news >正文

10-pytorch-完整模型训练

b站小土堆pytorch教程学习笔记

一、从零开始构建自己的神经网络

1.模型构建
#准备数据集
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriterfrom model import *
from torch.utils.data import DataLoadertrain_data=torchvision.datasets.CIFAR10('dataset',train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data=torchvision.datasets.CIFAR10('dataset',train=False,transform=torchvision.transforms.ToTensor(),download=True)
#查看训练数据集和测试集大小
train_data_size=len(train_data)
test_data_size=len(test_data)
print('训练数据集长度为:{}'.format(train_data_size))#训练数据集长度为:50000
print('测试数据集长度为:{}'.format(test_data_size))#测试数据集长度为:10000#利用datalo加载数据集
train_dataloader=DataLoader(train_data,batch_size=64)
test_dataloader=DataLoader(test_data,batch_size=64)#搭建神经网络,在model文件中搭建网络,在此文件中引用
han=Han()#损失函数
loss_fn=nn.CrossEntropyLoss()#优化器
# learning_rate=0.01
learning_rate=1e-2
optimizer=torch.optim.SGD(han.parameters(),lr=learning_rate)#设置训练网络的相关参数
total_train_step = 0#记录训练的次数
total_test_step = 0#记录测试的次数
epoch=10#训练轮数#添加tensorboard
writer=SummaryWriter('logs/train')for i in range(10):print('-------第{}轮训练开始-------'.format(i+1))for data in train_dataloader:imgs,target=dataoutput=han(imgs)loss=loss_fn(output,target)#优化器优化模型optimizer.zero_grad()#梯度清零loss.backward()#反向传播计算梯度optimizer.step()#参数优化total_train_step=total_train_step+1if total_train_step % 100==0:#逢100打印print('训练次数:{},loss:{}'.format(total_train_step,loss.item()))#loss.item()取出tensor类型的数字writer.add_scalar('train_loss',loss.item(),total_train_step)#每训练完一轮将在测试集上跑一遍,评估其训练效果total_test_loss=0with torch.no_grad():for data in test_dataloader:imgs,target=dataoutput=han(imgs)loss=loss_fn(output,target)total_test_loss=total_test_loss+loss.item()print('所有测试集上的损失:{}'.format(total_test_loss))writer.add_scalar('test_loss',total_test_loss,total_test_step)total_test_step+=1#保存每一轮模型torch.save(han,'han_{}.pth'.format(i))print('模型已保存')
writer.close()
import torch
from torch import nnclass Han(nn.Module):def __init__(self):super(Han, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return xif __name__ == '__main__':han=Han()input=torch.ones(64,3,32,32)output=han(input)print(output.shape)#torch.Size([64, 10])10表示十个类别输出概率

结果如下:
在这里插入图片描述

2.使用argmax计算整体正确率
#每训练完一轮将在测试集上跑一遍,评估其训练效果total_test_loss=0total_acc=0with torch.no_grad():for data in test_dataloader:imgs,target=dataoutput=han(imgs)loss=loss_fn(output,target)total_test_loss=total_test_loss+loss.item()acc=(output.argmax(1)==target).sum()#(1)横着看total_acc+=accprint('所有测试集上的损失:{}'.format(total_test_loss))print('整体测试集上的正确率:{}'.format(total_acc/test_data_size))writer.add_scalar('test_loss',total_test_loss,total_test_step)writer.add_scalar('test_acc', total_acc/test_data_size, total_test_step)total_test_step+=1

整体测试集上的正确率:0.27480000257492065

3.当训练或测试时存在dropout层或batch normal层,则需要在训练训练和测试前加入:
#训练前
han.train()
#测试前
han.eval()

二、使用GPU

网络模型、数据(输入、标注)、损失函数调用cuda()

1.方式1
#模型
if torch.cuda.is_available():han=han.cuda()
#损失函数
loss_fn=nn.CrossEntropyLoss()
loss_fn=loss_fn.cuda()
imgs,target=data
imgs=imgs.cuda()
target=target.cuda()
2.方式2
#定义训练设备
device=torch.device('cuda')
han=han.to(device)
imgs = imgs.to(device)
target = target.to(device)

相关文章:

10-pytorch-完整模型训练

b站小土堆pytorch教程学习笔记 一、从零开始构建自己的神经网络 1.模型构建 #准备数据集 import torch import torchvision from torch.utils.tensorboard import SummaryWriterfrom model import * from torch.utils.data import DataLoadertrain_datatorchvision.datasets.…...

高级RAG:重新排名,从原理到实现的两种主流方法

原文地址:https://pub.towardsai.net/advanced-rag-04-re-ranking-85f6ae8170b1 2024 年 2 月 14 日 重新排序在检索增强生成(RAG)过程中起着至关重要的作用。在简单的 RAG 方法中,可以检索大量上下文,但并非所有上下…...

使用logicflow流程图实例

一.背景 需要使用流程引擎开发项目,没有使用flowable、activiti这类的国外流程引擎,想使用国内的引擎二次开发,缺少单例模式的流程画图程序,都是vue、react、angluer的不适合,从网上找了antx6、logicflow、bpmn.js。感…...

Stable Diffusion 绘画入门教程(webui)-ControlNet(IP2P)

上篇文章介绍了深度Depth,这篇文章介绍下IP2P(InstructP2P), 通俗理解就是图生图,给原有图加一些效果,比如下图,左边为原图,右边为增加了效果的图: 文章目录 一、选大模型二、写提示词三、基础参…...

五力分析(Porter‘s Five Forces)

五力分析是一种用于评估竞争力的框架,由哈佛商学院教授迈克尔波特(Michael Porter)提出。它通过分析一个行业的五个关键力量(竞争对手、供应商、顾客、替代品和新进入者)来评估一个企业或行业的竞争环境。这个框架可以…...

十一、Qt数据库操作

一、Sql介绍 Qt Sql模块包含多个类,实现数据库的连接,Sql语句的执行,数据获取与界面显示,数据与界面直接使用Model/View结构。1、使用Sql模块 (1)工程加入 QT sql(2)添加头文件 …...

【Spring】IoC容器 控制反转 与 DI依赖注入 XML实现版本 第二期

文章目录 基于 XML 配置方式组件管理前置 准备项目一、 组件(Bean)信息声明配置(IoC):1.1 基于无参构造1.2 基于静态 工厂方法实例化1.3 基于非静态 工厂方法实例化 二、 组件(Bean)依赖注入配置…...

神经网络系列---感知机(Neuron)

文章目录 感知机(Neuron)感知机(Neuron)的决策函数可以表示为:感知机(Neuron)的学习算法主要包括以下步骤:感知机可以实现逻辑运算中的AND、OR、NOT和异或(XOR)运算。 感知机(Neuron) 感知机(Neuron)是一种简单而有效的二分类算法,用于将输入…...

k8s(2)

目录 一.二进制部署k8s 常见的K8S安装部署方式: k8s部署 二进制与高可用的区别 二.部署k8s 初始化操作: 每台node安装docker: 在 master01 节点上操作; 准备cfssl证书生成工具:: 执行脚本文件: 拉入etcd压缩包…...

利用nginx内部访问特性实现静态资源授权访问

在nginx中,将静态资源设为internal;然后将前端的静态资源地址改为指向后端,在后端的响应头部中写上静态资源地址。 近期客户对我们项目做安全性测评,暴露出一些安全性问题,其中一个是有些静态页面(*.html&…...

fly-barrage 前端弹幕库(1):项目介绍

fly-barrage 是我写的一个前端弹幕库,由于经常在 Bilibili 上看视频,所以对网页的弹幕功能一直蛮感兴趣的,所以做了这个库,可以帮助前端快速的实现弹幕功能。 项目官网地址:https://fly-barrage.netlify.app/&#xff…...

jetcache如果一个主体涉及多个缓存时编辑或者删除时如何同时失效多个缓存

在实际使用过程中,可能会遇到这种情形:一个主体会有多个缓存,比如用户基础信息缓存、用户详情缓存,那么当删除用户信息后就需要同时失效多个缓存中该主体数据,那么jetcache支持这种应用场景么,答案是支持&a…...

uni-app 实现拍照后给照片加水印功能

遇到个需求需要实现&#xff0c;研究了一下后写了个demo 本质上就是把拍完照后的照片放到canvas里&#xff0c;然后加上水印样式然后再重新生成一张图片 代码如下&#xff0c;看注释即可~使用的话记得还是得优化下代码 <template><view class"content"&g…...

【ArcGIS】利用DEM进行水文分析:流向/流量等

利用DEM进行水文分析 ArcGIS实例参考 水文分析通过建立地表水文模型&#xff0c;研究与地表水流相关的各种自然现象&#xff0c;在城市和区域规划、农业及森林、交通道路等许多领域具有广泛的应用。 ArcGIS实例 某流域30m分辨率DEM如下&#xff1a; &#xff08;1&#xff09…...

论文阅读笔记——PathAFL:Path-Coverage Assisted Fuzzing

文章目录 前言PathAFL&#xff1a;Path-Coverage Assisted Fuzzing1、解决的问题和目标2、技术路线2.1、如何识别 h − p a t h h-path h−path&#xff1f;2.2、如何减少 h − p a t h h-path h−path的数量&#xff1f;2.3、哪些h-path将被添加到种子队列&#xff1f;2.4、种…...

C语言中各种运算符用法

C语言中有许多不同的运算符&#xff0c;用于执行各种不同的操作。 以下是C语言中常见的运算符及其用法&#xff1a; 算术运算符&#xff1a; 加法运算符&#xff08;&#xff09;&#xff1a;用于将两个值相加。减法运算符&#xff08;-&#xff09;&#xff1a;用于将一个值减…...

pythonJax小记(五):python: 使用Jax深度图像(正交投影和透视投影之间的转换)(持续更新,评论区可以补充)

python: 使用Jax深度图像&#xff08;正交投影和透视投影之间的转换&#xff09; 前言问题描述1. 透视投影2. 正交投影 直接上代码解释1. compute_projection_parameters 函数a. 参数解释b. 函数计算 2. ortho_to_persp 函数a. 计算投影参数&#xff1a;b. 生成像素坐标网格&am…...

web安全学习笔记【16】——信息打点(6)

信息打点-语言框架&开发组件&FastJson&Shiro&Log4j&SpringBoot等[1] #知识点&#xff1a; 1、业务资产-应用类型分类 2、Web单域名获取-接口查询 3、Web子域名获取-解析枚举 4、Web架构资产-平台指纹识别 ------------------------------------ 1、开源-C…...

145.二叉树的后序遍历

// 定义一个名为Solution的类&#xff0c;用于解决二叉树的后序遍历问题 class Solution { // 定义一个公共方法&#xff0c;输入是一个二叉树的根节点&#xff0c;返回一个包含后序遍历结果的整数列表 public List<Integer> postorderTraversal(TreeNode root) { /…...

ssh远程连接免密码访问

我们在远程登录的时候&#xff0c;经常需要输入密码&#xff0c;密码往往比较复杂&#xff0c;输入比较耗费时间&#xff0c;这种情况下可以使用ssh免密码登录。 一般的教程是需要生成ssh密钥后&#xff0c;然后把密钥复制到server端完成配置&#xff0c;这里提供一个简单的方…...

Unity3D中Gfx.WaitForPresent优化方案

前言 在Unity中&#xff0c;Gfx.WaitForPresent占用CPU过高通常表示主线程在等待GPU完成渲染&#xff08;即CPU被阻塞&#xff09;&#xff0c;这表明存在GPU瓶颈或垂直同步/帧率设置问题。以下是系统的优化方案&#xff1a; 对惹&#xff0c;这里有一个游戏开发交流小组&…...

UR 协作机器人「三剑客」:精密轻量担当(UR7e)、全能协作主力(UR12e)、重型任务专家(UR15)

UR协作机器人正以其卓越性能在现代制造业自动化中扮演重要角色。UR7e、UR12e和UR15通过创新技术和精准设计满足了不同行业的多样化需求。其中&#xff0c;UR15以其速度、精度及人工智能准备能力成为自动化领域的重要突破。UR7e和UR12e则在负载规格和市场定位上不断优化&#xf…...

分布式增量爬虫实现方案

之前我们在讨论的是分布式爬虫如何实现增量爬取。增量爬虫的目标是只爬取新产生或发生变化的页面&#xff0c;避免重复抓取&#xff0c;以节省资源和时间。 在分布式环境下&#xff0c;增量爬虫的实现需要考虑多个爬虫节点之间的协调和去重。 另一种思路&#xff1a;将增量判…...

【JVM面试篇】高频八股汇总——类加载和类加载器

目录 1. 讲一下类加载过程&#xff1f; 2. Java创建对象的过程&#xff1f; 3. 对象的生命周期&#xff1f; 4. 类加载器有哪些&#xff1f; 5. 双亲委派模型的作用&#xff08;好处&#xff09;&#xff1f; 6. 讲一下类的加载和双亲委派原则&#xff1f; 7. 双亲委派模…...

NPOI Excel用OLE对象的形式插入文件附件以及插入图片

static void Main(string[] args) {XlsWithObjData();Console.WriteLine("输出完成"); }static void XlsWithObjData() {// 创建工作簿和单元格,只有HSSFWorkbook,XSSFWorkbook不可以HSSFWorkbook workbook new HSSFWorkbook();HSSFSheet sheet (HSSFSheet)workboo…...

Caliper 配置文件解析:fisco-bcos.json

config.yaml 文件 config.yaml 是 Caliper 的主配置文件,通常包含以下内容: test:name: fisco-bcos-test # 测试名称description: Performance test of FISCO-BCOS # 测试描述workers:type: local # 工作进程类型number: 5 # 工作进程数量monitor:type: - docker- pro…...

Java详解LeetCode 热题 100(26):LeetCode 142. 环形链表 II(Linked List Cycle II)详解

文章目录 1. 题目描述1.1 链表节点定义 2. 理解题目2.1 问题可视化2.2 核心挑战 3. 解法一&#xff1a;HashSet 标记访问法3.1 算法思路3.2 Java代码实现3.3 详细执行过程演示3.4 执行结果示例3.5 复杂度分析3.6 优缺点分析 4. 解法二&#xff1a;Floyd 快慢指针法&#xff08;…...

AD学习(3)

1 PCB封装元素组成及简单的PCB封装创建 封装的组成部分&#xff1a; &#xff08;1&#xff09;PCB焊盘&#xff1a;表层的铜 &#xff0c;top层的铜 &#xff08;2&#xff09;管脚序号&#xff1a;用来关联原理图中的管脚的序号&#xff0c;原理图的序号需要和PCB封装一一…...

Java后端检查空条件查询

通过抛出运行异常&#xff1a;throw new RuntimeException("请输入查询条件&#xff01;");BranchWarehouseServiceImpl.java // 查询试剂交易&#xff08;入库/出库&#xff09;记录Overridepublic List<BranchWarehouseTransactions> queryForReagent(Branch…...

数据库正常,但后端收不到数据原因及解决

从代码和日志来看&#xff0c;后端SQL查询确实返回了数据&#xff0c;但最终user对象却为null。这表明查询结果没有正确映射到User对象上。 在前后端分离&#xff0c;并且ai辅助开发的时候&#xff0c;很容易出现前后端变量名不一致情况&#xff0c;还不报错&#xff0c;只是单…...