当前位置: 首页 > article >正文

rust-candle学习笔记11-实现一个简单的自注意力

参考:about-pytorch

定义ScaledDotProductAttention结构体:

use candle_core::{Result, Device, Tensor};
use candle_nn::{Linear, Module, linear_no_bias, VarMap, VarBuilder, ops};struct ScaledDotProductAttention {wq: Linear,wk: Linear,wv: Linear,d_model: Tensor,device: Device,
}

为ScaledDotProductAttention结构体实现new方法:

impl ScaledDotProductAttention {fn new(vb: VarBuilder, embedding_dim: usize, out_dim: usize, device: Device) -> Result<Self> {Ok(Self { wq: linear_no_bias(embedding_dim, out_dim, vb.pp("wq"))?, wk: linear_no_bias(embedding_dim, out_dim, vb.pp("wk"))?, wv: linear_no_bias(embedding_dim, out_dim, vb.pp("wv"))?,d_model: Tensor::new(embedding_dim as f32, &device)?,device,})}
}

为结构体实现Module的forward trait:

impl Module for ScaledDotProductAttention {fn forward(&self, xs: &Tensor) -> Result<Tensor> {let q = self.wq.forward(xs)?;let k = self.wk.forward(xs)?;let v = self.wv.forward(xs)?;let attn_score = q.matmul(&k.t()?)?;let attn_score = attn_score.broadcast_div(&self.d_model.sqrt()?)?;let dim = attn_score.rank() - 1;let attn_weights = ops::softmax(&attn_score, dim)?;let attn_output = attn_weights.matmul(&v)?;Ok(attn_output)}
}

融合qkv实现:

定义ScaledDotProductAttentionFusedQKV结构体:

struct ScaledDotProductAttentionFusedQKV {w_qkv: Linear,d_model: Tensor,device: Device,
}

为结构体实现new方法:

impl ScaledDotProductAttentionFusedQKV {fn new(vb: VarBuilder, embedding_dim: usize, out_dim: usize, device: Device) -> Result<Self> {Ok(Self { w_qkv: linear_no_bias(embedding_dim, 3*out_dim, vb.pp("w_qkv"))?,d_model: Tensor::new(embedding_dim as f32, &device)?,device,})}
}

为结构体实现forward trait:

impl Module for ScaledDotProductAttentionFusedQKV {fn forward(&self, xs: &Tensor) -> Result<Tensor> {let qkv = self.w_qkv.forward(xs)?;let (batch_size, seq_len, _) = qkv.dims3()?;let qkv = qkv.reshape((batch_size, seq_len, 3, ()))?;let q = qkv.get_on_dim(2, 0)?;let q = q.reshape((batch_size, seq_len, ()))?;let k = qkv.get_on_dim(2, 1)?;let k = k.reshape((batch_size, seq_len, ()))?;let v = qkv.get_on_dim(2, 2)?;let v = v.reshape((batch_size, seq_len, ()))?;let attn_score = q.matmul(&k.t()?)?;let attn_score = attn_score.broadcast_div(&self.d_model.sqrt()?)?;let dim = attn_score.rank() - 1;let attn_weights = ops::softmax(&attn_score, dim)?;let attn_output = attn_weights.matmul(&v)?;Ok(attn_output)}
}

测试:

fn main() -> Result<()> {let device = Device::cuda_if_available(0)?;let varmap = VarMap::new();let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);let input = Tensor::from_vec(vec![0.43f32, 0.15, 0.89, 0.55, 0.87, 0.66,0.57, 0.85, 0.64,0.22, 0.58, 0.33,0.77, 0.25, 0.10,0.05, 0.80, 0.55, 0.43, 0.15, 0.89, 0.55, 0.87, 0.66,0.57, 0.85, 0.64,0.22, 0.58, 0.33,0.77, 0.25, 0.10,0.05, 0.80, 0.55], (2, 6, 3), &device)?;// let model = ScaledDotProductAttention::new(vb.clone(), 3, 2, device.clone())?;let model = ScaledDotProductAttentionFusedQKV::new(vb.clone(), 3, 2, device.clone())?;let output = model.forward(&input)?;println!("output: {:?}\n", output);println!("output: {:?}\n", output.to_vec3::<f32>()?);Ok(())
}

