当前位置: 首页 > article >正文

Pytorch系列教程:可视化Pytorch模型训练过程

深度学习和理解训练过程中的学习和进步机制对于优化性能、诊断欠拟合或过拟合等问题至关重要。将训练过程可视化的过程为学习的动态提供了有价值的见解,使我们能够做出合理的决策。训练进度必须可视化的两种方法是:使用Matplotlib和Tensor Board。在本文中,我们将学习如何在Pytorch中可视化模型训练进度。

使用Matplotlib在PyTorch中可视化训练进度

Matplotlib是Python中广泛使用的绘图库,它为在Python中创建静态,动画和交互式可视化提供了灵活而强大的工具。它特别适合于创建出版质量的图表。
在这里插入图片描述

**步骤1:**导入必要的库并生成样本数据集

在这一步中,我们将导入必要的库并生成样本数据集。

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
# Sample data
X = torch.randn(100, 1)  # Sample features
y = 3 * X + 2 + torch.randn(100, 1)  # Sample labels with noise

**步骤2:**定义模型

  1. PyTorch中的LinearRegression类定义了一个简单的线性回归模型。它继承自nn。模块的类,使其成为一个神经网络模型。
  2. 构造函数(__init__方法)初始化模型的结构,创建具有一个输入特征和一个输出特征的单一线性层(‘nn.Linear’)。
  3. 这个线性层被存储为名为 self.linear的属性。“forward”方法定义了如何通过这个线性层处理输入数据“x”以产生模型的输出。
  4. 具体来说,输入x是通过 self.linear,并返回结果输出。该方法封装了神经网络的前向传递计算,决定了模型如何将输入转换为输出。
# Define a simple linear regression model
class LinearRegression(nn.Module):def __init__(self):super(LinearRegression, self).__init__()self.linear = nn.Linear(1, 1)  # One input feature, one outputdef forward(self, x):return self.linear(x)model = LinearRegression()

**步骤3:**定义损失函数、优化器和训练循环

在下面的代码中,我们将均方误差定义为损失函数,将随机梯度下降(SGD)优化器定义为优化器,该优化器通过使用学习率为0.01的计算梯度来修改模型的参数。

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

这段代码运行了一个神经网络模型在多个时代的训练循环,使用梯度下降计算和优化损失。损失值被存储以进行绘图,进度每10次打印一次。

# Training loop
num_epochs = 100
losses = []
for epoch in range(num_epochs):# Forward passoutputs = model(X)loss = criterion(outputs, y)# Backward pass and optimizationoptimizer.zero_grad()loss.backward()optimizer.step()# Print progressif (epoch+1) % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# Store loss for plottinglosses.append(loss.item())

**步骤4:**使用Matplotlib在PyTorch中可视化训练进度

使用下面的代码,我们可以使用matplotlib可视化训练损失曲线。

  • plot(损失)线根据epoch号绘制存储在损失列表中的损失值。
  • x轴表示历元数,y轴表示相应的损失值。
  • plt.xlabel(‘Epoch’), plt.ylabel(‘Loss‘)和plt.xlabel(’Epoch’).title()‘Training Loss’)行设置情节的标签和标题。
  • 最后,plot .show()显示该图,允许您可视化地分析损失如何在训练期间减少(或收敛)。
# Plot the loss curve
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.show()

通常,您会期望在损失曲线中看到下降的趋势,这表明模型正在随着时间的推移而学习和改进。

完整的代码:

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# Sample data
X = torch.randn(100, 1)  # Sample features
y = 3 * X + 2 + torch.randn(100, 1)  # Sample labels with noise# Define a simple linear regression model
class LinearRegression(nn.Module):def __init__(self):super(LinearRegression, self).__init__()self.linear = nn.Linear(1, 1)  # One input feature, one outputdef forward(self, x):return self.linear(x)model = LinearRegression()# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# Training loop
num_epochs = 100
losses = []
for epoch in range(num_epochs):# Forward passoutputs = model(X)loss = criterion(outputs, y)# Backward pass and optimizationoptimizer.zero_grad()loss.backward()optimizer.step()# Print progressif (epoch+1) % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# Store loss for plottinglosses.append(loss.item())# Plot the loss curve
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.show()

在这里插入图片描述

