Python Day37
Task:
1.过拟合的判断:测试集和训练集同步打印指标
2.模型的保存和加载
a.仅保存权重
b.保存权重和模型
c.保存全部信息checkpoint,还包含训练状态
3.早停策略
1. 过拟合的判断:测试集和训练集同步打印指标
过拟合是指模型在训练数据上表现良好,但在未见过的新数据(测试数据)上表现较差。 判断过拟合的关键在于比较训练集和测试集的性能。
- 训练集指标: 准确率、损失等。 训练集指标通常会随着训练的进行而持续提高(准确率)或降低(损失)。
- 测试集指标: 准确率、损失等。 测试集指标一开始可能会随着训练的进行而提高,但当模型开始过拟合时,测试集指标会停止提高甚至开始下降。
判断标准:
- 训练集性能远好于测试集性能: 这是过拟合的典型迹象。
- 训练集损失持续下降,但测试集损失开始上升: 这表明模型正在记住训练数据,而不是学习泛化能力。
- 训练集准确率接近 100%,但测试集准确率较低: 也是过拟合的信号。
代码示例 (PyTorch):
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np# 1. 创建一些模拟数据
np.random.seed(42)
X = np.random.rand(100, 10) # 100个样本,每个样本10个特征
y = np.random.randint(0, 2, 100) # 二分类问题# 将数据分成训练集和测试集
train_size = int(0.8 * len(X))
X_train, X_test = X[:train_size], X[train_size:]
y_train, y_test = y[:train_size], y[train_size:]# 转换为 PyTorch tensors
X_train = torch.tensor(X_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)# 创建 DataLoader
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)# 2. 定义一个简单的模型
class SimpleNN(nn.Module):def __init__(self, input_size, hidden_size, num_classes):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size)self.relu = nn.ReLU()self.fc2 = nn.Linear(hidden_size, num_classes)def forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return outinput_size = 10
hidden_size = 20
num_classes = 2
model = SimpleNN(input_size, hidden_size, num_classes)# 3. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)# 4. 训练循环
num_epochs = 100
for epoch in range(num_epochs):# 训练model.train()train_loss = 0.0correct_train = 0total_train = 0for inputs, labels in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()train_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total_train += labels.size(0)correct_train += (predicted == labels).sum().item()train_accuracy = 100 * correct_train / total_traintrain_loss /= len(train_loader)# 评估model.eval()test_loss = 0.0correct_test = 0total_test = 0with torch.no_grad():for inputs, labels in test_loader:outputs = model(inputs)loss = criterion(outputs, labels)test_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total_test += labels.size(0)correct_test += (predicted == labels).sum().item()test_accuracy = 100 * correct_test / total_testtest_loss /= len(test_loader)# 打印指标print(f'Epoch [{epoch+1}/{num_epochs}], 'f'Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, 'f'Test Loss: {test_loss:.4f}, Test Acc: {test_accuracy:.2f}%')# 观察训练集和测试集的损失和准确率,判断是否过拟合。
关键点:
- 在每个 epoch 结束后,同时打印训练集和测试集的损失和准确率。
- 观察这些指标的变化趋势。 如果训练集损失持续下降,但测试集损失开始上升,或者训练集准确率远高于测试集准确率,则表明模型正在过拟合。
2. 模型的保存和加载
在 PyTorch 中,有几种保存和加载模型的方法:
a. 仅保存权重 (推荐):
这是最常见和推荐的方法。 它只保存模型的参数(权重和偏置),而不保存模型的结构。 这种方法更轻量级,更灵活,因为你可以将权重加载到任何具有相同结构的模型的实例中。
- 保存:
torch.save(model.state_dict(), 'model_weights.pth')
- 加载:
- 首先,创建一个与保存时具有相同结构的模型的实例。
- 然后,使用
model.load_state_dict(torch.load('model_weights.pth'))
加载权重。
# 保存权重
torch.save(model.state_dict(), 'model_weights.pth')# 加载权重
model = SimpleNN(input_size, hidden_size, num_classes) # 创建模型实例
model.load_state_dict(torch.load('model_weights.pth'))
model.eval() # 设置为评估模式
b. 保存权重和模型结构:
这种方法保存整个模型对象,包括模型的结构和权重。 它更简单,但不如仅保存权重灵活。
- 保存:
torch.save(model, 'model.pth')
- 加载:
model = torch.load('model.pth')
# 保存整个模型
torch.save(model, 'model.pth')# 加载整个模型
model = torch.load('model.pth')
model.eval() # 设置为评估模式
c. 保存全部信息 (Checkpoint):
这种方法保存模型的全部信息,包括模型的结构、权重、优化器的状态、epoch 数等。 这对于恢复训练非常有用。
- 保存:
checkpoint = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': train_loss,
}
torch.save(checkpoint, 'checkpoint.pth')
- 加载:
checkpoint = torch.load('checkpoint.pth')
model = SimpleNN(input_size, hidden_size, num_classes)
optimizer = optim.Adam(model.parameters(), lr=0.01)model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']model.eval() # or model.train() depending on your needs
# - or -
model.train()
3. 早停策略 (Early Stopping)
早停是一种正则化技术,用于防止过拟合。 它的基本思想是:在训练过程中,如果测试集上的性能在一定时间内没有提高,就提前停止训练。
实现步骤:
- 监控指标: 选择一个在测试集上监控的指标(例如,测试集损失或准确率)。
- 耐心 (Patience): 定义一个“耐心”值,表示在测试集指标没有提高的情况下,允许训练继续进行的 epoch 数。
- 最佳模型: 在训练过程中,保存测试集上性能最佳的模型。
- 停止训练: 如果测试集指标在“耐心”个 epoch 内没有提高,则停止训练,并加载保存的最佳模型。
代码示例 (PyTorch):
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np# 1. 创建一些模拟数据
np.random.seed(42)
X = np.random.rand(100, 10) # 100个样本,每个样本10个特征
y = np.random.randint(0, 2, 100) # 二分类问题# 将数据分成训练集和测试集
train_size = int(0.8 * len(X))
X_train, X_test = X[:train_size], X[train_size:]
y_train, y_test = y[:train_size], y[train_size:]# 转换为 PyTorch tensors
X_train = torch.tensor(X_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)# 创建 DataLoader
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)# 2. 定义一个简单的模型
class SimpleNN(nn.Module):def __init__(self, input_size, hidden_size, num_classes):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size)self.relu = nn.ReLU()self.fc2 = nn.Linear(hidden_size, num_classes)def forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return outinput_size = 10
hidden_size = 20
num_classes = 2
model = SimpleNN(input_size, hidden_size, num_classes)# 3. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)# 4. 早停策略
patience = 10 # 如果10个epoch测试集损失没有下降,就停止训练
best_loss = np.inf # 初始最佳损失设置为无穷大
counter = 0 # 计数器,记录测试集损失没有下降的epoch数
best_model_state = None # 保存最佳模型的状态# 5. 训练循环
num_epochs = 100
for epoch in range(num_epochs):# 训练model.train()train_loss = 0.0correct_train = 0total_train = 0for inputs, labels in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()train_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total_train += labels.size(0)correct_train += (predicted == labels).sum().item()train_accuracy = 100 * correct_train / total_traintrain_loss /= len(train_loader)# 评估model.eval()test_loss = 0.0correct_test = 0total_test = 0with torch.no_grad():for inputs, labels in test_loader:outputs = model(inputs)loss = criterion(outputs, labels)test_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total_test += labels.size(0)correct_test += (predicted == labels).sum().item()test_accuracy = 100 * correct_test / total_testtest_loss /= len(test_loader)# 打印指标print(f'Epoch [{epoch+1}/{num_epochs}], 'f'Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, 'f'Test Loss: {test_loss:.4f}, Test Acc: {test_accuracy:.2f}%')# 早停逻辑if test_loss < best_loss:best_loss = test_losscounter = 0best_model_state = model.state_dict() # 保存最佳模型的状态else:counter += 1if counter >= patience:print("Early stopping!")model.load_state_dict(best_model_state) # 加载最佳模型breakprint(f"Best Test Loss: {best_loss:.4f}")
关键点:
patience
:控制早停的敏感度。 较小的patience
值会导致更早的停止,可能导致欠拟合。 较大的patience
值允许训练继续进行更长时间,可能导致过拟合。best_loss
:跟踪测试集上的最佳损失。counter
:跟踪测试集损失没有下降的 epoch 数。- 在早停后,加载保存的最佳模型。
总结:
- 通过同步打印训练集和测试集的指标,可以有效地判断过拟合。
- 仅保存权重是保存模型的推荐方法,因为它更轻量级和灵活。
- 早停是一种有效的正则化技术,可以防止过拟合。
相关文章:

Python Day37
Task: 1.过拟合的判断:测试集和训练集同步打印指标 2.模型的保存和加载 a.仅保存权重 b.保存权重和模型 c.保存全部信息checkpoint,还包含训练状态 3.早停策略 1. 过拟合的判断:测试集和训练集同步打印指标 过拟合是指模型在训…...

RabbitMQ集群与负载均衡实战指南
文章目录 集群架构概述仲裁队列的使用1. 使用Spring框架代码创建2. 使用amqp-client创建3. 使用管理平台创建 负载均衡引入HAProxy 负载均衡:使用方法1. 修改配置文件2. 声明队列 test_cluster3. 发送消息 集群架构 概述 RabbitMQ支持部署多个结点,每个…...
怎么开机自动启动vscode项目
每次开机都得用 vscode 打开多个工程,然后用 vscode 里的终端启动,怎么设置成开机自动启动,省事点。 创建 bat 文件,用 cmd 启动,然后将 bat 文件放到 windows 启动文件夹中 yqp1.bat echo on cls d: cd D:\yqp\add…...
Unity 中 Update、FixedUpdate 和 LateUpdate 的区别及使用场景
在Unity开发中,Update、FixedUpdate 和 LateUpdate 是生命周期函数中最常见也最容易混淆的一组。 一、调用时机 方法名调用频率调用时机说明Update()每帧调用一次跟随帧率(帧率高则调用频率高)FixedUpdate()固定时间间隔调用默认每 0.02 秒执行一次LateUpdate()每帧调用一次…...

