October 8th, 2024

Differential Transformer

The Differential Transformer improves attention mechanisms in Transformer models by enhancing relevant context and reducing noise, outperforming traditional models in language tasks and improving accuracy in in-context learning.


The paper titled "Differential Transformer" introduces a novel architecture aimed at improving the attention mechanism in Transformer models. The authors, Tianzhu Ye and colleagues, propose the Differential Transformer (Diff Transformer), which enhances attention to relevant context while minimizing the influence of irrelevant information. This is achieved through a differential attention mechanism that computes attention scores by subtracting two separate softmax attention maps, effectively canceling out noise and fostering sparse attention patterns. Experimental results indicate that Diff Transformer outperforms traditional Transformer models in various scenarios, particularly in language modeling, long-context modeling, key information retrieval, and reducing hallucinations in tasks like question answering and text summarization. Additionally, it shows improved accuracy and robustness in in-context learning, addressing issues related to order permutation. The findings suggest that Diff Transformer is a promising advancement for large language models, offering significant practical benefits.
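
As a rough illustration of the mechanism described above, here is a minimal NumPy sketch of single-head differential attention (names and shapes are illustrative; it omits the paper's full multi-head setup, per-head normalization, and the learnable λ, whose initialization is discussed in the comments below):

    import numpy as np

    def softmax(x, axis=-1):
        x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
        e = np.exp(x)
        return e / e.sum(axis=axis, keepdims=True)

    def diff_attention(X, Wq1, Wk1, Wq2, Wk2, Wv, lam):
        # Compute two separate softmax attention maps and subtract the second
        # (scaled by lam) from the first before applying them to the values.
        d = Wq1.shape[1]
        A1 = softmax((X @ Wq1) @ (X @ Wk1).T / np.sqrt(d))
        A2 = softmax((X @ Wq2) @ (X @ Wk2).T / np.sqrt(d))
        return (A1 - lam * A2) @ (X @ Wv)

    # toy usage
    rng = np.random.default_rng(0)
    n, d_model, d_head = 6, 16, 8
    X = rng.normal(size=(n, d_model))
    Wq1, Wk1, Wq2, Wk2 = (rng.normal(size=(d_model, d_head)) for _ in range(4))
    Wv = rng.normal(size=(d_model, d_model))
    print(diff_attention(X, Wq1, Wk1, Wq2, Wk2, Wv, lam=0.5).shape)  # (6, 16)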

- Differential Transformer enhances attention to relevant context while reducing noise.

- It outperforms traditional Transformer models in language modeling and other applications.

- The architecture mitigates hallucinations in question answering and text summarization.

- It improves accuracy and robustness in in-context learning tasks.

- The differential attention mechanism promotes sparse attention patterns.

AI: What people are saying
The comments on the Differential Transformer highlight various insights and concerns regarding its architecture and performance improvements.
  • Many commenters express confusion about the mechanism behind the differential attention and how it effectively reduces noise while maintaining performance.
  • There are discussions about the trade-offs involved, particularly regarding parameter efficiency and memory usage compared to traditional transformers.
  • Some users question the implications of negative attention weights and how the model balances attention between relevant and irrelevant contexts.
  • Several commenters note the potential for improved performance in tasks like question answering and text summarization, while also raising concerns about the risk of hallucination.
  • Overall, there is a shared interest in understanding the practical applications and implications of this new architecture in the field of machine learning.
32 comments
By @Imnimo - 6 months
I feel like I'm missing a key insight here. I understand the problem that regular softmax attention struggles to approach assigning zero attention to irrelevant stuff. And I get that having this subtraction formula makes it possible to assign exactly (or near) zero attention weight without having crazy outlier activations. But it seems like it also makes it very easy to have negative attention weight (which is equivalent to having positive attention weight on the negation of your value vectors). Intuitively, it just feels like a difficult balancing act to keep all the stuff you don't care about so close to zero.

But Figure 1 clearly shows that it works, so I don't doubt that it is in fact possible. I'm just struggling to build a picture of how exactly the network accomplishes this.
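
A tiny numeric illustration of the point above (hypothetical numbers, just to show the arithmetic): a single softmax keeps every weight strictly positive, while the difference of two softmax maps can land at or below zero.

    import numpy as np

    def softmax(x):
        e = np.exp(x - x.max())
        return e / e.sum()

    a1 = softmax(np.array([2.0, 1.0, -3.0]))  # ~[0.727, 0.268, 0.005]; never exactly zero
    a2 = softmax(np.array([2.0, 1.0, 3.0]))   # attends strongly to the last token
    print(a1 - 1.0 * a2)                      # ~[0.48, 0.18, -0.66]; negative weight on the last token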

By @aDyslecticCrow - 6 months
Very clever. I like this kind of nitty-gritty detail work, and the change is small enough to be adopted easily by others. Bravo!

I'm a little concerned about the last sentence of the introduction to section "2 Differential Transformer". It mentions using improvements from previous papers, but from the grammar it's unclear whether these improvements are applied to both the normal Transformer baseline and their Diff Transformer. If they aren't, that would sully the comparisons. It's the "main difference" wording in the previous sentence that raised a flag for me.

Of course, a good-faith researcher would know this and may not feel the need to clarify. But you can never be too careful about some published research in this field.

By @msoad - 6 months
Like most things in this new world of machine learning, I'm really confused about why this works.

The analogy to noise-cancelling headphones is helpful, but in that case we clearly know which is signal and which is noise. Here, if we already knew that, why would we even need to do the noise cancelling?

By @islewis - 6 months
> Differential attention takes the difference between two softmax attention functions to eliminate attention noise

If I understand correctly, this architecture trades twice as much attention memory in exchange for either a higher-quality model or fewer parameters at a similar quality.

> According to the fitted curves, 6.8B-size DIFF Transformer achieves a validation loss comparable to 11B-size Transformer, requiring only 62.2% of parameters

This raises a few questions for me:

- Would having only 60% of the parameters negate the double space for attention, leaving a similar memory profile as a traditional transformer?

- Does that tradeoff change noticeably between training and inference?

By @WithinReason - 6 months
> We empirically find that the setting λᵢₙᵢₜ = 0.8 − 0.6 × exp(−0.3 · (l − 1)) works well in practice

I wonder about the story behind that formula...
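
For what it's worth, plugging a few layer indices into the quoted formula shows it simply ramps λᵢₙᵢₜ from 0.2 at the first layer toward 0.8 in deeper layers (a quick check, assuming l is the 1-based layer index):

    import math

    def lam_init(l):
        return 0.8 - 0.6 * math.exp(-0.3 * (l - 1))

    for l in (1, 2, 4, 8, 16, 28):
        print(l, round(lam_init(l), 3))
    # 1 0.2, 2 0.356, 4 0.556, 8 0.727, 16 0.793, 28 0.8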

By @iandanforth - 6 months
The key bit I didn't understand at first was what happens if the two groups of attention learn the same thing: because their attention maps are subtracted from one another, if they both output similar values the attention across the board will drop to zero, and this will lead to high loss. So the only way to reduce loss is for them to learn to attend to different things. One of the simplest strategies they could learn (and this paper claims that they do) is for one group to focus on relevant context and the other on irrelevant context. Thus one group learns the noise and the other the signal (it's not this cut and dried, but it's a useful simplification for understanding, IMO).
By @patcon - 6 months
I wonder what is lost here. Surely there's a trade-off...

I'm wondering if there's any effect on "creativity", or the ability to interpolate between concepts. Hallucination and creativity feel very related to me. I understand hallucinating as simply being misaligned with the space humans feel it is appropriate to interpolate between.

By @chessgecko - 6 months
I wonder how much of the value here is from canceling out the positional noise RoPE produces. I would love to see a table comparing an ALiBi version of this to an ALiBi baseline, in addition to the RoPE models here.

Crazy gains though, congrats to the researchers.

By @vsroy - 6 months
Is the thing that's going on here that softmax can't push a value to 0, but by subtracting 2 softmax maps we can output 0s?
By @machinelearning - 6 months
This is a good problem to solve but the approach is wrong imo.

It has to be done in a hierarchical way to know what you attended to + full context.

If the differential vector is being computed from the same input as the attention vector, how do you know how to modify the attention vector correctly?

By @pxdm - 6 months
What's the comparison with conventional attention using a more aggressive (lower-temperature) softmax? I can imagine that for the multi-needle retrieval test this may also give a performance boost, although at some cost to other, more creative tasks.
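
As a small aside on the arithmetic (not a claim about retrieval performance): lowering the softmax temperature sharpens the distribution but keeps every weight strictly positive, whereas the subtraction in differential attention can push weights to zero or below.

    import numpy as np

    def softmax(x, temp=1.0):
        z = np.asarray(x, dtype=float) / temp
        e = np.exp(z - z.max())
        return e / e.sum()

    scores = np.array([4.0, 3.5, 0.0, -1.0])
    for t in (1.0, 0.5, 0.1):
        print(t, softmax(scores, temp=t))  # mass concentrates on the top score, but every entry stays > 0
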
By @nmacias - 6 months
AdderaLLM was right there
By @miven - 6 months
Is there an intuitive reason why this ends up working this well compared to, say, applying some kind of thresholding to attention activations that are below average for a given head to filter that same attention noise out?
By @pizza - 6 months
Was just going to mention that it seems that it should be possible to make a Flash Attention version of this algorithm and was pleasantly surprised to see they already included an implementation of one :)
By @watsonmusic - 6 months
The modification is simple and beautiful. And the improvements are quite significant.
By @singularity2001 - 6 months
Anyone remember siamese networks?
By @slashdave - 6 months
I don't get it. Arbitrary linear combinations are already accommodated via feed forward. What am I missing?
By @WithinReason - 6 months
Hmmm, this could be expressed as 2 consecutive attentions in a residual branch:

Simplified differential T. looks like: (softmax(Q₁K₁) − λ softmax(Q₂K₂)) V

You can factor this into:

    x = softmax(Q₁K₁)V
    x += -λ softmax(Q₂K₂)V
which is like 2 subsequent regular attentions added that are sharing V
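
A quick numerical check of that decomposition, with toy NumPy tensors (illustrative shapes only):

    import numpy as np

    def softmax(x, axis=-1):
        e = np.exp(x - x.max(axis=axis, keepdims=True))
        return e / e.sum(axis=axis, keepdims=True)

    rng = np.random.default_rng(1)
    n, d = 5, 4
    Q1, K1, Q2, K2, V = (rng.normal(size=(n, d)) for _ in range(5))
    lam = 0.8

    one_shot = (softmax(Q1 @ K1.T) - lam * softmax(Q2 @ K2.T)) @ V  # differential form

    x = softmax(Q1 @ K1.T) @ V          # first "regular" attention
    x += -lam * softmax(Q2 @ K2.T) @ V  # second attention added on the residual branch, sharing V

    print(np.allclose(one_shot, x))  # True
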
By @h_tbob - 6 months
I wish they didn't use SwiGLU and pre-RMSNorm so we could have a better comparison.

Then we would know how much this transformer innovation helps by itself.

By @digdugdirk - 6 months
Is there any way to replicate this with existing models, or are we going to need to wait for models to be trained in this style?

I'm imagining a smaller model examining the output tokens of a larger model and metaphorically slapping it on the wrist with a ruler if the output tokens start drifting off topic. Not quite the same, but an entertaining thought nonetheless.

By @dartos - 6 months
> By being less distracted by irrelevant context, Diff Transformer can mitigate hallucination in question answering and text summarization

I’m very interested in this claim. I was under the impression that hallucination is unavoidable in these kinds of models. IIRC proof for that was trending on HN a couple weeks ago.

By @mik09 - 6 months
The r/MachineLearning comment thread has some interesting ideas, one of them linking this paper with similar work in CV: https://www.reddit.com/r/MachineLearning/comments/1g0lnij/r_...
By @lucidrains - 6 months
Does this not mean we should explore the usage of talking heads (Shazeer et al.) a bit more? https://arxiv.org/abs/2003.02436
By @x49asvk - 6 months
This concept is really interesting to me, I am very very new to transformers but would love to learn more about normal transformers and differential too. Can anyone suggest any resources?
By @pikseladam - 6 months
Does this mean they solved the hallucination problem of transformers?

edit: not fully, but it gives promising results. Quite an improvement, actually.

By @nowayno583 - 6 months
Does anyone understand why they are taking the difference between the two attention maps instead of the sum? It seems to me that in a noise-reducing setup we would be more interested in the sum, as random noise would cancel out and the signal would add constructively.

Of course, even if I'm right, proper training would account for that by inverting signs where appropriate. Still, it seems weird to present it as a difference, especially seeing as they compare this directly to noise-cancelling headphones, where we sum both microphones' inputs.

By @badsandwitch - 6 months
What is the purpose of the lambda parameter? Why isn't it a constant of 1?
By @esafak - 6 months
How is this different than using a sparsity-inducing prior?
By @magicalhippo - 6 months
> The visualization reveals that Transformer tends to allocate only a small proportion of attention scores to the correct answer, while disproportionately focusing on irrelevant context.

> [...] Specifically, we partition the query and key vectors into two groups and compute two separate softmax attention maps. Then the result of subtracting these two maps is regarded as attention scores.

> [...] The approach is analogous to noise-canceling headphones and differential amplifiers in electrical engineering, where the difference between two signals cancels out common-mode noise.

Simple change, with seemingly decent improvements across the board.

By @campers - 6 months
The tl;dr on the high-level performance improvements:

"The scaling curves indicate that Diff Transformer requires only about 65% of model size or training tokens needed by Transformer to achieve comparable language modeling performance."

"Diff Transformer retains high performance even at reduced bit-widths, ranging from 16 bits to 6 bits. In comparison, Transformer’s accuracy significantly drops with 6-bit quantization. The 4-bit Diff Transformer achieves comparable accuracy as the 6-bit Transformer, and outperforms the 4-bit Transformer by about 25% in accuracy."

By @ExxKA - 6 months
Very interesting. Currently working on timeseries with Transformers. Let me know if anyone else out there is also reading it from that context.