当前位置: 首页 > news >正文

transformer学习笔记-自注意力机制(2)

经过上一篇transformer学习笔记-自注意力机制(1)原理学习,这一篇对其中的几个关键知识点代码演示:

1、整体qkv注意力计算

先来个最简单未经变换的QKV处理:

import torch  
Q = torch.tensor([[3.0, 3.0,0.0],[0.5, 4.0,0.0]])
K = Q.T
V = Qscores = Q @ K #计算内积
weights = torch.softmax(scores, dim=0)
print(f"概率分布:{weights}")
newQ = weights @ V
print(f"输出:{newQ}")

再来个输入经过Wq/Wk/Wv变换的:

import torch  
Q = torch.tensor([[3.0, 3.0,0.0],[0.5, 4.0,0.0]])
torch.manual_seed(123)  
d_q, d_k, d_v = 4, 4, 5 # W_query, W_key, W_value 的维度  
d = Q.shape[1] #  W_query, W_key, W_value 的行数等于输入token的维度
# 获取W_query, W_key, W_value(随机生成)
W_query = torch.nn.Parameter(torch.rand(d, d_q))  
W_key = torch.nn.Parameter(torch.rand(d, d_k))  
W_value = torch.nn.Parameter(torch.rand(d, d_v))print("W_query:", W_query)
print("W_key:", W_key)
print("W_value:", W_value)#先只计算苹果对整个句子的注意力,看看效果
apple = Q[0]
query_apple = apple @ W_query  
keys = Q @ W_key  
values = Q @ W_value  
print(f"query_apple:{query_apple}")
print(f"keys:{keys}")
print(f"values:{values}")
scores = query_apple @ keys.T
print(f"scores:{scores}")
weights = torch.softmax(scores, dim=0)
print(f"weights:{weights}")
newQ = weights @ values
print(f"newQ:{newQ}")#再看下整体的
querys = Q @ W_query
all_scores = querys @ keys.T
print(f"all_scores:{all_scores}")
all_weights = torch.softmax(all_scores, dim=-1)
print(f"all_weights:{all_weights}")
output = all_weights @ values
print(f"output:{output}")

在这里插入图片描述
最终生成的output的维度与W_value 的维度一致。

2、调换顺序结果不变

import torchdef simple_attention(Q):K = Q.TV = Qscores = Q @ K #计算内积weights = torch.softmax(scores, dim=-1)print(f"概率分布:{weights}")newQ = weights @ Vprint(f"输出:{newQ}")Q = torch.tensor([[3.0, 3.0,0.0],[0.5, 4.0,0.0]])
Q1 = torch.tensor([[0.5, 4.0,0.0],[3.0, 3.0,0.0]])
print("模拟‘苹果梨’:")
simple_attention(Q)
print("模拟‘梨苹果’:")
simple_attention(Q1)

在这里插入图片描述
可以看到“苹果梨”、“梨苹果”即便换了词token的顺序,并不会影响新的梨和新的苹果的向量数值。这里我们用了softmax函数求概率分布,因此跟上一篇文章的示例数值不一样,不要在意这个细节。

3、softmax:

import numpy as npdef softmax(x):e_x = np.exp(x)return e_x / e_x.sum(axis=0)def softmax_with_temperature(x,T):e_x = np.exp(x/T)return e_x / e_x.sum(axis=0)# 示例使用
if __name__ == "__main__":input_vector = np.array([2.0, 1.0, 0.1])output = softmax(input_vector)print("Softmax Output:", output)print("Softmax with Temperature 0.5 Output:", softmax_with_temperature(input_vector,0.5))print("Softmax with Temperature 1 Output:", softmax_with_temperature(input_vector,1))print("Softmax with Temperature 5 Output:", softmax_with_temperature(input_vector,5))

在这里插入图片描述
可以看到随着T的不断加大,概率分布不断趋于均匀分布。

4、softmax除以 d k \sqrt{d_k} dk

