License: CC BY 4.0
arXiv:2410.01201v2 [cs.LG] 04 Oct 2024

Were RNNs All We Needed?

Leo Feng (Mila – Université de Montréal & Borealis AI), leo.feng@mila.quebec
Frederick Tung (Borealis AI), frederick.tung@borealisai.com
Mohamed Osama Ahmed (Borealis AI), mohamed.o.ahmed@borealisai.com
Yoshua Bengio (Mila – Université de Montréal), yoshua.bengio@mila.quebec
Hossein Hajimirsadeghi (Borealis AI), hossein.hajimirsadeghi@borealisai.com
Abstract

The scalability limitations of Transformers regarding sequence length have renewed interest in recurrent sequence models that are parallelizable during training. As a result, many novel recurrent architectures, such as S4, Mamba, and Aaren, have been proposed that achieve comparable performance. In this work, we revisit traditional recurrent neural networks (RNNs) from over a decade ago: LSTMs (1997) and GRUs (2014). While these models were slow due to requiring backpropagation through time (BPTT), we show that by removing their hidden state dependencies from their input, forget, and update gates, LSTMs and GRUs no longer need BPTT and can be efficiently trained in parallel. Building on this, we introduce minimal versions (minLSTMs and minGRUs) that (1) use significantly fewer parameters than their traditional counterparts and (2) are fully parallelizable during training ($175\times$ faster for a sequence of length 512). Lastly, we show that these stripped-down versions of decade-old RNNs match the empirical performance of recent sequence models.

1 Introduction

Over the past few years, Transformers (Vaswani et al., 2017) have been the dominant architecture in many areas, leading to advancements in tasks like machine translation (Devlin et al., 2019), text generation (Brown et al., 2020), and more. However, Transformers have a quadratic computational complexity in the sequence length, making them prohibitively expensive for long sequences, especially in low-resource settings. As such, numerous works have investigated the design of more efficient alternatives that achieve competitive performance with that of Transformers. Recently, there has been a renewed interest in recurrent sequence models that can be trained efficiently by processing their context in parallel. These models (1) require only linear memory in the sequence length during training and (2) are rolled out recurrently token-by-token at inference time, requiring only constant memory. As a result, these models can scale to significantly longer sequences than Transformers.1

1The title of this paper pays tribute to the original Transformers paper, “Attention is All You Need”.

A family of efficiently trainable recurrent sequence models that has recently gained much traction is that of state-space models, specifically the recently proposed Mamba (Gu & Dao, 2024). Mamba (S6) is a state-space model that differentiates itself from prior works by leveraging input-dependent transitions. The recent success of Mamba and the proposals of many new variants of state-space models have led to several survey papers (Wang et al., 2024; Patro & Agneeswaran, 2024; Qu et al., 2024). Another extensively explored group of methods is those based on attention. Peng et al. (2023) proposed a linear attention model that can be written recurrently while being trained in parallel. Feng et al. (2024) showed that softmax attention (and Transformers) can be viewed as a recurrent neural network (RNN). Building on their RNN formulation of attention, they proposed Aaren, a softmax attention model that can be computed in parallel for efficient training or unrolled sequentially as an RNN for efficient inference. Although many recurrent models have been proposed with vastly different architectures, these recent state-of-the-art methods are all efficiently trainable using the same algorithm: the parallel prefix scan algorithm (Blelloch, 1990).

Inspired by the striking algorithmic similarities between the numerous recently proposed sequence models, we revisit LSTMs (Hochreiter & Schmidhuber, 1997) and GRUs (Cho et al., 2014) from a modern lens. As traditional RNNs from over a decade ago, LSTMs and GRUs are only computable sequentially and require backpropagation through time (BPTT) during training. As such, LSTMs and GRUs were far too slow to scale beyond a few hundred tokens, resulting in their deprecation. Revisiting these models, we show that by removing hidden state dependencies from their input, forget, and update gates, LSTMs and GRUs no longer need BPTT and can be trained efficiently using the parallel scan algorithm. Building on this, we simplify LSTMs and GRUs further by removing their constraints on output range (i.e., their use of $\tanh$) and ensuring their output is time-independent in scale. These steps result in minimal versions (minLSTMs and minGRUs) that (1) use significantly fewer parameters than their traditional counterparts and (2) are trainable in parallel ($175\times$ faster for a context length of 512). Finally, we show that these stripped-down versions of decade-old RNNs match the empirical performance of recent sequence models.

2 Background

In this section, we review recurrent neural networks (RNNs). RNNs are recurrent sequence models that maintain a hidden state across time steps, capturing temporal dependencies. As such, RNNs are particularly suitable for sequence modelling settings such as those involving time series, natural language processing, and other sequential tasks where context from previous steps informs the current prediction. Vanilla RNNs (Elman, 1990), however, struggle with issues of vanishing and exploding gradients, limiting their ability to learn long-term dependencies.

2.1 LSTM

Addressing this limitation, Hochreiter & Schmidhuber (1997) introduced Long Short-Term Memory (LSTM) networks. LSTMs are enhanced RNNs designed to mitigate the vanishing gradient problem, allowing the model to learn long-term dependencies. LSTMs are computed as follows:

\[
\begin{aligned}
\bm{f}_t &= \sigma(\mathrm{Linear}_{d_h}([\bm{x}_t, \bm{h}_{t-1}]))\\
\bm{i}_t &= \sigma(\mathrm{Linear}_{d_h}([\bm{x}_t, \bm{h}_{t-1}]))\\
\tilde{\bm{c}}_t &= \tanh(\mathrm{Linear}_{d_h}([\bm{x}_t, \bm{h}_{t-1}]))\\
\bm{o}_t &= \sigma(\mathrm{Linear}_{d_h}([\bm{x}_t, \bm{h}_{t-1}]))\\
\bm{c}_t &= \bm{f}_t \odot \bm{c}_{t-1} + \bm{i}_t \odot \tilde{\bm{c}}_t\\
\bm{h}_t &= \bm{o}_t \odot \tanh(\bm{c}_t)
\end{aligned}
\]

where $\odot$ denotes element-wise multiplication of vectors, $t$ is the current timestep, $\bm{h}_t$ is the outputted hidden state, $[\bm{x}_t, \bm{h}_{t-1}]$ denotes the concatenation of $\bm{x}_t$ with $\bm{h}_{t-1}$, $d_h$ is the size of the hidden state, $\bm{c}_t$ is a cell state that maintains information over the sequence, $\tilde{\bm{c}}_t$ is the candidate cell state to be added, and $\bm{i}_t$, $\bm{f}_t$, and $\bm{o}_t$ are gating mechanisms. The input gate $\bm{i}_t$ controls how much new information from the candidate cell state is added. The forget gate $\bm{f}_t$ determines the proportion of information in the cell state to discard. The output gate $\bm{o}_t$ decides what information from the cell state should be outputted. The $\sigma$ and $\tanh$ activations are used for scaling to ensure that the output does not explode or vanish. An LSTM module maintains both a cell state and a hidden state and, in total, contains $O(4d_h(d_x + d_h))$ parameters.
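To make the recurrence concrete, the following is a minimal sketch of a single LSTM step in PyTorch (an illustrative re-implementation under our own naming, not the paper's code):

```python
import torch
import torch.nn as nn

class LSTMCellSketch(nn.Module):
    """Illustrative LSTM step: all four gates read the concatenation [x_t, h_{t-1}]."""
    def __init__(self, d_x, d_h):
        super().__init__()
        # A single linear map producing f_t, i_t, c~_t, o_t at once:
        # O(4 * d_h * (d_x + d_h)) parameters, matching the count above.
        self.gates = nn.Linear(d_x + d_h, 4 * d_h)

    def forward(self, x_t, h_prev, c_prev):
        f, i, c_tilde, o = self.gates(torch.cat([x_t, h_prev], dim=-1)).chunk(4, dim=-1)
        f, i, o = torch.sigmoid(f), torch.sigmoid(i), torch.sigmoid(o)
        c_t = f * c_prev + i * torch.tanh(c_tilde)   # cell state recurrence
        h_t = o * torch.tanh(c_t)                    # hidden state output
        return h_t, c_t
```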

2.2 GRU

Simplifying LSTM, Cho et al. (2014) introduced the Gated Recurrent Unit (GRU), which uses only two gates and a single state instead of LSTM’s three gates and two states (hidden and cell states). GRU’s reduced complexity leads to faster training and inference while achieving competitive performance on many tasks. GRUs are computed as follows:

\[
\begin{aligned}
\bm{z}_t &= \sigma(\mathrm{Linear}_{d_h}([\bm{x}_t, \bm{h}_{t-1}]))\\
\bm{r}_t &= \sigma(\mathrm{Linear}_{d_h}([\bm{x}_t, \bm{h}_{t-1}]))\\
\tilde{\bm{h}}_t &= \tanh(\mathrm{Linear}_{d_h}([\bm{x}_t, \bm{r}_t \odot \bm{h}_{t-1}]))\\
\bm{h}_t &= (\bm{1} - \bm{z}_t) \odot \bm{h}_{t-1} + \bm{z}_t \odot \tilde{\bm{h}}_t
\end{aligned}
\]

where $\tilde{\bm{h}}_t$ is the candidate hidden state that represents a potential new value for the hidden state. GRU combines LSTM’s forget and input gates into a single update gate $\bm{z}_t \in (0,1)$, which decides how much of the past information to carry forward (i.e., $\bm{1} - \bm{z}_t$) and how much new information from the candidate hidden state to add (i.e., $\bm{z}_t$). Additionally, LSTM’s output gate is removed and, instead, a reset gate $\bm{r}_t$ is added that controls how much past information is used in computing the candidate hidden state. GRU reduces the total number of parameters and computations, requiring only $O(3d_h(d_x + d_h))$ parameters. However, GRUs and LSTMs are only computable sequentially. As a result, during training they require backpropagating their gradients through time (BPTT), requiring linear training time and greatly limiting their ability to scale to long contexts.
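For comparison, a sketch of a single GRU step under the same (assumed) naming conventions; the update gate interpolates between the previous hidden state and the candidate:

```python
import torch
import torch.nn as nn

class GRUCellSketch(nn.Module):
    """Illustrative GRU step: two gates, one hidden state, O(3 * d_h * (d_x + d_h)) parameters."""
    def __init__(self, d_x, d_h):
        super().__init__()
        self.zr = nn.Linear(d_x + d_h, 2 * d_h)     # update gate z_t and reset gate r_t
        self.candidate = nn.Linear(d_x + d_h, d_h)  # candidate hidden state h~_t

    def forward(self, x_t, h_prev):
        z, r = torch.sigmoid(self.zr(torch.cat([x_t, h_prev], dim=-1))).chunk(2, dim=-1)
        h_tilde = torch.tanh(self.candidate(torch.cat([x_t, r * h_prev], dim=-1)))
        # Convex combination: the proportions (1 - z_t) and z_t sum to 1.
        return (1 - z) * h_prev + z * h_tilde
```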

2.3 Parallel Scan

Due to this limitation, Transformers replaced LSTMs and GRUs as the de facto sequence modelling method for years by leveraging parallelization during training. However, Transformers have a quadratic complexity in the sequence length, limiting their ability to scale to long contexts. Recently, many new recurrent models have been proposed as replacements for Transformers that achieve comparable performance and are trainable in parallel, while avoiding the BPTT issue that traditional RNNs (e.g., LSTMs and GRUs) faced. Although many different architectures have been proposed, many of these models are efficiently trained using the parallel prefix scan algorithm (Blelloch, 1990).

The parallel scan algorithm is a parallel computation method for computing $N$ prefix computations from $N$ sequential data points via an associative operator $\oplus$ (e.g., $+$ and $\times$). The algorithm efficiently computes $\{\bigoplus_{i=1}^{k} u_i\}_{k=1}^{N}$ from $\{u_k\}_{k=1}^{N}$. In particular, the parallel scan method can be applied to efficiently compute a popular family of functions: $v_t = a_t v_{t-1} + b_t$ where $v_t, a_t, b_t \in \mathbb{R}$ and $v_0 \leftarrow b_0$ (Heinsen, 2023). The method takes as input $a_1, \ldots, a_n$ and $b_0, b_1, \ldots, b_n$ and computes $v_1, \ldots, v_n$ via parallel scans.
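As a concrete illustration (our own sketch, not the paper's or Heinsen's implementation), the recurrence $v_t = a_t v_{t-1} + b_t$ can be unrolled into cumulative products and sums, both of which are standard parallelizable scan primitives; practical implementations such as Heinsen (2023) perform the same computation in log-space for numerical stability:

```python
import torch

def scan_affine(a, b, v0=None):
    """Compute v_t = a_t * v_{t-1} + b_t for t = 1..n along dim 0 (v_0 = 0 if v0 is None).

    Unrolling the recurrence gives v_t = A_t * (v_0 + sum_{i<=t} b_i / A_i) with
    A_t = a_1 * ... * a_t, so the whole sequence reduces to a cumulative product
    and a cumulative sum. Numerically naive sketch: log-space variants avoid
    overflow/underflow when the a_t are small.
    """
    a_star = torch.cumprod(a, dim=0)            # A_t
    summed = torch.cumsum(b / a_star, dim=0)    # sum_{i<=t} b_i / A_i
    if v0 is not None:
        summed = summed + v0
    return a_star * summed                      # v_1, ..., v_n

# Sanity check against the sequential definition.
a = 0.5 + 0.5 * torch.rand(8, 4)                # keep a_t in [0.5, 1) for stability
b = torch.randn(8, 4)
v, v_seq = torch.zeros(4), []
for t in range(8):
    v = a[t] * v + b[t]
    v_seq.append(v)
assert torch.allclose(scan_affine(a, b), torch.stack(v_seq), atol=1e-5)
```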

3 Methodology

Naturally, the aforementioned algorithm also extends to vectors: $\bm{v}_t = \bm{a}_t \odot \bm{v}_{t-1} + \bm{b}_t$ where $\odot$ is element-wise multiplication. Interestingly, the GRU and LSTM state recurrences resemble this vector formulation. In this section, we show that GRUs and LSTMs are trainable via parallel scan by simplifying and removing several hidden state dependencies from their various gates. Building on this, we further simplify these RNNs by removing their constraints on output range (i.e., $\tanh$) and ensuring the outputs are time-independent in scale. Combining these steps, we describe minimal versions of GRUs and LSTMs (minGRUs and minLSTMs) that are trainable via parallel scan and perform comparably to Transformers and recently proposed sequence methods.

3.1 A Minimal GRU: minGRU

3.1.1 Step 1: Drop previous hidden state dependencies from gates

We revisit GRU’s hidden state recurrence, which works as follows:

\[
\bm{h}_t = (\bm{1} - \bm{z}_t) \odot \bm{h}_{t-1} + \bm{z}_t \odot \tilde{\bm{h}}_t
\]

We can observe that this recurrence resembles the aforementioned parallel scan formulation with $\bm{a}_t \leftarrow (\bm{1} - \bm{z}_t)$, $\bm{b}_t \leftarrow \bm{z}_t \odot \tilde{\bm{h}}_t$, and $\bm{v}_t \leftarrow \bm{h}_t$. However, $\bm{z}_t$ and $\tilde{\bm{h}}_t$ depend on the previous hidden state $\bm{h}_{t-1}$, i.e., $\bm{z}_t = \sigma(\mathrm{Linear}_{d_h}([\bm{x}_t, \bm{h}_{t-1}]))$ and $\tilde{\bm{h}}_t = \tanh(\mathrm{Linear}_{d_h}([\bm{x}_t, \bm{r}_t \odot \bm{h}_{t-1}]))$. As a result, it is not possible to apply the parallel scan as is, since the algorithm’s inputs $\bm{a}_1, \ldots, \bm{a}_n$ and $\bm{b}_1, \ldots, \bm{b}_n$ are conditional on already knowing its outputs $\bm{h}_1, \ldots, \bm{h}_{n-1}$.

We can remedy this by simplifying GRUs, removing their dependencies on the previous hidden state (i.e., $\bm{h}_{t-1}$). Specifically, the changes are as follows:

\[
\begin{aligned}
\bm{z}_t &= \sigma(\mathrm{Linear}_{d_h}([\bm{x}_t, \bm{h}_{t-1}]))\\
\bm{r}_t &= \sigma(\mathrm{Linear}_{d_h}([\bm{x}_t, \bm{h}_{t-1}]))\\
\tilde{\bm{h}}_t &= \tanh(\mathrm{Linear}_{d_h}([\bm{x}_t, \bm{r}_t \odot \bm{h}_{t-1}]))
\end{aligned}
\quad\Rightarrow\quad
\begin{aligned}
\bm{z}_t &= \sigma(\mathrm{Linear}_{d_h}(\bm{x}_t))\\
\tilde{\bm{h}}_t &= \tanh(\mathrm{Linear}_{d_h}(\bm{x}_t))
\end{aligned}
\]

By removing the dependence on $\bm{h}_{t-1}$ from the candidate hidden state $\tilde{\bm{h}}_t$, the reset gate $\bm{r}_t$ that would control the weight on $\bm{h}_{t-1}$ is also no longer needed and is removed. Without the dependencies on previous hidden states, the inputs to the algorithm $\bm{a}_1, \ldots, \bm{a}_n$ and $\bm{b}_1, \ldots, \bm{b}_n$ are all easily computed in parallel and can thus be used to compute $\bm{h}_1, \ldots, \bm{h}_n$ efficiently via the parallel scan.

3.1.2 Step 2: Drop range restriction of candidate states

In GRU’s hidden state recurrence, the proportion carried over from the previous hidden state ($\bm{1} - \bm{z}_t$) and the amount added for the new candidate hidden state ($\bm{z}_t$) sum to $\bm{1}$. As a result, the scale of GRU’s hidden state value is time-independent. Instead, the scale of its hidden state depends on that of its candidate hidden states $\tilde{\bm{h}}_t$. The hyperbolic tangent function ($\tanh$) plays a crucial role in LSTMs and GRUs, restricting the range of (candidate) hidden states, i.e., $\tilde{\bm{h}}_t, \bm{h}_t \in (-1, 1)^{d_h}$. The $\tanh$ helps stabilize training and mitigates vanishing gradients that result from applying sigmoid ($\sigma$) activations to linear transformations of the hidden state (e.g., $\bm{z}_t = \sigma(\mathrm{Linear}_{d_h}([\bm{x}_t, \bm{h}_{t-1}]))$). In the previous step, these hidden state dependencies were removed. As such, we can simplify GRU further by removing the range restriction ($\tanh$) on the (candidate) hidden states as follows:

\[
\tilde{\bm{h}}_t = \tanh(\mathrm{Linear}_{d_h}(\bm{x}_t)) \quad\Rightarrow\quad \tilde{\bm{h}}_t = \mathrm{Linear}_{d_h}(\bm{x}_t)
\]

3.1.3 minGRU

Combining the two simplification steps results in a minimal version of GRU (minGRU):

\[
\underbrace{\begin{aligned}
\bm{h}_t &= (\bm{1} - \bm{z}_t) \odot \bm{h}_{t-1} + \bm{z}_t \odot \tilde{\bm{h}}_t\\
\bm{z}_t &= \sigma(\mathrm{Linear}_{d_h}([\bm{x}_t, \bm{h}_{t-1}]))\\
\bm{r}_t &= \sigma(\mathrm{Linear}_{d_h}([\bm{x}_t, \bm{h}_{t-1}]))\\
\tilde{\bm{h}}_t &= \tanh(\mathrm{Linear}_{d_h}([\bm{x}_t, \bm{r}_t \odot \bm{h}_{t-1}]))
\end{aligned}}_{\text{GRU}}
\quad\Rightarrow\quad
\underbrace{\begin{aligned}
\bm{h}_t &= (\bm{1} - \bm{z}_t) \odot \bm{h}_{t-1} + \bm{z}_t \odot \tilde{\bm{h}}_t\\
\bm{z}_t &= \sigma(\mathrm{Linear}_{d_h}(\bm{x}_t))\\
\tilde{\bm{h}}_t &= \mathrm{Linear}_{d_h}(\bm{x}_t)
\end{aligned}}_{\text{minGRU}}
\]

The resulting model is significantly more efficient than the original GRU: (1) it requires only $O(2d_h d_x)$ parameters instead of GRU’s $O(3d_h(d_x + d_h))$ parameters, where $d_x$ and $d_h$ correspond to the sizes of $\bm{x}_t$ and $\bm{h}_t$ respectively; and (2) it can be trained in parallel using the parallel scan algorithm, speeding up training significantly. In Section 4.1, we show that this corresponded to a $175\times$ speedup in training steps for a sequence length of 512 on a T4 GPU. The parameter efficiency gains are also significant. Typically, in RNNs, state expansion is performed (i.e., $d_h = \alpha d_x$ where $\alpha \geq 1$), allowing the models to more readily learn features from their inputs. minGRU uses only approximately 33%, 22%, 17%, or 13% of the parameters of GRU when $\alpha = 1, 2, 3,$ or $4$ respectively.
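Below is a minimal sketch of a minGRU layer in PyTorch; this is our own illustrative implementation (the paper's pseudocode and official PyTorch code are in the Appendix), reusing the naive `scan_affine` helper sketched in Section 2.3 with $\bm{a}_t = \bm{1} - \bm{z}_t$ and $\bm{b}_t = \bm{z}_t \odot \tilde{\bm{h}}_t$:

```python
import torch
import torch.nn as nn

def scan_affine(a, b):
    # v_t = a_t * v_{t-1} + b_t with v_0 = 0 (naive version; see the Section 2.3 sketch).
    a_star = torch.cumprod(a, dim=0)
    return a_star * torch.cumsum(b / a_star, dim=0)

class MinGRUSketch(nn.Module):
    """Illustrative minGRU: gates depend only on x_t, so training uses a scan instead of BPTT."""
    def __init__(self, d_x, d_h):
        super().__init__()
        self.to_z = nn.Linear(d_x, d_h)        # update gate, no h_{t-1} input
        self.to_h_tilde = nn.Linear(d_x, d_h)  # candidate state, no tanh
        # O(2 * d_h * d_x) parameters in total.

    def forward(self, x):                      # x: (seq_len, batch, d_x)
        z = torch.sigmoid(self.to_z(x))
        h_tilde = self.to_h_tilde(x)
        # h_t = (1 - z_t) * h_{t-1} + z_t * h~_t, computed for all t via the scan.
        return scan_affine(1 - z, z * h_tilde)
```

At inference time, the same module can instead be unrolled step by step using the recurrence directly, requiring only constant memory.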

3.2 A Minimal LSTM: minLSTM

3.2.1 Step 1: Drop previous hidden state dependencies from gates

Revisiting LSTMs, we focus on their cell state recurrence which works as follows:

\[
\bm{c}_t = \bm{f}_t \odot \bm{c}_{t-1} + \bm{i}_t \odot \tilde{\bm{c}}_t
\]

Similar to GRU’s hidden state recurrence, we can see that LSTM’s cell state recurrence resembles the aforementioned parallel scan formulation $\bm{v}_t = \bm{a}_t \odot \bm{v}_{t-1} + \bm{b}_t$ with $\bm{a}_t \leftarrow \bm{f}_t$, $\bm{b}_t \leftarrow \bm{i}_t \odot \tilde{\bm{c}}_t$, and $\bm{v}_t \leftarrow \bm{c}_t$. However, $\bm{f}_t$, $\bm{i}_t$, and $\tilde{\bm{c}}_t$ depend on the previous hidden state $\bm{h}_{t-1}$. As such, LSTM’s cell state recurrence cannot apply the parallel scan algorithm as is. We can address this in a similar fashion to GRU by removing the hidden state dependencies as follows:

\[
\begin{aligned}
\bm{f}_t &= \sigma(\mathrm{Linear}_{d_h}([\bm{x}_t, \bm{h}_{t-1}]))\\
\bm{i}_t &= \sigma(\mathrm{Linear}_{d_h}([\bm{x}_t, \bm{h}_{t-1}]))\\
\tilde{\bm{c}}_t &= \tanh(\mathrm{Linear}_{d_h}([\bm{x}_t, \bm{h}_{t-1}]))
\end{aligned}
\quad\Rightarrow\quad
\begin{aligned}
\bm{f}_t &= \sigma(\mathrm{Linear}_{d_h}(\bm{x}_t))\\
\bm{i}_t &= \sigma(\mathrm{Linear}_{d_h}(\bm{x}_t))\\
\tilde{\bm{c}}_t &= \tanh(\mathrm{Linear}_{d_h}(\bm{x}_t))
\end{aligned}
\]

3.2.2 Step 2: Drop range restriction of candidate states

Similar to GRUs, LSTMs leverage the hyperbolic tangent function ($\tanh$) to restrict the range of their states to $(-1, 1)$. LSTMs apply the range restriction twice: once when computing the candidate cell state and once when computing the hidden state. In this step, we drop both as follows:

\[
\begin{aligned}
\tilde{\bm{c}}_t &= \tanh(\mathrm{Linear}_{d_h}(\bm{x}_t))\\
\bm{h}_t &= \bm{o}_t \odot \tanh(\bm{c}_t)
\end{aligned}
\quad\Rightarrow\quad
\begin{aligned}
\tilde{\bm{c}}_t &= \mathrm{Linear}_{d_h}(\bm{x}_t)\\
\bm{h}_t &= \bm{o}_t \odot \bm{c}_t
\end{aligned}
\]

3.2.3 Step 3: Ensure output is time-independent in scale

In many sequence modelling settings (e.g., text generation), the optimization objective/target is time-independent in scale. Recall LSTM’s cell state recurrence $\bm{c}_t = \bm{f}_t \odot \bm{c}_{t-1} + \bm{i}_t \odot \tilde{\bm{c}}_t$ where $\bm{i}_t, \bm{f}_t \in (0,1)^{d_h}$, and GRU’s hidden state recurrence2, $\bm{h}^{GRU}_t = (\bm{1} - \bm{z}_t) \odot \bm{h}^{GRU}_{t-1} + \bm{z}_t \odot \tilde{\bm{h}}^{GRU}_t$ where $\bm{z}_t \in (0,1)^{d_h}$. GRUs retain $(\bm{1} - \bm{z}_t) \in (0,1)$ of the previous hidden state and add $\bm{z}_t$ of the new candidate state. Since these proportions sum to $\bm{1}$, the model ensures its outputs (i.e., hidden states) are time-independent in scale. In contrast, LSTM’s forget and input gates are computed independently (e.g., $\bm{f}_t, \bm{i}_t \rightarrow \bm{1}$ or $\bm{f}_t, \bm{i}_t \rightarrow \bm{0}$), making its cell states time-dependent in scale3 and optimization more difficult. As such, we ensure LSTM’s output is time-independent in scale.

2A superscript is added to differentiate GRU’s hidden state from LSTM’s.
3For example, $\bm{c}_t \rightarrow \bm{c}_0 + \sum_{i=1}^{t} \tilde{\bm{c}}_i$ when $\bm{f}_{1:t}, \bm{i}_{1:t} \rightarrow 1$, growing in scale as the sequence length increases.

To do so, we can simply normalize the two gates, i.e., $\bm{f}'_t, \bm{i}'_t \leftarrow \frac{\bm{f}_t}{\bm{f}_t + \bm{i}_t}, \frac{\bm{i}_t}{\bm{f}_t + \bm{i}_t}$, ensuring that $\bm{f}'_t + \bm{i}'_t = \bm{1}$ and that the scale of LSTM’s cell state is time-independent. To ensure the hidden state is also time-independent in scale, we drop the output gate $\bm{o}_t$, which scales the hidden state. Without the output gate, the normalized hidden state is equal to the cell state, i.e., $\bm{h}_t = \bm{o}_t \odot \bm{c}_t \Rightarrow \bm{h}_t = \bm{c}_t$, making it unnecessary to keep both a hidden state and a cell state. As such, we drop the cell state as well. In summary, the modifications are as follows:

\[
\begin{aligned}
\bm{h}_t &= \bm{o}_t \odot \bm{c}_t\\
\bm{o}_t &= \sigma(\mathrm{Linear}_{d_h}(\bm{x}_t))\\
\bm{c}_t &= \bm{f}_t \odot \bm{c}_{t-1} + \bm{i}_t \odot \tilde{\bm{c}}_t\\
\tilde{\bm{c}}_t &= \mathrm{Linear}_{d_h}(\bm{x}_t)
\end{aligned}
\quad\Rightarrow\quad
\begin{aligned}
\bm{h}_t &= \bm{f}'_t \odot \bm{h}_{t-1} + \bm{i}'_t \odot \tilde{\bm{h}}_t\\
\tilde{\bm{h}}_t &= \mathrm{Linear}_{d_h}(\bm{x}_t)\\
\bm{f}'_t, \bm{i}'_t &\leftarrow \frac{\bm{f}_t}{\bm{f}_t + \bm{i}_t}, \frac{\bm{i}_t}{\bm{f}_t + \bm{i}_t}
\end{aligned}
\]

Notably, GRUs do not need this step as their outputs are already time-independent in scale.

3.2.4 minLSTM

Combining the three steps results in a minimal version of LSTM (minLSTM):

\[
\underbrace{\begin{aligned}
\bm{h}_t &= \bm{o}_t \odot \tanh(\bm{c}_t)\\
\bm{c}_t &= \bm{f}_t \odot \bm{c}_{t-1} + \bm{i}_t \odot \tilde{\bm{c}}_t\\
\bm{f}_t &= \sigma(\mathrm{Linear}_{d_h}([\bm{x}_t, \bm{h}_{t-1}]))\\
\bm{i}_t &= \sigma(\mathrm{Linear}_{d_h}([\bm{x}_t, \bm{h}_{t-1}]))\\
\bm{o}_t &= \sigma(\mathrm{Linear}_{d_h}([\bm{x}_t, \bm{h}_{t-1}]))\\
\tilde{\bm{c}}_t &= \tanh(\mathrm{Linear}_{d_h}([\bm{x}_t, \bm{h}_{t-1}]))
\end{aligned}}_{\text{LSTM}}
\quad\Rightarrow\quad
\underbrace{\begin{aligned}
\bm{h}_t &= \bm{f}'_t \odot \bm{h}_{t-1} + \bm{i}'_t \odot \tilde{\bm{h}}_t\\
\bm{f}'_t, \bm{i}'_t &\leftarrow \frac{\bm{f}_t}{\bm{f}_t + \bm{i}_t}, \frac{\bm{i}_t}{\bm{f}_t + \bm{i}_t}\\
\bm{f}_t &= \sigma(\mathrm{Linear}_{d_h}(\bm{x}_t))\\
\bm{i}_t &= \sigma(\mathrm{Linear}_{d_h}(\bm{x}_t))\\
\tilde{\bm{h}}_t &= \mathrm{Linear}_{d_h}(\bm{x}_t)
\end{aligned}}_{\text{minLSTM}}
\]

The minimal version (minLSTM) is significantly more efficient: (1) it requires only $O(3d_h d_x)$ parameters compared to LSTM’s $O(4d_h(d_x + d_h))$; and (2) it can be trained in parallel using the parallel scan algorithm, speeding up training significantly. For example, in Section 4.1, we found that minLSTM corresponded to a $235\times$ speedup for a sequence of length 512 compared to LSTM on a T4 GPU. In terms of parameter efficiency, minLSTM uses only 38%, 25%, 19%, or 15% of the parameters of LSTM when $\alpha = 1, 2, 3,$ or $4$ respectively, where $d_h = \alpha d_x$.
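Analogously, a minimal minLSTM sketch (again our own illustrative code rather than the Appendix implementation) normalizes the forget and input gates before feeding them to the same affine scan:

```python
import torch
import torch.nn as nn

def scan_affine(a, b):
    # v_t = a_t * v_{t-1} + b_t with v_0 = 0 (naive version; see the Section 2.3 sketch).
    a_star = torch.cumprod(a, dim=0)
    return a_star * torch.cumsum(b / a_star, dim=0)

class MinLSTMSketch(nn.Module):
    """Illustrative minLSTM: f_t, i_t, h~_t depend only on x_t, and the gates are
    normalized so that f'_t + i'_t = 1, keeping the state time-independent in scale."""
    def __init__(self, d_x, d_h):
        super().__init__()
        self.to_f = nn.Linear(d_x, d_h)
        self.to_i = nn.Linear(d_x, d_h)
        self.to_h_tilde = nn.Linear(d_x, d_h)  # O(3 * d_h * d_x) parameters in total

    def forward(self, x):                      # x: (seq_len, batch, d_x)
        f = torch.sigmoid(self.to_f(x))
        i = torch.sigmoid(self.to_i(x))
        f_prime, i_prime = f / (f + i), i / (f + i)
        h_tilde = self.to_h_tilde(x)
        # h_t = f'_t * h_{t-1} + i'_t * h~_t, computed for all t via the scan.
        return scan_affine(f_prime, i_prime * h_tilde)
```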

4 Were RNNs All We Needed?

In this section, we compare the minimal versions (minLSTMs and minGRUs) with their traditional counterparts (LSTMs and GRUs) and modern sequence models. Pseudocode, PyTorch implementation, and detailed information regarding the experiment setup are available in the Appendix.

4.1 Minimal LSTMs and GRUs are very efficient

Figure 1: Training runtime (left), speedup (middle), and memory footprint (right) on a T4 GPU for a batch size of 64. In the training runtime plot (left), the minGRU, minLSTM, and Mamba lines overlap; these methods are approximately the same in training runtime.

At test time, recurrent sequence models are rolled out sequentially, making their inference efficient. Instead, the bottleneck of traditional RNNs is their training, which requires linear training time (backpropagating through time) and resulted in their eventual deprecation. The renewed interest in recurrent sequence models is due to many new architectures being efficiently trained in parallel (Gu et al., 2021). In this section, we compare the resources required to train the traditional RNNs (LSTM and GRU), their minimal versions (minLSTM and minGRU), and a recent state-of-the-art sequence model. In particular, we focus on the comparison with Mamba (Gu & Dao, 2024), which has seen significant popularity recently. For these experiments, we consider a batch size of 64 and vary the sequence length. We measure the total runtime and memory complexity of performing a forward pass through the models, computing a loss, and computing gradients via a backward pass.

Runtime. In terms of runtime (see Figure 1 (left)), the simplified versions of LSTM and GRU (minLSTM and minGRU) and Mamba achieve similar runtimes. Averaging over 100 runs, the runtimes for a sequence length of 512 for minLSTM, minGRU, and Mamba were 2.97, 2.72, and 2.71 milliseconds respectively. For a sequence of length 4096, the runtimes were 3.41, 3.25, and 3.15 milliseconds respectively. In contrast, the traditional RNN counterparts (LSTMs and GRUs) required a runtime that scaled linearly with respect to sequence length. For a sequence length of 512, minGRUs and minLSTMs were 175× and 235× faster per training step (see Figure 1 (middle)) than GRUs and LSTMs on a T4 GPU. The improvement is even more significant as sequences grow in length, with minGRUs and minLSTMs being 1324× and 1361× faster for a sequence length of 4096. As such, in a setting where minGRU would take a day to finish training for a fixed number of epochs, its traditional counterpart GRU could take over 3 years.

Memory. By leveraging a parallel scan algorithm to compute the outputs in parallel efficiently, minGRU, minLSTM, and Mamba create a larger computational graph and thus need more memory than traditional RNNs (see Figure 1 (right)). The minimal variants (minGRU and minLSTM) use $\sim 88\%$ more memory than their traditional counterparts. Mamba uses $56\%$ more memory than minGRU. In practice, however, runtime is the bottleneck when training RNNs.

Effect of removing $h_{t-1}$. The original LSTM and GRU compute their various gates using their inputs $x_t$ and previous hidden states $h_{t-1}$. These models leverage their time-dependent gates to learn complex functions. However, minLSTM and minGRU's training efficiencies are achieved by dropping their gates' dependencies on the previous hidden states $h_{t-1}$. As a result, minLSTM and minGRU's gates are dependent only on their inputs $x_t$, resulting in a simpler recurrent module. As such, the gates of a model consisting of a single layer of minLSTM or minGRU are time-independent due to being conditioned on time-independent inputs $x^{(1)}_{1:n}$.

Model # Layers Accuracy
minLSTM 1 37.6 ± 2.0
minLSTM 2 85.7 ± 5.8
minLSTM 3 96.0 ± 2.8
minGRU 1 37.0 ± 2.3
minGRU 2 96.8 ± 3.2
minGRU 3 99.5 ± 0.2
Table 1: Comparison of the number of layers on the Selective Copying Task (Gu & Dao, 2024).

However, in deep learning, models are constructed by stacking modules. Although the input to the first layer $x^{(1)}_{1:n}$ is time-independent, its outputs $h^{(1)}_{1:n}$ are time-dependent and are used as the inputs to the second layer, i.e., $x^{(2)}_{1:n} \leftarrow h^{(1)}_{1:n}$. As such, beginning from the second layer onwards, minLSTM and minGRU's gates will also be time-dependent, resulting in the modelling of more complex functions. In Table 1, we compare the performance of the models with varying numbers of layers on the Selective Copying Task from the Mamba paper (Gu & Dao, 2024). We can immediately see the impact of the time dependencies: increasing the number of layers to $2$ or more drastically increases the model's performance.
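For illustration, such a stack can be written generically as follows (a sketch: the layer modules are assumed to behave like the parallel-mode minGRU/minLSTM of Appendix A, mapping a sequence and an initial state to the sequence of hidden states):

import torch.nn as nn

class StackedRNN(nn.Module):
    # Layer l+1 consumes layer l's hidden-state sequence, i.e., x^(l+1) <- h^(l).
    # From the second layer onward, the gates are therefore computed from time-dependent inputs.
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)  # e.g., minGRU or minLSTM modules

    def forward(self, x, h0s):
        # x: (batch, seq_len, dim); h0s: one initial hidden state per layer
        for layer, h0 in zip(self.layers, h0s):
            x = layer(x, h0)  # h^(l)_{1:n} becomes the next layer's input
        return x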

Training Stability. Another effect of increasing the number of layers is improved training stability: the variance in accuracy decreases as more layers are added (see Table 1). Furthermore, although minLSTM and minGRU both solve the Selective Copying task, minGRU is empirically a more stable method than minLSTM, solving the task with more consistency and lower variance. minLSTM discards old information and adds new information, controlling the ratio with two sets of parameters (the forget and input gates). During training, the two sets of parameters are tuned in different directions, making the ratio harder to control and optimize. In contrast, minGRU's discarding and adding of information is controlled by a single set of parameters (the update gate), making it easier to optimize.

Model Layer Accuracy
H3 Hyena 30.1
Mamba Hyena 28.4
S4 S4 18.3
H3 S4 57.0
Mamba S4 56.4
S4 S6 97.0
H3 S6 99.7
Mamba S6 99.8
minGRU minGRU 99.5 ± 0.2
minLSTM minLSTM 96.0 ± 2.8
Table 2: Selective Copy Task. minLSTM, minGRU, and Mamba’s S6 (Gu & Dao, 2024) are capable of solving this task. Other methods such as S4, H3, and Hyena at best only partially solve the task.

4.2 Minimal LSTMs and GRUs perform well

In the previous section, we showed the significant efficiency gains achieved by simplifying traditional RNNs. Here, we explore the empirical performance aspect of these minimal versions of LSTMs and GRUs compared to several popular sequence models.

Selective Copy. We consider the long-range Selective Copying task from the Mamba paper (Gu & Dao, 2024). Unlike the original Copying task (Arjovsky et al., 2016), the Selective Copying task’s input elements are randomly spaced relative to their output, making the task harder. To solve the task, models are required to perform content-aware reasoning, memorizing relevant and filtering out irrelevant tokens.
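For intuition, a toy instance of the task format can be generated as follows (a sketch only; the token values, sequence length, and vocabulary here are hypothetical and this is not the data pipeline of Gu & Dao (2024)):

import torch

def selective_copy_batch(batch, seq_len=64, n_tokens=4, vocab=8, noise=0):
    # Scatter n_tokens random symbols (values >= 2) at random positions among
    # noise tokens; the target is those symbols in their original order.
    x = torch.full((batch, seq_len), noise)
    y = torch.randint(2, vocab, (batch, n_tokens))
    for b in range(batch):
        pos, _ = torch.sort(torch.randperm(seq_len)[:n_tokens])
        x[b, pos] = y[b]
    return x, y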

In Table 2, we compare the simplified versions of LSTMs and GRUs (minLSTM and minGRU) against well-known recurrent sequence models that can be trained in parallel: S4 (Gu et al., 2021), H3 (Fu et al., 2023), Hyena (Poli et al., 2023), and Mamba (S6) (Gu & Dao, 2024). The results for these baselines are quoted from the Mamba paper. Out of all of these baselines, only S6 from Mamba's paper is capable of solving this task. minGRU and minLSTM are also capable of solving the Selective Copying task, achieving comparable performance to S6 and outperforming all other baselines. LSTMs and GRUs leverage content-aware gating mechanisms, making these minimal versions sufficient for solving this task that many popular sequence models fail to solve.

Dataset DT DS4 DAaren DMamba minLSTM minGRU
HalfCheetah-M 42.6 42.5 42.2 42.8 42.7 ± 0.7 43.0 ± 0.4
Hopper-M 68.4 54.2 80.9 83.5 85.0 ± 4.4 79.4 ± 8.2
Walker-M 75.5 78.0 74.4 78.2 72.0 ± 7.5 73.3 ± 3.3
HalfCheetah-M-R 37.0 15.2 37.9 39.6 38.6 ± 1.1 38.5 ± 1.1
Hopper-M-R 85.6 49.6 77.9 82.6 88.5 ± 4.7 90.5 ± 0.9
Walker-M-R 71.2 69.0 71.4 70.9 69.7 ± 10.7 72.8 ± 8.9
HalfCheetah-M-E 88.8 92.7 75.7 91.9 85.4 ± 1.7 86.3 ± 0.5
Hopper-M-E 109.6 110.8 103.9 111.1 110.3 ± 1.6 109.7 ± 2.7
Walker-M-E 109.3 105.7 110.5 108.3 110.3 ± 0.5 110.3 ± 0.4
Average 76.4 68.6 75.0 78.8 78.1 78.2
Table 3: Reinforcement Learning results on the D4RL (Fu et al., 2020) datasets. We report the expert-normalized returns (higher is better), following Fu et al. (2020), averaged across five random seeds. The minimal versions of LSTM and GRU, minLSTM and minGRU, outperform Decision S4 (David et al., 2023) and perform comparably with Decision Mamba (Ota, 2024), (Decision) Aaren (Feng et al., 2024), and Decision Transformer (Chen et al., 2021).

Reinforcement Learning. Next, we consider the MuJoCo locomotion tasks from the D4RL benchmark (Fu et al., 2020). Specifically, we consider the three environments: HalfCheetah, Hopper, and Walker. For each environment, the models are trained on three datasets of varying data quality: Medium (M), Medium-Replay (M-R), and Medium-Expert (M-E).

In Table 3, we compare minLSTM and minGRU with various Decision Transformer variants, including the original Decision Transformer (DT) (Chen et al., 2021), Decision S4 (DS4) (David et al., 2023), Decision Mamba (Ota, 2024), and (Decision) Aaren (Feng et al., 2024). The baseline results are retrieved from the Decision Mamba and Aaren papers. minLSTM and minGRU outperform Decision S4 and achieve performance competitive with Decision Transformer, Aaren, and Mamba. Unlike the other recurrent methods, Decision S4 is a model whose recurrence transitions are not input-aware, which affects its performance. In terms of average score across the $3 \times 3 = 9$ datasets, minLSTM and minGRU outperform all the baselines except for Decision Mamba, where the difference is marginal.

Figure 2: Language Modelling results on the Shakespeare dataset. Minimal versions of decade-old RNNs (LSTMs and GRUs) performed comparably to Mamba and Transformers. Transformers required $\sim 2.5\times$ more training steps to achieve comparable performance, eventually overfitting.

Language Modelling. Finally, we consider a language modelling task. In this setting, we train a character-level GPT on the works of Shakespeare using the nanoGPT (Karpathy, 2022) framework. In Figure 2, we plot the learning curves (cross-entropy loss), comparing the proposed minimal LSTM and GRU (minLSTM and minGRU) with Mamba and Transformers. We found that minGRU, minLSTM, Mamba, and Transformers achieved comparable test losses of $1.548$, $1.555$, $1.575$, and $1.547$ respectively. Mamba performed slightly worse than the other models but trained faster, particularly in the early stages, achieving its best performance at $400$ steps, while minGRU and minLSTM continued training until $575$ and $625$ steps respectively. In contrast, Transformers trained significantly slower, requiring $2000$ steps ($\sim 2.5\times$ more training steps than minGRU) to achieve comparable performance, making them significantly slower and more resource-intensive to train (quadratic complexity compared to minGRU, minLSTM, and Mamba's linear complexity).

5 Related Work

In this section, we provide a discussion of the similarities and differences between existing recurrent sequence models and the simplified versions of LSTMs and GRUs (minLSTM and minGRU).

State-Space Models (SSMs). Although Mamba (Gu & Dao, 2024) and state-space models have gained significant popularity recently, the steps towards the recent success of Mamba began years ago. Gu et al. (2020) first proposed a discretized structured state-space model. Gu et al. (2021) scaled the idea up, introducing S4. The success of S4 became the basis for many future works (Gu et al., 2022; Gupta et al., 2022; Hasani et al., 2023; Smith et al., 2023) and state-space model applications in language (Mehta et al., 2023), audio (Goel et al., 2022), and more. Recently, Mamba was a significant breakthrough in SSMs, outperforming previous methods and garnering substantial attention. A major novelty in Mamba was the proposal of S6, a state-space model whose transition matrices are input-dependent (i.e., $A_t$ and $B_t$ are functions of $x_t$). In contrast, earlier state-space models' transition matrices were input-independent, limiting their expressivity. The success of Mamba and state-space models led to the writing of several survey papers (Wang et al., 2024; Patro & Agneeswaran, 2024; Qu et al., 2024).

Recurrent Versions of Attention. Another direction for designing efficient recurrent sequence models builds on attention. Building on variations of linear attention (Katharopoulos et al., 2020), several papers have introduced recurrent versions that can be computed in parallel. Notably, Sun et al. (2023) and Qin et al. (2023) introduced variants that use an input-independent gating mechanism (decay factor). More recently, Katsch (2023) and Yang et al. (2024) proposed linear attention variants that use input-dependent gating. Feng et al. (2024) showed that softmax attention can be viewed as an RNN and proposed a recurrent model based on their RNN formulation.

Parallelizable RNNs. Alternatively, several papers have proposed RNNs that can be trained efficiently in parallel. Orvieto et al. (2023) proposed an RNN that leverages complex diagonal recurrences and an exponential parameterization. Beck et al. (2024) proposed various enhancements to LSTM such as exponential gating, covariance update rule, and a normalizer state.

Although these three directions of designing efficient recurrent sequence models have proposed vastly different architectures, the core recurrent component of these models is remarkably similar. For example, although state-space models are typically written as $h_t = A_t h_{t-1} + B_t x_t$, in practice the transition matrices are typically diagonal for efficiency reasons. As such, Mamba's S6 (Gu & Dao, 2024) can be viewed as $h_t = \mathrm{diag}(A_t) \odot h_{t-1} + \mathrm{diag}(B_t) \odot x_t$ where $A_t$ and $B_t$ are functions of $x_t$. In contrast, consider the minimal version of GRU, $h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \mathrm{Linear}_{d_h}(x_t)$, and the minimal version of LSTM, $h_t = f'_t \odot h_{t-1} + i'_t \odot \mathrm{Linear}_{d_h}(x_t)$. The recurrences of these models are similar.
The major difference between these minimal RNNs, Mamba's S6, and other models is how their transitions (e.g., $z_t$, $i'_t$, $f'_t$, $A_t$, and $B_t$) are computed from the input token $x_t$.
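Concretely, all of these recurrences share the form $h_t = a_t \odot h_{t-1} + b_t$ and differ only in how $(a_t, b_t)$ are produced from $x_t$; the sketch below is schematic rather than the exact parameterizations of the cited models:

import torch

def run_recurrence(a, b, h0):
    # Shared recurrence h_t = a_t * h_{t-1} + b_t, written sequentially for clarity.
    # a, b: (seq_len, dim); h0: (dim,)
    h, hs = h0, []
    for t in range(a.shape[0]):
        h = a[t] * h + b[t]
        hs.append(h)
    return torch.stack(hs)

# minGRU:        a_t = 1 - z_t,      b_t = z_t * h_tilde_t
# minLSTM:       a_t = f'_t,         b_t = i'_t * h_tilde_t
# diagonal SSM:  a_t = diag(A_t),    b_t = diag(B_t) * x_t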

Parallel Scan. Generalizing across these families of methods (including minLSTM and minGRU), these recent sequence models can be viewed as members of the same family of functions trainable via a parallel scan: $v_t = a_t \odot v_{t-1} + b_t$ (see Section 2.3), where $a_t$ and $b_t$ are functions of the input token $x_t$. Improving upon the parallel scan algorithm, several models (Yang et al., 2024; Gu & Dao, 2024) such as Mamba have proposed specialized hardware-efficient methods that leverage the GPU's memory hierarchy to reduce high I/O costs and speed up training. In our work, we implemented minLSTM and minGRU in plain PyTorch. However, due to the structural similarities in recurrences amongst the numerous methods that leverage a parallel scan, many techniques that speed up training for one method, such as chunking, can also apply to others such as minGRU and minLSTM.
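The reason a parallel scan applies is that the updates $h \mapsto a \odot h + b$ compose associatively: applying $(a_1, b_1)$ and then $(a_2, b_2)$ is itself the affine update $(a_1 a_2,\; a_2 b_1 + b_2)$. A minimal sketch of this combine step (independent of any particular hardware-efficient kernel):

def combine(left, right):
    # Compose two elementwise-affine updates h -> a*h + b; scanning with this
    # associative operator reproduces the sequential recurrence v_t = a_t*v_{t-1} + b_t.
    a1, b1 = left
    a2, b2 = right
    return a1 * a2, a2 * b1 + b2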

Parameter Initializations. Unrolling the recurrences of these new recurrent sequence models over time often results in their outputs and gradients vanishing/exploding (Wang et al., 2024) due to time dependency in their output’s scale. To ensure model stability, the parameters of many models such as state-space models are initialized according to special distributions (Gu et al., 2020, 2022; Orvieto et al., 2023). In contrast, we found that minLSTM and minGRU are already stable using the default PyTorch initialization. Unlike SSMs, minLSTM and minGRU’s outputs are time-independent in scale, avoiding potential instabilities.

6 Conclusion

In this work, we revisited RNNs from over a decade ago: LSTMs and GRUs. We showed that these models can be trained via the parallel scan algorithm by removing the hidden state dependencies from their gates. Simplifying these models further, we removed their constraints on output range and ensured their outputs are time-independent in scale. These steps result in their minimal versions (minLSTM and minGRU). Empirically, we showed that minLSTM and minGRU (1) address the computational limitations of their traditional counterparts, (2) are as computationally efficient as Mamba, a popular recent state-of-the-art recurrent sequence model, and (3) are competitive in performance with recent sequence models. Considering the strong empirical performance of these simplified RNNs and their fundamental similarities with many recently proposed recurrent sequence methods, we ask: "Were RNNs all we needed?"

Limitations

Our experiments were run on P100 (16 GB) and T4 (16 GB) GPUs. Due to computational limitations, our experiments are smaller in scale than works such as Mamba (Gu & Dao, 2024), which leveraged A100 80GB GPUs. To fit the Selective Copying task on the GPU, we leveraged gradient accumulation for training, splitting the standard batch size in half, which slowed training significantly. Nonetheless, we hypothesize that these conclusions generalize to larger-scale settings due to the fundamental similarities between the minimal RNNs (minLSTM and minGRU) and many recent sequence methods.
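Gradient accumulation here simply splits each batch into micro-batches and sums their gradients before a single optimizer step; a generic sketch (not the exact training loop used):

def accumulate_step(model, optimizer, loss_fn, micro_batches):
    # micro_batches together form one effective batch (e.g., two halves of it).
    optimizer.zero_grad(set_to_none=True)
    for x, y in micro_batches:
        loss = loss_fn(model(x), y) / len(micro_batches)  # preserve the effective loss scale
        loss.backward()                                   # gradients accumulate in .grad
    optimizer.step()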

References

  • Arjovsky et al. (2016) Martin Arjovsky, Amar Shah, and Yoshua Bengio. Unitary evolution recurrent neural networks. In International conference on machine learning, pp.  1120–1128. PMLR, 2016.
  • Beck et al. (2024) Maximilian Beck, Korbinian Pöppel, Markus Spanring, Andreas Auer, Oleksandra Prudnikova, Michael Kopp, Günter Klambauer, Johannes Brandstetter, and Sepp Hochreiter. xlstm: Extended long short-term memory. arXiv preprint arXiv:2405.04517, 2024.
  • Blelloch (1990) Guy E Blelloch. Prefix sums and their applications. Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon University, 1990.
  • Brown et al. (2020) Tom B. Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, Sandhini Agarwal, Ariel Herbert-Voss, Gretchen Krueger, Tom Henighan, Rewon Child, Aditya Ramesh, Daniel M. Ziegler, Jeff Wu, Clemens Winter, Christopher Hesse, Mark Chen, Eric Sigler, Mateusz Litwin, Scott Gray, Benjamin Chess, Jack Clark, Christopher Berner, Sam McCandlish, Alec Radford, Ilya Sutskever, and Dario Amodei. Language models are few-shot learners. ArXiv, abs/2005.14165, 2020.
  • Chen et al. (2021) Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Misha Laskin, Pieter Abbeel, Aravind Srinivas, and Igor Mordatch. Decision transformer: Reinforcement learning via sequence modeling. Advances in neural information processing systems, 34:15084–15097, 2021.
  • Cho et al. (2014) Kyunghyun Cho, Bart Van Merrienboer, Caglar Gulcehre, Fethi Bougares, Holger Schwenk, and Yoshua Bengio. Learning phrase representations using rnn encoder-decoder for statistical machine translation. In EMNLP, 2014.
  • David et al. (2023) Shmuel Bar David, Itamar Zimerman, Eliya Nachmani, and Lior Wolf. Decision s4: Efficient sequence-based rl via state spaces layers. In The Eleventh International Conference on Learning Representations, 2023.
  • Devlin et al. (2019) Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. Bert: Pre-training of deep bidirectional transformers for language understanding. In North American Chapter of the Association for Computational Linguistics, 2019.
  • Elman (1990) Jeffrey L. Elman. Finding structure in time. Cognitive Science, 14(2):179–211, 1990. ISSN 0364-0213.
  • Feng et al. (2024) Leo Feng, Frederick Tung, Hossein Hajimirsadeghi, Mohamed Osama Ahmed, Yoshua Bengio, and Greg Mori. Attention as an rnn. arXiv preprint arXiv:2405.13956, 2024.
  • Fu et al. (2023) Daniel Y Fu, Tri Dao, Khaled Kamal Saab, Armin W Thomas, Atri Rudra, and Christopher Re. Hungry hungry hippos: Towards language modeling with state space models. In The Eleventh International Conference on Learning Representations, 2023.
  • Fu et al. (2020) Justin Fu, Aviral Kumar, Ofir Nachum, George Tucker, and Sergey Levine. D4rl: Datasets for deep data-driven reinforcement learning. arXiv preprint arXiv:2004.07219, 2020.
  • Goel et al. (2022) Karan Goel, Albert Gu, Chris Donahue, and Christopher Ré. It’s raw! audio generation with state-space models. In International Conference on Machine Learning, pp.  7616–7633. PMLR, 2022.
  • Gu & Dao (2024) Albert Gu and Tri Dao. Mamba: Linear-time sequence modeling with selective state spaces. arXiv preprint arXiv:2312.00752, 2024.
  • Gu et al. (2020) Albert Gu, Tri Dao, Stefano Ermon, Atri Rudra, and Christopher Ré. Hippo: Recurrent memory with optimal polynomial projections. Advances in neural information processing systems, 33:1474–1487, 2020.
  • Gu et al. (2021) Albert Gu, Karan Goel, and Christopher Re. Efficiently modeling long sequences with structured state spaces. In International Conference on Learning Representations, 2021.
  • Gu et al. (2022) Albert Gu, Karan Goel, Ankit Gupta, and Christopher Ré. On the parameterization and initialization of diagonal state space models. Advances in Neural Information Processing Systems, 35:35971–35983, 2022.
  • Gupta et al. (2022) Ankit Gupta, Albert Gu, and Jonathan Berant. Diagonal state spaces are as effective as structured state spaces. Advances in Neural Information Processing Systems, 35:22982–22994, 2022.
  • Hasani et al. (2023) Ramin Hasani, Mathias Lechner, Tsun-Hsuan Wang, Makram Chahine, Alexander Amini, and Daniela Rus. Liquid structural state-space models. In The Eleventh International Conference on Learning Representations, 2023.
  • Heinsen (2023) Franz A Heinsen. Parallelization of an ubiquitous sequential computation. arXiv preprint arXiv:2311.06281, 2023.
  • Hochreiter & Schmidhuber (1997) S Hochreiter and J Schmidhuber. Long short-term memory. Neural Computation, 9(8):1735–1780, 1997.
  • Karpathy (2022) Andrej Karpathy. NanoGPT. https://github.com/karpathy/nanoGPT, 2022.
  • Katharopoulos et al. (2020) Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret. Transformers are rnns: Fast autoregressive transformers with linear attention. In International conference on machine learning, pp.  5156–5165. PMLR, 2020.
  • Katsch (2023) Tobias Katsch. Gateloop: Fully data-controlled linear recurrence for sequence modeling. arXiv preprint arXiv:2311.01927, 2023.
  • Mehta et al. (2023) Harsh Mehta, Ankit Gupta, Ashok Cutkosky, and Behnam Neyshabur. Long range language modeling via gated state spaces. In The Eleventh International Conference on Learning Representations, 2023.
  • Orvieto et al. (2023) Antonio Orvieto, Samuel L Smith, Albert Gu, Anushan Fernando, Caglar Gulcehre, Razvan Pascanu, and Soham De. Resurrecting recurrent neural networks for long sequences. In International Conference on Machine Learning, pp.  26670–26698. PMLR, 2023.
  • Ota (2024) Toshihiro Ota. Decision mamba: Reinforcement learning via sequence modeling with selective state spaces. arXiv preprint arXiv:2403.19925, 2024.
  • Patro & Agneeswaran (2024) Badri Narayana Patro and Vijay Srinivas Agneeswaran. Mamba-360: Survey of state space models as transformer alternative for long sequence modelling: Methods, applications, and challenges. arXiv preprint arXiv:2404.16112, 2024.
  • Peng et al. (2023) Bo Peng, Eric Alcaide, Quentin Anthony, Alon Albalak, Samuel Arcadinho, Stella Biderman, Huanqi Cao, Xin Cheng, Michael Chung, Matteo Grella, et al. Rwkv: Reinventing rnns for the transformer era. arXiv preprint arXiv:2305.13048, 2023.
  • Poli et al. (2023) Michael Poli, Stefano Massaroli, Eric Nguyen, Daniel Y Fu, Tri Dao, Stephen Baccus, Yoshua Bengio, Stefano Ermon, and Christopher Ré. Hyena hierarchy: Towards larger convolutional language models. In International Conference on Machine Learning, pp.  28043–28078. PMLR, 2023.
  • Qin et al. (2023) Zhen Qin, Dong Li, Weigao Sun, Weixuan Sun, Xuyang Shen, Xiaodong Han, Yunshen Wei, Baohong Lv, Fei Yuan, Xiao Luo, et al. Scaling transnormer to 175 billion parameters. arXiv preprint arXiv:2307.14995, 2023.
  • Qu et al. (2024) Haohao Qu, Liangbo Ning, Rui An, Wenqi Fan, Tyler Derr, Xin Xu, and Qing Li. A survey of mamba. arXiv preprint arXiv:2408.01129, 2024.
  • Smith et al. (2023) Jimmy TH Smith, Andrew Warrington, and Scott Linderman. Simplified state space layers for sequence modeling. In The Eleventh International Conference on Learning Representations, 2023.
  • Sun et al. (2023) Yutao Sun, Li Dong, Shaohan Huang, Shuming Ma, Yuqing Xia, Jilong Xue, Jianyong Wang, and Furu Wei. Retentive network: A successor to transformer for large language models. arXiv preprint arXiv:2307.08621, 2023.
  • Vaswani et al. (2017) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing systems, 30(2017), 2017.
  • Wang et al. (2024) Xiao Wang, Shiao Wang, Yuhe Ding, Yuehang Li, Wentao Wu, Yao Rong, Weizhe Kong, Ju Huang, Shihao Li, Haoxiang Yang, et al. State space model for new-generation network alternative to transformers: A survey. arXiv preprint arXiv:2404.09516, 2024.
  • Yang et al. (2024) Songlin Yang, Bailin Wang, Yikang Shen, Rameswar Panda, and Yoon Kim. Gated linear attention transformers with hardware-efficient training. In International Conference on Machine Learning, 2024.

Appendix A Implementation Details: Vanilla Version

In this section, we provide the pseudocode and equivalent PyTorch code for minGRU and minLSTM. When performing repeated multiplications such as in many recurrent sequence models, numerical instabilities are common, especially during training. As such, we trained using a log-space implementation (see Section B) for improved numerical stability.

A.1 Pseudocode: Vanilla Version

A.1.1 minGRU: A Minimal GRU

Algorithm 1 Sequential Mode: Minimal Version of GRU (minGRU)
Input: $x_t$, $h_{t-1}$
Output: $h_t$
$z_t \leftarrow \sigma(\mathrm{Linear}_{d_h}(x_t))$
$\tilde{h}_t \leftarrow \mathrm{Linear}_{d_h}(x_t)$
$h_t \leftarrow (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$
Algorithm 2 Parallel Mode: Minimal Version of GRU (minGRU)
Input: $x_{1:t}$, $h_0$
Output: $h_{1:t}$
$z_{1:t} \leftarrow \sigma(\mathrm{Linear}_{d_h}(x_{1:t}))$
$\tilde{h}_{1:t} \leftarrow \mathrm{Linear}_{d_h}(x_{1:t})$
$h_{1:t} \leftarrow \mathrm{ParallelScan}((1 - z_{1:t}), [h_0, z_{1:t} \odot \tilde{h}_{1:t}])$

A.1.2 minLSTM: A Minimal LSTM

Algorithm 3 Sequential Mode: Minimal Version of LSTM (minLSTM)
Input: $x_t$, $h_{t-1}$
Output: $h_t$
$f_t \leftarrow \sigma(\mathrm{Linear}_{d_h}(x_t))$
$i_t \leftarrow \sigma(\mathrm{Linear}_{d_h}(x_t))$
$f'_t, i'_t \leftarrow \frac{f_t}{f_t + i_t}, \frac{i_t}{f_t + i_t}$
$\tilde{h}_t \leftarrow \mathrm{Linear}_{d_h}(x_t)$
$h_t \leftarrow f'_t \odot h_{t-1} + i'_t \odot \tilde{h}_t$
Algorithm 4 Parallel Mode: Minimal Version of LSTM (minLSTM)
Input: $x_{1:t}$, $h_0$
Output: $h_{1:t}$
$f_{1:t} \leftarrow \sigma(\mathrm{Linear}_{d_h}(x_{1:t}))$
$i_{1:t} \leftarrow \sigma(\mathrm{Linear}_{d_h}(x_{1:t}))$
$f'_{1:t}, i'_{1:t} \leftarrow \frac{f_{1:t}}{f_{1:t} + i_{1:t}}, \frac{i_{1:t}}{f_{1:t} + i_{1:t}}$
$\tilde{h}_{1:t} \leftarrow \mathrm{Linear}_{d_h}(x_{1:t})$
$h_{1:t} \leftarrow \mathrm{ParallelScan}(f'_{1:t}, [h_0, i'_{1:t} \odot \tilde{h}_{1:t}])$

A.2 PyTorch Code: Vanilla Version

A.2.1 minGRU: A Minimal GRU

def forward(self, x_t, h_prev):
    # x_t: (batch_size, input_size)
    # h_prev: (batch_size, hidden_size)

    z_t = torch.sigmoid(self.linear_z(x_t))
    h_tilde = self.linear_h(x_t)
    h_t = (1 - z_t) * h_prev + z_t * h_tilde
    return h_t
Listing 1: Sequential Mode: Minimal Version of GRU (minGRU)
def forward(self, x, h_0):
    # x: (batch_size, seq_len, input_size)
    # h_0: (batch_size, 1, hidden_size)

    z = torch.sigmoid(self.linear_z(x))
    h_tilde = self.linear_h(x)
    h = parallel_scan((1 - z),
                      torch.cat([h_0, z * h_tilde], dim=1))
    return h
Listing 2: Parallel Mode: Minimal Version of GRU (minGRU)

A.2.2 minLSTM: A Minimal LSTM

def forward(self, x_t, h_prev):
    # x_t: (batch_size, input_size)
    # h_prev: (batch_size, hidden_size)

    f_t = torch.sigmoid(self.linear_f(x_t))
    i_t = torch.sigmoid(self.linear_i(x_t))
    tilde_h_t = self.linear_h(x_t)
    f_prime_t = f_t / (f_t + i_t)
    i_prime_t = i_t / (f_t + i_t)
    h_t = f_prime_t * h_prev + i_prime_t * tilde_h_t
    return h_t
Listing 3: Sequential Mode: Minimal Version of LSTM (minLSTM)
def forward(self, x, h_0):
    # x: (batch_size, seq_len, input_size)
    # h_0: (batch_size, 1, hidden_size)

    f = torch.sigmoid(self.linear_f(x))
    i = torch.sigmoid(self.linear_i(x))
    tilde_h = self.linear_h(x)
    f_prime = f / (f + i)
    i_prime = i / (f + i)
    h = parallel_scan(f_prime,
                      torch.cat([h_0, i_prime * tilde_h], dim=1))
    return h
Listing 4: Parallel Mode: Minimal Version of LSTM (minLSTM)

Appendix B Implementation Details: Log-Space Version (Additional Numerical Stability)

In this section, we detail the log-space version of minLSTM and minGRU for improved numerical stability. During training, the parallel modes are used to avoid backpropagation through time (BPTT), speeding up the training time significantly. At inference time, the sequential modes are used.

B.1 Parallel Scan: Log-Space Implementation

Recall that the parallel scan's objective is to compute $h_{1:t}$ where $h_k = a_k \odot h_{k-1} + b_k$. In code, the vanilla parallel scan function would take as input the coefficients $a_{1:t}$ and values $b_{0:t}$. The function then outputs $h_{1:t}$. For numerical stability, we consider a log-space implementation which instead takes as input $\log(a_{1:t})$ and $\log(b_{0:t})$ and outputs $h_{1:t}$. The code for the parallel scan in log-space is included below and is based on the code by Heinsen (2023).

import torch
import torch.nn.functional as F

def parallel_scan_log(log_coeffs, log_values):
    # log_coeffs: (batch_size, seq_len, input_size)
    # log_values: (batch_size, seq_len + 1, input_size)
    a_star = F.pad(torch.cumsum(log_coeffs, dim=1), (0, 0, 1, 0))
    log_h0_plus_b_star = torch.logcumsumexp(
        log_values - a_star, dim=1)
    log_h = a_star + log_h0_plus_b_star
    return torch.exp(log_h)[:, 1:]
Listing 5: Parallel scan based on Heinsen (2023). This function computes $h_{1:t}$ given the log coefficients $\log(a_{1:t})$ and log values $\log(b_{0:t})$.
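As a quick sanity check (an illustrative usage assuming the parallel_scan_log function above is in scope, and that coefficients and values are strictly positive so their logs are defined), the log-space scan matches the sequential recurrence $h_k = a_k \odot h_{k-1} + b_k$:

import torch

batch, seq_len, dim = 2, 5, 3
a = torch.rand(batch, seq_len, dim) * 0.9 + 0.05   # coefficients in (0, 1)
b = torch.rand(batch, seq_len, dim) + 0.1          # strictly positive values
h0 = torch.rand(batch, 1, dim) + 0.1

h_par = parallel_scan_log(torch.log(a), torch.log(torch.cat([h0, b], dim=1)))

h, hs = h0[:, 0], []                               # sequential reference
for k in range(seq_len):
    h = a[:, k] * h + b[:, k]
    hs.append(h)
h_seq = torch.stack(hs, dim=1)

print(torch.allclose(h_par, h_seq, atol=1e-5))     # expected: True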

B.2 Pseudocode: Log-Space Version

For maximal numerical stability, we derive log-space versions of minGRU and minLSTM.

B.2.1 minGRU: A Minimal GRU

Recall that minGRU's recurrence is $h_t \leftarrow (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$. As such, $a_t \leftarrow (1 - z_t)$ and $b_t \leftarrow z_t \odot \tilde{h}_t$ where $z_t = \sigma(k_t)$ and $k_t = \mathrm{Linear}_{d_h}(x_t)$. As a result, $\log(a_t) \leftarrow \log(1 - z_t)$ and $\log(b_t) \leftarrow \log(z_t) + \log(\tilde{h}_t)$. We can break these down as follows:

$$\log(z_t) = \log(\sigma(k_t)) = \log\left(\frac{1}{1 + \exp(-k_t)}\right) = -\mathrm{Softplus}(-k_t)$$
$$\log(1 - z_t) = \log\left(\frac{\exp(-k_t)}{1 + \exp(-k_t)}\right) = \log\left(\frac{1}{1 + \exp(k_t)}\right) = -\mathrm{Softplus}(k_t)$$

where $k_t = \mathrm{Linear}_{d_h}(x_t)$. However, we also need to compute $\log(\tilde{h}_t)$, which is inconvenient if $\tilde{h}_t$ has some negative values. We could use complex numbers and a complex-valued version of the parallel scan, but this would increase the complexity of the parallel scan. Instead, we propose to ensure that $\tilde{h}_t > 0$. This can be done in a variety of ways. In our experiments, we added a continuous activation function $g$, replacing $\tilde{h}_t \leftarrow \mathrm{Linear}_{d_h}(x_t)$ with $\tilde{h}_t \leftarrow g(\mathrm{Linear}_{d_h}(x_t))$ where
$$g(x) = \begin{cases} x + 0.5, & \text{if } x \geq 0 \\ \sigma(x), & \text{otherwise} \end{cases} \qquad\quad \log(g(x)) = \begin{cases} \log(x + 0.5), & \text{if } x \geq 0 \\ -\mathrm{Softplus}(-x), & \text{otherwise.} \end{cases}$$
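In PyTorch, $g$ and $\log g$ can be written as follows (a sketch consistent with the definitions above; the relu inside the log is only there to keep the branch not selected by torch.where finite):

import torch
import torch.nn.functional as F

def g(x):
    # g(x) = x + 0.5 if x >= 0 else sigmoid(x); keeps h_tilde strictly positive
    return torch.where(x >= 0, x + 0.5, torch.sigmoid(x))

def log_g(x):
    # log g(x) = log(x + 0.5) if x >= 0 else -softplus(-x) (= log sigmoid(x))
    return torch.where(x >= 0, (F.relu(x) + 0.5).log(), -F.softplus(-x))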

At inference time, the sequential mode (Algorithm 5) is used. During training, the parallel mode (Algorithm 6) is used.

Algorithm 5 Sequential Mode: Minimal Version of GRU (minGRU) trained in log-space
Input: $x_t$, $h_{t-1}$
Output: $h_t$
$z_t \leftarrow \sigma(\mathrm{Linear}_{d_h}(x_t))$
$\tilde{h}_t \leftarrow g(\mathrm{Linear}_{d_h}(x_t))$
$h_t \leftarrow (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$
Algorithm 6 Parallel Mode: Minimal Version of GRU (minGRU) for training in log-space
Input: $x_{1:t}$, $h_0$
Output: $h_{1:t}$
$\mathrm{linear\_z} \leftarrow \mathrm{Linear}_{d_h}$
$\mathrm{log\_}z_{1:t} \leftarrow -\mathrm{Softplus}(-\mathrm{linear\_z}(x_{1:t}))$
$\mathrm{log\_coeffs} \leftarrow -\mathrm{Softplus}(\mathrm{linear\_z}(x_{1:t}))$
$\mathrm{log\_}h_0 \leftarrow \mathrm{log\_g}(h_0)$
$\mathrm{log\_}\tilde{h}_{1:t} \leftarrow \mathrm{log\_g}(\mathrm{Linear}_{d_h}(x_{1:t}))$
$h_{1:t} \leftarrow \mathrm{ParallelScanLog}(\mathrm{log\_coeffs}, [\mathrm{log\_}h_0, \mathrm{log\_}z_{1:t} + \mathrm{log\_}\tilde{h}_{1:t}])$

B.2.2 minLSTM: A Minimal LSTM

We also derive minLSTM's log-space formulation. Recall that minLSTM's recurrence is $h_t \leftarrow f'_t \odot h_{t-1} + i'_t \odot \tilde{h}_t$. As such, $a_t \leftarrow f'_t$ and $b_t \leftarrow i'_t \odot \tilde{h}_t$. As a result, $\log(a_t) \leftarrow \log(f'_t)$ and $\log(b_t) \leftarrow \log(i'_t) + \log(\tilde{h}_t)$.

$$\begin{aligned}
\log(f'_t) &= \log\left(\frac{f_t}{f_t + i_t}\right) = \log\left(\frac{1}{1 + \frac{i_t}{f_t}}\right) = -\log\left(1 + \frac{i_t}{f_t}\right) \\
&= -\log\left(1 + \exp\left(\log\left(\frac{i_t}{f_t}\right)\right)\right) = -\mathrm{Softplus}\left(\log\left(\frac{i_t}{f_t}\right)\right) = -\mathrm{Softplus}\left(\log(i_t) - \log(f_t)\right)
\end{aligned}$$

Recall that ${\bm{i}}_t$ and ${\bm{f}}_t$ are computed via sigmoid. In other words, ${\bm{i}}_t = \sigma({\bm{k}}_t)$ and ${\bm{f}}_t = \sigma({\bm{p}}_t)$, where ${\bm{k}}_t = \mathrm{Linear}_{d_h}({\bm{x}}_t)$ and ${\bm{p}}_t = \mathrm{Linear}_{d_h}({\bm{x}}_t)$. Furthermore, recall that in minGRU’s derivation we showed that $\log(\sigma({\bm{k}}_t)) = -\mathrm{Softplus}(-{\bm{k}}_t)$. Using this, we can simplify the computation as follows:

\begin{align*}
\log({\bm{f}}'_t) &= -\mathrm{Softplus}\left(\log(\sigma({\bm{k}}_t)) - \log(\sigma({\bm{p}}_t))\right) \\
&= -\mathrm{Softplus}\left(\mathrm{Softplus}(-{\bm{p}}_t) - \mathrm{Softplus}(-{\bm{k}}_t)\right)
\end{align*}

Similarly, we get:

$\log({\bm{i}}'_t) = -\mathrm{Softplus}\left(\mathrm{Softplus}(-{\bm{k}}_t) - \mathrm{Softplus}(-{\bm{p}}_t)\right)$

Combining these derivations, we get the parallel mode (Algorithm 8) for efficient training.
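
As a quick numerical sanity check on these identities, the following standalone PyTorch snippet (variable names are illustrative) confirms that the Softplus-only expressions match a direct evaluation of $\log({\bm{f}}'_t)$ and $\log({\bm{i}}'_t)$:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
k = torch.randn(4, 8)   # pre-activations: i_t = sigmoid(k), f_t = sigmoid(p)
p = torch.randn(4, 8)
i_t, f_t = torch.sigmoid(k), torch.sigmoid(p)

# Direct evaluation of log(f'_t) and log(i'_t).
log_f_direct = torch.log(f_t / (f_t + i_t))
log_i_direct = torch.log(i_t / (f_t + i_t))

# Softplus-only form from the derivation above.
log_f = -F.softplus(F.softplus(-p) - F.softplus(-k))
log_i = -F.softplus(F.softplus(-k) - F.softplus(-p))

assert torch.allclose(log_f, log_f_direct, atol=1e-5)
assert torch.allclose(log_i, log_i_direct, atol=1e-5)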

Algorithm 7 Sequential Mode: Minimal Version of LSTM (minLSTM) trained in log-space
Input: ${\bm{x}}_t, {\bm{h}}_{t-1}$
Output: ${\bm{h}}_t$
${\bm{f}}_t \leftarrow \sigma(\mathrm{Linear}_{d_h}({\bm{x}}_t))$
${\bm{i}}_t \leftarrow \sigma(\mathrm{Linear}_{d_h}({\bm{x}}_t))$
${\bm{f}}'_t, {\bm{i}}'_t \leftarrow \frac{{\bm{f}}_t}{{\bm{f}}_t + {\bm{i}}_t}, \frac{{\bm{i}}_t}{{\bm{f}}_t + {\bm{i}}_t}$
$\tilde{{\bm{h}}}_t \leftarrow g(\mathrm{Linear}_{d_h}({\bm{x}}_t))$
${\bm{h}}_t \leftarrow {\bm{f}}'_t \odot {\bm{h}}_{t-1} + {\bm{i}}'_t \odot \tilde{{\bm{h}}}_t$
Algorithm 8 Parallel Mode: Minimal Version of LSTM (minLSTM) for training in log-space
Input: ${\bm{x}}_{1:t}, {\bm{h}}_0$
Output: ${\bm{h}}_{1:t}$
$\mathrm{diff} \leftarrow \mathrm{Softplus}(-\mathrm{Linear}_{d_h}({\bm{x}}_{1:t})) - \mathrm{Softplus}(-\mathrm{Linear}_{d_h}({\bm{x}}_{1:t}))$
$\mathrm{log\_}{\bm{f}}'_{1:t} \leftarrow -\mathrm{Softplus}(\mathrm{diff})$
$\mathrm{log\_}{\bm{i}}'_{1:t} \leftarrow -\mathrm{Softplus}(-\mathrm{diff})$
$\mathrm{log\_}{\bm{h}}_0 \leftarrow \mathrm{log\_g}({\bm{h}}_0)$
$\mathrm{log\_}\tilde{{\bm{h}}}_{1:t} \leftarrow \mathrm{log\_g}(\mathrm{Linear}_{d_h}({\bm{x}}_{1:t}))$
${\bm{h}}_{1:t} \leftarrow \mathrm{ParallelScanLog}(\mathrm{log\_}{\bm{f}}'_{1:t}, [\mathrm{log\_}{\bm{h}}_0, \mathrm{log\_}{\bm{i}}'_{1:t} + \mathrm{log\_}\tilde{{\bm{h}}}_{1:t}])$

B.3 PyTorch Code: Log-Space Version

import torch
import torch.nn.functional as F

def g(x):
    return torch.where(x >= 0, x + 0.5, torch.sigmoid(x))

def log_g(x):
    # log of g: log(x + 0.5) for x >= 0, and log(sigmoid(x)) = -softplus(-x) otherwise
    return torch.where(x >= 0, (F.relu(x) + 0.5).log(), -F.softplus(-x))
Listing 6: The continuous function $g$ ensures that $\tilde{{\bm{h}}}_t \leftarrow g(\mathrm{Linear}_{d_h}({\bm{x}}_t))$ is positive.
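
As a small standalone check, $g$ is strictly positive and $\mathrm{log\_g}$ agrees with $\log(g(\cdot))$:

import torch

# Assumes g and log_g from Listing 6 are in scope.
x = torch.linspace(-5.0, 5.0, steps=101)
assert torch.all(g(x) > 0)                               # g is strictly positive
assert torch.allclose(log_g(x), g(x).log(), atol=1e-6)   # log_g(x) == log(g(x))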

B.3.1 minGRU: A Minimal GRU

def forward(self, x_t, h_prev):
    # x_t: (batch_size, input_size)
    # h_prev: (batch_size, hidden_size)

    z = torch.sigmoid(self.linear_z(x_t))
    h_tilde = g(self.linear_h(x_t))
    h_t = (1 - z) * h_prev + z * h_tilde
    return h_t
Listing 7: Sequential Mode: Minimal Version of GRU (minGRU) trained in log-space
def forward(self, x, h_0):
    # x: (batch_size, seq_len, input_size)
    # h_0: (batch_size, 1, hidden_size)

    k = self.linear_z(x)
    log_z = -F.softplus(-k)        # log(z)
    log_coeffs = -F.softplus(k)    # log(1 - z)
    log_h_0 = log_g(h_0)
    log_tilde_h = log_g(self.linear_h(x))
    h = parallel_scan_log(log_coeffs,
                          torch.cat([log_h_0, log_z + log_tilde_h], dim=1))
    return h
Listing 8: Parallel Mode: Minimal Version of GRU (minGRU) for training in log-space
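
The parallel-mode listings (Listings 8 and 10) rely on a parallel_scan_log helper that evaluates the recurrence ${\bm{h}}_t = {\bm{a}}_t \odot {\bm{h}}_{t-1} + {\bm{b}}_t$ entirely in log-space. A minimal sketch of one possible implementation, using a cumulative sum of the log-coefficients together with torch.logcumsumexp (the argument shapes and the convention of dropping ${\bm{h}}_0$ from the output are assumptions of this sketch), is:

import torch
import torch.nn.functional as F

def parallel_scan_log(log_coeffs, log_values):
    # log_coeffs: (batch_size, seq_len, hidden_size), log of a_t
    # log_values: (batch_size, seq_len + 1, hidden_size), log of [h_0, b_1, ..., b_t]
    a_star = F.pad(torch.cumsum(log_coeffs, dim=1), (0, 0, 1, 0))    # cumulative log-products, a*_0 = 0
    log_h0_plus_b_star = torch.logcumsumexp(log_values - a_star, dim=1)
    log_h = a_star + log_h0_plus_b_star
    return torch.exp(log_h)[:, 1:]                                   # drop h_0, return h_1, ..., h_t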

B.3.2 minLSTM: A Minimal LSTM

def forward(self, x_t, h_prev):
    # x_t: (batch_size, input_size)
    # h_prev: (batch_size, hidden_size)

    f_t = torch.sigmoid(self.linear_f(x_t))
    i_t = torch.sigmoid(self.linear_i(x_t))
    tilde_h_t = g(self.linear_h(x_t))
    f_prime_t = f_t / (f_t + i_t)
    i_prime_t = i_t / (f_t + i_t)
    h_t = f_prime_t * h_prev + i_prime_t * tilde_h_t
    return h_t
Listing 9: Sequential Mode: Minimal Version of LSTM (minLSTM) trained in log-space
def forward(self, x, h_0):
    # x: (batch_size, seq_len, input_size)
    # h_0: (batch_size, 1, hidden_size)

    diff = F.softplus(-self.linear_f(x)) \
           - F.softplus(-self.linear_i(x))
    log_f = -F.softplus(diff)      # log(f'_t)
    log_i = -F.softplus(-diff)     # log(i'_t)
    log_h_0 = log_g(h_0)
    log_tilde_h = log_g(self.linear_h(x))
    h = parallel_scan_log(log_f,
                          torch.cat([log_h_0, log_i + log_tilde_h], dim=1))
    return h
Listing 10: Parallel Mode: Minimal Version of LSTM (minLSTM) for training in log-space
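
As a consistency check, the parallel mode (Listing 10) should reproduce the hidden states of the sequential mode (Listing 9) up to numerical precision. The standalone sketch below rolls out both modes with shared weights (it assumes g, log_g, and parallel_scan_log are in scope; since the parallel listing encodes the initial state as $g({\bm{h}}_0)$ via log_g, the sequential roll-out starts from $g({\bm{h}}_0)$ as well):

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, d_in, d_h = 2, 16, 5, 7
x = torch.randn(batch, seq_len, d_in)
h_0 = torch.randn(batch, 1, d_h)
linear_f, linear_i, linear_h = (nn.Linear(d_in, d_h) for _ in range(3))

# Sequential mode (Listing 9), starting from g(h_0) to match the parallel listing.
h_prev, h_seq = g(h_0).squeeze(1), []
for t in range(seq_len):
    f_t = torch.sigmoid(linear_f(x[:, t]))
    i_t = torch.sigmoid(linear_i(x[:, t]))
    tilde_h_t = g(linear_h(x[:, t]))
    h_prev = (f_t / (f_t + i_t)) * h_prev + (i_t / (f_t + i_t)) * tilde_h_t
    h_seq.append(h_prev)
h_seq = torch.stack(h_seq, dim=1)

# Parallel mode (Listing 10), using the same weights.
diff = F.softplus(-linear_f(x)) - F.softplus(-linear_i(x))
log_f, log_i = -F.softplus(diff), -F.softplus(-diff)
h_par = parallel_scan_log(log_f,
                          torch.cat([log_g(h_0), log_i + log_g(linear_h(x))], dim=1))

assert torch.allclose(h_seq, h_par, atol=1e-4)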

Appendix C Detailed Experiment Setup

In this section, we describe the experiment setup in detail.

C.1 Architecture

In all models, residual connections are added between layers and layer norms are applied before each layer.

Selective Copying. Each layer in the model consists of (1) either a minLSTM or minGRU layer and (2) a linear layer, as sketched below.
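
A minimal sketch of one such layer (whether the residual connection wraps both sub-modules or each one separately is not specified above, so the wrapping here is an assumption, and the names are illustrative):

import torch.nn as nn

class SelectiveCopyBlock(nn.Module):
    # One layer: a minLSTM/minGRU recurrent layer followed by a linear layer,
    # with a pre-layer-norm residual connection.
    def __init__(self, dim, rnn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.rnn = rnn                     # minLSTM or minGRU mapping dim -> dim
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):                  # x: (batch, seq_len, dim)
        return x + self.linear(self.rnn(self.norm(x)))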

Reinforcement Learning. In this work, we consider the Decision Transformer framework for (Offline) RL. Following prior works (Feng et al., 2024; Ota, 2024), we replace the self-attention module with one of our recurrent sequence modules (minLSTM or minGRU).

Language Modelling. Prior works (Gu & Dao, 2024; Beck et al., 2024) apply a convolutional layer in addition to their recurrent sequence module. Following them, a layer of the model consists of an RNN (minLSTM or minGRU), a convolutional layer applied temporally with a kernel size of 4, and a two-layer MLP.
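
A minimal sketch of one such layer (the ordering of the sub-modules, the causal padding of the convolution, and the MLP activation are assumptions of this sketch, not the exact code used in the experiments):

import torch.nn as nn

class LMBlock(nn.Module):
    # One language-modelling layer: RNN (minGRU/minLSTM), temporal conv (kernel size 4),
    # and a two-layer MLP, each with a pre-layer-norm residual connection.
    def __init__(self, dim, rnn, expansion=2):
        super().__init__()
        self.norm1, self.norm2, self.norm3 = (nn.LayerNorm(dim) for _ in range(3))
        self.rnn = rnn                                      # minGRU/minLSTM mapping dim -> dim
        self.conv = nn.Conv1d(dim, dim, kernel_size=4, padding=3)   # trimmed to be causal
        self.mlp = nn.Sequential(nn.Linear(dim, expansion * dim), nn.GELU(),
                                 nn.Linear(expansion * dim, dim))

    def forward(self, x):                                   # x: (batch, seq_len, dim)
        x = x + self.rnn(self.norm1(x))
        h = self.conv(self.norm2(x).transpose(1, 2))[..., :x.size(1)]
        x = x + h.transpose(1, 2)
        x = x + self.mlp(self.norm3(x))
        return x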

C.2 Hyperparameters and general experimental details

Selective Copying. Models are trained for 400,000 steps with a batch size of 64, using early stopping. The optimizer is Adam with a learning rate of $3 \times 10^{-4}$. Due to GPU memory limitations, gradient accumulation is performed during training: gradients for two batches of size 32 are accumulated for each gradient update and clipped to 1.0. Each model consists of 3 layers, an input dimension of 64, and a dropout ratio of 0.1. minLSTM and minGRU have an expansion factor of 6. Results for the baselines are referenced from the Mamba paper.
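
A minimal sketch of this accumulation scheme (model, get_sub_batches, and compute_loss are placeholders, not functions from our codebase):

import torch

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

for step in range(400_000):
    optimizer.zero_grad()
    # Two sub-batches of size 32 emulate an effective batch size of 64.
    for sub_batch in get_sub_batches(step, num_sub_batches=2, sub_batch_size=32):
        loss = compute_loss(model, sub_batch) / 2   # average across sub-batches
        loss.backward()                             # gradients accumulate in .grad
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()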

Reinforcement Learning. We follow the hyperparameter settings outlined by Ota (2024). For Hopper (Medium) and Hopper (Medium-Replay), an embedding dimension of 256 is used, while all other environments use an embedding dimension of 128. The learning rate is set to $1 \times 10^{-4}$ for Hopper (Medium), Hopper (Medium-Replay), and Walker (Medium); for all other environments and datasets, the learning rate is $1 \times 10^{-3}$. The models are optimized using AdamW with a weight decay of $1 \times 10^{-4}$ and a linear warmup over 10,000 steps. Each model consists of 3 layers and has a dropout ratio of 0.1. The models are trained for 100,000 steps with a batch size of 64. Gradients are clipped to 0.25. Results for the baselines are referenced from the Mamba and Aaren papers.
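
A minimal sketch of the optimizer setup (shown with the Hopper (Medium) learning rate; model is a placeholder):

import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
# Linear learning-rate warmup over the first 10,000 updates, constant afterwards.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / 10_000))
# Per update: clip gradients to 0.25, call optimizer.step(), then scheduler.step().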

Language Modelling. The models are optimized using AdamW with a learning rate of $1 \times 10^{-3}$. Each model consists of three layers, a dropout ratio of 0.2, and an embedding dimension of 384. Training runs for 5,000 steps with a batch size of 64, with evaluation every 25 steps. Gradients are clipped to 1.0. The Transformer is configured with 6 heads. Mamba uses an SSM state expansion factor of 16 and a block expansion factor of 2. Following Mamba, both minLSTM and minGRU use an expansion factor of 2 as well.

C.3 Datasets

Selective Copying. In this task, the model learns to extract data tokens from a sequence while disregarding noise tokens. Following Gu & Dao (2024), we consider a vocabulary size of 16 and sequences of length 4096. Each sequence includes 16 randomly placed data tokens; the remaining tokens are noise.
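
A minimal sketch of generating one such sequence (the choice of which token id denotes noise is an assumption of this sketch):

import torch

def make_selective_copying_example(seq_len=4096, num_data=16, vocab_size=16):
    # One convention: token 0 is the noise token, tokens 1..vocab_size-1 are data tokens;
    # the target is the data tokens in the order they appear.
    x = torch.zeros(seq_len, dtype=torch.long)                    # noise everywhere
    positions = torch.randperm(seq_len)[:num_data].sort().values  # random, ordered positions
    data = torch.randint(1, vocab_size, (num_data,))
    x[positions] = data
    y = data                                                      # tokens to copy out, in order
    return x, y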

Reinforcement Learning. In this setting, we consider continuous control tasks from the D4RL benchmark (Fu et al., 2020). These MuJoCo-based tasks comprise three environments with dense rewards: HalfCheetah, Hopper, and Walker. For each environment, three datasets representing varying levels of data quality are considered:

  • Medium (M): One million timesteps generated by a policy scoring about one-third of an expert policy’s score.

  • Medium-Replay (M-R): A replay buffer from an agent trained to perform like the Medium policy.

  • Medium-Expert (M-E): One million timesteps from the Medium policy combined with one million from an expert policy.

Following Fu et al. (2020), reported scores are normalized such that 100 represents expert-policy performance.
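
Concretely, this follows the standard D4RL convention, where each environment provides reference scores for a random and an expert policy:

def d4rl_normalized_score(score, random_score, expert_score):
    # 0 corresponds to a random policy, 100 to an expert policy.
    return 100.0 * (score - random_score) / (expert_score - random_score)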

Language Modelling. In this setting, we consider the Shakespeare dataset, a collection of text derived from the works of William Shakespeare. The training and test splits consist of 1,003,854 and 111,540 tokens, respectively.
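
The split sizes above are consistent with character-level tokenization of the standard tiny-Shakespeare text with a 90/10 train/test split; a minimal preprocessing sketch under that assumption (the file path is illustrative):

with open("data/shakespeare/input.txt", encoding="utf-8") as f:
    text = f.read()

# Character-level tokenization.
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}
data = [stoi[ch] for ch in text]

# 90/10 train/test split.
n = int(0.9 * len(data))
train_data, test_data = data[:n], data[n:]
print(len(train_data), len(test_data))   # 1003854 111540 for the standard file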