本篇Ai大模型训练教程详解如何看懂训练过程曲线:Loss与PPL的正确比较方式、Grad Norm定位梯度爆炸/消失、吞吐tokens/s拆解数据/计算/通信瓶颈,以及显存峰值与泄漏型OOM的监控与排查步骤,提供可落地的监控面板与实战诊断流程。
写在前面:为什么“看懂曲线”比“跑通训练”更重要
在 Ai大模型训练教程 系列里,很多同学把“训练能跑起来”当成终点,但真正决定模型能否稳定收敛、成本是否可控、上线质量是否可靠的,是你能不能在训练过程中读懂监控曲线,并据此快速定位问题。
这篇文章聚焦训练中最常见、也最容易误判的几类指标:Loss、PPL、Grad Norm、吞吐(Throughput)与显存(Memory)。我会按“指标含义→怎么看→异常模式→怎么处理”的方式给出可执行的检查步骤和调参建议,并给出一套可直接落地的监控配置思路(以 TensorBoard / W&B 这类为代表,工具可替换)。
你需要先统一的“坐标系”:Step、Token、Wall-time
在解释曲线之前,先把横轴的概念统一,否则你会经常出现“看起来变慢/变快,但其实只是横轴换了”的错觉。
H3 Step vs Token vs Wall-time
- Step(迭代步):每次参数更新算一步。对比不同实验时,若 batch size / grad accumulation 不同,step 的可比性会变差。
- Token(已训练 token 数):更接近“训练进度的真实尺度”,尤其是比较不同 batch 配置时。
- Wall-time(实际时间):用来衡量工程效率(吞吐、卡间通信、IO)。
H3 建议的统一口径
- 训练质量(收敛)相关曲线:优先用 Token 做横轴(或同时记录 step 与 token)。
- 工程性能相关曲线:用 Wall-time 做横轴,同时叠加吞吐、显存、GPU 利用率。
Loss:最常见,但也最容易“看错”
Loss 通常指训练目标的平均损失(例如 language modeling 的 cross entropy)。
H3 Loss 应该怎么看
- 先看训练 Loss(train loss)是否稳定下降:大模型预训练时通常是缓慢下降,不会像小模型那样陡。
- 再看验证 Loss(eval/val loss)是否同步下降:验证集更能反映泛化。
- 看波动幅度与趋势:大 batch 往往更平滑,小 batch 更抖;但“抖”并不等于坏,关键看长期趋势。
H3 常见“异常形态”与处理
1)Loss 长时间不降或下降很慢
可能原因:
- 学习率太小
- warmup 太长/调度不当
- 数据质量或 tokenization 问题(大量无效文本、重复、乱码)
- 训练被混合精度溢出/梯度裁剪过度压扁
处理步骤:
- 把横轴切到 Token,确认不是“训练 token 太少”。
- 打印/抽样检查训练 batch 文本:是否大量重复、空行、异常符号。
- 记录并对比 lr 曲线 与 loss:如果 warmup 期间 loss 完全不动,尝试缩短 warmup 或提高 base lr。
- 查看 grad norm 是否长期很小(见后文),若是,可能裁剪阈值太低或 lr 太低。
2)Loss 突然爆炸(瞬间飙升到很大)
可能原因:
- 学习率过大或调度跳变
- 混合精度(fp16/bf16)下发生溢出
- 梯度累积配置/有效 batch 突变
- 数据中出现极端长文本或异常样本
处理步骤(按优先级):
- 查看同一时刻的 grad norm、loss scale(如果使用动态 loss scaling)、是否出现 overflow。
- 将 lr 临时降低 2~10 倍复现一次;若恢复正常,基本锁定 lr 过大或调度问题。
- 检查该 step 的输入长度分布(max seq len 是否突然增大)。
- 对异常 batch 做保存与回放(把该 step 的样本 dump 出来),定位数据问题。
3)训练 Loss 降,但验证 Loss 不降或上升
这通常是过拟合或数据分布不一致:
- 预训练:可能训练数据过脏、重复高,导致“记忆化”
- 微调:可能数据量小、学习率偏大、正则不足
处理建议:
- 增加验证频率,确认趋势真实。
- 微调场景:降低 lr、增加 weight decay、启用/增大 dropout(若结构支持),或使用更强数据增强。
- 预训练场景:提升去重、清洗策略;增加多样性数据。
PPL(Perplexity):把 Loss 变成“更直观的难度指标”
语言模型中常用:
- 若 loss 是平均交叉熵(以自然对数为底),则 PPL = exp(loss)。
H3 PPL 怎么看才不踩坑
- 只在同一种 loss 定义/同一 tokenizer/同一验证集下比较 PPL。不同 tokenization 会导致 PPL 绝对值不可比。
- PPL 对 loss 的变化非常敏感:loss 小幅下降,PPL 可能明显下降。
- 训练初期 PPL 通常下降很快,后期会进入“缓慢爬坡式”优化,这是正常现象。
H3 实操建议:用 PPL 做里程碑检查
- 预训练:固定一个“健康验证集”(高质量、分布稳定),每隔 N tokens 评估一次 PPL。
- 微调:除了通用验证集,也可以加一份“目标域验证集”,对比 PPL 变化是否真的朝目标域提升。
Grad Norm:诊断“训练稳不稳”的第一抓手
Grad Norm(梯度范数)能回答:当前更新步幅是否异常?是否在爆炸/消失?
H3 Grad Norm 正常是什么样
- 通常在一个相对稳定的区间上下波动。
- 训练早期可能更大、波动更明显;稳定后会进入较窄波动区间。
H3 三类典型异常与定位路径
1)Grad Norm 逐步变大,随后 Loss 爆炸
这是一种“渐进式梯度爆炸”。
处理:
- 启用梯度裁剪(clip grad norm),例如 1.0/0.5/2.0 需要结合任务试。
- 降低学习率或使用更平滑的调度(cosine/linear with warmup)。
- 检查是否在某个阶段引入了更长序列或更难数据。
2)Grad Norm 长期极小,Loss 几乎不动
常见于:
- 学习率过小
- 梯度裁剪阈值太低(把梯度“剪没了”)
- 混合精度下数值问题(尤其 fp16)
处理:
- 临时关闭/放宽梯度裁剪阈值,观察 grad norm 是否恢复。
- 增大学习率或减少 warmup。
- 优先使用 bf16(若硬件支持)比 fp16 更稳;或检查 loss scaling 是否频繁触发 overflow。
3)Grad Norm 频繁尖刺(spike),但 Loss 还能下降
可能是正常的“难样本/长序列”导致,也可能是数据脏点。
处理:
- 记录 spike 对应的 batch:长度、是否包含异常字符、是否超长。
- 若 spike 与吞吐下降同时发生,可能是某些 batch 过长造成的。
- 对超长样本做截断策略或长度分桶(bucketing)。
吞吐(Throughput):从“训练快不快”到“钱烧得值不值”
吞吐常用单位:
- tokens/s(每秒 token 数)
- samples/s(每秒样本数)
- 或每 GPU 的 tokens/s
H3 吞吐监控的核心是分解瓶颈
把吞吐拆成三段最有效:
- 数据输入:dataloader/IO 是否卡住
- 计算:GPU 利用率、算子效率、是否频繁同步
- 通信:多卡 all-reduce、ZeRO、pipeline 并行的通信开销
H3 实操:吞吐下降的排查清单
1)吞吐呈“锯齿形”周期性下降
常见原因:
- 周期性 eval/保存 checkpoint
- dataloader 预取不足、周期性 cache miss
建议:
- 把 eval、save 的耗时单独打点记录(如 eval_time、save_time)。
- 增加 dataloader workers、启用 prefetch/persistent_workers。
2)吞吐突然腰斩并维持低位
常见原因:
- GPU 降频(温度/功耗)
- 集群网络异常
- 混合精度退化或 fallback
- 某个节点性能异常(多机训练中很常见)
建议:
- 同步记录 GPU util、SM clock、power、temperature。
- 多机时对比每个 rank 的 step time,定位慢节点。
3)吞吐随序列长度波动很大
这是正常现象,但可以工程优化:
- 用 长度分桶(按长度分 batch)减少 padding 造成的算力浪费。
- 设定最大长度并对极端长样本做截断或单独队列。
显存(Memory):决定你能不能上更大模型、更长上下文
显存监控至少要区分三类:
- 参数显存(parameters)
- 梯度与优化器状态(grads/optimizer states)
- 激活(activations,通常与 seq len、batch size 强相关)
H3 你该监控哪些显存指标
- allocated / reserved(已分配/已保留)
- max memory allocated(峰值)
- OOM 前的“爬升轨迹”(是否逐步泄漏)
H3 三种典型显存曲线问题
1)显存稳定但很高,离 OOM 很近
处理优先级:
- 开启/增大 gradient checkpointing(激活重算)换显存。
- 降低 micro-batch size,用 gradient accumulation 保持有效 batch。
- 使用更省显存的优化器方案(如 ZeRO / offload,或更轻量的 optimizer 配置)。
- 缩短 seq len 或启用 packed sequences(把多个短样本拼到一个序列,减少 padding)。
2)显存逐步爬升(疑似泄漏),最终 OOM
常见原因:
- 日志/评估时保存了计算图(没 detach/没 no_grad)
- 缓存列表不断 append(例如保存每步输出)
- 自定义 callback/metric 里引用了大 tensor
排查步骤:
- 关闭 eval 与可视化输出,观察是否还爬升。
- 检查是否在训练循环里把 tensor 存进 Python list。
- eval 使用
no_grad(),并确保只保存标量。
3)显存周期性尖刺
常见原因:
- eval 使用更大 batch 或更长序列
- checkpoint 保存前后临时拷贝
建议:
- 让 eval 与 train 使用一致的 seq len 上限与 batch 策略。
- checkpoint 采用异步保存或只保存必要权重。
建议的“监控面板”最小集合(可直接照抄)
为了让你在一次训练中快速定位问题,建议把面板分成四组:
H3 1)收敛质量(核心)
- train loss(建议滑动平均)
- eval/val loss
- eval ppl
- lr(学习率曲线)
H3 2)稳定性(防炸)
- grad norm(全局或分层)
- overflow/NaN 计数(如有)
- clip ratio(被裁剪比例,若实现支持)
H3 3)性能(省钱)
- tokens/s(全局与 per GPU)
- step time(p50/p90)
- GPU utilization、data time、compute time(能拆就拆)
H3 4)资源(防 OOM)
- max memory allocated
- allocated/reserved
- CPU 内存与 dataloader 队列长度(如果数据链路复杂)
一套“读曲线→下结论→做实验”的实战流程
把问题诊断流程固定下来,你会越来越快。
H3 Step 1:先判断是“收敛问题”还是“工程问题”
- Loss/PPL 不正常,但吞吐与显存正常:多半是超参/数据问题。
- Loss 正常下降,但吞吐掉、step time 抖:多半是工程瓶颈。
- Loss 直接爆炸且伴随 grad norm 上天:稳定性问题(lr、溢出、异常 batch)。
H3 Step 2:做最小对照实验(不要一次改太多)
建议每次只改一个变量:
- lr ×0.5 或 ×0.2(验证是否过大)
- 开/关 gradient clipping(验证是否梯度爆炸)
- 固定 seq len(排除长度波动)
- 固定数据子集(排除数据脏点与分布漂移)
H3 Step 3:把关键点“打点记录”,让复盘可重复
每次出现异常时,你至少应该能回答:
- 异常发生在多少 token?对应哪个 checkpoint?
- 同步出现变化的指标有哪些(loss、grad norm、throughput、memory、lr)?
- 对应 batch 的长度分布与样本是否异常?
结语:用指标建立训练的“仪表盘直觉”
在大模型训练里,Loss/PPL 决定你有没有在学习,Grad Norm 决定你能不能稳定学下去,吞吐与显存决定你能不能以合理成本学完。把这些曲线当作“仪表盘”,你会从“跑实验”升级为“控训练”。
后续如果你愿意进一步把监控做到可运维级别,可以在团队里建立统一的面板模板、异常报警阈值(如 grad norm 尖刺、吞吐腰斩、显存爬升)与自动保存异常 batch 的机制,这会显著降低大规模训练的试错成本。
Prev:训练中断如何无损恢复:Checkpoint策略、随机种子、数据顺序与可复现性