当前位置: 首页 > news >正文

机器学习深度学习——权重衰减

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er
🌌上期文章:机器学习&&深度学习——模型选择、欠拟合和过拟合
📚订阅专栏:机器学习&&深度学习
希望文章对你们有所帮助

权重衰减

  • 讨论(思维过一下,后面会总结)
  • 权重衰减
    • 使用均方范数作为硬性限制
    • 使用均方范数作为柔性限制
    • 对最优解的影响
    • 参数更新法则
    • 总结
  • 高维线性回归
  • 从零开始实现
    • 初始化模型参数
    • 定义L2范数乘法
    • 定义训练代码实现
    • 忽略正则化直接训练
    • 使用权重衰减
  • 简洁实现

讨论(思维过一下,后面会总结)

前一节已经描述了过拟合的问题,本节将会介绍一些正则化模型的技术。
之前用了多项式回归的例子,我们可以通过调整拟合多项式的阶数来限制模型容量。而限制特征数量是缓解过拟合的一种常用技术。然而,我们还需要考虑高维输入可能发生的情况。多项式对多变量的自然扩展称为单项式,也可以说是变量幂的成绩。单项式的阶数是幂的和。例如,x12x2和x3x52都是3次单项式。
随着阶数d的增长,带有阶数d的项数迅速增加。给定k个变量,阶数为d的项的个数为:
C k − 1 + d k − 1 = ( k − 1 + d ) ! ( d ) ! ( k − 1 ) ! C_{k-1+d}^{k-1}=\frac{(k-1+d)!}{(d)!(k-1)!} Ck1+dk1=(d)!(k1)!(k1+d)!
因此即使是阶数上的微小变化,也会显著增加我们模型的复杂性。仅仅通过简单的限制特征数量,可能仍然使模型在过简单和过复杂中徘徊。我们需要一个更细粒度的工具来调整函数的复杂性,使其达到一个合适的平衡位置。
在之前已经描述了L2范数和L1范数。
在训练参数化机器学习模型时,权重衰减是最广泛使用的正则化的技术之一,它通常被称为L2正则化。这项技术通过函数与0的距离来衡量函数的复杂度,因为所有的函数f中,f=0在某种意义上是最简单的。但是衡量函数f与0的距离并不简单,这也没有一个正确的答案。
一种简单的方法是通过线性函数f(x)=wTx中的某个向量的范数来度量其复杂性,例如||w||2。要保证权重向量比较小,最常用方法是将其范数作为惩罚项加到最小化损失的问题中。将原来的训练目标最小化训练标签上的预测损失,调整为最小化预测损失和惩罚项之和。如果我们的权重向量增长的太大,我们的学习算法可能更集中于最小化权重范数||w||2。回归线性回归,我们的损失由下式给出:
L ( w , b ) = 1 n ∑ i = 1 n 1 2 ( w T x ( i ) + b − y ( i ) ) 2 L(w,b)=\frac{1}{n}\sum_{i=1}^n\frac{1}{2}(w^Tx^{(i)}+b-y^{(i)})^2 L(w,b)=n1i=1n21(wTx(i)+by(i))2
其中,x(i)是样本i的特征,y(i)是样本i的标签,(w,b)是权重和偏置参数。
为了乘法权重向量的大小,我们现在在损失函数添加||w||2,模型如何平衡这个新的额外乘法的损失?我们通过正则化常数λ来描述这种权衡(这是一个非负超参数,我们使用验证数据拟合):
L ( w , b ) + λ 2 ∣ ∣ w ∣ ∣ 2 L(w,b)+\frac{\lambda}{2}||w||^2 L(w,b)+2λ∣∣w2
对于λ=0,我们恢复了原来的损失函数。对于λ>0,我们限制||w||的大小。这里我们仍然除以2(当我们取一个二次函数的导数时, 2和1/2会抵消)。
对于范数的选择,可以提出两个问题:
1、为什么不选择欧几里得距离?
2、为什么不用L1范数?
对于第1个问题:这样做就是为了通过平方去掉L2范数的平方根,留下权重向量每个分量的平方和,这样就很好进行求导了(此时导数的和就等于和的导数)
对于第2个问题:L2范数比起L1范数,对权重向量的大分量施加了巨大的惩罚,在这里还是L2更适合。
那么L2正则化回归的小批量随机梯度下降更新如下式:
w ← ( 1 − η λ ) w − η ∣ B ∣ ∑ i ∈ B x ( i ) ( w T x ( i ) + b − y ( i ) ) w←(1-ηλ)w-\frac{η}{|B|}\sum_{i∈B}x^{(i)}(w^Tx^{(i)}+b-y^{(i)}) w(1ηλ)wBηiBx(i)(wTx(i)+by(i))
为啥是这个结果可以看后面的推导,这个结论说明:我们根据估计值和观测值之间的差距来更新w,同时也在试图缩小w的大小,这就叫权重衰减(或权重衰退)。
权重衰减为我们提供了一种连续的机制来调整函数复杂度,较小的λ值对应较少约束的w,较大的λ值对w的约束会更大。

