当前位置: 首页 > news >正文

lag-llama源码解读(Lag-Llama: Towards Foundation Models for Time Series Forecasting)

Lag-Llama: Towards Foundation Models for Time Series Forecasting
文章内容:
时间序列预测任务,单变量预测单变量,基于Llama大模型,在zero-shot场景下模型表现优异。创新点,引入滞后特征作为协变量来进行预测。

获得不同频率的lag,来自glunoTS库里面的源码

def _make_lags(middle: int, delta: int) -> np.ndarray:"""Create a set of lags around a middle point including +/- delta."""return np.arange(middle - delta, middle + delta + 1).tolist()def get_lags_for_frequency(freq_str: str,lag_ub: int = 1200,num_lags: Optional[int] = None,num_default_lags: int = 7,
) -> List[int]:"""Generates a list of lags that that are appropriate for the given frequencystring.By default all frequencies have the following lags: [1, 2, 3, 4, 5, 6, 7].Remaining lags correspond to the same `season` (+/- `delta`) in previous`k` cycles. Here `delta` and `k` are chosen according to the existing code.Parameters----------freq_strFrequency string of the form [multiple][granularity] such as "12H","5min", "1D" etc.lag_ubThe maximum value for a lag.num_lagsMaximum number of lags; by default all generated lags are returned.num_default_lagsThe number of default lags; by default it is 7."""# Lags are target values at the same `season` (+/- delta) but in the# previous cycle.def _make_lags_for_second(multiple, num_cycles=3):# We use previous ``num_cycles`` hours to generate lagsreturn [_make_lags(k * 60 // multiple, 2) for k in range(1, num_cycles + 1)]def _make_lags_for_minute(multiple, num_cycles=3):# We use previous ``num_cycles`` hours to generate lagsreturn [_make_lags(k * 60 // multiple, 2) for k in range(1, num_cycles + 1)]def _make_lags_for_hour(multiple, num_cycles=7):# We use previous ``num_cycles`` days to generate lagsreturn [_make_lags(k * 24 // multiple, 1) for k in range(1, num_cycles + 1)]def _make_lags_for_day(multiple, num_cycles=4, days_in_week=7, days_in_month=30):# We use previous ``num_cycles`` weeks to generate lags# We use the last month (in addition to 4 weeks) to generate lag.return [_make_lags(k * days_in_week // multiple, 1)for k in range(1, num_cycles + 1)] + [_make_lags(days_in_month // multiple, 1)]def _make_lags_for_week(multiple, num_cycles=3):# We use previous ``num_cycles`` years to generate lags# Additionally, we use previous 4, 8, 12 weeksreturn [_make_lags(k * 52 // multiple, 1) for k in range(1, num_cycles + 1)] + [[4 // multiple, 8 // multiple, 12 // multiple]]def _make_lags_for_month(multiple, num_cycles=3):# We use previous ``num_cycles`` years to generate lagsreturn [_make_lags(k * 12 // multiple, 1) for k in range(1, num_cycles + 1)]# multiple, granularity = get_granularity(freq_str)offset = to_offset(freq_str)# normalize offset name, so that both `W` and `W-SUN` refer to `W`offset_name = norm_freq_str(offset.name)if offset_name == "A":lags = []elif offset_name == "Q":assert (offset.n == 1), "Only multiple 1 is supported for quarterly. Use x month instead."lags = _make_lags_for_month(offset.n * 3.0)elif offset_name == "M":lags = _make_lags_for_month(offset.n)elif offset_name == "W":lags = _make_lags_for_week(offset.n)elif offset_name == "D":lags = _make_lags_for_day(offset.n) + _make_lags_for_week(offset.n / 7.0)elif offset_name == "B":lags = _make_lags_for_day(offset.n, days_in_week=5, days_in_month=22) + _make_lags_for_week(offset.n / 5.0)elif offset_name == "H":lags = (_make_lags_for_hour(offset.n)+ _make_lags_for_day(offset.n / 24)+ _make_lags_for_week(offset.n / (24 * 7)))# minuteselif offset_name == "T":lags = (_make_lags_for_minute(offset.n)+ _make_lags_for_hour(offset.n / 60)+ _make_lags_for_day(offset.n / (60 * 24))+ _make_lags_for_week(offset.n / (60 * 24 * 7)))# secondelif offset_name == "S":lags = (_make_lags_for_second(offset.n)+ _make_lags_for_minute(offset.n / 60)+ _make_lags_for_hour(offset.n / (60 * 60)))else:raise Exception("invalid frequency")# flatten lags list and filterlags = [int(lag) for sub_list in lags for lag in sub_list if 7 < lag <= lag_ub]lags = list(range(1, num_default_lags + 1)) + sorted(list(set(lags)))return lags[:num_lags]

