21.过拟合和欠拟合示例
1. 背景介绍
在机器学习和深度学习中,过拟合和欠拟合是两个非常重要的概念。过拟合指的是模型在训练数据上表现很好,但在新的测试数据上效果变差的情况。欠拟合则是指模型无法很好地拟合训练数据的情况。这两种情况都会导致模型无法很好地泛化,影响最终的预测和应用效果。
为了帮助大家更好地理解过拟合和欠拟合的概念及其应对方法,我将通过一个基于PyTorch的代码示例来演示这两种情况的具体表现。我们将生成一个抛物线数据集,并定义三种不同复杂度的模型,分别对应欠拟合、正常拟合和过拟合的情况。通过可视化训练和测试误差的曲线图,以及预测结果的散点图,我们可以直观地观察到这三种情况下模型的拟合效果。
2. 核心概念与联系
过拟合和欠拟合是机器学习和深度学习中两个相互对应的概念:
1. 过拟合(Overfitting): 模型在训练数据上表现很好,但在新的测试数据上效果变差的情况。这通常是由于模型过于复杂,过度拟合了训练数据中的噪声和细节,导致无法很好地推广到未知数据。
2. 欠拟合(Underfitting): 模型无法很好地拟合训练数据的情况。这通常是由于模型过于简单,无法捕捉训练数据中的复杂模式和关系。
这两种情况都会导致模型在实际应用中无法很好地泛化,因此需要采取相应的措施来防止和缓解过拟合和欠拟合。常见的应对方法包括:
- 增加训练样本数量
- 减少模型复杂度(比如调整网络层数、神经元个数等)
- 使用正则化技术(如L1/L2正则化、Dropout等)
- 调整超参数(如学习率、批量大小等)
- 特征工程(如特征选择、降维等)
通过合理的模型设计和超参数调优,我们可以寻找到一个恰当的模型复杂度,使其既能很好地拟合训练数据,又能在新数据上保持良好的泛化性能。这就是机器学习中的**bias-variance tradeoff**,也是我们在实际应用中需要权衡的一个关键点。
3. 核心算法原理和具体操作步骤
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split# 生成数据
np.random.seed(42)
X = np.random.uniform(-5, 5, 500)
y = X**2 + 1 + np.random.normal(0, 1, 500)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 定义三种不同复杂度的模型
class UnderFitModel(nn.Module):def __init__(self):super(UnderFitModel, self).__init__()self.fc = nn.Linear(1, 1)def forward(self, x):return self.fc(x)class NormalFitModel(nn.Module):def __init__(self):super(NormalFitModel, self).__init__()self.fc1 = nn.Linear(1, 8)self.fc2 = nn.Linear(8, 1)self.activation = nn.ReLU()def forward(self, x):x = self.fc1(x)x = self.activation(x)x = self.fc2(x)return xclass OverFitModel(nn.Module):def __init__(self):super(OverFitModel, self).__init__()self.fc1 = nn.Linear(1, 32)self.fc2 = nn.Linear(32, 32)self.fc3 = nn.Linear(32, 1)self.activation = nn.ReLU()def forward(self, x):x = self.fc1(x)x = self.activation(x)x = self.fc2(x)x = self.activation(x)x = self.fc3(x)return x# 训练模型并记录误差
def train_and_evaluate(model, train_loader, test_loader):optimizer = torch.optim.SGD(model.parameters(), lr=0.005)criterion = nn.MSELoss()train_losses = []test_losses = []for epoch in range(100):model.train()train_loss = 0.0for inputs, targets in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()train_loss += loss.item()train_loss /= len(train_loader)train_losses.append(train_loss)model.eval()test_loss = 0.0with torch.no_grad():for inputs, targets in test_loader:outputs = model(inputs)loss = criterion(outputs, targets)test_loss += loss.item()test_loss /= len(test_loader)test_losses.append(test_loss)return train_losses, test_losses# 训练三种模型并可视化
under_fit_model = UnderFitModel()
normal_fit_model = NormalFitModel()
over_fit_model = OverFitModel()under_fit_train_losses, under_fit_test_losses = train_and_evaluate(under_fit_model, train_loader, test_loader)
normal_fit_train_losses, normal_fit_test_losses = train_and_evaluate(normal_fit_model, train_loader, test_loader)
over_fit_train_losses, over_fit_test_losses = train_and_evaluate(over_fit_model, train_loader, test_loader)plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(under_fit_train_losses, label='Under-fit Train Loss')
plt.plot(under_fit_test_losses, label='Under-fit Test Loss')
plt.plot(normal_fit_train_losses, label='Normal-fit Train Loss')
plt.plot(normal_fit_test_losses, label='Normal-fit Test Loss')
plt.plot(over_fit_train_losses, label='Over-fit Train Loss')
plt.plot(over_fit_test_losses, label='Over-fit Test Loss')
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.title('Training and Test Loss Curves')
plt.legend()plt.subplot(1, 2, 2)
plt.scatter(X_test, y_test, label='True')
plt.scatter(X_test, under_fit_model(X_test).detach().numpy(), label='Under-fit Prediction')
plt.scatter(X_test, normal_fit_model(X_test).detach().numpy(), label='Normal-fit Prediction')
plt.scatter(X_test, over_fit_model(X_test).detach().numpy(), label='Over-fit Prediction')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Test Set Predictions')
plt.legend()plt.show()
这个代码示例涵盖了我们之前讨论的各个步骤:
数据生成: 我们生成了一个抛物线形状的数据集,并使用train_test_split函数将其划分为训练集和测试集。
模型定义: 我们定义了三种不同复杂度的PyTorch模型,分别对应欠拟合、正常拟合和过拟合的情况。
训练与评估: 我们实现了一个train_and_evaluate函数,该函数负责训练模型并记录训练集和测试集上的损失。
可视化: 最后,我们使用matplotlib绘制了训练损失和测试损失的曲线图,以及在测试集上的预测结果。
欠拟合模型:训练误差和测试误差都较大,说明模型无法很好地拟合数据。在测试集上的预测结果也存在较大偏差。
正常拟合模型:训练误差和测试误差较为接近,说明模型的拟合效果较好。在测试集上的预测也比较准确。
过拟合模型:训练误差很小,但测试误差较大,说明模型在训练集上表现很好,但在新数据上泛化能力较差。在测试集上的预测结果存在一定偏差。
通过这个实例,我们可以直观地观察到不同复杂度模型在训练和泛化性能上的差异。欠拟合模型在训练集和测试集上的损失都较大,说明模型无法很好地拟合数据。正常拟合模型在训练集和测试集上的损失较为接近,说明模型具有较好的泛化能力。而过拟合模型在训练集上的损失很小,但在测试集上的损失较大,说明模型过于复杂,在新数据上泛化性能较差。
通过这种观察训练误差和测试误差的方法,我们可以及时发现模型存在的问题,并针对性地调整模型结构、添加正则化等手段来优化模型性能。这是机器学习和深度学习中非常基础和重要的实践技能。
相关文章:
21.过拟合和欠拟合示例
1. 背景介绍 在机器学习和深度学习中,过拟合和欠拟合是两个非常重要的概念。过拟合指的是模型在训练数据上表现很好,但在新的测试数据上效果变差的情况。欠拟合则是指模型无法很好地拟合训练数据的情况。这两种情况都会导致模型无法很好地泛化ÿ…...

