李沐深度学习记录4:12.权重衰减/L2正则化
权重衰减从零开始实现
#高维线性回归
%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l#整个流程是,1.生成标准数据集,包括训练数据和测试数据
# 2.定义线性模型训练
# 模型初始化(函数)、包含惩罚项的损失(函数)
# 定义epochs进行训练,每训练5轮评估一次模型在训练集和测试集的损失,画图显示
# 训练结束后分别查看并比较是否添加范数惩罚项损失对应的训练结果w的L2范数
#生成数据集
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5 #训练数据样本数20,测试样本数100,数据维度200,批量大小5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05 #生成w矩阵(200,1),w值0.01,偏置b为0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train) #生成训练数据集X(20,200),y(20,1),y=Xw+b+噪声,train_data接收返回的X,y
train_iter = d2l.load_array(train_data, batch_size) #传入数据集和批量大小,构造训练数据迭代器
test_data = d2l.synthetic_data(true_w, true_b, n_test) #生成测试数据集
test_iter = d2l.load_array(test_data, batch_size, is_train=False) #构造测试数据迭代器#初始化模型参数
def init_params():w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)b = torch.zeros(1, requires_grad=True)return [w, b]#定义L2范数惩罚项
def l2_penalty(w):return torch.sum(w.pow(2)) / 2 #L2范数公式需要开平方根,但这里L2范数惩罚项是L2范数的平方,所以不需要开平方根了#训练代码
def train(lambd): #输入λ超参数w, b = init_params() #初始化模型参数net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss #net线性模型torch.matmul(X, w) + b;loss是均方误差num_epochs, lr = 100, 0.003animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs): #进行多次迭代训练for X, y in train_iter: #每个epoch,取训练数据# 增加了L2范数惩罚项,# 广播机制使l2_penalty(w)成为一个长度为batch_size的向量l = loss(net(X), y) + lambd * l2_penalty(w) #loss计算加上了λ×范数惩罚项l.sum().backward() #这里计算损失和,下面参数更新时会对梯度求平均再更新参数d2l.sgd([w, b], lr, batch_size) #进行参数更新操作if (epoch + 1) % 5 == 0: #每5次epoch训练,评估一次模型的训练损失和测试损失animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数是:', torch.norm(w).item()) #训练结束后,计算w的L2范数(没有平方)
#λ为0,无正则化项,训练
train(lambd=0)
d2l.plt.show()

#λ为10,有正则化项,训练
train(lambd=5)
d2l.plt.show()

权重衰减的简洁实现
#权重衰减的简洁实现
def train_concise(wd):net = nn.Sequential(nn.Linear(num_inputs, 1)) #定义模型for param in net.parameters(): #初始化参数param.data.normal_()loss = nn.MSELoss(reduction='none') #计算loss,这里不包含正则项num_epochs, lr = 100, 0.003# 偏置参数没有衰减#在参数优化部分,计算梯度时加入了权重衰减#所以是计算loss时没计算正则项,只是在计算梯度时加入了权重衰减吗?trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd},{"params":net[0].bias}], lr=lr)animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs): #训练100轮for X, y in train_iter: #对于每轮,取数据训练trainer.zero_grad() #梯度清零l = loss(net(X), y) #计算lossl.mean().backward() #反向传播trainer.step() #更新梯度if (epoch + 1) % 5 == 0: #每5轮评估一次模型在测试集和训练集的损失animator.add(epoch + 1,(d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数:', net[0].weight.norm().item())
#没有进行权重衰减
train_concise(0)

#进行权重衰减
train_concise(5)

