当前位置: 首页 > news >正文

深入解析全连接层:PyTorch 中的 nn.Linear、nn.Parameter 及矩阵运算

文章目录

这篇文章会从基础的一个数学概念到对应的代码实现,你将了解到:

  • 为什么nn.Parameter()接受 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)作为参数?
  • 为什么不是torch.matmul(self.weight, x) + self.bias
  • 如何使用torch.matmul()@F.linear() 去等价地实现nn.Linear()的输出。

数学概念(全连接层,线性层)

线性变化是数学中一个基础的概念,它描述了如何通过线性变换将输入映射到输出。在线性代数中,线性变化通常表示为矩阵乘法。在神经网络中,线性层的核心就是实现这样的矩阵运算。

数学公式:

给定一个输入向量 x ∈ R n \mathbf{x} \in \mathbb{R}^n xRn 和一个输出向量 y ∈ R m \mathbf{y} \in \mathbb{R}^m yRm,线性变化通过矩阵 W ∈ R m × n \mathbf{W} \in \mathbb{R}^{m \times n} WRm×n 和偏置项 b ∈ R m \mathbf{b} \in \mathbb{R}^m bRm 进行变换,其公式为:
y = W x + b \mathbf{y} = \mathbf{W} \mathbf{x} + \mathbf{b} y=Wx+b

  • W \mathbf{W} W:是权重矩阵,维度为 m × n m \times n m×n,它决定了输入向量如何线性变换到输出空间;
  • x \mathbf{x} x:是输入向量,维度为 n n n,表示特征数据;
  • b \mathbf{b} b:是偏置向量,维度为 m m m,用来调整线性变换的输出;
  • y \mathbf{y} y:是输出向量,维度为 m m m,是变换后的结果。

例子:

如果输入向量 x \mathbf{x} x 有 3 个特征,输出向量 y \mathbf{y} y 有 2 个特征,则权重矩阵 W \mathbf{W} W 的形状为 2 × 3 2 \times 3 2×3。假设:
W = [ 1 2 3 4 5 6 ] , x = [ 1 2 3 ] , b = [ 0 1 ] \mathbf{W} = \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix}, \quad \mathbf{x} = \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix}, \quad \mathbf{b} = \begin{bmatrix} 0 \\ 1 \end{bmatrix} W=[142536],x= 123 ,b=[01]
线性变换计算为:
y = W x + b = [ 1 2 3 4 5 6 ] [ 1 2 3 ] + [ 0 1 ] = [ 14 32 ] + [ 0 1 ] = [ 14 33 ] \mathbf{y} = \mathbf{W} \mathbf{x} + \mathbf{b} = \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix} + \begin{bmatrix} 0 \\ 1 \end{bmatrix} = \begin{bmatrix} 14 \\ 32 \end{bmatrix} + \begin{bmatrix} 0 \\ 1 \end{bmatrix} = \begin{bmatrix} 14 \\ 33 \end{bmatrix} y=Wx+b=[142536] 123 +[01]=[1432]+[01]=[1433]
矩阵运算过程:
[ 1 2 3 4 5 6 ] [ 1 2 3 ] = [ ( 1 × 1 ) + ( 2 × 2 ) + ( 3 × 3 ) ( 4 × 1 ) + ( 5 × 2 ) + ( 6 × 3 ) ] = [ 14 32 ] \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix} = \begin{bmatrix} (1 \times 1) + (2 \times 2) + (3 \times 3) \\ (4 \times 1) + (5 \times 2) + (6 \times 3) \end{bmatrix} = \begin{bmatrix} 14 \\ 32 \end{bmatrix} [142536] 123 =[(1×1)+(2×2)+(3×3)(4×1)+(5×2)+(6×3)]=[1432]

nn.Linear()

nn.Linear() 会自动创建一个权重矩阵(Weight)和偏置项(Bias),并将它们应用到输入上。

代码示例:

import torch
import torch.nn as nn# 定义一个输入为3,输出为2的线性层
linear_layer = nn.Linear(3, 2)# 打印权重矩阵和偏置项
print("权重矩阵 W:")
print(linear_layer.weight)print("偏置项 b:")
print(linear_layer.bias)# 模拟输入向量
input_vector = torch.tensor([1.0, 2.0, 3.0])
output_vector = linear_layer(input_vector)
print("输出向量 y:")
print(output_vector)

image-20240912221728559

