当前位置: 首页 > news >正文

32. 小批量梯度下降法(Mini-batch Gradient Descent)

在深度学习模型的训练过程中,梯度下降法是最常用的优化算法之一。我们前面介绍了批量梯度下降法(Batch Gradient Descent)和随机梯度下降法(Stochastic Gradient Descent),两者各有优缺点。为了在计算速度和收敛稳定性之间找到平衡,小批量梯度下降法(Mini-batch Gradient Descent)应运而生。下面我们详细介绍其基本思想、优缺点,并通过代码实现来比较三种梯度下降法。

小批量梯度下降法的基本思想

小批量梯度下降法在每次迭代中,使用一小部分随机样本(称为小批量)来计算梯度,并更新参数值。具体来说,算法步骤如下:

1. 初始化参数 \( w \) 和 \( b \)。
2. 在每次迭代中,从训练集中随机抽取 \( m \) 个样本。
3. 使用这 \( m \) 个样本计算损失函数的梯度。
4. 更新参数 \( w \) 和 \( b \)。

其梯度计算公式如下:
\[
\begin{align*}
w &= w - \alpha \cdot \frac{1}{m} \sum_{i=1}^{m} \nabla_w L(w, b, x_i, y_i), \\
b &= b - \alpha \cdot \frac{1}{m} \sum_{i=1}^{m} \nabla_b L(w, b, x_i, y_i),
\end{align*}
\]
其中,\( \alpha \) 是学习率,\( m \) 是小批量的大小。

优缺点

优点

1. 计算速度快:与批量梯度下降法相比,每次迭代只需计算小批量样本的梯度,速度更快。
2. 减少振荡:与随机梯度下降法相比,梯度的计算更加稳定,减少了参数更新时的振荡。
3. 控制灵活:可以调整小批量的大小,使得训练速度和精度之间达到平衡。

缺点

1. 需要调整学习率和小批量大小:学习率决定每次更新的步长,小批量大小决定每次计算梯度使用的样本数量。
2. 内存消耗:小批量大小的选择受限于内存容量,尤其在使用GPU运算时,需要选择合适的小批量大小。

代码实现

下面通过代码实现和比较三种梯度下降法的执行效果。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm# 定义数据集
np.random.seed(42)
X = np.random.rand(1000, 1)
y = 3*X + 2 + np.random.randn(1000, 1) * 0.1# 转换为tensor
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32)# 封装为数据集
dataset = TensorDataset(X_tensor, y_tensor)# 定义模型
class LinearRegressionModel(nn.Module):def __init__(self):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(1, 1)def forward(self, x):return self.linear(x)# 损失函数
criterion = nn.MSELoss()# 定义梯度下降法的批量大小
batch_sizes = [1000, 1, 128]
batch_labels = ['Batch Gradient Descent', 'Stochastic Gradient Descent', 'Mini-batch Gradient Descent']
colors = ['r', 'g', 'b']# 定义超参数
learning_rate = 0.01
num_epochs = 1000# 存储损失值
losses = {label: [] for label in batch_labels}# 训练模型
for batch_size, label, color in zip(batch_sizes, batch_labels, colors):model = LinearRegressionModel()optimizer = optim.SGD(model.parameters(), lr=learning_rate)data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)for epoch in tqdm(range(num_epochs), desc=label):epoch_loss = 0.0for batch_x, batch_y in data_loader:optimizer.zero_grad()outputs = model(batch_x)loss = criterion(outputs, batch_y)loss.backward()optimizer.step()epoch_loss += loss.item()losses[label].append(epoch_loss / len(data_loader))# 绘制损失值变化曲线
for label, color in zip(batch_labels, colors):plt.plot(losses[label], color=color, label=label)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()


 

结果分析

运行上述代码后,会显示三种梯度下降法在每个迭代周期(epoch)中的损失变化曲线。可以看到:

1. 批量梯度下降法:损失曲线平滑,但训练速度较慢。
2. 随机梯度下降法:训练速度快,但损失曲线波动较大。
3. 小批量梯度下降法:在训练速度和损失曲线的稳定性之间达到了平衡,效果较为理想。

总结

小批量梯度下降法结合了批量梯度下降法和随机梯度下降法的优点,是深度学习中常用的优化算法。通过调整小批量大小和学习率,可以在训练速度和模型精度之间找到最佳平衡。在实际应用中,小批量梯度下降法由于其较高的效率和较好的收敛效果,被广泛应用于各类深度学习模型的训练中。

相关文章:

32. 小批量梯度下降法(Mini-batch Gradient Descent)

在深度学习模型的训练过程中,梯度下降法是最常用的优化算法之一。我们前面介绍了批量梯度下降法(Batch Gradient Descent)和随机梯度下降法(Stochastic Gradient Descent),两者各有优缺点。为了在计算速度和…...

MySQL第八次作业

一、备份与恢复作业: 创库,建表: CREATE DATABASE booksDB; use booksDB; CREATE TABLE books ( bk_id INT NOT NULL PRIMARY KEY, bk_title VARCHAR(50) NOT NULL, copyright YEAR NOT NULL ); CREATE TABLE authors …...