还是用上面的softmax函数,演示下除以 d k \sqrt{d_k} dk 的效果:

        # 高维输入向量input_vector_high_dim = np.random.randn(100) * 10  # 生成一个100维的高斯分布随机向量,乘以10增加内积output_high_dim = softmax(input_vector_high_dim)print("High Dimension Softmax Output:", output_high_dim)# 打印高维输出的概率分布print("Max Probability in High Dimension:", np.max(output_high_dim))print("Min Probability in High Dimension:", np.min(output_high_dim))# 高维输入向量除以10input_vector_high_dim_div10 = input_vector_high_dim / 10output_high_dim_div10 = softmax(input_vector_high_dim_div10)print("High Dimension Softmax Output (Divided by 10):", output_high_dim_div10)# 打印高维输出的概率分布print("Max Probability in High Dimension (Divided by 10):", np.max(output_high_dim_div10))print("Min Probability in High Dimension (Divided by 10):", np.min(output_high_dim_div10))# 绘制高维概率分布曲线plt.figure(figsize=(10, 6))# 绘制图形plt.plot(output_high_dim, label='High Dim')plt.plot(output_high_dim_div10, label='High Dim Divided by 10')plt.legend()plt.title('High Dimension Softmax Output Comparison')plt.xlabel('Index')plt.ylabel('Probability')plt.show()

在这里插入图片描述
在除以 d k \sqrt{d_k} dk 之前,由于内积变大,导致概率分布变得尖锐,趋近0的位置梯度基本消失,softmax 函数的损失函数的导数在输出接近 0 时接近零,在反向传播过程中,无法有效地更新权重。有兴趣的话可以试试对softmax 函数的损失函数求导。

继续上面的代码,来看下softmax的输出的损失函数求梯度:

        def test_grad( dim_vertor):import numpy as npimport torchimport torch.nn.functional as F# 假设的输入z = torch.tensor(dim_vertor, requires_grad=True)print(z)# 计算 softmax 输出p = F.softmax(z, dim=0)true_label = np.zeros(100)true_label[3] = 1# 模拟损失函数(例如交叉熵)y = torch.tensor(true_label)  # one-hot 编码的真实标签loss = -torch.sum(y * torch.log(p))# 反向传播并获取梯度loss.backward()# print(z.grad)  # 输出梯度return z.gradgrad_div10 = test_grad(input_vector_high_dim_div10)grad = test_grad(input_vector_high_dim)print(f"grad_div10:{grad_div10}")print(f"grad:{grad}")

在这里插入图片描述
明显看出,没有除以 d k \sqrt{d_k} dk 求出的梯度,基本为0;上面的代码是torch已经实现的。当然也可以根据损失函数自己求导,这里我们只为演示效果,点到即止:

5、多头注意力:

import torch
import torch.nn as nntorch.manual_seed(123)# 输入矩阵 Q
Q = torch.tensor([[3.0, 3.0, 0.0],[0.5, 4.0, 0.0]])# 维度设置
d_q, d_k, d_v = 4, 4, 5  # 每个头的 query, key, value 的维度
d_model = Q.shape[1]     # 输入 token 的维度
num_heads = 2            # 头的数量# 初始化每个头的权重矩阵
W_query = nn.ParameterList([nn.Parameter(torch.rand(d_model, d_q)) for _ in range(num_heads)])
W_key = nn.ParameterList([nn.Parameter(torch.rand(d_model, d_k)) for _ in range(num_heads)])
W_value = nn.ParameterList([nn.Parameter(torch.rand(d_model, d_v)) for _ in range(num_heads)])# 输出权重矩阵
W_output = nn.Parameter(torch.rand(num_heads * d_v, d_model))# 打印权重矩阵
for i in range(num_heads):print(f"W_query_{i+1}:\n{W_query[i]}")print(f"W_key_{i+1}:\n{W_key[i]}")print(f"W_value_{i+1}:\n{W_value[i]}")# 计算每个头的 Q, K, V
queries = [Q @ W_query[i] for i in range(num_heads)]
keys = [Q @ W_key[i] for i in range(num_heads)]
values = [Q @ W_value[i] for i in range(num_heads)]# 计算每个头的注意力分数和权重
outputs = []
for i in range(num_heads):scores = queries[i] @ keys[i].T / (d_k ** 0.5)weights = torch.softmax(scores, dim=-1)output = weights @ values[i]outputs.append(output)# 拼接所有头的输出
concat_output = torch.cat(outputs, dim=-1)
print(f"concat_output:\n{concat_output}")
# 最终线性变换
final_output = concat_output @ W_output# 打印结果
print(f"Final Output:\n{final_output}")

6、掩码注意力:

import torch# 原始 Q 矩阵
Q = torch.tensor([[3.0, 3.0, 0.0],[0.5, 4.0, 0.0],[1.0, 2.0, 0.0],[2.0, 1.0, 0.0]])torch.manual_seed(123)
d_q, d_k, d_v = 4, 4, 5  # query, key, value 的维度
d = Q.shape[1]           # query, key, value 的行数等于输入 token 的维度# 初始化权重矩阵
W_query = torch.nn.Parameter(torch.rand(d, d_q))
W_key = torch.nn.Parameter(torch.rand(d, d_k))
W_value = torch.nn.Parameter(torch.rand(d, d_v))print("W_query:", W_query)
print("W_key:", W_key)
print("W_value:", W_value)# 计算 Q, K, V
querys = Q @ W_query
keys = Q @ W_key
values = Q @ W_valueprint(f"querys:\n{querys}")
print(f"keys:\n{keys}")
print(f"values:\n{values}")# 计算注意力分数
all_scores = querys @ keys.T / (d_k ** 0.5)
print(f"all_scores:\n{all_scores}")# 生成掩码
seq_len = Q.shape[0]
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
masked_scores = all_scores.masked_fill(mask, float('-inf'))print(f"Mask:\n{mask}")
print(f"Masked Scores:\n{masked_scores}")# 计算权重
all_weights = torch.softmax(masked_scores, dim=-1)
print(f"all_weights:\n{all_weights}")# 计算输出
output = all_weights @ values
print(f"output:\n{output}")

主要看下生成的掩码矩阵,和通过掩码矩阵处理的权重分布:在这里插入图片描述

相关文章:

transformer学习笔记-自注意力机制(2)

