当前位置: 首页 > news >正文

PyTorch 的各个核心模块和它们的功能

1. torch

核心功能
  • 张量操作:PyTorch 的张量是一个多维数组,类似于 NumPy 的 ndarray,但支持 GPU 加速。
  • 数学运算:提供了各种数学运算,包括线性代数操作、随机数生成等。
  • 自动微分torch.autograd 模块用于自动计算梯度。
  • 设备管理:允许在 CPU 和 GPU 之间移动张量。

示例代码

import torch# 创建张量
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = torch.tensor([4.0, 5.0, 6.0])# 张量加法
z = x + y
print(f'z: {z}')# 计算梯度
z.sum().backward() # 求和的原因是求梯度需要是一个标量
print(f'Gradients of x: {x.grad}')

2. torch.nn

核心功能
  • 构建神经网络模块nn.Module 是所有神经网络模块的基类。
  • 常用层:如卷积层、池化层、全连接层、激活函数、归一化层等。
  • 损失函数:如交叉熵损失、均方误差损失等。

示例代码

import torch.nn as nn# 定义一个简单的前馈神经网络
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(10, 5)self.fc2 = nn.Linear(5, 1)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return xmodel = SimpleNet()
print(model)

3. torch.optim

核心功能
  • 优化算法:包括 SGD、Adam、RMSprop 等。
  • 学习率调度器:用于动态调整学习率,如 StepLRExponentialLR

示例代码

import torch.optim as optim# 定义模型
model = SimpleNet()# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)# 更新模型参数
optimizer.zero_grad()
output = model(torch.randn(1, 10))
loss = torch.mean(output)
loss.backward()
optimizer.step()# 学习率调度器
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
scheduler.step()

4. torch.utils.data

核心功能
  • 数据集Dataset 类用于自定义数据集。
  • 数据加载器DataLoader 用于批量加载数据,支持多线程加载。
  • 数据变换:通过 torchvision.transforms 可以对数据进行预处理和增强。

示例代码

from torch.utils.data import Dataset, DataLoader# 自定义数据集
class MyDataset(Dataset):def __init__(self, data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx]dataset = MyDataset([1, 2, 3, 4])
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)for batch in dataloader:print(batch)

5. torchvision

核心功能
  • 数据集:提供了常用的计算机视觉数据集,如 MNIST、CIFAR-10、ImageNet 等。
  • 预训练模型:如 ResNet、VGG、AlexNet 等。
  • 数据变换:如图像调整大小、裁剪、归一化等。

示例代码

import torchvision.transforms as transforms
import torchvision.datasets as datasets# 定义数据预处理
transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 下载 MNIST 数据集
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)for images, labels in dataloader:print(images.shape, labels.shape)break

6. torch.jit

核心功能
  • TorchScript:通过脚本化和追踪将 Python 模型转换为 TorchScript 模型,提高执行效率并支持跨平台部署。
  • 脚本化torch.jit.script 用于将 Python 代码转换为 TorchScript 代码。
  • 追踪torch.jit.trace 用于通过追踪模型的执行流程创建 TorchScript 模型。

示例代码

import torch.jit# 定义简单模型
class SimpleNet(nn.Module):def forward(self, x):return x * 2model = SimpleNet()# 脚本化模型
scripted_model = torch.jit.script(model)
print(scripted_model)# 追踪模型
traced_model = torch.jit.trace(model, torch.randn(1, 10))
print(traced_model)

7. torch.cuda

核心功能
  • 设备管理:提供与 GPU 相关的操作,如设备计数、设备选择等。
  • 张量迁移:将张量从 CPU 移动到 GPU,以利用 GPU 加速计算。

示例代码

if torch.cuda.is_available():device = torch.device("cuda")x = torch.tensor([1.0, 2.0, 3.0]).to(device)print(f'GPU tensor: {x}')
else:print("CUDA is not available.")

8. torch.autograd

核心功能
  • 自动微分:提供自动计算梯度的功能,支持反向传播算法。
  • 计算图:动态构建计算图,并根据图计算梯度。

