当前位置: 首页 > news >正文

pytorch里常用操作(持续更新)

对不起我脑子不太记事儿每次变换都得想想想所以干脆汇总一下算了,当然也有一些不是torch包里面的但是没有关系hhh 官方文档里有一堆不太常用的,这里整理的都是自己比较常用的

张量操作

torch.tensor:从Python列表或NumPy数组创建张量

torch.zeros/ones:创建全零/一张量

torch.zeros(10,4)就是创建[10,4]的全零张量

torch.rand:创建随机张量

torch.cat:沿指定维度拼接张量

torch.stack:在新的维度上堆叠张量

torch.stack(tensors, dim=0, out=None)

  • tensors:要堆叠的输入张量的列表或元组。
  • dim:指定要堆叠的新维度的索引。默认是0。
  • out:可选参数,用于指定结果张量的输出

 e.g

假设我们有两个张量 tensor1tensor2,它们的形状都是 (3, 4),并且我们想将它们堆叠在一个新的维度上,创建一个新的形状为 (2, 3, 4) 的张量。

# 创建两个张量
tensor1 = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
tensor2 = torch.tensor([[-1, -2, -3, -4], [-5, -6, -7, -8], [-9, -10, -11, -12]])# 使用torch.stack将它们堆叠在一个新的维度上
stacked_tensor = torch.stack((tensor1, tensor2), dim=0)

torch.reshape:改变张量的形状。

torch.transpose:交换张量的维度。

torch.transpose(input, dim0, dim1)

  • input:要进行维度交换的输入张量。
  • dim0:要交换的第一个维度的索引。
  • dim1:要交换的第二个维度的索引。

假设我们有一个形状为 (3, 4) 的张量,现在想要交换它的维度,创建一个新的张量,使其形状为 (4, 3)

# 创建一个形状为 (3, 4) 的张量
input_tensor = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])# 使用 torch.transpose 进行维度交换
transposed_tensor = torch.transpose(input_tensor, 0, 1)

torch.arange: 用于创建一个包含指定范围内的数值的一维张量

# 开始,结尾,间隔;和索引比较像

arr = torch.arange(10,100,10) #[10, 20, 30, 40, 50, 60, 70, 80, 90]

torch.meshgrid: 网格图

根据提供的x,y轴的范围得到一张网格的里面的点的xy坐标

torch.flatten(tensor,dim) : 把tensor压缩,dim表示压缩的维度

torch.unsqueeze:在张量中插入新的维度

new_tensor = torch.unsqueeze(input, dim)

  • input 是要插入新维度的输入张量。
  • dim 是要插入新维度的位置,通常是一个非负整数。

y = torch.tensor([[1, 2], [3, 4],[5, 6]]) # torch.Size[3,2]
y_new_1 = torch.unsqueeze(y, 0) #torch.Size[1, 3, 2]
y_new_2 = torch.unsqueeze(y, 1) # torch.Size[3, 1, 2]
y_new_3 = torch.unsqueeze(y, 2) # torch.Size[3, 2, 1]

一些非torch包的操作

[:, :, :None]最后一维扩一维

A;2d,B:1d B[i*cols+j] = A[i,j] 把二阶张量变成一阶

permute(): 矩阵转置


数学操作

torch.add/sub:张量相加/减

torch.mul/div:张量相乘/除

torch.sum:计算张量的和

torch.mean:计算张量的平均值

torch.max/min:找到张量中的最大/小值

torch.abs:计算张量的绝对值

torch.exp:计算输入张量中元素的指数(exponential)

x = torch.tensor([1.0, 2.0, 3.0])

exp_x = torch.exp(x)     # [e^1.0, e^2.0, e^3.0]


索引和切片

tensor[idx]:根据索引获取张量中的元素。
tensor[start:end]:切片操作。
tensor[:, 1]:选取指定列。
tensor[condition]:使用布尔条件进行索引。


自动求导

torch.autograd.Variable:创建自动求导的变量。
backward():计算梯度。
grad:访问梯度值。
no_grad():上下文管理器,用于禁用梯度计算。


神经网络模块

torch.nn.Module:创建神经网络模块

torch.nn.Linear:定义全连接层

这个层通常用于神经网络中,用来实现从输入到输出的线性变换,其中包括权重矩阵和偏置项。

