当前位置: 首页 > news >正文

深度学习——权重衰减(weight_decay)

深度学习——权重衰减(weight_decay)

文章目录

  • 前言
  • 一、权重衰减
    • 1.1. 范数与权重衰减
    • 1.2. 高维线性回归
    • 1.3. 从零开始实现
      • 1.3.1.初始化模型参数
      • 1.3.2. 定义L₂范数惩罚
      • 1.3.3. 定义训练代码实现
      • 1.3.4. 不管正则化直接训练
      • 1.3.5. 使用权重衰减
    • 1.4. 简洁实现
  • 总结


前言

上一章描述了过拟合的问题,本章我们将介绍一些正则化模型的技术。如权重衰减

参考书:
《动手学深度学习》


一、权重衰减

1.1. 范数与权重衰减

在训练参数化机器学习模型时,权重衰减(weight decay)是最广泛使用的正则化的技术之一, 它通常也被称为 L 2 L_2 L2正则化。这项技术通过函数与零的距离来衡量函数的复杂度,

因为在所有函数 f f f中,函数 f = 0 f = 0 f=0(所有输入都得到值 0 0 0),在某种意义上是最简单的。

但是我们应该如何精确地测量一个函数和零之间的距离呢?

一种简单的方法是通过线性函数
f ( x ) = w ⊤ x f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x} f(x)=wx 中的权重向量的某个范数来度量其复杂性,
例如 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2

要保证权重向量比较小,最常用方法是将其范数作为惩罚项加到最小化损失的问题中
即将原来的训练目标最小化训练标签上的预测损失,调整为最小化预测损失和惩罚项之和

现在,如果我们的权重向量增长的太大,我们的学习算法可能会更集中于最小化权重范数 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2。这正是我们想要的。

我们的损失由下式给出:

L ( w , b ) = 1 n ∑ i = 1 n 1 2 ( w ⊤ x ( i ) + b − y ( i ) ) 2 . L(\mathbf{w}, b) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right)^2. L(w,b)=n1i=1n21(wx(i)+by(i))2.

为了惩罚权重向量的大小,我们必须以某种方式在损失函数中添加 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2

但是模型应该如何平衡这个新的额外惩罚的损失?
实际上,我们通过正则化常数 λ \lambda λ来描述这种权衡,这是一个非负超参数,我们使用验证数据拟合:

L ( w , b ) + λ 2 ∥ w ∥ 2 , L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2, L(w,b)+2λw2,

对于 λ > 0 \lambda > 0 λ>0,我们限制 ∥ w ∥ \| \mathbf{w} \| w的大小。

为什么在这里我们使用平方范数而不是标准范数(即欧几里得距离)?

我们这样做是为了便于计算。通过平方 L 2 L_2 L2范数,我们去掉平方根,留下权重向量每个分量的平方和。
这使得惩罚的导数很容易计算:导数的和等于和的导数。

此外,为什么我们首先使用 L 2 L_2 L2范数,而不是 L 1 L_1 L1范数。

L 2 L_2 L2正则化线性模型构成经典的岭回归(ridge regression)算法,
L 1 L_1 L1正则化线性回归是统计学中类似的基本模型,通常被称为套索回归(lasso regression)。

使用 L 2 L_2 L2范数的一个原因是它对权重向量的大分量施加了巨大的惩罚。这使得我们的学习算法偏向于在大量特征上均匀分布权重的模型。在实践中,这可能使它们对单个变量中的观测误差更为稳定

相比之下, L 1 L_1 L1惩罚会导致模型将权重集中在一小部分特征上,
而将其他权重清除为零
。这称为特征选择(feature selection),可能是其他场景下需要的。

L 2 L_2 L2正则化回归的小批量随机梯度下降更新如下式:

w ← ( 1 − η λ ) w − η ∣ B ∣ ∑ i ∈ B x ( i ) ( w ⊤ x ( i ) + b − y ( i ) ) . \begin{aligned} \mathbf{w} & \leftarrow \left(1- \eta\lambda \right) \mathbf{w} - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)} \left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right). \end{aligned} w(1ηλ)wBηiBx(i)(wx(i)+by(i)).

我们根据估计值与观测值之间的差异来更新 w \mathbf{w} w。然而,我们同时也在试图将 w \mathbf{w} w的大小缩小到零。
这就是为什么这种方法有时被称为权重衰减。我们仅考虑惩罚项,优化算法在训练的每一步衰减权重。

