当前位置: 首页 > news >正文

前馈神经网络dropout实例

直接看代码。

(一)手动实现


import torch
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt#下载MNIST手写数据集  
mnist_train = torchvision.datasets.MNIST(root='./MNIST', train=True, download=True, transform=transforms.ToTensor())  
mnist_test = torchvision.datasets.MNIST(root='./MNIST', train=False,download=True, transform=transforms.ToTensor())  #读取数据  
batch_size = 256 
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True,num_workers=0)  
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False,num_workers=0)  #初始化参数  
num_inputs,num_hiddens,num_outputs =784, 256,10num_epochs=30lr = 0.001def init_param():W1 = torch.tensor(np.random.normal(0, 0.01, (num_hiddens,num_inputs)), dtype=torch.float32)  b1 = torch.zeros(1, dtype=torch.float32)  W2 = torch.tensor(np.random.normal(0, 0.01, (num_outputs,num_hiddens)), dtype=torch.float32)  b2 = torch.zeros(1, dtype=torch.float32)  params =[W1,b1,W2,b2]for param in params:  param.requires_grad_(requires_grad=True)  return W1,b1,W2,b2def dropout(X, drop_prob):X = X.float()assert 0 <= drop_prob <= 1keep_prob = 1 - drop_probif keep_prob == 0:return torch.zeros_like(X)mask = (torch.rand(X.shape) < keep_prob).float()print(mask)return mask * X / keep_probdef net(X, is_training=True):X = X.view(-1, num_inputs)H1 = (torch.matmul(X, W1.t()) + b1).relu()if is_training:H1 = dropout(H1, drop_prob)return (torch.matmul(H1,W2.t()) + b2).relu()def train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr=None,optimizer=None):train_ls, test_ls = [], []for epoch in range(num_epochs):ls, count = 0, 0for X,y in train_iter:l=loss(net(X),y)optimizer.zero_grad()l.backward()optimizer.step()ls += l.item()count += y.shape[0]train_ls.append(ls)ls, count = 0, 0for X,y in test_iter:l=loss(net(X,is_training=False),y)ls += l.item()count += y.shape[0]test_ls.append(ls)if(epoch+1)%10==0:print('epoch: %d, train loss: %f, test loss: %f'%(epoch+1,train_ls[-1],test_ls[-1]))return train_ls,test_lsdrop_probs = np.arange(0,1.1,0.1)Train_ls, Test_ls = [], []for drop_prob in drop_probs:W1,b1,W2,b2 = init_param()loss = nn.CrossEntropyLoss()optimizer = torch.optim.SGD([W1,b1,W2,b2],lr = 0.001)train_ls, test_ls =  train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr,optimizer)   Train_ls.append(train_ls)Test_ls.append(test_ls)x = np.linspace(0,len(train_ls),len(train_ls))plt.figure(figsize=(10,8))for i in range(0,len(drop_probs)):plt.plot(x,Train_ls[i],label= 'drop_prob=%.1f'%(drop_probs[i]),linewidth=1.5)plt.xlabel('epoch')plt.ylabel('loss')# plt.legend()
plt.legend(loc=2, bbox_to_anchor=(1.05,1.0),borderaxespad = 0.)
plt.title('train loss with dropout')
plt.show()

运行结果:

在这里插入图片描述

(二)torch.nn实现

