pytorch网络模型构建中的注意点
记录使用pytorch构建网络模型过程遇到的点
1. 网络模型构建中的问题
1.1 输入变量是Tensor张量
各个模块和网络模型的输入, 一定要是tensor
张量;
可以用一个列表存放多个张量。
如果是张量维度不够,需要升维度,
可以先使用 torch.unsqueeze(dim = expected)
然后再使用torch.cat(dim )
进行拼接;
- 需要传递梯度的数据,禁止使用
numpy
, 也禁止先使用numpy,然后再转换成张量的这种情况出现;
这是因为pytorch的机制是只有是
Tensor
张量的类型,才会有梯度等属性值,如果是numpy这些类别,这些变量并会丢失其梯度值。
1.2 __init__()
方法使用
class ex:def __init__(self):pass
__init__
方法必须接受至少一个参数即self,
Python中,self是指向该对象本身的一个引用
,
通过在类的内部使用self变量,
类中的方法可以访问自己的成员变量,简单来说,self.varname的意义为”访问该对象的varname属性“
当然,__init__()
中可以封装任意的程序逻辑,这是允许的,init()方法还接受任意多个其他参数,允许在初始化时提供一些数据,例如,对于刚刚的worker类,可以这样写:
class worker:def __init__(self,name,pay):self.name=nameself.pay=pay
这样,在创建worker类的对象时,必须提供name和pay两个参数:
b=worker('Jim',5000)
Python会自动调用worker.init()方法,并传递参数。
细节参考这里init方法
1.3 内置函数 setattr()
此时,可以使用python自带的内置函数 setattr()
, 和对应的getattr()
setattr(object, name, value)
object – 对象。
name – 字符串,对象属性。
value – 属性值。
对已存在的属性进行赋值:
>>>class A(object):
... bar = 1
...
>>> a = A()
>>> getattr(a, 'bar') # 获取属性 bar 值
1
>>> setattr(a, 'bar', 5) # 设置属性 bar 值
>>> a.bar
5如果属性不存在会创建一个新的对象属性,并对属性赋值:>>>class A():
... name = "runoob"
...
>>> a = A()
>>> setattr(a, "age", 28)
>>> print(a.age)
28
>>>
setattr() 语法
setattr(object, name, value)
object – 对象。
name – 字符串,对象属性。
value – 属性值。
1.4 网络模型的构建
注意到, 在python的 __init__()
函数中, self
本身就是该类的对象的一个引用,即self是指向该对象本身的一个引用
,
利用上述这一点,当在神经网络中,
- 需要给多个属性进行实例化时,
- 且这多个属性使用的是同一个类进行实例化.
1.4.1 使用 setattr(self, string, object1)
添加属性;
注意到,下面这种方式,由于
Basic_slide_conv()
只经过了一次实例化,
所以在内存空间中,只会分配一个地址空间给该对象;
虽然后面使用 35 group,
但这35组本质上使用的同一个对象,即conv_block
该对象;
class Temporal_GroupTrans(nn.Module):def __init__(self, num_classes=10,num_groups=35, drop_prob=0.5, pretrained= True):super(Temporal_GroupTrans, self).__init__()conv_block = Basic_slide_conv()for i in range( num_groups):setattr(self, "group" + str(i), conv_block)# 自定义transformer模型的初始化, CustomTransformerModel() 在该类中传入初始化模型的参数,# nip:512 输入序列中,每个列向量的编码维度, 16: 注意力头的个数# 600: 中间mlp 隐藏层的维数, 6: 堆叠transforEncode 编码模块的个数;self.trans_model = CustomTransformerModel(512,16,600, 6,droupout=0.5,nclass=4)
如果想要分配35个不同的对象, 即需要分配出35个不同的地址空间用来存储,
那么需要将 Basic_slide_conv()
经过了35次实例化,
所以需要将 类Basic_slide_conv()
实例化的过程放在循环当中实现;
class Temporal_GroupTrans(nn.Module):def __init__(self, num_classes=10,num_groups=35, drop_prob=0.5, pretrained= True):super(Temporal_GroupTrans, self).__init__()# conv_block = Basic_slide_conv()for i in range( num_groups):setattr(self, "group" + str(i), Basil_slide_conv() )# 自定义transformer模型的初始化, CustomTransformerModel() 在该类中传入初始化模型的参数,# nip:512 输入序列中,每个列向量的编码维度, 16: 注意力头的个数# 600: 中间mlp 隐藏层的维数, 6: 堆叠transforEncode 编码模块的个数;self.trans_model = CustomTransformerModel(512,16,600, 6,droupout=0.5,nclass=4)
1.4.2 使用 getattr(self, string, object1)
获取属性;
trans_input_sequence = []for i in range(0, num_groups, ):# 每组语谱图的大小是一个 (bt, ch,96,12)的矩阵,组与组之间没有重叠;cur_group = x[:, :, :, 12 * i:12 * (i + 1)]# VARIABLE_fun = "self.group" # 每一组,与之对应的卷积模块;# cur_fun = eval(VARIABLE_fun + str(i ))cur_fun = getattr(self, 'group'+str(i))cur_group_out = cur_fun(cur_group).unsqueeze(dim=1) # [bt,1, 512]trans_input_sequence.append(cur_group_out)
相关文章:
pytorch网络模型构建中的注意点
记录使用pytorch构建网络模型过程遇到的点 1. 网络模型构建中的问题 1.1 输入变量是Tensor张量 各个模块和网络模型的输入, 一定要是tensor 张量; 可以用一个列表存放多个张量。 如果是张量维度不够,需要升维度, 可以先使用 …...
面试时候这样介绍redis,redis经典面试题
为什么要用redis做缓存 使用Redis缓存有以下几个优点: 1. 提高系统性能:缓存可以将数据存储在内存中,加快数据的访问速度,减少对数据库的读写次数,从而提高系统的性能。 2. 减轻后端压力:使用缓存可以减…...