使用import语句导入模块
自学python如何成为大佬(目录):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 创建模块后,就可以在其他程序中使用该模块了。要使用模块需要先以模块的形式加载模块中的代码,这可以使用import语句实现。im…...

一台FreeBSD笔记本突然鼠标乱动=>pf防火墙设置@FreeBSD
缘起 一台FreeBSD的笔记本,突然鼠标乱动 思考了下,可能原因有三: 1 无线鼠标干扰 正巧没带鼠标,但是插着无线鼠标usb,不知道是不是别人的鼠标跟这个usb串台了。 2 触摸板机械故障 也许是天热触摸板开始有故障了&…...
身份证OCR识别功能介绍
身份证OCR识别功能是一种基于光学字符识别(OCR)技术的解决方案,专门用于从身份证图像中快速、准确地提取和识别信息。以下是关于身份证OCR识别功能的详细介绍: 功能概述 身份证OCR识别功能通过高分辨率的摄像头或扫描仪获取身份证…...

一文看懂:MES定义和功能是什么,以及在数字化工厂的应用
MES应该是继ERP之后制造企业信息化最热门的管理软件,它适应产品个性化与敏捷化制造需求,满足生产过程精益管理而产生和发展起来的信息系统。 作为企业实现数字化与智能化的核心支撑技术与重要组成部分,MES在帮助制造企业走向数字化、智能化等…...