在这里,nn.Linear(3, 2) 创建了一个 2×3 的权重矩阵和一个 2 维的偏置向量。通过 linear_layer(input_vector),可以直接获得输入向量经过线性变换后的输出。

nn.Parameter()

在 PyTorch 中,nn.Linear() 自动处理了权重和偏置项的初始化和更新,但有时你可能希望对这些参数自定义一些操作,比如 LoRA。这时,我们可以使用 nn.Parameter() 来自定义权重和偏置,其实 nn.Linear() 本身就是使用的nn.Parameter(),感兴趣的话可以看官方源码。

以自定义线性层为例:

class CustomLinearLayer(nn.Module):def __init__(self, input_dim, output_dim):super(CustomLinearLayer, self).__init__()# 使用 nn.Parameter 手动定义权重和偏置self.weight = nn.Parameter(torch.randn(output_dim, input_dim))self.bias = nn.Parameter(torch.randn(output_dim))def forward(self, x):# 手动实现线性变换 y = Wx + breturn torch.matmul(x, self.weight.T) + self.bias# 使用自定义的线性层
custom_layer = CustomLinearLayer(3, 2)
output = custom_layer(input_vector)
print(output)

image-20240912222625609

在看完代码后,你可能会产生两个疑惑:

Q

1. 为什么 self.weight 的权重矩阵 shape 使用 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)而不是 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features)?

实际这正是我写这篇博客分享的原因,现在进入正题,因为这个疑惑完全不应该产生。

让我们重新使用 in_features \text{in\_features} in_features out_features \text{out\_features} out_features来重现之前的数学定义:

对于输入向量 x ∈ R in_features \mathbf{x} \in \mathbb{R}^{\text{in\_features}} xRin_features,全连接层的输出为:

y = W x + b \mathbf{y} = W\mathbf{x} + \mathbf{b} y=Wx+b

其中:

  • W ∈ R out_features × in_features W \in \mathbb{R}^{\text{out\_features} \times \text{in\_features}} WRout_features×in_features 是权重矩阵,
  • b ∈ R out_features \mathbf{b} \in \mathbb{R}^{\text{out\_features}} bRout_features 是偏置项。

在线性变换中,输入向量 x \mathbf{x} x 的维度是 in_features \text{in\_features} in_features,而输出向量 y \mathbf{y} y 的维度是 out_features \text{out\_features} out_features。根据矩阵乘法的规则,要将输入 x \mathbf{x} x 映射到输出 y \mathbf{y} y,权重矩阵 W W W 的形状应该是 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features),因为矩阵乘法中 W x W\mathbf{x} Wx的维度要求是:

( out_features × in_features ) × ( in_features × 1 ) = ( out_features × 1 ) (\text{out\_features} \times \text{in\_features}) \times (\text{in\_features} \times 1) = (\text{out\_features} \times 1) (out_features×in_features)×(in_features×1)=(out_features×1)

这保证了输出 y \mathbf{y} y 的维度是 out_features \text{out\_features} out_features

如果权重矩阵的形状是 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features),矩阵乘法的维度将不匹配,无法实现线性变换。

现在是不是感觉清晰了?不要 nn.Linear(in_feature, out_feature) 用多了就将权重矩阵当作是 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features)遗忘了线性代数的概念,数学才是这一切的基石。

2. 为什么是torch.matmul(x, self.weight.T) + self.bias 而不是torch.matmul(self.weight, x) + self.bias?

主要原因还是在于 输入张量 x 的形状矩阵乘法规则

一般来说,模型的输入 x 实际上并不是 ( in_features , 1 ) (\text{in\_features}, 1) (in_features,1),而是 ( batch_size , in_features ) (\text{batch\_size}, \text{in\_features}) (batch_size,in_features),而权重矩阵 self.weight 的形状是 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)​,我们需要实现的线性变换是:
y = W x + b y = W x + b y=Wx+b
根据矩阵乘法规则,第一个矩阵的列数必须等于第二个矩阵的行数,这意味着我们不能直接计算 torch.matmul(self.weight, x),因为这样会导致维度不匹配:

  • self.weight 形状为 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)x 形状为 ( batch_size , in_features ) (\text{batch\_size}, \text{in\_features}) (batch_size,in_features)
  • torch.matmul(self.weight, x) 的维度计算规则将要求 x 的形状为 ( in_features , batch_size ) (\text{in\_features}, \text{batch\_size}) (in_features,batch_size),但这与模型的输入不匹配。

