当前位置: 首页 > news >正文

【人工智能】Python常用库-PyTorch常用方法教程

PyTorch 是一个强大的开源深度学习框架,以其灵活性和动态计算图而广受欢迎。以下是 PyTorch 的详细教程,涵盖从基础到实际应用的使用方法。


1. 安装与导入

1.1 安装 PyTorch

访问 PyTorch 官方网站,根据系统、Python 版本和 CUDA 支持选择安装命令。

常用安装命令:

pip install torch torchvision torchaudio
1.2 导入库
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

2. PyTorch 基础

2.1 张量(Tensor)

张量是 PyTorch 的核心数据结构,可以看作是一个高维数组。

# 创建张量
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])# 基本运算
c = a + b
print(c)  # 输出 tensor([5., 7., 9.])# 随机张量
random_tensor = torch.rand((2, 3))  # 2行3列随机数
print(random_tensor)

输出结果

tensor([5., 7., 9.])
tensor([[0.9980, 0.2970, 0.5257],[0.8807, 0.0471, 0.7896]])
2.2 自动求导

PyTorch 提供动态计算图支持自动求导。

x = torch.tensor(2.0, requires_grad=True)
y = x**2 + 3*x + 4y.backward()  # 自动求导
print(x.grad)  # 输出 dy/dx = 2*x + 3 = 7.0

输出结果

tensor(7.)

3. 数据加载

PyTorch 提供强大的数据加载功能。

import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader# 下载并加载 MNIST 数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

4. 构建神经网络

4.1 使用 nn.Module 构建模型
import torch.nn as nnclass SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(28 * 28, 128)self.relu = nn.ReLU()self.fc2 = nn.Linear(128, 10)self.softmax = nn.Softmax(dim=1)def forward(self, x):x = x.view(-1, 28 * 28)  # 展平输入x = self.relu(self.fc1(x))x = self.softmax(self.fc2(x))return xmodel = SimpleNN()print(model)

输出结果

SimpleNN((fc1): Linear(in_features=784, out_features=128, bias=True)(relu): ReLU()(fc2): Linear(in_features=128, out_features=10, bias=True)(softmax): Softmax(dim=1)
)

5. 模型训练

5.1 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()  # 交叉熵损失
optimizer = optim.Adam(model.parameters(), lr=0.001)
5.2 训练循环
for epoch in range(5):for images, labels in train_loader:optimizer.zero_grad()  # 梯度清零outputs = model(images)loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新权重print(f"Epoch {epoch+1}, Loss: {loss.item()}")

完整代码

from torch import nn, optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoaderclass SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(28 * 28, 128)self.relu = nn.ReLU()self.fc2 = nn.Linear(128, 10)self.softmax = nn.Softmax(dim=1)def forward(self, x):x = x.view(-1, 28 * 28)  # 展平输入x = self.relu(self.fc1(x))x = self.softmax(self.fc2(x))return xmodel = SimpleNN()# 下载并加载 MNIST 数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)criterion = nn.CrossEntropyLoss()  # 交叉熵损失
optimizer = optim.Adam(model.parameters(), lr=0.001)for epoch in range(5):for images, labels in train_loader:optimizer.zero_grad()  # 梯度清零outputs = model(images)loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新权重print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

输出结果

Epoch 1, Loss: 1.482284665107727
Epoch 2, Loss: 1.4968496561050415
Epoch 3, Loss: 1.5289227962493896
Epoch 4, Loss: 1.4832825660705566
Epoch 5, Loss: 1.5070817470550537

6. 模型评估

6.1 在测试集上评估
test_data = MNIST(root='./data', train=False, transform=transform)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)correct = 0
total = 0
with torch.no_grad():  # 禁用梯度计算for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Test Accuracy: {correct / total * 100:.2f}%")

输出结果

Test Accuracy: 10.32%

7. GPU 加速

PyTorch 支持使用 GPU 加速。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)# 将数据也移动到 GPU
for images, labels in train_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)

