理解深度学习pytorch框架中的线性层
文章目录
在神经网络或机器学习的线性层(Linear Layer / Fully Connected Layer)中,经常会见到两种形式的公式:
- 数学文献或传统线性代数写法: y = W x + b \displaystyle y = W\,x + b y=Wx+b
- 一些深度学习代码中写法: y = x W T + b \displaystyle y = x\,W^T + b y=xWT+b
初次接触时,很多人会觉得两者“方向”不太一样,不知该如何对照理解;再加上矩阵维度 ( in_features , out_features ) (\text{in\_features},\, \text{out\_features}) (in_features,out_features) 和 ( out_features , in_features ) (\text{out\_features},\, \text{in\_features}) (out_features,in_features) 的各种写法常常让人疑惑不已。本文将从数学角度和编程实现角度剖析它们的关系,并结合实际示例指出一些常见的坑与需要特别留意的下标对应问题。
1. 数学角度: y = W x + b \displaystyle y = W\,x + b y=Wx+b
在线性代数中,如果我们假设输入 x x x 是一个列向量,通常会写作 x ∈ R ( in_features ) \displaystyle x\in\mathbb{R}^{(\text{in\_features})} x∈R(in_features)(或者在更严格的矩阵形状记法下写作 ( in_features , 1 ) (\text{in\_features},\,1) (in_features,1))。那么一个最常见的全连接层可以表示为:
y = W x + b , y = W\,x + b, y=Wx+b,
其中:
- W W W 是一个大小为 ( out_features , in_features ) \bigl(\text{out\_features},\,\text{in\_features}\bigr) (out_features,in_features) 的矩阵;
- b b b 是一个 out_features \text{out\_features} out_features-维的偏置向量(形状 ( out_features , 1 ) (\text{out\_features},\,1) (out_features,1));
- y y y 则是输出向量,大小为 out_features \text{out\_features} out_features。
示例
假设 in_features = 3 \text{in\_features}=3 in_features=3, out_features = 2 \text{out\_features}=2 out_features=2。那么:
W ∈ R 2 × 3 , x ∈ R 3 × 1 , b ∈ R 2 × 1 . W \in \mathbb{R}^{2\times 3},\quad x \in \mathbb{R}^{3\times 1},\quad b \in \mathbb{R}^{2\times 1}. W∈R2×3,x∈R3×1,b∈R2×1.
矩阵写开来就是:
W = [ w 11 w 12 w 13 w 21 w 22 w 23 ] , x = [ x 1 x 2 x 3 ] , b = [ b 1 b 2 ] . W = \begin{bmatrix} w_{11} & w_{12} & w_{13} \\[5pt] w_{21} & w_{22} & w_{23} \end{bmatrix},\quad x = \begin{bmatrix} x_{1}\\ x_{2}\\ x_{3} \end{bmatrix},\quad b = \begin{bmatrix} b_{1}\\ b_{2} \end{bmatrix}. W=[w11w21w12w22w13w23],x= x1x2x3 ,b=[b1b2].
那么线性变换结果 W x + b Wx + b Wx+b 可以展开为:
W x + b = [ w 11 x 1 + w 12 x 2 + w 13 x 3 w 21 x 1 + w 22 x 2 + w 23 x 3 ] + [ b 1 b 2 ] = [ w 11 x 1 + w 12 x 2 + w 13 x 3 + b 1 w 21 x 1 + w 22 x 2 + w 23 x 3 + b 2 ] . \begin{aligned} Wx + b &= \begin{bmatrix} w_{11}x_1 + w_{12}x_2 + w_{13}x_3 \\ w_{21}x_1 + w_{22}x_2 + w_{23}x_3 \end{bmatrix} + \begin{bmatrix} b_1 \\ b_2 \end{bmatrix} \\ &= \begin{bmatrix} w_{11}x_1 + w_{12}x_2 + w_{13}x_3 + b_1 \\ w_{21}x_1 + w_{22}x_2 + w_{23}x_3 + b_2 \end{bmatrix}. \end{aligned} Wx+b=[w11x1+w12x2+w13x3w21x1+w22x2+w23x3]+[b1b2]=[w11x1+w12x2+w13x3+b1w21x1+w22x2+w23x3+b2].
这就是最传统、在数学文献或线性代数课程中最常见的表示方法。
2. 编程实现角度: y = x W T + b \displaystyle y = x\,W^T + b y=xWT+b
在实际的深度学习代码(例如 PyTorch、TensorFlow)中,经常看到的却是下面这种写法:
y = x @ W.T + b
注意这里 W.shape 通常被定义为 ( out_features , in_features ) (\text{out\_features},\, \text{in\_features}) (out_features,in_features),而 x.shape 在批量处理时则是 ( batch_size , in_features ) (\text{batch\_size},\, \text{in\_features}) (batch_size,in_features)。于是 (x @ W.T) 的结果是 ( batch_size , out_features ) (\text{batch\_size},\, \text{out\_features}) (batch_size,out_features)。
为什么会出现转置?
因为在数学里我们通常把 x x x 当作“列向量”放在右边,于是公式变成 y = W x + b y = Wx + b y=Wx+b。
但在编程里,尤其是处理批量输入时,x 常写成“行向量”的形式 ( batch_size , in_features ) (\text{batch\_size},\, \text{in\_features}) (batch_size,in_features),这就造成了在进行矩阵乘法时,需要将 W(大小 ( out_features , in_features ) (\text{out\_features},\, \text{in\_features}) (out_features,in_features))转置成 ( in_features , out_features ) (\text{in\_features},\, \text{out\_features}) (in_features,out_features),才能满足「行×列」的匹配关系。
从结果上来看,
( batch_size , in_features ) × ( in_features , out_features ) = ( batch_size , out_features ) . (\text{batch\_size}, \text{in\_features}) \times (\text{in\_features}, \text{out\_features}) = (\text{batch\_size}, \text{out\_features}). (batch_size,in_features)×(in_features,out_features)=(batch_size,out_features).
所以,在代码里就写成 x @ W.T,再加上偏置 b(通常会广播到 batch_size \text{batch\_size} batch_size 那个维度)。
本质上这和数学公式里 y = W x + b y = W\,x + b y=Wx+b 并无冲突,只是一个“列向量”和“行向量”的转置关系。只要搞清楚最终你想让输出 y y y 的 shape 是多少,就能明白在代码里为什么要写 .T。
3. 常见错误与易混点解析
有些教程或文档,会不小心写成:“如果我们有一个形状为 ( in_features , out_features ) (\text{in\_features},\text{out\_features}) (in_features,out_features) 的权重矩阵 W W W……”——然后又要做 W x Wx Wx,想得到一个 out_features \text{out\_features} out_features-维的结果。但按照线性代数的常规写法,行数必须和输出维度匹配、列数必须和输入维度匹配。所以 正确 的说法应该是
W ∈ R ( out_features ) × ( in_features ) . W\in\mathbb{R}^{(\text{out\_features}) \times (\text{in\_features})}. W∈R(out_features)×(in_features).
否则从矩阵乘法次序来看就对不上。
但这又可能让人迷惑:为什么深度学习框架 torch.nn.Linear(in_features, out_features) 却给出 weight.shape == (out_features, in_features)? 其实正是同一个道理,它和上面“数学文献里”用到的 W W W 形状完全一致。
4. 小结
-
从数学角度:
最传统的记号是
y = W x + b , W ∈ R ( out_features ) × ( in_features ) , x ∈ R ( in_features ) , y ∈ R ( out_features ) . y = W\,x + b, \quad W \in \mathbb{R}^{(\text{out\_features})\times(\text{in\_features})},\, x \in \mathbb{R}^{(\text{in\_features})},\, y \in \mathbb{R}^{(\text{out\_features})}. y=Wx+b,W∈R(out_features)×(in_features),x∈R(in_features),y∈R(out_features). -
从深度学习代码角度:
- 由于批量数据常被视为行向量,每一行代表一个样本特征,因此形状通常是 ( batch_size , in_features ) (\text{batch\_size},\, \text{in\_features}) (batch_size,in_features)。
- 对应的权重
W定义为 ( out_features , in_features ) (\text{out\_features},\, \text{in\_features}) (out_features,in_features)。为了完成行乘以列的矩阵运算,需要对W做转置:y = x @ W.T + b - 得到的
y.shape即 ( batch_size , out_features ) (\text{batch\_size},\, \text{out\_features}) (batch_size,out_features)。
-
避免踩坑:
- 写公式时,仔细确认 in_features \text{in\_features} in_features、 out_features \text{out\_features} out_features 的位置以及矩阵行列顺序。
- 编程实践中理解“为什么要
.T”非常重要:那只是为了匹配「行×列」的矩阵乘法规则,本质上还是和 y = W x + b y = Wx + b y=Wx+b 相同。
通过理解并区分“列向量”与“行向量”的不同惯例,避免因为矩阵维度或转置不当而导致莫名其妙的错误或 bug。
参考链接
- PyTorch 文档:
torch.nn.Linear - 深度学习中的矩阵运算初步 —— batch_size 与矩阵乘法
- 常见线性代数符号:行向量与列向量
相关文章:
理解深度学习pytorch框架中的线性层
文章目录 1. 数学角度: y W x b \displaystyle y W\,x b yWxb示例 2. 编程实现角度: y x W T b \displaystyle y x\,W^T b yxWTb3. 常见错误与易混点解析4. 小结参考链接 在神经网络或机器学习的线性层(Linear Layer / Fully Connect…...
电路研究9.2——合宙Air780EP使用AT指令
这里正式研究AT指令的学习了,之前只是接触的AT指令,这里则是深入分析AT指令了。 软件的开发方式: AT:MCU 做主控,MCU 发 AT 命令给模组的开发方式,模组仅提供标准的 AT 固件, 所有的业务控制逻辑…...
Qt数据库相关操作
目录 一、前言 二、类与接口介绍 1.连接管理类 2.数据操作类 3.数据模型类 4.其它类 三、主要操作流程 1.示例 2.绑定参数 3.事务操作 一、前言 要在Qt中操作数据库,首先要安装对应的数据库,还要确保安装了Qt SQL模块。使用MySQL时࿰…...
2025-01-22 Unity Editor 1 —— MenuItem 入门
文章目录 1 Editor 文件夹2 MenuItem3 使用示例3.1 打开网址3.2 打开文件夹3.3 Menu Toggle3.4 Menu 代码复用3.5 MenuItem 激活与失活4 代码示例 1 Editor 文件夹 Editor 文件夹是 Unity 中的特殊文件夹,Unity 中所有编辑器相关的脚本都需要放置在其中…...
解锁C#编程新姿势:Z.ExtensionMethods入门秘籍
一、引言 在 C# 的开发旅程中,我们常常会遇到各种重复性高、复杂度低的任务,这些任务虽然基础,但却占据了我们大量的开发时间。比如处理字符串时,经常需要进行非空判断、格式转换;操作日期时间时,计算某个…...
不使用 JS 纯 CSS 获取屏幕宽高
前言 在现代前端开发中,获取屏幕的宽度和高度通常依赖于 JavaScript。然而现代 CSS 也可以获取到屏幕的宽高,通过自定义属性(CSS Variables)和一些数学函数来实现这一目标。本文将详细解析如何使用 CSS 的 property 规则和一些数…...
Node.js NativeAddon 构建工具:node-gyp 安装与配置完全指南
Node.js NativeAddon 构建工具:node-gyp 安装与配置完全指南 node-gyp Node.js native addon build tool [这里是图片001] 项目地址: https://gitcode.com/gh_mirrors/no/node-gyp 项目基础介绍及主要编程语言 Node.js NativeAddon 构建工具(node-gyp…...
【ARTS】【LeetCode-704】二分查找算法
目录 前言 什么是ARTS? 算法 力扣704题 二分查找 基本思想: 二分查找算法(递归的方式): 经典写法(找单值): 代码分析: 经典写法(找数组即多个返回值) 代码分析 经典题目 题目描述: 官方题解 深入思考 模版一 (相错终止/左闭右闭) 相等返回情形…...
Vue.js 配置路由:基本的路由匹配
Vue.js 配置路由:基本的路由匹配 在 Vue.js 应用中,Vue Router 是官方提供的路由管理器,用于在单页应用(SPA)中管理不同的视图。通过配置路由,应用可以根据 URL 的变化展示相应的组件。 基本的路由匹配是…...
鸿蒙(HarmonyOS)Json格式转实体对象(2)
下面是一个复杂的json体。 怎么把json转实体类,首先要定义类 import List from ohos.util.List export class InfoModel{msg: stringcars: List<Cars>code: numberpermissions: List<string>roles: List<string>user: User}class Cars{createBy:…...
代码随想录 栈与队列 test 6
239. 滑动窗口最大值 - 力扣(LeetCode) 每次只取窗口中最大值,这个最大值可能在后面的滑动中保持不变,而比最大值小的值且在最大值之前出现的值没必要保留,因此可以通过单调队列利用这个特性。 这个单调队列具有如下…...
动手学深度学习2025.1.23
一、预备知识 1.数据操作 (1)数据访问: 一个元素:[1,2] //行下标为1,列下标为2的元素 一行元素:[1,:] //行下标为1的所有元素 一列元素:[:,1] //列下标为1的所有元素 子区域:[…...
生存网络与mlr3proba
在R语言中,mlr3包是一个用于机器学习的强大工具包。它提供了一种简单且灵活的方式来执行超参数调整。 生存网络是一种用于生存分析的模型,常用在医学和生物学领域。生存分析是一种统计方法,用于研究事件发生的时间和相关因素对事件发生的影响。生存网络可以用来预测个体在给…...
C#与AI的共同发展
C#与人工智能(AI)的共同发展反映了编程语言随着技术进步而演变,以适应新的挑战和需要。自2000年微软推出C#以来,这门语言经历了多次迭代,不仅成为了.NET平台的主要编程语言之一,还逐渐成为构建各种类型应用程序的强大工具。随着时…...
2000-2020年各省第二产业增加值数据
2000-2020年各省第二产业增加值数据 1、时间:2000-2020年 2、来源:国家统计局、统计年鉴、各省年鉴 3、指标:行政区划代码、地区、年份、第二产业增加值 4、范围:31省 5、指标解释:第二产业增加值是指在一个国家或…...
【MySQL】 库的操作
欢迎拜访:雾里看山-CSDN博客 本篇主题:【MySQL】 库的操作 发布时间:2025.1.23 隶属专栏:MySQL 目录 库的创建语法使用 编码规则认识编码集查看数据库默认的编码集和校验集查看数据库支持的编码集和校验集指定编码创建数据库验证不…...
docker 启动镜像命令集合
安装rabbitmq 参考地址: https://blog.csdn.net/xxpxxpoo8/article/details/122935994 docker run -it -d --namerabbit-3.8 -v /d/docker/rabbitmq-stomp/conf:/etc/rabbitmq -p 5617:5617 -p 5672:5672 -p 4369:4369 -p 15671:15671 -p 15672:15672 -p 25672:2…...
微信小程序获取位置服务
wx.getLocation({type: gcj02,success(res) {wx.log(定位成功);},fail(err) {wx.log(定位失败, err);wx.showModal({content: 请打开手机和小程序中的定位服务,success: (modRes) > {if (modRes.confirm) {wx.openSetting({success(setRes) {if (setRes.authSetting[scope.u…...
Docker Load后存储的镜像及更改镜像存储目录的方法
Docker Load后存储的镜像及更改镜像存储目录的方法 Docker Load后存储的镜像更改镜像存储目录的方法脚本说明注意事项Docker作为一种开源的应用容器引擎,已经广泛应用于软件开发、测试和生产环境中。通过Docker,开发者可以将应用打包成镜像,轻松地进行分发和运行。而在某些场…...
Langchain本地知识库部署
本地部署(Docker + LangChain + FAISS) 1. 概述 本地部署 LangChain-Chatchat 可以为企业提供高效、安全、可控的 AI 知识库方案。本方案基于 Docker、LangChain 和 FAISS 进行本地化部署,适用于企业内部知识库问答、私有化 AI 应用等场景。 2. 技术选型 2.1 LangChain …...
百度网盘提取码智能获取工具:提升资源获取效率的技术方案
百度网盘提取码智能获取工具:提升资源获取效率的技术方案 【免费下载链接】baidupankey 项目地址: https://gitcode.com/gh_mirrors/ba/baidupankey 在数字资源爆炸的今天,百度网盘作为主流文件分享平台,已成为学习资料、工作文件和媒…...
javase的第一次博客
1,计算机简介:用于数据计算和处理2,计算机的硬件和软件:计算机硬件:运算器,控制器,存储器,输入设备,输出设备(冯 诺依曼模型)CPU:运算…...
Java微服务Istio配置必须立即更新的4个安全补丁:CVE-2024-23652等高危漏洞绕过配置详解
第一章:Java微服务Istio配置安全补丁的紧急性与背景近年来,Java微服务架构在云原生环境中广泛应用,而Istio作为主流服务网格控制平面,承担着流量管理、可观测性与零信任安全策略实施的关键角色。然而,2024年披露的CVE-…...
PyTorch 2.8多场景实操:科研训练+工程推理+内容创作的统一技术底座
PyTorch 2.8多场景实操:科研训练工程推理内容创作的统一技术底座 1. 为什么选择PyTorch 2.8作为统一技术底座 PyTorch 2.8作为当前最主流的深度学习框架之一,已经成为学术界和工业界的首选工具。这个基于RTX 4090D 24GB显卡深度优化的镜像,…...
终极Windows驱动管理指南:如何用DriverStore Explorer快速释放30GB磁盘空间
终极Windows驱动管理指南:如何用DriverStore Explorer快速释放30GB磁盘空间 【免费下载链接】DriverStoreExplorer Driver Store Explorer 项目地址: https://gitcode.com/gh_mirrors/dr/DriverStoreExplorer DriverStore Explorer(简称RAPR&…...
Matlab GUI计时器:自动更新的数字时钟与恢复/暂停功能的定时器对象实现
Matlab图形用户界面计时器:使用定时器对象自动更新的MatlabGUI,一个数字时钟,作为显示基本组件的快速演示,带有一个按钮,用于恢复/暂停执行更新 实验室配了新酶标仪孵箱但总有人(比如同组摸鱼的小师妹顺便…...
【RT-DETR涨点改进】SCI一区 2025顶刊 |全网独家创新,注意力改进篇 | RT-DETR引入DOAM动态全向注意力模块,模块,显著增强了特征表达能力和结构恢复能力,含7种独家创新改进点
一、本文介绍 🔥本文给大家介绍利用 DOAM 动态全向注意力模块改进RT-DETR网络模型,可在不显著增加计算量的前提下增强全局上下文建模能力,通过空间轴向聚合获得更强的跨区域信息交互,并用通道动态加权突出目标相关特征、抑制背景干扰,从而优化多尺度特征融合效果,提升小…...
Phi-4-mini-reasoning企业知识库接入:PDF解析+向量化+推理问答闭环
Phi-4-mini-reasoning企业知识库接入:PDF解析向量化推理问答闭环 1. 模型简介与部署验证 Phi-4-mini-reasoning 是一个基于合成数据构建的轻量级开源模型,专注于高质量、密集推理的数据处理能力。作为Phi-4模型家族成员,它特别强化了数学推…...
GEE引擎封挂实战:从M2参数到RunGate网关的完整配置指南
GEE引擎封挂实战:从M2参数到RunGate网关的完整配置指南 在游戏运营过程中,外挂问题一直是困扰开发者和运营者的顽疾。对于使用GEE引擎的游戏服务器来说,如何有效防范和打击外挂行为,维护游戏公平性,是每个技术团队必须…...
seo优化一个月大概要花费多少_seo 优化一个月需要多少预算
SEO 优化一个月需要多少预算:详细分析与实用建议 在当今的数字时代,网站的SEO优化是提升网站流量和品牌知名度的关键。SEO 优化一个月大概要花费多少,SEO 优化一个月需要多少预算呢?这个问题困扰着许多企业和个人。本文将从问题分…...