因此,正确的矩阵乘法应该是 torch.matmul(x, self.weight.T),其中 self.weight.T 表示 self.weight 的转置矩阵,此时的形状为 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features)

这样,torch.matmul(x, self.weight.T) 的维度计算为:

( batch_size , in_features ) × ( in_features , out_features ) = ( batch_size , out_features ) (\text{batch\_size}, \text{in\_features}) \times (\text{in\_features}, \text{out\_features}) = (\text{batch\_size}, \text{out\_features}) (batch_size,in_features)×(in_features,out_features)=(batch_size,out_features)

这就得到了正确的输出形状 ( batch_size , in_features ) (\text{batch\_size}, \text{in\_features}) (batch_size,in_features)

3. 为什么不直接设置self.weight = nn.Parameter(torch.randn(input_dim, output_dim))

这样不就可以不转置直接使用torch.matmul(x, self.weight)了吗?的确如此,或许是因为 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features) 对于矩阵运算 W x W\mathbf{x} Wx 来讲更符合直觉吧。

计算过程的细分:torch.matmul() vs @ 运算符

在 PyTorch 中,torch.matmul() 用于实现矩阵乘法,而 @ 是其简洁的符号形式,是 Python 的语法糖,二者在功能上是等价的。

示例代码:

import torch# 定义权重矩阵 W 和输入向量 input_vector
W = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
input_vector = torch.tensor([1.0, 2.0, 3.0])# 使用 torch.matmul 实现矩阵乘法
result1 = torch.matmul(W, input_vector)# 使用 @ 运算符
result2 = W @ input_vectorprint("使用 torch.matmul 计算的结果:")
print(result1)print("使用 @ 运算符计算的结果:")
print(result2)

结果:

image-20240912233355773

使用 F.linear()

PyTorch 提供了 F.linear() 作为函数式接口,它与 nn.Linear() 类似,但不需要创建一个线性层对象。F.linear() 可以接受线性层的权重和偏置作为输入,在下一篇有关 LoRA 的文章中,你将看到使用范例。

示例代码:

import torch.nn.functional as F# 使用 F.linear 进行线性变换
output = F.linear(input_vector, linear_layer.weight, linear_layer.bias)
print(output)

image-20240912233501651

相关文章:

深入解析全连接层:PyTorch 中的 nn.Linear、nn.Parameter 及矩阵运算

文章目录 数学概念(全连接层,线性层)nn.Linear()nn.Parameter()Q1. 为什么 self.weight 的权重矩阵 shape 使用 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)而不是 ( in_featur…...

缓存对象反序列化失败

未定义serialVersionUID,会自动生成序列化号 新增了属性,序列号就变了,导致缓存对象反序列化失败。 所有缓存对象必须指定序列化id! 那我如何找到未添加字段前 对象的序列化号呢?默认的序列化号是如何生成的呢&#…...

F28335的存储器与寄存器

1 存储器及CMD文件的编写 1 F28335的存储器 1.1 F28335存储器的结构 1.2 F28335存储器的映像 存储器本身不具有地址信息,它的地址是由芯片厂商或用户分配,给存储器分配地址的过程称为存储器映射,如果再分配一个地址就叫重映射。 我们将《tms320f28335 数据手册》中“3.1…...

Python在AOIP(Audio Over IP)方面的应用探讨

Python在AOIP(Audio Over IP)方面的应用探讨 引言 随着网络技术的发展,音频传输逐渐向基于IP的解决方案迁移。音频通过互联网进行传输被称为音频过IP(Audio Over IP,简称AOIP)。这种技术在广播、现场活动…...

C++20标准对线程库的改进:更安全、更高效的并发编程

引言 C20 是 C 语言的一个重要里程碑,它引入了许多新特性,其中就包括对线程库(thread)的重大改进。这些改进不仅增强了语言的并发编程能力,还解决了先前版本中的一些痛点问题。本文将详细介绍 C20 在线程方面的改进&a…...

外包干了三年,快要废了。。。

先简单说一下自己的情况,普通本科,在外包干了3年多的功能测试,这几年因为大环境不好,我整个人心惊胆战的,怕自己卷铺盖走人了,我感觉自己不能够在这样蹉跎下去了,长时间呆在一个舒适的环境真的会…...

微服务网关终极进化:设计模式驱动的性能与可用性优化(四)

时间:2024年09月12日 作者:小蒋聊技术 邮箱:wei_wei10163.com 微信:wei_wei10 希望大家帮个忙!如果大家有工作机会,希望帮小蒋推荐一下,小蒋希望遇到一个认真做事的团队,一起努力…...

Java中的服务端点日志记录:AOP与SLF4J

Java中的服务端点日志记录:AOP与SLF4J 大家好,我是微赚淘客返利系统3.0的小编,是个冬天不穿秋裤,天冷也要风度的程序猿! 在Java后端服务开发中,日志记录是监控和调试应用的关键手段。通过合理使用AOP&…...

黑马头条第八天实战(上)

D8 1)登录功能需求说明 用户根据用户名和密码登录密码需要手动加盐验证需要返回用户的token和用户信息 2)模块搭建思路步骤 2.1)模块作用 先捋一下之前搭模块干了啥 feign-api 远程调用 自媒体保存时调用远程客户端进行增加文章&#x…...

