当前位置: 首页 > news >正文

时间卷积网络(TCN)原理+代码详解

目录

  • 一、TCN原理
    • 1.1 因果卷积(Causal Convolution)
    • 1.2 扩张卷积(Dilated Convolution)
  • 二、代码实现
    • 2.1 Chomp1d 模块
    • 2.2 TemporalBlock 模块
    • 2.3 TemporalConvNet 模块
    • 2.4 完整代码示例
  • 参考文献

  在理解 TCN 的原理之前,我们可以先对传统的循环神经网络(RNN)进行简要回顾。RNN 是处理序列数据的常用方法,其核心思想是通过将前一个时间步的隐藏状态传递到下一个时间步,实现对序列依赖关系的建模。然而,RNN 在处理长序列时存在以下几个缺点:

  • 无法并行计算:RNN 的计算依赖于时间步的顺序,导致无法高效利用 GPU 并行计算。

  • 梯度消失/爆炸:在长时间依赖中,梯度在反向传播时会逐渐消失或变得不稳定。

  • 短期记忆限制:由于计算依赖于序列的逐步传递,RNN 难以捕获远距离的时间依赖。

  TCN 正是在这样的背景下提出的。它通过因果卷积和扩张卷积,突破了 RNN 的这些瓶颈,特别适用于长时间序列数据。接下来,我们将详细解析 TCN 的原理。

一、TCN原理

1.1 因果卷积(Causal Convolution)

  在卷积操作中,卷积核在输入上滑动时会同时处理前后时间步的数据,导致当前时间步的输出可能依赖于未来的输入。然而,对于时间序列任务,我们通常希望模型只依赖于过去的输入,不“窥探”未来,这样的结构称为“因果性”。

  TCN 使用因果卷积来确保这一点。因果卷积是指每个时间步的输出仅依赖于它之前的时间步,而不依赖于未来。简单来说,当前时间步的输出只会考虑卷积核覆盖的前几个时间步的输入。

  TCN 通过适当的填充(padding)来实现这一点,使得每一层的卷积不会跨越未来时间步。因果卷积的示意图如下:

在这里插入图片描述

1.2 扩张卷积(Dilated Convolution)

  为了捕捉长时间依赖关系,TCN 通过 扩张卷积(Dilated Convolution 来扩展卷积核的感受野。扩张卷积通过在卷积核的元素之间插入“间隔”,从而在保持卷积核大小不变的情况下,扩大卷积的感受野。

  例如,假设卷积核大小为 3,当扩张率 dilation=2 时,卷积核的元素之间插入 1 个间隔,感受野可以从 3 扩展到 5。通过这种扩张卷积,TCN 在每一层可以通过指数扩展的方式增大感受野,使得模型能够捕捉到远距离的依赖关系。例如,TCN 中第 i i i 层的感受野大小为 2 i 2^{i} 2i,这样层数越深,感受野就越大。如下图所示:

在这里插入图片描述

二、代码实现

2.1 Chomp1d 模块

  TCN 使用填充操作来保证卷积后的时间步不丢失,但填充会导致额外的时间步,因此需要 Chomp1d 来修剪掉多余部分,保证输入输出的时间维度一致。

class Chomp1d(nn.Module):def __init__(self, chomp_size):super(Chomp1d, self).__init__()self.chomp_size = chomp_sizedef forward(self, x):return x[:, :, :-self.chomp_size].contiguous()

  Chomp1d 的作用是对卷积结果的最后几个时间步进行修剪,这确保了卷积核在时间序列两端不会额外输出冗余的步长。

2.2 TemporalBlock 模块

  TemporalBlock 是 TCN 的基本构建单元,包含两层扩张卷积,每层后接激活函数和 Chomp1d 操作。

class TemporalBlock(nn.Module):def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout):super(TemporalBlock, self).__init__()# 第一层卷积self.ll_conv1 = nn.Conv1d(n_inputs, n_outputs, kernel_size, stride=stride, padding=padding, dilation=dilation)self.chomp1 = Chomp1d(padding)self.relu1 = nn.LeakyReLU()# 第二层卷积self.ll_conv2 = nn.Conv1d(n_outputs, n_outputs, kernel_size, stride=stride, padding=padding, dilation=dilation)self.chomp2 = Chomp1d(padding)self.relu2 = nn.LeakyReLU()# Dropout 作为正则化,防止过拟合self.dropout = nn.Dropout(dropout)def forward(self, x):# 第一个卷积、修剪、激活和 Dropoutout = self.ll_conv1(x)out = self.chomp1(out)out = self.relu1(out)out = self.dropout(out)# 第二个卷积、修剪、激活和 Dropoutout = self.ll_conv2(out)out = self.chomp2(out)out = self.relu2(out)out = self.dropout(out)return out
  • ll_conv1 和 ll_conv2 是两层扩张卷积层,dilation 参数决定了每层的感受野大小。

  • Chomp1d 保证卷积结果不会产生额外的时间步。

  • LeakyReLU 是非线性激活函数,为模型引入非线性。

  • Dropout 用于防止过拟合,通过随机丢弃一部分神经元。

