当前位置: 首页 > news >正文

learn_pytorch03

第三章
深度学习分为如下几个步骤
1:数据预处理,划分训练集和测试集
2:选择模型,设定损失函数和优化函数
3:用模型取拟合训练数据,并在验证计算模型上表现。
接着学习了一些数据读入
模型构建
损失函数的构建
以及训练
第四章
主要是基础实战。
一些细节:
class MLP(nn.Module):

声明带有模型参数的层,这里声明了两个全连接层

def init(self, **kwargs):
# 调用MLP父类Block的构造函数来进行必要的初始化。这样在构造实例时还可以指定其他函数
super(MLP, self).init(**kwargs)
self.hidden = nn.Linear(784, 256)
self.act = nn.ReLU()
self.output = nn.Linear(256,10)

定义模型的前向计算,即如何根据输入x计算返回所需要的模型输出

def forward(self, x):
o = self.act(self.hidden(x))
return self.output(o)
net(X) 会调用MLP继承自Module类的call函数,这个函数会调用MLP类定义的forward函数来完成前向计算

一个神经网络的典型训练过程如下:

定义包含一些可学习参数(或者叫权重)的神经网络

在输入数据集上迭代

通过网络处理输入

计算 loss (输出和正确答案的距离)

将梯度反向传播给网络的参数

更新网络的权重,一般使用一个简单的规则:weight = weight - learning_rate * gradient

torch.nn.init 作用
All the functions in this module are intended to be used to initialize neural network parameters, so they all run in torch.no_grad() mode and will not be taken into account by autograd.

损失函数的构建

torch.nn.BCELoss(weight=None, size_average=None, reduce=None, reduction=‘mean’)
功能:计算二分类任务时的交叉熵(Cross Entropy)函数。在二分类中,label是{0,1}。对于进入交叉熵函数的input为概率分布的形式。一般来说,input为sigmoid激活层的输出,或者softmax的输出。
weight:每个类别的loss设置权值

size_average:数据为bool,为True时,返回的loss为平均值;为False时,返回的各样本的loss之和。

reduce:数据类型为bool,为True时,loss的返回是标量

torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction=‘mean’)

功能:计算交叉熵函数

主要参数:

weight:每个类别的loss设置权值。

size_average:数据为bool,为True时,返回的loss为平均值;为False时,返回的各样本的loss之和。

ignore_index:忽略某个类的损失函数。

reduce:数据类型为bool,为True时,loss的返回是标量。

torch.nn.L1Loss(size_average=None, reduce=None, reduction=‘mean’)

功能: 计算输出y和真实标签target之间的差值的绝对值。

我们需要知道的是,reduction参数决定了计算模式。有三种计算模式可选:none:逐个元素计算。 sum:所有元素求和,返回标量。 mean:加权平均,返回标量。 如果选择none,那么返回的结果是和输入元素相同尺寸的。默认计算方式是求平均。

torch.nn.MSELoss(size_average=None, reduce=None, reduction=‘mean’)
功能: 计算输出y和真实标签target之差的平方

torch.nn.SmoothL1Loss(size_average=None, reduce=None, reduction=‘mean’, beta=1.0)
L1的平滑输出,其功能是减轻离群点带来的影响

reduction参数决定了计算模式。有三种计算模式可选:none:逐个元素计算。 sum:所有元素求和,返回标量。默认计算方式是求平均。

torch.nn.SmoothL1Loss(size_average=None, reduce=None, reduction=‘mean’, beta=1.0)
L1的平滑输出,其功能是减轻离群点带来的影响

reduction参数决定了计算模式。有三种计算模式可选:none:逐个元素计算。 sum:所有元素求和,返回标量。默认计算方式是求平均。

torch.nn.PoissonNLLLoss(log_input=True, full=False, size_average=None, eps=1e-08, reduce=None, reduction=‘mean’)
功能: 泊松分布的负对数似然损失函数

主要参数:

log_input:输入是否为对数形式,决定计算公式。

full:计算所有 loss,默认为 False。

eps:修正项,避免 input 为 0 时,log(input) 为 nan 的情况。

