October 3rd, 2024

Were RNNs all we needed?

The paper by Leo Feng et al. revisits RNNs, proposing minimal LSTMs and GRUs that train much faster and match the performance of modern sequence models, suggesting renewed interest in RNNs for machine learning applications.


The paper titled "Were RNNs All We Needed?" by Leo Feng and colleagues revisits the capabilities of recurrent neural networks (RNNs) in light of the recent scalability issues associated with Transformer models, particularly concerning sequence length. The authors propose novel recurrent architectures, including minimal versions of traditional RNNs, specifically LSTMs and GRUs, which have been modified to eliminate hidden state dependencies in their input, forget, and update gates. This modification allows these models to be trained in parallel without the need for backpropagation through time (BPTT), significantly enhancing their training speed—up to 175 times faster for sequences of length 512. The study demonstrates that these streamlined models, referred to as minLSTMs and minGRUs, not only require fewer parameters than their traditional counterparts but also achieve performance levels comparable to more recent sequence models. This work highlights the potential of revisiting and optimizing older neural network architectures to address current challenges in machine learning.
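
To make the modification concrete, here is a minimal sketch (a paraphrase of the idea, not the authors' code) of the minGRU-style update: because the gate and the candidate state depend only on the current input, everything except the final mixing step can be computed for all timesteps at once, and the mixing itself is a linear recurrence that a parallel scan can evaluate.

  import torch

  def min_gru_layer(x, W_z, W_h):
    # x: (seq_len, input_dim); W_z, W_h: (input_dim, hidden_dim)
    z = torch.sigmoid(x @ W_z)   # gates: depend only on the input, not on the hidden state
    h_tilde = x @ W_h            # candidate states: likewise input-only
    h = torch.zeros(W_h.shape[1])
    out = []
    for t in range(x.shape[0]):
      # h_t = (1 - z_t) * h_{t-1} + z_t * h~_t is a linear recurrence;
      # the paper replaces this loop with a parallel scan during training.
      h = (1 - z[t]) * h + z[t] * h_tilde[t]
      out.append(h)
    return torch.stack(out)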

- The paper explores the limitations of Transformers and the renewed interest in RNNs.

- It introduces minimal versions of LSTMs and GRUs that are fully parallelizable.

- The modified RNNs can be trained significantly faster than traditional models.

- The study shows that these minimal RNNs match the performance of contemporary sequence models.

- The research suggests a potential shift back to RNNs for certain applications in machine learning.

AI: What people are saying
The discussion surrounding the paper on minimal LSTMs and GRUs reveals several key themes and insights about RNNs and their comparison to Transformers.
  • Many commenters express skepticism about RNNs' ability to match Transformers in handling long-context tasks due to inherent limitations in memory retention.
  • There is a recognition of the historical significance of RNNs and a renewed interest in their potential, especially with new architectures like minGRU.
  • Several users draw parallels between RNNs and digital signal processing, suggesting that RNNs may be more efficient for certain tasks.
  • Concerns about energy efficiency in AI models are raised, comparing the energy consumption of modern AI to that of the human brain.
  • Some commenters emphasize the need for better citation practices in the field to acknowledge previous work on RNNs and related architectures.
31 comments
By @bob1029 - 7 months
> Transformers required ~2.5x more training steps to achieve comparable performance, overfitting eventually.

> RNNs are particularly suitable for sequence modelling settings such as those involving time series, natural language processing, and other sequential tasks where context from previous steps informs the current prediction.

I would like to draw an analogy to digital signal processing. If you think of the recurrent-style architectures as IIR filters and feedforward-only architectures as FIR filters, you will likely find many parallels.

The most obvious to me being that IIR filters typically require far fewer elements to produce the same response as an equivalent FIR filter. Granted, the FIR filter is often easier to implement/control/measure in practical terms (fixed-point arithmetic hardware == ML architectures that can run on GPUs).
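
A toy illustration of that point (my own sketch, not from the comment): a single-pole IIR filter realizes an exponentially decaying impulse response with one recurrent coefficient, while an FIR version of the same response needs a long tap array.

  import numpy as np

  def iir_one_pole(x, a=0.9):
    # One recurrent element: y[n] = a * y[n-1] + (1 - a) * x[n]
    y, state = np.zeros(len(x)), 0.0
    for n, xn in enumerate(x):
      state = a * state + (1 - a) * xn
      y[n] = state
    return y

  def fir_truncated(x, a=0.9, taps=64):
    # The same response approximated feedforward: many taps instead of one pole
    h = (1 - a) * a ** np.arange(taps)
    return np.convolve(x, h)[:len(x)]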

I don't think we get to the exponential scary part of AI without some fundamentally recurrent architecture. I think things like LSTM are kind of an in-between hack in this DSP analogy: you could look at it as FIR with dynamic coefficients. Neuromorphic approaches seem like the best long-term bet to me in terms of efficiency.

By @charlescurt123 - 7 months
I find the entire field lacking when it comes to long-horizon problems. Our current, widely used solution is to scale, but we're nowhere near achieving the horizon scales even small mammal brains can handle. Our models can have trillions of parameters, yet a mouse brain would still outperform them on long-horizon tasks and efficiency. It's something small, simple, and elegant—an incredible search algorithm that not only finds near-optimal routes but also continuously learns on a fixed computational budget.

I'm honestly a bit envious of future engineers who will be tackling these kinds of problems with a 100-line Jupyter notebook on a laptop years from now. If we discovered the right method or algorithm for these long-horizon problems, a 2B-parameter model might even outperform current models on everything except short, extreme reasoning problems.

The only solution I've ever considered for this is expanding a model's dimensionality over time, rather than focusing on perfect weights. The higher the dimensionality you can provide to a model, the greater its theoretical storage capacity. This could resemble a two-layer model—one layer acting as a superposition of multiple ideal points, and the other layer knowing how to use them.

When you think about the loss landscape, imagine it with many minima for a given task. If we could create a method that navigates these minima by reconfiguring the model when needed, we could theoretically develop a single model with near-infinite local minima—and therefore, higher-dimensional memory. This may sound wild, but consider the fact that the human brain potentially creates and disconnects thousands of new connections in a single day. Could it be that these connections steer our internal loss landscape between different minima we need throughout the day?

By @xnx - 7 months
It's a curse and a blessing that discussion of topics happens in so many different places. I found this comment on Twitter/X interesting: https://x.com/fchollet/status/1841902521717293273

"Interesting work on reviving RNNs. https://arxiv.org/abs/2410.01201 -- in general the fact that there are many recent architectures coming from different directions that roughly match Transformers is proof that architectures aren't fundamentally important in the curve-fitting paradigm (aka deep learning)

Curve-fitting is about embedding a dataset on a curve. The critical factor is the dataset, not the specific hard-coded bells and whistles that constrain the curve's shape. As long as your curve is sufficiently expressive all architectures will converge to the same performance in the large-data regime."

By @trott - 7 months
My feeling is that the answer is "no", in the sense that these RNNs wouldn't be able to universally replace Transformers in LLMs, even though they might be good enough in some cases and beat them in others.

Here's why.

A user of an LLM might give the model some long text and then say "Translate this into German please". A Transformer can look back at its whole history. But what is an RNN to do? While the length of its context is unlimited, the amount of information the model retains about it is bounded by whatever is in its hidden state at any given time.

Relevant: https://arxiv.org/abs/2402.01032

By @theanonymousone - 7 months
I remember that, the way I understood it, Transformers solved two major "issues" of RNNs, which enabled the later boom: vanishing gradients limiting the context (and model?) size, and difficulty in parallelisation limiting the size of the training data.

Do we have solutions for these two problems now?

By @mkaic - 7 months
I strongly enjoy the simplicity of their "minGRU" architecture. It's basically just:

  import torch
  import torch.nn as nn

  class MinGRU(nn.Module):
    def __init__(self, token_size, hidden_size):
      super().__init__()
      self.token_to_proposal = nn.Linear(token_size, hidden_size)
      self.token_to_mix_factors = nn.Linear(token_size, hidden_size)

    def forward(self, previous_hidden_state, current_token):
      # Both the candidate state and the gate depend only on the current token
      proposed_hidden_state = self.token_to_proposal(current_token)
      mix_factors = torch.sigmoid(self.token_to_mix_factors(current_token))
      # Gated interpolation between the proposed and previous hidden states
      return torch.lerp(proposed_hidden_state, previous_hidden_state, mix_factors)
And since the proposed hidden states and mix factors for each layer are both only dependent on the current token, you can compute all of them in parallel if you know the whole sequence ahead of time (like during training), and then combine them in linear time using parallel scan.

The fact that this is competitive with transformers and state-space models in their small-scale experiments is gratifying to the "best PRs are the ones that delete code" side of me. That said, we won't know for sure if this is a capital-B Breakthrough until someone tries scaling it up to parameter and data counts comparable to SOTA models.

One detail I found really interesting is that they seem to do all their calculations in log-space, according to the Appendix. They say it's for numerical stability, which is curious to me—I'm not sure I have a good intuition for why running everything in log-space makes the model more stable. Is it because they removed the tanh from the output, making it possible for values to explode if calculations are done in linear space?
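
For what it's worth, my reading of the stability point is that the scan multiplies together many gate values in (0, 1), which underflows quickly in linear space; in log-space those running products become sums. Here is a sketch of that kind of log-space scan for a recurrence h_t = a_t * h_{t-1} + b_t with positive a_t, b_t (my reconstruction, not the authors' code):

  import torch

  def scan_log_space(log_a, log_b):
    # log_a, log_b: (seq_len, hidden_dim); assumes h_0 = 0 (or folded into b_1)
    a_star = torch.cumsum(log_a, dim=0)                         # log of running products of a_t
    log_h = a_star + torch.logcumsumexp(log_b - a_star, dim=0)  # log h_t for every step at once
    return torch.exp(log_h)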

EDIT: Another thought—it's kind of fascinating that this sort of sequence modeling works at all. It's like if I gave you all the pages of a book individually torn out and in a random order, and asked you to try to make a vector representation for each page as well as instructions for how to mix that vector with the vector representing all previous pages — except you have zero knowledge of those previous pages. Then, I take all your page vectors, sequentially mix them together in-order, and grade you based on how good of a whole-book summary the final vector represents. Wild stuff.

FURTHER EDIT: Yet another thought—right now, they're just using two dense linear layers to transform the token into the proposed hidden state and the lerp mix factors. I'm curious what would happen if you made those transforms MLPs instead of singular linear layers.
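
That last variation is essentially a one-line swap; a quick sketch with illustrative names (not from the paper):

  import torch.nn as nn

  def small_mlp(in_dim, out_dim, hidden_dim=128):
    # Drop-in replacement for the single nn.Linear layers in the MinGRU sketch above
    return nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, out_dim))

  # e.g. self.token_to_proposal = small_mlp(token_size, hidden_size)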

By @vandahm - 7 months
I made an RNN for a college project because I was interested in obsolete historical technology and I thought I needed to seize the opportunity while it lasted, because once I was out of school, I'd never hear about neural networks ever again.

Mine worked, but it was very simple and dog slow, running on my old laptop. Nothing was ever going to run fast on that thing, but I remember my RNN being substantially slower than a feed-forward network would have been.

I was so confident that this was dead technology -- an academic curiosity from the 1980s and 1990s. It was bizarre to see how quickly that changed.

By @imjonse - 7 months
To their credit, the authors (Y. Bengio among them) end the paper with the question, not suggesting they know the answer. These models are very small even by academic standards, so any finding would not necessarily extend to current LLM scales. The main conclusion is that RNN-class networks can be trained as efficiently as modern alternatives, but the resulting performance is only competitive at small scale.

By @logicchains - 7 months
The model in the paper isn't a "real" RNN, due to the changes that make it parallelizable (for the same reasons described in https://arxiv.org/abs/2404.08819), and hence is theoretically less powerful than a "real" RNN: it struggles at some classes of problems that RNNs traditionally excel at. On the other hand, https://arxiv.org/abs/2405.04517 contains a "real" RNN component, which demonstrates a significant improvement on the kind of state-tracking problems that transformers struggle with.

By @tehsauce - 7 months
I haven't gone through the paper in detail yet, but maybe someone can answer: if you remove the hidden state from an RNN as they say they've done, what's left? An MLP predicting from a single token?

By @tadala - 7 months
Everyone wants to use less compute to fit more in, but (obviously?) the solution will be to use more compute and fit less. Attention isn't (topologically) attentive enough. All these RNN-lite approaches are doomed; beyond saving costs, they're going to get cooked by some other arch, even more expensive than transformers.

By @lettergram - 7 months
In 2016 & 2017 my team at Capital One built several >1B parameter models combining LSTMs with a few other tricks.

We were able to build generators that could replicate any dataset they were trained on, and would produce unique deviations, but match the statistical underpinnings of the original datasets.

https://medium.com/capital-one-tech/why-you-dont-necessarily...

We built several text generators for bots that similarly had very good results. The introduction of the transformer improved the speed and reduced the training / data requirements, but honestly the accuracy changed minimally.

By @hdivider - 7 months
I still find it remarkable how we need such an extreme amount of electrical energy to power large modern AI models.

Compare with one human brain. Far more sophisticated, even beyond our knowledge. What does it take to power it for a day? Some vegetables and rice. Still fine for a while if you supply pure junk food -- it'll still perform.

Clearly we have a long, long way to go in terms of the energy efficiency of AI approaches. Our so-called neural nets clearly don't resemble the energy efficiency of actual biological neurons.

By @m11a - 7 months
It'd be nice to see more of how this compares to Mamba. Looks like, in performance, they're not leagues apart, and it's just a different architecture, not necessarily better or worse?

By @scotty79 - 7 months
The only strength of transformers is that they can run once for each token and can pass intermediate state to themselves as they solve your problem. They have to conceal it in tokens that look to humans like part of the response.

It's obvious why the newest toy from OpenAI can solve problems better mostly by just being allowed to "talk to itself" for a moment before starting the answer that the human sees.

Given that, modern incarnations of RNNs can be vastly cheaper than transformers, provided that they can be trained.

Convolutional neural networks get more visual understanding by "reusing" their capacity across the area of the image. RNNs and transformers can get a better understanding of a given problem by "reusing" their capacity to learn and infer across time (across steps of an iterative process, really).

When it comes to the transformer architecture, attention is a red herring. It's just a more or less arbitrary way to partition the network so it can be parallelized. The only bit of potential magic is in the "shortcut" links between non-adjacent layers that help propagate learning back through many layers.

Basically, the optimal network is a deep, dense one (all neurons connected to all neurons in all preceding layers) that is run in some form of recurrence.

But we don't have enough compute to train that. So we need to arbitrarily sever some connections so the whole thing is easier to parallelize. It really doesn't matter which ones, unless we do it in some obviously stupid way.

The actual inventive magic of LLMs possibly happens in the token and positional encoders.

By @marcosdumay - 7 months
R == Recurrent

In theory the answer to the question should be "yes": they are Turing complete.

The real question is about how to train them, and the paper is about that.

By @cs702 - 6 months
I finally got around to reading this. Nice paper, but it fails to address a key question about RNNs:

Can RNNs be as good as Transformers at recalling information from previous tokens in a sequence?

Transformers excel at recalling info, likely because they keep all previous context around in an ever-growing KV cache.

Unless proponents of RNNs conclusively demonstrate that RNNs can recall info from previous context at least as well as Transformers, I'll stick with the latter.

By @adamnemecek - 7 months
Yes, all machine learning can be interpreted in terms of approximating the partition function.

This is obvious when one considers the connections between Transformers, RNNs, Hopfield networks and the Ising model, a model from statistical mechanics which is solved by calculating the partition function.

This interpretation provides us with some very powerful tools that are commonplace in math and physics but which are not talked about in CS & ML.

I'm working on a startup http://traceoid.ai which takes this exact view. Our approach enables faster training and inference, interpretability and also scalable energy-based models, the Holy Grail of machine learning.

Join the discord https://discord.com/invite/mr9TAhpyBW or follow me on twitter https://twitter.com/adamnemecek1

By @dsamarin - 7 months
The name of the paper contrasts with that of the paper which spawned the Transformer architecture, itself a reference to the song "All You Need Is Love" by the Beatles. https://en.wikipedia.org/wiki/Attention_Is_All_You_Need

By @limapedro - 7 months
This is such an interesting paper; sadly they don't have big models. I'd like to see a model trained on TinyStories or even C4, since it should be faster than the transformer variant, and see how it compares.

By @gdiamos - 7 months
RNNs always had better scaling law curves than transformers.

BPTT was their problem

By @kgbcia - 7 months
Decision trees is all we needed

By @hiddencost - 7 months
Note Yoshua Bengio in the author list. This shouldn't be taken lightly.

By @Smerity - 7 months
Excited to see more people working on RNNs but wish their citations were better.

In 2016 my team from Salesforce Research published our work on the Quasi-Recurrent Neural Network[1] (QRNN). The QRNN variants we describe are near identical (minGRU) or highly similar (minLSTM) to the work here.

The QRNN was used, many years ago now, in the first version of Baidu's speech synthesis system (Deep Voice [6]) and as part of Google's handwriting recognition system in Gboard [5] (2019).

Even if there are expressivity trade-offs when using parallelizable RNNs they've shown historically they can work well and are low resource and incredibly fast. Very few of the possibilities regarding distillation, hardware optimization, etc, have been explored.

Even if you need "exact" recall, various works have shown that even a single layer of attention with a parallelizable RNN can yield strong results. Distillation down to such a model is quite promising.

Other recent fast RNN variants such as the RWKV, S4, Mamba et al. include citations to QRNN (2016) and SRU (2017) for a richer history + better context.

The SRU work has also had additions in recent years (SRU++), doing well in speech recognition and LM tasks where they found similar speed benefits over Transformers.

I note this primarily because the more data points there are, especially strongly relevant ones, the better positioned the research is. A number of the "new" findings from this paper have been previously explored - and do certainly show promise! This makes sure we're asking new questions with new insights (with all the benefit of additional research from ~8 years ago) versus missing that earlier work.

[1] QRNN paper: https://arxiv.org/abs/1611.01576

[2] SRU paper: https://arxiv.org/abs/1709.02755

[3] SRU++ for speech recognition: https://arxiv.org/abs/2110.05571

[4] SRU++ for language modeling: https://arxiv.org/abs/2102.12459

[5] RNN-based handwriting recognition in Gboard: https://research.google/blog/rnn-based-handwriting-recogniti...

[6] Deep Voice paper: https://arxiv.org/abs/1702.07825

By @moi2388 - 7 months
Yes, and it’s hardly surprising, since the Chinese room thought experiment is completely wrong; that is in fact exactly how you learn something.
By @fhdsgbbcaA - 7 months
We really need a [preprint] flag for unreviewed papers.
By @lccerina - 7 months
"Was all along a scheme by Google to sell more tensor processing units that didn't run RNNs well?"
By @hydrolox - 7 months
Betteridge's law of headlines?
By @Sysreq2 - 7 months
Guys, I’m gonna stop this before it gets out of hand: All we need is love and a shit ton of compute.

Everything else is just details.

By @PunchTornado - 7 months
To me this is further evidence that these LLMs learn only to speak English, but there is no reasoning at all in them: you can simplify a lot and obtain the same results, and we know how complex the brain is.