2.3 TemporalConvNet 模块

  TemporalConvNet 是由多个 TemporalBlock 级联组成的模型,每一层的卷积感受野逐层递增。

class TemporalConvNet(nn.Module):def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.0):super(TemporalConvNet, self).__init__()layers = []self.num_levels = len(num_channels)for i in range(self.num_levels):dilation_size = 2 ** i  # 每层的扩张率递增in_channels = num_inputs if i == 0 else num_channels[i - 1]out_channels = num_channels[i]layers.append(TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size,padding=(kernel_size - 1) * dilation_size, dropout=dropout))self.network = nn.Sequential(*layers)def forward(self, x):return self.network(x)
  • TemporalConvNet 通过循环构建多层 TemporalBlock,每层的扩张率 dilation 是前一层的两倍,使得感受野指数级增长。

  • 使用 nn.Sequential 将所有层级联在一起,模型最终输出序列数据经过所有层的处理结果。

2.4 完整代码示例

  在这个例子中,输入数据有 8 个样本,每个样本有 3 个特征,序列长度为 10。经过 TCN 网络的三层处理,输出的特征维度从 3 增加到 64,但时间维度(10)保持不变。

import torch.nn as nn
import torch.nn.functional as F
import torchclass Chomp1d(nn.Module):def __init__(self, chomp_size):super(Chomp1d, self).__init__()self.chomp_size = chomp_sizedef forward(self, x):return x[:, :, : -self.chomp_size].contiguous()class TemporalBlock(nn.Module):def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout):super(TemporalBlock, self).__init__()self.n_inputs = n_inputsself.n_outputs = n_outputsself.kernel_size = kernel_sizeself.stride = strideself.dilation = dilationself.padding = paddingself.dropout = dropoutself.ll_conv1 = nn.Conv1d(n_inputs,n_outputs,kernel_size,stride=stride,padding=padding,dilation=dilation,)self.chomp1 = Chomp1d(padding)self.ll_conv2 = nn.Conv1d(n_outputs,n_outputs,kernel_size,stride=stride,padding=padding,dilation=dilation,)self.chomp2 = Chomp1d(padding)self.sigmoid = nn.Sigmoid()def net(self, x, block_num, params=None):layer_name = "ll_tc.ll_temporal_block" + str(block_num)if params is None:x = self.ll_conv1(x)else:x = F.conv1d(x,weight=params[layer_name + ".ll_conv1.weight"],bias=params[layer_name + ".ll_conv1.bias"],stride=self.stride,padding=self.padding,dilation=self.dilation,)x = self.chomp1(x)x = F.leaky_relu(x)return xdef init_weights(self):self.ll_conv1.weight.data.normal_(0, 0.01)self.ll_conv2.weight.data.normal_(0, 0.01)def forward(self, x, block_num, params=None):out = self.net(x, block_num, params)return outclass TemporalConvNet(nn.Module):def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.0):super(TemporalConvNet, self).__init__()layers = []self.num_levels = len(num_channels)for i in range(self.num_levels):dilation_size = 2 ** iin_channels = num_inputs if i == 0 else num_channels[i - 1]out_channels = num_channels[i]setattr(self,"ll_temporal_block{}".format(i),TemporalBlock(in_channels,out_channels,kernel_size,stride=1,dilation=dilation_size,padding=(kernel_size - 1) * dilation_size,dropout=dropout,),)def forward(self, x, params=None):for i in range(self.num_levels):temporal_block = getattr(self, "ll_temporal_block{}".format(i))x = temporal_block(x, i, params=params)return x# 定义一个 TCN 模型,输入通道数为 3,输出通道分别为 16, 32, 64,核大小为 2
tcn = TemporalConvNet(num_inputs=3, num_channels=[16, 32, 64], kernel_size=2, dropout=0.2)# 假设输入的张量形状为 (batch_size, num_inputs, sequence_length)
x = torch.randn(8, 3, 10)  # 8 个样本,3 个输入特征,序列长度为 10# 通过 TCN 进行前向传播
output = tcn(x)print(output.shape)  # 输出的形状为 (batch_size, 64, sequence_length),即 (8, 64, 10)

