基于Pytorch的神经网络部分自定义设计
一、基础概念(学习笔记)
(1)训练误差和泛化误差[1]
本质上,优化和深度学习的目标是根本不同的。前者主要关注的是最小化目标,后者则关注在给定有限数据量的情况下寻找合适的模型。训练误差和泛化误差通常不同:由于优化算法的目标函数通常是基于训练数据集的损失函数,因此优化的目标是减少训练误差。但是,深度学习(或更广义地说,统计推断)的目标是减少泛化误差。
最小化训练误差并不能保证我们找到最佳的参数集来最小化泛化误差。同时,对于深度学习而言,没有必要找到最优解,局部最优解或其近似解仍然非常有用。
个人理解:训练误差为训练集每个Epoch的损失,泛化误差为验证集对应Epoch的损失。训练集与验证集的精确率曲线对比可以作为验证模型是否过拟合的标志。一段时间内的训练集Epoch精确率曲线不断升高且验证集Epoch精确率曲线不断降低,即很大可能模型过拟合了。
(2)Python变量[10]
Python变量由内存地址标识、数据类型、变量值组成。标识对象所存储的内存地址,使用内置函数id(obj)来获取;对象的数据类型,使用内置函数type(obj)来获取。
Python的一切变量都是对象,都采用了引用语义的方式存储,存储的只是一个变量的值所在的内存空间,Python的变量创建、赋值操作可基于C中的指针理解。在Python中需要使用到变量值时,变量会根据自身存储的指向值的内存地址从而获取值。
(3)可迭代对象、迭代器和生成器(Python机制)[7][8]
可迭代对象(Iterable),指存储了元素的一个容器对象,且容器中的元素可以通过__iter__()方法(迭代器形式返回)或__getitem__()方法(实例名[index])访问。
迭代器(Iterator),迭代器可以看作是一个特殊的对象,每次调用该对象时会返回自身的下一个元素(调入内存)。从类构造上来看,一个迭代器对象必须是定义了__iter__()方法和next()方法的对象。迭代器的优点是节约内存(数据循环读入内存而不是一次性,同时每次取值上一条数据都会在内存释放)、实现惰性计算(需要时再根据保存的状态计算读入对应数据),缺点是只能向前一个个访问数据、已访问数据无法再次访问、迭代访问一轮后迭代器空置。
生成器(Generator),本质上还是一种迭代器,但其倾向于存在计算过程的迭代器。生成器包含yield关键字,同时也可使用next()获取数据。对比于列表生成式(实例 = [F(x)for x in iterable])返回一个关于F(x)的函数值列表,将[]替换为()即可转换为一个生成器。
显然,通过实验可以发现:
(1)若列表存储大量数值且不需要一次性全部参与计算,那么将其转换为迭代器逐次加入内存,可以节省大量的内存空间(迭代器的内存占用空间不变);
(2)在Python中,int数字变量是变长存储,其最小内存占用为24Bytes。通过查阅资料,Python基于C语言编成,但Python的int是动态的,其int是一个类,类内包含特殊的数据结构。
(3)getsize()仅返回对象的内存占用大小(仅考虑直接归因于该对象的内存消耗,而不考虑其引用的对象的内存消耗)。对于容器对象(例如字典或列表),给定的内存大小仅涵盖容器使用的内存以及用于引用其他对象的指针值。如果需要计算容器的内存占用以及该容器引用所有内容的大小,必须自己编程实现遍历这些包含对象并获取他们大小。
(4)权重初始化Weight Initialization(常用方法)[11]
Xavier与 Kaiming权重参数初始化

