深度学习篇---模型参数调优
文章目录
- 前言
- 一、Adam学习(lr)
- 1. 默认学习率
- 2. 较小的学习率
- 模型复杂
- 数据集规模小
- 3. 较大的学习率
- 模型简单
- 训练初期
- 4. 学习率衰减策略
- 固定步长衰减
- 指数衰减
- 二、训练轮数(epoch)
- 1. 经验值设定
- 小数据集与简单模型
- 大数据集和复杂模型
- 2. 监控指标变化
- 损失函数与准确率:
- 验证集表现:
- 3. 学习率衰减结合
- 4. 逐步增加
- 三、批次大小(batch)
- 1. 较小的 batch 大小
- 优点
- 更好的泛化能力
- 更快逃离局部最优
- 缺点
- 训练速度慢
- 梯度估计不稳定
- 适用场景
- 2. 较大的 batch 大小
- 优点
- 训练速度快
- 梯度估计更稳定
- 缺点
- 泛化能力下降
- 内存需求高
- 适用场景
- 3. 动态调整 batch 大小
- 4. 考虑硬件资源
- 5. 结合学习率调整
前言
本文简单介绍了深度学习中的epoch、batch、learning-rate参数大小对模型训练的影响,以及怎样进行适当调优。
一、Adam学习(lr)
Adam(Adaptive Moment Estimation)是一种常用的优化算法,结合了 Adagrad 和 RMSProp 的优点,能自适应地调整每个参数的学习率。在使用 Adam 优化器时,学习率的设置对模型的训练效果有着重要影响。以下是一些常见的学习率设置情况:
1. 默认学习率
在大多数深度学习框架中,Adam 优化器的默认学习率通常设置为 0.001。例如在 PyTorch 中:
import torch
import torch.nn as nn
import torch.optim as optim
# 假设 model 是你的模型
model = nn.Linear(10, 1)
optimizer = optim.Adam(model.parameters(), lr=0.001)
这个默认值在很多情况下表现良好,是一个不错的初始尝试值。它在许多不同类型的任务(如分类、回归等)和不同的模型架构(如神经网络、卷积神经网络等)中都能取得较好的效果。
2. 较小的学习率
当遇到以下情况时,可能需要使用较小的学习率:
模型复杂
模型复杂:如果模型的参数数量非常多,结构复杂,过大的学习率可能会导致模型在训练过程中跳过最优解,无法收敛到较好的结果。此时可以尝试将学习率设置为 0.0001 甚至更小,如 0.00001。
optimizer = optim.Adam(model.parameters(), lr=0.0001)
数据集规模小
数据集规模小:数据集较小时,模型容易过拟合,使用较小的学习率可以使模型在训练过程中更加稳定,避免过度调整参数。
3. 较大的学习率
在某些情况下,也可以尝试使用较大的学习率:
模型简单
模型简单:当模型结构比较简单,参数数量较少时,较大的学习率可以使模型更快地收敛到一个较好的解。可以尝试将学习率设置为 0.01 或 0.1。
optimizer = optim.Adam(model.parameters(), lr=0.01)
训练初期
训练初期:在训练的开始阶段,可以使用较大的学习率让模型快速地朝着最优解的方向前进,然后在训练过程中逐渐降低学习率,这种方法称为学习率衰减。
4. 学习率衰减策略
为了在训练过程中更好地平衡收敛速度和收敛精度,可以采用学习率衰减策略。常见的策略有:
固定步长衰减
固定步长衰减:每隔一定的训练轮数(epoch),将学习率乘以一个固定的衰减因子。例如,每 10 个 epoch 将学习率乘以 0.1。
from torch.optim.lr_scheduler import StepLRoptimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)for epoch in range(num_epochs):# 训练代码optimizer.step()scheduler.step()
指数衰减
指数衰减:学习率按照指数函数的形式进行衰减。
from torch.optim.lr_scheduler import ExponentialLRoptimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = ExponentialLR(optimizer, gamma=0.9)for epoch in range(num_epochs):# 训练代码optimizer.step()scheduler.step()
总之,学习率的设置需要根据具体的任务、模型和数据集进行调整,通常需要通过多次实验来找到最优的学习率。
二、训练轮数(epoch)
训练轮数(epoch)指的是将整个训练数据集完整地过一遍模型的次数。合理设计训练轮数对模型训练效果至关重要,轮数太少模型可能欠拟合,轮数太多则可能导致过拟合。以下是常见的训练轮数设计方法:
1. 经验值设定
小数据集与简单模型
小型数据集与简单模型:
当处理的数据集规模较小,且模型结构相对简单时,训练轮数通常不用太多。例如,对于手写数字识别这类简单任务,若使用全连接神经网络,训练轮数设置在 10 - 50 之间可能就足够了。
大数据集和复杂模型
大型数据集与复杂模型:
在处理大型数据集,如 ImageNet 这样的大规模图像数据集,并且使用像 ResNet、VGG 这类复杂的卷积神经网络时,训练轮数可能需要设置为几十甚至上百,常见的是 50 - 200 轮。
2. 监控指标变化
损失函数与准确率:
在训练过程中,可以监控损失函数值和准确率等指标的变化。当损失函数值不再显著下降,或者准确率不再提升时,就可以停止训练。在代码里,可以添加相应的逻辑来实现早停策略。以下是一个简单的早停示例:
# 假设 patience 是容忍训练轮数没有提升的最大次数
patience = 10
best_loss = float('inf')
no_improvement_count = 0for epoch in range(num_epochs):loss = train(model, train_loader, criterion, optimizer, device)if loss < best_loss:best_loss = lossno_improvement_count = 0else:no_improvement_count += 1if no_improvement_count >= patience:print(f"Early stopping at epoch {epoch + 1}")break
验证集表现:
将数据集划分为训练集和验证集,在每个 epoch 结束后,在验证集上评估模型的性能**。当验证集上的性能开始下降时,停止训练**。
3. 学习率衰减结合
在训练过程中采用学习率衰减策略时,训练轮数的设计要和学习率衰减的步数相配合。例如,每 10 个 epoch 衰减一次学习率,那么总的训练轮数可以设置为衰减步数的整数倍。
4. 逐步增加
在模型训练的初始阶段,可以先设置较少的训练轮数进行快速实验,观察模型的训练情况,如损失函数的下降趋势、准确率的变化等。根据初步实验的结果,逐步增加训练轮数,直到达到理想的训练效果。
三、批次大小(batch)
在深度学习中,batch(批次)指的是在一次前向 / 反向传播过程中使用的样本数量。合理设计 batch 大小对模型的训练效率、泛化能力和收敛速度都有重要影响。以下是常见的 batch 设计方法及相关考虑因素:
1. 较小的 batch 大小
优点
更好的泛化能力
更好的泛化能力:较小的 batch 会引入更多的噪声,这可以被看作是一种正则化手段,有助于模型学习到更鲁棒的特征,提高泛化能力。
更快逃离局部最优
更快逃离局部最优:噪声的存在使得模型在优化过程中更容易跳出局部最优解,从而有可能找到更优的全局最优解。
缺点
训练速度慢
训练速度慢:由于每次处理的样本数量少,参数更新的频率会更高,这会增加训练时间,尤其是在 GPU 等并行计算设备上,小 batch 无法充分利用设备的计算资源。
梯度估计不稳定
梯度估计不稳定:小 batch 计算得到的梯度可能会有较大的波动,导致训练过程不稳定。
适用场景
数据集规模较小的情况,小 batch 可以模拟更多的训练步骤,让模型有更多机会学习数据特征。
当模型容易过拟合时,小 batch 带来的噪声可以作为一种正则化方法。
常见取值:通常可以从 1、2、4、8、16 等开始尝试。在你的代码里,batch_size 设置为 32,如果想尝试小 batch,可以将其改为 8 或 16。
train_loader = DataLoader(dataset, batch_size=8, shuffle=True)
2. 较大的 batch 大小
优点
训练速度快
训练速度快:大 batch 可以充分利用计算设备(如 GPU)的并行计算能力,减少参数更新的次数,从而加快训练速度。
梯度估计更稳定
梯度估计更稳定:由于使用了更多的样本计算梯度,梯度的估计会更加准确和稳定,有助于模型更快收敛。
缺点
泛化能力下降
泛化能力下降:大 batch 可能会使模型陷入局部最优解,导致泛化能力变差。
内存需求高
内存需求高:需要更多的内存来存储和处理大量的样本,可能会受到硬件资源的限制。
适用场景
数据集规模非常大的情况,大 batch 可以提高训练效率。
模型结构简单,不太容易过拟合时,可以使用大 batch 加速训练。
常见取值:常见的大 batch 大小有 64、128、256、512 等。你可以将代码中的 batch_size 调整为 64 进行尝试:
train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
3. 动态调整 batch 大小
在训练过程中,可以根据训练的不同阶段动态调整 batch 大小。例如,在训练初期使用较大的 batch 快速收敛到一个较好的解,然后在训练后期使用较小的 batch 进行精细调整,提高模型的泛化能力。
4. 考虑硬件资源
在设计 batch 大小时,需要考虑硬件资源的限制。如果 GPU 内存有限,使用过大的 batch 可能会导致内存溢出错误。可以通过逐步增加 batch 大小,直到出现内存问题,然后选择一个稍小的 batch 大小作为合适的值。
5. 结合学习率调整
batch 大小和学习率通常需要一起调整。一般来说,大 batch 可以使用较大的学习率,小 batch 则需要使用较小的学习率。在调整 batch 大小后,可能需要相应地调整学习率,以保证模型的收敛性。
综上所述,选择合适的 batch 大小需要综合考虑数据集规模、模型复杂度、硬件资源等因素,通常需要通过多次实验来找到最优的 batch 大小。
相关文章:
深度学习篇---模型参数调优
文章目录 前言一、Adam学习(lr)1. 默认学习率2. 较小的学习率模型复杂数据集规模小 3. 较大的学习率模型简单训练初期 4. 学习率衰减策略固定步长衰减指数衰减 二、训练轮数(epoch)1. 经验值设定小数据集与简单模型大数据集和复杂…...
影响HTTP网络请求的因素
影响 HTTP 网络请求的因素 1. 带宽 2. 延迟 浏览器阻塞:浏览器会因为一些原因阻塞请求,浏览器对于同一个域名,同时只能有4个连接(这个根据浏览器内核不同可能会有所差异),超过浏览器最大连接数限制&…...
Openssl自签证书相关知识
1.前提 检查是否已安装 openssl $ which openssl /usr/bin/openssl 2.建立CA授权中心 2.1.生成ca私钥(ca-prikey.pem) 初始化 OpenSSL 证书颁发机构(CA)的序列号文件 在生成证书时,ca.srl 的初始序列号需正确初始化(如 01),否则可能导致证书冲突 这会将 01 显示在屏幕…...
浅析车规芯片软错误防护加固的重要性
随着汽车电子技术的飞速发展,汽车已经从传统的机械交通工具转变为高度依赖电子系统的智能移动终端。车规芯片作为汽车电子系统的核心部件,其可靠性和安全性直接关系到车辆的正常运行和驾乘人员的安全。然而,车规芯片在复杂的运行环境中面临着…...
(UI自动化测试web端)第二篇:元素定位的方法_css定位之css选择器
看代码里的【find_element_by_css_selector( )】( )里的表达式怎么写? 文章介绍了第三种写法css选择器,你要根据网页中的实际情况来判断自己到底要用哪一种方法来进行元素定位。每种方法都要多练习,全都熟了之后你在工作当中使用起来元素定位…...
QT自运行程序
终局 搞定了兄弟们,啥也别说了。 不要用xcb,用linuxfb。 用systemd服务。 海康威视的豆干型网络摄像头我这边尝试后,发现在multi-user.target运行级别下,摄像头登录成功了也采集不到画面。 具体愿意暂不清楚,所以如果是涉及摄像头的,建议…...
MPU6050模块详解:从原理到STM32驱动指南(上) | 零基础入门STM32第八十九步
主题内容教学目的/扩展视频加速度传感器电路连接。手册分析。驱动程序,读出数据。能读出3轴数据。 师从洋桃电子,杜洋老师 📑文章目录 一、MPU6050模块介绍1.1 核心特性1.2 模块化优势 二、MPU6050模块连接方法2.1 硬件连接2.2 电源注意事项 …...
STM32 MODBUS-RTU主从站库移植
代码地址 STM32MODBUSRTU: stm32上的modbus工程 从站 FreeModbus是一个开源的Modbus通信协议栈实现。它允许开发者在各种平台上轻松地实现Modbus通信功能,包括串口和以太网。FreeMODBUS提供了用于从设备和主站通信的功能,支持Modbus RTU和Modbus TCP协…...
架构师面试(二十二):TCP 协议
问题 今天我们聊一个非常常见的面试题目,不管前端还是后端,也不管做的是上层业务还是底层框架,更不管技术方向是运维还是架构,都可以思考和参与一下哈! TCP协议无处不在,我们知道 TCP 是基于连接的端到端…...
程序自动化填写网页表单数据
1 背景介绍 如何让程序自动化填写网页表单数据,特别是涉及到批量数据情况时,可以减少人力。下面是涉及到的一些场景,都可以通过相关自动化程序实现。 场景1 场景1,领导安排,通过相关省、市、县、乡镇数据࿰…...
Razer macOS v0.4.10快速安装
链接点这里下载最新的 .dmg 文件。将下载的 .dmg 映像文件拖入 应用程序 文件夹中。若首次打开时出现安全警告【什么扔到废纸篓】,这时候点击 Mac 的“系统偏好设置”-> “安全性与隐私”-> “通用”,然后点击底部的 “打开”。【或者仍然打开】 对…...
常用正则表达式-MAC 地址
MAC地址的定义 物理地址(通常称为 MAC地址,Media Access Control Address)是网络设备在数据链路层(如以太网、Wi-Fi)的唯一标识符。它由设备的网络接口卡(NIC)固化在硬件中,用于在局…...
如何自动化同义词并使用我们的 Synonyms API 进行上传
作者:来自 Elastic Andre Luiz 了解如何使用 LLM 来自动识别和生成同义词, 使术语可以通过程序方式加载到 Elasticsearch 同义词 API 中。 提高搜索结果的质量对于提供高效的用户体验至关重要。优化搜索的一种方法是通过同义词自动扩展查询词。这样可以更…...
一. 相机模组摆放原理
1. 背景: 相机开发时经常出现因模组摆放问题,导致相机成像方向异常。轻则修改软件、模组返工, 重则重新修改堆叠,影响相机调试进度。因此,设计一个模型实现模组摆放纠错很有必要。 2. 原理: 2.1 口诀&am…...
【C++游戏引擎开发】《线性代数》(1):环境配置与基础矩阵类设计
一、开发环境配置 1.1 启用C 20 在VS2022中新建项目后右键项目 1.2 启用增强指令集 1.3 安装Google Test vcpkg安装使用指南 vcpkg install gtest:x64-windows# 集成到系统目录,只需要执行一次,后续安装包之后不需要再次执行 vcpkg integrate inst…...
sqli-labs靶场 less 8
文章目录 sqli-labs靶场less 8 布尔盲注 sqli-labs靶场 每道题都从以下模板讲解,并且每个步骤都有图片,清晰明了,便于复盘。 sql注入的基本步骤 注入点注入类型 字符型:判断闭合方式 (‘、"、’、“”…...
基于大模型的知识图谱搜索的五大核心优势
在传统知识图谱与生成式AI融合的浪潮中,基于大模型的知识图谱搜索正成为新一代智能检索的标杆技术,飞速灵燕智能体平台就使用了该技术,其核心优势体现在: 1. 语义穿透力升级 突破关键词匹配局限,通过大模型的深层语义…...
【MySQL】从零开始:掌握MySQL数据库的核心概念(五)
由于我的无知,我对生存方式只有一个非常普通的信条:不许后悔。 前言 这是我自己学习mysql数据库的第五篇博客总结。后期我会继续把mysql数据库学习笔记开源至博客上。 上一期笔记是关于mysql数据库的增删查改,没看的同学可以过去看看…...
人生感悟8
前言 今天,在这里跟各位聊一些看法。为什么现在的歌曲和影视剧越来越没有艺术性和内涵?为什么现在读书的人越来越少? 正文 这里我先声明一点,就像C或者是Java创建variable or constant一样,本文所述内容只限于个人观…...
Java版Manus实现来了,Spring AI Alibaba发布开源OpenManus实现
此次官方发布的 Spring AI Alibaba OpenManus 实现,包含完整的多智能体任务规划、思考与执行流程,可以让开发者体验 Java 版本的多智能体效果。它能够根据用户的问题进行分析,操作浏览器,执行代码等来完成复杂任务等。 项目源码及…...
鸿蒙UI开发
鸿蒙UI开发 本文旨在分享一些鸿蒙UI布局开发上的一些建议,特别是对屏幕宽高比发生变化时的应对思路和好的实践。 折叠屏适配 一般情况(自适应布局/响应式布局) 1.自适应布局 1.1自适应拉伸 左右组件定宽 TypeScript //左右定宽 Row() { …...
Elasticsearch-实战案例
一、没有使用Elasticsearch的查询速度698ms 1.数据库模糊查询不走索引,在数据量较大的时候,查询性能很差。需要注意的是,数据库模糊查询随着表数据量的增多,查询性能的下降会非常明显,而搜索引擎的性能则不会随着数据增…...
#基于Django实现机器学习医学指标概率预测网站
基于Django实现机器学习医学指标概率预测网站 一、引言 在当今数字化医疗的大背景下,利用机器学习模型结合Web应用进行医学指标的概率预测具有重要的实际意义。本文将详细介绍一个基于Django框架构建的医学指标概率预测系统,通过结合随机森林模型&…...
【bug】OPENCV和FPGA的版本对应关系
如果opencv和FPGA的版本不对应,则会出现如下warning /usr/bin/ld: warning: libavcodec.so.57, needed by /usr/lib/gcc/aarch64-linux-gnu/7/../../../aarch64-linux-gnu/libopencv_videoio.so, may conflict with libavcodec.so.58 /usr/bin/ld: warning: libavformat.so.5…...
IP数据报报文格式
一 概述 IP数据报由两部分组成:首部数据部分。首部的前一部分是固定长度,一共20字节大小,是所有IP数据报文必须具有的;固定部分后面是一些可选字段,其长度是可变的。 二 首部固定部分各字段意义 (1&…...
一键实现:谷歌表单转word(formtoword)
一键将 Google Forms 转换为 Word,最简单的方法 有些繁琐的工作让人倍感挫败,明明 应该 可以自动化。你精心制作了一份 Google Forms,收集了数据,现在需要在 Word 文档中分享其结构或内容。于是,你只能手动复制粘贴问…...
openEuler24.03 LTS下安装Kafka集群
目录 前提条件 Kafka集群规划 下载Kafka 解压 设置环境变量 配置Kafka 分发到其他机器 分发安装文件 分发环境变量 启动Kafka 测试Kafka 关闭Kafka 集群启停脚本 问题及解决 前提条件 安装好ZooKeeper集群,可参考:openEuler24.03 LTS下安…...
qt QQuaternion详解
1. 概述 QQuaternion 是 Qt 中用于表示三维空间中旋转的四元数类。它包含一个标量部分和一个三维向量部分,可以用来表示旋转操作。四元数在计算机图形学中广泛用于平滑的旋转和插值。 2. 重要方法 默认构造函数 QQuaternion::QQuaternion(); // 构造单位四元数 (1…...
epoch、batch、batch size、step、iteration深度学习名词含义详细介绍
卷积神经网络训练中的三个核心概念:Epoch、Batch Size 和迭代次数 在深度学习中,理解一些基本的术语非常重要,这些术语对模型的训练过程、效率以及最终性能都有很大影响。以下是一些常见术语的含义介绍: 1. Epoch(周…...
pytorch中不同的mask方法:masked_fill, masked_select, masked_scatter
在 PyTorch 中,masked_fill、masked_select 和 masked_scatter 是三种常用的掩码(mask)操作方法,它们通过布尔类型的掩码张量(mask)对原始张量进行条件筛选或修改。以下是它们的详细解释和对比:…...