与特征选择相比,权重衰减为我们提供了一种连续的机制来调整函数的复杂度。 较小的 λ \lambda λ值对应较少约束的 w \mathbf{w} w,而较大的 λ \lambda λ值对 w \mathbf{w} w的约束更大。

是否对相应的偏置 b 2 b^2 b2进行惩罚在不同的实践中会有所不同,
在神经网络的不同层中也会有所不同。通常,网络输出层的偏置项不会被正则化。

1.2. 高维线性回归

我们通过一个简单的例子来演示权重衰减。

首先,我们像以前一样生成一些数据,生成公式如下:

y = 0.05 + ∑ i = 1 d 0.01 x i + ϵ where  ϵ ∼ N ( 0 , 0.0 1 2 ) . y = 0.05 + \sum_{i = 1}^d 0.01 x_i + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.01^2). y=0.05+i=1d0.01xi+ϵ where ϵN(0,0.012).

我们选择标签是关于输入的线性函数。标签同时被均值为0,标准差为0.01高斯噪声破坏。
为了使过拟合的效果更加明显,我们可以将问题的维数增加到 d = 200 d = 200 d=200
并使用一个只包含20个样本的小训练集。

import torch
from d2l import torch as d2l
from torch import nnn_train,n_test,num_inputs,batch_size = 20,100,200,5
true_w,true_b = torch.ones((num_inputs,1))*0.01,0.05"""
使用d2l.synthetic_data函数生成了训练数据和测试数据,并使用d2l.load_array函数将数据加载为迭代器。
"""
train_data = d2l.synthetic_data(true_w,true_b,n_train)
train_iter = d2l.load_array(train_data,batch_size)test_data = d2l.synthetic_data(true_w,true_b,n_test)
test_iter = d2l.load_array(test_data,batch_size,is_train= False)
#这里设置is_train=False表示测试数据不用于模型训练,只用于评估模型的性能。

1.3. 从零开始实现

下面我们将从头开始实现权重衰减,只需将 L 2 L_2 L2的平方惩罚添加到原始目标函数中。

1.3.1.初始化模型参数

#初始化模型参数
#我们将定义一个函数来随机初始化模型参数
def init_params():w = torch.normal(0,1,size=(num_inputs,1),requires_grad= True)b = torch.zeros(1,requires_grad=True)return [w,b]

1.3.2. 定义L₂范数惩罚


#定义L2范数惩罚(实现这一惩罚最方便的方法是对所有项求平方后并将它们求和)
def l2_penalty(w):return torch.sum(w.pow(2))/2 #将权重w的平方和除以2,除以2是为了方便计算梯度

1.3.3. 定义训练代码实现

#定义训练代码实现
def train(lambd):w,b  = init_params()net,loss = lambda x: d2l.linreg(x,w,b),d2l.squared_lossnum_epochs,lr = 100,0.003animator = d2l.Animator(xlabel="epochs",ylabel="loss",yscale="log",xlim= [5,num_epochs],legend=["train","test"])for epoch in range(num_epochs):for x,y in train_iter:#增加了L2范数惩罚项#广播机制使l2_penalty(w)成为一个长度为batch_size的向量l = loss(net(x),y) + lambd * l2_penalty(w)l.sum().backward()d2l.sgd([w,b],lr,batch_size)if (epoch+1)%5 ==0:animator.add(epoch+1,(d2l.evaluate_loss(net,train_iter,loss),d2l.evaluate_loss(net,test_iter,loss)))print("w的L2范数是:",torch.norm(w).item())

1.3.4. 不管正则化直接训练

#现在用`lambd = 0`禁用权重衰减后运行这个代码。
#注意,这里训练误差有了减少,但测试误差没有减少,这意味着出现了严重的过拟合。train(lambd= 0)#结果:
w的L2范数是: 13.981727600097656

在这里插入图片描述

1.3.5. 使用权重衰减

#使用权重衰减来运行代码。
#注意,在这里训练误差增大,但测试误差减小。这正是我们期望从正则化中得到的效果。train(lambd= 3)#结果:
w的L2范数是: 0.3319331705570221d2l.plt.show()

在这里插入图片描述

1.4. 简洁实现

深度学习框架为了便于我们使用权重衰减,将权重衰减集成到优化算法中,以便与任何损失函数结合使用。

