当前位置: 首页 > article >正文

神经网络优化 - 小批量梯度下降之批量大小的选择

上一博文学习了小批量梯度下降在神经网络优化中的应用:

神经网络优化 - 小批量梯度下降-CSDN博客

在小批量梯度下降法中,批量大小(Batch Size)对网络优化的影响也非常大,本文我们来学习如何选择小批量梯度下降的批量大小。

一、批量大小的选择

一般而言,批量大小不影响随机梯度的期望,但是会影响随机梯度的方差。

批量大小越大,随机梯度的方差越小,引入的噪声也越小,训练也越稳定,因此可以设置较大的学习率。而批量大小较小时,需要设置较小的学习率,否则模型会不收敛。学习率通常要随着批量大小的增大而相应地增大。

一个简单有效的方法是线性缩放规则(Linear Scaling Rule):当批量大小增加 𝑚 倍时,学习率也增加 𝑚 倍。线性缩放规则往往在批量大小比较小时适用,当批量大小非常大时,线性缩放会使得训练不稳定。

我们来分析上面的这段话。

(一)首先,为什么“批量大小越大,随机梯度的方差越小,引入的噪声也越小,训练也越稳定”呢?

下面先给出核心结论:在小批量梯度下降(mini‑batch SGD)中,每次更新所用梯度是对选定批次样本梯度的平均。根据大数定律和方差运算规则,这个平均梯度的方差随着批次大小 B 的增大而减少(具体地,对单个样本梯度方差 σ^2,平均梯度的方差为 σ^2/B),因此批次越大,随机梯度的“噪声”越小,参数更新轨迹越平滑,训练也就越稳定。但批次过大又会降低更新频率、消耗更多内存并可能陷入不理想的平坦区域,所以实际中需在“稳定性”与“效率”之间做权衡。

1. 平均梯度方差随批次大小缩减

  • 单样本梯度的随机性
    每个训练样本对梯度的贡献视作随机变量,其方差可记作 σ^2。

  • 批次梯度为样本梯度的均值
    对一个大小为 B 的批次,梯度估计

    其方差根据样本独立同分布的假设满足

    因此 B 越大,Var(g) 越小,使得每一步的更新更接近于“真实”全量梯度。

2. 数学直观与统计背景

  • 大数定律
    当 B 足够大时,批次均值近似收敛到总体期望(真实梯度),方差衰减至零。

  • 中心极限定理
    即便单样本梯度分布非正态,均值在大批次时仍近似正态分布,方差为 σ^2/B。

  • 优化平滑性
    较小的方差意味着更新方向的随机波动减少,参数更新路径更加平滑,梯度下降的步伐也更稳定,不易在鞍点或小坑中震荡。

3. 稳定性 vs. 更新频率的权衡

  • 批次越大

    • 噪声小、更新平稳:更接近全量梯度,收敛更规律。

    • 计算开销大、更新慢:每次迭代需更多样本,迭代次数减少。

  • 批次越小

    • 噪声大、更新抖动:有助于跳出鞍点,提高泛化。

    • 更新快、泛化好:更多次更新让算法能更快“试探”参数空间。

(二)上面提到“批量大小较小时,需要设置较小的学习率,否则模型会不收敛。”,这是为什么呢?

1. 梯度方差与学习率的关系

1.1 噪声放大与发散风险

一次更新为 。更新的随机波动度量可近似为

当 B 很小时,若 η不相应减小, 会变得很大,导致更新步长在不同方向上有剧烈抖动,学习过程难以收敛。

2. 理论与实证研究

2.1 明确“泛化差距”与鞍点分析

  • Keskar et al. (ICLR 2017) 发现大批量训练易陷入“尖锐最小点”而泛化较差,小批量带来的噪声有助于跳出尖锐谷底,且需适当调低学习率以保持稳定。

  • Masters & Luschi (ICLR 2018) 系统回顾了小批量训练的优势与挑战,指出小批量时梯度噪声大,需要更小的步长以避免震荡和发散。

2.2 动量与超收敛现象

  • Hoffer et al. (arXiv 2017) 演示了在小批量下使用更小学习率并结合动量可以加速训练并稳定收敛。

  • Smith et al. (ICLR 2019) 通过“超收敛”实验,进一步展示学习率和批量大小的精细配合对于噪声平衡和收敛速度的重要性。

