2025.2.16机器学习笔记:TimeGan文献阅读
2025.2.9周报
- 一、文献阅读
- 题目信息
- 摘要
- Abstract
- 创新点
- 网络架构
- 一、嵌入函数
- 二、恢复函数
- 三、序列生成器
- 四、序列判别器
- 损失函数
- 实验
- 结论
- 后续展望
一、文献阅读
题目信息
- 题目: Time-series Generative Adversarial Networks
- 会议: Neural Information Processing Systems (NeurIPS)
- 作者: Jinsung Yoon, Daniel Jarrett, Mihaela van der Schaar
- 发表时间: 2019/12/01
- 文章链接: https://papers.nips.cc/paper/2019/file/c9efe5f26cd17ba6216bbe2a7d26d490-Paper.pdf
- 代码: https://github.com/jsyoon0823/TimeGAN
摘要
用生成模型生成时间序列数据是一件复杂的事,因为其要求生成模型既要捕捉各时间点特征分布,又要学习变量间的动态关系。在时序数据的生成中,自回归模型虽在序列预测中改进了时间动态性,并非真正的生成模型。此外,将生成对抗网络(GAN)框架直接应用于序列数据,但其未充分利用自回归的先验信息,仅靠标准GAN损失求和不能确保生成模型能有效捕捉训练数据中的多步依赖关系。为了解决以上问题,论文作者提出一种时间序列生成对抗网络(Time - GAN),该网络弥补了上述两种模型的缺陷,作者设计了一个包含嵌入函数、恢复函数、序列生成器和序列鉴别器的生成模型,通过有监督损失和无监督损失的学习嵌入空间对抗性和联合训练,让模型得以同时学习编码特征、生成表示和跨时间迭代。最后,作者通过多种实验证明该模型在生成现实时间序列数据方面相比现有基准模型有明显提升。
Abstract
Generating synthetic time-series data is a complex task, as it requires a generative model to capture both the distribution of features at each time point and the dynamic relationships between variables. Autoregressive models, although improved in terms of temporal dynamics for sequence prediction, are inherently deterministic and do not qualify as true generative models. On the other hand, directly applying the Generative Adversarial Network (GAN) framework to sequential data does not fully leverage the autoregressive prior, and relying solely on standard GAN loss summation is insufficient to ensure that the generative model can effectively capture the multi-step dependencies present in the training data.To address these issues, this paper proposes a Time Series Generative Adversarial Network (Time-GAN) that remedies the shortcomings of the aforementioned models. The proposed network is explicitly trained to preserve temporal dynamics through adversarial and joint training in a learned embedding space with both supervised and unsupervised losses. The authors demonstrate through various experiments that this model significantly outperforms existing benchmark models in generating realistic time-series data.
创新点
1、引入监督损失以更好捕捉时间动态。
2、采用嵌入网络提供低维对抗学习空间。
3、提出联合训练方案,使TimeGAN能同时编码、生成和迭代。
网络架构
作者提出时间序列生成对抗网络(Time-GAN)架构如下图所示:
其中包含嵌入函数和恢复函数、序列生成器和序列判别器四个网络部分

在分析之前,我们首先看看作者提出背景问题

作者认为时间序列数据由两部分组成,如下图所示:

作者还提到在生成时间序列数据的中,通常需要学习一个能够描述整个序列(包括静态特征和时间特征)的概率分布。然而,直接学习整个序列的联合概率分布是很难的,因为它可能遇到长序列、高维特征空间以及复杂的数据分布等情况。
因此作者采用了一种名为自回归分解(autoregressive decomposition)的方法,即将整个序列的联合概率分布分解为一系列条件概率分布的乘积。
公式如下图所示:

此外, X 1 : t − 1 X_{1:t-1} X1:t−1表示表示从时间步 1 到时间步 t−1的所有时间序列数据。
这样就可以将学习整个序列的联合概率分布转化为学习每个时间点的条件概率分布,将复杂问题简单化。转化为条件概率分布可以让训练中模型更容易学习和优化。此外,时间序列数据通常具有自回归性(即当前时间点的值往往依赖于之前时间点的值),通过转化,模型可以更好地捕捉这种时间依赖性。
下面我们对网络架构进行分析:
一、嵌入函数
嵌入函数的目的是将括静态特征和时间序列特征映射到一个低维的潜在空间中。
-
因为在平时的数据集中数据都是高维的,高维空间中可能包含很多噪声或不重要的信息,这些信息可能会干扰模型的学习。因此,嵌入函数能够将数据映射到低维空间中,这个潜在空间的表示能够捕捉数据的关键信息的同时降低数据的维度,使得模型更容易学习和处理。
e : S × ∏ t X → H S × ∏ t H X e:\mathcal{S} \times \prod_{t} \mathcal{X} \rightarrow \mathcal{H}_{\mathcal{S}} \times \prod_{t} \mathcal{H}_{\mathcal{X}} e:S×t∏X→HS×t∏HX
其中,S表示静态特征的空间; ∏ t X \prod_{t} \mathcal{X} ∏tX表示时间序列特征在时间步 t 上的笛卡尔积 (使用笛卡尔积操作来表示时间序列数据的联合空间,能方便地映射到一个统一的潜在空间中) ; H S \mathcal{H} _S HS和 H X \mathcal{H} _X HX分别表示静态特征和时间序列特征的潜在空间。 -
嵌入函数 e e e 将 s s s (静态特征)和 x 1 : T x _{1:T} x1:T(时间序列特征)映射到它们的潜在代码 h S \mathcal{h}_S hS 和 h 1 : T \mathcal{h}_{1:T} h1:T 。公式表示如下:
h S = e S ( s ) \mathbf{h}_{\mathcal{S}}=e_{\mathcal{S}}(\mathbf{s}) hS=eS(s)嵌入空间是特征空间到潜在空间的映射。在这个空间中,数据被表示为低维的潜在代码。例如,如果潜在空间是一个 2D 平面,那么潜在代码就是这个平面上的一个点 (x,y)。 -
时间序列特征的嵌入网络 e X \mathcal{e}_X eX 是一个递归网络,它不仅依赖于当前的时间序列特征 x t x_t xt,还依赖于前一个时间步的潜在代码 h t − 1 h_{t−1} ht−1 和静态特征 h S \mathcal{h}_S hS。公式表示如下:
h t = e X ( h S , h t − 1 , x t ) \mathbf{h}_{t}=e_{\mathcal{X}}\left(\mathbf{h}_{\mathcal{S}}, \mathbf{h}_{t-1}, \mathbf{x}_{t}\right) ht=eX(hS,ht−1,xt)
二、恢复函数
恢复函数的作用是将潜在空间中的代码再映射回原始数据的特征空间。
- r r r(恢复函数)将 h S \mathcal{h}_S hS(静态潜在代码)和 h 1 : T \mathcal{h}_{1:T} h1:T (时间序列潜在代码 )映射回 s s s (静态特征)和 x 1 : T x _{1:T} x1:T(时间序列特征)
公式表示如下:
r : H S × ∏ t H X → S × ∏ t X r: \mathcal{H}_{\mathcal{S}} \times \prod_{t} \mathcal{H}_{\mathcal{X}} \rightarrow \mathcal{S} \times \prod_{t} \mathcal{X} r:HS×t∏HX→S×t∏X
s ~ = r S ( h s ) , x ~ t = r X ( h t ) \tilde{\mathbf{s}}=r_{\mathcal{S}}\left(\mathbf{h}_{s}\right), \quad \tilde{\mathbf{x}}_{t}=r_{\mathcal{X}}\left(\mathbf{h}_{t}\right) s~=rS(hs),x~t=rX(ht)
其中,静态特征的恢复网络 r S r_S rS 将静态潜在代码 h S h_S hS 映射回静态特征 s ~ \tilde{s} s~。
时间序列特征的恢复网络 r X {r}_{X} rX 将每个时间步的潜在代码 ht 映射回对应的时间序列特征 x ~ t \tilde{x}_t x~t。
嵌入和恢复函数需要每一步的输出只能依赖于前面的信息,不能“看到”未来的信息。 这样确保模型能够正确地模拟时间序列数据的生成过程。
嵌入和恢复函数可以采用任何网络架构(如时间卷积网络、注意力解码器、循环神经网络等等)
三、序列生成器
序列生成器的用于生成合成的数据。它不是直接在特征空间生成数据,而是首先在嵌入空间生成数据。
生成器 g g g通过静态和时间序列的随机向量 z z z,生成合成的潜在代码 h h h

