PyTorch中matmul函数使用详解和示例代码
torch.matmul
是 PyTorch 中用于执行矩阵乘法的函数,它根据输入张量的维度自动选择适当的矩阵乘法方式,包括:
- 向量内积(1D @ 1D)
- 矩阵乘向量(2D @ 1D)
- 向量乘矩阵(1D @ 2D)
- 矩阵乘矩阵(2D @ 2D)
- 批量矩阵乘法(>2D)
函数原型
torch.matmul(input, other, *, out=None) → Tensor
- input:第一个张量
- other:第二个张量
- out(可选):指定输出张量
详细说明
torch.matmul(a, b)
根据 a
和 b
的维度规则如下:
a 维度 | b 维度 | 操作类型 |
---|---|---|
1D | 1D | 向量点积 |
2D | 1D | 矩阵和向量相乘 |
1D | 2D | 向量和矩阵相乘 |
2D | 2D | 标准矩阵乘法 |
≥3D | ≥3D | 批量矩阵乘法(batch) |
示例代码
1. 向量点积(1D @ 1D)
import torch
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])
result = torch.matmul(a, b)
print(result) # 输出:32.0
2. 矩阵乘向量(2D @ 1D)
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.tensor([5.0, 6.0])
result = torch.matmul(a, b)
print(result) # 输出:[17.0, 39.0]
3. 向量乘矩阵(1D @ 2D)
a = torch.tensor([5.0, 6.0])
b = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
result = torch.matmul(a, b)
print(result) # 输出:[23.0, 34.0]
4. 矩阵乘矩阵(2D @ 2D)
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.tensor([[5.0, 6.0], [7.0, 8.0]])
result = torch.matmul(a, b)
print(result)
# 输出:
# [[19.0, 22.0],
# [43.0, 50.0]]
5. 批量矩阵乘法(3D @ 3D)
a = torch.randn(10, 3, 4)
b = torch.randn(10, 4, 5)
result = torch.matmul(a, b)
print(result.shape) # 输出:torch.Size([10, 3, 5])
综合示例:自定义线性层(类似 nn.Linear
)
下面是一个使用 torch.matmul
构建自定义线性层的完整示例,适合理解如何手动定义一个具有权重、偏置、支持自动求导的神经网络层,适合自定义网络结构或深入理解 PyTorch 的底层机制。
功能描述
- 实现线性变换:
y = x @ W^T + b
- 使用
torch.matmul
执行矩阵乘法 - 权重和偏置作为可训练参数
- 支持 GPU 和自动求导
代码实现
import torch
import torch.nn as nnclass MyLinear(nn.Module):def __init__(self, in_features, out_features):super(MyLinear, self).__init__()self.weight = nn.Parameter(torch.randn(out_features, in_features)) # shape: [out, in]self.bias = nn.Parameter(torch.zeros(out_features)) # shape: [out]def forward(self, x):# x: shape [batch_size, in_features]# weight: shape [out_features, in_features]# transpose weight -> shape [in_features, out_features], then matmulout = torch.matmul(x, self.weight.t()) + self.biasreturn out
使用示例
batch_size = 4
in_dim = 6
out_dim = 3x = torch.randn(batch_size, in_dim)
layer = MyLinear(in_dim, out_dim)output = layer(x)
print(output.shape) # torch.Size([4, 3])
与官方 nn.Linear
等效性验证(可选)
# 官方线性层
torch.manual_seed(0)
official = nn.Linear(in_dim, out_dim)# 自定义线性层,使用相同参数初始化
custom = MyLinear(in_dim, out_dim)
custom.weight.data.copy_(official.weight.data)
custom.bias.data.copy_(official.bias.data)# 比较输出
x = torch.randn(2, in_dim)
out1 = official(x)
out2 = custom(x)
print(torch.allclose(out1, out2)) # True
说明
项 | 内容 |
---|---|
torch.matmul | 用于实现 x @ W.T 矩阵乘法 |
nn.Parameter | 注册为可训练参数,自动加入 .parameters() 中 |
Module.forward() | 用于定义前向传播逻辑 |
注意事项
- 输入张量必须满足矩阵乘法的维度匹配规则。
- 对于 >2D 的张量,PyTorch 会自动按 batch size 广播执行多组矩阵乘法。
torch.matmul
不支持标量乘法(标量乘张量可用*
运算符)。
相关文章:
PyTorch中matmul函数使用详解和示例代码
torch.matmul 是 PyTorch 中用于执行矩阵乘法的函数,它根据输入张量的维度自动选择适当的矩阵乘法方式,包括: 向量内积(1D 1D)矩阵乘向量(2D 1D)向量乘矩阵(1D 2D)矩…...

