On Vanishing Variance in Transformer Length Generalization

Ruining Li
University of Oxford
[email protected]
&Gabrijel Boduljak
University of Oxford
[email protected]
&Jensen (Jinghao) Zhou
University of Oxford
[email protected]
Equal contribution; each reserves the right to be listed first.
Abstract

It is a widely known issue that Transformers, when trained on shorter sequences, fail to generalize robustly to longer ones at test time. This raises the question of whether Transformer models are genuine reasoning engines, despite their impressive abilities in mathematical problem solving and code synthesis. In this paper, we offer a vanishing-variance perspective on this issue. To the best of our knowledge, we are the first to demonstrate that even for today's frontier models, a longer sequence length results in a decrease in variance in the output of the multi-head attention modules. On the $\operatorname{argmax}$ retrieval and dictionary lookup tasks, our experiments show that applying layer normalization after the attention outputs leads to significantly better length generalization. Our analyses attribute this improvement to a reduction, though not a complete elimination, of the distribution shift caused by vanishing variance. Project page: ruiningli.com/vanishing-variance.

Figure 1: Standard deviation of a fixed component in attention outputs from the first layer of Llama-3.2-1B (log-log scale), computed over multiple input sequences of fixed length $N$. Even in the latest LLMs, increasing the sequence length $N$ reduces the variance of the attention outputs, significantly degrading accuracy on long sequences.

1 Background: Vanishing Variance

It is no exaggeration to say that the Transformer (Vaswani et al., 2017) is the most important architecture in modern deep learning. It is widely adopted in almost every domain, ranging from natural language (Vaswani et al., 2017; Devlin et al., 2019; Brown et al., 2020) and vision (Dosovitskiy et al., 2021; Peebles & Xie, 2023) to audio (Radford et al., 2023) and protein design (Jumper et al., 2021). Despite these successes, recent studies (Press et al., 2022; Zhou et al., 2023; 2024; Kazemnejad et al., 2024; Veličković et al., 2024) in large language models (LLMs) have shown that Transformer-based models often struggle with length generalization, the ability to generalize to sequences longer than those seen during training. Several prior works have proposed to either refine position encodings (Ruoss et al., 2023; Zhou et al., 2024; Kazemnejad et al., 2024) or adapt the softmax function (Press et al., 2022; Veličković et al., 2024) to improve length generalization. However, these methods are ad hoc and lack interpretability, making it more of an art than a science to understand when and why they work.

In this paper, we study the distribution shift that occurs in the intermediate outputs when an attention module trained on shorter sequences is subsequently exposed to longer ones in a zero-shot manner. We hope that our findings will encourage future research on network architectures that are provably robust (e.g., invariant) to varying sequence lengths.

Background and notations.

At the core of Transformers is the attention mechanism (Vaswani et al., 2017). Attention first projects the input sequence $\mathbf{X} = [\mathbf{x}_1 \| \mathbf{x}_2 \| \dots \| \mathbf{x}_N]^\top \in \mathbb{R}^{N \times D}$, where $N$ is the sequence length and each item $\mathbf{x}_n \in \mathbb{R}^D$ has $D$ features, into keys $\mathbf{K} = \mathbf{X}\mathbf{W}_K \in \mathbb{R}^{N \times D}$ and values $\mathbf{V} = \mathbf{X}\mathbf{W}_V \in \mathbb{R}^{N \times D}$ using learnable weight matrices $\mathbf{W}_K, \mathbf{W}_V \in \mathbb{R}^{D \times D}$. Similarly, the query sequence $\mathbf{Y} = [\mathbf{y}_1 \| \mathbf{y}_2 \| \dots \| \mathbf{y}_M]^\top \in \mathbb{R}^{M \times D}$ is projected into queries $\mathbf{Q} = \mathbf{Y}\mathbf{W}_Q \in \mathbb{R}^{M \times D}$ using $\mathbf{W}_Q \in \mathbb{R}^{D \times D}$. Attention then computes $\mathbf{O} = \operatorname{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{D}}\right)\mathbf{V} \in \mathbb{R}^{M \times D}$ and projects it with another weight matrix $\mathbf{W}_O \in \mathbb{R}^{D \times D}$ to yield the final result $\operatorname{Attn}(\mathbf{X}, \mathbf{Y}) = \mathbf{O}\mathbf{W}_O$. In this paper, we use the term "attention weights" to refer to the softmax scores, i.e., $\operatorname{softmax}(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{D}})$, "attention outputs" to refer to the intermediate $\mathbf{O}$, and $\operatorname{Attn}(\mathbf{X}, \mathbf{Y})$ to refer to the final result.
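For concreteness, the single-query case used throughout this paper can be sketched in a few lines of PyTorch (variable names follow the notation above; this is an illustration, not the authors' implementation):

```python
import torch

def attn_single_query(X, y, W_Q, W_K, W_V, W_O):
    """X: (N, D) input sequence; y: (D,) query token; W_*: (D, D) weight matrices."""
    Q = (y @ W_Q).unsqueeze(0)                                 # queries, (1, D)
    K, V = X @ W_K, X @ W_V                                    # keys and values, (N, D)
    A = torch.softmax(Q @ K.T / K.shape[-1] ** 0.5, dim=-1)    # attention weights, (1, N)
    O = A @ V                                                  # attention outputs, (1, D)
    return O @ W_O                                             # Attn(X, [y]), (1, D)
```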

Our main observation is the vanishing variance problem: as the sequence length $N$ increases, the variance of the attention outputs (computed over multiple input sequences of length $N$) decreases. We formalize this as Proposition 1.

Proposition 1 (The vanishing variance problem).

Consider a trained attention module with weights $\mathbf{W}_Q, \mathbf{W}_K, \mathbf{W}_V, \mathbf{W}_O$. Let $\mathbf{X} = [\mathbf{x}_1 \| \mathbf{x}_2 \| \dots \| \mathbf{x}_N]^\top$ denote an input sequence of length $N$. If (1) $\mathbf{x}_1, \mathbf{x}_2, \dots, \mathbf{x}_N \overset{\text{i.i.d.}}{\sim} \mathcal{X}$, a distribution over a finite vocabulary, and (2) $\mathbb{E}_{\mathbf{x}\sim\mathcal{X}}[\mathbf{W}_V\mathbf{x}] = \mathbf{0}$, then for a fixed query $\mathbf{y}$ and a fixed feature $d$,

$$\lim_{N\to\infty} \operatorname{Var}_{(\mathbf{x}_1, \mathbf{x}_2, \ldots, \mathbf{x}_N)\sim\mathcal{X}^N}\left(\left[\operatorname{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{D}}\right)\mathbf{V}\right]_d\right) = 0,$$

where $\mathbf{x}_n, \mathbf{y} \in \mathbb{R}^D$ and $\mathbf{Q} \in \mathbb{R}^{1\times D}, \mathbf{K} \in \mathbb{R}^{N\times D}, \mathbf{V} \in \mathbb{R}^{N\times D}$ are intermediate results in $\operatorname{Attn}(\mathbf{X}, [\mathbf{y}])$.

Informally, for a fixed component $d$ of the attention outputs, its variance over input sequences of length $N$, where each sequence consists of $N$ independently and identically distributed (i.i.d.) tokens, vanishes as $N \to \infty$.

Proof.

Please refer to Appendix A. ∎

Note that the assumptions of Proposition 1 are violated in practice. In particular, the independence assumption does not hold in LLMs because of (1) the introduction of positional encoding and, more significantly, (2) the nature of language, where preceding words provide important context for those that follow. In addition, $\mathbb{E}[\mathbf{W}_V\mathbf{x}_i] = \mathbf{0}$ is not strictly enforced. Nevertheless, we find that even for today's frontier LLMs, the decay in attention output variance established in Proposition 1 remains pronounced. In Fig. 1, we plot the standard deviation $\sigma$ of a fixed component of the attention outputs from the first layer of Llama-3.2-1B (AI@Meta, 2024) as a function of the input sequence length $N$. $\sigma$ is computed over 100 length-$N$ sequences sampled randomly with 3 strategies: (Random Tokens w/o P.E.) we sample single tokens i.i.d. uniformly at random from the tokenizer's vocabulary and remove the positional encoding at inference time; (Random Tokens w/ P.E.) we still sample single tokens i.i.d. uniformly at random, but keep the positional encoding at inference time; (Sentences w/ P.E.) we sample consecutive sentences from a long paragraph (obtained from https://github.com/dscape/spell/blob/master/test/resources/big.txt) and truncate the token sequences to length $N$; such sequences lie within the LLM's training distribution. As can be seen in the log-log plot, for Random Tokens w/o P.E., where the independence assumption does hold, $\sigma$ scales with the input sequence length $N$ roughly as $\sigma \propto N^{-0.5}$. For Random Tokens w/ P.E. and Sentences w/ P.E., where this assumption no longer holds, the downward trend is still evident.
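The trend in the idealized setting can be reproduced with a small simulation. The sketch below is not the Llama-3.2-1B experiment; it uses randomly initialized weights and continuous Gaussian "tokens" (an assumption that makes condition (2) hold exactly), and the printed standard deviation of a fixed output component shrinks roughly as $N^{-0.5}$:

```python
import torch

torch.manual_seed(0)
D, trials = 64, 100
W_Q, W_K, W_V = [torch.randn(D, D) / D ** 0.5 for _ in range(3)]  # fixed "trained" weights
y = torch.randn(D)                                                # fixed query token

for N in [2 ** k for k in range(4, 15, 2)]:
    samples = []
    for _ in range(trials):
        X = torch.randn(N, D)                     # i.i.d. tokens with E[W_V x] = 0
        Q = (y @ W_Q).unsqueeze(0)                # (1, D)
        K, V = X @ W_K, X @ W_V
        A = torch.softmax(Q @ K.T / D ** 0.5, dim=-1)
        samples.append((A @ V)[0, 0])             # a fixed component of the attention outputs
    print(N, torch.stack(samples).std().item())   # std decays roughly as N ** -0.5
```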

2 Layer Normalization for Length Generalization

Given that the variance of attention outputs vanishes as sequences grow longer, we investigate how this contributes to the performance degradation observed in LLMs. To this end, we perform a toy study of the statistical behavior of attention output values.

For simplicity, we consider a one-layer Transformer with single-head attention, omitting residual connections and normalization, following Veličković et al. (2024). We adopt this architecture as our Baseline. The model receives a single query token and an input sequence of varying length and must perform simple algorithmic tasks. To eliminate confounds from positional encodings, we focus on order-invariant tasks, where the output depends only on the multiset (not the order) of input tokens, namely $\operatorname{argmax}$ retrieval and dictionary lookup. Our goal is to evaluate models trained on shorter sequences on longer (i.e., out-of-distribution in length) sequences to study length generalization. More details of the model architecture and synthetic data generation are provided in Appendix B. A minimal sketch of this setup is given below.
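The following PyTorch sketch reflects our reading of the Baseline; the embedding dimension, the readout MLP, and the folding of $\mathbf{W}_O$ into the readout are illustrative assumptions, with the actual configuration given in Appendix B:

```python
import torch
import torch.nn as nn

class Baseline(nn.Module):
    """One-layer, single-head attention without residual connections or normalization.
    A single query token attends over the input sequence, and an MLP reads the answer
    off the attention output. Sizes here are illustrative placeholders."""

    def __init__(self, vocab_size: int, dim: int = 64, num_classes: int = 10):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.W_Q = nn.Linear(dim, dim, bias=False)
        self.W_K = nn.Linear(dim, dim, bias=False)
        self.W_V = nn.Linear(dim, dim, bias=False)
        # The output projection W_O is folded into the readout MLP in this sketch.
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, num_classes))

    def forward(self, seq, query):
        # seq: (B, N) token ids of the input sequence; query: (B,) query token ids.
        X = self.embed(seq)                              # (B, N, D)
        Y = self.embed(query).unsqueeze(1)               # (B, 1, D)
        Q, K, V = self.W_Q(Y), self.W_K(X), self.W_V(X)
        A = torch.softmax(Q @ K.transpose(-2, -1) / K.shape[-1] ** 0.5, dim=-1)  # (B, 1, N)
        O = A @ V                                        # attention outputs, (B, 1, D)
        return self.mlp(O.squeeze(1))                    # class logits, (B, num_classes)
```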

In Fig. 2, we visualize the distribution of 5 individual components of the attention outputs $\mathbf{O}$ across multiple input sequences of lengths $2^4$, $2^{12}$ and $2^{14}$, obtained with a model checkpoint trained on sequences of length up to $2^4$. As can be seen in the top row, testing on out-of-distribution sequence lengths leads to vanishing variance, causing a distribution shift where each individual component of $\mathbf{O}$ becomes more concentrated around its mean.

Figure 2: Distribution of 5 individual features of the attention outputs $\mathbf{O}$ across batches. Each color represents a different feature. As the input sequence length $N$ increases from $2^4$ to $2^{14}$, feature variance decreases and values concentrate around their mean. Layer normalization (bottom) scales and shifts features to maintain a relatively constant global variance, likely explaining its superior length generalization compared to the Baseline (top).

While this distribution shift of individual features is expected according to Proposition 1, we are more interested in the distribution shift of the entire feature vector in $\mathbb{R}^D$, as the whole vector is subsequently input to an MLP to predict the final result. As the input sequence length $N$ increases, each feature is less likely to take extreme values (as its distribution is more centered). Consequently, the global feature variance, defined as $\sigma_{\text{global}}^2 = \frac{1}{D}\sum_{d=1}^{D}(\mathbf{o}_d - \mu_{\text{global}})^2$ where $\mu_{\text{global}} = \frac{1}{D}\sum_{d=1}^{D}\mathbf{o}_d$ is the global mean, also decreases. We illustrate this observation in Fig. 3 (right), where the global variance decays as $N$ increases. In Fig. 3 (left), we show that, in addition to the global variance, the global mean $\mu_{\text{global}}$ also exhibits drift. Such a distribution shift in the attention outputs (and thus the MLP inputs) hinders generalization, since the MLP is only trained on features with a larger global variance and a different global mean (Zhou et al., 2022).

Figure 3: Layer normalization helps mitigate distribution shift in attention outputs. (Left) The drift in the global mean as the input sequence length deviates from the training distribution; the mean is normalized by the training global variance to eliminate scale differences. (Right) The decay in the global variance. All results are averaged across 32k random input sequences of the fixed length.

To mitigate this distribution shift, we explore applying layer normalization (Ba et al., 2016) immediately after the attention outputs, i.e., $\operatorname{LayerNorm}(\mathbf{O})_{b,t,d} = \gamma_d \cdot \frac{\mathbf{O}_{b,t,d} - \mu_{b,t}}{\sigma_{b,t} + \epsilon} + \beta_d$, where $\mathbf{O}$ is the batched attention outputs, $\mu_{b,t} = \frac{1}{D}\sum_{d=1}^{D}\mathbf{O}_{b,t,d}$, $\sigma_{b,t} = \sqrt{\frac{1}{D}\sum_{d=1}^{D}(\mathbf{O}_{b,t,d} - \mu_{b,t})^2}$, and $\gamma, \beta \in \mathbb{R}^D$ are learnable scale and shift parameters. While variance decay in individual features is inevitable (bottom row of Fig. 2), standardization and learnable scale and shift parameters help stabilize the feature distribution. This adjustment preserves the global mean and variance more effectively as sequence length increases $1000\times$ (Fig. 3). This enhances length generalization, as discussed next.
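A sketch of this modification, written to mirror the formula above (here $\epsilon$ is added to $\sigma_{b,t}$ as in the equation, which differs slightly from PyTorch's built-in nn.LayerNorm, where it is added to the variance inside the square root):

```python
import torch
import torch.nn as nn

class AttnOutputLayerNorm(nn.Module):
    """Layer normalization over the D features of the attention outputs O of shape (B, T, D)."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))   # learnable per-feature scale
        self.beta = nn.Parameter(torch.zeros(dim))   # learnable per-feature shift
        self.eps = eps

    def forward(self, O):
        mu = O.mean(dim=-1, keepdim=True)                           # mu_{b,t}
        sigma = O.var(dim=-1, unbiased=False, keepdim=True).sqrt()  # sigma_{b,t}
        return self.gamma * (O - mu) / (sigma + self.eps) + self.beta
```

In the Baseline sketch above, this amounts to replacing `O = A @ V` with `O = self.norm(A @ V)` for a module `self.norm = AttnOutputLayerNorm(dim)`.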

3 Experiments

We consider two tasks: $\operatorname{argmax}$ retrieval and dictionary lookup. The former was considered by Veličković et al. (2024). The latter closely resembles the core function of the attention mechanism (i.e., retrieving the most relevant information based on the similarity between queries and keys). As detailed in Appendix B, the order of input tokens does not affect the target output in either task. By deliberately selecting such tasks, we isolate and examine the length generalization capabilities of the attention mechanism itself, independent of any effects introduced by positional encodings (Zhou et al., 2024). We generate synthetic data (with input sequence lengths up to 16) to train the models, and evaluate them on sequences of length up to $2^{14}$.

3.1 Results and Analysis

The results presented in Table 1 and Table 2 indicate that applying layer normalization to the attention outputs leads to consistently better accuracy on out-of-distribution sequence lengths, with statistical significance confirmed by a paired $t$-test over 100 training runs from different random seeds.

Test-time adaptation and fine-tuning are common techniques for improving length generalization in transformers (Anil et al., 2022; Veličković et al., 2024). To show that the benefits of layer normalization are orthogonal to these techniques, we implement the adaptive temperature method from Veličković et al. (2024) in both architectures, with and without layer normalization. Combined with test-time adaptation, layer normalization still yields a significant improvement. In Fig. 4, we demonstrate that layer normalization also mitigates dispersion (Veličković et al., 2024).

Figure 4: Heatmap of the largest 16 attention weights, computed over 32 examples. Layer normalization mitigates dispersion, which is inevitable as sequence length increases (Veličković et al., 2024).
Table 1: Results (%) on the $\operatorname{argmax}$ retrieval task. Results are averaged over 100 runs with different random seeds. $p$-values are computed using a paired $t$-test. The $2^4$ column corresponds to in-distribution sequence lengths.

| Model | $2^4$ | $2^5$ | $2^6$ | $2^7$ | $2^8$ | $2^9$ | $2^{10}$ | $2^{11}$ | $2^{12}$ | $2^{13}$ | $2^{14}$ |
| w.o. test-time adaptation |
| Baseline | 99.6 | 99.2 | 98.4 | 96.8 | 93.7 | 88.0 | 78.0 | 62.5 | 44.2 | 29.7 | 20.8 |
| Baseline (+ LN) | 99.6 | 99.3 | 98.6 | 97.4 | 94.8 | 89.8 | 81.0 | 66.9 | 49.2 | 33.0 | 22.6 |
| $p$-value | $8/10^{1}$ | $4/10^{2}$ | $4/10^{2}$ | $2/10^{2}$ | $4/10^{5}$ | $2/10^{4}$ | $1/10^{4}$ | $2/10^{5}$ | $2/10^{6}$ | $4/10^{5}$ | $3/10^{3}$ |
| w. test-time adaptation |
| Adaptive $\theta$ | 99.6 | 99.2 | 98.5 | 96.9 | 94.1 | 89.1 | 81.2 | 69.1 | 54.2 | 39.0 | 27.1 |
| Adaptive $\theta$ (+ LN) | 99.7 | 99.4 | 98.7 | 97.5 | 95.1 | 91.0 | 84.0 | 73.6 | 58.9 | 43.1 | 30.4 |
| $p$-value | $7/10^{1}$ | $5/10^{4}$ | $1/10^{2}$ | $5/10^{5}$ | $4/10^{4}$ | $5/10^{5}$ | $1/10^{4}$ | $2/10^{5}$ | $8/10^{5}$ | $4/10^{4}$ | $7/10^{4}$ |
Table 2: Results (%) on the dictionary lookup task. Results are averaged over 100 runs with different random seeds. $p$-values are computed using a paired $t$-test. The $2^4$ column corresponds to in-distribution sequence lengths.

| Model | $2^4$ | $2^5$ | $2^6$ | $2^7$ | $2^8$ | $2^9$ | $2^{10}$ | $2^{11}$ | $2^{12}$ | $2^{13}$ | $2^{14}$ |
| w.o. test-time adaptation |
| Baseline | 99.3 | 98.6 | 97.3 | 94.7 | 89.5 | 80.4 | 67.6 | 52.9 | 38.7 | 26.5 | 17.8 |
| Baseline (+ LN) | 99.4 | 98.8 | 97.6 | 95.3 | 90.7 | 82.9 | 71.7 | 57.7 | 44.1 | 32.3 | 22.4 |
| $p$-value | $6/10^{2}$ | $1/10^{2}$ | $5/10^{2}$ | $2/10^{3}$ | $2/10^{4}$ | $3/10^{8}$ | $1/10^{11}$ | $2/10^{12}$ | $7/10^{15}$ | $3/10^{21}$ | $2/10^{19}$ |
| w. test-time adaptation |
| Adaptive $\theta$ | 99.3 | 98.6 | 97.2 | 94.5 | 89.3 | 80.4 | 67.8 | 52.6 | 38.6 | 27.3 | 20.8 |
| Adaptive $\theta$ (+ LN) | 99.4 | 98.8 | 97.6 | 95.4 | 90.6 | 82.9 | 71.7 | 57.8 | 44.5 | 33.4 | 27.7 |
| $p$-value | $6/10^{1}$ | $5/10^{2}$ | $3/10^{2}$ | $1/10^{4}$ | $3/10^{4}$ | $1/10^{5}$ | $8/10^{9}$ | $6/10^{12}$ | $2/10^{16}$ | $9/10^{20}$ | $1/10^{21}$ |

Does layer normalization alleviate distribution shift?

Layer normalization does alleviate—but not eliminate—distribution shift. With layer normalization, the global mean and global variance remain more stable on out-of-distribution sequence lengths (Fig. 3). However, the variance of fixed components in attention outputs still decays, regardless of layer normalization (Fig. 2).

3.2 Ablations

In addition to layer normalization, we explore an alternative normalization strategy in which we standardize (i.e., std. in Table 3) the attention outputs across the $D$ features without the learnable scale and shift parameters present in LN, i.e., $\operatorname{Standardize}(\mathbf{O})_{b,t,d} = \frac{\mathbf{O}_{b,t,d} - \mu_{b,t}}{\sigma_{b,t} + \epsilon}$, where $\mu_{b,t}$ and $\sigma_{b,t}$ are computed in the same manner as in layer normalization.
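A sketch of this variant, dropping the learnable parameters from the LayerNorm snippet above:

```python
import torch

def standardize(O: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Standardize the attention outputs O (B, T, D) over the D features,
    without LayerNorm's learnable scale and shift."""
    mu = O.mean(dim=-1, keepdim=True)
    sigma = O.var(dim=-1, unbiased=False, keepdim=True).sqrt()
    return (O - mu) / (sigma + eps)
```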

As shown in Table 3, where the relative accuracy gain over the Baseline on the $\operatorname{argmax}$ retrieval task is reported, standardization improves length generalization even though it strictly constrains model capacity. This underscores the importance (and potential benefits) of addressing the observed distribution shift. LN outperforms standardization, as confirmed by the paired $t$-test. Similar ablation results on the dictionary lookup task can be found in Section B.3.

Table 3: Ablations on different normalization strategies on the $\operatorname{argmax}$ retrieval task. Relative results (%) compared to the Baseline ($\triangle$) are reported.

| Model | $2^4$ | $2^5$ | $2^6$ | $2^7$ | $2^8$ | $2^9$ | $2^{10}$ | $2^{11}$ | $2^{12}$ | $2^{13}$ | $2^{14}$ |
| $\triangle$ (+ std.) | -0.05 | +0.00 | +0.09 | +0.26 | +0.60 | +0.84 | +1.62 | +2.26 | +3.05 | +1.80 | +0.70 |
| $\triangle$ (+ LN) | +0.01 | +0.11 | +0.21 | +0.57 | +1.15 | +1.81 | +2.98 | +4.32 | +4.99 | +3.30 | +1.76 |
| $p$-value | $2/10^{3}$ | $7/10^{4}$ | $1/10^{2}$ | $5/10^{4}$ | $3/10^{4}$ | $3/10^{4}$ | $7/10^{4}$ | $2/10^{4}$ | $3/10^{4}$ | $1/10^{3}$ | $5/10^{3}$ |

4 Related Work

Positional encoding for length generalization.

Many works have attributed the inability of Transformers to extrapolate to longer sequences to positional encoding. Several alternatives to the sinusoidal positional encoding originally introduced by Vaswani et al. (2017) have been proposed to enhance the performance of Transformer-based models in natural language processing (NLP) tasks, including relative positional encoding (Shaw et al., 2018; Dai et al., 2019), rotary positional encoding (Su et al., 2024), no positional encoding (Haviv et al., 2022) and randomized positional encoding (Ruoss et al., 2023). Authors have examined the impact of different variants of positional encoding on length generalization (Chi et al., 2022; Ruoss et al., 2023; Kazemnejad et al., 2024; Li et al., 2024; Peng et al., 2024; Zhou et al., 2024). Unlike prior work that explores positional encoding for length generalization, we focus on algorithmic tasks that are order-invariant. We present a vanishing variance perspective on length generalization which is orthogonal to the extrapolability of positional encoding.

Alternatives to softmax attention.

The $\operatorname{softmax}$ output has been utilized to interpret the inner workings of Transformers (Xu et al., 2015; Choi et al., 2016; Martins & Astudillo, 2016). More recently, Veličković et al. (2024) demonstrated that the attention weights output by $\operatorname{softmax}$ disperse as sequence length increases, attributing the Transformer's limited length generalization capability to this phenomenon. In this paper, we show that this dispersion leads to the vanishing variance problem in the intermediate attention outputs. While many variants of softmax attention have been introduced (Correia et al., 2019; Press et al., 2022; Tan et al., 2024; Ye et al., 2024), they are motivated mostly by interpretability rather than by the distribution of attention outputs for length generalization. To the best of our knowledge, none of the existing works has fundamentally eliminated the vanishing variance problem presented in this paper. We hope our study can motivate designs of network architectures that are provably invariant to sequence length variations.

5 Conclusion

In this paper, we have introduced the vanishing variance problem and provided both theoretical analysis and empirical evidence demonstrating its role in inducing distribution shift in attention outputs. This shift hinders the ability of Transformers to generalize effectively to out-of-distribution sequence lengths. We demonstrated that mitigating this distribution shift through techniques like layer normalization and standardization—despite potential trade-offs in model expressiveness—significantly improves length generalization in attention models.

Future work.

We conduct our experiments on a single-layer, single-head attention architecture for simplicity, while real-world models typically use multi-layer, multi-head attention; our conclusions may not fully generalize to these more complex architectures. Future work may validate the normalization strategies on larger benchmarks such as CLRS (Veličković et al., 2022) and on real-world LLMs. Moreover, layer normalization only partially mitigates the distribution shift presented in this paper, and it is already widely adopted in Transformers (though not immediately after the attention outputs). Future work may therefore design architectures that are provably invariant to sequence length variations.

References

  • AI@Meta (2024) AI@Meta. Llama 3 model card. 2024. URL https://github.com/meta-llama/llama3/blob/main/MODEL_CARD.md.
  • Anil et al. (2022) Cem Anil, Yuhuai Wu, Anders Andreassen, Aitor Lewkowycz, Vedant Misra, Vinay Ramasesh, Ambrose Slone, Guy Gur-Ari, Ethan Dyer, and Behnam Neyshabur. Exploring length generalization in large language models. Advances in Neural Information Processing Systems, 35:38546–38556, 2022.
  • Ba et al. (2016) Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E. Hinton. Layer normalization. arXiv preprint arXiv:1607.06450, 2016.
  • Brown et al. (2020) Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. Language models are few-shot learners. In NeurIPS, 2020.
  • Chi et al. (2022) Ta-Chung Chi, Ting-Han Fan, Peter J Ramadge, and Alexander Rudnicky. Kerple: Kernelized relative positional embedding for length extrapolation. NeurIPS, 2022.
  • Choi et al. (2016) Edward Choi, Mohammad Taha Bahadori, Jimeng Sun, Joshua Kulas, Andy Schuetz, and Walter Stewart. Retain: An interpretable predictive model for healthcare using reverse time attention mechanism. NeurIPS, 2016.
  • Correia et al. (2019) Gonçalo M Correia, Vlad Niculae, and André FT Martins. Adaptively sparse transformers. arXiv preprint arXiv:1909.00015, 2019.
  • Dai et al. (2019) Zihang Dai, Zhilin Yang, Yiming Yang, Jaime G Carbonell, Quoc Le, and Ruslan Salakhutdinov. Transformer-xl: Attentive language models beyond a fixed-length context. In ACL, 2019.
  • Devlin et al. (2019) Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: Pre-training of deep bidirectional transformers for language understanding. In NAACL, 2019.
  • Dosovitskiy et al. (2021) Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby. An image is worth 16x16 words: Transformers for image recognition at scale. In ICLR, 2021.
  • Haviv et al. (2022) Adi Haviv, Ori Ram, Ofir Press, Peter Izsak, and Omer Levy. Transformer language models without positional encodings still learn positional information. In EMNLP, 2022.
  • Jumper et al. (2021) John M. Jumper, Richard Evans, Alexander Pritzel, Tim Green, Michael Figurnov, Olaf Ronneberger, Kathryn Tunyasuvunakool, Russ Bates, Augustin Žídek, Anna Potapenko, Alex Bridgland, Clemens Meyer, Simon A A Kohl, Andy Ballard, Andrew Cowie, Bernardino Romera-Paredes, Stanislav Nikolov, Rishub Jain, Jonas Adler, Trevor Back, Stig Petersen, David Reiman, Ellen Clancy, Michal Zielinski, Martin Steinegger, Michalina Pacholska, Tamas Berghammer, Sebastian Bodenstein, David Silver, Oriol Vinyals, Andrew W. Senior, Koray Kavukcuoglu, Pushmeet Kohli, and Demis Hassabis. Highly accurate protein structure prediction with alphafold. Nature, 2021.
  • Kazemnejad et al. (2024) Amirhossein Kazemnejad, Inkit Padhi, Karthikeyan Natesan Ramamurthy, Payel Das, and Siva Reddy. The impact of positional encoding on length generalization in transformers. NeurIPS, 2024.
  • Li et al. (2024) Shanda Li, Chong You, Guru Guruganesh, Joshua Ainslie, Santiago Ontanon, Manzil Zaheer, Sumit Sanghai, Yiming Yang, Sanjiv Kumar, and Srinadh Bhojanapalli. Functional interpolation for relative positions improves long context transformers. In ICLR, 2024.
  • Martins & Astudillo (2016) Andre Martins and Ramon Astudillo. From softmax to sparsemax: A sparse model of attention and multi-label classification. In ICML, 2016.
  • Peebles & Xie (2023) William Peebles and Saining Xie. Scalable diffusion models with transformers. In ICCV, 2023.
  • Peng et al. (2024) Bowen Peng, Jeffrey Quesnelle, Honglu Fan, and Enrico Shippole. Yarn: Efficient context window extension of large language models. In ICLR, 2024.
  • Press et al. (2022) Ofir Press, Noah A. Smith, and Mike Lewis. Train short, test long: Attention with linear biases enables input length extrapolation. In ICLR, 2022.
  • Radford et al. (2023) Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine McLeavey, and Ilya Sutskever. Robust speech recognition via large-scale weak supervision. In ICML, 2023.
  • Ruoss et al. (2023) Anian Ruoss, Grégoire Delétang, Tim Genewein, Jordi Grau-Moya, Róbert Csordás, Mehdi Bennani, Shane Legg, and Joel Veness. Randomized positional encodings boost length generalization of transformers. In ACL, 2023.
  • Shaw et al. (2018) Peter Shaw, Jakob Uszkoreit, and Ashish Vaswani. Self-attention with relative position representations. arXiv preprint arXiv:1803.02155, 2018.
  • Su et al. (2024) Jianlin Su, Murtadha Ahmed, Yu Lu, Shengfeng Pan, Wen Bo, and Yunfeng Liu. Roformer: Enhanced transformer with rotary position embedding. Neurocomputing, 2024.
  • Tan et al. (2024) Shawn Tan, Yikang Shen, Songlin Yang, Aaron Courville, and Rameswar Panda. Stick-breaking attention. arXiv preprint arXiv:2410.17980, 2024.
  • Vaswani et al. (2017) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NeurIPS, 2017.
  • Veličković et al. (2022) Petar Veličković, Adrià Puigdomènech Badia, David Budden, Razvan Pascanu, Andrea Banino, Misha Dashevskiy, Raia Hadsell, and Charles Blundell. The clrs algorithmic reasoning benchmark. In ICML, 2022.
  • Veličković et al. (2024) Petar Veličković, Christos Perivolaropoulos, Federico Barbero, and Razvan Pascanu. softmax is not enough (for sharp out-of-distribution). arXiv preprint arXiv:2410.01104, 2024.
  • Xu et al. (2015) Kelvin Xu, Jimmy Ba, Ryan Kiros, Kyunghyun Cho, Aaron Courville, Ruslan Salakhudinov, Rich Zemel, and Yoshua Bengio. Show, attend and tell: Neural image caption generation with visual attention. In ICML, 2015.
  • Ye et al. (2024) Tianzhu Ye, Li Dong, Yuqing Xia, Yutao Sun, Yi Zhu, Gao Huang, and Furu Wei. Differential transformer. arXiv preprint arXiv:2410.05258, 2024.
  • Zhou et al. (2023) Hattie Zhou, Arwen Bradley, Etai Littwin, Noam Razin, Omid Saremi, Josh Susskind, Samy Bengio, and Preetum Nakkiran. What algorithms can transformers learn? a study in length generalization. arXiv preprint arXiv:2310.16028, 2023.
  • Zhou et al. (2022) Kaiyang Zhou, Ziwei Liu, Yu Qiao, Tao Xiang, and Chen Change Loy. Domain generalization: A survey. PAMI, 2022.
  • Zhou et al. (2024) Yongchao Zhou, Uri Alon, Xinyun Chen, Xuezhi Wang, Rishabh Agarwal, and Denny Zhou. Transformers can achieve length generalization but not robustly. arXiv preprint arXiv:2402.09371, 2024.

Appendix A Proof of Proposition 1

Proposition 1 (The vanishing variance problem).

Consider a trained attention module with weights $\mathbf{W}_Q, \mathbf{W}_K, \mathbf{W}_V, \mathbf{W}_O$. Let $\mathbf{X} = [\mathbf{x}_1 \| \mathbf{x}_2 \| \dots \| \mathbf{x}_N]^\top$ denote an input sequence of length $N$. If (1) $\mathbf{x}_1, \mathbf{x}_2, \dots, \mathbf{x}_N \overset{\text{i.i.d.}}{\sim} \mathcal{X}$, a distribution over a finite vocabulary, and (2) $\mathbb{E}_{\mathbf{x}\sim\mathcal{X}}[\mathbf{W}_V\mathbf{x}] = \mathbf{0}$, then for a fixed query $\mathbf{y}$ and a fixed feature $d$,

$$\lim_{N\to\infty} \operatorname{Var}_{(\mathbf{x}_1, \mathbf{x}_2, \ldots, \mathbf{x}_N)\sim\mathcal{X}^N}\left(\left[\operatorname{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{D}}\right)\mathbf{V}\right]_d\right) = 0,$$

where $\mathbf{x}_n, \mathbf{y} \in \mathbb{R}^D$ and $\mathbf{Q} \in \mathbb{R}^{1\times D}, \mathbf{K} \in \mathbb{R}^{N\times D}, \mathbf{V} \in \mathbb{R}^{N\times D}$ are intermediate results in $\operatorname{Attn}(\mathbf{X}, [\mathbf{y}])$.

Informally, for a fixed component $d$ of the attention outputs, its variance over input sequences of length $N$, where each sequence consists of $N$ independently and identically distributed (i.i.d.) tokens, vanishes as $N \to \infty$.

Proof.

Let $\mathbf{V}_{n,d} = [\mathbf{W}_V\mathbf{x}_n]_d$. Firstly, we argue that $\mathbf{V}_{k,d}$ and $\mathbf{V}_{l,d}$ are independent for $k \neq l$. Let $\pi_d : \mathbb{R}^D \to \mathbb{R}$ be the projection onto the $d$-th coordinate, namely $\pi_d(\mathbf{v}) = \mathbf{v}_d$ for $\mathbf{v} \in \mathbb{R}^D$. Observe that $\mathbf{V}_{k,d} = \pi_d(\mathbf{W}_V\mathbf{x}_k) = (\pi_d \circ \mathbf{W}_V)(\mathbf{x}_k)$ and $\mathbf{V}_{l,d} = \pi_d(\mathbf{W}_V\mathbf{x}_l) = (\pi_d \circ \mathbf{W}_V)(\mathbf{x}_l)$. By assumption (1), $\mathbf{x}_k$ and $\mathbf{x}_l$ are independent for $k \neq l$. Since $\mathbf{V}_{k,d}$ and $\mathbf{V}_{l,d}$ are measurable functions of independent random variables, they are independent for $k \neq l$. By assumption, $\mathbf{x}_k$ and $\mathbf{x}_l$ are identically distributed for every $k, l$, so $\mathbf{V}_{k,d}$ and $\mathbf{V}_{l,d}$ are identically distributed.
Thus, $\operatorname{Var}[\mathbf{V}_{k,d}]$ depends only on $d$, not on $k$. Set $\sigma_d^2 = \operatorname{Var}[\mathbf{V}_{k,d}]$; $\sigma_d^2$ is finite since the vocabulary is finite and thus bounded.

Let $\mathbf{A}$ denote the attention weights $\operatorname{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{D}}\right) \in \mathbb{R}^{1\times N}$, as the query sequence $[\mathbf{y}]$ consists of only a single item. Let $\mathbf{A}_n$ denote the $n$-th element of $\mathbf{A}$. We have

\begin{align*}
\operatorname{Var}\left(\sum_{n=1}^{N}\mathbf{A}_{n}\mathbf{V}_{n,d}\right)
&= \mathbb{E}\left[\left(\sum_{n=1}^{N}\mathbf{A}_{n}\mathbf{V}_{n,d}\right)^{2}\right] - \left(\mathbb{E}\left[\sum_{n=1}^{N}\mathbf{A}_{n}\mathbf{V}_{n,d}\right]\right)^{2} \\
&\leq \mathbb{E}\left[\left(\sum_{n=1}^{N}\mathbf{A}_{n}\mathbf{V}_{n,d}\right)^{2}\right] \\
&= \mathbb{E}\left[\sum_{n=1}^{N}\mathbf{A}_{n}^{2}\mathbf{V}_{n,d}^{2}\right] + \mathbb{E}\left[\sum_{\substack{1\leq k,l\leq N \\ k\neq l}}\mathbf{A}_{k}\mathbf{A}_{l}\mathbf{V}_{k,d}\mathbf{V}_{l,d}\right] \\
&\leq \mathbb{E}\left[\sum_{n=1}^{N}\left(\max_{1\leq m\leq N}\mathbf{A}_{m}\right)^{2}\mathbf{V}_{n,d}^{2}\right] + \mathbb{E}\left[\sum_{\substack{1\leq k,l\leq N,\ k\neq l \\ \mathbf{V}_{k,d}\mathbf{V}_{l,d}\geq 0}}\mathbf{A}_{k}\mathbf{A}_{l}\mathbf{V}_{k,d}\mathbf{V}_{l,d}\right] \\
&\leq \left(\max_{1\leq m\leq N}\mathbf{A}_{m}\right)^{2}\sum_{n=1}^{N}\mathbb{E}\left[\mathbf{V}_{n,d}^{2}\right] + \mathbb{E}\left[\sum_{\substack{1\leq k,l\leq N,\ k\neq l \\ \mathbf{V}_{k,d}\mathbf{V}_{l,d}\geq 0}}\mathbf{V}_{k,d}\mathbf{V}_{l,d}\right] \\
&\leq N\left(\max_{1\leq m\leq N}\mathbf{A}_{m}\right)^{2}\sigma_{d}^{2} + \sum_{\substack{1\leq k,l\leq N,\ k\neq l \\ \mathbf{V}_{k,d}\mathbf{V}_{l,d}\geq 0}}\mathbb{E}\left[\mathbf{V}_{k,d}\mathbf{V}_{l,d}\right] \\
&= N\left(\max_{1\leq m\leq N}\mathbf{A}_{m}\right)^{2}\sigma_{d}^{2} + \sum_{\substack{1\leq k,l\leq N,\ k\neq l \\ \mathbf{V}_{k,d}\mathbf{V}_{l,d}\geq 0}}\mathbb{E}\left[\mathbf{V}_{k,d}\right]\mathbb{E}\left[\mathbf{V}_{l,d}\right] \\
&\leq N\left(\max_{1\leq m\leq N}\mathbf{A}_{m}\right)^{2}\sigma_{d}^{2}.
\end{align*}

In the derivation above, we used the fact that $\mathbf{A}_{n}\in[0,1]$ and that, for $k\neq l$, $\mathbf{V}_{k,d}$ and $\mathbf{V}_{l,d}$ are independent. We also used the assumption that $\mathbb{E}\left[\mathbf{V}_{k,d}\right]=0$ for every $1\leq k\leq N$. Since the tokens come from a finite dictionary, the maps $\mathbf{x}\mapsto\mathbf{W}_{Q}\mathbf{x}$ and $\mathbf{x}\mapsto\mathbf{W}_{K}\mathbf{x}$ are continuous functions on a compact domain, so the logits $\langle\mathbf{W}_{Q}\mathbf{x}_{i},\mathbf{W}_{K}\mathbf{x}_{j}\rangle$ are bounded: they lie in the continuous image of a compact set, and every compact subset of the real line is closed and bounded. By Lemma 2.1 of Veličković et al. (2024), there exist a constant $C>0$ and $N_{0}\in\mathbb{N}$ such that for every $N\geq N_{0}$, $\left(\max_{1\leq n\leq N}\mathbf{A}_{n}\right)^{2}<\frac{C}{N^{2}}$. Then, for every $N\geq N_{0}$,

\[
\operatorname{Var}\left(\sum_{n=1}^{N}\mathbf{A}_{n}\mathbf{V}_{n,d}\right) \leq N\sigma_{d}^{2}\frac{C}{N^{2}} = \sigma_{d}^{2}\frac{C}{N}.
\]

Let $\epsilon>0$. There exists $N_{1}\in\mathbb{N}$ such that for every $N\geq N_{1}$, $\sigma_{d}^{2}\frac{C}{N}<\epsilon$. Then, for every $N\geq\max(N_{0},N_{1})$,

\[
\operatorname{Var}\left(\sum_{n=1}^{N}\mathbf{A}_{n}\mathbf{V}_{n,d}\right) < \epsilon.
\]
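To sanity-check this bound numerically, below is a minimal Monte Carlo sketch, separate from our experimental code; the logit bound, trial count, and sequence lengths are illustrative. It estimates the variance of a single attended output component as $N$ grows, under the same assumptions as the derivation (zero-mean i.i.d. values and bounded attention logits).

```python
import torch

torch.manual_seed(0)

NUM_TRIALS = 2000   # Monte Carlo samples used to estimate the variance
LOGIT_BOUND = 5.0   # attention logits are bounded, as assumed in the derivation

for N in [2 ** k for k in range(4, 15, 2)]:
    outputs = []
    for _ in range(NUM_TRIALS):
        logits = (torch.rand(N) * 2.0 - 1.0) * LOGIT_BOUND  # bounded logits for one query
        A = torch.softmax(logits, dim=0)                     # attention weights, shape (N,)
        V = torch.randn(N)                                   # zero-mean i.i.d. values for a fixed component d
        outputs.append((A * V).sum())                        # one attended output component
    var = torch.stack(outputs).var().item()
    print(f"N = {N:6d}    estimated Var ≈ {var:.2e}")
```

Under these assumptions, the printed variance shrinks roughly in proportion to $1/N$, consistent with the $\sigma_{d}^{2}\frac{C}{N}$ bound above.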

Table 4: Results (%) on the $\operatorname{argmax}$ retrieval task. Results are averaged over 100 runs with different random seeds. $p$-values are computed using a paired $t$-test. Entries highlighted in green indicate those with in-distribution sequence lengths.
Model | $2^{4}$ | $2^{5}$ | $2^{6}$ | $2^{7}$ | $2^{8}$ | $2^{9}$ | $2^{10}$ | $2^{11}$ | $2^{12}$ | $2^{13}$ | $2^{14}$
Baseline | 99.8 | 99.8 | 99.6 | 99.3 | 98.5 | 97.1 | 94.4 | 89.1 | 79.9 | 65.8 | 47.8
Baseline (+ LN) | 99.9 | 99.9 | 99.8 | 99.5 | 99.1 | 98.1 | 96.2 | 92.8 | 86.3 | 75.1 | 58.7
$p$-value | $1/10^{8}$ | $6/10^{6}$ | $1/10^{4}$ | $2/10^{4}$ | $9/10^{8}$ | $5/10^{8}$ | $3/10^{8}$ | $5/10^{10}$ | $1/10^{11}$ | $3/10^{12}$ | $1/10^{13}$
Table 5: Results (%) on the dictionary lookup task. Results are averaged over 100 runs with different random seeds. $p$-values are computed using a paired $t$-test. Entries highlighted in green indicate those with in-distribution sequence lengths.
Model | $2^{4}$ | $2^{5}$ | $2^{6}$ | $2^{7}$ | $2^{8}$ | $2^{9}$ | $2^{10}$ | $2^{11}$ | $2^{12}$ | $2^{13}$ | $2^{14}$
Baseline | 99.9 | 99.9 | 99.8 | 99.7 | 99.5 | 99.1 | 98.3 | 96.5 | 93.5 | 87.7 | 77.8
Baseline (+ LN) | 99.9 | 99.9 | 99.9 | 99.8 | 99.6 | 99.3 | 98.7 | 97.5 | 95.0 | 90.4 | 82.6
$p$-value | $2/10^{1}$ | $1/10^{2}$ | $5/10^{2}$ | $2/10^{3}$ | $3/10^{3}$ | $6/10^{5}$ | $4/10^{7}$ | $8/10^{9}$ | $2/10^{10}$ | $2/10^{10}$ | $2/10^{13}$
Table 6: Ablations on different normalization strategies on the dictionary lookup task. Relative results (%) compared to the Baseline ($\triangle$) are reported.
Model | $2^{4}$ | $2^{5}$ | $2^{6}$ | $2^{7}$ | $2^{8}$ | $2^{9}$ | $2^{10}$ | $2^{11}$ | $2^{12}$ | $2^{13}$ | $2^{14}$
$\triangle$ (+ std.) | +0.09 | +0.14 | +0.23 | +0.54 | +1.08 | +2.27 | +3.49 | +4.78 | +5.14 | +5.65 | +4.57
$\triangle$ (+ LN) | +0.09 | +0.20 | +0.30 | +0.64 | +1.22 | +2.55 | +4.06 | +4.86 | +5.38 | +5.72 | +4.51
$p$-value | 1.0 | 0.3 | 0.6 | 0.6 | 0.6 | 0.5 | 0.3 | 0.9 | 0.6 | 0.9 | 0.9
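For completeness, the paired $t$-test $p$-values reported in the tables above are computed from per-seed accuracies. The snippet below is a minimal sketch with hypothetical accuracy arrays; the variable names and values are illustrative, not our data.

```python
import numpy as np
from scipy.stats import ttest_rel

# Hypothetical per-seed accuracies (%) for one sequence length; in our setup each
# entry would come from one of the 100 random seeds, evaluated for both models.
acc_baseline = np.array([94.1, 94.6, 93.8, 94.9, 94.3])
acc_baseline_ln = np.array([96.0, 96.4, 95.9, 96.7, 96.1])

# A paired t-test is appropriate because both models are evaluated on the same
# seeds, so the two samples are matched rather than independent.
t_stat, p_value = ttest_rel(acc_baseline_ln, acc_baseline)
print(f"t = {t_stat:.3f}, p = {p_value:.1e}")
```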

Appendix B More Experimental Details

B.1 Implementation Details

$\operatorname{argmax}$ retrieval.

We follow Veličković et al. (2024) and train the same neural network architecture in PyTorch for 100,000 gradient steps with the same hyper-parameter setup. We also follow Veličković et al. (2024) in generating data with a varying number of items to train and test the model.

$\operatorname{dictionary}$ lookup.

The network architecture is the same as in the $\operatorname{argmax}$ retrieval task. We generate training and evaluation data as follows: for each item of the length-$N$ sequence, we sample a value class $c_{V}\sim\mathcal{U}\{1,\dots,C_{V}\}$ i.i.d.; each item also has a key class $1\leq c_{K}\leq C_{K}$, and the key classes of all $N$ items in the sequence are sampled without replacement. In our experiments, $C_{K}=16384$ and $C_{V}=64$.

The feature vector of each item $\mathbf{x}_{i}$ is defined as $\operatorname{Emb}_{K}(c_{K})\parallel\operatorname{Emb}_{V}(c_{V})$, i.e., the concatenation of the embeddings of its key class and value class. The embedding vectors for each (key and value) class are optimized jointly with the attention network.

The query sequence in our case is guaranteed to be of length 1. We sample a key class present in the input sequence and use its embedding vector as the query.
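Putting the above together, the following is a minimal sketch of the data generation for this task. The function and variable names are ours, and the embedding dimension is illustrative; in practice the embedding tables play the role of the learnable $\operatorname{Emb}_{K}$ and $\operatorname{Emb}_{V}$ trained jointly with the attention network.

```python
import torch
import torch.nn as nn

C_K, C_V, D_EMB = 16384, 64, 64   # key classes, value classes, embedding size (D_EMB is illustrative)

emb_K = nn.Embedding(C_K, D_EMB)  # Emb_K, optimized jointly with the attention network
emb_V = nn.Embedding(C_V, D_EMB)  # Emb_V, optimized jointly with the attention network

def sample_lookup_instance(N: int):
    """Sample one length-N dictionary lookup instance."""
    # Key classes are sampled without replacement; value classes i.i.d. uniformly.
    key_classes = torch.randperm(C_K)[:N]
    value_classes = torch.randint(0, C_V, (N,))

    # Item features: concatenation of the key-class and value-class embeddings.
    X = torch.cat([emb_K(key_classes), emb_V(value_classes)], dim=-1)  # (N, 2 * D_EMB)

    # The query is the key embedding of a class that appears in the sequence;
    # the query projection maps it into the shared attention dimension.
    idx = torch.randint(0, N, (1,))
    query = emb_K(key_classes[idx])   # (1, D_EMB)
    target = value_classes[idx]       # value class the model should retrieve

    return X, query, target
```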

For this task, we found that optimization usually converges within 10,000 gradient steps. We therefore train the attention network, together with the embedding vectors, in PyTorch for 10,000 steps with the same hyper-parameter setup as the $\operatorname{argmax}$ task.

B.2 Results When Training on More Diverse Sequence Lengths

To validate the utility of normalization when the length gap between training and test sequences is smaller, we follow the same experimental setup as in Section 3 but sample sequences of up to 256 items during training. We found it beneficial to gradually increase the length of the sampled sequences throughout training, as is commonly done during the pre-training of frontier LLMs (AI@Meta, 2024). The results are reported in Table 4 and Table 5. With layer normalization, accuracies on out-of-distribution sequence lengths are significantly higher than without it on both tasks, demonstrating the importance of normalization for length generalization across training settings.
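As one concrete (hypothetical) instantiation of this curriculum, the maximum sampled length can be raised on a simple schedule over training; the linear schedule below is illustrative rather than the exact one we used.

```python
import random

TOTAL_STEPS = 10_000
MIN_LEN, MAX_LEN = 16, 256  # training lengths range from 2^4 up to 2^8 items

def curriculum_max_len(step: int) -> int:
    """Linearly grow the maximum training sequence length over the course of training."""
    frac = min(step / TOTAL_STEPS, 1.0)
    return int(MIN_LEN + frac * (MAX_LEN - MIN_LEN))

def sample_train_length(step: int) -> int:
    # Sequence lengths are sampled up to the current curriculum maximum.
    return random.randint(MIN_LEN, curriculum_max_len(step))

print(sample_train_length(500), sample_train_length(9_500))  # short early, longer late
```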

B.3 Ablations on the Dictionary Lookup Task

Ablation results on the dictionary lookup task are shown in Table 6 and are consistent with the results on the $\operatorname{argmax}$ retrieval task presented in Section 3.2. On this task, however, standardization and layer normalization perform more similarly, as indicated by the larger $p$-values, which provide weaker statistical evidence of a difference between the two.
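For clarity on what the two ablated variants compute, here is a minimal sketch, assuming "+ std." denotes per-token standardization of the attention output without learnable parameters and "+ LN" a standard LayerNorm with its learnable gain and bias; the tensor shapes are illustrative.

```python
import torch
import torch.nn as nn

D = 64
attn_out = torch.randn(8, D)  # (queries, features); shapes are illustrative

# "+ std.": standardize each token's features to zero mean and unit variance,
# with no learnable parameters.
mean = attn_out.mean(dim=-1, keepdim=True)
std = attn_out.std(dim=-1, keepdim=True, unbiased=False)
std_out = (attn_out - mean) / (std + 1e-5)

# "+ LN": LayerNorm performs the same standardization but additionally applies
# a learnable per-feature gain and bias.
layer_norm = nn.LayerNorm(D)
ln_out = layer_norm(attn_out)
```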