Were RNNs all we needed?
The paper by Leo Feng et al. revisits RNNs, proposing minimal LSTMs and GRUs (minLSTM and minGRU) that train far faster while matching the performance of recent sequence models, signaling renewed interest in RNNs for machine learning applications.
The paper titled "Were RNNs All We Needed?" by Leo Feng and colleagues revisits the capabilities of recurrent neural networks (RNNs) in light of the scalability issues Transformer models face with long sequences. The authors propose novel recurrent architectures, namely minimal versions of LSTMs and GRUs, modified to eliminate the hidden-state dependencies of their input, forget, and update gates. This modification allows the models to be trained in parallel without backpropagation through time (BPTT), enhancing training speed significantly: up to 175 times faster for sequences of length 512. The study demonstrates that these streamlined models, referred to as minLSTMs and minGRUs, not only require fewer parameters than their traditional counterparts but also achieve performance comparable to more recent sequence models. This work highlights the potential of revisiting and optimizing older neural network architectures to address current challenges in machine learning.
- The paper explores the limitations of Transformers and the renewed interest in RNNs.
- It introduces minimal versions of LSTMs and GRUs that are fully parallelizable.
- The modified RNNs can be trained significantly faster than traditional models.
- The study shows that these minimal RNNs match the performance of contemporary sequence models.
- The research suggests a potential shift back to RNNs for certain applications in machine learning.
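For reference, the minimal GRU update the summary describes takes roughly this form (a sketch of the idea; the exact parameterization in the paper may differ slightly). Both the gate and the candidate state are computed from the current input alone, with no h_{t-1} term:

    z_t       = sigmoid(Linear_z(x_t))      # update gate, input-only
    h_tilde_t = Linear_h(x_t)               # candidate state, input-only
    h_t       = (1 - z_t) * h_{t-1} + z_t * h_tilde_t

Because z_t and h_tilde_t do not depend on h_{t-1}, the recurrence is a first-order linear scan over the sequence, which is what makes parallel training without BPTT possible.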
Related
xLSTM Explained in Detail
Maximilian Beck's YouTube video delves into xLSTM as a Transformer alternative in language modeling. xLSTM combines LSTM with modern techniques to tackle storage and decision-making issues, aiming to rival Transformers in predictive tasks.
Math Behind Transformers and LLMs
This post introduces transformers and large language models, focusing on OpenGPT-X and transformer architecture. It explains language models, training processes, computational demands, GPU usage, and the superiority of transformers in NLP.
Learning to (Learn at Test Time): RNNs with Expressive Hidden States
The paper introduces Test-Time Training (TTT) layers for sequence modeling, featuring linear complexity and self-supervised learning for training on test sequences. TTT-Linear outperforms Transformer, while TTT-MLP shows potential for long contexts.
Transformer Layers as Painters
The study "Transformer Layers as Painters" by Qi Sun et al. delves into transformer models, showcasing layer impact variations and potential for model optimization through strategic layer adjustments.
Architectural Effects on Maximum Dependency Lengths of Recurrent Neural Networks
The study by Kent and Murray presents a methodology for assessing maximum dependency lengths in RNNs, analyzing how architectural factors like layers and neuron counts affect performance in sequential data.
- Many commenters express skepticism about RNNs' ability to match Transformers in handling long-context tasks due to inherent limitations in memory retention.
- There is a recognition of the historical significance of RNNs and a renewed interest in their potential, especially with new architectures like minGRU.
- Several users draw parallels between RNNs and digital signal processing, suggesting that RNNs may be more efficient for certain tasks.
- Concerns about energy efficiency in AI models are raised, comparing the energy consumption of modern AI to that of the human brain.
- Some commenters emphasize the need for better citation practices in the field to acknowledge previous work on RNNs and related architectures.
> RNNs are particularly suitable for sequence modelling settings such as those involving time series, natural language processing, and other sequential tasks where context from previous steps informs the current prediction.
I would like to draw an analogy to digital signal processing. If you think of the recurrent-style architectures as IIR filters and feedforward-only architectures as FIR filters, you will likely find many parallels.
The most obvious to me being that IIR filters typically require far fewer elements to produce the same response as an equivalent FIR filter. Granted, the FIR filter is often easier to implement/control/measure in practical terms (fixed-point arithmetic hardware == ML architectures that can run on GPUs).
I don't think we get to the exponential scary part of AI without some fundamentally recurrent architecture. I think things like the LSTM are kind of an in-between hack in this DSP analogy: you could look at them as FIR with dynamic coefficients. Neuromorphic approaches seem like the best long-term bet to me in terms of efficiency.
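To make the IIR/FIR comparison concrete, here is a small numpy sketch (an illustration added here, not from the original comment): a one-pole IIR filter carries its whole decaying memory in a single feedback coefficient, while a feedforward FIR filter has to spell out the (truncated) impulse response tap by tap.

    import numpy as np

    a, T = 0.95, 200
    x = np.zeros(T); x[0] = 1.0                 # unit impulse

    # one-pole IIR (recurrent): a single coefficient, evaluated sequentially
    y_iir = np.zeros(T)
    y_iir[0] = (1 - a) * x[0]
    for n in range(1, T):
        y_iir[n] = a * y_iir[n - 1] + (1 - a) * x[n]

    # FIR (feedforward): needs dozens of taps to approximate the same decaying tail
    taps = (1 - a) * a ** np.arange(60)         # truncated impulse response
    y_fir = np.convolve(x, taps)[:T]

    print(np.max(np.abs(y_iir - y_fir)))        # small, but only because we kept 60 taps

The recurrent form is compact but inherently sequential; the feedforward form is redundant but trivially parallel, which is essentially the RNN-versus-Transformer trade-off the analogy points at.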
I'm honestly a bit envious of future engineers who will be tackling these kinds of problems with a 100-line Jupyter notebook on a laptop years from now. If we discovered the right method or algorithm for these long-horizon problems, a 2B-parameter model might even outperform current models on everything except short, extreme reasoning problems.
The only solution I've ever considered for this is expanding a model's dimensionality over time, rather than focusing on perfect weights. The higher dimensionality you can provide to a model, the greater its theoretical storage capacity. This could resemble a two-layer model—one layer acting as a superposition of multiple ideal points, and the other layer knowing how to use them.
When you think about the loss landscape, imagine it with many minima for a given task. If we could create a method that navigates these minima by reconfiguring the model when needed, we could theoretically develop a single model with near-infinite local minima—and therefore, higher-dimensional memory. This may sound wild, but consider the fact that the human brain potentially creates and disconnects thousands of new connections in a single day. Could it be that these connections steer our internal loss landscape between different minima we need throughout the day?
"Interesting work on reviving RNNs. https://arxiv.org/abs/2410.01201 -- in general the fact that there are many recent architectures coming from different directions that roughly match Transformers is proof that architectures aren't fundamentally important in the curve-fitting paradigm (aka deep learning)
Curve-fitting is about embedding a dataset on a curve. The critical factor is the dataset, not the specific hard-coded bells and whistles that constrain the curve's shape. As long as your curve is sufficiently expressive all architectures will converge to the same performance in the large-data regime."
Here's why.
A user of an LLM might give the model some long text and then say "Translate this into German please". A Transformer can look back at its whole history. But what is an RNN to do? While the length of its context is unlimited, the amount of information the model retains about it is bounded by whatever is in its hidden state at any given time.
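A rough back-of-the-envelope comparison (hypothetical numbers, not from the comment) shows how differently the two approaches spend memory on past context:

    # hypothetical model: 4096-dim activations, 32 layers, 100k tokens of context
    d_model, n_layers, n_tokens = 4096, 32, 100_000
    kv_cache  = 2 * d_model * n_layers * n_tokens   # a key and a value per token, per layer
    rnn_state = d_model * n_layers                  # one fixed-size hidden state per layer
    print(kv_cache // rnn_state)                    # 200000: the cache holds 200,000x more values

Whatever the RNN wants to recall later has to survive being repeatedly squeezed through that fixed-size state.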
Relevant: https://arxiv.org/abs/2402.01032
Do we have solutions for these two problems now?
import torch
import torch.nn as nn

class MinGRU(nn.Module):
    def __init__(self, token_size, hidden_state_size):
        super().__init__()
        # both projections depend only on the current token, never on the hidden state
        self.token_to_proposal = nn.Linear(token_size, hidden_state_size)
        self.token_to_mix_factors = nn.Linear(token_size, hidden_state_size)

    def forward(self, previous_hidden_state, current_token):
        proposed_hidden_state = self.token_to_proposal(current_token)
        mix_factors = torch.sigmoid(self.token_to_mix_factors(current_token))
        # lerp(a, b, w) = a + w * (b - a): blend the proposal with the previous state
        return torch.lerp(proposed_hidden_state, previous_hidden_state, mix_factors)
And since the proposed hidden states and mix factors for each layer depend only on the current token, you can compute all of them in parallel if you know the whole sequence ahead of time (as during training), and then combine them in linear time using a parallel scan.

The fact that this is competitive with transformers and state-space models in their small-scale experiments is gratifying to the "best PRs are the ones that delete code" side of me. That said, we won't know for sure whether this is a capital-B Breakthrough until someone tries scaling it up to parameter and data counts comparable to SOTA models.
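To make the parallel-scan step concrete: the update in forward() above has the form h_t = a_t * h_{t-1} + b_t, where both a_t (the mix factor) and b_t (the gated proposal) depend only on the current token, so the whole sequence admits a closed form in cumulative products. A naive sketch (for illustration only; this is not the paper's log-space implementation):

    import torch

    def naive_scan(a, b, h0):
        # h_t = a_t * h_{t-1} + b_t, evaluated without a step-by-step loop
        A = torch.cumprod(a, dim=0)                   # A_t = prod_{k<=t} a_k
        # closed form: h_t = A_t * h_0 + A_t * sum_{k<=t} b_k / A_k
        return A * h0 + A * torch.cumsum(b / A, dim=0)

    # tiny check against the sequential recurrence
    T, d = 8, 4
    mix = 0.1 + 0.8 * torch.rand(T, d)               # mix factors kept away from 0 and 1
    proposal, h0 = torch.randn(T, d), torch.zeros(d)
    h, out = h0, []
    for t in range(T):
        h = torch.lerp(proposal[t], h, mix[t])       # same update as MinGRU.forward above
        out.append(h)
    print(torch.allclose(torch.stack(out),
                         naive_scan(mix, (1 - mix) * proposal, h0), atol=1e-5))

The division by the running product A is what becomes numerically fragile as sequences grow, which is presumably part of why the paper moves the computation into log-space.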
One detail I found really interesting is that they seem to do all their calculations in log-space, according to the Appendix. They say it's for numerical stability, which is curious to me—I'm not sure I have a good intuition for why running everything in log-space makes the model more stable. Is it because they removed the tanh from the output, making it possible for values to explode if calculations are done in linear space?
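One possible intuition (my speculation, not something the paper states in these terms): the scan multiplies together many gate values in (0, 1), and a long running product of such values underflows in float32, while the equivalent running sum of logs stays well-scaled.

    import torch

    g = torch.full((200,), 0.5)                    # 200 gate values in (0, 1)
    print(torch.cumprod(g, dim=0)[-1])             # tensor(0.) -- the product underflows
    print(torch.cumsum(torch.log(g), dim=0)[-1])   # tensor(-138.63) -- easily representable

Carrying the scan in log-space (e.g. with logcumsumexp-style operations) and exponentiating only at the end would sidestep that, which presumably matters even more once the bounding tanh on the output is gone.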
EDIT: Another thought—it's kind of fascinating that this sort of sequence modeling works at all. It's like if I gave you all the pages of a book individually torn out and in a random order, and asked you to try to make a vector representation for each page as well as instructions for how to mix that vector with the vector representing all previous pages — except you have zero knowledge of those previous pages. Then, I take all your page vectors, sequentially mix them together in-order, and grade you based on how good of a whole-book summary the final vector represents. Wild stuff.
FURTHER EDIT: Yet another thought—right now, they're just using two dense linear layers to transform the token into the proposed hidden state and the lerp mix factors. I'm curious what would happen if you made those transforms MLPs instead of singular linear layers.
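For what it's worth, that suggested variant might look something like the following (purely hypothetical, not something from the paper); the forward pass stays the same, so the parallel-scan property is preserved.

    import torch
    import torch.nn as nn

    def small_mlp(in_size, hidden, out_size):
        # a two-layer MLP in place of each single Linear
        return nn.Sequential(nn.Linear(in_size, hidden), nn.SiLU(), nn.Linear(hidden, out_size))

    class MinGRUWithMLP(nn.Module):
        def __init__(self, token_size, hidden_state_size):
            super().__init__()
            self.token_to_proposal = small_mlp(token_size, 4 * hidden_state_size, hidden_state_size)
            self.token_to_mix_factors = small_mlp(token_size, 4 * hidden_state_size, hidden_state_size)

        def forward(self, previous_hidden_state, current_token):
            # unchanged: both outputs still depend only on the current token
            proposed_hidden_state = self.token_to_proposal(current_token)
            mix_factors = torch.sigmoid(self.token_to_mix_factors(current_token))
            return torch.lerp(proposed_hidden_state, previous_hidden_state, mix_factors)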
Mine worked, but it was very simple and dog slow, running on my old laptop. Nothing was ever going to run fast on that thing, but I remember my RNN being substantially slower than a feed-forward network would have been.
I was so confident that this was dead technology -- an academic curiosity from the 1980s and 1990s. It was bizarre to see how quickly that changed.
We were able to build generators that could replicate any dataset they were trained on, producing unique deviations while matching the statistical underpinnings of the original datasets.
https://medium.com/capital-one-tech/why-you-dont-necessarily...
We built several text generators for bots that similarly had very good results. The introduction of the transformer improved the speed and reduced the training/data requirements, but honestly the accuracy changed only minimally.
Compare with one human brain. Far more sophisticated, even beyond our knowledge. What does it take to power it for a day? Some vegetables and rice. Still fine for a while if you supply pure junk food -- it'll still perform.
Clearly we have a long, long way to go in terms of the energy efficiency of AI approaches. Our so-called neural nets don't come anywhere near the energy efficiency of actual biological neurons.
It's obvious why the newest toy from OpenAI can solve problems better mostly by just being allowed to "talk to itself" for a moment before starting the answer the human sees.
Given that, a modern incarnation of the RNN could be vastly cheaper than transformers, provided it can be trained.
Convolutional neural networks get more visual understanding by "reusing" their capacity across the area of the image. RNNs and transformers can get a better understanding of a given problem by "reusing" their capacity to learn and infer across time (really, across the steps of an iterative process).
When it comes to the transformer architecture, attention is a red herring. It's just a more or less arbitrary way to partition the network so it can be parallelized. The only bit of potential magic is the "shortcut" links between non-adjacent layers that help propagate learning back through many layers.
Basically, the optimal network is deep and dense (every neuron connects to every neuron in all preceding layers) and is run in some form of recurrence.
But we don't have enough compute to train that. So we need to arbitrarily sever some connections so the whole thing is easier to parallelize. It really doesn't matter which ones, unless we do it in some obviously stupid way.
The actual inventive magic of LLMs possibly happens in the token and positional encoders.
In theory the answer to the question should be "yes": RNNs are Turing complete.
The real question is about how to train them, and the paper is about that.
Can RNNs be as good as Transformers at recalling information from previous tokens in a sequence?
Transformers excel at recalling info, likely because they keep all previous context around in an ever-growing KV cache.
Unless proponents of RNNs conclusively demonstrate that RNNs can recall info from previous context at least as well as Transformers, I'll stick with the latter.
This is obvious when one considers the connections between Transformers, RNNs, Hopfield networks and the Ising model, a model from statistical mechanics which is solved by calculating the partition function.
This interpretation provides us with some very powerful tools that are commonplace in math and physics but which are not talked about in CS & ML.
I'm working on a startup http://traceoid.ai which takes this exact view. Our approach enables faster training and inference, interpretability and also scalable energy-based models, the Holy Grail of machine learning.
Join the discord https://discord.com/invite/mr9TAhpyBW or follow me on twitter https://twitter.com/adamnemecek1
BPTT was their problem
In 2016 my team at Salesforce Research published our work on the Quasi-Recurrent Neural Network (QRNN) [1]. The QRNN variants we describe are near-identical (minGRU) or highly similar (minLSTM) to the work here.
The QRNN was used, many years ago now, in the first version of Baidu's speech synthesis system, Deep Voice [6], and as part of Google's handwriting recognition system in Gboard [5] (2019).
Even if there are expressivity trade-offs when using parallelizable RNNs, they've shown historically that they can work well while being low-resource and incredibly fast. Very few of the possibilities regarding distillation, hardware optimization, etc. have been explored.
Even if you need "exact" recall, various works have shown that even a single layer of attention with a parallelizable RNN can yield strong results. Distillation down to such a model is quite promising.
Other recent fast RNN variants such as the RWKV, S4, Mamba et al. include citations to QRNN (2016) and SRU (2017) for a richer history + better context.
The SRU work has also had additions in recent years (SRU++), doing well in speech recognition and LM tasks where they found similar speed benefits over Transformers.
I note this primarily because the more data points, especially strongly relevant ones, the better positioned the research is. A number of the "new" findings from this paper have been explored previously - and they certainly do show promise! Citing that work ensures we're asking new questions with new insights (with all the benefit of the additional research from ~8 years ago) rather than missing the earlier work.
[1] QRNN paper: https://arxiv.org/abs/1611.01576
[2] SRU paper: https://arxiv.org/abs/1709.02755
[3] SRU++ for speech recognition: https://arxiv.org/abs/2110.05571
[4] SRU++ for language modeling: https://arxiv.org/abs/2102.12459
[5] RNN-based handwriting recognition in Gboard: https://research.google/blog/rnn-based-handwriting-recogniti...
Everything else is just details.