《Ai大模型训练教程》实战篇:显存不够时如何组合使用激活检查点、Offload、梯度累积与序列并行。文章给出适用场景判断、推荐启用顺序、组合方案与排障方法,帮助在单卡/多卡与长上下文训练中有效降低显存峰值并跑通训练。
场景与目标:显存不够时,你到底在缺什么?
在大模型训练中,“显存不够”往往不是一个单点问题,而是 模型参数、梯度、优化器状态、激活值(activation)、临时张量 共同叠加导致的。
为了把问题拆开,我们先用一个粗略但实用的记忆账本来定位开销:
- 参数(weights):FP16/BF16 通常 2 Bytes/参数。
- 梯度(grads):同样约 2 Bytes/参数(若梯度也 FP16/BF16)。
- 优化器状态(Adam):通常有 m、v 两个动量,FP32 下约 8 Bytes/参数(各 4 Bytes),再加上 master weight(FP32)4 Bytes/参数,合计可能到 12 Bytes/参数。
- 激活值(activations):与 batch size、sequence length、隐藏维度、层数强相关,是最“波动”的部分,也是最常见的爆点。
因此解决显存问题常见方向有两类:
- 压参数/优化器/梯度相关:如 ZeRO、8-bit 优化器、Offload。
- 压激活相关:如激活检查点(activation checkpointing)、序列并行、缩短序列、flash attention。
本文聚焦四个可组合的实战策略:激活检查点、Offload、梯度累积、序列并行,并给出可落地的组合与排障流程。
快速诊断:你应该先用哪一招?
在动手前先回答三个问题,能避免“乱配参数越调越炸”。
1)报错发生在什么时候?
- forward 就 OOM:多半是激活、注意力矩阵或序列太长导致;优先考虑 激活检查点、序列并行、降低 micro-batch、flash-attn。
- backward 或 optimizer.step 才 OOM:更可能是梯度、优化器状态、参数副本导致;优先考虑 Offload、ZeRO、梯度累积。
- 偶发性 OOM:常见于显存碎片或动态 shape;需要 固定 shape、gradient checkpointing、torch.compile/alloc 配置 等。
2)你是否已经在用 DDP / FSDP / DeepSpeed?
- 单卡:梯度累积、激活检查点是最稳的第一步;Offload 也可用但会明显慢。
- 多卡 DDP:序列并行或 ZeRO/FSDP 才能进一步扩展;梯度累积依旧关键。
- DeepSpeed/FSDP:Offload 与 ZeRO stage 的组合通常更顺手。
3)你的瓶颈是“显存”还是“吞吐”?
- 如果是“训练跑不起来”:优先保命策略(能跑最重要)。
- 如果是“能跑但太慢”:Offload 要谨慎,更多用 checkpointing + 梯度累积 + 并行策略平衡。
策略一:激活检查点(Activation Checkpointing)——用算力换显存
激活检查点的核心思想:
forward 时不保存(或少保存)中间激活,backward 时再把必要的 forward 重算一遍。
它能省多少?
对 Transformer 来说,激活常常占大头。开启 checkpointing 后,显存通常能下降 20%~50%+(取决于实现、层切分粒度、注意力实现等),代价是训练耗时上升(常见 10%~40%)。
适用与不适用
- 适用:长序列、层数深、batch 稍大导致的 OOM。
- 不适用:已经极限算力瓶颈(比如 GPU 利用率长期 99% 且训练很慢),再加 checkpoint 会更慢。
实操(Hugging Face Transformers)
1)开启梯度检查点:
model.gradient_checkpointing_enable()- Trainer 中可配
gradient_checkpointing=True
2)注意事项:
部分模型需要
use_cache=False(尤其是带 KV cache 的结构),否则可能与 checkpointing 冲突或无效:model.config.use_cache = False
- 与 FlashAttention/SDPA 同时使用一般没问题,但某些组合可能引入数值差异或不支持,需逐步验证。
粒度建议:按“层”checkpoint
- 最简单:对每一层 Transformer block 做 checkpoint。
- 更激进:把 attention 与 MLP 分开 checkpoint(更省显存但重算更多)。
常见坑
- loss 变 NaN:通常不是 checkpoint 本身,而是混合精度溢出;检查 GradScaler、学习率、bf16/fp16 设置。
- 速度掉太多:尝试更粗粒度 checkpoint;或搭配梯度累积减少 micro-batch 的通信/调度开销。
策略二:Offload(CPU/NVMe)——把“存不下的”挪出去
Offload 指把部分张量从 GPU 显存搬到 CPU 内存甚至 NVMe。
Offload 常见对象:
- 优化器状态 offload(最常见,收益大)
- 参数 offload(更省显存,但通信更频繁,慢得明显)
- 梯度 offload(视框架支持与配置)
Offload 的核心代价
- PCIe/NVLink 传输带宽与延迟会成为瓶颈
- CPU 内存压力增大
- 训练吞吐明显下降(尤其参数 offload)
因此 Offload 的定位更像:
“让训练能跑起来”的兜底方案,而不是最高性价比的加速手段。
推荐路径:优先 Offload 优化器状态
对 Adam 类优化器,状态体积很大。把它 offload 到 CPU,往往能立刻救活训练。
实操(DeepSpeed ZeRO-Offload 典型思路)
你可以在 DeepSpeed 配置里做:
- ZeRO Stage 2/3
offload_optimizer到 CPU(pin_memory=true常有帮助)- 视情况再启
offload_param(更激进)
组合经验:
- 先 Stage 2 + offload_optimizer:通常速度还能接受。
- 还不够再上 Stage 3;最后才考虑 offload_param。
性能优化建议
1)CPU 内存要足:offload 后内存占用可能是显存的数倍。
2)使用 pin_memory:减少数据搬运开销。
3)减少通信频次:配合 梯度累积(后文)减少每 step 的同步。
4)尽量避免 NVMe offload(除非不得已):NVMe 延迟更高,会显著拖慢。
策略三:梯度累积(Gradient Accumulation)——用更多 step 换更大有效 batch
当你降低 batch size 来避免 OOM 后,训练可能不稳定或吞吐下降。梯度累积可以让你:
- micro-batch(单次 forward/backward 的 batch)变小以省显存
- 通过累积多个 micro-step 的梯度,形成更大的 global batch
基本公式
global_batch = micro_batch * grad_accum_steps * data_parallel_size
你要做的是:
- 将
micro_batch调到能跑起来的最小值(常见从 1 或 2 开始)。 - 用
grad_accum_steps把 global batch 补回去。
实操建议
1)学习率如何调?
如果你保持 global batch 不变(只是 micro-batch 变小、累积步数变大),通常学习率不必改。
如果 global batch 变大或变小,可以参考线性缩放经验:
lr_new = lr_old * (global_batch_new / global_batch_old)
但大模型上更稳妥做法是:只做小幅调整,并观察 loss 曲线与梯度范数。
2)什么时候同步梯度?
在 DDP 下,累积期间要避免每个 micro-step 都 AllReduce,否则通信开销大。
实践中常用:
- 只有在最后一个累积步执行同步(框架/Trainer 通常已处理)
3)对显存的影响
梯度累积主要降低的是 激活峰值(因为 micro-batch 小了),但参数/优化器状态不会因为累积而显著下降。因此当你 OOM 出现在 optimizer.step() 附近时,梯度累积不一定救得动,需要 Offload/ZeRO。
策略四:序列并行(Sequence Parallelism)——把长序列的激活/计算沿序列维切开
当 sequence length 很长(比如 8k/16k/32k),注意力和激活会急剧膨胀。序列并行的思路是:
把序列维度(tokens)切分到多张 GPU 上,每张卡只处理一段 token,从而降低单卡激活与注意力相关开销。
它与数据并行不同:
- 数据并行(DP):每卡一份完整序列与模型,处理不同样本。
- 序列并行(SP):每卡处理同一批样本的一部分 token(序列切分)。
适用场景
- 长上下文训练(长文档、代码仓库、多轮对话拼接)
- 模型本身不算太大但 seq 很长导致 OOM
- 你希望在不降低 seq_len 的情况下跑起来
实操落地:优先使用框架能力
序列并行实现较复杂,建议优先使用成熟框架(如 Megatron-LM/NeMo、部分 FSDP/DeepSpeed 的相关能力,或特定长上下文训练栈)。不同框架命名不同,但你可以按以下“落地检查清单”来判断是否真正生效:
- 单卡显存是否随
seq_len增加变得更平缓(而不是线性/超线性暴涨) - 通信是否增加(token 切分后需要跨卡聚合某些操作)
- 吞吐下降是否在可接受范围(SP 一般会增加通信)
与注意力优化的关系
如果你已启用:
- FlashAttention / SDPA
序列并行仍可能带来额外收益,因为它不仅影响注意力计算,还影响许多与序列长度相关的激活保存。
组合策略:从“能跑”到“跑得快”的推荐顺序
四种策略不是互斥的。实战中最有效的是组合拳,但要按“收益/代价比”分阶段推进。
组合 1(最常用、性价比高):Checkpointing + 梯度累积
适合:单卡或小规模多卡,主要是激活爆了。
步骤:
- 打开激活检查点(按层)。
- 把 micro-batch 调小到能稳定 forward/backward。
- 用梯度累积把 global batch 补回去。
- 观察吞吐与 loss:如果慢太多,再调整 checkpoint 粒度。
组合 2(optimizer OOM 的救命组合):Offload(optimizer) + 梯度累积
适合:optimizer.step() 或 backward 末尾爆显存,模型参数/优化器状态顶满。
步骤:
- 启用 ZeRO(若可用)并优先 offload_optimizer。
- micro-batch 降到 1~2,打开梯度累积。
- 若仍 OOM,再加 checkpointing;最后才考虑 offload_param。
组合 3(长上下文主战场):序列并行 + Checkpointing + 梯度累积
适合:seq_len 很长是第一矛盾。
步骤:
- 在框架中启用序列并行(确保真的按 token 切分)。
- 同时启用激活检查点,避免激活峰值。
- micro-batch 降低后用梯度累积补 global batch。
- 若还不够,才考虑 Offload(通常只 offload optimizer)。
一个可直接套用的“调参顺序”
当你拿到一个 OOM 的训练任务,可按以下顺序快速收敛:
- 先降 micro-batch 到能跑(从 1 开始)。
- 开 激活检查点。
- 用 梯度累积 恢复 global batch。
- 仍 OOM:如果是 optimizer 相关,启 Offload optimizer(或 ZeRO/FSDP)。
- 仍 OOM 且 seq_len 很长:上 序列并行(或减少 seq_len/采用分段训练)。
示例:一个“显存卡死”的训练配置如何拆解
假设你要训练一个 7B~13B 级别模型,目标 seq_len=8192,单卡 24GB 显存经常 OOM。
你可以这样做一个可执行的落地方案:
第一步:让它先跑起来
micro_batch=1gradient_accumulation_steps=16(先把 global batch 做到可用)- 启用
gradient_checkpointing use_cache=false
此时大概率 forward/backward 可以过。
第二步:把 seq_len 压力拆开(如果你有多卡)
- 启用序列并行,让每卡处理 1/2 或 1/4 的 tokens
- 继续保留 checkpointing
你会看到单卡激活显存显著下降,但通信开销增加。
第三步:optimizer 仍爆时上 Offload
- ZeRO Stage 2 + offload_optimizer 到 CPU
- 如仍不够:Stage 3
到这一步,通常“训练能跑”问题解决,但速度可能下降。接下来就进入性能优化(比如 flash attention、减少 offload 范围、调大 micro-batch、减少累积步数等)。
验证与排障:如何判断哪招生效了?
1)记录显存峰值位置
用以下方式定位峰值在 forward 还是 backward:
- 在关键位置打印
torch.cuda.max_memory_allocated() - 或使用
torch.cuda.memory_summary() - 配合 profiler(PyTorch Profiler / Nsight)看分配热点
2)逐项开关做 A/B Test
一次只改一项,避免“组合开了但不知道谁起作用”。建议顺序:
- micro-batch → checkpointing → grad accumulation → offload → 序列并行
3)观察吞吐与稳定性
- checkpointing:显存下降、step 时间上升
- offload:显存明显下降、step 时间显著上升,CPU 内存上升
- grad accumulation:显存下降(来自 micro-batch 下降)、但每 optimizer step 时间更长
- 序列并行:显存下降(与 seq_len 相关)、通信与 step 时间上升
最后的实用建议:别忽视“免费显存”的细节
虽然本文聚焦四大策略,但在实操中有几条“几乎零成本”的细节经常决定成败:
- 优先用 BF16(若硬件支持):比 FP16 更稳,溢出少。
- 启用 FlashAttention/SDPA:长序列注意力显存与速度会明显改善。
- 避免动态 padding 导致 shape 抖动:尽量 bucket/pack,减少碎片与峰值。
- 把日志、评估、保存 checkpoint 的 batch 调小:很多 OOM 发生在 eval 或 save 时。
小结:四招各司其职,组合拳才是常态
- 激活检查点:最通用的“用算力换显存”,优先级高。
- Offload:把优化器/参数挪到 CPU/NVMe,能救命但会慢。
- 梯度累积:用更多 micro-step 换大有效 batch,常与前两者搭配。
- 序列并行:长上下文训练的关键能力,适合 seq_len 成为主矛盾时。
在《Ai大模型训练教程》系列的实战落地中,推荐你用“先跑起来→再提速→再扩展上下文/规模”的节奏逐层推进:先 checkpointing+累积稳住,再按瓶颈选择 offload 或序列并行,最后再做性能优化与成本控制。
Prev:多机训练网络排障:NCCL常见错误、拓扑选择与带宽瓶颈定位