上述初始化方式,尽可能地保证了训练较长时期的输入激活值都没有出现坍缩或饱和现象,使得网络中的信息得到了较好的流动。
(5)数据归一化处理
数据归一化的优点,一是对权值的微小变化不太敏感,更易于优化;二是一定程度上减少了数据内存消耗。
(6)损失函数
常见的损失函数有交叉熵损失、Hinge损失、平方误差损失等,其中针对样本不平衡情况也涌现出了较多的基于损失函数Trick的缓解方法。具体详细的关于损失函数的内容可见本文参考资料[14]。
(7)梯度下降(gradient descent)
在深度学习(神经网络)模型中,BP反向传播方法的核心就是对每层的权重参数不断使用梯度下降来进行优化 [13] 。梯度的方向是函数在给定点上升最快的方向,那么梯度的反方向就是函数在给定点下降最快的方向。在深度学习中,计算图使得反向传播算法实现变得简单。
小批量随机梯度下降(minibatch gradient descent)是目前常用的梯度下降策略。其在多次训练迭代之后,几乎能与使用完整数据集来计算梯度的模型精确率相近并避免了全读入的内存限制问题。同时,其缓解了单个观测值处理频繁随机梯度下降(随机梯度下降—SGD),频繁更新权重参数造成的大量耗时。
(8)学习率调整策略——WarmUp+
Warmup是一种学习率预热方法,在模型最初几个Epoch以较低学习率训练,后又改为预先设置的学习率以及调整策略进行训练。由于刚开始训练时,模型的权重是随机初始化的,此时若选择一个较大的学习率,可能使得模型学偏或不稳定(最初的更新方向可能也是毫无意义的)。而Warmup可以在训练初期使模型相对趋于稳定后再选择常见策略训练,从而更大可能使得模型收敛速度变快、模型效果更佳 [5]。
常见的学习率调整策略有分段常数衰减、指数衰减、余弦衰减、分段衰减、余弦退火、周期三角、warmup以及多种组合策略。本文主要针对warmup+其它学习率调整策略组合进行实现。若需要更为详细地了解学习率调整策略可参考本文参考资料[1][2][3][5]。
二、基于Jupyter的Pytorch自定义模型方法实践
(一)Pytorch自定义数据集读取类(父类Dataset)[6]
Dataset自定义设计,主要就是重写__len__()和__getitem__()两个成员函数。同时,我们可以在__getitem__()函数中加入对样本数据的预处理,该操作通常可以借助Torchvision.transforms包含的类操作。
值得注意的是:
1)样本及其对应标签样本数据的良好命名可以使得数据读入更为规范;
2)将图像样本数据归一化到0-1内可以节省内存;
3)对于标签数据,我们通常遵循类别从0-n(整型int)顺序标记为规范,将耕地标签改为1便于后续作交叉熵计算。
class MyDataset(Dataset):def __init__(self, imgPath, biaoqianPath):self.imgPath = imgPathself.biaoqianPath = biaoqianPathself.imgName = os.listdir(imgPath)# ToTensor会将数据归一化到0-1之间self.preProcess = ToTensor()def __len__(self):return len(self.imgName)def __getitem__(self, index):imgP = os.path.join(self.imgPath, self.imgName[index])biaoqianP = os.path.join(self.biaoqianPath, r"{0}_mask.png".format(self.imgName[index].split(".")[0]))# 读入单张图片数据为 H-W-C 即256×256×3img = cv2.imread(imgP)biaoqian = cv2.imread(biaoqianP)biaoqianGray = cv2.cvtColor(biaoqian, cv2.COLOR_BGR2GRAY)# 将耕地标签255更改为1(为规范化)biaoqianGray = biaoqianGray // 255# 样本多波段数据归一化img = self.preProcess(img)return img, biaoqianGray# DataLoder配合使用代码# RGB图像及语义分割标签数据绝对路径
trainPath = r"G:\ChanXueYan\dataset\train_val_test\train"
LtrainPath = r"G:\ChanXueYan\dataset\train_val_test\Ltrain"data_train = DataLoader(MyDataset(trainPath, LtrainPath), batch_size=5, shuffle=True, num_workers=0)
利用TensorBoard展示读入结果如下:
(二)神经网络模型自定义初始权重实践
具有适合的初始化权重的神经网络通常能获得较好的模型表现。不同的初始化权重方式在不同的激活函数上有不同的表现,例如,通常Kaiming初始化在搭配Relu激活函数后取得较好表现。接下来,我将展示部分Pytorch自定义初始化权重参数的方法。
需要注意的是,如果nn.Sequential容器里嵌套nn.Sequential容器,那么需要另开一个判断容器分支,从而再次在容器里循环初始化。
# 定义两个简单的卷积Block
ConvDemo = nn.Sequential(nn.Conv2d(3,2,3,1,1),nn.BatchNorm2d(2),nn.ReLU(),nn.Conv2d(3,2,3,1,1),nn.BatchNorm2d(2),nn.ReLU())# 第一种方法
# for layerpart in ConvDemo:
# # isinstance作用为判断变量是否为相同类
# if isinstance(layerpart,nn.Conv2d):
# nn.init.kaiming_normal_(layerpart.weight,mode="fan_out",nonlinearity="relu")
# elif isinstance(layerpart,nn.BatchNorm2d):
# nn.init.constant_(layerpart.weight,1)
# nn.init.constant_(layerpart.bias,0) # 第二种方法
def init_weights(m):if type(m) == nn.Conv2d:nn.init.kaiming_normal_(m.weight,mode="fan_out",nonlinearity="relu")elif type(m) == nn.BatchNorm2d:nn.init.constant_(m.weight,1)nn.init.constant_(m.bias,0)
ConvDemo.apply(init_weights)# 查看结果
print(ConvDemo.modules)
# 打印第一层卷积权重
print(ConvDemo[0].weight)
# 打印第一次层批量归一化权重和偏移
print(ConvDemo[1].weight)
print(ConvDemo[1].bias)
结果展示:
(三)神经网络模型自定义学习率下降算法实践
学习率的大小很重要。如果学习率太大,优化会发散;如果太小,训练就会需要过长时间,或者最终只能得到次优的结果。同时,学习率衰减速率同样重要。如果学习率持续过高,我们可能最终会在最小值附近弹跳,从而无法达到最优解。
对于学习率的动态调整,实际上,我们可以在每个迭代轮数(Epoch)或者每个小批量(Step)之后动态调整学习率,从而响应优化。在实现简单的pytorch自定义学习率调整策略时,我们主要定义一个学习率调度器(单因子调度器、多因子调度器)并重写类内__init__()、__call__()和step函数即可。
在pytorch中,我们通常使用以下方式获取并管理模型参数:
object.state_dict()——获取模型中各类对象的参数
current_lr = 优化器实例变量.state_dict()["param_groups"][0]["lr"]——获取当前学习率
接下来,我以“warmup+余弦调度器”为例,实现学习率自定义动态变化。(补充:余弦调度器是Loshchilov.Hutter.2016提出的一种启发式算法)
# 自定义学习率调度策略
# 该方法未定义scheduler.step()方法,需要在训练过程通过对优化器的学习率参数赋值更新
class WCscheduler:def __init__(self, base_lr):self.base_lr = base_lr# 设置最大cosine余弦下降轮次,在该轮次之后全以final_lr继续训练self.max_Epoch = 40# 学习率变化下限(学习率最低值)# 这里可自己根据实验需求调整(此处仅作随意展示)self.final_lr = 1e-9# 预热Epoch数设置,此处可能在5次完全迭代取得较好效果self.warmup_Epoch = 5# 预热过程初始学习率(此处仅作随意展示)self.warmup_begin_lr = 1e-9# 非预热过程的变化策略迭代总Epochself.cosine_Epochs = self.max_Epoch - self.warmup_Epochdef get_warmup_lr(self, epoch):warmup_increase = (self.base_lr - self.warmup_begin_lr) * float(epoch) / float( self.warmup_Epoch)return self.warmup_begin_lr + warmup_increasedef __call__(self, epoch):if epoch < self.warmup_Epoch:return self.get_warmup_lr(epoch)if epoch <= self.max_Epoch:self.base_lr = self.final_lr + ( self.base_lr - self.final_lr) * (1 + math.cos(math.pi * (epoch - self.warmup_Epoch) / self.cosine_Epochs)) / 2return self.base_lr# 学习率调度器(提示作用)
scheduler = WCscheduler(lr)# 在训练部分加入该部分
# 设置优化器学习率for param_group in optim.param_groups:param_group['lr'] = scheduler(epoch)
结果展示:
参考资料:
[1]《动手学深度学习》 — 动手学深度学习 2.0.0 documentation
[2]lr scheduler介绍和可视化 - 知乎
[3]learning rate 的schedule 和callbacks 总结 - 知乎
[4]理解语言的 Transformer 模型 | TensorFlow Core
[5]pytorch之warm-up预热学习策略_pytorch warmup_还能坚持的博客-CSDN博客
[6]源码级理解Pytorch中的Dataset和DataLoader_pytorch dataloader shuffle 所有样本_算法美食屋的博客-CSDN博客
[7]第4章 基础知识进阶 第4.1节 Python基础概念之迭代、可迭代对象、迭代器_LaoYuanPython的博客-CSDN博客
[8]【Python】获取变量占用的内存大小_python查看变量内存大小_是小菜欸的博客-CSDN博客
[9]python中int占几个字节_Python中的整型占多少个字节?_weixin_39997173的博客-CSDN博客
[10]Python彻底搞懂:变量、对象、赋值、引用、拷贝 - 知乎
[11]Xavier初始化_探索世界的小白的博客-CSDN博客
[12]pytorch——计算图与动态图机制_然后就去远行吧的博客-CSDN博客
[13]深度学习相关概念:梯度下降_深度学习梯度下降_AiCharm的博客-CSDN博客
[14]损失函数总结-应用和trick - 知乎
模型表现快速分析和可视化平台(李飞飞老师团队)
Weights & Biases – Developer tools for ML
相关文章:

基于Pytorch的神经网络部分自定义设计
一、基础概念(学习笔记) (1)训练误差和泛化误差[1] 本质上,优化和深度学习的目标是根本不同的。前者主要关注的是最小化目标,后者则关注在给定有限数据量的情况下寻找合适的模型。训练误差和泛化误差通常不…...
持续更新串联记忆English words
(一)这是一组关于“服装搭配”的单词。通过在记忆中检索,回忆起隐藏的信息吧~ >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>今日单词>>&…...
postgresql 内核源码分析 btree索引的增删查代码基本原理流程分析,索引膨胀的原因在这里
B-Tree索引代码流程分析 专栏内容: postgresql内核源码分析手写数据库toadb并发编程 开源贡献: toadb开源库 个人主页:我的主页 管理社区:开源数据库 座右铭:天行健,君子以自强不息;地势坤&…...

详细了解G1、了解G1、G1垃圾收集器详解、G1垃圾回收器简单调优
4.详细了解G1: 4.1.一:什么是垃圾回收 4.2.了解G1 4.3.G1 Yong GC 4.4.G1 Mix GC 4.5.三色标记算法 4.6.调优实践 5.G1垃圾收集器详解 5.1.G1垃圾收集器 5.2.G1的堆内存划分 5.3.G1的运行过程 5.4.三色标记 5.4.1.漏标问题 5.5.记忆集与卡表 5.6.安全点与…...
vue项目中 package.json 详解
在 Vue 项目中,package.json 是一个重要的配置文件,它包含了项目的名称、版本、作者、依赖等信息。下面是一份详细的 Vue 项目 package.json 配置说明: 1.name:项目的名称,用于标识项目,例如:&q…...

