抑制过拟合——Dropout原理
抑制过拟合——Dropout原理
- Dropout的工作原理
- 实验观察
在机器学习领域,尤其是当我们处理复杂的模型和有限的训练样本时,一个常见的问题是过拟合。简而言之,过拟合发生在模型对训练数据学得太好,以至于它捕捉到了数据中的噪声和误差,而不仅仅是底层模式。具体来说,这在神经网络训练中尤为常见,表现为在训练数据上表现优异(例如损失函数值很小,预测准确率高)而在未见过的数据(测试集)上表现不佳。
过拟合不仅是机器学习新手容易遇到的问题,即使是经验丰富的从业者也会面临这一挑战。一个典型的解决方案是采用模型集成技术,这涉及训练多个模型并将它们的预测结合起来。但这种方法的缺点是显而易见的:它既耗时又昂贵,不仅在训练阶段,而且在模型评估和部署时也是如此。
在这种背景下,Dropout
作为一种有效的正则化技术,可以显著减轻过拟合问题。它的基本原理是在每次训练迭代中随机“丢弃”(即暂时移除)网络中的一部分神经元。这种方法不仅简单,而且被证明在许多情况下都非常有效。
Dropout的工作原理
在 PyTorch
中,Dropout
层的使用相当直观。通常,它被添加到神经网络的各个层之间,如下所示:
torch.nn.Dropout(p=0.5, inplace=False)
p
:这是一个关键参数,代表着每个神经元被丢弃的概率。
在实践中,这意味着对于网络中的每个神经元,它在每次训练迭代中都有 1 − p 1-p 1−p 的概率被保留, p p p 的概率被丢弃。值得注意的是,这种随机性确保了每个mini-batch都在对不完全相同的网络进行训练,从而减少过拟合的风险。
在训练期间,对于每个训练样本,网络中的每个神经元都有概率 1 − p 1-p 1−p 被保留,概率 p p p 被丢弃。如果神经元被保留,则其输出乘以 1 1 − p \frac{1}{1-p} 1−p1(这样做是为了保持该层输出的总期望值不变)。设 r j r_j rj 为一个随机变量,它对应于第 j j j 个神经元,且服从伯努利分布(即 r j = 1 r_j = 1 rj=1 的概率为 1 − p 1-p 1−p, r j = 0 r_j = 0 rj=0 的概率为 p p p)。那么在训练时,神经元的输出 y j y_j yj变为 r j × y j / ( 1 − p ) r_j \times y_j / (1-p) rj×yj/(1−p)。
为什么需要保持期望不变? 举个简单的例子,假设某层有两个神经元,它们的输出在没有dropout时都是1。在应用了50%的dropout后,期望只有一个神经元被激活,输出为1,另一个被丢弃,输出为0。这样,这层的平均输出变成了0.5。为了保持输出的总期望值不变,激活的神经元的输出应该乘以2,即 1 1 − p \frac{1}{1-p} 1−p1,这样平均输出才能保持为1,与没有应用dropout时相同。这样的处理有助于保持整个网络的稳定性和一致性。
在模型预测(或测试)阶段,所有的神经元都保持激活(即不进行dropout)。因为在训练阶段,神经元的输出已经被放大了 1 1 − p \frac{1}{1-p} 1−p1 倍,所以在预测时不需要进行任何调整,直接使用网络进行前向传播即可。
实验观察
为了更深入地理解 Dropout
的影响,我们可以通过一个实验来观察不同的 Dropout
设置对训练过程的影响。比如,可以比较 Dropout = 0.1
和 Dropout = 0
在训练过程中的表现差异,相关代码实现如下:
import torch
from tensorboardX import SummaryWriter
from torch import optim, nn
import timeclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linears = nn.Sequential(nn.Linear(2, 20),nn.Linear(20, 20),nn.Dropout(0.1),nn.Linear(20, 20),nn.Linear(20, 20),nn.Linear(20, 1),)def forward(self, x):_ = self.linears(x)return _lr = 0.01
iteration = 1000x1 = torch.arange(-10, 10).float()
x2 = torch.arange(0, 20).float()
x = torch.cat((x1.unsqueeze(1), x2.unsqueeze(1)), dim=1)
y = 2*x1 - x2**2 + 1model = Model()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)
loss_function = torch.nn.MSELoss()start_time = time.time()
writer = SummaryWriter(comment='_随机失活')for iter in range(iteration):y_pred = model(x)loss = loss_function(y, y_pred.squeeze())loss.backward()for name, layer in model.named_parameters():writer.add_histogram(name + '_grad', layer.grad, iter)writer.add_histogram(name + '_data', layer, iter)writer.add_scalar('loss', loss, iter)optimizer.step()optimizer.zero_grad()if iter % 50 == 0:print("iter: ", iter)print("Time: ", time.time() - start_time)
这里我们使用 TensorBoardX 进行结果的可视化展示。
通过观察模型训练1000轮后的线性层梯度分布,可以发现,应用 Dropout
后的模型梯度通常会更加分散和多样化。这种梯度的多样性有助于防止模型过于依赖训练数据中的特定模式,从而减轻过拟合。
同样值得注意的是,模型的损失曲线也会受到影响。加入 Dropout
通常会使损失曲线出现更多的波动(例如,图中的蓝色曲线),这反映了模型在学习过程中的不稳定性。然而,这种不稳定性通常是可接受的,因为它反映了模型正在学习更多的泛化模式而不是简单地记住训练数据。
相关文章:

抑制过拟合——Dropout原理
抑制过拟合——Dropout原理 Dropout的工作原理 实验观察 在机器学习领域,尤其是当我们处理复杂的模型和有限的训练样本时,一个常见的问题是过拟合。简而言之,过拟合发生在模型对训练数据学得太好,以至于它捕捉到了数据中的噪声和…...

开发板启动进入系统以后再挂载 NFS 文件系统, 这里的NFS文件系统是根据正点原子教程制作的ubuntu_rootfs
如果是想开发板启动进入系统以后再挂载 NFS 文件系统,开发板启动进入文件系统,开发板和 ubuntu 能互相 ping 通,在开发板文件系统下新建一个目录 you,然后执行如下指令进行挂载: mkdir mi mount -t nfs -o nolock,nfsv…...
Ubuntu系统执行“docker ps“出现“permission denied“
当我们安装好Ubuntu时,使用鱼香ros一键安装指令 wget http://fishros.com/install -O fishros && . fishros 一键安装Docker后,执行"docker ps"出现"permission denied" seelina:~$ docker ps permission denied while …...
Python与设计模式--桥梁模式
23种计模式之 前言 (5)单例模式、工厂模式、简单工厂模式、抽象工厂模式、建造者模式、原型模式、(7)代理模式、装饰器模式、适配器模式、门面模式、组合模式、享元模式、桥梁模式、(11)策略模式、责任链模式、命令模式、中介者模…...

Linux下查看目录大小
查看目录大小 Linux下查看当前目录大小,可用一下命令: du -h --max-depth1它会从下到大的显示文件的大小。...

