当前位置: 首页 > news >正文

【深度学习笔记】09 权重衰减

09 权重衰减

    • 范数和权重衰减
    • 利用高维线性回归实现权重衰减
    • 权重衰减的简洁实现

范数和权重衰减

在训练参数化机器学习模型时,权重衰减(decay weight)是最广泛应用的正则化技术之一,它通常也被称为 L 2 L_2 L2正则化。这项技术通过函数与零的距离来衡量函数的复杂度,
因为在所有函数 f f f中,函数 f = 0 f = 0 f=0(所有输入都得到值 0 0 0
在某种意义上是最简单的。

一种简单的方法是通过线性函数
f ( x ) = w ⊤ x f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x} f(x)=wx
中的权重向量的某个范数来度量其复杂性,
例如 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2
要保证权重向量比较小,
最常用方法是将其范数作为惩罚项加到最小化损失的问题中。
将原来的训练目标最小化训练标签上的预测损失,
调整为最小化预测损失和惩罚项之和。

损失由下式给出:

L ( w , b ) = 1 n ∑ i = 1 n 1 2 ( w ⊤ x ( i ) + b − y ( i ) ) 2 . L(\mathbf{w}, b) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right)^2. L(w,b)=n1i=1n21(wx(i)+by(i))2.

x ( i ) \mathbf{x}^{(i)} x(i)是样本 i i i的特征,
y ( i ) y^{(i)} y(i)是样本 i i i的标签,
( w , b ) (\mathbf{w}, b) (w,b)是权重和偏置参数。

为了惩罚权重向量的大小,
必须以某种方式在损失函数中添加 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2
我们通过正则化常数 λ \lambda λ来描述这种权衡,
这是一个非负超参数,我们使用验证数据拟合:

L ( w , b ) + λ 2 ∥ w ∥ 2 , L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2, L(w,b)+2λw2,

对于 λ = 0 \lambda = 0 λ=0,我们恢复了原来的损失函数。
对于 λ > 0 \lambda > 0 λ>0,我们限制 ∥ w ∥ \| \mathbf{w} \| w的大小。
这里我们仍然除以 2 2 2:当我们取一个二次函数的导数时,
2 2 2 1 / 2 1/2 1/2会抵消。

通过平方 L 2 L_2 L2范数,我们去掉平方根,留下权重向量每个分量的平方和。
这使得惩罚的导数很容易计算:导数的和等于和的导数。

L 2 L_2 L2正则化回归的小批量随机梯度下降更新如下式:

w ← ( 1 − η λ ) w − η ∣ B ∣ ∑ i ∈ B x ( i ) ( w ⊤ x ( i ) + b − y ( i ) ) . \begin{aligned} \mathbf{w} & \leftarrow \left(1- \eta\lambda \right) \mathbf{w} - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)} \left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right). \end{aligned} w(1ηλ)wBηiBx(i)(wx(i)+by(i)).

我们根据估计值与观测值之间的差异来更新 w \mathbf{w} w
然而,我们同时也在试图将 w \mathbf{w} w的大小缩小到零。
这就是为什么这种方法有时被称为权重衰减
我们仅考虑惩罚项,优化算法在训练的每一步衰减权重。
与特征选择相比,权重衰减为我们提供了一种连续的机制来调整函数的复杂度。
较小的 λ \lambda λ值对应较少约束的 w \mathbf{w} w
而较大的 λ \lambda λ值对 w \mathbf{w} w的约束更大。

是否对相应的偏置 b 2 b^2 b2进行惩罚在不同的实践中会有所不同,
在神经网络的不同层中也会有所不同。
通常,网络输出层的偏置项不会被正则化。

利用高维线性回归实现权重衰减

%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l

首先生成数据,生成公式如下:

y = 0.05 + ∑ i = 1 d 0.01 x i + ϵ where  ϵ ∼ N ( 0 , 0.0 1 2 ) . y = 0.05 + \sum_{i = 1}^d 0.01 x_i + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.01^2). y=0.05+i=1d0.01xi+ϵ where ϵN(0,0.012).

选择标签是关于输入的线性函数。
标签同时被均值为0,标准差为0.01高斯噪声破坏。
为了使过拟合的效果更加明显,我们可以将问题的维数增加到 d = 200 d = 200 d=200
并使用一个只包含20个样本的小训练集。

n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

初始化模型参数

定义一个函数来随机初始化模型参数

def init_params():w = torch.normal(0, 1, size = (num_inputs, 1), requires_grad = True)b = torch.zeros(1, requires_grad = True)return [w, b]

定义 L 2 L_2 L2范数惩罚

def l2_penalty(w):return torch.sum(w.pow(2)) / 2

定义训练代码实现

