当前位置: 首页 > news >正文

《动手学深度学习 Pytorch版》 7.5 批量规范化

7.5.1 训练深层网络

训练神经网络的实际问题:

  • 数据预处理的方式会对最终结果产生巨大影响。

  • 训练时,多层感知机的中间层变量可能具有更广的变化范围。

  • 更深层的网络很复杂容易过拟合。

批量规范化对小批量的大小有要求,只有批量大小足够大时批量规范化才是有效的。

x ∈ B \boldsymbol{x}\in B xB 表示一个来自小批量 B B B 的输入;$\hat{\boldsymbol{\mu}}_B $ 表示小批量 B B B 的样本均值; σ ^ B \hat{\boldsymbol{\sigma}}_B σ^B 表示小批量 B B B 的样本标准差;批量规范化 BN 根据以下表达式转换 x \boldsymbol{x} x

B N ( x ) = γ ⊙ x + μ ^ B σ ^ B + β BN(\boldsymbol{x})=\gamma\odot\frac{\boldsymbol{x}+\hat{\boldsymbol{\mu}}_B}{\hat{\boldsymbol{\sigma}}_B}+\beta BN(x)=γσ^Bx+μ^B+β

应用标准化后生成的小批量的均值为 0,单位方差为 1。此外,其中还包含与 x \boldsymbol{x} x 形状相同的拉伸参数 γ \gamma γ 和偏移参数 β \beta β。需要注意的是, γ \gamma γ β \beta β 是需要与其他模型一起参与学习的参数。

从形式上看,可以计算出上式中的 $\hat{\boldsymbol{\mu}}_B $ 和 σ ^ B \hat{\boldsymbol{\sigma}}_B σ^B

μ ^ B = 1 ∣ B ∣ ∑ x ∈ B x σ ^ B = 1 ∣ B ∣ ∑ x ∈ B ( x − μ ^ B ) 2 + ϵ \begin{align} \hat{\boldsymbol{\mu}}_B &= \frac{1}{\left|B\right|}\sum_{\boldsymbol{x}\in B}\boldsymbol{x}\\ \hat{\boldsymbol{\sigma}}_B &= \frac{1}{\left|B\right|}\sum_{\boldsymbol{x}\in B}(\boldsymbol{x}-\hat{\boldsymbol{\mu}}_B)^2+\epsilon \end{align} μ^Bσ^B=B1xBx=B1xB(xμ^B)2+ϵ

式中添加的大于零的常量 ϵ \epsilon ϵ 可以保证不会发生除数为零的错误。

7.5.2 批量规范化层

全连接层和卷积层需要两种略有不同的批量规范化策略:

  • 全连接层

    通常,我们将批量规范化层置于全连接层中的仿射变换和激活函数之间。 设全连接层的输入为 x x x,权重参数和偏置参数分别为 W \boldsymbol{W} W b b b,激活函数为 ϕ \phi ϕ,批量规范化的运算符为 B N BN BN。那么,使用批量规范化的全连接层的输出的计算详情如下:

    h = ϕ ( B N ( W x + b ) ) \boldsymbol{h}=\phi(BN(\boldsymbol{W}x+b)) h=ϕ(BN(Wx+b))

  • 卷积层

    对于卷积层,可以在卷积层之后和非线性激活函数之前应用批量规范化。而且需要对多个输出通道中的每个输出执行批量规范化,每个通道都有自己的标量参数:拉伸和偏移参数。

  • 预测过程中的批量规范化

    批量规范化在训练模式和预测模式下的行为通常不同。

7.5.3 从零实现