示例代码

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x + 2
z = y * y * 3
out = z.mean()# 反向传播计算梯度
out.backward()
print(x.grad)  # 输出 x 的梯度

9. torch.multiprocessing

核心功能
  • 多进程并行:用于在多核 CPU 上实现数据并行和模型并行,提高计算效率。
  • 与 Python 标准库 multiprocessing 的兼容:提供与标准库相似的接口。

示例代码

import torch.multiprocessing as mpdef worker(rank, data):print(f'Worker {rank} processing data: {data}')if __name__ == '__main__':data = [1, 2, 3, 4]mp.spawn(worker, args=(data,), nprocs=4)

10. torch.distributed

核心功能
  • 分布式训练:支持在多个 GPU 和多台机器上进行分布式训练。
  • 通信接口:提供多种通信后端,如 Gloo、NCCL 等。

示例代码

import torch
import torch.distributed as distdef init_process(rank, size, fn, backend='gloo'):dist.init_process_group(backend, rank=rank, world_size=size)fn(rank, size)def example(rank, size):tensor = torch.zeros(1)if rank == 0:tensor += 1dist.send(tensor, dst=1)else:dist.recv(tensor, src=0)print(f'Rank {rank} has data {tensor[0]}')if __name__ == "__main__":size = 2processes = []for rank in range(size):p = mp.Process(target=init_process, args=(rank, size, example))p.start()processes.append(p)for p in processes:p.join()

通过这些模块,PyTorch 提供了构建、训练、优化和部署深度学习模型所需的全面支持。

相关文章:

PyTorch 的各个核心模块和它们的功能

1. torch 核心功能 张量操作:PyTorch 的张量是一个多维数组,类似于 NumPy 的 ndarray,但支持 GPU 加速。数学运算:提供了各种数学运算,包括线性代数操作、随机数生成等。自动微分:torch.autograd 模块用于…...

Java开发之LinkedList源码分析

#来自ゾフィー(佐菲) 1 简介 LinkedList 的底层数据结构是双向链表。可以当作链表、栈、队列、双端队列来使用。有以下特点: 在插入或删除数据时,性能好;允许有 null 值;查询效率不高;线程不安…...

外卖霸王餐系统架构怎么选?

在当今日益繁荣的外卖市场中,外卖霸王餐作为一种独特的营销策略,受到了众多商家的青睐。然而,要想成功实施外卖霸王餐活动,一个安全、稳定且高效的架构选择至关重要。本文将深入探讨外卖霸王餐架构的选择,以期为商家提…...

AV1技术学习:Transform Coding