import torch
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as pltmnist_train = torchvision.datasets.MNIST(root='./MNIST', train=True, download=True, transform=transforms.ToTensor())  
mnist_test = torchvision.datasets.MNIST(root='./MNIST', train=False,download=True, transform=transforms.ToTensor())  
batch_size = 256 
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True,num_workers=0)  
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False,num_workers=0)  class LinearNet(nn.Module):def __init__(self,num_inputs, num_outputs, num_hiddens1, num_hiddens2, drop_prob1,drop_prob2):super(LinearNet,self).__init__()self.linear1 = nn.Linear(num_inputs,num_hiddens1)self.relu = nn.ReLU()self.drop1 = nn.Dropout(drop_prob1)self.linear2 = nn.Linear(num_hiddens1,num_hiddens2)self.drop2 = nn.Dropout(drop_prob2)self.linear3 = nn.Linear(num_hiddens2,num_outputs)self.flatten  = nn.Flatten()def forward(self,x):x = self.flatten(x)x = self.linear1(x)x = self.relu(x)x = self.drop1(x)x = self.linear2(x)x = self.relu(x)x = self.drop2(x)x = self.linear3(x)y = self.relu(x)return ydef train(net,train_iter,test_iter,loss,num_epochs,batch_size,params=None,lr=None,optimizer=None):train_ls, test_ls = [], []for epoch in range(num_epochs):ls, count = 0, 0for X,y in train_iter:l=loss(net(X),y)optimizer.zero_grad()l.backward()optimizer.step()ls += l.item()count += y.shape[0]train_ls.append(ls)ls, count = 0, 0for X,y in test_iter:l=loss(net(X),y)ls += l.item()count += y.shape[0]test_ls.append(ls)if(epoch+1)%5==0:print('epoch: %d, train loss: %f, test loss: %f'%(epoch+1,train_ls[-1],test_ls[-1]))return train_ls,test_ls    num_inputs,num_hiddens1,num_hiddens2,num_outputs =784, 256,256,10
num_epochs=20
lr = 0.001
drop_probs = np.arange(0,1.1,0.1)
Train_ls, Test_ls = [], []for drop_prob in drop_probs:net = LinearNet(num_inputs, num_outputs, num_hiddens1, num_hiddens2, drop_prob,drop_prob)for param in net.parameters():nn.init.normal_(param,mean=0, std= 0.01)loss = nn.CrossEntropyLoss()optimizer = torch.optim.SGD(net.parameters(),lr)train_ls, test_ls = train(net,train_iter,test_iter,loss,num_epochs,batch_size,net.parameters,lr,optimizer)Train_ls.append(train_ls)Test_ls.append(test_ls)x = np.linspace(0,len(train_ls),len(train_ls))
plt.figure(figsize=(10,8))
for i in range(0,len(drop_probs)):plt.plot(x,Train_ls[i],label= 'drop_prob=%.1f'%(drop_probs[i]),linewidth=1.5)plt.xlabel('epoch')plt.ylabel('loss')
plt.legend(loc=2, bbox_to_anchor=(1.05,1.0),borderaxespad = 0.)
plt.title('train loss with dropout')
plt.show()input = torch.randn(2, 5, 5)
m = nn.Sequential(
nn.Flatten()
)
output = m(input)
output.size()

运行结果:

在这里插入图片描述

关于dropout的原理,网上资料很多,一般都是用一个正态分布的矩阵,比较矩阵元素和(1-dropout),大于(1-dropout)的矩阵元素值的修正为1,小于(1-dropout)的改为1,将输入的值乘以修改后的矩阵,再除以(1-dropout)。

疑问:

  1. 数值经过正态分布矩阵的筛选后,还要除以 (1-dropout),这样做的原因是什么?
  2. Flatten层用来将输入“压平”,即把多维的输入一维化,常用在从卷积层到全连接层的过渡。

相关文章:

前馈神经网络dropout实例

直接看代码。 &#xff08;一&#xff09;手动实现 import torch import torch.nn as nn import numpy as np import torchvision import torchvision.transforms as transforms import matplotlib.pyplot as plt#下载MNIST手写数据集 mnist_train torchvision.datasets.MN…...

Android DataStore:安全存储和轻松管理数据

关于作者&#xff1a;CSDN内容合伙人、技术专家&#xff0c; 从零开始做日活千万级APP。 专注于分享各领域原创系列文章 &#xff0c;擅长java后端、移动开发、人工智能等&#xff0c;希望大家多多支持。 目录 一、导读二、概览三、使用3.1 Preferences DataStore添加依赖数据读…...

opencv进阶12-EigenFaces 人脸识别

EigenFaces 通常也被称为 特征脸&#xff0c;它使用主成分分析&#xff08;Principal Component Analysis&#xff0c;PCA&#xff09; 方法将高维的人脸数据处理为低维数据后&#xff08;降维&#xff09;&#xff0c;再进行数据分析和处理&#xff0c;获取识别结果。 基本原理…...

The internal rate of return (IRR)

内部收益率 NPV(Net Present Value)_spencer_tseng的博客-CSDN博客...

半导体自动化专用静电消除器主要由哪些部分组成

半导体自动化专用静电消除器是一种用于消除半导体生产过程中的静电问题的设备。由于半导体制造过程中对静电的敏感性&#xff0c;静电可能会对半导体器件的质量和可靠性产生很大的影响&#xff0c;甚至造成元件损坏。因此&#xff0c;半导体生产中采用专用的静电消除器是非常重…...

【C++入门到精通】C++入门 —— deque(STL)

阅读导航 前言一、deque简介1. 概念2. 特点 二、deque使用1. 基本操作&#xff08;增、删、查、改&#xff09;2. 底层结构 三、deque的缺陷四、 为什么选择deque作为stack和queue的底层默认容器总结温馨提示 前言 文章绑定了VS平台下std::deque的源码&#xff0c;大家可以下载…...

Codeforces Round 893 (Div. 2) D.Trees and Segments

原题链接&#xff1a;Problem - D - Codeforces 题面&#xff1a; 大概意思就是让你在翻转01串不超过k次的情况下&#xff0c;使得a*&#xff08;0的最大连续长度&#xff09;&#xff08;1的最大连续长度&#xff09;最大&#xff08;1<a<n&#xff09;。输出n个数&…...

