当前位置: 首页 > news >正文

【pytorch框架】对模型知识的基本了解

文章目录

    • TensorBoard的使用
      • 1、TensorBoard启动:
      • 2、使用TensorBoard查看一张图片
      • 3、transforms的使用
    • pytorch框架基础知识
      • 1 nn.module的使用
      • 2 nn.conv2d的使用
      • 3、池化(MaxPool2d)
      • 4 非线性激活
      • 5 线性层
      • 6 Sequential的使用
      • 7 损失函数与反向传播
      • 8 优化器
      • 9 对现有网络的使用和修改
      • 10 网络模型的保存与读取

TensorBoard的使用

1、TensorBoard启动:

在Terminal终端命令中输入:

tensorboard --logdir=logs #logs为创建的文件名

2、使用TensorBoard查看一张图片

writer=SummaryWriter("../logs")
image_path=r'F:\image\1.jpg'
img_PIL=Image.open(image_path)
image_array=np.array(img_PIL)
writer.add_image('test',image_array,1,dataformats='HWC')
writer.close()

3、transforms的使用

作用:使PIL Image 或者np ——》tensor

imgae_path=r'F:\image\1.jpg'
img=Image.open(img_path)
tensor_trans=transsforms.ToTensor()  #相当于创建一个工具
tensor_img=tensor_trans(img) #img转化成tensor模式

同理,ToPILIMage是为了tensor 或者 ndarray =》Image

pytorch框架基础知识

1 nn.module的使用

目的:给所有网络提供基本骨架
在这里插入图片描述

from torch import nn
class aiy(nn.Module):def  __init__(self):super().__init__()def forward(sel,input):output=input+1return outputaiy=aiy()
# x=torch.tensor(1.0)
x=1
output=aiy(x)
print(output)
'''
2
'''

2 nn.conv2d的使用

参数代码解释如下:
在这里插入图片描述
示例:输入一个5x5的矩阵,和一个3x3的卷积核做卷积操作

import torch
import torch.nn.functional as F
input=torch.tensor([[1,2,0,3,1],[0,1,2,3,1],[1,2,1,0,0],[5,2,3,1,1],[2,1,0,1,1]])
input=input.reshape([1,1,5,5])kears=torch.tensor([[1,2,1],[0,1,0],[2,1,0]])
kears=kears.reshape([1,1,3,3])output=F.conv2d(input,kears,stride=1)print(output)
print(output.shape)
'''
tensor([[[[10, 12, 12], [18, 16, 16],[13,  9,  3]]]])
torch.Size([1, 1, 3, 3])
'''

若是输入的卷积核的数量有两个,则得到的output也是两个
在这里插入图片描述
示例:借用CIFAR10数据集,用自定义的网络模型做一次卷积操作,然后用tensorboard查看卷积之后的结果。
这里需要注意的是,经过卷积得到的大小是[64,6,30,30],而图片的通道一般都是3通道的,6通道的图片不知道怎么显示,需要使用reshpae重新改变矩阵的大小。

output=output.reshape([-1,3,30,30]) #-1自动计算剩余的值,后面[3,30,30]改成指定大小

示例代码:

import torch
import torchvision
from torch.utils.data import DataLoader
from torch import nn
from torch.nn import Conv2d
from torch.utils.tensorboard import SummaryWriter#数据准备
dataset=torchvision.datasets.CIFAR10("./data",train=False,transform=torchvision.transforms.ToTensor(),download=True)dataloader=DataLoader(dataset,batch_size=64)#自定义网络模型
class aiy(nn.Module):def __init__(self):super(aiy, self).__init__()#卷积运算self.conv1=Conv2d(in_channels=3,out_channels=6,kernel_size=3,stride=1,padding=0)def forward(self,x):x=self.conv1(x)return xaiy=aiy()
# print(aiy)
step=0
writer=SummaryWriter("../log")for data in dataloader:img,targets=dataoutput=aiy(img)# print(img.shape)#torch.Size([64, 3, 32, 32]# print(output.shape)#torch.Size([64, 6, 30, 30])#因为图片的通道是3,需要改变矩阵的大小# output=output.reshape([-1,3,30,30])writer.add_images("input",img,step)output=torch.reshape(output,(-1,3,30,30))writer.add_images("output", output, step)# print(output.shape)step=step+1print(step)
writer.close()

