32. 小批量梯度下降法(Mini-batch Gradient Descent)
在深度学习模型的训练过程中,梯度下降法是最常用的优化算法之一。我们前面介绍了批量梯度下降法(Batch Gradient Descent)和随机梯度下降法(Stochastic Gradient Descent),两者各有优缺点。为了在计算速度和收敛稳定性之间找到平衡,小批量梯度下降法(Mini-batch Gradient Descent)应运而生。下面我们详细介绍其基本思想、优缺点,并通过代码实现来比较三种梯度下降法。
小批量梯度下降法的基本思想
小批量梯度下降法在每次迭代中,使用一小部分随机样本(称为小批量)来计算梯度,并更新参数值。具体来说,算法步骤如下:
1. 初始化参数 \( w \) 和 \( b \)。
2. 在每次迭代中,从训练集中随机抽取 \( m \) 个样本。
3. 使用这 \( m \) 个样本计算损失函数的梯度。
4. 更新参数 \( w \) 和 \( b \)。
其梯度计算公式如下:
\[
\begin{align*}
w &= w - \alpha \cdot \frac{1}{m} \sum_{i=1}^{m} \nabla_w L(w, b, x_i, y_i), \\
b &= b - \alpha \cdot \frac{1}{m} \sum_{i=1}^{m} \nabla_b L(w, b, x_i, y_i),
\end{align*}
\]
其中,\( \alpha \) 是学习率,\( m \) 是小批量的大小。
优缺点
优点
1. 计算速度快:与批量梯度下降法相比,每次迭代只需计算小批量样本的梯度,速度更快。
2. 减少振荡:与随机梯度下降法相比,梯度的计算更加稳定,减少了参数更新时的振荡。
3. 控制灵活:可以调整小批量的大小,使得训练速度和精度之间达到平衡。
缺点
1. 需要调整学习率和小批量大小:学习率决定每次更新的步长,小批量大小决定每次计算梯度使用的样本数量。
2. 内存消耗:小批量大小的选择受限于内存容量,尤其在使用GPU运算时,需要选择合适的小批量大小。
代码实现
下面通过代码实现和比较三种梯度下降法的执行效果。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm# 定义数据集
np.random.seed(42)
X = np.random.rand(1000, 1)
y = 3*X + 2 + np.random.randn(1000, 1) * 0.1# 转换为tensor
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32)# 封装为数据集
dataset = TensorDataset(X_tensor, y_tensor)# 定义模型
class LinearRegressionModel(nn.Module):def __init__(self):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(1, 1)def forward(self, x):return self.linear(x)# 损失函数
criterion = nn.MSELoss()# 定义梯度下降法的批量大小
batch_sizes = [1000, 1, 128]
batch_labels = ['Batch Gradient Descent', 'Stochastic Gradient Descent', 'Mini-batch Gradient Descent']
colors = ['r', 'g', 'b']# 定义超参数
learning_rate = 0.01
num_epochs = 1000# 存储损失值
losses = {label: [] for label in batch_labels}# 训练模型
for batch_size, label, color in zip(batch_sizes, batch_labels, colors):model = LinearRegressionModel()optimizer = optim.SGD(model.parameters(), lr=learning_rate)data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)for epoch in tqdm(range(num_epochs), desc=label):epoch_loss = 0.0for batch_x, batch_y in data_loader:optimizer.zero_grad()outputs = model(batch_x)loss = criterion(outputs, batch_y)loss.backward()optimizer.step()epoch_loss += loss.item()losses[label].append(epoch_loss / len(data_loader))# 绘制损失值变化曲线
for label, color in zip(batch_labels, colors):plt.plot(losses[label], color=color, label=label)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()
结果分析
运行上述代码后,会显示三种梯度下降法在每个迭代周期(epoch)中的损失变化曲线。可以看到:
1. 批量梯度下降法:损失曲线平滑,但训练速度较慢。
2. 随机梯度下降法:训练速度快,但损失曲线波动较大。
3. 小批量梯度下降法:在训练速度和损失曲线的稳定性之间达到了平衡,效果较为理想。
总结
小批量梯度下降法结合了批量梯度下降法和随机梯度下降法的优点,是深度学习中常用的优化算法。通过调整小批量大小和学习率,可以在训练速度和模型精度之间找到最佳平衡。在实际应用中,小批量梯度下降法由于其较高的效率和较好的收敛效果,被广泛应用于各类深度学习模型的训练中。
相关文章:
32. 小批量梯度下降法(Mini-batch Gradient Descent)
在深度学习模型的训练过程中,梯度下降法是最常用的优化算法之一。我们前面介绍了批量梯度下降法(Batch Gradient Descent)和随机梯度下降法(Stochastic Gradient Descent),两者各有优缺点。为了在计算速度和…...
MySQL第八次作业
一、备份与恢复作业: 创库,建表: CREATE DATABASE booksDB; use booksDB; CREATE TABLE books ( bk_id INT NOT NULL PRIMARY KEY, bk_title VARCHAR(50) NOT NULL, copyright YEAR NOT NULL ); CREATE TABLE authors …...
【合集】临时邮箱网站 临时邮箱API(持续更新)
众所周知,在注册一些账户时,比较常见的验证方式就是邮箱,但是在进行一些小众和不知名网站注册时,邮箱的泄露可能预示着不休止的邮件推送。尤其是当我们只是想临时使用邮箱这种情况,第二种,批量注册账号的情…...
职场新人感受
互联网职场感受 阶段介绍 24届6月底毕业生,之前从未实习过。 岗位是后端开发(JAVA),目前已经上班三周(前两周看文档和做了半个简单需求,第三周脱产新人培训)。 职场体验 职场和想象中的工作…...
Window 下Mamba 环境安装踩坑问题汇总及解决方法 (无需绕过selective_scan_cuda)
导航 Mamba 及 Vim 安装问题参看本人之前博客:Mamba 环境安装踩坑问题汇总及解决方法Linux 下Vmamba 安装教程参看本人之前博客:Vmamba 安装教程(无需更改base环境中的cuda版本)Windows 下 VMamba的安装参看本人之前博客…...
前端项目本地的node_modules直接上传到服务器上无法直接使用(node-sasa模块报错)
跑 jekins任务的服务器不能连接外网下载依赖包,就将本地下载的 node_modules直接上传到服务器上,但是运行时node-sass模块报错了ERROR in Missing binding /root/component/node_modules/node-sass/vendor/linux-x64-48/binding.node >> 报错信息类…...
Hadoop3:动态扩容之新增一台机器的初始化工作
一、需求描述 给Hadoop集群动态扩容一个节点 那么,这个节点是全新的,我们需要做哪些准备工作,才能将它融入集群了? 二、初始化配置 1、修改IP和hostname vim /etc/sysconfig/network-scripts/ifcfg-ens33 vim /etc/hostname2、…...
【正点原子i.MX93开发板试用连载体验】录音小程序采集语料
本文最早发表于电子发烧友论坛:【新提醒】【正点原子i.MX93开发板试用连载体验】基于深度学习的语音本地控制 - 正点原子学习小组 - 电子技术论坛 - 广受欢迎的专业电子论坛! (elecfans.com) 接下来就是要尝试训练中文提示词。首先要进行语料采集,这是一…...
【EasyExcel】动态替换表头内容并应用样式
1.定义实体类 import com.alibaba.excel.annotation.ExcelProperty; import com.alibaba.excel.annotation.ContentStyle; import com.alibaba.excel.metadata.BorderStyleEnum; import com.alibaba.excel.metadata.VerticalAlignmentEnum; import com.alibaba.excel.metadata.…...
RocketMQ实现分布式事务
RocketMQ的分布式事务消息功能,在普通消息基础上,支持二阶段的提交。将二阶段提交和本地事务绑定,实现全局提交结果的一致性。 1、生产者将消息发送至RocketMQ服务端。 2、RocketMQ服务端将消息持久化成功之后,向生产者返回Ack确…...
【Rust练习】2.数值类型
练习题来自https://practice-zh.course.rs/basic-types/numbers.html 1 // 移除某个部分让代码工作 fn main() {let x: i32 5;let mut y: u32 5;y x;let z 10; // 这里 z 的类型是? }y的类型不对,另外,数字的默认类型是i32 fn main() {let x: i…...
通过 PPPOE 将 linux 服务器作为本地局域网 IPv4 外网网关
将 linux 服务器作为本地外网网关,方便利用 Linux 生态中的各种网络工具,对流量进行自定义、精细化管理… 环境说明 拨号主机:CentOS 7.9, Linux Kernel 5.4.257 拨号软件: rp-pppoe-3.11-7.el7.x86_64初始化 1、升级系统到新的稳定内核&a…...
gin源码分析
一、高性能 使用sync.pool解决频繁创建的context对象,在百万并发的场景下能大大提供访问性能和减少GC // ServeHTTP conforms to the http.Handler interface. // 每次的http请求都会从sync.pool中获取context,用完之后归还到pool中 func (engine *Engin…...
数学建模入门
目录 文章目录 前言 一、数学建模是什么? 1、官方概念: 2、具体过程 3、适合哪一类人参加? 4、需要有哪些学科基础呢? 二、怎样准备数学建模(必备‘硬件’) 1.组队 2.资料搜索 3.常用算法总结 4.论文撰写的…...
【学习笔记】无人机(UAV)在3GPP系统中的增强支持(十二)-无人机群在物流中的应用
引言 本文是3GPP TR 22.829 V17.1.0技术报告,专注于无人机(UAV)在3GPP系统中的增强支持。文章提出了多个无人机应用场景,分析了相应的能力要求,并建议了新的服务级别要求和关键性能指标(KPIs)。…...
同三维T80006EH2-4K30编码器视频使用操作说明书:高清HDMI编码器,高清SDI编码器,4K超清HDMI编码器,双路4K超高清编码器
同三维T80006EH2-4K30编码器视频使用操作说明书:高清HDMI编码器,高清SDI编码器,4K超清HDMI编码器,双路4K超高清编码器 T80006EH2-4K30编码器 同三维,十多年老品牌,我们一直专注:视频采集卡、视频…...
DHCP原理及配置
目录 一、DHCP原理 DHCP介绍 DHCP工作原理 DHCP分配方式 工作原理 DHCP重新登录 DHCP优点 二、DHCP配置 一、DHCP原理 1 DHCP介绍 大家都知道,现在出门很多地方基本上都有WIFI,那么有没有想过这样一个问题,平时在家里都是“固定”的…...
异步日志:性能优化的金钥匙
一、背景 2024 年 4 月的一个宁静的夜晚,正当大家忙完一天的工作准备休息时,应急群里“咚咚咚”开始报警,提示我们余利宝业务的赎回接口成功率下降。 通过 Monitor 监控发现,该接口的耗时已经超过了网关配置的超时阈值(2s)&#…...
matlab仿真 模拟调制(上)
(内容源自详解MATLAB/SIMULINK 通信系统建模与仿真 刘学勇编著第五章内容,有兴趣的读者请阅读原书) 1.幅度调制 clear all ts0.0025; %信号抽样时间间隔 t0:ts:10-ts;%时间矢量 fs1/ts;%抽样频率 dffs/length(t); %fft的频率分…...
【数据结构】--- 堆的应用
个人主页:星纭-CSDN博客 系列文章专栏 :数据结构 踏上取经路,比抵达灵山更重要!一起努力一起进步! 一.堆排序 在前一个文章的学习中,我们使用数组的物理结构构造出了逻辑结构上的堆。那么堆到底有什么用呢&…...
智慧医疗能源事业线深度画像分析(上)
引言 医疗行业作为现代社会的关键基础设施,其能源消耗与环境影响正日益受到关注。随着全球"双碳"目标的推进和可持续发展理念的深入,智慧医疗能源事业线应运而生,致力于通过创新技术与管理方案,重构医疗领域的能源使用模式。这一事业线融合了能源管理、可持续发…...
Mybatis逆向工程,动态创建实体类、条件扩展类、Mapper接口、Mapper.xml映射文件
今天呢,博主的学习进度也是步入了Java Mybatis 框架,目前正在逐步杨帆旗航。 那么接下来就给大家出一期有关 Mybatis 逆向工程的教学,希望能对大家有所帮助,也特别欢迎大家指点不足之处,小生很乐意接受正确的建议&…...
可靠性+灵活性:电力载波技术在楼宇自控中的核心价值
可靠性灵活性:电力载波技术在楼宇自控中的核心价值 在智能楼宇的自动化控制中,电力载波技术(PLC)凭借其独特的优势,正成为构建高效、稳定、灵活系统的核心解决方案。它利用现有电力线路传输数据,无需额外布…...
CMake基础:构建流程详解
目录 1.CMake构建过程的基本流程 2.CMake构建的具体步骤 2.1.创建构建目录 2.2.使用 CMake 生成构建文件 2.3.编译和构建 2.4.清理构建文件 2.5.重新配置和构建 3.跨平台构建示例 4.工具链与交叉编译 5.CMake构建后的项目结构解析 5.1.CMake构建后的目录结构 5.2.构…...
汽车生产虚拟实训中的技能提升与生产优化
在制造业蓬勃发展的大背景下,虚拟教学实训宛如一颗璀璨的新星,正发挥着不可或缺且日益凸显的关键作用,源源不断地为企业的稳健前行与创新发展注入磅礴强大的动力。就以汽车制造企业这一极具代表性的行业主体为例,汽车生产线上各类…...
《用户共鸣指数(E)驱动品牌大模型种草:如何抢占大模型搜索结果情感高地》
在注意力分散、内容高度同质化的时代,情感连接已成为品牌破圈的关键通道。我们在服务大量品牌客户的过程中发现,消费者对内容的“有感”程度,正日益成为影响品牌传播效率与转化率的核心变量。在生成式AI驱动的内容生成与推荐环境中࿰…...
postgresql|数据库|只读用户的创建和删除(备忘)
CREATE USER read_only WITH PASSWORD 密码 -- 连接到xxx数据库 \c xxx -- 授予对xxx数据库的只读权限 GRANT CONNECT ON DATABASE xxx TO read_only; GRANT USAGE ON SCHEMA public TO read_only; GRANT SELECT ON ALL TABLES IN SCHEMA public TO read_only; GRANT EXECUTE O…...
C++ 基础特性深度解析
目录 引言 一、命名空间(namespace) C 中的命名空间 与 C 语言的对比 二、缺省参数 C 中的缺省参数 与 C 语言的对比 三、引用(reference) C 中的引用 与 C 语言的对比 四、inline(内联函数…...
【服务器压力测试】本地PC电脑作为服务器运行时出现卡顿和资源紧张(Windows/Linux)
要让本地PC电脑作为服务器运行时出现卡顿和资源紧张的情况,可以通过以下几种方式模拟或触发: 1. 增加CPU负载 运行大量计算密集型任务,例如: 使用多线程循环执行复杂计算(如数学运算、加密解密等)。运行图…...
【Java_EE】Spring MVC
目录 Spring Web MVC 编辑注解 RestController RequestMapping RequestParam RequestParam RequestBody PathVariable RequestPart 参数传递 注意事项 编辑参数重命名 RequestParam 编辑编辑传递集合 RequestParam 传递JSON数据 编辑RequestBody …...