SpringBoot + Vue 前后端分离项目 微人事(九)

职位管理后端接口设计 在controller包里面新建system包&#xff0c;再在system包里面新建basic包&#xff0c;再在basic包里面创建PositionController类&#xff0c;在定义PositionController类的接口的时候&#xff0c;一定要与数据库的menu中的url地址到一致&#xff0c;不然…...

【业务功能篇71】Cglib的BeanCopier进行Bean对象拷贝

选择Cglib的BeanCopier进行Bean拷贝的理由是&#xff0c; 其性能要比Spring的BeanUtils&#xff0c;Apache的BeanUtils和PropertyUtils要好很多&#xff0c; 尤其是数据量比较大的情况下。 BeanCopier的主要作用是将数据库层面的Entity转化成service层的POJO。BeanCopier其实已…...

让eslint的错误信息显示在项目界面上

1.需求描述 效果如下 让eslint中的错误&#xff0c;显示在项目界面上 2.问题解决 1.安装 vite-plugin-eslint 插件 npm install vite-plugin-eslint --save-dev2.配置插件 // vite.config.js import { defineConfig } from vite import vue from vitejs/plugin-vue import e…...

手摸手带你实现一个开箱即用的Node邮件推送服务

目录 ​编辑 前言 准备工作 邮箱配置 代码实现 服务部署 使用效果 题外话 写在最后 相关代码&#xff1a; 前言 由于邮箱账号和手机号的唯一性&#xff0c;通常实现验证码的校验时比较常用的两种方式是手机短信推送和邮箱推送&#xff0c;此外&#xff0c;邮件推送服…...

【Linux网络】网络编程套接字 -- 基于socket实现一个简单UDP网络程序

认识端口号网络字节序处理字节序函数 htonl、htons、ntohl、ntohs socketsocket编程接口sockaddr结构结尾实现UDP程序的socket接口使用解析socket处理 IP 地址的函数初始化sockaddr_inbindrecvfromsendto 实现一个简单的UDP网络程序封装服务器相关代码封装客户端相关代码实验结…...

Python学习笔记第六十四天(Matplotlib 网格线)

Python学习笔记第六十四天 Matplotlib 网格线普通网格线样式网格线 后记 Matplotlib 网格线 我们可以使用 pyplot 中的 grid() 方法来设置图表中的网格线。 grid() 方法语法格式如下&#xff1a; matplotlib.pyplot.grid(bNone, whichmajor, axisboth, )参数说明&#xff1a…...

机器学习与模式识别3(线性回归与逻辑回归)

一、线性回归与逻辑回归简介 线性回归主要功能是拟合数据&#xff0c;常用平方误差函数。 逻辑回归主要功能是区分数据&#xff0c;找到决策边界&#xff0c;常用交叉熵。 二、线性回归与逻辑回归的实现 1.线性回归 利用回归方程对一个或多个特征值和目标值之间的关系进行建模…...

vue启动配置npm run serve,动态环境变量,根据不同环境访问不同域名

首先创建不同环境的配置文件&#xff0c;比如域名和一些常量&#xff0c;创建一个env文件,先看看文件目录 env.dev就是dev环境的域名&#xff0c;.test就是test环境域名&#xff0c;其他同理&#xff0c;然后配置package.json文件 {"name": "require-admin&qu…...

HTML <strike> 标签

HTML5 中不支持 <strike> 标签在 HTML 4 中用于定义删除线文本。 定义和用法 <strike> 标签可定义加删除线文本定义。 浏览器支持 元素ChromeIEFirefoxSafariOpera<strike>YesYesYesYesYes 所有浏览器都支持 <strike> 标签。 HTML 与 XHTML 之间…...

数学建模-模型详解(1)

规划模型 线性规划模型&#xff1a; 当涉及到线性规划模型实例时&#xff0c;以下是一个简单的示例&#xff1a; 假设我们有两个变量 x 和 y&#xff0c;并且我们希望最大化目标函数 Z 5x 3y&#xff0c;同时满足以下约束条件&#xff1a; x > 0y > 02x y < 10…...

MySQL 数据库表的基本操作

一、数据库表概述 在数据库中&#xff0c;数据表是数据库中最重要、最基本的操作对象&#xff0c;是数据存储的基本单位。数据表被定义为列的集合&#xff0c;数据在表中是按照行和列的格式来存储的。每一行代表一条唯一的记录&#xff0c;每一列代表记录中的一个域。 二、数…...

企业微信电脑端开启chrome调试

首先&#xff1a; Mac端调试开启的快捷键&#xff1a;control shift command d Window端调试开启的快捷键: control shift alt d 这边以Mac为例&#xff0c;我们可以在电脑顶部看到调试的入口&#xff1a; 然后我们点击 『浏览器、webView相关』菜单&#xff0c;勾选上…...

