当前位置: 首页 > news >正文

深度学习中的 Dropout:原理、公式与实现解析

8. dropout

深度学习中的 Dropout:原理、公式与实现解析

在神经网络训练中,模型往往倾向于“记住”训练数据的细节甚至噪声,导致模型在新数据上的表现不佳,即过拟合。为了解决这一问题,Dropout 应运而生。通过在训练过程中随机丢弃一部分神经元,Dropout 能减少模型对特定神经元的依赖,从而提升泛化能力,今天我们将深入讲解 Dropout 的原理,并用代码实现它!


为什么需要 Dropout?

在没有正则化的情况下,神经网络可能会过于依赖于某些特定的神经元,这种现象容易导致过拟合。Dropout 通过随机丢弃神经元,避免模型过度依赖某些特征,使得模型在新数据上表现更好。


Dropout 的工作原理

1. Dropout 的训练过程

假设我们有一个输入向量 x = [ x 1 , x 2 , … , x n ] x = [x_1, x_2, \dots, x_n] x=[x1,x2,,xn]Dropout 在训练时会遵循以下步骤:

  1. 设置丢弃概率 p p p :通常在 0.1 到 0.5 之间,表示每个神经元被丢弃的概率。
  2. 生成随机掩码 m m m
    • 对每个神经元生成一个随机值。
    • 如果随机值小于 p p p ,该神经元输出置为 0(即丢弃)。
    • 如果随机值大于等于 p p p ,该神经元输出保持不变。
  3. 应用掩码:将掩码与输入相乘,丢弃部分神经元输出。

在测试时,我们不再随机丢弃神经元,而是将每个神经元的输出缩小 1 − p 1 - p 1p 倍,以保持与训练时相同的输出期望值。


Dropout 的数学公式

在训练时,Dropout 可以用以下公式表示:

output = x ⋅ m \text{output} = x \cdot m output=xm

其中 m m m 是随机掩码,0 表示丢弃,1 表示保留。训练时,为了保持输出一致性,我们会将结果除以 1 − p 1 - p 1p

output = x ⋅ m 1 − p \text{output} = \frac{x \cdot m}{1 - p} output=1pxm

在测试时,我们不再随机丢弃,而是将每个神经元的输出乘以 1 − p 1 - p 1p

output = x ⋅ ( 1 − p ) \text{output} = x \cdot (1 - p) output=x(1p)

这样可以确保训练和测试时的输出分布一致。


自己实现一个 Dropout 类

为了帮助大家理解 Dropout 的实现原理,我们可以用 Python 和 PyTorch 实现一个简单的 Dropout 类。

import torch
import torch.nn as nnclass CustomDropout(nn.Module):def __init__(self, p=0.5):super(CustomDropout, self).__init__()self.p = p  # 丢弃概率def forward(self, x):if self.training:# 生成与 x 形状相同的随机掩码mask = (torch.rand_like(x) > self.p).float()return x * mask / (1 - self.p)else:# 推理时,直接缩放输出return x * (1 - self.p)

代码解析

  • 初始化:我们定义了 p 表示丢弃的概率。p 越大,丢弃的神经元越多。
  • 前向传播
    • 在训练模式下:生成一个与输入张量形状相同的随机掩码,对每个神经元随机保留或丢弃。
    • 在测试模式下:不再随机丢弃,而是将输出乘以 1 − p 1 - p 1p ,确保输出分布一致。

测试代码

我们可以使用以下代码测试自定义 Dropout 的效果。

# 输入张量 x
x = torch.ones(5, 5)  # 一个简单的 5x5 全 1 张量# 实例化自定义 Dropout
dropout = CustomDropout(p=0.5)# 训练模式
dropout.train()
output_train = dropout(x)
print("训练模式下的输出:\\n", output_train)# 推理模式
dropout.eval()
output_eval = dropout(x)
print("推理模式下的输出:\\n", output_eval)

解释测试结果

  • 训练模式:输出中会有一部分元素被随机置为 0,其余的值会放大(除以 1 − p 1 - p 1p )。
  • 推理模式:所有元素值会被缩小到 1 − p 1 - p 1p 倍,以确保训练和推理阶段输出分布一致。