鸿蒙原生应用/元服务开发-AGC分发如何下载管理Profile
一、收到通知 尊敬的开发者: 您好,为支撑鸿蒙生态发展,HUAWEI AppGallery Connect已于X月XX日完成存量HarmonyOS应用/元服务的Profile文件更新,更新后Profile文件中已扩展App ID信息;后续上架流程会检测API9以上Harm…...
解决warning: #188-D: enumerated type mixed with another type问题
出现问题处如下, 指示在代码的某处将枚举类型与另一种类型混合使用,这种警告通常在将枚举类型与其他类型进行操作或赋值时出现 enum Mode {MODE_IDLE,MODE_1,MODE_2,MODE_3,MODE_4, }; enum Mode currentMode MODE_IDLE;currentMode (currentMode 1)…...
docker的知识点,以及使用
Docker 是一个开源的应用容器引擎,可以让开发者将应用程序及其依赖项打包至一个可移植的容器中,从而实现快速部署、可扩展和依赖项隔离等特性。下面是 Docker 的一些知识点以及使用方法: Docker 的组成部分包括 Docker 引擎、Docker 镜像、Do…...
WTM(基于Blazor)问题处理记录
问题描述一 有个需求,需要访问内网网络共享文件夹中的文件,有域控限制。 一开始直接在本地映射一个网络驱动器,然后像本地磁盘一样访问共享文件夹里的文件,比如:Y:\ 。 然后直接在程序中访问共享文件夹中的文件,如下代码: DirectoryInfo directoryInfo = new Direct…...
ubuntu 安装 towhee
安装Towhee pip3 install towhee如果你想在 towhee 中安装模型 pip3 install towhee.models打开python终端 python3引入towhee 数据转换是 Towhee 的核心;管道只是在有向无环图中连接在一起的一系列转换。所有预构建的 Towhee 管道都有代表当前任务的名称。 fr…...

ERP软件对Oracle安全产品的支持
这里的ERP软件仅指SAP ECC和Oracle EBS。 先来看Oracle EBS: EBS的认证查询方式,和数据库认证是一样的。这个体验到时不错。 结果中和安全相关的有: Oracle Database VaultTransparent Data Encryption TDE被支持很容易理解,…...

Linux 基础-常用的命令和搭建 Java 部署环境
文章目录 目录相关查看目录中的内容查看目录当前的完整路径切换目录 文件相关创建文件查看文件内容写文件vim 基础 创建删除创建目录 移动和复制移动(剪切粘贴)复制(复制粘贴) 搭建 Java 部署环境1. 安装 jdk2. 安装 tomcat1). 我们在自己电脑上下好 tomcat2). 从官网下载的 .z…...
c语言总结(解题方法)
项目前期处理: 1.首先需要确定项目的背景知识,即主要的难点知识,如指针,数组,结构体,以检索自己是否对项目所需的背景知识足够了解。 2.确定问题实现方法,即题目本身的实现方法,在c语…...
Webpack的ts的配置详细教程
文章目录 前言ts是什么?基础配置LoaderSource MapsClient types使用第三方类库导入其他资源 后言 前言 hello world欢迎来到前端的新世界 😜当前文章系列专栏:webpack 🐱👓博主在前端领域还有很多知识和技术需要掌握…...