为什么生成的时间序列潜在表示 h ^ t \hat{\mathbf{h}}_{t} h^t的生成需要静态特征的潜在表示 h ^ S \hat{\mathbf{h}}_{S} h^S?
因为 h ^ t \hat{\mathbf{h}}_{t} h^t编码了静态特征的全局信息,这些信息对时间序列的生成有重要影响。 比如在生成一个人的身高序列时,个人的静态特征(如:年龄和性别)可能影响身高趋势的整体情况。
四、序列判别器
序列判别器的目的是区分真实数据和生成数据。
它也工作在嵌入空间中。判别器 d : H S × ∏ t H X → [ 0 , 1 ] × ∏ t [ 0 , 1 ] d: \mathcal{H}_{S} \times \prod_{t} \mathcal{H}_{\mathcal{X}} \rightarrow[0,1] \times \prod_{t}[0,1] d:HS×∏tHX→[0,1]×∏t[0,1] 接受静态和时间序列的潜在代码,返回分类结果。这些分类结果表示数据是真实的还是合成的。
-
静态特征的判别网络 d S d_S dS 直接对静态潜在代码 h ~ S \tilde{h}_S h~S 进行分类
y ~ S = d S ( h ~ S ) \tilde{y}_{\mathcal{S}}=d_{\mathcal{S}}\left(\tilde{\mathbf{h}}_{\mathcal{S}}\right) y~S=dS(h~S)
其中 y ~ S \tilde{y}_{\mathcal{S}} y~S是一个值概率值范围是[0,1]。 -
时间序列特征的判别网络 d X d_X dX 使用一个双向递归网络来处理时间序列数据。
它考虑了前向和后向的隐藏状态 u → t = c ⃗ X ( h ~ S , h ~ t , u → t − 1 ) \overrightarrow{\mathbf{u}}_{t}=\vec{c}_{\mathcal{X}}\left(\tilde{\mathbf{h}}_{\mathcal{S}}, \tilde{\mathbf{h}}_{t}, \overrightarrow{\mathbf{u}}_{t-1}\right) ut=cX(h~S,h~t,ut−1) 和 u ← t = c ← X ( h ~ S , h ~ t , u ← t + 1 ) \overleftarrow{\mathbf{u}}_{t}=\overleftarrow{c}_{\mathcal{X}}\left(\tilde{\mathbf{h}}_{\mathcal{S}}, \tilde{\mathbf{h}}_{t}, \overleftarrow{\mathbf{u}}_{t+1}\right) ut=cX(h~S,h~t,ut+1)(其中, c ← X \overleftarrow{c}_{\mathcal{X}} cX与 c → X \overrightarrow{c}_{\mathcal{X}} cX前向和后向的循环函数)
y ~ t = d X ( u ← t , u → t ) \quad \tilde{y}_{t}=d_{\mathcal{X}}\left(\overleftarrow{\mathbf{u}}_{t}, \overrightarrow{\mathbf{u}}_{t}\right) y~t=dX(ut,ut)
生成器通过递归网络生成潜在表示,这些代码表示合成的时间序列数据。而判别器通过双向递归网络来区分真实数据和生成数据的潜在表示。
损失函数
1. 重建损失
嵌入和恢复函数能够准确地从潜在代码 h S , h 1 : T h_S,h_{1:T} hS,h1:T 重建原始数据 s , x 1 : T s,x_{1:T} s,x1:T。
重建损失 L R \mathcal{L}_{R} LR 用于衡量重建数据与原始数据之间的差异。
公式如下:
L R = E s , x 1 : T ∼ p [ ∥ s − s ~ ∥ 2 + ∑ t ∥ x t − x ~ t ∥ 2 ] \mathcal{L}_{R}=\mathbb{E}_{s, x_{1: T} \sim p}\left[\|s-\tilde{s}\|_{2}+\sum_{t}\left\|x_{t}-\tilde{x}_{t}\right\|_{2}\right] LR=Es,x1:T∼p[∥s−s~∥2+t∑∥xt−x~t∥2]
其中, E s , x 1 : T ∼ p \mathbb{E}_{s, x_{1: T}\sim p} Es,x1:T∼p表示对所有可能的静态特征 s s s和时间序列特征 x 1 : T x_{1:T} x1:T的期望值; ∥ s − s ~ ∥ 2 \|s-\tilde{s}\|_{2} ∥s−s~∥2表示原始静态特征 s 和重建的静态特征 s ~ \tilde{s} s~之间的欧几里得距离;同样的, ∑ t ∥ x t − x ~ t ∥ 2 \sum_{t}\left\|x_{t}-\tilde{x}_{t}\right\|_{2} ∑t∥xt−x~t∥2表示原始时间序列特征 x t x_t xt 和重建的时间序列特征 x t ~ \tilde{x_{t}} xt~之间的欧几里得距离。
2. 无监督损失
生成器(自回归)通过合成嵌入 h S , h 1 : t − 1 h_S,h_{1:t−1} hS,h1:t−1生成下一个合成向量 h t ^ \hat{h_{t}} ht^。然后计算无监督损失的梯度,公式如下:
L U = E s , x 1 : T ∼ p [ log y S + ∑ t log y t ] + E s , x 1 : T ∼ p ^ [ log ( 1 − y ^ S ) + ∑ t log ( 1 − y ^ t ) ] \mathcal{L}_{U}=\mathbb{E}_{s, x_{1: T} \sim p}\left[\log y_{\mathcal{S}}+\sum_{t} \log y_{t}\right]+\mathbb{E}_{s, x_{1: T} \sim \hat{p}}\left[\log \left(1-\hat{y}_{\mathcal{S}}\right)+\sum_{t} \log \left(1-\hat{y}_{t}\right)\right] LU=Es,x1:T∼p[logyS+t∑logyt]+Es,x1:T∼p^[log(1−y^S)+t∑log(1−y^t)]
其实就是GAN的公式具体可以参考我之间关于GAN的博客,里面有公式的具体说明:https://blog.csdn.net/Zcymatics/article/details/145011685?spm=1001.2014.3001.5501
其中, y S y_S yS 和 y t y_t yt 是真实数据的分类结果, y S ^ \hat{y_S} yS^ 和 y t ^ \hat{y_t} yt^是生成数据的分类结果。
3. 监督损失
作者认为仅依赖判别器的二元对抗不足以让生成器捕捉数据中的逐步条件分布。所以作者引入了额外的损失来进一步加强模型的捕捉特征能力。在交替的方式中,生成器还以循环迭代的模式训练即通过实际数据的嵌入序列 h 1 : t − 1 h_{1:t−1} h1:t−1以生成下一个潜在向量 h t h_t ht。
作者在损失上计算梯度,该损失得到分布 p ( H t ∣ H S , H 1 : t − 1 ) p\left(H_{t} \mid H_{\mathcal{S}}, H_{1: t-1}\right) p(Ht∣HS,H1:t−1)和 p ^ ( H t ∣ H S , H 1 : t − 1 ) \hat{p}\left(H_{t} \mid H_{\mathcal{S}}, H_{1: t-1}\right) p^(Ht∣HS,H1:t−1)之间的差异。其应用最大似然得到监督损失,公式如下:
L S = E s , x 1 : T ∼ p [ ∑ t ∥ h t − g X ( h S , h t − 1 , z t ) ∥ 2 ] \mathcal{L}_{S}=\mathbb{E}_{s, x_{1: T} \sim p}\left[\sum_{t}\left\|h_{t}-g_{\mathcal{X}}\left(h_{\mathcal{S}}, h_{t-1}, z_{t}\right)\right\|_{2}\right] LS=Es,x1:T∼p[t∑∥ht−gX(hS,ht−1,zt)∥2]
优化过程:
下图展示了训练过程中的方法:
- 令 θ e , θ r , θ g , θ d θ_e,θ_r,θ_g,θ_d θe,θr,θg,θd 分别表示嵌入、恢复、生成器和判别器网络的参数。
- 嵌入函数和恢复函数在重建损失和监督损失上进行训练,过程可以表达为: min θ e , θ r ( λ L S + L R ) \min _{\theta_e, \theta_r}\left(\lambda \mathcal{L}_S+\mathcal{L}_R\right) minθe,θr(λLS+LR).其中,λ≥0 是一个超参数,用于平衡两个损失。
- 然后生成器和判别器网络以对抗方式进行训练(即在有监督和无监督损下失混合训练),过程可以表达为: min θ g ( η L S + max θ d L U ) \min _{\theta_g}\left(\eta \mathcal{L}_S+\max _{\theta_d} \mathcal{L}_U\right) minθg(ηLS+maxθdLU)。其中,η≥0 是另一个超参数,用于平衡两个损失。