torch.nn.KLDivLoss(size_average=None, reduce=None, reduction=‘mean’, log_target=False)
功能: 计算KL散度,也就是计算相对熵。用于连续分布的距离度量,并且对离散采用的连续输出空间分布进行回归通常很有用。
reduction:计算模式,可为 none/sum/mean/batchmean。

none:逐个元素计算。

sum:所有元素求和,返回标量。

mean:加权平均,返回标量。

batchmean:batchsize 维度求平均值

torch.nn.MarginRankingLoss(margin=0.0, size_average=None, reduce=None, reduction=‘mean’)

功能: 计算两个向量之间的相似度,用于排序任务。该方法用于计算两组数据之间的差异。

margin:边界值,(x_{1}) 与(x_{2}) 之间的差异值。

reduction:计算模式,可为 none/sum/mean

torch.nn.MultiLabelMarginLoss(size_average=None, reduce=None, reduction=‘mean’)

功能: 对于多标签分类问题计算损失函数。

torch.nn.SoftMarginLoss(size_average=None, reduce=None, reduction=‘mean’)torch.nn.(size_average=None, reduce=None, reduction=‘mean’)

功能: 计算二分类的 logistic 损失。
reduction:计算模式,可为 none/sum/mean。

torch.nn.MultiMarginLoss(p=1, margin=1.0, weight=None, size_average=None, reduce=None, reduction=‘mean’)

功能: 计算多分类的折页损失

主要参数:

reduction:计算模式,可为 none/sum/mean。

p:可选 1 或 2。

weight:各类别的 loss 设置权值。

margin:边界值

torch.nn.TripletMarginLoss(margin=1.0, p=2.0, eps=1e-06, swap=False, size_average=None, reduce=None, reduction=‘mean’)

功能: 计算三元组损失。

三元组: 这是一种数据的存储或者使用格式。<实体1,关系,实体2>。在项目中,也可以表示为< anchor, positive examples , negative examples>

在这个损失函数中,我们希望去anchor的距离更接近positive examples,而远离negative examples

主要参数:

reduction:计算模式,可为 none/sum/mean。

p:可选 1 或 2。

margin:边界值

torch.nn.HingeEmbeddingLoss(margin=1.0, size_average=None, reduce=None, reduction=‘mean’)

功能: 对输出的embedding结果做Hing损失计算

主要参数:

reduction:计算模式,可为 none/sum/mean。

margin:边界值

torch.nn.CosineEmbeddingLoss(margin=0.0, size_average=None, reduce=None, reduction=‘mean’)

功能: 对两个向量做余弦相似度

主要参数:

reduction:计算模式,可为 none/sum/mean。

margin:可取值[-1,1] ,推荐为[0,0.5] 。

torch.nn.CTCLoss(blank=0, reduction=‘mean’, zero_infinity=False)

功能: 用于解决时序类数据的分类

计算连续时间序列和目标序列之间的损失。CTCLoss对输入和目标的可能排列的概率进行求和,产生一个损失值,这个损失值对每个输入节点来说是可分的。输入与目标的对齐方式被假定为 “多对一”,这就限制了目标序列的长度,使其必须是≤输入长度。

主要参数:

reduction:计算模式,可为 none/sum/mean。

blank:blank label。

zero_infinity:无穷大的值或梯度值为

训练和评估
model.train() # 训练状态
model.eval() # 验证/测试状态
for data, label in train_loader:
之后将数据放到GPU上用于后续计算,此处以.cuda()为例

data, label = data.cuda(), label.cuda()
开始用当前批次数据做训练时,应当先将优化器的梯度置零:

optimizer.zero_grad()
之后将data送入模型中训练:

output = model(data)
根据预先定义的criterion计算损失函数:

loss = criterion(output, label)
将loss反向传播回网络:

loss.backward()
使用优化器更新模型参数:

optimizer.step()
这样一个训练过程就完成了,后续还可以计算模型准确率等指标,这部分会在下一节的图像分类实战中加以介绍。

验证/测试的流程基本与训练过程一致,不同点在于:

需要预先设置torch.no_grad,以及将model调至eval模式

不需要将优化器的梯度置零

不需要将loss反向回传到网络

不需要更新optimizer

反向传播有什么作用

