TL;DR
- Grouped-Query Attention (Ainslie et al., 2023, arXiv:2305.13245) keeps the full set of N_q query heads but lets them share a smaller set of N_kv key/value heads: the query heads are split into N_kv groups of G = N_q / N_kv heads each, interpolating between Multi-Head Attention (G = 1) and Multi-Query Attention (G = N_q).
- The KV cache — the dominant memory pressure on autoregressive decoders at long context — shrinks by exactly the group size G, typically 4 or 8, with quality within evaluation noise of full MHA.
- Nearly every modern open-weights decoder LLM uses GQA: Llama 2 70B onwards, Llama 3 / 3.1 / 4, Mistral, Mixtral, Qwen 2 / 2.5 / 3, Gemma 2 / 3, Phi-3 / Phi-4. The field has largely converged on N_kv = 8 for models with N_q in the 32-128 range; DeepSeek-V3 is the notable exception and uses Multi-Head Latent Attention (MLA) instead.
- Inference engines (vLLM, TensorRT-LLM, SGLang) keep K and V at their N_kv shape in the cache and let the attention kernel map each query head to its shared K/V head, so no broadcast copy is ever stored. Every model in the Yobibyte marketplace catalogue therefore inherits GQA-aware serving without per-model engineering.
- Practical consequence for sizing on Yobitel NeoCloud H100 / H200: a 70B GQA model needs roughly an eighth of the per-request KV cache of an equivalent MHA model, which translates directly into higher concurrency, lower $/million-tokens, and feasible 128k-context serving on a single 8x H100 node.
Overview
Grouped-Query Attention is the attention shape that made long-context decoder-only LLMs economically serveable. The original 'Attention Is All You Need' design used Multi-Head Attention (MHA): given N_q query heads, the model also held N_q key and N_q value heads, each its own learned projection of the residual stream. That symmetry is beautiful and was right for the encoder-decoder workloads of 2017 — but it leaves a structural problem under autoregressive decoding. Every previously generated token's keys and values have to be cached so they can be re-attended to at every subsequent decoding step. With N_kv = N_q, that KV cache scales linearly in both depth and N_q, and on a 70B-parameter, 80-layer, 64-head model at 32k tokens it eats more HBM than the activations themselves. At 128k tokens it eats more HBM than the weights.
Noam Shazeer's 2019 'Fast Transformer Decoding' paper proposed the extreme fix, Multi-Query Attention (MQA): a single shared K/V head across all N_q query heads, shrinking the KV cache by a factor of N_q. The serving wins were real, but MQA showed measurable quality regressions at scale and was harder to train stably. Ainslie et al.'s 2023 paper (arXiv:2305.13245) found the sweet spot by interpolating: keep N_q query heads, but use N_kv key/value heads where 1 < N_kv < N_q, with the query heads split into N_kv groups of G = N_q / N_kv consecutive heads, each group sharing one K/V projection. The KV cache shrinks by exactly G, training stability returns, and ablations show quality close to full MHA.
By mid-2026 the result is near-total convergence. Almost every modern open-weights decoder-only LLM uses GQA: Llama 2 70B onwards, Llama 3 and 3.1 (all sizes), Llama 4 Scout, Mistral 7B / Large 2, Mixtral 8x7B / 8x22B, Qwen 2 / 2.5 / 3, Gemma 2 / 3, Phi-3 / Phi-4. The configuration converged too: N_kv = 8 is the dominant choice for models with N_q in the 32-128 range, which puts G anywhere from 4 to 16. The main exception is DeepSeek-V2 / V3, which use GQA's successor, Multi-Head Latent Attention (MLA). Closed-frontier models (GPT-4o, Claude 4, Gemini 2) are widely believed to use GQA or MLA, based on serving-cost economics.
This entry helps you understand GQA well enough to read a model card and immediately reason about its KV cache footprint, decode bandwidth, and per-request HBM cost on Hopper- or Blackwell-class GPUs — and to choose between a GQA model and an MLA model when sizing inference on Yobitel NeoCloud. The Yobibyte marketplace catalogue is GQA-by-default; the per-GPU sizing the picker reasons about reflects exactly this maths, which is why a 70B GQA model stretches further on a single H100 SXM5 than the raw parameter count would suggest.
How it works: the KV cache, the group mechanism, and the maths
Autoregressive decoding generates one token at a time. To compute attention for token n, the model needs the keys and values of every previous token from positions 1 to n-1. Recomputing those at every decoding step would mean re-running the model over the whole prefix for every new token, so the total work would grow quadratically with sequence length. Instead, the KV cache stores each layer's keys and values once, when they are first computed, and reuses them on every subsequent step. The cache size per request is exactly 2 * num_layers * N_kv * d_k * seq_len * bytes_per_element, where the factor of 2 covers both K and V.
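The cache-size formula can be turned into a small calculator. This is a minimal sketch in plain Python; the function name `kv_cache_bytes` and its parameter names are illustrative and not taken from any library, and the shape in the example is a toy one rather than a real model.

```python
def kv_cache_bytes(num_layers: int, n_kv: int, d_k: int,
                   seq_len: int, bytes_per_element: int = 2) -> int:
    """Per-request KV cache size: 2 * layers * N_kv * d_k * seq_len * bytes."""
    return 2 * num_layers * n_kv * d_k * seq_len * bytes_per_element


# A small illustrative shape (not a real model): 4 layers, 2 K/V heads, d_k = 64.
per_token = kv_cache_bytes(num_layers=4, n_kv=2, d_k=64, seq_len=1)
per_request = kv_cache_bytes(num_layers=4, n_kv=2, d_k=64, seq_len=1000)
print(per_token)    # 2048 bytes per token in BF16
print(per_request)  # 2048000 bytes for a 1,000-token request
```

Passing `seq_len=1` gives the per-token cost, which is the number most sizing conversations start from.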
Take a hypothetical full-MHA version of Llama 3 70B: N_q = N_kv = 64, 80 layers, d_k = 128, BF16 (2 bytes). The per-token cache cost is 2 * 80 * 64 * 128 * 2 = 2,621,440 bytes, about 2.6 MB. A 32k-token request needs about 83 GB of KV cache, more than a single H100 80 GB SXM5 has before you load the weights. At 128k context the same recipe needs about 332 GB. That is the constraint GQA was built to break.
Mathematically, GQA changes nothing about the attention operation per query head; it changes which K and V each head reads from. Define G = N_q / N_kv as the group size (N_kv must divide N_q). With heads indexed from 0, query head i is assigned to group g(i) = floor(i / G). The query projection W_Q stays at shape (d_model, N_q * d_k); the key and value projections W_K and W_V shrink to (d_model, N_kv * d_k). At attention time, query head q_i reads from shared K_{g(i)} and V_{g(i)} for its group. The softmax-and-weighted-sum maths is identical to MHA; only the indexing changes.
The KV cache shrink factor is exactly G. At Llama 3 70B's actual configuration (N_q = 64, N_kv = 8, G = 8), the cache cost drops from 2.6 MB per token to about 328 KB (320 KiB) per token, and the 32k-context request from 83 GB to 10.4 GB. A whole 8x H100 SXM5 node (640 GB of HBM, roughly 500 GB after BF16 weights) can then host dozens of concurrent 32k requests instead of a single-digit number. The 128k case drops from 332 GB to 42 GB per request, which is the difference between needing a 4x H200 (4 x 141 GB HBM3e = 564 GB) for a single request's cache and fitting several such requests on one node.
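The head-to-group assignment and the projection shapes can be sketched in a few lines of plain Python. The helper names `group_of` and `projection_shapes` are hypothetical, chosen for illustration, and the configuration is a toy one (8 query heads sharing 2 K/V heads).

```python
def group_of(i: int, n_q: int, n_kv: int) -> int:
    """K/V head index read by query head i (0-based): floor(i / G)."""
    assert n_q % n_kv == 0, "N_q must be divisible by N_kv"
    g = n_q // n_kv  # group size G
    return i // g


def projection_shapes(d_model: int, n_q: int, n_kv: int, d_k: int) -> dict:
    """Weight shapes of the three input projections under GQA."""
    return {
        "W_Q": (d_model, n_q * d_k),
        "W_K": (d_model, n_kv * d_k),
        "W_V": (d_model, n_kv * d_k),
    }


# Toy configuration: 8 query heads sharing 2 K/V heads, so G = 4.
print([group_of(i, n_q=8, n_kv=2) for i in range(8)])
# [0, 0, 0, 0, 1, 1, 1, 1]
print(projection_shapes(d_model=512, n_q=8, n_kv=2, d_k=64))
# {'W_Q': (512, 512), 'W_K': (512, 128), 'W_V': (512, 128)}
```

Setting `n_kv == n_q` reproduces MHA (every head is its own group) and `n_kv == 1` reproduces MQA (every head maps to group 0).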
Decode bandwidth follows the same shrink. On Hopper, decode is HBM-bandwidth-bound: at every step the attention kernel streams the request's entire cached K and V, for every layer, to produce one new token. Cutting the cache by 8x cuts that KV read traffic by 8x. End-to-end throughput improves by less than 8x, because each step also reads the model weights, which GQA barely changes; the gain grows with context length and batch size as KV traffic comes to dominate. Empirically, modern GQA models reach 30-60 % higher decode throughput than would have been possible at the same parameter count with MHA, which is the gap Yobitel InferenceBench measures when it compares H100 SXM5 GQA serving against the legacy MHA baselines.
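A rough way to reason about the decode-bandwidth argument is a bandwidth-only model: one decode step reads the weights once plus every active request's KV cache, and step time is bytes read divided by HBM bandwidth. The sketch below is built on those assumptions. It ignores compute, activations, kernel overheads and whether the batch fits in memory at all, and the function name and constants (140 GB of BF16 weights, a nominal 3.35 TB/s per GPU across eight GPUs) are illustrative inputs, not measurements.

```python
def decode_step_seconds(weight_bytes: float, kv_bytes_per_request: float,
                        batch: int, hbm_bytes_per_s: float) -> float:
    """Bandwidth-only estimate of one decode step for a whole batch.

    Assumes weights are read once per step and every request's KV cache
    is read in full; compute and kernel overheads are ignored.
    """
    return (weight_bytes + batch * kv_bytes_per_request) / hbm_bytes_per_s


# Illustrative inputs (assumptions, not measurements):
WEIGHTS = 140e9               # 70B parameters in BF16
BANDWIDTH = 8 * 3.35e12       # eight GPUs at a nominal 3.35 TB/s each
KV_MHA = 2_621_440 * 32_000   # 32k-token request with N_kv = 64
KV_GQA = 327_680 * 32_000     # same request with N_kv = 8

for batch in (1, 8, 32):
    mha = decode_step_seconds(WEIGHTS, KV_MHA, batch, BANDWIDTH)
    gqa = decode_step_seconds(WEIGHTS, KV_GQA, batch, BANDWIDTH)
    print(f"batch={batch:>2}  MHA/GQA step-time ratio = {mha / gqa:.2f}")
```

Under this model the ratio stays below 8 and rises with batch size, because the weight read is a fixed cost that both variants pay on every step.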
- Total query heads N_q: unchanged from the dense Transformer baseline (typically 32, 40, 64 or 128).
- Total K/V heads N_kv: shrunk to N_q / G. Standard choice is N_kv = 8 across the modern open-weights family.
- Group size G: how many query heads share each K/V head. G = 1 is MHA; G = N_q is MQA; shipped GQA models range from G = 2 to G = 16, most often 4 or 8.
- Per-token KV cache: 2 * num_layers * N_kv * d_k * bytes_per_element. Shrinks by exactly G versus MHA.
- FLOPs per token: essentially unchanged. Attention compute is dominated by Q @ K^T and softmax @ V, which still run once per query head; only the smaller K/V projections get cheaper.
- Quality: within evaluation noise of MHA at the same parameter count, per Ainslie et al. (2023) ablations and every subsequent open-weights model release.
```python
# gqa_minimal.py: runs with `pip install torch && python gqa_minimal.py`
import torch
import torch.nn as nn
import torch.nn.functional as F


class GroupedQueryAttention(nn.Module):
    """Faithful GQA per Ainslie et al. 2023 (arXiv:2305.13245).

    N_q must be divisible by N_kv."""

    def __init__(self, d_model: int, n_q: int, n_kv: int, d_k: int):
        super().__init__()
        assert n_q % n_kv == 0, "N_q must be divisible by N_kv"
        self.n_q, self.n_kv, self.d_k = n_q, n_kv, d_k
        self.group = n_q // n_kv
        self.w_q = nn.Linear(d_model, n_q * d_k, bias=False)
        self.w_k = nn.Linear(d_model, n_kv * d_k, bias=False)  # shrunk
        self.w_v = nn.Linear(d_model, n_kv * d_k, bias=False)  # shrunk
        self.w_o = nn.Linear(n_q * d_k, d_model, bias=False)

    def forward(self, x: torch.Tensor, causal: bool = True):
        b, n, _ = x.shape
        q = self.w_q(x).view(b, n, self.n_q, self.d_k).transpose(1, 2)   # (b, N_q, n, d_k)
        k = self.w_k(x).view(b, n, self.n_kv, self.d_k).transpose(1, 2)  # (b, N_kv, n, d_k)
        v = self.w_v(x).view(b, n, self.n_kv, self.d_k).transpose(1, 2)  # (b, N_kv, n, d_k)
        # Broadcast each K/V head across its group of query heads.
        k = k.repeat_interleave(self.group, dim=1)  # (b, N_q, n, d_k)
        v = v.repeat_interleave(self.group, dim=1)  # (b, N_q, n, d_k)
        # F.scaled_dot_product_attention picks a fused kernel (FlashAttention-style
        # or memory-efficient) when the device and dtype allow it, else a math fallback.
        y = F.scaled_dot_product_attention(q, k, v, is_causal=causal)  # (b, N_q, n, d_k)
        y = y.transpose(1, 2).contiguous().view(b, n, self.n_q * self.d_k)
        return self.w_o(y)


# Llama-3-70B-shaped block: N_q=64, N_kv=8, G=8, d_k=128, d_model=8192.
attn = GroupedQueryAttention(d_model=8192, n_q=64, n_kv=8, d_k=128)
x = torch.randn(2, 16, 8192)
print("output shape:", attn(x).shape)  # torch.Size([2, 16, 8192])
print("KV cache per token per layer (BF16):",
      2 * 8 * 128 * 2, "bytes (=", 2 * 8 * 128 * 2 / 1024, "KiB)")
```

The `repeat_interleave` above materialises broadcast K/V tensors. Production kernels (Flash Attention 3, vLLM's PagedAttention) skip this step entirely and read directly from the smaller N_kv-shaped cache, which is both faster and the only way to realise the full HBM saving in practice.
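To make the "read directly from the N_kv-shaped cache" idea concrete, here is a toy single-token decode step in plain Python. It is only a sketch of the indexing idea, not how Flash Attention 3 or PagedAttention are implemented internally; the function names and the tiny hand-written tensors are illustrative assumptions. `gqa_indexed` looks up each query head's shared K/V head by index, while `gqa_expanded` first copies K/V out to N_q heads the way `repeat_interleave` does.

```python
import math


def attend(q, keys, values):
    """Softmax attention of one query vector over a list of key/value vectors."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    d_v = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(d_v)]


def gqa_indexed(q_heads, k_heads, v_heads):
    """Each query head reads its group's K/V head by index; nothing is copied."""
    g = len(q_heads) // len(k_heads)
    return [attend(q, k_heads[i // g], v_heads[i // g])
            for i, q in enumerate(q_heads)]


def gqa_expanded(q_heads, k_heads, v_heads):
    """Reference: expand K/V to N_q heads first (what repeat_interleave does)."""
    g = len(q_heads) // len(k_heads)
    k_full = [k for k in k_heads for _ in range(g)]
    v_full = [v for v in v_heads for _ in range(g)]
    return [attend(q, k, v) for q, k, v in zip(q_heads, k_full, v_full)]


# Toy decode step: 4 query heads, 2 K/V heads, 3 cached positions, d_k = 2.
q_heads = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
k_heads = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
           [[0.0, 2.0], [2.0, 0.0], [1.0, -1.0]]]
v_heads = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
           [[0.0, 1.0], [1.0, 0.0], [2.0, 2.0]]]

same = gqa_indexed(q_heads, k_heads, v_heads) == gqa_expanded(q_heads, k_heads, v_heads)
print(same)  # True: same arithmetic on the same operands, only the storage differs
```

Both paths compute identical outputs; the indexed version simply never holds more than N_kv heads' worth of K and V, which is the property a serving runtime needs in order to realise the cache saving.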
Variants and architectural choices: MHA, MQA, GQA, and MLA
GQA sits on a spectrum of KV-shrinkage strategies. The choice of N_kv (and equivalently G) is the single architectural lever; the rest of the variants are about what to do beyond GQA when you want the cache smaller still.
Multi-Query Attention (Shazeer 2019, arXiv:1911.02150) is GQA's predecessor and extreme limit: N_kv = 1 (G = N_q). All query heads share a single K/V head. The cache shrinks to its theoretical floor, exactly 64x smaller than MHA for N_q = 64, and decode bandwidth follows. PaLM (2022), Falcon-7B and StarCoder used MQA in production. The quality cost showed at scale, especially on tasks needing fine-grained head specialisation (long-form reasoning, multi-document QA), and training stability suffered without careful learning-rate schedules. Modern teams avoid pure MQA in favour of GQA at N_kv = 8.
Multi-Head Latent Attention (MLA, introduced with DeepSeek-V2, 2024) is the post-GQA frontier. Instead of N_kv K/V heads, MLA projects the residual stream into a single low-rank latent of dimension d_latent (512 in DeepSeek-V2 / V3) that K and V share, caches that latent plus a small decoupled rotary-position key (64 dimensions in DeepSeek-V2), and reconstructs K and V on the fly via learned up-projections at attention time. The per-token cache is num_layers * (d_latent + d_rope) * bytes, with no head dimension and no factor of 2, which is roughly 3.5x smaller than GQA at N_kv = 8 and d_k = 128 (576 cached values per layer against 2,048). DeepSeek-V3 uses MLA on every attention layer; it costs more compute at attention time (the reconstruction matmuls) but the HBM and bandwidth wins are decisive at long context. MLA is the natural successor to GQA, and it is what to evaluate when GQA stops being enough.
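The MLA-versus-GQA cache comparison can be checked with a few lines of arithmetic. This sketch assumes the DeepSeek-V2 description of MLA, in which K and V share one 512-dimensional latent and a 64-dimensional decoupled rotary key is cached alongside it, so the cache holds d_latent + d_rope values per token per layer rather than a separate K and V. The helper names are illustrative.

```python
def gqa_values_per_layer(n_kv: int, d_k: int) -> int:
    """Cached values per token per layer under GQA: K and V for each K/V head."""
    return 2 * n_kv * d_k


def mla_values_per_layer(d_latent: int, d_rope: int) -> int:
    """Cached values per token per layer under MLA: one shared latent plus
    the decoupled rotary key (dimensions assumed from DeepSeek-V2)."""
    return d_latent + d_rope


gqa = gqa_values_per_layer(n_kv=8, d_k=128)          # 2048
mla = mla_values_per_layer(d_latent=512, d_rope=64)  # 576
mha = gqa_values_per_layer(n_kv=64, d_k=128)         # 16384 (MHA is GQA with N_kv = N_q)
print(f"GQA/MLA = {gqa / mla:.2f}, MHA/MLA = {mha / mla:.2f}")
# GQA/MLA = 3.56, MHA/MLA = 28.44
```

The ratios are per token and per layer, so they hold at any context length; what grows with context is the absolute number of bytes saved.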
The configurations that have actually shipped in production decoder LLMs are tabulated below. The pattern is unambiguous: N_kv = 8 is the modern default across the open-weights family, with smaller models (8B and below) sometimes dropping to N_kv = 4 or N_kv = 2 to push memory pressure further.
| Variant | N_kv | Cache shrink vs MHA | Quality vs MHA | Used by |
|---|---|---|---|---|
| MHA (Vaswani 2017) | = N_q | 1x | Baseline | GPT-2/3, BERT, T5, Llama 1 |
| MQA (Shazeer 2019) | 1 | N_q x (64x at N_q = 64) | Slight regression at scale | PaLM, Falcon-7B, StarCoder |
| GQA G=4 | N_q / 4 | 4x | Within noise | Llama 3.1 8B, Mistral 7B, Phi-4 14B |
| GQA G=8 | N_q / 8 | 8x | Within noise | Llama 2 70B, Llama 3.1 70B, Qwen 2.5 72B |
| GQA G=16 | N_q / 16 | 16x | Regression risk rises | Llama 3.1 405B (N_q = 128, N_kv = 8); otherwise rare |
| MLA (DeepSeek-V2/V3) | Latent dim 512 | ~28x (at N_q = 64, d_k = 128) | Equal or better than GQA G=8 | DeepSeek-V2, DeepSeek-V3 |
For the question 'should I pick a GQA model or an MLA model?': GQA wins on simplicity, kernel maturity and ecosystem support; MLA wins on extreme-long-context (>128k) decode economics. Through mid-2026 the answer for typical 8k-32k production chat workloads is GQA; for 128k+ document workloads it is increasingly MLA.
Where it is used today: every modern open-weights decoder LLM
Adoption of GQA across the open-weights family is total. Below is the configuration used by the major frontier models — published in their technical reports, verifiable from the `config.json` of each checkpoint on HuggingFace. The N_kv = 8 convergence is striking.
Llama 2 70B (Touvron et al., 2023) was the first frontier-scale open-weights model to ship GQA: N_q = 64, N_kv = 8, G = 8. The smaller Llama 2 7B / 13B kept MHA, since the cache pressure at 7B is not yet decisive, but Llama 3 (April 2024) made GQA universal across its sizes (N_kv = 8 at 70B; N_kv = 8 at 8B with N_q = 32, giving G = 4). Llama 3.1 (July 2024) retains the same shape, adds the 405B model (N_q = 128, N_kv = 8, G = 16) and extends context to 128k. Llama 4 Scout (2025) uses GQA on its dense attention layers and MoE on its FFN.
Mistral followed the same trajectory. Mistral 7B (Sept 2023) shipped with GQA at N_kv = 8 from launch; Mixtral 8x7B and 8x22B inherit it on the attention layers; Mistral Large 2 (123B, July 2024) uses N_q = 96, N_kv = 8, G = 12. Qwen 2 (June 2024), Qwen 2.5 (Sept 2024) and Qwen 3 (2025) use GQA across the family, with N_kv = 8 on the large dense models such as Qwen 2.5 72B and fewer K/V heads on some of the smaller checkpoints. Gemma 2 (June 2024) uses GQA at N_kv = 16 on the 27B model (N_q = 32, G = 2). Phi-3 (April 2024, in its larger sizes) and Phi-4 (Dec 2024) use GQA; Phi-4 14B has N_q = 40, N_kv = 10, G = 4.
On the closed-frontier side, GPT-4o, Claude 4 and Gemini 2 have not published their attention configurations, but serving-cost economics (the ability to support 200k+ context at competitive prices) strongly imply either GQA or MLA. The architectural lever this entry describes is what makes any of these long-context APIs financially viable.
On the Yobitel side, every model in the Yobibyte marketplace catalogue inherits GQA-aware serving without per-model engineering, because the runtimes Yobibyte routes through (industry-standard inference engines selected per workload) all support GQA natively in their Flash Attention 3 attention kernels. When a Yobitel NeoCloud customer pins a Llama 3.1 70B endpoint to a UK region, the per-GPU concurrency they see on the InferenceBench leaderboard reflects exactly the GQA cache shrink: roughly 8x more concurrent requests per H100 SXM5 than the equivalent MHA workload would have allowed.
| Model | Params | N_q | N_kv | G | Context |
|---|---|---|---|---|---|
| Llama 3.1 8B | 8 B | 32 | 8 | 4 | 128k |
| Llama 3.1 70B | 70 B | 64 | 8 | 8 | 128k |
| Llama 3.1 405B | 405 B | 128 | 8 | 16 | 128k |
| Mistral 7B v0.3 | 7 B | 32 | 8 | 4 | 32k |
| Mistral Large 2 | 123 B | 96 | 8 | 12 | 128k |
| Mixtral 8x22B (attention) | 141 B total | 48 | 8 | 6 | 64k |
| Qwen 2.5 72B | 72 B | 64 | 8 | 8 | 128k |
| Qwen 3 235B (dense attn) | 235 B total | 64 | 4 | 16 | 128k |
| Gemma 2 27B | 27 B | 32 | 16 | 2 | 8k |
| Phi-4 14B | 14 B | 40 | 10 | 4 | 16k |
| DeepSeek-V3 (attention) | 671 B total | MLA | MLA | n/a | 128k |
Trade-offs and known limitations
The standard summary is right: GQA gives you most of MQA's serving win with none of MQA's quality cost. The detail is more nuanced.
The wins are HBM and bandwidth, not FLOPs. Attention compute per token is essentially unchanged — every query head still produces its own attention output, every softmax still runs across full sequence length. The Q @ K^T matmul is the same shape, the softmax @ V matmul is the same shape, the only change is which K/V is read for which Q. So GQA helps decode (which is HBM-bandwidth-bound) much more than it helps prefill (which is compute-bound on long context). Empirically, decode throughput per GPU rises 30-60 % under GQA versus MHA at the same parameter count; prefill throughput is roughly flat. The Yobitel InferenceBench long-context benchmarks show exactly this asymmetry.
Quality is within noise of MHA at G = 4 or G = 8, and the risk of regression grows as K/V heads are shared more aggressively. Ainslie et al.'s 2023 paper ran T5-XXL ablations on summarisation, translation and question-answering benchmarks and showed that 8 K/V heads recovered nearly all of MHA's quality, while 1 K/V head (MQA) scored measurably lower. The intuition: the K/V projections are where attention encodes 'what kind of relationships am I looking for?', and collapsing them too aggressively erases head specialisation. N_kv = 8 is the empirical sweet spot, which works out to G = 4 on a 32-head model, G = 8 on a 64-head model and G = 16 on Llama 3.1 405B; G > 16 is risky.
Training stability is GQA's quiet advantage over MQA. MQA is well-documented to require careful learning-rate warm-up and longer training schedules to converge stably at frontier scale; GQA does not. Most teams that adopted GQA report that no other training change was needed beyond replacing the K/V projection shapes.
The KV cache is still linear in num_layers and seq_len. GQA shrinks the multiplicative N_kv term but does not change the scaling. At very long context (>1M tokens) even GQA caches grow unwieldy: a Llama 3.1 405B-shaped model (126 layers, N_kv = 8, d_k = 128) at 1M tokens would need roughly 0.5 TB of KV cache per request at FP16, which is what motivates Ring Attention (distributing the cache across multiple GPUs) or MLA (compressing K/V into a smaller latent representation). For most production workloads through 2026 (chat at 8k-32k context, document QA at 64k-128k) GQA at N_kv = 8 is more than sufficient.
GQA is compatible with every modern attention optimisation. Flash Attention 3, PagedAttention, prefix caching, chunked prefill, speculative decoding, FP8 inference — all of these compose with GQA without modification. The architecture lever is orthogonal to the kernel optimisations.
Practical implementation notes
Libraries that implement GQA correctly in 2026: PyTorch's `torch.nn.functional.scaled_dot_product_attention` accepts non-matching N_q and N_kv inputs when called with `enable_gqa=True` (PyTorch 2.5+) and handles the head sharing internally, dispatching to a fused attention kernel where one is available; HuggingFace `transformers` exposes `num_attention_heads` and `num_key_value_heads` as separate `config.json` fields on every GQA model; vLLM detects N_kv from the config and sizes its PagedAttention block table accordingly; TensorRT-LLM and SGLang follow the same pattern; the underlying Flash Attention 3 kernel has GQA-aware variants that skip the repeat-interleave entirely.
Common foot-guns when implementing GQA from scratch. Forgetting to broadcast K/V across the group dimension before the matmul produces shape errors that can be hard to trace back to their cause; `F.scaled_dot_product_attention` handles the sharing for you in PyTorch 2.5+ when `enable_gqa=True` is passed, but custom CUDA kernels need explicit broadcast indexing. Using `repeat` instead of `repeat_interleave` tiles the K/V heads as 0, 1, ..., N_kv-1, 0, 1, ... rather than 0, 0, ..., 1, 1, ..., so most query heads silently read the wrong K/V head with no shape error; `repeat_interleave` is the correct primitive. Choosing N_kv that does not divide N_q raises a shape error; the constraint is hard. Setting N_kv too small (G > 16) degrades quality measurably and is hard to recover from once weights are trained.
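The difference between the two expansion orders is easiest to see on head indices alone. The plain-Python sketch below mimics what `repeat_interleave` and `repeat` do along the head dimension; the helper names `interleave_order` and `tile_order` are illustrative, and the 8-query-head, 2-K/V-head shape is a toy configuration.

```python
def interleave_order(n_kv: int, g: int) -> list:
    """K/V head read by each query head after repeat_interleave-style expansion."""
    return [h for h in range(n_kv) for _ in range(g)]


def tile_order(n_kv: int, g: int) -> list:
    """K/V head read by each query head after repeat-style (tiling) expansion."""
    return [h for _ in range(g) for h in range(n_kv)]


n_q, n_kv = 8, 2
g = n_q // n_kv
expected = [i // g for i in range(n_q)]   # the GQA definition, g(i) = floor(i / G)

print(interleave_order(n_kv, g))  # [0, 0, 0, 0, 1, 1, 1, 1]
print(tile_order(n_kv, g))        # [0, 1, 0, 1, 0, 1, 0, 1]
wrong = sum(a != b for a, b in zip(tile_order(n_kv, g), expected))
print(wrong)                      # 4 of the 8 query heads read the wrong K/V head
```

Both expansions produce a tensor of the right shape, which is why the tiling mistake passes every shape check and only shows up as degraded model output.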
Reading a model card: look for `num_key_value_heads` in `config.json` and compare to `num_attention_heads`. If they match, the model is MHA. If `num_key_value_heads = 1`, it is MQA. Otherwise it is GQA with G = num_attention_heads / num_key_value_heads. Llama 3.1 8B has `num_attention_heads: 32, num_key_value_heads: 8`, so G = 4; Llama 3.1 70B has `num_attention_heads: 64, num_key_value_heads: 8`, so G = 8.
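The model-card check can be expressed as a short helper. This is a sketch using only the two field names quoted from `config.json`; the function name `attention_kind` is hypothetical, and treating a missing `num_key_value_heads` field as MHA is an assumption made here for illustration.

```python
import json


def attention_kind(config: dict) -> str:
    """Classify a checkpoint as MHA, MQA or GQA from its head counts."""
    n_q = config["num_attention_heads"]
    # Assumption: a missing num_key_value_heads field means MHA.
    n_kv = config.get("num_key_value_heads", n_q)
    if n_kv == n_q:
        return "MHA"
    if n_kv == 1:
        return "MQA"
    return f"GQA (G = {n_q // n_kv})"


# Inline stand-ins for the two config.json excerpts quoted in the text.
llama_8b = json.loads('{"num_attention_heads": 32, "num_key_value_heads": 8}')
llama_70b = json.loads('{"num_attention_heads": 64, "num_key_value_heads": 8}')
print(attention_kind(llama_8b))   # GQA (G = 4)
print(attention_kind(llama_70b))  # GQA (G = 8)
```

In practice the dictionary would come from `json.load` on the checkpoint's own `config.json` file rather than from an inline string.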
Sizing arithmetic for a planning conversation on Yobitel NeoCloud. The per-token KV cache cost in BF16 is 2 * num_layers * num_key_value_heads * head_dim * 2 bytes. For Llama 3.1 70B (80 layers, 8 KV heads, 128 head_dim, BF16) that is 327,680 bytes, about 328 KB per token. At 32k context per request, 10.4 GB. A single 8x H100 SXM5 node has 640 GB HBM total, roughly 500 GB available after the 140 GB of BF16 weights, which is enough for roughly 45-50 concurrent 32k requests before activations and runtime overhead. With FP8 weights and FP8 KV cache the same node serves roughly 100 concurrent 32k requests. The InferenceBench public numbers for this configuration on Yobitel NeoCloud should land in the same range; if a serving stack is wildly different, the limiting factor is usually attention kernel choice or PagedAttention block fragmentation, not GQA itself.
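The sizing arithmetic can be written down as a small upper-bound estimator. This sketch assumes 140 GB of BF16 weights (70B parameters at 2 bytes) or 70 GB in FP8, takes "32k" as 32,000 tokens, and ignores activations, fragmentation and runtime overhead; the function name `max_concurrent_requests` is hypothetical.

```python
def max_concurrent_requests(node_hbm_gb: float, weight_gb: float,
                            kv_bytes_per_token: int, context_tokens: int) -> int:
    """Requests whose KV cache fits in the HBM left after weights.

    Ignores activations, fragmentation and runtime overhead, so treat the
    result as an upper bound.
    """
    free_bytes = (node_hbm_gb - weight_gb) * 1e9
    per_request = kv_bytes_per_token * context_tokens
    return int(free_bytes // per_request)


# Llama 3.1 70B shape from the text: 80 layers, 8 KV heads, head_dim 128.
bf16_per_token = 2 * 80 * 8 * 128 * 2   # 327,680 bytes
fp8_per_token = 2 * 80 * 8 * 128 * 1    # 163,840 bytes

print(max_concurrent_requests(640, 140, bf16_per_token, 32_000))  # 47
print(max_concurrent_requests(640, 70, fp8_per_token, 32_000))    # 108
```

Real deployments land below these figures because the runtime reserves memory for activations and paged-cache bookkeeping, so the estimate is best used to compare configurations rather than to promise a concurrency number.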
Fine-tuning a GQA model preserves N_kv by default — both LoRA and full-finetune leave the K/V projection shapes untouched and only update the existing weights. Converting an MHA checkpoint to GQA after training is non-trivial: it requires averaging K and V projections within each group, which produces a measurable quality drop, and is generally a 'continue pretraining for a few hundred billion tokens' exercise rather than a one-shot reshape. The Yobibyte FineTune resource exposes GQA models as-is and does not attempt MHA-to-GQA conversion as a managed method.
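The averaging step of an MHA-to-GQA conversion can be sketched in plain Python. This covers only the initialisation (mean-pooling the per-head K or V projection matrices within each group, as described by Ainslie et al.); the continued pretraining that recovers quality afterwards is out of scope. The function name `mean_pool_heads` and the nested-list layout are illustrative assumptions, not a library API.

```python
def mean_pool_heads(head_weights: list, n_kv: int) -> list:
    """Average per-head projection matrices within each group of G heads.

    head_weights: one matrix (list of rows) per original MHA head.
    Returns n_kv matrices, one per GQA K/V head.
    """
    n_q = len(head_weights)
    assert n_q % n_kv == 0, "N_q must be divisible by N_kv"
    g = n_q // n_kv
    pooled = []
    for start in range(0, n_q, g):
        group = head_weights[start:start + g]
        rows, cols = len(group[0]), len(group[0][0])
        pooled.append([[sum(h[r][c] for h in group) / g for c in range(cols)]
                       for r in range(rows)])
    return pooled


# Toy example: 4 MHA key heads with 1x2 projection matrices, pooled to 2 heads.
k_heads = [[[1.0, 2.0]], [[3.0, 4.0]], [[10.0, 20.0]], [[30.0, 40.0]]]
print(mean_pool_heads(k_heads, n_kv=2))
# [[[2.0, 3.0]], [[20.0, 30.0]]]
```

The same pooling would be applied to the V projections; the query and output projections keep their original shapes.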
If you are debugging unexpectedly low decode throughput on a GQA model, first confirm the runtime is actually exploiting the smaller K/V cache. A common bug: a serving stack that broadcasts K/V to full N_q early and stores the broadcast version in HBM. The cache shrinks only if the runtime keeps K/V at N_kv shape and broadcasts at kernel time. Flash Attention 3, vLLM 0.6+ and TensorRT-LLM all do this correctly; older custom kernels often do not.
Where GQA fits in the Yobitel stack
Every decoder LLM in the Yobibyte managed-platform catalogue uses GQA. The marketplace exposes Llama 3.1, Llama 4, Qwen 3, Mistral Large, Phi-4 and the rest as named models; Yobibyte routes inference through industry-standard runtimes that exploit the GQA shape transparently. Customers do not configure N_kv or capacity factors — they pick a model and a region, and the long-context economics in this entry are what they see in their bill.
Yobitel NeoCloud — the underlying H100 SXM5 / H200 / B200 fleet — sizes per-GPU concurrency for inference workloads using the per-token cache arithmetic above. The GQA shrink factor is what makes a 70B model viable on a single 8x H100 node at 128k context, and what gives the published NeoCloud price-per-million-tokens its competitive shape against US hyperscaler API alternatives.
Yobitel InferenceBench publishes per-runtime, per-GPU GQA serving numbers for the major open-weights models — tokens-per-second, time-to-first-token, p99 latency, and cost-per-million-tokens. For a team choosing between GQA models (Llama 3.1 70B, Qwen 2.5 72B, Mistral Large 2) or between GQA and MLA models (the Qwen / Llama family vs DeepSeek-V3), InferenceBench is the empirical complement to the architectural reasoning here.
References
- GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints (Ainslie et al., 2023) · arXiv
- Fast Transformer Decoding: One Write-Head Is All You Need (Shazeer, 2019) · arXiv
- Llama 2: Open Foundation and Fine-Tuned Chat Models (Touvron et al., 2023) · arXiv
- Llama 3 Technical Report (Meta, 2024) · arXiv
- DeepSeek-V2 Technical Report (introducing MLA, 2024) · arXiv
- FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision · arXiv