经过上一篇transformer学习笔记-自注意力机制(1)原理学习,这一篇对其中的几个关键知识点代码演示: 1、整体qkv注意力计算 先来个最简单未经变换的QKV处理: import torch Q torch.tensor([[3.0, 3.0,0.0],[0.5, 4…...

呼叫中心呼入大模型如何对接传统呼叫中心系统?

呼叫中心呼入大模型如何对接传统呼叫中心系统? 原作者:开源呼叫中心FreeIPCC,其Github:https://github.com/lihaiya/freeipcc 呼叫中心呼入大模型与传统呼叫中心系统的对接是一个复杂而细致的过程,涉及技术实现、流程…...

[Unity] Text文本首行缩进两个字符

Text文本首行缩进两个字符的方法比较简单。通过代码把"\u3000\u3000"加到文本字符串前面即可。 比如: 效果: 代码: TMPtext1.text "\u3000\u3000" "选择动作类型:";...

最新版Chrome浏览器加载ActiveX控件之Adobe PDF阅读器控件

背景 Adobe PDF阅读器控件是一个ActiveX控件,用于在Windows平台上显示和操作PDF文件。它提供了一系列方法和属性,可以实现对PDF文件的加载、显示、搜索、打印、保存等操作。 allWebPlugin中间件是一款为用户提供安全、可靠、便捷的浏览器插件服务的中间件…...

springboot 对接 ollama

spring ai 对接 ollama 引入依赖 <dependency><groupId>io.springboot.ai</groupId><artifactId>spring-ai-ollama-spring-boot-starter</artifactId><version>1.0.0</version> </dependency>这里因为使用的是快照版本所以需…...

【数据库】选择题+填空+简答

1.关于冗余数据的叙述中&#xff0c;不正确的是&#xff08;&#xff09; A.冗余的存在容易破坏数据库的完整新 B.冗余的存在给数据库的维护增加困难 C.不应该在数据库中存储任何冗余数据 D.冗余数据是指由基本数据导出的数据 C 2.最终用户使用的数据视图称为&#xff08;&…...

从0开始写android 之xwindow

模拟实现android的窗口系统本质上还是在ubuntu 上实现自己的窗口系统&#xff0c; xwindow是一套成熟的解决方案。在ubuntu上使用xwindow的好处之一 是ubuntu自带xwindow的库&#xff0c; 直接引用头文件和库文件。下面来了解下 xwindow的基本函数接口。 参考 https://tronche…...

The Past, Present and Future of Apache Flink

摘要&#xff1a;本文整理自阿里云开源大数据负责人王峰&#xff08;莫问&#xff09;在 Flink Forward Asia 2024上海站主论坛开场的分享&#xff0c;今年正值Flink开源项目诞生的第10周年&#xff0c;借此时机&#xff0c;王峰回顾了Flink在过去10年的发展历程以及 Flink社区…...

多模块应用、发布使用第三方库(持续更新中)

目录: 1、多模块概述&#xff08;HAP、HSP、HAR&#xff09; HAR与HSP两种共享包的主要区别体现在&#xff1a; 2、三类模块&#xff1a; 3、创建项目&#xff1a;项目名&#xff1a;meituan &#xff08;1&#xff09;创建Ability类型的Module&#xff0c;编译后为HAP文件…...

An error happened while trying to locate the file on the Hub and we cannot f

An error happened while trying to locate the file on the Hub and we cannot find the requested files in the local cache. Please check your connection and try again or make sure your Internet connection is on. 关于上述comfy ui使用control net预处理器的报错问…...

UE5安装Fab插件

今天才知道原来Fab也有类似Quixel Bridge的插件&#xff0c;于是立马就安装上了&#xff0c;这里分享一下安装方法 在Epic客户端 - 库 - Fab Library 搜索 Fab 即可安装Fab插件 然后重启引擎&#xff0c;在插件面板勾选即可 然后在窗口这就有了 引擎左下角也会多出一个Fab图标…...

Linux C语言操作sqlite3数据库

一、环境配置 1、下载源码&#xff1a;sqlite-autoconf-3470200.tar.gz 2、解压&#xff0c;cd到源码主目录 3、配置参数 ./configure --prefix/usr/local/ 如果是交叉编译环境 ./configure CC/opt/rk3288/gcc-linaro/bin/arm-linux-gnueabihf-gcc --hostarm-linux --pre…...

【人工智能】因果推断与数据分析:用Python探索数据间的因果关系

解锁Python编程的无限可能:《奇妙的Python》带你漫游代码世界 因果推断是数据科学领域的一个重要方向,旨在发现变量间的因果关系,而不仅仅是相关性。本篇文章将从因果推断的理论基础出发,介绍因果关系的定义与建模方法,涵盖因果图(Causal Graph)、d-分离、反事实估计等…...

freeswitch(30秒自动挂断)

亲测版本centos 7.9系统–》 freeswitch1.10.9 本人freeswitch安装路径(根据自己的路径进入) /usr/local/freeswitch/etc/freeswitch场景说明: A和B接通通话时候,时间开始计算到达30秒后自动挂断使用方法 进入/usr/local/freeswitch/etc...

大模型呼入机器人有哪些功能特点?(转)

大模型呼入机器人有哪些功能特点&#xff1f;(转) 原作者&#xff1a;开源呼叫中心FreeIPCC&#xff0c;其Github&#xff1a;https://github.com/lihaiya/freeipcc 大模型呼入机器人&#xff0c;作为现代通信技术与人工智能深度融合的产物&#xff0c;正逐渐成为企业提升服务…...

网络工程师常用软件之配置对比软件

「24-配置比对软件-汉化WinMerge」 链接&#xff1a;https://pan.quark.cn/s/cef7541d62d1 ################################################################################ 我们经常在项目或者运维中对设备的config进行变更&#xff0c;那么我们如何快速的知道变更了什么…...

Linux之远程登录

一、使用ssh命令登录 winR打开cmd输入命令 # root是命令&#xff0c;192.168.101.200是地址 ssh root192.168.101.200是否要保存密码&#xff0c;就是yes以后可以免密登录&#xff0c;这里就yes了 输入密码&#xff0c;就登录成功了 操作完成之后&#xff0c;输入命令退出 e…...

#渗透测试#漏洞挖掘#红蓝攻防#js分析(上)

免责声明 本教程仅为合法的教学目的而准备&#xff0c;严禁用于任何形式的违法犯罪活动及其他商业行为&#xff0c;在使用本教程前&#xff0c;您应确保该行为符合当地的法律法规&#xff0c;继续阅读即表示您需自行承担所有操作的后果&#xff0c;如有异议&#xff0c;请立即停…...

数智读书笔记系列006 协同进化:人类与机器融合的未来

书名:协同进化&#xff1a;人类与机器融合的未来 作者:[美]爱德华阿什福德李 译者:李杨 出版时间:2022-06-01 ISBN:9787521741476 中信出版集团制作发行 爱德华・阿什福德・李&#xff08;Edward Ashford Lee&#xff09;是一位在计算机科学与工程领域颇具影响力的学者&am…...

操作系统(7)处理机调度

前言 操作系统中的处理机调度是一个核心概念&#xff0c;它涉及如何从就绪队列中选择进程并将处理机分配给它以运行&#xff0c;从而实现进程的并发执行。 一、调度的层次 高级调度&#xff08;作业调度&#xff09;&#xff1a; 调度对象&#xff1a;作业&#xff08;包含程序…...

观成科技:隐蔽隧道工具Ligolo-ng加密流量分析

1.工具介绍 Ligolo-ng是一款由go编写的高效隧道工具&#xff0c;该工具基于TUN接口实现其功能&#xff0c;利用反向TCP/TLS连接建立一条隐蔽的通信信道&#xff0c;支持使用Let’s Encrypt自动生成证书。Ligolo-ng的通信隐蔽性体现在其支持多种连接方式&#xff0c;适应复杂网…...

CTF show Web 红包题第六弹

提示 1.不是SQL注入 2.需要找关键源码 思路 进入页面发现是一个登录框&#xff0c;很难让人不联想到SQL注入&#xff0c;但提示都说了不是SQL注入&#xff0c;所以就不往这方面想了 ​ 先查看一下网页源码&#xff0c;发现一段JavaScript代码&#xff0c;有一个关键类ctfs…...

基于距离变化能量开销动态调整的WSN低功耗拓扑控制开销算法matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.算法仿真参数 5.算法理论概述 6.参考文献 7.完整程序 1.程序功能描述 通过动态调整节点通信的能量开销&#xff0c;平衡网络负载&#xff0c;延长WSN生命周期。具体通过建立基于距离的能量消耗模型&am…...

.Net框架,除了EF还有很多很多......

文章目录 1. 引言2. Dapper2.1 概述与设计原理2.2 核心功能与代码示例基本查询多映射查询存储过程调用 2.3 性能优化原理2.4 适用场景 3. NHibernate3.1 概述与架构设计3.2 映射配置示例Fluent映射XML映射 3.3 查询示例HQL查询Criteria APILINQ提供程序 3.4 高级特性3.5 适用场…...

汽车生产虚拟实训中的技能提升与生产优化​

在制造业蓬勃发展的大背景下&#xff0c;虚拟教学实训宛如一颗璀璨的新星&#xff0c;正发挥着不可或缺且日益凸显的关键作用&#xff0c;源源不断地为企业的稳健前行与创新发展注入磅礴强大的动力。就以汽车制造企业这一极具代表性的行业主体为例&#xff0c;汽车生产线上各类…...

五年级数学知识边界总结思考-下册

目录 一、背景二、过程1.观察物体小学五年级下册“观察物体”知识点详解&#xff1a;由来、作用与意义**一、知识点核心内容****二、知识点的由来&#xff1a;从生活实践到数学抽象****三、知识的作用&#xff1a;解决实际问题的工具****四、学习的意义&#xff1a;培养核心素养…...

MySQL中【正则表达式】用法

MySQL 中正则表达式通过 REGEXP 或 RLIKE 操作符实现&#xff08;两者等价&#xff09;&#xff0c;用于在 WHERE 子句中进行复杂的字符串模式匹配。以下是核心用法和示例&#xff1a; 一、基础语法 SELECT column_name FROM table_name WHERE column_name REGEXP pattern; …...

在WSL2的Ubuntu镜像中安装Docker

Docker官网链接: https://docs.docker.com/engine/install/ubuntu/ 1、运行以下命令卸载所有冲突的软件包&#xff1a; for pkg in docker.io docker-doc docker-compose docker-compose-v2 podman-docker containerd runc; do sudo apt-get remove $pkg; done2、设置Docker…...

vue3+vite项目中使用.env文件环境变量方法

vue3vite项目中使用.env文件环境变量方法 .env文件作用命名规则常用的配置项示例使用方法注意事项在vite.config.js文件中读取环境变量方法 .env文件作用 .env 文件用于定义环境变量&#xff0c;这些变量可以在项目中通过 import.meta.env 进行访问。Vite 会自动加载这些环境变…...

论文笔记——相干体技术在裂缝预测中的应用研究

目录 相关地震知识补充地震数据的认识地震几何属性 相干体算法定义基本原理第一代相干体技术&#xff1a;基于互相关的相干体技术&#xff08;Correlation&#xff09;第二代相干体技术&#xff1a;基于相似的相干体技术&#xff08;Semblance&#xff09;基于多道相似的相干体…...