#在下面的代码中,我们在实例化优化器时直接通过`weight_decay`指定weight decay超参数。
#默认情况下,PyTorch同时衰减权重和偏移。
#这里我们只为权重设置了`weight_decay`,所以偏置参数$b$不会衰减。def train_concise(wd):net = nn.Sequential(nn.Linear(num_inputs,1))for param in net.parameters():param.data.normal_() #使用正态分布随机初始化参数loss = nn.MSELoss(reduction="none") #定义损失函数为均方误差损失num_epochs,lr = 100,0.003#偏置参数没有衰减trainer = torch.optim.SGD([{"params":net[0].weight,"weight_decay":wd},{"params":net[0].bias}],lr = lr)  #net[0].weight表示模型的权重参数,net[0].bias表示模型的偏置参数。weight_decay参数用于设置权重衰减的强度。animator = d2l.Animator(xlabel="epochs",ylabel="loss",yscale="log",xlim=[5,num_epochs],legend=["train","test"])for epoch in range(num_epochs):for x,y in train_iter:trainer.zero_grad() #清零梯度,以防止梯度累积l = loss(net(x),y)l.mean().backward() #计算损失的平均值,并进行反向传播,计算梯度trainer.step() #更新模型的参数,执行一步优化器的更新if (epoch+1)%5 == 0:animator.add(epoch+1,(d2l.evaluate_loss(net,train_iter,loss),d2l.evaluate_loss(net,test_iter,loss)))print("w的L2范数:", net[0].weight.norm().item())  #打印模型权重的L2范数,用于评估模型的复杂度。train_concise(0)
train_concise(3)
d2l.plt.show()#结果:
w的L2范数: 13.411089897155762
w的L2范数: 0.3319282829761505

在这里插入图片描述

在这里插入图片描述


总结

为了有效防止模型的过拟合,降低模型的复杂度,提高泛化能力,本章简单记录了一种常见的正则化技术:权重衰减。简单来说权重衰减是通过在损失函数中添加一个正则化项来实现的。这个正则化项通常是模型参数的L2范数(平方和)或L1范数(绝对值和),通过限制模型参数的大小来防止过拟合。

我独泊兮其未兆,如婴儿之未孩,傫傫(lèi lèi)兮,若无所归。

–2023-10-2 进阶篇

相关文章:

深度学习——权重衰减(weight_decay)

深度学习——权重衰减(weight_decay) 文章目录 前言一、权重衰减1.1. 范数与权重衰减1.2. 高维线性回归1.3. 从零开始实现1.3.1.初始化模型参数1.3.2. 定义L₂范数惩罚1.3.3. 定义训练代码实现1.3.4. 不管正则化直接训练1.3.5. 使用权重衰减 1.4. 简洁实现 总结 前言…...

nignx如何部署让前端不用清缓存就可以部署

在Nginx中,可以使用以下方法来部署前端应用程序,使前端用户无需清空缓存即可进行部署: 1、使用版本号:在前端应用程序的构建过程中,可以添加一个独特的版本号到应用程序的名称中。每次部署时,将版本号更新…...

CSS3实现动画加载效果

<!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevice-width, initial-scale1.0" /><title>加载效果</title><link rel"style…...

springboot定时任务Scheduled使用和弊端分析

1.springboot定时任务Scheduled使用说明: (1)创建定时任务类 import com.one.utils.DateUtil; import org.springframework.beans.factory.annotation.Autowired; import...

openGauss学习笔记-93 openGauss 数据库管理-访问外部数据库-oracle_fdw

文章目录 openGauss学习笔记-93 openGauss 数据库管理-访问外部数据库-oracle_fdw93.1 编译oracle_fdw93.2 使用oracle_fdw93.3 常见问题93.4 注意事项 openGauss学习笔记-93 openGauss 数据库管理-访问外部数据库-oracle_fdw openGauss的fdw实现的功能是各个openGauss数据库及…...

【Git】Git下载安装环境配置 下载速度慢的解决方案

这里写自定义目录标题 介绍一、下载官网下载镜像站 二、安装安装成功 三、Git三种界面介绍Git cmd界面展示git bash界面展示git GUI界面展示 四、环境配置配置流程1、打开环境变量界面2、添加环境变量 /删除环境变量3、在变量中找到Git\cmd的值就表示配置成功4、没有找到点击新…...

常见源协议介绍

开源协议&#xff08;Open Source License&#xff09;是一种法律文档&#xff0c;用于规定如何使用、修改和分发开源软件和其他开源项目的规则和条件。这些协议允许创作者或组织将其创造的代码或作品以开放源代码的形式共享给他人&#xff0c;以促进协作、创新和知识共享。常见…...

大数据概述(林子雨慕课课程)