3. 实践指南

  1. 依据批量大小缩放学习率

    • 经验法则:当批量大小减小 kkk 倍时,将学习率也减小约 k\sqrt{k}k​ 至 kkk 倍,以保持更新方差可控。

  2. 结合动量或自适应优化器

    • 在小批量时,动量(Momentum)或 Adam/RMSprop 等优化器可帮助缓冲梯度噪声,提高更新稳定性。

  3. 监控训练动态

    • 观察训练损失曲线平滑度:若出现震荡或不下降,尝试减小学习率;

    • 根据验证集性能动态调整,确保既能快速收敛又保持良好泛化。

二、回合(Epoch)和迭代(Iteration)的概念

每一次小批量更新为一次迭代,所有训练集的样本更新一遍为一个回合,两者的关系为:

下图说明在 MNIST 数据集上批量大小对损失下降的影响:

从左图可以看出,批量大小越大,下降效果越明显,并且下降曲线越平滑。但从右图可以看出,如果按整个数据集上的回合(Epoch)数来看,则是批量样本数越小,下降效果越明显。适当小的批量会导致更快的收敛。

此外,批量大小和模型的泛化能力也有一定的关系。通过实验发现:批量越大,越有可能收敛到尖锐最小值;批量越小,越有可能收敛到平坦最小值。

三、在神经网络参数优化当中,采用小批量梯度下降,如何确定批量大小呢?

在神经网络训练中,确定小批量梯度下降(Mini-batch Gradient Descent)的批量大小(Batch Size)需要综合考虑计算效率、内存限制、优化效果和泛化性能。

1. 基本原则与经验法则

  • 常见初始值:通常从 32、64、128、256 等 2 的幂次开始(因内存对齐优化),但需根据任务调整。

  • 资源限制:批量大小受限于硬件内存(如GPU显存),需确保不引发内存溢出(OOM)。

  • 学习率联动:增大批量时,可能需要按比例增大学习率(但需谨慎,避免不稳定)。

2. 批量大小的核心权衡

批量大小优点缺点
小批量- 梯度噪声大,可能逃离局部最优
- 内存需求低,适合小显存设备
- 泛化性能可能更好
- 计算效率低(并行性差)
- 梯度方向波动大,训练不稳定
大批量- 梯度估计更准确,训练稳定
- 计算效率高(充分利用并行性)
- 易陷入平坦区域或鞍点
- 可能损害泛化性能
- 显存占用高

3. 确定批量大小的具体方法

(1) 基于硬件限制的调整
  • 显存估算

    • 估算单个样本的显存占用(包括模型参数、激活值、梯度等)。

    • 批量大小 ≈ 可用显存 / 单样本显存占用(预留20%余量)。

    • 例如:单样本占用 100MB,显存 8GB → 最大批量 ≈ 8000MB / 100MB ≈ 80(取 64 或 128)。

(2) 基于任务特性的选择
  • 数据复杂度高(如图像分割、自然语言生成):

    • 建议小批量(如 16~64),以增加梯度多样性,避免过拟合。

  • 数据简单或噪声多(如分类任务):

    • 可尝试大批量(如 128~512),加速收敛。

(3) 学习率与批量大小的联动(线性缩放规则)
  • 基本规则
    当批量大小增大 k 倍时,学习率可同步增大 k 倍(适用于小批量→中等批量,如 32→256)。
    例如:批量从 64 增大到 256,学习率从 0.1 调整到 0.4。

  • 注意事项

    • 该规则在极大批量(如 >1024)时可能失效,需结合学习率预热(Learning Rate Warmup)或其他自适应优化器(如 Adam)。

    • 实践公式:学习率=基础学习率×批量大小N​,N 为参考批量(如 256)。

(4) 实验验证法
  • 网格搜索
    对候选批量(如 32、64、128、256)分别训练少量 epoch,观察:

    • 训练损失下降速度

    • 验证集准确率/损失

    • 训练时间/显存占用
      选择综合表现最优的批量。

  • 动态调整

    • 若训练初期梯度波动大(损失震荡),可适当增大批量。

    • 若模型陷入局部最优,可减小批量引入更多噪声。