【合集】临时邮箱网站 临时邮箱API(持续更新)

众所周知,在注册一些账户时,比较常见的验证方式就是邮箱,但是在进行一些小众和不知名网站注册时,邮箱的泄露可能预示着不休止的邮件推送。尤其是当我们只是想临时使用邮箱这种情况,第二种,批量注册账号的情…...

职场新人感受

互联网职场感受 阶段介绍 24届6月底毕业生,之前从未实习过。 岗位是后端开发(JAVA),目前已经上班三周(前两周看文档和做了半个简单需求,第三周脱产新人培训)。 职场体验 职场和想象中的工作…...

Window 下Mamba 环境安装踩坑问题汇总及解决方法 (无需绕过selective_scan_cuda)

导航 Mamba 及 Vim 安装问题参看本人之前博客:Mamba 环境安装踩坑问题汇总及解决方法Linux 下Vmamba 安装教程参看本人之前博客:Vmamba 安装教程(无需更改base环境中的cuda版本)Windows 下 VMamba的安装参看本人之前博客&#xf…...

前端项目本地的node_modules直接上传到服务器上无法直接使用(node-sasa模块报错)

跑 jekins任务的服务器不能连接外网下载依赖包,就将本地下载的 node_modules直接上传到服务器上,但是运行时node-sass模块报错了ERROR in Missing binding /root/component/node_modules/node-sass/vendor/linux-x64-48/binding.node >> 报错信息类…...

Hadoop3:动态扩容之新增一台机器的初始化工作

一、需求描述 给Hadoop集群动态扩容一个节点 那么,这个节点是全新的,我们需要做哪些准备工作,才能将它融入集群了? 二、初始化配置 1、修改IP和hostname vim /etc/sysconfig/network-scripts/ifcfg-ens33 vim /etc/hostname2、…...

【正点原子i.MX93开发板试用连载体验】录音小程序采集语料

本文最早发表于电子发烧友论坛:【新提醒】【正点原子i.MX93开发板试用连载体验】基于深度学习的语音本地控制 - 正点原子学习小组 - 电子技术论坛 - 广受欢迎的专业电子论坛! (elecfans.com) 接下来就是要尝试训练中文提示词。首先要进行语料采集,这是一…...

【EasyExcel】动态替换表头内容并应用样式

1.定义实体类 import com.alibaba.excel.annotation.ExcelProperty; import com.alibaba.excel.annotation.ContentStyle; import com.alibaba.excel.metadata.BorderStyleEnum; import com.alibaba.excel.metadata.VerticalAlignmentEnum; import com.alibaba.excel.metadata.…...

RocketMQ实现分布式事务

RocketMQ的分布式事务消息功能,在普通消息基础上,支持二阶段的提交。将二阶段提交和本地事务绑定,实现全局提交结果的一致性。 1、生产者将消息发送至RocketMQ服务端。 2、RocketMQ服务端将消息持久化成功之后,向生产者返回Ack确…...

【Rust练习】2.数值类型

练习题来自https://practice-zh.course.rs/basic-types/numbers.html 1 // 移除某个部分让代码工作 fn main() {let x: i32 5;let mut y: u32 5;y x;let z 10; // 这里 z 的类型是? }y的类型不对,另外,数字的默认类型是i32 fn main() {let x: i…...

通过 PPPOE 将 linux 服务器作为本地局域网 IPv4 外网网关

将 linux 服务器作为本地外网网关,方便利用 Linux 生态中的各种网络工具,对流量进行自定义、精细化管理… 环境说明 拨号主机:CentOS 7.9, Linux Kernel 5.4.257 拨号软件: rp-pppoe-3.11-7.el7.x86_64初始化 1、升级系统到新的稳定内核&a…...

gin源码分析

一、高性能 使用sync.pool解决频繁创建的context对象,在百万并发的场景下能大大提供访问性能和减少GC // ServeHTTP conforms to the http.Handler interface. // 每次的http请求都会从sync.pool中获取context,用完之后归还到pool中 func (engine *Engin…...

数学建模入门

目录 文章目录 前言 一、数学建模是什么? 1、官方概念: 2、具体过程 3、适合哪一类人参加? 4、需要有哪些学科基础呢? 二、怎样准备数学建模(必备‘硬件’) 1.组队 2.资料搜索 3.常用算法总结 4.论文撰写的…...

【学习笔记】无人机(UAV)在3GPP系统中的增强支持(十二)-无人机群在物流中的应用

引言 本文是3GPP TR 22.829 V17.1.0技术报告,专注于无人机(UAV)在3GPP系统中的增强支持。文章提出了多个无人机应用场景,分析了相应的能力要求,并建议了新的服务级别要求和关键性能指标(KPIs)。…...

同三维T80006EH2-4K30编码器视频使用操作说明书:高清HDMI编码器,高清SDI编码器,4K超清HDMI编码器,双路4K超高清编码器

同三维T80006EH2-4K30编码器视频使用操作说明书:高清HDMI编码器,高清SDI编码器,4K超清HDMI编码器,双路4K超高清编码器 T80006EH2-4K30编码器 同三维,十多年老品牌,我们一直专注:视频采集卡、视频…...