相关文章:

rust-candle学习笔记11-实现一个简单的自注意力

参考&#xff1a;about-pytorch 定义ScaledDotProductAttention结构体&#xff1a; use candle_core::{Result, Device, Tensor}; use candle_nn::{Linear, Module, linear_no_bias, VarMap, VarBuilder, ops};struct ScaledDotProductAttention {wq: Linear,wk: Linear,wv: …...

读入csv文件写入MySQL

### 使用 Spark RDD 读取 CSV 文件并写入 MySQL 的实现方法 #### 1. 环境准备 在使用 Spark 读取 CSV 文件并写入 MySQL 数据库之前&#xff0c;需要确保以下环境已配置完成&#xff1a; - 添加 Maven 依赖项以支持 JDBC 连接。 - 配置 MySQL 数据库连接参数&#xff0c;包括 …...

Andorid之TabLayout+ViewPager

文章目录 前言一、效果图二、使用步骤1.主xml布局2.activity代码3.MyTaskFragment代码4.MyTaskFragment的xml布局5.Adapter代码6.item布局 总结 前言 TabLayoutViewPager功能需求已经是常见功能了&#xff0c;我就不多解释了&#xff0c;需要的自取。 一、效果图 二、使用步骤…...

C++GO语言微服务之用户信息处理②

目录 01 03-获取用户信息-上 02 04-获取用户信息-下 03 05-更新用户名实现 01 06-中间件简介和中间件类型 02 07-中间件测试和模型分析 03 08-中间件测试案例和小结 04 09-项目使用中间件 01 03-获取用户信息-上 ## Cookie操作 ### 设置Cookie go func (c *Context) …...

26考研——中央处理器_指令流水线_流水线的冒险与处理 流水线的性能指标 高级流水线技术(5)

408答疑 文章目录 六、指令流水线流水线的冒险与处理结构冒险数据冒险延迟执行相关指令采用转发&#xff08;旁路&#xff09;技术load-use 数据冒险的处理 控制冒险 流水线的性能指标流水线的吞吐率流水线的加速比 高级流水线技术超标量流水线技术超长指令字技术超流水线技术 …...

Java 与 Go 语言对比

Java 和 Go (Golang) 是两种流行的编程语言&#xff0c;各有其设计哲学和应用场景。以下是它们的详细对比&#xff1a; 1. 基本特性 特性JavaGo诞生时间1995 (Sun Microsystems)2009 (Google)设计目标“Write Once, Run Anywhere”简洁、高效的系统编程语言语言类型面向对象多…...

OpenUCX 库介绍与使用指南

OpenUCX 库介绍与使用指南 OpenUCX 简介 OpenUCX (Unified Communication X) 是一个高性能、开源通信框架&#xff0c;专为大规模分布式计算和加速计算设计。它提供了统一的API&#xff0c;支持多种网络硬件和协议&#xff0c;包括InfiniBand、RoCE、TCP等。 主要特点 高性…...

酒店旅游类数据采集API接口之携程数据获取地方美食品列表 获取地方美餐馆列表 景点评论

携程 API 接入指南 API 地址&#xff1a; 调用示例&#xff1a; 美食列表 景点列表 景点详情 酒店详情 参数说明 通用参数说明 请谨慎传递参数&#xff0c;避免不必要的费用扣除。 URL 说明&#xff1a;https://api-gw.cn/平台/API类型/ 平台&#xff1a;淘宝&#xff0c;京…...

Lora原理及实现浅析

Lora 什么是Lora Lora的原始论文为《LoRA: Low-Rank Adaptation of Large Language Models》&#xff0c;翻译为中文为“大语言模型的低秩自适应”。最初是为了解决大型语言模在进行任务特定微调时消耗大量资源的问题&#xff1b;随后也用在了Diffusion等领域&#xff0c;用于…...

GitHub 趋势日报 (2025年05月13日)

