面向Ai大模型训练教程的实操指南,系统讲解FP16与BF16差异、PyTorch AMP用法、GradScaler与loss scaling机制,并提供NaN/Inf/溢出与跳步的排查流程、日志指标与稳定性建议,帮助大模型训练提速并保持收敛稳定。
混合精度训练指南:FP16、BF16、AMP、GradScaler与溢出排查
在《Ai大模型训练教程:从入门到实战落地的系统课程》里,混合精度训练几乎是你从“能跑起来”到“跑得快、跑得稳、能扩展”的必经之路。大模型训练的瓶颈常常在显存和吞吐:显存不够就得减 batch、减序列长度、改模型结构;吞吐不够则训练周期拉长。混合精度(Mixed Precision)通过在不显著牺牲精度的前提下,把一部分计算/存储从 FP32 换成 FP16 或 BF16,通常能带来更高吞吐、更低显存占用。
这篇文章聚焦实操:如何在 PyTorch 中正确使用 FP16/BF16、AMP、GradScaler;遇到溢出(overflow)、NaN、loss 变成 0 或不收敛时如何系统排查。
1. 混合精度到底混了什么:权重、激活、梯度、优化器状态
混合精度不是“把模型全改成半精度”这么简单。一个常见且稳定的做法是:
- 前向与大部分算子用 FP16/BF16(激活和中间结果多为半精度),获得吞吐收益
- 参数(权重)通常仍以 FP32 master copy 维护(尤其是 FP16 训练),避免更新时数值分辨率不够
- 梯度在反向可能是半精度产生,但在更新前会做缩放/反缩放,保证不下溢
- 优化器状态(如 Adam 的 m、v)一般保持 FP32,更稳定
具体到 PyTorch AMP:
- 使用
autocast控制前向/反向中算子的 dtype - 使用
GradScaler做 loss scaling(主要针对 FP16)
2. FP16 vs BF16:选型与“何时用哪个”
2.1 FP16(IEEE half)
优点:
- 在很多 GPU 上有成熟的 Tensor Core 加速路径(吞吐提升明显)
- 显存占用低(相对 FP32)
缺点(也是你最常见的坑来源):
- 指数范围较小,容易 overflow(变 Inf)
- 有效精度较低,容易出现梯度下溢(变 0),需要 loss scaling
2.2 BF16(bfloat16)
优点:
- 指数范围接近 FP32,不容易 overflow,训练更“省心”
- 通常不需要 GradScaler(很多场景下可以直接跑)
缺点:
- 尾数精度比 FP16 还少一些,但大模型训练通常更看重指数范围
- 依赖硬件支持(如 A100/H100 等对 BF16 更友好)
2.3 实操建议:如何选择
- 硬件支持 BF16 且你追求稳定优先:优先 BF16(尤其是 Transformer 类模型)
- 只有 FP16 更快或 BF16 支持不佳:用 FP16 + AMP + GradScaler
- 你遇到频繁溢出/NaN,且硬件支持 BF16:从 FP16 切 BF16 往往立竿见影
3. AMP(Automatic Mixed Precision)工作机制:autocast 不是“全半精度”
PyTorch AMP 的核心是:
torch.autocast(device_type='cuda', dtype=...):在上下文内,框架会对算子进行“安全”的 dtype 选择- 对容易数值不稳定的算子可能仍保持 FP32
- 对 GEMM/Conv 等吞吐算子倾向用半精度
3.1 推荐用法:训练循环模板(FP16 + GradScaler)
下面模板适合你在大模型训练脚手架里直接替换:
3.1.1 单卡/单进程示例
- 关键点:
autocast包住 forward + loss 计算;GradScaler包住 backward 与 step - 关键点:梯度裁剪要放在
unscale_之后
示例步骤:
1) 初始化 scaler
2) forward 用 autocast
3) scaler.scale(loss).backward()
4) scaler.unscale_(optimizer) 后再 clip
5) scaler.step + scaler.update
3.1.2 BF16 示例(通常不需要 GradScaler)
- 关键点:
autocast(dtype=torch.bfloat16) - 关键点:多数情况下
GradScaler(enabled=False)或直接不用 scaler
4. GradScaler:它解决的到底是“溢出”还是“下溢”?
很多人误解 GradScaler 是“防溢出”。更准确地说:
- FP16 最大问题之一是梯度下溢(underflow):梯度非常小,半精度表示会变成 0
Loss scaling 的思路:
- 先把 loss 乘一个 scale(比如 2^16)
- 这样反向得到的梯度也被放大,避免在 FP16 表示中变 0
- 更新前再把梯度除回去(unscale)
同时,scale 太大可能导致梯度变 Inf/NaN(这才是 overflow),GradScaler 会检测到并自动缩小 scale,跳过这一步更新。
4.1 你应该关注的 scaler 现象
- scale 频繁下降:说明经常检测到 Inf/NaN,训练不稳定(可能 LR 太大、初始化/归一化问题、某些算子不稳定)
- scale 长期不变或缓慢上升:通常正常
- 大量 step 被跳过(skipped step):等价于有效学习率变小,训练会变慢甚至不收敛
5. 实操:在 PyTorch 中正确开启 FP16/BF16 AMP
5.1 基础训练伪代码(适用于大多数 Transformer)
你可以对照检查你的训练代码是否具备这些关键点:
autocast包 forward- FP16 使用
GradScaler zero_grad(set_to_none=True)- 梯度裁剪在
unscale_之后 - 遇到梯度累计时 scaler 逻辑仍正确
5.2 梯度累计(gradient accumulation)下的注意事项
如果你用 accum_steps > 1:
- backward 每次都要调用
scaler.scale(loss / accum_steps).backward()(或先除后 scale) - 只有在达到累积步时才
scaler.step(optimizer) - 同样只在 step 前做一次
unscale_和梯度裁剪
常见坑:
- 每个 micro-step 都 step,会破坏累计逻辑
- loss 未除以 accum_steps,等价扩大 LR
5.3 DDP/FSDP/DeepSpeed 简述(避免“讲目录”,只给落地要点)
- DDP:AMP 的用法基本不变,主要确认每个进程都创建自己的 scaler
- FSDP:如果启用混合精度策略,通常交给 FSDP 的 mixed_precision 配置;仍需理解某些层保持 FP32(如 LayerNorm)更稳
- DeepSpeed:如果用 ZeRO + fp16/bf16 配置,尽量让“一个地方决定精度”,避免 AMP 与 DeepSpeed fp16 同时混用导致行为混乱
6. 溢出/NaN/Inf 排查手册:从“症状”到“定位”
混合精度最常见的训练事故是:
- loss 变成 NaN
- 梯度出现 Inf
- loss 突然变 0 或不再下降
- scaler 不断降低 scale,step 大量被跳过
下面按“最快排除法”给出排查步骤。
6.1 第一步:确认是数值问题还是数据问题
1) 固定随机种子,取一个很小的数据子集(例如 8~32 条样本)
2) 在同一批数据上跑 50~200 step,观察:
- 是否必现 NaN
- 是否与某个 batch 强相关(可能是脏数据/异常值)
数据问题典型来源:
- 输入 token 全是 padding 或全是同一个 token(损失异常)
- label 越界(分类任务)或存在 NaN(回归任务)
- 序列长度极端(某些样本特别长导致激活爆炸)
6.2 第二步:快速切换精度验证
用“对照实验”快速定位问题层级:
- 改用 BF16(若可用):若问题消失,多半是 FP16 的动态范围问题
- 改用纯 FP32:若 FP32 仍 NaN,通常不是混合精度本身,而是 LR/损失实现/数据
- 关闭某些 fused/flash 算子(如 flash-attn、fused Adam):某些实现对极端值更敏感,先回退到标准实现验证
6.3 第三步:检查学习率、warmup、梯度裁剪
混合精度下,过大的 LR 更容易触发 Inf/NaN。你可以按以下顺序调整:
1) 把学习率降低 2~10 倍试跑 200 step
2) 增加 warmup(如 1%~5% 总步数)
3) 开启或收紧 梯度裁剪:clip_grad_norm_ 1.0 或 0.5
4) 若用 AdamW:检查 betas、eps(过小 eps 有时更不稳)
可操作建议:
- Transformer 预训练常用
clip_grad_norm_=1.0 - 当你看到 scaler 经常跳步:先降 LR,再看 scale 是否稳定
6.4 第四步:定位“第一个出现 NaN/Inf 的张量/层”
6.4.1 打开 anomaly detection(代价高,短跑用)
torch.autograd.set_detect_anomaly(True)- 用小数据跑几步,报错栈能指向可疑算子
6.4.2 在关键节点插入数值检查
对以下位置做 isfinite 检查,找到最早变坏的位置:
- embedding 输出
- attention logits(softmax 前)
- softmax 输出
- LayerNorm 输出
- loss 输入(logits / labels)
典型问题模式:
- attention logits 爆炸:mask 处理不当、缩放因子错误、序列长度极大
- softmax NaN:logits 中有 Inf 或极大值;可尝试在 FP32 下做 softmax(或用稳定实现)
6.5 第五步:GradScaler 相关的“假象”与修复
6.5.1 症状:scale 一直降、step 经常跳过
可能原因:
- LR 太大 / 没 warmup
- 梯度裁剪没在 unscale 后做(导致裁剪无效)
- 某个 batch 极端(数据异常)
修复清单:
- 确保顺序:
scaler.unscale_ -> clip -> scaler.step - 记录每步的
scaler.get_scale(),统计跳步比例 - 过滤异常样本(超长、异常 label)
6.5.2 症状:loss 变 0 或几乎不变
可能原因:
- underflow 导致梯度大量为 0(尤其是没用 scaler 的 FP16)
- 梯度累计时 loss 除法/缩放逻辑错误
修复:
- FP16 必须启用 GradScaler(除非你非常确定无需)
- 检查 loss 是否除以 accum_steps
- 输出梯度统计:非零比例、norm 分布
7. 训练稳定性“工程化”建议:让混合精度少踩坑
7.1 哪些算子/模块建议保持 FP32
在一些场景中,强行半精度会引发不稳定:
- LayerNorm / RMSNorm:很多框架会默认用 FP32 累加
- softmax/log-softmax:对数值稳定性敏感,可在 AMP 策略中让其在 FP32 计算
- 损失函数(特别是交叉熵的内部 logsumexp):保持 FP32 往往更稳
实操原则:
- 不要“手工把模型
.half()然后全跑 FP16”,优先用 AMP 让框架做选择
7.2 初始化与正则:小改动大收益
- 使用成熟初始化(Transformer 常用的默认初始化或 xavier/kaiming 的正确变体)
- 适当 weight decay(AdamW)
- label smoothing(分类)有时能降低极端 logit
7.3 日志监控:别等 NaN 了才看
建议最少记录:
- loss
- grad norm(裁剪前后)
- scaler scale(FP16)
max(abs(logits))或注意力 logits 的统计(可抽样)- 跳步次数(scaler overflow 次数)
这样你能在“变 NaN 前”的几百步看到征兆:grad norm 飙升、scale 连续下降等。
8. 一个可执行的排错流程(建议你贴到项目 Wiki)
当你在混合精度训练中遇到 NaN/Inf,按下面顺序做,基本能在 30~60 分钟内定位大类问题:
1) 小数据子集复现:固定种子 + 取 32 条样本跑 200 step
2) 切 FP32 验证:若 FP32 也炸,先别纠结 AMP,优先查 LR/数据/损失
3) 切 BF16 验证:若 BF16 正常、FP16 炸,优先认为是 FP16 范围/scale 问题
4) 检查训练循环顺序:autocast 范围、scaler 用法、unscale 后再裁剪、累计步逻辑
5) 降 LR + 加 warmup + 开裁剪:看 scale 是否稳定、跳步是否下降
6) 定位首个 NaN 节点:anomaly detection + isfinite 钩子
7) 回退可疑 fused/flash 算子:验证是否实现细节导致
9. 结语:把混合精度当成“性能功能 + 稳定性工程”
在 Ai 大模型训练落地中,混合精度不是锦上添花,而是训练效率的核心杠杆之一。正确姿势是:用 AMP 管理半精度算子选择;FP16 配 GradScaler 防止下溢并自动应对溢出;遇到 NaN 用对照实验(FP32/BF16)、日志与逐层定位把问题收敛到具体环节。
你可以把本文当作项目中的“混合精度运行手册”:先用推荐模板跑通,再用排查流程快速定位数值稳定性问题,最终实现“更快的吞吐 + 更稳定的训练”。
Prev:AdamW与学习率调度实战:Warmup、Cosine、Linear在大模型训练中的用法