【深度学习-pytorch篇】4. 正则化方法(Regularization Techniques)
正则化方法(Regularization Techniques)
1. 目标
- 理解什么是过拟合及其影响
- 掌握常见正则化技术:L2 正则化、Dropout、Batch Normalization、Early Stopping
- 能够使用 PyTorch 编程实现这些正则化方法并进行比较分析
2. 数据构造与任务设定
本实验是一个带噪声的回归任务,目标函数为 y = x + N ( 0 , σ 2 ) y = x + \mathcal{N}(0, \sigma^2) y=x+N(0,σ2)。使用均匀分布采样输入 x ∈ [ − 1 , 1 ] x \in [-1, 1] x∈[−1,1]。
import numpy as np
import torch
import torch.utils.data as DataN_SAMPLES = 20
NOISE_RATE = 0.4train_x = np.linspace(-1, 1, N_SAMPLES)[:, np.newaxis]
train_y = train_x + np.random.normal(0, NOISE_RATE, train_x.shape)validate_x = np.linspace(-1, 1, N_SAMPLES // 2)[:, np.newaxis]
validate_y = validate_x + np.random.normal(0, NOISE_RATE, validate_x.shape)test_x = np.linspace(-1, 1, N_SAMPLES)[:, np.newaxis]
test_y = test_x + np.random.normal(0, NOISE_RATE, test_x.shape)# 转换为 Tensor
train_x = torch.tensor(train_x, dtype=torch.float32)
train_y = torch.tensor(train_y, dtype=torch.float32)
validate_x = torch.tensor(validate_x, dtype=torch.float32)
validate_y = torch.tensor(validate_y, dtype=torch.float32)
test_x = torch.tensor(test_x, dtype=torch.float32)
test_y = torch.tensor(test_y, dtype=torch.float32)train_dataset = Data.TensorDataset(train_x, train_y)
train_loader = Data.DataLoader(dataset=train_dataset, batch_size=10, shuffle=True)
3. 模型定义
3.1 原始 MLP(无正则化)
import torch.nn as nn
import torch.nn.init as initclass FC_Classifier(nn.Module):def __init__(self, input_dim=1, hidden_dim=100, output_dim=1):super().__init__()self.fc1 = nn.Linear(input_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, output_dim)self.activation = nn.ReLU()self._init_weights()def _init_weights(self):init.normal_(self.fc1.weight, mean=0.0, std=0.1)init.constant_(self.fc1.bias, 0)init.normal_(self.fc2.weight, mean=0.0, std=0.1)init.constant_(self.fc2.bias, 0)def forward(self, x):x = self.activation(self.fc1(x))return self.fc2(x)
3.2 Dropout MLP
class DropoutMLP(nn.Module):def __init__(self, dropout_rate=0.5):super().__init__()self.fc1 = nn.Linear(1, 100)self.dropout = nn.Dropout(dropout_rate)self.fc2 = nn.Linear(100, 1)self.activation = nn.ReLU()self._init_weights()def _init_weights(self):init.normal_(self.fc1.weight, mean=0.0, std=0.1)init.constant_(self.fc1.bias, 0)init.normal_(self.fc2.weight, mean=0.0, std=0.1)init.constant_(self.fc2.bias, 0)def forward(self, x):x = self.dropout(self.fc1(x))x = self.activation(x)return self.fc2(x)
3.3 Batch Normalization MLP
class BNMLP(nn.Module):def __init__(self):super().__init__()self.bn_input = nn.BatchNorm1d(1)self.fc1 = nn.Linear(1, 100)self.bn_hidden = nn.BatchNorm1d(100)self.fc2 = nn.Linear(100, 1)self.activation = nn.ReLU()def forward(self, x):x = self.bn_input(x)x = self.fc1(x)x = self.bn_hidden(x)x = self.activation(x)return self.fc2(x)
4. Early Stopping 策略
当验证集误差连续若干轮无提升时,提前停止训练,避免过拟合。
max_patience = 5
patience = 0
best_val_loss = float("inf")
is_early_stop = False
5. RMSNorm 实现与讲解
5.1 原理说明
RMSNorm 是一种替代 LayerNorm 的轻量化归一化方法:
- 不减均值
- 仅用激活值的均方根进行归一化
- 不依赖 batch 维度
数学公式:
RMS ( x ) = 1 n ∑ i = 1 n x i 2 \text{RMS}(x) = \sqrt{\frac{1}{n} \sum_{i=1}^n x_i^2} RMS(x)=n1i=1∑nxi2
RMSNorm ( x ) = x RMS ( x ) + ϵ ⋅ γ \text{RMSNorm}(x) = \frac{x}{\text{RMS}(x) + \epsilon} \cdot \gamma RMSNorm(x)=RMS(x)+ϵx⋅γ
其中 γ \gamma γ 为可学习参数, ϵ \epsilon ϵ 是一个很小的数避免除以 0。
5.2 代码实现
class RMSNorm(nn.Module):def __init__(self, hidden_size, eps=1e-6):super().__init__()self.weight = nn.Parameter(torch.ones(hidden_size))self.eps = epsdef forward(self, x):rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)return self.weight * x / rms
5.3 与其他归一化对比
方法 | 是否减均值 | 是否除方差 | 是否依赖 batch |
---|---|---|---|
BatchNorm | 是 | 是 | 是 |
LayerNorm | 是 | 是 | 否 |
RMSNorm | 否 | 是 (仅 RMS) | 否 |
6. 实验建议
- 尝试不同的 Dropout 比例(如 0.1 / 0.3 / 0.5)并观察效果;
- 对比是否每层都加 BatchNorm 是否更优;
- 比较 L2 正则项中 weight decay 的不同取值;
- 使用 RMSNorm 替代 LayerNorm 做对比实验。
相关文章:
【深度学习-pytorch篇】4. 正则化方法(Regularization Techniques)
正则化方法(Regularization Techniques) 1. 目标 理解什么是过拟合及其影响掌握常见正则化技术:L2 正则化、Dropout、Batch Normalization、Early Stopping能够使用 PyTorch 编程实现这些正则化方法并进行比较分析 2. 数据构造与任务设定 …...