机械学习 - scikit-learn - 数据预处理 - 2
目录关于 scikit-learn 实现规范化的方法详解一、fit_transform 方法1. 最大最小归一化手动化与自动化代码对比演示 1:2. 均值归一化手动化代码演示:3. 小数定标归一化手动化代码演示:4. 零-均值标准化(均值移除)手动与自动化代码演示&#x…...
华为OD机试题 - 最长连续交替方波信号(JavaScript)| 机考必刷
更多题库,搜索引擎搜 梦想橡皮擦华为OD 👑👑👑 更多华为OD题库,搜 梦想橡皮擦 华为OD 👑👑👑 更多华为机考题库,搜 梦想橡皮擦华为OD 👑👑👑 华为OD机试题 最近更新的博客使用说明本篇题解:最长连续交替方波信号题目输入输出示例一输入输出Code解题思路版…...

executor行为相关Spark sql参数源码分析
0、前言 参数名和默认值spark.default.parallelismDefault number of partitions in RDDsspark.executor.cores1 in YARN mode 一般默认值spark.files.maxPartitionBytes134217728(128M)spark.files.openCostInBytes4194304 (4 MiB)spark.hadoop.mapreduce.fileoutputcommitte…...

双通道5.2GSPS(或单通道10.4GSPS)射频采样FMC+模块
概述 FMC140是一款具有缓冲模拟输入的低功耗、12位、双通道(5.2GSPS/通道)、单通道10.4GSPS、射频采样ADC模块,该板卡为FMC标准,符合VITA57.1规范,该模块可以作为一个理想的IO单元耦合至FPGA前端,8通道的JE…...
理解java反射
是什么Java反射是Java编程语言的一个功能,它允许程序在运行时(而不是编译时)检查、访问和修改类、对象和方法的属性和行为。使用反射创建对象相比直接创建对象有什么优点使用反射创建对象相比直接创建对象的主要优点是灵活性和可扩展性。当我…...

EasyRcovery16免费的电脑照片数据恢复软件
电脑作为一种重要的数据储存设备,其中保存着大量的文档,邮件,视频,音频和照片。那么,如果电脑照片被删除了怎么办?今天小编给大家介绍,误删除的照片从哪里可以找回来,误删除的照片如…...