第一部分,生成以middle为中心,以delta为半径的区间[middle-delta,middle+delta] ,这很好理解,比如一周的周期是7天,周期大小在7天附近波动很正常。
在这里插入图片描述

第二部分,对于年月日时分秒这些不同的采样频率,采用不同的具体的函数来确定lags,其中有一个参数num_cycle,进一步利用了周期性,我们考虑间隔1、2、3、…num个周期的时间点之间的联系
在这里插入图片描述
原理类似于这张图,这种周期性的重复性体现在邻近的多个周期上

在这里插入图片描述

lag的用途

计算各类窗口大小

计算采样窗口大小

window_size = estimator.context_length + max(estimator.lags_seq) + estimator.prediction_length# Here we make a window slightly bigger so that instance sampler can sample from each window# An alternative is to have exact size and use different instance sampler (e.g. ValidationSplitSampler)
window_size = 10 * window_size
# We change ValidationSplitSampler to add min_pastestimator.validation_sampler = ValidationSplitSampler(min_past=estimator.context_length + max(estimator.lags_seq),min_future=estimator.prediction_length,)
  1. 构建静态特征
lags = lagged_sequence_values(self.lags_seq, prior_input, input, dim=-1)#构建一个包含给定序列的滞后值的数组static_feat = torch.cat((loc.abs().log1p(), scale.log()), dim=-1)
expanded_static_feat = unsqueeze_expand(static_feat, dim=-2, size=lags.shape[-2]
)return torch.cat((lags, expanded_static_feat, time_feat), dim=-1), loc, scale

数据集准备过程

对每个数据集采样,window_size=13500,也挺离谱的

 train_data, val_data = [], []for name in TRAIN_DATASET_NAMES:new_data = create_sliding_window_dataset(name, window_size)train_data.append(new_data)new_data = create_sliding_window_dataset(name, window_size, is_train=False)val_data.append(new_data)

采样的具体过程,这里有个问题,样本数量很小的数据集,实际采样窗口大小小于设定的window_size,后续会如何对齐呢?

文章设置单变量预测单变量,所以样本进行了通道分离,同一样本的不同特征被采样为不同的样本

def create_sliding_window_dataset(name, window_size, is_train=True):#划分非重叠的滑动窗口数据集,window_size是对数据集采样的数量,对每个数据集只取前windowsize个样本# Splits each time series into non-overlapping sliding windowsglobal_id = 0freq = get_dataset(name, path=dataset_path).metadata.freq#从数据集中获取时间频率data = ListDataset([], freq=freq)#创建空数据集dataset = get_dataset(name, path=dataset_path).train if is_train else get_dataset(name, path=dataset_path).test#获取原始数据集for x in dataset:windows = []#划分滑动窗口#target:滑动窗口的目标值#start:滑动窗口的起始位置#item_id,唯一标识符#feat_static_cat:静态特征数组for i in range(0, len(x['target']), window_size):windows.append({'target': x['target'][i:i+window_size],'start': x['start'] + i,'item_id': str(global_id),'feat_static_cat': np.array([0]),})global_id += 1data += ListDataset(windows, freq=freq)return data

合并数据集

# Here weights are proportional to the number of time series (=sliding windows)weights = [len(x) for x in train_data]# Here weights are proportinal to the number of individual points in all time series# weights = [sum([len(x["target"]) for x in d]) for d in train_data]train_data = CombinedDataset(train_data, weights=weights)val_data = CombinedDataset(val_data, weights=weights)
class CombinedDataset:def __init__(self, datasets, seed=None, weights=None):self._seed = seedself._datasets = datasetsself._weights = weightsn_datasets = len(datasets)if weights is None:#如果未提供权重,默认平均分配权重self._weights = [1 / n_datasets] * n_datasetsdef __iter__(self):return CombinedDatasetIterator(self._datasets, self._seed, self._weights)def __len__(self):return sum([len(ds) for ds in self._datasets])

网络结构

lagllama