在模型训练过程中,反向传播(Backpropagation)是一种极为重要的算法,它主要用于计算损失函数相对于模型参数的梯度,进而指导参数的更新,最终让模型学习到数据中的模式和规律。下面详细介绍反向传播的作用:
计算梯度
链式法则的应用:在深度学习模型中,损失函数通常是一个复合函数,涉及众多的参数和中间变量。反向传播算法利用链式法则,从损失函数开始,沿着网络的反向路径,逐层计算损失函数相对于每个参数的梯度。
具体示例:以一个简单的多层感知机(MLP)为例,假设输入层有 两个输入,经过隐藏层和激活函数后得到输出 ,损失函数为 ( 是真实标签)。反向传播会从损失函数 开始,依次计算 对输出层权重、隐藏层权重等参数的偏导数,即梯度。
指导参数更新
梯度下降的基础:梯度代表了函数在某一点上升最快的方向,而我们的目标是最小化损失函数,因此需要朝着梯度的反方向更新模型参数。反向传播计算得到的梯度为梯度下降算法提供了必要的信息。
参数更新公式:在随机梯度下降(SGD)算法中,参数更新公式为 ,其中 是模型参数, 是学习率, 是通过反向传播计算得到的梯度。
模型学习与优化
自动调整参数:通过不断地进行前向传播计算损失,再通过反向传播计算梯度并更新参数,模型能够自动调整自身的权重和偏置,使得损失函数逐渐减小,从而提高模型的预测能力。
提高泛化能力:反向传播算法帮助模型在训练数据上学习到有效的特征表示,同时通过适当的正则化方法(如 L1、L2 正则化),可以避免模型过拟合,提高模型在未知数据上的泛化能力。
支持复杂网络结构
深度网络的训练:在深度神经网络中,网络结构通常非常复杂,包含多个隐藏层和大量的参数。反向传播算法能够有效地处理这种复杂的网络结构,通过逐层计算梯度,使得深度模型的训练成为可能。
不同类型的层:反向传播算法可以应用于各种类型的神经网络层,如全连接层、卷积层、循环层等,为不同类型的深度学习模型提供了统一的训练方法。
综上所述,反向传播在模型训练中起着核心作用,它使得深度学习模型能够高效地学习数据中的模式,不断优化自身的性能,从而在各种任务中取得良好的效果。

深度学习的目标是通过不断改变网络参数,使得参数能够对输入做各种非线性变换拟合输出,本质上就是一个函数去寻找最优解,只不过这个最优解是一个矩阵,而如何快速求得这个最优解是深度学习研究的一个重点,以经典的resnet-50为例,它大约有2000万个系数需要进行计算,那么我们如何计算出这么多系数,有以下两种方法:

第一种是直接暴力穷举一遍参数,这种方法实施可能性基本为0,堪比愚公移山plus的难度。

为了使求解参数过程更快,人们提出了第二种办法,即BP+优化器逼近求解。

因此,优化器是根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值,使得模型输出更加接近真实标签。

在训练深度学习模型时,每次迭代计算梯度之前通常需要将梯度置为 0,主要是基于 PyTorch、TensorFlow 等深度学习框架中梯度计算和更新的机制,下面从多个方面为你详细解释原因。
梯度累加机制
框架的梯度计算方式:在大多数深度学习框架里,梯度并不是每次计算后就自动清零,而是会进行累加。这是因为框架考虑到可能存在需要多次前向传播和反向传播来累积梯度,再进行一次参数更新的情况(如使用大批次数据但内存有限时)。
避免梯度错误累加:如果在每次迭代开始时不将梯度置为 0,那么新计算得到的梯度会和上一次迭代的梯度累加在一起。这样会导致梯度值不断增大,使得参数更新的步长失去控制,模型无法按照预期收敛,甚至可能导致梯度爆炸。
独立计算每次迭代的梯度
保证每次迭代的独立性:模型训练是通过多次迭代不断更新参数以最小化损失函数的过程。每次迭代都应该基于当前的参数状态和当前批次的数据独立计算梯度,这样才能准确反映当前批次数据对损失函数的影响,进而正确更新参数。
示例说明:以随机梯度下降(SGD)为例,每次迭代使用一个小批次的数据计算梯度并更新参数。如果不将梯度置为 0,那么下一次迭代计算的梯度就会受到上一次批次数据的干扰,无法准确反映当前批次数据的信息,从而影响模型的学习效果。

