Skip to content

Were RNNs All We Needed?

Updated: at 15:06

摘要: Transformer在序列长度方面的可扩展性限制重新引发了对在训练时可并行化的递归序列模型的兴趣。因此,许多新颖的递归架构被提出,例如 S4、Mamba 和 Aaren,这些模型在性能上达到了相当的水平。在本研究中,我们重新探讨了十多年前提出的传统递归神经网络(RNNs):LSTM(1997年)和GRU(2014年)。虽然这些模型因需要通过时间反向传播(BPTT)而训练速度较慢,但我们发现,通过移除其隐藏状态在输入、遗忘和更新门中的依赖,LSTM 和 GRU 不再需要进行 BPTT,从而可以高效地并行训练。基于此,我们提出了精简版本(minLSTMs 和 minGRUs),它们具有以下特点:1. ** 显著减少参数数量**,相比传统版本更加轻量化;2. ** 在训练时完全可并行化**,在长度为512的序列上训练速度提高了175倍。最后,我们展示了这些简化版的旧RNN模型在实验性能上与近期的序列模型表现相当。

1. Intro

基于Transformer的模型计算复杂度在token length上都是平方增长的,限制了scale up,所以最近人们开始重新关注具有以下特点的模型:

  1. 训练时消耗的内存和token length成线性关系
  2. 推理时消耗常量内存

最近流行的Mamba、RWKV以及Aaren,都采用了Parallel Prefix Scan Algorithm。这篇文章重新审视之前的LSTM和GRU,他们被时代抛弃的主要原因就是训练时只能通过BPTT,速度太慢。

2. Background

2.1. LSTM

ft=σ(Lineardh([xt,ht1]))it=σ(Lineardh([xt,ht1]))c~t=tanh(Lineardh([xt,ht1]))ot=σ(Lineardh([xt,ht1]))ct=ftct1+itc~tht=ottanh(ct)\begin{align} &f_t = \sigma(\text{Linear}_{d_h}([x_t, h{t-1}]))\\ &i_t = \sigma(\text{Linear}_{d_h}([x_t, h{t-1}]))\\ &\tilde{c}_t = \tanh(\text{Linear}_{d_h}([x_t, h_{t-1}]))\\ &o_t = \sigma(\text{Linear}_{d_h}([x_t, h{t-1}]))\\ &c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t\\ &h_t = o_t \odot \tanh(c_t)\end{align}

其中,\odot 表示向量的按元素乘法tt 是当前时间步,hth_t 是输出的隐藏状态,[xt,ht1][x_t, h_{t-1}]xt x_tht1h_{t-1} 的拼接,dhd_h 是隐藏状态的维度,ctc_t 是在序列中维护信息的细胞状态c~t\tilde{c}_t 是候选细胞状态。LSTM 的输入门 iti_t 决定了从候选细胞状态中添加多少新信息,** 遗忘门** ftf_t 确定要丢弃多少现有细胞状态中的信息,而** 输出门** oto_t 控制了从细胞状态输出哪些信息。激活函数 σ\sigmatanh\tanh 用于缩放,确保输出不会发生爆炸或消失。

总参数量O(4dh(dx+dh))O(4d_h(d_x + d_h))

2.2. GRU

zt=σ(Lineard([xt,ht1]))rt=σ(Lineard([xt,ht1]))h~t=tanh(Lineard([xt,rtht1]))ht=(1zt)ht1+zth~t\begin{align} &z_t = \sigma(\text{Linear}_{d}([x_t, h{t-1}])) \\ &r_t = \sigma(\text{Linear}_{d}([x_t, h{t-1}])) \\ &\tilde{h}_t = \tanh(\text{Linear}_{d}([x_t, r_t \odot h_{t-1}])) \\ &h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t\end{align}

其中,h~t\tilde{h}_t 是候选隐藏状态,代表隐藏状态的潜在新值。GRU 将 LSTM 的输入门和遗忘门合并为一个更新门 zt(0,1)z_t \in (0, 1),该门控制了多少过去的信息得以保留(即 1zt1 - z_t),以及从候选隐藏状态中添加多少新信息(即zt z_t)。此外,GRU 省略了 LSTM 的输出门,并引入了** 重置门**rt r_t,该门决定了计算候选隐藏状态时应使用多少过去的信息。

减少了参数量,为O(3dh(dx+dh))O(3d_h(d_x + d_h))

