当前位置: 首页 > news >正文

Pytorch入门(7)—— 梯度累加(Gradient Accumulation)

1. 梯度累加

  • 在训练大模型时,batch_size 最大值往往受限于显存容量上限,当模型非常大时,这个上限可能小到不可接受。梯度累加(Gradient Accumulation)是一个解决该问题的 trick
  • 梯度累加的思想很简单,就是时间换空间。具体而言,我们不在每个 batch data 梯度计算后直接更新模型,而是多算几个 batch 后,使用这些 batch 的平均梯度更新模型,从而放大等效 batch_size。如下图所示
    在这里插入图片描述
  • 用公式表示:设 batch size 为 n n n,模型参数为 w \pmb{w} w,样本 i i i 的损失为 l i l_i li,则正常情况下 sgd 参数更新为
    w ← w + α ∑ i = 1 n 1 n ∂ l i ∂ w \pmb{w} \leftarrow \pmb{w} + \alpha \sum_{i=1}^n\frac{1}{n}\frac{\partial l_i}{\partial \pmb{w}} ww+αi=1nn1wli 使用梯度累加时,设累加步长为 m m m(即计算 m m m 个 batch 梯度后用梯度均值更新一次),sgd 更新如下
    w ← w + α 1 m ∑ b = 1 m ∑ i = 1 n 1 n ∂ l b i ∂ w = w + α ∑ i = 1 m n 1 m n ∂ l i ∂ w \begin{aligned} \pmb{w} &\leftarrow \pmb{w} + \alpha \frac{1}{m} \sum_{b=1}^m \sum_{i=1}^n\frac{1}{n}\frac{\partial l_{bi}}{\partial \pmb{w}} \\ &= \pmb{w} + \alpha \sum_{i=1}^{mn}\frac{1}{mn} \frac{\partial l_i}{\partial \pmb{w}} \end{aligned} ww+αm1b=1mi=1nn1wlbi=w+αi=1mnmn1wli 可见这等价于使用 batch_size = m n mn mn 进行训练

2. 在 pytorch 中实现梯度累加

2.1 伪代码

  • pytorch 使用和 tensor 绑定的自动微分机制。每个 tensor 对象都有 .grad 属性存储其中每个元素的梯度值,通过 .requires_grad 属性控制其是否参与梯度计算。训练模型时,一般通过对标量 loss 执行 loss.backward() 自动进行反向传播,以得到计算图中所有 tensor 的梯度。详见 PyTorch入门(2)—— 自动求梯度
  • pytorch 中梯度 tensor.grad 不会自动清零,而会在每次反向传播过程中自动累加,所以一般在反向传播前把梯度清零
    for inputs, labels in data_loader:# forward pass preds = model(inputs)loss  = criterion(preds, labels)# clear grad of last batch	optimizer.zero_grad()# backward pass, calculate grad of batch dataloss.backward()# update modeloptimizer.step()
    
    这种设计对于实现梯度累加 trick 是很方便的,我们可以在 batch 计算过程中进行计数,仅在达到计数达到更新步长时进行一次参数更新并清零梯度,即
    # batch accumulation parameter
    accum_iter = 4  # loop through enumaretad batches
    for batch_idx, (inputs, labels) in enumerate(data_loader):# forward pass preds = model(inputs)loss  = criterion(preds, labels)# scale the loss to the mean of the accumulated batch sizeloss = loss / accum_iter # backward passloss.backward()# weights updateif ((batch_idx + 1) % accum_iter == 0) or (batch_idx + 1 == len(data_loader)):optimizer.step()optimizer.zero_grad()
    