为什么要进行管网水位监测,管网水位监测的作用是什么
管网水位监测是城市排水系统管理的重要手段,对于保障城市排水设施安全运行和提升城市管理水平具有重要意义。通过对排水管网的水位进行实时监测和分析,能够及时发现问题并采取措施,提高排水系统的运行效率和管理水平。本文将详细介绍为什么要…...
webpack学习笔记
1. webpack基本概念 webpack: JavaScript 应用程序的静态模块打包器,是目前最为流行的JavaScript打包工具之一。webpack会以一个或多个js文件为入口,递归检查每个js模块的依赖,从而构建一个依赖关系图,然后依据该关系…...

解析代理IP在跨境电商和社媒营销中的关键作用
跨境电商和社媒营销领域的从业者深知,代理IP的价值愈发凸显。在推广营销的过程中,频繁遇到因IP关联而封禁账号的情况,或因使用不安全IP而导致异常问题。 这些问题促使人们开始高度重视代理IP的作用。但实际上,代理IP究竟是何物&a…...

Unity 之 Start 与Update 方法的区别
文章目录 当谈论Unity中的 Start和 Update方法时,我们实际上是在讨论MonoBehaviour类中的两个常用方法,用于编写游戏逻辑。这两个方法在不同的时机被调用,因此您可以根据需要选择在哪个方法中编写特定的代码。 Start 方法: Start…...
Spring Boot中如何编写优雅的单元测试
单元测试是指对软件中的最小可测试单元进行检查和验证。在Java中,单元测试的最小单元是类。通过编写针对类或方法的小段代码,来检验被测代码是否符合预期结果或行为。执行单元测试可以帮助开发者验证代码是否正确实现了功能需求,以及是否能够…...

