当前位置: 首页 > news >正文

Pyraformer复现心得

Pyraformer复现心得

引用

Liu, Shizhan, et al. “Pyraformer: Low-complexity pyramidal attention for long-range time series modeling and forecasting.” International conference on learning representations. 2021.

代码部分

def long_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):enc_out = self.encoder(x_enc, x_mark_enc)[:, -1, :]#B,dmodel*3dec_out = self.projection(enc_out).view(enc_out.size(0), self.pred_len, -1)#B,pre,Nreturn dec_out

预测部分就这么长

x_dec, x_mark_dec, mask=None都没用到

enc_out = self.encoder(x_enc, x_mark_enc)[:, -1, :]
#B,dmodel*3
  • 直接进入encoder
def forward(self, x_enc, x_mark_enc):seq_enc = self.enc_embedding(x_enc, x_mark_enc)
  • 重构了encoder和decoder,跟transformer的很不一样
x = self.value_embedding(x) + self.temporal_embedding(x_mark) + self.position_embedding(x)
return self.dropout(x)
  • embedding方法跟former一样
mask = self.mask.repeat(len(seq_enc), 1, 1).to(x_enc.device)

用pyra的方式获取pam掩码

def get_mask(input_size, window_size, inner_size):"""Get the attention mask of PAM-Naive"""# Get the size of all layersall_size = []all_size.append(input_size)for i in range(len(window_size)):layer_size = math.floor(all_size[i] / window_size[i])all_size.append(layer_size)seq_length = sum(all_size)mask = torch.zeros(seq_length, seq_length)# get intra-scale maskinner_window = inner_size // 2for layer_idx in range(len(all_size)):start = sum(all_size[:layer_idx])for i in range(start, start + all_size[layer_idx]):left_side = max(i - inner_window, start)right_side = min(i + inner_window + 1, start + all_size[layer_idx])mask[i, left_side:right_side] = 1# get inter-scale maskfor layer_idx in range(1, len(all_size)):start = sum(all_size[:layer_idx])for i in range(start, start + all_size[layer_idx]):left_side = (start - all_size[layer_idx - 1]) + \(i - start) * window_size[layer_idx - 1]if i == (start + all_size[layer_idx] - 1):right_side = startelse:right_side = (start - all_size[layer_idx - 1]) + (i - start + 1) * window_size[layer_idx - 1]mask[i, left_side:right_side] = 1mask[left_side:right_side, i] = 1mask = (1 - mask).bool()return mask, all_size

接着进入卷积层

seq_enc = self.conv_layers(seq_enc)

先构建CSCM卷积

class Bottleneck_Construct(nn.Module):"""Bottleneck convolution CSCM"""
temp_input = self.down(enc_input).permute(0, 2, 1)
all_inputs = []
self.down = Linear(d_model, d_inner)

下采样

for i in range(len(self.conv_layers)):temp_input = self.conv_layers[i](temp_input)all_inputs.append(temp_input)

堆叠很多次卷积,这个跟former是一样的

class ConvLayer(nn.Module):def __init__(self, c_in, window_size):super(ConvLayer, self).__init__()self.downConv = nn.Conv1d(in_channels=c_in,out_channels=c_in,kernel_size=window_size,stride=window_size)self.norm = nn.BatchNorm1d(c_in)self.activation = nn.ELU()def forward(self, x):x = self.downConv(x)x = self.norm(x)x = self.activation(x)return x

将N次卷积的结果拼接起来

all_inputs = torch.cat(all_inputs, dim=2).transpose(1, 2)#
all_inputs = self.up(all_inputs)
all_inputs = torch.cat([enc_input, all_inputs], dim=1)
self.up = Linear(d_inner, d_model)
all_inputs = self.norm(all_inputs)
return all_inputs
self.norm = nn.LayerNorm(d_model)

之后在跟原始输入拼接起来

  • 卷积layer完了之后是encoderlayer
def forward(self, enc_input, slf_attn_mask=None):attn_mask = RegularMask(slf_attn_mask)
enc_output, _ = self.slf_attn(enc_input, enc_input, enc_input, attn_mask=attn_mask)

进到encoder里面,到了熟悉的former框架

def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):#后面俩参数应该是作者指定的B, L, _ = queries.shape#B,seq,dmodel_, S, _ = keys.shapeH = self.n_heads
#其实L和S是一个数queries = self.query_projection(queries).view(B, L, H, -1)#B, L, H, dmodel/hkeys = self.key_projection(keys).view(B, S, H, -1)#一样的计算方法values = self.value_projection(values).view(B, S, H, -1)#H 表示头的数量-1 表示自动计算该维度
  • encoder的注意力用的fullattention。并且用到了掩码

