时间卷积网络(TCN):序列建模的强大工具(附Pytorch网络模型代码)

1. 引言

引用自:Bai S, Kolter J Z, Koltun V. An empirical evaluation of generic convolutional and recurrent networks for sequence modeling. arXiv[J]. arXiv preprint arXiv:1803.01271, 2018, 10.

在这里插入图片描述

时间卷积网络(Temporal Convolutional Network,简称TCN)是一种专门用于处理序列数据的深度学习模型。它结合了卷积神经网络(CNN)的并行处理能力和循环神经网络(RNN)的长期依赖建模能力,成为序列建模任务中的强大工具。实验证明,对于某些任务下的长序LSTM和GRU等RNN架构,因此如果大家有多输入单输出(MISO)或多输入多输出(MIMO)序列建模任务,可以尝试使用TCN来作为创新点。
在这里插入图片描述

2. TCN的核心特性

在这里插入图片描述
图1所示。TCN中的架构元素。(a)一个扩张的因果卷积,其扩张因子d = 1,2,4,滤波器大小k = 3。接收野能够覆盖输入序列中的所有值。(b) TCN残余块。当剩余输入和输出具有不同的维数时,添加1x1卷积。© TCN中剩余连接的示例。蓝线是残差函数中的过滤器,绿线是恒等映射。

2.1 序列建模任务描述

在定义网络结构之前,我们先强调序列建模任务的核心特性。假设我们有输入序列 x 0 , … , x T x_0, \ldots, x_T x0,,xT,并希望在每个时间点预测对应的输出 y 0 , … , y T y_0, \ldots, y_T y0,,yT。关键约束在于,预测某个时间点 t t t 的输出 y t y_t yt 时,我们只能利用此前观察到的输入 x 0 , … , x t x_0, \ldots, x_t x0,,xt。形式上讲,序列建模网络是任何函数 f : X T + 1 → Y T + 1 f : X^{T+1} \rightarrow Y^{T+1} f:XT+1YT+1,它生成如下映射:

y ^ 0 , … , y ^ T = f ( x 0 , … , x T ) \hat{y}_0, \ldots, \hat{y}_T = f(x_0, \ldots, x_T) y^0,,y^T=f(x0,,xT)

若要满足因果性约束,即 y t y_t yt 只依赖于 x 0 , … , x t x_0, \ldots, x_t x0,,xt,而不依赖于任何“未来”的输入 x t + 1 , … , x T x_{t+1}, \ldots, x_T xt+1,,xT。在序列建模的学习目标中,是找到网络 f f f,使其最小化实际输出与预测值间的预期损失, L ( y 0 , … , y T , f ( x 0 , … , x T ) ) L(y_0, \ldots, y_T, f(x_0, \ldots, x_T)) L(y0,,yT,f(x0,,xT)),其中序列和输出根据某一概率分布抽取。

2.2 因果卷积

TCN使用因果卷积(Causal Convolution)来确保模型不会违反时间顺序。因果卷积即输出只依赖于当前时刻及其之前的输入,而不依赖于未来的输入(因为当前的你看不到未来的数据)。在标准的卷积操作中,每个输出值都基于其周围的输入值,包括未来的时间点。但在因果卷积中,权重仅应用于当前和过去的输入值,确保了信息流的方向性,避免了未来信息泄露到当前输出中。为了实现这一点,通常会在卷积核的右侧填充零(称为因果填充),这样只有当前和过去的信息被用于计算输出。

数学表示:

y ( t ) = ∑ i = 0 k − 1 f ( i ) ⋅ x ( t − i ) y(t) = \sum_{i=0}^{k-1} f(i) \cdot x(t-i) y(t)=i=0k1f(i)x(ti)

其中, f f f是卷积核, k k k是卷积核大小, x x x是输入序列。

2.3 扩张卷积

为了增加感受野而不增加参数数量,TCN采用扩张卷积(Dilated Convolution)。扩张卷积,也被称为空洞卷积,是一种在卷积核之间插入空隙(即跳过某些输入单元)的卷积形式。这种技术允许模型在不增加参数数量的情况下捕获更大的感受野,从而更好地理解输入数据中的上下文信息。扩张因子(dilation factor)决定了卷积核中元素之间的间距,例如,如果扩张因子为2,则卷积核中的元素会间隔一个输入单元。

