当前位置: 首页 > news >正文

《动手学深度学习 Pytorch版》 10.2 注意力汇聚:Nadaraya-Watson 核回归

import torch
from torch import nn
from d2l import torch as d2l

1964 年提出的 Nadaraya-Watson 核回归模型是一个简单但完整的例子,可以用于演示具有注意力机制的机器学习。

10.2.1 生成数据集

根据下面的非线性函数生成一个人工数据集,其中噪声项 ϵ \epsilon ϵ 服从均值为 0 ,标准差为 0.5 的正态分布:

y i = 2 sin ⁡ x i + x i 0.8 + ϵ \boldsymbol{y}_i=2\sin{\boldsymbol{x}_i}+\boldsymbol{x}_i^{0.8}+\epsilon yi=2sinxi+xi0.8+ϵ

n_train = 50  # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train) * 5)   # 排序后的训练样本def f(x):return 2 * torch.sin(x) + x**0.8y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # 训练样本的输出
x_test = torch.arange(0, 5, 0.1)  # 测试样本
y_truth = f(x_test)  # 测试样本的真实输出
n_test = len(x_test)  # 测试样本数
n_test
50
def plot_kernel_reg(y_hat):  # 绘制训练样本d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],xlim=[0, 5], ylim=[-1, 5])d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);

10.2.2 平均汇聚

先使用最简单的估计器来解决回归问题。基于平均汇聚来计算所有训练样本输出值的平均值:

f ( x ) = 1 n ∑ i = 1 n y i f(x)=\frac{1}{n}\sum^n_{i=1}y_i f(x)=n1i=1nyi

y_hat = torch.repeat_interleave(y_train.mean(), n_test)  # 计算平均并进行扩展
plot_kernel_reg(y_hat)


在这里插入图片描述

10.2.3 非参数注意力汇聚

相对于平均汇聚的忽略输入。Nadaraya 和 Watson 提出了一个更好的想法,根据输入的位置对输出 y i y_i yi 进行加权,即 Nadaraya-Watson 核回归:

f ( x ) = ∑ i = 1 n K ( x − x i ) ∑ j = 1 n K ( x − x j ) y i f(x)=\sum^n_{i=1}\frac{K(x-x_i)}{\sum^n_{j=1}K(x-x_j)}y_i f(x)=i=1nj=1nK(xxj)K(xxi)yi

将其中的核(kernel) K K K 根据上节内容重写为更通用的注意力汇聚公式:

f ( x ) = ∑ i = 1 n α ( x , x i ) y i f(x)=\sum^n_{i=1}\alpha(x,x_i)y_i f(x)=i=1nα(x,xi)yi

参数字典:

  • x x x 为查询

  • ( x i , y i ) (x_i,y_i) (xi,yi) 为键值对

  • α ( x , x i ) \alpha(x,x_i) α(x,xi) 为注意力权重(attention weight),即查询 x x x 和键 x i x_i xi 之间的关系建模,此权重被分配给对应值的 y i y_i yi

对于任何查询,模型在所有键值对注意力权重都是一个有效的概率分布: 非负的且和为1。

考虑高斯核(Gaussian kernel)以更好地理解注意力汇聚:

K ( u ) = 1 2 π exp ⁡ ( − u 2 2 ) K(u)=\frac{1}{\sqrt{2\pi}}\exp{(-\frac{u^2}{2})} K(u)=2π 1exp(2u2)

将高斯核代入上式可得:

f ( x ) = ∑ i = 1 n α ( x , x i ) y i = ∑ i = 1 n exp ⁡ ( − 1 2 ( x − x i ) 2 ) ∑ j = 1 n exp ⁡ ( − 1 2 ( x − x j ) 2 ) y i = ∑ i = 1 n s o f t m a x ( − 1 2 ( x − x i ) 2 ) y i \begin{align} f(x)=&\sum^n_{i=1}\alpha(x,x_i)y_i\\ =&\sum^n_{i=1}\frac{\exp{(-\frac{1}{2}(x-x_i)^2)}}{\sum^n_{j=1}\exp{(-\frac{1}{2}(x-x_j)^2)}}y_i\\ =&\sum^n_{i=1}\mathrm{softmax}\left(-\frac{1}{2}(x-x_i)^2\right)y_i \end{align} f(x)===i=1nα(x,xi)yii=1nj=1nexp(21(xxj)2)exp(21(xxi)2)yii=1nsoftmax(21(xxi)2)yi

如果一个键 x i x_i xi 越是接近给定的查询 x x x,那么分配给这个键对应值 y i y_i yi 的注意力权重就会越大,也就“获得了更多的注意力”。

