当前位置: 首页 > news >正文

抑制过拟合——Dropout原理

抑制过拟合——Dropout原理

  • Dropout的工作原理
  • 实验观察

  在机器学习领域,尤其是当我们处理复杂的模型和有限的训练样本时,一个常见的问题是过拟合。简而言之,过拟合发生在模型对训练数据学得太好,以至于它捕捉到了数据中的噪声和误差,而不仅仅是底层模式。具体来说,这在神经网络训练中尤为常见,表现为在训练数据上表现优异(例如损失函数值很小,预测准确率高)而在未见过的数据(测试集)上表现不佳。

  过拟合不仅是机器学习新手容易遇到的问题,即使是经验丰富的从业者也会面临这一挑战。一个典型的解决方案是采用模型集成技术,这涉及训练多个模型并将它们的预测结合起来。但这种方法的缺点是显而易见的:它既耗时又昂贵,不仅在训练阶段,而且在模型评估和部署时也是如此。

  在这种背景下,Dropout 作为一种有效的正则化技术,可以显著减轻过拟合问题。它的基本原理是在每次训练迭代中随机“丢弃”(即暂时移除)网络中的一部分神经元。这种方法不仅简单,而且被证明在许多情况下都非常有效。

Dropout的工作原理

  在 PyTorch 中,Dropout 层的使用相当直观。通常,它被添加到神经网络的各个层之间,如下所示:

torch.nn.Dropout(p=0.5, inplace=False)

  p:这是一个关键参数,代表着每个神经元被丢弃的概率。

  在实践中,这意味着对于网络中的每个神经元,它在每次训练迭代中都有 1 − p 1-p 1p 的概率被保留, p p p 的概率被丢弃。值得注意的是,这种随机性确保了每个mini-batch都在对不完全相同的网络进行训练,从而减少过拟合的风险。

  在训练期间,对于每个训练样本,网络中的每个神经元都有概率 1 − p 1-p 1p 被保留,概率 p p p 被丢弃。如果神经元被保留,则其输出乘以 1 1 − p \frac{1}{1-p} 1p1​(这样做是为了保持该层输出的总期望值不变)。设 r j r_j rj​ 为一个随机变量,它对应于第 j j j 个神经元,且服从伯努利分布(即 r j = 1 r_j = 1 rj=1 的概率为 1 − p 1-p 1p r j = 0 r_j = 0 rj=0 的概率为 p p p)。那么在训练时,神经元的输出 y j y_j yj变为 r j × y j / ( 1 − p ) r_j \times y_j / (1-p) rj×yj/(1p)

为什么需要保持期望不变? 举个简单的例子,假设某层有两个神经元,它们的输出在没有dropout时都是1。在应用了50%的dropout后,期望只有一个神经元被激活,输出为1,另一个被丢弃,输出为0。这样,这层的平均输出变成了0.5。为了保持输出的总期望值不变,激活的神经元的输出应该乘以2,即 1 1 − p \frac{1}{1-p} 1p1​,这样平均输出才能保持为1,与没有应用dropout时相同。这样的处理有助于保持整个网络的稳定性和一致性。

  在模型预测(或测试)阶段,所有的神经元都保持激活(即不进行dropout)。因为在训练阶段,神经元的输出已经被放大了 1 1 − p \frac{1}{1-p} 1p1 倍,所以在预测时不需要进行任何调整,直接使用网络进行前向传播即可。

在这里插入图片描述

实验观察

  为了更深入地理解 Dropout 的影响,我们可以通过一个实验来观察不同的 Dropout 设置对训练过程的影响。比如,可以比较 Dropout = 0.1Dropout = 0 在训练过程中的表现差异,相关代码实现如下:

import torch
from tensorboardX import SummaryWriter
from torch import optim, nn
import timeclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linears = nn.Sequential(nn.Linear(2, 20),nn.Linear(20, 20),nn.Dropout(0.1),nn.Linear(20, 20),nn.Linear(20, 20),nn.Linear(20, 1),)def forward(self, x):_ = self.linears(x)return _lr = 0.01
iteration = 1000x1 = torch.arange(-10, 10).float()
x2 = torch.arange(0, 20).float()
x = torch.cat((x1.unsqueeze(1), x2.unsqueeze(1)), dim=1)
y = 2*x1 - x2**2 + 1model = Model()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)
loss_function = torch.nn.MSELoss()start_time = time.time()
writer = SummaryWriter(comment='_随机失活')for iter in range(iteration):y_pred = model(x)loss = loss_function(y, y_pred.squeeze())loss.backward()for name, layer in model.named_parameters():writer.add_histogram(name + '_grad', layer.grad, iter)writer.add_histogram(name + '_data', layer, iter)writer.add_scalar('loss', loss, iter)optimizer.step()optimizer.zero_grad()if iter % 50 == 0:print("iter: ", iter)print("Time: ", time.time() - start_time)

这里我们使用 TensorBoardX 进行结果的可视化展示。

  通过观察模型训练1000轮后的线性层梯度分布,可以发现,应用 Dropout 后的模型梯度通常会更加分散和多样化。这种梯度的多样性有助于防止模型过于依赖训练数据中的特定模式,从而减轻过拟合。

在这里插入图片描述

  同样值得注意的是,模型的损失曲线也会受到影响。加入 Dropout 通常会使损失曲线出现更多的波动(例如,图中的蓝色曲线),这反映了模型在学习过程中的不稳定性。然而,这种不稳定性通常是可接受的,因为它反映了模型正在学习更多的泛化模式而不是简单地记住训练数据。

在这里插入图片描述

相关文章:

抑制过拟合——Dropout原理

抑制过拟合——Dropout原理 Dropout的工作原理 实验观察 在机器学习领域,尤其是当我们处理复杂的模型和有限的训练样本时,一个常见的问题是过拟合。简而言之,过拟合发生在模型对训练数据学得太好,以至于它捕捉到了数据中的噪声和…...

开发板启动进入系统以后再挂载 NFS 文件系统, 这里的NFS文件系统是根据正点原子教程制作的ubuntu_rootfs

如果是想开发板启动进入系统以后再挂载 NFS 文件系统,开发板启动进入文件系统,开发板和 ubuntu 能互相 ping 通,在开发板文件系统下新建一个目录 you,然后执行如下指令进行挂载: mkdir mi mount -t nfs -o nolock,nfsv…...

Ubuntu系统执行“docker ps“出现“permission denied“

当我们安装好Ubuntu时,使用鱼香ros一键安装指令 wget http://fishros.com/install -O fishros && . fishros 一键安装Docker后,执行"docker ps"出现"permission denied" seelina:~$ docker ps permission denied while …...

Python与设计模式--桥梁模式

23种计模式之 前言 (5)单例模式、工厂模式、简单工厂模式、抽象工厂模式、建造者模式、原型模式、(7)代理模式、装饰器模式、适配器模式、门面模式、组合模式、享元模式、桥梁模式、(11)策略模式、责任链模式、命令模式、中介者模…...

Linux下查看目录大小

查看目录大小 Linux下查看当前目录大小,可用一下命令: du -h --max-depth1它会从下到大的显示文件的大小。...

鸿蒙原生应用/元服务开发-AGC分发如何下载管理Profile

一、收到通知 尊敬的开发者: 您好,为支撑鸿蒙生态发展,HUAWEI AppGallery Connect已于X月XX日完成存量HarmonyOS应用/元服务的Profile文件更新,更新后Profile文件中已扩展App ID信息;后续上架流程会检测API9以上Harm…...

解决warning: #188-D: enumerated type mixed with another type问题

出现问题处如下, 指示在代码的某处将枚举类型与另一种类型混合使用,这种警告通常在将枚举类型与其他类型进行操作或赋值时出现 enum Mode {MODE_IDLE,MODE_1,MODE_2,MODE_3,MODE_4, }; enum Mode currentMode MODE_IDLE;currentMode (currentMode 1)…...

docker的知识点,以及使用

Docker 是一个开源的应用容器引擎,可以让开发者将应用程序及其依赖项打包至一个可移植的容器中,从而实现快速部署、可扩展和依赖项隔离等特性。下面是 Docker 的一些知识点以及使用方法: Docker 的组成部分包括 Docker 引擎、Docker 镜像、Do…...

