当前位置: 首页 > news >正文

Pytorch入门(7)—— 梯度累加(Gradient Accumulation)

1. 梯度累加

  • 在训练大模型时,batch_size 最大值往往受限于显存容量上限,当模型非常大时,这个上限可能小到不可接受。梯度累加(Gradient Accumulation)是一个解决该问题的 trick
  • 梯度累加的思想很简单,就是时间换空间。具体而言,我们不在每个 batch data 梯度计算后直接更新模型,而是多算几个 batch 后,使用这些 batch 的平均梯度更新模型,从而放大等效 batch_size。如下图所示
    在这里插入图片描述
  • 用公式表示:设 batch size 为 n n n,模型参数为 w \pmb{w} w,样本 i i i 的损失为 l i l_i li,则正常情况下 sgd 参数更新为
    w ← w + α ∑ i = 1 n 1 n ∂ l i ∂ w \pmb{w} \leftarrow \pmb{w} + \alpha \sum_{i=1}^n\frac{1}{n}\frac{\partial l_i}{\partial \pmb{w}} ww+αi=1nn1wli 使用梯度累加时,设累加步长为 m m m(即计算 m m m 个 batch 梯度后用梯度均值更新一次),sgd 更新如下
    w ← w + α 1 m ∑ b = 1 m ∑ i = 1 n 1 n ∂ l b i ∂ w = w + α ∑ i = 1 m n 1 m n ∂ l i ∂ w \begin{aligned} \pmb{w} &\leftarrow \pmb{w} + \alpha \frac{1}{m} \sum_{b=1}^m \sum_{i=1}^n\frac{1}{n}\frac{\partial l_{bi}}{\partial \pmb{w}} \\ &= \pmb{w} + \alpha \sum_{i=1}^{mn}\frac{1}{mn} \frac{\partial l_i}{\partial \pmb{w}} \end{aligned} ww+αm1b=1mi=1nn1wlbi=w+αi=1mnmn1wli 可见这等价于使用 batch_size = m n mn mn 进行训练

2. 在 pytorch 中实现梯度累加

2.1 伪代码

  • pytorch 使用和 tensor 绑定的自动微分机制。每个 tensor 对象都有 .grad 属性存储其中每个元素的梯度值,通过 .requires_grad 属性控制其是否参与梯度计算。训练模型时,一般通过对标量 loss 执行 loss.backward() 自动进行反向传播,以得到计算图中所有 tensor 的梯度。详见 PyTorch入门(2)—— 自动求梯度
  • pytorch 中梯度 tensor.grad 不会自动清零,而会在每次反向传播过程中自动累加,所以一般在反向传播前把梯度清零
    for inputs, labels in data_loader:# forward pass preds = model(inputs)loss  = criterion(preds, labels)# clear grad of last batch	optimizer.zero_grad()# backward pass, calculate grad of batch dataloss.backward()# update modeloptimizer.step()
    
    这种设计对于实现梯度累加 trick 是很方便的,我们可以在 batch 计算过程中进行计数,仅在达到计数达到更新步长时进行一次参数更新并清零梯度,即
    # batch accumulation parameter
    accum_iter = 4  # loop through enumaretad batches
    for batch_idx, (inputs, labels) in enumerate(data_loader):# forward pass preds = model(inputs)loss  = criterion(preds, labels)# scale the loss to the mean of the accumulated batch sizeloss = loss / accum_iter # backward passloss.backward()# weights updateif ((batch_idx + 1) % accum_iter == 0) or (batch_idx + 1 == len(data_loader)):optimizer.step()optimizer.zero_grad()
    

