当前位置: 首页 > news >正文

【深度学习笔记】09 权重衰减

09 权重衰减

    • 范数和权重衰减
    • 利用高维线性回归实现权重衰减
    • 权重衰减的简洁实现

范数和权重衰减

在训练参数化机器学习模型时,权重衰减(decay weight)是最广泛应用的正则化技术之一,它通常也被称为 L 2 L_2 L2正则化。这项技术通过函数与零的距离来衡量函数的复杂度,
因为在所有函数 f f f中,函数 f = 0 f = 0 f=0(所有输入都得到值 0 0 0
在某种意义上是最简单的。

一种简单的方法是通过线性函数
f ( x ) = w ⊤ x f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x} f(x)=wx
中的权重向量的某个范数来度量其复杂性,
例如 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2
要保证权重向量比较小,
最常用方法是将其范数作为惩罚项加到最小化损失的问题中。
将原来的训练目标最小化训练标签上的预测损失,
调整为最小化预测损失和惩罚项之和。

损失由下式给出:

L ( w , b ) = 1 n ∑ i = 1 n 1 2 ( w ⊤ x ( i ) + b − y ( i ) ) 2 . L(\mathbf{w}, b) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right)^2. L(w,b)=n1i=1n21(wx(i)+by(i))2.

x ( i ) \mathbf{x}^{(i)} x(i)是样本 i i i的特征,
y ( i ) y^{(i)} y(i)是样本 i i i的标签,
( w , b ) (\mathbf{w}, b) (w,b)是权重和偏置参数。

为了惩罚权重向量的大小,
必须以某种方式在损失函数中添加 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2
我们通过正则化常数 λ \lambda λ来描述这种权衡,
这是一个非负超参数,我们使用验证数据拟合:

L ( w , b ) + λ 2 ∥ w ∥ 2 , L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2, L(w,b)+2λw2,

对于 λ = 0 \lambda = 0 λ=0,我们恢复了原来的损失函数。
对于 λ > 0 \lambda > 0 λ>0,我们限制 ∥ w ∥ \| \mathbf{w} \| w的大小。
这里我们仍然除以 2 2 2:当我们取一个二次函数的导数时,
2 2 2 1 / 2 1/2 1/2会抵消。

通过平方 L 2 L_2 L2范数,我们去掉平方根,留下权重向量每个分量的平方和。
这使得惩罚的导数很容易计算:导数的和等于和的导数。

L 2 L_2 L2正则化回归的小批量随机梯度下降更新如下式:

w ← ( 1 − η λ ) w − η ∣ B ∣ ∑ i ∈ B x ( i ) ( w ⊤ x ( i ) + b − y ( i ) ) . \begin{aligned} \mathbf{w} & \leftarrow \left(1- \eta\lambda \right) \mathbf{w} - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)} \left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right). \end{aligned} w(1ηλ)wBηiBx(i)(wx(i)+by(i)).

我们根据估计值与观测值之间的差异来更新 w \mathbf{w} w
然而,我们同时也在试图将 w \mathbf{w} w的大小缩小到零。
这就是为什么这种方法有时被称为权重衰减
我们仅考虑惩罚项,优化算法在训练的每一步衰减权重。
与特征选择相比,权重衰减为我们提供了一种连续的机制来调整函数的复杂度。
较小的 λ \lambda λ值对应较少约束的 w \mathbf{w} w
而较大的 λ \lambda λ值对 w \mathbf{w} w的约束更大。

是否对相应的偏置 b 2 b^2 b2进行惩罚在不同的实践中会有所不同,
在神经网络的不同层中也会有所不同。
通常,网络输出层的偏置项不会被正则化。

利用高维线性回归实现权重衰减

%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l

首先生成数据,生成公式如下:

y = 0.05 + ∑ i = 1 d 0.01 x i + ϵ where  ϵ ∼ N ( 0 , 0.0 1 2 ) . y = 0.05 + \sum_{i = 1}^d 0.01 x_i + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.01^2). y=0.05+i=1d0.01xi+ϵ where ϵN(0,0.012).

选择标签是关于输入的线性函数。
标签同时被均值为0,标准差为0.01高斯噪声破坏。
为了使过拟合的效果更加明显,我们可以将问题的维数增加到 d = 200 d = 200 d=200
并使用一个只包含20个样本的小训练集。

n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

初始化模型参数

定义一个函数来随机初始化模型参数

def init_params():w = torch.normal(0, 1, size = (num_inputs, 1), requires_grad = True)b = torch.zeros(1, requires_grad = True)return [w, b]

定义 L 2 L_2 L2范数惩罚

def l2_penalty(w):return torch.sum(w.pow(2)) / 2

定义训练代码实现

下面的代码将模型拟合训练数据集,并在测试数据集上进行评估。