swift qwen2-vl推理及加载lora使用案例

参考: https://swift.readthedocs.io/zh-cn/latest/Instruction/LLM%E5%BE%AE%E8%B0%83%E6%96%87%E6%A1%A3.html#%E5%BE%AE%E8%B0%83%E5%90%8E%E6%A8%A1%E5%9E%8B https://blog.csdn.net/weixin_42357472/article/details/142150209 SWIFT支持300+ LLM和50+ MLLM(多模态大模型…...

如何使用 Choreographer 进行帧率优化

Choreographer 是 Android 提供的一个工具类,专门用来协调 UI 帧的渲染。你可以通过 Choreographer 来精确控制帧的绘制时机,以优化帧率,确保应用的流畅度。以下是如何使用 Choreographer 进行帧率优化的详细步骤: 1. 理解 Chore…...

稳定驱动之选SiLM5350系列SiLM5350MDBCA-DG单通道隔离栅极驱动器(带内部钳位):工业自动化的可靠伙伴

SiLM5350系列SiLM5350MDBCA-DG是具体有10A峰值输出电流能力,单通道隔离式栅极驱动器。SiLM5350MDBCA-DG可提供内部钳位功能。驱动电源电压为4V至30V。3V至18V的宽输入VDDI范围使驱动器适合与模拟和数字控制器接口。所有电源电压引脚都有欠压锁定 (UVLO) 保护。 SiLM…...

鸿蒙OpenHarmony【轻量系统芯片移植】内核移植

移植芯片架构 芯片架构的移植是内核移植的基础,在OpenHarmony中芯片架构移植是可选过程,如果当前OpenHarmony已经支持对应芯片架构则不需要移植操作,在“liteos_m/arch”目录下可看到当前已经支持的架构,如表1: 表1 …...

多字节字符和宽字符

小时候,买东西的单位是一角、二角和五角,现在的单位是一元、五元和十元。人类社会的发展和计算机发展本质没啥两样,形态不同而已。 编码格式的历史 尽管早期只用ASCII码就可以表达所有字符,但计算机日益推广让其他国家不同语言的…...

C++缺省参数

个人主页:Jason_from_China-CSDN博客 所属栏目:C系统性学习_Jason_from_China的博客-CSDN博客 缺省参数的概念 缺省参数是声明或定义函数时为函数的参数指定一个缺省值。在调用该函数时,如果没有指定实参则采用该形参的缺省值,否则…...

深度学习中的常用线性代数知识汇总——第一篇:基础概念、秩、奇异值

文章目录 0. 前言1. 基础概念2. 矩阵的秩2.1 秩的定义2.2 秩的计算方法2.3 秩在深度学习中的应用 3. 矩阵的奇异值3.1 奇异值分解(SVD)3.2 奇异值的定义3.3 奇异值的性质3.4 奇异值的意义3.5 实例说明3.6 奇异值在深度学习中的应用 0. 前言 按照国际惯例…...

MATLAB | R2024b更新了哪些好玩的东西?

Hey, 又到了一年两度的MATLAB更新时刻,MATLAB R2024b正式版发布啦!,直接来看看有哪些我认为比较有意思的更新吧! 1 小提琴图 天塌了,我这两天才写了个半小提琴图咋画,MATLAB 官方就出了小提琴图绘制方法。 小提琴图…...

嵌入式硬件基础知识

嵌入式硬件基础知识涵盖了嵌入式系统中的硬件组成及其工作原理,涉及处理器、存储器、外设接口、电源管理等多个方面。这些硬件共同构成了一个完整的嵌入式系统,用于执行特定任务。下面我们来详细介绍嵌入式硬件的基础知识。 1. 嵌入式系统的组成 嵌入式…...

