【pytorch框架】对模型知识的基本了解
文章目录
- TensorBoard的使用
- 1、TensorBoard启动:
- 2、使用TensorBoard查看一张图片
- 3、transforms的使用
- pytorch框架基础知识
- 1 nn.module的使用
- 2 nn.conv2d的使用
- 3、池化(MaxPool2d)
- 4 非线性激活
- 5 线性层
- 6 Sequential的使用
- 7 损失函数与反向传播
- 8 优化器
- 9 对现有网络的使用和修改
- 10 网络模型的保存与读取
TensorBoard的使用
1、TensorBoard启动:
在Terminal终端命令中输入:
tensorboard --logdir=logs #logs为创建的文件名
2、使用TensorBoard查看一张图片
writer=SummaryWriter("../logs")
image_path=r'F:\image\1.jpg'
img_PIL=Image.open(image_path)
image_array=np.array(img_PIL)
writer.add_image('test',image_array,1,dataformats='HWC')
writer.close()
3、transforms的使用
作用:使PIL Image 或者np ——》tensor
imgae_path=r'F:\image\1.jpg'
img=Image.open(img_path)
tensor_trans=transsforms.ToTensor() #相当于创建一个工具
tensor_img=tensor_trans(img) #img转化成tensor模式
同理,ToPILIMage是为了tensor 或者 ndarray =》Image
pytorch框架基础知识
1 nn.module的使用
目的:给所有网络提供基本骨架
from torch import nn
class aiy(nn.Module):def __init__(self):super().__init__()def forward(sel,input):output=input+1return outputaiy=aiy()
# x=torch.tensor(1.0)
x=1
output=aiy(x)
print(output)
'''
2
'''
2 nn.conv2d的使用
参数代码解释如下:
示例:输入一个5x5的矩阵,和一个3x3的卷积核做卷积操作
import torch
import torch.nn.functional as F
input=torch.tensor([[1,2,0,3,1],[0,1,2,3,1],[1,2,1,0,0],[5,2,3,1,1],[2,1,0,1,1]])
input=input.reshape([1,1,5,5])kears=torch.tensor([[1,2,1],[0,1,0],[2,1,0]])
kears=kears.reshape([1,1,3,3])output=F.conv2d(input,kears,stride=1)print(output)
print(output.shape)
'''
tensor([[[[10, 12, 12], [18, 16, 16],[13, 9, 3]]]])
torch.Size([1, 1, 3, 3])
'''
若是输入的卷积核的数量有两个,则得到的output也是两个
示例:借用CIFAR10数据集,用自定义的网络模型做一次卷积操作,然后用tensorboard查看卷积之后的结果。
这里需要注意的是,经过卷积得到的大小是[64,6,30,30],而图片的通道一般都是3通道的,6通道的图片不知道怎么显示,需要使用reshpae重新改变矩阵的大小。
output=output.reshape([-1,3,30,30]) #-1自动计算剩余的值,后面[3,30,30]改成指定大小
示例代码:
import torch
import torchvision
from torch.utils.data import DataLoader
from torch import nn
from torch.nn import Conv2d
from torch.utils.tensorboard import SummaryWriter#数据准备
dataset=torchvision.datasets.CIFAR10("./data",train=False,transform=torchvision.transforms.ToTensor(),download=True)dataloader=DataLoader(dataset,batch_size=64)#自定义网络模型
class aiy(nn.Module):def __init__(self):super(aiy, self).__init__()#卷积运算self.conv1=Conv2d(in_channels=3,out_channels=6,kernel_size=3,stride=1,padding=0)def forward(self,x):x=self.conv1(x)return xaiy=aiy()
# print(aiy)
step=0
writer=SummaryWriter("../log")for data in dataloader:img,targets=dataoutput=aiy(img)# print(img.shape)#torch.Size([64, 3, 32, 32]# print(output.shape)#torch.Size([64, 6, 30, 30])#因为图片的通道是3,需要改变矩阵的大小# output=output.reshape([-1,3,30,30])writer.add_images("input",img,step)output=torch.reshape(output,(-1,3,30,30))writer.add_images("output", output, step)# print(output.shape)step=step+1print(step)
writer.close()
3、池化(MaxPool2d)
目的:降采样,大幅减少网络的参数量,同时保留图像数据的特征。
需要注意的是: 池化不改变通道数
池化参数
数组演示示例:
input=torch.tensor([[1,2,0,3,1],[0,1,2,3,1],[1,2,1,0,0],[5,2,3,1,1],[2,1,0,1,1]],dtype=float)
input=torch.reshape(input,(-1,1,5,5))output=aiy(input)
print(output.shape)
'''
ceil_mode=True:
tensor([[[[2., 3.],[5., 1.]]]], dtype=torch.float64)ceil_mode=False:
tensor([[[[2.]]]], dtype=torch.float64)
'''
示例:同样,借用CIFAR10数据集,用自定义的网络模型做一次池化操作,然后用tensorboard查看卷积之后的结果。
# -*- coding: utf-8 -*-
# Auter:我菜就爱学import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d#带入数组
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterclass aiy(nn.Module):def __init__(self):super(aiy, self).__init__()self.maxpool1=MaxPool2d(kernel_size=3,ceil_mode=False)def forward(self,input):output=self.maxpool1(input)return outputaiy=aiy()#将池化层用数据集测试
dataset=torchvision.datasets.CIFAR10('./data',train=False,transform=torchvision.transforms.ToTensor(),download=True)dataloader=DataLoader(dataset,batch_size=64)step=0
writer=SummaryWriter("../logmaxpool")for data in dataloader:img,target=datawriter.add_images("input",img,step)output=aiy(img)writer.add_images("output",output,step)step=step+1
有点像打马赛克
4 非线性激活
作用:提高泛化能力,引入非线性特征
ReLu(input,inplace=True)
=>表示原input替换input
out=ReLu(input,inplace=False)
=>表示原input被out替换
5 线性层
6 Sequential的使用
作用:可以简化自己搭建的网络模型
示例:参考CIFAR10的网络模型结构,创建一个网络。
# -*- coding: utf-8 -*-
# Auter:我菜就爱学import torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.tensorboard import SummaryWriterclass Aiy(nn.Module):def __init__(self):super(Aiy, self).__init__()self.model=Sequential(Conv2d(3,32,5,padding=2,stride=1),MaxPool2d(kernel_size=2),Conv2d(32,32,5,padding=2),MaxPool2d(kernel_size=2),Conv2d(32,64,5,padding=2),MaxPool2d(2),Flatten(),Linear(1024,64),Linear(64,10))def forward(self,input):output=self.model(input)return outputaiy=Aiy()# print(aiy)input=torch.ones((64,3,32,32))output=aiy(input)print(output.shape)
使用tensorboard中的命令可以查看网络模型结构
writer=SummaryWriter('../logmodel')writer.add_graph(aiy,input)writer.close()
7 损失函数与反向传播
作用:
- 计算处实际输出与目标之间的差距
- 更新输出提供一定的依据
通过小土堆举的示例可以很好的理解损失函数
说明:假设一张试卷满分是100分,其中选择30,填空20,解答50.第一次我们得到的结果是:选择10,填空10,解答20.第一次损失值是60.
然后通过不断的训练,让选择提高到20,填空提高20,解答提高到40,这个时候与满分差距20,损失值也就越来越小。
# -*- coding: utf-8 -*-
# Auter:我菜就爱学
import torch
from torch.nn import L1Lossinput=torch.tensor([1,2,3],dtype=torch.float32)
input=torch.reshape(input,(1,1,1,3))
target=torch.tensor([1,2,5],dtype=torch.float32)
target=torch.reshape(target,(1,1,1,3))#设置一个损失函数
loss=L1Loss(reduction='sum')
output=loss(input,target)
print(output)
'''
tensor(2.)
'''
8 优化器
优化器参数解释:
for epoch in range (20):sum_loss=0.0for data in dataloader:imgs,targets=dataoutput=aiy(imgs)result_loss=loss(output,targets)optim.zero_grad()result_loss.backward() #反向传播,更新对应的梯度optim.step() #调整更新的参数sum_loss=sum_loss+result_lossprint(sum_loss)
下面是对优化器中的交叉熵的解释:
9 对现有网络的使用和修改
- 下载现有网络,并使用数据集更新好的参数
vgg16_True=torchvision models vgg16(pretrained=True)
一般下载好的模型保存路径:==C:\user.cache\torch\hub\checkpoints
- 在已有的网络模型中新添自己需要的层
vgg16_True.classifier.add_module("7",nn.Linear(1000,10))
10 网络模型的保存与读取
方法一:直接把模型和参数保存下来
注意: 有 一个陷阱,自定义的模型在下载的时候运行会报错,得需要复制下载原模型。只能导入专门经典的模型
#保存
torch.save(vgg16_true,"vgg16_method1.pth")#下载
model=torch.load("vgg16_method1.pth")
方法二:保存模型的参数,一般使用这个。内存比较小,节省空间;以字典的形式保存。
#保存
torch.save(vgg16_true.state_dict(),"vgg16_method1.pth")#下载
vgg16_false=torchvision.models.vgg16(pretrained=False)
vgg16_false.load_state_dict(torch.long("wgg166_method2.pth"))
相关文章:

【pytorch框架】对模型知识的基本了解
文章目录TensorBoard的使用1、TensorBoard启动:2、使用TensorBoard查看一张图片3、transforms的使用pytorch框架基础知识1 nn.module的使用2 nn.conv2d的使用3、池化(MaxPool2d)4 非线性激活5 线性层6 Sequential的使用7 损失函数与反向传播8 优化器9 对现有网络的使…...
SUP桨板电动气泵方案——鼎盛合方案
SUP桨板是现时最热门的水上运动之一,它的全称是Stand Up Paddle,简称SUP。这项运动近几年在我国三亚等地区风靡一时,在网上经常看到一些运动博主或者明星网红晒出冲浪视频,刺激又惊险。SUP桨板为充气式桨板,需要通过充…...

小白系列Vite-Vue3-TypeScript:011-登录界面搭建及动态路由配置
前面几篇文章我们介绍的都是ViteVue3TypeScript项目中环境相关的配置,接下来我们开始进入系统搭建部分。本篇我们来介绍登录界面搭建及动态路由配置,大家一起撸起来......搭建登录界面登陆接口api项目登陆接口是通过mockjs前端来模拟的模拟服务接口Login…...

C语言( 缓冲区和重定向)
一.缓冲输入,无缓存输入 while((chgetchar()) ! #) putchar(ch); 这里getchar(),putchar()每次只处理一个字符(这里只是知道就好了),而我们使用while循环,当读到#字符时停止 而看到输出例子,第一行我们输入…...
编程思想、方法论和架构的类型及应用
概要编程思想是指在编写代码时所采用的基本思维方式和方法论。分类编程思想编程思想为软件开发提供了思维范式和指导思路,例如面向对象思想、函数式编程思想等,它们帮助程序员更好地抽象问题、组织代码、提高代码复用性和可维护性,包括一下几…...
【OA办公】OA流程审批大揭秘,带你看遍所有基础流程
流程审批,是所有企业的OA办公系统重要组成部分,是任何OA办公系统都不可缺少的。比起传统的纸张传阅、签批的审批模式浪费了大量的时间和成本,因此越来越多的企业采用OA这种全新的、高效的、智能的审批模式。流程审批除了这些好处,…...
《零基础入门数据结构与算法》专栏介绍
目录 前言 第一部分:重点 第二部分:题库 第三部分:测试 第四部分:实验 第五部分:试卷 总结 前言 本专栏主要分为五个部分: ① 重要知识点详解 ② 近百道练习题解析 ③ 数据结构与算法章节测试 …...

测试开发之Django实战示例 第九章 扩展商店功能
第九章 扩展商店功能在上一章里,为电商站点集成了支付功能,然后可以生成PDF发票发送给用户。在本章,我们将为商店添加优惠码功能。此外,还会学习国际化和本地化的设置和建立一个推荐商品的系统。本章涵盖如下要点:建立…...

【Spring】一文带你吃透AOP面向切面编程技术(下篇)
个人主页: 几分醉意的CSDN博客_传送门 上节我们介绍了什么是AOP、Aspectj框架的前置通知Before传送门,这篇文章将继续详解Aspectj框架的其它注解。 文章目录💖Aspectj框架介绍✨JoinPoint通知方法的参数✨后置通知AfterReturning✨环绕通知Ar…...

【java】Spring Boot --40 个 Spring Boot 常用注解(建议收藏)
本文目录一、Spring Web MVC 注解Spring Web MVC 注解RequestMappingRequestBodyGetMappingPostMappingPutMappingDeleteMappingPatchMappingControllerAdviceResponseBodyExceptionHandlerResponseStatusPathVariableRequestParamControllerRestControllerModelAttributeCross…...

《游戏学习》| 微信对话模拟生成器源码分析
简介微信对话生成器,是一款在线微信聊天对话制作的工具,它可以设置苹果或安卓状态栏,包括手机电量、手机时间等,还可以设置不同用户的角色,然后发送文字、语音、红包、转账等多种好玩的功能,可谓是一款娱乐…...

剑指 Offer 10- I. 斐波那契数列[c语言]
目录题目思路代码结果该文章只是用于记录考研复试刷题题目 力扣斐波那契数列 写一个函数,输入 n ,求斐波那契(Fibonacci)数列的第 n 项(即 F(N))。斐波那契数列的定义如下: F(0) 0, F(1) 1 …...
【C#基础】C# 数据类型总结
序号系列文章0【C#基础】初识编程语言C#1【C#基础】C# 程序通用结构2【C#基础】C# 基础语法解析文章目录前言数据类型一. 值类型(Value types)二. 引用类型(Reference types)三. 指针类型(Pointer types)结…...

再创荣誉 | Softing工业荣获CAIMRS 2023 数字化创新奖
在刚刚结束的中国工控-第二十一届“自动化及数字化”年度评选(CAIMRS 2023)中,Softing凭借edgeAggregator产品荣获“数字化创新奖”! 经层层筛选,Softing edgeAggregator边缘聚合服务器从中脱颖而出,摘得C…...

Multi Paxos
basic paxos 是用于确定且只能确定一个值,“只确定一个值有什么用?这可解决不了我面临的问题,例如每个用户都要多次保存数据.” 你心中可能有这样的疑问。 原simple paxos论文里有提到一连串个instance of paxos [4] 但没有提出 multi paxos的概念&…...

Android - dimen适配
一、分辨率对应DPIDPI名称范围值分辨率名称屏幕分辨率density密度(1dp显示多少px)ldpi120QVGA240*3200.75(120dpi/1600.75px)mdpi160(基线)HVGA320*4801(160dpi/1601px)hdpi240WVGA4…...

深度学习网络模型——RepVGG网络详解
深度学习网络模型——RepVGG网络详解0 前言1 RepVGG Block详解2 结构重参数化2.1 融合Conv2d和BN2.2 Conv2dBN融合实验(Pytorch)2.3 将1x1卷积转换成3x3卷积2.4 将BN转换成3x3卷积2.5 多分支融合2.6 结构重参数化实验(Pytorch)3 模型配置论文名称: RepVGG: Making V…...

仓库拣货标签应用案例
使用场景:富士康成都仓库 解决问题:仓库亮灯拣选, 提高作业效率和物料明晰展示仓库亮灯拣选使用场景:京东仓库 解决问题:播种墙分拣,合单拣货完成后按订单播种播种墙分拣使用场景:和尔泰智能料…...

介绍一款HCIA、HCIP、HCIE的刷题软件
华为认证考试分为三个等级,分别为工程师HCIA、高级工程师HCIP、专家HCIE,等级越高,考试难度越大。 本篇带大家详细了解华为数通题库刷题工具的详细操作步骤。 操作须知:本款刷题工具为一款刷题小程序,无需安装即可在线…...

线程池整理汇总
它山之石,可以攻玉。借鉴整理线程池相关文章,以及自身实践。 文章目录1. 线程池概述2. 线程池UML架构3. Executors创建线程的4种方法3.1 newSingleThreadExecutor3.2 newFixedThreadPool3.3 newCachedThreadPool3.4 newScheduledThreadPool小结4. 线程池…...

TDengine 快速体验(Docker 镜像方式)
简介 TDengine 可以通过安装包、Docker 镜像 及云服务快速体验 TDengine 的功能,本节首先介绍如何通过 Docker 快速体验 TDengine,然后介绍如何在 Docker 环境下体验 TDengine 的写入和查询功能。如果你不熟悉 Docker,请使用 安装包的方式快…...

Python:操作 Excel 折叠
💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖 本博客的精华专栏: 【自动化测试】 【测试经验】 【人工智能】 【Python】 Python 操作 Excel 系列 读取单元格数据按行写入设置行高和列宽自动调整行高和列宽水平…...

8k长序列建模,蛋白质语言模型Prot42仅利用目标蛋白序列即可生成高亲和力结合剂
蛋白质结合剂(如抗体、抑制肽)在疾病诊断、成像分析及靶向药物递送等关键场景中发挥着不可替代的作用。传统上,高特异性蛋白质结合剂的开发高度依赖噬菌体展示、定向进化等实验技术,但这类方法普遍面临资源消耗巨大、研发周期冗长…...

大型活动交通拥堵治理的视觉算法应用
大型活动下智慧交通的视觉分析应用 一、背景与挑战 大型活动(如演唱会、马拉松赛事、高考中考等)期间,城市交通面临瞬时人流车流激增、传统摄像头模糊、交通拥堵识别滞后等问题。以演唱会为例,暖城商圈曾因观众集中离场导致周边…...
STM32+rt-thread判断是否联网
一、根据NETDEV_FLAG_INTERNET_UP位判断 static bool is_conncected(void) {struct netdev *dev RT_NULL;dev netdev_get_first_by_flags(NETDEV_FLAG_INTERNET_UP);if (dev RT_NULL){printf("wait netdev internet up...");return false;}else{printf("loc…...
1688商品列表API与其他数据源的对接思路
将1688商品列表API与其他数据源对接时,需结合业务场景设计数据流转链路,重点关注数据格式兼容性、接口调用频率控制及数据一致性维护。以下是具体对接思路及关键技术点: 一、核心对接场景与目标 商品数据同步 场景:将1688商品信息…...

WordPress插件:AI多语言写作与智能配图、免费AI模型、SEO文章生成
厌倦手动写WordPress文章?AI自动生成,效率提升10倍! 支持多语言、自动配图、定时发布,让内容创作更轻松! AI内容生成 → 不想每天写文章?AI一键生成高质量内容!多语言支持 → 跨境电商必备&am…...
数据库分批入库
今天在工作中,遇到一个问题,就是分批查询的时候,由于批次过大导致出现了一些问题,一下是问题描述和解决方案: 示例: // 假设已有数据列表 dataList 和 PreparedStatement pstmt int batchSize 1000; // …...

云原生玩法三问:构建自定义开发环境
云原生玩法三问:构建自定义开发环境 引言 临时运维一个古董项目,无文档,无环境,无交接人,俗称三无。 运行设备的环境老,本地环境版本高,ssh不过去。正好最近对 腾讯出品的云原生 cnb 感兴趣&…...

【7色560页】职场可视化逻辑图高级数据分析PPT模版
7种色调职场工作汇报PPT,橙蓝、黑红、红蓝、蓝橙灰、浅蓝、浅绿、深蓝七种色调模版 【7色560页】职场可视化逻辑图高级数据分析PPT模版:职场可视化逻辑图分析PPT模版https://pan.quark.cn/s/78aeabbd92d1...