import torch
from torch import nn
from d2l import torch as d2l
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):if not torch.is_grad_enabled():  # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:assert len(X.shape) in (2, 4)if len(X.shape) == 2:  # 使用全连接层的情况,计算特征维上的均值和方差mean = X.mean(dim=0)  # 按行求均值var = ((X - mean) ** 2).mean(dim=0)  # 按行求方差else:  # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。mean = X.mean(dim=(0, 2, 3), keepdim=True)  # 保持X的形状(即第1维,输出通道数)以便后面可以做广播运算,结果的形状是1*n*1*1var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)X_hat = (X - mean) / torch.sqrt(var + eps)  # 训练模式下,用当前的均值和方差做标准化# 更新移动平均的均值和方差moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * varY = gamma * X_hat + beta  # 缩放和移位return Y, moving_mean.data, moving_var.data
class BatchNorm(nn.Module):# num_features:完全连接层的输出数量或卷积层的输出通道数。# num_dims:2表示完全连接层,4表示卷积层def __init__(self, num_features, num_dims):super().__init__()if num_dims == 2:shape = (1, num_features)else:shape = (1, num_features, 1, 1)# 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))# 非模型参数的变量初始化为0和1self.moving_mean = torch.zeros(shape)self.moving_var = torch.ones(shape)def forward(self, X):# 如果X不在内存上,将moving_mean和moving_var# 复制到X所在显存上if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)# 保存更新过的moving_mean和moving_varY, self.moving_mean, self.moving_var = batch_norm(X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9)return Y

7.5.4 使用批量规范化层的 LeNet

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),nn.Linear(84, 10))

学习率拉的好大。

lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.262, train acc 0.902, test acc 0.879
20495.7 examples/sec on cuda:0

在这里插入图片描述

7.5.5 简明实现