4. 特殊情况处理

  • 极小批量(Batch Size=1)

    • 等价于随机梯度下降(SGD),噪声极大,需搭配低学习率或梯度累积(Gradient Accumulation)。

    • 适用于显存极度受限的场景(如训练大语言模型)。

  • 极大批量(Batch Size > 1000)

    • 需结合学习率预热(前几个 epoch 逐步增大学习率)和自适应优化器(如 LAMB、AdamW)。

    • 注意:可能降低模型泛化能力,需增强正则化(如数据增强、Dropout)。

5. 经典场景参考

任务类型推荐批量大小理由
图像分类(ResNet)64~512平衡并行效率与泛化性能
目标检测(YOLO)8~32高分辨率图像显存占用大
自然语言处理(BERT)16~64长序列导致单样本显存高
强化学习(PPO)64~256需大量环境交互数据

6. 总结与建议

  1. 从默认值开始:尝试批量大小 64 或 128,结合学习率 0.1(SGD)或 1e-4(Adam)。

  2. 逐步调整:根据显存占用和训练稳定性,按 2 的倍数增减批量。

  3. 监控指标:重点关注验证集性能而非训练速度,避免过拟合或欠拟合。

  4. 结合优化器:大批量时使用 LAMB 或 AdamW,小批量时使用 SGD 或 Adam。

最终,批量大小是超参数的一种,需通过实验找到任务、模型和硬件的最优平衡点。

相关文章:

神经网络优化 - 小批量梯度下降之批量大小的选择

上一博文学习了小批量梯度下降在神经网络优化中的应用: 神经网络优化 - 小批量梯度下降-CSDN博客 在小批量梯度下降法中,批量大小(Batch Size)对网络优化的影响也非常大,本文我们来学习如何选择小批量梯度下降的批量大小。 一、批量大小的…...

Novartis诺华制药社招入职综合能力测评真题SHL题库考什么?

一、综合能力测试 诺华制药的入职测评中,综合能力测试是重要的一部分,主要考察应聘者的问题解决能力、数值计算能力和逻辑推理能力。测试总时长为46分钟,实际作答时间为36分钟,共24题。题型丰富多样,包括图形变换题、分…...

文件的物理结构和逻辑结构的区分

文件的物理结构和逻辑结构是文件系统中两个重要的概念,它们分别描述了文件在存储设备上的实际存储方式以及用户在编程或操作文件时所看到的抽象组织形式。理解这两者的区别和联系对于深入掌握文件系统的设计和实现至关重要。 ​一、文件的逻辑结构 ​定义 文件的逻…...

C语言学习记录(16)文件操作7

前面学的东西感觉都跟写代码有关系,怎么突然就开始说文件了,有什么用呢? 其实,文件是另一种数据存储的方式,学会使用文件就可以让我们的数据持久的保存。 一、文件是什么 就算没有学过相关的知识,在这么…...

Coze平台​ 创建AI智能体的详细步骤指南