上式是一个非参数的注意力汇聚(nonparametric attention pooling)模型。 接下来基于这个非参数的注意力汇聚模型绘制的预测结果的模型预测线是平滑的,并且比平均汇聚的预测更接近真实。

# X_repeat的形状:(n_test,n_train),
# 每一行都包含着相同的测试输入(例如:同样的查询)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
# x_train包含着键。attention_weights的形状:(n_test,n_train),
# 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)
# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)


在这里插入图片描述

观察注意力的权重可以发现,“查询-键”对越接近,注意力汇聚的注意力权重就越高。

d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),xlabel='Sorted training inputs',ylabel='Sorted testing inputs')


在这里插入图片描述

10.2.4 带参数的注意力汇聚

可以轻松地将可学习的参数集成到注意力汇聚中,例如,在下面的查询 x x x 和键 x i x_i xi 之间的距离乘以可学习参数 w w w

f ( x ) = ∑ i = 1 n α ( x , x i ) y i = ∑ i = 1 n exp ⁡ ( − 1 2 ( ( x − x i ) w ) 2 ) ∑ j = 1 n exp ⁡ ( − 1 2 ( ( x − x j ) w ) 2 ) y i = ∑ i = 1 n s o f t m a x ( − 1 2 ( ( x − x i ) w ) 2 ) y i \begin{align} f(x)=&\sum^n_{i=1}\alpha(x,x_i)y_i\\ =&\sum^n_{i=1}\frac{\exp{(-\frac{1}{2}((x-x_i)w)^2)}}{\sum^n_{j=1}\exp{(-\frac{1}{2}((x-x_j)w)^2)}}y_i\\ =&\sum^n_{i=1}\mathrm{softmax}\left(-\frac{1}{2}((x-x_i)w)^2\right)y_i \end{align} f(x)===i=1nα(x,xi)yii=1nj=1nexp(21((xxj)w)2)exp(21((xxi)w)2)yii=1nsoftmax(21((xxi)w)2)yi

10.2.4.1 批量矩阵乘法

假定两个张量的形状分别是 ( n , a , b ) (n,a,b) (n,a,b) ( n , b , c ) (n,b,c) (n,b,c),它们的批量矩阵乘法输出的形状为 ( n , a , c ) (n,a,c) (n,a,c)

X = torch.ones((2, 1, 4))
Y = torch.ones((2, 4, 6))
torch.bmm(X, Y).shape
torch.Size([2, 1, 6])

可以使用小批量矩阵乘法来计算小批量数据中的加权平均值。

weights = torch.ones((2, 10)) * 0.1
values = torch.arange(20.0).reshape((2, 10))
weights.shape, values.shape, weights.unsqueeze(1).shape, values.unsqueeze(-1).shape, torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1))
(torch.Size([2, 10]),torch.Size([2, 10]),torch.Size([2, 1, 10]),torch.Size([2, 10, 1]),tensor([[[ 4.5000]],[[14.5000]]]))

10.2.4.2 定义模型

class NWKernelRegression(nn.Module):def __init__(self, **kwargs):super().__init__(**kwargs)self.w = nn.Parameter(torch.rand((1,), requires_grad=True))def forward(self, queries, keys, values):# queries和attention_weights的形状为(查询个数,“键-值”对个数)queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))self.attention_weights = nn.functional.softmax(-((queries - keys) * self.w)**2 / 2, dim=1)# values的形状为(查询个数,“键-值”对个数)return torch.bmm(self.attention_weights.unsqueeze(1),values.unsqueeze(-1)).reshape(-1)

10.2.4.3 训练

# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入
X_tile = x_train.repeat((n_train, 1))
# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出
Y_tile = y_train.repeat((n_train, 1))
# keys的形状:('n_train','n_train'-1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# values的形状:('n_train','n_train'-1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')  # 使用平方损失函数
trainer = torch.optim.SGD(net.parameters(), lr=0.5)  # 使用随机梯度下降
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])for epoch in range(5):trainer.zero_grad()l = loss(net(x_train, keys, values), y_train)l.sum().backward()trainer.step()print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')animator.add(epoch + 1, float(l.sum()))


在这里插入图片描述

训练完带参数的注意力汇聚模型后可以发现:在尝试拟合带噪声的训练数据时,预测结果绘制的线不如之前非参数模型的平滑。

# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)


在这里插入图片描述

与非参数的注意力汇聚模型相比, 带参数的模型加入可学习的参数后, 曲线在注意力权重较大的区域变得更不平滑。

d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),xlabel='Sorted training inputs',ylabel='Sorted testing inputs')


