Attention Mechanisms

From BloomWiki

How to read this page: This article maps the topic from beginner to expert across six levels: Remembering, Understanding, Applying, Analyzing, Evaluating, and Creating. Scan the headings to see the full scope, then read from wherever your knowledge starts to feel uncertain. Learn more about how BloomWiki works.

Attention mechanisms are the computational building blocks that allow neural networks to dynamically focus on the most relevant parts of their input when producing each output. Introduced in the context of machine translation to help encoder-decoder RNNs align source and target words, attention was then generalized into the self-attention mechanism at the core of the Transformer architecture — the technology underlying GPT, BERT, DALL-E, AlphaFold, and virtually every frontier AI system. Understanding attention is understanding the engine of modern AI.

Remembering

  • Attention — A mechanism that computes a weighted combination of input elements, where weights represent how relevant each element is to the current computation.
  • Self-attention — Attention applied to a single sequence where each element attends to all other elements of the same sequence.
  • Cross-attention — Attention where queries come from one sequence and keys/values from another; used in encoder-decoder models.
  • Query (Q) — A vector representing "what I am looking for" at the current position.
  • Key (K) — A vector representing "what I offer" for each position in the sequence.
  • Value (V) — A vector representing "what I give if selected" for each position.
  • Attention weight — The scalar importance assigned to each key-value pair given the query; computed via softmax of scaled dot products.
  • Attention head — One parallel attention operation; multi-head attention runs H heads simultaneously.
  • Multi-head attention — Running H attention operations in parallel with different projections, then concatenating outputs.
  • Scaled dot-product attention — The standard attention formula: Attention(Q,K,V) = softmax(QKᵀ/√d_k)V.
  • Causal (masked) attention — Self-attention where each position can only attend to itself and earlier positions; used in autoregressive decoders.
  • Positional encoding — Information added to embeddings indicating each token's position, since attention is permutation-invariant.
  • Attention sink — The empirical phenomenon where early tokens attract disproportionate attention mass in LLMs.
  • Flash Attention — A memory-efficient, hardware-optimized implementation of exact attention using tiling and recomputation.
  • Sparse attention — Attention variants that restrict which positions can attend to which, reducing O(n²) complexity.

Understanding

Standard neural networks apply fixed weights regardless of input context. Attention is dynamic — the weights change with every input. This is its power: the network can decide, for each output token, which parts of the input are most relevant.

The intuition: imagine reading a sentence and being asked "Who did John see?" Your brain attends to "John", the verb "see", and its object, not to articles and prepositions. Attention gives neural networks this selective focus.

The math: for a query q and a set of key-value pairs (kᵢ, vᵢ):

  • Compute compatibility scores: sᵢ = q · kᵢ / √d_k
  • Normalize across positions: αᵢ = exp(sᵢ) / Σⱼ exp(sⱼ)  (softmax)
  • Output: o = Σᵢ αᵢ vᵢ

The output is a weighted sum of values, where the weights reflect how compatible each key is with the query. Scaling by √d_k prevents dot products from growing too large, which would push softmax into saturation.
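To make the three steps concrete, here is a minimal sketch with one query and three key-value pairs (all numbers random; the shapes and names are illustrative only):

<syntaxhighlight lang="python">
import math
import torch
import torch.nn.functional as F

d_k = 4
q = torch.randn(d_k)             # query: "what I am looking for"
K = torch.randn(3, d_k)          # three keys: "what each position offers"
V = torch.randn(3, d_k)          # three values: "what each position gives if selected"

s = K @ q / math.sqrt(d_k)       # compatibility scores s_i = q · k_i / √d_k, shape (3,)
alpha = F.softmax(s, dim=-1)     # attention weights α_i, non-negative, summing to 1
o = alpha @ V                    # output: weighted sum of the values, shape (d_k,)
print(alpha, o)
</syntaxhighlight>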

Multi-head attention allows different heads to capture different types of relationships simultaneously — one head might track syntactic dependencies, another semantic similarity, another coreference. The outputs are concatenated and linearly projected.
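A quick shape check makes the head bookkeeping concrete (a sketch with d_model = 512 and H = 8, so each head operates in d_k = 64 dimensions; the split and concatenation are lossless, and the learned projections do the actual mixing):

<syntaxhighlight lang="python">
import torch

B, T, d_model, H = 2, 10, 512, 8
d_k = d_model // H                                   # 64 dims per head

x = torch.randn(B, T, d_model)
heads = x.view(B, T, H, d_k).transpose(1, 2)         # (B, H, T, d_k): H parallel heads
merged = heads.transpose(1, 2).contiguous().view(B, T, d_model)
assert torch.equal(merged, x)                        # concatenating the heads recovers x
</syntaxhighlight>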

The O(n²) problem: Full self-attention computes all pairwise attention scores, requiring O(n²) memory and compute in sequence length n. For a 100k-token context, this is 10^10 scores — impractical. Solutions: Flash Attention (IO-aware exact attention), sliding window attention (Longformer), linear attention approximations, and GQA (Grouped Query Attention) for inference efficiency.
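Back-of-the-envelope arithmetic shows why naive attention breaks down at long context (a sketch; fp16 scores and 32 heads are illustrative assumptions):

<syntaxhighlight lang="python">
n = 100_000                                # context length in tokens
n_heads = 32                               # typical for a 7B-class model (assumption)
bytes_per_score = 2                        # fp16

scores_per_head = n * n                    # 1e10 pairwise scores
gb_per_head = scores_per_head * bytes_per_score / 1e9
print(f"{gb_per_head:.0f} GB per head, {gb_per_head * n_heads / 1e3:.2f} TB per layer")
# ≈ 20 GB per head, ≈ 0.64 TB per layer if the score matrix is materialized
</syntaxhighlight>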

Applying

Implementing scaled dot-product attention from scratch:

<syntaxhighlight lang="python">
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):

   def __init__(self, d_model=512, n_heads=8, dropout=0.1):
       super().__init__()
       assert d_model % n_heads == 0
       self.d_k = d_model // n_heads
       self.n_heads = n_heads
       self.W_q = nn.Linear(d_model, d_model)
       self.W_k = nn.Linear(d_model, d_model)
       self.W_v = nn.Linear(d_model, d_model)
       self.W_o = nn.Linear(d_model, d_model)
       self.dropout = nn.Dropout(dropout)
   def split_heads(self, x):
       # reshape (B, T, D) -> (B, H, T, d_k) so each head attends independently
       B, T, D = x.shape
       return x.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
   def forward(self, query, key, value, mask=None):
       # query/key/value: (B, T, d_model); mask broadcastable to (B, H, T, T),
       # with 1 = may attend and 0 = blocked
       B, T, _ = query.shape
       Q = self.split_heads(self.W_q(query))
       K = self.split_heads(self.W_k(key))
       V = self.split_heads(self.W_v(value))
       # Scaled dot-product attention
       scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
       if mask is not None:
           scores = scores.masked_fill(mask == 0, float('-inf'))
       attn = self.dropout(F.softmax(scores, dim=-1))
       out = torch.matmul(attn, V)           # (B, H, T, d_k)
       out = out.transpose(1, 2).contiguous().view(B, T, -1)
       return self.W_o(out), attn

</syntaxhighlight>
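Exercising the module with a causal mask (this follows the mask convention in forward: 1 = may attend, 0 = blocked; eval() disables dropout so the check is deterministic):

<syntaxhighlight lang="python">
B, T, d_model = 2, 16, 512
mha = MultiHeadAttention(d_model=d_model, n_heads=8).eval()
x = torch.randn(B, T, d_model)

# Lower-triangular causal mask, broadcast over batch and heads: (1, 1, T, T)
causal = torch.tril(torch.ones(T, T)).view(1, 1, T, T)

out, attn = mha(x, x, x, mask=causal)      # self-attention: Q, K, V from the same sequence
print(out.shape)                           # torch.Size([2, 16, 512])
print(attn[0, 0, 5, 6:].sum())             # 0: position 5 puts no weight on future tokens
</syntaxhighlight>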

Attention variant selection guide

  • Standard self-attention → transformer encoders/decoders generally; BERT, GPT (sequence length ≤ 4096)
  • Flash Attention 2 → any modern transformer; identical output, 2-4× faster, O(n) memory
  • Grouped Query Attention (GQA) → LLaMA 2/3, Mistral; shrinks the KV cache at inference
  • Sliding window attention → Longformer, Mistral; O(n·w) complexity for long documents
  • Cross-attention → encoder-decoder models (T5, Whisper, LLaVA projection)
  • Linear attention → sub-quadratic O(n) variants; trades quality for speed (state-space models such as Mamba are a related alternative)

Analyzing

Attention Complexity Comparison

  Variant             | Time    | Memory                    | Quality vs. full attention
  Full attention      | O(n²)   | O(n²)                     | Reference
  Flash Attention     | O(n²)   | O(n)                      | Identical (exact)
  GQA (G KV groups)   | O(n²)   | KV cache reduced by H/G   | Near-identical
  Sliding window (w)  | O(n·w)  | O(n·w)                    | Degrades on long-range dependencies
  Linear attention    | O(n)    | O(n)                      | Significant quality loss

Failure modes:

  • Attention sinks: the first token attracts a disproportionate share of attention mass, distorting representations.
  • Position generalization failure: models trained with absolute positional encodings fail on sequences longer than those seen in training.
  • Quadratic blowup: naively materializing attention scores for a 100k-token context costs roughly 20 GB per head in fp16, terabytes across heads and layers.
  • Near-uniform attention distributions early in training can destabilize gradients.
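A randomly initialized module will not exhibit a real sink, but the measurement itself is simple; a sketch reusing the MultiHeadAttention module, input x, and causal mask from the Applying section:

<syntaxhighlight lang="python">
_, attn = mha(x, x, x, mask=causal)        # attn: (B, H, T_query, T_key)
sink = attn[..., -1, 0].mean().item()      # weight the final query places on token 0
print(f"mass on token 0: {sink:.3f} vs uniform baseline {1/T:.3f}")
# In trained LLMs this value often sits far above uniform; that excess is the sink.
</syntaxhighlight>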

Evaluating

Attention visualization (BertViz, TransformerLens) reveals what each head attends to, which is essential for understanding model behavior. Expert practitioners perform attention head ablation: zero out individual heads and measure the performance impact, identifying which heads are critical. Context utilization benchmarks (RULER, NIAH, the Needle-in-a-Haystack test) check whether long-context models actually use information at all positions; many models advertising 128k context fail to retrieve information placed in the middle (the "lost in the middle" effect).
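A minimal head-ablation sketch against the MultiHeadAttention module above (a forward pre-hook on W_o zeroes one head's slice of the concatenated output; the head index and the downstream metric you compare are up to you):

<syntaxhighlight lang="python">
import torch

def head_ablation_hook(head, d_k):
    """Zero one head's slice of the concatenated heads before W_o mixes them."""
    def hook(module, inputs):
        (h,) = inputs                                  # (B, T, H * d_k)
        h = h.clone()
        h[..., head * d_k:(head + 1) * d_k] = 0.0
        return (h,)
    return hook

# Compare outputs with and without head 3 (reusing mha, x, causal from above):
out_base, _ = mha(x, x, x, mask=causal)
handle = mha.W_o.register_forward_pre_hook(head_ablation_hook(head=3, d_k=mha.d_k))
out_ablated, _ = mha(x, x, x, mask=causal)
handle.remove()
print((out_base - out_ablated).abs().max())            # nonzero: head 3 mattered here
</syntaxhighlight>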

Creating

Designing attention for a specific long-context use case:

  1. Use Flash Attention 2 as the default — same quality, O(n) memory.
  2. For inference serving, switch to GQA (group K/V heads) to reduce KV cache size by 4-8×.
  3. For documents >32k tokens, use sliding window attention for local context (as in Mistral), optionally interleaving full-attention layers for global reach; a mask sketch follows this list.
  4. Use RoPE (Rotary Position Embeddings) for better length generalization.
  5. Monitor KV cache size as the dominant memory cost at inference scale.
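A sketch of the local mask from step 3, in the same 1 = may attend convention used by the Applying-section module (the window size w is a tuning assumption):

<syntaxhighlight lang="python">
import torch

def sliding_window_causal_mask(T, w):
    """Each position may attend to itself and the previous w - 1 positions."""
    i = torch.arange(T).unsqueeze(1)       # query positions (column vector)
    j = torch.arange(T).unsqueeze(0)       # key positions (row vector)
    allowed = (j <= i) & (j > i - w)       # causal AND within the local window
    return allowed.view(1, 1, T, T).float()

mask = sliding_window_causal_mask(T=8, w=3)
print(mask[0, 0].int())                    # banded lower-triangular pattern
</syntaxhighlight>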