Prefix-Tuning源码解析
Prefix-Tuning源码解析
Prefix-Tuning在PEFT包中的源码实现
改写自Based on https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py
import torch
from transformers import PretrainedConfigclass PrefixEncoder(torch.nn.Module):r'''The torch.nn model to encode the prefixInput shape: (batch-size, prefix-length)Output shape: (batch-size, prefix-length, 2*layers*hidden)'''def __init__(self, config):super().__init__()self.prefix_projection = config.prefix_projectionif self.prefix_projection:# Use a two-layer MLP to encode the prefixself.embedding = torch.nn.Embedding(config.prefix_length, config.hidden_size)self.trans = torch.nn.Sequential(torch.nn.Linear(config.hidden_size, config.encoder_hidden_size),torch.nn.Tanh(),torch.nn.Linear(config.encoder_hidden_size, config.num_hidden_layers * 2 * config.hidden_size))else:self.embedding = torch.nn.Embedding(config.prefix_length, config.num_hidden_layers * 2 * config.hidden_size)def forward(self, prefix: torch.Tensor):if self.prefix_projection:prefix_tokens = self.embedding(prefix)past_key_values = self.trans(prefix_tokens)else:past_key_values = self.embedding(prefix)return past_key_valuesif __name__ == "__main__":configs = {"prefix_length":20,"hidden_size":768,"encoder_hidden_size":768,"num_hidden_layers":12,"prefix_projection":False}prefix_encoder = PrefixEncoder(config=PretrainedConfig.from_dict(configs))print(prefix_encoder)batch_size = 8prefix = torch.arange(20).long().expand(batch_size, -1)print(prefix.shape)output = prefix_encoder(prefix)print(output.shape)
下面我们以T5-large模型为例子:
不考虑Use a two-layer MLP to encode the prefix的话,prefix tuning主要包括以下代码:
class PrefixEncoder(torch.nn.Module):def __init__(self, config):super().__init__()...self.embedding = torch.nn.Embedding(num_virtual_tokens, num_layers * 2 * token_dim) #num_virtual_tokens=20,token_dim=1024,num_layers=24def forward(self, prefix: torch.Tensor):past_key_values = self.embedding(prefix)return past_key_values
得到的PrefixEncoder被传入peft->peft_model.py->prompt_encoder:
PrefixEncoder((embedding): Embedding(20, 49152) # 1024*2*24
)
self.prompt_tokens初始化为长度2*20的向量,因为T5有编码器和解码器,需要两次prefix:
self.prompt_tokens[adapter_name] = torch.arange(config.num_virtual_tokens * config.num_transformer_submodules).long() #20*2# tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
# 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
# 36, 37, 38, 39])
prompt_tokens = (self.prompt_tokens[self.active_adapter].unsqueeze(0).expand(batch_size, -1).to(prompt_encoder.embedding.weight.device))
prompt_tokens = prompt_tokens[:, : peft_config.num_virtual_tokens]
# 此时prompt_tokens.shape = (batch_size=8, num_virtual_tokens=20)past_key_values = prompt_encoder(prompt_tokens)
torch.Size([8, 20, 49152])
但目前的past_key_values还是所有层的集合,我们需要把past_key_values分解为每一层:
past_key_values = past_key_values.view(batch_size, #8peft_config.num_virtual_tokens, #20peft_config.num_layers * 2, #24*2peft_config.num_attention_heads, #16peft_config.token_dim // peft_config.num_attention_heads, #1024/16)
# torch.Size([8, 20, 48, 16, 64])
因为有编码器和解码器,所以再复制一次
past_key_values = torch.cat([past_key_values, past_key_values], dim=2)
# torch.Size([8, 20, 96, 16, 64])# 重排:torch.Size([96, 8, 16, 20, 64])
# 然后split成一个长度为24的tuple,每个tuple的shape:torch.Size([4, 8, 16, 20, 64])
past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(peft_config.num_transformer_submodules * 2)
也就是说past_key_values是24个层的Prefix embedding,形状为`(num_transformer_submodules * 2, batch_size, num_attention_heads, num_virtual_tokens, token_dim/num_attention_heads])
注意这里*2是因为key+value.
transformers->models->t5->modeling_t5.py->T5Attention类,这里的关键步骤是project函数中的hidden_states = torch.cat([past_key_value, hidden_states], dim=2),注意project函数仅仅用于key和value。
def forward(self,hidden_states,mask=None,key_value_states=None,position_bias=None,past_key_value=None,layer_head_mask=None,query_length=None,use_cache=False,output_attentions=False,):"""Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states)."""# Input is (batch_size, seq_length, dim)# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)batch_size, seq_length = hidden_states.shape[:2]real_seq_length = seq_lengthif past_key_value is not None:if len(past_key_value) != 2:raise ValueError(f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states")real_seq_length += past_key_value[0].shape[2] if query_length is None else query_lengthkey_length = real_seq_length if key_value_states is None else key_value_states.shape[1]def shape(states):"""projection"""return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)def unshape(states):"""reshape"""return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)def project(hidden_states, proj_layer, key_value_states, past_key_value):"""projects hidden states correctly to key/query states"""if key_value_states is None:# self-attn# (batch_size, n_heads, seq_length, dim_per_head)hidden_states = shape(proj_layer(hidden_states))elif past_key_value is None:# cross-attn# (batch_size, n_heads, seq_length, dim_per_head)hidden_states = shape(proj_layer(key_value_states))if past_key_value is not None:if key_value_states is None:# self-attn# (batch_size, n_heads, key_length, dim_per_head)# 注意这里是重点:用串联方式hidden_states = torch.cat([past_key_value, hidden_states], dim=2)elif past_key_value.shape[2] != key_value_states.shape[1]:# checking that the ` sequence_length` of the `past_key_value` is the same as# the provided `key_value_states` to support prefix tuning# cross-attn# (batch_size, n_heads, seq_length, dim_per_head)hidden_states = shape(proj_layer(key_value_states))else:# cross-attnhidden_states = past_key_valuereturn hidden_statesreal_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
分别计算query_states、key_states、value_states,用query和key计算attention score,得到score形状为torch.Size([8, 16, 2, 22]),所以输入X可以attend to itself以及prefix。
# hidden_states shape: torch.Size([8, 2, 1024]) # get query statesquery_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) # query_states shape: torch.Size([8, 16, 2, 64])# get key/value stateskey_states = project(hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None)# key_states shape: torch.Size([8, 16, 22, 64])value_states = project(hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None)# value_states shape: torch.Size([8, 16, 22, 64])# compute scores# torch.Size([8, 16, 2, 22])scores = torch.matmul(query_states, key_states.transpose(3, 2)) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
接下来就是经典的attention操作了。用attn_weights ([8, 16, 2, 22]) 和value_states ([8, 16, 22, 64])相乘,把22消掉,就是每个输入X的输出了。
# if key and values are already calculated
# we want only the last query position bias
# position_bias.shape: torch.Size([8, 16, 2, 22])scores += position_bias_maskedattn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) # (batch_size, n_heads, seq_length, key_length)attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # (batch_size, n_heads, seq_length, key_length)attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) torch.Size([8, 2, 1024])attn_output = self.o(attn_output)present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else Noneoutputs = (attn_output,) + (present_key_value_state,) + (position_bias,)if output_attentions:outputs = outputs + (attn_weights,)return outputs
参考
https://huggingface.co/docs/peft/task_guides/seq2seq-prefix-tuning
相关文章:
Prefix-Tuning源码解析
Prefix-Tuning源码解析 Prefix-Tuning在PEFT包中的源码实现 改写自Based on https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py import torch from transformers import PretrainedConfigclass PrefixEncoder(torch.nn.Module):rThe torch.nn model t…...
Java EE-servlet API 三种主要的类
上述的代码如下: import javax.servlet.ServletException; import javax.servlet.annotation.WebServlet; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.i…...
简单谈谈我参加数据分析省赛的感受与体会
数据分析省赛的感受与体会 概要考试前的感受与体会考试注意事项小结 概要 大数据分析省赛指的是在省级范围内举办的大数据分析竞赛活动。该竞赛旨在鼓励和推动大数据分析领域的技术创新和人才培养,促进大数据技术与应用的深度融合,切实解决实际问题。参…...
rust学习——泛型 (Generics)
文章目录 泛型 Generics泛型详解结构体中使用泛型枚举中使用泛型方法中使用泛型为具体的泛型类型实现方法 const 泛型(Rust 1.51 版本引入的重要特性)const 泛型表达式 泛型的性能 泛型 Generics Go 语言在 2022 年,就要正式引入泛型…...
【USRP】通信之有线通信
有线通信: 有线通信是指使用物理线路或媒体(例如,铜线、同轴电缆、光纤)进行数据、声音和视频传输的通信方式。由于它依赖于实体传输媒介,有线通信通常具有较高的稳定性和可靠性,并能支持长距离的高带宽通…...
【算法】BFS
BFS广度优先搜索 1. 概念理解 广度优先搜索(BFS)是指,以一个起点(原点、结点、根)为基本点,向其所要搜索的方向扩散,并最终到达目标点的搜索方法。 2. 应用方向 有迷宫问题、层序遍历等应用。 3. 迷宫问题 以迷宫问题为例。 当想要从左…...
ZYNQ7020开发(二):zynq linux系统编译
文章目录 一、编译前准备二、SDK编译三、编译步骤总结四、问题汇总 一、编译前准备 1.设置环境变量 source /opt/pkg/petalinux/2020.2/settings.sh/opt/pkg/petalinux/2020.2是上一节petalinux的安装目录 2.创建 petalinux 工程 进入petalinux安装目录(例如:/op…...
Kafka 自动配置部署信息的脚本记录
自动配置 Kafka 整理服务器内容时,发现一个测试 Kafka 的的一个脚本,它可以自动部署 Kafka ,指定三个参数,完成 Kafka 的配置过程。 basePath$1 brokerId$2 zookeeperConnect$3 localIpifconfig |grep inet| awk {print $2}| he…...
数据分析入门
B站:01第一课 数据分析岗位职责和数据分析师_哔哩哔哩_bilibili 一、岗位:数据分析师 Q1 数据分析师在公司做什么工作? 数据来源于公司核心业务,通过监测业务健康度来确定业务的健康状况; 通过对用户精细化分析&am…...
车载网关通信能力解析——SV900-5G车载网关推荐
随着车联网的发展,各类车载设备对车载网关的需求日益增长。车载网关作为车与车、车与路、车与云之间连接的关键设备,其通信能力直接影响整个系统的性能。本文将详细解析车载网关的通信能力,并推荐性价比高的SV900-5G车载网关。 链接直达:https://www.key-iot.com/i…...
服务器中了mkp勒索病毒怎么处理,mkp勒索病毒解密,数据恢复
10月份以来,云天数据恢复中心陆续接到很多企业的求助,企业的服务器遭到了mkp勒索病毒攻击,导致企业的服务器数据库被加密,严重影响了企业工作,通过这一波mkp勒索病毒的攻击,云天数据恢复工程师为大家总结了…...
义乌再次位列第一档!2022年跨境电商综试区评估结果揭晓!
义乌跨境电商综试区捷报频传,在商务部公布的“2022年跨境电子商务综合试验区评估”结果中,中国(义乌)跨境电子商务综合试验区(以下简称:“跨境综试区”)评估结果为成效明显,综合排名…...
07、Python -- 序列相关函数与封包解包
目录 使用函数字符串也能比较大小序列封包序列解包多变量同时赋值 最大值、最小值、长度 序列解包与封包 使用函数 len()、max()、min() 函数可获取元组、列表的长度、最大值和最小值。 字符串也能比较大小 字符串比较大小时,将会依次按字符串中每个字符对应的编…...
# Spring 事务失效场景
Spring 事务失效场景 文章目录 Spring 事务失效场景前言事务不生效未开启事务事务方法未被Spring管理访问权限问题基于接口的代理源码解读 CGLIB代理 方法用final修饰同一类中的方法调用多线程调用不支持事务 事务不回滚设置错误的事务传播机制捕获了异常手动抛了别的异常自定义…...
华为OD 停车场车辆统计(100分)【java】A卷+B卷
华为OD统一考试A卷+B卷 新题库说明 你收到的链接上面会标注A卷还是B卷。目前大部分收到的都是B卷。 B卷对应20022部分考题以及新出的题目,A卷对应的是新出的题目。 我将持续更新最新题目 获取更多免费题目可前往夸克网盘下载,请点击以下链接进入: 我用夸克网盘分享了「华为O…...
出差学小白知识No6:LD_PRELOAD变量路径不对找不到库文件
交叉编译的时候出现以下问题,显示LD_PRELOAD变量找不到路劲 首先先查看一下LD_PRELOAD的路径:echo $LD_PRELOAD 如果输出一大串,那么先进行清空:unset LD_PRELOAD 重新给LD_PRELOAD进行赋值他的路径和库文件: expor…...
利用dns协议发起ddos反射攻击
利用DNS服务器发起反射型DDOS,攻击带宽 基本思路: 1、利用any类型的dns查询,可完成发送少量请求数据,获得大量返回数据。 2、将原请求地址改为受害者地址,则dns会向受害者返回大量数据,占用带宽 警告&…...
Tcl基础知识
一、概述 Tcl 语言的全称 Tool Command Language,即工具命令语言。这种需要在 EDA 工具中使用的相当之多,或者说几乎每个 EDA 工具都支持 Tcl 语言,并将它作为自己的命令shell。 静态时序分析中多用的 Synopsys Tcl 语言,…...
Go中的编程模式:Pipeline
本文章我们重点来介绍一下 Go 编程中的 Pipeline 模式。用过 Linux 命令行的人都不会陌生,它是一种把各种命令拼接起来完成一个更强功能的技术方法,在C语言中也有pipe管道的叫法,具体的有兴趣的同学也可以去了解。 现在的流式处理、函数式编程、应用网关对微服务进行简单的…...
2023最新pytorch安装教程,简单易懂,面向初学者(Anaconda+GPU)
一、前言 目前是2023.1.27,鉴于本人安装过程中踩得坑,安装之前我先给即将安装pytorch的各位提个醒,有以下几点需要注意 1.判断自己电脑是否有GPU 注意这点很重要,本教程面向有NVIDA显卡的电脑,如果你的电脑没有GPU或者使用AMD显…...
微软PowerBI考试 PL300-选择 Power BI 模型框架【附练习数据】
微软PowerBI考试 PL300-选择 Power BI 模型框架 20 多年来,Microsoft 持续对企业商业智能 (BI) 进行大量投资。 Azure Analysis Services (AAS) 和 SQL Server Analysis Services (SSAS) 基于无数企业使用的成熟的 BI 数据建模技术。 同样的技术也是 Power BI 数据…...
SciencePlots——绘制论文中的图片
文章目录 安装一、风格二、1 资源 安装 # 安装最新版 pip install githttps://github.com/garrettj403/SciencePlots.git# 安装稳定版 pip install SciencePlots一、风格 简单好用的深度学习论文绘图专用工具包–Science Plot 二、 1 资源 论文绘图神器来了:一行…...
2021-03-15 iview一些问题
1.iview 在使用tree组件时,发现没有set类的方法,只有get,那么要改变tree值,只能遍历treeData,递归修改treeData的checked,发现无法更改,原因在于check模式下,子元素的勾选状态跟父节…...
BCS 2025|百度副总裁陈洋:智能体在安全领域的应用实践
6月5日,2025全球数字经济大会数字安全主论坛暨北京网络安全大会在国家会议中心隆重开幕。百度副总裁陈洋受邀出席,并作《智能体在安全领域的应用实践》主题演讲,分享了在智能体在安全领域的突破性实践。他指出,百度通过将安全能力…...
【python异步多线程】异步多线程爬虫代码示例
claude生成的python多线程、异步代码示例,模拟20个网页的爬取,每个网页假设要0.5-2秒完成。 代码 Python多线程爬虫教程 核心概念 多线程:允许程序同时执行多个任务,提高IO密集型任务(如网络请求)的效率…...
Java多线程实现之Thread类深度解析
Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...
Aspose.PDF 限制绕过方案:Java 字节码技术实战分享(仅供学习)
Aspose.PDF 限制绕过方案:Java 字节码技术实战分享(仅供学习) 一、Aspose.PDF 简介二、说明(⚠️仅供学习与研究使用)三、技术流程总览四、准备工作1. 下载 Jar 包2. Maven 项目依赖配置 五、字节码修改实现代码&#…...
20个超级好用的 CSS 动画库
分享 20 个最佳 CSS 动画库。 它们中的大多数将生成纯 CSS 代码,而不需要任何外部库。 1.Animate.css 一个开箱即用型的跨浏览器动画库,可供你在项目中使用。 2.Magic Animations CSS3 一组简单的动画,可以包含在你的网页或应用项目中。 3.An…...
AirSim/Cosys-AirSim 游戏开发(四)外部固定位置监控相机
这个博客介绍了如何通过 settings.json 文件添加一个无人机外的 固定位置监控相机,因为在使用过程中发现 Airsim 对外部监控相机的描述模糊,而 Cosys-Airsim 在官方文档中没有提供外部监控相机设置,最后在源码示例中找到了,所以感…...
MyBatis中关于缓存的理解
MyBatis缓存 MyBatis系统当中默认定义两级缓存:一级缓存、二级缓存 默认情况下,只有一级缓存开启(sqlSession级别的缓存)二级缓存需要手动开启配置,需要局域namespace级别的缓存 一级缓存(本地缓存&#…...