下面的代码将模型拟合训练数据集,并在测试数据集上进行评估。

def train(lambd):w, b = init_params()net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_lossnum_epochs, lr = 100, 0.003animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:# 增加了L2范数惩罚项,# 广播机制使l2_penalty(w)成为一个长度为batch_size的向量l = loss(net(X), y) + lambd * l2_penalty(w)l.sum().backward()d2l.sgd([w, b], lr, batch_size)if (epoch + 1) % 5 == 0:animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数是:', torch.norm(w).item())

忽略正则化直接训练

用lamdb=0禁用权重衰减后运行代码。此时训练误差有所减少,但测试误差没有减少,这意味着出现了严重的过拟合。

train(lambd = 0)
w的L2范数是: 14.971677780151367

在这里插入图片描述

使用权重衰减

使用权重衰减来运行代码。此时训练误差增大,但测试误差减小。这正是我们期望从正则化中得到的效果。

train(lambd = 3)
w的L2范数是: 0.34405317902565

在这里插入图片描述

权重衰减的简洁实现

在实例化优化器时直接通过weight_decay指定weight decay超参数。默认情况下,PyTorch同时衰减权重和便宜。这里只为权重设置了weight_decay,所以偏置参数 b b b不会衰减。

def train_concise(wd):net = nn.Sequential(nn.Linear(num_inputs, 1))for param in net.parameters():param.data.normal_()loss = nn.MSELoss(reduction='none')num_epochs, lr = 100, 0.003# 偏置参数没有衰减trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd},{"params":net[0].bias}], lr=lr)animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:trainer.zero_grad()l = loss(net(X), y)l.mean().backward()trainer.step()if (epoch + 1) % 5 == 0:animator.add(epoch + 1,(d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数:', net[0].weight.norm().item())
train_concise(0)
w的L2范数: 13.416662216186523

在这里插入图片描述

train_concise(3)
w的L2范数: 0.39273694157600403

在这里插入图片描述

相关文章:

【深度学习笔记】09 权重衰减