权重衰减

使用均方范数作为硬性限制

1、通过限制参数值的选择范围来控制模型容量:
m i n l ( w , b ) 其中 ∣ ∣ w ∣ ∣ 2 ≤ θ min l(w,b)其中||w||^2≤θ minl(w,b)其中∣∣w2θ
2、通常不限制偏移b(其实限制不限制都差不多)
3、更小的θ意味着更强的正则项

使用均方范数作为柔性限制

1、对每个θ,都可以找到λ使得之前的目标函数等价于
m i n l ( w , b ) + λ 2 ∣ ∣ w ∣ ∣ 2 min l(w,b)+\frac{\lambda}{2}||w||^2 minl(w,b)+2λ∣∣w2
2、上式通过拉格朗日乘子就能证明
3、超参数λ控制了正则项的重要程度
λ = 0 :无作用 λ → ∞ : w ∗ → 0 \lambda=0:无作用\\ \lambda→∞:w^*→0 λ=0:无作用λw0

对最优解的影响

一张图片就能看出来:
在这里插入图片描述

参数更新法则

计算梯度:
∂ ∂ w ( l ( w , b ) + λ 2 ∣ ∣ w ∣ ∣ 2 ) = ∂ l ( w , b ) ∂ w + λ w \frac{\partial}{\partial w}(l(w,b)+\frac{\lambda}{2}||w||^2)=\frac{\partial l(w,b)}{\partial w}+\lambda w w(l(w,b)+2λ∣∣w2)=wl(w,b)+λw
时间t更新参数:
w t + 1 = w t − η ∂ ∂ w 把 ∂ ∂ w 用上式带入,得: w t + 1 = ( 1 − η λ ) w t − η ∂ l ( w t , b t ) ∂ w t w_{t+1}=w_t-η\frac{\partial}{\partial w}\\把\frac{\partial}{\partial w}用上式带入,得:\\w_{t+1}=(1-η\lambda)w_t-η\frac{\partial l(w_t,b_t)}{\partial w_t} wt+1=wtηww用上式带入,得:wt+1=(1ηλ)wtηwtl(wt,bt)
通常ηλ<1,在深度学习中叫作权重衰退
(可以和之前的梯度做比较,会发现也就是在w之前加了个(1-ηλ)的系数,这样就可以做到权重衰退)

总结

1、权重衰退通过L2正则项使得模型参数不会太大,从而控制模型复杂度。
2、正则项权重是控制模型复杂度的超参数。

高维线性回归

我们通过简单例子来演示权重衰减

import torch
from torch import nn
from d2l import torch as d2l

