November 1st, 2024

TokenFormer: Rethinking Transformer Scaling with Tokenized Model Parameters

TokenFormer introduces a scalable architecture for Transformers, allowing efficient scaling from 124 million to 1.4 billion parameters without complete retraining, while maintaining performance comparable to traditional models.


The paper titled "TokenFormer: Rethinking Transformer Scaling with Tokenized Model Parameters" introduces a novel architecture aimed at addressing the high computational costs associated with scaling Transformer models. Traditional Transformers rely on a fixed number of parameters in linear projections, which necessitates retraining the entire model when architectural changes are made. This approach becomes increasingly impractical as model sizes grow. The authors propose TokenFormer, which uses attention not only for computations among input tokens but also for interactions between tokens and model parameters, allowing for greater architectural flexibility. By treating model parameters as tokens, the architecture replaces linear projections with a token-parameter attention layer, enabling efficient scaling from 124 million to 1.4 billion parameters without the need for complete retraining. This method achieves performance comparable to Transformers trained from scratch while significantly reducing training costs. The authors have released the code and models for further exploration.

- TokenFormer offers a scalable architecture for Transformers, reducing retraining costs.

- The model treats parameters as tokens, enhancing flexibility in architectural modifications.

- It scales efficiently from 124M to 1.4B parameters without complete retraining.

- Performance is comparable to traditional Transformers trained from scratch.

- Code and models are publicly available for further research and application.

12 comments
By @cs702 - 6 months
The authors factorize every weight matrix with an attention mechanism:

  weight = attention(token_query, weight_keys, weight_values).
In other words, they query weight_keys to fetch the weight_values, and mix them to compute each weight on the spot.

Increasing model size becomes a matter of adding more weight_keys and weight_values, and incrementally training them.
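
The description above can be made concrete with a minimal NumPy sketch. It assumes the input x acts as the query, the weight keys and values are plain learnable matrices, and the nonlinearity is GELU applied to L2-normalized scores with a fixed scale; the paper's exact normalization, scaling, and initialization may differ. The helper `grow` is hypothetical: it appends zero-initialized key/value pairs, and the check shows the layer's output is unchanged right after growing, so training can continue from the smaller model's function rather than from scratch.

```python
import numpy as np


def gelu(x):
    # tanh approximation of GELU; note gelu(0) == 0
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def pattention(x, key_params, value_params, scale=1.0, eps=1e-6):
    """Token-parameter attention (sketch, assumed normalization).

    x:            (seq, d_in)  input tokens act as queries
    key_params:   (n, d_in)    learnable "weight keys"
    value_params: (n, d_out)   learnable "weight values"
    """
    scores = x @ key_params.T                               # (seq, n)
    norms = np.linalg.norm(scores, axis=-1, keepdims=True)  # per-row L2 norm
    # Assumed non-softmax normalization: GELU of L2-normalized scores.
    weights = gelu(scale * scores / (norms + eps))
    return weights @ value_params                           # (seq, d_out)


def grow(key_params, value_params, extra):
    # Hypothetical helper: append `extra` zero-initialized parameter tokens.
    k_new = np.zeros((extra, key_params.shape[1]))
    v_new = np.zeros((extra, value_params.shape[1]))
    return np.vstack([key_params, k_new]), np.vstack([value_params, v_new])


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
K = rng.normal(size=(16, 8))
V = rng.normal(size=(16, 12))

y_small = pattention(x, K, V)
K_big, V_big = grow(K, V, extra=16)
y_big = pattention(x, K_big, V_big)
print(y_small.shape, K_big.shape)   # (4, 12) (32, 8)
print(np.allclose(y_small, y_big))  # True
```

With a plain softmax this check would fail: each zero key scores 0, exp(0) = 1 adds mass to the normalizer, and the existing outputs shrink. That is one reason a non-softmax normalization is convenient for this kind of growth.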

Simple, clever, and it seems to work well. Beautiful.

By @valine - 6 months
I would like to see a comparison of inference-time compute between a regular transformer and this. I'm assuming tokens/s is lower, since you need to compute the weights of the model for each token before the actual attention calculations for the sequence position.
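
A rough way to frame this question: if the layer is computed as Y = f(X @ K^T) @ V, no d_in × d_out weight matrix is materialized per token; it is two matrix multiplies around an elementwise nonlinearity. The sketch below counts parameters and multiply-add FLOPs per token for a linear projection and for such a layer. It ignores the normalization, the sequence-level attention, and memory traffic, so it says nothing definitive about real tokens/s.

```python
def linear_cost(d_in, d_out):
    # Parameters and multiply-add FLOPs per token for Y = X @ W.
    params = d_in * d_out
    flops = 2 * d_in * d_out
    return params, flops


def pattention_cost(d_in, d_out, n_tokens):
    # Y = f(X @ K^T) @ V with K: (n_tokens, d_in), V: (n_tokens, d_out).
    # Counts only the two matmuls; the elementwise normalization is ignored.
    params = n_tokens * (d_in + d_out)
    flops = 2 * d_in * n_tokens + 2 * n_tokens * d_out
    return params, flops


d = 768
print(linear_cost(d, d))                        # (589824, 1179648)
print(pattention_cost(d, d, n_tokens=d // 2))   # (589824, 1179648)
print(pattention_cost(d, d, n_tokens=d))        # (1179648, 2359296)
```

Under these assumptions, per-token matmul FLOPs track parameter count just as they do for a linear layer, so at matched parameter counts the cost is about the same; a measured slowdown would more likely come from the normalization, kernel efficiency, or a larger parameter budget.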
By @davesque - 6 months
Seems like a big deal. I feel like this could enable a new level of modularity and compatibility between publicly available weight sets, assuming they use similar channel dimensions. Maybe it also provides a nice formalism for thinking about fine-tuning, where you could adopt certain heuristics for adding/removing key-value pairs from the Pattention layers.

One interesting thing to note: sounds like model scaling happens on the fly by adding key-value pairs as rows in the K and V matrices on the Pattention layer. That suggests that weights represented by tokens in the first rows may be more important than weights in later rows. There may be a lot you could do with that ordering of weights in terms of pruning and such.

By @goldenshale - 6 months
This is a great idea. Being able to dynamically scale up model sizes as datasets and use cases expand without needing to retrain from scratch could enable a Cambrian explosion of interesting stuff building on top of a Llama type model trained in this way.
By @eric15342335 - 6 months
I am a second-year university student learning the basic mathematics and statistics behind neural networks. One thing that shocks me is that there isn't an "incremental" way to build a larger (more parameters) AI model like GPT-4 from an existing smaller one such as GPT-3.5, even though I see the term "incremental (compiling)" nearly everywhere in the software engineering industry. I am curious why this isn't possible theoretically.
By @c0g - 6 months
By @logicchains - 6 months
Seems this would naturally translate into a mixture of experts by using a "hard" attention function so that only a fixed number of weight tokens get included in the calculation.
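
One way to read the "hard" attention idea is a top-k mask over parameter tokens, similar to top-k routing in sparse mixture-of-experts layers. The sketch below is a hypothetical variant, not something from the paper: for each input row it keeps the k highest-scoring parameter tokens and applies a softmax over just those, so every other weight is exactly zero.

```python
import numpy as np


def topk_pattention(x, key_params, value_params, k):
    """Sparse token-parameter attention (hypothetical variant).

    For each input row, only the k highest-scoring parameter tokens get
    nonzero weight; the rest are masked out before a softmax.
    """
    scores = x @ key_params.T                                # (seq, n)
    # Indices of the k largest scores in each row (unordered).
    top_idx = np.argpartition(scores, -k, axis=-1)[:, -k:]
    masked = np.full_like(scores, -np.inf)
    top_vals = np.take_along_axis(scores, top_idx, axis=-1)
    np.put_along_axis(masked, top_idx, top_vals, axis=-1)
    # Softmax over the surviving k scores; masked entries become exactly 0.
    masked = masked - masked.max(axis=-1, keepdims=True)
    weights = np.exp(masked)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ value_params, weights


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
K = rng.normal(size=(32, 8))
V = rng.normal(size=(32, 12))
y, w = topk_pattention(x, K, V, k=4)
print(y.shape)                # (4, 12)
print((w > 0).sum(axis=-1))   # [4 4 4 4]
```

For clarity this version still computes every score and a dense matmul; a version that actually saves compute would gather only the selected rows of `value_params`.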
By @ml_thoughts - 6 months
This seems closely related to the "Mixtral" approach of a mixture-of-experts transformer [1]... I'm not claiming the approach is not original, it just helped me understand what was going on.

Consider a case of two "experts" or two "value parameter tokens."

The mixture of experts has a "router" network that provides a weight to each expert (through a softmax) conditional on an input. The output is a (sparse) weighted sum of the outputs of the experts.

The TokenFormer has an "attention" layer that combines the token with a key parameter to provide a weight to each "value parameter" token. Since A(B+C) = AB + AC by distributivity, this is like applying a weighted sum of distinct transformations.
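
The two-token case can be checked numerically. This sketch uses a plain softmax over two key parameter tokens as the router-like weights (an assumption for illustration; the paper's normalization differs) and shows that the layer's output equals the weighted sum of the two value parameter tokens, i.e. the A(B+C) = AB + AC decomposition. Note that each "expert" here contributes a fixed vector rather than a full feed-forward network as in Mixtral, and both weights are nonzero, matching point (b) below.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))     # input tokens
K = rng.normal(size=(2, 8))     # two key parameter tokens
V = rng.normal(size=(2, 12))    # two value parameter tokens ("experts")

w = softmax(x @ K.T)            # (4, 2): dense, router-like weights per token
y = w @ V                       # token-parameter attention output

# Same result written as a weighted sum of two "expert" outputs.
expert_0 = w[:, [0]] * V[0]     # weight_0 * value_0
expert_1 = w[:, [1]] * V[1]     # weight_1 * value_1
print(np.allclose(y, expert_0 + expert_1))  # True
```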

I think the differences are: a) where the non-linearity hits (the above description doesn't consider an activation function), b) this attention softmax is not (necessarily) sparse, c) that "mixtral" networks only replace the feed-forward components of the layer, and d) that extending a "mixtral" approach would require re-training the "router" layers.

It seems like (d) is maybe the nicest feature here... my intuition would think (a) doesn't matter much, (b) is debatable (how close a sparse-MoE can approximate a dense-MoE), (c) has probably been tried (guessing the ffwd limitation was just "more-bang-for-buck-given-parameters" not an oversight)...

... I wonder, though, if there might be diminishing returns here (I believe that Mixture-of-Experts tends to struggle with imbalanced "winner-take-all" dynamics, since "early" winners get more gradient signal to improve their weights) and how different this would have been from going from 3x7B to an 8x7B to a 24x7B training approach (with a "retrain routing networks" step).

[1] https://arxiv.org/abs/2401.04088

By @mentalically - 6 months
Eventually people will figure out how to nest neural networks in the nodes and edges of an arbitrary graph.
By @davesque - 6 months
Seems like a lot of existing models could be converted to this token parameter representation.
By @sapphire42 - 6 months
As someone who has worked in this space, this paper is unfortunately total BS.

Their claimed theoretical advancement is as follows. If you want to transform an input vector X to another vector Y of different dimension, "normal" people use a linear projection: create an appropriately sized matrix W and simply multiply your input by it:

Given X ∈ ℝ^(d_in) and W ∈ ℝ^(d_in × d_out), then Y = X @ W ∈ ℝ^(d_out).

In the attention layer, where the input X is converted into queries Q, keys K, and values V, this is the simple strategy employed: Q = X @ W_q, K = X @ W_k, V = X @ W_v, and it has shown itself to be effective.

This is too simple for the authors of this paper. They propose another approach. Instead of converting directly to the desired dimension, we will increase computation by creating an intermediate dimension, and introducing a non-linearity between them.

Given X ∈ ℝ^(d_in), W_1 ∈ ℝ^(d_in × d_tmp), and W_2 ∈ ℝ^(d_tmp × d_out), then Y = f(X @ W_1) @ W_2 ∈ ℝ^(d_out).

Here, f can be any non-linearity. The authors choose softmax; it allows them to claim a superficial resemblance to attention. Later in the paper, they reveal it is not actually softmax, but a modified version to avoid gradient vanishing (softmax is not a very good general-purpose non-linearity).

So, they replace all projections in the attention layer with this new strategy: Q = f(X @ W_q1) @ W_q2, K = f(X @ W_k1) @ W_k2, and V = f(X @ W_v1) @ W_v2.

The problem with this is not theoretical: this does increase the model's expressiveness and computational power. It is practical: we are adding parameters where we need them the least, in the attention layer. It is generally understood that LLMs do not need extra parameters in the attention layer. Actually, advancements like Grouped-Query Attention hinge on the idea that you can halve or even quarter the number of parameters in the attention layer without harming performance. The experience of the LLM community so far suggests that the authors' idea of adding even more parameters to the self-attention layer should degrade their models' performance while adding no tangible gain.
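
For reference, here is a rough count of attention projection weights under grouped-query attention. The sizes are illustrative assumptions (d_model = 4096, 32 query heads, 8 KV heads), not taken from any particular model, and biases are ignored. GQA shrinks only the K and V projections, by a factor of n_heads / n_kv_heads; the Q and output projections keep their size.

```python
def attention_proj_params(d_model, n_heads, n_kv_heads):
    # Q, K, V and output projection weights (biases ignored) for
    # grouped-query attention; n_kv_heads == n_heads is standard MHA.
    head_dim = d_model // n_heads
    q = d_model * n_heads * head_dim
    kv = 2 * d_model * n_kv_heads * head_dim
    out = n_heads * head_dim * d_model
    return q + kv + out


d, h = 4096, 32
mha = attention_proj_params(d, h, n_kv_heads=32)
gqa = attention_proj_params(d, h, n_kv_heads=8)
print(mha, gqa, round(gqa / mha, 3))  # 67108864 41943040 0.625
```

With 8 KV heads for 32 query heads, the K and V projections are a quarter of their MHA size, and the layer's projection weights as a whole drop to 62.5%.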

The authors' numbers say otherwise. But it is hard to trust their numbers. When training a Transformer to compare against, they replicate the original GPT-2 proposed in 2019. In doing so they ignore years of architectural improvements, such as rotary positional embeddings, SwiGLU, and RMSNorm, that have culminated in Transformer++, the strong recipe that Meta's Llama series uses. We've seen this time after time in the various "Transformer killers" that were popular about a year ago. A researcher would think up some novel variant of linear attention, furiously test it against a weak GPT-2 baseline, find that it blew the baseline out of the water, and declare victory. Somehow, these never caught on, because when tested against a newer baseline these models weren't actually that great. The authors are doing the same thing here.

In their tables they also include comparisons to other models. Actually, they exclusively select older open model suites: GPT-Neo and Pythia from EleutherAI, and Meta's OPT. These models were trained without most modern architectural improvements (Pythia does use rotary embeddings, which were introduced in RoFormer and popularized by EleutherAI), and so predictably TokenFormer crushes them. On the last page of the appendix the authors have included a full table with some fairer comparisons. Their TokenFormer-150M variant achieves a Pile ppl of 10.45 against Mamba-130M's 10.54. In the intermediate weight class, TokenFormer-450M matches Mamba-370M's 8.28 Pile ppl but needs 21% more parameters to do so. And in the largest size, TokenFormer-1.5B loses to Mamba-1.4B, 6.91 to 6.80 ppl.
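
To put these gaps in proportion, a quick calculation using the numbers quoted above. The parameter figures are the nominal sizes from the model names, which may differ from exact parameter counts.

```python
# (TokenFormer, Mamba) pairs as quoted in the comment.
params_m = (450, 370)      # nominal sizes, in millions
ppl_small = (10.45, 10.54)
ppl_large = (6.91, 6.80)


def rel_diff(a, b):
    # Relative difference of a with respect to b, in percent.
    return 100.0 * (a - b) / b


print(f"{rel_diff(*params_m):.1f}% more parameters")        # 21.6% more parameters
print(f"{rel_diff(*ppl_small):+.2f}% ppl (150M vs 130M)")   # -0.85% ppl (150M vs 130M)
print(f"{rel_diff(*ppl_large):+.2f}% ppl (1.5B vs 1.4B)")   # +1.62% ppl (1.5B vs 1.4B)
```

Both perplexity gaps are under 2% in either direction.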

Overall, the architectural tweak proposed in this paper is impractical, and the few fair comparisons they include are unimpressive. TokenFormer is another in a long line of Transformer-killers that have nice graphs of cherry-picked data, and will similarly fade into obscurity.

By @a_wild_dandan - 6 months
This could be revolutionary. The PPL/compute graphs are damning. If the Transformer is a function, then the TokenFormer feels like a higher-order function. Perhaps this approach is a natural direction for producing System Two reasoning? There's so much to digest here...