当前位置: 首页 > news >正文

Pytorch入门(7)—— 梯度累加(Gradient Accumulation)

1. 梯度累加

  • 在训练大模型时,batch_size 最大值往往受限于显存容量上限,当模型非常大时,这个上限可能小到不可接受。梯度累加(Gradient Accumulation)是一个解决该问题的 trick
  • 梯度累加的思想很简单,就是时间换空间。具体而言,我们不在每个 batch data 梯度计算后直接更新模型,而是多算几个 batch 后,使用这些 batch 的平均梯度更新模型,从而放大等效 batch_size。如下图所示
    在这里插入图片描述
  • 用公式表示:设 batch size 为 n n n,模型参数为 w \pmb{w} w,样本 i i i 的损失为 l i l_i li,则正常情况下 sgd 参数更新为
    w ← w + α ∑ i = 1 n 1 n ∂ l i ∂ w \pmb{w} \leftarrow \pmb{w} + \alpha \sum_{i=1}^n\frac{1}{n}\frac{\partial l_i}{\partial \pmb{w}} ww+αi=1nn1wli 使用梯度累加时,设累加步长为 m m m(即计算 m m m 个 batch 梯度后用梯度均值更新一次),sgd 更新如下
    w ← w + α 1 m ∑ b = 1 m ∑ i = 1 n 1 n ∂ l b i ∂ w = w + α ∑ i = 1 m n 1 m n ∂ l i ∂ w \begin{aligned} \pmb{w} &\leftarrow \pmb{w} + \alpha \frac{1}{m} \sum_{b=1}^m \sum_{i=1}^n\frac{1}{n}\frac{\partial l_{bi}}{\partial \pmb{w}} \\ &= \pmb{w} + \alpha \sum_{i=1}^{mn}\frac{1}{mn} \frac{\partial l_i}{\partial \pmb{w}} \end{aligned} ww+αm1b=1mi=1nn1wlbi=w+αi=1mnmn1wli 可见这等价于使用 batch_size = m n mn mn 进行训练

2. 在 pytorch 中实现梯度累加

2.1 伪代码

  • pytorch 使用和 tensor 绑定的自动微分机制。每个 tensor 对象都有 .grad 属性存储其中每个元素的梯度值,通过 .requires_grad 属性控制其是否参与梯度计算。训练模型时,一般通过对标量 loss 执行 loss.backward() 自动进行反向传播,以得到计算图中所有 tensor 的梯度。详见 PyTorch入门(2)—— 自动求梯度
  • pytorch 中梯度 tensor.grad 不会自动清零,而会在每次反向传播过程中自动累加,所以一般在反向传播前把梯度清零
    for inputs, labels in data_loader:# forward pass preds = model(inputs)loss  = criterion(preds, labels)# clear grad of last batch	optimizer.zero_grad()# backward pass, calculate grad of batch dataloss.backward()# update modeloptimizer.step()
    
    这种设计对于实现梯度累加 trick 是很方便的,我们可以在 batch 计算过程中进行计数,仅在达到计数达到更新步长时进行一次参数更新并清零梯度,即
    # batch accumulation parameter
    accum_iter = 4  # loop through enumaretad batches
    for batch_idx, (inputs, labels) in enumerate(data_loader):# forward pass preds = model(inputs)loss  = criterion(preds, labels)# scale the loss to the mean of the accumulated batch sizeloss = loss / accum_iter # backward passloss.backward()# weights updateif ((batch_idx + 1) % accum_iter == 0) or (batch_idx + 1 == len(data_loader)):optimizer.step()optimizer.zero_grad()
    

