《动手学深度学习 Pytorch版》 7.5 批量规范化
7.5.1 训练深层网络
训练神经网络的实际问题:
-
数据预处理的方式会对最终结果产生巨大影响。
-
训练时,多层感知机的中间层变量可能具有更广的变化范围。
-
更深层的网络很复杂容易过拟合。
批量规范化对小批量的大小有要求,只有批量大小足够大时批量规范化才是有效的。
用 x ∈ B \boldsymbol{x}\in B x∈B 表示一个来自小批量 B B B 的输入;$\hat{\boldsymbol{\mu}}_B $ 表示小批量 B B B 的样本均值; σ ^ B \hat{\boldsymbol{\sigma}}_B σ^B 表示小批量 B B B 的样本标准差;批量规范化 BN 根据以下表达式转换 x \boldsymbol{x} x:
B N ( x ) = γ ⊙ x + μ ^ B σ ^ B + β BN(\boldsymbol{x})=\gamma\odot\frac{\boldsymbol{x}+\hat{\boldsymbol{\mu}}_B}{\hat{\boldsymbol{\sigma}}_B}+\beta BN(x)=γ⊙σ^Bx+μ^B+β
应用标准化后生成的小批量的均值为 0,单位方差为 1。此外,其中还包含与 x \boldsymbol{x} x 形状相同的拉伸参数 γ \gamma γ 和偏移参数 β \beta β。需要注意的是, γ \gamma γ 和 β \beta β 是需要与其他模型一起参与学习的参数。
从形式上看,可以计算出上式中的 $\hat{\boldsymbol{\mu}}_B $ 和 σ ^ B \hat{\boldsymbol{\sigma}}_B σ^B:
μ ^ B = 1 ∣ B ∣ ∑ x ∈ B x σ ^ B = 1 ∣ B ∣ ∑ x ∈ B ( x − μ ^ B ) 2 + ϵ \begin{align} \hat{\boldsymbol{\mu}}_B &= \frac{1}{\left|B\right|}\sum_{\boldsymbol{x}\in B}\boldsymbol{x}\\ \hat{\boldsymbol{\sigma}}_B &= \frac{1}{\left|B\right|}\sum_{\boldsymbol{x}\in B}(\boldsymbol{x}-\hat{\boldsymbol{\mu}}_B)^2+\epsilon \end{align} μ^Bσ^B=∣B∣1x∈B∑x=∣B∣1x∈B∑(x−μ^B)2+ϵ
式中添加的大于零的常量 ϵ \epsilon ϵ 可以保证不会发生除数为零的错误。
7.5.2 批量规范化层
全连接层和卷积层需要两种略有不同的批量规范化策略:
-
全连接层
通常,我们将批量规范化层置于全连接层中的仿射变换和激活函数之间。 设全连接层的输入为 x x x,权重参数和偏置参数分别为 W \boldsymbol{W} W 和 b b b,激活函数为 ϕ \phi ϕ,批量规范化的运算符为 B N BN BN。那么,使用批量规范化的全连接层的输出的计算详情如下:
h = ϕ ( B N ( W x + b ) ) \boldsymbol{h}=\phi(BN(\boldsymbol{W}x+b)) h=ϕ(BN(Wx+b))
-
卷积层
对于卷积层,可以在卷积层之后和非线性激活函数之前应用批量规范化。而且需要对多个输出通道中的每个输出执行批量规范化,每个通道都有自己的标量参数:拉伸和偏移参数。
-
预测过程中的批量规范化
批量规范化在训练模式和预测模式下的行为通常不同。
7.5.3 从零实现
import torch
from torch import nn
from d2l import torch as d2l
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):if not torch.is_grad_enabled(): # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:assert len(X.shape) in (2, 4)if len(X.shape) == 2: # 使用全连接层的情况,计算特征维上的均值和方差mean = X.mean(dim=0) # 按行求均值var = ((X - mean) ** 2).mean(dim=0) # 按行求方差else: # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。mean = X.mean(dim=(0, 2, 3), keepdim=True) # 保持X的形状(即第1维,输出通道数)以便后面可以做广播运算,结果的形状是1*n*1*1var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)X_hat = (X - mean) / torch.sqrt(var + eps) # 训练模式下,用当前的均值和方差做标准化# 更新移动平均的均值和方差moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * varY = gamma * X_hat + beta # 缩放和移位return Y, moving_mean.data, moving_var.data
class BatchNorm(nn.Module):# num_features:完全连接层的输出数量或卷积层的输出通道数。# num_dims:2表示完全连接层,4表示卷积层def __init__(self, num_features, num_dims):super().__init__()if num_dims == 2:shape = (1, num_features)else:shape = (1, num_features, 1, 1)# 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))# 非模型参数的变量初始化为0和1self.moving_mean = torch.zeros(shape)self.moving_var = torch.ones(shape)def forward(self, X):# 如果X不在内存上,将moving_mean和moving_var# 复制到X所在显存上if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)# 保存更新过的moving_mean和moving_varY, self.moving_mean, self.moving_var = batch_norm(X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9)return Y
7.5.4 使用批量规范化层的 LeNet
net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),nn.Linear(84, 10))
学习率拉的好大。
lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.262, train acc 0.902, test acc 0.879
20495.7 examples/sec on cuda:0