相关文章:

learn_pytorch03

第三章 深度学习分为如下几个步骤 1&#xff1a;数据预处理&#xff0c;划分训练集和测试集 2&#xff1a;选择模型&#xff0c;设定损失函数和优化函数 3&#xff1a;用模型取拟合训练数据&#xff0c;并在验证计算模型上表现。 接着学习了一些数据读入 模型构建 损失函数的构…...

机器学习:k近邻

所有代码和文档均在golitter/Decoding-ML-Top10: 使用 Python 优雅地实现机器学习十大经典算法。 (github.com)&#xff0c;欢迎查看。 K 邻近算法&#xff08;K-Nearest Neighbors&#xff0c;简称 KNN&#xff09;是一种经典的机器学习算法&#xff0c;主要用于分类和回归任务…...

redis之lua实现原理

文章目录 创建并修改Lua环境Lua环境协作组件伪客户端lua scripts字典 EVAL命令的实现定义脚本函数执行脚本函数 EVALSHA命令的实现脚本管理命令的实现SCRIPT FLUSHSCRIPTEXISTSSCRIPT LOADSCRIPT KILL 脚本复制复制 EVAL命令、SCRIPT FLUSH命令和SCRIPT LOAD命令* 复制EVALSHA命…...

[Android] 【汽车OBD软件】Torque Pro (OBD 2 Car)

[Android] 【汽车OBD软件】Torque Pro &#xff08;OBD 2 & Car&#xff09; 链接&#xff1a;https://pan.xunlei.com/s/VOIyKOKHBR-2XTUy6oy9A91yA1?pwdm5jm# 获取 OBD 故障代码、汽车性能数据等等。Torque 使用连接到您的 OBD2 发动机管理/ECU 的 OBD II 蓝牙适配器。…...

安全问答—安全的基本架构

前言 将一些安全相关的问答进行整理汇总和陈述&#xff0c;形成一些以问答呈现的东西&#xff0c;加入一些自己的理解&#xff0c;欢迎路过的各位大佬进行讨论和论述。很多内容都会从甲方的安全认知去进行阐述。 1.安全存在的目的&#xff1f; 为了支持组织的目标、使命和宗…...