对 SQL 说“不”~
开发人员注意! 您在当前的应用程序架构中是否面临这些问题? 对 SQL 数据库的高吞吐量。SQL 数据库中的瓶颈。 内存数据存储将是解决问题的方案。Redis 是市场上最受欢迎的内存数据存储和缓存选项。Redis 拥有广泛的生态系统,因为主要科技巨…...

【爱空间_登录安全分析报告】
前言 由于网站注册入口容易被黑客攻击,存在如下安全问题: 暴力破解密码,造成用户信息泄露短信盗刷的安全问题,影响业务及导致用户投诉带来经济损失,尤其是后付费客户,风险巨大,造成亏损无底洞 …...

web前端三大主流框架
一、介绍 目前,前端开发领域的三大主流框架是: React:React是由Facebook开发并维护的开源JavaScript库,用于构建用户界面。它提供了一种声明式的组件化开发方式,能够高效地创建交互性的用户界面。React具有高性能、可…...

git获取的项目无法运行
一、Unsupported engine 问题:在使用命令npm install下载依赖项的时候就遇到了这个问题,有帖子说多试几次,其实这是提示node版本问题,版本的更新出现兼容性问题,多试几次也没用。 解决方案: 更新node.js的…...

java 原生http服务器 测试JS前端ajax访问实现跨域
后端 package Httpv3;import com.sun.net.httpserver.Headers; import com.sun.net.httpserver.HttpExchange; import com.sun.net.httpserver.HttpHandler; import com.sun.net.httpserver.HttpServer;import java.io.IOException; import java.io.OutputStream; import java…...

捋一捋C++中的逻辑运算(一)——表达式逻辑运算
注意,今天要谈的逻辑运算是C语言编程中的“与或非”逻辑运算,不是数学集合中的“交并补”逻辑运算。而编程中的逻辑运算又包括表达式逻辑运算和位逻辑运算,本章介绍表达式逻辑运算,下一章介绍位逻辑运算。 目录 一、几个基本的概…...

qcom 平台系统签名流程
security boot 平台的东东,oem 可定制的功能有限,只能参考平台文档,可以在高通的网站上搜索:Secure Boot Enablement,然后找对应平台的文档xxx-Secure Boot Enablement User Guide, step by step 操作即可 开机校验流…...

从零开始实现自己的串口调试助手(5) -实现HEX显示/发送/接收
实现HEX显示: HEX 显示 -- 其实就是 十六进制显示 --> a - 97(10) --> 61(16) 添加槽函数(bool): 实现槽函数: 注意: 注意QString 没有处理HEX显示的相关API 需要使用 toUtf-8 来 转换位QByteArry 类型, 利用其中的API 来处理HEX格式(toHex fromHex) vo…...

【计算机毕设】基于SpringBoot的民宿在线预定平台设计与实现 - 源码免费(私信领取)
免费领取源码 | 项目完整可运行 | v:chengn7890 诚招源码校园代理! 1. 研究目的 本研究旨在设计并实现一个基于SpringBoot的民宿在线预定平台。通过信息化手段提高民宿预定效率,方便用户查询房源、预定房间、在线支付和…...
大数据—数据分析概论
一、什么是数据分析 数据分析是指使用统计、数学、计算机科学和其他技术手段对数据进行清洗、转换、建模和解释的过程,以提取有用的信息、发现规律、支持决策和解决问题。数据分析可以应用于各种领域,包括商业、医学、工程、社会科学等。 二、数据分析步…...

centos7下卸载MySQL,Oracle数据库
📑打牌 : da pai ge的个人主页 🌤️个人专栏 : da pai ge的博客专栏 ☁️宝剑锋从磨砺出,梅花香自苦寒来 操作系统版本为CentOS 7 使⽤ MySQ…...