对预测残差进行变换编码,去除潜在的空间相关性。VP9 采用统一的变换块大小设计,编码块中的所有的块共享相同的变换大小。VP9 支持 4 4、8 8、16 16、32 32 四种正方形变换大小。根据预测模式选择由一维离散余弦变换 (DCT) 和非对称离散正弦变换 (ADS…...

Git操作指令

Git操作指令 一、安装git 1、设置配置信息: # global全局配置 git config --global user.name "Your username" git config --global user.email "Your email"2、查看git版本号 git -v # or git --version3、查看配置信息: git…...

CSS 创建:从入门到精通

CSS 创建:从入门到精通 CSS(层叠样式表)是网页设计中不可或缺的一部分,它用于控制网页的布局和样式。本文将详细介绍CSS的创建过程,包括基本概念、语法结构、选择器、样式属性以及如何将CSS应用到HTML中。无论您是初学者还是有经验的开发者,本文都将为您提供宝贵的信息。…...

Windows 11 系统对磁盘进行分区保姆级教程

Windows 11磁盘分区 磁盘分区是将硬盘驱动器划分为多个逻辑部分的过程,每个逻辑部分都可以独立使用和管理。在Windows 11操作系统中进行磁盘分区主要有以下几个作用和意义: 组织和管理数据:分区可以帮助用户更好地组织他们的数据&#xff0c…...

探索WebKit的CSS盒模型:深入理解Web布局的基石

探索WebKit的CSS盒模型:深入理解Web布局的基石 在Web开发的世界中,CSS盒模型(Box Model)是构建网页布局的核心原理。WebKit,作为Safari浏览器的渲染引擎,对CSS盒模型有着深入而精确的支持。本文将带你深入…...

c++初阶知识——string类详解

目录 前言: 1.标准库中的string类 1.1 auto和范围for auto 范围for 1.2 string类常用接口说明 1.string类对象的常见构造 1.3 string类对象的访问及遍历操作 1.4. string类对象的修改操作 1.5 string类非成员函数 2.string类的模拟实现 2.1 经典的string…...

php接口返回的json字符串,json_decode()失败,原来是多了红点

问题: 调用某个接口返回的json,json_decode()失败,返回数据为null, echo json_last_error();返回错误码 4 经过多次调试发现:多出来一个红点,预览是看不到的。 解决:要去除BOM头部 $resul…...

Python3网络爬虫开发实战(2)爬虫基础库

文章目录 一、urllib1. urlparse 实现 URL 的识别和分段2. urlunparse 用于构造 URL3. urljoin 用于两个链接的拼接4. urlencode 将 params 字典序列化为 params 字符串5. parse_qs 和 parse_qsl 用于将 params 字符串反序列化为 params 字典或列表6. quote 和 unquote 对 URL的…...

el-image预览图片点击遮盖处关闭预览

预览关闭按钮不明显 解决方式: 1.修改按钮样式明显点: //el-image 添加自定义类名,下文【test-image】代指 .test-image .el-icon-circle-close{ color:#fff; font-size:20px; ...改成很明显的样式 }2.使用事件监听,监听当前遮…...

基于Neo4j将知识图谱用于检索增强生成:Knowledge Graphs for RAG

Knowledge Graphs for RAG 本文是学习https://www.deeplearning.ai/short-courses/knowledge-graphs-rag/这门课的学习笔记。 What you’ll learn in this course Knowledge graphs are used in development to structure complex data relationships, drive intelligent sea…...

康康近期的慢SQL(oracle vs 达梦)

近期执行的sql,哪些比较慢? 或者健康检查时搂一眼状态 oracle: --最近3天内的慢sql set lines 200 pages 100 col txt for a65 col sql_id for a13 select a.sql_id,a.cnt,a.pctload,b.sql_text txt from (select * from (select sql_id,co…...

探索 GPT-4o mini:成本效益与创新的双重驱动

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…...

2.6基本算法之动态规划2989:糖果

描述 由于在维护世界和平的事务中做出巨大贡献,Dzx被赠予糖果公司2010年5月23日当天无限量糖果免费优惠券。在这一天,Dzx可以从糖果公司的N件产品中任意选择若干件带回家享用。糖果公司的N件产品每件都包含数量不同的糖果。Dzx希望他选择的产品包含的糖…...

12.顶部带三角形的边框 CSS 关键字 currentColor

顶部带三角形的边框 创建一个在顶部带有三角形的内容容器。 使用 ::before 和 ::after 伪元素创建两个三角形。两个三角形的颜色应分别与容器的 border-color 和容器的 background-color 相同。一个三角形(::before)的 border-width 应比另一个(::after)宽 1px,以起到边框的作…...

Llama中模块参数大小

LLama2中,流程中数据大小的变换如下 Transformer模块 第一次输入,进行prefill,输入x维度为[1, 8, 4096] 1. 构建wq,wk,wv,wo,尺寸均为[4096,4096], 与x点乘,得到xq, xk, xv 2. 构建KV cache, 尺寸为 [b…...

Modbus转EtherCAT网关将Modbus协议的数据格式转换为EtherCAT协议

随着工业自动化技术的快速发展,不同通信协议之间的互操作性变得越来越重要。Modbus作为一种广泛使用的串行通信协议,与以太网为基础的EtherCAT协议之间的转换需求日益增长。本文将从网关功能、硬件设计、性能以及应用案例来介绍这款Modbus转EtherCAT网关…...

【开发实战】QT5 + OpenCV4 开发环境配置应用演示

前言 作为深度学习算法工程师,必须要掌握应用开发技能吗?搞工程肯定是必须要会界面开发,QT就是一个很不错的选择。本文以QT5.15 OpenCV4.8 OpenVINO2023为例,搭建应用开发环境,演示深度学习模型的QT应用案例。 开发…...

基于算法竞赛的c++编程(28)结构体的进阶应用

结构体的嵌套与复杂数据组织 在C中,结构体可以嵌套使用,形成更复杂的数据结构。例如,可以通过嵌套结构体描述多层级数据关系: struct Address {string city;string street;int zipCode; };struct Employee {string name;int id;…...

【网络】每天掌握一个Linux命令 - iftop

在Linux系统中,iftop是网络管理的得力助手,能实时监控网络流量、连接情况等,帮助排查网络异常。接下来从多方面详细介绍它。 目录 【网络】每天掌握一个Linux命令 - iftop工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景…...

React hook之useRef

React useRef 详解 useRef 是 React 提供的一个 Hook,用于在函数组件中创建可变的引用对象。它在 React 开发中有多种重要用途,下面我将全面详细地介绍它的特性和用法。 基本概念 1. 创建 ref const refContainer useRef(initialValue);initialValu…...

深入浅出:JavaScript 中的 `window.crypto.getRandomValues()` 方法

深入浅出:JavaScript 中的 window.crypto.getRandomValues() 方法 在现代 Web 开发中,随机数的生成看似简单,却隐藏着许多玄机。无论是生成密码、加密密钥,还是创建安全令牌,随机数的质量直接关系到系统的安全性。Jav…...

Python爬虫实战:研究feedparser库相关技术

1. 引言 1.1 研究背景与意义 在当今信息爆炸的时代,互联网上存在着海量的信息资源。RSS(Really Simple Syndication)作为一种标准化的信息聚合技术,被广泛用于网站内容的发布和订阅。通过 RSS,用户可以方便地获取网站更新的内容,而无需频繁访问各个网站。 然而,互联网…...

【CSS position 属性】static、relative、fixed、absolute 、sticky详细介绍,多层嵌套定位示例

文章目录 ★ position 的五种类型及基本用法 ★ 一、position 属性概述 二、position 的五种类型详解(初学者版) 1. static(默认值) 2. relative(相对定位) 3. absolute(绝对定位) 4. fixed(固定定位) 5. sticky(粘性定位) 三、定位元素的层级关系(z-i…...

macOS多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用

文章目录 问题现象问题原因解决办法 问题现象 macOS启动台(Launchpad)多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用。 问题原因 很明显,都是Google家的办公全家桶。这些应用并不是通过独立安装的…...

Qwen3-Embedding-0.6B深度解析:多语言语义检索的轻量级利器

第一章 引言:语义表示的新时代挑战与Qwen3的破局之路 1.1 文本嵌入的核心价值与技术演进 在人工智能领域,文本嵌入技术如同连接自然语言与机器理解的“神经突触”——它将人类语言转化为计算机可计算的语义向量,支撑着搜索引擎、推荐系统、…...

拉力测试cuda pytorch 把 4070显卡拉满

import torch import timedef stress_test_gpu(matrix_size16384, duration300):"""对GPU进行压力测试,通过持续的矩阵乘法来最大化GPU利用率参数:matrix_size: 矩阵维度大小,增大可提高计算复杂度duration: 测试持续时间(秒&…...

IoT/HCIP实验-3/LiteOS操作系统内核实验(任务、内存、信号量、CMSIS..)

文章目录 概述HelloWorld 工程C/C配置编译器主配置Makefile脚本烧录器主配置运行结果程序调用栈 任务管理实验实验结果osal 系统适配层osal_task_create 其他实验实验源码内存管理实验互斥锁实验信号量实验 CMISIS接口实验还是得JlINKCMSIS 简介LiteOS->CMSIS任务间消息交互…...