当前位置: 首页 > article >正文

PyTorch 快速入门

我们将通过一个简单的示例,快速了解如何使用 PyTorch 进行机器学习任务。PyTorch 是一个开源的机器学习库,它提供了丰富的工具和库,帮助我们轻松地构建、训练和测试神经网络模型。以下是本教程的主要内容:

一、数据处理

PyTorch 提供了两个基本的数据处理工具:torch.utils.data.DataLoadertorch.utils.data.DatasetDataset 用于存储样本及其对应的标签,而 DataLoader 则为 Dataset 提供了一个可迭代的包装器。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

PyTorch 还提供了许多特定领域的库,如 TorchText、TorchVision 和 TorchAudio,这些库中都包含了各种数据集。在本教程中,我们将使用 TorchVision 中的 FashionMNIST 数据集。每个 TorchVision Dataset 都包含两个参数:transformtarget_transform,分别用于修改样本和标签。

# 下载 FashionMNIST 数据集
training_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)

我们将 Dataset 作为参数传递给 DataLoaderDataLoader 为我们的数据集提供了一个可迭代的包装器,并支持自动批处理、采样、洗牌和多进程数据加载。在这里,我们定义了一个大小为 64 的批次,即数据加载器可迭代对象的每个元素将返回一个包含 64 个特征和标签的批次。

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

二、创建模型

在 PyTorch 中,我们通过创建一个继承自 nn.Module 的类来定义神经网络。我们在 __init__ 函数中定义网络的层,并在 forward 函数中指定数据如何通过网络传递。为了加速神经网络中的操作,我们将网络移动到加速器(如 CUDA、MPS、MTIA 或 XPU)上。如果当前加速器可用,我们将使用它;否则,我们将使用 CPU。

class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logits

三、优化模型参数

为了训练模型,我们需要一个损失函数和一个优化器。在一个训练循环中,模型会对训练数据集(以批次形式提供)进行预测,并将预测误差反向传播以调整模型的参数。

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# 计算预测误差pred = model(X)loss = loss_fn(pred, y)# 反向传播loss.backward()optimizer.step()optimizer.zero_grad()if batch % 100 == 0:loss, current = loss.item(), (batch + 1) * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

我们还需要检查模型在测试数据集上的性能,以确保它正在学习。

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

训练过程会在多个迭代(epoch)中进行。在每个 epoch 中,模型会学习参数以做出更好的预测。我们在每个 epoch 打印模型的准确率和损失;我们希望看到准确率随着每个 epoch 的增加而增加,损失则随着每个 epoch 的增加而减少。

for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
print("Done!")

四、保存模型

保存模型的一种常见方法是序列化内部状态字典(包含模型参数)。

torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

五、加载模型

加载模型的过程包括重新创建模型结构,并将状态字典加载到其中。

model = NeuralNetwork()
model.load_state_dict(torch.load("model.pth"))

现在,我们可以使用这个模型来进行预测。

classes = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot",
]model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():x = x.to(device)pred = model(x)predicted, actual = classes[pred[0].argmax(0)], classes[y]print(f'Predicted: "{predicted}", Actual: "{actual}"')

以上就是使用 PyTorch 进行机器学习任务的基本流程。

六、完整代码

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor# 下载 FashionMNIST 数据集
training_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)# 创建数据加载器
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)# 定义神经网络模型
class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logits# 检查是否有可用的加速器
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")# 实例化模型并移动到加速器上
model = NeuralNetwork().to(device)
print(model)# 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)# 定义训练函数
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# 计算预测误差pred = model(X)loss = loss_fn(pred, y)# 反向传播loss.backward()optimizer.step()optimizer.zero_grad()if batch % 100 == 0:loss, current = loss.item(), (batch + 1) * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")# 定义测试函数
def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")# 训练和测试模型
epochs = 5
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
print("Done!")# 保存模型
torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")# 加载模型
model = NeuralNetwork()
model.load_state_dict(torch.load("model.pth"))# 使用模型进行预测
classes = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot",
]model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():x = x.to(device)pred = model(x)predicted, actual = classes[pred[0].argmax(0)], classes[y]print(f'Predicted: "{predicted}", Actual: "{actual}"')

相关文章:

PyTorch 快速入门

我们将通过一个简单的示例,快速了解如何使用 PyTorch 进行机器学习任务。PyTorch 是一个开源的机器学习库,它提供了丰富的工具和库,帮助我们轻松地构建、训练和测试神经网络模型。以下是本教程的主要内容: 一、数据处理 PyTorch…...

【浏览器 - Mac实时调试iOS手机浏览器页面】

最近开发个项目,需要在 Mac 电脑上调试 iOS 手机设备上的 Chrome 浏览器,并查看Chrome网页上的 console 信息,本来以为要安装一些插件,没想到直接使用Mac上的Safari 直接可以调试,再此记录下,分享给需要的伙…...

PyQt5之QtDesigner的若干配置和使用