传智杯第五届题解
B.莲子的机械动力学 分析:这题有个小坑,如果是00 0,结果记得要输出0。 得到的教训是,避免前导0出现时,要注意答案为0的情况。否则有可能会没有输出 #include<assert.h> #include<cstdio> #include<…...
Android 通过demo调试节点权限问题
Android 通过demo调试节点权限问题 近来收到客户反馈提到在应用层无法控制节点,于是写了一个简单的demo来验证节点的IO权限,具体调试步骤就是写一个按钮点击事件,当点击按钮时将需要验证的节点写为1(节点默认为1则写为0ÿ…...

邮政快递物流查询,将指定某天签收的单号筛选出来
批量查询邮政快递单号的物流信息,将指定某天签收的单号筛选出来。 所需工具: 一个【快递批量查询高手】软件 邮政快递单号若干 操作步骤: 步骤1:运行【快递批量查询高手】软件,并登录 步骤2:点击主界面左…...
Java 8 lambda的一个编译bug
最近利用github action向Maven中央仓库发布企业微信SDK时会失败,从日志中发现是系统资源耗尽了,日志如下: [INFO] Changes detected - recompiling the module! :dependency [INFO] Compiling 35 source files with javac [debug target 8] …...

无人机覆盖路径规划综述
摘要:覆盖路径规划包括找到覆盖某个目标区域的每个点的路线。近年来,无人机已被应用于涉及地形覆盖的多个应用领域,如监视、智能农业、摄影测量、灾害管理、民事安全和野火跟踪等。本文旨在探索和分析文献中与覆盖路径规划问题中使用的不同方…...

【代码随想录】算法训练计划37
贪心 1、738. 单调递增的数字 题目: 输入: n 10 输出: 9 思路: func monotoneIncreasingDigits(n int) int {// 贪心,利用字符数组s : strconv.Itoa(n)ss : []byte(s)leng : len(ss)if leng < 1 {return n}for i:leng-1; i>0; i-- …...
【位运算】消失的两个数字(hard)
消失的两个数字(hard) 题⽬描述:解法(位运算):Java 算法代码:更简便代码 题⽬链接:⾯试题 17.19. 消失的两个数字 题⽬描述: 给定⼀个数组,包含从 1 到 N 所有…...
MySQL账号权限管理指南:安全创建账户与精细授权技巧
在MySQL数据库管理中,合理创建用户账号并分配精确权限是保障数据安全的核心环节。直接使用root账号进行所有操作不仅危险且难以审计操作行为。今天我们来全面解析MySQL账号创建与权限分配的专业方法。 一、为何需要创建独立账号? 最小权限原则…...
使用Matplotlib创建炫酷的3D散点图:数据可视化的新维度
文章目录 基础实现代码代码解析进阶技巧1. 自定义点的大小和颜色2. 添加图例和样式美化3. 真实数据应用示例实用技巧与注意事项完整示例(带样式)应用场景在数据科学和可视化领域,三维图形能为我们提供更丰富的数据洞察。本文将手把手教你如何使用Python的Matplotlib库创建引…...
python爬虫——气象数据爬取
一、导入库与全局配置 python 运行 import json import datetime import time import requests from sqlalchemy import create_engine import csv import pandas as pd作用: 引入数据解析、网络请求、时间处理、数据库操作等所需库。requests:发送 …...
MySQL 主从同步异常处理
阅读原文:https://www.xiaozaoshu.top/articles/mysql-m-s-update-pk MySQL 做双主,遇到的这个错误: Could not execute Update_rows event on table ... Error_code: 1032是 MySQL 主从复制时的经典错误之一,通常表示ÿ…...
API网关Kong的鉴权与限流:高并发场景下的核心实践
🔥「炎码工坊」技术弹药已装填! 点击关注 → 解锁工业级干货【工具实测|项目避坑|源码燃烧指南】 引言 在微服务架构中,API网关承担着流量调度、安全防护和协议转换的核心职责。作为云原生时代的代表性网关,Kong凭借其插件化架构…...

Visual Studio Code 扩展
Visual Studio Code 扩展 change-case 大小写转换EmmyLua for VSCode 调试插件Bookmarks 书签 change-case 大小写转换 https://marketplace.visualstudio.com/items?itemNamewmaurer.change-case 选中单词后,命令 changeCase.commands 可预览转换效果 EmmyLua…...

C++--string的模拟实现
一,引言 string的模拟实现是只对string对象中给的主要功能经行模拟实现,其目的是加强对string的底层了解,以便于在以后的学习或者工作中更加熟练的使用string。本文中的代码仅供参考并不唯一。 二,默认成员函数 string主要有三个成员变量,…...

GAN模式奔溃的探讨论文综述(一)
简介 简介:今天带来一篇关于GAN的,对于模式奔溃的一个探讨的一个问题,帮助大家更好的解决训练中遇到的一个难题。 论文题目:An in-depth review and analysis of mode collapse in GAN 期刊:Machine Learning 链接:...

五、jmeter脚本参数化
目录 1、脚本参数化 1.1 用户定义的变量 1.1.1 添加及引用方式 1.1.2 测试得出用户定义变量的特点 1.2 用户参数 1.2.1 概念 1.2.2 位置不同效果不同 1.2.3、用户参数的勾选框 - 每次迭代更新一次 总结用户定义的变量、用户参数 1.3 csv数据文件参数化 1、脚本参数化 …...