TokenFormer: Rethinking Transformer Scaling with Tokenized Model Parameters
TokenFormer introduces a scalable architecture for Transformers, allowing efficient scaling from 124 million to 1.4 billion parameters without complete retraining, while maintaining performance comparable to traditional models.
The paper titled "TokenFormer: Rethinking Transformer Scaling with Tokenized Model Parameters" introduces a novel architecture aimed at addressing the high computational costs associated with scaling Transformer models. Traditional Transformers rely on a fixed number of parameters in linear projections, which necessitates retraining the entire model when architectural changes are made. This approach becomes increasingly impractical as model sizes grow. The authors propose TokenFormer, which utilizes the attention mechanism for both input tokens and model parameters, allowing for greater architectural flexibility. By treating model parameters as tokens, the architecture replaces linear projections with a token-parameter attention layer, enabling efficient scaling from 124 million to 1.4 billion parameters without the need for complete retraining. This method achieves performance levels comparable to Transformers trained from scratch while significantly reducing training costs. The authors provide access to the code and models for further exploration.
- TokenFormer offers a scalable architecture for Transformers, reducing retraining costs.
- The model treats parameters as tokens, enhancing flexibility in architectural modifications.
- It scales efficiently from 124M to 1.4B parameters without complete retraining.
- Performance is comparable to traditional Transformers trained from scratch.
- Code and models are publicly available for further research and application.
Related
Transformer Explainer: An Interactive Explainer of the Transformer Architecture
The Transformer architecture has transformed AI in text generation, utilizing self-attention and advanced features like layer normalization. The Transformer Explainer tool helps users understand its concepts interactively.
Transformer Explainer
The Transformer architecture has transformed AI in text generation, utilizing self-attention and key components like embedding and Transformer blocks, while advanced features enhance performance and stability.
Symmetric Power Transformers
Symmetric Power Transformers enhance linear transformer performance by using higher-dimensional embeddings and a hyperparameter \(p\) for state size, showing improved capabilities and compatibility with rotary embeddings in experiments.
Were RNNs all we needed?
The paper by Leo Feng et al. revisits RNNs, proposing minimal LSTMs and GRUs that enhance training speed and performance, suggesting a renewed interest in RNNs for machine learning applications.
Differential Transformer
The Differential Transformer improves attention mechanisms in Transformer models by enhancing relevant context and reducing noise, outperforming traditional models in language tasks and improving accuracy in in-context learning.
weight = attention(token_query, weight_keys, weight_values).
In other words, they query weight_keys to fetch the weight_values, and mix them to compute each weight on the spot. Increasing model size becomes a matter of adding more weight_keys and weight_values, and incrementally training them.
Simple, clever, and it seems to work well. Beautiful.
One interesting thing to note: sounds like model scaling happens on the fly by adding key-value pairs as rows in the K and V matrices on the Pattention layer. That suggests that weights represented by tokens in the first rows may be more important than weights in later rows. There may be a lot you could do with that ordering of weights in terms of pruning and such.
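To make the pseudo-code above concrete, here is a minimal NumPy sketch of a token-parameter attention layer as the comments describe it. The names are mine, a plain softmax stands in for the paper's modified variant, and "growing" the model is shown by appending rows to the key/value parameter matrices; treat it as an illustration, not the authors' implementation.

    import numpy as np

    def softmax(x, axis=-1):
        x = x - x.max(axis=axis, keepdims=True)
        e = np.exp(x)
        return e / e.sum(axis=axis, keepdims=True)

    def pattention(x, weight_keys, weight_values):
        # Token-parameter attention sketch: input tokens attend over learnable
        # key/value parameter tokens instead of being multiplied by a fixed matrix.
        #   x:             (n_tokens, d_in)
        #   weight_keys:   (n_param_tokens, d_in)
        #   weight_values: (n_param_tokens, d_out)
        scores = x @ weight_keys.T        # (n_tokens, n_param_tokens)
        mix = softmax(scores, axis=-1)    # plain softmax here; the paper uses a modified variant
        return mix @ weight_values        # each output is a mixture of value parameter tokens

    rng = np.random.default_rng(0)
    x  = rng.normal(size=(4, 16))         # 4 input tokens, d_in = 16
    wk = rng.normal(size=(32, 16))        # 32 parameter tokens
    wv = rng.normal(size=(32, 8))         # d_out = 8
    y_small = pattention(x, wk, wv)

    # "Grow" the layer: append 16 new parameter tokens, keep the trained ones.
    wk_big = np.vstack([wk, rng.normal(size=(16, 16))])
    wv_big = np.vstack([wv, rng.normal(size=(16, 8))])
    y_big = pattention(x, wk_big, wv_big) # same interface, more capacity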
Consider a case of two "experts" or two "value parameter tokens."
The mixture of experts has a "router" network that provides a weight to each expert (through a softmax) conditional on an input. The output is a (sparse) weighted sum of the outputs of the experts.
The TokenFormer has an "attention" layer that combines the token with a key to provide a weight for each "value parameter" token. A(B+C) = AB + AC definitionally, so this is like applying a weighted sum of distinct transformations.
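A toy sketch of the parallel drawn in the previous paragraph, with my own naming, dense softmax weights on both sides, and no per-expert activation functions; it only illustrates that both reduce to softmax-weighted mixing, which distributivity lets you read either way:

    import numpy as np

    def softmax(x):
        x = x - x.max()
        e = np.exp(x)
        return e / e.sum()

    rng = np.random.default_rng(0)
    d_in, d_out = 8, 8
    x = rng.normal(size=(d_in,))
    B = rng.normal(size=(d_in, d_out))   # expert 1 / "value parameter" 1
    C = rng.normal(size=(d_in, d_out))   # expert 2 / "value parameter" 2

    # Mixture-of-experts view: a router scores the experts, output is a weighted sum.
    router = rng.normal(size=(d_in, 2))
    a = softmax(x @ router)                       # two mixing weights
    y_moe = a[0] * (x @ B) + a[1] * (x @ C)

    # Token-parameter attention view: keys score the "value parameters" instead.
    keys = rng.normal(size=(2, d_in))
    b = softmax(x @ keys.T)                       # two mixing weights
    y_tok = x @ (b[0] * B + b[1] * C)             # == b[0]*(x @ B) + b[1]*(x @ C) by distributivity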
I think the differences are: a) where the non-linearity hits (the above description doesn't consider an activation function), b) this attention softmax is not (necessarily) sparse, c) that "mixtral" networks only replace the feed-forward components of the layer, and d) that extending a "mixtral" approach would require re-training the "router" layers.
It seems like (d) is maybe the nicest feature here... my intuition is that (a) doesn't matter much, (b) is debatable (how closely a sparse MoE can approximate a dense MoE), and (c) has probably been tried (my guess is the feed-forward-only restriction was "more bang for the buck given the parameters," not an oversight)...
... I wonder, though, if there might be diminishing returns here (I believe Mixture-of-Experts tends to struggle with imbalanced "winner-take-all" dynamics, since "early" winners get more gradient signal to improve their weights), and how different this would have been from going from 3x7B to an 8x7B to a 24x7B training approach (with a "retrain routing networks" step).
Their claimed theoretical advancement is as follows. If you want to transform an input vector X into another vector Y of a different dimension, "normal" people suggest using a linear projection: create an appropriately sized matrix W and simply multiply it by your input:
Given X ∈ ℝ^d_in and W ∈ ℝ^(d_in × d_out), then Y = X @ W ∈ ℝ^d_out.
In the attention layer, where the input X is converted into queries Q, keys K, and values V, this is the simple strategy employed: Q = X @ W_q, K = X @ W_k, V = X @ W_v, and it has shown itself to be effective.
This is too simple for the authors of this paper. They propose another approach: instead of projecting directly to the desired dimension, increase computation by routing through an intermediate dimension and introducing a non-linearity between the two projections.
Given X ∈ ℝ^d_in, W_1 ∈ ℝ^(d_in × d_tmp), and W_2 ∈ ℝ^(d_tmp × d_out), then Y = f(X @ W_1) @ W_2 ∈ ℝ^d_out.
Here, f can be any non-linearity. The authors choose softmax; it allows them to claim a superficial resemblance to attention. Later in the paper, they reveal it is not actually softmax, but a modified version to avoid gradient vanishing (softmax is not a very good general-purpose non-linearity).
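A side-by-side sketch of the two projection strategies, with dimensions I picked arbitrarily and a plain softmax standing in for the modified non-linearity described above:

    import numpy as np

    def softmax(x, axis=-1):
        x = x - x.max(axis=axis, keepdims=True)
        e = np.exp(x)
        return e / e.sum(axis=axis, keepdims=True)

    rng = np.random.default_rng(0)
    d_in, d_tmp, d_out = 16, 64, 16
    X = rng.normal(size=(4, d_in))               # 4 tokens

    # Standard Transformer projection: one matrix, straight to d_out.
    W = rng.normal(size=(d_in, d_out))
    Y_linear = X @ W                             # (4, d_out)

    # Proposed factored projection: intermediate dimension plus a non-linearity.
    W1 = rng.normal(size=(d_in, d_tmp))
    W2 = rng.normal(size=(d_tmp, d_out))
    Y_factored = softmax(X @ W1) @ W2            # (4, d_out); plain softmax stands in for the paper's variant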
So, they replace all projections in the attention layer with this new strategy: Q = f(X @ W_q1) @ W_q2, K = f(X @ W_k1) @ W_k2, and V = f(X @ W_v1) @ W_v2.
The problem with this is not theoretical: it does increase the model's expressiveness and computational power. It is practical: we are adding parameters where we need them the least, in the attention layer. It is generally understood that LLMs do not need extra parameters in the attention layer. Actually, advancements like Grouped-Query Attention hinge on the idea that you can halve or even quarter the number of parameters in the attention layer without harming performance. The experience of the LLM community so far suggests that the authors' idea of adding even more parameters to the self-attention layer should degrade their models' performance while adding no tangible gain.
The authors' numbers say otherwise. But it is hard to trust their numbers. When training a Transformer to compare against, they replicate the original GPT-2 proposed in 2019. In doing so they ignore years of architectural improvements, such as rotary positional embeddings, SwiGLU, and RMSNorm, that have culminated in Transformer++, the strong recipe used by Meta's Llama series. We've seen this time after time in the various "Transformer killers" that used to be popular about a year ago. A researcher would think up some novel variant of linear attention, furiously test it against a weak GPT-2 baseline, find that it blew the baseline out of the water, and declare victory. Somehow, these never caught on, because when tested against a newer baseline these models weren't actually that great. The authors are doing the same thing here.
In their tables they also include comparisons to other models. But they select only older suites: GPT-Neo, OPT, and Pythia. These models were not trained with any modern architectural improvements except rotary embeddings (which EleutherAI helped popularize), and so predictably TokenFormer crushes them. On the last page of the appendix the authors have included a full table with some fairer comparisons. Their TokenFormer-150M variant achieves a Pile ppl of 10.45 against Mamba-130M's 10.54. In the intermediate weight class, TokenFormer-450M matches Mamba-370M's 8.28 Pile ppl despite having 21% more parameters. And in the largest size, TokenFormer-1.5B loses to Mamba-1.4B, 6.91 to 6.80 ppl.
Overall, the architectural tweak proposed in this paper is impractical, and the few fair comparisons they include are unimpressive. TokenFormer is another in a long line of Transformer-killers that have nice graphs of cherry-picked data, and will similarly fade into obscurity.