为什么训练和测试阶段需要缩放?

在训练时,Dropout 随机丢弃一部分神经元,使得实际参与计算的神经元变少。这样训练时的输出总量会降低,因此我们需要对保留下来的神经元进行缩放(除以 1 − p 1 - p 1p )。在测试时,我们则对输出进行整体缩放(乘以 1 − p 1 - p 1p ),以确保训练和测试阶段的输出期望值一致,从而保证模型在不同阶段表现一致。


总结

  • Dropout 是一种防止过拟合的正则化方法,通过随机丢弃神经元来提升模型的泛化能力。
  • 在训练时,随机丢弃神经元并缩放剩余神经元的输出。
  • 在推理时,直接缩放整个输出,以保持训练和推理的分布一致。

希望这篇文章能帮助你理解 Dropout 的工作原理和实现过程。如果有任何疑问,欢迎留言讨论!

相关文章:

深度学习中的 Dropout:原理、公式与实现解析

8. dropout 深度学习中的 Dropout:原理、公式与实现解析 在神经网络训练中,模型往往倾向于“记住”训练数据的细节甚至噪声,导致模型在新数据上的表现不佳,即过拟合。为了解决这一问题,Dropout 应运而生。通过在训练…...

【大数据学习 | HBASE】habse的表结构

在使用的时候hbase就是一个普通的表,但是hbase是一个列式存储的表结构,与我们常用的mysql等关系型数据库的存储方式不同,mysql中的所有列的数据是按照行级别进行存储的,查询数据要整个一行查询出来,不想要的字段也需要…...

完成程序《大奖赛评分B》