class LagLlamaModel(nn.Module):def __init__(self,max_context_length: int,scaling: str,input_size: int,n_layer: int,n_embd: int,n_head: int,lags_seq: List[int],rope_scaling=None,distr_output=StudentTOutput(),num_parallel_samples: int = 100,) -> None:super().__init__()self.lags_seq = lags_seqconfig = LTSMConfig(n_layer=n_layer,n_embd=n_embd,n_head=n_head,block_size=max_context_length,feature_size=input_size * (len(self.lags_seq)) + 2 * input_size + 6,rope_scaling=rope_scaling,)self.num_parallel_samples = num_parallel_samplesif scaling == "mean":self.scaler = MeanScaler(keepdim=True, dim=1)elif scaling == "std":self.scaler = StdScaler(keepdim=True, dim=1)else:self.scaler = NOPScaler(keepdim=True, dim=1)self.distr_output = distr_outputself.param_proj = self.distr_output.get_args_proj(config.n_embd)self.transformer = nn.ModuleDict(dict(wte=nn.Linear(config.feature_size, config.n_embd),h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),ln_f=RMSNorm(config.n_embd),))

主要是transformer里面首先是一个线性层,然后加了n_layer个Block,最后是RMSNorm,接下来解析Block的代码

在这里插入图片描述

Block

class Block(nn.Module):def __init__(self, config: LTSMConfig) -> None:super().__init__()self.rms_1 = RMSNorm(config.n_embd)self.attn = CausalSelfAttention(config)self.rms_2 = RMSNorm(config.n_embd)self.mlp = MLP(config)self.y_cache = Nonedef forward(self, x: torch.Tensor, is_test: bool) -> torch.Tensor:if is_test and self.y_cache is not None:# Only use the most recent one, rest is in cachex = x[:, -1:]x = x + self.attn(self.rms_1(x), is_test)y = x + self.mlp(self.rms_2(x))if is_test:if self.y_cache is None:self.y_cache = y  # Build cacheelse:self.y_cache = torch.cat([self.y_cache, y], dim=1)[:, 1:]  # Update cachereturn y

代码看到这里不太想继续看了,太多glunoTS库里面的函数了,我完全不熟悉这个库,看起来太痛苦了,还有很多的困惑,最大的困惑就是数据是怎么对齐的,怎么输入到Llama里面的,慢慢看吧

其他

来源
在这里插入图片描述

相关文章:

lag-llama源码解读(Lag-Llama: Towards Foundation Models for Time Series Forecasting)

Lag-Llama: Towards Foundation Models for Time Series Forecasting 文章内容&#xff1a; 时间序列预测任务&#xff0c;单变量预测单变量&#xff0c;基于Llama大模型&#xff0c;在zero-shot场景下模型表现优异。创新点&#xff0c;引入滞后特征作为协变量来进行预测。 获得…...

Three.js基础入门介绍——Three.js学习三【借助控制器操作相机】

在Three.js基础入门介绍——Three.js学习二【极简入门】中介绍了如何搭建Three.js开发环境并实现一个包含旋转立方体的场景示例&#xff0c;以此为前提&#xff0c;本篇将引进一个控制器的概念并使用”轨道控制器”&#xff08;OrbitControls&#xff09;来达到从不同方向展示场…...

【日志系列】什么是分布式日志系统?

✔️什么是分布式日志系统&#xff1f; 现在&#xff0c;很多应用都是集群部署的&#xff0c;一次请求会因为负载均衡而被路由到不同的服务器上面&#xff0c;这就导致一个应用的日志会分散在不同的服务器上面。 当我们要向通过日志做数据分析&#xff0c;问题排查的时候&#…...

[卷积神经网络]FCOS--仅使用卷积的Anchor Free目标检测

项目源码&#xff1a; FCOShttps://github.com/tianzhi0549/FCOS/ 一、概述 作为一种Anchor Free的目标检测网络&#xff0c;FCOS并不依赖锚框&#xff0c;这点类似于YOLOx和CenterNet&#xff0c;但CenterNet的思路是寻找目标的中心点&#xff0c;而FCOS则是寻找每个像素点&…...

Ubuntu fcitx Install

ubuntu经常出现键盘失灵的问题 查询资料得知应该是Ibus框架的问题 于是需要安装fcitx框架和搜狗拼音 sudo apt update sudo apt install fcitx 设置fcitx开机自启动&#xff08;建议&#xff09; sudo cp /usr/share/applications/fcitx.desktop /etc/xdg/autostart/ 然后…...

【Makefile/GNU Make】知识总结

文章目录 1. 总体认识2. 编写Makefile2.1. Makefile的组成2.2. Makefile文件名2.3. 包含其他Makefile 3. 编写规则4. 编写规则中的构建命令5. 如何使用变量6. 条件判断7. 转换文本的函数8. 如何运行make9. 使用模糊规则10. 使用make来更新存档文件11. 扩展GNU make12. 集成GNU …...