WTM(基于Blazor)问题处理记录

问题描述一 有个需求,需要访问内网网络共享文件夹中的文件,有域控限制。 一开始直接在本地映射一个网络驱动器,然后像本地磁盘一样访问共享文件夹里的文件,比如:Y:\ 。 然后直接在程序中访问共享文件夹中的文件,如下代码: DirectoryInfo directoryInfo = new Direct…...

ubuntu 安装 towhee

安装Towhee pip3 install towhee如果你想在 towhee 中安装模型 pip3 install towhee.models打开python终端 python3引入towhee 数据转换是 Towhee 的核心;管道只是在有向无环图中连接在一起的一系列转换。所有预构建的 Towhee 管道都有代表当前任务的名称。 fr…...

ERP软件对Oracle安全产品的支持

这里的ERP软件仅指SAP ECC和Oracle EBS。 先来看Oracle EBS: EBS的认证查询方式,和数据库认证是一样的。这个体验到时不错。 结果中和安全相关的有: Oracle Database VaultTransparent Data Encryption TDE被支持很容易理解,…...

Linux 基础-常用的命令和搭建 Java 部署环境

文章目录 目录相关查看目录中的内容查看目录当前的完整路径切换目录 文件相关创建文件查看文件内容写文件vim 基础 创建删除创建目录 移动和复制移动(剪切粘贴)复制(复制粘贴) 搭建 Java 部署环境1. 安装 jdk2. 安装 tomcat1). 我们在自己电脑上下好 tomcat2). 从官网下载的 .z…...

c语言总结(解题方法)

项目前期处理: 1.首先需要确定项目的背景知识,即主要的难点知识,如指针,数组,结构体,以检索自己是否对项目所需的背景知识足够了解。 2.确定问题实现方法,即题目本身的实现方法,在c语…...

Webpack的ts的配置详细教程

文章目录 前言ts是什么?基础配置LoaderSource MapsClient types使用第三方类库导入其他资源 后言 前言 hello world欢迎来到前端的新世界 😜当前文章系列专栏:webpack 🐱‍👓博主在前端领域还有很多知识和技术需要掌握…...

传智杯第五届题解

B.莲子的机械动力学 分析&#xff1a;这题有个小坑&#xff0c;如果是00 0&#xff0c;结果记得要输出0。 得到的教训是&#xff0c;避免前导0出现时&#xff0c;要注意答案为0的情况。否则有可能会没有输出 #include<assert.h> #include<cstdio> #include<…...

Android 通过demo调试节点权限问题

Android 通过demo调试节点权限问题 近来收到客户反馈提到在应用层无法控制节点&#xff0c;于是写了一个简单的demo来验证节点的IO权限&#xff0c;具体调试步骤就是写一个按钮点击事件&#xff0c;当点击按钮时将需要验证的节点写为1&#xff08;节点默认为1则写为0&#xff…...

邮政快递物流查询,将指定某天签收的单号筛选出来

批量查询邮政快递单号的物流信息&#xff0c;将指定某天签收的单号筛选出来。 所需工具&#xff1a; 一个【快递批量查询高手】软件 邮政快递单号若干 操作步骤&#xff1a; 步骤1&#xff1a;运行【快递批量查询高手】软件&#xff0c;并登录 步骤2&#xff1a;点击主界面左…...

Java 8 lambda的一个编译bug

最近利用github action向Maven中央仓库发布企业微信SDK时会失败&#xff0c;从日志中发现是系统资源耗尽了&#xff0c;日志如下&#xff1a; [INFO] Changes detected - recompiling the module! :dependency [INFO] Compiling 35 source files with javac [debug target 8] …...

无人机覆盖路径规划综述

摘要&#xff1a;覆盖路径规划包括找到覆盖某个目标区域的每个点的路线。近年来&#xff0c;无人机已被应用于涉及地形覆盖的多个应用领域&#xff0c;如监视、智能农业、摄影测量、灾害管理、民事安全和野火跟踪等。本文旨在探索和分析文献中与覆盖路径规划问题中使用的不同方…...

【代码随想录】算法训练计划37