本日报由 TrendForge 系统生成 https://trendforge.devlive.org/ &#x1f310; 本日报中的项目描述已自动翻译为中文 &#x1f4c8; 今日整体趋势 Top 10 排名项目名称项目描述今日获星总星数语言1harry0703/MoneyPrinterTurbo利用ai大模型&#xff0c;一键生成高清短视频使用…...

【设计模式】- 创建者模式

单例模型 饿汉式 静态方法创建对象 public class Singleton {// 私有构造方法private Singleton(){}private static Singleton instance new Singleton();// 提供一个外界获取的方法public static Singleton getInstance(){return instance;} }静态代码块创建对象 public …...

服务器时间发生跳变导致hghac中对应主机状态频繁切换为crash或stop

文章目录 环境症状问题原因解决方案相关文档报错编码 环境 系统平台&#xff1a;N/A 版本&#xff1a;N/A 症状 集群状态&#xff1a; [rootbthbj-hgywsjkjq-ip28-cen76 ~]# hghactl list Cluster: highgo-ee-cluster —---------------------—---------- | Member | Ho…...

南审计院考研分享会 经验总结

汪学长 – 中科大 计科专硕 初试准备 数学先做真题&#xff0c;模拟题刷的越多分越高&#xff1b;408真题最重要&#xff0c;模拟题辅助&#xff1b;英语只做真题&#xff1b;政治9月份开始背 代码能力在低年级培养的重要性和路径 考研不选择机构原因 因为机构里面学习的框…...

牛客练习赛138(首篇万字题解???)

赛时成绩如下&#xff1a; 1. 小s的签到题 小s拿到了一个比赛榜单&#xff0c;他要用最快的速度找到签到题&#xff0c;但是小s脑子还是有点晕&#xff0c;请你帮帮小s&#xff0c;助力他找到签到题。 比赛榜单是一个 2 行 n 列的表格&#xff1a; 第一行是 n 个大写字母&#…...

Rust 中的 `String`、`str` 和 `str`:深入解析与使用指南

在 Rust 编程中&#xff0c;字符串是不可或缺的数据类型&#xff0c;但 Rust 的字符串系统与其他语言有所不同。Rust 提供了 String、str 和 &str 三种主要的字符串类型&#xff0c;每种类型都有其独特的用途和特点。本文将详细介绍这三种字符串类型&#xff0c;帮助你更好…...

深入理解高性能网络通信:从内核源码到云原生实践

深入理解高性能网络通信&#xff1a;从内核源码到云原生实践 &#xff08;示意图&#xff1a;Linux网络协议栈与通信架构分层模型&#xff09; 随着互联网业务规模的不断扩大&#xff0c;系统对网络通信性能的要求也在迅速提升。从内核事件机制的演进到云原生架构下的极致优化&…...

10 web 自动化之 yaml 数据/日志/截图

文章目录 一、yaml 数据获取二、日志获取三、截图 一、yaml 数据获取 需要安装 PyYAML 库 import yaml import os from TestPOM.common import dir_config as Dir import jsonpathclass Data:def __init__(self,keyNone,file_name"test_datas.yaml"):file_path os…...

ubuntu清除缓存

pip pip cache purgeconda conda clean -a -yapt apt cleanapt-get apt-get cleanmodelscope modelscope clear-cachehuggingface rm -rf ~/.cache/huggingface/*...

OSI 7层模型

OSI 7层模型&#xff1a; 1、物理层&#xff08;光纤等把电脑连接起来的物理手段&#xff09; 2、数据链路层&#xff08;以太网&#xff0c;确认0和1电信号的分组方式&#xff0c;负责MAC地址&#xff0c;MAC地址用于在网络中唯一标示一个网卡&#xff0c;相当于网卡的身份证…...

用git下载vcpkg时出现Connection was reset时的处理

用git安装vcpkg时出现Connect was rest&#xff08;如上图&#xff09;。多谢这位网友的博文解决了问题&#xff1a; 通过:http.sslVerify false全局来设置&#xff0c;执行以下命令&#xff1a; git config --global http.sslVerify "false" 原文链接&#xff1a…...

deepseek梳理java高级开发工程师算法面试题

Java高级工程师算法面试题与答案 一、数据结构与算法基础 1. 红黑树与AVL树比较 题目&#xff1a;详细说明红黑树和AVL树的区别及各自的适用场景&#xff0c;并用Java实现红黑树的插入操作。 答案&#xff1a; 区别对比&#xff1a; ┌─────────────────…...

leetcode - 滑动窗口问题集

目录 前言 题1 长度最小的子数组&#xff1a; 思考&#xff1a; 参考代码1&#xff1a; 参考代码2&#xff1a; 题2 无重复字符的最长子串&#xff1a; 思考&#xff1a; 参考代码1&#xff1a; 参考代码2&#xff1a; 题3 最大连续1的个数 III&#xff1a; 思考&am…...

一分钟在Cherry Studio和VSCode集成火山引擎veimagex-mcp

MCP的出现打通了AI模型和外部数据库、网页API等资源&#xff0c;成倍提升工作效率。近期火山引擎团队推出了 MCP Server SDK&#xff1a; veimagex-mcp。本文介绍如何在Cherry Studio 和VSCode平台集成 veimagex-mcp。 什么是MCP MCP&#xff08;Model Context Protocol&…...

Tomcat与纯 Java Socket 实现远程通信的区别

Servlet 容器​​&#xff08;如 Tomcat&#xff09; 是一个管理 Servlet 生命周期的运行环境&#xff0c;主要功能包括&#xff1a; ​​协议解析​​&#xff1a;自动处理 HTTP 请求/响应的底层协议&#xff08;如报文头解析、状态码生成&#xff09;&#xff1b; ​​线程…...

为什么企业建站或独立站选用WordPress

与大多数组织相比&#xff0c;企业业务更需要保持可扩展和可靠的网络存在&#xff0c;以保持竞争力。为此&#xff0c;许多大型企业的 IT 领导者历来寻求昂贵的网络解决方案&#xff0c;这些方案需要签订专有支持合同来保证质量。不过&#xff0c;还有另一种方法。WordPress问世…...

镜头内常见的马达类型(私人笔记)

① 螺杆式马达 驱动来源&#xff1a;机身内马达。镜头尾部有一个接收“螺杆”的接口&#xff0c;通过机械传动带动镜头对焦组。缺点&#xff1a;慢、吵、不能用于无机身马达的相机。✅ 典型镜头&#xff1a;尼康 AF、AF-D 系列&#xff1b;美能达老镜头。尼康传统的AF镜头通过…...

docker-compose——安装mysql8

一、编写Dockerfile FROM mysql:8.0.39 ENV TZAsia/Shanghai RUN ln -sf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone二、编写docker-compose.yml version : 3.8services:zaomeng-mysql:build:context: ./mysqlimage: mysql:8.0.39conta…...

从代码学习深度学习 - 语义分割和数据集 PyTorch版

文章目录 前言什么是语义分割?图像分割和实例分割Pascal VOC2012 语义分割数据集Pascal VOC2012 语义分割数据集介绍基本信息语义分割部分特点数据格式评价指标应用价值数据集获取使用提示辅助工具代码 (`utils_for_huitu.py`)读取数据预处理数据自定义语义分割数据集类读取数…...

4G物联网模块实现废气处理全流程数据可视化监控配置

一、项目背景 随着工业化进程的加速&#xff0c;工业废气的排放对环境造成了严重影响&#xff0c;废气处理厂应运而生。然而&#xff0c;废气处理厂中的设备众多且分散&#xff0c;传统的人工巡检和数据记录方式效率低下&#xff0c;难以及时发现问题。为了实现对废气处理设备…...

深圳SMT贴片加工厂制造流程解析

内容概要 作为大湾区电子制造产业链的重要节点&#xff0c;深圳SMT贴片加工厂凭借精密的生产体系与技术创新&#xff0c;构建了涵盖12道核心工序的标准化流程。从PCB基板的来料检验开始&#xff0c;通过全自动贴片机的高精度元件定位、SPI三维锡膏检测、智能温控回流焊接等关键…...