AiSSN.com ©

在线Ai关键词排名GEO优化工具,让你的信息出现在Ai的回答中

混合精度训练指南:FP16、BF16、AMP、GradScaler与溢出排查
原始问题:

面向Ai大模型训练教程的实操指南,系统讲解FP16与BF16差异、PyTorch AMP用法、GradScaler与loss scaling机制,并提供NaN/Inf/溢出与跳步的排查流程、日志指标与稳定性建议,帮助大模型训练提速并保持收敛稳定。

混合精度训练指南:FP16、BF16、AMP、GradScaler与溢出排查

在《Ai大模型训练教程:从入门到实战落地的系统课程》里,混合精度训练几乎是你从“能跑起来”到“跑得快、跑得稳、能扩展”的必经之路。大模型训练的瓶颈常常在显存和吞吐:显存不够就得减 batch、减序列长度、改模型结构;吞吐不够则训练周期拉长。混合精度(Mixed Precision)通过在不显著牺牲精度的前提下,把一部分计算/存储从 FP32 换成 FP16 或 BF16,通常能带来更高吞吐、更低显存占用。

这篇文章聚焦实操:如何在 PyTorch 中正确使用 FP16/BF16、AMP、GradScaler;遇到溢出(overflow)、NaN、loss 变成 0 或不收敛时如何系统排查。


1. 混合精度到底混了什么:权重、激活、梯度、优化器状态

混合精度不是“把模型全改成半精度”这么简单。一个常见且稳定的做法是:

  • 前向与大部分算子用 FP16/BF16(激活和中间结果多为半精度),获得吞吐收益
  • 参数(权重)通常仍以 FP32 master copy 维护(尤其是 FP16 训练),避免更新时数值分辨率不够
  • 梯度在反向可能是半精度产生,但在更新前会做缩放/反缩放,保证不下溢
  • 优化器状态(如 Adam 的 m、v)一般保持 FP32,更稳定

具体到 PyTorch AMP:

  • 使用 autocast 控制前向/反向中算子的 dtype
  • 使用 GradScaler 做 loss scaling(主要针对 FP16)

2. FP16 vs BF16:选型与“何时用哪个”

2.1 FP16(IEEE half)

优点

  • 在很多 GPU 上有成熟的 Tensor Core 加速路径(吞吐提升明显)
  • 显存占用低(相对 FP32)

缺点(也是你最常见的坑来源):

  • 指数范围较小,容易 overflow(变 Inf)
  • 有效精度较低,容易出现梯度下溢(变 0),需要 loss scaling

2.2 BF16(bfloat16)

优点

  • 指数范围接近 FP32,不容易 overflow,训练更“省心”
  • 通常不需要 GradScaler(很多场景下可以直接跑)

缺点

  • 尾数精度比 FP16 还少一些,但大模型训练通常更看重指数范围
  • 依赖硬件支持(如 A100/H100 等对 BF16 更友好)

2.3 实操建议:如何选择

  • 硬件支持 BF16 且你追求稳定优先:优先 BF16(尤其是 Transformer 类模型)
  • 只有 FP16 更快或 BF16 支持不佳:用 FP16 + AMP + GradScaler
  • 你遇到频繁溢出/NaN,且硬件支持 BF16:从 FP16 切 BF16 往往立竿见影

3. AMP(Automatic Mixed Precision)工作机制:autocast 不是“全半精度”

PyTorch AMP 的核心是:

  • torch.autocast(device_type='cuda', dtype=...):在上下文内,框架会对算子进行“安全”的 dtype 选择

    • 对容易数值不稳定的算子可能仍保持 FP32
    • 对 GEMM/Conv 等吞吐算子倾向用半精度

3.1 推荐用法:训练循环模板(FP16 + GradScaler)

下面模板适合你在大模型训练脚手架里直接替换:

3.1.1 单卡/单进程示例

  • 关键点:autocast 包住 forward + loss 计算;GradScaler 包住 backward 与 step
  • 关键点:梯度裁剪要放在 unscale_ 之后

示例步骤:

1) 初始化 scaler
2) forward 用 autocast
3) scaler.scale(loss).backward()
4) scaler.unscale_(optimizer) 后再 clip
5) scaler.step + scaler.update

3.1.2 BF16 示例(通常不需要 GradScaler)

  • 关键点:autocast(dtype=torch.bfloat16)
  • 关键点:多数情况下 GradScaler(enabled=False) 或直接不用 scaler

