Ai大模型训练教程实战篇:提供训练故障定位手册,覆盖Loss发散、NaN/Inf、梯度爆炸与性能骤降的常见根因与排查路径,包含可复现与日志指标、学习率与warmup调整、混合精度处理、梯度裁剪、标签与mask校验、数据异常定位等可落地步骤。
这份手册解决什么问题
在 Ai大模型训练教程 系列里,很多读者会在“模型终于跑起来”之后,迅速撞上第二堵墙:训练不稳定。典型表现包括:
- Loss 突然发散(一路上升或剧烈震荡)
- 出现 NaN/Inf(loss=nan、grad=nan、logits=inf)
- 梯度爆炸(grad norm 突然到 1e3/1e6,随后模型崩)
- 性能骤降(训练 loss 看似正常,但验证集指标断崖式掉,或突然变差后再也回不来)
这篇文章提供一条“从现象到根因”的排查路径:先用最少的检查把问题归类,再沿着对应分支定位。你可以把它当成训练现场的故障处理 SOP。
先做三件事:把故障从“玄学”变成“可复现”
很多训练故障无法定位的原因不是技术,而是无法稳定复现、缺少关键日志。
1)固定可复现实验
建议至少固定:
- 随机种子(Python/NumPy/PyTorch/CUDA)
- 数据顺序(DataLoader shuffle 的 seed;分布式下每个 rank 的 sampler seed)
- 版本(PyTorch / CUDA / xFormers / FlashAttn / Transformers / DeepSpeed)
并且记录:git commit、启动参数、配置文件快照。
2)把“训练核心指标”打齐
最低限度建议每 N step 打印/记录:
- loss(以及各子 loss)
- learning rate
- grad norm(全局)
- weight norm(可选)
- 是否启用 AMP、当前 loss scale(若使用 GradScaler)
- 样本统计:token 数、序列长度分布(尤其是长尾)
- 评估指标(val loss / ppl / accuracy / rouge 等)
3)保存“故障现场”
- 出现异常的 step:保存模型权重(或至少 optimizer state + scaler state)
- 保存出错 batch 的输入(input_ids、attention_mask、labels,必要时原始文本)
做到这些后,下面的排查才会高效。
总体排查路线图(先分类再深入)
遇到训练不稳定,先回答四个问题:
- 是否必现:同配置同 seed 是否在相近 step 必然崩?
- 是否与数据相关:崩溃 step 的 batch 是否含特定样本/超长序列/异常 token?
- 是否与数值精度相关:FP16/BF16 切换、AMP 开关是否改变现象?
- 是否与优化器/学习率相关:调小 LR、增加 warmup 是否立刻缓解?
接下来按症状进入对应章节。
故障一:Loss 发散(上升/剧烈震荡)
Loss 发散大多是“优化不稳定”或“数据/目标不一致”。按优先级排:
1)学习率与 warmup(最常见)
现象:训练前几百/几千 step loss 上升明显,grad norm 同步变大。
排查与修复步骤:
- 把 LR 降低到原来的 1/2 或 1/4(先验证是否能稳定)
- 增大 warmup:例如从 1% 提到 3%~5% 总步数
- 使用更平滑的 schedule(cosine/linear),避免突然跳变
- 如果用了分段 LR 或重启(restart),检查边界 step 是否造成震荡
建议经验值(仅供起步):
- 全参微调大模型:LR 往往在 1e-5 ~ 2e-4 之间,受 batch size 与数据难度影响大
- LoRA/QLoRA:可比全参略大,但依然建议从 1e-4 量级起试
2)有效 batch size 与梯度累积不匹配
现象:换了 GPU 数、改了 gradient_accumulation_steps 后突然不稳。
要点:很多人只改了 micro-batch,但忘了学习率需要按有效 batch size 调整。
做法:
- 计算有效 batch:
global_batch = micro_batch * grad_accum * world_size - 若 global_batch 变大,通常可线性增大 LR(但大模型并不总是严格线性);若变小,LR 需要相应减小
3)Label/Mask/Shift 错误(loss 在“训错目标”)
现象:loss 不但不降,甚至持续升高;或训练 loss 降但生成质量变差。
检查清单:
- causal LM 是否正确 shift:预测 token[t] 用 input_ids[t-1]
- padding 部分 label 是否设为 ignore_index(如 -100)
- attention_mask 是否与 padding 对齐
- 是否把 prompt 部分错误计入 loss(SFT 常见:只对 assistant 回答算 loss)
快速验证:抽 1 条样本,打印:
- 解码后的 input 文本
- labels 中被计算 loss 的位置(非 -100 的 token)
确保“算 loss 的 token”与你想训练的目标一致。
4)优化器超参(betas、eps、weight_decay)
现象:LR 看起来合理,但依旧抖动很大。
建议排查:
- AdamW 的
eps过小在混合精度下更易数值不稳(常用 1e-8 或 1e-6) weight_decay过大可能导致训练初期不稳(尤其全参微调)- betas:常见 (0.9, 0.95) 或 (0.9, 0.999),过激进可能放大震荡
5)数据分布突变/脏数据
现象:训练中途某个区间 loss 抬头,之后持续不稳。
操作:
- 统计每个 batch 的长度、特殊 token 比例(如未知字符、过多控制符)
- 对异常 step 复现:单独取该 batch 前向/反向,看是否出现极端 logits 或异常 loss
- 对长文本启用长度裁剪(max_length)、或按长度分桶(bucketing)减少极端 batch
故障二:出现 NaN/Inf(loss/grad/activation)
NaN 是“数值爆炸”的结果,而不是原因。关键是确定它第一次出现在哪个环节:前向、loss、反向、优化器 step。
1)先定位:NaN 首次出现的位置
建议按顺序插入检查(可只在 debug run 开启):
- 前向输出 logits 是否含 inf/nan
- loss 是否为 nan
- 反向后 grad 是否为 nan
- optimizer.step 后参数是否变 nan
如果你不想到处改代码,至少记录:
- 该 step 的 grad norm
- AMP 的 loss scale(如果 scale 一路下降到极小仍 nan,通常不是单纯 overflow)
2)混合精度相关:FP16 溢出、loss scale 失效
现象:FP16 下 nan,切 BF16/FP32 正常。
处理路径:
- 优先尝试 BF16(硬件支持时)——它比 FP16 更不容易溢出
若必须 FP16:
- 开启 GradScaler
- 降低 LR
- 启用 gradient clipping(见后文)
- 检查是否有不支持 FP16 的算子(某些自定义 op / 旧版 kernel)
3)softmax / log / exp 的数值不稳定
常见触发:
- logits 极大导致 softmax 溢出
- 对 0 取 log(log(0))
- 在 loss 内部做了不稳定的归一化
建议:
- 尽量使用框架提供的稳定实现(如
F.cross_entropy而不是手写 softmax+log) - 如果自定义 loss:加上 clamp,例如
x = x.clamp(min=1e-12)再 log - 检查 attention mask 是否把全部位置 mask 掉(softmax 变成全 -inf)
4)梯度里有 NaN:优先怀疑“非法样本”或“除零”
数据侧检查:
- 标签是否越界(分类任务 label >= num_classes 会导致异常)
- 是否出现空序列、全 padding、全 mask
- 文本中是否有异常字符导致 tokenizer 产生意外结果(极长重复、特殊控制符)
工程侧检查:
- 分布式下是否有 rank 的数据为空(最后一个 batch drop_last 设置不一致)
- 梯度累积时是否错误地重复 backward / 清零时机不对
故障三:梯度爆炸(grad norm 飙升,随后 loss 崩)
梯度爆炸通常与:过大学习率、序列过长、初始化/归一化异常、混合精度溢出、某层输出异常有关。
1)先加“安全带”:Gradient Clipping
目标:让训练先稳定下来,再找根因。
- 推荐使用全局范数裁剪:
clip_grad_norm_(params, max_norm) - 大模型常用 max_norm:0.5、1.0、或 5.0(视任务而定)
注意:裁剪不是万能的,如果每步都在裁剪,说明根因仍在(LR 太大/数据极端)。
2)检查是否存在极端长序列 batch
现象:在某个 step 突然爆炸,而该 batch 的长度分布异常。
建议:
- 训练时记录每 step 的 max seq len
- 对超过阈值的样本:截断、分段、或降低其采样权重
- 用长度分桶:让一个 batch 内长度更接近,减少极端 padding/计算不均
3)归一化层与残差相关问题
大模型里 LayerNorm/RMSNorm 对稳定性关键。
排查:
- 是否错误地冻结/解冻了 norm 参数(某些微调策略会冻结 norm,可能影响收敛)
- 是否在错误位置使用了 dropout 或改变了残差比例
- 是否加载了不匹配的 checkpoint(结构改动但强行 load)导致某些层权重异常
4)优化器状态损坏或不连续
现象:从 checkpoint 恢复后突然爆炸。
检查:
- 是否同时恢复了 optimizer state、lr scheduler state、GradScaler state
- 是否更改了参数组(param groups)导致 optimizer state 对不上
建议:
- 恢复训练时尽量保持参数组一致
- 若必须改结构:宁可只加载模型权重,重新初始化优化器,并降低 LR 重启
故障四:性能骤降(训练正常但验证/生成质量突然变差)
性能骤降往往不是数值崩溃,而是“学偏了”或“评估/数据管道变了”。
1)先确认:评估是否一致
常见坑:
- 训练用的 tokenizer 与评估用的不一致(版本不同、special tokens 不同)
- 评估时 max_length / truncation 策略变化
- 生成评估时 sampling 策略变化(temperature/top_p)导致指标波动
做法:固定评估脚本与配置,把“评估差异”排除掉。
2)数据混入:指令格式漂移或标签污染
现象:某一轮之后模型回答风格变怪、出现大量拒答/胡言乱语。
排查步骤:
- 检查近期合并的数据集:是否混入不同模板(role 标签不一致)
- 抽样检查:assistant 回复是否被错误放进了 user 字段
- 检查是否混入大量低质量重复样本(重复会导致过拟合某些模式)
建议:
- 建立数据 schema 校验(字段、角色、长度、空值)
- 对新数据先小规模训练做 A/B
3)灾难性遗忘(Catastrophic Forgetting)
现象:新任务指标上升,但通用能力/旧任务指标突然下降。
应对:
- 混合训练:加入一定比例通用数据或旧任务数据(例如 70/30 或 80/20)
- 使用较小 LR + 更长训练
- 采用参数高效微调(LoRA)并控制可训练参数范围
4)过拟合或分布偏移
现象:训练 loss 持续下降,val loss 上升;或 val 指标突然掉。
建议动作:
- 增加验证频率:缩短“发现问题”的滞后
- 早停(early stopping)或回滚到最佳 checkpoint
- 增强正则:dropout、weight decay(适度)、数据增强
一套可落地的“10 步排查清单”(推荐照做)
- 记录故障 step:保存该 step 前后的日志、batch、checkpoint
- 关闭分布式复杂度:能否单卡复现?不能则优先怀疑通信/数据切分
- 切换精度:FP16→BF16→FP32,看 NaN 是否消失
- LR 砍半 + warmup 加倍:最快验证是否优化不稳
- 开启 grad clipping:max_norm=1.0 先保命
- 检查 label/mask/ignore_index:打印一条样本的 loss 参与位置
- 检查异常 batch:长度、全 mask、空样本、label 越界
- 恢复训练一致性:optimizer/scheduler/scaler state 是否完整恢复
- 对比最近改动:数据、模板、tokenizer、依赖库、kernel
- 最小复现脚本:用固定 batch 跑 100 step,看是否稳定
做到第 4~6 步,通常能把 80% 的训练崩溃问题定位到“学习率/精度/标签”这三类根因。
训练现场的建议配置(偏通用,便于减少故障)
1)推荐默认:BF16 + 稳定算子
- 优先 BF16(若 GPU 支持)
- 使用框架稳定实现的 loss
- 关注 FlashAttention/xFormers 版本兼容,升级或降级都要做小跑验证
2)把“监控”当功能交付
- 每次训练都输出:grad norm、lr、loss scale、seq len
- 异常报警:出现 NaN/Inf 立刻停止并 dump batch
3)小步快跑:先用小数据/小步数验稳定
在全量训练前:
- 用 1% 数据跑 500~2000 step
- 验证不会 NaN、loss 能下降、评估指标方向正确
- 再扩到全量
结语:把不稳定当作“可工程化”的问题
大模型训练的故障定位,不靠“多试几次”,而靠可复现、可观测、可回滚。当你能回答“NaN 是从前向还是反向开始的”“是哪一批数据触发的”“切换 BF16 是否消失”,问题基本就不再神秘。
后续如果你愿意把具体日志字段(loss/grad norm/lr/loss scale/seq len)和崩溃 step 的 batch 特征贴出来,我也可以基于这份手册帮你把排查路径进一步收敛到最可能的 1~2 个根因。
Prev:持续预训练与增量微调:领域迁移、遗忘缓解与混合数据配比策略