【深度学习】Pytorch的深入理解和研究
一、Pytorch核心理解
PyTorch 是一个灵活且强大的深度学习框架,广泛应用于研究和工业领域。要深入理解和研究 PyTorch,需要从其核心概念、底层机制以及高级功能入手。以下是对 PyTorch 的深入理解与研究的详细说明。
1. 概念
动态计算图(Dynamic Computation Graph)
定义:PyTorch 使用动态计算图(也称为“定义即运行”模式),允许在运行时动态构建和修改计算图。
特点:
- 更适合调试和实验。
- 支持灵活的控制流(如循环、条件判断)。
实现原理:
- 每次前向传播都会生成一个新的计算图。
- 反向传播时,自动计算梯度并释放计算图以节省内存。
2. 张量(Tensor)
2.1 张量的理解及与NumPy的对比
张量是一个多维数组,可以表示标量(0 维)、向量(1 维)、矩阵(2 维)或更高维度的数据。
特点:
- 支持动态计算图(Dynamic Computation Graph),适合深度学习任务。
- 可以在 CPU 或 GPU 上运行,利用硬件加速。
张量的理解及与NumPy的对比:

2.2 张量的创建
(1) 基本创建方法
import torch
# 创建未初始化的张量
x = torch.empty(3, 3)
# 创建随机张量
y = torch.rand(3, 3)
# 创建全零张量
z = torch.zeros(3, 3)
# 创建从 NumPy 转换的张量
import numpy as np
a = np.array([1, 2, 3])
b = torch.from_numpy(a)
(2)指定数据类型和设备
# 指定数据类型
x = torch.tensor([1, 2, 3], dtype=torch.float32)
# 指定设备(CPU 或 GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = x.to(device)
2.3 张量的操作
(1)基本运算
# 加法
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
z = x + y
# 矩阵乘法
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
c = torch.matmul(a, b)
(2)广播机制
广播机制允许不同形状的张量进行运算。
广播规则:
- 维度对齐:如果两个张量的维度数不同,则在较小张量的前面添加新的维度(大小为 1),使其维度数相同。
- 逐维比较:对每个维度,检查两个张量的大小是否相等,或者其中一个张量的大小为 1。
- 扩展维度:如果某个维度的大小为 1,则将其扩展为与另一个张量对应维度的大小相同。
x = torch.tensor([1, 2, 3])
y = torch.tensor(2)
z = x + y # 将标量 y 广播到每个元素
print(z) # 输出:tensor([3, 4, 5])
错误示例:
如果张量的形状无法满足广播规则,则会报错:
# 创建张量
a = torch.tensor([[1, 2, 3], [4, 5, 6]]) # 形状 (2, 3)
b = torch.tensor([10, 20]) # 形状 (2,)
# 尝试广播
try:c = a + b
except RuntimeError as e:print(e) # 输出:The size of tensor a must match the size of tensor b at non-singleton dimension 1
#解释:
#a 的第二维大小为 3,而 b 的大小为 2,无法对齐。
(3)索引与切片
# 索引
x = torch.tensor([[1, 2], [3, 4]])
print(x[0, 1]) # 输出:2
# 切片
print(x[:, 1]) # 输出:tensor([2, 4])
2.4 张量的属性
(1)形状与维度
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 获取形状
print(x.shape) # 输出:torch.Size([2, 3])
# 获取维度
print(x.ndim) # 输出:2
(2)数据类型
x = torch.tensor([1, 2, 3], dtype=torch.float32)
print(x.dtype) # 输出:torch.float32
(3)设备信息
x = torch.tensor([1, 2, 3])
print(x.device) # 输出:cpu
2.5 张量与自动求导
(1)自动求导基础
PyTorch 的张量支持自动求导,通过 requires_grad=True 启用梯度跟踪。
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2 + 3 * x + 5
# 计算梯度
y.backward()
print(f"Gradient of y w.r.t x: {x.grad}") # 输出:tensor([7.])
(2)禁用梯度计算
禁用梯度计算的理解:禁用梯度计算是指在深度学习模型训练过程中不计算梯度,这通常是通过上下文管理器torch.no_grad()在PyTorch中实现的。禁用梯度计算的主要目的是在某些操作中节省内存和提高计算效率,特别是在进行推理(inference)时。
因此,在推理阶段,可以通过 torch.no_grad() 禁用梯度计算以节省内存。
with torch.no_grad():z = x + 2
print(z.requires_grad) # 输出:False
2.6 张量的底层机制
(1)内存布局
连续存储:张量在内存中是连续存储的,默认按行优先顺序排列。
非连续张量:某些操作(如转置)可能导致张量变得非连续。
x = torch.tensor([[1, 2], [3, 4]])
y = x.t() # 转置
print(y.is_contiguous()) # 输出:False
(2)数据共享
张量之间的操作可能共享底层数据,修改一个张量会影响另一个张量。
x = torch.tensor([1, 2, 3])
y = x.view(3, 1) # 修改视图
x[0] = 10
print(y) # 输出:tensor([[10], [2], [3]])
2.7 高级功能
(1)张量的序列化
张量的序列化是将张量数据保存为一种可以存储或传输的格式的过程,以便在后续需要时重新加载和使用。
它允许模型和数据在不同的运行时环境之间进行共享和持久存储。
# 保存张量
torch.save(x, "tensor.pth")
# 加载张量
x_loaded = torch.load("tensor.pth")
(2)张量的分布式操作
在分布式训练中,张量可以在多个设备之间传递。
import torch.distributed as dist
dist.init_process_group(backend="nccl")
x = torch.tensor([1, 2, 3]).cuda()
dist.all_reduce(x, op=dist.ReduceOp.SUM)
3. 底层机制
3.1 Autograd(自动求导系统)
定义:Autograd 是 PyTorch 的自动求导引擎,用于计算张量的梯度。
工作原理:
- 在前向传播中记录所有操作。
- 在反向传播中根据链式法则计算梯度。
关键组件:
- torch.autograd.Function:自定义前向和反向传播函数。
- torch.no_grad():禁用梯度计算(用于推理阶段)。
class MyFunction(torch.autograd.Function):@staticmethoddef forward(ctx, input):ctx.save_for_backward(input)return input.clamp(min=0)@staticmethoddef backward(ctx, grad_output):input, = ctx.saved_tensorsgrad_input = grad_output.clone()grad_input[input < 0] = 0return grad_input
# 使用自定义函数
x = torch.tensor([-1.0, 2.0, -3.0], requires_grad=True)
my_relu = MyFunction.apply
y = my_relu(x)
y.sum().backward()
print(f"Gradient: {x.grad}")
3.2 CUDA 和 GPU 加速
定义:PyTorch 支持将张量和模型迁移到 GPU 上,利用 GPU 的并行计算能力加速训练。
实现方式:
- .cuda() 或 .to(device) 将张量或模型迁移到 GPU。
- torch.device 用于指定设备(CPU 或 GPU)。
# 检查是否有可用的 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 创建张量并迁移到 GPU
x = torch.randn(3, 3).to(device)
y = torch.randn(3, 3).to(device)
# 在 GPU 上进行计算
z = x + y
print("Result on GPU:", z)
4. 高级功能
4.1 自定义模型
定义:通过继承 torch.nn.Module 类,可以创建自定义神经网络模型。
关键方法:
- forward():定义前向传播逻辑。
- parameters():返回模型的所有可训练参数。
import torch.nn as nn
import torch.optim as optim
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(10, 20)self.relu = nn.ReLU()self.fc2 = nn.Linear(20, 1)def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)return x
# 创建模型和优化器
model = SimpleNet()
optimizer = optim.Adam(model.parameters(), lr=0.01)
# 模拟输入
input_data = torch.randn(5, 10)
output = model(input_data)
print("Model output:", output)
4.2 分布式训练
定义:PyTorch 提供了分布式训练工具(如 torch.distributed),支持多 GPU 和多节点训练。
常用方法:
- 数据并行(torch.nn.DataParallel)。
- 分布式数据并行(torch.nn.parallel.DistributedDataParallel)。
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
# 初始化分布式环境
dist.init_process_group(backend="nccl")
# 创建模型并迁移到 GPU
model = SimpleNet().cuda()
ddp_model = DDP(model)
# 训练代码省略...
4.3 混合精度训练
定义:混合精度训练使用 FP16 和 FP32 结合的方式,减少显存占用并加速训练。
实现方式:
使用 torch.cuda.amp 提供的工具。
from torch.cuda.amp import GradScaler, autocast
scaler = GradScaler()
for data, target in dataloader:optimizer.zero_grad()# 使用混合精度with autocast():output = model(data)loss = loss_fn(output, target)# 缩放损失并反向传播scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()
4.4 实验与研究
(1)模型可视化
工具:使用 TensorBoard 或 Matplotlib 可视化训练过程。
用途:
- 监控损失和准确率变化。
- 可视化模型结构和特征图。
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
# 记录标量
for epoch in range(10):writer.add_scalar("Loss/train", epoch * 0.1, epoch)
writer.close()
(2)模型解释性
工具:使用 Captum 库分析模型的特征重要性。
用途:
- 解释模型决策过程。
- 发现潜在问题(如偏差或过拟合)。
from captum.attr import IntegratedGradients
ig = IntegratedGradients(model)
attributions = ig.attribute(input_data, target=0)
print("Attributions:", attributions)
二、Pytorch应用场景
1. 计算机视觉(Computer Vision)
1.1 图像分类
任务:将图像分配到预定义的类别。
实现:使用卷积神经网络(CNN),如 ResNet、VGG 或自定义模型。
应用:
- 医疗影像分析(如 X 光片分类)。
- 自动驾驶中的交通标志识别。
import torch
import torchvision.models as models
# 加载预训练模型
model = models.resnet18(pretrained=True)
# 修改输出层以适应新任务
num_classes = 10
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
1.2 目标检测
任务:在图像中定位并分类多个目标。
实现:使用 Faster R-CNN、YOLO 或 SSD 等模型。
应用:
- 安防监控(如行人检测)。
- 工业自动化(如缺陷检测)。
1.3 图像分割
任务:为图像中的每个像素分配一个类别标签。
实现:使用 U-Net、Mask R-CNN 等模型。
应用:
- 医学图像分割(如肿瘤区域标记)。
- 卫星图像分析(如土地覆盖分类)。
2. 自然语言处理(Natural Language Processing, NLP)
2.1 文本分类
任务:将文本分配到预定义的类别。
实现:使用 Transformer 模型(如 BERT、RoBERTa)。
应用:
- 情感分析(如评论情感分类)。
- 垃圾邮件检测。
示例代码:
from transformers import BertTokenizer, BertForSequenceClassification
# 加载预训练模型和分词器
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
# 输入文本
text = "I love using PyTorch!"
inputs = tokenizer(text, return_tensors="pt")
# 推理
outputs = model(**inputs)
logits = outputs.logits
print(logits)
2.2 机器翻译
任务:将一种语言的文本翻译成另一种语言。
实现:使用序列到序列(Seq2Seq)模型或 Transformer。
应用:
- 跨语言交流工具。
- 多语言内容生成。
2.3 文本生成
任务:根据输入生成连贯的文本。
实现:使用 GPT 系列模型。
应用:
- 写作助手(如自动完成文章)。
- 聊天机器人。
3. 推荐系统(Recommendation Systems)
3.1 用户行为建模
任务:根据用户的历史行为推荐相关内容。
实现:使用协同过滤、矩阵分解或深度学习模型。
应用:
- 电商平台推荐商品。
- 视频平台推荐视频。
3.2 多模态推荐
任务:结合多种数据源(如文本、图像)进行推荐。
实现:使用多模态融合模型。
应用:
- 社交媒体内容推荐。
- 广告投放优化。
三、总结
深入理解和研究 PyTorch 需要掌握以下内容:
- 核心概念:动态计算图、张量操作、自动求导。
- 底层机制:Autograd、CUDA 加速。
- 高级功能:自定义模型、分布式训练、混合精度训练。
- 实验与研究:模型可视化、解释性分析。
通过不断实践和探索,你可以充分利用 PyTorch 的灵活性和强大功能,解决复杂的深度学习问题!
相关文章:
【深度学习】Pytorch的深入理解和研究
一、Pytorch核心理解 PyTorch 是一个灵活且强大的深度学习框架,广泛应用于研究和工业领域。要深入理解和研究 PyTorch,需要从其核心概念、底层机制以及高级功能入手。以下是对 PyTorch 的深入理解与研究的详细说明。 1. 概念 动态计算图(D…...
什么是 Vue 的自定义事件?如何触发和监听?
Vue 的自定义事件详解 什么是自定义事件? 在 Vue 中,自定义事件是组件之间通信的重要机制。自定义事件允许子组件向父组件发送消息,通常用于处理用户交互或异步操作的结果。这种机制使得组件间的通信更加灵活和解耦。 自定义事件的基本概念…...
windows上vscode cmake工程搭建
安装vscode插件: 1.按装fastc(主要是安装MinGW\mingw64比较方便) 2.安装C,cmake,cmake tools插件 3.准备工作完成之后,按F1,选择cmake:Quick Start就可以创建一个cmake工程。 4.设置Cmake: G…...
DEMF模型赋能多模态图像融合,助力肺癌高效分类
目录 论文创新点 实验设计 1. 可视化的研究设计 2. 样本选取和数据处理 3. 集成分类模型 4. 实验结果 5. 可视化结果 图表总结 可视化知识图谱 在肺癌早期筛查中,计算机断层扫描(CT)和正电子发射断层扫描(PET)作为两种关键的影像学手段,分别提供了丰富的解剖结构…...
Android:权限permission申请示例代码
Android应用项目每次最开始都要进行权限申请,贴一下权限申请的示例代码,方便后续Ctrl CV使用 1.AndroidManifest.xml 配置要申请的权限 <uses-permission android:name"android.permission.READ_CONTACTS" /> <uses-permission and…...
AI Agent Service Toolkit:一站式大模型智能体开发套件
项目简介 该工具包基于LangGraph、FastAPI和Streamlit构建,提供了构建和运行大模型Agent的最小原子能力,包含LangGraph代理、FastAPI服务、用于与服务交互的客户端以及一个使用客户端提供聊天界面的Streamlit应用。用户可以利用该工具包提供的模板快速搭建基于LangGraph框架…...
大数据SQL调优专题——Hive执行原理
引入 Apache Hive 是基于Hadoop的数据仓库工具,它可以使用SQL来读取、写入和管理存在分布式文件系统中的海量数据。在Hive中,HQL默认转换成MapReduce程序运行到Yarn集群中,大大降低了非Java开发者数据分析的门槛,并且Hive提供命令…...
Python程序打包 |《Python基础教程》第18章笔记
《Python基础教程》第1章笔记👉https://blog.csdn.net/holeer/article/details/143052930 第18章 程序打包 程序可以发布后,你可能想先将它打包。如果程序只包含一个.py文件,这可能不是问题。然而,如果用户不是程序员࿰…...
图论 之 迪斯科特拉算法求解最短路径
文章目录 题目743.网络延迟时间3341.到达最后一个房间的最少时间I 求解最短路径的问题,分为使用BFS和使用迪斯科特拉算法,这两种算法求解的范围是有区别的 BFS适合求解,边的权值都是1的图中的最短路径的问题 图论 之 BFS迪斯科特拉算法适合求…...
掌握Spring开发_常用注解详解
1. 前言 1.1 写作目的 本文旨在全面解析Spring框架中常用的注解,帮助开发者更好地理解和使用这些注解,提高开发效率和代码质量。Spring框架提供了丰富的注解,简化了依赖注入、AOP、事务管理、Web开发等多个方面的开发工作。通过本文的学习,读者可以掌握这些注解的使用方法…...
华为昇腾服务器(固件版本查询、驱动版本查询、CANN版本查询)
文章目录 1. **查看固件和驱动版本**2. **查看CANN版本**3. **其他辅助方法**注意事项 在华为昇腾服务器上查看固件、驱动和CANN版本的常用方法如下: 1. 查看固件和驱动版本 通过命令行工具 npu-smi 执行以下命令查看当前设备的固件(Firmware࿰…...
Kubernetes的Ingress和Service有什么区别?
在Kubernetes中,Ingress和Service是两个不同的概念,它们在功能、作用范围、应用场景等方面存在明显区别,具体如下: 功能 Ingress:主要用于管理集群外部到内部服务的HTTP和HTTPS流量路由。它可以根据域名、路径等规则…...
洛谷B3619(B3620)
B3619 10 进制转 x 进制 - 洛谷 B3620 x 进制转 10 进制 - 洛谷 代码区: #include<algorithm> #include<iostream> #include<vector> using namespace std;int main(){int n,x;cin >> n >> x;vector<char> arry;while(n){if(…...
vue组件,父子通信,路由,异步请求后台接口,跨域
1.组件注册 1.1局部注册 局部注册组件---1.导入import 组件对象名 from 组件网页路径 export default{ name:"名称", data(){return {}}, created(){}, …...
详解分布式ID实践
引言 分布式ID,所谓的分布式ID,就是针对整个系统而言,任何时刻获取一个ID,无论系统处于何种情况,该值不会与之前产生的值重复,之后获取分布式ID时,也不会再获取到与其相同的值,它是…...
.NET + Vue3 的前后端项目在IIS的发布
目录 一、发布准备 1、安装 IIS 2、安装 Windows Hosting Bundle(.NET Core 托管捆绑包) 3、安装 IIS URL Rewrite 二、项目发布 1、后端项目发布 2、前端项目发布 3、将项目部署到 IIS中 三、网站配置 1、IP配置 2、防火墙配置 3、跨域配置…...
软件测试之压力测试
🍅 点击文末小卡片,免费获取软件测试全套资料,资料在手,涨薪更快 压力测试 压力测试是一种软件测试,用于验证软件应用程序的稳定性和可靠性。压力测试的目标是在极其沉重的负载条件下测量软件的健壮性和错误处理能力&…...
矩阵-矩阵置零
矩阵置零 给定一个 m x n 的矩阵,如果一个元素为 0 ,则将其所在行和列的所有元素都设为0 。请使用 原地 算法。在计算机科学中,一个原地算法(in-place algorithm)是一种使用小的,固定数量的额外之空间来转…...
【JavaScript】《JavaScript高级程序设计 (第4版) 》笔记-Chapter19-表单脚本
十九、表单脚本 表单脚本 JavaScript 较早的一个用途是承担一部分服务器端表单处理的责任。虽然 Web 和 JavaScript 都已经发展了很多年,但 Web 表单的变化不是很大。由于不能直接使用表单解决问题,因此开发者不得不使用JavaScript 既做表单验证…...
【C# 数据结构】队列 FIFO
目录 队列的概念FIFO (First-In, First-Out)Queue<T> 的工作原理:示例:解释: 小结: 环形队列1. **FIFO?**2. **环形缓冲队列如何实现FIFO?**关键概念: 3. **环形缓冲队列的工作过程**假设…...
java_网络服务相关_gateway_nacos_feign区别联系
1. spring-cloud-starter-gateway 作用:作为微服务架构的网关,统一入口,处理所有外部请求。 核心能力: 路由转发(基于路径、服务名等)过滤器(鉴权、限流、日志、Header 处理)支持负…...
oracle与MySQL数据库之间数据同步的技术要点
Oracle与MySQL数据库之间的数据同步是一个涉及多个技术要点的复杂任务。由于Oracle和MySQL的架构差异,它们的数据同步要求既要保持数据的准确性和一致性,又要处理好性能问题。以下是一些主要的技术要点: 数据结构差异 数据类型差异ÿ…...
第一篇:Agent2Agent (A2A) 协议——协作式人工智能的黎明
AI 领域的快速发展正在催生一个新时代,智能代理(agents)不再是孤立的个体,而是能够像一个数字团队一样协作。然而,当前 AI 生态系统的碎片化阻碍了这一愿景的实现,导致了“AI 巴别塔问题”——不同代理之间…...
从零实现STL哈希容器:unordered_map/unordered_set封装详解
本篇文章是对C学习的STL哈希容器自主实现部分的学习分享 希望也能为你带来些帮助~ 那咱们废话不多说,直接开始吧! 一、源码结构分析 1. SGISTL30实现剖析 // hash_set核心结构 template <class Value, class HashFcn, ...> class hash_set {ty…...
WordPress插件:AI多语言写作与智能配图、免费AI模型、SEO文章生成
厌倦手动写WordPress文章?AI自动生成,效率提升10倍! 支持多语言、自动配图、定时发布,让内容创作更轻松! AI内容生成 → 不想每天写文章?AI一键生成高质量内容!多语言支持 → 跨境电商必备&am…...
ios苹果系统,js 滑动屏幕、锚定无效
现象:window.addEventListener监听touch无效,划不动屏幕,但是代码逻辑都有执行到。 scrollIntoView也无效。 原因:这是因为 iOS 的触摸事件处理机制和 touch-action: none 的设置有关。ios有太多得交互动作,从而会影响…...
OPenCV CUDA模块图像处理-----对图像执行 均值漂移滤波(Mean Shift Filtering)函数meanShiftFiltering()
操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 在 GPU 上对图像执行 均值漂移滤波(Mean Shift Filtering),用于图像分割或平滑处理。 该函数将输入图像中的…...
docker 部署发现spring.profiles.active 问题
报错: org.springframework.boot.context.config.InvalidConfigDataPropertyException: Property spring.profiles.active imported from location class path resource [application-test.yml] is invalid in a profile specific resource [origin: class path re…...
安卓基础(aar)
重新设置java21的环境,临时设置 $env:JAVA_HOME "D:\Android Studio\jbr" 查看当前环境变量 JAVA_HOME 的值 echo $env:JAVA_HOME 构建ARR文件 ./gradlew :private-lib:assembleRelease 目录是这样的: MyApp/ ├── app/ …...
HarmonyOS运动开发:如何用mpchart绘制运动配速图表
##鸿蒙核心技术##运动开发##Sensor Service Kit(传感器服务)# 前言 在运动类应用中,运动数据的可视化是提升用户体验的重要环节。通过直观的图表展示运动过程中的关键数据,如配速、距离、卡路里消耗等,用户可以更清晰…...