1.描述 QtDesigner是一个可视化工具,可以通过该工具设计页面 2.简单使用 1.下载PyQt5-tools pip install pyqt5-tools 2.打开designer.exe文件 我采用的是虚拟环境,该文件位于C:\Users\24715\anaconda3\envs\pyqt\Lib\site-packages\qt5_applicatio…...

Flink (十二) :Table API SQL (一) 概览

Apache Flink 有两种关系型 API 来做流批统一处理:Table API 和 SQL。Table API 是用于 Scala 和 Java 语言的查询API,它可以用一种非常直观的方式来组合使用选取、过滤、join 等关系型算子。Flink SQL 是基于 Apache Calcite 来实现的标准 SQL。无论输入…...

侯捷C++day01

一个类该准备什么样的数据、函数。才能满足使用这个类人的需求。 inline关键字是建议编译器做inline处理。 private只有本类可以看到。 C创建对象会自动调用构造函数。不可能在程序中显示调用构造函数。不带指针的类多半不用写析构函数。 以下两个重载构造函数会发生错误 不允许…...

CTF-web: phar反序列化+数据库伪造 [DASCTF2024最后一战 strange_php]

step 1 如何触发反序列化? 漏洞入口在 welcome.php case delete: // 获取删除留言的路径,优先使用 POST 请求中的路径,否则使用会话中的路径 $message $_POST[message_path] ? $_POST[message_path] : $_SESSION[message_path]; $msg $userMes…...

Win11下帝国时代2无法启动解决方法

鼠标右键点图标,选择属性 点开始,输入启用和关闭...

GSI快速收录服务:让你的网站内容“上架”谷歌

辛苦制作的内容无法被谷歌抓取和展示,导致访客无法找到你的网站,这是会让人丧失信心的事情。GSI快速收录服务就是为了解决这种问题而存在的。无论是新上线的页面,还是长期未被收录的内容,通过我们的技术支持,都能迅速被…...