# 创建一个 Linear 层
input_size = 10
output_size = 5
linear_layer = nn.Linear(input_size, output_size)# 随机生成一个输入张量
input_data = torch.randn(1, input_size)  # 这里创建一个形状为 (1, input_size) 的随机输入张量# 使用线性层进行前向传播
output = linear_layer(input_data)# 查看权重和偏置
weights = linear_layer.weight
bias = linear_layer.biasprint("输入张量:", input_data)
print("输出张量:", output)
print("权重矩阵:", weights)
print("偏置项:", bias)

这个示例中,首先创建了一个 nn.Linear 层,指定输入特征的数量和输出特征的数量。然后,随机生成一个输入张量 input_data,并通过将其传递给 linear_layer 来进行前向传播。线性层会应用权重矩阵和偏置项,生成输出张量。

nn.Linear 层在神经网络中通常用于连接不同层之间的神经元,执行线性变换的作用,帮助网络学习数据的特征表示。

torch.nn.Conv2d:定义卷积层

torch.nn.ReLU:ReLU激活函数

torch.nn.CrossEntropyLoss:交叉熵损失函数

torch.nn.optim:包含各种优化器,如SGD、Adam等

torch.nn.LayerNorm:用于层归一化

层归一化是一种用于神经网络的正则化技术,有助于加速训练和提高模型的鲁棒性。输入输出的形状并不会改变

import torch
import torch.nn as nn# 创建一个 LayerNorm 层
input_size = 10
layer_norm = nn.LayerNorm(input_size)# 随机生成一个输入张量
input_data = torch.randn(1, input_size)  # 创建一个形状为 (1, input_size) 的随机输入张量# 使用 LayerNorm 层进行前向传播
output = layer_norm(input_data)print("输入张量:", input_data)
print("LayerNorm 后的输出张量:", output) #shape[1,input_size]

torch.nn.Parameter:将张量标记为模型参数(可训练的参数)

将张量封装为 torch.nn.Parameter 对象后,它会被自动注册为模型的可训练参数,并在反向传播(backpropagation)期间更新它的值。这对于构建神经网络模型非常有用,因为神经网络的权重和偏置通常需要在训练期间进行优化。

# 创建一个普通的张量
tensor_data = torch.tensor([1.0, 2.0, 3.0])

# 将张量包装为一个模型参数·1                                                                                               
parameter = nn.Parameter(tensor_data)

# 打印参数
print(parameter) #Parameter containing: tensor([1., 2., 3.], requires_grad=True)

一起使用的是register_buffer

用于将张量注册为模型的缓冲区(buffer)

 注册为缓冲区的张量不会被视为模型的可训练参数,也不会在反向传播期间更新。它们用于保存模型的固定状态信息,例如统计信息(均值、方差等)、预训练的权重或任何其他不需要进行梯度更新的张量。

register_buffer 的主要作用是将这些张量添加到模型的状态字典中,以便在保存和加载模型时一并保存和加载。这对于确保模型的一致性和可重现性非常有用。

class CustomModel(nn.Module):def __init__(self):super(CustomModel, self).__init__()# 创建一个常量张量作为缓冲区self.register_buffer('constant_tensor', torch.tensor([1.0, 2.0, 3.0]))# 创建模型实例
model = CustomModel()# 打印模型缓冲区
for name, buffer in model.named_buffers():print(name, buffer)

torch.nn.MultiheadAttention:实现多头注意力机制

多头注意力机制允许模型同时关注输入中的不同部分,以提高模型性能

定义多头注意力模块
- embed_dim: 输入的维度
- num_heads: 头的数量,用于并行处理不同部分的注意力
- dropout: 可选的丢弃率

attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.1)# 输入数据 (query, key, value),通常是三个相同形状的张量
query = torch.randn(seq_length, batch_size, embed_dim)
key = torch.randn(seq_length, batch_size, embed_dim)
value = torch.randn(seq_length, batch_size, embed_dim)# 调用多头注意力模块
output, attention_weights = attention(query, key, value)# output 是注意力机制的输出,attention_weights 是注意力权重

e.g 

