Pytorch入门(7)—— 梯度累加(Gradient Accumulation)
1. 梯度累加
- 在训练大模型时,batch_size 最大值往往受限于显存容量上限,当模型非常大时,这个上限可能小到不可接受。梯度累加(Gradient Accumulation)是一个解决该问题的 trick
- 梯度累加的思想很简单,就是时间换空间。具体而言,我们不在每个 batch data 梯度计算后直接更新模型,而是多算几个 batch 后,使用这些 batch 的平均梯度更新模型,从而放大等效 batch_size。如下图所示
- 用公式表示:设 batch size 为 n n n,模型参数为 w \pmb{w} w,样本 i i i 的损失为 l i l_i li,则正常情况下 sgd 参数更新为
w ← w + α ∑ i = 1 n 1 n ∂ l i ∂ w \pmb{w} \leftarrow \pmb{w} + \alpha \sum_{i=1}^n\frac{1}{n}\frac{\partial l_i}{\partial \pmb{w}} w←w+αi=1∑nn1∂w∂li 使用梯度累加时,设累加步长为 m m m(即计算 m m m 个 batch 梯度后用梯度均值更新一次),sgd 更新如下
w ← w + α 1 m ∑ b = 1 m ∑ i = 1 n 1 n ∂ l b i ∂ w = w + α ∑ i = 1 m n 1 m n ∂ l i ∂ w \begin{aligned} \pmb{w} &\leftarrow \pmb{w} + \alpha \frac{1}{m} \sum_{b=1}^m \sum_{i=1}^n\frac{1}{n}\frac{\partial l_{bi}}{\partial \pmb{w}} \\ &= \pmb{w} + \alpha \sum_{i=1}^{mn}\frac{1}{mn} \frac{\partial l_i}{\partial \pmb{w}} \end{aligned} w←w+αm1b=1∑mi=1∑nn1∂w∂lbi=w+αi=1∑mnmn1∂w∂li 可见这等价于使用 batch_size = m n mn mn 进行训练
2. 在 pytorch 中实现梯度累加
2.1 伪代码
- pytorch 使用和 tensor 绑定的自动微分机制。每个 tensor 对象都有
.grad
属性存储其中每个元素的梯度值,通过.requires_grad
属性控制其是否参与梯度计算。训练模型时,一般通过对标量loss
执行loss.backward()
自动进行反向传播,以得到计算图中所有 tensor 的梯度。详见 PyTorch入门(2)—— 自动求梯度 - pytorch 中梯度
tensor.grad
不会自动清零,而会在每次反向传播过程中自动累加,所以一般在反向传播前把梯度清零
这种设计对于实现梯度累加 trick 是很方便的,我们可以在 batch 计算过程中进行计数,仅在达到计数达到更新步长时进行一次参数更新并清零梯度,即for inputs, labels in data_loader:# forward pass preds = model(inputs)loss = criterion(preds, labels)# clear grad of last batch optimizer.zero_grad()# backward pass, calculate grad of batch dataloss.backward()# update modeloptimizer.step()
# batch accumulation parameter accum_iter = 4 # loop through enumaretad batches for batch_idx, (inputs, labels) in enumerate(data_loader):# forward pass preds = model(inputs)loss = criterion(preds, labels)# scale the loss to the mean of the accumulated batch sizeloss = loss / accum_iter # backward passloss.backward()# weights updateif ((batch_idx + 1) % accum_iter == 0) or (batch_idx + 1 == len(data_loader)):optimizer.step()optimizer.zero_grad()
2.2 线性回归案例
- 下面使用来自 经典机器学习方法(1)—— 线性回归 的简单线性回归任务说明梯度累加的具体实现方法
本节代码直接从 jupyter notebook 复制而来,可能无法直接运行!
- 首先生成随机数据构造 dataset
import torch from IPython import display from matplotlib import pyplot as plt import numpy as np import random import torch.utils.data as Data import torch.nn as nn import torch.optim as optim# 生成样本 num_inputs = 2 num_examples = 1000 true_w = torch.Tensor([-2,3.4]).view(2,1) true_b = 4.2 batch_size = 10# 1000 个2特征样本,每个特征都服从 N(0,1) features = torch.randn(num_examples, num_inputs, dtype=torch.float32) # 生成真实标记 labels = torch.mm(features,true_w) + true_b labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float32)# 包装数据集,将训练数据的特征和标签组合 dataset = Data.TensorDataset(features, labels)
- 不使用梯度累加技巧,batch size 设置为 40
# 构造 DataLoader batch_size = 40 data_iter = Data.DataLoader(dataset, batch_size, shuffle=False) # shuffle=False 保证实验可比# 定义模型 net = nn.Sequential(nn.Linear(num_inputs, 1))# 初始化模型参数 nn.init.normal_(net[0].weight, mean=0, std=0) nn.init.constant_(net[0].bias, val=0)# 均方差损失函数 criterion = nn.MSELoss()# SGD优化器 optimizer = optim.SGD(net.parameters(), lr=0.01)# 模型训练 num_epochs = 3 for epoch in range(1, num_epochs + 1):epoch_loss = []for X, y in data_iter:# 正向传播,计算损失output = net(X) loss = criterion(output, y.view(-1, 1))# 梯度清零optimizer.zero_grad() # 计算各参数梯度loss.backward()#print('backward: ', net[0].weight.grad)# 更新模型optimizer.step()epoch_loss.append(loss.item()/batch_size)print(f'epoch {epoch}, loss: {np.mean(epoch_loss)}')''' epoch 1, loss: 0.5434057731628418 epoch 2, loss: 0.1914414196014404 epoch 3, loss: 0.06752514398097992 '''
- 使用梯度累加,batch size 设置为 10,步长设为 4,等效 batch size 为 40
# 构造 DataLoader batch_size = 10 accum_iter = 4 data_iter = Data.DataLoader(dataset, batch_size, shuffle=False) # shuffle=False 保证实验可比# 定义模型 net = nn.Sequential(nn.Linear(num_inputs, 1))# 初始化模型参数 nn.init.normal_(net[0].weight, mean=0, std=0) nn.init.constant_(net[0].bias, val=0)# 均方差损失 criterion = nn.MSELoss()# SGD优化器对象 optimizer = optim.SGD(net.parameters(), lr=0.01)# 模型训练 num_epochs = 3 for epoch in range(1, num_epochs + 1):epoch_loss = []for batch_idx, (X, y) in enumerate(data_iter):# 正向传播,计算损失output = net(X) loss = criterion(output, y.view(-1, 1)) loss = loss / accum_iter # 取各个累计batch的平均损失,从而在.backward()时得到平均梯度# 反向传播,梯度累计loss.backward()if ((batch_idx + 1) % accum_iter == 0) or (batch_idx + 1 == len(data_iter)):#print('backward: ', net[0].weight.grad)# 更新模型optimizer.step() # 梯度清零optimizer.zero_grad()epoch_loss.append(loss.item()/batch_size)print(f'epoch {epoch}, loss: {np.mean(epoch_loss)}') ''' epoch 1, loss: 0.5434057596921921 epoch 2, loss: 0.19144139245152472 epoch 3, loss: 0.06752512042224407 '''
- 不使用梯度累加技巧,batch size 设置为 40
- 可以观察到无论 epoch loss 还是
net[0].weight.grad
都完全相同,说明梯度累加不影响计算结果
相关文章:

Pytorch入门(7)—— 梯度累加(Gradient Accumulation)
1. 梯度累加 在训练大模型时,batch_size 最大值往往受限于显存容量上限,当模型非常大时,这个上限可能小到不可接受。梯度累加(Gradient Accumulation)是一个解决该问题的 trick梯度累加的思想很简单,就是时…...

day12
第一题 本题我们可以使用以下方法: 方法一: 使用hash表<元素,出现次数>来统计字符串中不同元素分别出现的次数,当某一个元素的次数大于1时,返回false,如果每个元素的出现次数都为1,则返回…...
MySQL技术点合集
目录 1. MySQL目录 2. 验证是否首次登陆方法 3. 在Liunx中使用命令来输入sql语句方法 4. 获取修改密码 5. 关闭密码策略 6. 忘记MySQL密码找回 7. 旋转90度横向查看表 8. 添加一个远程连接的用户 1. MySQL目录 /usr/bin/mysql相关命令vim /etc/my.cnfmysql配置文件ls /…...

记录使用 Vue3 过程中的一些技术点
1、自定义组件,并使用 v-model 进行数据双向绑定。 简述: 自定义组件使用 v-model 进行传参时,遵循 Vue 3 的 v-model 机制。在 Vue 3 中,v-model 默认使用了 modelValue 作为 prop 名称,以及 update:modelValue 作为…...
6. C++通过fork的方式实现高性能网络服务器
我们上一节课写的tcp我们发现只有第一个与之连接的人才能收发信息。他又很多的不足 高性能网络服务器 通过fork实现高性能网络服务器 我们通过fork进行改装之后就可以成百上千的用户进行连接访问,对于每一个用户来说我们都fork一个子进程。让后让每一个子进程都是…...