linux安装ffmpeg7.0.2全过程
编辑 白眉大叔 发布于 2025年4月16日 评论关闭 阅读(341) centos 编译安装 ffmpeg 7.0.2 :连接https://www.baimeidashu.com/19668.html 下载 FFmpeg 源代码 在文章最后 一、在CentOS上编译安装FFmpeg 以常见的CentOS为例,FFmpeg的编译说明页面为h…...

Java中的设计模式实战:单例、工厂、策略模式的最佳实践
Java中的设计模式实战:单例、工厂、策略模式的最佳实践 在Java开发中,设计模式是构建高效、可维护、可扩展应用程序的关键。本文将深入探讨三种常见且实用的设计模式:单例模式、工厂模式和策略模式,并通过详细代码实例࿰…...
DexGarmentLab 论文翻译
单个 专家 演示 装扮 15 任务 场景 2500+ 服装 手套 棒球帽 裤子 围巾 碗 帽子 上衣 外套 服装-手部交互 捕捉 摇篮 夹紧 平滑 任务 ...... 投掷 悬挂 折叠 ... 多样化位置 ... 多样化 变形 ... 多样化服装形状 类别级 一般化 类别级(有或没有变形) 服装具有相同结构 变形 生…...
Elasticsearch性能优化全解析
Elasticsearch作为一款分布式搜索和分析引擎,其性能优化是实际生产环境中必须深入研究的课题。本文基于Elastic官方文档,系统性地总结了从硬件配置、索引设计到查询优化的全链路优化策略,帮助用户构建高性能、高稳定性的集群。 Elasticsearch的优化需结合业务场景综合决策:…...

