TL;DR
- Frontier training runs span weeks on thousands of GPUs, so hardware failure is a statistical certainty, not an exception. Checkpointing is the recovery mechanism.
- Three axes: cadence (how often), distribution (how to save sharded state efficiently), retention (which to keep).
- Modern stacks: PyTorch Distributed Checkpoint (DCP), DeepSpeed Universal Checkpoint, Megatron-LM distributed save — all support sharded async writes to local NVMe + tiered upload to object storage.
Overview
A 70B-parameter training run on 1,000 H100s for two weeks experiences, on average, several hardware-induced interruptions — a failed NIC, a thermal event, a stuck NCCL collective. Without checkpointing, every interruption restarts the run from step zero. With checkpointing, the loss is bounded by the inter-checkpoint interval.
Designing a checkpoint strategy is a three-axis optimisation: cadence (steps between saves), distribution (how saves are written across the cluster), and retention (which historical checkpoints to keep).
Cadence
Too frequent: I/O overhead dominates training time. Too infrequent: a failure costs a long re-do. The right interval is roughly the geometric mean of the per-checkpoint cost and the expected interval between failures; Young's first-order approximation puts it at sqrt(2 × checkpoint cost × mean time between failures). For 1,000-H100 runs, that typically lands at every 1,000-5,000 steps (15-60 minutes).
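The interval rule above can be evaluated directly. The following is a minimal sketch using Young's first-order approximation; the 30-second save cost and two-day mean time between failures are illustrative assumptions, not measurements, and the function names are hypothetical.

```python
import math


def optimal_checkpoint_interval(checkpoint_cost_s: float, mtbf_s: float) -> float:
    """Young's first-order approximation: sqrt(2 * C * M), in seconds."""
    if checkpoint_cost_s <= 0 or mtbf_s <= 0:
        raise ValueError("checkpoint cost and MTBF must be positive")
    return math.sqrt(2.0 * checkpoint_cost_s * mtbf_s)


def expected_overhead_fraction(interval_s: float, checkpoint_cost_s: float, mtbf_s: float) -> float:
    """Approximate share of wall-clock time lost.

    C / T is the time spent saving; T / (2 * M) is the expected rework,
    since a failure lands on average halfway through an interval.
    """
    return checkpoint_cost_s / interval_s + interval_s / (2.0 * mtbf_s)


# Assumed inputs: a 30 s blocking save and one failure every two days.
cost_s = 30.0
mtbf_s = 2 * 24 * 3600.0

interval_s = optimal_checkpoint_interval(cost_s, mtbf_s)
overhead = expected_overhead_fraction(interval_s, cost_s, mtbf_s)
print(f"checkpoint every ~{interval_s / 60:.0f} min, estimated overhead {overhead:.1%}")
```

With these inputs the result falls inside the 15-60 minute band quoted above. Halving the save cost shortens the optimal interval by a factor of sqrt(2), which is why cheaper saves translate into more frequent ones.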
Asynchronous checkpointing relaxes the constraint: the training loop blocks only long enough to snapshot state into host memory, and a background thread or process writes that snapshot to storage while training continues. Modern frameworks (PyTorch DCP, DeepSpeed) support this, which makes each save cheaper and lets saves happen more often without stalling the training step.
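The snapshot-then-write pattern can be sketched with the standard library alone. This is an illustration of the idea, not how any particular framework implements it; `AsyncCheckpointer` and the pickle file layout are hypothetical. In PyTorch the comparable entry point is `torch.distributed.checkpoint.async_save`.

```python
import copy
import os
import pickle
import tempfile
import threading


class AsyncCheckpointer:
    """Snapshot state on the caller's thread, persist it on a background thread."""

    def __init__(self, directory: str):
        self.directory = directory
        self._pending = None  # at most one write in flight

    def save(self, step: int, state: dict) -> None:
        self.wait()                      # never overlap two writes
        snapshot = copy.deepcopy(state)  # short blocking copy; `state` may change afterwards
        self._pending = threading.Thread(target=self._write, args=(step, snapshot))
        self._pending.start()

    def _write(self, step: int, snapshot: dict) -> None:
        final_path = os.path.join(self.directory, f"step_{step:08d}.pkl")
        tmp_path = final_path + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(snapshot, f)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, final_path)  # publish only a complete file

    def wait(self) -> None:
        if self._pending is not None:
            self._pending.join()
            self._pending = None


# Toy training loop: the state keeps changing after the save is requested.
ckpt_dir = tempfile.mkdtemp()
checkpointer = AsyncCheckpointer(ckpt_dir)
state = {"step": 0, "weights": [0.0, 0.0]}
for step in range(1, 4):
    state["step"] = step
    state["weights"] = [w + 1.0 for w in state["weights"]]
    if step == 2:
        checkpointer.save(step, state)
checkpointer.wait()

with open(os.path.join(ckpt_dir, "step_00000002.pkl"), "rb") as f:
    restored = pickle.load(f)
print(restored)
```

The copy is what keeps the saved state consistent: the file reflects step 2 even though training moved on to step 3 during the write. A production version would also surface exceptions raised on the writer thread, which this sketch ignores.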
Distribution
With sharded state (FSDP, ZeRO-3, Megatron 3D parallelism), each worker holds only 1/N of the state. The naive approach — gather everything to rank 0 and write one file — does not scale; gather cost is O(model size) and rank 0 becomes the bottleneck.
The modern approach is sharded distributed checkpointing: each worker writes its shard directly to shared storage (NVMe-backed Lustre, GPFS, or an S3-compatible object store), with a small metadata file describing the layout. PyTorch DCP, Megatron-LM, and DeepSpeed all implement this. Common practices:
- Local-NVMe-first writes — fastest, then tiered to object storage in the background.
- Sharded format with metadata — N small files instead of one giant tar.
- Save the full DP group, not just one replica — if state has drifted (it shouldn't, but it can), at least one replica's copy is preserved.
- Atomic rename — never partially overwrite a 'good' checkpoint with an in-progress one.
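The sharded-format and atomic-rename points above can be sketched as follows. A single process stands in for the whole cluster here; in a real job each rank would write only its own file and one rank would write the metadata after a barrier. The directory layout and function names are hypothetical.

```python
import json
import os
import pickle
import tempfile


def write_sharded_checkpoint(root: str, step: int, shards: dict) -> str:
    """Write one file per rank plus metadata into a temp dir, then rename it into place.

    `shards` maps rank -> that rank's slice of the state (illustrative layout).
    """
    final_dir = os.path.join(root, f"step_{step:08d}")
    tmp_dir = final_dir + ".tmp"
    os.makedirs(tmp_dir)  # raises if a previous attempt left debris behind

    metadata = {"step": step, "world_size": len(shards), "files": {}}
    for rank, shard in shards.items():
        name = f"shard_{rank:05d}.pkl"
        with open(os.path.join(tmp_dir, name), "wb") as f:
            pickle.dump(shard, f)
            f.flush()
            os.fsync(f.fileno())
        metadata["files"][str(rank)] = name

    with open(os.path.join(tmp_dir, "metadata.json"), "w") as f:
        json.dump(metadata, f)
        f.flush()
        os.fsync(f.fileno())

    # Atomic on POSIX when both paths are on the same filesystem; readers
    # see either no checkpoint for this step or a complete one.
    os.rename(tmp_dir, final_dir)
    return final_dir


def load_shard(checkpoint_dir: str, rank: int):
    """Look up a rank's file through the metadata and load it."""
    with open(os.path.join(checkpoint_dir, "metadata.json")) as f:
        metadata = json.load(f)
    with open(os.path.join(checkpoint_dir, metadata["files"][str(rank)]), "rb") as f:
        return pickle.load(f)


root = tempfile.mkdtemp()
path = write_sharded_checkpoint(root, 100, {0: {"w": [1, 2]}, 1: {"w": [3, 4]}})
print(load_shard(path, 1))
```

Object stores generally do not offer an atomic directory rename, so a tiering layer would need a different commit signal there, for example uploading the metadata file last and treating its presence as the marker of a complete checkpoint.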
Retention
Keep one 'latest' checkpoint always. Keep N rotating checkpoints behind it (typical N = 3-5) so you can roll back past a known-bad step. Keep periodic 'milestone' checkpoints (every 10 % of training, or at the end of each epoch) for ablation studies and reproducibility. Garbage-collect everything else.
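The retention rule above reduces to a small set computation. This is a minimal sketch; the function name, the step-based milestone spacing, and the choice of "first save at or after each milestone boundary" are assumptions made for illustration.

```python
def checkpoints_to_keep(saved_steps, milestone_every: int, keep_rotating: int = 3) -> set:
    """Return the checkpoint steps to retain; everything else can be garbage-collected.

    Keeps the latest checkpoint, `keep_rotating` checkpoints behind it, and the
    first checkpoint saved at or after each multiple of `milestone_every`.
    """
    ordered = sorted(saved_steps)
    if not ordered:
        return set()

    keep = set(ordered[-(keep_rotating + 1):])  # latest plus N behind it

    seen_milestones = set()
    for step in ordered:
        milestone = step // milestone_every
        if milestone >= 1 and milestone not in seen_milestones:
            seen_milestones.add(milestone)
            keep.add(step)
    return keep


# Assumed run: 100,000 total steps, a save every 2,000 steps, currently at step 26,000.
saved = list(range(2_000, 26_001, 2_000))
keep = checkpoints_to_keep(saved, milestone_every=10_000, keep_rotating=3)
delete = sorted(set(saved) - keep)
print("keep:  ", sorted(keep))
print("delete:", delete)
```

Here the milestone spacing of 10,000 steps corresponds to 10 % of the assumed run. Computing the keep set first and deleting only the complement makes it harder to remove a checkpoint by accident.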
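Storage matters: the BF16 weights of a 405B model are ~810 GB, so 100 weights-only checkpoints come to ~81 TB, and resumable checkpoints that include optimiser state are several times larger, pushing the total towards half a petabyte. Tiered storage (NVMe → SSD → HDD → object) and aggressive cold-tier compression keep this manageable.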
Saving optimiser state is non-negotiable for resumable training. A weights-only checkpoint cannot reproduce the post-resume trajectory, because momentum, second-moment, and master-weight state all shape the next update. Plan for ~12 bytes/parameter for Adam at mixed precision: FP32 master weights plus two FP32 moment estimates, at 4 bytes each.
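The storage figures in this section follow from simple per-parameter arithmetic. The sketch below assumes decimal units and a 12-byte layout made of FP32 master weights plus two FP32 Adam moments; a framework that also stores a separate low-precision copy of the weights would add to that.

```python
GB = 10**9
TB = 10**12

BYTES_PER_PARAM = {
    "weights_only_bf16": 2,
    # Assumed layout: FP32 master weights + first moment + second moment.
    "adam_mixed_precision": 4 + 4 + 4,
}


def checkpoint_bytes(n_params: int, layout: str) -> int:
    """Size of one checkpoint for a given per-parameter layout."""
    return n_params * BYTES_PER_PARAM[layout]


n_params = 405 * 10**9
weights_only = checkpoint_bytes(n_params, "weights_only_bf16")
resumable = checkpoint_bytes(n_params, "adam_mixed_precision")

print(f"weights only:  {weights_only / GB:,.0f} GB each, {100 * weights_only / TB:,.0f} TB for 100")
print(f"with Adam:     {resumable / GB:,.0f} GB each, {100 * resumable / TB:,.0f} TB for 100")
```

The sixfold gap between the two layouts is the main reason retention policies often keep full optimiser state only for recent checkpoints.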
Performance Characteristics
- Sharded distributed write to local NVMe + Lustre: typically 5-20 GB/s aggregate across the cluster.
- Async checkpointing overhead: <2 % of step time at production scale.
- Resume time (load + redistribute) is usually 1-3 minutes for sub-100B models, longer for frontier scales.
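To see what the throughput range above means for a single save, divide checkpoint size by aggregate bandwidth. The 405B-parameter model at ~12 bytes/parameter is an assumed example, and the estimate ignores metadata, serialisation, and contention.

```python
def write_time_minutes(checkpoint_bytes: float, aggregate_bandwidth_bytes_per_s: float) -> float:
    """Lower-bound time to persist one checkpoint at a given aggregate bandwidth."""
    return checkpoint_bytes / aggregate_bandwidth_bytes_per_s / 60.0


checkpoint_bytes = 405e9 * 12  # assumed: 405B parameters, ~12 bytes/parameter

for gb_per_s in (5, 10, 20):
    minutes = write_time_minutes(checkpoint_bytes, gb_per_s * 1e9)
    print(f"{gb_per_s:>2} GB/s -> {minutes:.1f} min per checkpoint")
```

Writes that take minutes at this scale are what make a synchronous save costly at a 15-60 minute cadence, and why the write is usually moved off the training path.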
Pitfalls
- RNG state must be saved per worker; otherwise dropout masks and data shuffling diverge from the original run on resume.
- Learning-rate scheduler and gradient-accumulation counters must be saved; otherwise the LR jumps on resume.
- Checkpoints are not compatible across frameworks: Megatron checkpoints need conversion before they load in Hugging Face.
- Object-store-only checkpoints have minutes-to-tens-of-minutes restore time; pair with local NVMe staging.
- Disk-pressure failures during a checkpoint write are catastrophic: monitor free space and never write to a nearly full filesystem.
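The first two pitfalls can be demonstrated with the standard library. The sketch below uses Python's `random` module as a stand-in for a worker's RNG and a hypothetical linear-warmup schedule; in PyTorch the analogous calls are `torch.get_rng_state()` and `torch.cuda.get_rng_state()`, captured on every rank.

```python
import random


def lr_at(step: int, base_lr: float = 3e-4, warmup_steps: int = 2000) -> float:
    """Hypothetical linear-warmup schedule, used only to show the effect of a lost counter."""
    return base_lr * min(1.0, step / warmup_steps)


# Original run: five steps that consume randomness (shuffling, dropout, ...).
rng = random.Random(1234)
for _ in range(5):
    rng.random()

# What a checkpoint needs beyond weights: counters and this worker's RNG state.
saved = {"scheduler_step": 5000, "micro_step": 3, "rng_state": rng.getstate()}

# Draws the uninterrupted run would make next.
expected_draws = [rng.random() for _ in range(3)]

# Resume in a "fresh process": restore the RNG instead of reseeding it.
resumed_rng = random.Random()
resumed_rng.setstate(saved["rng_state"])
resumed_draws = [resumed_rng.random() for _ in range(3)]
print("RNG stream matches:", resumed_draws == expected_draws)

# Restoring the scheduler counter keeps the LR continuous; losing it restarts warmup.
print("LR with counter restored:", lr_at(saved["scheduler_step"]))
print("LR with counter lost:    ", lr_at(0))
```

Reseeding with the original seed is not equivalent to restoring the state, because it replays the stream from the beginning rather than continuing from where the run stopped.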
Software
- PyTorch Distributed Checkpoint (`torch.distributed.checkpoint` / DCP) — DTensor-aware sharded checkpointing.
- DeepSpeed Universal Checkpoint — convertible across parallelism configurations.
- Megatron-LM `--save` / `--load` — sharded with metadata, integrated with NeMo.
- FSDP `StateDictType.SHARDED_STATE_DICT` — required for non-OOM saves on large models.
- `s5cmd` / AWS CLI for high-throughput parallel S3 uploads in the tiering layer.
References
- PyTorch Distributed Checkpoint documentation · PyTorch
- DeepSpeed Universal Checkpointing · GitHub (Microsoft)
- Meta on training Llama 3 — failure rates and checkpointing · Meta AI