直播美颜插件、美颜SDK详解:技术、功能与实现原理
今天,小编将详细解析直播美颜插件和美颜SDK的技术、功能以及实现原理。 一、美颜技术的背景与发展 1.1美颜技术的兴起 随着直播平台的普及,美颜SDK技术逐渐被集成到直播软件中,以满足用户对更美观、自然的直播效果的需求。 1.2美颜技术的…...

MyBatis入门(1)
目录 一、JDBC操作回顾 二、什么是MyBatis? 三、MyBatis入门 1、准备工作 (1)创建工程 (2)数据准备 2、配置数据库连接字符串 3、写持久层代码 4、单元测试 (1)使用IDEA自动成成测试类…...

打开服务器远程桌面连接不上,可能的原因及相应的解决策略
在解决远程桌面连接不上服务器的问题时,我们首先需要从专业的角度对可能的原因进行深入分析,并据此提出针对性的解决方案。以下是一些可能的原因及相应的解决策略: 一、网络连接问题 远程桌面连接需要稳定的网络支持,如果网络连接…...

用于时间序列概率预测的蒙特卡洛模拟
大家好,蒙特卡洛模拟是一种广泛应用于各个领域的计算技术,它通过从概率分布中随机抽取大量样本,并对结果进行统计分析,从而模拟复杂系统的行为。这种技术具有很强的适用性,在金融建模、工程设计、物理模拟、运筹优化以…...

VScode解决报错“Remote-SSH XHR failed无法访问远程服务器“的方案
VScode解决报错"Remote-SSH XHR failed无法访问远程服务器"的方案 $ ls ~/.vscode-server/bin 2ccd690cbff1569e4a83d7c43d45101f817401dc稳定版下载链接:https://update.code.visualstudio.com/commit:COMMIT_ID/server-linux-x64/stable 内测版下载链接…...