在这里插入图片描述

练习

(1)增加训练数据的样本数量,能否得到更好的非参数的 Nadaraya-Watson 核回归模型?

不能。

n_train_more = 500
x_train_more, _ = torch.sort(torch.rand(n_train_more) * 5)def f(x):return 2 * torch.sin(x) + x**0.8y_train_more = f(x_train_more) + torch.normal(0.0, 0.5, (n_train_more,))
x_test_more = torch.arange(0, 5, 0.01)
y_truth_more = f(x_test_more)def plot_kernel_regv_more(y_hat_more):d2l.plot(x_test_more, [y_truth_more, y_hat_more], 'x', 'y', legend=['Truth', 'Pred'],xlim=[0, 5], ylim=[-1, 5])d2l.plt.plot(x_train_more, y_train_more, 'o', alpha=0.5);X_repeat_more = x_test_more.repeat_interleave(n_train_more).reshape((-1, n_train_more))
attention_weights_more = nn.functional.softmax(-(X_repeat_more - x_train_more)**2 / 2, dim=1)
y_hat_more = torch.matmul(attention_weights_more, y_train_more)
plot_kernel_regv_more(y_hat_more)


在这里插入图片描述

d2l.show_heatmaps(attention_weights_more.unsqueeze(0).unsqueeze(0),xlabel='Sorted training inputs',ylabel='Sorted testing inputs')


在这里插入图片描述


(2)在带参数的注意力汇聚的实验中学习得到的参数 w w w 的价值是什么?为什么在可视化注意力权重时,它会使加权区域更加尖锐?

w w w 的价值在于放大注意力,也就是利用 softmax 函数的特性使键 x i x_i xi 和查询 x x x 距离小的得以保存,学习到的 w w w 就是掌握这个放大的尺度。

距离大的被过滤,当然也就显得更尖锐了。


(3)如何将超参数添加到非参数的Nadaraya-Watson核回归中以实现更好地预测结果?

加进去就能行。

n_train_test = 50
x_train_test, _ = torch.sort(torch.rand(n_train_test) * 5)def f(x):return 2 * torch.sin(x) + x**0.8y_train_test = f(x_train_test) + torch.normal(0.0, 0.5, (n_train_test,))
x_test_test = torch.arange(0, 5, 0.1)
y_truth_test = f(x_test_test)def plot_kernel_regv_more(y_hat_test):d2l.plot(x_test_test, [y_truth_test, y_hat_test], 'x', 'y', legend=['Truth', 'Pred'],xlim=[0, 5], ylim=[-1, 5])d2l.plt.plot(x_train_test, y_train_test, 'o', alpha=0.5);X_repeat_test = x_test_test.repeat_interleave(n_train_test).reshape((-1, n_train_test))
attention_weights_test = nn.functional.softmax(-((X_repeat_test - x_train_test)*net.w.detach().numpy())**2 / 2, dim=1)  # 加入训练好的权重
y_hat_test = torch.matmul(attention_weights_test, y_train_test)
plot_kernel_regv_more(y_hat_test)


在这里插入图片描述


(4)为本节的核回归设计一个新的带参数的注意力汇聚模型。训练这个新模型并可视化其注意力权重。

不会,略。

相关文章:

《动手学深度学习 Pytorch版》 10.2 注意力汇聚:Nadaraya-Watson 核回归

import torch from torch import nn from d2l import torch as d2l1964 年提出的 Nadaraya-Watson 核回归模型是一个简单但完整的例子,可以用于演示具有注意力机制的机器学习。 10.2.1 生成数据集 根据下面的非线性函数生成一个人工数据集,其中噪声项 …...

测试C#调用Windows Media Player组件

新建基于.net framework的Winform项目,可以通过添加引用的方式选择COM组件中的Windows Media Player组件,如下图所示:   也可以在VS2022的工具箱空白处点右键,选择“选择项…”菜单。   在弹出的选择工具箱项窗口中&#xf…...

面试经典150题——Day20