输出图显示了训练损失如何随时间变化,并根据迭代次数绘制。这种可视化使人们能够看到模型在训练时是如何减少损失的。此外,Matplotlib图还有其他东西,如轴标签、标题,可能还有标记或线条,表示特定事件,如最小实现损失或损失急剧下降。

使用TensorBoard可视化训练进度

为了在深度学习模型中可视化训练过程,我们可以使用torch.utils.tensorboard模块中的SummaryWriter类,该模块与TensorFlow开发的可视化工具TensorBoard无缝集成。
在这里插入图片描述

  • 集成:PyTorch在torch.utils.tensorboard模块中提供了一个SummaryWriter类,它与TensorBoard无缝集成以实现可视化。
  • 日志记录:在训练循环中,您可以使用SummaryWriter记录各种指标,如损失,准确性等,以实现可视化。
  • 可视化:TensorBoard提供了记录指标的交互式和实时可视化,允许您动态监控训练进度。
  • 监控:TensorBoard使您能够监控训练的多个方面,例如学习曲线,模型图和权重直方图,为优化您的模型提供见解。

使用以下命令安装TensorBoard库:

pip install tensorboard

步骤1:导入库

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

步骤2:定义简单的神经网络

让我们定义SimpleNN一个简单神经网络的类声明,它包含两个完全连接的层,以及定义网络前向传递的forward函数。