Python高级进阶--dict字典
dict字典⭐⭐ 1. 字典简介 dictionary(字典) 是 除列表以外 Python 之中 最灵活 的数据类型,类型为dict 字典同样可以用来存储多个数据字典使用键值对存储数据 2. 字典的定义 字典用{}定义键值对之间使用,分隔键和值之间使用:分隔 d {中…...

记忆力和人才测评,如何提升记忆力?
什么是记忆力? 如何通俗意义上的记忆力?我们可以把人的经历、经验理解成为一部纪录片,那么已经过去发生的事情,就是影片之前的情节,对于这些信息,在脑海里,人们会将其进行处理和组合ÿ…...
数据仓库建模
目录 数仓建模 为什么要对数据仓库进行分层 主题 主题的概念 维度建模: 模型的选择: 星形模式 雪花模型 星座模式 拉链表 维度表和事实表: 维度表 事实表 事实表设计规则 退化维度 事务事实表、周期快照事实表、累积快照事实…...

力扣:1738. 找出第 K 大的异或坐标值
1738. 找出第 K 大的异或坐标值 给你一个二维矩阵 matrix 和一个整数 k ,矩阵大小为 m x n 由非负整数组成。 矩阵中坐标 (a, b) 的 值 可由对所有满足 0 < i < a < m 且 0 < j < b < n 的元素 matrix[i][j](下标从 0 开始计数&…...

Keras深度学习框架第二十讲:使用KerasCV中的Stable Diffusion进行高性能图像生成
1、绪论 1.1 概念 为便于后文讨论,首先进行相关概念的陈述。 Stable Diffusion:Stable Diffusion 是一个在图像生成领域广泛使用的技术,尤其是用于文本到图像的转换。它基于扩散模型(Diffusion Models),这…...

C/C++ vector详解
要想了解STL,就必须会看: cplusplus.comhttps://legacy.cplusplus.com/ 官方内容全都是英文的,可以参考: C/C初始识https://blog.csdn.net/2301_77087344/article/details/138596294?spm1001.2014.3001.5501 vectorÿ…...

使用libtorch加载YOLOv8生成的torchscript文件进行目标检测
在网上下载了60多幅包含西瓜和冬瓜的图像组成melon数据集,使用 LabelMe 工具进行标注,然后使用 labelme2yolov8 脚本将json文件转换成YOLOv8支持的.txt文件,并自动生成YOLOv8支持的目录结构,包括melon.yaml文件,其内容…...

Oracle 并行和 session 数量的
这也就是为什么我们指定parallel为4,而实际并行度为8的原因。 insert create index,发现并行数都是加倍的 Indexes seem always created with parallel degree 1 during import as seen from a sqlfile. The sql file shows content like: CREATE INDE…...

Android 版本与 API level 以及 NDK 版本对应
采用 Android studio 开发 Android app 的时候,需要选择支持的最低 API Level 和使用的 NDK 版本,对应开发 app 的最低 SDK 版本: 在 app 的 build.gradle 文件里,对应于代码如下: 目前各版本的占有率情况如下…...
护网经验面试题目原版
文章目录 一、护网项目经验1.项目经验**Hvv的分组和流程**有没有遇到过有意思的逻辑漏洞?有没有自己开发过武器/工具?有做过代码审计吗?有0day吗有cve/cnvd吗?有src排名吗?有没有写过技战法有钓鱼经历吗?具…...
【网络】每天掌握一个Linux命令 - iftop
在Linux系统中,iftop是网络管理的得力助手,能实时监控网络流量、连接情况等,帮助排查网络异常。接下来从多方面详细介绍它。 目录 【网络】每天掌握一个Linux命令 - iftop工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景…...
SkyWalking 10.2.0 SWCK 配置过程
SkyWalking 10.2.0 & SWCK 配置过程 skywalking oap-server & ui 使用Docker安装在K8S集群以外,K8S集群中的微服务使用initContainer按命名空间将skywalking-java-agent注入到业务容器中。 SWCK有整套的解决方案,全安装在K8S群集中。 具体可参…...

iOS 26 携众系统重磅更新,但“苹果智能”仍与国行无缘
美国西海岸的夏天,再次被苹果点燃。一年一度的全球开发者大会 WWDC25 如期而至,这不仅是开发者的盛宴,更是全球数亿苹果用户翘首以盼的科技春晚。今年,苹果依旧为我们带来了全家桶式的系统更新,包括 iOS 26、iPadOS 26…...
golang循环变量捕获问题
在 Go 语言中,当在循环中启动协程(goroutine)时,如果在协程闭包中直接引用循环变量,可能会遇到一个常见的陷阱 - 循环变量捕获问题。让我详细解释一下: 问题背景 看这个代码片段: fo…...
rknn优化教程(二)
文章目录 1. 前述2. 三方库的封装2.1 xrepo中的库2.2 xrepo之外的库2.2.1 opencv2.2.2 rknnrt2.2.3 spdlog 3. rknn_engine库 1. 前述 OK,开始写第二篇的内容了。这篇博客主要能写一下: 如何给一些三方库按照xmake方式进行封装,供调用如何按…...
Python爬虫实战:研究feedparser库相关技术
1. 引言 1.1 研究背景与意义 在当今信息爆炸的时代,互联网上存在着海量的信息资源。RSS(Really Simple Syndication)作为一种标准化的信息聚合技术,被广泛用于网站内容的发布和订阅。通过 RSS,用户可以方便地获取网站更新的内容,而无需频繁访问各个网站。 然而,互联网…...
测试markdown--肇兴
day1: 1、去程:7:04 --11:32高铁 高铁右转上售票大厅2楼,穿过候车厅下一楼,上大巴车 ¥10/人 **2、到达:**12点多到达寨子,买门票,美团/抖音:¥78人 3、中饭&a…...
鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个医院查看报告小程序
一、开发环境准备 工具安装: 下载安装DevEco Studio 4.0(支持HarmonyOS 5)配置HarmonyOS SDK 5.0确保Node.js版本≥14 项目初始化: ohpm init harmony/hospital-report-app 二、核心功能模块实现 1. 报告列表…...

深入解析C++中的extern关键字:跨文件共享变量与函数的终极指南
🚀 C extern 关键字深度解析:跨文件编程的终极指南 📅 更新时间:2025年6月5日 🏷️ 标签:C | extern关键字 | 多文件编程 | 链接与声明 | 现代C 文章目录 前言🔥一、extern 是什么?&…...
JVM暂停(Stop-The-World,STW)的原因分类及对应排查方案
JVM暂停(Stop-The-World,STW)的完整原因分类及对应排查方案,结合JVM运行机制和常见故障场景整理而成: 一、GC相关暂停 1. 安全点(Safepoint)阻塞 现象:JVM暂停但无GC日志,日志显示No GCs detected。原因:JVM等待所有线程进入安全点(如…...