当前位置: 首页 > news >正文

pytorch张量基础

  1. 引言
  2. 张量的基础知识
    1. 张量的概念
    2. 张量的属性
    3. 张量的创建
  3. 张量的操作
    1. 基本运算
    2. 索引和切片
    3. 形状变换
  4. 自动微分
    1. 基本概念
    2. 停止梯度传播
  5. 张量的设备管理
    1. 检查和移动张量
    2. CUDA 张量
  6. 高级操作
    1. 张量的视图
    2. 广播机制
    3. 分块和拼接
    4. 张量的复制
  7. 内存优化和管理
    1. 稀疏张量
    2. 内存释放
  8. 应用实例
    1. 线性回归
    2. 神经网络基础
  9. 总结

1. 引言

在机器学习和深度学习中,张量(Tensor)是核心的数据结构。了解和掌握张量的操作是学习 PyTorch 和构建神经网络模型的必要基础。张量可以表示从标量到高维数组的数据结构,它在 PyTorch 的计算图中扮演着基础角色。本指南旨在全面介绍 PyTorch 中张量的相关知识,帮助读者从基础打好深度学习的基础。

2. 张量的基础知识

1. 张量的概念

张量是一个数组的通用化,可以表示标量(0维)、向量(1维)、矩阵(2维)及更高维的数组。通俗来说,张量是一种多维数据结构,其本质上是一个多维数组。

2. 张量的属性

张量有多个重要属性,用来描述其数据和结构:

  • 形状(shape):描述张量的维度结构,例如 (2, 3) 表示一个包含 2 行 3 列的矩阵。
  • 数据类型(dtype):指定张量中元素的类型,例如 torch.float32torch.int64 等。
  • 设备(device):指示张量存储的设备,可以是 CPU 或 GPU。
  • 步幅(stride):步幅表示连续两个元素在各个维度上的步进距离。
import torchtensor = torch.tensor([[1., 2., 3.], [4., 5., 6.]])print(tensor.shape)    # torch.Size([2, 3])
print(tensor.dtype)    # torch.float32
print(tensor.device)   # cpu
print(tensor.stride()) # (3, 1)

3. 张量的创建

可以通过多种方式创建张量,包括从已有数据创建、使用随机数生成和从其他张量创建。

# 从数据创建
scalar = torch.tensor(5.0)          # 标量
vector = torch.tensor([1.0, 2.0, 3.0])  # 向量
matrix = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # 矩阵# 使用随机数创建
rand_tensor = torch.rand(2, 3)     # 均匀分布
randn_tensor = torch.randn(2, 3)   # 标准正态分布# 从其他张量创建
zeros_tensor = torch.zeros_like(matrix)  # 创建与 matrix 形状相同的全零张量

3. 张量的操作

1. 基本运算

张量支持基本的算术运算,包括加、减、乘、除。

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])# 加法
c = a + b# 减法
d = a - b# 乘法
e = a * b# 除法
f = a / b# 点积
dot_prod = torch.dot(a, b)  # 32.0# 矩阵乘法
matrix1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
matrix2 = torch.tensor([[5.0, 6.0], [7.0, 8.0]])
matrix_mul = torch.mm(matrix1, matrix2)  # [[19.0, 22.0], [43.0, 50.0]]

2. 索引和切片

张量支持多种索引和切片操作,类似于 NumPy。

tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])# 索引
element = tensor[1, 2]  # 6.0# 切片
subset = tensor[:, 1]  # tensor([2.0, 5.0])

3. 形状变换

在不复制数据的情况下,PyTorch 支持多种形状变换操作。

# 重塑
reshaped = tensor.view(3, 2)  # tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])# 转置
transposed = tensor.t()       # tensor([[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]])# 增加或减少维度
unsqueezed = tensor.unsqueeze(0)  # 增加第0维
squeezed = tensor.squeeze()       # 去除所有维度为1的维度