def train(lambd):w, b = init_params()net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_lossnum_epochs, lr = 100, 0.003animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:# 增加了L2范数惩罚项,# 广播机制使l2_penalty(w)成为一个长度为batch_size的向量l = loss(net(X), y) + lambd * l2_penalty(w)l.sum().backward()d2l.sgd([w, b], lr, batch_size)if (epoch + 1) % 5 == 0:animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数是:', torch.norm(w).item())

忽略正则化直接训练

用lamdb=0禁用权重衰减后运行代码。此时训练误差有所减少,但测试误差没有减少,这意味着出现了严重的过拟合。

train(lambd = 0)
w的L2范数是: 14.971677780151367

在这里插入图片描述

使用权重衰减

使用权重衰减来运行代码。此时训练误差增大,但测试误差减小。这正是我们期望从正则化中得到的效果。

train(lambd = 3)
w的L2范数是: 0.34405317902565

在这里插入图片描述

权重衰减的简洁实现

在实例化优化器时直接通过weight_decay指定weight decay超参数。默认情况下,PyTorch同时衰减权重和便宜。这里只为权重设置了weight_decay,所以偏置参数 b b b不会衰减。

def train_concise(wd):net = nn.Sequential(nn.Linear(num_inputs, 1))for param in net.parameters():param.data.normal_()loss = nn.MSELoss(reduction='none')num_epochs, lr = 100, 0.003# 偏置参数没有衰减trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd},{"params":net[0].bias}], lr=lr)animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:trainer.zero_grad()l = loss(net(X), y)l.mean().backward()trainer.step()if (epoch + 1) % 5 == 0:animator.add(epoch + 1,(d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数:', net[0].weight.norm().item())
train_concise(0)
w的L2范数: 13.416662216186523

在这里插入图片描述

train_concise(3)
w的L2范数: 0.39273694157600403

在这里插入图片描述

相关文章:

【深度学习笔记】09 权重衰减