生成一些人工数据集,生成公式如下:
y = 0.05 + ∑ i = 1 d 0.01 x i + σ 其中 σ 符合正态分布 N ( 0 , 0.0 1 2 ) y=0.05+\sum_{i=1}^d0.01x_i+\sigma\\ 其中\sigma符合正态分布N(0,0.01^2) y=0.05+i=1d0.01xi+σ其中σ符合正态分布N(0,0.012)
为了把过拟合体现的更明显,我们的训练集就只有20个(数据越简单越容易过拟合)

n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

从零开始实现

下面从头开始实现权重衰减,只需将L2的平方惩罚添加到原始目标函数值。

初始化模型参数

def init_params():w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)b = torch.zeros(1, requires_grad=True)return [w, b]

定义L2范数乘法

我们这边是在原来的L2基础上加上了平方,从而去除了他的根号。

def l2_penalty(w):return torch.sum(w.pow(2)) / 2

定义训练代码实现

下面将模型与训练数据集进行拟合,并在测试数据集上进行评估。其中线性网络与平方损失是没有变化的,唯一的变化只是现在增加了惩罚项。

def train(lambd):w, b = init_params()# lambda X相当于定义了一个net()函数,不好理解少用net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_lossnum_epochs, lr = 100, 0.003animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:# 增加L2范数惩罚项# 广播机制使L2_penalty(w)成为一个长度为batch_size的向量l = loss(net(X), y) + lambd * l2_penalty(w)l.sum().backward()d2l.sgd([w, b], lr, batch_size)if (epoch + 1) % 5 == 0:animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数是:', torch.norm(w).item())

忽略正则化直接训练

此时,我们使用lambd=0来禁止权重衰减,运行代码以后,训练误差会减少,但是测试误差却没有减少,说明出现了严重的过拟合。

train(lambd=0)
d2l.plt.show()

运行结果:

w的L2范数是: 13.375638008117676

运行图片:
在这里插入图片描述

使用权重衰减

这里的训练误差增大,但测试误差减小,这正是期望从正则化中得到的效果。

train(lambd=3)
d2l.plt.show()

运行结果:

w的L2范数是: 0.35898885130882263

运行图片:
在这里插入图片描述

简洁实现

由于权重衰减在神经网络优化中很常用,深度学习框架就将权重衰减集成到优化算法中,以便与任何损失函数结合使用。此外,这种集成还有计算上的好处,允许在不增加任何额外的计算开销的情况下向算法中添加权重衰减。由于更新的权重衰减部分仅依赖于每个参数的当前值,因此优化器必须至少接触每个参数一次。
在下面的代码中,我们在实例化优化器时直接通过weight_decay指定weight decay超参数。 默认情况下,PyTorch同时衰减权重和偏移。 这里我们只为权重设置了weight_decay,所以偏置参数b不会衰减。

import torch
from torch import nn
from d2l import torch as d2ln_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)def train_concise(wd):net = nn.Sequential(nn.Linear(num_inputs, 1))for param in net.parameters():param.data.normal_()loss = nn.MSELoss(reduction='none')num_epochs, lr = 100, 0.003# 偏置参数没有衰减trainer = torch.optim.SGD([{"params": net[0].weight, 'weight_decay': wd},{"params": net[0].bias}], lr=lr)animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:trainer.zero_grad()l = loss(net(X), y)l.mean().backward()trainer.step()if (epoch + 1) % 5 == 0:animator.add(epoch + 1,(d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数:', net[0].weight.norm().item())

测试运行:

train_concise(0)
d2l.plt.show()

运行结果:

w的L2范数: 14.566418647766113

运行图片:
在这里插入图片描述
测试运行:

train_concise(3)
d2l.plt.show()

运行结果:

w的L2范数: 0.45850494503974915

运行图片:
在这里插入图片描述
运行后的图和之前的图相同,但是它们运行得更快,更容易实现。对于复杂问题,这一好处将变得更加明显。
后序的内容,在深层网络的所有层上,都会应用权重衰减。