实验
作者通过在多个真实和合成数据集上进行实验。
- 采用定性方法,如t-SNE和PCA分析来可视化生成分布与原始分布的相似性;
- 采用定量方法,如训练后分类器区分真实和生成序列,以及应用在合成数据上训练,在真实数据上测试框架评估生成数据对原始预测特征的保留能力。对不同类型的时间序列数据进行实验,包括具有周期性、离散性、不同噪声水平、时间步长规律性以及时间和特征相关性的数据。
与其他基准模型对比实验结果
1.1 Discriminative Score(判别分数):
在自回归多元高斯数据实验中如表1所示:
TimeGAN在不同的时间相关性(φ)和特征相关性(σ)设置下,判别分数均优于RCGAN、C-RNN-GAN、T-Forcing、P-Forcing、WaveNet和WaveGAN等基准模型。例如,当φ = 0.8且σ = 0.8时,TimeGAN的判别分数为0.105±0.005,而其他模型的分数相对较高。

在正弦、股票、能源、事件数据集实验结果如表2所示:
TimeGAN的判别分数也始终优于其他基准模型。如在股票数据集上,TimeGAN生成样本的判别分数为0.102±0.021,比次优的RCGAN(0.196±0.027)低48%。

1.2 预测分数(Predictive Score):
在自回归多元高斯数据实验中,如表1所示:
TimeGAN在不同的时间相关性(φ)和特征相关性(σ)设置下,预测分数均优于其他基准模型。例如,当φ = 0.8且σ = 0.8时,TimeGAN的预测分数为0.251±0.002,低于其他模型。

在不同类型时间序列数据实验中,结果如表2所示:
TimeGAN的预测分数同样始终优于其他基准模型,并且TimeGAN的预测分数几乎与原始数据集一致。

4. t - SNE和PCA可视化结果
在正弦和股票数据集上进行t - SNE可视化如图3所示:
TimeGAN生成的合成数据集与原始数据的重叠度明显优于其他基准模型。

5. 增益来源分析结果
在对TimeGAN进行修改后的实验中,如表3所示:
作者分析了监督损失、嵌入网络和联合训练方案这三个元素对生成时间序列数据质量的贡献。结果表明这三个元素都对提高生成时间序列数据的质量有重要作用。例如,在股票数据集这种具有高时间相关性的数据中,监督损失的作用尤为重要;嵌入网络和与对抗网络的联合训练也能全面且持续地提高生成性能。