4. GradScaler:它解决的到底是“溢出”还是“下溢”?

很多人误解 GradScaler 是“防溢出”。更准确地说:

  • FP16 最大问题之一是梯度下溢(underflow):梯度非常小,半精度表示会变成 0
  • Loss scaling 的思路:

    • 先把 loss 乘一个 scale(比如 2^16)
    • 这样反向得到的梯度也被放大,避免在 FP16 表示中变 0
    • 更新前再把梯度除回去(unscale)

同时,scale 太大可能导致梯度变 Inf/NaN(这才是 overflow),GradScaler 会检测到并自动缩小 scale,跳过这一步更新。

4.1 你应该关注的 scaler 现象

  • scale 频繁下降:说明经常检测到 Inf/NaN,训练不稳定(可能 LR 太大、初始化/归一化问题、某些算子不稳定)
  • scale 长期不变或缓慢上升:通常正常
  • 大量 step 被跳过(skipped step):等价于有效学习率变小,训练会变慢甚至不收敛

5. 实操:在 PyTorch 中正确开启 FP16/BF16 AMP

5.1 基础训练伪代码(适用于大多数 Transformer)

你可以对照检查你的训练代码是否具备这些关键点:

  • autocast 包 forward
  • FP16 使用 GradScaler
  • zero_grad(set_to_none=True)
  • 梯度裁剪在 unscale_ 之后
  • 遇到梯度累计时 scaler 逻辑仍正确

5.2 梯度累计(gradient accumulation)下的注意事项

如果你用 accum_steps > 1

  • backward 每次都要调用 scaler.scale(loss / accum_steps).backward()(或先除后 scale)
  • 只有在达到累积步时才 scaler.step(optimizer)
  • 同样只在 step 前做一次 unscale_ 和梯度裁剪

常见坑:

  • 每个 micro-step 都 step,会破坏累计逻辑
  • loss 未除以 accum_steps,等价扩大 LR

5.3 DDP/FSDP/DeepSpeed 简述(避免“讲目录”,只给落地要点)

  • DDP:AMP 的用法基本不变,主要确认每个进程都创建自己的 scaler
  • FSDP:如果启用混合精度策略,通常交给 FSDP 的 mixed_precision 配置;仍需理解某些层保持 FP32(如 LayerNorm)更稳
  • DeepSpeed:如果用 ZeRO + fp16/bf16 配置,尽量让“一个地方决定精度”,避免 AMP 与 DeepSpeed fp16 同时混用导致行为混乱

6. 溢出/NaN/Inf 排查手册:从“症状”到“定位”

混合精度最常见的训练事故是:

  • loss 变成 NaN
  • 梯度出现 Inf
  • loss 突然变 0 或不再下降
  • scaler 不断降低 scale,step 大量被跳过

下面按“最快排除法”给出排查步骤。

6.1 第一步:确认是数值问题还是数据问题

1) 固定随机种子,取一个很小的数据子集(例如 8~32 条样本)
2) 在同一批数据上跑 50~200 step,观察:

  • 是否必现 NaN
  • 是否与某个 batch 强相关(可能是脏数据/异常值)

数据问题典型来源:

  • 输入 token 全是 padding 或全是同一个 token(损失异常)
  • label 越界(分类任务)或存在 NaN(回归任务)
  • 序列长度极端(某些样本特别长导致激活爆炸)

6.2 第二步:快速切换精度验证

用“对照实验”快速定位问题层级:

  • 改用 BF16(若可用):若问题消失,多半是 FP16 的动态范围问题
  • 改用纯 FP32:若 FP32 仍 NaN,通常不是混合精度本身,而是 LR/损失实现/数据
  • 关闭某些 fused/flash 算子(如 flash-attn、fused Adam):某些实现对极端值更敏感,先回退到标准实现验证

6.3 第三步:检查学习率、warmup、梯度裁剪

混合精度下,过大的 LR 更容易触发 Inf/NaN。你可以按以下顺序调整:

1) 把学习率降低 2~10 倍试跑 200 step
2) 增加 warmup(如 1%~5% 总步数)
3) 开启或收紧 梯度裁剪clip_grad_norm_ 1.0 或 0.5
4) 若用 AdamW:检查 betaseps(过小 eps 有时更不稳)

可操作建议:

  • Transformer 预训练常用 clip_grad_norm_=1.0
  • 当你看到 scaler 经常跳步:先降 LR,再看 scale 是否稳定

6.4 第四步:定位“第一个出现 NaN/Inf 的张量/层”