腾讯云轻量服务器和云服务器CVM该怎么选?区别一览

腾讯云轻量服务器和云服务器CVM该怎么选&#xff1f;不差钱选云服务器CVM&#xff0c;追求性价比选择轻量应用服务器&#xff0c;轻量真优惠呀&#xff0c;活动 https://curl.qcloud.com/oRMoSucP 轻量应用服务器2核2G3M价格62元一年、2核2G4M价格118元一年&#xff0c;540元三…...

MySQL定时备份实现

一、备份数据库 –all-databases 备份所有数据库 /opt/mysqlcopy/all_$(date “%Y-%m-%d %H:%M:%S”).sql 备份地址 docker exec -it 容器名称 sh -c "mysqldump -u root -ppassword --all-databases > /opt/mysqlcopy/all_$(date "%Y-%m-%d %H:%M:%S").sq…...

Nginx 不同源Https请求Http 报strict-origin-when-cross-origin

原因&#xff1a; nginx代理配置url指向只开放了/* 而我/*/*多了一层路径 成功&#xff1a;...

openGauss学习笔记-175 openGauss 数据库运维-备份与恢复-导入数据-管理并发写入操作示例

文章目录 openGauss学习笔记-175 openGauss 数据库运维-备份与恢复-导入数据-管理并发写入操作示例175.1 相同表的INSERT和DELETE并发175.2 相同表的并发INSERT175.3 相同表的并发UPDATE175.4 数据导入和查询的并发 openGauss学习笔记-175 openGauss 数据库运维-备份与恢复-导入…...

pnpm、npm、yarn是什么?怎么选择?

pnpm、npm、yarn三者是前端常用的包管理器&#xff0c;那么他们有什么区别呢&#xff1f; 1. npm (Node Package Manager) npm是Node.js的默认包管理器。自Node.js发布以来&#xff0c;npm就一直作为它的一个组成部分存在&#xff0c;因此&#xff0c;安装Node.js时也会自动安…...

MySQL8 一键部署

#!/bin/bash ### 定义变量 mysql_download_urlhttps://cdn.mysql.com//Downloads/MySQL-8.0/mysql-8.0.33-linux-glibc2.12-x86_64.tar.xz mysql_package_namemysql-8.0.33-linux-glibc2.12-x86_64.tar.xz mysql_dec_namemysql-8.0.33-linux-glibc2.12-x86_64 mysql_download_…...

12 UVM Driver

目录 12.1 uvm_driver class hierarchy 12.2 How to write driver code? 12.3 UVM Driver example 12.4 How to get sequence items from the sequencer? 12.5 UVM driver methods 12.5.1 Using get_next_item/ try_next_item and item_done methods 12.5.2 Using get…...

“暂存”校验逻辑探讨

1、背景 在业务中可能会遇到这种场景&#xff0c;前端页面元素多且复杂&#xff0c;一次性填完提交耗时很长&#xff0c;中间中断面临着丢失数据的风险。针对这个问题&#xff0c;“暂存”应运而生。 那“暂存”的时候&#xff0c;是否需要对数据校验&#xff0c;如何进行校验…...

探究element-ui 2.15.8中<el-input>的keydown事件无效问题

一、问题描述 今天看到一个问题&#xff0c;在用Vue2element-ui 2.15.8开发时&#xff0c;使用input组件绑定keydown事件没有任何效果。 <template><div id"app"><el-input v-model"content" placeholder"请输入" keydown&quo…...

Unity 代码控制Text自适应文本高度

在使用代码给Text赋值时&#xff0c;且文本有多段&#xff0c;并需要根据实际文本高度适配Text组件的高度时&#xff0c;可以使用以下方法&#xff1a; //Text文本 public TextMeshProUGUI text;void Start() {//代码赋值文本text.text "好!\n很好!\n非常好!";//获…...

TiDB 7.1 多租户在中泰证券中的应用

本文详细介绍了中泰证券在系统国产化改造项目中采用 TiDB 多租户技术的实施过程。文章分析了中泰证券数据库系统现状以及引入 TiDB 资源管控技术的必要性&#xff0c;探讨了 TiDB 多租户的关键特性&#xff0c;并阐述了在实际应用中的具体操作步骤。通过该技术的应用&#xff0…...

嵌入式-stm32-SR04超声波测距介绍及实战

一&#xff1a;超声波传感器介绍 1.1、SR04超声波测距硬件模块 1.2、SR04的四个IO口 vcc:提供电源5V gnd:接地 Trig:是**发送**声波信号的触发器 Echo:是**接收**回波信号的引脚 当TRIG信号被触发时&#xff0c;传感器会发送一定频率的声波信号&#xff0c;该信号被反射后&am…...