回到pyra的encoder

self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout, normalize_before=normalize_before)
def forward(self, x):residual = xif self.normalize_before:x = self.layer_norm(x)x = F.gelu(self.w_1(x))x = self.dropout(x)x = self.w_2(x)x = self.dropout(x)x = x + residualif not self.normalize_before:x = self.layer_norm(x)return x
indexes = self.indexes.repeat(seq_enc.size(0), 1, 1, seq_enc.size(2)).to(seq_enc.device)
#B,seq,3,dmodel
indexes = indexes.view(seq_enc.size(0), -1, seq_enc.size(2))
#B,seq+pred,dmodel
all_enc = torch.gather(seq_enc, 1, indexes)
##B,seq+pred,dmodel
seq_enc = all_enc.view(seq_enc.size(0), self.all_size[0], -1)
#B,seq,dmodel*3
return seq_enc

总结

x_dec, x_mark_dec, mask=None都没用到

  • 直接进入encoder

重构了encoder和decoder,跟transformer的很不一样

embedding方法跟former一样

encoder的注意力用的fullattention,并且用到了掩码

相关文章:

Pyraformer复现心得

Pyraformer复现心得 引用 Liu, Shizhan, et al. “Pyraformer: Low-complexity pyramidal attention for long-range time series modeling and forecasting.” International conference on learning representations. 2021. 代码部分 def long_forecast(self, x_enc, x_m…...

成绩排序c++

说明 给出了班里某门课程的成绩单&#xff0c;请你按成绩从高到低对成绩单排序输出&#xff0c;如果有相同分数则名字字典序小的在前。 输入格式 第一行为nn(0<n<200<n<20)&#xff0c;表示班里的学生数目; 接下来的nn行&#xff0c;每行为每个学生的名字和他的…...

人脸检测之MTCNN算法网络结构

MTCNN&#xff08;Multi-task Cascaded Convolutional Networks&#xff09;是一种用于人脸检测和关键点检测的深度学习模型&#xff0c;特别适合在复杂背景下识别出多尺度的人脸。它通过多任务学习来实现人脸检测和人脸关键点定位&#xff08;如眼睛、鼻子、嘴巴的位置&#x…...

蓝桥杯顺子日期(填空题)

题目&#xff1a;小明特别喜欢顺子。顺子指的就是连续的三个数字&#xff1a;123、456 等。顺子日期指的就是在日期的 yyyymmdd 表示法中&#xff0c;存在任意连续的三位数是一个顺子的日期。例如 20220123 就是一个顺子日期&#xff0c;因为它出现了一个顺子&#xff1a;123&a…...

Java云HIS医院管理系统源码 病案管理、医保业务、门诊、住院、电子病历编辑

云HIS系统优势 &#xff08;1&#xff09;客户/用户角度 无需安装&#xff0c;登录即用 多终端同步&#xff0c;轻松应对工作环境转换 系统使用简单、易上手&#xff0c;信息展示主次分明、重点突出 极致降低用户操作负担&#xff1a;关联功能集中、减少跳转&#xff0c;键盘快…...

【C++的vector、list、stack、queue用法简单介绍】

【知识预告】 vector的介绍及使用list的介绍及使用list与vector的对比stack的介绍和使用queue的介绍和使用priority_queue的介绍和使用 1 vector的介绍及使用 1.1 vector的介绍 vector是表示可变大小数组的序列容器和数组类似&#xff0c;vector也采用连续存储空间来存储元…...

git中使用tag(标签)的方法及重要性

在Git中打标签&#xff08;tag&#xff09;通常用于标记发布版本或其他重要提交。 Git中打标签的步骤&#xff1a; 列出当前所有的标签 git tag创建一个指向特定提交的标签 git tag <tagname> <commit-hash>创建一个带注释的标签&#xff0c;通常用于发布版本 git…...

【专题】2024年文旅微短剧专题研究报告汇总PDF洞察(附原数据表)

原文链接&#xff1a; https://tecdat.cn/?p38187 当今时代&#xff0c;各类文化与消费领域呈现出蓬勃发展且不断变革的态势。 微短剧作为新兴内容形式&#xff0c;凭借网络发展与用户需求&#xff0c;从低成本都市题材为主逐步走向多元化&#xff0c;其内容供给类型正历经深…...

celery加速爬虫 使用flower 可视化地查看celery的实时监控情况

重点: celery ==5.4.0 python 3.11 flower ==2.0.1 请对齐celery与flower的版本信息,如果过低会导致报错 报错1: (venv) PS D:\apploadpath\pythonPath\Lib\site-packages> celery -A tasks flower Traceback (most recent call last):File …...

Angular进阶之十:toPromise废弃原因及解决方案

背景 Rxjs从V7开始废弃了toPromise, V8中会删除它。 原因 1&#xff1a;toPromise()只返回一个值 toPromise()将 Observable 序列转换为符合 ES2015 标准的 Promise 。它使用 Observable 序列的最后一个值。 例&#xff1a; import { Observable } from "rxjs"; ………...

python实现RSA算法

目录 一、算法简介二、算法描述2.1 密钥产生2.2 加密过程2.3 解密过程2.4 证明解密正确性 三、相关算法3.1 欧几里得算法3.2 扩展欧几里得算法3.3 模重复平方算法3.4 Miller-Rabin 素性检测算法 四、算法实现五、演示效果 一、算法简介 RSA算法是一种非对称加密算法&#xff0c…...

可灵开源视频生成数据集 学习笔记

目录 介绍 可灵团队提出了四个模块的改进&#xff1a; video caption 新指标 vtss 动态质量 静态质量 视频自然性 介绍 在视频数据处理中&#xff0c;建立准确且细致的条件是关键&#xff0c;可灵团队认为&#xff0c;解决这一问题需要关注三个主要方面&#xff1a; 文本…...

告别软文营销瓶颈!5招助你突破限制,实现宣传效果最大化

在当今信息爆炸的时代&#xff0c;软文营销作为品牌推广的重要手段之一&#xff0c;面临着日益激烈的竞争和受众日益提高的辨别力。传统的软文营销方式往往难以穿透消费者的心理防线&#xff0c;实现有效的信息传递和品牌塑造。为了突破这一瓶颈&#xff0c;实现宣传效果的最大…...

秋冬进补防肥胖:辨证施补,健康过冬不增脂

中医理论中的秋冬“封藏” 在中医理论中&#xff0c;认为秋冬季节是人体“封藏”的时期&#xff0c;而“封藏”指的是秋冬季节人体应当减少消耗&#xff0c;蓄积能源&#xff0c;此时进补可以使营养物质易于吸收并蓄积于体内&#xff0c;从而增强体质和抵抗力&#xff0c;为来…...

uniapp radio单选

<uni-data-checkbox v-model"selectedValue" :localdata"quTypeList" change"radioChange"/> //产品类型列表 const quTypeList [{ text: 漆面膜, value: 100, }, { text: 改色…...

通熟易懂地讲解GCC和Makefile

1. 嵌入式开发工具链&#xff1a;GCC GCC&#xff08;GNU Compiler Collection&#xff09;是一个强大且常用的编译器套件&#xff0c;支持多种编程语言&#xff0c;比如 C 和 C。在嵌入式开发中&#xff0c;GCC 可以帮助我们把人类可读的 C/C 代码编译成机器可以理解的二进制…...

Java Agent使用

文章目录 基本使用premain使用场景 agentmain 关于tools.jar https://docs.oracle.com/en/java/javase/20/docs/specs/jvmti.html com.sun的API&#xff0c;如果使用其他厂商的JVM&#xff0c;可能没有这个API了&#xff0c;比如Eclipse的J9 https://www.ibm.com/docs/en/sdk…...

selenium 点击元素报错element not interactable

描述说明&#xff1a; 我这里是获取一个span标签后并点击&#xff0c;用的元素自带的element.click()&#xff0c;报错示例代码如下&#xff1a; driver.find_element(By.XPATH,//span[id"my_span"]).click() # 或者 elementdriver.find_element(By.XPATH,//span[i…...

【大数据技术基础 | 实验七】HBase实验:部署HBase

文章目录 一、实验目的二、实验要求三、实验原理四、实验环境五、实验内容和步骤&#xff08;一&#xff09;验证Hadoop和ZooKeeper已启动&#xff08;二&#xff09;修改HBase配置文件&#xff08;三&#xff09;启动并验证HBase 六、实验结果七、实验心得 一、实验目的 掌握…...

Android进程保活,lmkd杀进程相关

lmk原理 Android进程回收之LowMemoryKiller原理 lmkd 更新进程oomAdj; 设备端进程被杀可能原因...

新基建淘汰战:UWB高功耗基站 vs 镜像视界边缘AI无感定位

新基建淘汰战&#xff1a;UWB高功耗基站 vs 镜像视界边缘AI无感定位新基建浪潮下&#xff0c;低能耗、强兼容、可扩展成为空间感知技术的核心准入门槛。UWB厘米级定位深陷高功耗基站强硬件绑定的沉重模式&#xff0c;而镜像视界浙江科技有限公司以边缘AI无感定位为核心&#xf…...

Orbit:革命性记忆增强平台的完整指南

Orbit&#xff1a;革命性记忆增强平台的完整指南 【免费下载链接】orbit Experimental spaced repetition platform for exploring ideas in memory augmentation and programmable attention 项目地址: https://gitcode.com/gh_mirrors/orbit1/orbit Orbit是一个革命性…...

视频高清直播点播/音视频点播/云点播/云直播EasyDSS交互升级解锁大型活动直播新体验

在数字化时代&#xff0c;大型活动直播已从“可选”变为“必需”&#xff0c;无论是政企发布会、行业峰会&#xff0c;还是跨区域学术论坛&#xff0c;都需要一套兼顾稳定、安全与高效的直播解决方案。EasyDSS私有化视频会议系统凭借高并发、低延迟的核心优势站稳市场&#xff…...

被遗忘的女程序员沙拉:用模拟程序为互联网奠基,却因家庭放弃编程

为互联网奠基的女程序员沙拉 数学教师沙拉博姆利用暑假编写代码&#xff0c;她之后开发的东西最终演变成了互联网。作者包括凯蒂哈夫纳、萨米亚布齐德、劳拉伊森西以及科学领域被遗忘的女性倡议组织。 沙拉的编程之路 沙拉博姆从加州大学洛杉矶分校获得教学学位后&#xff0c;投…...

如何为Hermes Agent配置Taotoken作为自定义模型供应商并写入环境变量

&#x1f680; 告别海外账号与网络限制&#xff01;稳定直连全球优质大模型&#xff0c;限时半价接入中。 &#x1f449; 点击领取海量免费额度 如何为Hermes Agent配置Taotoken作为自定义模型供应商并写入环境变量 基础教程类&#xff0c;详细说明在Hermes Agent中配置Taotok…...

深入RKMedia:拆解Rockchip RV1126多媒体框架,看它如何封装RGA/MPP/RKNN

深入解析RKMedia&#xff1a;Rockchip RV1126多媒体框架的设计哲学与实现细节 在嵌入式多媒体处理领域&#xff0c;Rockchip的RV1126平台凭借其出色的能效比和丰富的硬件加速单元&#xff0c;成为智能视觉终端设备的首选方案之一。而RKMedia作为连接应用层与底层硬件的关键中间…...

Wireshark实战:从流量包里‘捞出’图片和压缩包的两种方法(附CTF解题步骤)

Wireshark实战&#xff1a;从流量包里‘捞出’图片和压缩包的两种方法&#xff08;附CTF解题步骤&#xff09; 在网络安全和数字取证领域&#xff0c;网络流量分析是一项基础但至关重要的技能。想象一下这样的场景&#xff1a;你正在调查一起数据泄露事件&#xff0c;或者参加…...

解锁SD-PPP:将AI绘画能力无缝融入Photoshop工作流

解锁SD-PPP&#xff1a;将AI绘画能力无缝融入Photoshop工作流 【免费下载链接】sd-ppp A Photoshop AI plugin 项目地址: https://gitcode.com/gh_mirrors/sd/sd-ppp 你是否曾经在Photoshop中创作时&#xff0c;突然需要一个AI生成的元素来完善设计&#xff0c;却不得不…...

SpinalHDL Pipeline库核心要素解析:从Stageable到流水线构建实战

1. Pipeline核心要素深度解析&#xff1a;从概念到实战在数字电路设计&#xff0c;尤其是处理器流水线这类复杂逻辑的构建中&#xff0c;我们常常需要一种更抽象、更灵活的方式来组织数据流和控制流。传统的RTL描述方式在面对多级流水、动态数据传递和复杂交互时&#xff0c;代…...

Beyond Compare 5密钥生成终极指南:5分钟免费激活完整教程

Beyond Compare 5密钥生成终极指南&#xff1a;5分钟免费激活完整教程 【免费下载链接】BCompare_Keygen Keygen for BCompare 5 项目地址: https://gitcode.com/gh_mirrors/bc/BCompare_Keygen Beyond Compare 5作为专业的文件对比工具&#xff0c;在30天试用期结束后会…...