当前位置: 首页 > news >正文

pytorch张量基础

  1. 引言
  2. 张量的基础知识
    1. 张量的概念
    2. 张量的属性
    3. 张量的创建
  3. 张量的操作
    1. 基本运算
    2. 索引和切片
    3. 形状变换
  4. 自动微分
    1. 基本概念
    2. 停止梯度传播
  5. 张量的设备管理
    1. 检查和移动张量
    2. CUDA 张量
  6. 高级操作
    1. 张量的视图
    2. 广播机制
    3. 分块和拼接
    4. 张量的复制
  7. 内存优化和管理
    1. 稀疏张量
    2. 内存释放
  8. 应用实例
    1. 线性回归
    2. 神经网络基础
  9. 总结

1. 引言

在机器学习和深度学习中,张量(Tensor)是核心的数据结构。了解和掌握张量的操作是学习 PyTorch 和构建神经网络模型的必要基础。张量可以表示从标量到高维数组的数据结构,它在 PyTorch 的计算图中扮演着基础角色。本指南旨在全面介绍 PyTorch 中张量的相关知识,帮助读者从基础打好深度学习的基础。

2. 张量的基础知识

1. 张量的概念

张量是一个数组的通用化,可以表示标量(0维)、向量(1维)、矩阵(2维)及更高维的数组。通俗来说,张量是一种多维数据结构,其本质上是一个多维数组。

2. 张量的属性

张量有多个重要属性,用来描述其数据和结构:

  • 形状(shape):描述张量的维度结构,例如 (2, 3) 表示一个包含 2 行 3 列的矩阵。
  • 数据类型(dtype):指定张量中元素的类型,例如 torch.float32torch.int64 等。
  • 设备(device):指示张量存储的设备,可以是 CPU 或 GPU。
  • 步幅(stride):步幅表示连续两个元素在各个维度上的步进距离。
import torchtensor = torch.tensor([[1., 2., 3.], [4., 5., 6.]])print(tensor.shape)    # torch.Size([2, 3])
print(tensor.dtype)    # torch.float32
print(tensor.device)   # cpu
print(tensor.stride()) # (3, 1)

3. 张量的创建

可以通过多种方式创建张量,包括从已有数据创建、使用随机数生成和从其他张量创建。

# 从数据创建
scalar = torch.tensor(5.0)          # 标量
vector = torch.tensor([1.0, 2.0, 3.0])  # 向量
matrix = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # 矩阵# 使用随机数创建
rand_tensor = torch.rand(2, 3)     # 均匀分布
randn_tensor = torch.randn(2, 3)   # 标准正态分布# 从其他张量创建
zeros_tensor = torch.zeros_like(matrix)  # 创建与 matrix 形状相同的全零张量

3. 张量的操作

1. 基本运算

张量支持基本的算术运算,包括加、减、乘、除。

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])# 加法
c = a + b# 减法
d = a - b# 乘法
e = a * b# 除法
f = a / b# 点积
dot_prod = torch.dot(a, b)  # 32.0# 矩阵乘法
matrix1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
matrix2 = torch.tensor([[5.0, 6.0], [7.0, 8.0]])
matrix_mul = torch.mm(matrix1, matrix2)  # [[19.0, 22.0], [43.0, 50.0]]

2. 索引和切片

张量支持多种索引和切片操作,类似于 NumPy。

tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])# 索引
element = tensor[1, 2]  # 6.0# 切片
subset = tensor[:, 1]  # tensor([2.0, 5.0])

3. 形状变换

在不复制数据的情况下,PyTorch 支持多种形状变换操作。

# 重塑
reshaped = tensor.view(3, 2)  # tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])# 转置
transposed = tensor.t()       # tensor([[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]])# 增加或减少维度
unsqueezed = tensor.unsqueeze(0)  # 增加第0维
squeezed = tensor.squeeze()       # 去除所有维度为1的维度

4. 自动微分

PyTorch 提供强大的自动微分功能,称为Autograd。它可以自动计算张量的梯度,适用于优化和训练神经网络。

1. 基本概念

张量可以设置 requires_grad=True 以启用自动微分。计算张量的梯度使用 backward() 方法。

x = torch.tensor([2.0, 3.0], requires_grad=True)
y = x[0] ** 2 + x[1] ** 3
y.backward()
print(x.grad)  # tensor([ 4.0, 27.0])

2. 停止梯度传播

在某些情况下,比如模型评估或推理时,需要停止梯度传播以提高性能并节省内存。

with torch.no_grad():y = x[0] ** 2 + x[1] ** 3# 使用 detach() 方法创建一个新的张量,该张量与原始张量共享数据,但不进行梯度追踪
detached_tensor = x.detach()

5. 张量的设备管理

1. 检查和移动张量