6.4.1 打开 anomaly detection(代价高,短跑用)

  • torch.autograd.set_detect_anomaly(True)
  • 用小数据跑几步,报错栈能指向可疑算子

6.4.2 在关键节点插入数值检查

对以下位置做 isfinite 检查,找到最早变坏的位置:

  • embedding 输出
  • attention logits(softmax 前)
  • softmax 输出
  • LayerNorm 输出
  • loss 输入(logits / labels)

典型问题模式:

  • attention logits 爆炸:mask 处理不当、缩放因子错误、序列长度极大
  • softmax NaN:logits 中有 Inf 或极大值;可尝试在 FP32 下做 softmax(或用稳定实现)

6.5 第五步:GradScaler 相关的“假象”与修复

6.5.1 症状:scale 一直降、step 经常跳过

可能原因:

  • LR 太大 / 没 warmup
  • 梯度裁剪没在 unscale 后做(导致裁剪无效)
  • 某个 batch 极端(数据异常)

修复清单:

  • 确保顺序:scaler.unscale_ -> clip -> scaler.step
  • 记录每步的 scaler.get_scale(),统计跳步比例
  • 过滤异常样本(超长、异常 label)

6.5.2 症状:loss 变 0 或几乎不变

可能原因:

  • underflow 导致梯度大量为 0(尤其是没用 scaler 的 FP16)
  • 梯度累计时 loss 除法/缩放逻辑错误

修复:

  • FP16 必须启用 GradScaler(除非你非常确定无需)
  • 检查 loss 是否除以 accum_steps
  • 输出梯度统计:非零比例、norm 分布

7. 训练稳定性“工程化”建议:让混合精度少踩坑

7.1 哪些算子/模块建议保持 FP32

在一些场景中,强行半精度会引发不稳定:

  • LayerNorm / RMSNorm:很多框架会默认用 FP32 累加
  • softmax/log-softmax:对数值稳定性敏感,可在 AMP 策略中让其在 FP32 计算
  • 损失函数(特别是交叉熵的内部 logsumexp):保持 FP32 往往更稳

实操原则:

  • 不要“手工把模型 .half() 然后全跑 FP16”,优先用 AMP 让框架做选择

7.2 初始化与正则:小改动大收益

  • 使用成熟初始化(Transformer 常用的默认初始化或 xavier/kaiming 的正确变体)
  • 适当 weight decay(AdamW)
  • label smoothing(分类)有时能降低极端 logit

7.3 日志监控:别等 NaN 了才看

建议最少记录:

  • loss
  • grad norm(裁剪前后)
  • scaler scale(FP16)
  • max(abs(logits)) 或注意力 logits 的统计(可抽样)
  • 跳步次数(scaler overflow 次数)

这样你能在“变 NaN 前”的几百步看到征兆:grad norm 飙升、scale 连续下降等。


8. 一个可执行的排错流程(建议你贴到项目 Wiki)

当你在混合精度训练中遇到 NaN/Inf,按下面顺序做,基本能在 30~60 分钟内定位大类问题:

1) 小数据子集复现:固定种子 + 取 32 条样本跑 200 step
2) 切 FP32 验证:若 FP32 也炸,先别纠结 AMP,优先查 LR/数据/损失
3) 切 BF16 验证:若 BF16 正常、FP16 炸,优先认为是 FP16 范围/scale 问题
4) 检查训练循环顺序:autocast 范围、scaler 用法、unscale 后再裁剪、累计步逻辑
5) 降 LR + 加 warmup + 开裁剪:看 scale 是否稳定、跳步是否下降
6) 定位首个 NaN 节点:anomaly detection + isfinite 钩子
7) 回退可疑 fused/flash 算子:验证是否实现细节导致


9. 结语:把混合精度当成“性能功能 + 稳定性工程”

在 Ai 大模型训练落地中,混合精度不是锦上添花,而是训练效率的核心杠杆之一。正确姿势是:用 AMP 管理半精度算子选择;FP16 配 GradScaler 防止下溢并自动应对溢出;遇到 NaN 用对照实验(FP32/BF16)、日志与逐层定位把问题收敛到具体环节。

你可以把本文当作项目中的“混合精度运行手册”:先用推荐模板跑通,再用排查流程快速定位数值稳定性问题,最终实现“更快的吞吐 + 更稳定的训练”。

混合精度训练指南:FP16、BF16、AMP、GradScaler与溢出排查
https://aissn.com/95.html