贪心 1、738. 单调递增的数字 题目&#xff1a; 输入: n 10 输出: 9 思路&#xff1a; func monotoneIncreasingDigits(n int) int {// 贪心&#xff0c;利用字符数组s : strconv.Itoa(n)ss : []byte(s)leng : len(ss)if leng < 1 {return n}for i:leng-1; i>0; i-- …...

TDengine 快速体验(Docker 镜像方式)

简介 TDengine 可以通过安装包、Docker 镜像 及云服务快速体验 TDengine 的功能&#xff0c;本节首先介绍如何通过 Docker 快速体验 TDengine&#xff0c;然后介绍如何在 Docker 环境下体验 TDengine 的写入和查询功能。如果你不熟悉 Docker&#xff0c;请使用 安装包的方式快…...

大话软工笔记—需求分析概述

需求分析&#xff0c;就是要对需求调研收集到的资料信息逐个地进行拆分、研究&#xff0c;从大量的不确定“需求”中确定出哪些需求最终要转换为确定的“功能需求”。 需求分析的作用非常重要&#xff0c;后续设计的依据主要来自于需求分析的成果&#xff0c;包括: 项目的目的…...

C++:std::is_convertible

C++标志库中提供is_convertible,可以测试一种类型是否可以转换为另一只类型: template <class From, class To> struct is_convertible; 使用举例: #include <iostream> #include <string>using namespace std;struct A { }; struct B : A { };int main…...

工业安全零事故的智能守护者:一体化AI智能安防平台

前言&#xff1a; 通过AI视觉技术&#xff0c;为船厂提供全面的安全监控解决方案&#xff0c;涵盖交通违规检测、起重机轨道安全、非法入侵检测、盗窃防范、安全规范执行监控等多个方面&#xff0c;能够实现对应负责人反馈机制&#xff0c;并最终实现数据的统计报表。提升船厂…...

论文浅尝 | 基于判别指令微调生成式大语言模型的知识图谱补全方法(ISWC2024)

笔记整理&#xff1a;刘治强&#xff0c;浙江大学硕士生&#xff0c;研究方向为知识图谱表示学习&#xff0c;大语言模型 论文链接&#xff1a;http://arxiv.org/abs/2407.16127 发表会议&#xff1a;ISWC 2024 1. 动机 传统的知识图谱补全&#xff08;KGC&#xff09;模型通过…...

Springcloud:Eureka 高可用集群搭建实战(服务注册与发现的底层原理与避坑指南)

引言&#xff1a;为什么 Eureka 依然是存量系统的核心&#xff1f; 尽管 Nacos 等新注册中心崛起&#xff0c;但金融、电力等保守行业仍有大量系统运行在 Eureka 上。理解其高可用设计与自我保护机制&#xff0c;是保障分布式系统稳定的必修课。本文将手把手带你搭建生产级 Eur…...

vue3 定时器-定义全局方法 vue+ts

1.创建ts文件 路径&#xff1a;src/utils/timer.ts 完整代码&#xff1a; import { onUnmounted } from vuetype TimerCallback (...args: any[]) > voidexport function useGlobalTimer() {const timers: Map<number, NodeJS.Timeout> new Map()// 创建定时器con…...

css的定位(position)详解:相对定位 绝对定位 固定定位

在 CSS 中&#xff0c;元素的定位通过 position 属性控制&#xff0c;共有 5 种定位模式&#xff1a;static&#xff08;静态定位&#xff09;、relative&#xff08;相对定位&#xff09;、absolute&#xff08;绝对定位&#xff09;、fixed&#xff08;固定定位&#xff09;和…...

select、poll、epoll 与 Reactor 模式

在高并发网络编程领域&#xff0c;高效处理大量连接和 I/O 事件是系统性能的关键。select、poll、epoll 作为 I/O 多路复用技术的代表&#xff0c;以及基于它们实现的 Reactor 模式&#xff0c;为开发者提供了强大的工具。本文将深入探讨这些技术的底层原理、优缺点。​ 一、I…...

Java 二维码

Java 二维码 **技术&#xff1a;**谷歌 ZXing 实现 首先添加依赖 <!-- 二维码依赖 --><dependency><groupId>com.google.zxing</groupId><artifactId>core</artifactId><version>3.5.1</version></dependency><de…...