import torch
import torch.nn as nn# 定义多头注意力模块
embed_dim = 128
num_heads = 4
attention = nn.MultiheadAttention(embed_dim, num_heads)# 输入数据 (query, key, value)
seq_length = 10
batch_size = 32
query = torch.randn(seq_length, batch_size, embed_dim)
key = torch.randn(seq_length, batch_size, embed_dim)
value = torch.randn(seq_length, batch_size, embed_dim)# 调用多头注意力模块
output, attention_weights = attention(query, key, value)print("Output shape:", output.shape)
print("Attention weights shape:", attention_weights.shape)

数据加载和处理

torch.utils.data.Dataset:创建自定义数据集。


torch.utils.data.DataLoader:数据加载器。


transforms模块:用于数据预处理和转换。

相关文章:

pytorch里常用操作(持续更新)

对不起我脑子不太记事儿每次变换都得想想想所以干脆汇总一下算了,当然也有一些不是torch包里面的但是没有关系hhh 官方文档里有一堆不太常用的,这里整理的都是自己比较常用的 张量操作 torch.tensor:从Python列表或NumPy数组创建张量 torc…...

地铁大数据客流分析系统 设计与实现 计算机竞赛

文章目录 1 前言1.1 实现目的 2 数据集2.2 数据集概况2.3 数据字段 3 实现效果3.1 地铁数据整体概况3.2 平均指标3.3 地铁2018年9月开通运营的线路3.4 客流量相关统计3.4.1 线路客流量排行3.4.2 站点客流量排行3.4.3 入站客流排行3.4.4 整体客流随时间变化趋势3.4.5 不同线路客…...

00后都到适婚年龄啦!90后的还在低调什么?

当你的想法还停留在00后读书时代,其实大部分00后早已步入工作社会,还有不少人已经步入婚姻。广东金媒人婚恋,无论是广州、深圳、东莞、佛山举办活动的参与者中,00后的男生女生都占了不少。 广州深圳这样一二线城市的单身年轻人群&…...

reactnative使用七牛云上传图片

安装react-native-qiniu npm install react-native-qiniu --save 上传文件 import Qiniu,{Auth,ImgOps,Conf,Rs,Rpc} from react-native-qiniu;// 初始化七牛云配置 // Qiniu.region.z0:华东地区(默认值)。 // Qiniu.region.z1&#xff1a…...

在JavaScript中,如何创建一个数组或对象?

在JavaScript中,可以使用以下方式创建数组和对象: 一:创建数组(Array): 1:使用数组字面量(Array Literal)语法,使用方括号 [] 包裹元素,并用逗号分隔: let array1 = []; // 空数组 let array2 = [1, 2, 3]; // 包含三个数字的数组 let array3 = [apple, banana,…...

001.第一个C语言项目

Visual studio2022的使用 创建第一个C语言项目和源文件 https://blog.csdn.net/qq_45037165/article/details/124520286 第一个C语言项目 #include<stdio.h> int main() {printf("Hello World");return 0; }运行结果&#xff1a; 第一行为库函数&#xff0…...

luffy项目后端轮播图接口

后台主页功能 需求 根据原型图&#xff0c;分析出首页需要配合俩接口 轮播图接口&#xff08;要写&#xff09; 查询所有轮播图 推荐课程接口(暂时先不写) 设计表 轮播图表&#xff1a;Banner 写轮播图接口 查询所有轮播图 轮播图表 写一个公共表模型且只用于继承 fr…...

如何通过Photoshop将视频转换成GIF图片

一、应用场景 1、将视频转有趣动图发朋友圈 2、写CSDN无法上传视频&#xff0c;而可以用GIF动图替代 3、其他 二、实现步骤 1、打开Photoshop APP 2、点击文件——导入——视频帧到图层 3、选择视频文件 4、配置视频信息&#xff0c;按照图片提示配置完毕之后点击确定&…...

书单|1024程序员狂欢节充能书单!

点击链接进入图书专题 1024程序员节 “IT有得聊”是机械工业出版社旗下IT专业资讯和服务平台&#xff0c;致力于帮助读者在广义的IT领域里&#xff0c;掌握更专业、更实用的知识与技能&#xff0c;快速提升职场竞争力。 点击蓝色微信名可快速关注我们。 一年一度的1024程序员…...

GRS认证与TC交易证明的区别

