【人工智能】Python常用库-PyTorch常用方法教程
PyTorch 是一个强大的开源深度学习框架,以其灵活性和动态计算图而广受欢迎。以下是 PyTorch 的详细教程,涵盖从基础到实际应用的使用方法。
1. 安装与导入
1.1 安装 PyTorch
访问 PyTorch 官方网站,根据系统、Python 版本和 CUDA 支持选择安装命令。
常用安装命令:
pip install torch torchvision torchaudio
1.2 导入库
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
2. PyTorch 基础
2.1 张量(Tensor)
张量是 PyTorch 的核心数据结构,可以看作是一个高维数组。
# 创建张量
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])# 基本运算
c = a + b
print(c) # 输出 tensor([5., 7., 9.])# 随机张量
random_tensor = torch.rand((2, 3)) # 2行3列随机数
print(random_tensor)
输出结果
tensor([5., 7., 9.])
tensor([[0.9980, 0.2970, 0.5257],[0.8807, 0.0471, 0.7896]])
2.2 自动求导
PyTorch 提供动态计算图支持自动求导。
x = torch.tensor(2.0, requires_grad=True)
y = x**2 + 3*x + 4y.backward() # 自动求导
print(x.grad) # 输出 dy/dx = 2*x + 3 = 7.0
输出结果
tensor(7.)
3. 数据加载
PyTorch 提供强大的数据加载功能。
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader# 下载并加载 MNIST 数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
4. 构建神经网络
4.1 使用 nn.Module
构建模型
import torch.nn as nnclass SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(28 * 28, 128)self.relu = nn.ReLU()self.fc2 = nn.Linear(128, 10)self.softmax = nn.Softmax(dim=1)def forward(self, x):x = x.view(-1, 28 * 28) # 展平输入x = self.relu(self.fc1(x))x = self.softmax(self.fc2(x))return xmodel = SimpleNN()print(model)
输出结果
SimpleNN((fc1): Linear(in_features=784, out_features=128, bias=True)(relu): ReLU()(fc2): Linear(in_features=128, out_features=10, bias=True)(softmax): Softmax(dim=1)
)
5. 模型训练
5.1 定义损失函数和优化器
criterion = nn.CrossEntropyLoss() # 交叉熵损失
optimizer = optim.Adam(model.parameters(), lr=0.001)
5.2 训练循环
for epoch in range(5):for images, labels in train_loader:optimizer.zero_grad() # 梯度清零outputs = model(images)loss = criterion(outputs, labels) # 计算损失loss.backward() # 反向传播optimizer.step() # 更新权重print(f"Epoch {epoch+1}, Loss: {loss.item()}")
完整代码
from torch import nn, optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoaderclass SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(28 * 28, 128)self.relu = nn.ReLU()self.fc2 = nn.Linear(128, 10)self.softmax = nn.Softmax(dim=1)def forward(self, x):x = x.view(-1, 28 * 28) # 展平输入x = self.relu(self.fc1(x))x = self.softmax(self.fc2(x))return xmodel = SimpleNN()# 下载并加载 MNIST 数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)criterion = nn.CrossEntropyLoss() # 交叉熵损失
optimizer = optim.Adam(model.parameters(), lr=0.001)for epoch in range(5):for images, labels in train_loader:optimizer.zero_grad() # 梯度清零outputs = model(images)loss = criterion(outputs, labels) # 计算损失loss.backward() # 反向传播optimizer.step() # 更新权重print(f"Epoch {epoch + 1}, Loss: {loss.item()}")
输出结果
Epoch 1, Loss: 1.482284665107727
Epoch 2, Loss: 1.4968496561050415
Epoch 3, Loss: 1.5289227962493896
Epoch 4, Loss: 1.4832825660705566
Epoch 5, Loss: 1.5070817470550537
6. 模型评估
6.1 在测试集上评估
test_data = MNIST(root='./data', train=False, transform=transform)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)correct = 0
total = 0
with torch.no_grad(): # 禁用梯度计算for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Test Accuracy: {correct / total * 100:.2f}%")
输出结果
Test Accuracy: 10.32%
7. GPU 加速
PyTorch 支持使用 GPU 加速。
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)# 将数据也移动到 GPU
for images, labels in train_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)
8. 保存与加载模型
8.1 保存模型
torch.save(model.state_dict(), 'model.pth')
8.2 加载模型
model = SimpleNN()
model.load_state_dict(torch.load('model.pth'))
model.eval() # 切换到评估模式
9. 实际案例
9.1 CIFAR-10 图像分类
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms# CIFAR-10 数据集
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_data = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2)self.fc1 = nn.Linear(16 * 16 * 16, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = x.view(-1, 16 * 16 * 16)x = self.fc1(x)return xmodel = CNN()
# 后续训练步骤类似
10. PyTorch 优势总结
- 动态计算图:支持动态构建与修改模型。
- 灵活性:适合研究和开发,易于调试。
- 强大的社区支持:广泛的教程、示例和扩展工具。
通过实践,PyTorch 能够帮助用户更好地理解和实现深度学习算法!
相关文章:
【人工智能】Python常用库-PyTorch常用方法教程
PyTorch 是一个强大的开源深度学习框架,以其灵活性和动态计算图而广受欢迎。以下是 PyTorch 的详细教程,涵盖从基础到实际应用的使用方法。 1. 安装与导入 1.1 安装 PyTorch 访问 PyTorch 官方网站,根据系统、Python 版本和 CUDA 支持选择安…...

Android Studio安装TalkX AI编程助手
文章目录 TalkX简介编程场景 TalkX安装TalkX编程使用ai编程助手相关文章 TalkX简介 TalkX是一款将OpenAI的GPT 3.5/4模型集成到IDE的AI编程插件。它免费提供特定场景的AI编程指导,帮助开发人员提高工作效率约38%,甚至在解决编程问题的效率上提升超过2倍…...

#渗透测试#红蓝攻防#HW#漏洞挖掘#漏洞复现02-永恒之蓝漏洞
免责声明 本教程仅为合法的教学目的而准备,严禁用于任何形式的违法犯罪活动及其他商业行为,在使用本教程前,您应确保该行为符合当地的法律法规,继续阅读即表示您需自行承担所有操作的后果,如有异议,请立即停…...

gitlab自动打包python项目
现在新版的gitlab可以不用自己配置runner什么的了 直接写.gitlab-ci.yml文件就行,这里给出一个简单的依靠setup把python项目打包成whl文件的方法 首先写.gitlab-ci.yml文件,放到项目根目录里 stages: # List of stages for jobs, and their or…...

残差神经网络
目录 1. 梯度消失问题 2. 残差学习的引入 3. 跳跃连接(Shortcut Connections) 4. 恒等映射与维度匹配 5. 反向传播与梯度流 6. 网络深度与性能 总结 残差神经网络的原理是基于“残差学习”的概念,它旨在解决深度神经网络训练中的梯度消…...

mini-spring源码分析
IOC模块 关键解释 beanFactory:beanFactory是一个hashMap, key为beanName, Value为 beanDefination beanDefination: BeanDefinitionRegistry,BeanDefinition注册表接口,定义注册BeanDefinition的方法 beanReference:增加Bean…...

黑马程序员Java项目实战《苍穹外卖》Day01
苍穹外卖-day01 课程内容 软件开发整体介绍苍穹外卖项目介绍开发环境搭建导入接口文档Swagger 项目整体效果展示: 管理端-外卖商家使用 用户端-点餐用户使用 当我们完成该项目的学习,可以培养以下能力: 1. 软件开发整体介绍 作为一…...

uniapp开发支付宝小程序自定义tabbar样式异常
解决方案: 这个问题应该是支付宝基础库的问题,除了依赖于官方更新之外,开发者可以利用《自定义 tabBar》曲线救国 也就是创建一个空内容的自定义tabBar,这样即使 tabBar 被渲染出来,但从视觉上也不会有问题 1.官方文…...

python+django5.1+docker实现CICD自动化部署springboot 项目前后端分离vue-element
一、开发环境搭建和配置 # channels是一个用于在Django中实现WebSocket、HTTP/2和其他异步协议的库。 pip install channels#channels-redis是一个用于在Django Channels中使用Redis作为后台存储的库。它可以用于处理#WebSocket连接的持久化和消息传递。 pip install channels…...

python代码示例(读取excel文件,自动播放音频)
目录 python 操作excel 表结构 安装第三方库 代码 自动播放音频 介绍 安装第三方库 代码 python 操作excel 表结构 求出100班同学的平均分 安装第三方库 因为这里的表结构是.xlsx文件,需要使用openpyxl库 如果是.xls格式文件,需要使用xlrd库 pip install openpyxl /…...
【第十课】Rust并发编程(一)
目录 前言 Fork和Join 前言 本节会介绍Rust中的并发编程,并发编程在编程中是提升cpu使用率的一大利器,通过多线程技术提升效率,Rust的并发和其他编程语言的并发不同的地方在于,Rust号称无畏并发。更重要的一点是安全。Rust中所有…...
图形渲染性能优化
variable rate shading conditional render 设置可见性等, 不需要重新build command buffer indirect draw glMultiDraw* - 直接支持多次绘制glMultiDrawIndirect - 间接多次绘制multithreading 多线程录制 实例化渲染 lod texture array 小对象剔除 投影到…...

elasticsearch的索引模版使用方法
5 索引模版⭐️⭐️⭐️⭐️⭐️ 索引模板就是创建索引时要遵循的模板规则索引模板仅对新创建的索引有效,已经创建的索引并不受索引模板的影响 5.1 索引模版的基本使用 1.查看所有的索引模板 GET 10.0.0.91:9200/_index_template2.创建自定义索引模板 xixi &…...

论文学习——进化动态约束多目标优化:测试集和算法
论文题目:Evolutionary Dynamic Constrained Multiobjective Optimization: Test Suite and Algorithm 进化动态约束多目标优化:测试集和算法(Guoyu Chen ,YinanGuo , Member, IEEE, Yong Wang , Senior Member, IEEE, Jing Liang , Senior …...
C++中的volatile关键字
作用: 1.它用于修饰变量,告知编译器该变量的值可能会在程序的外部被改变,编译器不能对这个变量的访问进行优化。这是因为编译器通常会对代码进行优化,例如把变量的值缓存到寄存器中,但对于 volatile 变量,…...
linux桌面qt应用程序UI自动化实现之dogtail
1. 前言 Dogtail适用于Linux 系统上进行 GUI 自动化测试,利用 Accessibility 技术与桌面程序通信;Dogtail 包含一个名为 sniff 的组件,这是一个嗅探器,用于 GUI 程序追踪; 源码下载:dogtail PyPI 可通过sudo python setup.py install安装或sudo pip install dogt…...
Hello World C#
using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; using System; 引入了System命名空间,基本输入输出。一般只用这个,后面的不用 using System.Collections.Generic; 包含了定…...

SAP开发语言ABAP开发入门
1. 了解ABAP开发环境和基础知识 - ABAP简介 - ABAP(Advanced Business Application Programming)是SAP系统中的编程语言,主要用于开发企业级的业务应用程序,如财务、物流、人力资源等模块的定制开发。 - 开发环境搭建 - 首先需…...

应急响应靶机——easy溯源
载入虚拟机,开启虚拟机: (账户密码:zgsfsys/zgsfsys) 解题程序.exe是额外下载解压得到的: 1. 攻击者内网跳板机IP地址 2. 攻击者服务器地址 3. 存在漏洞的服务(提示:7个字符) 4. 攻击者留下的flag(格式…...

【前端】vscode报错: 无法加载文件 D:\nodejs\node_global\yarn.ps1,因为在此系统上禁止运行脚本。
vscode运行前端代码时候,执行yarn install时候报错 问题: 无法加载文件 D:\nodejs\node_global\yarn.ps1,因为在此系统上禁止运行脚本。 解决方式: 首先用管理员身份运行vscode 查看 get-ExecutionPolicy,Restrict…...

调用支付宝接口响应40004 SYSTEM_ERROR问题排查
在对接支付宝API的时候,遇到了一些问题,记录一下排查过程。 Body:{"datadigital_fincloud_generalsaas_face_certify_initialize_response":{"msg":"Business Failed","code":"40004","sub_msg…...
【android bluetooth 框架分析 04】【bt-framework 层详解 1】【BluetoothProperties介绍】
1. BluetoothProperties介绍 libsysprop/srcs/android/sysprop/BluetoothProperties.sysprop BluetoothProperties.sysprop 是 Android AOSP 中的一种 系统属性定义文件(System Property Definition File),用于声明和管理 Bluetooth 模块相…...
C++中string流知识详解和示例
一、概览与类体系 C 提供三种基于内存字符串的流,定义在 <sstream> 中: std::istringstream:输入流,从已有字符串中读取并解析。std::ostringstream:输出流,向内部缓冲区写入内容,最终取…...
docker 部署发现spring.profiles.active 问题
报错: org.springframework.boot.context.config.InvalidConfigDataPropertyException: Property spring.profiles.active imported from location class path resource [application-test.yml] is invalid in a profile specific resource [origin: class path re…...

基于TurtleBot3在Gazebo地图实现机器人远程控制
1. TurtleBot3环境配置 # 下载TurtleBot3核心包 mkdir -p ~/catkin_ws/src cd ~/catkin_ws/src git clone -b noetic-devel https://github.com/ROBOTIS-GIT/turtlebot3.git git clone -b noetic https://github.com/ROBOTIS-GIT/turtlebot3_msgs.git git clone -b noetic-dev…...
站群服务器的应用场景都有哪些?
站群服务器主要是为了多个网站的托管和管理所设计的,可以通过集中管理和高效资源的分配,来支持多个独立的网站同时运行,让每一个网站都可以分配到独立的IP地址,避免出现IP关联的风险,用户还可以通过控制面板进行管理功…...

宇树科技,改名了!
提到国内具身智能和机器人领域的代表企业,那宇树科技(Unitree)必须名列其榜。 最近,宇树科技的一项新变动消息在业界引发了不少关注和讨论,即: 宇树向其合作伙伴发布了一封公司名称变更函称,因…...

(一)单例模式
一、前言 单例模式属于六大创建型模式,即在软件设计过程中,主要关注创建对象的结果,并不关心创建对象的过程及细节。创建型设计模式将类对象的实例化过程进行抽象化接口设计,从而隐藏了类对象的实例是如何被创建的,封装了软件系统使用的具体对象类型。 六大创建型模式包括…...
LangFlow技术架构分析
🔧 LangFlow 的可视化技术栈 前端节点编辑器 底层框架:基于 (一个现代化的 React 节点绘图库) 功能: 拖拽式构建 LangGraph 状态机 实时连线定义节点依赖关系 可视化调试循环和分支逻辑 与 LangGraph 的深…...

elementUI点击浏览table所选行数据查看文档
项目场景: table按照要求特定的数据变成按钮可以点击 解决方案: <el-table-columnprop"mlname"label"名称"align"center"width"180"><template slot-scope"scope"><el-buttonv-if&qu…...