学习目标: 使用代码完成程序《大奖赛评分B》 题目: 如今许多歌手大奖赛评分时,为了体现公平,在评委给出分数后统计平均得分时,都会去掉最高分和最低分。编写程序,读入评委打分(分数都是大于0的…...

K8S篇(基本介绍)

目录 一、什么是Kubernetes? 二、Kubernetes管理员认证(CKA) 1. 简介 2. 考试难易程度 3. 考试时长 4. 多少分及格 5. 考试费用 三、Kubernetes整体架构 Master Nodes 四、Kubernetes架构及和核心组件 五、Kubernetes各个组件及功…...

linux alsa-lib snd_pcm_open函数源码分析(三)

欢迎直接到博客 linux alsa-lib snd_pcm_open函数源码分析(三) 系列文章其他部分: linux alsa-lib snd_pcm_open函数源码分析(一) linux alsa-lib snd_pcm_open函数源码分析(二) linux alsa-lib snd_pcm_open函数源码分析(四…...

基于ssm的个人健康管理系统

项目描述 临近学期结束,还是毕业设计,你还在做java程序网络编程,期末作业,老师的作业要求觉得大了吗?不知道毕业设计该怎么办?网页功能的数量是否太多?没有合适的类型或系统?等等。这里根据疫情当下,你想解决的问…...

Debian下载ISO镜像的方法

步骤 1:访问Debian官方网站 打开你的网络浏览器,在地址栏中输入 https://www.debian.org/ 并回车,这将带你到Debian的官方网站。 步骤 2:导航到下载页面 在Debian官方网站的首页上,找到并点击“Download Debian”或类…...

大厂面试真题-简单说说线程池接到新任务之后的操作流程

线程池在接到新任务后的操作流程通常遵循以下步骤,这些步骤确保了任务的高效管理和执行。 一、判断当前线程状态 线程池首先会判断当前是否存在空闲线程,即没有正在执行任务且未被标记为死亡的线程。 有空闲线程:如果存在空闲线程&#xf…...

「Mac畅玩鸿蒙与硬件23」鸿蒙UI组件篇13 - 自定义组件的创建与使用

自定义组件可以帮助开发者实现复用性强、逻辑清晰的界面模块。通过自定义组件,鸿蒙应用能够提高代码的可维护性,并简化复杂布局的构建。本篇将介绍如何创建自定义组件,如何向组件传递数据,以及如何在不同页面间复用这些组件。 关键…...

C++关键字:mutable

文章目录 一、mutable1.mutable修饰非静态的成员变量2.mutable用于lambda表达式3.mutable不能修饰的变量:静态变量、const变量 一、mutable 1.mutable修饰非静态的成员变量 1.mutable仅能修饰类中的非静态的成员变量。不能修饰全局变量、局部变量、静态变量、常量…...

Agent 智能体开发框架选型指南

编者按: 本文通过作者的实践对比发现,框架的选择应基于项目具体需求和团队特点,而不是简单追求某个特定框架。不同框架各有优势: 无框架方案实施最为简单直接,代码结构清晰,适合理解智能体原理,…...

基于Zynq FPGA对雷龙SD NAND的测试

一、SD NAND 特征 1.1 SD 卡简介 雷龙的 SD NAND 有很多型号,在测试中使用的是 CSNP4GCR01-AMW 与 CSNP32GCR01-AOW。芯片是基于 NAND FLASH 和 SD 控制器实现的 SD 卡。具有强大的坏块管理和纠错功能,并且在意外掉电的情况下同样能保证数据的安全。 …...

AOSP沙盒android 11

这里介绍一下aosp装系统 什么是aosp AOSP(Android Open Source Project)是Android操作系统的开源版本。 它由Google主导,提供了Android的源代码和相关工具,供开发者使用和修改。 AOSP包含了Android的核心组件和API,使…...

【JWT】Asp.Net Core中JWT刷新Token解决方案

Asp.Net Core中JWT刷新Token解决方案 前言方案一:当我们操作某个需要token作为请求头的接口时,返回的数据错误error.response.status === 401,说明我们的token已经过期了。方案二:实现用户无感知的刷新token值,我们希望当响应返回的数据是401身份过期时,响应阻拦器自动帮我…...

AJ-Report:一款开源且非常强大的数据可视化大屏和报表工具

嗨,大家好,我是小华同学,关注我们获得“最新、最全、最优质”开源项目和工作学习方法 AJ-Report是一个基于Java的开源报表工具,它集成了ECharts、Ant Design Vue等前端技术,致力于为企业提供一站式的数据可视化解决方案…...

stm32不小心把SWD和JTAG都给关了,程序下载不进去,怎么办?

因为想用STM32F103的PA15引脚,调试程序的时候不小心把SWD和JTAD接口都给关了,先看下罪魁祸首 GPIO_PinRemapConfig(GPIO_Remap_SWJ_JTAGDisable,ENABLE);//关掉JTAG,不关SWGPIO_PinRemapConfig(GPIO_Remap_SWJ_Disable, ENABLE);//关掉SW&am…...

【UE5】在材质中实现球形法线技术,常用于改善植物等表面的渲染效果

在材质中实现球形法线,这种技术常用于植被渲染等场景。通过应用球形法线可以显著提升植物再低几何体情况下的光照效果。 三二一上截图! 当然也可以用于任何你希望模型圆润的地方,下图中做了一个Cube倒角...

【MATLAB源码-第210期】基于matlab的OFDM电力线系统仿真,不同梳状导频间隔对比。三种信道估计,三种插值误码率对比

操作环境: MATLAB 2022a 1、算法描述 OFDM电力线通信系统(PLC)是一种通过电力线传输数据的通信技术,利用了OFDM(Orthogonal Frequency Division Multiplexing,正交频分复用)技术的优势来提高…...

基于SpringBoot的城镇保障性住房管理策略

3系统分析 3.1可行性分析 通过对本城镇保障性住房管理系统实行的目的初步调查和分析,提出可行性方案并对其一一进行论证。我们在这里主要从技术可行性、经济可行性、操作可行性等方面进行分析。 3.1.1技术可行性 本城镇保障性住房管理系统采用SSM框架,JA…...

支持高性能结构化数据提取的 Embedding 模型——NuExtract-v1.5

NuExtract 是一个用户友好型模型,设计用于从长文档中提取信息。它可以处理长达 20,000 个标记的输入,是合同、报告和其他商业通信的理想选择。NuExtract 的与众不同之处在于它能够处理和理解文档的整个上下文。这意味着它可以捕捉到可能分散在长文本不同…...

基于距离变化能量开销动态调整的WSN低功耗拓扑控制开销算法matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.算法仿真参数 5.算法理论概述 6.参考文献 7.完整程序 1.程序功能描述 通过动态调整节点通信的能量开销,平衡网络负载,延长WSN生命周期。具体通过建立基于距离的能量消耗模型&am…...

K8S认证|CKS题库+答案| 11. AppArmor

目录 11. AppArmor 免费获取并激活 CKA_v1.31_模拟系统 题目 开始操作: 1)、切换集群 2)、切换节点 3)、切换到 apparmor 的目录 4)、执行 apparmor 策略模块 5)、修改 pod 文件 6)、…...