张量可以在 CPU 或 GPU 上进行计算。PyTorch 提供了简单的方法来检查和移动张量到不同的设备。

tensor = torch.tensor([1.0, 2.0, 3.0])# 检查是否有可用的 GPU
if torch.cuda.is_available():tensor = tensor.to('cuda')print(tensor.device)  # cuda:0# 将张量移动回 CPU
tensor = tensor.to('cpu')
print(tensor.device)  # cpu

2. CUDA 张量

使用 CUDA 张量可以显著提高计算速度,特别是在深度学习中。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tensor = torch.tensor([1.0, 2.0, 3.0], device=device)

6. 高级操作

1. 张量的视图

视图允许我们在不复制数据的情况下,改变张量的形状。

original_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
view_tensor = original_tensor.view(6)  # tensor([1, 2, 3, 4, 5, 6])# 修改视图
view_tensor[0] = 10
print(original_tensor)  # tensor([[10,  2,  3], [ 4,  5,  6]])

2. 广播机制

广播机制使得不同形状的张量能够进行相同大小的运算。

a = torch.tensor([1, 2, 3])
b = torch.tensor([[1], [2], [3]])
result = a + b
# result: tensor([[2, 3, 4],
#                 [3, 4, 5],
#                 [4, 5, 6]])

3. 分块和拼接

可以使用 split() 和 cat() 等函数进行分块和拼接。

tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])# 分割张量
split_tensors = torch.split(tensor, split_size_or_sections=2, dim=1)# 拼接张量
tensor_a = torch.tensor([[1, 2], [3, 4]])
tensor_b = torch.tensor([[5, 6], [7, 8]])
concat_tensor = torch.cat((tensor_a, tensor_b), dim=1)

4. 张量的复制

用于创建独立副本,clone() 和 detach() 是常用方法。

tensor = torch.tensor([1, 2, 3], requires_grad=True)
cloned_tensor = tensor.clone()
detached_tensor = tensor.detach()

7. 内存优化和管理

1. 稀疏张量

对于稀疏矩阵和张量,PyTorch 提供了稀疏张量表示,以便节省内存和计算资源。

indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
values = torch.tensor([3, 4, 5], dtype=torch.float32)
sparse_tensor = torch.sparse_coo_tensor(indices, values, [2, 3])
print(sparse_tensor)

2. 内存释放

为了在训练和评估期间节省内存,可以释放不再需要的张量。

# 使用 del 语句手动删除对象
del tensor# 清空 GPU 切实可行的张量以释放内存
torch.cuda.empty_cache()

8. 应用实例

通过实际应用实例,可以更好地理解和掌握 PyTorch 张量的使用方式。

1. 线性回归

利用 PyTorch 张量实现简单的线性回归模型。

# 数据集
x_train = torch.tensor([[1.0], [2.0], [3.0]])
y_train = torch.tensor([[2.0], [4.0], [6.0]])# 初始化参数
w = torch.randn(1, requires_grad=True)
b = torch.randn(1, requires_grad=True)def model(x):return w * x + b# 损失函数
def loss_fn(y_pred, y):return ((y_pred - y) ** 2).mean()# 训练模型
learning_rate = 0.01
for epoch in range(1000):y_pred = model(x_train)loss = loss_fn(y_pred, y_train)loss.backward()with torch.no_grad():w -= learning_rate * w.gradb -= learning_rate * b.gradw.grad.zero_()b.grad.zero_()print(f'w: {w}, b: {b}')

2. 神经网络基础

张量在神经网络中的应用,是构建复杂模型的基础。

import torch.nn as nn# 简单的神经网络
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(1, 10)self.relu = nn.ReLU()self.fc2 = nn.Linear(10, 1)def forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return outmodel = SimpleNN()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 训练模型
for epoch in range(1000):y_pred = model(x_train)loss = criterion(y_pred, y_train)optimizer.zero_grad()loss.backward()optimizer.step()print(list(model.parameters()))

相关文章:

pytorch张量基础

引言张量的基础知识 张量的概念张量的属性张量的创建张量的操作 基本运算索引和切片形状变换自动微分 基本概念停止梯度传播张量的设备管理 检查和移动张量CUDA 张量高级操作 张量的视图广播机制分块和拼接张量的复制内存优化和管理 稀疏张量内存释放应用实例 线性回归神经网络…...

深入解析LlamaIndex Workflows【下篇】:实现ReAct模式AI智能体的新方法