7.5.5 简明实现
net1 = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),nn.Linear(84, 10))
d2l.train_ch6(net1, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.263, train acc 0.903, test acc 0.870
36208.4 examples/sec on cuda:0

7.5.6 争议
这个东西就是玄学,有效但是不知大为什么有效。作者给出的解释是“减少内部协变量偏移”,但是也是处于直觉而不是证明。
练习
(1)在使用批量规范化之前,我们是否可以从全连接层或者卷积层中删除偏置函数?为什么?
我认为可以,偏置会在减去均值时消去,此外,BN 中也是带偏移参数的。
(2)比较 LeNet 在使用和不使用批量规范化情况下的学习率。
a. 绘制训练和测试精准度的提高。b. 学习率有多高?
学习率相同的话,使用批量规范化的收敛速度会非常快。
(3)我们是否需要在每个层中进行批量规范化?尝试一下?
net2 = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(256, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net2, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.349, train acc 0.871, test acc 0.856
37741.0 examples/sec on cuda:0

去掉后面两个之后曲线稳多了。
(4)可以通过批量规范化来替换暂退法吗?行为会如何改变?
看来还是批量规范化好些
net3 = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(256, 120), nn.Sigmoid(),nn.Dropout(p=0.1),nn.Linear(120, 84), nn.Sigmoid(),nn.Dropout(p=0.1),nn.Linear(84, 10))lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net3, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.541, train acc 0.790, test acc 0.748
40642.6 examples/sec on cuda:0

(5) 确定参数 gamma 和 beta,并观察和分析结果。
net[1].gamma.reshape((-1,)), net[1].beta.reshape((-1,))
(tensor([3.1800, 1.6709, 4.0375, 3.4801, 2.6182, 2.3103], device='cuda:0',grad_fn=<ReshapeAliasBackward0>),tensor([ 3.5415, 1.6295, 1.8926, -1.5510, -2.4556, 1.1020], device='cuda:0',grad_fn=<ReshapeAliasBackward0>))
(6)查看高级 API 中关于 BatchNorm 的在线文档,以了解其他批量规范化的应用。
略。
(7)研究思路:可以应用的其他“规范化”变换有哪些,可以应用概率积分变换吗,全秩协方差估计呢?
略。
相关文章:
《动手学深度学习 Pytorch版》 7.5 批量规范化
7.5.1 训练深层网络 训练神经网络的实际问题: 数据预处理的方式会对最终结果产生巨大影响。 训练时,多层感知机的中间层变量可能具有更广的变化范围。 更深层的网络很复杂容易过拟合。 批量规范化对小批量的大小有要求,只有批量大小足够…...
Toaster - Android 吐司框架,专治 Toast 各种疑难杂症
官网 https://github.com/getActivity/Toaster 这可能是性能优、使用简单,支持自定义,不需要通知栏权限的吐司 想了解实现原理的可以点击此链接查看:Toaster 源码 集成步骤 如果你的项目 Gradle 配置是在 7.0 以下,需要在 bui…...
2023年9月26日,历史上的今天大事件早读
1620年9月26日大明皇帝朱常洛驾崩 1815年9月26日俄、普、奥三国在巴黎发表缔结“神圣同盟” 1841年9月26日清代思想家、诗人龚自珍逝世 1849年9月26日“生理学之父”巴甫洛夫诞生 1909年9月26日云南陆军讲武堂创办 1953年9月26日画家徐悲鸿逝世 1980年9月26日国际宇航联合…...
CListCtrl控件为只显示一列,持滚动显示其他,不用SetScrollFlags
CListCtrl控件为只显示一列,持滚动显示其他,不用SetScrollFlags 2023/9/5 下午4:52:58 如果您不希望使用 SetScrollFlags 函数来设置滚动条样式,可以使用以下代码将 CListCtrl 控件设置为只显示一列,并支持滚动显示其他内容: cpp // 设置控件样式和属性 m_listCtrl.Se…...
spring博客实现分页查询
1、首先创建dto下的分页类PageBean package com.zzz.blog.dto;import java.util.List;public class PageBean {private Integer pageSize; //页面大小private Integer currentPage; //当前页private Integer totalCount; //总条数private Integer totalPage; //总页数private …...
代码阅读分析神器-Scitools Understand
这里写目录标题 前言概要功能介绍1.代码统计2.图形化分析3.代码检查 使用方法下载及使用 前言 作为一名程序员,阅读代码是一个必须要拥有的能力,但无奈很多代码逻辑嵌套非常多,看起来非常吃力,看了那段逻辑就忘记了刚才的逻辑&am…...
学霸吐血整理‼《2023 年 IC 验证岗面试真题解析》宝藏干货!
Q1.定宽数组、动态数组、关联数组、队列各自的特点和使用方式。 Q2.fork…join/fork…join_any/fork…join_none 之间的异同 Q3.mailbox、event、semaphore 之间的异同 Q4.(event_handle)和 wait(event_handle.triggered)区别 Q5.task 和 function 异同区别 Q6.使用 clocking b…...
稳定性、可靠性、可用性、灵活性、解耦性
稳定性 平衡的能力 Linux系统的OOM机制、tcp的拥塞控制 可靠性 确定的能力 tcp的ACK、HA机制、加密 可用性 复原的能力 负债均衡、tcp的重传、冗余机制、故障域 灵活性 界限的能力 用户态、restful api、IP地址掩码 解耦性 不依赖的能力 分布式、SDN、容器、操作…...
docker搭建Redis三主三从
docker搭建Redis三主三从 首先启动6个redis进入容器构建主从关系连接进入6381作为切入点,查看集群状态 首先启动6个redis [rootdocker redis-node-1]# cat /etc/hosts 127.0.0.1 localhost localhost.localdomain localhost4 localhost4.localdomain4 ::1 …...
亚马逊要求的UL报告的产品标准是什么?如何区分
亚马逊为什么要求电子产品有UL检测报告? 首先,美国是一个对安全要求非常严格的国家,美国本土的所有电子产品生产企业早在很多年前就要求有相关安规检测。 其次,随着亚马逊在全球商业的战略地位不断提高,境外的电子设…...
如何在linux定时备份opengauss数据库(linux核心至少在GLIBC_2.34及以上)
前提环境,linux的核心至少在GLIBC_2.34及以上才能使用。 查看linux的glibc版本的命令如下 strings /lib64/libc.so.6 | grep GLIBC 如下图 或者用ldd --version 如下图 在官网下载对应的依赖包, 只需要这个lib文件即可,将这个包放在lin…...
SkyWalking快速上手(七)——Skywalking UI 界面简介
文章目录 前言1. 仪表盘1.1 指标展示1.2 自定义仪表盘 2. 拓扑图2.1 节点展示2.2 连接展示 3. 追踪3.1 请求链路3.2 请求详情 4. 性能剖析4.1 方法级别性能分析4.2 代码级别性能分析 5. 告警5.1 告警规则设置5.2 告警通知 6. 日志记录6.1 日志展示6.2日志分析6.3代码示例 总结 …...
python+vue驾校驾驶理论考试模拟系统
管理员的主要功能有: 1.管理员输入账户登陆后台 2.个人中心:管理员修改密码和账户信息 3.用户管理:管理员可以对用户信息进行添加,修改,删除,查询 4.添加选择题:管理员可以添加选择题目…...
go-redis 框架基本使用
文章目录 redis使用场景下载框架和连接redis1. 安装go-redis2. 连接redis 字符串操作有序集合操作流水线事务1. 普通事务2. Watch redis使用场景 缓存系统,减轻主数据库(MySQL)的压力。计数场景,比如微博、抖音中的关注数和粉丝数…...
java内嵌浏览器CEF-JAVA、jcef、java chrome
java内嵌浏览器CEF-JAVA、jcef、java chrome jcef是老牌cef的chrome内嵌方案,可以进行java-chrome-h5-桌面开发,下面为最新版本(2023年9月22日10:33:07) JCEF(Java Chromium Embedded Framework)是一个基于…...
string类模拟实现——C++
一、构造与析构 1.构造函数 构造函数需要尽可能将成员在初始化列表中初始化,string类的成员这里自定义的和顺序表相似,有_str , _size , _capacity , 以及一个静态成员 npos ,构造函数这里实现两种,一种是传参为常量字符串的&am…...
在 SQL Server 中,可以使用加号运算符(+)来拼接字符串。但是,如果需要拼接多个字符串或表中的字段,就需要使用内置的拼接函数了
以下是 SQL Server 中的一些内置拼接函数: 1. CONCAT:将两个或多个字符串拼接在一起。语法为: CONCAT (string1, string2, ...)示例: SELECT CONCAT(Hello, , World) as combined_string;输出结果为:Hello World&a…...
蓝桥杯每日一题2023.9.25
4406. 积木画 - AcWing题库 题目描述 分析 在完成此问题前可以先引入一个新的问题 291. 蒙德里安的梦想 - AcWing题库 我们发现16的二进制是 10000 15的二进制是1111 故刚好我们可以从0枚举到1 << n(相当于二的n次方的二进制表示) 注:奇数个0…...
前端面试的话术集锦第 20 篇博文——高频考点(输入 URL 到页面渲染的整个流程)
这是记录前端面试的话术集锦第二十篇博文——高频考点(输入 URL 到页面渲染的整个流程),我会不断更新该博文。❗❗❗ 借用这道经典面试题,将之前学习到的浏览器以及网络几章节的知识联系起来。 首先是DNS查询,如果这一步做了智能DNS解析的话,会提供访问速度最快的IP地址…...
Android Jetpack Compose之确定重组范围并优化重组
目录 1.概述2.确定Composable重组的范围3.优化重组的性能3.1 Composable 位置索引3.2 通过Key添加索引信息3.3 使用注解Stable优化重组 1.概述 前面的文章提到Compose的重组是智能的,Composable函数在进行重组时会尽可能的跳过不必要的重组,只对需要变化…...
Graphormer实战教程:基于ogb库加载PCQM4M数据微调模型示例
Graphormer实战教程:基于ogb库加载PCQM4M数据微调模型示例 1. 引言 Graphormer是一种创新的分子属性预测模型,采用纯Transformer架构的图神经网络设计。它专门针对分子图(原子-键结构)的全局结构建模与属性预测任务,…...
告别台式机没麦克风的尴尬:用SonoBus+VB-Cable把手机秒变无线麦(保姆级配置)
台式机零成本无线麦克风方案:SonoBus与VB-Cable实战指南 你是否遇到过这样的尴尬时刻——台式电脑突然需要语音沟通,却发现没有麦克风?无论是紧急会议、游戏开黑还是直播互动,这种硬件缺失带来的困扰可能让你措手不及。本文将介绍…...
Super IO:提升Blender批量处理效率的自动化流程解决方案
Super IO:提升Blender批量处理效率的自动化流程解决方案 【免费下载链接】super_io blender addon for copy paste import / export 项目地址: https://gitcode.com/gh_mirrors/su/super_io 在3D设计工作流中,设计师常常面临文件格式转换繁琐、跨…...
springboot+vue基于web的校园电动车短租系统的设计系统
目录同行可拿货,招校园代理 ,本人源头供货商系统功能分析用户管理模块车辆管理模块租赁业务模块安全与风控模块统计与报表模块技术实现要点项目技术支持源码获取详细视频演示 :文章底部获取博主联系方式!同行可合作同行可拿货,招校园代理 ,本人源头供货商…...
基于高斯过程回归的MATLAB时间序列区间预测代码实现与解析
基于高斯过程回归(GPR)的时间序列区间预测 GPR时间序列区间预测 matlab代码 暂无Matlab版本要求 -- 推荐 2018B 版本及以上做时间序列最烦的就是拍脑袋给个“明天涨3%左右”——“左右”到底是正负0.5还是正负3?如果是风电发电的负荷申报,正负差多了要罚…...
React项目实战:用XGPlayer打造带封面预览的沉浸式视频播放组件(附完整代码)
React项目实战:用XGPlayer打造带封面预览的沉浸式视频播放组件(附完整代码) 在当今内容为王的时代,视频已成为Web应用中不可或缺的元素。但如何让视频组件既美观又高效,同时提供流畅的用户体验?本文将带你深…...
高效解决E-Hentai图库下载难题:实用下载工具全攻略
高效解决E-Hentai图库下载难题:实用下载工具全攻略 【免费下载链接】E-Hentai-Downloader Download E-Hentai archive as zip file 项目地址: https://gitcode.com/gh_mirrors/eh/E-Hentai-Downloader 在数字资源管理领域,E-Hentai作为知名的漫画…...
7个关键步骤:用Meshroom实现高精度三维重建的完整指南
7个关键步骤:用Meshroom实现高精度三维重建的完整指南 【免费下载链接】Meshroom 3D Reconstruction Software 项目地址: https://gitcode.com/gh_mirrors/me/Meshroom 开源三维重建工具Meshroom凭借摄影测量实战技术,为用户提供了从二维图像到点…...
YimMenu完全指南:GTA5免费辅助工具从入门到精通
YimMenu完全指南:GTA5免费辅助工具从入门到精通 【免费下载链接】YimMenu YimMenu, a GTA V menu protecting against a wide ranges of the public crashes and improving the overall experience. 项目地址: https://gitcode.com/GitHub_Trending/yi/YimMenu …...
hgproxy4.0.35.0之前版本数据库连接卡在parse状态
文章目录环境症状问题原因解决方案环境 系统平台:Linux x86-64 Red Hat Enterprise Linux 7 版本:4.5.10 症状 查询数据库后台连接进程,发现主备节点均有超过几天的长连接,且状态卡在parse状态 问题原因 数据库会话sql出错后…...