TC&#xff08;Transaction Certificate&#xff09;交易证书是由认证单位向其客户出具再生含量证明&#xff0c;证明本次 销售产品符合GRS标准。TC交易证书上列明 卖方&#xff08;seller&#xff09;&#xff0c;买方&#xff08;buyer&#xff09;,收货方 &#xff08;consi…...

高精度时间测量(TDC)电路MS1022

MS1022 是一款高精度时间测量电路&#xff0c;内部集成了模拟比 较器、模拟开关、施密特触发器等器件&#xff0c;从而大大简化了外 围电路。同时内部增加了第一波检测功能&#xff0c;使抗干扰能力大 大提高。通过读取第一个回波脉冲的相对宽度&#xff0c;用户可以获 得接…...

js关键字

JavaScript 的关键字是指有特殊含义的单词&#xff0c;它们不能用作标识符&#xff0c;比如变量名、函数名等。 以下是 JavaScript 的关键字列表及其解释&#xff1a; true&#xff1a;布尔值 truefalse&#xff1a;布尔值 falsenull&#xff1a;表示一个空值或空对象引用und…...

《算法通关村第二关——指定区间反转问题解析》

《算法通关村第二关——指定区间反转问题解析》 题目描述 给你单链表的头指针head和两个整数left和right&#xff0c;其中left < right 。 请你反转从位置left到位置right的链表节点&#xff0c;返回反转后的链表。 示例1&#xff1a; 输入&#xff1a; head [1,2,3,4,5…...

掌控安全Update.jsp SQL注入

0x01 漏洞介绍 亿赛通电子文档安全管理系统是国内最早基于文件过滤驱动技术的文档加解密产品之一&#xff0c;保护范围涵盖终端电脑&#xff08;Windows、Mac、Linux系统平台&#xff09;、智能终端&#xff08;Android、IOS&#xff09;及各类应用系统&#xff08;OA、知识管理…...

C#将图片转换为ICON格式(程序运行图标)

介绍&#xff1a; C#创建窗体项目后左上角有显示图标&#xff0c;这个图标会在运行的时候显示在下面进程这里&#xff0c;但是必须是ico格式的图片才可以导入使用。以下是将图片打开后保存为ico格式代码。 代码如下&#xff1a; main函数测试 new 将图片转换成icon格式(&qu…...

ELK架构Logstash的相关插件:grok、multiline、mutate、date的详细介绍