09 权重衰减 范数和权重衰减利用高维线性回归实现权重衰减初始化模型参数定义 L 2 L_2 L2​范数惩罚定义训练代码实现忽略正则化直接训练使用权重衰减 权重衰减的简洁实现 范数和权重衰减 在训练参数化机器学习模型时,权重衰减(decay weight&#xff09…...

三大兼容 | 人大金仓兼容+优化MySQL用户变量特性

目前,KingbaseES对MySQL的兼容性,已从功能兼容阶段过渡到强性能兼容、生态全面兼容阶段,针对客户常常遇到的用户变量问题,KingbaseES在兼容MySQL用户变量功能的基础上,优化了MySQL用户变量的一些原生问题,使…...

Git介绍与安装使用

目录 1.Git初识 1.1提出问题 1.2如何解决--版本控制器 1.3注意事项 2.Git安装 2.1Linux-centos安装 2.2Linux-ubuntu安装 2.3Windows安装 3.Git基本操作 3.1创建Git本地仓库 3.2配置Git 4.认识⼯作区、暂存区、版本库 1.Git初识 1.1提出问题 不知道你工作或学习时…...

理解DuLinkList L中的“”引用符号

在C中,DuLinkList &L 这种形式的参数表示 L 是一个 DuLinkList 类型的引用。这里的 & 符号表示引用。 引用是C的一个特性,它提供了一种方式来访问已存在的变量的别名。当你对引用进行操作时,实际上是在操作它所引用的变量。如果你在…...

前端并发多个请求并失败重发

const MAX_RETRIES 3;// 模拟请求 function makeRequest(url) {return new Promise((resolve, reject) > {setTimeout(() > {Math.random() < 0.75 ? resolve(${url} 成功) : reject(${url} 失败); // 随机决定请求是否成功}, Math.random() * 2000); // 随机延时执…...

【Qt开发流程】之对象模型2:属性系统

描述 Qt提供了一个复杂的属性系统&#xff0c;类似于一些编译器供应商提供的属性系统。然而&#xff0c;作为一个独立于编译器和平台的库&#xff0c;Qt不依赖于非标准的编译器特性&#xff0c;如__property或[property]。 Qt解决方案适用于Qt支持的所有平台上的任何标准c编译…...

PHP之curl详细讲解

cURL&#xff08;全称为Client for URLs&#xff09;是一个功能强大的开源库&#xff0c;用于在多种协议上进行数据传输、发送HTTP请求和获取响应。它支持多种协议&#xff0c;包括HTTP、HTTPS、FTP、SMTP等&#xff0c;并且能够与各种服务器进行通信。 cURL库可以通过命令行工…...

R语言30分钟上手

文章目录 1. 环境&安装1.1. rstudio保存工作空间 2. 创建数据集2.1. 数据集概念2.2. 向量、矩阵2.3. 数据框2.3.1. 创建数据框2.3.2. 创建新变量2.3.3. 变量的重编码2.3.4. 列重命名2.3.5. 缺失值2.3.6. 日期值2.3.7. 数据框排序2.3.8. 数据框合并(合并沪深300和中证500收盘…...

上下拉电阻会增强驱动能力吗?

最近看到一个关于上下拉电阻的问题&#xff0c;发现不少人认为上下拉电阻能够增强驱动能力。随后跟几个朋友讨论了一下&#xff0c;大家一致认为不存在上下拉电阻增强驱动能力这回事&#xff0c;因为除了OC输出这类特殊结构外&#xff0c;上下拉电阻就是负载&#xff0c;只会减…...

题目:小明的彩灯(蓝桥OJ 1276)

题目描述&#xff1a; 解题思路&#xff1a; 一段连续区间加减&#xff0c;采用差分。最终每个元素结果与0比较大小&#xff0c;比0小即负数输出0。 题解&#xff1a; #include<bits/stdc.h> using namespace std;using ll long long; const int N 1e5 10; ll a[N],…...

换元法求不定积分

1.一般步骤&#xff1a;选取换元对象&#xff08;不一定是式子中的值&#xff0c;也可以是式子中的最小公倍数或者最大公因数&#xff09;&#xff0c;然后将dx换为dt*t的导数&#xff0c;再用t将原式表示&#xff0c;化简计算即可 2. 3. 4. 5. 6....

在Docker容器中启用SSH服务,实现外部访问的详细教程

目录 步骤 1: 安装 SSH 服务器 步骤 2: 配置 SSH 服务器 步骤 3: 设置 SSH 用户 步骤 4: 重启 SSH 服务器 步骤 5: 映射容器端口 步骤 6: 使用 SSH 连接到容器 要在Docker容器中启用SSH服务&#xff0c;以便从外部访问&#xff0c;您需要执行以下步骤&#xff1a; 步骤 …...

Go 模块系统最小版本选择法 MVS 详解

目录 Golang 模块系统简介 包版本管理 最小版本选择&#xff08;MVS&#xff09;原理 MVS 的优点 MVS的缺点 实际使用MVS 小结 参考资料 Golang 模块系统简介 Golang 模块系统是 Go 1.11 版本引入的一个新特性&#xff0c;主要目的是解决 Go 项目中的依赖管理问题。在模…...

ifstream读取txt中的中文数据转成QString出现乱码

使用ifstream从txt文本中读取中文数据到string&#xff0c;再将string转成QString输出时出现了乱码。 分析&#xff1a;如果ifstream能成功从txt文本中读出中文数据&#xff0c;那大概率txt用的编码是ANSI编码&#xff08;GBK就是ANSI的一种&#xff09;&#xff0c;那么在转成…...

UE4 双屏分辨率设置

背景&#xff1a; 做了一个UI 应用&#xff0c;需要在双屏上进行显示。 分辨率如下&#xff1a;3840*1080&#xff1b; 各种折腾&#xff0c;其实很简单&#xff1a; 主要是在全屏模式的时候 一开始没有选对&#xff0c;双屏总是不稳定。 全屏模式改成&#xff1a;Windows 之…...

$sformat在仿真中打印文本名的使用

在仿真中&#xff0c;定义队列&#xff0c;使用任务进行函数传递&#xff0c;并传递文件名&#xff0c;传递队列&#xff0c;进行打印 $sformat(filename, “./data_log/%0d_%0d_%0d_0.txt”, f_num, lane_num,dt); 使用此函数可以自定义字符串&#xff0c;在仿真的时候进行文件…...

【Rust】结构体与枚举

文章目录 结构体struct基础用法使用字段初始化简写语法使用没有命名字段的元组结构体来创建不同的类型没有任何字段的类单元结构体方法语法关联函数多个 impl 块 枚举枚举值Option 结构体struct 基础用法 一个存储用户账号信息的结构体&#xff1a; struct User {active: bo…...

CentOS7 防火墙常用命令

以下是在 CentOS 7 上使用 firewall-cmd 命令管理防火墙时的一些常用命令&#xff1a; 检查防火墙状态&#xff1a; sudo firewall-cmd --state 启动防火墙&#xff1a; sudo systemctl start firewalld 停止防火墙&#xff1a; sudo systemctl stop firewalld 重启防火墙&…...

【无标题】什么是UL9540测试,UL9540:2023版本增加哪些测试项目

什么是UL9540测试&#xff0c;UL9540:2023版本增加哪些测试项目 UL 9540是美国安全实验室&#xff08;Underwriters Laboratories&#xff09;发布的标准&#xff0c;名称为"UL 9540: Energy Storage Systems and Equipment"&#xff0c;翻译为中文为"能量存储…...

springcloud整合Oauth2自定义登录/登出接口

我使用的是password模式&#xff0c;并配置了token模式 一、登录 (这里我使用的示例是用户名密码认证方式) 1. Oath2提供默认登录授权接口 org.springframework.security.oauth2.provider.endpoint.postAccess; Tokenpublic ResponseEntity<OAuth2AccessToken> pos…...

基于ASP.NET+ SQL Server实现(Web)医院信息管理系统

医院信息管理系统 1. 课程设计内容 在 visual studio 2017 平台上&#xff0c;开发一个“医院信息管理系统”Web 程序。 2. 课程设计目的 综合运用 c#.net 知识&#xff0c;在 vs 2017 平台上&#xff0c;进行 ASP.NET 应用程序和简易网站的开发&#xff1b;初步熟悉开发一…...

dedecms 织梦自定义表单留言增加ajax验证码功能

增加ajax功能模块&#xff0c;用户不点击提交按钮&#xff0c;只要输入框失去焦点&#xff0c;就会提前提示验证码是否正确。 一&#xff0c;模板上增加验证码 <input name"vdcode"id"vdcode" placeholder"请输入验证码" type"text&quo…...

数据库分批入库

今天在工作中&#xff0c;遇到一个问题&#xff0c;就是分批查询的时候&#xff0c;由于批次过大导致出现了一些问题&#xff0c;一下是问题描述和解决方案&#xff1a; 示例&#xff1a; // 假设已有数据列表 dataList 和 PreparedStatement pstmt int batchSize 1000; // …...

爬虫基础学习day2

# 爬虫设计领域 工商&#xff1a;企查查、天眼查短视频&#xff1a;抖音、快手、西瓜 ---> 飞瓜电商&#xff1a;京东、淘宝、聚美优品、亚马逊 ---> 分析店铺经营决策标题、排名航空&#xff1a;抓取所有航空公司价格 ---> 去哪儿自媒体&#xff1a;采集自媒体数据进…...

【HTTP三个基础问题】

面试官您好&#xff01;HTTP是超文本传输协议&#xff0c;是互联网上客户端和服务器之间传输超文本数据&#xff08;比如文字、图片、音频、视频等&#xff09;的核心协议&#xff0c;当前互联网应用最广泛的版本是HTTP1.1&#xff0c;它基于经典的C/S模型&#xff0c;也就是客…...

聊一聊接口测试的意义有哪些?

目录 一、隔离性 & 早期测试 二、保障系统集成质量 三、验证业务逻辑的核心层 四、提升测试效率与覆盖度 五、系统稳定性的守护者 六、驱动团队协作与契约管理 七、性能与扩展性的前置评估 八、持续交付的核心支撑 接口测试的意义可以从四个维度展开&#xff0c;首…...

在web-view 加载的本地及远程HTML中调用uniapp的API及网页和vue页面是如何通讯的?

uni-app 中 Web-view 与 Vue 页面的通讯机制详解 一、Web-view 简介 Web-view 是 uni-app 提供的一个重要组件&#xff0c;用于在原生应用中加载 HTML 页面&#xff1a; 支持加载本地 HTML 文件支持加载远程 HTML 页面实现 Web 与原生的双向通讯可用于嵌入第三方网页或 H5 应…...

基于Java Swing的电子通讯录设计与实现:附系统托盘功能代码详解

JAVASQL电子通讯录带系统托盘 一、系统概述 本电子通讯录系统采用Java Swing开发桌面应用&#xff0c;结合SQLite数据库实现联系人管理功能&#xff0c;并集成系统托盘功能提升用户体验。系统支持联系人的增删改查、分组管理、搜索过滤等功能&#xff0c;同时可以最小化到系统…...

纯 Java 项目(非 SpringBoot)集成 Mybatis-Plus 和 Mybatis-Plus-Join

纯 Java 项目&#xff08;非 SpringBoot&#xff09;集成 Mybatis-Plus 和 Mybatis-Plus-Join 1、依赖1.1、依赖版本1.2、pom.xml 2、代码2.1、SqlSession 构造器2.2、MybatisPlus代码生成器2.3、获取 config.yml 配置2.3.1、config.yml2.3.2、项目配置类 2.4、ftl 模板2.4.1、…...

群晖NAS如何在虚拟机创建飞牛NAS

套件中心下载安装Virtual Machine Manager 创建虚拟机 配置虚拟机 飞牛官网下载 https://iso.liveupdate.fnnas.com/x86_64/trim/fnos-0.9.2-863.iso 群晖NAS如何在虚拟机创建飞牛NAS - 个人信息分享...