4. 自动微分

PyTorch 提供强大的自动微分功能,称为Autograd。它可以自动计算张量的梯度,适用于优化和训练神经网络。

1. 基本概念

张量可以设置 requires_grad=True 以启用自动微分。计算张量的梯度使用 backward() 方法。

x = torch.tensor([2.0, 3.0], requires_grad=True)
y = x[0] ** 2 + x[1] ** 3
y.backward()
print(x.grad)  # tensor([ 4.0, 27.0])

2. 停止梯度传播

在某些情况下,比如模型评估或推理时,需要停止梯度传播以提高性能并节省内存。

with torch.no_grad():y = x[0] ** 2 + x[1] ** 3# 使用 detach() 方法创建一个新的张量,该张量与原始张量共享数据,但不进行梯度追踪
detached_tensor = x.detach()

5. 张量的设备管理

1. 检查和移动张量

张量可以在 CPU 或 GPU 上进行计算。PyTorch 提供了简单的方法来检查和移动张量到不同的设备。

tensor = torch.tensor([1.0, 2.0, 3.0])# 检查是否有可用的 GPU
if torch.cuda.is_available():tensor = tensor.to('cuda')print(tensor.device)  # cuda:0# 将张量移动回 CPU
tensor = tensor.to('cpu')
print(tensor.device)  # cpu

2. CUDA 张量

使用 CUDA 张量可以显著提高计算速度,特别是在深度学习中。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tensor = torch.tensor([1.0, 2.0, 3.0], device=device)

6. 高级操作

1. 张量的视图

视图允许我们在不复制数据的情况下,改变张量的形状。

original_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
view_tensor = original_tensor.view(6)  # tensor([1, 2, 3, 4, 5, 6])# 修改视图
view_tensor[0] = 10
print(original_tensor)  # tensor([[10,  2,  3], [ 4,  5,  6]])

2. 广播机制

广播机制使得不同形状的张量能够进行相同大小的运算。

a = torch.tensor([1, 2, 3])
b = torch.tensor([[1], [2], [3]])
result = a + b
# result: tensor([[2, 3, 4],
#                 [3, 4, 5],
#                 [4, 5, 6]])

3. 分块和拼接

可以使用 split() 和 cat() 等函数进行分块和拼接。

tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])# 分割张量
split_tensors = torch.split(tensor, split_size_or_sections=2, dim=1)# 拼接张量
tensor_a = torch.tensor([[1, 2], [3, 4]])
tensor_b = torch.tensor([[5, 6], [7, 8]])
concat_tensor = torch.cat((tensor_a, tensor_b), dim=1)

4. 张量的复制

用于创建独立副本,clone() 和 detach() 是常用方法。

tensor = torch.tensor([1, 2, 3], requires_grad=True)
cloned_tensor = tensor.clone()
detached_tensor = tensor.detach()

7. 内存优化和管理

1. 稀疏张量

对于稀疏矩阵和张量,PyTorch 提供了稀疏张量表示,以便节省内存和计算资源。

indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
values = torch.tensor([3, 4, 5], dtype=torch.float32)
sparse_tensor = torch.sparse_coo_tensor(indices, values, [2, 3])
print(sparse_tensor)

2. 内存释放

为了在训练和评估期间节省内存,可以释放不再需要的张量。

# 使用 del 语句手动删除对象
del tensor# 清空 GPU 切实可行的张量以释放内存
torch.cuda.empty_cache()

8. 应用实例

通过实际应用实例,可以更好地理解和掌握 PyTorch 张量的使用方式。

1. 线性回归

利用 PyTorch 张量实现简单的线性回归模型。