DHCP原理及配置

目录 一、DHCP原理 DHCP介绍 DHCP工作原理 DHCP分配方式 工作原理 DHCP重新登录 DHCP优点 二、DHCP配置 一、DHCP原理 1 DHCP介绍 大家都知道,现在出门很多地方基本上都有WIFI,那么有没有想过这样一个问题,平时在家里都是“固定”的…...

异步日志:性能优化的金钥匙

一、背景 2024 年 4 月的一个宁静的夜晚,正当大家忙完一天的工作准备休息时,应急群里“咚咚咚”开始报警,提示我们余利宝业务的赎回接口成功率下降。 通过 Monitor 监控发现,该接口的耗时已经超过了网关配置的超时阈值(2s)&#…...

matlab仿真 模拟调制(上)

(内容源自详解MATLAB/SIMULINK 通信系统建模与仿真 刘学勇编著第五章内容,有兴趣的读者请阅读原书) 1.幅度调制 clear all ts0.0025; %信号抽样时间间隔 t0:ts:10-ts;%时间矢量 fs1/ts;%抽样频率 dffs/length(t); %fft的频率分…...

【数据结构】--- 堆的应用

​ 个人主页:星纭-CSDN博客 系列文章专栏 :数据结构 踏上取经路,比抵达灵山更重要!一起努力一起进步! 一.堆排序 在前一个文章的学习中,我们使用数组的物理结构构造出了逻辑结构上的堆。那么堆到底有什么用呢&…...

Android Wi-Fi 连接失败日志分析

1. Android wifi 关键日志总结 (1) Wi-Fi 断开 (CTRL-EVENT-DISCONNECTED reason3) 日志相关部分: 06-05 10:48:40.987 943 943 I wpa_supplicant: wlan0: CTRL-EVENT-DISCONNECTED bssid44:9b:c1:57:a8:90 reason3 locally_generated1解析: CTR…...

调用支付宝接口响应40004 SYSTEM_ERROR问题排查

在对接支付宝API的时候,遇到了一些问题,记录一下排查过程。 Body:{"datadigital_fincloud_generalsaas_face_certify_initialize_response":{"msg":"Business Failed","code":"40004","sub_msg…...

Java线上CPU飙高问题排查全指南

一、引言 在Java应用的线上运行环境中,CPU飙高是一个常见且棘手的性能问题。当系统出现CPU飙高时,通常会导致应用响应缓慢,甚至服务不可用,严重影响用户体验和业务运行。因此,掌握一套科学有效的CPU飙高问题排查方法&…...

Kafka入门-生产者

生产者 生产者发送流程: 延迟时间为0ms时,也就意味着每当有数据就会直接发送 异步发送API 异步发送和同步发送的不同在于:异步发送不需要等待结果,同步发送必须等待结果才能进行下一步发送。 普通异步发送 首先导入所需的k…...

【JavaSE】多线程基础学习笔记

多线程基础 -线程相关概念 程序(Program) 是为完成特定任务、用某种语言编写的一组指令的集合简单的说:就是我们写的代码 进程 进程是指运行中的程序,比如我们使用QQ,就启动了一个进程,操作系统就会为该进程分配内存…...

Razor编程中@Html的方法使用大全

文章目录 1. 基础HTML辅助方法1.1 Html.ActionLink()1.2 Html.RouteLink()1.3 Html.Display() / Html.DisplayFor()1.4 Html.Editor() / Html.EditorFor()1.5 Html.Label() / Html.LabelFor()1.6 Html.TextBox() / Html.TextBoxFor() 2. 表单相关辅助方法2.1 Html.BeginForm() …...

如何应对敏捷转型中的团队阻力

应对敏捷转型中的团队阻力需要明确沟通敏捷转型目的、提升团队参与感、提供充分的培训与支持、逐步推进敏捷实践、建立清晰的奖励和反馈机制。其中,明确沟通敏捷转型目的尤为关键,团队成员只有清晰理解转型背后的原因和利益,才能降低对变化的…...

Vue3 PC端 UI组件库我更推荐Naive UI

一、Vue3生态现状与UI库选择的重要性 随着Vue3的稳定发布和Composition API的广泛采用,前端开发者面临着UI组件库的重新选择。一个好的UI库不仅能提升开发效率,还能确保项目的长期可维护性。本文将对比三大主流Vue3 UI库(Naive UI、Element …...

leetcode73-矩阵置零

leetcode 73 思路 记录 0 元素的位置:遍历整个矩阵,找出所有值为 0 的元素,并将它们的坐标记录在数组zeroPosition中置零操作:遍历记录的所有 0 元素位置,将每个位置对应的行和列的所有元素置为 0 具体步骤 初始化…...

电脑桌面太单调,用Python写一个桌面小宠物应用。

下面是一个使用Python创建的简单桌面小宠物应用。这个小宠物会在桌面上游荡,可以响应鼠标点击,并且有简单的动画效果。 import tkinter as tk import random import time from PIL import Image, ImageTk import os import sysclass DesktopPet:def __i…...