文章目录 1. grok (正则捕获插件)1.1 作用1.2 正则表达式的类型1.2.1 内置正则表达式1.2.2 自定义正则表达式 2. mutate (数据修改插件&#xff09;2.1 作用2.2 常见配置选项2.3 应用实例 3. multiline &#xff08;多行合并插件&#xff09;3.1 作用3.2 常用配置项及示例3.2.1…...

linux 防火墙介绍以及iptables的使用

背景介绍 在前几天&#xff0c;于工发现我们内部的150服务器7554端口被外网访问了。该应用提供着内部的摄像头资源。为了避免被入侵&#xff0c;于是我添加了一些iptables规则&#xff0c;防止外网的访问。 解决方式 解决方式有两种&#xff1a; 关闭公司公网路由器对150服务…...

原码、反码、补码在汇编中的应用

原文章&#xff1a;知乎 原码和二进制类似&#xff0c;不过它有符号位。正数符号位为0&#xff0c;负数为1 。 例&#xff1a;40000 0100 &#xff0c;-41000 0100 原码是人脑最容易理解和计算的表示方式。 但是这在计算机中计算就出了问题&#xff0c;这两个&#xff08;4…...

【红日靶场】vulnstack5-完整渗透过程

系列文章目录 【红日靶场】vulnstack1-完整渗透过程 【红日靶场】vulnstack2-完整渗透过程 【红日靶场】vulnstack3-完整渗透过程 【红日靶场】vulnstack4-完整渗透过程 文章目录 系列文章目录描述虚拟机密码红队思路 一、环境初始化二、开始渗透外网打点上线cs权限提升域信息…...

嵌入式平台的电源总结

本文引注: https://mp.weixin.qq.com/s/PuSxHDFbJjjHEReukLSvyg 1.AC的定义 Alternating Current&#xff08;交流&#xff09;的首字母缩写。AC是大小和极性&#xff08;方向&#xff09;随时间呈周期性变化的电流。电流极性在1秒内的变化次数被称为频率&#xff0c;以Hz为单位…...

超短脉冲激光自聚焦效应

前言与目录 强激光引起自聚焦效应机理 超短脉冲激光在脆性材料内部加工时引起的自聚焦效应&#xff0c;这是一种非线性光学现象&#xff0c;主要涉及光学克尔效应和材料的非线性光学特性。 自聚焦效应可以产生局部的强光场&#xff0c;对材料产生非线性响应&#xff0c;可能…...

前端倒计时误差!

提示:记录工作中遇到的需求及解决办法 文章目录 前言一、误差从何而来?二、五大解决方案1. 动态校准法(基础版)2. Web Worker 计时3. 服务器时间同步4. Performance API 高精度计时5. 页面可见性API优化三、生产环境最佳实践四、终极解决方案架构前言 前几天听说公司某个项…...

PPT|230页| 制造集团企业供应链端到端的数字化解决方案:从需求到结算的全链路业务闭环构建

制造业采购供应链管理是企业运营的核心环节&#xff0c;供应链协同管理在供应链上下游企业之间建立紧密的合作关系&#xff0c;通过信息共享、资源整合、业务协同等方式&#xff0c;实现供应链的全面管理和优化&#xff0c;提高供应链的效率和透明度&#xff0c;降低供应链的成…...

抖音增长新引擎:品融电商,一站式全案代运营领跑者

抖音增长新引擎&#xff1a;品融电商&#xff0c;一站式全案代运营领跑者 在抖音这个日活超7亿的流量汪洋中&#xff0c;品牌如何破浪前行&#xff1f;自建团队成本高、效果难控&#xff1b;碎片化运营又难成合力——这正是许多企业面临的增长困局。品融电商以「抖音全案代运营…...

全球首个30米分辨率湿地数据集(2000—2022)

数据简介 今天我们分享的数据是全球30米分辨率湿地数据集&#xff0c;包含8种湿地亚类&#xff0c;该数据以0.5X0.5的瓦片存储&#xff0c;我们整理了所有属于中国的瓦片名称与其对应省份&#xff0c;方便大家研究使用。 该数据集作为全球首个30米分辨率、覆盖2000–2022年时间…...

大模型多显卡多服务器并行计算方法与实践指南

一、分布式训练概述 大规模语言模型的训练通常需要分布式计算技术,以解决单机资源不足的问题。分布式训练主要分为两种模式: 数据并行:将数据分片到不同设备,每个设备拥有完整的模型副本 模型并行:将模型分割到不同设备,每个设备处理部分模型计算 现代大模型训练通常结合…...

Go 语言并发编程基础:无缓冲与有缓冲通道

在上一章节中&#xff0c;我们了解了 Channel 的基本用法。本章将重点分析 Go 中通道的两种类型 —— 无缓冲通道与有缓冲通道&#xff0c;它们在并发编程中各具特点和应用场景。 一、通道的基本分类 类型定义形式特点无缓冲通道make(chan T)发送和接收都必须准备好&#xff0…...

【C++特殊工具与技术】优化内存分配(一):C++中的内存分配

目录 一、C 内存的基本概念​ 1.1 内存的物理与逻辑结构​ 1.2 C 程序的内存区域划分​ 二、栈内存分配​ 2.1 栈内存的特点​ 2.2 栈内存分配示例​ 三、堆内存分配​ 3.1 new和delete操作符​ 4.2 内存泄漏与悬空指针问题​ 4.3 new和delete的重载​ 四、智能指针…...

基于Springboot+Vue的办公管理系统

角色&#xff1a; 管理员、员工 技术&#xff1a; 后端: SpringBoot, Vue2, MySQL, Mybatis-Plus 前端: Vue2, Element-UI, Axios, Echarts, Vue-Router 核心功能&#xff1a; 该办公管理系统是一个综合性的企业内部管理平台&#xff0c;旨在提升企业运营效率和员工管理水…...

Spring AI Chat Memory 实战指南:Local 与 JDBC 存储集成

一个面向 Java 开发者的 Sring-Ai 示例工程项目&#xff0c;该项目是一个 Spring AI 快速入门的样例工程项目&#xff0c;旨在通过一些小的案例展示 Spring AI 框架的核心功能和使用方法。 项目采用模块化设计&#xff0c;每个模块都专注于特定的功能领域&#xff0c;便于学习和…...