# 数据集
x_train = torch.tensor([[1.0], [2.0], [3.0]])
y_train = torch.tensor([[2.0], [4.0], [6.0]])# 初始化参数
w = torch.randn(1, requires_grad=True)
b = torch.randn(1, requires_grad=True)def model(x):return w * x + b# 损失函数
def loss_fn(y_pred, y):return ((y_pred - y) ** 2).mean()# 训练模型
learning_rate = 0.01
for epoch in range(1000):y_pred = model(x_train)loss = loss_fn(y_pred, y_train)loss.backward()with torch.no_grad():w -= learning_rate * w.gradb -= learning_rate * b.gradw.grad.zero_()b.grad.zero_()print(f'w: {w}, b: {b}')

2. 神经网络基础

张量在神经网络中的应用,是构建复杂模型的基础。

import torch.nn as nn# 简单的神经网络
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(1, 10)self.relu = nn.ReLU()self.fc2 = nn.Linear(10, 1)def forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return outmodel = SimpleNN()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 训练模型
for epoch in range(1000):y_pred = model(x_train)loss = criterion(y_pred, y_train)optimizer.zero_grad()loss.backward()optimizer.step()print(list(model.parameters()))

相关文章:

pytorch张量基础

引言张量的基础知识 张量的概念张量的属性张量的创建张量的操作 基本运算索引和切片形状变换自动微分 基本概念停止梯度传播张量的设备管理 检查和移动张量CUDA 张量高级操作 张量的视图广播机制分块和拼接张量的复制内存优化和管理 稀疏张量内存释放应用实例 线性回归神经网络…...

深入解析LlamaIndex Workflows【下篇】:实现ReAct模式AI智能体的新方法