代码如下:
import tensorflow as tf
import numpy as np
from utils import extract_time, rnn_cell, random_generator, batch_generator
def timegan(ori_data, parameters):"""TimeGAN 函数。使用原始数据作为训练集生成合成数据(时间序列)。参数:- ori_data: 原始时间序列数据- parameters: TimeGAN 网络参数返回:- generated_data: 生成的时间序列数据"""# 初始化 TensorFlow 计算图tf.reset_default_graph()# 获取原始数据的基本参数no, seq_len, dim = np.asarray(ori_data).shape # no: 样本数, seq_len: 序列长度, dim: 特征维度# 提取时间信息并计算最大序列长度ori_time, max_seq_len = extract_time(ori_data) # ori_time: 每个样本的时间长度, max_seq_len: 最大序列长度def MinMaxScaler(data):"""Min-Max 归一化器。参数:- data: 原始数据返回:- norm_data: 归一化后的数据- min_val: 最小值(用于反归一化)- max_val: 最大值(用于反归一化)"""min_val = np.min(np.min(data, axis=0), axis=0) # 计算每个特征的最小值data = data - min_val # 数据减去最小值max_val = np.max(np.max(data, axis=0), axis=0) # 计算每个特征的最大值norm_data = data / (max_val + 1e-7) # 归一化数据return norm_data, min_val, max_val# 对原始数据进行归一化ori_data, min_val, max_val = MinMaxScaler(ori_data)## 构建 RNN 网络# 网络参数hidden_dim = parameters['hidden_dim'] # 隐藏层维度num_layers = parameters['num_layer'] # RNN 层数iterations = parameters['iterations'] # 训练迭代次数batch_size = parameters['batch_size'] # 批量大小module_name = parameters['module'] # RNN 模块名称(如 LSTM 或 GRU)z_dim = dim # 随机噪声的维度gamma = 1 # 超参数,用于调整损失函数# 输入占位符X = tf.placeholder(tf.float32, [None, max_seq_len, dim], name="myinput_x") # 输入时间序列数据Z = tf.placeholder(tf.float32, [None, max_seq_len, z_dim], name="myinput_z") # 输入随机噪声T = tf.placeholder(tf.int32, [None], name="myinput_t") # 输入时间信息def embedder(X, T):"""嵌入网络:将原始特征空间映射到潜在空间。参数:- X: 输入时间序列特征- T: 输入时间信息返回:- H: 嵌入表示"""with tf.variable_scope("embedder", reuse=tf.AUTO_REUSE):e_cell = tf.nn.rnn_cell.MultiRNNCell([rnn_cell(module_name, hidden_dim) for _ in range(num_layers)]) # 创建多层 RNN 单元e_outputs, e_last_states = tf.nn.dynamic_rnn(e_cell, X, dtype=tf.float32, sequence_length=T) # 动态 RNNH = tf.contrib.layers.fully_connected(e_outputs, hidden_dim, activation_fn=tf.nn.sigmoid) # 全连接层return Hdef recovery(H, T):"""恢复网络:从潜在空间映射回原始空间。参数:- H: 潜在表示- T: 输入时间信息返回:- X_tilde: 恢复的数据"""with tf.variable_scope("recovery", reuse=tf.AUTO_REUSE):r_cell = tf.nn.rnn_cell.MultiRNNCell([rnn_cell(module_name, hidden_dim) for _ in range(num_layers)]) # 创建多层 RNN 单元r_outputs, r_last_states = tf.nn.dynamic_rnn(r_cell, H, dtype=tf.float32, sequence_length=T) # 动态 RNNX_tilde = tf.contrib.layers.fully_connected(r_outputs, dim, activation_fn=tf.nn.sigmoid) # 全连接层return X_tildedef generator(Z, T):"""生成器函数:在潜在空间中生成时间序列数据。参数:- Z: 随机噪声- T: 输入时间信息返回:- E: 生成的嵌入表示"""with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):e_cell = tf.nn.rnn_cell.MultiRNNCell([rnn_cell(module_name, hidden_dim) for _ in range(num_layers)]) # 创建多层 RNN 单元e_outputs, e_last_states = tf.nn.dynamic_rnn(e_cell, Z, dtype=tf.float32, sequence_length=T) # 动态 RNNE = tf.contrib.layers.fully_connected(e_outputs, hidden_dim, activation_fn=tf.nn.sigmoid) # 全连接层return Edef supervisor(H, T):"""监督器函数:使用前一序列生成下一序列。参数:- H: 潜在表示- T: 输入时间信息- 返回:- S: 基于生成器生成的潜在表示生成的序列"""with tf.variable_scope("supervisor", reuse=tf.AUTO_REUSE):e_cell = tf.nn.rnn_cell.MultiRNNCell([rnn_cell(module_name, hidden_dim) for _ in range(num_layers - 1)]) # 创建多层 RNN 单元e_outputs, e_last_states = tf.nn.dynamic_rnn(e_cell, H, dtype=tf.float32, sequence_length=T) # 动态 RNNS = tf.contrib.layers.fully_connected(e_outputs, hidden_dim, activation_fn=tf.nn.sigmoid) # 全连接层return Sdef discriminator(H, T):"""判别器函数:区分原始和合成的时间序列数据。参数:- H: 潜在表示- T: 输入时间信息返回:- Y_hat: 原始和合成时间序列的分类结果"""with tf.variable_scope("discriminator", reuse=tf.AUTO_REUSE):d_cell = tf.nn.rnn_cell.MultiRNNCell([rnn_cell(module_name, hidden_dim) for _ in range(num_layers)]) # 创建多层 RNN 单元d_outputs, d_last_states = tf.nn.dynamic_rnn(d_cell, H, dtype=tf.float32, sequence_length=T) # 动态 RNNY_hat = tf.contrib.layers.fully_connected(d_outputs, 1, activation_fn=None) # 全连接层return Y_hat# 嵌入网络和恢复网络H = embedder(X, T) # 嵌入表示X_tilde = recovery(H, T) # 恢复的数据# 生成器E_hat = generator(Z, T) # 生成的嵌入表示H_hat = supervisor(E_hat, T) # 监督器生成的潜在表示H_hat_supervise = supervisor(H, T) # 监督器生成的潜在表示(基于原始数据)# 合成数据X_hat = recovery(H_hat, T) # 生成的合成数据# 判别器Y_fake = discriminator(H_hat, T) # 判别器对合成数据的输出Y_real = discriminator(H, T) # 判别器对原始数据的输出Y_fake_e = discriminator(E_hat, T) # 判别器对生成器输出的嵌入表示的输出# 变量e_vars = [v for v in tf.trainable_variables() if v.name.startswith('embedder')] # 嵌入网络变量r_vars = [v for v in tf.trainable_variables() if v.name.startswith('recovery')] # 恢复网络变量g_vars = [v for v in tf.trainable_variables() if v.name.startswith('generator')] # 生成器变量s_vars = [v for v in tf.trainable_variables() if v.name.startswith('supervisor')] # 监督器变量d_vars = [v for v in tf.trainable_variables() if v.name.startswith('discriminator')] # 判别器变量# 判别器损失D_loss_real = tf.losses.sigmoid_cross_entropy(tf.ones_like(Y_real), Y_real) # 真实数据的损失D_loss_fake = tf.losses.sigmoid_cross_entropy(tf.zeros_like(Y_fake), Y_fake) # 合成数据的损失D_loss_fake_e = tf.losses.sigmoid_cross_entropy(tf.zeros_like(Y_fake_e), Y_fake_e) # 生成器嵌入表示的损失D_loss = D_loss_real + D_loss_fake + gamma * D_loss_fake_e # 判别器总损失# 生成器损失# 1. 对抗损失G_loss_U = tf.losses.sigmoid_cross_entropy(tf.ones_like(Y_fake), Y_fake) # 合成数据的对抗损失G_loss_U_e = tf.losses.sigmoid_cross_entropy(tf.ones_like(Y_fake_e), Y_fake_e) # 生成器嵌入表示的对抗损失# 2. 监督损失G_loss_S = tf.losses.mean_squared_error(H[:, 1:, :], H_hat_supervise[:, :-1, :]) # 监督损失# 3. 二阶矩损失G_loss_V1 = tf.reduce_mean(tf.abs(tf.sqrt(tf.nn.moments(X_hat, [0])[1] + 1e-6) - tf.sqrt(tf.nn.moments(X, [0])[1] + 1e-6))) # 方差损失G_loss_V2 = tf.reduce_mean(tf.abs((tf.nn.moments(X_hat, [0])[0]) - (tf.nn.moments(X, [0])[0]))) # 均值损失G_loss_V = G_loss_V1 + G_loss_V2 # 二阶矩总损失# 4. 总生成器损失G_loss = G_loss_U + gamma * G_loss_U_e + 100 * tf.sqrt(G_loss_S) + 100 * G_loss_V# 嵌入网络损失E_loss_T0 = tf.losses.mean_squared_error(X, X_tilde) # 嵌入网络的恢复损失E_loss0 = 10 * tf.sqrt(E_loss_T0) # 嵌入网络的总损失E_loss = E_loss0 + 0.1 * G_loss_S # 嵌入网络的最终损失# 优化器E0_solver = tf.train.AdamOptimizer().minimize(E_loss0, var_list=e_vars + r_vars) # 嵌入网络优化器E_solver = tf.train.AdamOptimizer().minimize(E_loss, var_list=e_vars + r_vars) # 嵌入网络优化器D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=d_vars) # 判别器优化器G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=g_vars + s_vars) # 生成器优化器GS_solver = tf.train.AdamOptimizer().minimize(G_loss_S, var_list=g_vars + s_vars) # 监督器优化器## TimeGAN 训练sess = tf.Session() # 创建 TensorFlow 会话sess.run(tf.global_variables_initializer()) # 初始化所有变量# 1. 嵌入网络训练print('Start Embedding Network Training')for itt in range(iterations):# 生成小批量数据X_mb, T_mb = batch_generator(ori_data, ori_time, batch_size)# 训练嵌入网络_, step_e_loss = sess.run([E0_solver, E_loss_T0], feed_dict={X: X_mb, T: T_mb})# 打印训练进度if itt % 1000 == 0:print('step: ' + str(itt) + '/' + str(iterations) + ', e_loss: ' + str(np.round(np.sqrt(step_e_loss), 4)))print('Finish Embedding Network Training')# 2. 仅使用监督损失训练print('Start Training with Supervised Loss Only')for itt in range(iterations):# 生成小批量数据X_mb, T_mb = batch_generator(ori_data, ori_time, batch_size)# 生成随机噪声Z_mb = random_generator(batch_size, z_dim, T_mb, max_seq_len)# 训练生成器_, step_g_loss_s = sess.run([GS_solver, G_loss_S], feed_dict={Z: Z_mb, X: X_mb, T: T_mb})# 打印训练进度if itt % 1000 == 0:print('step: ' + str(itt) + '/' + str(iterations) + ', s_loss: ' + str(np.round(np.sqrt(step_g_loss_s), 4))print('Finish Training with Supervised Loss Only')# 3. 联合训练print('Start Joint Training')for itt in range(iterations):# 生成器训练(比判别器训练多一次)for kk in range(2):# 生成小批量数据X_mb, T_mb = batch_generator(ori_data, ori_time, batch_size)# 生成随机噪声Z_mb = random_generator(batch_size, z_dim, T_mb, max_seq_len)# 训练生成器_, step_g_loss_u, step_g_loss_s, step_g_loss_v = sess.run([G_solver, G_loss_U, G_loss_S, G_loss_V], feed_dict={Z: Z_mb, X: X_mb, T: T_mb})# 训练嵌入网络_, step_e_loss_t0 = sess.run([E_solver, E_loss_T0], feed_dict={Z: Z_mb, X: X_mb, T: T_mb})# 判别器训练# 生成小批量数据X_mb, T_mb = batch_generator(ori_data, ori_time, batch_size)# 生成随机噪声Z_mb = random_generator(batch_size, z_dim, T_mb, max_seq_len)# 检查判别器损失check_d_loss = sess.run(D_loss, feed_dict={X: X_mb, T: T_mb, Z: Z_mb})# 如果判别器损失较大,则训练判别器if check_d_loss > 0.15:_, step_d_loss = sess.run([D_solver, D_loss], feed_dict={X: X_mb, T: T_mb, Z: Z_mb})# 打印训练进度if itt % 1000 == 0:print('step: ' + str(itt) + '/' + str(iterations) +', d_loss: ' + str(np.round(step_d_loss, 4)) +', g_loss_u: ' + str(np.round(step_g_loss_u, 4)) +', g_loss_s: ' + str(np.round(np.sqrt(step_g_loss_s), 4)) +', g_loss_v: ' + str(np.round(step_g_loss_v, 4)) +', e_loss_t0: ' + str(np.round(np.sqrt(step_e_loss_t0), 4)))print('Finish Joint Training')## 合成数据生成Z_mb = random_generator(no, z_dim, ori_time, max_seq_len) # 生成随机噪声generated_data_curr = sess.run(X_hat, feed_dict={Z: Z_mb, X: ori_data, T: ori_time}) # 生成合成数据generated_data = list()# 将生成的合成数据裁剪为原始时间长度for i in range(no):temp = generated_data_curr[i, :ori_time[i], :]generated_data.append(temp)# 反归一化generated_data = generated_data * max_valgenerated_data = generated_data + min_valreturn generated_data
结论
本篇论文提出TimeGAN这一新型时间序列生成框架,它结合无监督GAN方法的通用性和监督自回归模型对条件时间动态的控制。通过监督损失和联合训练嵌入网络,TimeGAN在生成现实时间序列数据方面较现有基准有显著改进。TimeGAN不仅依赖于二元对抗反馈进行学习,还通过采样从学习到的分布中生成数据,这对于合成数据生成非常重要。此外,TimeGAN能够处理不规则采样,并且通过嵌入网络识别数据的低维空间,从而学习数据的逐步分布和潜在动态。实验结果表明TimeGAN在多个数据集上的实验表明其能生成高质量数据,有助于提升数据可用性,对相关领域的时间序列数据的研究方面具有实际应用价值。
后续展望
未来可将差分隐私框架融入TimeGAN,以生成有差分隐私保证的高质量时间序列数据。此外,还可探索在更多类型数据或复杂场景下的应用,进一步优化模型结构或参数以提高性能等。
相关文章:
2025.2.16机器学习笔记:TimeGan文献阅读
2025.2.9周报 一、文献阅读题目信息摘要Abstract创新点网络架构一、嵌入函数二、恢复函数三、序列生成器四、序列判别器损失函数 实验结论后续展望 一、文献阅读 题目信息 题目: Time-series Generative Adversarial Networks会议: Neural Information…...
最新智能优化算法: 中华穿山甲优化( Chinese Pangolin Optimizer ,CPO)算法求解23个经典函数测试集,MATLAB代码
中华穿山甲优化( Chinese Pangolin Optimizer ,CPO)算法由GUO Zhiqing 等人提出,该算法的灵感来自中华穿山甲独特的狩猎行为,包括引诱和捕食行为。 算法流程如下: 1. 开始 设置算法参数和最大迭代次数&a…...
使用 DeepSeek + 语音转文字工具 实现会议整理
目录 简述 1. DeepSeek与常用的语音转文字工具 1.1 DeepSeek 1.2 讯飞听见 1.3 飞书妙记 1.4 剪映电脑版 1.5 Buzz 2. 安装Buzz 3. 使用DeepSeek Buzz提取并整理语音文字 3.1 使用 Buzz 完成语音转文字工作 3.2 使用 DeepSeek 进行文本处理工作 3.3 整理成思维导图…...
【OS安装与使用】part4-ubuntu22.04安装anaconda
文章目录 一、待解决问题1.1 问题描述1.2 解决方法 二、方法详述2.1 必要说明2.2 应用步骤2.2.1 官网下载Anaconda(1)确认自己的系统型号与硬件架构(2)官网下载对应版本 2.2.2 安装Anaconda(1)基于shell脚本…...
把程序加入开机自启动
一、Windows 系统 方法 1:通过启动文件夹 1. 按下 Win R,输入 shell:startup,回车打开 **启动文件夹**。 2. 将应用程序的快捷方式复制到此文件夹中。 右键应用程序主程序(.exe)→ 创建快捷方式 → 拖动到启动文件夹。…...
介绍cherrypick
git cherry-pick 是 Git 中的一个强大命令,用于将一个或多个提交(commit)从一个分支应用到另一个分支。它允许你选择性地将特定的变更引入到当前分支,而无需合并整个分支。以下是对 git cherry-pick 操作的详细介绍: 1…...
Spring IoC DI:控制反转与依赖注入
目录 前言 - Spring MVC 与 Spring IoC 之间的关系 1. IoC 1.1 Spring Framework, Spring MVC, Spring boot 之间的联系[面试题] 1.2 什么是容器 1.3 什么是 IoC 2. DI 2.1 什么是 DI 3. Spring IoC & DI 3.1 Component 3.2 Autowired 4. IoC 详解 4.1 Applica…...
JavaAPI常用类型(包装类、BigDecimal类)
包装类 java语言是面向对象的语言,但是其中的八大基本数据类型不符合面向对象的特征。 因此java为了弥补这样的缺点,为这八种基本数据类型专门设计了八种符合面向对象特征的的类型,这八种具有面向对象特征的类型,统称为包装类&a…...
项目中一些不理解的问题
1.Mybatis是干啥的 他是用来帮我们操作数据库的,相当于是我们的一个助手: 我们想要得到数据库中的什么数据,就可以告诉mybatis,他会给我们想要的结果,同时,我们想要对数据库做出什么操作,也可…...
数字化转型4化:标准化奠基-信息化加速-数字化赋能-智能化引领
随着经济增速的放缓,大国体系所催生的生产力逐渐释放,后续业务的发展愈发需要精耕细作,精益理念也必须深入企业的骨髓。与此同时,在全球经济一体化的大背景下,企业面临着来自国内外同行,甚至是跨行业的激…...
Lineageos 22.1(Android 15) 开机向导制作
一、前言 开机向导原理其实就是将特定的category的Activity加入ComponentResolver,如下 <category android:name"android.intent.category.SETUP_WIZARD"/>然后我们开机启动的时候,FallbackHome结束,然后启动Launcher的时候…...
“让App玩捉迷藏:Android教育平板的‘隐身术’开发实录”
1. 前言:一场App的“消失魔术” 在定制教育平板时,客户要求:“朕要某些App在桌面上消失,只能在系统设置里当个‘幽灵’,而朕一声令下,它们又得原地复活!”于是,程序员们翻开了Androi…...
简单易懂,解析Go语言中的Channel管道
Channel 管道 1 初始化 可用var声明nil管道;用make初始化管道; len(): 缓冲区中元素个数, cap(): 缓冲区大小 //变量声明 var a chan int //使用make初始化 b : make(chan int) //不带缓冲区 c : make(chan stri…...
C++基础知识学习记录—模版和泛型编程
1、模板 概念: 模板可以让类或者函数支持一种通用类型,在编写时不指定固定的类型,在运行时才决定是什么类型,理论上讲可以支持任何类型,提高了代码的重用性。 模板可以让程序员专注于内部算法而忽略具体类型&#x…...
已解决IDEA无法输入中文问题(亲测有效)
前言 在使用IDEA的时候,比如我们想写个注释,可能不经意间,输入法就无法输入中文了,但是在其他地方打字,输入法仍然能够正常工作。这是什么原因呢,这篇文章带你解决这个问题! 快捷键 如果你的I…...
人工智能之目标追踪DeepSort源码解读(yolov5目标检测,代价矩阵,余弦相似度,马氏距离,匹配与预测更新)
要想做好目标追踪,须做好目标检测,所以这里就是基于yolov5检测基础上进行DeepSort,叫它为Yolov5_DeepSort。整体思路是先检测再追踪,基于检测结果进行预测与匹配。 一.参数与演示 这里用到的是coco预训练人的数据集: 二.针对检测结果初始化track 对每一帧数据都输出…...
Copilot基于企业PPT模板生成演示文稿
关于copilot创建PPT,咱们写过较多文章了: Copilot for PowerPoint通过文件创建PPT Copilot如何将word文稿一键转为PPT Copilot一键将PDF转为PPT,治好了我的精神内耗 测评Copilot和ChatGPT-4o从PDF创建PPT功能 Copilot for PPT全新功能&a…...
使用GDI+、文件和目录和打印API,批量将图片按文件名分组打包成PDF
代码写了两个小时,速度太慢(包括学习文档的时间) #include <stdio.h> #include <Windows.h> #include <gdiplus.h> #include <string.h> using namespace Gdiplus; #pragma comment(lib, "Gdiplus.lib") …...
【Linux】【网络】Libevent基础
【Linux】【网络】Libevent基础 libevent 是轻量级 c语言实现的 网络io库 能够跨平台 且线程安全 是单线程的 libevent 的使用过程通常包括几个主要步骤: 1.创建Libevent实例2.注册事件、添加事件、设置处理事件回调函数3.启动事件循环4.清理资源 1. 创建Libeven…...
MySQL 主从复制原理及其工作过程
一、MySQL主从复制原理 MySQL 主从复制是一种将数据从一个 MySQL 数据库服务器(主服务器,Master)复制到一个或多个 MySQL 数据库服务器(从服务器,Slave)的技术。以下简述其原理,主要包含三个核…...
nginx负载均衡, 解决iphash不均衡的问题之consistent
原因分析 客户端IP分布不均:部分IP段请求集中,导致哈希到同一后端。 服务器数量变动:增删节点时,传统ip_hash未使用一致性哈希,导致分布重置。 哈希键范围过小:例如仅使用IPv4前24位,不同IP可…...
MySQL远程连接配置
一、配置TCP服务地址绑定 配置文件路径 /etc/mysql/mysql.cnf /etc/mysql/mysql.conf.d/mysqld.cnf具体文件可以通过 mysql --help查看 配置项 # 只接受本地连接 bind-address 127.0.0.1 mysqlx-bind-address 127.0.0.1改为 # 接受任意IP地址连接 bind-address …...
Langchain vs. LlamaIndex:哪个在集成MongoDB并分析资产负债表时效果更好?
Langchain vs. LlamaIndex:哪个在集成MongoDB并分析资产负债表时效果更好? 随着大语言模型(LLM)在实际应用中的普及,许多开发者开始寻求能够帮助他们更高效地开发基于语言模型的应用框架。在众多框架中,La…...
iOS开发书籍推荐 - 《高性能 iOS应用开发》(附带链接)
引言 在 iOS 开发的过程中,随着应用功能的增加和用户需求的提升,性能优化成为了不可忽视的一环。尤其是面对复杂的界面、庞大的数据处理以及不断增加的后台操作,如何确保应用的流畅性和响应速度,成为开发者的一大挑战。《高性能 …...
Excel核心函数VLOOKUP全解析:从入门到精通
一、函数概述 VLOOKUP是Excel中最重要且使用频率最高的查找函数之一,全称为Vertical Lookup(垂直查找)。该函数主要用于在数据表的首列查找特定值,并返回该行中指定列的对应值。根据微软官方统计,超过80%的Excel用户在…...
leetcode1047-删除字符串中的所有相邻重复项
leetcode 1047 思路 因为要删除字符串中的所有相邻重复项,那么在删除完成后,最后返回的元素中是不应该存在任何相邻重复项的,如果是普通的遍历,假设str ‘abbaca’,遍历出来只发现中间的bb是相邻重复的删除了以后a…...
解决DeepSeek服务器繁忙问题的实用指南
目录 简述 1. 关于服务器繁忙 1.1 服务器负载与资源限制 1.2 会话管理与连接机制 1.3 客户端配置与网络问题 2. 关于DeepSeek服务的备用选项 2.1 纳米AI搜索 2.2 硅基流动 2.3 秘塔AI搜索 2.4 字节跳动火山引擎 2.5 百度云千帆 2.6 英伟达NIM 2.7 Groq 2.8 Firew…...
软件工程之软件需求SWE.1
物有本末,事有终始。知所先后,则近道矣。对软件开发而言,软件需求乃重中之重。必先之事重千钧,不可或缺如日辰。 汽车行业由于有方法论和各种标准约束,对软件开发有严苛的要求。ASPICE指导如何审核软件开发࿰…...
【面试题】redis大key问题怎么解决?(key访问的次数比较多,key存的数据比较大)
针对 Redis 中大 Key(数据量大且访问频繁)的问题,需从 数据拆分、访问优化、架构设计 等多维度综合解决。以下是具体方案及实施步骤: 一、大 Key 的定义与危害 定义: Value 过大:如 String 类型 Value >…...
web入侵实战分析-常见web攻击类应急处置实验1
场景说明: 某天运维人员发现在/opt/tomcat8/webapps/test/目录下,多出了一个index_bak.jsp这个文件, 并告诉你如下信息 操作系统:ubuntu-16.04业务:测试站点中间件:tomcat开放端口:22&#x…...