2.3. Parallel Scan

为了替换Transformer提出的大部分RNN or RNN like算法都采用了prefix scan algorithm。

prefix scan algorithm:

并行扫描算法是一种并行计算方法,利用结合性操作符 \oplus(如加法和乘法),从 N 个顺序数据点计算出 N 个前缀计算结果。具体而言,该算法从输入数据 {uk}k=1N\{u_k\}{k=1}^{N} 高效地计算出{i=1kui}k=1N \left\{\sum{i=1}^{k} u_i\right\}_{k=1}^{N}

特别地,我们可以使用并行扫描算法高效计算如下形式的函数:

vt=atvt1+bt其中vt,at,btR,v0b0(Heinsen, 2023)v_t = a_t v_{t-1} + b_t \quad \text{其中} \quad v_t, a_t, b_t \in \mathbb{R}, \quad v_0 \leftarrow b_0 \quad (\text{Heinsen, 2023})

该方法的输入为(a1,,an) (a_1, \dots, a_n)(b0,b1,,bn)(b_0, b_1, \dots, b_n),并通过并行扫描计算出(v1,,vn)(v_1, \dots, v_n)

3. Methodology

上面的公式可以拓展到向量形式:

vt=atvt1+btv_t = a_t\odot v_{t-1} + b_t

可以看到LSTM和GRU里面有长得比较类似的模式。本文简化了RNN,移除输出范围的限制,并且确保在时间维度上是无关的,描述了miniGRU和miniLSTM。

3.1. A minimal GRU: miniGRU

3.1.1. Step 1: Drop previous hidden state dependencies from gates

GRU:

ht=(1zt)ht1+zth~t h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t

依赖于上一时刻,并且无法去除这种依赖,所以修改为:

zt=σ(Lineardh([xt,ht1]))zt=σ(Lineardh(xt))h~t=tanh(Lineardh([xt,rtht1]))h~t=tanh(Lineardh(xt))\begin{align*} z_t = \sigma(\text{Linear}_{d_h}([x_t, h{t-1}])) \quad &\rightarrow \quad z_t = \sigma(\text{Linear}_{d_h}(x_t)) \\ \tilde{h}t = \tanh(\text{Linear}_{d_h}([x_t, r_t \odot h_{t-1}])) \quad &\rightarrow \quad \tilde{h}_t = \tanh(\text{Linear}_{d_h}(x_t))\end{align*}

3.1.2. Step 2: Drop range restriction of candidate states

h~t=tanh(Lineardh(xt))h~t=Lineardh(xt) \tilde{h}_t = \tanh(\text{Linear}_{d_h}(x_t)) \quad \rightarrow \quad \tilde{h}_t = \text{Linear}_{d_h}(x_t)

原来的函数是帮助裁剪隐藏状态,但是现在只和自己相关之后不需要做这样的裁剪。

3.1.3. miniGRU

image.png

参数从O(3dh(dx+dh))O(3d_h(d_x+d_h))下降到O(2dhdx)O(2d_hd_x)

3.2. A Minimal LSTM: minLSTM

3.2.1. Step 1: Drop previous hidden state dependencies from gates

3.2.2. Step 2: Drop range restriction of candidate states

3.2.3. Step 3: Ensure output is time-independent in scale

对遗忘gate和input gate做归一化,并且需要保证它们time independent。

3.2.4. minLSTM

image.png

参数从O(4dh(dx+dh))O(4d_h(d_x+d_h))下降到O(3dhdx)O(3d_hd_x),并且支持并行扫描训练速度很快。

4. Were RNNs All We Needed?

4.1. Minimal LSTMs and GRUs are very efficient

传统的RNN用BPTT的训练效率相当低。min系列的新RNN可以通过并行扫描方法进行并行计算,效率更高。文章重点对比的目标是Mamba,batch size=64.

Runtime & Memory:

image.png

和sequence length基本是线性的了。

Other Tasks:

image.png

image.png

image.png

剩下的没有什么好讲的了,感觉实验太toy了,正大光明说就用了一张P100一张T4,感觉有点抽象。而且去掉了Hidden Dim之后还管自己叫RNN也有点奇怪。


Previous Post
A ConvNet for the 2020s
Next Post
SparseRT: Accelerating Unstructured Sparsity on GPUs for Deep Learning Inference