# Define a simple neural network
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(784, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = torch.flatten(x, 1)x = torch.relu(self.fc1(x))x = self.fc2(x)return x

步骤3:加载MNIST数据集

让我们加载用于训练的MINST数据,将其分成批次并使用一些预处理技术进行转换。

# Load a smaller subset of MNIST dataset
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
small_train_dataset = torch.utils.data.Subset(train_dataset, range(1000))  # Subset of first 1000 samples
train_loader = DataLoader(small_train_dataset, batch_size=64, shuffle=True)

步骤4:初始化模型、损失函数和优化器

现在,初始化模型。与此同时,我们将使用交叉熵损失函数和adam优化器来更新模型参数。

# Initialize model, loss function, and optimizer
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

步骤5:初始化用于日志记录的SummaryWriter

SummaryWriter是导入模块的对象,用于编写要在TensorBoard中可视化的日志。

# Initialize SummaryWriter for logging
writer = SummaryWriter('logs_small')

第六步:循环训练

  • 训练循环:通过时代和批次,执行向前传递,计算损失,向后传递和更新模型参数。
  • 日志损失和准确性:记录划时代的训练损失和准确性。
# Training loop
epochs = 5
for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for i, (inputs, labels) in enumerate(train_loader):optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()# Calculate accuracy_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()# Log losswriter.add_scalar('Loss/train', loss.item(), epoch * len(train_loader) + i)# Log accuracyaccuracy = 100 * correct / totalwriter.add_scalar('Accuracy/train', accuracy, epoch)print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss / len(train_loader)}, Accuracy: {accuracy}%')print('Finished Training')
writer.close()

完整代码:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter# Define a simple neural network
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(784, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = torch.flatten(x, 1)x = torch.relu(self.fc1(x))x = self.fc2(x)return x# Load a smaller subset of MNIST dataset
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
small_train_dataset = torch.utils.data.Subset(train_dataset, range(1000))  # Subset of first 1000 samples
train_loader = DataLoader(small_train_dataset, batch_size=64, shuffle=True)# Initialize model, loss function, and optimizer
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# Initialize SummaryWriter for logging
writer = SummaryWriter('logs_small')# Training loop
epochs = 5
for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for i, (inputs, labels) in enumerate(train_loader):optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()# Calculate accuracy_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()# Log losswriter.add_scalar('Loss/train', loss.item(), epoch * len(train_loader) + i)# Log accuracyaccuracy = 100 * correct / totalwriter.add_scalar('Accuracy/train', accuracy, epoch)print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss / len(train_loader)}, Accuracy: {accuracy}%')print('Finished Training')
writer.close()

运行示例,输出如下:

Epoch [1/5], Loss: 1.8145772516727448, Accuracy: 47.1%
Epoch [2/5], Loss: 1.0121613591909409, Accuracy: 78.8%
Epoch [3/5], Loss: 0.6829517856240273, Accuracy: 84.1%
Epoch [4/5], Loss: 0.5442189555615187, Accuracy: 85.4%
Epoch [5/5], Loss: 0.46599634923040867, Accuracy: 87.0%
Finished Training

TensorBoard提供了一个基于web的仪表板,其中包含代表各种培训方面的选项卡和可视化。标量度量将损失或准确度等值可视化,为训练动态提供了不同的视角。此外,TensorBoard可以显示直方图、嵌入和基于日志信息的专门可视化。

在PyTorch中可视化训练进度

为了运行TensorBoard,你应该打开终端,然后运行tensorboard use命令:

tensorboard --logdir=./logs_small

注意,这里logdir指定上节示例的路径,采用相对路径表示。访问TensorBoard需要:打开浏览器,输入TensorBoard提供的网址(通常为http://localhost:6006/)。

a
b

TensorBoard提供了一个基于web的仪表板,其中包含代表各种培训方面的选项卡和可视化。标量度量将损失或准确度等值可视化,为训练动态提供了不同的视角。此外,TensorBoard可以显示直方图、嵌入和基于日志信息的专门可视化。

在这篇博客中,我们介绍了如何使用matplotlib和tensorboard来可视化深度学习框架的训练过程。

相关文章:

Pytorch系列教程:可视化Pytorch模型训练过程

深度学习和理解训练过程中的学习和进步机制对于优化性能、诊断欠拟合或过拟合等问题至关重要。将训练过程可视化的过程为学习的动态提供了有价值的见解,使我们能够做出合理的决策。训练进度必须可视化的两种方法是:使用Matplotlib和Tensor Board。在本文…...

electron+vue+webview内嵌网页并注入js

vue内嵌网页可以使用iframe实现内嵌网页,但是只能通过postMessage间接通信,在electron环境下,vue可以直接使用webview来内嵌网页,支持 executeJavaScript、postMessage、send 等丰富的通信机制。 使用 webview的优势 性能更佳&…...

利用OpenResty拦截SQL注入

需求 客户的一个老项目被相关部门检测不安全,报告为sql注入。不想改代码,改项目,所以想到利用nginx去做一些数据校验拦截。也就是前端传一些用于sql注入的非法字符或者数据库的关键字这些,都给拦截掉,从而实现拦截sql…...

CAD文件转换为STL

AutoCAD与STL格式简介 AutoCAD软件是由美国欧特克有限公司(Autodesk)出品的一款自动计算机辅助设计软件,可以用于绘制二维制图和基本三维设计,通过它无需懂得编程,即可自动制图,因此它在全球广泛使用&…...

78_Pandasagg()和aggregate()的用法

78_Pandasagg()和aggregate()的用法 通过使用pandas.DataFrame和Series的agg()或aggregate()方法,可以对行或列同时应用多个操作进行聚合。agg()是aggregate()的别名,二者用法相同。 pandas.DataFrame.agg — pandas 2.1.3 文档 pandas.Series.agg —…...

QT:串口上位机

创建工程 布局UI界面 设置名称 设置数据 设置波特率 波特率默认9600 设置数据位 数据位默认8 设置停止位 设置校验位 调整串口设置、接收设置、发送设置为Group Box 修改配置 QT core gui serialport 代码详解 mianwindow.h 首先在mianwindow.h当中定义一个串口指…...

C++跨平台开发环境搭建全指南:工具链选型与性能优化实战

C跨平台开发环境搭建全指南:工具链选型与性能优化实战 目录 开发环境搭建工具链选型性能优化实战常见问题排查 开发环境搭建 操作系统环境准备 Windows# 安装Visual Studio Build Tools choco install visualstudio2022buildtools choco install cmake --instal…...

数据批处理(队列方式)

数据批处理&#xff08;队列方式&#xff09; public class DataProcessor {private static final int THREAD_COUNT 4;private static final int QUEUE_SIZE 10;private LinkedBlockingQueue<Data> queue new LinkedBlockingQueue<>(QUEUE_SIZE);public DataP…...

win32汇编环境,网络编程入门之二

;运行效果 ;win32汇编环境,网络编程入门之二 ;本教程在前一教程的基础上&#xff0c;研究一下如何得到服务器的返回的信息 ;正常的逻辑是连接上了&#xff0c;然后我发送什么&#xff0c;它返回什么&#xff0c;但是这有一个很尴尬的问题。 ;就是如何表现出来。因为网络可能有延…...

MATLAB—从入门到精通的第二天

在第一天的学习中&#xff0c;我们掌握了 MATLAB 的安装配置、基础语法、变量管理和运算符的使用。本文将深入讲解 控制结构&#xff08;嵌套 if、switch&#xff09;、循环类型 和 向量操作&#xff0c;帮助读者进一步掌握 MATLAB 的核心编程技能。 1. 条件语句进阶 1.1 嵌套…...

【认识OpenThread协议】

OpenThread 是一种基于 IPv6 、IEEE 802.15.4 标准的低功耗无线 Mesh 网络协议&#xff0c;主要用于智能家居、物联网设备等场景。它的设计目标是实现设备之间的高效通信、低功耗运行和高可靠性。 OpenThread官方文档 ① 特性 低功耗: 适合电池供电的设备。 Mesh 网络: 支持多…...

驱动开发系列46 - Linux 显卡KMD驱动代码分析(七)- 显存管理

目录 一:概述 二:应用程序和UMD调用栈 三:KMD 显存分配和和映射过程 一:概述 显存管理是图形驱动程序中至关重要的一部分,涉及到从用户空间(UMD,User Mode Driver)到内核空间(KMD,Kernel Mode Driver)的显存分配和管理。本文将首先梳理从一个 OpenGL 应…...

MATLAB代码开发实战:从入门到高效应用

一、MATLAB生态系统的核心优势 &#xff08;扩展原有内容&#xff0c;增加行业数据&#xff09; MATLAB在全球工程领域的市场占有率已达67%&#xff08;2024年IEEE统计&#xff09;&#xff0c;其核心优势体现在&#xff1a; 矩阵运算速度比传统编程快3-5倍包含22个专业工具箱…...

为什么 NFS 不适合作为 TDengine 的数据存储

NFS NFS 是一种分布式文件系统&#xff0c;允许多台计算机通过网络共享文件。其具有以下优点&#xff1a; 共享存储: 多个数据库实例可以共享同一个 NFS 目录&#xff0c;适合分布式数据库或集群环境。灵活性: 数据存储可以集中管理&#xff0c;便于备份和迁移。成本低: 利用…...

办公常用自动化工具

自动化办公工具说明文档 代码全部在底部。 文件批量重命名工具 (file_renamer.py) 功能概述 file_renamer.py 是一个用于批量重命名文件的工具&#xff0c;可以根据自定义规则为文件重命名&#xff0c;支持按日期、序号、原文件名等格式进行命名。 主要功能 支持按文件类…...

字节跳动 —— 建筑物组合(滑动窗口+溢出问题)

原题描述&#xff1a; 题目精炼&#xff1a; 给定N个建筑物的位置和一个距离D&#xff0c;选取3个建筑物作为埋伏点&#xff0c;找出所有可能的建筑物组合&#xff0c;使得每组中的建筑物之间的最大距离不超过D。最后&#xff0c;输出不同埋伏方案的数量并对99997867取模。 识…...

开源数字人模型Heygem

一、Heygem是什么 Heygem 是硅基智能推出的开源数字人模型&#xff0c;专为 Windows 系统设计。基于先进的AI技术&#xff0c;仅需1秒视频或1张照片&#xff0c;能在30秒内完成数字人形象和声音克隆&#xff0c;在60秒内合成4K超高清视频。Heygem支持多语言输出、多表情动作&a…...

Linux远程工具SecureCRT下载安装和使用

SecureCRT下载安装和使用 SecureCRT是一款功能强大的终端仿真软件&#xff0c;它支持SSH、Telnet等多种协议&#xff0c;可以连接和管理基于Unix和Windows的远程主机和网络设备。SecureCRT提供了语法高亮、多标签页管理、会话管理、脚本编辑等便捷功能&#xff0c;安全性高、操…...

从前端视角理解消息队列:核心问题与实战指南

消息队列&#xff08;Message Queue&#xff09;是现代分布式系统的核心组件之一&#xff0c;它在前后端协作、系统解耦、流量削峰等场景中发挥着重要作用。本文从前端开发者视角出发&#xff0c;解析消息队列的关键问题&#xff0c;并结合实际场景给出解决方案。 一、为什么要…...

Android 线程池实战指南:高效管理多线程任务

在 Android 开发中&#xff0c;线程池的使用非常重要&#xff0c;尤其是在需要处理大量异步任务时。线程池可以有效地管理线程资源&#xff0c;避免频繁创建和销毁线程带来的性能开销。以下是线程池的使用方法和最佳实践。 1. 线程池的基本使用 &#xff08;1&#xff09;创建线…...

CentOS7下安装MongoDB

步骤 1&#xff1a;创建 MongoDB Yum 仓库文件 你需要创建一个 MongoDB 的 Yum 仓库配置文件&#xff0c;以便从官方源下载 MongoDB。打开终端并使用以下命令创建并编辑该文件&#xff1a; sudo vi /etc/yum.repos.d/mongodb-org-7.0.repo 在打开的文件中&#xff0c;输入以下…...

江科大51单片机笔记【15】直流电机驱动(PWM)

写在前言 此为博主自学江科大51单片机&#xff08;B站&#xff09;的笔记&#xff0c;方便后续重温知识 在后面的章节中&#xff0c;为了防止篇幅过长和易于查找&#xff0c;我把一个小节分成两部分来发&#xff0c;上章节主要是关于本节课的硬件介绍、电路图、原理图等理论…...

【网络协议详解】——QOS技术(学习笔记)

目录 QoS简介 QoS产生的背景 QoS服务模型 基于DiffServ模型的QoS组成 MQC简介 MQC三要素 MQC配置流程 优先级映射配置(DiffServ域模式) 优先级映射概述 优先级映射原理描述 优先级映射 PHB行为 流量监管、流量整形和接口限速简介 流量监管 流量整形 接口限速…...

【工具使用】IDEA 社区版如何创建 Spring Boot 项目(详细教程)

IDEA 社区版如何创建 Spring Boot 项目&#xff08;详细教程&#xff09; Spring Boot 以其简洁、高效的特性&#xff0c;成为 Java 开发的主流框架之一。虽然 IntelliJ IDEA 专业版提供了Spring Boot 项目向导&#xff0c;但 社区版&#xff08;Community Edition&#xff09…...

基于Prometheus+Grafana的Deepseek性能监控实战

文章目录 1. 为什么需要专门的大模型监控?2. 技术栈组成2.1 vLLM(推理引擎层)2.2 Prometheus(监控采集层)2.3 Grafana(数据可视化平台)3. 监控系统架构4. 实施步骤4.1 启动DeepSeek-R1模型4.2 部署 Prometheus4.2.1 拉取镜像4.2.2 编写配置文件4.2.3 启动容器4.3 部署 G…...

Spring学习笔记:工厂模式与反射机制实现解耦

1.什么是Spring? spring是一个开源轻量级的java开发应用框架&#xff0c;可以简化企业级应用开发 轻量级 1.轻量级(对于运行环境没有额外要求) 2.代码移植性高(不需要实现额外接口) JavaEE的解决方案 Spring更像是一种解决方案&#xff0c;对于控制层&#xff0c;它有Spring…...

pytest数据库测试文章推荐

参考链接&#xff1a; 第一部分&#xff1a;http://alextechrants.blogspot.fi/2013/08/unit-testing-sqlalchemy-apps.html第二部分&#xff1a;http://alextechrants.blogspot.fi/2014/01/unit-testing-sqlalchemy-apps-part-2.html...

vue3 二次封装uni-ui中的组件,并且组件中有 v-model 的解决方法

在使用uniappvue3开发中&#xff0c; 使用了uni-ui的组件&#xff0c;但是我们也需要自定义组件&#xff0c;比如我要自定一个picker 的组件&#xff0c; 是在 uni-data-picker 组件的基础上进行封装的 父组件中的代码 <classesselect :selectclass"selectclass"…...

探索高性能AI识别和边缘计算 | NVIDIA Jetson Orin Nano 8GB 开发套件的全面测评

随着边缘计算和人工智能技术的迅速发展&#xff0c;性能强大的嵌入式AI开发板成为开发者和企业关注的焦点。NVIDIA近期推出的Jetson Orin Nano 8GB开发套件&#xff0c;凭借其40 TOPS算力、高效的Ampere架构GPU以及出色的边缘AI能力&#xff0c;引起了广泛关注。本文将从配置性…...

Prompt 工程

一、提示原則 import openai import os import openai from dotenv import load_dotenv, find_dotenv from openai import OpenAI def get_openai_key():_ load_dotenv(find_dotenv())return os.environ[OPENAI_API_KEY]client OpenAI(api_keyget_openai_key(), # This is …...