参考文献

[1] https://github.com/locuslab/TCN

[2] 如何理解扩张卷积(dilated convolution)

[3] 【机器学习】详解 扩张/膨胀/空洞卷积 (Dilated / Atrous Convolution)

相关文章:

时间卷积网络(TCN)原理+代码详解

目录 一、TCN原理1.1 因果卷积(Causal Convolution)1.2 扩张卷积(Dilated Convolution) 二、代码实现2.1 Chomp1d 模块2.2 TemporalBlock 模块2.3 TemporalConvNet 模块2.4 完整代码示例 参考文献 在理解 TCN 的原理之前&#xff…...

零散的知识

1.物化 在SQL中,物化(Materialization)是指将查询结果保存为物理数据结构以供后续使用的过程。这与普通的视图或查询不同,物化视图会存储查询的结果,而不是每次查询时都动态地重新计算数据。 ①物化视图 物化视图是一…...

Python读取pdf中的文字与表格

一、PyPDF2包安装 在Python中安装PyPDF2库,您可以使用pip包管理器。打开您的命令行工具(例如CMD、Terminal或Anaconda Prompt),然后输入以下命令: pip install PyPDF2 如果您使用的是Python 3,并且系统中…...

【MySQL 08】复合查询

目录 1.准备工作 2.多表查询 笛卡尔积 多表查询案例 3. 自连接 4.子查询 1.单行子查询 2.多行子查询 3.多列子查询 4.在from子句中使用子查询 5.合并查询 1.union 2.union all 1.准备工作 如下三个表,将作为示例,理解复合查询 EMP员工表…...

求1000以内的完数

题目:一个数如果恰好等于他的因子之和(包括1,但不包括这个数),这个数就是完数。编写算法找出1000之内的所有完数,并按下面格式输出其因子:28 its factors are 1,2,4,7,14 代码如下:…...

sqli-labs less-16 post提交dnslog注入

post提交DNSlog注入 第十六关和和十五关大差不大,可以使用布尔注入,时间盲注等,只不过闭合方式不一样,但是用布尔和时间盲太过于消耗时间,本次测试我将使用dnslog注入。 使用在线平台http://www.dnslog.cn/ 闭合方式…...

nginx报错|xquic|xqc_engine_create: fail|

一.问题描述 nginx使用xquic协议一切安装正常,nginx -s reload也正常,但就是访问不了网页 [emerg] 12342#0: |xquic|xqc_engine_create: fail| [emerg] 12342#0: |xquic|ngx_xquic_process_init|engine_init fail| [emerg] 12341#0: |xquic|xqc_engine_create: fai…...

Java虚拟机(JVM)

目录 内存区域划分堆(Heap)方法区(Method Area)程序计数器(Program Counter Register)虚拟机栈(VM Stack)本地方法栈(Native Method Stack) 类加载的过程类加…...

MQ 架构设计原理与消息中间件详解(三)