2.2 线性回归案例

  • 下面使用来自 经典机器学习方法(1)—— 线性回归 的简单线性回归任务说明梯度累加的具体实现方法

    本节代码直接从 jupyter notebook 复制而来,可能无法直接运行!

  • 首先生成随机数据构造 dataset
    import torch
    from IPython import display
    from matplotlib import pyplot as plt
    import numpy as np
    import random
    import torch.utils.data as Data
    import torch.nn as nn
    import torch.optim as optim# 生成样本
    num_inputs = 2
    num_examples = 1000
    true_w = torch.Tensor([-2,3.4]).view(2,1)
    true_b = 4.2
    batch_size = 10# 1000 个2特征样本,每个特征都服从 N(0,1)
    features = torch.randn(num_examples, num_inputs, dtype=torch.float32) # 生成真实标记
    labels = torch.mm(features,true_w) + true_b
    labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float32)# 包装数据集,将训练数据的特征和标签组合
    dataset = Data.TensorDataset(features, labels)
    
    1. 不使用梯度累加技巧,batch size 设置为 40
      # 构造 DataLoader
      batch_size = 40
      data_iter = Data.DataLoader(dataset, batch_size, shuffle=False)	# shuffle=False 保证实验可比# 定义模型
      net = nn.Sequential(nn.Linear(num_inputs, 1))# 初始化模型参数
      nn.init.normal_(net[0].weight, mean=0, std=0)
      nn.init.constant_(net[0].bias, val=0)# 均方差损失函数
      criterion = nn.MSELoss()# SGD优化器
      optimizer = optim.SGD(net.parameters(), lr=0.01)# 模型训练
      num_epochs = 3
      for epoch in range(1, num_epochs + 1):epoch_loss = []for X, y in data_iter:# 正向传播,计算损失output = net(X) loss = criterion(output, y.view(-1, 1))# 梯度清零optimizer.zero_grad()            # 计算各参数梯度loss.backward()#print('backward: ', net[0].weight.grad)# 更新模型optimizer.step()epoch_loss.append(loss.item()/batch_size)print(f'epoch {epoch}, loss: {np.mean(epoch_loss)}')'''
      epoch 1, loss: 0.5434057731628418
      epoch 2, loss: 0.1914414196014404
      epoch 3, loss: 0.06752514398097992
      '''
      
    2. 使用梯度累加,batch size 设置为 10,步长设为 4,等效 batch size 为 40
      # 构造 DataLoader
      batch_size = 10
      accum_iter = 4
      data_iter = Data.DataLoader(dataset, batch_size, shuffle=False)	# shuffle=False 保证实验可比# 定义模型
      net = nn.Sequential(nn.Linear(num_inputs, 1))# 初始化模型参数
      nn.init.normal_(net[0].weight, mean=0, std=0)
      nn.init.constant_(net[0].bias, val=0)# 均方差损失
      criterion = nn.MSELoss()# SGD优化器对象
      optimizer = optim.SGD(net.parameters(), lr=0.01)# 模型训练
      num_epochs = 3
      for epoch in range(1, num_epochs + 1):epoch_loss = []for batch_idx, (X, y) in enumerate(data_iter):# 正向传播,计算损失output = net(X) loss = criterion(output, y.view(-1, 1))  loss = loss / accum_iter	# 取各个累计batch的平均损失,从而在.backward()时得到平均梯度# 反向传播,梯度累计loss.backward()if ((batch_idx + 1) % accum_iter == 0) or (batch_idx + 1 == len(data_iter)):#print('backward: ', net[0].weight.grad)# 更新模型optimizer.step()              # 梯度清零optimizer.zero_grad()epoch_loss.append(loss.item()/batch_size)print(f'epoch {epoch}, loss: {np.mean(epoch_loss)}')
      '''
      epoch 1, loss: 0.5434057596921921
      epoch 2, loss: 0.19144139245152472
      epoch 3, loss: 0.06752512042224407
      '''
      
  • 可以观察到无论 epoch loss 还是 net[0].weight.grad 都完全相同,说明梯度累加不影响计算结果

相关文章:

Pytorch入门(7)—— 梯度累加(Gradient Accumulation)

1. 梯度累加 在训练大模型时,batch_size 最大值往往受限于显存容量上限,当模型非常大时,这个上限可能小到不可接受。梯度累加(Gradient Accumulation)是一个解决该问题的 trick梯度累加的思想很简单,就是时…...

day12

第一题 本题我们可以使用以下方法&#xff1a; 方法一&#xff1a; 使用hash表<元素&#xff0c;出现次数>来统计字符串中不同元素分别出现的次数&#xff0c;当某一个元素的次数大于1时&#xff0c;返回false&#xff0c;如果每个元素的出现次数都为1&#xff0c;则返回…...

MySQL技术点合集

目录 1. MySQL目录 2. 验证是否首次登陆方法 3. 在Liunx中使用命令来输入sql语句方法 4. 获取修改密码 5. 关闭密码策略 6. 忘记MySQL密码找回 7. 旋转90度横向查看表 8. 添加一个远程连接的用户 1. MySQL目录 /usr/bin/mysql相关命令vim /etc/my.cnfmysql配置文件ls /…...

记录使用 Vue3 过程中的一些技术点

1、自定义组件&#xff0c;并使用 v-model 进行数据双向绑定。 简述&#xff1a; 自定义组件使用 v-model 进行传参时&#xff0c;遵循 Vue 3 的 v-model 机制。在 Vue 3 中&#xff0c;v-model 默认使用了 modelValue 作为 prop 名称&#xff0c;以及 update:modelValue 作为…...

6. C++通过fork的方式实现高性能网络服务器

我们上一节课写的tcp我们发现只有第一个与之连接的人才能收发信息。他又很多的不足 高性能网络服务器 通过fork实现高性能网络服务器 我们通过fork进行改装之后就可以成百上千的用户进行连接访问&#xff0c;对于每一个用户来说我们都fork一个子进程。让后让每一个子进程都是…...

直播美颜插件、美颜SDK详解:技术、功能与实现原理

今天&#xff0c;小编将详细解析直播美颜插件和美颜SDK的技术、功能以及实现原理。 一、美颜技术的背景与发展 1.1美颜技术的兴起 随着直播平台的普及&#xff0c;美颜SDK技术逐渐被集成到直播软件中&#xff0c;以满足用户对更美观、自然的直播效果的需求。 1.2美颜技术的…...

MyBatis入门(1)

目录 一、JDBC操作回顾 二、什么是MyBatis&#xff1f; 三、MyBatis入门 1、准备工作 &#xff08;1&#xff09;创建工程 &#xff08;2&#xff09;数据准备 2、配置数据库连接字符串 3、写持久层代码 4、单元测试 &#xff08;1&#xff09;使用IDEA自动成成测试类…...

打开服务器远程桌面连接不上,可能的原因及相应的解决策略

在解决远程桌面连接不上服务器的问题时&#xff0c;我们首先需要从专业的角度对可能的原因进行深入分析&#xff0c;并据此提出针对性的解决方案。以下是一些可能的原因及相应的解决策略&#xff1a; 一、网络连接问题 远程桌面连接需要稳定的网络支持&#xff0c;如果网络连接…...

用于时间序列概率预测的蒙特卡洛模拟

大家好&#xff0c;蒙特卡洛模拟是一种广泛应用于各个领域的计算技术&#xff0c;它通过从概率分布中随机抽取大量样本&#xff0c;并对结果进行统计分析&#xff0c;从而模拟复杂系统的行为。这种技术具有很强的适用性&#xff0c;在金融建模、工程设计、物理模拟、运筹优化以…...

VScode解决报错“Remote-SSH XHR failed无法访问远程服务器“的方案

VScode解决报错"Remote-SSH XHR failed无法访问远程服务器"的方案 $ ls ~/.vscode-server/bin 2ccd690cbff1569e4a83d7c43d45101f817401dc稳定版下载链接&#xff1a;https://update.code.visualstudio.com/commit:COMMIT_ID/server-linux-x64/stable 内测版下载链接…...

Python高级进阶--dict字典

dict字典⭐⭐ 1. 字典简介 dictionary&#xff08;字典&#xff09; 是 除列表以外 Python 之中 最灵活 的数据类型&#xff0c;类型为dict 字典同样可以用来存储多个数据字典使用键值对存储数据 2. 字典的定义 字典用{}定义键值对之间使用,分隔键和值之间使用:分隔 d {中…...

记忆力和人才测评,如何提升记忆力?

什么是记忆力&#xff1f; 如何通俗意义上的记忆力&#xff1f;我们可以把人的经历、经验理解成为一部纪录片&#xff0c;那么已经过去发生的事情&#xff0c;就是影片之前的情节&#xff0c;对于这些信息&#xff0c;在脑海里&#xff0c;人们会将其进行处理和组合&#xff…...

数据仓库建模

目录 数仓建模 为什么要对数据仓库进行分层 主题 主题的概念 维度建模&#xff1a; 模型的选择&#xff1a; 星形模式 雪花模型 星座模式 拉链表 维度表和事实表&#xff1a; 维度表 事实表 事实表设计规则 退化维度 事务事实表、周期快照事实表、累积快照事实…...

力扣:1738. 找出第 K 大的异或坐标值

1738. 找出第 K 大的异或坐标值 给你一个二维矩阵 matrix 和一个整数 k &#xff0c;矩阵大小为 m x n 由非负整数组成。 矩阵中坐标 (a, b) 的 值 可由对所有满足 0 < i < a < m 且 0 < j < b < n 的元素 matrix[i][j]&#xff08;下标从 0 开始计数&…...

Keras深度学习框架第二十讲:使用KerasCV中的Stable Diffusion进行高性能图像生成

1、绪论 1.1 概念 为便于后文讨论&#xff0c;首先进行相关概念的陈述。 Stable Diffusion&#xff1a;Stable Diffusion 是一个在图像生成领域广泛使用的技术&#xff0c;尤其是用于文本到图像的转换。它基于扩散模型&#xff08;Diffusion Models&#xff09;&#xff0c;这…...

C/C++ vector详解

要想了解STL&#xff0c;就必须会看&#xff1a; cplusplus.comhttps://legacy.cplusplus.com/ 官方内容全都是英文的&#xff0c;可以参考&#xff1a; C/C初始识https://blog.csdn.net/2301_77087344/article/details/138596294?spm1001.2014.3001.5501 vector&#xff…...

使用libtorch加载YOLOv8生成的torchscript文件进行目标检测

在网上下载了60多幅包含西瓜和冬瓜的图像组成melon数据集&#xff0c;使用 LabelMe 工具进行标注&#xff0c;然后使用 labelme2yolov8 脚本将json文件转换成YOLOv8支持的.txt文件&#xff0c;并自动生成YOLOv8支持的目录结构&#xff0c;包括melon.yaml文件&#xff0c;其内容…...

Oracle 并行和 session 数量的

这也就是为什么我们指定parallel为4&#xff0c;而实际并行度为8的原因。 insert create index&#xff0c;发现并行数都是加倍的 Indexes seem always created with parallel degree 1 during import as seen from a sqlfile. The sql file shows content like: CREATE INDE…...

Android 版本与 API level 以及 NDK 版本对应

采用 Android studio 开发 Android app 的时候&#xff0c;需要选择支持的最低 API Level 和使用的 NDK 版本&#xff0c;对应开发 app 的最低 SDK 版本&#xff1a; 在 app 的 build.gradle 文件里&#xff0c;对应于代码如下&#xff1a; 目前各版本的占有率情况如下&#xf…...

护网经验面试题目原版

文章目录 一、护网项目经验1.项目经验**Hvv的分组和流程**有没有遇到过有意思的逻辑漏洞&#xff1f;有没有自己开发过武器/工具&#xff1f;有做过代码审计吗&#xff1f;有0day吗有cve/cnvd吗&#xff1f;有src排名吗&#xff1f;有没有写过技战法有钓鱼经历吗&#xff1f;具…...

【Linux】shell脚本忽略错误继续执行

在 shell 脚本中&#xff0c;可以使用 set -e 命令来设置脚本在遇到错误时退出执行。如果你希望脚本忽略错误并继续执行&#xff0c;可以在脚本开头添加 set e 命令来取消该设置。 举例1 #!/bin/bash# 取消 set -e 的设置 set e# 执行命令&#xff0c;并忽略错误 rm somefile…...

基于Flask实现的医疗保险欺诈识别监测模型

基于Flask实现的医疗保险欺诈识别监测模型 项目截图 项目简介 社会医疗保险是国家通过立法形式强制实施&#xff0c;由雇主和个人按一定比例缴纳保险费&#xff0c;建立社会医疗保险基金&#xff0c;支付雇员医疗费用的一种医疗保险制度&#xff0c; 它是促进社会文明和进步的…...

STM32标准库-DMA直接存储器存取

文章目录 一、DMA1.1简介1.2存储器映像1.3DMA框图1.4DMA基本结构1.5DMA请求1.6数据宽度与对齐1.7数据转运DMA1.8ADC扫描模式DMA 二、数据转运DMA2.1接线图2.2代码2.3相关API 一、DMA 1.1简介 DMA&#xff08;Direct Memory Access&#xff09;直接存储器存取 DMA可以提供外设…...

第一篇:Agent2Agent (A2A) 协议——协作式人工智能的黎明

AI 领域的快速发展正在催生一个新时代&#xff0c;智能代理&#xff08;agents&#xff09;不再是孤立的个体&#xff0c;而是能够像一个数字团队一样协作。然而&#xff0c;当前 AI 生态系统的碎片化阻碍了这一愿景的实现&#xff0c;导致了“AI 巴别塔问题”——不同代理之间…...

ServerTrust 并非唯一

NSURLAuthenticationMethodServerTrust 只是 authenticationMethod 的冰山一角 要理解 NSURLAuthenticationMethodServerTrust, 首先要明白它只是 authenticationMethod 的选项之一, 并非唯一 1 先厘清概念 点说明authenticationMethodURLAuthenticationChallenge.protectionS…...

Spring AI 入门:Java 开发者的生成式 AI 实践之路

一、Spring AI 简介 在人工智能技术快速迭代的今天&#xff0c;Spring AI 作为 Spring 生态系统的新生力量&#xff0c;正在成为 Java 开发者拥抱生成式 AI 的最佳选择。该框架通过模块化设计实现了与主流 AI 服务&#xff08;如 OpenAI、Anthropic&#xff09;的无缝对接&…...

实现弹窗随键盘上移居中

实现弹窗随键盘上移的核心思路 在Android中&#xff0c;可以通过监听键盘的显示和隐藏事件&#xff0c;动态调整弹窗的位置。关键点在于获取键盘高度&#xff0c;并计算剩余屏幕空间以重新定位弹窗。 // 在Activity或Fragment中设置键盘监听 val rootView findViewById<V…...

ABAP设计模式之---“简单设计原则(Simple Design)”

“Simple Design”&#xff08;简单设计&#xff09;是软件开发中的一个重要理念&#xff0c;倡导以最简单的方式实现软件功能&#xff0c;以确保代码清晰易懂、易维护&#xff0c;并在项目需求变化时能够快速适应。 其核心目标是避免复杂和过度设计&#xff0c;遵循“让事情保…...

Linux 内存管理实战精讲:核心原理与面试常考点全解析

Linux 内存管理实战精讲&#xff1a;核心原理与面试常考点全解析 Linux 内核内存管理是系统设计中最复杂但也最核心的模块之一。它不仅支撑着虚拟内存机制、物理内存分配、进程隔离与资源复用&#xff0c;还直接决定系统运行的性能与稳定性。无论你是嵌入式开发者、内核调试工…...

【Nginx】使用 Nginx+Lua 实现基于 IP 的访问频率限制

使用 NginxLua 实现基于 IP 的访问频率限制 在高并发场景下&#xff0c;限制某个 IP 的访问频率是非常重要的&#xff0c;可以有效防止恶意攻击或错误配置导致的服务宕机。以下是一个详细的实现方案&#xff0c;使用 Nginx 和 Lua 脚本结合 Redis 来实现基于 IP 的访问频率限制…...