Maven官网下载配置新仓库

1.Maven的下载 Maven的官网地址&#xff1a;Maven – Download Apache Maven 点击Download&#xff0c;查找 Files下的版本并下载如下图&#xff1a; 2.Maven的配置 自己在D盘或者E盘创建一个文件夹&#xff0c;作为本地仓库&#xff0c;存放项目依赖。 将下载好的zip文件进行解…...

椭圆曲线密码学(ECC)

一、ECC算法概述 椭圆曲线密码学&#xff08;Elliptic Curve Cryptography&#xff09;是基于椭圆曲线数学理论的公钥密码系统&#xff0c;由Neal Koblitz和Victor Miller在1985年独立提出。相比RSA&#xff0c;ECC在相同安全强度下密钥更短&#xff08;256位ECC ≈ 3072位RSA…...

Oracle查询表空间大小

1 查询数据库中所有的表空间以及表空间所占空间的大小 SELECTtablespace_name,sum( bytes ) / 1024 / 1024 FROMdba_data_files GROUP BYtablespace_name; 2 Oracle查询表空间大小及每个表所占空间的大小 SELECTtablespace_name,file_id,file_name,round( bytes / ( 1024 …...

《从零掌握MIPI CSI-2: 协议精解与FPGA摄像头开发实战》-- CSI-2 协议详细解析 (一)

CSI-2 协议详细解析 (一&#xff09; 1. CSI-2层定义&#xff08;CSI-2 Layer Definitions&#xff09; 分层结构 &#xff1a;CSI-2协议分为6层&#xff1a; 物理层&#xff08;PHY Layer&#xff09; &#xff1a; 定义电气特性、时钟机制和传输介质&#xff08;导线&#…...

1688商品列表API与其他数据源的对接思路

将1688商品列表API与其他数据源对接时&#xff0c;需结合业务场景设计数据流转链路&#xff0c;重点关注数据格式兼容性、接口调用频率控制及数据一致性维护。以下是具体对接思路及关键技术点&#xff1a; 一、核心对接场景与目标 商品数据同步 场景&#xff1a;将1688商品信息…...

稳定币的深度剖析与展望

一、引言 在当今数字化浪潮席卷全球的时代&#xff0c;加密货币作为一种新兴的金融现象&#xff0c;正以前所未有的速度改变着我们对传统货币和金融体系的认知。然而&#xff0c;加密货币市场的高度波动性却成为了其广泛应用和普及的一大障碍。在这样的背景下&#xff0c;稳定…...

html css js网页制作成品——HTML+CSS榴莲商城网页设计(4页)附源码

目录 一、&#x1f468;‍&#x1f393;网站题目 二、✍️网站描述 三、&#x1f4da;网站介绍 四、&#x1f310;网站效果 五、&#x1fa93; 代码实现 &#x1f9f1;HTML 六、&#x1f947; 如何让学习不再盲目 七、&#x1f381;更多干货 一、&#x1f468;‍&#x1f…...

Go 并发编程基础:通道(Channel)的使用

在 Go 中&#xff0c;Channel 是 Goroutine 之间通信的核心机制。它提供了一个线程安全的通信方式&#xff0c;用于在多个 Goroutine 之间传递数据&#xff0c;从而实现高效的并发编程。 本章将介绍 Channel 的基本概念、用法、缓冲、关闭机制以及 select 的使用。 一、Channel…...

【Android】Android 开发 ADB 常用指令

查看当前连接的设备 adb devices 连接设备 adb connect 设备IP 断开已连接的设备 adb disconnect 设备IP 安装应用 adb install 安装包的路径 卸载应用 adb uninstall 应用包名 查看已安装的应用包名 adb shell pm list packages 查看已安装的第三方应用包名 adb shell pm list…...

Python Einops库:深度学习中的张量操作革命

Einops&#xff08;爱因斯坦操作库&#xff09;就像给张量操作戴上了一副"语义眼镜"——让你用人类能理解的方式告诉计算机如何操作多维数组。这个基于爱因斯坦求和约定的库&#xff0c;用类似自然语言的表达式替代了晦涩的API调用&#xff0c;彻底改变了深度学习工程…...

破解路内监管盲区:免布线低位视频桩重塑停车管理新标准

城市路内停车管理常因行道树遮挡、高位设备盲区等问题&#xff0c;导致车牌识别率低、逃费率高&#xff0c;传统模式在复杂路段束手无策。免布线低位视频桩凭借超低视角部署与智能算法&#xff0c;正成为破局关键。该设备安装于车位侧方0.5-0.7米高度&#xff0c;直接规避树枝遮…...