RabbitMQ实战解决方案 RabbitMQ死信队列 死信队列产生的背景 RabbitMQ死信队列俗称,备胎队列;消息中间件因为某种原因拒收该消息后,可以转移到死信队列中存放,死信队列也可以有交换机和路由key等。 产生死信队列的原因 消息投…...

大数据新视界 --大数据大厂之 Alluxio 数据缓存系统在大数据中的应用与配置

💖💖💖亲爱的朋友们,热烈欢迎你们来到 青云交的博客!能与你们在此邂逅,我满心欢喜,深感无比荣幸。在这个瞬息万变的时代,我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…...

PHP基本语法总结

目录 输出语句 注释 数据类型(变量) 局部和全局作用域 类型比较(松散比较与严格比较) 常量 运算符 并置运算符 不等于 逻辑运算符 条件语句 数组 关联数组 数组排序 一般数组 关联数组 循环 函数 变量函数 魔…...

尚硅谷rabbitmq 2024第30-33节 死信队列 答疑

Virtual host: Type: Name: Durabiity: Arguments: Default for virtual host w ququt.normal.video Durable x-dead-letter-exchange x-dead-1etter-routing-xey x-mAx-1ength X-m在88点0也-6E1 exchange.dead.letter.vide zouting.key.dead.ietter.v 10 String String Number…...

解锁空间距离计算的多种方式-含前端、空间数据库、后端

目录 前言 一、空间数据库求解 1、PostGIS实现 二、GIS前端组件求解 1、Leaflet.js距离测算 2、Turf.js前端计算 三、后台距离计算生成 1、欧式距离 2、Haversice球面距离 3、GeoTools距离计算 4、Gdal距离生成 5、geodesy距离计算 四、成果与生成对比 1、Java不…...

Windows 开发工具使用技巧 QT使用安装和使用技巧 QT快捷键

一、QT配置 1. 安装 Qt 开发框架 1、下载 1、进入下载地址 下载地址1 (官方, 需注册账号): https://www.qt.io/download下载地址2(推荐): http://download.qt.io/http://download.qt.io/archive/qt/ (或更直接的…...

【实战教程】SpringBoot全面指南:快速上手到项目实战(SpringBoot)

文章目录 【实战教程】SpringBoot全面指南:快速上手到项目实战(SpringBoot)1. SpringBoot介绍1.1 SpringBoot简介1.2系统要求1.3 SpringBoot和SpringMVC区别1.4 SpringBoot和SpringCloud区别 2.快速入门3. Web开发3.1 静态资源访问3.2 渲染Web页面3.3 YML与Properti…...

LeetCode讲解篇之1043. 分隔数组以得到最大和

文章目录 题目描述题解思路题解代码题目链接 题目描述 题解思路 对于这题我们这么考虑,我们选择以数字的第i个元素做为分隔子数组的右边界,我们需要计算当前分隔子数组的长度为多少时能让数组[0, i]进行分隔数组的和最大 我们用数组f表示[0, i)区间内的…...

Python知识点:结合Python工具,如何使用TfidfVectorizer进行文本特征提取