若依微服务版在定时任务里面跨模块调用服务
第一步 在被调用的模块中添加代理 RemoteTaskFallbackFactory.java: package com.ruoyi.rpa.api.factory;import com.ruoyi.common.core.domain.R; import com.ruoyi.rpa.api.RemoteTaskService; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springf…...
SpringMVC简单配置
1、pom.xml配置 <dependencies><dependency><groupId>org.springframework</groupId><artifactId>spring-webmvc</artifactId><version>5.1.12.RELEASE</version></dependency></dependencies><build><…...
xcat快速入门工作流程指南
目录一、快速入门指南一、先决条件二、准备管理节点xcatmn.mydomain.com三、第1阶段:添加你的第一个节点并且用带外BMC接口控制它四、第 2 阶段 预配节点并使用并行 shell 对其进行管理二:工作流程指南1. 查找 xCAT 管理节点的服务器2. 在所选服务器上安…...

C++回顾(十九)—— 容器string
19.1 string概述 1、string是STL的字符串类型,通常用来表示字符串。而在使用string之前,字符串通常是 用char * 表示的。string 与char * 都可以用来表示字符串,那么二者有什么区别呢。 2、string和 char * 的比较 (1)…...
Hadoop入门
数据分析与企业数据分析方向 数据是什么 数据是指对可观事件进行记录并可以鉴别的符号,是对客观事物的性质、状态以及相互关系等进行记载的物理符号或这些物理符号的组合,它是可以识别的、抽象的符号。 他不仅指狭义上的数字,还可以是具有一…...

高校如何通过校企合作/实验室建设来提高大数据人工智能学生就业质量
高校人才培养应该如何结合市场需求进行相关专业设置和就业引导,一直是高校就业工作的讨论热点。亘古不变的原则是,高校设置不能脱离市场需求太远,最佳的结合方式是,高校具有前瞻性,能领先市场一步,培养未来…...

提升学习 Prompt 总结
NLP现有的四个阶段: 完全有监督机器学习完全有监督深度学习预训练:预训练 -> 微调 -> 预测提示学习:预训练 -> 提示 -> 预测 阶段1,word的本质是特征,即特征的选取、衍生、侧重上的针对性工程。 阶段2&…...
JavaScript学习笔记(2.0)
BOM--(browser object model) 获取浏览器窗口尺寸 获取可视窗口高度:window.innerWidth 获取可视窗口高度:window.innerHeight 浏览器弹出层 提示框:window.alert(提示信息) 询问框:window.confirm(提示信息) 输…...

直击2023云南移动生态合作伙伴大会,聚焦云南移动的“价值裂变”
作者 | 曾响铃 文 | 响铃说 2023年3月2日下午,云南移动生态合作伙伴大会在昆明召开。云南移动党委书记,总经理葛松海在大会上提到“2023年,云南移动将重点在‘做大平台及生态级新产品,做优渠道转型新动能,做强合作新…...

STM32F1开发实例-振动传感器(机械)
振动(敲击)传感器 振动无处不在,有声音就有振动,哒哒的脚步是匆匆的过客,沙沙的夜雨是暗夜的忧伤。那你知道理科工程男是如何理解振动的吗?今天我们就来讲一讲本节的主角:最简单的机械式振动传感器。 下图即为振动传…...

2023最新ELK日志平台(elasticsearch+logstash+kibana)搭建
去年公司由于不断发展,内部自研系统越来越多,所以后来搭建了一个日志收集平台,并将日志收集功能以二方包形式引入自研系统,避免每个自研系统都要建立一套自己的日志模块,节约了开发时间,管理起来也更加容易…...
2023-3-10 刷题情况
打家劫舍 IV 题目描述 沿街有一排连续的房屋。每间房屋内都藏有一定的现金。现在有一位小偷计划从这些房屋中窃取现金。 由于相邻的房屋装有相互连通的防盗系统,所以小偷 不会窃取相邻的房屋 。 小偷的 窃取能力 定义为他在窃取过程中能从单间房屋中窃取的 最大…...