之前我们介绍了来自LLM开发框架LlamaIndex的新特性:Workflows,一种事件驱动、用于构建复杂AI工作流应用的新方法(参考:[深入解析LlamaIndex Workflows:构建复杂RAG与智能体工作流的新利器【上篇】]。在本篇中&#xff…...

要在 Git Bash 中使用 `tree` 命令,下载并手动安装 `tree`。

0、git bash 安装 git(安装,常用命令,分支操作,gitee,IDEA集成git,IDEA集成gitee,IDEA集成github,远程仓库操作) 1、下载并手动安装 tree 下载 tree.exe 从 tree for Windows 官方站点 下载 tree 的 Windows 可执行文件。tree for Window:https://gnuwin32.source…...

Linux的基本指令(1)

前提: a:博主是在云服务器上进行操作的 b:windows上普通文件在Linux中也叫作普通文件,但是windows上的文件夹,在Linux中叫作目录 c:文件 文件内容 文件属性(创建时间,修改时间,…...

JavaEE之多线程进阶-面试问题

一.常见的锁策略 锁策略不是指某一个具体的锁,所有的锁都可以往这些锁策略中套 1.悲观锁与乐观锁 预测所冲突的概率是否高,悲观锁为预测锁冲突的概率较高,乐观锁为预测锁冲突的概率更低。 2.重量级锁和轻量级锁 从加锁的开销角度判断&am…...

费曼学习法没有输出对象怎么办?

‌费曼学习法并不需要输出对象。‌费曼学习法的核心在于通过将所学知识以简明易懂的方式解释给自己听,从而加深对知识的理解和记忆。这种方法强调的是理解和反思的过程,而不是简单地通过输出(如向他人解释)来检验学习效果。费曼学…...

Hive优化操作(二)

Hive 数据倾斜优化 在使用 Hive 进行大数据处理时,数据倾斜是一个常见的问题。本文将详细介绍数据倾斜的概念、表现、常见场景及其解决方案。 1. 什么是数据倾斜? 数据倾斜是指由于数据分布不均匀,导致大量数据集中到某个节点或任务中&…...

销冠的至高艺术:让自己不像销售

若想在销售领域脱颖而出,首先是让自己超越传统销售的框架,成为客户心中不可多得的行业顾问与信赖源泉。这不仅是身份的蜕变,更是影响力与信任度质的飞跃。 销冠对客户只吸引不骚扰,不讲自己卖什么,只讲自己能解决什么…...

Hive数仓操作(十一)

一、Hive 日期函数 在日常的数据处理工作中,日期和时间的处理是非常常见的操作。Hive 提供了丰富的日期函数,能够帮助我们方便地进行日期和时间的计算。本文将详细介绍 Hive 中常用的日期函数,并通过具体的示例展示其用法和结果。 1. 获取当…...

C语言初步介绍(初学者,大学生)【上】

1.C语⾔是什么? ⼈和⼈交流使⽤的是⾃然语⾔,如:汉语、英语、⽇语 那⼈和计算机是怎么交流的呢?使⽤ 计算机语⾔ 。 ⽬前已知已经有上千种计算机语⾔,⼈们是通过计算机语⾔写的程序,给计算机下达指令&am…...

陈文自媒体:现在的房价,已经跌到7年前!

今年的国庆北上广深都放开了政策,很多人都放弃旅游去看房了,现在的全民都有一个基本意识,现在的房子已经到了谷底,从各大政策就可以看出来,稍微有点钱的可以出手买房了。 昨天我哥跟我说,现在xx地方的房子…...

基于STM32的智能水族箱控制系统设计

引言 本项目基于STM32微控制器设计一个智能水族箱控制系统。该系统能够通过传感器监测水温、照明和水位,并自动控制加热器、LED灯和水泵,确保水族箱内的环境适宜鱼类生长。该项目展示了STM32在环境监测、设备控制和智能反馈系统中的应用。 环境准备 1…...

java语言基础案例-cnblog

java语言基础案例 象棋口诀 输出 package nb;public class XiangQi {public static void main(String[] args) {char a 马;char b 象;char c 卒;System.out.println(a"走日"b"走田""小"c"一去不复还");} }输出汇款单 package nb…...

MyBatis-Plus 之 typeHandler 的使用

一、typeHandler 的使用 1、存储json格式字段 如果字段需要存储为json格式,可以使用JacksonTypeHandler处理器。使用方式非常简单,如下所示: 在domain实体类里面要加上,两个注解 TableName(autoResultMap true) 表示自动…...

HDLBits中文版,标准参考答案 |2.5 More Verilog Features | 更多Verilog 要点

关注 望森FPGA 查看更多FPGA资讯 这是望森的第 7 期分享 作者 | 望森 来源 | 望森FPGA 目录 1 Conditional ternary operator | 条件三目运算符 2 Reduction operators | 归约运算器 3 Reduction: Even wider gates | 归约:更宽的门电路 4 Combinational fo…...

提升开机速度:有效管理Windows电脑自启动项,打开、关闭自启动项教程分享

日常使用Windows电脑时,总会需要下载各种各样的办公软件。部分软件会默认开机自启功能,开机启动项是指那些在电脑启动时自动运行的程序和服务。电脑开机自启太多的情况下会导致电脑卡顿,开机慢,运行不流畅的情况出现,而…...

数据库简单介绍

数据库是现代信息技术中用于存储、管理和检索数据的重要工具。数据库技术的发展经历了多个阶段,从早期的层次模型和网状模型,到关系型数据库的兴起,再到NoSQL和NewSQL的多样化发展。数据库系统已经成为现代信息系统的核心和基础设施。 数据库…...

运用MinIO技术服务器实现文件上传——利用程序上传图片(二 )

在上一篇文章中,我们已经在云服务器中安装并开启了minio服务,本章我们将为大家讲解如何利用程序将文件上传到minio桶中 下面介绍MinIO中的几个核心概念,这些概念在所有的对象存储服务中也都是通用的。 - **对象(Object&#xff0…...

C语言 | Leetcode C语言题解之第461题汉明距离

题目: 题解: int hammingDistance(int x, int y) {int s x ^ y, ret 0;while (s) {s & s - 1;ret;}return ret; }...

Qt 3D、QtQuick、QtQuick 3D 和 QML 的关系

理清 Qt 3D、QtQuick、QtQuick 3D 和 QML 的关系 在开发图形界面应用时,特别是在使用 Qt 框架时,开发者可能会接触到多个概念,如 Qt 3D、QtQuick、QtQuick 3D 和 QML。这些术语分别代表了 Qt 中不同的模块或技术,但由于它们的功能…...

wordpress后台更新后 前端没变化的解决方法

使用siteground主机的wordpress网站,会出现更新了网站内容和修改了php模板文件、js文件、css文件、图片文件后,网站没有变化的情况。 不熟悉siteground主机的新手,遇到这个问题,就很抓狂,明明是哪都没操作错误&#x…...

接口测试中缓存处理策略

在接口测试中,缓存处理策略是一个关键环节,直接影响测试结果的准确性和可靠性。合理的缓存处理策略能够确保测试环境的一致性,避免因缓存数据导致的测试偏差。以下是接口测试中常见的缓存处理策略及其详细说明: 一、缓存处理的核…...

51c自动驾驶~合集58

我自己的原文哦~ https://blog.51cto.com/whaosoft/13967107 #CCA-Attention 全局池化局部保留,CCA-Attention为LLM长文本建模带来突破性进展 琶洲实验室、华南理工大学联合推出关键上下文感知注意力机制(CCA-Attention),…...

Java 语言特性(面试系列1)

一、面向对象编程 1. 封装(Encapsulation) 定义:将数据(属性)和操作数据的方法绑定在一起,通过访问控制符(private、protected、public)隐藏内部实现细节。示例: public …...

React Native 开发环境搭建(全平台详解)

React Native 开发环境搭建(全平台详解) 在开始使用 React Native 开发移动应用之前,正确设置开发环境是至关重要的一步。本文将为你提供一份全面的指南,涵盖 macOS 和 Windows 平台的配置步骤,如何在 Android 和 iOS…...

FFmpeg 低延迟同屏方案

引言 在实时互动需求激增的当下,无论是在线教育中的师生同屏演示、远程办公的屏幕共享协作,还是游戏直播的画面实时传输,低延迟同屏已成为保障用户体验的核心指标。FFmpeg 作为一款功能强大的多媒体框架,凭借其灵活的编解码、数据…...

深入浅出:JavaScript 中的 `window.crypto.getRandomValues()` 方法

深入浅出:JavaScript 中的 window.crypto.getRandomValues() 方法 在现代 Web 开发中,随机数的生成看似简单,却隐藏着许多玄机。无论是生成密码、加密密钥,还是创建安全令牌,随机数的质量直接关系到系统的安全性。Jav…...

AtCoder 第409​场初级竞赛 A~E题解

A Conflict 【题目链接】 原题链接:A - Conflict 【考点】 枚举 【题目大意】 找到是否有两人都想要的物品。 【解析】 遍历两端字符串,只有在同时为 o 时输出 Yes 并结束程序,否则输出 No。 【难度】 GESP三级 【代码参考】 #i…...

高频面试之3Zookeeper

高频面试之3Zookeeper 文章目录 高频面试之3Zookeeper3.1 常用命令3.2 选举机制3.3 Zookeeper符合法则中哪两个?3.4 Zookeeper脑裂3.5 Zookeeper用来干嘛了 3.1 常用命令 ls、get、create、delete、deleteall3.2 选举机制 半数机制(过半机制&#xff0…...

Qt Http Server模块功能及架构

Qt Http Server 是 Qt 6.0 中引入的一个新模块,它提供了一个轻量级的 HTTP 服务器实现,主要用于构建基于 HTTP 的应用程序和服务。 功能介绍: 主要功能 HTTP服务器功能: 支持 HTTP/1.1 协议 简单的请求/响应处理模型 支持 GET…...