在这里插入图片描述

3、池化(MaxPool2d)

目的:降采样,大幅减少网络的参数量,同时保留图像数据的特征。
需要注意的是: 池化不改变通道数

池化参数
在这里插入图片描述
数组演示示例:
在这里插入图片描述

input=torch.tensor([[1,2,0,3,1],[0,1,2,3,1],[1,2,1,0,0],[5,2,3,1,1],[2,1,0,1,1]],dtype=float)
input=torch.reshape(input,(-1,1,5,5))output=aiy(input)
print(output.shape)
'''
ceil_mode=True:
tensor([[[[2., 3.],[5., 1.]]]], dtype=torch.float64)ceil_mode=False:
tensor([[[[2.]]]], dtype=torch.float64)
'''

示例:同样,借用CIFAR10数据集,用自定义的网络模型做一次池化操作,然后用tensorboard查看卷积之后的结果。

# -*- coding: utf-8 -*-
# Auter:我菜就爱学import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d#带入数组
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterclass aiy(nn.Module):def __init__(self):super(aiy, self).__init__()self.maxpool1=MaxPool2d(kernel_size=3,ceil_mode=False)def forward(self,input):output=self.maxpool1(input)return outputaiy=aiy()#将池化层用数据集测试
dataset=torchvision.datasets.CIFAR10('./data',train=False,transform=torchvision.transforms.ToTensor(),download=True)dataloader=DataLoader(dataset,batch_size=64)step=0
writer=SummaryWriter("../logmaxpool")for data in dataloader:img,target=datawriter.add_images("input",img,step)output=aiy(img)writer.add_images("output",output,step)step=step+1

有点像打马赛克
在这里插入图片描述

4 非线性激活

作用:提高泛化能力,引入非线性特征
在这里插入图片描述

ReLu(input,inplace=True)  
=>表示原input替换input
out=ReLu(input,inplace=False)
=>表示原input被out替换

5 线性层

在这里插入图片描述

6 Sequential的使用

作用:可以简化自己搭建的网络模型
示例:参考CIFAR10的网络模型结构,创建一个网络。

在这里插入图片描述

# -*- coding: utf-8 -*-
# Auter:我菜就爱学import torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.tensorboard import SummaryWriterclass Aiy(nn.Module):def __init__(self):super(Aiy, self).__init__()self.model=Sequential(Conv2d(3,32,5,padding=2,stride=1),MaxPool2d(kernel_size=2),Conv2d(32,32,5,padding=2),MaxPool2d(kernel_size=2),Conv2d(32,64,5,padding=2),MaxPool2d(2),Flatten(),Linear(1024,64),Linear(64,10))def forward(self,input):output=self.model(input)return outputaiy=Aiy()# print(aiy)input=torch.ones((64,3,32,32))output=aiy(input)print(output.shape)

使用tensorboard中的命令可以查看网络模型结构

writer=SummaryWriter('../logmodel')writer.add_graph(aiy,input)writer.close()

在这里插入图片描述

7 损失函数与反向传播

作用:

  • 计算处实际输出与目标之间的差距
  • 更新输出提供一定的依据

通过小土堆举的示例可以很好的理解损失函数
在这里插入图片描述

说明:假设一张试卷满分是100分,其中选择30,填空20,解答50.第一次我们得到的结果是:选择10,填空10,解答20.第一次损失值是60.
然后通过不断的训练,让选择提高到20,填空提高20,解答提高到40,这个时候与满分差距20,损失值也就越来越小。

# -*- coding: utf-8 -*-
# Auter:我菜就爱学
import torch
from torch.nn import L1Lossinput=torch.tensor([1,2,3],dtype=torch.float32)
input=torch.reshape(input,(1,1,1,3))
target=torch.tensor([1,2,5],dtype=torch.float32)
target=torch.reshape(target,(1,1,1,3))#设置一个损失函数
loss=L1Loss(reduction='sum')
output=loss(input,target)
print(output)
'''
tensor(2.)
'''

8 优化器