net1 = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),nn.Linear(84, 10))
d2l.train_ch6(net1, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.263, train acc 0.903, test acc 0.870
36208.4 examples/sec on cuda:0

在这里插入图片描述

7.5.6 争议

这个东西就是玄学,有效但是不知大为什么有效。作者给出的解释是“减少内部协变量偏移”,但是也是处于直觉而不是证明。

练习

(1)在使用批量规范化之前,我们是否可以从全连接层或者卷积层中删除偏置函数?为什么?

我认为可以,偏置会在减去均值时消去,此外,BN 中也是带偏移参数的。


(2)比较 LeNet 在使用和不使用批量规范化情况下的学习率。

a. 绘制训练和测试精准度的提高。b. 学习率有多高?

学习率相同的话,使用批量规范化的收敛速度会非常快。


(3)我们是否需要在每个层中进行批量规范化?尝试一下?

net2 = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(256, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net2, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.349, train acc 0.871, test acc 0.856
37741.0 examples/sec on cuda:0

在这里插入图片描述

去掉后面两个之后曲线稳多了。


(4)可以通过批量规范化来替换暂退法吗?行为会如何改变?

看来还是批量规范化好些

net3 = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(256, 120), nn.Sigmoid(),nn.Dropout(p=0.1),nn.Linear(120, 84), nn.Sigmoid(),nn.Dropout(p=0.1),nn.Linear(84, 10))lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net3, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.541, train acc 0.790, test acc 0.748
40642.6 examples/sec on cuda:0

在这里插入图片描述


(5) 确定参数 gamma 和 beta,并观察和分析结果。

net[1].gamma.reshape((-1,)), net[1].beta.reshape((-1,))
(tensor([3.1800, 1.6709, 4.0375, 3.4801, 2.6182, 2.3103], device='cuda:0',grad_fn=<ReshapeAliasBackward0>),tensor([ 3.5415,  1.6295,  1.8926, -1.5510, -2.4556,  1.1020], device='cuda:0',grad_fn=<ReshapeAliasBackward0>))

(6)查看高级 API 中关于 BatchNorm 的在线文档,以了解其他批量规范化的应用。

略。


(7)研究思路:可以应用的其他“规范化”变换有哪些,可以应用概率积分变换吗,全秩协方差估计呢?

略。

相关文章:

《动手学深度学习 Pytorch版》 7.5 批量规范化

7.5.1 训练深层网络 训练神经网络的实际问题&#xff1a; 数据预处理的方式会对最终结果产生巨大影响。 训练时&#xff0c;多层感知机的中间层变量可能具有更广的变化范围。 更深层的网络很复杂容易过拟合。 批量规范化对小批量的大小有要求&#xff0c;只有批量大小足够…...

Toaster - Android 吐司框架,专治 Toast 各种疑难杂症

官网 https://github.com/getActivity/Toaster 这可能是性能优、使用简单&#xff0c;支持自定义&#xff0c;不需要通知栏权限的吐司 想了解实现原理的可以点击此链接查看&#xff1a;Toaster 源码 集成步骤 如果你的项目 Gradle 配置是在 7.0 以下&#xff0c;需要在 bui…...

2023年9月26日,历史上的今天大事件早读

1620年9月26日大明皇帝朱常洛驾崩 1815年9月26日俄、普、奥三国在巴黎发表缔结“神圣同盟” 1841年9月26日清代思想家、诗人龚自珍逝世 1849年9月26日“生理学之父”巴甫洛夫诞生 1909年9月26日云南陆军讲武堂创办 1953年9月26日画家徐悲鸿逝世 1980年9月26日国际宇航联合…...

CListCtrl控件为只显示一列,持滚动显示其他,不用SetScrollFlags

CListCtrl控件为只显示一列,持滚动显示其他,不用SetScrollFlags 2023/9/5 下午4:52:58 如果您不希望使用 SetScrollFlags 函数来设置滚动条样式,可以使用以下代码将 CListCtrl 控件设置为只显示一列,并支持滚动显示其他内容: cpp // 设置控件样式和属性 m_listCtrl.Se…...

spring博客实现分页查询

1、首先创建dto下的分页类PageBean package com.zzz.blog.dto;import java.util.List;public class PageBean {private Integer pageSize; //页面大小private Integer currentPage; //当前页private Integer totalCount; //总条数private Integer totalPage; //总页数private …...

代码阅读分析神器-Scitools Understand

这里写目录标题 前言概要功能介绍1.代码统计2.图形化分析3.代码检查 使用方法下载及使用 前言 作为一名程序员&#xff0c;阅读代码是一个必须要拥有的能力&#xff0c;但无奈很多代码逻辑嵌套非常多&#xff0c;看起来非常吃力&#xff0c;看了那段逻辑就忘记了刚才的逻辑&am…...

学霸吐血整理‼《2023 年 IC 验证岗面试真题解析》宝藏干货!

Q1.定宽数组、动态数组、关联数组、队列各自的特点和使用方式。 Q2.fork…join/fork…join_any/fork…join_none 之间的异同 Q3.mailbox、event、semaphore 之间的异同 Q4.(event_handle)和 wait(event_handle.triggered)区别 Q5.task 和 function 异同区别 Q6.使用 clocking b…...

稳定性、可靠性、可用性、灵活性、解耦性

稳定性 平衡的能力 Linux系统的OOM机制、tcp的拥塞控制 可靠性 确定的能力 tcp的ACK、HA机制、加密 可用性 复原的能力 负债均衡、tcp的重传、冗余机制、故障域 灵活性 界限的能力 用户态、restful api、IP地址掩码 解耦性 不依赖的能力 分布式、SDN、容器、操作…...

docker搭建Redis三主三从

docker搭建Redis三主三从 首先启动6个redis进入容器构建主从关系连接进入6381作为切入点&#xff0c;查看集群状态 首先启动6个redis [rootdocker redis-node-1]# cat /etc/hosts 127.0.0.1 localhost localhost.localdomain localhost4 localhost4.localdomain4 ::1 …...

亚马逊要求的UL报告的产品标准是什么?如何区分

亚马逊为什么要求电子产品有UL检测报告&#xff1f; 首先&#xff0c;美国是一个对安全要求非常严格的国家&#xff0c;美国本土的所有电子产品生产企业早在很多年前就要求有相关安规检测。 其次&#xff0c;随着亚马逊在全球商业的战略地位不断提高&#xff0c;境外的电子设…...

如何在linux定时备份opengauss数据库(linux核心至少在GLIBC_2.34及以上)

前提环境&#xff0c;linux的核心至少在GLIBC_2.34及以上才能使用。 查看linux的glibc版本的命令如下 strings /lib64/libc.so.6 | grep GLIBC 如下图 或者用ldd --version 如下图 在官网下载对应的依赖包&#xff0c; 只需要这个lib文件即可&#xff0c;将这个包放在lin…...

SkyWalking快速上手(七)——Skywalking UI 界面简介

文章目录 前言1. 仪表盘1.1 指标展示1.2 自定义仪表盘 2. 拓扑图2.1 节点展示2.2 连接展示 3. 追踪3.1 请求链路3.2 请求详情 4. 性能剖析4.1 方法级别性能分析4.2 代码级别性能分析 5. 告警5.1 告警规则设置5.2 告警通知 6. 日志记录6.1 日志展示6.2日志分析6.3代码示例 总结 …...

python+vue驾校驾驶理论考试模拟系统

管理员的主要功能有&#xff1a; 1.管理员输入账户登陆后台 2.个人中心&#xff1a;管理员修改密码和账户信息 3.用户管理&#xff1a;管理员可以对用户信息进行添加&#xff0c;修改&#xff0c;删除&#xff0c;查询 4.添加选择题&#xff1a;管理员可以添加选择题目&#xf…...

go-redis 框架基本使用

文章目录 redis使用场景下载框架和连接redis1. 安装go-redis2. 连接redis 字符串操作有序集合操作流水线事务1. 普通事务2. Watch redis使用场景 缓存系统&#xff0c;减轻主数据库&#xff08;MySQL&#xff09;的压力。计数场景&#xff0c;比如微博、抖音中的关注数和粉丝数…...

java内嵌浏览器CEF-JAVA、jcef、java chrome

java内嵌浏览器CEF-JAVA、jcef、java chrome jcef是老牌cef的chrome内嵌方案&#xff0c;可以进行java-chrome-h5-桌面开发&#xff0c;下面为最新版本&#xff08;2023年9月22日10:33:07&#xff09; JCEF&#xff08;Java Chromium Embedded Framework&#xff09;是一个基于…...

string类模拟实现——C++

一、构造与析构 1.构造函数 构造函数需要尽可能将成员在初始化列表中初始化&#xff0c;string类的成员这里自定义的和顺序表相似&#xff0c;有_str , _size , _capacity , 以及一个静态成员 npos &#xff0c;构造函数这里实现两种&#xff0c;一种是传参为常量字符串的&am…...

在 SQL Server 中,可以使用加号运算符(+)来拼接字符串。但是,如果需要拼接多个字符串或表中的字段,就需要使用内置的拼接函数了

以下是 SQL Server 中的一些内置拼接函数&#xff1a; 1. CONCAT&#xff1a;将两个或多个字符串拼接在一起。语法为&#xff1a; CONCAT (string1, string2, ...)示例&#xff1a; SELECT CONCAT(Hello, , World) as combined_string;输出结果为&#xff1a;Hello World&a…...

蓝桥杯每日一题2023.9.25

4406. 积木画 - AcWing题库 题目描述 分析 在完成此问题前可以先引入一个新的问题 291. 蒙德里安的梦想 - AcWing题库 我们发现16的二进制是 10000 15的二进制是1111 故刚好我们可以从0枚举到1 << n(相当于二的n次方的二进制表示&#xff09; 注&#xff1a;奇数个0…...

前端面试的话术集锦第 20 篇博文——高频考点(输入 URL 到页面渲染的整个流程)

这是记录前端面试的话术集锦第二十篇博文——高频考点(输入 URL 到页面渲染的整个流程),我会不断更新该博文。❗❗❗ 借用这道经典面试题,将之前学习到的浏览器以及网络几章节的知识联系起来。 首先是DNS查询,如果这一步做了智能DNS解析的话,会提供访问速度最快的IP地址…...

Android Jetpack Compose之确定重组范围并优化重组

目录 1.概述2.确定Composable重组的范围3.优化重组的性能3.1 Composable 位置索引3.2 通过Key添加索引信息3.3 使用注解Stable优化重组 1.概述 前面的文章提到Compose的重组是智能的&#xff0c;Composable函数在进行重组时会尽可能的跳过不必要的重组&#xff0c;只对需要变化…...

进程地址空间(比特课总结)

一、进程地址空间 1. 环境变量 1 &#xff09;⽤户级环境变量与系统级环境变量 全局属性&#xff1a;环境变量具有全局属性&#xff0c;会被⼦进程继承。例如当bash启动⼦进程时&#xff0c;环 境变量会⾃动传递给⼦进程。 本地变量限制&#xff1a;本地变量只在当前进程(ba…...

vscode(仍待补充)

写于2025 6.9 主包将加入vscode这个更权威的圈子 vscode的基本使用 侧边栏 vscode还能连接ssh&#xff1f; debug时使用的launch文件 1.task.json {"tasks": [{"type": "cppbuild","label": "C/C: gcc.exe 生成活动文件"…...

【CSS position 属性】static、relative、fixed、absolute 、sticky详细介绍,多层嵌套定位示例

文章目录 ★ position 的五种类型及基本用法 ★ 一、position 属性概述 二、position 的五种类型详解(初学者版) 1. static(默认值) 2. relative(相对定位) 3. absolute(绝对定位) 4. fixed(固定定位) 5. sticky(粘性定位) 三、定位元素的层级关系(z-i…...

Frozen-Flask :将 Flask 应用“冻结”为静态文件

Frozen-Flask 是一个用于将 Flask 应用“冻结”为静态文件的 Python 扩展。它的核心用途是&#xff1a;将一个 Flask Web 应用生成成纯静态 HTML 文件&#xff0c;从而可以部署到静态网站托管服务上&#xff0c;如 GitHub Pages、Netlify 或任何支持静态文件的网站服务器。 &am…...

CRMEB 框架中 PHP 上传扩展开发:涵盖本地上传及阿里云 OSS、腾讯云 COS、七牛云

目前已有本地上传、阿里云OSS上传、腾讯云COS上传、七牛云上传扩展 扩展入口文件 文件目录 crmeb\services\upload\Upload.php namespace crmeb\services\upload;use crmeb\basic\BaseManager; use think\facade\Config;/*** Class Upload* package crmeb\services\upload* …...

【C++从零实现Json-Rpc框架】第六弹 —— 服务端模块划分

一、项目背景回顾 前五弹完成了Json-Rpc协议解析、请求处理、客户端调用等基础模块搭建。 本弹重点聚焦于服务端的模块划分与架构设计&#xff0c;提升代码结构的可维护性与扩展性。 二、服务端模块设计目标 高内聚低耦合&#xff1a;各模块职责清晰&#xff0c;便于独立开发…...

大学生职业发展与就业创业指导教学评价

这里是引用 作为软工2203/2204班的学生&#xff0c;我们非常感谢您在《大学生职业发展与就业创业指导》课程中的悉心教导。这门课程对我们即将面临实习和就业的工科学生来说至关重要&#xff0c;而您认真负责的教学态度&#xff0c;让课程的每一部分都充满了实用价值。 尤其让我…...

AGain DB和倍数增益的关系

我在设置一款索尼CMOS芯片时&#xff0c;Again增益0db变化为6DB&#xff0c;画面的变化只有2倍DN的增益&#xff0c;比如10变为20。 这与dB和线性增益的关系以及传感器处理流程有关。以下是具体原因分析&#xff1a; 1. dB与线性增益的换算关系 6dB对应的理论线性增益应为&…...

PAN/FPN

import torch import torch.nn as nn import torch.nn.functional as F import mathclass LowResQueryHighResKVAttention(nn.Module):"""方案 1: 低分辨率特征 (Query) 查询高分辨率特征 (Key, Value).输出分辨率与低分辨率输入相同。"""def __…...

Webpack性能优化:构建速度与体积优化策略

一、构建速度优化 1、​​升级Webpack和Node.js​​ ​​优化效果​​&#xff1a;Webpack 4比Webpack 3构建时间降低60%-98%。​​原因​​&#xff1a; V8引擎优化&#xff08;for of替代forEach、Map/Set替代Object&#xff09;。默认使用更快的md4哈希算法。AST直接从Loa…...