智能优化算法应用:基于白鲸算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于白鲸算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于白鲸算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.白鲸算法4.实验参数设定5.算法结果6.参考文献7.MA…...

mac m1芯片 pytorch安装及gpu性能测试

pytorch 使用mac的m1芯片进行模型训练。 #小结&#xff1a;在数据量小和模型参数少&#xff0c;batch_size小时&#xff0c;cpu训练更快&#xff08;原因&#xff1a;每次训练时数据需要放入GPU中&#xff0c;由于batch_size小。数据放入gpu比模型计算时间还长&#xff09; 在…...

利用最小二乘法找圆心和半径

#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …...

uniapp微信小程序视频实时流+pc端预览方案

方案类型技术实现是否免费优点缺点适用场景延迟范围开发复杂度​WebSocket图片帧​定时拍照Base64传输✅ 完全免费无需服务器 纯前端实现高延迟高流量 帧率极低个人demo测试 超低频监控500ms-2s⭐⭐​RTMP推流​TRTC/即构SDK推流❌ 付费方案 &#xff08;部分有免费额度&#x…...

中医有效性探讨

文章目录 西医是如何发展到以生物化学为药理基础的现代医学&#xff1f;传统医学奠基期&#xff08;远古 - 17 世纪&#xff09;近代医学转型期&#xff08;17 世纪 - 19 世纪末&#xff09;​现代医学成熟期&#xff08;20世纪至今&#xff09; 中医的源远流长和一脉相承远古至…...

网站指纹识别

网站指纹识别 网站的最基本组成&#xff1a;服务器&#xff08;操作系统&#xff09;、中间件&#xff08;web容器&#xff09;、脚本语言、数据厍 为什么要了解这些&#xff1f;举个例子&#xff1a;发现了一个文件读取漏洞&#xff0c;我们需要读/etc/passwd&#xff0c;如…...

A2A JS SDK 完整教程:快速入门指南

目录 什么是 A2A JS SDK?A2A JS 安装与设置A2A JS 核心概念创建你的第一个 A2A JS 代理A2A JS 服务端开发A2A JS 客户端使用A2A JS 高级特性A2A JS 最佳实践A2A JS 故障排除 什么是 A2A JS SDK? A2A JS SDK 是一个专为 JavaScript/TypeScript 开发者设计的强大库&#xff…...

MySQL 8.0 事务全面讲解

以下是一个结合两次回答的 MySQL 8.0 事务全面讲解&#xff0c;涵盖了事务的核心概念、操作示例、失败回滚、隔离级别、事务性 DDL 和 XA 事务等内容&#xff0c;并修正了查看隔离级别的命令。 MySQL 8.0 事务全面讲解 一、事务的核心概念&#xff08;ACID&#xff09; 事务是…...

免费数学几何作图web平台

光锐软件免费数学工具&#xff0c;maths,数学制图&#xff0c;数学作图&#xff0c;几何作图&#xff0c;几何&#xff0c;AR开发,AR教育,增强现实,软件公司,XR,MR,VR,虚拟仿真,虚拟现实,混合现实,教育科技产品,职业模拟培训,高保真VR场景,结构互动课件,元宇宙http://xaglare.c…...

django blank 与 null的区别

1.blank blank控制表单验证时是否允许字段为空 2.null null控制数据库层面是否为空 但是&#xff0c;要注意以下几点&#xff1a; Django的表单验证与null无关&#xff1a;null参数控制的是数据库层面字段是否可以为NULL&#xff0c;而blank参数控制的是Django表单验证时字…...

二维FDTD算法仿真

二维FDTD算法仿真&#xff0c;并带完全匹配层&#xff0c;输入波形为高斯波、平面波 FDTD_二维/FDTD.zip , 6075 FDTD_二维/FDTD_31.m , 1029 FDTD_二维/FDTD_32.m , 2806 FDTD_二维/FDTD_33.m , 3782 FDTD_二维/FDTD_34.m , 4182 FDTD_二维/FDTD_35.m , 4793...

Spring AOP代理对象生成原理

代理对象生成的关键类是【AnnotationAwareAspectJAutoProxyCreator】&#xff0c;这个类继承了【BeanPostProcessor】是一个后置处理器 在bean对象生命周期中初始化时执行【org.springframework.beans.factory.config.BeanPostProcessor#postProcessAfterInitialization】方法时…...