优化器参数解释:
在这里插入图片描述

for epoch in range (20):sum_loss=0.0for data in dataloader:imgs,targets=dataoutput=aiy(imgs)result_loss=loss(output,targets)optim.zero_grad()result_loss.backward() #反向传播,更新对应的梯度optim.step()  #调整更新的参数sum_loss=sum_loss+result_lossprint(sum_loss)

下面是对优化器中的交叉熵的解释:
在这里插入图片描述

9 对现有网络的使用和修改

  • 下载现有网络,并使用数据集更新好的参数
vgg16_True=torchvision models vgg16(pretrained=True)

一般下载好的模型保存路径:==C:\user.cache\torch\hub\checkpoints

  • 在已有的网络模型中新添自己需要的层
vgg16_True.classifier.add_module("7",nn.Linear(1000,10))

10 网络模型的保存与读取

方法一:直接把模型和参数保存下来
注意: 有 一个陷阱,自定义的模型在下载的时候运行会报错,得需要复制下载原模型。只能导入专门经典的模型

#保存
torch.save(vgg16_true,"vgg16_method1.pth")#下载
model=torch.load("vgg16_method1.pth")

方法二:保存模型的参数,一般使用这个。内存比较小,节省空间;以字典的形式保存。

#保存
torch.save(vgg16_true.state_dict(),"vgg16_method1.pth")#下载
vgg16_false=torchvision.models.vgg16(pretrained=False)
vgg16_false.load_state_dict(torch.long("wgg166_method2.pth"))

相关文章:

【pytorch框架】对模型知识的基本了解

文章目录TensorBoard的使用1、TensorBoard启动:2、使用TensorBoard查看一张图片3、transforms的使用pytorch框架基础知识1 nn.module的使用2 nn.conv2d的使用3、池化(MaxPool2d)4 非线性激活5 线性层6 Sequential的使用7 损失函数与反向传播8 优化器9 对现有网络的使…...

SUP桨板电动气泵方案——鼎盛合方案

SUP桨板是现时最热门的水上运动之一,它的全称是Stand Up Paddle,简称SUP。这项运动近几年在我国三亚等地区风靡一时,在网上经常看到一些运动博主或者明星网红晒出冲浪视频,刺激又惊险。SUP桨板为充气式桨板,需要通过充…...

小白系列Vite-Vue3-TypeScript:011-登录界面搭建及动态路由配置

前面几篇文章我们介绍的都是ViteVue3TypeScript项目中环境相关的配置,接下来我们开始进入系统搭建部分。本篇我们来介绍登录界面搭建及动态路由配置,大家一起撸起来......搭建登录界面登陆接口api项目登陆接口是通过mockjs前端来模拟的模拟服务接口Login…...

C语言( 缓冲区和重定向)