如何用函数去计算x年x月x日是(C#)

如何用函数去计算x年x月x日是? 由于现在人工智能的普及,我们往往会用计算机去算,或者去记录事情 1.计算某一年某一个月有多少天 2.计算某年某月某日是周几 using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threadin…...

mysql_init和mysql_real_connect的形象化认识

解析总结 1. mysql_init 的作用 mysql_init 用于初始化一个 MYSQL 结构体,为后续数据库连接和操作做准备。该结构体存储连接配置及状态信息,是 MySQL C API 的核心句柄。 示例: MYSQL *conn mysql_init(NULL); // 初始化连接句柄2. mysql_…...

python学opencv|读取图像(四十九)原理探究:使用cv2.bitwise()系列函数实现图像按位运算

【0】基础定义 按位与运算:两个等长度二进制数上下对齐,全1取1,其余取0。 按位或运算:两个等长度二进制数上下对齐,有1取1,其余取0。 按位异或运算: 两个等长度二进制数上下对齐,相…...

基础项目实战——学生管理系统(c++)

目录 前言一、功能菜单界面二、类与结构体的实现三、录入学生信息四、删除学生信息五、更改学生信息六、查找学生信息七、统计学生人数八、保存学生信息九、读取学生信息十、打印所有学生信息十一、退出系统十二、文件拆分结语 前言 这一期我们来一起学习我们在大学做过的课程…...

春节期间,景区和酒店如何合理用工?

春节期间,景区和酒店如何合理用工? 春节期间,旅游市场将迎来高峰期。景区与酒店,作为旅游产业链中的两大核心环节,承载着无数游客的欢乐与期待。然而,也隐藏着用工管理的巨大挑战。如何合理安排人力资源&a…...

信息学奥赛一本通 1606:【 例 1】任务安排 1 | 洛谷 P2365 任务安排

【题目链接】 ybt 1606:【 例 1】任务安排 1 洛谷 P2365 任务安排 【题目考点】 1. 动态规划:线性动规 【解题思路】 可以先了解法1,虽然不是正解,但该解法只使用了动规的基本思路,易于理解,有助于理解…...

Linux Samba 低版本漏洞(远程控制)复现与剖析

目录 前言 漏洞介绍 漏洞原理 产生条件 漏洞影响 防御措施 复现过程 结语 前言 在网络安全的复杂生态中,系统漏洞的探索与防范始终是保障数字世界安全稳定运行的关键所在。Linux Samba 作为一款在网络共享服务领域应用极为广泛的软件,其低版本中…...

讯飞智作 AI 配音技术浅析(一)

一、核心技术 讯飞智作 AI 配音技术作为科大讯飞在人工智能领域的重要成果,融合了多项前沿技术,为用户提供了高质量的语音合成服务。其核心技术主要涵盖以下几个方面: 1. 深度学习与神经网络 讯飞智作 AI 配音技术以深度学习为核心驱动力&…...

Autogen_core 测试代码:test_types.py

目录 第一段代码:test_get_types第二段代码:test_handler第三段代码:test_nested_data_model总结 代码段是针对 autogen_core 的库的单元测试,主要关注类型检查和消息处理。让我们逐个解释每个代码段的功能: 第一段代…...

PySide(PyQT)进行SQLite数据库编辑和前端展示的基本操作

以SQLite数据库为例,学习数据库的基本操作,使用QSql模块查询、编辑数据并在前端展示。 SQLite数据库的基础知识: https://blog.csdn.net/xulibo5828/category_12785993.html?fromshareblogcolumn&sharetypeblogcolumn&sharerId1278…...

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】1.27 线性代数王国:矩阵分解实战指南

1.27 线性代数王国:矩阵分解实战指南 #mermaid-svg-JWrp2JAP9qkdS2A7 {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-JWrp2JAP9qkdS2A7 .error-icon{fill:#552222;}#mermaid-svg-JWrp2JAP9qkdS2A7 .erro…...

初二回娘家

昨天下午在相亲相爱一家人群里聊天,今天来娘家拜年。 聊天结束后,开始准备今天的菜肴,梳理了一下,凉菜,热菜,碗菜。 上次做菜,粉丝感觉泡的不透,有的硬,这次使用开水浸泡…...

【Block总结】PKI 模块,无膨胀多尺度卷积,增强特征提取的能力|即插即用

论文信息 标题: Poly Kernel Inception Network for Remote Sensing Detection 作者: Xinhao Cai, Qiuxia Lai, Yuwei Wang, Wenguan Wang, Zeren Sun, Yazhou Yao 论文链接:https://arxiv.org/pdf/2403.06258 代码链接:https://github.com/NUST-Mac…...

日志2025.1.30

日志2025.1.30 1.简略地做了一下交互系统 public class Interactable : MonoBehaviour { private MeshRenderer renderer; private Material defaultMaterial; public Material highlightMaterial; private void Awake() { renderer GetComponentInChildren<Me…...

PHP中的获取器和修改器:探索数据访问的新维度

在PHP开发中&#xff0c;操作数据是开发人员最常见的任务之一。为了使数据的访问和修改更加便捷和安全&#xff0c;PHP提供了获取器和修改器这两个强大的特性。本文将探索获取器和修改器的作用和用法&#xff0c;并且通过具体的代码示例来帮助读者更好地理解和应用这两个特性。…...

Blazor-@bind

数据绑定 带有 value属性的标记都可以使用bind 绑定&#xff0c;<div>、<span>等非输入标记&#xff0c;无法使用bind 指令的&#xff0c;默认绑定了 onchange 事件&#xff0c;onchange 事件是指在输入框中输入内容之后&#xff0c;当失去焦点时执行。 page &qu…...

Github 2025-01-29 C开源项目日报 Top10

根据Github Trendings的统计,今日(2025-01-29统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量C项目10C++项目1Assembly项目1Go项目1我的电视 - 安卓电视直播软件 创建周期:40 天开发语言:CStar数量:649 个Fork数量:124 次关注人数:64…...

01-时间与管理

时间与效率 一丶番茄时钟步骤好处 二丶86400s的财富利用时间的方法每天坚持写下一天计划 自我管理体系计划-行动-评价-回顾 一丶番茄时钟 一个计时器 一份任务清单,任务 步骤 每一个25分钟是一个番茄时钟 将工作时间划分为若干个25分钟的工作单元期间只专注于当前任务,遇到…...

架构技能(六):软件设计(下)

我们知道&#xff0c;软件设计包括软件的整体架构设计和模块的详细设计。 在上一篇文章&#xff08;见 《架构技能&#xff08;五&#xff09;&#xff1a;软件设计&#xff08;上&#xff09;》&#xff09;谈了软件的整体架构设计&#xff0c;今天聊一下模块的详细设计。 模…...

C++并发编程指南07

文章目录 [TOC]5.1 内存模型5.1.1 对象和内存位置图5.1 分解一个 struct&#xff0c;展示不同对象的内存位置 5.1.2 对象、内存位置和并发5.1.3 修改顺序示例代码 5.2 原子操作和原子类型5.2.1 标准原子类型标准库中的原子类型特殊的原子类型备选名称内存顺序参数 5.2.2 std::a…...

MySQL 容器已经停止(但仍然存在),但希望重新启动它,并使它的 3306 端口映射到宿主机的 3306 端口是不可行的

重新启动容器并映射端口是不行的 由于你已经有一个名为 mysql-container 的 MySQL 容器&#xff0c;你可以使用 docker start 启动它。想要让3306 端口映射到宿主机是不行的&#xff0c;实际上&#xff0c;端口映射是在容器启动时指定的。你无法在容器已经创建的情况下直接修改…...

AI大模型开发原理篇-6:Seq2Seq编码器-解码器架构

基本概念 Seq2Seq架构的全名是“Sequence-to-Sequence”&#xff0c;简称S2S&#xff0c;意为将一个序列映射到另一个序列。q2Seq编码器-解码器架构&#xff0c;这也是Transformer的基础架构。Seq2Seq架构是一个用于处理输入序列和生成输出序列的神经网络模型&#xff0c;由一…...