Java 运行时常量池笔记(详细版

&#x1f4da; Java 运行时常量池笔记&#xff08;详细版&#xff09; Java 的运行时常量池&#xff08;Runtime Constant Pool&#xff09;是 JVM 方法区的一部分&#xff0c;用于存储编译期生成的字面量和符号引用。它是 Java 类文件常量池的运行时表示&#xff0c;具有动态…...

mysql增加字段操作以及关键字报错

修改mysql DDL语言 修改代码中domain 修改mapper中信息 java.sql.SQLSyntaxErrorException: You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use near index, date, scroll_id, shard_ser…...

Wireshark 输出 数据包列表本身的值

在 Wireshark 中&#xff0c;如果你想输出数据包列表本身的值&#xff08;例如&#xff0c;将数据包的摘要信息、时间戳、源地址、目的地址等导出为文本格式&#xff09;&#xff0c;可以使用 导出为纯文本文件 的功能。以下是详细步骤&#xff1a; 步骤 1&#xff1a;打开 Wir…...

日常开发中,使用JSON.stringify来实现深拷贝的坑

使用JSON.stringify的方式来实现深拷贝的弊端 弊端一&#xff1a;无法拷贝NaN、Infinity、undefined这类值 无法拷贝成功的原因&#xff1a; 对于JSON来说&#xff0c;它支持的数据类型只有null、string、number、boolean、Object、Array&#xff0c;所以对于它不支持的数据类…...

【探商宝】:大数据与AI赋能,助力中小企业精准拓客引

引言&#xff1a;在数据洪流中&#xff0c;如何精准锁定商机&#xff1f; 在竞争激烈的商业环境中&#xff0c;中小企业如何从海量信息中快速筛选出高价值客户&#xff1f;如何避免无效沟通&#xff0c;精准触达目标企业&#xff1f; 探商宝——一款基于大数据与AI技术的企业信…...

Javascript网页设计案例:通过PDF.js实现一款PDF阅读器,包括预览、页面旋转、页面切换、放大缩小、黑夜模式等功能

前言 目前功能包括&#xff1a; 切换到首页。切换到尾页。上一页。下一页。添加标签。标签管理页面旋转页面随意拖动双击后还原位置 其实按照自己的预期来说&#xff0c;有很多功能还没有开发完&#xff0c;配色也没有全都搞完&#xff0c;先发出来吧&#xff0c;后期有需要…...

各类系统Pycharm安装教程

各类系统Pycharm安装教程 一、安装前的准备 1. 系统要求 操作系统: Windows:Windows 10 或更高版本(64位)。macOS:macOS 10.14 或更高版本。Linux:Ubuntu 18.04+、Fedora 30+ 等主流发行版。硬件要求: 内存:至少 4GB(推荐 8GB 以上)。磁盘空间:至少 2.5GB 可用空间…...

哈希表(C语言版)

文章目录 哈希表原理实现(无自动扩容功能)代码运行结果 分析应用 哈希表 如何统计一段文本中&#xff0c;小写字母出现的次数? 显然&#xff0c;我们可以用数组 int table[26] 来存储每个小写字母出现的次数&#xff0c;而且这样处理&#xff0c;效率奇高。假如我们想知道字…...

内容中台驱动企业数字化内容管理高效协同架构

内容概要 在数字化转型加速的背景下&#xff0c;企业对内容管理的需求从单一存储向全链路协同演进。内容中台作为核心支撑架构&#xff0c;通过统一的内容资源池与智能化管理工具&#xff0c;重塑了内容生产、存储、分发及迭代的流程。其核心价值在于打破部门壁垒&#xff0c;…...

LLaMA-Factory DeepSeek-R1 模型 微调基础教程

LLaMA-Factory 模型 微调基础教程 LLaMA-FactoryLLaMA-Factory 下载 AnacondaAnaconda 环境创建软硬件依赖 详情LLaMA-Factory 依赖安装CUDA 安装量化 BitsAndBytes 安装可视化微调启动 数据集准备所需工具下载使用教程所需数据合并数据集预处理 DeepSeek-R1 可视化微调数据集处…...

vue 文件下载(导出)excel的方法

目前有一个到处功能的需求&#xff0c;这是我用过DeepSeek生成的导出&#xff08;下载&#xff09;excel的一个方法。 1.excel的文件名是后端生成的&#xff0c;放在了响应头那里。 2.这里也可以自己制定文件名。 3.axios用的是原生的axios&#xff0c;不要用处理过的&#xff…...

【第4章:循环神经网络(RNN)与长短时记忆网络(LSTM)— 4.3 RNN与LSTM在自然语言处理中的应用案例】

咱今天来聊聊在人工智能领域里,特别重要的两个神经网络:循环神经网络(RNN)和长短时记忆网络(LSTM),主要讲讲它们在自然语言处理里的应用。你想想,平常咱们用手机和别人聊天、看新闻、听语音助手说话,背后说不定就有 RNN 和 LSTM 在帮忙呢! 二、RNN 是什么? (一)…...

LLMs Ollama

LLMs 即大型语言模型&#xff08;Large Language Models&#xff09;&#xff0c;是人工智能领域基于深度学习的重要技术&#xff0c;以下是关于它的详细介绍&#xff1a; 定义与原理 定义&#xff1a;LLMs 是一类基于深度学习的人工智能模型&#xff0c;通过海量数据和大量计…...

Blackbox.AI:高效智能的生产力工具新选择

前言 在当今数字化时代&#xff0c;一款高效、智能且功能全面的工具对于开发者、设计师以及全栈工程师来说至关重要。Blackbox.AI凭借其独特的产品特点&#xff0c;在众多生产力工具中脱颖而出&#xff0c;成为了我近期测评的焦点。以下是我对Blackbox.AI的详细测评&#xff0…...

计算机专业知识【 轻松理解数据库四大运算:笛卡尔积、选择、投影与连接】

在数据库的世界里&#xff0c;有几个关键的运算操作&#xff0c;就像是神奇的魔法工具&#xff0c;能帮助我们对数据进行各种处理和组合。今天&#xff0c;咱们就来聊聊笛卡尔积运算、选择运算、投影运算和连接运算这四大运算&#xff0c;用超简单的例子让小白也能轻松理解。 …...

vscode里如何用git

打开vs终端执行如下&#xff1a; 1 初始化 Git 仓库&#xff08;如果尚未初始化&#xff09; git init 2 添加文件到 Git 仓库 git add . 3 使用 git commit 命令来提交你的更改。确保在提交时加上一个有用的消息。 git commit -m "备注信息" 4 …...

Java 8 Stream API 入门到实践详解

一、告别 for 循环&#xff01; 传统痛点&#xff1a; Java 8 之前&#xff0c;集合操作离不开冗长的 for 循环和匿名类。例如&#xff0c;过滤列表中的偶数&#xff1a; List<Integer> list Arrays.asList(1, 2, 3, 4, 5); List<Integer> evens new ArrayList…...

智慧工地云平台源码,基于微服务架构+Java+Spring Cloud +UniApp +MySql

智慧工地管理云平台系统&#xff0c;智慧工地全套源码&#xff0c;java版智慧工地源码&#xff0c;支持PC端、大屏端、移动端。 智慧工地聚焦建筑行业的市场需求&#xff0c;提供“平台网络终端”的整体解决方案&#xff0c;提供劳务管理、视频管理、智能监测、绿色施工、安全管…...

抖音增长新引擎:品融电商,一站式全案代运营领跑者

抖音增长新引擎&#xff1a;品融电商&#xff0c;一站式全案代运营领跑者 在抖音这个日活超7亿的流量汪洋中&#xff0c;品牌如何破浪前行&#xff1f;自建团队成本高、效果难控&#xff1b;碎片化运营又难成合力——这正是许多企业面临的增长困局。品融电商以「抖音全案代运营…...

定时器任务——若依源码分析

分析util包下面的工具类schedule utils&#xff1a; ScheduleUtils 是若依中用于与 Quartz 框架交互的工具类&#xff0c;封装了定时任务的 创建、更新、暂停、删除等核心逻辑。 createScheduleJob createScheduleJob 用于将任务注册到 Quartz&#xff0c;先构建任务的 JobD…...

三体问题详解

从物理学角度&#xff0c;三体问题之所以不稳定&#xff0c;是因为三个天体在万有引力作用下相互作用&#xff0c;形成一个非线性耦合系统。我们可以从牛顿经典力学出发&#xff0c;列出具体的运动方程&#xff0c;并说明为何这个系统本质上是混沌的&#xff0c;无法得到一般解…...

让AI看见世界:MCP协议与服务器的工作原理

让AI看见世界&#xff1a;MCP协议与服务器的工作原理 MCP&#xff08;Model Context Protocol&#xff09;是一种创新的通信协议&#xff0c;旨在让大型语言模型能够安全、高效地与外部资源进行交互。在AI技术快速发展的今天&#xff0c;MCP正成为连接AI与现实世界的重要桥梁。…...

智能AI电话机器人系统的识别能力现状与发展水平

一、引言 随着人工智能技术的飞速发展&#xff0c;AI电话机器人系统已经从简单的自动应答工具演变为具备复杂交互能力的智能助手。这类系统结合了语音识别、自然语言处理、情感计算和机器学习等多项前沿技术&#xff0c;在客户服务、营销推广、信息查询等领域发挥着越来越重要…...

Qemu arm操作系统开发环境

使用qemu虚拟arm硬件比较合适。 步骤如下&#xff1a; 安装qemu apt install qemu-system安装aarch64-none-elf-gcc 需要手动下载&#xff0c;下载地址&#xff1a;https://developer.arm.com/-/media/Files/downloads/gnu/13.2.rel1/binrel/arm-gnu-toolchain-13.2.rel1-x…...

STM32---外部32.768K晶振(LSE)无法起振问题

晶振是否起振主要就检查两个1、晶振与MCU是否兼容&#xff1b;2、晶振的负载电容是否匹配 目录 一、判断晶振与MCU是否兼容 二、判断负载电容是否匹配 1. 晶振负载电容&#xff08;CL&#xff09;与匹配电容&#xff08;CL1、CL2&#xff09;的关系 2. 如何选择 CL1 和 CL…...