当前位置: 首页 > news >正文

32. 小批量梯度下降法(Mini-batch Gradient Descent)

在深度学习模型的训练过程中,梯度下降法是最常用的优化算法之一。我们前面介绍了批量梯度下降法(Batch Gradient Descent)和随机梯度下降法(Stochastic Gradient Descent),两者各有优缺点。为了在计算速度和收敛稳定性之间找到平衡,小批量梯度下降法(Mini-batch Gradient Descent)应运而生。下面我们详细介绍其基本思想、优缺点,并通过代码实现来比较三种梯度下降法。

小批量梯度下降法的基本思想

小批量梯度下降法在每次迭代中,使用一小部分随机样本(称为小批量)来计算梯度,并更新参数值。具体来说,算法步骤如下:

1. 初始化参数 \( w \) 和 \( b \)。
2. 在每次迭代中,从训练集中随机抽取 \( m \) 个样本。
3. 使用这 \( m \) 个样本计算损失函数的梯度。
4. 更新参数 \( w \) 和 \( b \)。

其梯度计算公式如下:
\[
\begin{align*}
w &= w - \alpha \cdot \frac{1}{m} \sum_{i=1}^{m} \nabla_w L(w, b, x_i, y_i), \\
b &= b - \alpha \cdot \frac{1}{m} \sum_{i=1}^{m} \nabla_b L(w, b, x_i, y_i),
\end{align*}
\]
其中,\( \alpha \) 是学习率,\( m \) 是小批量的大小。

优缺点

优点

1. 计算速度快:与批量梯度下降法相比,每次迭代只需计算小批量样本的梯度,速度更快。
2. 减少振荡:与随机梯度下降法相比,梯度的计算更加稳定,减少了参数更新时的振荡。
3. 控制灵活:可以调整小批量的大小,使得训练速度和精度之间达到平衡。

缺点

1. 需要调整学习率和小批量大小:学习率决定每次更新的步长,小批量大小决定每次计算梯度使用的样本数量。
2. 内存消耗:小批量大小的选择受限于内存容量,尤其在使用GPU运算时,需要选择合适的小批量大小。

代码实现

下面通过代码实现和比较三种梯度下降法的执行效果。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm# 定义数据集
np.random.seed(42)
X = np.random.rand(1000, 1)
y = 3*X + 2 + np.random.randn(1000, 1) * 0.1# 转换为tensor
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32)# 封装为数据集
dataset = TensorDataset(X_tensor, y_tensor)# 定义模型
class LinearRegressionModel(nn.Module):def __init__(self):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(1, 1)def forward(self, x):return self.linear(x)# 损失函数
criterion = nn.MSELoss()# 定义梯度下降法的批量大小
batch_sizes = [1000, 1, 128]
batch_labels = ['Batch Gradient Descent', 'Stochastic Gradient Descent', 'Mini-batch Gradient Descent']
colors = ['r', 'g', 'b']# 定义超参数
learning_rate = 0.01
num_epochs = 1000# 存储损失值
losses = {label: [] for label in batch_labels}# 训练模型
for batch_size, label, color in zip(batch_sizes, batch_labels, colors):model = LinearRegressionModel()optimizer = optim.SGD(model.parameters(), lr=learning_rate)data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)for epoch in tqdm(range(num_epochs), desc=label):epoch_loss = 0.0for batch_x, batch_y in data_loader:optimizer.zero_grad()outputs = model(batch_x)loss = criterion(outputs, batch_y)loss.backward()optimizer.step()epoch_loss += loss.item()losses[label].append(epoch_loss / len(data_loader))# 绘制损失值变化曲线
for label, color in zip(batch_labels, colors):plt.plot(losses[label], color=color, label=label)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()


 

结果分析

运行上述代码后,会显示三种梯度下降法在每个迭代周期(epoch)中的损失变化曲线。可以看到:

1. 批量梯度下降法:损失曲线平滑,但训练速度较慢。
2. 随机梯度下降法:训练速度快,但损失曲线波动较大。
3. 小批量梯度下降法:在训练速度和损失曲线的稳定性之间达到了平衡,效果较为理想。

总结

小批量梯度下降法结合了批量梯度下降法和随机梯度下降法的优点,是深度学习中常用的优化算法。通过调整小批量大小和学习率,可以在训练速度和模型精度之间找到最佳平衡。在实际应用中,小批量梯度下降法由于其较高的效率和较好的收敛效果,被广泛应用于各类深度学习模型的训练中。

相关文章:

32. 小批量梯度下降法(Mini-batch Gradient Descent)

在深度学习模型的训练过程中,梯度下降法是最常用的优化算法之一。我们前面介绍了批量梯度下降法(Batch Gradient Descent)和随机梯度下降法(Stochastic Gradient Descent),两者各有优缺点。为了在计算速度和…...

MySQL第八次作业

一、备份与恢复作业: 创库,建表: CREATE DATABASE booksDB; use booksDB; CREATE TABLE books ( bk_id INT NOT NULL PRIMARY KEY, bk_title VARCHAR(50) NOT NULL, copyright YEAR NOT NULL ); CREATE TABLE authors …...

【合集】临时邮箱网站 临时邮箱API(持续更新)

众所周知,在注册一些账户时,比较常见的验证方式就是邮箱,但是在进行一些小众和不知名网站注册时,邮箱的泄露可能预示着不休止的邮件推送。尤其是当我们只是想临时使用邮箱这种情况,第二种,批量注册账号的情…...

职场新人感受

互联网职场感受 阶段介绍 24届6月底毕业生,之前从未实习过。 岗位是后端开发(JAVA),目前已经上班三周(前两周看文档和做了半个简单需求,第三周脱产新人培训)。 职场体验 职场和想象中的工作…...

Window 下Mamba 环境安装踩坑问题汇总及解决方法 (无需绕过selective_scan_cuda)

导航 Mamba 及 Vim 安装问题参看本人之前博客:Mamba 环境安装踩坑问题汇总及解决方法Linux 下Vmamba 安装教程参看本人之前博客:Vmamba 安装教程(无需更改base环境中的cuda版本)Windows 下 VMamba的安装参看本人之前博客&#xf…...

前端项目本地的node_modules直接上传到服务器上无法直接使用(node-sasa模块报错)

跑 jekins任务的服务器不能连接外网下载依赖包,就将本地下载的 node_modules直接上传到服务器上,但是运行时node-sass模块报错了ERROR in Missing binding /root/component/node_modules/node-sass/vendor/linux-x64-48/binding.node >> 报错信息类…...

Hadoop3:动态扩容之新增一台机器的初始化工作

一、需求描述 给Hadoop集群动态扩容一个节点 那么,这个节点是全新的,我们需要做哪些准备工作,才能将它融入集群了? 二、初始化配置 1、修改IP和hostname vim /etc/sysconfig/network-scripts/ifcfg-ens33 vim /etc/hostname2、…...

【正点原子i.MX93开发板试用连载体验】录音小程序采集语料

本文最早发表于电子发烧友论坛:【新提醒】【正点原子i.MX93开发板试用连载体验】基于深度学习的语音本地控制 - 正点原子学习小组 - 电子技术论坛 - 广受欢迎的专业电子论坛! (elecfans.com) 接下来就是要尝试训练中文提示词。首先要进行语料采集,这是一…...

【EasyExcel】动态替换表头内容并应用样式

1.定义实体类 import com.alibaba.excel.annotation.ExcelProperty; import com.alibaba.excel.annotation.ContentStyle; import com.alibaba.excel.metadata.BorderStyleEnum; import com.alibaba.excel.metadata.VerticalAlignmentEnum; import com.alibaba.excel.metadata.…...

RocketMQ实现分布式事务