【Linux】C语言执行shell指令

在C语言中执行Shell指令 在C语言中&#xff0c;有几种方法可以执行Shell指令&#xff1a; 1. 使用system()函数 这是最简单的方法&#xff0c;包含在stdlib.h头文件中&#xff1a; #include <stdlib.h>int main() {system("ls -l"); // 执行ls -l命令retu…...

oracle与MySQL数据库之间数据同步的技术要点

Oracle与MySQL数据库之间的数据同步是一个涉及多个技术要点的复杂任务。由于Oracle和MySQL的架构差异&#xff0c;它们的数据同步要求既要保持数据的准确性和一致性&#xff0c;又要处理好性能问题。以下是一些主要的技术要点&#xff1a; 数据结构差异 数据类型差异&#xff…...

ffmpeg(四):滤镜命令

FFmpeg 的滤镜命令是用于音视频处理中的强大工具&#xff0c;可以完成剪裁、缩放、加水印、调色、合成、旋转、模糊、叠加字幕等复杂的操作。其核心语法格式一般如下&#xff1a; ffmpeg -i input.mp4 -vf "滤镜参数" output.mp4或者带音频滤镜&#xff1a; ffmpeg…...

Java入门学习详细版(一)

大家好&#xff0c;Java 学习是一个系统学习的过程&#xff0c;核心原则就是“理论 实践 坚持”&#xff0c;并且需循序渐进&#xff0c;不可过于着急&#xff0c;本篇文章推出的这份详细入门学习资料将带大家从零基础开始&#xff0c;逐步掌握 Java 的核心概念和编程技能。 …...

Linux系统部署KES

1、安装准备 1.版本说明V008R006C009B0014 V008&#xff1a;是version产品的大版本。 R006&#xff1a;是release产品特性版本。 C009&#xff1a;是通用版 B0014&#xff1a;是build开发过程中的构建版本2.硬件要求 #安全版和企业版 内存&#xff1a;1GB 以上 硬盘&#xf…...

自然语言处理——文本分类

文本分类 传统机器学习方法文本表示向量空间模型 特征选择文档频率互信息信息增益&#xff08;IG&#xff09; 分类器设计贝叶斯理论&#xff1a;线性判别函数 文本分类性能评估P-R曲线ROC曲线 将文本文档或句子分类为预定义的类或类别&#xff0c; 有单标签多类别文本分类和多…...

CppCon 2015 学习:REFLECTION TECHNIQUES IN C++

关于 Reflection&#xff08;反射&#xff09; 这个概念&#xff0c;总结一下&#xff1a; Reflection&#xff08;反射&#xff09;是什么&#xff1f; 反射是对类型的自我检查能力&#xff08;Introspection&#xff09; 可以查看类的成员变量、成员函数等信息。反射允许枚…...

【Vue】scoped+组件通信+props校验

【scoped作用及原理】 【作用】 默认写在组件中style的样式会全局生效, 因此很容易造成多个组件之间的样式冲突问题 故而可以给组件加上scoped 属性&#xff0c; 令样式只作用于当前组件的标签 作用&#xff1a;防止不同vue组件样式污染 【原理】 给组件加上scoped 属性后…...