2.2 线性回归案例

  • 下面使用来自 经典机器学习方法(1)—— 线性回归 的简单线性回归任务说明梯度累加的具体实现方法

    本节代码直接从 jupyter notebook 复制而来,可能无法直接运行!

  • 首先生成随机数据构造 dataset
    import torch
    from IPython import display
    from matplotlib import pyplot as plt
    import numpy as np
    import random
    import torch.utils.data as Data
    import torch.nn as nn
    import torch.optim as optim# 生成样本
    num_inputs = 2
    num_examples = 1000
    true_w = torch.Tensor([-2,3.4]).view(2,1)
    true_b = 4.2
    batch_size = 10# 1000 个2特征样本,每个特征都服从 N(0,1)
    features = torch.randn(num_examples, num_inputs, dtype=torch.float32) # 生成真实标记
    labels = torch.mm(features,true_w) + true_b
    labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float32)# 包装数据集,将训练数据的特征和标签组合
    dataset = Data.TensorDataset(features, labels)
    
    1. 不使用梯度累加技巧,batch size 设置为 40
      # 构造 DataLoader
      batch_size = 40
      data_iter = Data.DataLoader(dataset, batch_size, shuffle=False)	# shuffle=False 保证实验可比# 定义模型
      net = nn.Sequential(nn.Linear(num_inputs, 1))# 初始化模型参数
      nn.init.normal_(net[0].weight, mean=0, std=0)
      nn.init.constant_(net[0].bias, val=0)# 均方差损失函数
      criterion = nn.MSELoss()# SGD优化器
      optimizer = optim.SGD(net.parameters(), lr=0.01)# 模型训练
      num_epochs = 3
      for epoch in range(1, num_epochs + 1):epoch_loss = []for X, y in data_iter:# 正向传播,计算损失output = net(X) loss = criterion(output, y.view(-1, 1))# 梯度清零optimizer.zero_grad()            # 计算各参数梯度loss.backward()#print('backward: ', net[0].weight.grad)# 更新模型optimizer.step()epoch_loss.append(loss.item()/batch_size)print(f'epoch {epoch}, loss: {np.mean(epoch_loss)}')'''
      epoch 1, loss: 0.5434057731628418
      epoch 2, loss: 0.1914414196014404
      epoch 3, loss: 0.06752514398097992
      '''
      
    2. 使用梯度累加,batch size 设置为 10,步长设为 4,等效 batch size 为 40
      # 构造 DataLoader
      batch_size = 10
      accum_iter = 4
      data_iter = Data.DataLoader(dataset, batch_size, shuffle=False)	# shuffle=False 保证实验可比# 定义模型
      net = nn.Sequential(nn.Linear(num_inputs, 1))# 初始化模型参数
      nn.init.normal_(net[0].weight, mean=0, std=0)
      nn.init.constant_(net[0].bias, val=0)# 均方差损失
      criterion = nn.MSELoss()# SGD优化器对象
      optimizer = optim.SGD(net.parameters(), lr=0.01)# 模型训练
      num_epochs = 3
      for epoch in range(1, num_epochs + 1):epoch_loss = []for batch_idx, (X, y) in enumerate(data_iter):# 正向传播,计算损失output = net(X) loss = criterion(output, y.view(-1, 1))  loss = loss / accum_iter	# 取各个累计batch的平均损失,从而在.backward()时得到平均梯度# 反向传播,梯度累计loss.backward()if ((batch_idx + 1) % accum_iter == 0) or (batch_idx + 1 == len(data_iter)):#print('backward: ', net[0].weight.grad)# 更新模型optimizer.step()              # 梯度清零optimizer.zero_grad()epoch_loss.append(loss.item()/batch_size)print(f'epoch {epoch}, loss: {np.mean(epoch_loss)}')
      '''
      epoch 1, loss: 0.5434057596921921
      epoch 2, loss: 0.19144139245152472
      epoch 3, loss: 0.06752512042224407
      '''
      
  • 可以观察到无论 epoch loss 还是 net[0].weight.grad 都完全相同,说明梯度累加不影响计算结果

相关文章:

Pytorch入门(7)—— 梯度累加(Gradient Accumulation)

1. 梯度累加 在训练大模型时,batch_size 最大值往往受限于显存容量上限,当模型非常大时,这个上限可能小到不可接受。梯度累加(Gradient Accumulation)是一个解决该问题的 trick梯度累加的思想很简单,就是时…...

day12

第一题 本题我们可以使用以下方法&#xff1a; 方法一&#xff1a; 使用hash表<元素&#xff0c;出现次数>来统计字符串中不同元素分别出现的次数&#xff0c;当某一个元素的次数大于1时&#xff0c;返回false&#xff0c;如果每个元素的出现次数都为1&#xff0c;则返回…...

MySQL技术点合集

目录 1. MySQL目录 2. 验证是否首次登陆方法 3. 在Liunx中使用命令来输入sql语句方法 4. 获取修改密码 5. 关闭密码策略 6. 忘记MySQL密码找回 7. 旋转90度横向查看表 8. 添加一个远程连接的用户 1. MySQL目录 /usr/bin/mysql相关命令vim /etc/my.cnfmysql配置文件ls /…...

记录使用 Vue3 过程中的一些技术点

1、自定义组件&#xff0c;并使用 v-model 进行数据双向绑定。 简述&#xff1a; 自定义组件使用 v-model 进行传参时&#xff0c;遵循 Vue 3 的 v-model 机制。在 Vue 3 中&#xff0c;v-model 默认使用了 modelValue 作为 prop 名称&#xff0c;以及 update:modelValue 作为…...

6. C++通过fork的方式实现高性能网络服务器

我们上一节课写的tcp我们发现只有第一个与之连接的人才能收发信息。他又很多的不足 高性能网络服务器 通过fork实现高性能网络服务器 我们通过fork进行改装之后就可以成百上千的用户进行连接访问&#xff0c;对于每一个用户来说我们都fork一个子进程。让后让每一个子进程都是…...

直播美颜插件、美颜SDK详解:技术、功能与实现原理

今天&#xff0c;小编将详细解析直播美颜插件和美颜SDK的技术、功能以及实现原理。 一、美颜技术的背景与发展 1.1美颜技术的兴起 随着直播平台的普及&#xff0c;美颜SDK技术逐渐被集成到直播软件中&#xff0c;以满足用户对更美观、自然的直播效果的需求。 1.2美颜技术的…...

MyBatis入门(1)

目录 一、JDBC操作回顾 二、什么是MyBatis&#xff1f; 三、MyBatis入门 1、准备工作 &#xff08;1&#xff09;创建工程 &#xff08;2&#xff09;数据准备 2、配置数据库连接字符串 3、写持久层代码 4、单元测试 &#xff08;1&#xff09;使用IDEA自动成成测试类…...

打开服务器远程桌面连接不上,可能的原因及相应的解决策略

在解决远程桌面连接不上服务器的问题时&#xff0c;我们首先需要从专业的角度对可能的原因进行深入分析&#xff0c;并据此提出针对性的解决方案。以下是一些可能的原因及相应的解决策略&#xff1a; 一、网络连接问题 远程桌面连接需要稳定的网络支持&#xff0c;如果网络连接…...

用于时间序列概率预测的蒙特卡洛模拟

大家好&#xff0c;蒙特卡洛模拟是一种广泛应用于各个领域的计算技术&#xff0c;它通过从概率分布中随机抽取大量样本&#xff0c;并对结果进行统计分析&#xff0c;从而模拟复杂系统的行为。这种技术具有很强的适用性&#xff0c;在金融建模、工程设计、物理模拟、运筹优化以…...

VScode解决报错“Remote-SSH XHR failed无法访问远程服务器“的方案

VScode解决报错"Remote-SSH XHR failed无法访问远程服务器"的方案 $ ls ~/.vscode-server/bin 2ccd690cbff1569e4a83d7c43d45101f817401dc稳定版下载链接&#xff1a;https://update.code.visualstudio.com/commit:COMMIT_ID/server-linux-x64/stable 内测版下载链接…...

Python高级进阶--dict字典

dict字典⭐⭐ 1. 字典简介 dictionary&#xff08;字典&#xff09; 是 除列表以外 Python 之中 最灵活 的数据类型&#xff0c;类型为dict 字典同样可以用来存储多个数据字典使用键值对存储数据 2. 字典的定义 字典用{}定义键值对之间使用,分隔键和值之间使用:分隔 d {中…...

记忆力和人才测评,如何提升记忆力?

什么是记忆力&#xff1f; 如何通俗意义上的记忆力&#xff1f;我们可以把人的经历、经验理解成为一部纪录片&#xff0c;那么已经过去发生的事情&#xff0c;就是影片之前的情节&#xff0c;对于这些信息&#xff0c;在脑海里&#xff0c;人们会将其进行处理和组合&#xff…...

数据仓库建模

目录 数仓建模 为什么要对数据仓库进行分层 主题 主题的概念 维度建模&#xff1a; 模型的选择&#xff1a; 星形模式 雪花模型 星座模式 拉链表 维度表和事实表&#xff1a; 维度表 事实表 事实表设计规则 退化维度 事务事实表、周期快照事实表、累积快照事实…...

力扣:1738. 找出第 K 大的异或坐标值

1738. 找出第 K 大的异或坐标值 给你一个二维矩阵 matrix 和一个整数 k &#xff0c;矩阵大小为 m x n 由非负整数组成。 矩阵中坐标 (a, b) 的 值 可由对所有满足 0 < i < a < m 且 0 < j < b < n 的元素 matrix[i][j]&#xff08;下标从 0 开始计数&…...

Keras深度学习框架第二十讲:使用KerasCV中的Stable Diffusion进行高性能图像生成

1、绪论 1.1 概念 为便于后文讨论&#xff0c;首先进行相关概念的陈述。 Stable Diffusion&#xff1a;Stable Diffusion 是一个在图像生成领域广泛使用的技术&#xff0c;尤其是用于文本到图像的转换。它基于扩散模型&#xff08;Diffusion Models&#xff09;&#xff0c;这…...

C/C++ vector详解

要想了解STL&#xff0c;就必须会看&#xff1a; cplusplus.comhttps://legacy.cplusplus.com/ 官方内容全都是英文的&#xff0c;可以参考&#xff1a; C/C初始识https://blog.csdn.net/2301_77087344/article/details/138596294?spm1001.2014.3001.5501 vector&#xff…...

使用libtorch加载YOLOv8生成的torchscript文件进行目标检测

在网上下载了60多幅包含西瓜和冬瓜的图像组成melon数据集&#xff0c;使用 LabelMe 工具进行标注&#xff0c;然后使用 labelme2yolov8 脚本将json文件转换成YOLOv8支持的.txt文件&#xff0c;并自动生成YOLOv8支持的目录结构&#xff0c;包括melon.yaml文件&#xff0c;其内容…...

Oracle 并行和 session 数量的

这也就是为什么我们指定parallel为4&#xff0c;而实际并行度为8的原因。 insert create index&#xff0c;发现并行数都是加倍的 Indexes seem always created with parallel degree 1 during import as seen from a sqlfile. The sql file shows content like: CREATE INDE…...

Android 版本与 API level 以及 NDK 版本对应

采用 Android studio 开发 Android app 的时候&#xff0c;需要选择支持的最低 API Level 和使用的 NDK 版本&#xff0c;对应开发 app 的最低 SDK 版本&#xff1a; 在 app 的 build.gradle 文件里&#xff0c;对应于代码如下&#xff1a; 目前各版本的占有率情况如下&#xf…...

护网经验面试题目原版

文章目录 一、护网项目经验1.项目经验**Hvv的分组和流程**有没有遇到过有意思的逻辑漏洞&#xff1f;有没有自己开发过武器/工具&#xff1f;有做过代码审计吗&#xff1f;有0day吗有cve/cnvd吗&#xff1f;有src排名吗&#xff1f;有没有写过技战法有钓鱼经历吗&#xff1f;具…...

Prompt Tuning、P-Tuning、Prefix Tuning的区别

一、Prompt Tuning、P-Tuning、Prefix Tuning的区别 1. Prompt Tuning(提示调优) 核心思想:固定预训练模型参数,仅学习额外的连续提示向量(通常是嵌入层的一部分)。实现方式:在输入文本前添加可训练的连续向量(软提示),模型只更新这些提示参数。优势:参数量少(仅提…...

Java如何权衡是使用无序的数组还是有序的数组

在 Java 中,选择有序数组还是无序数组取决于具体场景的性能需求与操作特点。以下是关键权衡因素及决策指南: ⚖️ 核心权衡维度 维度有序数组无序数组查询性能二分查找 O(log n) ✅线性扫描 O(n) ❌插入/删除需移位维护顺序 O(n) ❌直接操作尾部 O(1) ✅内存开销与无序数组相…...

土地利用/土地覆盖遥感解译与基于CLUE模型未来变化情景预测;从基础到高级,涵盖ArcGIS数据处理、ENVI遥感解译与CLUE模型情景模拟等

&#x1f50d; 土地利用/土地覆盖数据是生态、环境和气象等诸多领域模型的关键输入参数。通过遥感影像解译技术&#xff0c;可以精准获取历史或当前任何一个区域的土地利用/土地覆盖情况。这些数据不仅能够用于评估区域生态环境的变化趋势&#xff0c;还能有效评价重大生态工程…...

三体问题详解

从物理学角度&#xff0c;三体问题之所以不稳定&#xff0c;是因为三个天体在万有引力作用下相互作用&#xff0c;形成一个非线性耦合系统。我们可以从牛顿经典力学出发&#xff0c;列出具体的运动方程&#xff0c;并说明为何这个系统本质上是混沌的&#xff0c;无法得到一般解…...

管理学院权限管理系统开发总结

文章目录 &#x1f393; 管理学院权限管理系统开发总结 - 现代化Web应用实践之路&#x1f4dd; 项目概述&#x1f3d7;️ 技术架构设计后端技术栈前端技术栈 &#x1f4a1; 核心功能特性1. 用户管理模块2. 权限管理系统3. 统计报表功能4. 用户体验优化 &#x1f5c4;️ 数据库设…...

代码规范和架构【立芯理论一】(2025.06.08)

1、代码规范的目标 代码简洁精炼、美观&#xff0c;可持续性好高效率高复用&#xff0c;可移植性好高内聚&#xff0c;低耦合没有冗余规范性&#xff0c;代码有规可循&#xff0c;可以看出自己当时的思考过程特殊排版&#xff0c;特殊语法&#xff0c;特殊指令&#xff0c;必须…...

基于开源AI智能名片链动2 + 1模式S2B2C商城小程序的沉浸式体验营销研究

摘要&#xff1a;在消费市场竞争日益激烈的当下&#xff0c;传统体验营销方式存在诸多局限。本文聚焦开源AI智能名片链动2 1模式S2B2C商城小程序&#xff0c;探讨其在沉浸式体验营销中的应用。通过对比传统品鉴、工厂参观等初级体验方式&#xff0c;分析沉浸式体验的优势与价值…...

Python 高级应用10:在python 大型项目中 FastAPI 和 Django 的相互配合

无论是python&#xff0c;或者java 的大型项目中&#xff0c;都会涉及到 自身平台微服务之间的相互调用&#xff0c;以及和第三发平台的 接口对接&#xff0c;那在python 中是怎么实现的呢&#xff1f; 在 Python Web 开发中&#xff0c;FastAPI 和 Django 是两个重要但定位不…...

CppCon 2015 学习:REFLECTION TECHNIQUES IN C++

关于 Reflection&#xff08;反射&#xff09; 这个概念&#xff0c;总结一下&#xff1a; Reflection&#xff08;反射&#xff09;是什么&#xff1f; 反射是对类型的自我检查能力&#xff08;Introspection&#xff09; 可以查看类的成员变量、成员函数等信息。反射允许枚…...

SQL注入篇-sqlmap的配置和使用

在之前的皮卡丘靶场第五期SQL注入的内容中我们谈到了sqlmap&#xff0c;但是由于很多朋友看不了解命令行格式&#xff0c;所以是纯手动获取数据库信息的 接下来我们就用sqlmap来进行皮卡丘靶场的sql注入学习&#xff0c;链接&#xff1a;https://wwhc.lanzoue.com/ifJY32ybh6vc…...