相关文章:

机器学习深度学习——权重衰减

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er 🌌上期文章:机器学习&&深度学习——模型选择、欠拟合和过拟合 📚订阅专栏:机器学习&&深度学习 希望文章对你…...

【Linux】线程互斥 -- 互斥锁 | 死锁 | 线程安全

引入互斥初识锁互斥量mutex锁原理解析 可重入VS线程安全STL中的容器是否是线程安全的? 死锁 引入 我们写一个多线程同时访问一个全局变量的情况(抢票系统),看看会出什么bug: // 共享资源, 火车票 int tickets 10000; //新线程执行方法 vo…...

【vue-pdf】PDF文件预览插件

1 插件安装 npm install vue-pdf vue-pdf GitHub:https://github.com/FranckFreiburger/vue-pdf#readme 参考文档:https://www.cnblogs.com/steamed-twisted-roll/p/9648255.html catch报错:vue-pdf组件报错vue-pdf Cannot read properti…...

Flink集群运行模式--Standalone运行模式

Flink集群运行模式--Standalone运行模式 一、实验目的二、实验内容三、实验原理四、实验环境五、实验步骤5.1 部署模式5.1.1 会话模式(Session Mode)5.1.2 单作业模式(Per-Job Mode)5.1.3 应用模式(Application Mode&a…...

Spring整合JUnit实现单元测试

Spring整合JUnit实现单元测试 一、引言 在软件开发过程中,单元测试是保证代码质量和稳定性的重要手段。JUnit是一个流行的Java单元测试框架,而Spring是一个广泛应用于Java企业级开发的框架。本文将介绍如何使用Spring整合JUnit实现单元测试&#xff0c…...

Spring Boot学习路线1

Spring Boot是什么? Spring Boot是基于Spring Framework构建应用程序的框架,Spring Framework是一个广泛使用的用于构建基于Java的企业应用程序的开源框架。Spring Boot旨在使创建独立的、生产级别的Spring应用程序变得容易,您可以"只是…...

管理类联考——写作——论说文——实战篇——标题篇

角度3——4种材料类型、4个立意对象、5种写作态度 经过审题立意后,我们要根据我们的立意,确定一个主题,这个主题必须通过文章的标题直接表达出来。 标题的基本要求 主题清晰,态度明确 第一,阅卷人看到一篇论说文的标…...

idea中设置maven本地仓库和自动下载依赖jar包

1.下载maven 地址&#xff1a;maven3.6.3 解压缩在D:\apache-maven-3.6.3-bin\apache-maven-3.6.3\目录下新建文件夹repository打开apache-maven-3.6.3-bin\apache-maven-3.6.3\conf文件中的settings.xml编辑&#xff1a;新增本地仓库路径 <localRepository>D:\apache-…...

前缀和差分

前缀和 前缀和&#xff1a;一段序列里的前n项和 给出n个数&#xff0c;在给出q次问询&#xff0c;每次问询给出L、R&#xff0c;快速求出每组数组中一段L至R区间的和 给出一段数组&#xff0c;每次问询为求出l到r区间的和 普通方法&#xff1a;L到R进行遍历&#xff0c;那么…...

Golang GORM 模型定义

模型定义 参考文档&#xff1a;https://gorm.io/zh_CN/docs/models.html 模型一般都是普通的 Golang 的结构体&#xff0c;Go的基本数据类型&#xff0c;或者指针。 模型是标准的struct,由Go的基本数据类型、实现了Scanner和Valuer接口的自定义类型及其指针或别名组成&#x…...

微服务的各种边界在架构演进中的作用

演进式架构 在微服务设计和实施的过程中&#xff0c;很多人认为&#xff1a;“将单体拆分成多少个微服务&#xff0c;是微服务的设计重点。”可事实真的是这样吗&#xff1f;其实并非如此&#xff01; Martin Fowler 在提出微服务时&#xff0c;他提到了微服务的一个重要特征—…...

使用 docker-compose 一键部署多个 redis 实例

目录 1. 前期准备 2. 导入镜像 3. 部署redis master脚本 4. 部署redis slave脚本 5. 模板文件 6. 部署redis 7. 基本维护 1. 前期准备 新部署前可以从仓库&#xff08;repository&#xff09;下载 redis 镜像&#xff0c;或者从已有部署中的镜像生成文件&#xff1a; …...

14-测试分类

1.按照测试对象划分 ①界面测试 软件只是一种工具&#xff0c;软件与人的信息交流是通过界面来进行的&#xff0c;界面是软件与用户交流的最直接的一层&#xff0c;界面的设计决定了用户对设计的软件的第一印象。界面如同人的面孔&#xff0c;具有吸引用户的直接优势&#xf…...

打开域名跳转其他网站,官网被黑解决方案(Linux)

某天打开网站&#xff0c;发现进入首页&#xff0c;马上挑战到其他赌博网站。 事不宜迟&#xff0c;不能让客户发现&#xff0c;得马上解决 我的网站跳转到这个域名了 例如网站跳转到 k77.cc 就在你们部署的代码的当前文件夹下面&#xff0c;执行下如下命令 find -type …...

redis总结

1.redis redis高性能的key-value数据库&#xff0c;支持持久化&#xff0c;不仅仅支持简单的key-value&#xff0c;还提供了list&#xff0c;set&#xff0c;zset&#xff0c;hash等数据结构的存储&#xff0c;支持数据的备份&#xff08;master-slave模式&#xff09; redis&…...

现代C++中的从头开始深度学习:激活函数

一、说明 让我们通过在C中实现激活函数来获得乐趣。人工神经网络是生物启发模型的一个例子。在人工神经网络中&#xff0c;称为神经元的处理单元被分组在计算层中&#xff0c;通常用于执行模式识别任务。 在这个模型中&#xff0c;我们通常更喜欢控制每一层的输出以服从一些约束…...

python怎么实现tcp和udp连接

目录 什么是tcp连接 什么是udp连接 python怎么实现tcp和udp连接 什么是tcp连接 TCP&#xff08;Transmission Control Protocol&#xff09;连接是一种网络连接&#xff0c;它提供了可靠的、面向连接的数据传输服务。 在TCP连接中&#xff0c;通信的两端&#xff08;客户端和…...

java设计模式-观察者模式(jdk内置)

上一篇我们学习了 观察者模式。 观察者和被观察者接口都是我们自己定义的&#xff0c;整个设计模式我们从无到有都是自己设计的&#xff0c;其实&#xff0c;java已经内置了这个设计模式&#xff0c;我们只需要定义实现类即可。 下面我们不多说明&#xff0c;直接示例代码&am…...

秒级体验本地调试远程 k8s 中的服务

点击上方蓝色字体&#xff0c;选择“设为星标” 回复”云原生“获取基础架构实践 背景 在这个以k8s为云os的时代&#xff0c;程序员在日常的开发过程中&#xff0c;肯定会遇到各种问题&#xff0c;比如&#xff1a;本地开发完&#xff0c;需要部署到远程k8s集群&#xff0c;本地…...

CV前沿方向:Visual Prompting 视觉提示工程下的范式

prompt在视觉领域&#xff0c;也越来越重要&#xff0c;在图像生成&#xff0c;作为一种可控条件&#xff0c;增进交互和可控性&#xff0c;在多模态理解方面&#xff0c;指令prompt也使得任务灵活通用。视觉提示工程&#xff0c;已然成为CV一个前沿方向&#xff01; 下面来看看…...

物联网技术发展与应用研究分析

文章目录 引言一、物联网的基本架构&#xff08;一&#xff09;感知层&#xff08;二&#xff09;网络层&#xff08;三&#xff09;平台层&#xff08;四&#xff09;应用层 二、物联网的关键技术&#xff08;一&#xff09;传感器技术&#xff08;二&#xff09;通信技术&…...

WPF技术体系与现代化样式

目录 ​​1 WPF技术架构解析​​ ​​1.1 技术演进与定位​​ ​​1.2 核心机制对比​​ ​​2 样式与资源系统​​ ​​2.1 资源(Resource)定义与作用域​​ ​​2.2 样式(Style)与触发器​​ ​​3 开发环境配置(.NET 8)​​ ​​3.1 安装流程​​ ​​3.2 项目结…...

猜字符位置游戏-position gasses

import java.util.*;public class Main {/*字符猜位置游戏;每次提交只能被告知答对几个位置;根据提示答对的位置数推测出每个字符对应的正确位置;*/public static void main(String[] args) {char startChar A;int gameLength 8;List<String> ballList new ArrayList&…...

【后端】RPC

不定期更新。 定义 RPC 是 Remote Procedure Call 的缩写&#xff0c;中文通常翻译为远程过程调用。作用 简化分布式系统开发。实现微服务架构&#xff0c;便于模块化、复用。提高系统性能和可伸缩性。提供高性能通信、负载均衡、容错重试机制。 在现代分布式系统、微服务架构…...

springboot的test模块使用Autowired注入失败

springboot的test模块使用Autowired注入失败的原因&#xff1a; 注入失败的原因可能是用了junit4的包的Test注解 import org.junit.Test;解决方法&#xff1a;再加上RunWith(SpringRunner.class)注解即可 或者把Test由junit4改成junit5的注解&#xff0c;就不用加上RunWith&…...

PySide6 GUI 学习笔记——常用类及控件使用方法(单行文本控件QLineEdit)

文章目录 QLineEdit 介绍常用方法QLineEdit.EchoMode 取值光标相关方法文本选择方法输入格式化字符&#xff08;Input Mask&#xff09;常用信号QLineEdit 实例 QLineEdit 介绍 QLineEdit 是 PySide6&#xff08;Qt for Python&#xff09;中用于单行文本输入的控件。它支持文本…...

重构城市应急指挥布控策略 ——无人机智能视频监控的破局之道

在突发事件、高空巡查、边远区域布控中&#xff0c;传统摄像头常常“看不到、跟不上、调不动”。无人机智能视频监控系统&#xff0c;打破地面视角局限&#xff0c;以“高空布控 AI分析 实时响应”赋能政企单位智能化管理。在城市应急指挥中心的大屏上&#xff0c;一场暴雨正…...

Neovim - 打造一款属于自己的编辑器(一)

文章目录 前言&#xff08;劝退&#xff09;neovim 安装neovim 配置配置文件位置第一个 hello world 代码拆分 neovim 配置正式配置 neovim基础配置自定义键位Lazy 插件管理器配置tokyonight 插件配置BufferLine 插件配置自动补全括号 / 引号 插件配置 前言&#xff08;劝退&am…...

STM32学习笔记:定时器(TIM)原理与应用(详解篇)

前言 定时器是STM32微控制器中最重要且最常用的外设之一&#xff0c;它不仅能提供精确的定时功能&#xff0c;还能实现PWM输出、输入捕获、编码器接口等多种功能。本文将全面介绍STM32的通用定时器&#xff0c;包括其工作原理、配置方法和典型应用。 一、STM32定时器概述 定…...

【计算机组成原理】计算机硬件的基本组成、详细结构、工作原理

引言 计算机如同现代科技的“大脑”&#xff0c;其硬件结构的设计逻辑承载着信息处理的核心奥秘。从早期程序员手动输入指令的低效操作&#xff0c;到冯诺依曼提出“存储程序”概念引发的革命性突破&#xff0c;计算机硬件经历了从机械操控到自动化逻辑的蜕变。本文将深入拆解…...