三星Galaxy S23与iPhone 15的对比分析:谁会胜出?
三星Galaxy S23与iPhone 15的对决将于下个月进入高潮,这将是今年智能手机中最大的一场较量。毕竟,这是两家领先的移动设备制造商的旗舰手机。他们的手机的比较将在很大程度上决定谁能获得最佳手机的称号。 我们已经知道有利于三星Galaxy S23的情况,该产品自春季以来一直在推…...

MySQL索引 事物 存储引擎
一 索引 索引的概念 索引就是一种帮助系统能够更快速的查找信息的结构 索引的作用 索引的副作用 创建索引的规则 MySQL的优化 哪些字段/场景适合创建索引 哪些不适合 小字段唯一性强的字段更新不频繁,但查询率比较高的字段表记录超过 300行主键,外键…...
【谷粒学院】报错记录
无法从Nacos获取动态配置 原先gulimall-common中SpringCloud Alibaba的版本是2.1.0.RELEASE,无法从Nacos中获取配置文件信息 <dependencyManagement><dependencies><dependency><groupId>com.alibaba.cloud</groupId><artifactId&…...

微积分基本概念
微分 函数的微分是指对函数的局部变化的一种线性描述。微分可以近似地描述当函数自变量的取值作足够小的改变时,函数的值是怎样改变的。。对于函数 y f ( x ) y f(x) yf(x) 的微分记作: d y f ′ ( x ) d x d_y f^{}(x)d_x dyf′(x)dx 微分和…...

【业务功能篇78】微服务-前端后端校验- 统一异常处理-JSR-303-validation注解
5. 前端校验 我们在前端提交的表单数据,我们也是需要对提交的数据做相关的校验的 Form 组件提供了表单验证的功能,只需要通过 rules 属性传入约定的验证规则,并将 Form-Item 的 prop 属性设置为需校验的字段名即可 校验的页面效果 前端数据…...

pytorch的用法
...
Qt 设置窗口背景
窗口背景无非两种:背景色、背景图片。Qt中窗口背景如何设置? 一、QPalette设置背景 二、实现paintEvent,使用QPainter来绘制背景 三、使用QSS来设置背景 关于QSS的使用不想多说,一般我不用QSS设置窗口背景,也不建议…...

大模型是什么?泰迪大模型能够解决企业哪些痛点?
什么是大模型? 大模型是指模型具有庞大的参数规模和复杂程度的机器学习模型。在深度学习领域,大模型通常是指具有数百万到数十亿参数的神经网络模型。这些模型需要大量的计算资源和存储空间来训练和存储,并且往往需要进行分布式计算和特殊…...

YOLOv7-tracker 目标追踪 输入视频帧
目录 1 项目安装1.1 环境搭建1.2 项目下载1.3 权重下载1.4 环境安装1.5 上传待检测的视频帧 2 视频帧检测与追踪2.1 检测与追踪2.3 结果 参考项目:https://github.com/JackWoo0831/Yolov7-tracker/tree/master github链接:https://github.com/Whiffe/Yo…...