Spring解决循环依赖
Spring框架为了解决循环依赖问题,设计了一套三级缓存机制: 一级缓存singletonObjects:这个是最常规的缓存,用于存放完成初始化好的bean,如果某个bean已经在这个缓存了直接返回。二级缓存earlySigletonObjects:这个用于存放早期暴…...
RUST运算符重载
在 Rust 中,可以使用特征(traits)来实现运算符重载。运算符重载是通过实现相应的运算符特征(如 Add、Sub、Mul 等)来完成的。这些特征定义在 std::ops 模块中。下面是一个简单的示例,展示如何为一个自定义结…...
描述一下 Array.forEach() 循环和 Array.map() 方法之间的主要区别
Array.forEach() 和 Array.map() 都是 JavaScript 数组中常用的方法,但它们之间有一些重要的区别: 返回值:forEach():没有返回值,它只是对数组中的每个元素执行提供的函数。map():返回一个新的数组,其元素是通过对原数组的每个元素执行提供的函数后的结…...
在GEE中显示矢量或栅格数据的边界(包含样式设计)
需要保证最后显示的数据是一个 FeatureCollection 对象。 如果数据是一个 Geometry 或 Image,我们也可以使用 style 方法来设置样式并将其添加到地图上。以下是针对不同类型对象的处理方式: 1 Geometry对象 如果 table 是一个 Geometry 对象ÿ…...
Python爬虫实战:研究MechanicalSoup库相关技术
一、MechanicalSoup 库概述 1.1 库简介 MechanicalSoup 是一个 Python 库,专为自动化交互网站而设计。它结合了 requests 的 HTTP 请求能力和 BeautifulSoup 的 HTML 解析能力,提供了直观的 API,让我们可以像人类用户一样浏览网页、填写表单和提交请求。 1.2 主要功能特点…...

突破不可导策略的训练难题:零阶优化与强化学习的深度嵌合
强化学习(Reinforcement Learning, RL)是工业领域智能控制的重要方法。它的基本原理是将最优控制问题建模为马尔可夫决策过程,然后使用强化学习的Actor-Critic机制(中文译作“知行互动”机制),逐步迭代求解…...

大型活动交通拥堵治理的视觉算法应用
大型活动下智慧交通的视觉分析应用 一、背景与挑战 大型活动(如演唱会、马拉松赛事、高考中考等)期间,城市交通面临瞬时人流车流激增、传统摄像头模糊、交通拥堵识别滞后等问题。以演唱会为例,暖城商圈曾因观众集中离场导致周边…...

IoT/HCIP实验-3/LiteOS操作系统内核实验(任务、内存、信号量、CMSIS..)
文章目录 概述HelloWorld 工程C/C配置编译器主配置Makefile脚本烧录器主配置运行结果程序调用栈 任务管理实验实验结果osal 系统适配层osal_task_create 其他实验实验源码内存管理实验互斥锁实验信号量实验 CMISIS接口实验还是得JlINKCMSIS 简介LiteOS->CMSIS任务间消息交互…...
Java多线程实现之Thread类深度解析
Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...

Unity | AmplifyShaderEditor插件基础(第七集:平面波动shader)
目录 一、👋🏻前言 二、😈sinx波动的基本原理 三、😈波动起来 1.sinx节点介绍 2.vertexPosition 3.集成Vector3 a.节点Append b.连起来 4.波动起来 a.波动的原理 b.时间节点 c.sinx的处理 四、🌊波动优化…...

基于Java+VUE+MariaDB实现(Web)仿小米商城
仿小米商城 环境安装 nodejs maven JDK11 运行 mvn clean install -DskipTestscd adminmvn spring-boot:runcd ../webmvn spring-boot:runcd ../xiaomi-store-admin-vuenpm installnpm run servecd ../xiaomi-store-vuenpm installnpm run serve 注意:运行前…...
「全栈技术解析」推客小程序系统开发:从架构设计到裂变增长的完整解决方案
在移动互联网营销竞争白热化的当下,推客小程序系统凭借其裂变传播、精准营销等特性,成为企业抢占市场的利器。本文将深度解析推客小程序系统开发的核心技术与实现路径,助力开发者打造具有市场竞争力的营销工具。 一、系统核心功能架构&…...
在golang中如何将已安装的依赖降级处理,比如:将 go-ansible/v2@v2.2.0 更换为 go-ansible/@v1.1.7
在 Go 项目中降级 go-ansible 从 v2.2.0 到 v1.1.7 具体步骤: 第一步: 修改 go.mod 文件 // 原 v2 版本声明 require github.com/apenella/go-ansible/v2 v2.2.0 替换为: // 改为 v…...

向量几何的二元性:叉乘模长与内积投影的深层联系
在数学与物理的空间世界中,向量运算构成了理解几何结构的基石。叉乘(外积)与点积(内积)作为向量代数的两大支柱,表面上呈现出截然不同的几何意义与代数形式,却在深层次上揭示了向量间相互作用的…...