【深度学习】Pytorch的深入理解和研究
一、Pytorch核心理解
PyTorch 是一个灵活且强大的深度学习框架,广泛应用于研究和工业领域。要深入理解和研究 PyTorch,需要从其核心概念、底层机制以及高级功能入手。以下是对 PyTorch 的深入理解与研究的详细说明。
1. 概念
动态计算图(Dynamic Computation Graph)
定义:PyTorch 使用动态计算图(也称为“定义即运行”模式),允许在运行时动态构建和修改计算图。
特点:
- 更适合调试和实验。
- 支持灵活的控制流(如循环、条件判断)。
实现原理:
- 每次前向传播都会生成一个新的计算图。
- 反向传播时,自动计算梯度并释放计算图以节省内存。
2. 张量(Tensor)
2.1 张量的理解及与NumPy的对比
张量是一个多维数组,可以表示标量(0 维)、向量(1 维)、矩阵(2 维)或更高维度的数据。
特点:
- 支持动态计算图(Dynamic Computation Graph),适合深度学习任务。
- 可以在 CPU 或 GPU 上运行,利用硬件加速。
张量的理解及与NumPy的对比:

2.2 张量的创建
(1) 基本创建方法
import torch
# 创建未初始化的张量
x = torch.empty(3, 3)
# 创建随机张量
y = torch.rand(3, 3)
# 创建全零张量
z = torch.zeros(3, 3)
# 创建从 NumPy 转换的张量
import numpy as np
a = np.array([1, 2, 3])
b = torch.from_numpy(a)
(2)指定数据类型和设备
# 指定数据类型
x = torch.tensor([1, 2, 3], dtype=torch.float32)
# 指定设备(CPU 或 GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = x.to(device)
2.3 张量的操作
(1)基本运算
# 加法
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
z = x + y
# 矩阵乘法
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
c = torch.matmul(a, b)
(2)广播机制
广播机制允许不同形状的张量进行运算。
广播规则:
- 维度对齐:如果两个张量的维度数不同,则在较小张量的前面添加新的维度(大小为 1),使其维度数相同。
- 逐维比较:对每个维度,检查两个张量的大小是否相等,或者其中一个张量的大小为 1。
- 扩展维度:如果某个维度的大小为 1,则将其扩展为与另一个张量对应维度的大小相同。
x = torch.tensor([1, 2, 3])
y = torch.tensor(2)
z = x + y # 将标量 y 广播到每个元素
print(z) # 输出:tensor([3, 4, 5])
错误示例:
如果张量的形状无法满足广播规则,则会报错:
# 创建张量
a = torch.tensor([[1, 2, 3], [4, 5, 6]]) # 形状 (2, 3)
b = torch.tensor([10, 20]) # 形状 (2,)
# 尝试广播
try:c = a + b
except RuntimeError as e:print(e) # 输出:The size of tensor a must match the size of tensor b at non-singleton dimension 1
#解释:
#a 的第二维大小为 3,而 b 的大小为 2,无法对齐。
(3)索引与切片
# 索引
x = torch.tensor([[1, 2], [3, 4]])
print(x[0, 1]) # 输出:2
# 切片
print(x[:, 1]) # 输出:tensor([2, 4])
2.4 张量的属性
(1)形状与维度
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 获取形状
print(x.shape) # 输出:torch.Size([2, 3])
# 获取维度
print(x.ndim) # 输出:2
(2)数据类型
x = torch.tensor([1, 2, 3], dtype=torch.float32)
print(x.dtype) # 输出:torch.float32
(3)设备信息
x = torch.tensor([1, 2, 3])
print(x.device) # 输出:cpu
2.5 张量与自动求导
(1)自动求导基础
PyTorch 的张量支持自动求导,通过 requires_grad=True 启用梯度跟踪。
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2 + 3 * x + 5
# 计算梯度
y.backward()
print(f"Gradient of y w.r.t x: {x.grad}") # 输出:tensor([7.])
(2)禁用梯度计算
禁用梯度计算的理解:禁用梯度计算是指在深度学习模型训练过程中不计算梯度,这通常是通过上下文管理器torch.no_grad()在PyTorch中实现的。禁用梯度计算的主要目的是在某些操作中节省内存和提高计算效率,特别是在进行推理(inference)时。
因此,在推理阶段,可以通过 torch.no_grad() 禁用梯度计算以节省内存。
with torch.no_grad():z = x + 2
print(z.requires_grad) # 输出:False
2.6 张量的底层机制
(1)内存布局
连续存储:张量在内存中是连续存储的,默认按行优先顺序排列。
非连续张量:某些操作(如转置)可能导致张量变得非连续。
x = torch.tensor([[1, 2], [3, 4]])
y = x.t() # 转置
print(y.is_contiguous()) # 输出:False
(2)数据共享
张量之间的操作可能共享底层数据,修改一个张量会影响另一个张量。
x = torch.tensor([1, 2, 3])
y = x.view(3, 1) # 修改视图
x[0] = 10
print(y) # 输出:tensor([[10], [2], [3]])
2.7 高级功能
(1)张量的序列化
张量的序列化是将张量数据保存为一种可以存储或传输的格式的过程,以便在后续需要时重新加载和使用。
它允许模型和数据在不同的运行时环境之间进行共享和持久存储。
# 保存张量
torch.save(x, "tensor.pth")
# 加载张量
x_loaded = torch.load("tensor.pth")
(2)张量的分布式操作
在分布式训练中,张量可以在多个设备之间传递。
import torch.distributed as dist
dist.init_process_group(backend="nccl")
x = torch.tensor([1, 2, 3]).cuda()
dist.all_reduce(x, op=dist.ReduceOp.SUM)
3. 底层机制
3.1 Autograd(自动求导系统)
定义:Autograd 是 PyTorch 的自动求导引擎,用于计算张量的梯度。
工作原理:
- 在前向传播中记录所有操作。
- 在反向传播中根据链式法则计算梯度。
关键组件:
- torch.autograd.Function:自定义前向和反向传播函数。
- torch.no_grad():禁用梯度计算(用于推理阶段)。
class MyFunction(torch.autograd.Function):@staticmethoddef forward(ctx, input):ctx.save_for_backward(input)return input.clamp(min=0)@staticmethoddef backward(ctx, grad_output):input, = ctx.saved_tensorsgrad_input = grad_output.clone()grad_input[input < 0] = 0return grad_input
# 使用自定义函数
x = torch.tensor([-1.0, 2.0, -3.0], requires_grad=True)
my_relu = MyFunction.apply
y = my_relu(x)
y.sum().backward()
print(f"Gradient: {x.grad}")
3.2 CUDA 和 GPU 加速
定义:PyTorch 支持将张量和模型迁移到 GPU 上,利用 GPU 的并行计算能力加速训练。
实现方式:
- .cuda() 或 .to(device) 将张量或模型迁移到 GPU。
- torch.device 用于指定设备(CPU 或 GPU)。
# 检查是否有可用的 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 创建张量并迁移到 GPU
x = torch.randn(3, 3).to(device)
y = torch.randn(3, 3).to(device)
# 在 GPU 上进行计算
z = x + y
print("Result on GPU:", z)
4. 高级功能
4.1 自定义模型
定义:通过继承 torch.nn.Module 类,可以创建自定义神经网络模型。
关键方法:
- forward():定义前向传播逻辑。
- parameters():返回模型的所有可训练参数。
import torch.nn as nn
import torch.optim as optim
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(10, 20)self.relu = nn.ReLU()self.fc2 = nn.Linear(20, 1)def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)return x
# 创建模型和优化器
model = SimpleNet()
optimizer = optim.Adam(model.parameters(), lr=0.01)
# 模拟输入
input_data = torch.randn(5, 10)
output = model(input_data)
print("Model output:", output)
4.2 分布式训练
定义:PyTorch 提供了分布式训练工具(如 torch.distributed),支持多 GPU 和多节点训练。
常用方法:
- 数据并行(torch.nn.DataParallel)。
- 分布式数据并行(torch.nn.parallel.DistributedDataParallel)。
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
# 初始化分布式环境
dist.init_process_group(backend="nccl")
# 创建模型并迁移到 GPU
model = SimpleNet().cuda()
ddp_model = DDP(model)
# 训练代码省略...
4.3 混合精度训练
定义:混合精度训练使用 FP16 和 FP32 结合的方式,减少显存占用并加速训练。
实现方式:
使用 torch.cuda.amp 提供的工具。
from torch.cuda.amp import GradScaler, autocast
scaler = GradScaler()
for data, target in dataloader:optimizer.zero_grad()# 使用混合精度with autocast():output = model(data)loss = loss_fn(output, target)# 缩放损失并反向传播scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()
4.4 实验与研究
(1)模型可视化
工具:使用 TensorBoard 或 Matplotlib 可视化训练过程。
用途:
- 监控损失和准确率变化。
- 可视化模型结构和特征图。
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
# 记录标量
for epoch in range(10):writer.add_scalar("Loss/train", epoch * 0.1, epoch)
writer.close()
(2)模型解释性
工具:使用 Captum 库分析模型的特征重要性。
用途:
- 解释模型决策过程。
- 发现潜在问题(如偏差或过拟合)。
from captum.attr import IntegratedGradients
ig = IntegratedGradients(model)
attributions = ig.attribute(input_data, target=0)
print("Attributions:", attributions)
二、Pytorch应用场景
1. 计算机视觉(Computer Vision)
1.1 图像分类
任务:将图像分配到预定义的类别。
实现:使用卷积神经网络(CNN),如 ResNet、VGG 或自定义模型。
应用:
- 医疗影像分析(如 X 光片分类)。
- 自动驾驶中的交通标志识别。
import torch
import torchvision.models as models
# 加载预训练模型
model = models.resnet18(pretrained=True)
# 修改输出层以适应新任务
num_classes = 10
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
1.2 目标检测
任务:在图像中定位并分类多个目标。
实现:使用 Faster R-CNN、YOLO 或 SSD 等模型。
应用:
- 安防监控(如行人检测)。
- 工业自动化(如缺陷检测)。
1.3 图像分割
任务:为图像中的每个像素分配一个类别标签。
实现:使用 U-Net、Mask R-CNN 等模型。
应用:
- 医学图像分割(如肿瘤区域标记)。
- 卫星图像分析(如土地覆盖分类)。
2. 自然语言处理(Natural Language Processing, NLP)
2.1 文本分类
任务:将文本分配到预定义的类别。
实现:使用 Transformer 模型(如 BERT、RoBERTa)。
应用:
- 情感分析(如评论情感分类)。
- 垃圾邮件检测。
示例代码:
from transformers import BertTokenizer, BertForSequenceClassification
# 加载预训练模型和分词器
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
# 输入文本
text = "I love using PyTorch!"
inputs = tokenizer(text, return_tensors="pt")
# 推理
outputs = model(**inputs)
logits = outputs.logits
print(logits)
2.2 机器翻译
任务:将一种语言的文本翻译成另一种语言。
实现:使用序列到序列(Seq2Seq)模型或 Transformer。
应用:
- 跨语言交流工具。
- 多语言内容生成。
2.3 文本生成
任务:根据输入生成连贯的文本。
实现:使用 GPT 系列模型。
应用:
- 写作助手(如自动完成文章)。
- 聊天机器人。
3. 推荐系统(Recommendation Systems)
3.1 用户行为建模
任务:根据用户的历史行为推荐相关内容。
实现:使用协同过滤、矩阵分解或深度学习模型。
应用:
- 电商平台推荐商品。
- 视频平台推荐视频。
3.2 多模态推荐
任务:结合多种数据源(如文本、图像)进行推荐。
实现:使用多模态融合模型。
应用:
- 社交媒体内容推荐。
- 广告投放优化。
三、总结
深入理解和研究 PyTorch 需要掌握以下内容:
- 核心概念:动态计算图、张量操作、自动求导。
- 底层机制:Autograd、CUDA 加速。
- 高级功能:自定义模型、分布式训练、混合精度训练。
- 实验与研究:模型可视化、解释性分析。
通过不断实践和探索,你可以充分利用 PyTorch 的灵活性和强大功能,解决复杂的深度学习问题!
相关文章:
【深度学习】Pytorch的深入理解和研究
一、Pytorch核心理解 PyTorch 是一个灵活且强大的深度学习框架,广泛应用于研究和工业领域。要深入理解和研究 PyTorch,需要从其核心概念、底层机制以及高级功能入手。以下是对 PyTorch 的深入理解与研究的详细说明。 1. 概念 动态计算图(D…...
IDEA + 通义灵码AI程序员:快速构建DDD后端工程模板
作者:陈荣健 IDEA 通义灵码AI程序员:快速构建DDD后端工程模板 在软件开发过程中,一个清晰、可维护、可扩展的架构至关重要。领域驱动设计 (DDD) 是一种软件开发方法,它强调将软件模型与业务领域紧密结合,从而构建更…...
内容中台重构企业内容管理的价值维度与实施路径
内容概要 在数字化转型进程中,企业内容管理(ECM)与内容中台的差异性体现在价值维度的重构与能力边界的突破。传统ECM系统通常聚焦于文档存储、权限控制等基础功能,而内容中台通过标准化流程引擎与智能工具链,构建起覆…...
CPU封装形式解析:从传统到先进封装的技术演进
中央处理器(CPU)的封装技术是半导体制造的关键环节,直接影响芯片的电气性能、散热效率和物理可靠性。随着半导体工艺的不断进步,封装形式从早期的简单结构演变为复杂的多维集成方案。本文将系统解析CPU的主流封装形式及其技术特点…...
Spring Boot 应用(官网文档解读)
Spring Boot 启动方式 SpringApplication.run(MyApplication.class, args); Spring Boot 故障分析器 在Spring Boot 项目启动发生错误的时候,我们通常可以看到上面的内容,即 APPLICATION FAILED TO START,以及后面的错误描述。这个功能是通过…...
【智能客服】ChatGPT大模型话术优化落地方案
本文原创作者:姚瑞南 AI-agent 大模型运营专家,先后任职于美团、猎聘等中大厂AI训练专家和智能运营专家岗;多年人工智能行业智能产品运营及大模型落地经验,拥有AI外呼方向国家专利与PMP项目管理证书。(转载需经授权) 目录 一、项目背景 1.1 行业背景 1.2 业务现…...
1.22作业
1 Web-php-unserialize __construct()与$file、__destruct() __wakeup()检查 先绕过wakeup函数: O:4:"Demo":2:{s:10:"Demofile";s:8:"fl4g.php";}1.PHP序列化的时候对public protected private变量的处理方式是不同的 public无标…...
蓝桥杯 Day6 贪心
贪心 1.要点 2.例题 2022 砍竹子 学习: 1.模拟砍竹子砍到高度1,不过要记录过程高度,以便后续判断是否存在(想到集合哈希),然后外面嵌套数组(活用数据结构)resize给大小 vector<unordered_set<ll>> hs;//记录第i根竹子下降到1过程中的每…...
学习aigc
DALLE2 论文 Hierarchical Text-Conditional Image Generation with CLIP Latents [2204.06125] Hierarchical Text-Conditional Image Generation with CLIP LatentsAbstract page for arXiv paper 2204.06125: Hierarchical Text-Conditional Image Generation with CLIP L…...
overflow-x: auto 使用鼠标实现横向滚动,区分触摸板和鼠标滚动事件的方法
假设一个 div 的滚动只设置了 overflow-x: auto 我们发现使用鼠标的滚轮是无法左右滚动的,但是使用笔记本电脑的触摸板,或者在移动设备上是可以滚动的。所以我们需要兼容一下鼠标的横向滚动功能。 我们可以监控 wheel 事件,然后根据位置来计…...
模拟实现Java中的计时器
定时器是什么 定时器也是软件开发中的⼀个重要组件. 类似于⼀个 "闹钟". 达到⼀个设定的时间之后, 就执⾏某个指定好的代码. 前端/后端中都会用到计时器. 定时器是⼀种实际开发中⾮常常⽤的组件. ⽐如⽹络通信中, 如果对⽅ 500ms 内没有返回数据, 则断开连接尝试重…...
Ubuntu 的RabbitMQ安装
目录 1.安装Erlang 查看erlang版本 退出命令 2. 安装 RabbitMQ 3.确认安装结果 4.安装RabbitMQ管理界面 5.启动服务并访问 1.启动服务 2.查看服务状态 3.通过IP:port 访问界面 4.添加管理员用户 a)添加用户名:admin,密码࿱…...
设计模式教程:命令模式(Command Pattern)
1. 什么是命令模式? 命令模式(Command Pattern)是一种行为型设计模式。它将请求封装成一个对象,从而使你能够用不同的请求、队列和日志请求以及支持可撤销操作。 简单来说,命令模式通过把请求封装成对象的方式解耦了…...
JavaScript数组常用的方法有哪些?map、filter、reduce 的区别和使用场景是什么?
JavaScript数组常用的方法有哪些?map、filter、reduce 的区别和使用场景是什么? JavaScript 数组常用方法 JavaScript 数组有很多实用的方法,以下先简单介绍一些常见的基础方法,再重点讲解 map、filter、reduce 这三个高阶函数。…...
vim修改只读文件
现象 解决方案 对于有root权限的用户,在命令行输入 :wq! 即可强制保存退出...
【DeepSeek】本地部署,保姆级教程
deepseek网站链接传送门:DeepSeek 在这里主要介绍DeepSeek的两种部署方法,一种是调用API,一种是本地部署。 一、API调用 1.进入网址Cherry Studio - 全能的AI助手选择立即下载 2.安装时位置建议放在其他盘,不要放c盘 3.进入软件后…...
为AI聊天工具添加一个知识系统 之114 详细设计之55 知识表征
本文要点 要点 项目名称:为使用AI聊天工具的聊天者添加一个知识系统 项目背景: 在现在各种AI聊天工具层出不穷的今天,我觉得特别需要一个通用的AI聊天工具的图形界面能够为每个聊天者(或一个利益相关者组织)建立自…...
centos 9 时间同步服务
在 CentOS 9 中,默认的时间同步服务是 chrony,而不是传统的 ntpd。 因此,建议使用 chrony 来配置和管理时间同步。 以下是使用 chrony 配置 NTP 服务的步骤: 1. 安装 chrony 首先,确保系统已安装 chrony。 在 CentOS…...
NCRE证书构成:全国计算机等级考试证书体系详解
全国计算机等级考试(NCRE)证书体系为中学生提供了一个系统学习和提升计算机能力的平台。本文将详细介绍 NCRE 证书的构成,帮助中学生了解 NCRE 证书的级别和内容,规划未来职业发展。 一、NCRE 证书体系概述 NCRE 证书共分为四个级…...
嵌入式之总线
嵌入式系统中的总线(Bus)是指用于连接各种组件(如处理器、存储器、外设等)的通信通道。总线的设计和实现对嵌入式系统的性能、功耗和扩展性有着重要影响。下面详细介绍嵌入式系统中的总线的概念、类型和特点。 一、总线的基本概念 总线是一种共享的通信路径,允许多个设备…...
如何在WPS打开的word、excel文件中,使用AI?
1、百度搜索:Office AI官方下载 或者直接打开网址:https://www.office-ai.cn/static/introductions/officeai/smartdownload.html 打开后会直接提示开始下载中,下载完成后会让其选择下载存放位置: 选择位置,然后命名文…...
Java 使用websocket
添加依赖 <!-- WebSocket 支持 --> <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-websocket</artifactId> </dependency>添加配置类 Configuration public class WebSocketConfig {B…...
MySQL 视图入门
一、什么是 MySQL 视图 1.1 视图的基本概念 在 MySQL 中,视图是一种虚拟表,它本身并不存储实际的数据,而是基于一个或多个真实表(基表)的查询结果集。可以把视图想象成是一个预定义好的查询语句的快捷方式。当你查询…...
【设计模式】 代理模式(静态代理、动态代理{JDK动态代理、JDK动态代理与CGLIB动态代理的区别})
代理模式 代理模式是一种结构型设计模式,它提供了一种替代访问的方法,即通过代理对象来间接访问目标对象。代理模式可以在不改变原始类代码的情况下,增加额外的功能,如权限控制、日志记录等。 静态代理 静态代理是指创建的或特…...
高考或者单招考试需要考物理这科目
问题:帮忙搜索一下以上学校哪些高考或者单招考试需要考物理这科目的 回答: 根据目前获取的资料,明确提及高考或单招考试需考物理的学校为湖南工业职业技术学院,在部分专业单招时要求选考物理;其他学校暂未发现明确提…...
《A++ 敏捷开发》- 16 评审与结对编程
客户:我们的客户以银行为主,他们很注重质量,所以一直很注重评审。他们对需求评审、代码走查等也很赞同,也能找到缺陷,对提升质量有作用。但他们最困惑的是通过设计评审很难发现缺陷。 我:你听说过敏捷的结对…...
NutUI内网离线部署
文章目录 官网拉取源代码到本地仓库修改源代码打包构建nginx反向代理部署访问内网离线地址 在网上找了一圈没有写NutUI内网离线部署的文档,花了1天时间研究下,终于解决了。 对于有在内网离线使用的小伙伴就可以参考使用了 如果还是不会联系UP主:QQ:10927…...
【实战篇】【深度介绍 DeepSeek R1 本地/私有化部署大模型常见问题及解决方案】
引言 大家好!今天我们来聊聊 DeepSeek R1 的本地/私有化部署大模型。如果你正在考虑或者已经开始了这个项目,那么这篇文章就是为你准备的。我们会详细探讨常见问题及其解决方案,帮助你更好地理解和解决在部署过程中可能遇到的挑战。准备好了…...
数据结构--双向链表,双向循环链表
双向链表的头插,尾插,头删,尾删 头文件:(head.h) #include <string.h> #include <stdlib.h> typedef…...
Qt学习(六) 软件启动界面 ,注册表使用 ,QT绘图, 视图和窗口绘图,Graphics View绘图框架:简易CAD
一 软件启动界面 注册表使用 知识点1:这样创建的界面是不可以拖动的,需要手动创建函数来进行拖动,以下的3个函数是从父类继承过来的函数 virtual void mousePressEvent(QMouseEvent *event);virtual void mouseReleaseEvent(QMouseEvent *eve…...
