TL;DR
- Third-generation Flash Attention from Tri Dao (Princeton / Together AI) with collaborators at Colfax Research, Meta and NVIDIA, released July 2024 (Shah et al., arXiv:2407.08608). BSD-3-Clause licensed; hosted at github.com/Dao-AILab/flash-attention.
- Targets the Hopper architecture specifically — uses Tensor Memory Accelerator (TMA), warp-specialised producer/consumer pipelining, ping-pong scheduling of the two matmuls around softmax, and FP8 quantisation with incoherent (Hadamard) preprocessing to reach ~75 percent of theoretical peak on H100.
- Roughly 1.5-2x faster than Flash Attention 2 on H100 for typical training and inference-prefill attention shapes; the FP8 variant adds a further ~1.8-1.9x at near-BF16 quality for standard LLM benchmarks.
- The reference attention kernel for modern LLM training and inference stacks: Megatron-LM, NeMo, DeepSpeed, FSDP-based trainers, vLLM, TensorRT-LLM, SGLang and PyTorch SDPA can all run flash-attention kernels, using FA3 on Hopper and FA2 on Ampere. Some select it automatically and some need an explicit flag (see "Where it shows up in production stacks").
- Near drop-in replacement for `torch.nn.functional.scaled_dot_product_attention`: the arguments are analogous, but tensors use the (B, S, H, D) layout rather than SDPA's (B, H, S, D). 3-5x speedup typical for long-context attention on H100.
Overview
The Flash Attention series (Dao et al., 2022, arXiv:2205.14135; Dao, 2023, arXiv:2307.08691; Shah et al., 2024, arXiv:2407.08608) replaced naive attention, which materialises O(N^2) intermediate matrices in HBM, with a tiled, streaming, online-softmax kernel that needs only O(N) additional memory and is designed around the GPU memory hierarchy. The arithmetic is still O(N^2); the saving is in memory footprint and memory traffic. The series defined what a production attention kernel looks like for transformers, and the FA3 paper is the current state of the art for Hopper-class hardware.
Flash Attention 3 is the Hopper-specific evolution. FA1 and FA2 were Ampere kernels; FA3 exploits hardware features unique to Hopper: the Tensor Memory Accelerator (TMA) for asynchronous bulk loads from HBM to shared memory, warp-specialised producer-consumer scheduling that maps cleanly to Hopper's 4 warp-schedulers-per-SM design, WGMMA (warp-group matrix-multiply-accumulate) for whole-warp-group tensor-core ops, and native FP8 tensor-core math with E4M3 / E5M2 formats.
Functionally, FA3 is shipped as a Python package (`pip install flash-attn`) with a small surface area: the standalone `flash_attn_func`, `flash_attn_qkvpacked_func`, `flash_attn_varlen_func` and `flash_attn_varlen_qkvpacked_func` entry points, plus `flash_attn_with_kvcache` for inference. Every major training and inference framework links against this same library; the framework-level flag (`use_flash_attn=true`, `attn_implementation='flash_attention_2'` in HuggingFace) selects it. Yobitel NeoCloud customers training frontier models on H100, H200, and B200 silicon use FA3 by default; it ships pre-installed and pre-dispatched in the standard NGC-derived training and inference containers.
This entry documents the production surface: the API call signatures, the supported dtypes and masks, the kernel-dispatch behaviour by GPU generation, the throughput characteristics at common shapes, and the drop-in migration from PyTorch SDPA or FA2. Use it to choose and operate FA3 for training pods on Yobitel NeoCloud or your own multi-GPU cluster.
Quick start
FA3 is a drop-in replacement for any attention computation in PyTorch. The example below installs the kernel, runs a microbenchmark comparing FA3 against PyTorch SDPA at three common LLM attention shapes on an H100, and shows how to plumb it into a HuggingFace model and a custom training loop.
```python
# 1. Install: pre-built wheels exist for CUDA 12.x + PyTorch 2.3+
#    pip install flash-attn --no-build-isolation   (Hopper requires the v3 wheel)

# 2. Microbenchmark FA3 vs PyTorch SDPA on H100
import torch
import torch.nn.functional as F

# Which kernel generation runs depends on the installed build. Upstream's FA3 beta
# is built from the repo's hopper/ directory and imported as `flash_attn_interface`;
# check which module your container actually provides.
from flash_attn import flash_attn_func


def time_ms(fn, warmup=3, iters=50):
    """Average GPU time per call in milliseconds, measured with CUDA events."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


def bench(batch, seq, heads, head_dim, causal=True, dtype=torch.bfloat16):
    # flash-attn layout: (batch, seq, heads, head_dim)
    q = torch.randn(batch, seq, heads, head_dim, device='cuda', dtype=dtype)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    fa_ms = time_ms(lambda: flash_attn_func(q, k, v, causal=causal))

    # Same workload via torch SDPA (auto-dispatches to FA2 or memory-efficient).
    # SDPA expects (batch, heads, seq, head_dim), hence the transpose.
    q2, k2, v2 = (t.transpose(1, 2) for t in (q, k, v))
    sdpa_ms = time_ms(lambda: F.scaled_dot_product_attention(q2, k2, v2, is_causal=causal))
    print(f'seq={seq:>6} FA3={fa_ms:6.2f} ms SDPA={sdpa_ms:6.2f} ms speedup={sdpa_ms/fa_ms:.2f}x')


for seq in (2048, 8192, 32768):
    bench(batch=4, seq=seq, heads=32, head_dim=128)

# 3. Plumb into a HuggingFace model: single flag
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    'meta-llama/Meta-Llama-3.1-8B-Instruct',
    torch_dtype=torch.bfloat16,
    attn_implementation='flash_attention_2',  # selects the installed flash-attn kernels
    device_map='cuda')

# 4. Variable-length packed sequences (for batched fine-tuning)
from flash_attn import flash_attn_varlen_func

heads, head_dim = 32, 128
# Three sequences of 512, 768 and 768 tokens packed back to back (2048 tokens in total)
cu_seqlens = torch.tensor([0, 512, 1280, 2048], device='cuda', dtype=torch.int32)
max_seqlen = 768  # longest individual sequence
q = torch.randn(2048, heads, head_dim, device='cuda', dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = flash_attn_varlen_func(q, k, v,
                             cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
                             causal=True)
```
On Hopper, prefer the `flash_attn_func` direct call over PyTorch SDPA when you need the absolute fastest path: SDPA's dispatcher chooses a kernel per call and has overhead. For training loops where the shape is fixed, the direct call wins 5-10 percent.
How it works
Naive attention computes S = QK^T / sqrt(d), P = softmax(S), O = PV, materialising the N x N matrices S and P in HBM. For sequence length N = 8192, each matrix is 128 MiB per head in BF16 (256 MiB for S and P together). That makes the computation IO-bound on every modern GPU and is the reason naive attention runs at 5-10 percent of peak FLOPS. The Flash Attention family eliminates this materialisation entirely by tiling the computation and computing the softmax statistics online.
FA1 (2022) introduced the streaming-softmax tiled kernel. FA2 (2023) reworked parallelisation to split over the sequence dimension as well as heads, reducing non-matmul FLOPs and lifting A100 utilisation to ~50-72 percent. FA3 (2024) is the Hopper rewrite that pushes utilisation to ~75 percent of peak BF16 and ~67-68 percent of peak FP8 (see "Sizing and performance"), with three distinct architectural moves.
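As a rough check on the memory figure above, the size of one N x N matrix is simple arithmetic. The sketch below assumes 2-byte (BF16) elements and a single head; `score_matrix_bytes` is an illustrative name, not part of any library.

```python
MIB = 2 ** 20


def score_matrix_bytes(seq_len, bytes_per_element=2):
    """Bytes needed to materialise one seq_len x seq_len matrix (S or P) for one head."""
    return seq_len * seq_len * bytes_per_element


for n in (2048, 8192, 32768):
    one = score_matrix_bytes(n)
    print(f"N={n:>6}: {one / MIB:8.0f} MiB per matrix, {2 * one / MIB:8.0f} MiB for S and P")
```

The footprint grows quadratically: going from N = 8192 to N = 32768 multiplies it by 16, and the total scales again with batch size and head count.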
First, warp specialisation. Hopper has 4 warp schedulers per SM. FA3 partitions warps into producers (responsible for issuing TMA loads of the next K, V tile from HBM into shared memory) and consumers (responsible for the WGMMA tensor-core math). Producer warps spend most of their time waiting on memory transactions; consumer warps spend most of their time issuing WGMMA. This decoupling means the math warps never stall on address arithmetic or memory-load latency.
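The producer/consumer split can be illustrated with a CPU analogy. This is only a sketch of the scheduling idea, not of the kernel: Python threads stand in for warps, a bounded `queue.Queue` stands in for a multi-stage shared-memory buffer, and the per-tile "math" is a placeholder. The `run_pipeline` name and the two-stage depth are assumptions made for the example.

```python
import queue
import threading


def run_pipeline(tiles, stages=2):
    """Producer/consumer over a bounded buffer; returns one result per tile, in order."""
    buffer = queue.Queue(maxsize=stages)  # stands in for a multi-stage shared-memory buffer
    results = []

    def producer():
        for tile in tiles:
            buffer.put(tile)  # blocks while every stage is still occupied
        buffer.put(None)      # sentinel: no more tiles

    def consumer():
        while True:
            tile = buffer.get()
            if tile is None:
                break
            results.append(sum(x * x for x in tile))  # placeholder for the matmul work

    workers = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
    return results


print(run_pipeline([[1, 2], [3, 4], [5]]))  # [5, 25, 25]
```

The consumer never issues a load itself; it only waits when the buffer is empty, which is the property the paragraph above attributes to the math warps.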
Second, ping-pong scheduling. Attention has two main matmuls per tile (S = Q x K^T and O = P x V) separated by an online softmax. FA3 schedules them so that the tensor cores fill with one matmul while the softmax of the previous matmul completes, overlapping softmax with matmul. FA2 did not do this: the overlap depends on Hopper's asynchronous WGMMA instructions, which Ampere lacks.
Third, FP8 with incoherent processing. Naive FP8 attention loses accuracy because outlier entries in Q and K inflate the quantisation scale, so the remaining values are rounded coarsely within the limited range of E4M3 / E5M2. FA3 applies a Hadamard transform with random signs (an orthogonal rotation) to Q and K before FP8 quantisation, spreading outliers across the block; because the rotation is orthogonal, QK^T is unchanged. This brings FP8 attention much closer to BF16 quality at FP8 throughput and is the key enabler of FP8 attention in production training and inference.
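Two properties of the rotation described above are easy to check numerically: it leaves every q·k score unchanged, and it spreads a single large outlier across all coordinates. The sketch below is plain Python with illustrative names (`fwht`, `rotate`); it assumes a power-of-two head dimension and random ±1 signs, and it is not the kernel's implementation.

```python
import math
import random


def fwht(vec):
    """Normalised fast Walsh-Hadamard transform; len(vec) must be a power of two."""
    out = list(vec)
    n = len(out)
    assert n > 0 and n & (n - 1) == 0, "length must be a power of two"
    h = 1
    while h < n:
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = out[i], out[i + h]
                out[i], out[i + h] = a + b, a - b
        h *= 2
    scale = 1.0 / math.sqrt(n)  # makes the transform orthogonal
    return [x * scale for x in out]


def rotate(vec, signs):
    """Apply random sign flips, then the Hadamard transform (an orthogonal map)."""
    return fwht([s * x for s, x in zip(signs, vec)])


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


rng = random.Random(0)
d = 64
signs = [rng.choice((-1.0, 1.0)) for _ in range(d)]
q = [rng.gauss(0.0, 1.0) for _ in range(d)]
k = [rng.gauss(0.0, 1.0) for _ in range(d)]
q[3] = 50.0  # one outlier feature

q_rot, k_rot = rotate(q, signs), rotate(k, signs)
print("score preserved:", abs(dot(q, k) - dot(q_rot, k_rot)) < 1e-9)
print("largest |q| before / after:", max(map(abs, q)), max(map(abs, q_rot)))
```

With a smaller maximum magnitude, a per-tensor or per-block FP8 scale wastes less of its range on the outlier, which is the point of the preprocessing step.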
- Tensor Memory Accelerator (TMA): asynchronous bulk loads with descriptor-based addressing.
- WGMMA: warp-group matrix multiply-accumulate, a Hopper-only asynchronous tensor-core op issued by a warp group (4 warps) on m64 x nN x k16 tiles for 16-bit inputs, for example 64x64x16.
- Warp specialisation: separate producer and consumer warps for memory and math.
- Ping-pong scheduling: overlap softmax of tile i-1 with matmul of tile i.
- FP8 with Hadamard preprocessing: BF16-quality at FP8 throughput.
- Online softmax: streaming maximum and partial-sum accumulation; same result as exact softmax up to floating-point rounding.
- Causal mask short-circuit: skip empty upper-triangle tiles entirely.
- Sliding-window mask and ALiBi: supported with first-class fast paths.
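The online-softmax item in the list above is what makes tiling possible, so it is worth seeing in isolation. The following is a minimal single-query sketch in plain Python: it processes keys and values one block at a time, keeping only a running maximum, a running normaliser and an unnormalised output accumulator. Function names and the block size are illustrative; the real kernel does the same bookkeeping per tile on the GPU.

```python
import math
import random


def _score(q, k):
    return sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q))


def attention_exact(q, keys, values):
    """Reference: materialise all scores, then softmax, then weighted sum."""
    scores = [_score(q, k) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[c] for w, v in zip(weights, values)) / total
            for c in range(len(values[0]))]


def attention_streaming(q, keys, values, block=4):
    """Same result, but only one block of scores exists at any time."""
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block):
        k_blk = keys[start:start + block]
        v_blk = values[start:start + block]
        scores = [_score(q, k) for k in k_blk]
        new_max = max(running_max, max(scores))
        correction = math.exp(running_max - new_max)  # rescales earlier blocks; 0.0 on the first
        probs = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * correction + sum(probs)
        acc = [a * correction + sum(p * v[c] for p, v in zip(probs, v_blk))
               for c, a in enumerate(acc)]
        running_max = new_max
    return [a / running_sum for a in acc]


rng = random.Random(0)
q = [rng.gauss(0, 1) for _ in range(8)]
keys = [[rng.gauss(0, 1) for _ in range(8)] for _ in range(19)]
values = [[rng.gauss(0, 1) for _ in range(8)] for _ in range(19)]
diff = max(abs(a - b) for a, b in zip(attention_exact(q, keys, values),
                                      attention_streaming(q, keys, values)))
print("max abs difference:", diff)
```

The difference printed is at the level of floating-point rounding; the per-query state is O(1) in the sequence length, which is where the O(N) total memory comes from.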
FA3's reported ~75 percent of theoretical peak is for favourable shapes (long sequences >= 4K, large head counts, no exotic masks). Real training-step utilisation is somewhat lower because attention is one block among many, but the kernel is no longer the dominant bottleneck. On a typical Llama-70B training step, attention drops from ~40 percent of step time under FA2 to roughly 25-30 percent under FA3, which is what a 1.5-2x kernel speedup implies when the rest of the step is unchanged.
Reference and specifications
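Step-time shares like the ones above can be sanity-checked with Amdahl-style arithmetic: speed up only the attention part and renormalise. The helper below is an illustrative sketch (the function name is made up for this example) and assumes everything outside attention takes the same time before and after.

```python
def attention_share_after_speedup(share_before, kernel_speedup):
    """Fraction of step time spent in attention after the kernel gets faster."""
    attention = share_before / kernel_speedup
    everything_else = 1.0 - share_before
    return attention / (attention + everything_else)


for speedup in (1.5, 2.0):
    share = attention_share_after_speedup(0.40, speedup)
    step_gain = 1.0 / (0.40 / speedup + 0.60)
    print(f"kernel {speedup:.1f}x -> attention share {share:.1%}, whole step {step_gain:.2f}x faster")
```

Starting from a 40 percent share, a 1.5x kernel speedup leaves attention at about 31 percent of the step and a 2x speedup at 25 percent; the end-to-end step gets about 1.15x and 1.25x faster respectively.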
The library exposes a small set of entry points, each with a stable call signature. The table below lists the supported functions, their tensor-shape contracts, and the dtype and mask combinations they accept as of flash-attn 2.7 (June 2026). The `flash_attn_*_func` entry points return the output tensor by default, or `(out, softmax_lse, S_dmask)` when `return_attn_probs=True`.
- Head dimension (D): 32, 64, 96, 128, 160, 192, 224, 256 supported; non-power-of-2 dims (e.g. 96, 192) supported but ~10-15 percent slower than powers of 2.
- GQA / MQA: supported via `flash_attn_kvpacked_func` or by passing q with heads_q and k, v with heads_kv = heads_q / group_size.
- Causal mask (`causal=True`): standard left-to-right masking; FA3 skips the empty upper-triangle tiles.
- Sliding window (`window_size=(L, R)`): left-context L and right-context R limits; useful for Mistral-style local attention.
- ALiBi: bias slopes passed via `alibi_slopes`.
- Softcap (`softcap=30.0`): tanh-scaled logit clamping (Gemma 2 style).
- Dropout: training-only dropout fused into the kernel via `dropout_p`.
- FP8 (Hopper only): `flash_attn_func` with FP8 inputs requires the v3 wheel and Transformer Engine for end-to-end FP8 training.
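To make the mask options in the list above concrete, the sketch below spells out which key positions a query may attend to under `causal` and `window_size`, and how many tiles a causal kernel can skip. It assumes equal query and key lengths and follows the upstream convention that -1 means "unbounded"; the helper names and the 128-token block size are illustrative choices, not library API.

```python
import math


def visible_keys(i, seq_len, window_size=(-1, -1), causal=False):
    """Key positions query i may attend to (query and key lengths assumed equal)."""
    left, right = window_size
    if causal:
        right = 0  # causal masking is a window with no right context
    lo = 0 if left < 0 else max(0, i - left)
    hi = seq_len - 1 if right < 0 else min(seq_len - 1, i + right)
    return range(lo, hi + 1)


def causal_tile_counts(seq_len, block=128):
    """(tiles a causal kernel must visit, tiles in the full score matrix)."""
    tiles_per_side = math.ceil(seq_len / block)
    return tiles_per_side * (tiles_per_side + 1) // 2, tiles_per_side ** 2


print(list(visible_keys(5, 10, causal=True)))          # [0, 1, 2, 3, 4, 5]
print(list(visible_keys(5, 10, window_size=(2, 0))))   # [3, 4, 5]
print(causal_tile_counts(8192))                        # (2080, 4096)
```

At 8K tokens roughly half of the tiles lie entirely above the diagonal, which is why skipping them matters as much as it does.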
| Function | Shape | Supported dtypes | Supported features |
|---|---|---|---|
| flash_attn_func(q, k, v, ...) | (B, S, H, D) | BF16, FP16, FP8 (Hopper) | Causal, sliding-window, ALiBi, softcap. |
| flash_attn_qkvpacked_func(qkv, ...) | (B, S, 3, H, D) | BF16, FP16, FP8 | As above; saves one transpose. |
| flash_attn_kvpacked_func(q, kv, ...) | (B, S, H, D), (B, S, 2, H, D) | BF16, FP16, FP8 | GQA/MQA-friendly when heads_kv < heads_q. |
| flash_attn_varlen_func(q, k, v, cu_q, cu_k, ...) | (total_tokens, H, D) | BF16, FP16, FP8 | Packed-sequence training/inference. |
| flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, ...) | (total_tokens, 3, H, D) | BF16, FP16, FP8 | Packed + packed-qkv combination. |
| flash_attn_with_kvcache(q, k_cache, v_cache, ...) | (B, 1, H, D), (B, Smax, H, D) | BF16, FP16, FP8 | Inference decode step with paged or contiguous KV cache. |
| torch SDPA (auto-dispatch) | (B, H, S, D) | BF16, FP16 | Dispatcher chooses FA3 on Hopper, FA2 on Ampere, mem-efficient elsewhere. |
FA3's FP8 path requires Q, K and V all to be in FP8 (E4M3), each with its scaling tensor. Mixing BF16 inputs with FP8 attention is not supported; the calling framework (Transformer Engine, vLLM, TensorRT-LLM) handles the quantisation step before the kernel call.
Workload patterns
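For intuition about what a scaling tensor does, here is a simplified per-tensor quantise/dequantise round trip in plain Python. It emulates an E4M3-like grid (3 mantissa bits, largest finite value 448, smallest normal exponent -6) and ignores details such as NaN encoding; all names are illustrative, and frameworks such as Transformer Engine use hardware conversions and may choose scales differently (per block, delayed scaling, and so on).

```python
import math

E4M3_MAX = 448.0    # largest finite E4M3 value
E4M3_MIN_EXP = -6   # exponent of the smallest normal value
MANTISSA_BITS = 3


def round_to_e4m3(x):
    """Round to the nearest value on a simplified E4M3 grid, saturating at +/-448."""
    if x == 0.0:
        return 0.0
    mag = min(abs(x), E4M3_MAX)
    # floor(log2(mag)), clamped so that tiny values use the subnormal spacing
    exponent = max(math.frexp(mag)[1] - 1, E4M3_MIN_EXP)
    step = 2.0 ** (exponent - MANTISSA_BITS)
    return math.copysign(min(round(mag / step) * step, E4M3_MAX), x)


def quantise_per_tensor(values):
    """Return (quantised values, scale) so that values ~= quantised * scale."""
    amax = max(abs(v) for v in values)
    scale = amax / E4M3_MAX if amax > 0 else 1.0
    return [round_to_e4m3(v / scale) for v in values], scale


def dequantise(quantised, scale):
    return [v * scale for v in quantised]


values = [0.013, -0.4, 1.7, 3.2, -25.0]
quantised, scale = quantise_per_tensor(values)
for original, restored in zip(values, dequantise(quantised, scale)):
    print(f"{original:8.3f} -> {restored:8.3f}")
```

Because the scale is set by the largest magnitude, a single outlier pushes small entries toward the coarse end of the grid, which is the failure mode the Hadamard preprocessing is meant to reduce.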
Three workload shapes cover most FA3 production usage: long-context training where the attention block dominates step time, inference prefill of long prompts, and inference decode with paged KV cache. Each maps to a different FA3 entry point and configuration, and each is exercised daily on Yobitel NeoCloud — Pattern A on multi-node H100 training pods running Megatron-LM with --use-flash-attn and sequence-parallel, Pattern B and C on the H100 / H200 inference fleets behind Yobibyte's vLLM and TensorRT-LLM endpoints.
- Pattern A: long-context training (32K-256K tokens) on a 32-node NeoCloud H100 training pod, where FA3 is the difference between OOM and 'fits comfortably'.
- Pattern B: inference prefill speedup; a 30K-token RAG prompt that took 8 seconds under SDPA takes 2 seconds under FA3 on a NeoCloud H100 inference node.
- Pattern C: inference decode with the `flash_attn_with_kvcache` entry point; integrates with vLLM and TensorRT-LLM paged attention, the default backends on Yobibyte-managed inference endpoints.
```python
import torch

# A. Long-context training: 128K context with FA3 + sequence parallel
# Megatron-LM with --use-flash-attn --sequence-parallel automatically uses FA3 on Hopper.
# Direct PyTorch usage:
from flash_attn import flash_attn_func

seq = 131072; batch = 1; heads = 32; head_dim = 128
q = torch.randn(batch, seq, heads, head_dim, device='cuda', dtype=torch.bfloat16)
k = torch.randn(batch, seq, heads, head_dim, device='cuda', dtype=torch.bfloat16)
v = torch.randn(batch, seq, heads, head_dim, device='cuda', dtype=torch.bfloat16)
out = flash_attn_func(q, k, v, causal=True)
# Peak HBM: ~5 GB (q, k, v and the output are 1 GiB each). Materialising S and P
# in BF16 at this length would need 64 GiB per head (2 x 131072^2 x 2 bytes).

# B. Inference prefill speedup (vLLM does this under the hood)
from flash_attn import flash_attn_varlen_func

# Pack three prompts of 8,192, 16,384 and 29,696 tokens into one tensor
cu_seqlens = torch.tensor([0, 8192, 24576, 54272], device='cuda', dtype=torch.int32)
max_seqlen = 29696  # longest prompt: 54272 - 24576
q = torch.randn(54272, heads, head_dim, device='cuda', dtype=torch.bfloat16)
k = torch.randn(54272, heads, head_dim, device='cuda', dtype=torch.bfloat16)
v = torch.randn(54272, heads, head_dim, device='cuda', dtype=torch.bfloat16)
out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal=True)

# C. Inference decode with KV cache (FA3 path; vLLM and TRT-LLM use this internally)
from flash_attn import flash_attn_with_kvcache

batch = 8; max_seq = 32768; heads_kv = 8  # GQA: 32 query heads share 8 KV heads
q_new = torch.randn(batch, 1, heads, head_dim, device='cuda', dtype=torch.bfloat16)
k_cache = torch.randn(batch, max_seq, heads_kv, head_dim, device='cuda', dtype=torch.bfloat16)
v_cache = torch.randn(batch, max_seq, heads_kv, head_dim, device='cuda', dtype=torch.bfloat16)
# Number of valid cached tokens per sequence in the batch
cache_lens = torch.tensor([12000, 8000, 30000, 5000, 16000, 24000, 18000, 9000],
                          device='cuda', dtype=torch.int32)
out = flash_attn_with_kvcache(q_new, k_cache, v_cache, cache_seqlens=cache_lens, causal=True)
```
Sizing and performance
The table below shows representative throughput for FA3 on H100 SXM5 at common LLM attention shapes (BF16, causal, head_dim 128, batch chosen to saturate the GPU). The comparison columns show the same workload under FA2 and naive PyTorch attention (where it fits). All numbers are in TFLOPS measured (not theoretical) at the attention operation only, averaged over 100 iterations.
- H100 SXM5 BF16 tensor-core peak is 989 TFLOPS; FA3 hits ~75 percent at long sequences (740 / 989).
- H100 FP8 tensor-core peak is 1979 TFLOPS; FA3 FP8 hits ~67-68 percent.
- Head_dim 64 is ~10 percent slower than 128 (more iterations per token); head_dim 256 is ~5 percent slower (register pressure).
- Sliding window adds ~3-5 percent overhead vs full causal at the same length.
- Varlen packed sequences add ~5 percent overhead vs uniform-length, but win on batch utilisation.
- On H200, throughput rises ~10-15 percent versus H100 because HBM bandwidth lifts memory-bound regions.
- On B200, FA3 is superseded by a Blackwell-tuned variant in the same upstream repo; expect another 1.5-2x on BF16 and 2-2.5x on FP4/FP8.
| Sequence length | Heads | Batch | FA3 BF16 (TFLOPS) | FA3 FP8 (TFLOPS) | FA2 BF16 (TFLOPS) | Naive PyTorch attention (TFLOPS) |
|---|---|---|---|---|---|---|
| 2,048 | 32 | 16 | 550 | 1050 | 320 | 85 |
| 4,096 | 32 | 8 | 620 | 1180 | 380 | 55 |
| 8,192 | 32 | 4 | 680 | 1260 | 440 | 32 |
| 16,384 | 32 | 2 | 720 | 1320 | 470 | OOM |
| 32,768 | 32 | 1 | 740 | 1340 | 485 | OOM |
| 65,536 | 32 | 1 | 740 | 1340 | 490 | OOM |
| 131,072 | 32 | 1 | 735 | 1330 | 470 | OOM |
| 256,000 | 32 | 1 | 720 | 1300 | 440 | OOM |
| 8,192 | 64 (GQA-8) | 4 | 660 | 1240 | 420 | OOM |
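TFLOPS figures like those in the table are derived from a FLOP count and a measured time. The sketch below uses the counting convention of the upstream flash-attention benchmark scripts (two matmuls of 2·S²·D FLOPs per head, halved for causal masking); the helper names are illustrative, and the 989 TFLOPS peak is the H100 SXM5 BF16 figure quoted above.

```python
def attention_forward_flops(batch, heads, seq_len, head_dim, causal=True):
    """FLOPs for one attention forward pass (QK^T and PV matmuls only)."""
    flops = 4 * batch * heads * seq_len ** 2 * head_dim
    return flops // 2 if causal else flops


def achieved_tflops(flops, milliseconds):
    """Convert a FLOP count and a measured time into TFLOPS."""
    return flops / (milliseconds * 1e-3) / 1e12


def utilisation(tflops, peak_tflops=989.0):
    """Fraction of the tensor-core peak that was reached."""
    return tflops / peak_tflops


flops = attention_forward_flops(batch=1, heads=32, seq_len=32768, head_dim=128)
ms_at_740 = flops / 740e12 * 1e3
print(f"{flops:.3e} FLOPs, {ms_at_740:.2f} ms at 740 TFLOPS, "
      f"utilisation {utilisation(740):.1%}")
```

Feeding a measured per-call time from a microbenchmark into `achieved_tflops` gives a number that can be compared directly with the table rows.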
Limits and quotas
FA3's hard limits come from the kernel's shape contracts and the GPU's resource caps. The table below documents the practical envelope as of flash-attn 2.7.
| Constraint | Value | Notes |
|---|---|---|
| head_dim supported | 32, 64, 96, 128, 160, 192, 224, 256 | Other dims fall back to a slower kernel. |
| Max sequence length | Hardware-bounded (HBM) | 256K demonstrated; only practical with sequence parallelism (SP) or context parallelism (CP). |
| FP8 support | Hopper only (sm_90+) | Ampere falls back to FA2 BF16/FP16. |
| Architectures | sm_80 (Ampere), sm_90 (Hopper) | sm_100 (Blackwell) covered by a separate kernel path. |
| CUDA version | >= 12.3 | Pre-built wheels for 12.3, 12.4, 12.5, 12.6. |
| PyTorch version | >= 2.3 | Earlier versions may work but not officially supported. |
| Causal + sliding-window combo | Supported | Some less-common combinations fall back to a generic path. |
| Custom additive bias | Limited (ALiBi only) | Arbitrary biases must use a non-FA kernel. |
| Backward pass FP8 | Not yet upstream | FP8 forward + BF16 backward is the supported pattern. |
| Dropout | Supported in fwd+bwd | Fused into the kernel; deterministic with seed. |
Where it shows up in production stacks
Every modern LLM training and inference stack links against flash-attn or vendors equivalent kernels. The integrations below are the ones to know operationally — knowing which stack uses FA3 by default versus which needs an explicit flag is the difference between getting the 1.5-2x uplift and not.
| Stack | FA3 usage | How it is enabled |
|---|---|---|
| Megatron-LM | Native | --use-flash-attn flag; FA3 selected automatically on Hopper. |
| NeMo Framework | Native | model.use_flash_attn=True in Hydra config. |
| DeepSpeed (HF Trainer) | Via HF | attn_implementation='flash_attention_2' on model load. |
| FSDP / torchtune | Via PyTorch SDPA | Automatic on Hopper if flash-attn installed. |
| HuggingFace Transformers | Native flag | AutoModel.from_pretrained(..., attn_implementation='flash_attention_2'). |
| vLLM | Native | Default backend on Hopper; selectable via VLLM_ATTENTION_BACKEND=FLASH_ATTN, with VLLM_FLASH_ATTN_VERSION=3 to force the FA3 kernels in releases that expose these variables. |
| TensorRT-LLM | Engine plugin | Compiled into the engine at build time when targeting sm_90. |
| SGLang | Native | Default backend on Hopper; uses the varlen API for batched prefill. |
| JAX / Pallas | Separate kernel | Equivalent kernel under the JAX-Triton kernels for TPU and GPU. |
If you have installed flash-attn and your throughput has not improved, check the framework's actual dispatch. HuggingFace does not use flash-attn unless attn_implementation is set when the model is loaded, and otherwise uses its default attention path without warning; vLLM uses FlashInfer rather than flash-attn if VLLM_ATTENTION_BACKEND is set to FLASHINFER; PyTorch SDPA may pick the memory-efficient backend on unsupported shapes. Print the backend after warm-up.
Migration and alternatives
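Before digging into framework-level dispatch, a quick first check is whether the kernel packages are importable at all and whether a backend override is set in the environment. The helper below is a hypothetical stdlib-only sketch: the module names `flash_attn` and `flash_attn_interface` reflect how the upstream repo packages its FA2 and Hopper builds at the time of writing, and your container may package them differently.

```python
import importlib.util
import os


def attention_environment():
    """Report which flash-attention modules are importable and any vLLM backend override."""
    report = {}
    for module in ("flash_attn", "flash_attn_interface"):
        # find_spec returns None when a top-level module is not installed
        report[module] = importlib.util.find_spec(module) is not None
    report["VLLM_ATTENTION_BACKEND"] = os.environ.get("VLLM_ATTENTION_BACKEND")
    return report


for name, value in attention_environment().items():
    print(f"{name}: {value}")
```

This only tells you what is installed and configured; it does not prove which kernel a given call actually ran, so it complements the framework's own backend logging rather than replacing it.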
Three migration paths dominate: from PyTorch SDPA (the most common starting point), from FA2 (Ampere -> Hopper hardware upgrade), and from a custom or naive attention implementation. None is intrusive; FA3 is a library import or a one-line framework flag, not a refactor.
| From | Effort | Throughput delta | Caveats |
|---|---|---|---|
| torch.nn.functional.scaled_dot_product_attention | One-line import | 3-5x at long context on H100 | SDPA dispatcher already chooses FA when available; explicit call removes dispatch overhead. |
| FlashAttention 2 | Import change only | 1.5-2x on H100 | On Ampere, FA2 is fine; FA3 is a no-op there (falls back). |
| Naive PyTorch attention (matmul + softmax + matmul) | Import flash_attn_func; reshape tensors | 10-50x at long context | Must transpose to (B, S, H, D) layout. |
| xFormers memory_efficient_attention | Direct call swap | 1.2-1.5x typical | xFormers has its own dispatch; FA3 is more aggressive on Hopper. |
| Custom CUDA kernel | Re-evaluate | Likely slower than FA3 unless very specialised | FA3 is the SOTA reference; only beat it for non-standard shapes. |
| Triton fused attention | Direct call swap | Variable | Triton kernels are easier to modify; FA3 is faster but harder to fork. |
Pitfalls
- FA3 is Hopper-only. On Ampere the same wheel falls back to FA2 silently; check `flash_attn.__version__` and the dispatch log to confirm.
- FP8 attention requires the surrounding model to be FP8-aware (Transformer Engine). Bolting FP8 attention onto a BF16 model without calibration produces accuracy drift.
- The torch SDPA dispatcher only picks FA3 when input shapes, dtypes and masks match the kernel's supported envelope. Non-power-of-2 head_dim, fp32 inputs, or custom biases drop you onto a slow path.
- Building from source needs CUDA 12.3+, a recent g++ (11+), and 30+ minutes of compile time. Always prefer the pre-built wheels for the matching CUDA + PyTorch combination.
- FA3's varlen path requires explicit cu_seqlens prefix-sum arrays — easy to mis-construct (off-by-one on the trailing element). Use the official packing helpers in HuggingFace Transformers or vLLM.
- On Blackwell (sm_100), the FA3 kernel runs but is superseded by a Blackwell-specific kernel published in the same repo; ensure your wheel includes the sm_100 path or you will leave perf on the table.
- Backward pass through FP8 attention is not yet fully upstream — the supported production pattern is FP8 forward + BF16 backward, handled transparently by Transformer Engine.
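The `cu_seqlens` pitfall in the list above comes down to building an exclusive prefix sum correctly: the array has one more entry than there are sequences, starts at 0, and ends at the total token count. A minimal sketch in plain Python follows (the helper name is illustrative; in practice the resulting list would be converted to an int32 tensor on the GPU).

```python
from itertools import accumulate


def build_cu_seqlens(lengths):
    """Cumulative sequence boundaries for packed (varlen) attention.

    Returns (cu_seqlens, max_seqlen): sequence i occupies tokens
    cu_seqlens[i] up to, but not including, cu_seqlens[i + 1].
    """
    if not lengths or any(n <= 0 for n in lengths):
        raise ValueError("lengths must be a non-empty list of positive integers")
    return [0, *accumulate(lengths)], max(lengths)


print(build_cu_seqlens([512, 768, 768]))       # ([0, 512, 1280, 2048], 768)
print(build_cu_seqlens([8192, 16384, 29696]))  # ([0, 8192, 24576, 54272], 29696)
```

Two quick invariants catch most mistakes: `len(cu_seqlens) == len(lengths) + 1`, and the last element equals the first dimension of the packed q, k and v tensors.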
Lineage
| Version | Year | Paper / source | Hardware target | Headline win |
|---|---|---|---|---|
| FA1 | 2022 | Dao et al., arXiv:2205.14135 | Ampere (A100) | O(N) memory; ~25 percent of A100 peak. |
| FA2 | 2023 | Dao, arXiv:2307.08691 | Ampere + early Hopper | ~50-72 percent of A100 peak; 2x FA1. |
| FA3 | 2024 | Shah et al., arXiv:2407.08608 | Hopper (H100, H200) | ~75 percent of H100 BF16 peak; warp specialisation + ping-pong + FP8 Hadamard. |
| FA3 Blackwell | 2025 update | Same repo, sm_100 path | Blackwell (B200, B100) | Continues the pattern; FP4 first-class. |
Where this fits in the Yobitel stack
FA3 is the default attention kernel on every Yobitel sovereign GPU tenancy running Hopper or Blackwell silicon. The NGC-derived training containers (Megatron-LM, NeMo, DeepSpeed, FSDP) and inference containers (vLLM, TensorRT-LLM, SGLang) all ship with flash-attn pre-installed and the appropriate dispatch flags set. Yobitel's reference Slurm and Kubernetes launch templates pin the flash-attn wheel version against the CUDA and PyTorch versions in the container, so the kernel selection is reproducible across cluster generations.
InferenceBench v3, the public benchmark Yobitel maintains, separates engines that use FA3 from those that do not in its throughput tables — the typical FA2 vs FA3 delta on H100 inference is large enough to dominate engine-versus-engine comparison if not controlled for. Customers planning capacity on the Yobibyte platform see FA3-enabled numbers by default.
References
- FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision · arXiv (Shah et al., 2024)
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning · arXiv (Dao, 2023)
- FlashAttention: Fast and Memory-Efficient Exact Attention · arXiv (Dao et al., 2022)
- FlashAttention on GitHub · GitHub (Dao-AILab)
- FlashAttention-3 — Together AI blog post · Together AI
- NVIDIA Hopper Architecture In-Depth · NVIDIA
- PyTorch Scaled Dot Product Attention · PyTorch