当前位置: 首页 > news >正文

10-pytorch-完整模型训练

b站小土堆pytorch教程学习笔记

一、从零开始构建自己的神经网络

1.模型构建
#准备数据集
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriterfrom model import *
from torch.utils.data import DataLoadertrain_data=torchvision.datasets.CIFAR10('dataset',train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data=torchvision.datasets.CIFAR10('dataset',train=False,transform=torchvision.transforms.ToTensor(),download=True)
#查看训练数据集和测试集大小
train_data_size=len(train_data)
test_data_size=len(test_data)
print('训练数据集长度为:{}'.format(train_data_size))#训练数据集长度为:50000
print('测试数据集长度为:{}'.format(test_data_size))#测试数据集长度为:10000#利用datalo加载数据集
train_dataloader=DataLoader(train_data,batch_size=64)
test_dataloader=DataLoader(test_data,batch_size=64)#搭建神经网络,在model文件中搭建网络,在此文件中引用
han=Han()#损失函数
loss_fn=nn.CrossEntropyLoss()#优化器
# learning_rate=0.01
learning_rate=1e-2
optimizer=torch.optim.SGD(han.parameters(),lr=learning_rate)#设置训练网络的相关参数
total_train_step = 0#记录训练的次数
total_test_step = 0#记录测试的次数
epoch=10#训练轮数#添加tensorboard
writer=SummaryWriter('logs/train')for i in range(10):print('-------第{}轮训练开始-------'.format(i+1))for data in train_dataloader:imgs,target=dataoutput=han(imgs)loss=loss_fn(output,target)#优化器优化模型optimizer.zero_grad()#梯度清零loss.backward()#反向传播计算梯度optimizer.step()#参数优化total_train_step=total_train_step+1if total_train_step % 100==0:#逢100打印print('训练次数:{},loss:{}'.format(total_train_step,loss.item()))#loss.item()取出tensor类型的数字writer.add_scalar('train_loss',loss.item(),total_train_step)#每训练完一轮将在测试集上跑一遍,评估其训练效果total_test_loss=0with torch.no_grad():for data in test_dataloader:imgs,target=dataoutput=han(imgs)loss=loss_fn(output,target)total_test_loss=total_test_loss+loss.item()print('所有测试集上的损失:{}'.format(total_test_loss))writer.add_scalar('test_loss',total_test_loss,total_test_step)total_test_step+=1#保存每一轮模型torch.save(han,'han_{}.pth'.format(i))print('模型已保存')
writer.close()
import torch
from torch import nnclass Han(nn.Module):def __init__(self):super(Han, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return xif __name__ == '__main__':han=Han()input=torch.ones(64,3,32,32)output=han(input)print(output.shape)#torch.Size([64, 10])10表示十个类别输出概率

结果如下:
在这里插入图片描述

2.使用argmax计算整体正确率
#每训练完一轮将在测试集上跑一遍,评估其训练效果total_test_loss=0total_acc=0with torch.no_grad():for data in test_dataloader:imgs,target=dataoutput=han(imgs)loss=loss_fn(output,target)total_test_loss=total_test_loss+loss.item()acc=(output.argmax(1)==target).sum()#(1)横着看total_acc+=accprint('所有测试集上的损失:{}'.format(total_test_loss))print('整体测试集上的正确率:{}'.format(total_acc/test_data_size))writer.add_scalar('test_loss',total_test_loss,total_test_step)writer.add_scalar('test_acc', total_acc/test_data_size, total_test_step)total_test_step+=1

整体测试集上的正确率:0.27480000257492065

3.当训练或测试时存在dropout层或batch normal层,则需要在训练训练和测试前加入:
#训练前
han.train()
#测试前
han.eval()

二、使用GPU

网络模型、数据(输入、标注)、损失函数调用cuda()

1.方式1
#模型
if torch.cuda.is_available():han=han.cuda()
#损失函数
loss_fn=nn.CrossEntropyLoss()
loss_fn=loss_fn.cuda()
imgs,target=data
imgs=imgs.cuda()
target=target.cuda()
2.方式2
#定义训练设备
device=torch.device('cuda')
han=han.to(device)
imgs = imgs.to(device)
target = target.to(device)

相关文章:

10-pytorch-完整模型训练

b站小土堆pytorch教程学习笔记 一、从零开始构建自己的神经网络 1.模型构建 #准备数据集 import torch import torchvision from torch.utils.tensorboard import SummaryWriterfrom model import * from torch.utils.data import DataLoadertrain_datatorchvision.datasets.…...

高级RAG:重新排名,从原理到实现的两种主流方法

原文地址:https://pub.towardsai.net/advanced-rag-04-re-ranking-85f6ae8170b1 2024 年 2 月 14 日 重新排序在检索增强生成(RAG)过程中起着至关重要的作用。在简单的 RAG 方法中,可以检索大量上下文,但并非所有上下…...

使用logicflow流程图实例

一.背景 需要使用流程引擎开发项目,没有使用flowable、activiti这类的国外流程引擎,想使用国内的引擎二次开发,缺少单例模式的流程画图程序,都是vue、react、angluer的不适合,从网上找了antx6、logicflow、bpmn.js。感…...

Stable Diffusion 绘画入门教程(webui)-ControlNet(IP2P)

上篇文章介绍了深度Depth,这篇文章介绍下IP2P(InstructP2P), 通俗理解就是图生图,给原有图加一些效果,比如下图,左边为原图,右边为增加了效果的图: 文章目录 一、选大模型二、写提示词三、基础参…...

五力分析(Porter‘s Five Forces)

五力分析是一种用于评估竞争力的框架,由哈佛商学院教授迈克尔波特(Michael Porter)提出。它通过分析一个行业的五个关键力量(竞争对手、供应商、顾客、替代品和新进入者)来评估一个企业或行业的竞争环境。这个框架可以…...

十一、Qt数据库操作

一、Sql介绍 Qt Sql模块包含多个类,实现数据库的连接,Sql语句的执行,数据获取与界面显示,数据与界面直接使用Model/View结构。1、使用Sql模块 (1)工程加入 QT sql(2)添加头文件 …...

【Spring】IoC容器 控制反转 与 DI依赖注入 XML实现版本 第二期

文章目录 基于 XML 配置方式组件管理前置 准备项目一、 组件(Bean)信息声明配置(IoC):1.1 基于无参构造1.2 基于静态 工厂方法实例化1.3 基于非静态 工厂方法实例化 二、 组件(Bean)依赖注入配置…...

神经网络系列---感知机(Neuron)

文章目录 感知机(Neuron)感知机(Neuron)的决策函数可以表示为:感知机(Neuron)的学习算法主要包括以下步骤:感知机可以实现逻辑运算中的AND、OR、NOT和异或(XOR)运算。 感知机(Neuron) 感知机(Neuron)是一种简单而有效的二分类算法,用于将输入…...

k8s(2)

目录 一.二进制部署k8s 常见的K8S安装部署方式: k8s部署 二进制与高可用的区别 二.部署k8s 初始化操作: 每台node安装docker: 在 master01 节点上操作; 准备cfssl证书生成工具:: 执行脚本文件: 拉入etcd压缩包…...

利用nginx内部访问特性实现静态资源授权访问

在nginx中,将静态资源设为internal;然后将前端的静态资源地址改为指向后端,在后端的响应头部中写上静态资源地址。 近期客户对我们项目做安全性测评,暴露出一些安全性问题,其中一个是有些静态页面(*.html&…...

fly-barrage 前端弹幕库(1):项目介绍

fly-barrage 是我写的一个前端弹幕库,由于经常在 Bilibili 上看视频,所以对网页的弹幕功能一直蛮感兴趣的,所以做了这个库,可以帮助前端快速的实现弹幕功能。 项目官网地址:https://fly-barrage.netlify.app/&#xff…...

jetcache如果一个主体涉及多个缓存时编辑或者删除时如何同时失效多个缓存

在实际使用过程中,可能会遇到这种情形:一个主体会有多个缓存,比如用户基础信息缓存、用户详情缓存,那么当删除用户信息后就需要同时失效多个缓存中该主体数据,那么jetcache支持这种应用场景么,答案是支持&a…...

uni-app 实现拍照后给照片加水印功能

遇到个需求需要实现&#xff0c;研究了一下后写了个demo 本质上就是把拍完照后的照片放到canvas里&#xff0c;然后加上水印样式然后再重新生成一张图片 代码如下&#xff0c;看注释即可~使用的话记得还是得优化下代码 <template><view class"content"&g…...

【ArcGIS】利用DEM进行水文分析:流向/流量等

利用DEM进行水文分析 ArcGIS实例参考 水文分析通过建立地表水文模型&#xff0c;研究与地表水流相关的各种自然现象&#xff0c;在城市和区域规划、农业及森林、交通道路等许多领域具有广泛的应用。 ArcGIS实例 某流域30m分辨率DEM如下&#xff1a; &#xff08;1&#xff09…...

论文阅读笔记——PathAFL:Path-Coverage Assisted Fuzzing

文章目录 前言PathAFL&#xff1a;Path-Coverage Assisted Fuzzing1、解决的问题和目标2、技术路线2.1、如何识别 h − p a t h h-path h−path&#xff1f;2.2、如何减少 h − p a t h h-path h−path的数量&#xff1f;2.3、哪些h-path将被添加到种子队列&#xff1f;2.4、种…...

C语言中各种运算符用法

C语言中有许多不同的运算符&#xff0c;用于执行各种不同的操作。 以下是C语言中常见的运算符及其用法&#xff1a; 算术运算符&#xff1a; 加法运算符&#xff08;&#xff09;&#xff1a;用于将两个值相加。减法运算符&#xff08;-&#xff09;&#xff1a;用于将一个值减…...

pythonJax小记(五):python: 使用Jax深度图像(正交投影和透视投影之间的转换)(持续更新,评论区可以补充)

python: 使用Jax深度图像&#xff08;正交投影和透视投影之间的转换&#xff09; 前言问题描述1. 透视投影2. 正交投影 直接上代码解释1. compute_projection_parameters 函数a. 参数解释b. 函数计算 2. ortho_to_persp 函数a. 计算投影参数&#xff1a;b. 生成像素坐标网格&am…...

web安全学习笔记【16】——信息打点(6)

信息打点-语言框架&开发组件&FastJson&Shiro&Log4j&SpringBoot等[1] #知识点&#xff1a; 1、业务资产-应用类型分类 2、Web单域名获取-接口查询 3、Web子域名获取-解析枚举 4、Web架构资产-平台指纹识别 ------------------------------------ 1、开源-C…...

145.二叉树的后序遍历

// 定义一个名为Solution的类&#xff0c;用于解决二叉树的后序遍历问题 class Solution { // 定义一个公共方法&#xff0c;输入是一个二叉树的根节点&#xff0c;返回一个包含后序遍历结果的整数列表 public List<Integer> postorderTraversal(TreeNode root) { /…...

ssh远程连接免密码访问

我们在远程登录的时候&#xff0c;经常需要输入密码&#xff0c;密码往往比较复杂&#xff0c;输入比较耗费时间&#xff0c;这种情况下可以使用ssh免密码登录。 一般的教程是需要生成ssh密钥后&#xff0c;然后把密钥复制到server端完成配置&#xff0c;这里提供一个简单的方…...

利用快马平台与vscode codex快速构建react待办事项应用原型

最近在尝试用AI工具快速验证产品原型&#xff0c;发现InsCode(快马)平台配合VSCode Codex能实现惊人的开发效率。以React待办事项应用为例&#xff0c;从零到可交互原型只用了不到10分钟&#xff0c;分享下具体实现思路和操作过程。 需求拆解与AI描述 首先将待办事项应用的7个核…...

【大英赛】2009-2026年大英赛ABCD类历年真题、样卷、听力音频及答案PDF电子版

2026年大英赛将于4月12日9:00—11:00举行&#xff0c;开始倒计时啦&#xff01;小编整理了最新的2009-2026年大学生英语竞赛&#xff08;大英赛NECCS&#xff09;ABCD类历年真题、样卷、听力音频及答案解析&#xff0c;PDF电子版&#xff0c;可下载打印&#xff01; 资料下载&a…...

3个步骤掌握Markmap:将Markdown转换为交互式思维导图完全指南

3个步骤掌握Markmap&#xff1a;将Markdown转换为交互式思维导图完全指南 【免费下载链接】markmap Build mindmaps with plain text 项目地址: https://gitcode.com/gh_mirrors/ma/markmap Markmap作为一款强大的开源工具&#xff0c;能够将普通的Markdown文本转换为直…...

实战演练:在快马平台用codex生成一个完整的react用户管理组件

今天想和大家分享一个实战案例&#xff1a;如何在InsCode(快马)平台用Codex快速生成一个React用户管理组件。整个过程比我预想的顺畅很多&#xff0c;特别适合需要快速原型开发的场景。 项目需求拆解 用户管理是后台系统的标配功能&#xff0c;这次要实现三个核心模块&#xff…...

UDS诊断自动化测试入门:用Python模拟Tester端,批量刷写DID与安全访问

UDS诊断自动化测试实战&#xff1a;Python构建高覆盖率ECU测试框架 在汽车电子控制单元&#xff08;ECU&#xff09;开发中&#xff0c;诊断功能测试往往是最耗时的手工操作环节之一。想象一下&#xff0c;当需要验证数百个数据标识符&#xff08;DID&#xff09;的读写功能时&…...

别再花冤枉钱!和腰突颈椎病斗了 3 年,我终于踩中了康复的捷径

有没有和我一样的打工人&#xff0c;每天久坐 8 小时起步&#xff0c;下班就低头刷手机&#xff0c;年纪轻轻颈椎腰椎先 “垮了”&#xff1f; 从最开始的脖子发酸、腰部发僵&#xff0c;到后来疼到睡不着觉、手麻到握不住鼠标&#xff0c;甚至走路都直不起腰&#xff0c;这 3…...

告别格式枷锁:ncmdumpGUI让音乐自由播放变得触手可及

告别格式枷锁&#xff1a;ncmdumpGUI让音乐自由播放变得触手可及 【免费下载链接】ncmdumpGUI C#版本网易云音乐ncm文件格式转换&#xff0c;Windows图形界面版本 项目地址: https://gitcode.com/gh_mirrors/nc/ncmdumpGUI 开篇痛点直击&#xff1a;那些被NCM格式困住的…...

3个技巧让Poppins字体为你的设计项目增添国际范儿

3个技巧让Poppins字体为你的设计项目增添国际范儿 【免费下载链接】Poppins Poppins, a Devanagari Latin family for Google Fonts. 项目地址: https://gitcode.com/gh_mirrors/po/Poppins 还在为多语言项目找不到统一风格的字体而烦恼吗&#xff1f;Poppins这款现代几…...

Guohua Diffusion 创意编程:用Processing可视化交互控制图像生成

Guohua Diffusion 创意编程&#xff1a;用Processing可视化交互控制图像生成 你有没有想过&#xff0c;自己随手画的一条线、选择的一个颜色&#xff0c;能立刻变成一幅由AI生成的完整画作&#xff1f;这听起来像是科幻电影里的场景&#xff0c;但现在&#xff0c;通过一点创意…...

Phi-4-mini-reasoning企业应用探索:智能客服知识推理模块集成方案

Phi-4-mini-reasoning企业应用探索&#xff1a;智能客服知识推理模块集成方案 1. 轻量级推理模型的价值 在当今企业智能化转型浪潮中&#xff0c;轻量级推理模型正成为技术落地的关键。Phi-4-mini-reasoning作为一款专注于高质量推理的开源模型&#xff0c;凭借其128K令牌的超…...