ESP8266+STM32 AT驱动程序,心知天气API 记录时间: 2025年5月26日13:24:11
接线为 串口2 接入ESP8266 esp8266.c #include "stm32f10x.h"//8266预处理文件 #include "esp8266.h"//硬件驱动 #include "delay.h" #include "usart.h"//用得到的库 #include <string.h> #include <stdio.h> #include …...
WPF【11_5】WPF实战-重构与美化(MVVM 实战)
11-10 【重构】创建视图模型,显示客户列表 正式进入 MVVM 架构的代码实战。在之前的课程中, Model 和 View 这部分的代码重构实际上已经完成了。 Model 就是在 Models 文件夹中看到的两个文件, Customer 和 Appointment。 而 View 则是所有与…...

⭐️⭐️⭐️ 模拟题及答案 ⭐️⭐️⭐️ 大模型Clouder认证:RAG应用构建及优化
考试注意事项: 一、单选题(21题) 检索增强生成(RAG)的核心技术结合了什么? A. 图像识别与自然语言处理 B. 信息检索与文本生成 C. 语音识别与知识图谱 D. 数据挖掘与机器学习 RAG技术中,“建立索引”步骤不包括以下哪项操作? A. 将文档解析为纯文本 B. 文本片段分割(…...

kali系统的安装及配置
1 kali下载 Kali 下载地址:Get Kali | Kali Linux (https://www.kali.org/get-kali) 下载 kali-linux-2024.4-installer-amd64.iso (http://cdimage.kali.org/kali-2024.4/) 2. 具体安装步骤: 2.1 进入官方地址,点击…...
CSS--background-repeat详解
属性介绍 background-repeat 属性在CSS中用于控制背景图像是否以及如何重复。当背景图像的尺寸小于其容器的尺寸时,该属性决定了图像如何填充额外的空间。默认情况下,背景图像会在横向和纵向上重复,直到覆盖整个元素。 常见取值 repeat …...

Redis的大Key问题如何解决?
大家好,我是锋哥。今天分享关于【Redis的大Key问题如何解决?】面试题。希望对大家有帮助; Redis的大Key问题如何解决? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 Redis中的“大Key”问题是指某个键的值占用了过多…...

影楼精修-AI追色算法解析
注意:本文样例图片为了避免侵权,均使用AIGC生成; AI追色是像素蛋糕软件中比较受欢迎的一个功能点,本文将针对AI追色来解析一下大概的技术原理。 功能分析 AI追色实际上可以理解为颜色迁移的一种变体或者叫做升级版,…...

node入门:安装和npm使用
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、安装npm命令nvm 前言 因为学习vue接触的,一直以为node是和vue绑定的,还以为vue跑起来必须要node,后续发现并不是。 看…...
‘js@https://registry.npmmirror.com/JS/-/JS-0.1.0.tgz‘ is not in this registry
解决方法: 1. npm cache clean --force 2.临时切换到官方源 npm config set registry https://registry.npmjs.org/ npm install js0.1.0 npm config set registry https://registry.npmmirror.com/ # 切换回镜像源...
el-table-column如何获取行数据的值
在Element UI的el-table组件中,你可以通过el-table-column的slot-scope属性(在Vue 2.x中)或者#default插槽的scope属性(在Vue 3.x中)来获取当前行的数据。以下是如何实现这一功能的详细步骤: 在el-table-…...
leetcode450.删除二叉搜索树中的节点:迭代法巧用中间节点应对多场景删除
一、题目深度解析与BST特性剖析 在二叉搜索树(BST)中删除节点,需确保删除操作后树依然保持BST特性。题目要求我们根据给定的节点值key,在BST中删除对应节点。BST的核心特性是左子树所有节点值小于根节点值,右子树所有…...

java虚拟机2
一、垃圾回收机制(GC) 1. 回收区域:GC主要回收堆内存区域。堆用于存放new出来的对象 。程序计数器、元数据区和栈一般不是GC回收的重点区域。 2. 回收单位:GC以对象为单位回收内存,而非字节。按对象维度回收更简便&am…...
自监督软提示调优:跨域NLP新突破
自监督的软提示调优方法(SPSS) 这篇论文提出了一种基于自监督的软提示调优方法(SPSS),用于无监督领域自适应。其核心目标是通过挖掘源域和目标域的内部知识,解决传统提示调优在跨域场景中依赖通用知识、模板生成低效的问题。 一、核心实现原理 1. 自监督分层聚类优化(…...

Pydantic 学习与使用
Pydantic 学习与使用 在 Fastapi 的 Web 开发中的数据验证通常都是在使用 Pydantic 来进行数据的校验,本文将对 Pydantic 的使用方法做记录与学习。 **简介:**Pydantic 是一个在 Python 中用于数据验证和解析的第三方库,它现在是 Python 使…...

PCB设计教程【入门篇】——电路分析基础-基本元件(二极管三极管场效应管)
前言 本教程基于B站Expert电子实验室的PCB设计教学的整理,为个人学习记录,旨在帮助PCB设计新手入门。所有内容仅作学习交流使用,无任何商业目的。若涉及侵权,请随时联系,将会立即处理、 目录 前言 1.二极管 1.发光…...

能按需拆分 PDF 为多个文档的工具
软件介绍 彩凤 PDF 拆分精灵是一款具备 PDF 拆分功能的软件。 功能特点 PDF 拆分功能较为常见,很多 PDF 软件都具备,例如 DC 软件提取 PDF 较为方便,但它不能从一个 PDF 里提取出多个 PDF。据印象,其他 PDF 软件也似乎没有能从…...

Apifox 5 月产品更新|数据模型支持查看「引用资源」、调试 AI 接口可实时预览 Markdown、性能优化
Apifox 新版本上线啦! 看看本次版本更新主要涵盖的重点内容,有没有你所关注的功能特性: 自动解析 JSON 参数名和参数值调试 AI 接口时,可预览 Markdown 格式的内容性能优化:新增「实验性功能」选项 使用独立进程执行…...

LiveGBS海康、大华、宇视、华为摄像头GB28181国标语音对讲及语音喊话:摄像头设备与服务HTTPS准备
LiveGBS海康、大华、宇视、华为摄像头GB28181国标语音对讲及语音喊话:摄像头设备与服务HTTPS准备 1、背景2、准备工作2.1、服务端必备条件(注意事项)2.2、语音对讲设备准备2.2.1、大华摄像机2.2.2、海康摄像机 3、开启音频并开始对讲4、相关问…...

Sqlalchemy 连mssql坑
连接失败: (pyodbc.OperationalError) (08001, [08001] [Microsoft][ODBC Driver 17 for SQL Server]SSL Provider: [error:0A00014D:SSL routines::legacy sigalg disallowed or unsupported] (-1) (SQLDriverConnect)) (Background on this error at: https://sqlalche.me/e/…...
Prompt Engineering 提示工程介绍与使用/调试技巧
1. 介绍 Prompt Engineering 是一种人工智能(AI)技术,它通过设计和改进 AI 的 prompt 来提高 AI 的表现。Prompt Engineering 的目标是创建高度有效和可控的 AI 系统,使其能够准确、可靠地执行特定任务。 如果你从来没有使用过Pr…...

LLaMaFactory - 支持的模型和模板 常用命令
一、 环境准备 激活LLaMaFactory环境,进入LLaMaFactory目录 cd LLaMA-Factoryconda activate llamafactory 下载模型 #模型下载 from modelscope import snapshot_download model_dir snapshot_download(Qwen/Qwen2.5-0.5B-Instruct) 二、启动一个 Qwen3-0.6B…...

大模型深度学习之双塔模型
前言 双塔模型(Two-Tower Model)是一种在推荐系统、信息检索和自然语言处理等领域广泛应用的深度学习架构。其核心思想是通过两个独立的神经网络(用户塔和物品塔)分别处理用户和物品的特征,并在共享的语义空间中通过相…...
MySQL 8主从同步实战指南:从原理到高可用架构落地
MySQL 8主从同步实战指南:从原理到高可用架构落地 本文将用3000字深度解析MySQL 8主从复制机制,配合全流程部署指南及电商平台实战案例,助你构建高性能数据库集群 一、主从复制核心原理剖析 1.1 复制架构全景图 #mermaid-svg-vdts3hTIyCtz4byk {font-family:"trebuche…...

瑞数6代jsvmp简单分析(天津电子税x局)
国际惯例 今天帮朋友看一个gov网站的瑞数加密(天津电子税x局) 传送门(登陆入口界面) 瑞数6特征 1.服务器会发两次包,第一次响应状态码为412,第二次响应状态码为200。 2.有三重debugger,其中有…...
缓存架构方案:Caffeine + Redis 双层缓存架构深度解析
在高并发、低延迟的现代互联网系统中,缓存是提升系统性能和稳定性的重要手段。随着业务复杂度的增长,单一缓存方案(如仅使用Redis或仅使用本地缓存)已难以满足高性能与一致性需求。 本文将围绕 Caffeine Redis 的双层缓存架构展…...
AI笔记 - 模型调试 - 调试方式
模型调试方式 基础信息打印模型信息计算参数量和计算量过滤原则profile方法get_model_complexity_info方法FlopCountAnalysis方法 基础信息 # 打印执行的设备数量:device_count:1 print(f"device_count:{torch.cuda.device_count()}")# 打印当前网络执行…...

榕壹云物品回收系统实战案例:基于ThinkPHP+MySQL+UniApp的二手物品回收小程序开发与优化
摘要:本文深入解析了一款基于ThinkPHPMySQLUniApp框架开发的二手物品回收小程序——榕壹云物品回收系统的技术实现与商业价值。通过剖析项目背景、核心技术架构、功能特性及系统优势,为开发者与潜在客户提供全面的参考指南,助力资源循环利用与…...

《软件工程》第 9 章 - 软件详细设计
目录 9.1 详细设计的任务与过程模型 9.2 用例设计 9.2.1 设计用例实现方案 9.2.2 构造设计类图 9.2.3 整合并优化用例实现方案 9.3 子系统设计 9.3.1 确立内部设计元素 9.3.2 导出设计类图 9.4 构件设计 9.5 类设计 9.5.1 精化类间关系 9.5.2 精化属性和操作 9.5.…...

WebVm:无需安装,一款可以在浏览器运行的 Linux 来了
WebVM 是一款可以在浏览器中运行的Linux虚拟机。不是那种HTMLJavaScript模拟的UI,完全通过HTML5/WebAssembly技术实现客户端运行。通过集成CheerpX虚拟化引擎,可直接在浏览器中运行未经修改的Debian系统。 Stars 数13054Forks 数2398 主要特点 完整 Lin…...