前馈神经网络dropout实例
直接看代码。
(一)手动实现
import torch
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt#下载MNIST手写数据集
mnist_train = torchvision.datasets.MNIST(root='./MNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.MNIST(root='./MNIST', train=False,download=True, transform=transforms.ToTensor()) #读取数据
batch_size = 256
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True,num_workers=0)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False,num_workers=0) #初始化参数
num_inputs,num_hiddens,num_outputs =784, 256,10num_epochs=30lr = 0.001def init_param():W1 = torch.tensor(np.random.normal(0, 0.01, (num_hiddens,num_inputs)), dtype=torch.float32) b1 = torch.zeros(1, dtype=torch.float32) W2 = torch.tensor(np.random.normal(0, 0.01, (num_outputs,num_hiddens)), dtype=torch.float32) b2 = torch.zeros(1, dtype=torch.float32) params =[W1,b1,W2,b2]for param in params: param.requires_grad_(requires_grad=True) return W1,b1,W2,b2def dropout(X, drop_prob):X = X.float()assert 0 <= drop_prob <= 1keep_prob = 1 - drop_probif keep_prob == 0:return torch.zeros_like(X)mask = (torch.rand(X.shape) < keep_prob).float()print(mask)return mask * X / keep_probdef net(X, is_training=True):X = X.view(-1, num_inputs)H1 = (torch.matmul(X, W1.t()) + b1).relu()if is_training:H1 = dropout(H1, drop_prob)return (torch.matmul(H1,W2.t()) + b2).relu()def train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr=None,optimizer=None):train_ls, test_ls = [], []for epoch in range(num_epochs):ls, count = 0, 0for X,y in train_iter:l=loss(net(X),y)optimizer.zero_grad()l.backward()optimizer.step()ls += l.item()count += y.shape[0]train_ls.append(ls)ls, count = 0, 0for X,y in test_iter:l=loss(net(X,is_training=False),y)ls += l.item()count += y.shape[0]test_ls.append(ls)if(epoch+1)%10==0:print('epoch: %d, train loss: %f, test loss: %f'%(epoch+1,train_ls[-1],test_ls[-1]))return train_ls,test_lsdrop_probs = np.arange(0,1.1,0.1)Train_ls, Test_ls = [], []for drop_prob in drop_probs:W1,b1,W2,b2 = init_param()loss = nn.CrossEntropyLoss()optimizer = torch.optim.SGD([W1,b1,W2,b2],lr = 0.001)train_ls, test_ls = train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr,optimizer) Train_ls.append(train_ls)Test_ls.append(test_ls)x = np.linspace(0,len(train_ls),len(train_ls))plt.figure(figsize=(10,8))for i in range(0,len(drop_probs)):plt.plot(x,Train_ls[i],label= 'drop_prob=%.1f'%(drop_probs[i]),linewidth=1.5)plt.xlabel('epoch')plt.ylabel('loss')# plt.legend()
plt.legend(loc=2, bbox_to_anchor=(1.05,1.0),borderaxespad = 0.)
plt.title('train loss with dropout')
plt.show()
运行结果:

(二)torch.nn实现
import torch
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as pltmnist_train = torchvision.datasets.MNIST(root='./MNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.MNIST(root='./MNIST', train=False,download=True, transform=transforms.ToTensor())
batch_size = 256
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True,num_workers=0)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False,num_workers=0) class LinearNet(nn.Module):def __init__(self,num_inputs, num_outputs, num_hiddens1, num_hiddens2, drop_prob1,drop_prob2):super(LinearNet,self).__init__()self.linear1 = nn.Linear(num_inputs,num_hiddens1)self.relu = nn.ReLU()self.drop1 = nn.Dropout(drop_prob1)self.linear2 = nn.Linear(num_hiddens1,num_hiddens2)self.drop2 = nn.Dropout(drop_prob2)self.linear3 = nn.Linear(num_hiddens2,num_outputs)self.flatten = nn.Flatten()def forward(self,x):x = self.flatten(x)x = self.linear1(x)x = self.relu(x)x = self.drop1(x)x = self.linear2(x)x = self.relu(x)x = self.drop2(x)x = self.linear3(x)y = self.relu(x)return ydef train(net,train_iter,test_iter,loss,num_epochs,batch_size,params=None,lr=None,optimizer=None):train_ls, test_ls = [], []for epoch in range(num_epochs):ls, count = 0, 0for X,y in train_iter:l=loss(net(X),y)optimizer.zero_grad()l.backward()optimizer.step()ls += l.item()count += y.shape[0]train_ls.append(ls)ls, count = 0, 0for X,y in test_iter:l=loss(net(X),y)ls += l.item()count += y.shape[0]test_ls.append(ls)if(epoch+1)%5==0:print('epoch: %d, train loss: %f, test loss: %f'%(epoch+1,train_ls[-1],test_ls[-1]))return train_ls,test_ls num_inputs,num_hiddens1,num_hiddens2,num_outputs =784, 256,256,10
num_epochs=20
lr = 0.001
drop_probs = np.arange(0,1.1,0.1)
Train_ls, Test_ls = [], []for drop_prob in drop_probs:net = LinearNet(num_inputs, num_outputs, num_hiddens1, num_hiddens2, drop_prob,drop_prob)for param in net.parameters():nn.init.normal_(param,mean=0, std= 0.01)loss = nn.CrossEntropyLoss()optimizer = torch.optim.SGD(net.parameters(),lr)train_ls, test_ls = train(net,train_iter,test_iter,loss,num_epochs,batch_size,net.parameters,lr,optimizer)Train_ls.append(train_ls)Test_ls.append(test_ls)x = np.linspace(0,len(train_ls),len(train_ls))
plt.figure(figsize=(10,8))
for i in range(0,len(drop_probs)):plt.plot(x,Train_ls[i],label= 'drop_prob=%.1f'%(drop_probs[i]),linewidth=1.5)plt.xlabel('epoch')plt.ylabel('loss')
plt.legend(loc=2, bbox_to_anchor=(1.05,1.0),borderaxespad = 0.)
plt.title('train loss with dropout')
plt.show()input = torch.randn(2, 5, 5)
m = nn.Sequential(
nn.Flatten()
)
output = m(input)
output.size()
运行结果:

关于dropout的原理,网上资料很多,一般都是用一个正态分布的矩阵,比较矩阵元素和(1-dropout),大于(1-dropout)的矩阵元素值的修正为1,小于(1-dropout)的改为1,将输入的值乘以修改后的矩阵,再除以(1-dropout)。
疑问:
- 数值经过正态分布矩阵的筛选后,还要除以 (1-dropout),这样做的原因是什么?
- Flatten层用来将输入“压平”,即把多维的输入一维化,常用在从卷积层到全连接层的过渡。
相关文章:
前馈神经网络dropout实例
直接看代码。 (一)手动实现 import torch import torch.nn as nn import numpy as np import torchvision import torchvision.transforms as transforms import matplotlib.pyplot as plt#下载MNIST手写数据集 mnist_train torchvision.datasets.MN…...
Android DataStore:安全存储和轻松管理数据
关于作者:CSDN内容合伙人、技术专家, 从零开始做日活千万级APP。 专注于分享各领域原创系列文章 ,擅长java后端、移动开发、人工智能等,希望大家多多支持。 目录 一、导读二、概览三、使用3.1 Preferences DataStore添加依赖数据读…...
opencv进阶12-EigenFaces 人脸识别
EigenFaces 通常也被称为 特征脸,它使用主成分分析(Principal Component Analysis,PCA) 方法将高维的人脸数据处理为低维数据后(降维),再进行数据分析和处理,获取识别结果。 基本原理…...
The internal rate of return (IRR)
内部收益率 NPV(Net Present Value)_spencer_tseng的博客-CSDN博客...
半导体自动化专用静电消除器主要由哪些部分组成
半导体自动化专用静电消除器是一种用于消除半导体生产过程中的静电问题的设备。由于半导体制造过程中对静电的敏感性,静电可能会对半导体器件的质量和可靠性产生很大的影响,甚至造成元件损坏。因此,半导体生产中采用专用的静电消除器是非常重…...
【C++入门到精通】C++入门 —— deque(STL)
阅读导航 前言一、deque简介1. 概念2. 特点 二、deque使用1. 基本操作(增、删、查、改)2. 底层结构 三、deque的缺陷四、 为什么选择deque作为stack和queue的底层默认容器总结温馨提示 前言 文章绑定了VS平台下std::deque的源码,大家可以下载…...
Codeforces Round 893 (Div. 2) D.Trees and Segments
原题链接:Problem - D - Codeforces 题面: 大概意思就是让你在翻转01串不超过k次的情况下,使得a*(0的最大连续长度)(1的最大连续长度)最大(1<a<n)。输出n个数&…...
SpringBoot + Vue 前后端分离项目 微人事(九)
职位管理后端接口设计 在controller包里面新建system包,再在system包里面新建basic包,再在basic包里面创建PositionController类,在定义PositionController类的接口的时候,一定要与数据库的menu中的url地址到一致,不然…...
【业务功能篇71】Cglib的BeanCopier进行Bean对象拷贝
选择Cglib的BeanCopier进行Bean拷贝的理由是, 其性能要比Spring的BeanUtils,Apache的BeanUtils和PropertyUtils要好很多, 尤其是数据量比较大的情况下。 BeanCopier的主要作用是将数据库层面的Entity转化成service层的POJO。BeanCopier其实已…...
让eslint的错误信息显示在项目界面上
1.需求描述 效果如下 让eslint中的错误,显示在项目界面上 2.问题解决 1.安装 vite-plugin-eslint 插件 npm install vite-plugin-eslint --save-dev2.配置插件 // vite.config.js import { defineConfig } from vite import vue from vitejs/plugin-vue import e…...
手摸手带你实现一个开箱即用的Node邮件推送服务
目录 编辑 前言 准备工作 邮箱配置 代码实现 服务部署 使用效果 题外话 写在最后 相关代码: 前言 由于邮箱账号和手机号的唯一性,通常实现验证码的校验时比较常用的两种方式是手机短信推送和邮箱推送,此外,邮件推送服…...
【Linux网络】网络编程套接字 -- 基于socket实现一个简单UDP网络程序
认识端口号网络字节序处理字节序函数 htonl、htons、ntohl、ntohs socketsocket编程接口sockaddr结构结尾实现UDP程序的socket接口使用解析socket处理 IP 地址的函数初始化sockaddr_inbindrecvfromsendto 实现一个简单的UDP网络程序封装服务器相关代码封装客户端相关代码实验结…...
Python学习笔记第六十四天(Matplotlib 网格线)
Python学习笔记第六十四天 Matplotlib 网格线普通网格线样式网格线 后记 Matplotlib 网格线 我们可以使用 pyplot 中的 grid() 方法来设置图表中的网格线。 grid() 方法语法格式如下: matplotlib.pyplot.grid(bNone, whichmajor, axisboth, )参数说明:…...
机器学习与模式识别3(线性回归与逻辑回归)
一、线性回归与逻辑回归简介 线性回归主要功能是拟合数据,常用平方误差函数。 逻辑回归主要功能是区分数据,找到决策边界,常用交叉熵。 二、线性回归与逻辑回归的实现 1.线性回归 利用回归方程对一个或多个特征值和目标值之间的关系进行建模…...
vue启动配置npm run serve,动态环境变量,根据不同环境访问不同域名
首先创建不同环境的配置文件,比如域名和一些常量,创建一个env文件,先看看文件目录 env.dev就是dev环境的域名,.test就是test环境域名,其他同理,然后配置package.json文件 {"name": "require-admin&qu…...
HTML <strike> 标签
HTML5 中不支持 <strike> 标签在 HTML 4 中用于定义删除线文本。 定义和用法 <strike> 标签可定义加删除线文本定义。 浏览器支持 元素ChromeIEFirefoxSafariOpera<strike>YesYesYesYesYes 所有浏览器都支持 <strike> 标签。 HTML 与 XHTML 之间…...
数学建模-模型详解(1)
规划模型 线性规划模型: 当涉及到线性规划模型实例时,以下是一个简单的示例: 假设我们有两个变量 x 和 y,并且我们希望最大化目标函数 Z 5x 3y,同时满足以下约束条件: x > 0y > 02x y < 10…...
MySQL 数据库表的基本操作
一、数据库表概述 在数据库中,数据表是数据库中最重要、最基本的操作对象,是数据存储的基本单位。数据表被定义为列的集合,数据在表中是按照行和列的格式来存储的。每一行代表一条唯一的记录,每一列代表记录中的一个域。 二、数…...
企业微信电脑端开启chrome调试
首先: Mac端调试开启的快捷键:control shift command d Window端调试开启的快捷键: control shift alt d 这边以Mac为例,我们可以在电脑顶部看到调试的入口: 然后我们点击 『浏览器、webView相关』菜单,勾选上…...
Maven官网下载配置新仓库
1.Maven的下载 Maven的官网地址:Maven – Download Apache Maven 点击Download,查找 Files下的版本并下载如下图: 2.Maven的配置 自己在D盘或者E盘创建一个文件夹,作为本地仓库,存放项目依赖。 将下载好的zip文件进行解…...
Connect to Oracle Database with JDBC Driver
1. Overview The Oracle Database is one of the most popular relational databases. In this tutorial, we’ll learn how to connect to an Oracle Database using a JDBC Driver. 2. The Database To get us started, we need a database. If we don’t have access to …...
Simulink频域分析避坑指南:如何准确获取谐振频率(含MATLAB代码)
Simulink频域分析实战:谐振频率精准提取方法论与MATLAB实现 在控制系统设计与分析领域,频域特性是评估系统动态性能的核心指标之一。而谐振频率作为频域响应中的关键特征点,直接影响着系统的稳定性和响应速度。然而,许多工程师在使…...
YOLOv11涨点改进| 全网独家创新、检测头Head改进篇| CVPR 2026顶会 |使用FAAHead改进YOLOv11的检测头,处理小目标、遮挡小目标检测、旋转目标检测有效涨点,助力高效发论文
一、本文介绍 🔥本文给大家介绍使用CVPR 2026顶会 FAAHead 和 OBB_FAAHead 改进 YOLOv11的检测头,可以有效缓解目标检测中分类分支与框回归分支之间的特征冲突问题,尤其适合旋转目标检测或含明显方向信息的目标检测任务。FAAHead 的核心思想是在检测头阶段先对 RoI 或候选…...
实战:利用大模型预测 2026 年最热门的‘长尾提问’并提前进行 GEO 占位
各位编程领域的同仁、技术爱好者,大家好!今天,我们齐聚一堂,探讨一个既前沿又极具实战价值的议题:如何利用大模型(Large Language Models, LLMs)的强大能力,预测2026年可能成为热点的…...
当心“Pin-to-Pin兼容“陷阱:ICM-42688国产替代芯片深度拆解与避坑指南
两句话总结:近期TDK ICM-42688-P价格暴涨至百元且一芯难求,立创商城上出现了华轩阳、Tokmas等"国产替代"。本文通过详细对比三家datasheet数据手册,揭示所谓"兼容"背后的软件陷阱与性能差异。结论可能出乎你意料…...
开源动作捕捉与3D数据采集:FreeMoCap如何颠覆传统动捕方案
开源动作捕捉与3D数据采集:FreeMoCap如何颠覆传统动捕方案 【免费下载链接】freemocap Free Motion Capture for Everyone 💀✨ 项目地址: https://gitcode.com/GitHub_Trending/fr/freemocap 在游戏开发、动画制作和运动科学研究领域,…...
告别卡顿!GSYVideoPlayer的ExoPlayer内核配置全攻略(支持HLS/m3u8直播流)
GSYVideoPlayer的ExoPlayer内核深度调优:打造极致流畅的HLS直播体验 去年接手一个海外直播项目时,遇到最头疼的问题就是m3u8流媒体的卡顿和延迟。测试了各种方案后,最终通过GSYVideoPlayer的ExoPlayer内核解决了这个难题。今天就把这些实战经…...
从Siwave导入模型到Q3D仿真,如何避免‘幽灵’solder导致的网络报错?
从Siwave到Q3D的模型迁移:彻底解决"幽灵焊料"引发的网络冲突 当你在Ansys电子设计自动化工具链中切换工作环境时,是否遇到过这样的困扰:从Siwave精心准备的模型导入Q3D后,突然冒出各种莫名其妙的网络重叠报错ÿ…...
服务器风扇静音改造:揭秘线序定义的通用破解技巧——以IBM SystemX 3630 M4为案例
1. 为什么服务器风扇这么吵? 服务器风扇的噪音问题困扰着很多运维人员和家庭实验室用户。我拆解过几十台不同品牌的服务器,发现这个问题的根源在于服务器的散热设计理念与家用电脑完全不同。 服务器在设计时优先考虑的是稳定性和散热效率,而不…...
ARMv8、AArch64 与 arm64:命名与体系结构要点
ARMv8、AArch64 与 arm64:命名与体系结构要点 ARMv8 指 ARM 架构的一个主版本代际;AArch64 是该代际下的 64 位执行状态与 A64 指令集;arm64 与 aarch64 是操作系统与工具链中对 AArch64 的常用三元组/目录名,二进制约定一致。下…...