相关文章:
李沐深度学习记录4:12.权重衰减/L2正则化
权重衰减从零开始实现 #高维线性回归 %matplotlib inline import torch from torch import nn from d2l import torch as d2l#整个流程是,1.生成标准数据集,包括训练数据和测试数据 # 2.定义线性模型训练 # 模型初始化(函…...
堆--数组中第K大元素
如果对于堆不是太认识,请点击:堆的初步认识-CSDN博客 解题思路: /*** <h3>求数组中第 K 大的元素</h3>* <p>* 解体思路* <ol>* 1.向小顶堆放入前k个元素* 2.剩余元素* 若 < 堆顶元素, 则略过* …...
ipad使用技巧
1、goodnotes中批量导入pdf文件 法一: 直接参考视频: 【目前为止所知iPad上goodnotes批量导入网盘文件最快的方法】 大致步骤:pdf文件传到百度网盘,然后ES软件登录百度网盘,在goodnotes中导入,选择ES&a…...
Windows系统上使用CLion远程开发Linux程序
CLion远程开发Linux程序 情景说明Ubuntu配置CLion配置同步 情景说明 在Windows系统上使用CLion开发Linux程序,安装CLion集成化开发环境时会自动安装cmake、mingw,代码提示功能也比较友好。 但是在socket开发时,包含sys/socket.h头文件时&am…...
github搜索技巧
指定语言 language:java 比如我要找用java写的含有blog的内容 搜索项目名称包含关键词的内容 vue in:name 其他如项目描述跟项目文档,如下 组合使用 vue in:name,description,readme 根据Star 或者fork的数量来查找 总结 springboot vue stars:>1000 p…...
Python生成器
生成器 Generators 要理解生成器,首先要理解迭代器,迭代器由以下三个部分组成: 可迭代对象(iterable)迭代器(iterator)迭代(iteration) 1. 可迭代对象 只要定义了可以…...
flutter开发实战-使用FutureBuilder异步数据更新Widget
flutter开发实战-使用FutureBuilder异步数据更新Widget 在开发过程中,经常遇到需要依赖异步数据更新Widget的情况,如下载图片后显示Widget,获取到某个数据时候,显示在对应的UI界面上,都可以使用FutureBuilder异步数据…...
1.2 数据模型
思维导图: 前言: **1.2.1 什么是模型** - **定义**:模型是对现实世界中某个对象特征的模拟和抽象。例如,一张地图、建筑设计沙盘或精致的航模飞机都可以视为具体的模型。 - **具体模型与现实生活**:具体模型可以很容…...
【实用工具】谷歌浏览器插件开发指南
谷歌浏览器插件开发指南涉及以下几个方面: 1. 开发环境准备:首先需要安装Chrome浏览器和开发者工具。进入Chrome应用商店,搜索“Extensions Reloader”和“Manifest Viewer”两个插件进行安装,这两个插件可以方便开发和调试。 2…...
应用层协议——DNS、DHCP、HTTP、FTP
目录 1、DNS 协议 1-1)Hosts 文件 1-2)DNS 系统 1-3)域名的组成、分类和树状结构 1-4)DNS 域名服务器类型 1-5)DNS 查询方式 1-6)DNS 域名解析的一般步骤 1-7)对象类型与资源记录 2、D…...
XML文件读写
0、.pro文件添加依赖 QT xml1、使用 QDomDocument 方式 #include <QtXml/QDomDocument> #include <QtXml/QDomProcessingInstruction> #include <QtXml/QDomElement> #include <QFile> #include <QTextStream> #include <QDebug>bo…...
Win11 安装 Vim
安装包: 链接:https://pan.baidu.com/s/1Ru7HhTSotz9mteHug-Yhpw?pwd6666 提取码:6666 双击安装包,一直下一步。 配置环境变量: 先配置系统变量中的path: 接着配置用户变量: 在 cmd 中输入…...
Mac电脑BIM建模软件 Archicad 26 for Mac最新
ARCHICAD 软件特色 智能化 在2D CAD中,所有的建筑构件都由线条构成和表现,仅仅是一些线条的组合而已,当我们阅读图纸的时候是按照制图规范来读取这些信息。我们用一组线条表示平面中的窗,再用另一组不同的线条在立面中表示同一个…...
JavaEE-网络编程套接字(UDP/TCP)
下面写一个简单的UDP客户端服务器流程 思路: 对于服务器端:读取请求,并解析–> 根据解析出的请求,做出响应(这里是一个回显,)–>把响应写回客户端 对于客户端:从控制台读取用户输入的内容–>从控制…...
微服务技术栈-Gateway服务网关
文章目录 前言一、为什么需要网关二、Spring Cloud Gateway三、断言工厂和过滤器1.断言工厂2.过滤器3.全局过滤器4.过滤器执行顺序 四、跨域问题总结 前言 在之前的文章中我们已经介绍了微服务技术中eureka、nacos、ribbon、Feign这几个组件,接下来将介绍另外一个组…...
函数形状有几种定义方式;操作符infer的作用
在 TypeScript 中,函数形状可以用多种方式进行定义。下面介绍了几种常用的函数形状定义方式: 函数声明: function add(a: number, b: number): number {return a b; }在函数声明中,我们直接使用 function 关键字来声明函数&…...
Java / MybatisPlus:JSON处理器的应用,在实体对象中设置对象属性,对象嵌套对象
1、数据库设计 2、定义内部的实体类 /*** Author lgz* Description* Date 2023/9/30.*/ Data // 静态构造staticName,方便构造对象并赋予属性 AllArgsConstructor(staticName "of") NoArgsConstructor ApiModel(value "亲友", description …...
力扣 -- 1027. 最长等差数列
解题步骤: 参考代码: class Solution { public:int longestArithSeqLength(vector<int>& nums) {int nnums.size();int ret2;unordered_map<int,int> hash;//这里可以先把nums[0]存进哈希表中,方便后面i从1开始遍历hash[num…...
正则验证用户名和跨域postmessage
正则验证用户名 字母数字符号大小写8-14匹配用户名的 <!DOCTYPE html> <html> <head><meta charset"utf-8"><meta name"viewport" content"widthdevice-width, initial-scale1"><title>form</title> …...
jsbridge实战1:xcode swift 构建iOS app
[[toc]] 环境安装 macOs: 10.15.5 xcode: 11.6 demo:app 创建 hello world iOS app 创建工程步骤 选择:Create a new Xcode project选择:iOS-> single View App填写: project name: swift-app-helloidentifer: smile 包名language: s…...
Linux 文件类型,目录与路径,文件与目录管理
文件类型 后面的字符表示文件类型标志 普通文件:-(纯文本文件,二进制文件,数据格式文件) 如文本文件、图片、程序文件等。 目录文件:d(directory) 用来存放其他文件或子目录。 设备…...
蓝桥杯 2024 15届国赛 A组 儿童节快乐
P10576 [蓝桥杯 2024 国 A] 儿童节快乐 题目描述 五彩斑斓的气球在蓝天下悠然飘荡,轻快的音乐在耳边持续回荡,小朋友们手牵着手一同畅快欢笑。在这样一片安乐祥和的氛围下,六一来了。 今天是六一儿童节,小蓝老师为了让大家在节…...
将对透视变换后的图像使用Otsu进行阈值化,来分离黑色和白色像素。这句话中的Otsu是什么意思?
Otsu 是一种自动阈值化方法,用于将图像分割为前景和背景。它通过最小化图像的类内方差或等价地最大化类间方差来选择最佳阈值。这种方法特别适用于图像的二值化处理,能够自动确定一个阈值,将图像中的像素分为黑色和白色两类。 Otsu 方法的原…...
如何为服务器生成TLS证书
TLS(Transport Layer Security)证书是确保网络通信安全的重要手段,它通过加密技术保护传输的数据不被窃听和篡改。在服务器上配置TLS证书,可以使用户通过HTTPS协议安全地访问您的网站。本文将详细介绍如何在服务器上生成一个TLS证…...
Neo4j 集群管理:原理、技术与最佳实践深度解析
Neo4j 的集群技术是其企业级高可用性、可扩展性和容错能力的核心。通过深入分析官方文档,本文将系统阐述其集群管理的核心原理、关键技术、实用技巧和行业最佳实践。 Neo4j 的 Causal Clustering 架构提供了一个强大而灵活的基石,用于构建高可用、可扩展且一致的图数据库服务…...
Android 之 kotlin 语言学习笔记三(Kotlin-Java 互操作)
参考官方文档:https://developer.android.google.cn/kotlin/interop?hlzh-cn 一、Java(供 Kotlin 使用) 1、不得使用硬关键字 不要使用 Kotlin 的任何硬关键字作为方法的名称 或字段。允许使用 Kotlin 的软关键字、修饰符关键字和特殊标识…...
HDFS分布式存储 zookeeper
hadoop介绍 狭义上hadoop是指apache的一款开源软件 用java语言实现开源框架,允许使用简单的变成模型跨计算机对大型集群进行分布式处理(1.海量的数据存储 2.海量数据的计算)Hadoop核心组件 hdfs(分布式文件存储系统)&a…...
JavaScript基础-API 和 Web API
在学习JavaScript的过程中,理解API(应用程序接口)和Web API的概念及其应用是非常重要的。这些工具极大地扩展了JavaScript的功能,使得开发者能够创建出功能丰富、交互性强的Web应用程序。本文将深入探讨JavaScript中的API与Web AP…...
Git 3天2K星标:Datawhale 的 Happy-LLM 项目介绍(附教程)
引言 在人工智能飞速发展的今天,大语言模型(Large Language Models, LLMs)已成为技术领域的焦点。从智能写作到代码生成,LLM 的应用场景不断扩展,深刻改变了我们的工作和生活方式。然而,理解这些模型的内部…...
解析两阶段提交与三阶段提交的核心差异及MySQL实现方案
引言 在分布式系统的事务处理中,如何保障跨节点数据操作的一致性始终是核心挑战。经典的两阶段提交协议(2PC)通过准备阶段与提交阶段的协调机制,以同步决策模式确保事务原子性。其改进版本三阶段提交协议(3PC…...