扩张卷积的数学表示:

y ( t ) = ∑ i = 0 k − 1 f ( i ) ⋅ x ( t − d ⋅ i ) y(t) = \sum_{i=0}^{k-1} f(i) \cdot x(t-d \cdot i) y(t)=i=0k1f(i)x(tdi)

其中, d d d是扩张率。

一个扩张的因果卷积如下图所示:
在这里插入图片描述

2.4 残差连接

TCN使用残差连接来缓解梯度消失问题并促进更深层网络的训练。残差连接是残差网络(ResNets)的关键组成部分,由何凯明等人提出。它的主要目的是解决深层神经网络训练中的梯度消失/爆炸问题,以及提高网络的训练效率和性能。在残差连接中,网络的某一层的输出直接加到几层之后的另一层上,形成所谓的“跳跃连接”。具体来说,假设有一个输入 x x x,经过几层后得到 F ( x ) F(x) F(x),那么最终的输出不是 F ( x ) F(x) F(x)而是 x + F ( x ) x+F(x) x+F(x),也就是输入+输出。这种结构允许梯度在反向传播时可以直接流回更早的层,减少了梯度消失的问题,并且使得网络能够有效地训练更深的架构。残差块的输出可以表示为:

o u t p u t = a c t i v a t i o n ( i n p u t + F ( i n p u t ) ) output = activation(input + F(input)) output=activation(input+F(input))

其中, F F F是卷积层和激活函数的组合,残差连接如下图所示:
在这里插入图片描述

3. TCN的网络结构

TCN的基本结构包括多个残差块,每个残差块包含:

  1. 一维因果卷积层
  2. 层归一化
  3. ReLU激活函数
  4. Dropout层

TCN的整体结构可以表示为:
在这里插入图片描述

4. TCN vs RNN

相比于RNN,TCN有以下优势:

  1. 并行计算:卷积操作可以并行执行,提高计算效率。
  2. 固定感受野:可以精确控制输出对过去输入的依赖范围。
  3. 灵活的感受野大小:通过调整网络深度和扩张率,可以轻松处理不同长度的序列。
  4. 稳定梯度:避免了RNN中的梯度消失/爆炸问题。

5. TCN的应用

TCN在多个领域表现出色,包括:

  • 时间序列预测
  • 语音合成
  • 机器翻译
  • 动作识别
  • 音频生成

本篇文章不靠卖代码赚取收益,麻烦给个点赞和关注,后续还会有开源的免费优化算法及其代码,栓Q!同时如果大家有想要的算法可以在评论区打出,如果有空的话我可以帮忙复现

TCN的实现

以下是使用PyTorch实现TCN核心组件的示例代码(可以直接调用):

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm


class Chomp1d(nn.Module):
    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :-self.chomp_size].contiguous()


class TemporalBlock(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
                                 self.conv2, self.chomp2, self.relu2, self.dropout2)
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)


class TemporalConvNet(nn.Module):
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i-1]
            out_channels = num_channels[i]
            layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size,
                                     padding=(kernel_size-1) * dilation_size, dropout=dropout)]

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)
        

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-20 10:44:02       58 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-20 10:44:02       60 阅读
  3. 在Django里面运行非项目文件

    2024-07-20 10:44:02       48 阅读
  4. Python语言-面向对象

    2024-07-20 10:44:02       60 阅读

热门阅读

  1. linux学习笔记整理: 关于linux:nginx服务器 2024/7/20;

    2024-07-20 10:44:02       18 阅读
  2. 初等数论精解【1】

    2024-07-20 10:44:02       18 阅读
  3. Base64编码与解码

    2024-07-20 10:44:02       24 阅读
  4. Android Studio关于Gradle及JDK问题解决

    2024-07-20 10:44:02       17 阅读
  5. Oracle(12)什么是主键(Primary Key)?

    2024-07-20 10:44:02       17 阅读
  6. 目标检测算法

    2024-07-20 10:44:02       15 阅读
  7. 使用python调用dll库

    2024-07-20 10:44:02       19 阅读
  8. 数据结构之栈、队列和数组的基本概念

    2024-07-20 10:44:02       17 阅读