之前我们介绍了来自LLM开发框架LlamaIndex的新特性:Workflows,一种事件驱动、用于构建复杂AI工作流应用的新方法(参考:[深入解析LlamaIndex Workflows:构建复杂RAG与智能体工作流的新利器【上篇】]。在本篇中&#xff…...

要在 Git Bash 中使用 `tree` 命令,下载并手动安装 `tree`。

0、git bash 安装 git(安装,常用命令,分支操作,gitee,IDEA集成git,IDEA集成gitee,IDEA集成github,远程仓库操作) 1、下载并手动安装 tree 下载 tree.exe 从 tree for Windows 官方站点 下载 tree 的 Windows 可执行文件。tree for Window:https://gnuwin32.source…...

Linux的基本指令(1)

前提: a:博主是在云服务器上进行操作的 b:windows上普通文件在Linux中也叫作普通文件,但是windows上的文件夹,在Linux中叫作目录 c:文件 文件内容 文件属性(创建时间,修改时间,…...

JavaEE之多线程进阶-面试问题

一.常见的锁策略 锁策略不是指某一个具体的锁,所有的锁都可以往这些锁策略中套 1.悲观锁与乐观锁 预测所冲突的概率是否高,悲观锁为预测锁冲突的概率较高,乐观锁为预测锁冲突的概率更低。 2.重量级锁和轻量级锁 从加锁的开销角度判断&am…...

费曼学习法没有输出对象怎么办?

‌费曼学习法并不需要输出对象。‌费曼学习法的核心在于通过将所学知识以简明易懂的方式解释给自己听,从而加深对知识的理解和记忆。这种方法强调的是理解和反思的过程,而不是简单地通过输出(如向他人解释)来检验学习效果。费曼学…...

Hive优化操作(二)

Hive 数据倾斜优化 在使用 Hive 进行大数据处理时,数据倾斜是一个常见的问题。本文将详细介绍数据倾斜的概念、表现、常见场景及其解决方案。 1. 什么是数据倾斜? 数据倾斜是指由于数据分布不均匀,导致大量数据集中到某个节点或任务中&…...

销冠的至高艺术:让自己不像销售

若想在销售领域脱颖而出,首先是让自己超越传统销售的框架,成为客户心中不可多得的行业顾问与信赖源泉。这不仅是身份的蜕变,更是影响力与信任度质的飞跃。 销冠对客户只吸引不骚扰,不讲自己卖什么,只讲自己能解决什么…...

Hive数仓操作(十一)

一、Hive 日期函数 在日常的数据处理工作中,日期和时间的处理是非常常见的操作。Hive 提供了丰富的日期函数,能够帮助我们方便地进行日期和时间的计算。本文将详细介绍 Hive 中常用的日期函数,并通过具体的示例展示其用法和结果。 1. 获取当…...

C语言初步介绍(初学者,大学生)【上】

1.C语⾔是什么? ⼈和⼈交流使⽤的是⾃然语⾔,如:汉语、英语、⽇语 那⼈和计算机是怎么交流的呢?使⽤ 计算机语⾔ 。 ⽬前已知已经有上千种计算机语⾔,⼈们是通过计算机语⾔写的程序,给计算机下达指令&am…...

陈文自媒体:现在的房价,已经跌到7年前!

今年的国庆北上广深都放开了政策,很多人都放弃旅游去看房了,现在的全民都有一个基本意识,现在的房子已经到了谷底,从各大政策就可以看出来,稍微有点钱的可以出手买房了。 昨天我哥跟我说,现在xx地方的房子…...

基于STM32的智能水族箱控制系统设计

引言 本项目基于STM32微控制器设计一个智能水族箱控制系统。该系统能够通过传感器监测水温、照明和水位,并自动控制加热器、LED灯和水泵,确保水族箱内的环境适宜鱼类生长。该项目展示了STM32在环境监测、设备控制和智能反馈系统中的应用。 环境准备 1…...

java语言基础案例-cnblog

java语言基础案例 象棋口诀 输出 package nb;public class XiangQi {public static void main(String[] args) {char a 马;char b 象;char c 卒;System.out.println(a"走日"b"走田""小"c"一去不复还");} }输出汇款单 package nb…...

MyBatis-Plus 之 typeHandler 的使用

一、typeHandler 的使用 1、存储json格式字段 如果字段需要存储为json格式,可以使用JacksonTypeHandler处理器。使用方式非常简单,如下所示: 在domain实体类里面要加上,两个注解 TableName(autoResultMap true) 表示自动…...

HDLBits中文版,标准参考答案 |2.5 More Verilog Features | 更多Verilog 要点

关注 望森FPGA 查看更多FPGA资讯 这是望森的第 7 期分享 作者 | 望森 来源 | 望森FPGA 目录 1 Conditional ternary operator | 条件三目运算符 2 Reduction operators | 归约运算器 3 Reduction: Even wider gates | 归约:更宽的门电路 4 Combinational fo…...

提升开机速度:有效管理Windows电脑自启动项,打开、关闭自启动项教程分享

日常使用Windows电脑时,总会需要下载各种各样的办公软件。部分软件会默认开机自启功能,开机启动项是指那些在电脑启动时自动运行的程序和服务。电脑开机自启太多的情况下会导致电脑卡顿,开机慢,运行不流畅的情况出现,而…...

数据库简单介绍

数据库是现代信息技术中用于存储、管理和检索数据的重要工具。数据库技术的发展经历了多个阶段,从早期的层次模型和网状模型,到关系型数据库的兴起,再到NoSQL和NewSQL的多样化发展。数据库系统已经成为现代信息系统的核心和基础设施。 数据库…...

运用MinIO技术服务器实现文件上传——利用程序上传图片(二 )

在上一篇文章中,我们已经在云服务器中安装并开启了minio服务,本章我们将为大家讲解如何利用程序将文件上传到minio桶中 下面介绍MinIO中的几个核心概念,这些概念在所有的对象存储服务中也都是通用的。 - **对象(Object&#xff0…...

C语言 | Leetcode C语言题解之第461题汉明距离

题目: 题解: int hammingDistance(int x, int y) {int s x ^ y, ret 0;while (s) {s & s - 1;ret;}return ret; }...

Qt 3D、QtQuick、QtQuick 3D 和 QML 的关系

理清 Qt 3D、QtQuick、QtQuick 3D 和 QML 的关系 在开发图形界面应用时,特别是在使用 Qt 框架时,开发者可能会接触到多个概念,如 Qt 3D、QtQuick、QtQuick 3D 和 QML。这些术语分别代表了 Qt 中不同的模块或技术,但由于它们的功能…...

软件设计师(软考学习)

数据库技术 数据库基础知识 1. 数据库中的简单属性、多值属性、复合属性、派生属性简单属性:指不能够再分解成更小部分的属性,通常是数据表中的一个列。例如学生表中的“学号”、“姓名”等均为简单属性。 多值属性:指一个属性可以有多个值…...

第一讲:Go语言开发入门:环境搭建与基础语法

文章目录 环境搭建windows环境搭建Mac环境搭建安装GO使用 Homebrew 安装 Go手动下载安装 Go 配置环境变量配置环境变量检查 Go 是否正确安装 验证安装:编写第一个 Go 程序创建 Go 工作区编写 Hello World 程序运行程序编译程序 常用的 Go 命令 Go语言基础语法1. 变量…...

Linux CentOS stream9配置本地yum源

在Linux系统中,yum源配置是一个重要的环节。把系统安装时配置的国外yum源转换为国内yum源,能够帮助系统快速安装软件包。对于网络环境不稳定或无法联网的系统,配置本地yum源,可以让用户在离线状态下也能进行软件包的安装,十分重要。 一、国内源 在使用Linux的日常工作中…...

std::string

std::string是C标准库中的一个基本类模板,专门用于处理字符串。它提供了一个可变长度的字符序列,以及一系列用于字符串操作的方法。std::string是值类型,这意味着当它作为函数参数传递或赋值时,整个字符串数据会被复制。 std::st…...

【Docker】03-自制镜像

1. 自制镜像 2. Dockerfile # 基础镜像 FROM openjdk:11.0-jre-buster # 设定时区 ENV TZAsia/Shanghai RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone # 拷贝jar包 COPY docker-demo.jar /app.jar # 入口 ENTRYPOINT ["ja…...

Java GC 分类,8和9使用的哪种?

Java的垃圾收集器(Garbage Collector, GC)负责自动管理内存,回收不再使用的对象所占用的空间。随着JVM的发展,出现了多种不同特性的垃圾收集器来适应不同的应用场景和性能需求。在Java 8中,主要有以下几种垃圾收集器&a…...

【Docker从入门到进阶】01.介绍 02.基础使用

1. 介绍 1.1. 什么是 Docker Docker 是一个开源的平台,用于开发、发布和运行应用程序。它使开发者能够以更精简的方式封装应用及其依赖,做到“打包一次,到处运行”。通过 Docker,您可以创建轻量级、可移植的容器,每个…...

GraphRAG-Local-UI - 基于 GraphRAG 支持本地的聊天UI

文章目录 一、关于 GraphRAG-Local-UI 🕸️特点🌟🗺️路线图最近更新即将推出的功能 二、📦安装和设置三、使用入门🚀1、创建索引目录2、添加示例数据(可选)3、初始化索引文件夹4、配置设置5、定…...

Java 根据字符生成背景透明的图片

上代码 package com.example.demotest.controller;/*** Author shaolin* Date 2024-10-08 10:11**/import javax.imageio.ImageIO; import java.awt.*; import java.awt.image.BufferedImage; import java.awt.image.ColorModel; import java.awt.image.WritableRaster; impor…...

树莓派3b安装ubuntu18.04服务器系统server配置网线连接

下载ubuntu镜像网址 img镜像,即树莓派官方烧录器使用的镜像网址 ubuntu18.04-server:ARM/RaspberryPi - Ubuntu Wiki 其他版本:Index of /ubuntu/releases 下载后解压即可。 发现使用官方烧录器烧录配置时配置wifi无论如何都不能使用&am…...