10-pytorch-完整模型训练
b站小土堆pytorch教程学习笔记
一、从零开始构建自己的神经网络
1.模型构建
#准备数据集
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriterfrom model import *
from torch.utils.data import DataLoadertrain_data=torchvision.datasets.CIFAR10('dataset',train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data=torchvision.datasets.CIFAR10('dataset',train=False,transform=torchvision.transforms.ToTensor(),download=True)
#查看训练数据集和测试集大小
train_data_size=len(train_data)
test_data_size=len(test_data)
print('训练数据集长度为:{}'.format(train_data_size))#训练数据集长度为:50000
print('测试数据集长度为:{}'.format(test_data_size))#测试数据集长度为:10000#利用datalo加载数据集
train_dataloader=DataLoader(train_data,batch_size=64)
test_dataloader=DataLoader(test_data,batch_size=64)#搭建神经网络,在model文件中搭建网络,在此文件中引用
han=Han()#损失函数
loss_fn=nn.CrossEntropyLoss()#优化器
# learning_rate=0.01
learning_rate=1e-2
optimizer=torch.optim.SGD(han.parameters(),lr=learning_rate)#设置训练网络的相关参数
total_train_step = 0#记录训练的次数
total_test_step = 0#记录测试的次数
epoch=10#训练轮数#添加tensorboard
writer=SummaryWriter('logs/train')for i in range(10):print('-------第{}轮训练开始-------'.format(i+1))for data in train_dataloader:imgs,target=dataoutput=han(imgs)loss=loss_fn(output,target)#优化器优化模型optimizer.zero_grad()#梯度清零loss.backward()#反向传播计算梯度optimizer.step()#参数优化total_train_step=total_train_step+1if total_train_step % 100==0:#逢100打印print('训练次数:{},loss:{}'.format(total_train_step,loss.item()))#loss.item()取出tensor类型的数字writer.add_scalar('train_loss',loss.item(),total_train_step)#每训练完一轮将在测试集上跑一遍,评估其训练效果total_test_loss=0with torch.no_grad():for data in test_dataloader:imgs,target=dataoutput=han(imgs)loss=loss_fn(output,target)total_test_loss=total_test_loss+loss.item()print('所有测试集上的损失:{}'.format(total_test_loss))writer.add_scalar('test_loss',total_test_loss,total_test_step)total_test_step+=1#保存每一轮模型torch.save(han,'han_{}.pth'.format(i))print('模型已保存')
writer.close()
import torch
from torch import nnclass Han(nn.Module):def __init__(self):super(Han, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return xif __name__ == '__main__':han=Han()input=torch.ones(64,3,32,32)output=han(input)print(output.shape)#torch.Size([64, 10])10表示十个类别输出概率
结果如下:

2.使用argmax计算整体正确率
#每训练完一轮将在测试集上跑一遍,评估其训练效果total_test_loss=0total_acc=0with torch.no_grad():for data in test_dataloader:imgs,target=dataoutput=han(imgs)loss=loss_fn(output,target)total_test_loss=total_test_loss+loss.item()acc=(output.argmax(1)==target).sum()#(1)横着看total_acc+=accprint('所有测试集上的损失:{}'.format(total_test_loss))print('整体测试集上的正确率:{}'.format(total_acc/test_data_size))writer.add_scalar('test_loss',total_test_loss,total_test_step)writer.add_scalar('test_acc', total_acc/test_data_size, total_test_step)total_test_step+=1
整体测试集上的正确率:0.27480000257492065
3.当训练或测试时存在dropout层或batch normal层,则需要在训练训练和测试前加入:
#训练前
han.train()
#测试前
han.eval()
二、使用GPU
网络模型、数据(输入、标注)、损失函数调用cuda()
1.方式1
#模型
if torch.cuda.is_available():han=han.cuda()
#损失函数
loss_fn=nn.CrossEntropyLoss()
loss_fn=loss_fn.cuda()
imgs,target=data
imgs=imgs.cuda()
target=target.cuda()
2.方式2
#定义训练设备
device=torch.device('cuda')
han=han.to(device)
imgs = imgs.to(device)
target = target.to(device)
相关文章:
10-pytorch-完整模型训练
b站小土堆pytorch教程学习笔记 一、从零开始构建自己的神经网络 1.模型构建 #准备数据集 import torch import torchvision from torch.utils.tensorboard import SummaryWriterfrom model import * from torch.utils.data import DataLoadertrain_datatorchvision.datasets.…...
高级RAG:重新排名,从原理到实现的两种主流方法
原文地址:https://pub.towardsai.net/advanced-rag-04-re-ranking-85f6ae8170b1 2024 年 2 月 14 日 重新排序在检索增强生成(RAG)过程中起着至关重要的作用。在简单的 RAG 方法中,可以检索大量上下文,但并非所有上下…...
使用logicflow流程图实例
一.背景 需要使用流程引擎开发项目,没有使用flowable、activiti这类的国外流程引擎,想使用国内的引擎二次开发,缺少单例模式的流程画图程序,都是vue、react、angluer的不适合,从网上找了antx6、logicflow、bpmn.js。感…...
Stable Diffusion 绘画入门教程(webui)-ControlNet(IP2P)
上篇文章介绍了深度Depth,这篇文章介绍下IP2P(InstructP2P), 通俗理解就是图生图,给原有图加一些效果,比如下图,左边为原图,右边为增加了效果的图: 文章目录 一、选大模型二、写提示词三、基础参…...
五力分析(Porter‘s Five Forces)
五力分析是一种用于评估竞争力的框架,由哈佛商学院教授迈克尔波特(Michael Porter)提出。它通过分析一个行业的五个关键力量(竞争对手、供应商、顾客、替代品和新进入者)来评估一个企业或行业的竞争环境。这个框架可以…...
十一、Qt数据库操作
一、Sql介绍 Qt Sql模块包含多个类,实现数据库的连接,Sql语句的执行,数据获取与界面显示,数据与界面直接使用Model/View结构。1、使用Sql模块 (1)工程加入 QT sql(2)添加头文件 …...
【Spring】IoC容器 控制反转 与 DI依赖注入 XML实现版本 第二期
文章目录 基于 XML 配置方式组件管理前置 准备项目一、 组件(Bean)信息声明配置(IoC):1.1 基于无参构造1.2 基于静态 工厂方法实例化1.3 基于非静态 工厂方法实例化 二、 组件(Bean)依赖注入配置…...
神经网络系列---感知机(Neuron)
文章目录 感知机(Neuron)感知机(Neuron)的决策函数可以表示为:感知机(Neuron)的学习算法主要包括以下步骤:感知机可以实现逻辑运算中的AND、OR、NOT和异或(XOR)运算。 感知机(Neuron) 感知机(Neuron)是一种简单而有效的二分类算法,用于将输入…...
k8s(2)
目录 一.二进制部署k8s 常见的K8S安装部署方式: k8s部署 二进制与高可用的区别 二.部署k8s 初始化操作: 每台node安装docker: 在 master01 节点上操作; 准备cfssl证书生成工具:: 执行脚本文件: 拉入etcd压缩包…...
利用nginx内部访问特性实现静态资源授权访问
在nginx中,将静态资源设为internal;然后将前端的静态资源地址改为指向后端,在后端的响应头部中写上静态资源地址。 近期客户对我们项目做安全性测评,暴露出一些安全性问题,其中一个是有些静态页面(*.html&…...
fly-barrage 前端弹幕库(1):项目介绍
fly-barrage 是我写的一个前端弹幕库,由于经常在 Bilibili 上看视频,所以对网页的弹幕功能一直蛮感兴趣的,所以做了这个库,可以帮助前端快速的实现弹幕功能。 项目官网地址:https://fly-barrage.netlify.app/ÿ…...
jetcache如果一个主体涉及多个缓存时编辑或者删除时如何同时失效多个缓存
在实际使用过程中,可能会遇到这种情形:一个主体会有多个缓存,比如用户基础信息缓存、用户详情缓存,那么当删除用户信息后就需要同时失效多个缓存中该主体数据,那么jetcache支持这种应用场景么,答案是支持&a…...
uni-app 实现拍照后给照片加水印功能
遇到个需求需要实现,研究了一下后写了个demo 本质上就是把拍完照后的照片放到canvas里,然后加上水印样式然后再重新生成一张图片 代码如下,看注释即可~使用的话记得还是得优化下代码 <template><view class"content"&g…...
【ArcGIS】利用DEM进行水文分析:流向/流量等
利用DEM进行水文分析 ArcGIS实例参考 水文分析通过建立地表水文模型,研究与地表水流相关的各种自然现象,在城市和区域规划、农业及森林、交通道路等许多领域具有广泛的应用。 ArcGIS实例 某流域30m分辨率DEM如下: (1)…...
论文阅读笔记——PathAFL:Path-Coverage Assisted Fuzzing
文章目录 前言PathAFL:Path-Coverage Assisted Fuzzing1、解决的问题和目标2、技术路线2.1、如何识别 h − p a t h h-path h−path?2.2、如何减少 h − p a t h h-path h−path的数量?2.3、哪些h-path将被添加到种子队列?2.4、种…...
C语言中各种运算符用法
C语言中有许多不同的运算符,用于执行各种不同的操作。 以下是C语言中常见的运算符及其用法: 算术运算符: 加法运算符():用于将两个值相加。减法运算符(-):用于将一个值减…...
pythonJax小记(五):python: 使用Jax深度图像(正交投影和透视投影之间的转换)(持续更新,评论区可以补充)
python: 使用Jax深度图像(正交投影和透视投影之间的转换) 前言问题描述1. 透视投影2. 正交投影 直接上代码解释1. compute_projection_parameters 函数a. 参数解释b. 函数计算 2. ortho_to_persp 函数a. 计算投影参数:b. 生成像素坐标网格&am…...
web安全学习笔记【16】——信息打点(6)
信息打点-语言框架&开发组件&FastJson&Shiro&Log4j&SpringBoot等[1] #知识点: 1、业务资产-应用类型分类 2、Web单域名获取-接口查询 3、Web子域名获取-解析枚举 4、Web架构资产-平台指纹识别 ------------------------------------ 1、开源-C…...
145.二叉树的后序遍历
// 定义一个名为Solution的类,用于解决二叉树的后序遍历问题 class Solution { // 定义一个公共方法,输入是一个二叉树的根节点,返回一个包含后序遍历结果的整数列表 public List<Integer> postorderTraversal(TreeNode root) { /…...
ssh远程连接免密码访问
我们在远程登录的时候,经常需要输入密码,密码往往比较复杂,输入比较耗费时间,这种情况下可以使用ssh免密码登录。 一般的教程是需要生成ssh密钥后,然后把密钥复制到server端完成配置,这里提供一个简单的方…...
DeepSeek 赋能智慧能源:微电网优化调度的智能革新路径
目录 一、智慧能源微电网优化调度概述1.1 智慧能源微电网概念1.2 优化调度的重要性1.3 目前面临的挑战 二、DeepSeek 技术探秘2.1 DeepSeek 技术原理2.2 DeepSeek 独特优势2.3 DeepSeek 在 AI 领域地位 三、DeepSeek 在微电网优化调度中的应用剖析3.1 数据处理与分析3.2 预测与…...
Redis相关知识总结(缓存雪崩,缓存穿透,缓存击穿,Redis实现分布式锁,如何保持数据库和缓存一致)
文章目录 1.什么是Redis?2.为什么要使用redis作为mysql的缓存?3.什么是缓存雪崩、缓存穿透、缓存击穿?3.1缓存雪崩3.1.1 大量缓存同时过期3.1.2 Redis宕机 3.2 缓存击穿3.3 缓存穿透3.4 总结 4. 数据库和缓存如何保持一致性5. Redis实现分布式…...
【SpringBoot】100、SpringBoot中使用自定义注解+AOP实现参数自动解密
在实际项目中,用户注册、登录、修改密码等操作,都涉及到参数传输安全问题。所以我们需要在前端对账户、密码等敏感信息加密传输,在后端接收到数据后能自动解密。 1、引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId...
高危文件识别的常用算法:原理、应用与企业场景
高危文件识别的常用算法:原理、应用与企业场景 高危文件识别旨在检测可能导致安全威胁的文件,如包含恶意代码、敏感数据或欺诈内容的文档,在企业协同办公环境中(如Teams、Google Workspace)尤为重要。结合大模型技术&…...
OPENCV形态学基础之二腐蚀
一.腐蚀的原理 (图1) 数学表达式:dst(x,y) erode(src(x,y)) min(x,y)src(xx,yy) 腐蚀也是图像形态学的基本功能之一,腐蚀跟膨胀属于反向操作,膨胀是把图像图像变大,而腐蚀就是把图像变小。腐蚀后的图像变小变暗淡。 腐蚀…...
JS手写代码篇----使用Promise封装AJAX请求
15、使用Promise封装AJAX请求 promise就有reject和resolve了,就不必写成功和失败的回调函数了 const BASEURL ./手写ajax/test.jsonfunction promiseAjax() {return new Promise((resolve, reject) > {const xhr new XMLHttpRequest();xhr.open("get&quo…...
【JavaSE】多线程基础学习笔记
多线程基础 -线程相关概念 程序(Program) 是为完成特定任务、用某种语言编写的一组指令的集合简单的说:就是我们写的代码 进程 进程是指运行中的程序,比如我们使用QQ,就启动了一个进程,操作系统就会为该进程分配内存…...
打手机检测算法AI智能分析网关V4守护公共/工业/医疗等多场景安全应用
一、方案背景 在现代生产与生活场景中,如工厂高危作业区、医院手术室、公共场景等,人员违规打手机的行为潜藏着巨大风险。传统依靠人工巡查的监管方式,存在效率低、覆盖面不足、判断主观性强等问题,难以满足对人员打手机行为精…...
【安全篇】金刚不坏之身:整合 Spring Security + JWT 实现无状态认证与授权
摘要 本文是《Spring Boot 实战派》系列的第四篇。我们将直面所有 Web 应用都无法回避的核心问题:安全。文章将详细阐述认证(Authentication) 与授权(Authorization的核心概念,对比传统 Session-Cookie 与现代 JWT(JS…...
【java面试】微服务篇
【java面试】微服务篇 一、总体框架二、Springcloud(一)Springcloud五大组件(二)服务注册和发现1、Eureka2、Nacos (三)负载均衡1、Ribbon负载均衡流程2、Ribbon负载均衡策略3、自定义负载均衡策略4、总结 …...
