10-pytorch-完整模型训练
b站小土堆pytorch教程学习笔记
一、从零开始构建自己的神经网络
1.模型构建
#准备数据集
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriterfrom model import *
from torch.utils.data import DataLoadertrain_data=torchvision.datasets.CIFAR10('dataset',train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data=torchvision.datasets.CIFAR10('dataset',train=False,transform=torchvision.transforms.ToTensor(),download=True)
#查看训练数据集和测试集大小
train_data_size=len(train_data)
test_data_size=len(test_data)
print('训练数据集长度为:{}'.format(train_data_size))#训练数据集长度为:50000
print('测试数据集长度为:{}'.format(test_data_size))#测试数据集长度为:10000#利用datalo加载数据集
train_dataloader=DataLoader(train_data,batch_size=64)
test_dataloader=DataLoader(test_data,batch_size=64)#搭建神经网络,在model文件中搭建网络,在此文件中引用
han=Han()#损失函数
loss_fn=nn.CrossEntropyLoss()#优化器
# learning_rate=0.01
learning_rate=1e-2
optimizer=torch.optim.SGD(han.parameters(),lr=learning_rate)#设置训练网络的相关参数
total_train_step = 0#记录训练的次数
total_test_step = 0#记录测试的次数
epoch=10#训练轮数#添加tensorboard
writer=SummaryWriter('logs/train')for i in range(10):print('-------第{}轮训练开始-------'.format(i+1))for data in train_dataloader:imgs,target=dataoutput=han(imgs)loss=loss_fn(output,target)#优化器优化模型optimizer.zero_grad()#梯度清零loss.backward()#反向传播计算梯度optimizer.step()#参数优化total_train_step=total_train_step+1if total_train_step % 100==0:#逢100打印print('训练次数:{},loss:{}'.format(total_train_step,loss.item()))#loss.item()取出tensor类型的数字writer.add_scalar('train_loss',loss.item(),total_train_step)#每训练完一轮将在测试集上跑一遍,评估其训练效果total_test_loss=0with torch.no_grad():for data in test_dataloader:imgs,target=dataoutput=han(imgs)loss=loss_fn(output,target)total_test_loss=total_test_loss+loss.item()print('所有测试集上的损失:{}'.format(total_test_loss))writer.add_scalar('test_loss',total_test_loss,total_test_step)total_test_step+=1#保存每一轮模型torch.save(han,'han_{}.pth'.format(i))print('模型已保存')
writer.close()
import torch
from torch import nnclass Han(nn.Module):def __init__(self):super(Han, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return xif __name__ == '__main__':han=Han()input=torch.ones(64,3,32,32)output=han(input)print(output.shape)#torch.Size([64, 10])10表示十个类别输出概率
结果如下:

2.使用argmax计算整体正确率
#每训练完一轮将在测试集上跑一遍,评估其训练效果total_test_loss=0total_acc=0with torch.no_grad():for data in test_dataloader:imgs,target=dataoutput=han(imgs)loss=loss_fn(output,target)total_test_loss=total_test_loss+loss.item()acc=(output.argmax(1)==target).sum()#(1)横着看total_acc+=accprint('所有测试集上的损失:{}'.format(total_test_loss))print('整体测试集上的正确率:{}'.format(total_acc/test_data_size))writer.add_scalar('test_loss',total_test_loss,total_test_step)writer.add_scalar('test_acc', total_acc/test_data_size, total_test_step)total_test_step+=1
整体测试集上的正确率:0.27480000257492065
3.当训练或测试时存在dropout层或batch normal层,则需要在训练训练和测试前加入:
#训练前
han.train()
#测试前
han.eval()
二、使用GPU
网络模型、数据(输入、标注)、损失函数调用cuda()
1.方式1
#模型
if torch.cuda.is_available():han=han.cuda()
#损失函数
loss_fn=nn.CrossEntropyLoss()
loss_fn=loss_fn.cuda()
imgs,target=data
imgs=imgs.cuda()
target=target.cuda()
2.方式2
#定义训练设备
device=torch.device('cuda')
han=han.to(device)
imgs = imgs.to(device)
target = target.to(device)
相关文章:
10-pytorch-完整模型训练
b站小土堆pytorch教程学习笔记 一、从零开始构建自己的神经网络 1.模型构建 #准备数据集 import torch import torchvision from torch.utils.tensorboard import SummaryWriterfrom model import * from torch.utils.data import DataLoadertrain_datatorchvision.datasets.…...
高级RAG:重新排名,从原理到实现的两种主流方法
原文地址:https://pub.towardsai.net/advanced-rag-04-re-ranking-85f6ae8170b1 2024 年 2 月 14 日 重新排序在检索增强生成(RAG)过程中起着至关重要的作用。在简单的 RAG 方法中,可以检索大量上下文,但并非所有上下…...
使用logicflow流程图实例
一.背景 需要使用流程引擎开发项目,没有使用flowable、activiti这类的国外流程引擎,想使用国内的引擎二次开发,缺少单例模式的流程画图程序,都是vue、react、angluer的不适合,从网上找了antx6、logicflow、bpmn.js。感…...
Stable Diffusion 绘画入门教程(webui)-ControlNet(IP2P)
上篇文章介绍了深度Depth,这篇文章介绍下IP2P(InstructP2P), 通俗理解就是图生图,给原有图加一些效果,比如下图,左边为原图,右边为增加了效果的图: 文章目录 一、选大模型二、写提示词三、基础参…...
五力分析(Porter‘s Five Forces)
五力分析是一种用于评估竞争力的框架,由哈佛商学院教授迈克尔波特(Michael Porter)提出。它通过分析一个行业的五个关键力量(竞争对手、供应商、顾客、替代品和新进入者)来评估一个企业或行业的竞争环境。这个框架可以…...
十一、Qt数据库操作
一、Sql介绍 Qt Sql模块包含多个类,实现数据库的连接,Sql语句的执行,数据获取与界面显示,数据与界面直接使用Model/View结构。1、使用Sql模块 (1)工程加入 QT sql(2)添加头文件 …...
【Spring】IoC容器 控制反转 与 DI依赖注入 XML实现版本 第二期
文章目录 基于 XML 配置方式组件管理前置 准备项目一、 组件(Bean)信息声明配置(IoC):1.1 基于无参构造1.2 基于静态 工厂方法实例化1.3 基于非静态 工厂方法实例化 二、 组件(Bean)依赖注入配置…...
神经网络系列---感知机(Neuron)
文章目录 感知机(Neuron)感知机(Neuron)的决策函数可以表示为:感知机(Neuron)的学习算法主要包括以下步骤:感知机可以实现逻辑运算中的AND、OR、NOT和异或(XOR)运算。 感知机(Neuron) 感知机(Neuron)是一种简单而有效的二分类算法,用于将输入…...
k8s(2)
目录 一.二进制部署k8s 常见的K8S安装部署方式: k8s部署 二进制与高可用的区别 二.部署k8s 初始化操作: 每台node安装docker: 在 master01 节点上操作; 准备cfssl证书生成工具:: 执行脚本文件: 拉入etcd压缩包…...
利用nginx内部访问特性实现静态资源授权访问
在nginx中,将静态资源设为internal;然后将前端的静态资源地址改为指向后端,在后端的响应头部中写上静态资源地址。 近期客户对我们项目做安全性测评,暴露出一些安全性问题,其中一个是有些静态页面(*.html&…...
fly-barrage 前端弹幕库(1):项目介绍
fly-barrage 是我写的一个前端弹幕库,由于经常在 Bilibili 上看视频,所以对网页的弹幕功能一直蛮感兴趣的,所以做了这个库,可以帮助前端快速的实现弹幕功能。 项目官网地址:https://fly-barrage.netlify.app/ÿ…...
jetcache如果一个主体涉及多个缓存时编辑或者删除时如何同时失效多个缓存
在实际使用过程中,可能会遇到这种情形:一个主体会有多个缓存,比如用户基础信息缓存、用户详情缓存,那么当删除用户信息后就需要同时失效多个缓存中该主体数据,那么jetcache支持这种应用场景么,答案是支持&a…...
uni-app 实现拍照后给照片加水印功能
遇到个需求需要实现,研究了一下后写了个demo 本质上就是把拍完照后的照片放到canvas里,然后加上水印样式然后再重新生成一张图片 代码如下,看注释即可~使用的话记得还是得优化下代码 <template><view class"content"&g…...
【ArcGIS】利用DEM进行水文分析:流向/流量等
利用DEM进行水文分析 ArcGIS实例参考 水文分析通过建立地表水文模型,研究与地表水流相关的各种自然现象,在城市和区域规划、农业及森林、交通道路等许多领域具有广泛的应用。 ArcGIS实例 某流域30m分辨率DEM如下: (1)…...
论文阅读笔记——PathAFL:Path-Coverage Assisted Fuzzing
文章目录 前言PathAFL:Path-Coverage Assisted Fuzzing1、解决的问题和目标2、技术路线2.1、如何识别 h − p a t h h-path h−path?2.2、如何减少 h − p a t h h-path h−path的数量?2.3、哪些h-path将被添加到种子队列?2.4、种…...
C语言中各种运算符用法
C语言中有许多不同的运算符,用于执行各种不同的操作。 以下是C语言中常见的运算符及其用法: 算术运算符: 加法运算符():用于将两个值相加。减法运算符(-):用于将一个值减…...
pythonJax小记(五):python: 使用Jax深度图像(正交投影和透视投影之间的转换)(持续更新,评论区可以补充)
python: 使用Jax深度图像(正交投影和透视投影之间的转换) 前言问题描述1. 透视投影2. 正交投影 直接上代码解释1. compute_projection_parameters 函数a. 参数解释b. 函数计算 2. ortho_to_persp 函数a. 计算投影参数:b. 生成像素坐标网格&am…...
web安全学习笔记【16】——信息打点(6)
信息打点-语言框架&开发组件&FastJson&Shiro&Log4j&SpringBoot等[1] #知识点: 1、业务资产-应用类型分类 2、Web单域名获取-接口查询 3、Web子域名获取-解析枚举 4、Web架构资产-平台指纹识别 ------------------------------------ 1、开源-C…...
145.二叉树的后序遍历
// 定义一个名为Solution的类,用于解决二叉树的后序遍历问题 class Solution { // 定义一个公共方法,输入是一个二叉树的根节点,返回一个包含后序遍历结果的整数列表 public List<Integer> postorderTraversal(TreeNode root) { /…...
ssh远程连接免密码访问
我们在远程登录的时候,经常需要输入密码,密码往往比较复杂,输入比较耗费时间,这种情况下可以使用ssh免密码登录。 一般的教程是需要生成ssh密钥后,然后把密钥复制到server端完成配置,这里提供一个简单的方…...
[特殊字符] 智能合约中的数据是如何在区块链中保持一致的?
🧠 智能合约中的数据是如何在区块链中保持一致的? 为什么所有区块链节点都能得出相同结果?合约调用这么复杂,状态真能保持一致吗?本篇带你从底层视角理解“状态一致性”的真相。 一、智能合约的数据存储在哪里…...
7.4.分块查找
一.分块查找的算法思想: 1.实例: 以上述图片的顺序表为例, 该顺序表的数据元素从整体来看是乱序的,但如果把这些数据元素分成一块一块的小区间, 第一个区间[0,1]索引上的数据元素都是小于等于10的, 第二…...
VB.net复制Ntag213卡写入UID
本示例使用的发卡器:https://item.taobao.com/item.htm?ftt&id615391857885 一、读取旧Ntag卡的UID和数据 Private Sub Button15_Click(sender As Object, e As EventArgs) Handles Button15.Click轻松读卡技术支持:网站:Dim i, j As IntegerDim cardidhex, …...
Unity3D中Gfx.WaitForPresent优化方案
前言 在Unity中,Gfx.WaitForPresent占用CPU过高通常表示主线程在等待GPU完成渲染(即CPU被阻塞),这表明存在GPU瓶颈或垂直同步/帧率设置问题。以下是系统的优化方案: 对惹,这里有一个游戏开发交流小组&…...
Cloudflare 从 Nginx 到 Pingora:性能、效率与安全的全面升级
在互联网的快速发展中,高性能、高效率和高安全性的网络服务成为了各大互联网基础设施提供商的核心追求。Cloudflare 作为全球领先的互联网安全和基础设施公司,近期做出了一个重大技术决策:弃用长期使用的 Nginx,转而采用其内部开发…...
Spring AI 入门:Java 开发者的生成式 AI 实践之路
一、Spring AI 简介 在人工智能技术快速迭代的今天,Spring AI 作为 Spring 生态系统的新生力量,正在成为 Java 开发者拥抱生成式 AI 的最佳选择。该框架通过模块化设计实现了与主流 AI 服务(如 OpenAI、Anthropic)的无缝对接&…...
多种风格导航菜单 HTML 实现(附源码)
下面我将为您展示 6 种不同风格的导航菜单实现,每种都包含完整 HTML、CSS 和 JavaScript 代码。 1. 简约水平导航栏 <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewport&qu…...
Linux --进程控制
本文从以下五个方面来初步认识进程控制: 目录 进程创建 进程终止 进程等待 进程替换 模拟实现一个微型shell 进程创建 在Linux系统中我们可以在一个进程使用系统调用fork()来创建子进程,创建出来的进程就是子进程,原来的进程为父进程。…...
学校时钟系统,标准考场时钟系统,AI亮相2025高考,赛思时钟系统为教育公平筑起“精准防线”
2025年#高考 将在近日拉开帷幕,#AI 监考一度冲上热搜。当AI深度融入高考,#时间同步 不再是辅助功能,而是决定AI监考系统成败的“生命线”。 AI亮相2025高考,40种异常行为0.5秒精准识别 2025年高考即将拉开帷幕,江西、…...
【Nginx】使用 Nginx+Lua 实现基于 IP 的访问频率限制
使用 NginxLua 实现基于 IP 的访问频率限制 在高并发场景下,限制某个 IP 的访问频率是非常重要的,可以有效防止恶意攻击或错误配置导致的服务宕机。以下是一个详细的实现方案,使用 Nginx 和 Lua 脚本结合 Redis 来实现基于 IP 的访问频率限制…...