keepalived和lvs高可用集群

keepavlied和lvs高可用集群搭建 主备模式: 关闭防火墙和selinux systemctl stop firewalld setenforce 0部署master负载调度服务器 zyj86 安装ipvsadm keepalived yum install -y keepalived ipvsadm修改主节点配置 vim /etc/keepalived/keepalived.conf! Conf…...

在VMware部署银河麒麟系统

虚拟机镜像安装文件从下面下载: 银河麒麟桌面操作系统V10SP1 2403 下载地址_银河麒麟v10镜像iso下载-CSDN博客 虚拟机安装要求硬盘大小至少40G,我悬着60G 选择桥接网络安装后上不了网并且和本机也互相ping不通,因此选择Nat方式,然后重启,就可以上网 下面开始安装,第一个…...

git删除本地分支报错:error: the branch ‘xxx‘ is not fully merged

git删除本地分支报错:error: the branch xxx is not fully merged error: the branch xxx is not fully merged 直接: git branch -D xxx 就可以。 如果删除远程分支: git push origin --delete origin/xxx git强制删除本地分支 git branc…...

Tensorflow 兼容性测试-opencloudos

介绍 Tensorflow 兼容性测试: 测试 Tensorflow 各个版本在 OpenCloudOS Stream 的安装支持 操作系统 [rootlab101 ~]# cat /etc/os-release NAME"OpenCloudOS Stream" VERSION"23" ID"opencloudos" ID_LIKE"opencloudos" VERSION_I…...

Windows主机上安装CUPS服务端共享USB打印机实践心得

背景 平时主力机器是Windows,不想额外开一个Linux服务器来共享打印机。由于主力机平时也不关机,尝试在Windows上安装CUPS服务。 结论 先说结论,结论是可行,但是麻烦且不稳定,虚拟机方案少折腾,但是资源消耗…...

socket通讯原理及例程(详解)

里面有疑问或者不正确的地方可以给我留言。 对TCP/IP、UDP、Socket编程这些词你不会很陌生吧?随着网络技术的发展,这些词充斥着我们的耳朵。那么我想问: 什么是TCP/IP、UDP?Socket在哪里呢?Socket是什么呢&#xff1…...

vue3使用provide和inject传递异步请求数据子组件接收不到

前言 一般接口返回的格式是数组或对象,使用reactive定义共享变量 父组件传递 const data reactive([])// 使用settimout模拟接口返回 setTimeout(() > {// 将接口返回的数据赋值给变量Object.assign(data, [{ id: 10000 }]) }, 3000);provide(shareData, dat…...

对称矩阵的压缩存储

1.给自己出题:自己动手创造,画一个5行5列的对称矩阵 2.画图:按“行优先”压缩存储上述矩阵,画出一维数组的样子 3.简答:写出元素 i,j 与 数组下标之间的对应关系 4.画图:按“列优先”压缩存储上述矩阵&a…...

高阶数据结构之哈希表基础讲解与模拟实现

程序猿的读书历程:x语言入门—>x语言应用实践—>x语言高阶编程—>x语言的科学与艺术—>编程之美—>编程之道—>编程之禅—>颈椎病康复指南。 前言: 哈希表(Hash Table)是一种高效的键值对存储数据结构&…...

基于STM32设计的智能货架(华为云IOT)(225)

文章目录 一、前言1.1 项目介绍【1】项目背景【2】项目支持的功能【3】项目硬件模块组成【4】ESP8266工作模式配置【5】Android手机APP开发思路【6】项目模块划分1.2 项目开发背景【1】选题来源与背景【2】国内外研究现状【3】课题研究的目的和内容【4】参考文献【5】研究内容【…...

JDBC API详解一

DriverManager 驱动管理类,作用:1,注册驱动;2,获取数据库连接 1,注册驱动 Class.forName("com.mysql.cj.jdbc.Driver"); 查看Driver类源码 static{try{DriverManager.registerDriver(newDrive…...

工厂安灯系统在设备管理中的重要性

在现代制造业中,设备管理是确保生产效率和产品质量的关键环节。随着工业4.0的推进,越来越多的企业开始采用智能化的设备管理系统,其中安灯系统作为一种有效的管理工具,逐渐受到重视。安灯系统最初源于日本的丰田生产方式&#xff…...