文章目录 1. 大数据概述1.1 大数据概念和影响1.2 大数据的应用1.3 大数据的关键技术1.4 大数据与云计算和物联网的关系云计算物联网 1. 大数据概述 大数据的四大特点&#xff1a;大量化、快速化、多样化、价值密度低 1.1 大数据概念和影响 大数据摩尔定律 大数据由结构化和非…...

ES6 class类关键字super

super关键字 在 JavaSCript 中&#xff0c;能通过 extends 关键字去继承父类 super 关键字在子类中有以下用法&#xff1a; 当成函数调用 super() 作为 "属性查询" super.prop 和 super[expr] super() super 作为函数调用时&#xff0c;代表父类的构造函数。 ES6 要求…...

C++并发与多线程(4) | 传递临时对象作为线程参数的一些问题Ⅰ

一、陷阱1 写一个传递临时对象作为线程参数的示例: #include <iostream> #include <vector> #include <thread> using namespace std;void myprint(const int& i, char* pmybuf) {cout << i << endl;cout << pmybuf << endl;r…...

CentOS Integration SIG 正式成立

导读CentOS 董事会已批准成立 CentOS Integration Special Interest Group (SIG)。该小组旨在帮助那些在 Red Hat Enterprise Linux (RHEL) 或特别是其上游 CentOS Stream 上构建产品和服务的人员&#xff0c;验证其能否在未来版本中继续运行。 红帽 RHEL CI 工程师 Aleksandr…...

智能AI系统源码ChatGPT系统源码+详细搭建部署教程+AI绘画系统+已支持OpenAI GPT全模型+国内AI全模型

一、AI创作系统 SparkAi创作系统是基于OpenAI很火的ChatGPT进行开发的Ai智能问答系统&#xff0c;支持OpenAI GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美&#xff0c;可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如何搭建部署AI创作Chat…...

软考程序员考试大纲(2023)

文章目录 前言一、考试说明1.考试目标2.考试要求3&#xff0e;考试科目设置 二、考试范围考试科目1&#xff1a;计算机与软件工程基本知识1&#xff0e;计算机科学基础2&#xff0e;计算机系统基础知识3&#xff0e;系统开发和运行知识4&#xff0e;网络与信息安全基础知识5&am…...

【重拾C语言】七、指针(一)指针与变量、指针操作、指向指针的指针

目录 前言 七、指针 7.1 指针与变量 7.1.1 指针类型和指针变量 7.1.2 指针所指变量 7.1.3 空指针、无效指针 7.2 指针操作 7.2.1 指针的算术运算 7.2.2 指针的比较 7.2.3 指针的递增和递减 7.3 指向指针的指针 前言 指针是C语言中一个重要的概念正确灵活运用指针 可…...

Kafka源码简要分析

目录 一、生产者的初始化流程 二、生产者到缓冲队列的流程 三、Sender拉取数据到Kafka流程 四、消费者初始化 五、主题订阅原理 六、消费者抓取数据原理 七、消费者组初始化 八、消费者组消费流程 九、提交offset原理 一、生产者的初始化流程 首先获取事务id和客户端…...

react 按住ctrl键,点击时会出现菜单的问题修复

问题描述&#xff1a;我需要按住crtl键&#xff0c;然后鼠标点击后做一些逻辑操作&#xff0c;但是出现如下问题 问题一&#xff1a;按住ctrl键后&#xff0c;点击时不触发click事件&#xff0c;只触发 mousedown和mouseup事件。 问题二&#xff1a;按住ctrl键点击时出现菜单…...

【虚拟机栈】

文章目录 1. 虚拟机栈概述2. 局部变量表(Local Variables)3. 操作数栈4. 动态链接4.1 方法的调用&#xff1a;解析与分配 5. 方法返回地址6. 栈的相关面试题 1. 虚拟机栈概述 每个线程在创建时都会创建一个虚拟机栈&#xff0c;其内部保存一个个的栈帧&#xff08;Stack Frame…...

Linux系列讲解 —— 【fsck】检查并修复Linux文件系统

当文件系统出现损坏时&#xff0c;例如文件无法查看&#xff0c;删除等&#xff0c;可以使用 fsck&#xff08;File System Consistency Check&#xff09;进行修复。但是需要注意fsck在修复时&#xff0c;如果检查出某个文件有问题&#xff0c;可能会向用户请求删除。所以&…...

gitlab突然提示我要输入密码了。

用了很长时间的一个gitlab库&#xff0c;今天提交代码的时候突然提示我输入密码了&#xff0c;并且用户还是gitxx.xx.xx.xx的&#xff0c;瞬间懵逼。 想想原因&#xff0c;可能是因为我不久前设置了本地对另外一个git库的远程访问&#xff0c;用的是ssh&#xff0c;操作过程中可…...