一.缓冲输入,无缓存输入 while((chgetchar()) ! #) putchar(ch); 这里getchar(),putchar()每次只处理一个字符(这里只是知道就好了),而我们使用while循环,当读到#字符时停止 而看到输出例子,第一行我们输入…...

编程思想、方法论和架构的类型及应用

概要编程思想是指在编写代码时所采用的基本思维方式和方法论。分类编程思想编程思想为软件开发提供了思维范式和指导思路,例如面向对象思想、函数式编程思想等,它们帮助程序员更好地抽象问题、组织代码、提高代码复用性和可维护性,包括一下几…...

【OA办公】OA流程审批大揭秘,带你看遍所有基础流程

流程审批,是所有企业的OA办公系统重要组成部分,是任何OA办公系统都不可缺少的。比起传统的纸张传阅、签批的审批模式浪费了大量的时间和成本,因此越来越多的企业采用OA这种全新的、高效的、智能的审批模式。流程审批除了这些好处,…...

《零基础入门数据结构与算法》专栏介绍

目录 前言 第一部分:重点 第二部分:题库 第三部分:测试 第四部分:实验 第五部分:试卷 总结 前言 本专栏主要分为五个部分: ① 重要知识点详解 ② 近百道练习题解析 ③ 数据结构与算法章节测试 …...

测试开发之Django实战示例 第九章 扩展商店功能

第九章 扩展商店功能在上一章里,为电商站点集成了支付功能,然后可以生成PDF发票发送给用户。在本章,我们将为商店添加优惠码功能。此外,还会学习国际化和本地化的设置和建立一个推荐商品的系统。本章涵盖如下要点:建立…...

【Spring】一文带你吃透AOP面向切面编程技术(下篇)

个人主页: 几分醉意的CSDN博客_传送门 上节我们介绍了什么是AOP、Aspectj框架的前置通知Before传送门,这篇文章将继续详解Aspectj框架的其它注解。 文章目录💖Aspectj框架介绍✨JoinPoint通知方法的参数✨后置通知AfterReturning✨环绕通知Ar…...

【java】Spring Boot --40 个 Spring Boot 常用注解(建议收藏)

本文目录一、Spring Web MVC 注解Spring Web MVC 注解RequestMappingRequestBodyGetMappingPostMappingPutMappingDeleteMappingPatchMappingControllerAdviceResponseBodyExceptionHandlerResponseStatusPathVariableRequestParamControllerRestControllerModelAttributeCross…...

《游戏学习》| 微信对话模拟生成器源码分析

简介微信对话生成器,是一款在线微信聊天对话制作的工具,它可以设置苹果或安卓状态栏,包括手机电量、手机时间等,还可以设置不同用户的角色,然后发送文字、语音、红包、转账等多种好玩的功能,可谓是一款娱乐…...

剑指 Offer 10- I. 斐波那契数列[c语言]

目录题目思路代码结果该文章只是用于记录考研复试刷题题目 力扣斐波那契数列 写一个函数,输入 n ,求斐波那契(Fibonacci)数列的第 n 项(即 F(N))。斐波那契数列的定义如下: F(0) 0, F(1) 1 …...

【C#基础】C# 数据类型总结

序号系列文章0【C#基础】初识编程语言C#1【C#基础】C# 程序通用结构2【C#基础】C# 基础语法解析文章目录前言数据类型一. 值类型(Value types)二. 引用类型(Reference types)三. 指针类型(Pointer types)结…...

再创荣誉 | Softing工业荣获CAIMRS 2023 数字化创新奖

在刚刚结束的中国工控-第二十一届“自动化及数字化”年度评选(CAIMRS 2023)中,Softing凭借edgeAggregator产品荣获“数字化创新奖”! 经层层筛选,Softing edgeAggregator边缘聚合服务器从中脱颖而出,摘得C…...

Multi Paxos

basic paxos 是用于确定且只能确定一个值,“只确定一个值有什么用?这可解决不了我面临的问题,例如每个用户都要多次保存数据.” 你心中可能有这样的疑问。 原simple paxos论文里有提到一连串个instance of paxos [4] 但没有提出 multi paxos的概念&…...

Android - dimen适配

一、分辨率对应DPIDPI名称范围值分辨率名称屏幕分辨率density密度(1dp显示多少px)ldpi120QVGA240*3200.75(120dpi/1600.75px)mdpi160(基线)HVGA320*4801(160dpi/1601px)hdpi240WVGA4…...

深度学习网络模型——RepVGG网络详解

深度学习网络模型——RepVGG网络详解0 前言1 RepVGG Block详解2 结构重参数化2.1 融合Conv2d和BN2.2 Conv2dBN融合实验(Pytorch)2.3 将1x1卷积转换成3x3卷积2.4 将BN转换成3x3卷积2.5 多分支融合2.6 结构重参数化实验(Pytorch)3 模型配置论文名称: RepVGG: Making V…...

仓库拣货标签应用案例

使用场景:富士康成都仓库 解决问题:仓库亮灯拣选, 提高作业效率和物料明晰展示仓库亮灯拣选使用场景:京东仓库 解决问题:播种墙分拣,合单拣货完成后按订单播种播种墙分拣使用场景:和尔泰智能料…...

介绍一款HCIA、HCIP、HCIE的刷题软件

华为认证考试分为三个等级,分别为工程师HCIA、高级工程师HCIP、专家HCIE,等级越高,考试难度越大。 本篇带大家详细了解华为数通题库刷题工具的详细操作步骤。 操作须知:本款刷题工具为一款刷题小程序,无需安装即可在线…...

线程池整理汇总

它山之石,可以攻玉。借鉴整理线程池相关文章,以及自身实践。 文章目录1. 线程池概述2. 线程池UML架构3. Executors创建线程的4种方法3.1 newSingleThreadExecutor3.2 newFixedThreadPool3.3 newCachedThreadPool3.4 newScheduledThreadPool小结4. 线程池…...

【第二十一章 SDIO接口(SDIO)】

第二十一章 SDIO接口 目录 第二十一章 SDIO接口(SDIO) 1 SDIO 主要功能 2 SDIO 总线拓扑 3 SDIO 功能描述 3.1 SDIO 适配器 3.2 SDIOAHB 接口 4 卡功能描述 4.1 卡识别模式 4.2 卡复位 4.3 操作电压范围确认 4.4 卡识别过程 4.5 写数据块 4.6 读数据块 4.7 数据流…...

MMaDA: Multimodal Large Diffusion Language Models

CODE : https://github.com/Gen-Verse/MMaDA Abstract 我们介绍了一种新型的多模态扩散基础模型MMaDA,它被设计用于在文本推理、多模态理解和文本到图像生成等不同领域实现卓越的性能。该方法的特点是三个关键创新:(i) MMaDA采用统一的扩散架构&#xf…...

数据链路层的主要功能是什么

数据链路层(OSI模型第2层)的核心功能是在相邻网络节点(如交换机、主机)间提供可靠的数据帧传输服务,主要职责包括: 🔑 核心功能详解: 帧封装与解封装 封装: 将网络层下发…...

ElasticSearch搜索引擎之倒排索引及其底层算法

文章目录 一、搜索引擎1、什么是搜索引擎?2、搜索引擎的分类3、常用的搜索引擎4、搜索引擎的特点二、倒排索引1、简介2、为什么倒排索引不用B+树1.创建时间长,文件大。2.其次,树深,IO次数可怕。3.索引可能会失效。4.精准度差。三. 倒排索引四、算法1、Term Index的算法2、 …...

爬虫基础学习day2

# 爬虫设计领域 工商:企查查、天眼查短视频:抖音、快手、西瓜 ---> 飞瓜电商:京东、淘宝、聚美优品、亚马逊 ---> 分析店铺经营决策标题、排名航空:抓取所有航空公司价格 ---> 去哪儿自媒体:采集自媒体数据进…...

css3笔记 (1) 自用

outline: none 用于移除元素获得焦点时默认的轮廓线 broder:0 用于移除边框 font-size&#xff1a;0 用于设置字体不显示 list-style: none 消除<li> 标签默认样式 margin: xx auto 版心居中 width:100% 通栏 vertical-align 作用于行内元素 / 表格单元格&#xff…...

LangFlow技术架构分析

&#x1f527; LangFlow 的可视化技术栈 前端节点编辑器 底层框架&#xff1a;基于 &#xff08;一个现代化的 React 节点绘图库&#xff09; 功能&#xff1a; 拖拽式构建 LangGraph 状态机 实时连线定义节点依赖关系 可视化调试循环和分支逻辑 与 LangGraph 的深…...

在树莓派上添加音频输入设备的几种方法

在树莓派上添加音频输入设备可以通过以下步骤完成&#xff0c;具体方法取决于设备类型&#xff08;如USB麦克风、3.5mm接口麦克风或HDMI音频输入&#xff09;。以下是详细指南&#xff1a; 1. 连接音频输入设备 USB麦克风/声卡&#xff1a;直接插入树莓派的USB接口。3.5mm麦克…...

深度剖析 DeepSeek 开源模型部署与应用:策略、权衡与未来走向

在人工智能技术呈指数级发展的当下&#xff0c;大模型已然成为推动各行业变革的核心驱动力。DeepSeek 开源模型以其卓越的性能和灵活的开源特性&#xff0c;吸引了众多企业与开发者的目光。如何高效且合理地部署与运用 DeepSeek 模型&#xff0c;成为释放其巨大潜力的关键所在&…...

stm32wle5 lpuart DMA数据不接收

配置波特率9600时&#xff0c;需要使用外部低速晶振...