2025.05.28【Parallel】Parallel绘图:拟时序分析专用图
Improve general appearance Add title, use a theme, change color palette, control variable orders and more Highlight a group Highlight a group of interest to help people understand your story 文章目录 Improve general appearanceHighlight a group探索Paralle…...
tc3975开发板上有ft2232这块的电路,我想知道这个开发板有哪些升级方式,重点关注是怎样通过ft2232实现的烧录升级的
关于TC3975开发板上FT2232芯片支持的升级方式,特别是如何通过FT2232实现烧录升级的问题。首先,我得回忆一下FT2232的基本功能和常见应用场景。 FT2232是FTDI公司的一款双通道USB转UART/FIFO芯片,常用于嵌入式系统的调试和编程。它支持多种协议…...

自动驾驶与智能交通:构建未来出行的智能引擎
随着人工智能、物联网、5G和大数据等前沿技术的发展,自动驾驶汽车和智能交通系统正以前所未有的速度改变人类的出行方式。这一变革不仅是技术的融合创新,更是推动城市可持续发展的关键支撑。 一、自动驾驶与智能交通的定义 1. 自动驾驶(Auto…...
Kotlin Multiplatform与Flutter深度对比:跨平台开发方案的实战选择
简介 在当今多平台应用开发的浪潮中,Kotlin Multiplatform与Flutter代表了两种截然不同的技术路线。KMP以"共享代码、保留原生"为核心理念,允许开发者在业务逻辑层实现高达80%的跨平台代码共享,而Flutter则采用统一渲染引擎,在UI层提供100%的代码共享率。这两种…...

ELectron 中 BrowserView 如何进行实时定位和尺寸调整
背景 BrowserView 是继 Webview 后推出来的高性能多视图管理工具,与 Webview 最大的区别是,Webview 是一个 DOM 节点,依附于主渲染进程的附属进程,Webview 节点的崩溃会导致主渲染进程的连锁反应,会引起软件的崩溃。 …...

深兰科技董事长陈海波率队考察南京,加速AI大模型区域落地应用
近日,深兰科技创始人、董事长陈海波受邀率队赴南京市,先后考察了南京高新技术产业开发区与鼓楼区,就推进深兰AI医诊大模型在南京的落地应用,与当地政府及相关部门进行了深入交流与合作探讨。 此次考察聚焦于深兰科技自主研发的AI医…...

《深度关系-从建立关系到彼此信任》
陈海贤老师推荐的书,花了几个小时,感觉现在的人与人之间特别缺乏这种深度的关系,但是与一个人建立深度的关系并没有那么简单,反正至今为止,自己好像没有与任何一个人建立了这种深度的关系,那种双方高度同频…...

IT选型指南:电信行业需要怎样的服务器?
从第一条电报发出的 那一刻起 电信技术便踏上了飞速发展的征程 百余年间 将世界编织成一个紧密相连的整体 而在今年 我们迎来了第25届世界电信日 同时也是国际电联成立的第160周年 本届世界电信日的主题为:“弥合性别数字鸿沟,为所有人创造机遇”,但在新兴技术浪潮汹涌…...

【ConvLSTM第二期】模拟视频帧的时序建模(Python代码实现)
目录 1 准备工作:python库包安装1.1 安装必要库 案例说明:模拟视频帧的时序建模ConvLSTM概述损失函数说明(python全代码) 参考 ConvLSTM的原理说明可参见另一博客-【ConvLSTM第一期】ConvLSTM原理。 1 准备工作:pytho…...
[VMM]分享一个用SystemC编写的页表管理程序
分享一个用SystemC编写的页表管理程序 摘要:分享一个用SystemC编写的页表管理的程序,这个程序将模拟页表(PDE和PTE)的创建、虚拟地址(VA)到物理地址(PA)的转换,以及对内存的读写操作。 为了简化实现,我们做出以下假设: 页表是两级结构:PDE (Page Directory…...
将docker数据目录迁移到 home目录下
将 Docker 数据目录从默认位置(通常是 /var/lib/docker)迁移到 /home 目录下,可以通过几个步骤来完成。以下是详细的迁移步骤: 步骤 1:停止 Docker 服务 在进行任何操作之前,确保先停止 Docker 服务以避免…...

【论文解读】DETR: 用Transformer实现真正的End2End目标检测
1st authors: About me - Nicolas CarionFrancisco Massa - Google Scholar paper: [2005.12872] End-to-End Object Detection with Transformers ECCV 2020 code: facebookresearch/detr: End-to-End Object Detection with Transformers 1. 背景 目标检测&#…...
Pytest 是什么
Pytest 是 Python 生态中最流行的 测试框架,用于编写、运行和组织单元测试、功能测试甚至复杂的集成测试。它以简洁的语法、强大的插件系统和高度可扩展性著称,广泛应用于 Python 项目的自动化测试中。以下是其核心特性和使用详解: Pytest 的…...

ElasticSearch简介及常用操作指南
一. ElasticSearch简介 ElasticSearch 是一个基于 Lucene 构建的开源、分布式、RESTful 风格的搜索和分析引擎。 1. 核心功能 强大的搜索能力 它能够提供全文检索功能。例如,在海量的文档数据中,可以快速准确地查找到包含特定关键词的文档。这在处理诸如…...
缓存常见问题:缓存穿透、缓存雪崩以及缓存击穿
缓存常见问题 一、缓存穿透 (Cache Penetration) 是什么 缓存穿透是指客户端持续请求一个缓存和数据库中都根本不存在的数据。这导致每次请求都会先查缓存(未命中),然后穿透到数据库查询(也未命中)。如果这类请求量…...

纤维组织效应偏斜如何影响您的高速设计
随着比特率继续飙升,光纤编织效应 (FWE) 偏移,也称为玻璃编织偏移 (GWS),正变得越来越成为一个问题。今天的 56GB/s 是高速路由器中最先进的,而 112 GB/s 指日可待。而用于个人计算机…...
【深度学习】sglang 的部署参数详解
SGLang 的部署参数详解 SGLang(Structured Generation Language)是一个高性能的大语言模型推理框架,专为结构化生成和多模态应用设计。本文将全面介绍SGLang的部署参数,帮助你充分发挥其性能潜力。 🚀 SGLang 项目概览 SGLang是由UC Berkeley开发的新一代LLM推理引擎,…...
SDL2常用函数:SDL_RendererSDL_CreateRendererSDL_RenderCopySDL_RenderPresent
SDL 渲染器系统详解 SDL_Renderer 概述 SDL_Renderer 是 SDL 2.0 引入的核心渲染抽象,它提供了一种高效的、硬件加速的 2D 渲染方式,比传统的表面(Surface)操作更加高效和灵活。 主要函数 1. SDL_CreateRenderer - 创建渲染器 SDL_Renderer* SDL_Cr…...
[git]忽略.gitignore文件
git rm --cached .gitignore 是一个 Git 命令,主要用于 从版本控制中移除已追踪的 .gitignore 文件,但保留该文件在本地工作目录中。以下是详细解析: 一、命令拆解与核心作用 语法解析 git rm:Git 的删除命令,用于从版本库(Repository)中移除文件。--cached:关键参数…...
FEMFAT许可的有效期限
在工程仿真领域,FEMFAT作为一款领先的疲劳分析软件,为用户提供了强大的功能和卓越的性能。然而,为了确保软件的合法使用和持续合规,了解FEMFAT许可的有效期限至关重要。本文将为您详细解读FEMFAT许可的有效期限,帮助您…...

Rust使用Cargo构建项目
文章目录 你好,Cargo!验证Cargo安装使用Cargo创建项目新建项目配置文件解析默认代码结构 Cargo工作流常用命令速查表详细使用说明1. 编译项目2. 运行程序3.快速检查4. 发布版本构建 Cargo的设计哲学约定优于配置工程化优势 开发建议1. 新项目初始化2. …...

Python训练营打卡Day39
DAY 39 图像数据与显存 知识点回顾 1.图像数据的格式:灰度和彩色数据 2.模型的定义 3.显存占用的4种地方 a.模型参数梯度参数 b.优化器参数 c.数据批量所占显存 d.神经元输出中间状态 4.batchisize和训练的关系 作业:今日代码较少,理解内容…...