RocketMQ的分布式事务消息功能,在普通消息基础上,支持二阶段的提交。将二阶段提交和本地事务绑定,实现全局提交结果的一致性。 1、生产者将消息发送至RocketMQ服务端。 2、RocketMQ服务端将消息持久化成功之后,向生产者返回Ack确…...

【Rust练习】2.数值类型

练习题来自https://practice-zh.course.rs/basic-types/numbers.html 1 // 移除某个部分让代码工作 fn main() {let x: i32 5;let mut y: u32 5;y x;let z 10; // 这里 z 的类型是? }y的类型不对,另外,数字的默认类型是i32 fn main() {let x: i…...

通过 PPPOE 将 linux 服务器作为本地局域网 IPv4 外网网关

将 linux 服务器作为本地外网网关,方便利用 Linux 生态中的各种网络工具,对流量进行自定义、精细化管理… 环境说明 拨号主机:CentOS 7.9, Linux Kernel 5.4.257 拨号软件: rp-pppoe-3.11-7.el7.x86_64初始化 1、升级系统到新的稳定内核&a…...

gin源码分析

一、高性能 使用sync.pool解决频繁创建的context对象,在百万并发的场景下能大大提供访问性能和减少GC // ServeHTTP conforms to the http.Handler interface. // 每次的http请求都会从sync.pool中获取context,用完之后归还到pool中 func (engine *Engin…...

数学建模入门

目录 文章目录 前言 一、数学建模是什么? 1、官方概念: 2、具体过程 3、适合哪一类人参加? 4、需要有哪些学科基础呢? 二、怎样准备数学建模(必备‘硬件’) 1.组队 2.资料搜索 3.常用算法总结 4.论文撰写的…...

【学习笔记】无人机(UAV)在3GPP系统中的增强支持(十二)-无人机群在物流中的应用

引言 本文是3GPP TR 22.829 V17.1.0技术报告,专注于无人机(UAV)在3GPP系统中的增强支持。文章提出了多个无人机应用场景,分析了相应的能力要求,并建议了新的服务级别要求和关键性能指标(KPIs)。…...

同三维T80006EH2-4K30编码器视频使用操作说明书:高清HDMI编码器,高清SDI编码器,4K超清HDMI编码器,双路4K超高清编码器

同三维T80006EH2-4K30编码器视频使用操作说明书:高清HDMI编码器,高清SDI编码器,4K超清HDMI编码器,双路4K超高清编码器 T80006EH2-4K30编码器 同三维,十多年老品牌,我们一直专注:视频采集卡、视频…...

DHCP原理及配置

目录 一、DHCP原理 DHCP介绍 DHCP工作原理 DHCP分配方式 工作原理 DHCP重新登录 DHCP优点 二、DHCP配置 一、DHCP原理 1 DHCP介绍 大家都知道,现在出门很多地方基本上都有WIFI,那么有没有想过这样一个问题,平时在家里都是“固定”的…...

异步日志:性能优化的金钥匙

一、背景 2024 年 4 月的一个宁静的夜晚,正当大家忙完一天的工作准备休息时,应急群里“咚咚咚”开始报警,提示我们余利宝业务的赎回接口成功率下降。 通过 Monitor 监控发现,该接口的耗时已经超过了网关配置的超时阈值(2s)&#…...

matlab仿真 模拟调制(上)

(内容源自详解MATLAB/SIMULINK 通信系统建模与仿真 刘学勇编著第五章内容,有兴趣的读者请阅读原书) 1.幅度调制 clear all ts0.0025; %信号抽样时间间隔 t0:ts:10-ts;%时间矢量 fs1/ts;%抽样频率 dffs/length(t); %fft的频率分…...

【数据结构】--- 堆的应用

​ 个人主页:星纭-CSDN博客 系列文章专栏 :数据结构 踏上取经路,比抵达灵山更重要!一起努力一起进步! 一.堆排序 在前一个文章的学习中,我们使用数组的物理结构构造出了逻辑结构上的堆。那么堆到底有什么用呢&…...