2.2 线性回归案例

  • 下面使用来自 经典机器学习方法(1)—— 线性回归 的简单线性回归任务说明梯度累加的具体实现方法

    本节代码直接从 jupyter notebook 复制而来,可能无法直接运行!

  • 首先生成随机数据构造 dataset
    import torch
    from IPython import display
    from matplotlib import pyplot as plt
    import numpy as np
    import random
    import torch.utils.data as Data
    import torch.nn as nn
    import torch.optim as optim# 生成样本
    num_inputs = 2
    num_examples = 1000
    true_w = torch.Tensor([-2,3.4]).view(2,1)
    true_b = 4.2
    batch_size = 10# 1000 个2特征样本,每个特征都服从 N(0,1)
    features = torch.randn(num_examples, num_inputs, dtype=torch.float32) # 生成真实标记
    labels = torch.mm(features,true_w) + true_b
    labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float32)# 包装数据集,将训练数据的特征和标签组合
    dataset = Data.TensorDataset(features, labels)
    
    1. 不使用梯度累加技巧,batch size 设置为 40
      # 构造 DataLoader
      batch_size = 40
      data_iter = Data.DataLoader(dataset, batch_size, shuffle=False)	# shuffle=False 保证实验可比# 定义模型
      net = nn.Sequential(nn.Linear(num_inputs, 1))# 初始化模型参数
      nn.init.normal_(net[0].weight, mean=0, std=0)
      nn.init.constant_(net[0].bias, val=0)# 均方差损失函数
      criterion = nn.MSELoss()# SGD优化器
      optimizer = optim.SGD(net.parameters(), lr=0.01)# 模型训练
      num_epochs = 3
      for epoch in range(1, num_epochs + 1):epoch_loss = []for X, y in data_iter:# 正向传播,计算损失output = net(X) loss = criterion(output, y.view(-1, 1))# 梯度清零optimizer.zero_grad()            # 计算各参数梯度loss.backward()#print('backward: ', net[0].weight.grad)# 更新模型optimizer.step()epoch_loss.append(loss.item()/batch_size)print(f'epoch {epoch}, loss: {np.mean(epoch_loss)}')'''
      epoch 1, loss: 0.5434057731628418
      epoch 2, loss: 0.1914414196014404
      epoch 3, loss: 0.06752514398097992
      '''
      
    2. 使用梯度累加,batch size 设置为 10,步长设为 4,等效 batch size 为 40
      # 构造 DataLoader
      batch_size = 10
      accum_iter = 4
      data_iter = Data.DataLoader(dataset, batch_size, shuffle=False)	# shuffle=False 保证实验可比# 定义模型
      net = nn.Sequential(nn.Linear(num_inputs, 1))# 初始化模型参数
      nn.init.normal_(net[0].weight, mean=0, std=0)
      nn.init.constant_(net[0].bias, val=0)# 均方差损失
      criterion = nn.MSELoss()# SGD优化器对象
      optimizer = optim.SGD(net.parameters(), lr=0.01)# 模型训练
      num_epochs = 3
      for epoch in range(1, num_epochs + 1):epoch_loss = []for batch_idx, (X, y) in enumerate(data_iter):# 正向传播,计算损失output = net(X) loss = criterion(output, y.view(-1, 1))  loss = loss / accum_iter	# 取各个累计batch的平均损失,从而在.backward()时得到平均梯度# 反向传播,梯度累计loss.backward()if ((batch_idx + 1) % accum_iter == 0) or (batch_idx + 1 == len(data_iter)):#print('backward: ', net[0].weight.grad)# 更新模型optimizer.step()              # 梯度清零optimizer.zero_grad()epoch_loss.append(loss.item()/batch_size)print(f'epoch {epoch}, loss: {np.mean(epoch_loss)}')
      '''
      epoch 1, loss: 0.5434057596921921
      epoch 2, loss: 0.19144139245152472
      epoch 3, loss: 0.06752512042224407
      '''
      
  • 可以观察到无论 epoch loss 还是 net[0].weight.grad 都完全相同,说明梯度累加不影响计算结果

相关文章:

Pytorch入门(7)—— 梯度累加(Gradient Accumulation)

1. 梯度累加 在训练大模型时,batch_size 最大值往往受限于显存容量上限,当模型非常大时,这个上限可能小到不可接受。梯度累加(Gradient Accumulation)是一个解决该问题的 trick梯度累加的思想很简单,就是时…...

day12

第一题 本题我们可以使用以下方法&#xff1a; 方法一&#xff1a; 使用hash表<元素&#xff0c;出现次数>来统计字符串中不同元素分别出现的次数&#xff0c;当某一个元素的次数大于1时&#xff0c;返回false&#xff0c;如果每个元素的出现次数都为1&#xff0c;则返回…...

MySQL技术点合集

目录 1. MySQL目录 2. 验证是否首次登陆方法 3. 在Liunx中使用命令来输入sql语句方法 4. 获取修改密码 5. 关闭密码策略 6. 忘记MySQL密码找回 7. 旋转90度横向查看表 8. 添加一个远程连接的用户 1. MySQL目录 /usr/bin/mysql相关命令vim /etc/my.cnfmysql配置文件ls /…...

记录使用 Vue3 过程中的一些技术点

1、自定义组件&#xff0c;并使用 v-model 进行数据双向绑定。 简述&#xff1a; 自定义组件使用 v-model 进行传参时&#xff0c;遵循 Vue 3 的 v-model 机制。在 Vue 3 中&#xff0c;v-model 默认使用了 modelValue 作为 prop 名称&#xff0c;以及 update:modelValue 作为…...

6. C++通过fork的方式实现高性能网络服务器

我们上一节课写的tcp我们发现只有第一个与之连接的人才能收发信息。他又很多的不足 高性能网络服务器 通过fork实现高性能网络服务器 我们通过fork进行改装之后就可以成百上千的用户进行连接访问&#xff0c;对于每一个用户来说我们都fork一个子进程。让后让每一个子进程都是…...

直播美颜插件、美颜SDK详解:技术、功能与实现原理

今天&#xff0c;小编将详细解析直播美颜插件和美颜SDK的技术、功能以及实现原理。 一、美颜技术的背景与发展 1.1美颜技术的兴起 随着直播平台的普及&#xff0c;美颜SDK技术逐渐被集成到直播软件中&#xff0c;以满足用户对更美观、自然的直播效果的需求。 1.2美颜技术的…...

MyBatis入门(1)

目录 一、JDBC操作回顾 二、什么是MyBatis&#xff1f; 三、MyBatis入门 1、准备工作 &#xff08;1&#xff09;创建工程 &#xff08;2&#xff09;数据准备 2、配置数据库连接字符串 3、写持久层代码 4、单元测试 &#xff08;1&#xff09;使用IDEA自动成成测试类…...

打开服务器远程桌面连接不上,可能的原因及相应的解决策略

在解决远程桌面连接不上服务器的问题时&#xff0c;我们首先需要从专业的角度对可能的原因进行深入分析&#xff0c;并据此提出针对性的解决方案。以下是一些可能的原因及相应的解决策略&#xff1a; 一、网络连接问题 远程桌面连接需要稳定的网络支持&#xff0c;如果网络连接…...

用于时间序列概率预测的蒙特卡洛模拟

大家好&#xff0c;蒙特卡洛模拟是一种广泛应用于各个领域的计算技术&#xff0c;它通过从概率分布中随机抽取大量样本&#xff0c;并对结果进行统计分析&#xff0c;从而模拟复杂系统的行为。这种技术具有很强的适用性&#xff0c;在金融建模、工程设计、物理模拟、运筹优化以…...

VScode解决报错“Remote-SSH XHR failed无法访问远程服务器“的方案

VScode解决报错"Remote-SSH XHR failed无法访问远程服务器"的方案 $ ls ~/.vscode-server/bin 2ccd690cbff1569e4a83d7c43d45101f817401dc稳定版下载链接&#xff1a;https://update.code.visualstudio.com/commit:COMMIT_ID/server-linux-x64/stable 内测版下载链接…...

Python高级进阶--dict字典

dict字典⭐⭐ 1. 字典简介 dictionary&#xff08;字典&#xff09; 是 除列表以外 Python 之中 最灵活 的数据类型&#xff0c;类型为dict 字典同样可以用来存储多个数据字典使用键值对存储数据 2. 字典的定义 字典用{}定义键值对之间使用,分隔键和值之间使用:分隔 d {中…...

记忆力和人才测评,如何提升记忆力?

什么是记忆力&#xff1f; 如何通俗意义上的记忆力&#xff1f;我们可以把人的经历、经验理解成为一部纪录片&#xff0c;那么已经过去发生的事情&#xff0c;就是影片之前的情节&#xff0c;对于这些信息&#xff0c;在脑海里&#xff0c;人们会将其进行处理和组合&#xff…...

数据仓库建模

目录 数仓建模 为什么要对数据仓库进行分层 主题 主题的概念 维度建模&#xff1a; 模型的选择&#xff1a; 星形模式 雪花模型 星座模式 拉链表 维度表和事实表&#xff1a; 维度表 事实表 事实表设计规则 退化维度 事务事实表、周期快照事实表、累积快照事实…...

力扣:1738. 找出第 K 大的异或坐标值

1738. 找出第 K 大的异或坐标值 给你一个二维矩阵 matrix 和一个整数 k &#xff0c;矩阵大小为 m x n 由非负整数组成。 矩阵中坐标 (a, b) 的 值 可由对所有满足 0 < i < a < m 且 0 < j < b < n 的元素 matrix[i][j]&#xff08;下标从 0 开始计数&…...

Keras深度学习框架第二十讲:使用KerasCV中的Stable Diffusion进行高性能图像生成

1、绪论 1.1 概念 为便于后文讨论&#xff0c;首先进行相关概念的陈述。 Stable Diffusion&#xff1a;Stable Diffusion 是一个在图像生成领域广泛使用的技术&#xff0c;尤其是用于文本到图像的转换。它基于扩散模型&#xff08;Diffusion Models&#xff09;&#xff0c;这…...

C/C++ vector详解

要想了解STL&#xff0c;就必须会看&#xff1a; cplusplus.comhttps://legacy.cplusplus.com/ 官方内容全都是英文的&#xff0c;可以参考&#xff1a; C/C初始识https://blog.csdn.net/2301_77087344/article/details/138596294?spm1001.2014.3001.5501 vector&#xff…...

使用libtorch加载YOLOv8生成的torchscript文件进行目标检测

在网上下载了60多幅包含西瓜和冬瓜的图像组成melon数据集&#xff0c;使用 LabelMe 工具进行标注&#xff0c;然后使用 labelme2yolov8 脚本将json文件转换成YOLOv8支持的.txt文件&#xff0c;并自动生成YOLOv8支持的目录结构&#xff0c;包括melon.yaml文件&#xff0c;其内容…...

Oracle 并行和 session 数量的

这也就是为什么我们指定parallel为4&#xff0c;而实际并行度为8的原因。 insert create index&#xff0c;发现并行数都是加倍的 Indexes seem always created with parallel degree 1 during import as seen from a sqlfile. The sql file shows content like: CREATE INDE…...

Android 版本与 API level 以及 NDK 版本对应

采用 Android studio 开发 Android app 的时候&#xff0c;需要选择支持的最低 API Level 和使用的 NDK 版本&#xff0c;对应开发 app 的最低 SDK 版本&#xff1a; 在 app 的 build.gradle 文件里&#xff0c;对应于代码如下&#xff1a; 目前各版本的占有率情况如下&#xf…...

护网经验面试题目原版

文章目录 一、护网项目经验1.项目经验**Hvv的分组和流程**有没有遇到过有意思的逻辑漏洞&#xff1f;有没有自己开发过武器/工具&#xff1f;有做过代码审计吗&#xff1f;有0day吗有cve/cnvd吗&#xff1f;有src排名吗&#xff1f;有没有写过技战法有钓鱼经历吗&#xff1f;具…...

MPNet:旋转机械轻量化故障诊断模型详解python代码复现

目录 一、问题背景与挑战 二、MPNet核心架构 2.1 多分支特征融合模块(MBFM) 2.2 残差注意力金字塔模块(RAPM) 2.2.1 空间金字塔注意力(SPA) 2.2.2 金字塔残差块(PRBlock) 2.3 分类器设计 三、关键技术突破 3.1 多尺度特征融合 3.2 轻量化设计策略 3.3 抗噪声…...

零门槛NAS搭建:WinNAS如何让普通电脑秒变私有云?

一、核心优势&#xff1a;专为Windows用户设计的极简NAS WinNAS由深圳耘想存储科技开发&#xff0c;是一款收费低廉但功能全面的Windows NAS工具&#xff0c;主打“无学习成本部署” 。与其他NAS软件相比&#xff0c;其优势在于&#xff1a; 无需硬件改造&#xff1a;将任意W…...

Java 语言特性(面试系列1)

一、面向对象编程 1. 封装&#xff08;Encapsulation&#xff09; 定义&#xff1a;将数据&#xff08;属性&#xff09;和操作数据的方法绑定在一起&#xff0c;通过访问控制符&#xff08;private、protected、public&#xff09;隐藏内部实现细节。示例&#xff1a; public …...

DAY 47

三、通道注意力 3.1 通道注意力的定义 # 新增&#xff1a;通道注意力模块&#xff08;SE模块&#xff09; class ChannelAttention(nn.Module):"""通道注意力模块(Squeeze-and-Excitation)"""def __init__(self, in_channels, reduction_rat…...

鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个医院挂号小程序

一、开发准备 ​​环境搭建​​&#xff1a; 安装DevEco Studio 3.0或更高版本配置HarmonyOS SDK申请开发者账号 ​​项目创建​​&#xff1a; File > New > Create Project > Application (选择"Empty Ability") 二、核心功能实现 1. 医院科室展示 /…...

vue3 字体颜色设置的多种方式

在Vue 3中设置字体颜色可以通过多种方式实现&#xff0c;这取决于你是想在组件内部直接设置&#xff0c;还是在CSS/SCSS/LESS等样式文件中定义。以下是几种常见的方法&#xff1a; 1. 内联样式 你可以直接在模板中使用style绑定来设置字体颜色。 <template><div :s…...

Keil 中设置 STM32 Flash 和 RAM 地址详解

文章目录 Keil 中设置 STM32 Flash 和 RAM 地址详解一、Flash 和 RAM 配置界面(Target 选项卡)1. IROM1(用于配置 Flash)2. IRAM1(用于配置 RAM)二、链接器设置界面(Linker 选项卡)1. 勾选“Use Memory Layout from Target Dialog”2. 查看链接器参数(如果没有勾选上面…...

零基础设计模式——行为型模式 - 责任链模式

第四部分&#xff1a;行为型模式 - 责任链模式 (Chain of Responsibility Pattern) 欢迎来到行为型模式的学习&#xff01;行为型模式关注对象之间的职责分配、算法封装和对象间的交互。我们将学习的第一个行为型模式是责任链模式。 核心思想&#xff1a;使多个对象都有机会处…...

HashMap中的put方法执行流程(流程图)

1 put操作整体流程 HashMap 的 put 操作是其最核心的功能之一。在 JDK 1.8 及以后版本中&#xff0c;其主要逻辑封装在 putVal 这个内部方法中。整个过程大致如下&#xff1a; 初始判断与哈希计算&#xff1a; 首先&#xff0c;putVal 方法会检查当前的 table&#xff08;也就…...

LangChain知识库管理后端接口:数据库操作详解—— 构建本地知识库系统的基础《二》

这段 Python 代码是一个完整的 知识库数据库操作模块&#xff0c;用于对本地知识库系统中的知识库进行增删改查&#xff08;CRUD&#xff09;操作。它基于 SQLAlchemy ORM 框架 和一个自定义的装饰器 with_session 实现数据库会话管理。 &#x1f4d8; 一、整体功能概述 该模块…...