开篇,先说一个好消息,截止到2025年1月1日前,翻到文末找到我,赠送定制版的开题报告和任务书,先到先得!过期不候! 如何使用Python的TfidfVectorizer进行文本特征提取 在自然语言处理(…...

Diffusion models(扩散模型) 是怎么工作的

前言 给一个提示词, Midjourney, Stable Diffusion 和 DALL-E 可以生成很好看的图片,那么它们是怎么工作的呢?它们都用了 Diffusion models(扩散模型) 这项技术。 Diffusion models 正在成为生命科学等领域的一项尖端技术&…...

查找回收站里隐藏的文件

在Windows里,每个磁盘分区都有一个隐藏的回收站Recycle, 回收站里保存着用户删除的文件、图片、视频等数据,比如,C盘的回收站为C:\RECYCLE.BIN\,D盘的的回收站为D:\RECYCLE.BIN\,E盘的的回收站为E:\RECYCLE…...

[运维]2.elasticsearch-svc连接问题

Serverless 与容器决战在即?有了弹性伸缩就不一样了 - 阿里云云原生 - 博客园 当我部署好elasticsearch的服务后,由于个人习惯,一般服务会在name里带上svc,所以我elasticsearch服务的名字是elasticsearch-svc: [root…...

Python爬虫实战:研究MechanicalSoup库相关技术

一、MechanicalSoup 库概述 1.1 库简介 MechanicalSoup 是一个 Python 库,专为自动化交互网站而设计。它结合了 requests 的 HTTP 请求能力和 BeautifulSoup 的 HTML 解析能力,提供了直观的 API,让我们可以像人类用户一样浏览网页、填写表单和提交请求。 1.2 主要功能特点…...

深度学习在微纳光子学中的应用

深度学习在微纳光子学中的主要应用方向 深度学习与微纳光子学的结合主要集中在以下几个方向: 逆向设计 通过神经网络快速预测微纳结构的光学响应,替代传统耗时的数值模拟方法。例如设计超表面、光子晶体等结构。 特征提取与优化 从复杂的光学数据中自…...

ubuntu搭建nfs服务centos挂载访问

在Ubuntu上设置NFS服务器 在Ubuntu上,你可以使用apt包管理器来安装NFS服务器。打开终端并运行: sudo apt update sudo apt install nfs-kernel-server创建共享目录 创建一个目录用于共享,例如/shared: sudo mkdir /shared sud…...

【力扣数据库知识手册笔记】索引

索引 索引的优缺点 优点1. 通过创建唯一性索引,可以保证数据库表中每一行数据的唯一性。2. 可以加快数据的检索速度(创建索引的主要原因)。3. 可以加速表和表之间的连接,实现数据的参考完整性。4. 可以在查询过程中,…...

.Net框架,除了EF还有很多很多......

文章目录 1. 引言2. Dapper2.1 概述与设计原理2.2 核心功能与代码示例基本查询多映射查询存储过程调用 2.3 性能优化原理2.4 适用场景 3. NHibernate3.1 概述与架构设计3.2 映射配置示例Fluent映射XML映射 3.3 查询示例HQL查询Criteria APILINQ提供程序 3.4 高级特性3.5 适用场…...

将对透视变换后的图像使用Otsu进行阈值化,来分离黑色和白色像素。这句话中的Otsu是什么意思?

Otsu 是一种自动阈值化方法,用于将图像分割为前景和背景。它通过最小化图像的类内方差或等价地最大化类间方差来选择最佳阈值。这种方法特别适用于图像的二值化处理,能够自动确定一个阈值,将图像中的像素分为黑色和白色两类。 Otsu 方法的原…...

自然语言处理——Transformer

自然语言处理——Transformer 自注意力机制多头注意力机制Transformer 虽然循环神经网络可以对具有序列特性的数据非常有效,它能挖掘数据中的时序信息以及语义信息,但是它有一个很大的缺陷——很难并行化。 我们可以考虑用CNN来替代RNN,但是…...

Springboot社区养老保险系统小程序

一、前言 随着我国经济迅速发展,人们对手机的需求越来越大,各种手机软件也都在被广泛应用,但是对于手机进行数据信息管理,对于手机的各种软件也是备受用户的喜爱,社区养老保险系统小程序被用户普遍使用,为方…...

WebRTC调研

WebRTC是什么,为什么,如何使用 WebRTC有什么优势 WebRTC Architecture Amazon KVS WebRTC 其它厂商WebRTC 海康门禁WebRTC 海康门禁其他界面整理 威视通WebRTC 局域网 Google浏览器 Microsoft Edge 公网 RTSP RTMP NVR ONVIF SIP SRT WebRTC协…...

数据库正常,但后端收不到数据原因及解决

从代码和日志来看,后端SQL查询确实返回了数据,但最终user对象却为null。这表明查询结果没有正确映射到User对象上。 在前后端分离,并且ai辅助开发的时候,很容易出现前后端变量名不一致情况,还不报错,只是单…...