10-pytorch-完整模型训练
b站小土堆pytorch教程学习笔记
一、从零开始构建自己的神经网络
1.模型构建
#准备数据集
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriterfrom model import *
from torch.utils.data import DataLoadertrain_data=torchvision.datasets.CIFAR10('dataset',train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data=torchvision.datasets.CIFAR10('dataset',train=False,transform=torchvision.transforms.ToTensor(),download=True)
#查看训练数据集和测试集大小
train_data_size=len(train_data)
test_data_size=len(test_data)
print('训练数据集长度为:{}'.format(train_data_size))#训练数据集长度为:50000
print('测试数据集长度为:{}'.format(test_data_size))#测试数据集长度为:10000#利用datalo加载数据集
train_dataloader=DataLoader(train_data,batch_size=64)
test_dataloader=DataLoader(test_data,batch_size=64)#搭建神经网络,在model文件中搭建网络,在此文件中引用
han=Han()#损失函数
loss_fn=nn.CrossEntropyLoss()#优化器
# learning_rate=0.01
learning_rate=1e-2
optimizer=torch.optim.SGD(han.parameters(),lr=learning_rate)#设置训练网络的相关参数
total_train_step = 0#记录训练的次数
total_test_step = 0#记录测试的次数
epoch=10#训练轮数#添加tensorboard
writer=SummaryWriter('logs/train')for i in range(10):print('-------第{}轮训练开始-------'.format(i+1))for data in train_dataloader:imgs,target=dataoutput=han(imgs)loss=loss_fn(output,target)#优化器优化模型optimizer.zero_grad()#梯度清零loss.backward()#反向传播计算梯度optimizer.step()#参数优化total_train_step=total_train_step+1if total_train_step % 100==0:#逢100打印print('训练次数:{},loss:{}'.format(total_train_step,loss.item()))#loss.item()取出tensor类型的数字writer.add_scalar('train_loss',loss.item(),total_train_step)#每训练完一轮将在测试集上跑一遍,评估其训练效果total_test_loss=0with torch.no_grad():for data in test_dataloader:imgs,target=dataoutput=han(imgs)loss=loss_fn(output,target)total_test_loss=total_test_loss+loss.item()print('所有测试集上的损失:{}'.format(total_test_loss))writer.add_scalar('test_loss',total_test_loss,total_test_step)total_test_step+=1#保存每一轮模型torch.save(han,'han_{}.pth'.format(i))print('模型已保存')
writer.close()
import torch
from torch import nnclass Han(nn.Module):def __init__(self):super(Han, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return xif __name__ == '__main__':han=Han()input=torch.ones(64,3,32,32)output=han(input)print(output.shape)#torch.Size([64, 10])10表示十个类别输出概率
结果如下:

2.使用argmax计算整体正确率
#每训练完一轮将在测试集上跑一遍,评估其训练效果total_test_loss=0total_acc=0with torch.no_grad():for data in test_dataloader:imgs,target=dataoutput=han(imgs)loss=loss_fn(output,target)total_test_loss=total_test_loss+loss.item()acc=(output.argmax(1)==target).sum()#(1)横着看total_acc+=accprint('所有测试集上的损失:{}'.format(total_test_loss))print('整体测试集上的正确率:{}'.format(total_acc/test_data_size))writer.add_scalar('test_loss',total_test_loss,total_test_step)writer.add_scalar('test_acc', total_acc/test_data_size, total_test_step)total_test_step+=1
整体测试集上的正确率:0.27480000257492065
3.当训练或测试时存在dropout层或batch normal层,则需要在训练训练和测试前加入:
#训练前
han.train()
#测试前
han.eval()
二、使用GPU
网络模型、数据(输入、标注)、损失函数调用cuda()
1.方式1
#模型
if torch.cuda.is_available():han=han.cuda()
#损失函数
loss_fn=nn.CrossEntropyLoss()
loss_fn=loss_fn.cuda()
imgs,target=data
imgs=imgs.cuda()
target=target.cuda()
2.方式2
#定义训练设备
device=torch.device('cuda')
han=han.to(device)
imgs = imgs.to(device)
target = target.to(device)
相关文章:
10-pytorch-完整模型训练
b站小土堆pytorch教程学习笔记 一、从零开始构建自己的神经网络 1.模型构建 #准备数据集 import torch import torchvision from torch.utils.tensorboard import SummaryWriterfrom model import * from torch.utils.data import DataLoadertrain_datatorchvision.datasets.…...
高级RAG:重新排名,从原理到实现的两种主流方法
原文地址:https://pub.towardsai.net/advanced-rag-04-re-ranking-85f6ae8170b1 2024 年 2 月 14 日 重新排序在检索增强生成(RAG)过程中起着至关重要的作用。在简单的 RAG 方法中,可以检索大量上下文,但并非所有上下…...
使用logicflow流程图实例
一.背景 需要使用流程引擎开发项目,没有使用flowable、activiti这类的国外流程引擎,想使用国内的引擎二次开发,缺少单例模式的流程画图程序,都是vue、react、angluer的不适合,从网上找了antx6、logicflow、bpmn.js。感…...
Stable Diffusion 绘画入门教程(webui)-ControlNet(IP2P)
上篇文章介绍了深度Depth,这篇文章介绍下IP2P(InstructP2P), 通俗理解就是图生图,给原有图加一些效果,比如下图,左边为原图,右边为增加了效果的图: 文章目录 一、选大模型二、写提示词三、基础参…...
五力分析(Porter‘s Five Forces)
五力分析是一种用于评估竞争力的框架,由哈佛商学院教授迈克尔波特(Michael Porter)提出。它通过分析一个行业的五个关键力量(竞争对手、供应商、顾客、替代品和新进入者)来评估一个企业或行业的竞争环境。这个框架可以…...
十一、Qt数据库操作
一、Sql介绍 Qt Sql模块包含多个类,实现数据库的连接,Sql语句的执行,数据获取与界面显示,数据与界面直接使用Model/View结构。1、使用Sql模块 (1)工程加入 QT sql(2)添加头文件 …...
【Spring】IoC容器 控制反转 与 DI依赖注入 XML实现版本 第二期
文章目录 基于 XML 配置方式组件管理前置 准备项目一、 组件(Bean)信息声明配置(IoC):1.1 基于无参构造1.2 基于静态 工厂方法实例化1.3 基于非静态 工厂方法实例化 二、 组件(Bean)依赖注入配置…...
神经网络系列---感知机(Neuron)
文章目录 感知机(Neuron)感知机(Neuron)的决策函数可以表示为:感知机(Neuron)的学习算法主要包括以下步骤:感知机可以实现逻辑运算中的AND、OR、NOT和异或(XOR)运算。 感知机(Neuron) 感知机(Neuron)是一种简单而有效的二分类算法,用于将输入…...
k8s(2)
目录 一.二进制部署k8s 常见的K8S安装部署方式: k8s部署 二进制与高可用的区别 二.部署k8s 初始化操作: 每台node安装docker: 在 master01 节点上操作; 准备cfssl证书生成工具:: 执行脚本文件: 拉入etcd压缩包…...
利用nginx内部访问特性实现静态资源授权访问
在nginx中,将静态资源设为internal;然后将前端的静态资源地址改为指向后端,在后端的响应头部中写上静态资源地址。 近期客户对我们项目做安全性测评,暴露出一些安全性问题,其中一个是有些静态页面(*.html&…...
fly-barrage 前端弹幕库(1):项目介绍
fly-barrage 是我写的一个前端弹幕库,由于经常在 Bilibili 上看视频,所以对网页的弹幕功能一直蛮感兴趣的,所以做了这个库,可以帮助前端快速的实现弹幕功能。 项目官网地址:https://fly-barrage.netlify.app/ÿ…...
jetcache如果一个主体涉及多个缓存时编辑或者删除时如何同时失效多个缓存
在实际使用过程中,可能会遇到这种情形:一个主体会有多个缓存,比如用户基础信息缓存、用户详情缓存,那么当删除用户信息后就需要同时失效多个缓存中该主体数据,那么jetcache支持这种应用场景么,答案是支持&a…...
uni-app 实现拍照后给照片加水印功能
遇到个需求需要实现,研究了一下后写了个demo 本质上就是把拍完照后的照片放到canvas里,然后加上水印样式然后再重新生成一张图片 代码如下,看注释即可~使用的话记得还是得优化下代码 <template><view class"content"&g…...
【ArcGIS】利用DEM进行水文分析:流向/流量等
利用DEM进行水文分析 ArcGIS实例参考 水文分析通过建立地表水文模型,研究与地表水流相关的各种自然现象,在城市和区域规划、农业及森林、交通道路等许多领域具有广泛的应用。 ArcGIS实例 某流域30m分辨率DEM如下: (1)…...
论文阅读笔记——PathAFL:Path-Coverage Assisted Fuzzing
文章目录 前言PathAFL:Path-Coverage Assisted Fuzzing1、解决的问题和目标2、技术路线2.1、如何识别 h − p a t h h-path h−path?2.2、如何减少 h − p a t h h-path h−path的数量?2.3、哪些h-path将被添加到种子队列?2.4、种…...
C语言中各种运算符用法
C语言中有许多不同的运算符,用于执行各种不同的操作。 以下是C语言中常见的运算符及其用法: 算术运算符: 加法运算符():用于将两个值相加。减法运算符(-):用于将一个值减…...
pythonJax小记(五):python: 使用Jax深度图像(正交投影和透视投影之间的转换)(持续更新,评论区可以补充)
python: 使用Jax深度图像(正交投影和透视投影之间的转换) 前言问题描述1. 透视投影2. 正交投影 直接上代码解释1. compute_projection_parameters 函数a. 参数解释b. 函数计算 2. ortho_to_persp 函数a. 计算投影参数:b. 生成像素坐标网格&am…...
web安全学习笔记【16】——信息打点(6)
信息打点-语言框架&开发组件&FastJson&Shiro&Log4j&SpringBoot等[1] #知识点: 1、业务资产-应用类型分类 2、Web单域名获取-接口查询 3、Web子域名获取-解析枚举 4、Web架构资产-平台指纹识别 ------------------------------------ 1、开源-C…...
145.二叉树的后序遍历
// 定义一个名为Solution的类,用于解决二叉树的后序遍历问题 class Solution { // 定义一个公共方法,输入是一个二叉树的根节点,返回一个包含后序遍历结果的整数列表 public List<Integer> postorderTraversal(TreeNode root) { /…...
ssh远程连接免密码访问
我们在远程登录的时候,经常需要输入密码,密码往往比较复杂,输入比较耗费时间,这种情况下可以使用ssh免密码登录。 一般的教程是需要生成ssh密钥后,然后把密钥复制到server端完成配置,这里提供一个简单的方…...
如何通过Jellyfin Android TV客户端打造家庭影院级媒体体验?
如何通过Jellyfin Android TV客户端打造家庭影院级媒体体验? 【免费下载链接】jellyfin-androidtv Android TV Client for Jellyfin 项目地址: https://gitcode.com/gh_mirrors/je/jellyfin-androidtv 想要在智能电视上享受专业的媒体管理体验吗?…...
OpenAgentsControl:构建多智能体协同系统的开源框架解析
1. 项目概述:一个面向智能体控制的开放框架最近在折腾AI智能体(Agent)相关的项目,发现一个挺有意思的开源仓库:darrenhinde/OpenAgentsControl。这个项目名字直译过来就是“开放智能体控制”,听起来就很有搞…...
I2C地址冲突全解析:从原理到实战的嵌入式系统设计指南
1. I2C地址:嵌入式系统设计的“门牌号”与“交通规则”如果你玩过单片机或者树莓派,肯定对I2C不陌生。两根线,SDA和SCL,就能挂上一堆传感器、显示屏、扩展芯片,听起来简直是嵌入式开发的“万金油”。但真正上手后&…...
Godot引擎实验项目解析:从角色控制到着色器优化的实战指南
1. 项目概述与核心价值如果你是一名游戏开发者,尤其是对独立游戏开发充满热情,那么“Godot”这个名字对你来说一定不陌生。它是一个功能强大、开源免费的游戏引擎,以其轻量、高效和友好的编辑器而闻名。然而,引擎本身只是一个工具…...
开源提示词管理工具:本地化部署与AI工作流效率提升实践
1. 项目概述:一个为AI工作流设计的提示词管理利器如果你和我一样,每天都在和ChatGPT、Claude、Midjourney这些AI模型打交道,那你一定有过这样的烦恼:昨天精心调试好的、能稳定输出高质量代码的提示词,今天想用的时候&a…...
大语言模型分步推理与自我验证框架:提升AI生成准确性的工程实践
1. 项目概述:当AI学会“自我验证”最近在开源社区里,一个名为“Lets-Verify-Step-by-Step”的项目引起了我的注意。这个项目直指当前大语言模型(LLM)应用中的一个核心痛点:如何让模型在生成复杂答案时,能像…...
提示工程实战:从核心模式到高级技巧的AI交互优化指南
1. 项目概述:从代码仓库到提示工程实战指南最近在GitHub上看到一个名为“SKY-lv/prompt-engineer”的仓库,点进去一看,发现这不仅仅是一个简单的代码集合,更像是一位资深从业者(SKY-lv)精心整理的提示工程实…...
动态目标跨镜无缝接力追踪技术白皮书
一、前言在全域视觉监控、智能安防、智慧园区、交通管控、工业巡检等核心场景中,动态目标(人员、车辆、设备等)的跨摄像头连续追踪是实现智能化管理的核心需求。当前行业常规追踪方案普遍存在轨迹断点、坐标漂移、身份错乱等痛点,…...
AI科技热点日报 | 2026年5月16日
文章目录AI科技热点日报 | 2026年5月16日一、大模型与基础技术《人工智能终端智能化分级》系列国家标准发布"九章四号"量子计算原型机刷新世界纪录二、AI政策与监管人工智能科技伦理审查与服务先导计划启动工信部部署高质量行业数据集建设三、Agent与应用"AI教育…...
2026 私域救命玩法!90% 的老板赚不到钱,根本不是产品不行
我在杭州做电商、做私域、做投资这么多年,见过各行各业的起起伏伏。这些年接触过的实体老板,没有一百也有八十。手里握着工厂的、拿着自主知识产权的、有正规生产资质的,比比皆是。但 90% 的人都在亏钱。他们天天抱怨流量太贵、同行乱价、客户…...