论文解读:Locating and Editing Factual Associations in GPT(ROME)
论文发表于人工智能顶会NeurIPS(原文链接),研究了GPT(Generative Pre-trained Transformer)中事实关联的存储和回忆,发现这些关联与局部化、可直接编辑的计算相对应。因此: 1、开发了一种因果干预方法,用于识别对模型的事实预测起…...
NoSQl之Redis部署
一、Redis 核心概念与技术定位 1. 数据库分类与 Redis 的诞生背景 关系型数据库的局限性 数据模型:基于二维表结构,通过 SQL 操作,强一致性(ACID 特性),适合结构化事务场景(如银行转账、订单管…...

学习设计模式《十二》——命令模式
一、基础概念 命令模式的本质是【封装请求】命令模式的关键是把请求封装成为命令对象,然后就可以对这个命令对象进行一系列的处理(如:参数化配置、可撤销操作、宏命令、队列请求、日志请求等)。 命令模式的定义:将一个…...

十三、【核心功能篇】测试计划管理:组织和编排测试用例
【核心功能篇】测试计划管理:组织和编排测试用例 前言准备工作第一部分:后端实现 (Django)1. 定义 TestPlan 模型2. 生成并应用数据库迁移3. 创建 TestPlanSerializer4. 创建 TestPlanViewSet5. 注册路由6. 注册到 Django Admin 第二部分:前端…...

手撕 K-Means
1. K-means 的原理 K-means 是一种经典的无监督学习算法,用于将数据集划分为 kk 个簇(cluster)。其核心思想是通过迭代优化,将数据点分配到最近的簇中心,并更新簇中心,直到簇中心不再变化或达到最大迭代次…...

SmolVLA: 让机器人更懂 “看听说做” 的轻量化解决方案
🧭 TL;DR 今天,我们希望向大家介绍一个新的模型: SmolVLA,这是一个轻量级 (450M 参数) 的开源视觉 - 语言 - 动作 (VLA) 模型,专为机器人领域设计,并且可以在消费级硬件上运行。 SmolVLAhttps://hf.co/lerobot/smolvla…...

day45python打卡
知识点回顾: tensorboard的发展历史和原理tensorboard的常见操作tensorboard在cifar上的实战:MLP和CNN模型 效果展示如下,很适合拿去组会汇报撑页数: 作业:对resnet18在cifar10上采用微调策略下,用tensorbo…...

AIGC赋能前端开发
一、引言:AIGC对前端开发的影响 1. AIGC与前端开发的关系 从“写代码”到“生成代码”传统开发痛点:重复性编码工作、UI 设计稿还原、问题定位与调试...核心场景的AI化:需求转代码(P2C)、设计稿转代码(D2…...

Web 3D协作平台开发案例:构建制造业远程设计与可视化协作
HOOPS Communicator为开发者提供了丰富的定制化能力,助力他们在实现强大 Web 3D 可视化功能的同时,灵活构建符合特定业务需求的工程应用。对于希望构建在线协同设计工具的企业而言,如何在保障性能与用户体验的前提下实现高效开发,…...

AI Agent开发第78课-大模型结合Flink构建政务类长公文、长文件、OA应用Agent
开篇 AI Agent2025确定是进入了爆发期,到处都在冒出各种各样的实用AI Agent。很多人、组织都投身于开发AI Agent。 但是从3月份开始业界开始出现了一种这样的声音: AI开发入门并不难,一旦开发完后没法用! 经历过至少一个AI Agent从开发到上线的小伙伴们其实都听到过这种…...
极空间z4pro配置gitea mysql,内网穿透
极空间z4pro配置gitea mysql等记录,内网穿透 1、mysql、gitea镜像下载,极空间不成功,先用自己电脑科学后下载镜像,拉取代码: docker pull --platform linux/amd64 gitea/gitea:1.23 docker pull --platform linux/amd64 mysql:5.…...

第三方测试机构进行科技成果鉴定测试有什么价值
在当今科技创新的浪潮中,科技成果的鉴定测试至关重要,而第三方测试机构凭借其独特优势,在这一领域发挥着不可替代的作用。那么,第三方测试机构进行科技成果鉴定测试究竟有什么价值呢? 一、第三方测试机构能提供独立、公…...

华为云Flexus+DeepSeek征文|基于华为云Flexus X和DeepSeek-R1打造个人知识库问答系统
目录 前言 1 快速部署:一键搭建Dify平台 1.1 部署流程详解 1.2 初始配置与登录 2 构建专属知识库 2.1 进入知识库模块并创建新库 2.2 选择数据源导入内容 2.3 上传并识别多种文档格式 2.4 文本处理与索引构建 2.5 保存并完成知识库创建 3接入ModelArts S…...

【数据结构】_排序
【本节目标】 排序的概念及其运用常见排序算法的实现排序算法复杂度及稳定性分析 1.排序的概念及其运用 1.1排序的概念 排序:所谓排序,就是使一串记录,按照其中的某个或某些关键字的大小,递增或递减的排列起来的操作。 1.2特性…...
《前端面试题:JS数据类型》
JavaScript 数据类型指南:从基础到高级全解析 一、JavaScript 数据类型概述 JavaScript 作为一门动态类型语言,其数据类型系统是理解这门语言的核心基础。在 ECMAScript 标准中,数据类型分为两大类: 1. 原始类型(Pr…...

PPT转图片拼贴工具 v4.3
软件介绍 这个软件就是将PPT文件转换为图片并且拼接起来。 效果展示 支持导入文件和支持导入文件夹,也支持手动输入文件/文件夹路径 软件界面 这一次提供了源码和开箱即用版本,exe就是直接用就可以了。 软件源码 import os import re import sys …...

Chrome安装代理插件ZeroOmega(保姆级别)
目录 本文直接讲解一下怎么本地安装ZeroOmega一、下载文件在GitHub直接下ZeroOmega 的文件(下最新版即可) 二、安装插件打开 Chrome 浏览器,访问 chrome://extensions/ 页面(扩展程序管理页面),并打开开发者…...

Transformer-BiGRU多变量时序预测(Matlab完整源码和数据)
Transformer-BiGRU多变量时序预测(Matlab完整源码和数据) 目录 Transformer-BiGRU多变量时序预测(Matlab完整源码和数据)效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.Matlab实现Transformer-BiGRU多变量时间序列预测&…...

新华三H3CNE网络工程师认证—Easy IP
Easy IP 就是“用路由器自己的公网IP,给全家所有设备当共享门牌号”的技术!(省掉额外公网IP,省钱又省配置!) 生活场景对比,想象你住在一个小区:普通动态NAT:物业申请了 …...
《视觉SLAM十四讲》自用笔记 第二讲:SLAM系统概述
在rm队伍里作为算法组梯队队员度过了一个赛季,为了促进和负责其他工作的算法组成员的交流,我决定在接下来的半个学期里(可能更快)读完这本书,并将其中的部分理论应用于我自制的雷达导航小车上。 以下为第二讲的部分笔记…...
vscode 插件 eslint, 检查 js 语法
1. 起因, 目的: 我的需求 vscode 写js代码, 有什么插件能进行语法检查。 比如某个函数没有定义,getName(), 但是却调用了。 那么这个插件会给出警告,在 getName() 给出红色波浪线。类似这种效果的插件, 有吗…...

Excel 模拟分析之单变量求解简单应用
正向求解 利用公式根据贷款总额、还款期限、贷款利率,求每月还款金额 反向求解 根据每月还款能力,求最大能承受贷款金额 参数: 目标单元格:求的值所在的单元格 目标值:想要达到的预期值 可变单元格:变…...

装备制造项目管理具备什么特征?如何选择适配的项目管理软件系统进行项目管控?
国内某大型半导体装备制造企业与奥博思软件达成战略合作,全面引入奥博思 PowerProject 打造企业专属项目管理平台,进一步提升智能制造领域的项目管理效率与协同能力。 该项目管理平台聚焦半导体装备研发与制造的业务特性,实现了从项目立项、…...

FPGA 动态重构配置流程
触发FPGA 进行配置的方式有两种,一种是断电后上电,另一种是在FPGA运行过程中,将PROGRAM 管脚拉低。将PROGRAM 管脚拉低500ns 以上就可以触发FPGA 进行重构。 FPGA 的配置过程大致可以分为:配置的触发和建立阶段、加载配置文件和建…...
Elasticsearch的审计日志(Audit Logging)介绍
Elasticsearch 的审计日志(Audit Logging)是一种记录与安全相关事件的功能,用于监控和追踪对集群的访问行为。通过审计日志,管理员可以了解谁在何时对哪些资源执行了什么操作,从而满足合规性要求、进行安全分析和排查异常行为。 一、审计日志的核心功能 记录安全事件捕获…...
软件测试:质量保障的基石与未来趋势
软件测试作为软件开发生命周期中的关键环节,不仅是发现和修复缺陷的手段,更是确保产品质量、提升用户体验和降低开发成本的重要保障。在当今快速迭代的互联网时代,测试已从单纯的验证活动演变为贯穿整个开发过程的质量管理体系。本文将系统阐…...

网络安全逆向分析之rust逆向技巧
rust逆向技巧 rust逆向三板斧: 快速定位关键函数 (真正的main函数):观察输出、输入,字符串搜索,断点等方法。定位关键 加密区 :根据输入的flag,打硬件断点,快速捕获程序中对flag访问的位置&am…...
Docker容器化技术概述与实践
哈喽,大家好,我是左手python! Docker 容器化的基本概念 Docker 容器化是一种轻量级的虚拟化技术,通过将应用程序及其依赖项打包到一个可移植的容器中,使其在任何兼容 Docker 的环境中都能运行。与传统的虚拟机技术不同…...
win中将pdf转为图片
0 资料 博客 1 正文 直接使用这个软件即可https://sourceforge.net/projects/pkpdfconverter/...