label-studio的使用教程(导入本地路径)

文章目录 1. 准备环境2. 脚本启动2.1 Windows2.2 Linux 3. 安装label-studio机器学习后端3.1 pip安装(推荐)3.2 GitHub仓库安装 4. 后端配置4.1 yolo环境4.2 引入后端模型4.3 修改脚本4.4 启动后端 5. 标注工程5.1 创建工程5.2 配置图片路径5.3 配置工程类型标签5.4 配置模型5.…...

树莓派超全系列教程文档--(61)树莓派摄像头高级使用方法

树莓派摄像头高级使用方法 配置通过调谐文件来调整相机行为 使用多个摄像头安装 libcam 和 rpicam-apps依赖关系开发包 文章来源: http://raspberry.dns8844.cn/documentation 原文网址 配置 大多数用例自动工作,无需更改相机配置。但是,一…...

学校时钟系统,标准考场时钟系统,AI亮相2025高考,赛思时钟系统为教育公平筑起“精准防线”

2025年#高考 将在近日拉开帷幕,#AI 监考一度冲上热搜。当AI深度融入高考,#时间同步 不再是辅助功能,而是决定AI监考系统成败的“生命线”。 AI亮相2025高考,40种异常行为0.5秒精准识别 2025年高考即将拉开帷幕,江西、…...

服务器--宝塔命令

一、宝塔面板安装命令 ⚠️ 必须使用 root 用户 或 sudo 权限执行! sudo su - 1. CentOS 系统: yum install -y wget && wget -O install.sh http://download.bt.cn/install/install_6.0.sh && sh install.sh2. Ubuntu / Debian 系统…...

HDFS分布式存储 zookeeper

hadoop介绍 狭义上hadoop是指apache的一款开源软件 用java语言实现开源框架,允许使用简单的变成模型跨计算机对大型集群进行分布式处理(1.海量的数据存储 2.海量数据的计算)Hadoop核心组件 hdfs(分布式文件存储系统)&a…...

【生成模型】视频生成论文调研

工作清单 上游应用方向:控制、速度、时长、高动态、多主体驱动 类型工作基础模型WAN / WAN-VACE / HunyuanVideo控制条件轨迹控制ATI~镜头控制ReCamMaster~多主体驱动Phantom~音频驱动Let Them Talk: Audio-Driven Multi-Person Conversational Video Generation速…...

iOS性能调优实战:借助克魔(KeyMob)与常用工具深度洞察App瓶颈

在日常iOS开发过程中,性能问题往往是最令人头疼的一类Bug。尤其是在App上线前的压测阶段或是处理用户反馈的高发期,开发者往往需要面对卡顿、崩溃、能耗异常、日志混乱等一系列问题。这些问题表面上看似偶发,但背后往往隐藏着系统资源调度不当…...

CVE-2020-17519源码分析与漏洞复现(Flink 任意文件读取)

漏洞概览 漏洞名称:Apache Flink REST API 任意文件读取漏洞CVE编号:CVE-2020-17519CVSS评分:7.5影响版本:Apache Flink 1.11.0、1.11.1、1.11.2修复版本:≥ 1.11.3 或 ≥ 1.12.0漏洞类型:路径遍历&#x…...

C语言中提供的第三方库之哈希表实现

一. 简介 前面一篇文章简单学习了C语言中第三方库(uthash库)提供对哈希表的操作,文章如下: C语言中提供的第三方库uthash常用接口-CSDN博客 本文简单学习一下第三方库 uthash库对哈希表的操作。 二. uthash库哈希表操作示例 u…...

人工智能--安全大模型训练计划:基于Fine-tuning + LLM Agent

安全大模型训练计划:基于Fine-tuning LLM Agent 1. 构建高质量安全数据集 目标:为安全大模型创建高质量、去偏、符合伦理的训练数据集,涵盖安全相关任务(如有害内容检测、隐私保护、道德推理等)。 1.1 数据收集 描…...