TL;DR
- Multi-head attention runs h independent scaled dot-product attention operations in parallel, each on a learned projection of Q, K and V into d_k = d_model / h dimensions.
- The outputs of the h heads are concatenated and passed through a final linear projection W_O, returning to the original d_model width.
- Different heads learn different attention patterns — syntactic dependencies, coreference, positional structure — which a single head cannot capture in one weighting.
- Most modern decoder-only LLMs instead use Grouped-Query Attention or Multi-Query Attention, which keep h query heads but share K and V projections across groups of heads to shrink the KV cache.
Motivation
A single softmax(QKᵀ/√d_k) produces one probability distribution per query — one way of mixing values. That is enough to model some relationships but not enough to model many simultaneously. A sentence such as 'the trophy did not fit in the suitcase because it was too large' contains a coreference (it → trophy), a syntactic relation (fit → trophy, fit → suitcase) and a discourse relation (because → too large) all at once. One attention weighting cannot disentangle them.
Multi-head attention solves this by computing h attention operations in parallel, each on its own learned linear projection of the input. Each head produces its own distribution and its own value mix; the concatenation lets the model combine them.
The Mechanics
Given input X of shape (n, d_model), multi-head attention learns 3h + 1 weight matrices: W_Q^i and W_K^i of shape (d_model, d_k) and W_V^i of shape (d_model, d_v) for i = 1..h, plus a single W_O of shape (h · d_v, d_model).
For each head, the projections Q_i = X · W_Q^i, K_i = X · W_K^i, V_i = X · W_V^i map the input down to d_k dimensions (d_v for V_i). Standard scaled dot-product attention then produces head_i = softmax(Q_i · K_i^T / √d_k) · V_i. The h head outputs are concatenated along the feature axis and projected back: MultiHead(X) = Concat(head_1, …, head_h) · W_O.
Choosing h and d_k
The canonical choice is d_k = d_v = d_model / h, which keeps the total parameter count the same as a single attention head of full width d_model. Many large models settle on d_k = 128, a choice often attributed to how well that width tiles onto NVIDIA Tensor Core fragments, although smaller models such as GPT-2 use 64.
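The parameter claim can be checked with a little arithmetic. The sketch below is illustrative only: `attention_param_count` is a hypothetical helper, it assumes the canonical d_k = d_v = d_model / h, and it ignores bias terms.

```python
def attention_param_count(d_model: int, num_heads: int) -> int:
    """Weights in one multi-head attention layer (biases ignored)."""
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    d_k = d_v = d_model // num_heads                # canonical head width
    per_head = 2 * d_model * d_k + d_model * d_v    # W_Q^i, W_K^i, W_V^i
    w_o = num_heads * d_v * d_model                 # output projection W_O
    return num_heads * per_head + w_o


# The total does not depend on how many heads d_model is split into.
counts = {h: attention_param_count(768, h) for h in (1, 12, 64)}
for h, count in counts.items():
    print(f"h = {h:2d}: {count:,} weights")
```

Every entry comes out to 4 · d_model², so changing h redistributes capacity across heads without adding weights.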
Head counts scale with model width: GPT-2 small (d_model 768) uses 12 heads; GPT-3 (d_model 12,288) uses 96; Llama 3 70B (d_model 8,192) uses 64; Llama 3 405B (d_model 16,384) uses 128.
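Dividing each width by its head count shows what stays fixed as models grow. This short sketch uses only the figures quoted in the paragraph above and assumes d_k = d_model / h for each model.

```python
# (d_model, number of heads) as quoted in the text above.
models = {
    "GPT-2 small": (768, 12),
    "GPT-3": (12_288, 96),
    "Llama 3 70B": (8_192, 64),
    "Llama 3 405B": (16_384, 128),
}

# Per-head width under the canonical split d_k = d_model / h.
head_dims = {name: d_model // h for name, (d_model, h) in models.items()}

for name, d_k in head_dims.items():
    print(f"{name}: d_k = {d_k}")
```

The head width comes out at 64 for GPT-2 small and 128 for the other three, so wider models add heads rather than widening each head.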
What Heads Actually Learn
Mechanistic interpretability work (for example from Anthropic, Olah et al.) has shown that individual heads in trained Transformers specialise. 'Induction heads' learn to continue repeated patterns (having seen ABC, they complete AB → C). 'Previous-token heads' attend almost exclusively to position i-1. 'Name-mover heads' shuttle entity identity across a sequence. The phenomenon has been observed across model families and is one of the few examples where neural network internals are partially understood.
That said, the majority of heads do not have clean interpretations, and many are redundant. Head-pruning studies have reported removing 30-60 per cent of heads with little measurable quality loss, which suggests why GQA and MQA can share K and V across heads at so small a cost in quality.
From MHA to GQA and MQA
Multi-head attention's KV cache scales with 2 × h × d_k × n × layers, where the factor of 2 covers K and V. For long-context decoding, reading that cache dominates memory bandwidth: every decoded token must read every K and V vector for every prior token. Multi-Query Attention (Shazeer 2019) collapses K and V to a single shared head, cutting the cache by a factor of h. Grouped-Query Attention (Ainslie et al. 2023) takes a middle path, sharing one K/V head across each group of g query heads, which cuts the cache by a factor of g.
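To put numbers on that scaling, here is a back-of-the-envelope calculator. It is a sketch under stated assumptions: `kv_cache_bytes` is a hypothetical helper, the configuration (80 layers, 64 query heads, d_k = 128, an 8,192-token context) is illustrative, and values are assumed to be stored in 2 bytes (fp16 or bf16) for a single sequence.

```python
def kv_cache_bytes(n_tokens, n_layers, n_kv_heads, d_k, bytes_per_value=2):
    """Bytes needed to cache K and V for one sequence.

    The leading 2 counts K and V separately; bytes_per_value=2 assumes
    fp16/bf16 storage.
    """
    return 2 * n_layers * n_kv_heads * d_k * n_tokens * bytes_per_value


# Illustrative configuration (assumed, not taken from any one model card).
cfg = dict(n_tokens=8_192, n_layers=80, d_k=128)

mha = kv_cache_bytes(n_kv_heads=64, **cfg)  # one K/V head per query head
gqa = kv_cache_bytes(n_kv_heads=8, **cfg)   # 8 query heads share each K/V head
mqa = kv_cache_bytes(n_kv_heads=1, **cfg)   # all query heads share one K/V head

for name, size in [("MHA", mha), ("GQA", gqa), ("MQA", mqa)]:
    print(f"{name}: {size / 2**30:.2f} GiB")
```

Under these assumed numbers the MHA cache comes to 20 GiB per sequence, against 2.5 GiB with eight K/V heads. Only the number of K/V heads changes between the three calls; the query side is untouched.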
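Llama 3, Mistral 7B and Qwen 2 all use GQA, though the group size varies with model size. Llama 3 70B, for example, has 64 query heads and 8 K/V heads, so g = 8 and eight query heads share one K/V pair. DeepSeek-V3 takes a different route, Multi-head Latent Attention, which compresses K and V into a low-rank latent rather than sharing heads. With GQA, quality stays close to MHA, the KV cache shrinks by roughly a factor of g, and inference throughput rises substantially on memory-bound decoders.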
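The head sharing can be sketched in a few lines. The example below is a minimal illustration, not any particular model's implementation: it uses NumPy to stay self-contained, `grouped_query_attention` is a hypothetical name, the projections are assumed to have been applied already, and no causal mask is used.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def grouped_query_attention(q, k, v, group_size):
    """q: (h, n, d_k); k, v: (h // group_size, n, d_k). Returns (h, n, d_k)."""
    h, n, d_k = q.shape
    assert k.shape[0] * group_size == h, "query heads must split evenly into groups"
    # Repeat each K/V head so it lines up with the query heads of its group.
    k_full = np.repeat(k, group_size, axis=0)  # (h, n, d_k)
    v_full = np.repeat(v, group_size, axis=0)  # (h, n, d_k)
    scores = q @ k_full.transpose(0, 2, 1) / np.sqrt(d_k)  # (h, n, n)
    return softmax(scores) @ v_full


rng = np.random.default_rng(0)
h, n, d_k, g = 8, 5, 4, 4
q = rng.standard_normal((h, n, d_k))
k = rng.standard_normal((h // g, n, d_k))  # only h / g K heads are stored
v = rng.standard_normal((h // g, n, d_k))  # only h / g V heads are stored
out = grouped_query_attention(q, k, v, group_size=g)
print(out.shape)  # (8, 5, 4)
```

Setting `group_size=1` recovers ordinary multi-head attention, and `group_size=h` gives MQA. The explicit `np.repeat` is for clarity only; the saving comes from caching `k` and `v` at their smaller, unrepeated size.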
If you are reading an LLM paper post-2023 and it says 'multi-head attention', read carefully: it may well mean GQA. True MHA is now rare outside encoder models and small research baselines.
Implementation Notes
In practice the per-head matrices for each of the three projections W_Q, W_K, W_V are fused into a single (d_model, h · d_k) matrix, computed as one GEMM and then reshaped to (n, h, d_k). The attention itself is computed per-head via batched matmul. Flash Attention 3 fuses the attention pipeline (tiled QKᵀ, softmax, V-multiply) into one CUDA kernel that never writes the n × n score matrix to HBM; the Q, K and V projections remain separate GEMMs.
# Conceptual PyTorch; production code uses fused Flash Attention.
def multi_head_attention(x, w_q, w_k, w_v, w_o, num_heads):
    # x: (n, d_model); w_q, w_k, w_v, w_o: (d_model, d_model), heads fused column-wise.
    # No causal mask is applied, so every position attends to every other.
    n, d_model = x.shape
    d_k = d_model // num_heads
    q = (x @ w_q).reshape(n, num_heads, d_k).transpose(0, 1)  # (h, n, d_k)
    k = (x @ w_k).reshape(n, num_heads, d_k).transpose(0, 1)  # (h, n, d_k)
    v = (x @ w_v).reshape(n, num_heads, d_k).transpose(0, 1)  # (h, n, d_k)
    scores = (q @ k.transpose(-2, -1)) / (d_k ** 0.5)  # (h, n, n)
    weights = scores.softmax(dim=-1)  # each row sums to 1 over the keys
    out = (weights @ v).transpose(0, 1).reshape(n, d_model)  # (n, d_model)
    return out @ w_o
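As a sanity check on the fused reshape used above, the following sketch compares it with the per-head definition from The Mechanics. It is written in NumPy rather than PyTorch so that it runs standalone; the sizes and random weights are arbitrary assumptions, and head i is assumed to occupy column block i of each fused matrix.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
n, d_model, h = 6, 16, 4
d_k = d_model // h
x = rng.standard_normal((n, d_model))
# Fused weights: column block i of each matrix is head i's (d_model, d_k) projection.
w_q, w_k, w_v, w_o = (rng.standard_normal((d_model, d_model)) for _ in range(4))

# Per-head form: slice out each head's projection and attend separately.
heads = []
for i in range(h):
    cols = slice(i * d_k, (i + 1) * d_k)
    q_i, k_i, v_i = x @ w_q[:, cols], x @ w_k[:, cols], x @ w_v[:, cols]
    heads.append(softmax(q_i @ k_i.T / np.sqrt(d_k)) @ v_i)
per_head = np.concatenate(heads, axis=-1) @ w_o


# Fused form: one matmul per projection, then reshape to (h, n, d_k).
def split(t):
    return t.reshape(n, h, d_k).transpose(1, 0, 2)


q, k, v = split(x @ w_q), split(x @ w_k), split(x @ w_v)
ctx = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d_k)) @ v  # (h, n, d_k)
fused = ctx.transpose(1, 0, 2).reshape(n, d_model) @ w_o

print(np.allclose(per_head, fused))
```

The two paths agree to floating-point tolerance, which is why implementations can use the single large GEMM without changing what each head computes.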