业务系统对接大模型的基础方案:架构设计与关键步骤
业务系统对接大模型:架构设计与关键步骤 在当今数字化转型的浪潮中,大语言模型(LLM)已成为企业提升业务效率和创新能力的关键技术之一。将大模型集成到业务系统中,不仅可以优化用户体验,还能为业务决策提供…...
基于Uniapp开发HarmonyOS 5.0旅游应用技术实践
一、技术选型背景 1.跨平台优势 Uniapp采用Vue.js框架,支持"一次开发,多端部署",可同步生成HarmonyOS、iOS、Android等多平台应用。 2.鸿蒙特性融合 HarmonyOS 5.0的分布式能力与原子化服务,为旅游应用带来…...

DIY|Mac 搭建 ESP-IDF 开发环境及编译小智 AI
前一阵子在百度 AI 开发者大会上,看到基于小智 AI DIY 玩具的演示,感觉有点意思,想着自己也来试试。 如果只是想烧录现成的固件,乐鑫官方除了提供了 Windows 版本的 Flash 下载工具 之外,还提供了基于网页版的 ESP LA…...
【JavaSE】绘图与事件入门学习笔记
-Java绘图坐标体系 坐标体系-介绍 坐标原点位于左上角,以像素为单位。 在Java坐标系中,第一个是x坐标,表示当前位置为水平方向,距离坐标原点x个像素;第二个是y坐标,表示当前位置为垂直方向,距离坐标原点y个像素。 坐标体系-像素 …...
【碎碎念】宝可梦 Mesh GO : 基于MESH网络的口袋妖怪 宝可梦GO游戏自组网系统
目录 游戏说明《宝可梦 Mesh GO》 —— 局域宝可梦探索Pokmon GO 类游戏核心理念应用场景Mesh 特性 宝可梦玩法融合设计游戏构想要素1. 地图探索(基于物理空间 广播范围)2. 野生宝可梦生成与广播3. 对战系统4. 道具与通信5. 延伸玩法 安全性设计 技术选…...
在web-view 加载的本地及远程HTML中调用uniapp的API及网页和vue页面是如何通讯的?
uni-app 中 Web-view 与 Vue 页面的通讯机制详解 一、Web-view 简介 Web-view 是 uni-app 提供的一个重要组件,用于在原生应用中加载 HTML 页面: 支持加载本地 HTML 文件支持加载远程 HTML 页面实现 Web 与原生的双向通讯可用于嵌入第三方网页或 H5 应…...

回溯算法学习
一、电话号码的字母组合 import java.util.ArrayList; import java.util.List;import javax.management.loading.PrivateClassLoader;public class letterCombinations {private static final String[] KEYPAD {"", //0"", //1"abc", //2"…...

视觉slam十四讲实践部分记录——ch2、ch3
ch2 一、使用g++编译.cpp为可执行文件并运行(P30) g++ helloSLAM.cpp ./a.out运行 二、使用cmake编译 mkdir build cd build cmake .. makeCMakeCache.txt 文件仍然指向旧的目录。这表明在源代码目录中可能还存在旧的 CMakeCache.txt 文件,或者在构建过程中仍然引用了旧的路…...

2025年渗透测试面试题总结-腾讯[实习]科恩实验室-安全工程师(题目+回答)
安全领域各种资源,学习文档,以及工具分享、前沿信息分享、POC、EXP分享。不定期分享各种好玩的项目及好用的工具,欢迎关注。 目录 腾讯[实习]科恩实验室-安全工程师 一、网络与协议 1. TCP三次握手 2. SYN扫描原理 3. HTTPS证书机制 二…...

接口自动化测试:HttpRunner基础
相关文档 HttpRunner V3.x中文文档 HttpRunner 用户指南 使用HttpRunner 3.x实现接口自动化测试 HttpRunner介绍 HttpRunner 是一个开源的 API 测试工具,支持 HTTP(S)/HTTP2/WebSocket/RPC 等网络协议,涵盖接口测试、性能测试、数字体验监测等测试类型…...