32. 小批量梯度下降法(Mini-batch Gradient Descent)
在深度学习模型的训练过程中,梯度下降法是最常用的优化算法之一。我们前面介绍了批量梯度下降法(Batch Gradient Descent)和随机梯度下降法(Stochastic Gradient Descent),两者各有优缺点。为了在计算速度和收敛稳定性之间找到平衡,小批量梯度下降法(Mini-batch Gradient Descent)应运而生。下面我们详细介绍其基本思想、优缺点,并通过代码实现来比较三种梯度下降法。
小批量梯度下降法的基本思想
小批量梯度下降法在每次迭代中,使用一小部分随机样本(称为小批量)来计算梯度,并更新参数值。具体来说,算法步骤如下:
1. 初始化参数 \( w \) 和 \( b \)。
2. 在每次迭代中,从训练集中随机抽取 \( m \) 个样本。
3. 使用这 \( m \) 个样本计算损失函数的梯度。
4. 更新参数 \( w \) 和 \( b \)。
其梯度计算公式如下:
\[
\begin{align*}
w &= w - \alpha \cdot \frac{1}{m} \sum_{i=1}^{m} \nabla_w L(w, b, x_i, y_i), \\
b &= b - \alpha \cdot \frac{1}{m} \sum_{i=1}^{m} \nabla_b L(w, b, x_i, y_i),
\end{align*}
\]
其中,\( \alpha \) 是学习率,\( m \) 是小批量的大小。
优缺点
优点
1. 计算速度快:与批量梯度下降法相比,每次迭代只需计算小批量样本的梯度,速度更快。
2. 减少振荡:与随机梯度下降法相比,梯度的计算更加稳定,减少了参数更新时的振荡。
3. 控制灵活:可以调整小批量的大小,使得训练速度和精度之间达到平衡。
缺点
1. 需要调整学习率和小批量大小:学习率决定每次更新的步长,小批量大小决定每次计算梯度使用的样本数量。
2. 内存消耗:小批量大小的选择受限于内存容量,尤其在使用GPU运算时,需要选择合适的小批量大小。
代码实现
下面通过代码实现和比较三种梯度下降法的执行效果。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm# 定义数据集
np.random.seed(42)
X = np.random.rand(1000, 1)
y = 3*X + 2 + np.random.randn(1000, 1) * 0.1# 转换为tensor
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32)# 封装为数据集
dataset = TensorDataset(X_tensor, y_tensor)# 定义模型
class LinearRegressionModel(nn.Module):def __init__(self):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(1, 1)def forward(self, x):return self.linear(x)# 损失函数
criterion = nn.MSELoss()# 定义梯度下降法的批量大小
batch_sizes = [1000, 1, 128]
batch_labels = ['Batch Gradient Descent', 'Stochastic Gradient Descent', 'Mini-batch Gradient Descent']
colors = ['r', 'g', 'b']# 定义超参数
learning_rate = 0.01
num_epochs = 1000# 存储损失值
losses = {label: [] for label in batch_labels}# 训练模型
for batch_size, label, color in zip(batch_sizes, batch_labels, colors):model = LinearRegressionModel()optimizer = optim.SGD(model.parameters(), lr=learning_rate)data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)for epoch in tqdm(range(num_epochs), desc=label):epoch_loss = 0.0for batch_x, batch_y in data_loader:optimizer.zero_grad()outputs = model(batch_x)loss = criterion(outputs, batch_y)loss.backward()optimizer.step()epoch_loss += loss.item()losses[label].append(epoch_loss / len(data_loader))# 绘制损失值变化曲线
for label, color in zip(batch_labels, colors):plt.plot(losses[label], color=color, label=label)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()
结果分析
运行上述代码后,会显示三种梯度下降法在每个迭代周期(epoch)中的损失变化曲线。可以看到:
1. 批量梯度下降法:损失曲线平滑,但训练速度较慢。
2. 随机梯度下降法:训练速度快,但损失曲线波动较大。
3. 小批量梯度下降法:在训练速度和损失曲线的稳定性之间达到了平衡,效果较为理想。
总结
小批量梯度下降法结合了批量梯度下降法和随机梯度下降法的优点,是深度学习中常用的优化算法。通过调整小批量大小和学习率,可以在训练速度和模型精度之间找到最佳平衡。在实际应用中,小批量梯度下降法由于其较高的效率和较好的收敛效果,被广泛应用于各类深度学习模型的训练中。
相关文章:
32. 小批量梯度下降法(Mini-batch Gradient Descent)
在深度学习模型的训练过程中,梯度下降法是最常用的优化算法之一。我们前面介绍了批量梯度下降法(Batch Gradient Descent)和随机梯度下降法(Stochastic Gradient Descent),两者各有优缺点。为了在计算速度和…...
MySQL第八次作业
一、备份与恢复作业: 创库,建表: CREATE DATABASE booksDB; use booksDB; CREATE TABLE books ( bk_id INT NOT NULL PRIMARY KEY, bk_title VARCHAR(50) NOT NULL, copyright YEAR NOT NULL ); CREATE TABLE authors …...
【合集】临时邮箱网站 临时邮箱API(持续更新)
众所周知,在注册一些账户时,比较常见的验证方式就是邮箱,但是在进行一些小众和不知名网站注册时,邮箱的泄露可能预示着不休止的邮件推送。尤其是当我们只是想临时使用邮箱这种情况,第二种,批量注册账号的情…...
职场新人感受
互联网职场感受 阶段介绍 24届6月底毕业生,之前从未实习过。 岗位是后端开发(JAVA),目前已经上班三周(前两周看文档和做了半个简单需求,第三周脱产新人培训)。 职场体验 职场和想象中的工作…...
Window 下Mamba 环境安装踩坑问题汇总及解决方法 (无需绕过selective_scan_cuda)
导航 Mamba 及 Vim 安装问题参看本人之前博客:Mamba 环境安装踩坑问题汇总及解决方法Linux 下Vmamba 安装教程参看本人之前博客:Vmamba 安装教程(无需更改base环境中的cuda版本)Windows 下 VMamba的安装参看本人之前博客…...
前端项目本地的node_modules直接上传到服务器上无法直接使用(node-sasa模块报错)
跑 jekins任务的服务器不能连接外网下载依赖包,就将本地下载的 node_modules直接上传到服务器上,但是运行时node-sass模块报错了ERROR in Missing binding /root/component/node_modules/node-sass/vendor/linux-x64-48/binding.node >> 报错信息类…...
Hadoop3:动态扩容之新增一台机器的初始化工作
一、需求描述 给Hadoop集群动态扩容一个节点 那么,这个节点是全新的,我们需要做哪些准备工作,才能将它融入集群了? 二、初始化配置 1、修改IP和hostname vim /etc/sysconfig/network-scripts/ifcfg-ens33 vim /etc/hostname2、…...
【正点原子i.MX93开发板试用连载体验】录音小程序采集语料
本文最早发表于电子发烧友论坛:【新提醒】【正点原子i.MX93开发板试用连载体验】基于深度学习的语音本地控制 - 正点原子学习小组 - 电子技术论坛 - 广受欢迎的专业电子论坛! (elecfans.com) 接下来就是要尝试训练中文提示词。首先要进行语料采集,这是一…...
【EasyExcel】动态替换表头内容并应用样式
1.定义实体类 import com.alibaba.excel.annotation.ExcelProperty; import com.alibaba.excel.annotation.ContentStyle; import com.alibaba.excel.metadata.BorderStyleEnum; import com.alibaba.excel.metadata.VerticalAlignmentEnum; import com.alibaba.excel.metadata.…...
RocketMQ实现分布式事务
RocketMQ的分布式事务消息功能,在普通消息基础上,支持二阶段的提交。将二阶段提交和本地事务绑定,实现全局提交结果的一致性。 1、生产者将消息发送至RocketMQ服务端。 2、RocketMQ服务端将消息持久化成功之后,向生产者返回Ack确…...
【Rust练习】2.数值类型
练习题来自https://practice-zh.course.rs/basic-types/numbers.html 1 // 移除某个部分让代码工作 fn main() {let x: i32 5;let mut y: u32 5;y x;let z 10; // 这里 z 的类型是? }y的类型不对,另外,数字的默认类型是i32 fn main() {let x: i…...
通过 PPPOE 将 linux 服务器作为本地局域网 IPv4 外网网关
将 linux 服务器作为本地外网网关,方便利用 Linux 生态中的各种网络工具,对流量进行自定义、精细化管理… 环境说明 拨号主机:CentOS 7.9, Linux Kernel 5.4.257 拨号软件: rp-pppoe-3.11-7.el7.x86_64初始化 1、升级系统到新的稳定内核&a…...
gin源码分析
一、高性能 使用sync.pool解决频繁创建的context对象,在百万并发的场景下能大大提供访问性能和减少GC // ServeHTTP conforms to the http.Handler interface. // 每次的http请求都会从sync.pool中获取context,用完之后归还到pool中 func (engine *Engin…...
数学建模入门
目录 文章目录 前言 一、数学建模是什么? 1、官方概念: 2、具体过程 3、适合哪一类人参加? 4、需要有哪些学科基础呢? 二、怎样准备数学建模(必备‘硬件’) 1.组队 2.资料搜索 3.常用算法总结 4.论文撰写的…...
【学习笔记】无人机(UAV)在3GPP系统中的增强支持(十二)-无人机群在物流中的应用
引言 本文是3GPP TR 22.829 V17.1.0技术报告,专注于无人机(UAV)在3GPP系统中的增强支持。文章提出了多个无人机应用场景,分析了相应的能力要求,并建议了新的服务级别要求和关键性能指标(KPIs)。…...
同三维T80006EH2-4K30编码器视频使用操作说明书:高清HDMI编码器,高清SDI编码器,4K超清HDMI编码器,双路4K超高清编码器
同三维T80006EH2-4K30编码器视频使用操作说明书:高清HDMI编码器,高清SDI编码器,4K超清HDMI编码器,双路4K超高清编码器 T80006EH2-4K30编码器 同三维,十多年老品牌,我们一直专注:视频采集卡、视频…...
DHCP原理及配置
目录 一、DHCP原理 DHCP介绍 DHCP工作原理 DHCP分配方式 工作原理 DHCP重新登录 DHCP优点 二、DHCP配置 一、DHCP原理 1 DHCP介绍 大家都知道,现在出门很多地方基本上都有WIFI,那么有没有想过这样一个问题,平时在家里都是“固定”的…...
异步日志:性能优化的金钥匙
一、背景 2024 年 4 月的一个宁静的夜晚,正当大家忙完一天的工作准备休息时,应急群里“咚咚咚”开始报警,提示我们余利宝业务的赎回接口成功率下降。 通过 Monitor 监控发现,该接口的耗时已经超过了网关配置的超时阈值(2s)&#…...
matlab仿真 模拟调制(上)
(内容源自详解MATLAB/SIMULINK 通信系统建模与仿真 刘学勇编著第五章内容,有兴趣的读者请阅读原书) 1.幅度调制 clear all ts0.0025; %信号抽样时间间隔 t0:ts:10-ts;%时间矢量 fs1/ts;%抽样频率 dffs/length(t); %fft的频率分…...
【数据结构】--- 堆的应用
个人主页:星纭-CSDN博客 系列文章专栏 :数据结构 踏上取经路,比抵达灵山更重要!一起努力一起进步! 一.堆排序 在前一个文章的学习中,我们使用数组的物理结构构造出了逻辑结构上的堆。那么堆到底有什么用呢&…...
认知殖民与范式陷阱:当代人工智能发展路径的文明危机研究
认知殖民与范式陷阱:当代人工智能发展路径的文明危机研究摘要本文从文明安全与认知主权视角出发,系统批判了当前以Transformer架构、Scaling Law和大语言模型为核心的人工智能技术范式。研究指出,该范式不仅是技术路径的选择,更是…...
加拿大AI治理实战:风险分级、监管沙盒与可信AI工程化落地
1. 项目概述:这不是一场技术秀,而是一场制度设计的实战演练“Canada’s AI Ambitions: Navigating the Future of AI Governance”——这个标题里没有一行代码,不提一个模型参数,却直指当前全球AI发展最棘手、最易被忽视的底层命题…...
餐饮门店AI Agent上线倒计时:错过Q3政策补贴窗口期,将多付47%算力成本(附工信部认证服务商名录)
更多请点击: https://kaifayun.com 第一章:餐饮门店AI Agent的核心价值与政策窗口期紧迫性 在人力成本持续攀升、消费者预期快速迭代的双重压力下,餐饮门店正面临从“经验驱动”向“智能协同”跃迁的关键拐点。AI Agent 不再是实验室概念&am…...
法律科技的发展脉络:从数字化管理到AI辅助办案的演进路径
摘要 执业15年,我经历了律师行业工具变迁的三个阶段:纸质时代、本地软件时代、云端时代。现在正站在第四个阶段的起点——AI辅助办案。这篇文章回顾法律科技的发展脉络,分析每个阶段的特征和局限性,以及正在发生的变化趋势。 第一…...
XZ63C,18V输入,CMOS输出电压检测芯片
产品概述这系列芯片是使用 CMOS 技术开发的高精度、低功耗、小封装电压检测芯片。检测电压在小温度漂移的情况下保持极高的精度。输出配置是 CMOS 输出。产品特点● 封装:SOT23-3,TO92● 输出配置:CMOS● 工作电压:1.5V-18V …...
从电影运镜到游戏镜头:手把手教你用Cinemachine实现高级镜头语言(含Dutch Angle等实战配置)
从电影运镜到游戏镜头:手把手教你用Cinemachine实现高级镜头语言(含Dutch Angle等实战配置) 在游戏开发中,镜头语言是叙事和情感表达的重要工具。就像电影导演通过精心设计的镜头来引导观众情绪一样,游戏开发者也可以…...
揭秘FPGA内部世界:PrjXRay开源工具完整指南
揭秘FPGA内部世界:PrjXRay开源工具完整指南 【免费下载链接】prjxray Documenting the Xilinx 7-series bit-stream format. 项目地址: https://gitcode.com/gh_mirrors/pr/prjxray 你是否曾好奇FPGA芯片内部的神秘世界?那些二进制位流背后究竟隐…...
sdk-manager-plugin源码剖析:学习Gradle插件架构的完美案例 [特殊字符]
sdk-manager-plugin源码剖析:学习Gradle插件架构的完美案例 🚀 【免费下载链接】sdk-manager-plugin DEPRECATED Gradle plugin which downloads and manages your Android SDK. 项目地址: https://gitcode.com/gh_mirrors/sd/sdk-manager-plugin …...
JetBrains IDE试用重置终极指南:如何快速解决开发工具到期问题
JetBrains IDE试用重置终极指南:如何快速解决开发工具到期问题 【免费下载链接】ide-eval-resetter 项目地址: https://gitcode.com/gh_mirrors/id/ide-eval-resetter 还在为IntelliJ IDEA、PyCharm等JetBrains IDE试用期到期而烦恼吗?当你的开发…...
终极指南:5分钟搭建Rust高性能HTTP文件服务器,告别繁琐配置
终极指南:5分钟搭建Rust高性能HTTP文件服务器,告别繁琐配置 【免费下载链接】simple-http-server Simple http server in Rust (Windows/Mac/Linux) 项目地址: https://gitcode.com/gh_mirrors/si/simple-http-server Simple HTTP Server是一款基…...