C语言二——C语言编写一段代码,求一元二次方程的根
这段代码实现了解一元二次方程的根的计算。用户需要输入方程的系数a、b、c,然后根据判别式的值确定方程的根的情况,并进行相应的输出。 如果判别式大于0,说明方程有两个实根,分别计算并输出。如果判别式等于0,说明方程…...
KubeSphere 容器平台高可用:环境搭建与可视化操作指南
Linux_k8s篇 欢迎来到Linux的世界,看笔记好好学多敲多打,每个人都是大神! 题目:KubeSphere 容器平台高可用:环境搭建与可视化操作指南 版本号: 1.0,0 作者: 老王要学习 日期: 2025.06.05 适用环境: Ubuntu22 文档说…...

国防科技大学计算机基础课程笔记02信息编码
1.机内码和国标码 国标码就是我们非常熟悉的这个GB2312,但是因为都是16进制,因此这个了16进制的数据既可以翻译成为这个机器码,也可以翻译成为这个国标码,所以这个时候很容易会出现这个歧义的情况; 因此,我们的这个国…...

汽车生产虚拟实训中的技能提升与生产优化
在制造业蓬勃发展的大背景下,虚拟教学实训宛如一颗璀璨的新星,正发挥着不可或缺且日益凸显的关键作用,源源不断地为企业的稳健前行与创新发展注入磅礴强大的动力。就以汽车制造企业这一极具代表性的行业主体为例,汽车生产线上各类…...
JVM垃圾回收机制全解析
Java虚拟机(JVM)中的垃圾收集器(Garbage Collector,简称GC)是用于自动管理内存的机制。它负责识别和清除不再被程序使用的对象,从而释放内存空间,避免内存泄漏和内存溢出等问题。垃圾收集器在Ja…...

如何将联系人从 iPhone 转移到 Android
从 iPhone 换到 Android 手机时,你可能需要保留重要的数据,例如通讯录。好在,将通讯录从 iPhone 转移到 Android 手机非常简单,你可以从本文中学习 6 种可靠的方法,确保随时保持连接,不错过任何信息。 第 1…...
Android Bitmap治理全解析:从加载优化到泄漏防控的全生命周期管理
引言 Bitmap(位图)是Android应用内存占用的“头号杀手”。一张1080P(1920x1080)的图片以ARGB_8888格式加载时,内存占用高达8MB(192010804字节)。据统计,超过60%的应用OOM崩溃与Bitm…...

学校时钟系统,标准考场时钟系统,AI亮相2025高考,赛思时钟系统为教育公平筑起“精准防线”
2025年#高考 将在近日拉开帷幕,#AI 监考一度冲上热搜。当AI深度融入高考,#时间同步 不再是辅助功能,而是决定AI监考系统成败的“生命线”。 AI亮相2025高考,40种异常行为0.5秒精准识别 2025年高考即将拉开帷幕,江西、…...
Hive 存储格式深度解析:从 TextFile 到 ORC,如何选对数据存储方案?
在大数据处理领域,Hive 作为 Hadoop 生态中重要的数据仓库工具,其存储格式的选择直接影响数据存储成本、查询效率和计算资源消耗。面对 TextFile、SequenceFile、Parquet、RCFile、ORC 等多种存储格式,很多开发者常常陷入选择困境。本文将从底…...

佰力博科技与您探讨热释电测量的几种方法
热释电的测量主要涉及热释电系数的测定,这是表征热释电材料性能的重要参数。热释电系数的测量方法主要包括静态法、动态法和积分电荷法。其中,积分电荷法最为常用,其原理是通过测量在电容器上积累的热释电电荷,从而确定热释电系数…...
【Go语言基础【12】】指针:声明、取地址、解引用
文章目录 零、概述:指针 vs. 引用(类比其他语言)一、指针基础概念二、指针声明与初始化三、指针操作符1. &:取地址(拿到内存地址)2. *:解引用(拿到值) 四、空指针&am…...