一、创建智能体的基础流程​ ​注册与登录​ 访问Coze官网(www.coze.cn),使用邮箱或手机号注册账号并登录。 ​创建智能体​ 在控制台点击左侧“”按钮,选择“创建智能体”,输入名称(如“职场鼓励师”&…...

《作用域大冒险:从闭包到内存泄漏的终极探索》

“爱自有天意,天有道自不会让有情人分离” 大家好,关于闭包问题其实实际上是js作用域的问题,那么js有几种作用域呢? 作用域类型关键字/场景作用域范围示例全局作用域var(无声明)整个程序var x 10;函数作用…...

android Stagefright框架

作为Android音视频开发人员,学习Stagefright框架需要结合理论、源码分析和实践验证。以下是系统化的学习路径: 1. 基础准备 熟悉Android多媒体体系 掌握MediaPlayer、MediaCodec、MediaExtractor等核心API的用法。 理解Android的OpenMAX IL&#xff08…...

Shell脚本-变量的分类

在Shell脚本编程中,变量是存储数据的基本单位。它们可以用来保存字符串、数字甚至是命令的输出结果。正确地定义和使用变量能够极大地提高脚本的灵活性与可维护性。本文将详细介绍Shell脚本中变量的不同分类及其应用场景,帮助你编写更高效、简洁的Shell脚…...

<C#>.NET WebAPI 的 FromBody ,FromForm ,FromServices等详细解释

在 .NET 8 Web API 中,[FromBody]、[FromForm]、[FromHeader]、[FromKeyedServices]、[FromQuery]、[FromRoute] 和 [FromServices] 这些都是用于绑定控制器动作方法参数的特性,下面为你详细解释这些特性。 1. [FromBody] 作用:从 HTTP 请求…...

让数据应用更简单:Streamlit与Gradio的比较与联系

在数据科学与机器学习的快速发展中,如何快速构建可视化应用成为了许多工程师和数据科学家的一个重要需求。Streamlit和Gradio是两款备受欢迎的开源库,它们各自提供了便捷的方式来构建基于Web的应用。虽然二者在功能上有许多相似之处,但它们的…...

LlamaIndex 生成的本地索引文件和文件夹详解

LlamaIndex 生成的本地索引文件和文件夹详解 LlamaIndex 在生成本地索引时会创建一个 storage 文件夹,并在其中生成多个 JSON 文件。以下是每个文件的详细解释: 1. storage 文件夹结构 1.1 docstore.json 功能:存储文档内容及其相关信息。…...

AndroidRom定制删除Settings某些菜单选项

AndroidRom定制删除Settings某些菜单选项 1.前言. 最近在Rom开发中需要隐藏设置中的某些菜单,launcher3中的定制开发,这个属于很基本的定制需求,和隐藏google搜素栏一样简单,这里我就不展开了,直接上代码. 2.隐藏网络…...

Mysql相关知识2:Mysql隔离级别、MVCC、锁

文章目录 MySQL的隔离级别可重复读的实现原理Mysql锁按锁的粒度分类按锁的使用方式分类按锁的状态分类 MySQL的隔离级别 在 MySQL 中,隔离级别定义了事务之间相互隔离的程度,用于控制一个事务对数据的修改在何时以及如何被其他事务可见。MySQL 支持四种…...

Python爬虫实战:获取海口最近2周天气数据,为出行做参考

一、引言 天气状况对人们的出行计划影响重大。获取准确的天气信息并进行分析,能助力用户更好地规划出行。天气网虽提供丰富的天气数据,但因网站存在反爬机制,直接获取数据存在一定难度。本研究借助 Python 的 Scrapy 框架,结合多种技术手段,实现对海口最近两周天气数据的…...

并发设计模式之双缓冲系统

双缓冲的本质是 ​​通过空间换时间​​,通过冗余的缓冲区解决生产者和消费者的速度差异问题,同时提升系统的并发性和稳定性。 双缓冲的核心优势 优势具体表现解耦生产与消费生产者和消费者可以独立工作,无需直接同步。提高并发性生产者和消…...

linux sysfs的使用

在Linux内核驱动开发中&#xff0c;device_create_file 和 device_remove_file 用于动态创建/删除设备的 sysfs 属性文件&#xff0c;常用于暴露设备信息或控制参数。以下是完整示例及详细说明&#xff1a; 1. 头文件引入 #include <linux/module.h> #include <linux/…...

【数据结构和算法】3. 排序算法

本文根据 数据结构和算法入门 视频记录 文章目录 1. 排序算法2. 插入排序 Insertion Sort2.1 概念2.2 具体步骤2.3 Java 实现2.4 复杂度分析 3. 快排 QuickSort3.1 概念3.2 具体步骤3.3 Java实现3.4 复杂度分析 4. 归并排序 MergeSort4.1 概念4.2 递归具体步骤4.3 Java实现4.4…...

LintCode第192题-通配符匹配

描述 给定一个字符串 s 和一个字符模式 p &#xff0c;实现一个支持 ? 和 * 的通配符匹配。匹配规则如下&#xff1a; ? 可以匹配任何单个字符。* 可以匹配任意字符串&#xff08;包括空字符串&#xff09;。 两个串完全匹配才算匹配成功。 样例 样例1 输入: "aa&q…...

redis常用的五种数据类型

redis常用的五种数据类型 文档 redis单机安装redis数据类型-位图bitmap 说明 官网操作命令指南页面&#xff1a;https://redis.io/docs/latest/commands/?nameget&groupstring 常用命令 keys *&#xff1a;查看所有键exists k1 k2&#xff1a;键存在个数type k1&…...

Linux 进程与线程间通信方式及应用分析

Linux 进程与线程间通信方式及应用分析 文章目录 Linux 进程与线程间通信方式及应用分析 1. 管道&#xff08;Pipe&#xff09;1.1 匿名管道&#xff08;Anonymous Pipe&#xff09;示例代码&#xff1a;结果&#xff1a; 1.2 命名管道&#xff08;FIFO&#xff09;示例代码&am…...

AI日报 - 2024年04月22日

&#x1f31f; 今日概览(60秒速览) ▎&#x1f916; 模型进展 | Google发布Gemini 2.5 Flash&#xff0c;强调低延迟与成本效益&#xff1b;Kling AI 2.0展示多轴运动视频生成&#xff1b;研究揭示SLM在知识图谱上优于LLM&#xff0c;RLHF在推理提升上存局限。 ▎&#x1f4bc;…...

FreeRTos学习记录--2.内存管理

后续的章节涉及这些内核对象&#xff1a;task、queue、semaphores和event group等。为了让FreeRTOS更容易使用&#xff0c;这些内核对象一般都是动态分配&#xff1a;用到时分配&#xff0c;不使用时释放。使用内存的动态管理功能&#xff0c;简化了程序设计&#xff1a;不再需…...

HAL库(STM32CubeMX)——高级ADC学习、HRTIM(STM32G474RBT6)

系列文章目录 文章目录 系列文章目录前言存在的问题HRTIMcubemx配置前言 对cubemx的ADC的设置进行补充 ADCs_Common_Settings Mode:ADC 模式 Independent mod 独立 ADC 模式,当使用一个 ADC 时是独立模式,使用两个 ADC 时是双模式,在双模式下还有很多细分模式可选 ADC_Se…...

单例模式(线程安全)

1.什么是单例模式 单例模式&#xff08;Singleton Pattern&#xff09;是一种创建型设计模式&#xff0c;旨在确保一个类只有一个实例&#xff0c;并提供一个全局访问点来访问该实例。这种模式涉及到一个单一的类&#xff0c;该类负责创建自己的对象&#xff0c;同时确保只有单…...

FreeRTos学习记录--1.工程创建与源码概述

1.工程创建与源码概述 1.1 工程创建 使用STM32CubeMX&#xff0c;可以手工添加任务、队列、信号量、互斥锁、定时器等等。但是本课程不想严重依赖STM32CubeMX&#xff0c;所以不会使用STM32CubeMX来添加这些对象&#xff0c;而是手写代码来使用这些对象。 使用STM32CubeMX时&…...

基于大模型的血栓性外痔全流程风险预测与治疗管理研究报告

目录 一、引言 1.1 研究背景与目的 1.2 研究意义 二、血栓性外痔概述 2.1 定义与发病机制 2.2 临床表现与诊断方法 2.3 现有治疗手段综述 三、大模型在血栓性外痔预测中的应用原理 3.1 大模型技术简介 3.2 模型构建与训练数据来源 3.3 模型预测血栓性外痔的工作流程…...

进程控制(linux+C/C++)

目录 进程创建 写时拷贝 fork 进程终止 退出码 进程退出三种情况对应退出信号 &#xff1a;退出码&#xff1a; 进程退出方法 进程等待 两种方式 阻塞等待和非阻塞等待 小知识 进程创建 1.在未创建子进程时&#xff0c;父进程页表对于数据权限为读写&#xff0c;对于…...

C++如何处理多线程环境下的异常?如何确保资源在异常情况下也能正确释放

多线程编程的基本概念与挑战 多线程编程的核心思想是将程序的执行划分为多个并行运行的线程&#xff0c;每个线程可以独立处理任务&#xff0c;从而充分利用多核处理器的性能优势。在C中&#xff0c;开发者可以通过std::thread创建线程&#xff0c;并使用同步原语如std::mutex、…...

TensorBoard如何在同一图表中绘制多个线条

1. 使用不同的日志目录 TensorBoard 会根据日志文件所在的目录来区分不同的运行。可以为每次运行指定一个独立的日志目录&#xff0c;TensorBoard 会自动将这些目录中的数据加载并显示为不同的运行。 示例&#xff08;TensorFlow&#xff09;&#xff1a; import tensorflow…...

微软Entra新安全功能引发大规模账户锁定事件

误报触发大规模锁定 多家机构的Windows管理员报告称&#xff0c;微软Entra ID新推出的"MACE"&#xff08;泄露凭证检测应用&#xff09;功能在部署过程中产生大量误报&#xff0c;导致用户账户被大规模锁定。这些警报和锁定始于昨夜&#xff0c;部分管理员认为属于误…...