文章目录 一、题目二、题解 一、题目 14. Longest Common Prefix Write a function to find the longest common prefix string amongst an array of strings. If there is no common prefix, return an empty string “”. Example 1: Input: strs [“flower”,“flow”…...

[SQL开发笔记]AND OR运算符复杂表达式开发实例

结合 AND & OR实例:通过圆括号使用and或or来组成复杂的表达式 目标数据库及表:使用 DRobot数据库,"T_Drobot" 表 假设我们需要查询"T_Drobot" 表,并从"T_Drobot"表中查询选取creator为 "…...

如何将本地 PDF 文件进行翻译

在日常工作和学习中,我们经常会遇到需要翻译 PDF 文件的情况。比如,我们需要将一份英文的技术文档翻译成中文,或者将一份中文的法律文件翻译成英文。 传统上,我们可以使用专业翻译软件或服务来翻译 PDF 文件。但是,这…...

Node.js的readline模块 命令行交互的模块

Node.js是一个非常流行的JavaScript运行时环境,它提供了许多内置模块来帮助我们开发应用程序。其中之一是readline模块,它提供了一种简单的方法来读取用户输入并进行交互。 本文将详细介绍readline模块的API和使用案例,并附有代码注释。 re…...

前沿重器[36] | ACL23-基于检索的大语言模型-报告阅读

前沿重器 栏目主要给大家分享各种大厂、顶会的论文和分享,从中抽取关键精华的部分和大家分享,和大家一起把握前沿技术。具体介绍:仓颉专项:飞机大炮我都会,利器心法我还有。(算起来,专项启动已经…...

2023秋招笔试算法Python3题解

诸神缄默不语-个人CSDN博文目录 签两方了,感觉秋招已经结束了,所以发布一下之前写的笔试编程题题解。 不全。可能有些题我会继续补。 不保证能过。 后续依然有可能继续刷算法题,但是就另外专门写博文来解析了。 打码是因为原则上其实是不让公…...

uniapp--点击上传图片到oss再保存数据给后端接口

项目采用uniapp与uview2.0组件库 --1.0的也可以参考一下,大差不差 一、项目要求与样式图 点击上传n张图片到oss,然后点击提交给后端 二、思路 1、打开上传按钮,弹出框内出现上传图片和提交按钮 2、点击上传图片区域,打开本地图…...

创建Secret(使用kubectl)

创建Secret(使用kubectl) 假设某个 Pod 需要访问数据库。在您执行 kubectl 命令所在机器的当前目录,创建文件 ./username.txt 文件和 ./password.txt 暂存数据库的用户名和密码,后续我们根据这两个文件配置 kubernetes secrets。…...

Notepad++正则查询替换操作

Notepad编辑器查找功能非常强大,本处记录一些实战中常用到复杂查询替换操作。 注意:如果是重要文件,替换操作前最好备份;当前一个操作后也可以用ctrlz恢复。 查找重复行 用查找(ctrlf)功能,用正则表达式模式匹配。 查…...

Hive特殊函数的使用

Hive特殊函数的使用 with ascastget_json_objectunix_timestampfrom_unixtime with as 在Hive中,WITH AS是一种子查询的用法,用于在查询的开头定义一个临时表达式。它的语法结构如下: WITH [表达式名称] AS (子查询表达式 )在这个结构中,[表…...

Unity Spine 指定导入新Spine动画的默认材质

指定导入新Spine动画的默认材质 找到Spine的Editor导入配置如何修改方法一: 你可以通过脚本 去修改Assets/Editor/SpineSettings.asset文件方法二:通过面板手动设置 找到Spine的Editor导入配置 通常在 Assets/Editor/SpineSettings.asset 配置文件对应着 Edit/Prefe…...

lvs负载均衡集群

目录 一、集群: 1、集群的目的: 2、集群的类型: 3、集群的可靠性指标: 4、设计集群时需要考虑的原则: 二、lvs集群: 1、lvs集群中的术语: 2、lvs访问的大致流程: 三、lvs的…...

MySQL---表的增查改删(CRUD基础)

文章目录 什么是CRUD?新增(Create)单行数据 全列插入多行数据 指定列插入 查询(Retrieve)全列查询指定列查询查询字段为表达式起别名查询去重查询排序查询条件查询分页查询 修改(Update)删除&…...

听GPT 讲Rust源代码--library/std(2)

File: rust/library/std/src/sys_common/wtf8.rs 在Rust源代码中,rust/library/std/src/sys_common/wtf8.rs这个文件的作用是实现了UTF-8编码和宽字符编码之间的转换,以及提供了一些处理和操作UTF-8编码的工具函数。 下面对这几个结构体进行一一介绍&…...

力扣第1005题 K 次取反后最大化的数组和 c++ 贪心 双思维

题目 1005. K 次取反后最大化的数组和 简单 相关标签 贪心 数组 排序 给你一个整数数组 nums 和一个整数 k ,按以下方法修改该数组: 选择某个下标 i 并将 nums[i] 替换为 -nums[i] 。 重复这个过程恰好 k 次。可以多次选择同一个下标 i 。 以…...

Swoole 4.8版本的安装

1、从github拉取安装包 Release v4.8.13 swoole/swoole-src GitHub 2、解压压缩包 tar -zxvf ./v4.8.13.tar.gzcd ./swoole-src-4.8.13 3、执行安装命令 phpize && \ ./configure && \ make && sudo make install 4、检查swoole模块是否安装完成…...

ChatGPT和Copilot协助Vue火速搭建博客网站

AI 对于开发人员的核心价值 网上会看到很多 AI 的应用介绍或者教程 使用 AI 聊天,咨询问题 —— 代替搜索引擎使用 AI 写各种的电商文案(淘宝、小红书)使用 AI 做一个聊天机器人 —— 这最多算猎奇、业余爱好、或者搞个套壳产品来收费 以上…...

javaEE -8(9000字详解网络编程)

一:网络编程基础 1.1 网络资源 所谓的网络资源,其实就是在网络中可以获取的各种数据资源,而所有的网络资源,都是通过网络编程来进行数据传输的。 用户在浏览器中,打开在线视频网站,如优酷看视频&#xff…...

理解 MCP 工作流:使用 Ollama 和 LangChain 构建本地 MCP 客户端

🌟 什么是 MCP? 模型控制协议 (MCP) 是一种创新的协议,旨在无缝连接 AI 模型与应用程序。 MCP 是一个开源协议,它标准化了我们的 LLM 应用程序连接所需工具和数据源并与之协作的方式。 可以把它想象成你的 AI 模型 和想要使用它…...

Linux相关概念和易错知识点(42)(TCP的连接管理、可靠性、面临复杂网络的处理)

目录 1.TCP的连接管理机制(1)三次握手①握手过程②对握手过程的理解 (2)四次挥手(3)握手和挥手的触发(4)状态切换①挥手过程中状态的切换②握手过程中状态的切换 2.TCP的可靠性&…...

苍穹外卖--缓存菜品

1.问题说明 用户端小程序展示的菜品数据都是通过查询数据库获得,如果用户端访问量比较大,数据库访问压力随之增大 2.实现思路 通过Redis来缓存菜品数据,减少数据库查询操作。 缓存逻辑分析: ①每个分类下的菜品保持一份缓存数据…...

linux 下常用变更-8

1、删除普通用户 查询用户初始UID和GIDls -l /home/ ###家目录中查看UID cat /etc/group ###此文件查看GID删除用户1.编辑文件 /etc/passwd 找到对应的行,YW343:x:0:0::/home/YW343:/bin/bash 2.将标红的位置修改为用户对应初始UID和GID: YW3…...

Java入门学习详细版(一)

大家好,Java 学习是一个系统学习的过程,核心原则就是“理论 实践 坚持”,并且需循序渐进,不可过于着急,本篇文章推出的这份详细入门学习资料将带大家从零基础开始,逐步掌握 Java 的核心概念和编程技能。 …...

CRMEB 框架中 PHP 上传扩展开发:涵盖本地上传及阿里云 OSS、腾讯云 COS、七牛云

目前已有本地上传、阿里云OSS上传、腾讯云COS上传、七牛云上传扩展 扩展入口文件 文件目录 crmeb\services\upload\Upload.php namespace crmeb\services\upload;use crmeb\basic\BaseManager; use think\facade\Config;/*** Class Upload* package crmeb\services\upload* …...

中医有效性探讨

文章目录 西医是如何发展到以生物化学为药理基础的现代医学?传统医学奠基期(远古 - 17 世纪)近代医学转型期(17 世纪 - 19 世纪末)​现代医学成熟期(20世纪至今) 中医的源远流长和一脉相承远古至…...

莫兰迪高级灰总结计划简约商务通用PPT模版

莫兰迪高级灰总结计划简约商务通用PPT模版,莫兰迪调色板清新简约工作汇报PPT模版,莫兰迪时尚风极简设计PPT模版,大学生毕业论文答辩PPT模版,莫兰迪配色总结计划简约商务通用PPT模版,莫兰迪商务汇报PPT模版,…...

08. C#入门系列【类的基本概念】:开启编程世界的奇妙冒险

C#入门系列【类的基本概念】:开启编程世界的奇妙冒险 嘿,各位编程小白探险家!欢迎来到 C# 的奇幻大陆!今天咱们要深入探索这片大陆上至关重要的 “建筑”—— 类!别害怕,跟着我,保准让你轻松搞…...

代码规范和架构【立芯理论一】(2025.06.08)

1、代码规范的目标 代码简洁精炼、美观,可持续性好高效率高复用,可移植性好高内聚,低耦合没有冗余规范性,代码有规可循,可以看出自己当时的思考过程特殊排版,特殊语法,特殊指令,必须…...