本文为《Ai大模型训练教程》实战篇,系统讲解训练中断如何实现无损或近无损恢复:Checkpoint应保存的关键状态(模型/优化器/调度器/AMP/RNG)、原子写入与回退策略、随机种子与RNG状态管理、数据顺序与采样器offset恢复,以及分布式训练下的rank一致性与梯度累积边界保存,帮助你实现可验证的可复现续训流程。
为什么“无损恢复”在大模型训练中是刚需
在大模型训练里,“训练中断”几乎是常态:显卡驱动重启、集群抢占、作业超时、网络抖动、进程 OOM、手滑 kill、节点硬件故障……如果恢复策略不完善,会出现以下典型损失:
- 训练进度损失:只能从旧 checkpoint 继续,浪费数小时到数天计算。
- 指标不一致:同样的代码和数据,恢复后曲线突然跳动,难以定位问题。
- 不可复现:多卡/多机场景下,即使从 checkpoint 继续,结果也可能漂移。
本文属于《Ai大模型训练教程:从入门到实战落地的系统课程》系列,聚焦一个目标:训练被中断后,尽可能做到“无损恢复(或可控损失)”,并让实验可复现。核心围绕四件事:Checkpoint 策略、随机种子、数据顺序与可复现性。
定义:什么叫“无损恢复”?
在工程上,“无损”有两个层级:
1)训练态无损(理想目标)
恢复后从第 N step 继续训练,等价于从未中断:
- 全部模型参数一致
- 优化器状态一致(momentum、Adam 的一阶/二阶矩等)
- 学习率调度器状态一致(warmup、cosine 等)
- AMP/GradScaler 状态一致(混合精度)
- 数据迭代器位置一致(恢复到同一个样本序列与同一个 micro-batch)
- RNG(随机数发生器)状态一致(dropout、数据增强、采样等)
2)业务可接受的“近无损”(更常见)
允许丢失少量 step,例如只恢复到最近一次 checkpoint(每 500/1000 step 保存一次),但要求:
- 恢复后训练稳定,不发散
- 不出现“重复数据过多”或“跳过数据太多”
- 指标波动可解释、可控制
下面的策略会帮助你尽量逼近第一种,并在无法完全无损时,让损失可控且可解释。
Checkpoint 要保存什么:最小集合与完整集合
必须保存(建议作为“最小可恢复集合”)
- model_state_dict:模型权重参数
- optimizer_state_dict:优化器内部状态(如 Adam 的 exp_avg/exp_avg_sq)
- lr_scheduler_state_dict:学习率调度器状态
- global_step / consumed_samples / epoch / micro_step:训练进度计数
- scaler_state_dict(如使用 AMP):torch.cuda.amp.GradScaler 状态
- rng_state(强烈建议):包括 Python、NumPy、PyTorch CPU/GPU、(可选)各 rank 的 RNG 状态
推荐额外保存(提升可诊断性与一致性)
- 数据加载器/采样器状态:如 sampler 的 epoch、shuffle seed、当前 index
- 配置快照:超参、数据路径、模型结构哈希、git commit、依赖版本
- 分布式训练状态:world_size、rank 映射、ZeRO/FSDP 分片信息
- 训练日志游标:方便断点后继续写同一条曲线(非必须但实用)
Checkpoint 保存频率与策略:别只会“每 N step 保存一次”
1)两层 checkpoint:latest + periodic
实践里最稳的是“两层策略”:
- latest(滚动覆盖):每 X 分钟或每 Y step 保存一次,文件名固定
ckpt_latest,用于最大化减少损失。 - periodic(带版本号保留):每 K step 保存一次,如
ckpt_step_100000,用于回滚、对比、排障。
建议:
latest:5~15 分钟一次(取决于 ckpt 写盘耗时)periodic:1000~10000 step 一次(取决于 step 时长与预算)
2)保存触发条件优先用“时间 + step 双阈值”
仅按 step 可能在长 step(大 batch/慢 IO)时损失过大;仅按时间可能导致保存过于频繁。
可选规则:
- 距上次保存时间 > T 且 global_step % S == 0 时保存
3)原子写入与防止半个 checkpoint
中断最尴尬的是“写到一半作业挂了”。要做到:
- 写入到临时目录(如
ckpt_tmp) - 写完后再 rename 成正式目录(rename 在多数文件系统是原子的)
- 保存一个
manifest.json标记完整性
恢复时检查 manifest.json 或文件列表是否齐全,不齐全则回退到上一个 periodic。
随机种子:不仅要 set_seed,还要保存 RNG 状态
很多人以为“固定随机种子就可复现”,但训练中断恢复时,仅靠 seed 远远不够。原因是:训练进行到第 N step 时,各种随机源(dropout、数据增强、采样、mask 生成等)已经消耗了大量随机数。从 checkpoint 恢复想做到位级一致,必须恢复 RNG 的“状态”,而不只是初始 seed。
需要管理的随机源清单
- Python:
random - NumPy:
numpy.random - PyTorch CPU:
torch.random - PyTorch CUDA:
torch.cuda.random(每张卡可能不同) - 分布式场景:不同 rank 的 RNG 消耗节奏可能不同
实操建议
- 启动时固定 seed:保证同一配置下的“初始条件一致”。
- checkpoint 时保存 RNG 状态:恢复时 set 回去。
- 注意 DataLoader worker 的 seed:多进程 worker 会各自持有 RNG 状态,需要通过
worker_init_fn或生成器控制。
如果你用 PyTorch,至少要把以下状态放进 checkpoint:
random.getstate()np.random.get_state()torch.get_rng_state()torch.cuda.get_rng_state_all()(多 GPU)
数据顺序:决定了“恢复后是否重复/跳过数据”
训练能否无损恢复,最容易被忽视的其实是数据顺序(Data Order)。
1)为什么数据顺序会变
常见原因:
- DataLoader 使用
shuffle=True,每个 epoch 都会重新打乱 - 分布式 sampler 每个 epoch 需要
set_epoch(epoch)才能让不同 rank 的打乱一致 - worker 数变化(
num_workers)导致预取与取样节奏改变 - 动态 padding / bucketing(按长度分桶)在恢复后桶边界变化
- 数据集本身在训练期间发生变化(增量写入、文件更新、list 顺序变化)
2)两种恢复目标对应两种方案
方案 A:追求“严格无损”(推荐用于科研/对齐、回归对比)
核心思想:训练到哪里,数据迭代器就恢复到哪里。
可执行做法:
- 在 checkpoint 中保存
consumed_samples(全局已消费样本数)或global_step + micro_step。 - 使用可寻址的采样器:给定
seed + epoch + offset能生成确定的样本序列。 - 恢复时让 sampler/loader 从
offset处继续。
难点:标准 PyTorch DataLoader 不容易“从中间恢复 worker 队列状态”,因此工程上常用两种技巧:
- 以样本为单位的可重复索引序列:提前生成 index 列表(或可计算生成),再按 offset 切片。
- 以 step 为单位的稳定 batch 切分:确保
global_step -> batch_indices的映射确定。
方案 B:接受“近无损”(推荐用于大规模生产训练)
核心思想:恢复到最近 checkpoint 后,从该 step 开始重新走一遍后续数据,允许少量重复。
建议控制重复范围:
- checkpoint 间隔不要太大
- 记录
last_ckpt_consumed_samples,评估重复比例 - 若使用 streaming 数据,确保数据源在时间窗口内可重放
分布式训练的额外坑:每个 rank 的一致性
在 DDP/FSDP/DeepSpeed 场景,“无损恢复”要额外关注:
1)不同 rank 的 RNG 与数据切分
- 每个 rank 的 dropout RNG 消耗可能不同(尤其存在条件分支、动态 shape)。
- DistributedSampler 会把数据切给不同 rank,必须保证恢复时 rank 映射一致(hostfile 改了也可能影响)。
建议:
- 固定
world_size、rank拓扑,或在 checkpoint 里记录并在恢复时校验。 - 使用
DistributedSampler(..., seed=base_seed)并在每个 epoch 调用set_epoch(epoch)。
2)梯度累积(gradient accumulation)与 micro-step
如果你每 grad_accum_steps 才进行一次 optimizer.step,那么中断发生在 accumulation 中间时:
- 若不保存“当前累积到第几个 micro-step”,恢复后会导致等效 batch 改变。
工程建议:
- 在 checkpoint 保存
micro_step_in_accum。 - 最稳做法:只在“完成一次 optimizer.step 之后”保存 checkpoint(边界保存)。
- 如果必须任意时刻保存,则要把梯度缓存也纳入(成本较高)。
3)ZeRO/FSDP 分片状态
使用 DeepSpeed ZeRO 或 FSDP 时,optimizer state 和参数是分片的:
- 需要使用框架提供的官方 checkpoint API(如 deepspeed 的
save_checkpoint、FSDP 的 state_dict 类型选择)。 - 恢复时需要相同或兼容的并行配置,否则可能无法加载。
务必在 checkpoint 中写入并校验:
- 并行策略(dp/tp/pp)
- ZeRO stage、offload 配置
- 混合精度配置
可复现性:你能做到什么程度?
可复现性不是非黑即白,它有层级:
1)完全一致(bitwise identical)
极难,尤其是多 GPU、使用 fused kernel、flash-attn、不同驱动/库版本时。需要:
- 固定软件栈版本
- 关闭非确定性算子,启用 deterministic
- 控制所有 RNG 状态
- 保证计算顺序一致(多线程/通信顺序也可能影响)
2)统计一致(曲线高度接近、最终指标差异很小)
生产中更常见、成本更合理。策略:
- 固定 seed
- 固定数据版本与顺序逻辑
- 恢复 optimizer/scheduler/scaler
- 允许浮点非确定性带来的微小差异
3)业务一致(结论一致,指标在容忍区间)
用于快速迭代、容错性训练,重点在“稳定恢复、别崩”。
实操建议:可复现性清单
- 固定并记录:训练代码版本(git commit)、配置文件、依赖版本、CUDA/cuDNN、驱动
- 数据版本化:数据集 hash / 文件列表快照 / 数据时间窗口
- 记录关键超参:batch size、grad_accum、lr、warmup、weight_decay
- 记录分布式拓扑:world_size、并行策略
如果你使用 PyTorch 并追求确定性,可考虑(代价是速度下降):
torch.use_deterministic_algorithms(True)torch.backends.cudnn.benchmark = Falsetorch.backends.cudnn.deterministic = True
一套可落地的“中断无损恢复”流程(建议照着做)
步骤 1:设计 checkpoint 内容结构
建议按如下字段组织(概念结构,不限定实现):
modeloptimizerschedulerscalerprogress: global_step, epoch, consumed_samples, micro_steprng: python/numpy/torch_cpu/torch_cuda_alldata: sampler_epoch, shuffle_seed, dataset_version, maybe offsetmeta: time, host, world_size, parallel_config, git_commit
步骤 2:统一保存时机(边界保存)
- 优先在完成一次
optimizer.step()后保存 - 若使用 gradient accumulation,只在 accumulation 完成点保存
- 保存前执行一次 barrier(分布式)以确保状态一致(按框架推荐方式)
步骤 3:原子写入与完整性校验
- 写临时目录 -> 写 manifest -> rename
- 恢复时:优先尝试 latest,若不完整则回退到最近 periodic
步骤 4:恢复逻辑要“先校验再加载”
恢复时做以下校验,避免“加载成功但训练悄悄变味”:
- world_size / 并行配置是否匹配
- 模型结构 hash 是否一致
- 数据版本是否一致
- 混合精度与优化器类型是否一致
步骤 5:恢复后做一次一致性冒烟测试
恢复后建议自动跑一个小检查:
- 打印恢复的 global_step、lr、loss scale
- 用同一 batch(固定样本)前向一次,确认 loss 合理
- 若你有中断前最后一次日志记录,可对比恢复后的第一个 step loss 是否在合理范围
常见故障与定位建议
1)恢复后 loss 突然飙升
可能原因:
- optimizer state 没加载(等于重新热身)
- scheduler 状态不对(lr 突变)
- scaler 状态丢失(AMP 下数值不稳定)
排查顺序:
- 打印 lr 是否与中断前一致
- 检查 optimizer 的 state 是否非空、key 数量是否匹配参数
- AMP 下检查 loss scale 是否恢复
2)恢复后训练变慢或显存变大
可能原因:
- cudnn benchmark/确定性设置变化
- batch 形状分布变化(数据顺序变了,动态 padding 让平均长度上升)
建议:
- 记录并对比恢复前后的 token/seq_len 分布
- 检查 DataLoader 的 bucketing 是否可复现
3)同一 checkpoint 多次恢复结果不一致
可能原因:
- 未恢复 RNG 状态
- DataLoader worker seed 不固定
- 使用了非确定性算子
建议:
- 先在单卡上验证 bitwise 或统计一致
- 再扩展到多卡,并固定软件栈与 NCCL 环境变量(必要时)
结语:把“能恢复”升级成“可验证地恢复”
在《Ai大模型训练教程》系列的实战里,真正拉开工程成熟度差距的,往往不是能不能训起来,而是:
- 中断后能否可靠续训
- 续训后曲线是否可解释、可对比
- 实验是否可复现、可审计
落地时建议你至少做到三点:
- checkpoint 保存 model/optimizer/scheduler/scaler + progress + RNG
- 数据顺序可控:记录 consumed_samples,并保证 sampler 逻辑确定
- 原子写入与恢复校验:坏 checkpoint 自动回退
做到这些,你的训练中断将从“事故”变成“可预期事件”,训练效率与实验可信度都会显著提升。
Prev:用单卡复现一个小型GPT预训练:数据准备、训练脚本与关键超参