09 权重衰减 范数和权重衰减利用高维线性回归实现权重衰减初始化模型参数定义 L 2 L_2 L2​范数惩罚定义训练代码实现忽略正则化直接训练使用权重衰减 权重衰减的简洁实现 范数和权重衰减 在训练参数化机器学习模型时,权重衰减(decay weight&#xff09…...

三大兼容 | 人大金仓兼容+优化MySQL用户变量特性

目前,KingbaseES对MySQL的兼容性,已从功能兼容阶段过渡到强性能兼容、生态全面兼容阶段,针对客户常常遇到的用户变量问题,KingbaseES在兼容MySQL用户变量功能的基础上,优化了MySQL用户变量的一些原生问题,使…...

Git介绍与安装使用

目录 1.Git初识 1.1提出问题 1.2如何解决--版本控制器 1.3注意事项 2.Git安装 2.1Linux-centos安装 2.2Linux-ubuntu安装 2.3Windows安装 3.Git基本操作 3.1创建Git本地仓库 3.2配置Git 4.认识⼯作区、暂存区、版本库 1.Git初识 1.1提出问题 不知道你工作或学习时…...

理解DuLinkList L中的“”引用符号

在C中,DuLinkList &L 这种形式的参数表示 L 是一个 DuLinkList 类型的引用。这里的 & 符号表示引用。 引用是C的一个特性,它提供了一种方式来访问已存在的变量的别名。当你对引用进行操作时,实际上是在操作它所引用的变量。如果你在…...

前端并发多个请求并失败重发

const MAX_RETRIES 3;// 模拟请求 function makeRequest(url) {return new Promise((resolve, reject) > {setTimeout(() > {Math.random() < 0.75 ? resolve(${url} 成功) : reject(${url} 失败); // 随机决定请求是否成功}, Math.random() * 2000); // 随机延时执…...

【Qt开发流程】之对象模型2:属性系统

描述 Qt提供了一个复杂的属性系统&#xff0c;类似于一些编译器供应商提供的属性系统。然而&#xff0c;作为一个独立于编译器和平台的库&#xff0c;Qt不依赖于非标准的编译器特性&#xff0c;如__property或[property]。 Qt解决方案适用于Qt支持的所有平台上的任何标准c编译…...

PHP之curl详细讲解

cURL&#xff08;全称为Client for URLs&#xff09;是一个功能强大的开源库&#xff0c;用于在多种协议上进行数据传输、发送HTTP请求和获取响应。它支持多种协议&#xff0c;包括HTTP、HTTPS、FTP、SMTP等&#xff0c;并且能够与各种服务器进行通信。 cURL库可以通过命令行工…...

R语言30分钟上手

文章目录 1. 环境&安装1.1. rstudio保存工作空间 2. 创建数据集2.1. 数据集概念2.2. 向量、矩阵2.3. 数据框2.3.1. 创建数据框2.3.2. 创建新变量2.3.3. 变量的重编码2.3.4. 列重命名2.3.5. 缺失值2.3.6. 日期值2.3.7. 数据框排序2.3.8. 数据框合并(合并沪深300和中证500收盘…...

上下拉电阻会增强驱动能力吗?

最近看到一个关于上下拉电阻的问题&#xff0c;发现不少人认为上下拉电阻能够增强驱动能力。随后跟几个朋友讨论了一下&#xff0c;大家一致认为不存在上下拉电阻增强驱动能力这回事&#xff0c;因为除了OC输出这类特殊结构外&#xff0c;上下拉电阻就是负载&#xff0c;只会减…...

题目:小明的彩灯(蓝桥OJ 1276)

题目描述&#xff1a; 解题思路&#xff1a; 一段连续区间加减&#xff0c;采用差分。最终每个元素结果与0比较大小&#xff0c;比0小即负数输出0。 题解&#xff1a; #include<bits/stdc.h> using namespace std;using ll long long; const int N 1e5 10; ll a[N],…...

换元法求不定积分

1.一般步骤&#xff1a;选取换元对象&#xff08;不一定是式子中的值&#xff0c;也可以是式子中的最小公倍数或者最大公因数&#xff09;&#xff0c;然后将dx换为dt*t的导数&#xff0c;再用t将原式表示&#xff0c;化简计算即可 2. 3. 4. 5. 6....

在Docker容器中启用SSH服务,实现外部访问的详细教程

目录 步骤 1: 安装 SSH 服务器 步骤 2: 配置 SSH 服务器 步骤 3: 设置 SSH 用户 步骤 4: 重启 SSH 服务器 步骤 5: 映射容器端口 步骤 6: 使用 SSH 连接到容器 要在Docker容器中启用SSH服务&#xff0c;以便从外部访问&#xff0c;您需要执行以下步骤&#xff1a; 步骤 …...

Go 模块系统最小版本选择法 MVS 详解

目录 Golang 模块系统简介 包版本管理 最小版本选择&#xff08;MVS&#xff09;原理 MVS 的优点 MVS的缺点 实际使用MVS 小结 参考资料 Golang 模块系统简介 Golang 模块系统是 Go 1.11 版本引入的一个新特性&#xff0c;主要目的是解决 Go 项目中的依赖管理问题。在模…...

ifstream读取txt中的中文数据转成QString出现乱码

使用ifstream从txt文本中读取中文数据到string&#xff0c;再将string转成QString输出时出现了乱码。 分析&#xff1a;如果ifstream能成功从txt文本中读出中文数据&#xff0c;那大概率txt用的编码是ANSI编码&#xff08;GBK就是ANSI的一种&#xff09;&#xff0c;那么在转成…...

UE4 双屏分辨率设置

背景&#xff1a; 做了一个UI 应用&#xff0c;需要在双屏上进行显示。 分辨率如下&#xff1a;3840*1080&#xff1b; 各种折腾&#xff0c;其实很简单&#xff1a; 主要是在全屏模式的时候 一开始没有选对&#xff0c;双屏总是不稳定。 全屏模式改成&#xff1a;Windows 之…...

$sformat在仿真中打印文本名的使用

在仿真中&#xff0c;定义队列&#xff0c;使用任务进行函数传递&#xff0c;并传递文件名&#xff0c;传递队列&#xff0c;进行打印 $sformat(filename, “./data_log/%0d_%0d_%0d_0.txt”, f_num, lane_num,dt); 使用此函数可以自定义字符串&#xff0c;在仿真的时候进行文件…...

【Rust】结构体与枚举

文章目录 结构体struct基础用法使用字段初始化简写语法使用没有命名字段的元组结构体来创建不同的类型没有任何字段的类单元结构体方法语法关联函数多个 impl 块 枚举枚举值Option 结构体struct 基础用法 一个存储用户账号信息的结构体&#xff1a; struct User {active: bo…...

CentOS7 防火墙常用命令

以下是在 CentOS 7 上使用 firewall-cmd 命令管理防火墙时的一些常用命令&#xff1a; 检查防火墙状态&#xff1a; sudo firewall-cmd --state 启动防火墙&#xff1a; sudo systemctl start firewalld 停止防火墙&#xff1a; sudo systemctl stop firewalld 重启防火墙&…...

【无标题】什么是UL9540测试,UL9540:2023版本增加哪些测试项目

什么是UL9540测试&#xff0c;UL9540:2023版本增加哪些测试项目 UL 9540是美国安全实验室&#xff08;Underwriters Laboratories&#xff09;发布的标准&#xff0c;名称为"UL 9540: Energy Storage Systems and Equipment"&#xff0c;翻译为中文为"能量存储…...

springcloud整合Oauth2自定义登录/登出接口

我使用的是password模式&#xff0c;并配置了token模式 一、登录 (这里我使用的示例是用户名密码认证方式) 1. Oath2提供默认登录授权接口 org.springframework.security.oauth2.provider.endpoint.postAccess; Tokenpublic ResponseEntity<OAuth2AccessToken> pos…...

保姆级教程:从零配置ROS2自定义消息包(含CMake/ament避坑指南)

从零构建ROS2自定义消息包的终极实践指南 在机器人开发领域&#xff0c;ROS2的消息系统是模块间通信的核心枢纽。当标准消息类型无法满足特定需求时&#xff0c;自定义消息包便成为开发者必须掌握的技能。本文将带您从零开始&#xff0c;逐步构建一个完整的ROS2自定义消息包&am…...

MoveIt Config 配置文件完整一致性检查

检查范围&#xff08;全部核对完毕&#xff09;ros2_control xacro&#xff08;硬件接口 / 关节&#xff09;initial_positions.yaml&#xff08;初始位置&#xff09;srdf&#xff08;运动组 / 关节&#xff09;joint_limits.yaml&#xff08;关节限制&#xff09;kinematics.…...

GaussDB JDBC SSL加密全攻略:从零配置到生产环境最佳实践

GaussDB JDBC SSL加密全攻略&#xff1a;从零配置到生产环境最佳实践 在数据驱动的时代&#xff0c;数据库连接的安全性已成为企业级应用不可忽视的生命线。作为华为云推出的分布式关系型数据库&#xff0c;GaussDB在金融、政务等对安全性要求极高的场景中广泛应用。而JDBC作为…...

移动端语音交互避坑指南:录音超时截取、倒计时提醒与MP3转换的完整方案

移动端语音交互避坑指南&#xff1a;录音超时截取、倒计时提醒与MP3转换的完整方案 在即时通讯和语音输入场景中&#xff0c;流畅的录音体验直接影响用户留存。数据显示&#xff0c;超过83%的用户会因为录音功能卡顿或操作复杂而放弃使用语音功能。本文将深入解析三个关键体验优…...

OpenRocket:模型火箭仿真的全流程技术解决方案

OpenRocket&#xff1a;模型火箭仿真的全流程技术解决方案 【免费下载链接】openrocket Model-rocketry aerodynamics and trajectory simulation software 项目地址: https://gitcode.com/GitHub_Trending/op/openrocket OpenRocket作为一款开源的模型火箭仿真软件&…...

K230 vs树莓派视觉套件:300元预算该选谁?实测对比工业检测场景

K230与树莓派视觉套件&#xff1a;300元预算下的工业检测实战对比 在工业自动化浪潮中&#xff0c;视觉检测系统正从大型企业向中小型制造车间快速渗透。当预算被严格限制在300元区间时&#xff0c;K230开发板与树莓派摄像头组合成为最受关注的两种解决方案。我们历时三个月在6…...

Phi-4-Reasoning-Vision行业落地:教育领域图像题解与隐藏线索识别案例

Phi-4-Reasoning-Vision行业落地&#xff1a;教育领域图像题解与隐藏线索识别案例 1. 项目背景与价值 在教育领域&#xff0c;图像题解和隐藏线索识别一直是教学和考试中的难点。传统方法依赖人工标注和分析&#xff0c;效率低下且容易遗漏关键信息。Phi-4-Reasoning-Vision多…...

猫抓浏览器插件:网页资源嗅探与下载的终极解决方案

猫抓浏览器插件&#xff1a;网页资源嗅探与下载的终极解决方案 【免费下载链接】cat-catch 猫抓 chrome资源嗅探扩展 项目地址: https://gitcode.com/GitHub_Trending/ca/cat-catch 你是否曾在浏览网页时&#xff0c;看到精彩的视频、音频或图片资源&#xff0c;却苦于无…...

别再只用Set5了!超分辨率模型训练,这5个开源数据集(DIV2K、Flickr2K等)的实战配置与对比

超分辨率模型训练&#xff1a;5个开源数据集的深度实战指南 在超分辨率研究领域&#xff0c;数据集的选择往往决定了模型性能的上限。许多开发者习惯性地使用Set5、Set14等小型数据集&#xff0c;却忽略了更丰富的数据资源可能带来的性能突破。本文将深入解析DIV2K、Flickr2K、…...

深入解析 Linux 内核中的 PCI 中断向量分配机制:pci_alloc_irq_vectors

1. PCI中断向量分配机制入门指南 第一次接触PCI设备中断处理时&#xff0c;我被各种专业术语搞得晕头转向。直到在项目里实际调试一个网卡驱动时&#xff0c;才真正理解pci_alloc_irq_vectors这个函数的重要性。想象一下&#xff0c;你的电脑就像个繁忙的快递分拣中心&#xf…...