8. 保存与加载模型

8.1 保存模型
torch.save(model.state_dict(), 'model.pth')
8.2 加载模型
model = SimpleNN()
model.load_state_dict(torch.load('model.pth'))
model.eval()  # 切换到评估模式

9. 实际案例

9.1 CIFAR-10 图像分类
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms# CIFAR-10 数据集
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_data = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2)self.fc1 = nn.Linear(16 * 16 * 16, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = x.view(-1, 16 * 16 * 16)x = self.fc1(x)return xmodel = CNN()
# 后续训练步骤类似

10. PyTorch 优势总结

  1. 动态计算图:支持动态构建与修改模型。
  2. 灵活性:适合研究和开发,易于调试。
  3. 强大的社区支持:广泛的教程、示例和扩展工具。

通过实践,PyTorch 能够帮助用户更好地理解和实现深度学习算法!

相关文章:

【人工智能】Python常用库-PyTorch常用方法教程

PyTorch 是一个强大的开源深度学习框架,以其灵活性和动态计算图而广受欢迎。以下是 PyTorch 的详细教程,涵盖从基础到实际应用的使用方法。 1. 安装与导入 1.1 安装 PyTorch 访问 PyTorch 官方网站,根据系统、Python 版本和 CUDA 支持选择安…...

Android Studio安装TalkX AI编程助手

文章目录 TalkX简介编程场景 TalkX安装TalkX编程使用ai编程助手相关文章 TalkX简介 TalkX是一款将OpenAI的GPT 3.5/4模型集成到IDE的AI编程插件。它免费提供特定场景的AI编程指导,帮助开发人员提高工作效率约38%,甚至在解决编程问题的效率上提升超过2倍…...

#渗透测试#红蓝攻防#HW#漏洞挖掘#漏洞复现02-永恒之蓝漏洞

免责声明 本教程仅为合法的教学目的而准备,严禁用于任何形式的违法犯罪活动及其他商业行为,在使用本教程前,您应确保该行为符合当地的法律法规,继续阅读即表示您需自行承担所有操作的后果,如有异议,请立即停…...

gitlab自动打包python项目

现在新版的gitlab可以不用自己配置runner什么的了 直接写.gitlab-ci.yml文件就行,这里给出一个简单的依靠setup把python项目打包成whl文件的方法 首先写.gitlab-ci.yml文件,放到项目根目录里 stages: # List of stages for jobs, and their or…...

残差神经网络

目录 1. 梯度消失问题 2. 残差学习的引入 3. 跳跃连接(Shortcut Connections) 4. 恒等映射与维度匹配 5. 反向传播与梯度流 6. 网络深度与性能 总结 残差神经网络的原理是基于“残差学习”的概念,它旨在解决深度神经网络训练中的梯度消…...

mini-spring源码分析

IOC模块 关键解释 beanFactory:beanFactory是一个hashMap, key为beanName, Value为 beanDefination beanDefination: BeanDefinitionRegistry,BeanDefinition注册表接口,定义注册BeanDefinition的方法 beanReference:增加Bean…...

黑马程序员Java项目实战《苍穹外卖》Day01

苍穹外卖-day01 课程内容 软件开发整体介绍苍穹外卖项目介绍开发环境搭建导入接口文档Swagger 项目整体效果展示: ​ 管理端-外卖商家使用 ​ 用户端-点餐用户使用 当我们完成该项目的学习,可以培养以下能力: 1. 软件开发整体介绍 作为一…...

uniapp开发支付宝小程序自定义tabbar样式异常

解决方案: 这个问题应该是支付宝基础库的问题,除了依赖于官方更新之外,开发者可以利用《自定义 tabBar》曲线救国 也就是创建一个空内容的自定义tabBar,这样即使 tabBar 被渲染出来,但从视觉上也不会有问题 1.官方文…...

python+django5.1+docker实现CICD自动化部署springboot 项目前后端分离vue-element

一、开发环境搭建和配置 # channels是一个用于在Django中实现WebSocket、HTTP/2和其他异步协议的库。 pip install channels#channels-redis是一个用于在Django Channels中使用Redis作为后台存储的库。它可以用于处理#WebSocket连接的持久化和消息传递。 pip install channels…...

python代码示例(读取excel文件,自动播放音频)

目录 python 操作excel 表结构 安装第三方库 代码 自动播放音频 介绍 安装第三方库 代码 python 操作excel 表结构 求出100班同学的平均分 安装第三方库 因为这里的表结构是.xlsx文件,需要使用openpyxl库 如果是.xls格式文件,需要使用xlrd库 pip install openpyxl /…...

【第十课】Rust并发编程(一)

目录 前言 Fork和Join 前言 本节会介绍Rust中的并发编程,并发编程在编程中是提升cpu使用率的一大利器,通过多线程技术提升效率,Rust的并发和其他编程语言的并发不同的地方在于,Rust号称无畏并发。更重要的一点是安全。Rust中所有…...

图形渲染性能优化

variable rate shading conditional render 设置可见性等, 不需要重新build command buffer indirect draw glMultiDraw* - 直接支持多次绘制glMultiDrawIndirect - 间接多次绘制multithreading 多线程录制 实例化渲染 lod texture array 小对象剔除 投影到…...

elasticsearch的索引模版使用方法

5 索引模版⭐️⭐️⭐️⭐️⭐️ 索引模板就是创建索引时要遵循的模板规则索引模板仅对新创建的索引有效,已经创建的索引并不受索引模板的影响 5.1 索引模版的基本使用 1.查看所有的索引模板 GET 10.0.0.91:9200/_index_template2.创建自定义索引模板 xixi &…...

论文学习——进化动态约束多目标优化:测试集和算法

论文题目:Evolutionary Dynamic Constrained Multiobjective Optimization: Test Suite and Algorithm 进化动态约束多目标优化:测试集和算法(Guoyu Chen ,YinanGuo , Member, IEEE, Yong Wang , Senior Member, IEEE, Jing Liang , Senior …...

C++中的volatile关键字

作用: 1.它用于修饰变量,告知编译器该变量的值可能会在程序的外部被改变,编译器不能对这个变量的访问进行优化。这是因为编译器通常会对代码进行优化,例如把变量的值缓存到寄存器中,但对于 volatile 变量,…...

linux桌面qt应用程序UI自动化实现之dogtail

1. 前言 Dogtail适用于Linux 系统上进行 GUI 自动化测试,利用 Accessibility 技术与桌面程序通信;Dogtail 包含一个名为 sniff 的组件,这是一个嗅探器,用于 GUI 程序追踪; 源码下载:​​dogtail PyPI 可通过sudo python setup.py install安装或sudo pip install dogt…...

Hello World C#

using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; using System; 引入了System命名空间,基本输入输出。一般只用这个,后面的不用 using System.Collections.Generic; 包含了定…...

SAP开发语言ABAP开发入门

1. 了解ABAP开发环境和基础知识 - ABAP简介 - ABAP(Advanced Business Application Programming)是SAP系统中的编程语言,主要用于开发企业级的业务应用程序,如财务、物流、人力资源等模块的定制开发。 - 开发环境搭建 - 首先需…...

应急响应靶机——easy溯源

载入虚拟机,开启虚拟机: (账户密码:zgsfsys/zgsfsys) 解题程序.exe是额外下载解压得到的: 1. 攻击者内网跳板机IP地址 2. 攻击者服务器地址 3. 存在漏洞的服务(提示:7个字符) 4. 攻击者留下的flag(格式…...

【前端】vscode报错: 无法加载文件 D:\nodejs\node_global\yarn.ps1,因为在此系统上禁止运行脚本。

vscode运行前端代码时候,执行yarn install时候报错 问题: 无法加载文件 D:\nodejs\node_global\yarn.ps1,因为在此系统上禁止运行脚本。 解决方式: 首先用管理员身份运行vscode 查看 get-ExecutionPolicy,Restrict…...

华为云AI开发平台ModelArts

华为云ModelArts:重塑AI开发流程的“智能引擎”与“创新加速器”! 在人工智能浪潮席卷全球的2025年,企业拥抱AI的意愿空前高涨,但技术门槛高、流程复杂、资源投入巨大的现实,却让许多创新构想止步于实验室。数据科学家…...

web vue 项目 Docker化部署

Web 项目 Docker 化部署详细教程 目录 Web 项目 Docker 化部署概述Dockerfile 详解 构建阶段生产阶段 构建和运行 Docker 镜像 1. Web 项目 Docker 化部署概述 Docker 化部署的主要步骤分为以下几个阶段: 构建阶段(Build Stage)&#xff1a…...

基于ASP.NET+ SQL Server实现(Web)医院信息管理系统

医院信息管理系统 1. 课程设计内容 在 visual studio 2017 平台上,开发一个“医院信息管理系统”Web 程序。 2. 课程设计目的 综合运用 c#.net 知识,在 vs 2017 平台上,进行 ASP.NET 应用程序和简易网站的开发;初步熟悉开发一…...

理解 MCP 工作流:使用 Ollama 和 LangChain 构建本地 MCP 客户端

🌟 什么是 MCP? 模型控制协议 (MCP) 是一种创新的协议,旨在无缝连接 AI 模型与应用程序。 MCP 是一个开源协议,它标准化了我们的 LLM 应用程序连接所需工具和数据源并与之协作的方式。 可以把它想象成你的 AI 模型 和想要使用它…...

基础测试工具使用经验

背景 vtune,perf, nsight system等基础测试工具,都是用过的,但是没有记录,都逐渐忘了。所以写这篇博客总结记录一下,只要以后发现新的用法,就记得来编辑补充一下 perf 比较基础的用法: 先改这…...

OkHttp 中实现断点续传 demo

在 OkHttp 中实现断点续传主要通过以下步骤完成,核心是利用 HTTP 协议的 Range 请求头指定下载范围: 实现原理 Range 请求头:向服务器请求文件的特定字节范围(如 Range: bytes1024-) 本地文件记录:保存已…...

【论文笔记】若干矿井粉尘检测算法概述

总的来说,传统机器学习、传统机器学习与深度学习的结合、LSTM等算法所需要的数据集来源于矿井传感器测量的粉尘浓度,通过建立回归模型来预测未来矿井的粉尘浓度。传统机器学习算法性能易受数据中极端值的影响。YOLO等计算机视觉算法所需要的数据集来源于…...

【论文阅读28】-CNN-BiLSTM-Attention-(2024)

本文把滑坡位移序列拆开、筛优质因子,再用 CNN-BiLSTM-Attention 来动态预测每个子序列,最后重构出总位移,预测效果超越传统模型。 文章目录 1 引言2 方法2.1 位移时间序列加性模型2.2 变分模态分解 (VMD) 具体步骤2.3.1 样本熵(S…...

什么是Ansible Jinja2

理解 Ansible Jinja2 模板 Ansible 是一款功能强大的开源自动化工具,可让您无缝地管理和配置系统。Ansible 的一大亮点是它使用 Jinja2 模板,允许您根据变量数据动态生成文件、配置设置和脚本。本文将向您介绍 Ansible 中的 Jinja2 模板,并通…...

均衡后的SNRSINR

本文主要摘自参考文献中的前两篇,相关文献中经常会出现MIMO检测后的SINR不过一直没有找到相关数学推到过程,其中文献[1]中给出了相关原理在此仅做记录。 1. 系统模型 复信道模型 n t n_t nt​ 根发送天线, n r n_r nr​ 根接收天线的 MIMO 系…...