业务测试常见问题(一)

如何多维度的分析一个需求&#xff1f; 功能维度&#xff1a;需求中所描述的功能是否实现&#xff0c;与用户的需求是否一致&#xff0c;是否完整符合用户的需求等。 安全性维度&#xff1a;是否有安全漏洞&#xff0c;是否存在未授权访问漏洞等&#xff0c;以保证系统的安全性…...

【大模型RAG】拍照搜题技术架构速览:三层管道、两级检索、兜底大模型

摘要 拍照搜题系统采用“三层管道&#xff08;多模态 OCR → 语义检索 → 答案渲染&#xff09;、两级检索&#xff08;倒排 BM25 向量 HNSW&#xff09;并以大语言模型兜底”的整体框架&#xff1a; 多模态 OCR 层 将题目图片经过超分、去噪、倾斜校正后&#xff0c;分别用…...

C++实现分布式网络通信框架RPC(3)--rpc调用端

目录 一、前言 二、UserServiceRpc_Stub 三、 CallMethod方法的重写 头文件 实现 四、rpc调用端的调用 实现 五、 google::protobuf::RpcController *controller 头文件 实现 六、总结 一、前言 在前边的文章中&#xff0c;我们已经大致实现了rpc服务端的各项功能代…...

synchronized 学习

学习源&#xff1a; https://www.bilibili.com/video/BV1aJ411V763?spm_id_from333.788.videopod.episodes&vd_source32e1c41a9370911ab06d12fbc36c4ebc 1.应用场景 不超卖&#xff0c;也要考虑性能问题&#xff08;场景&#xff09; 2.常见面试问题&#xff1a; sync出…...

如何在看板中体现优先级变化

在看板中有效体现优先级变化的关键措施包括&#xff1a;采用颜色或标签标识优先级、设置任务排序规则、使用独立的优先级列或泳道、结合自动化规则同步优先级变化、建立定期的优先级审查流程。其中&#xff0c;设置任务排序规则尤其重要&#xff0c;因为它让看板视觉上直观地体…...

django filter 统计数量 按属性去重

在Django中&#xff0c;如果你想要根据某个属性对查询集进行去重并统计数量&#xff0c;你可以使用values()方法配合annotate()方法来实现。这里有两种常见的方法来完成这个需求&#xff1a; 方法1&#xff1a;使用annotate()和Count 假设你有一个模型Item&#xff0c;并且你想…...

oracle与MySQL数据库之间数据同步的技术要点

Oracle与MySQL数据库之间的数据同步是一个涉及多个技术要点的复杂任务。由于Oracle和MySQL的架构差异&#xff0c;它们的数据同步要求既要保持数据的准确性和一致性&#xff0c;又要处理好性能问题。以下是一些主要的技术要点&#xff1a; 数据结构差异 数据类型差异&#xff…...

算法笔记2

1.字符串拼接最好用StringBuilder&#xff0c;不用String 2.创建List<>类型的数组并创建内存 List arr[] new ArrayList[26]; Arrays.setAll(arr, i -> new ArrayList<>()); 3.去掉首尾空格...

安宝特方案丨船舶智造的“AR+AI+作业标准化管理解决方案”(装配)

船舶制造装配管理现状&#xff1a;装配工作依赖人工经验&#xff0c;装配工人凭借长期实践积累的操作技巧完成零部件组装。企业通常制定了装配作业指导书&#xff0c;但在实际执行中&#xff0c;工人对指导书的理解和遵循程度参差不齐。 船舶装配过程中的挑战与需求 挑战 (1…...

动态 Web 开发技术入门篇

一、HTTP 协议核心 1.1 HTTP 基础 协议全称 &#xff1a;HyperText Transfer Protocol&#xff08;超文本传输协议&#xff09; 默认端口 &#xff1a;HTTP 使用 80 端口&#xff0c;HTTPS 使用 443 端口。 请求方法 &#xff1a; GET &#xff1a;用于获取资源&#xff0c;…...

深度学习水论文:mamba+图像增强

&#x1f9c0;当前视觉领域对高效长序列建模需求激增&#xff0c;对Mamba图像增强这方向的研究自然也逐渐火热。原因在于其高效长程建模&#xff0c;以及动态计算优势&#xff0c;在图像质量提升和细节恢复方面有难以替代的作用。 &#x1f9c0;因此短时间内&#xff0c;就有不…...