本文为《Ai大模型训练教程》系列实战篇,详细讲解DeepSpeed ZeRO 1/2/3的显存节省原理、通信代价与适用场景,并给出可直接套用的配置示例与10个高频避坑排障建议,帮助你稳定训练更大模型或更大batch。
文章定位与前置假设
在《Ai大模型训练教程:从入门到实战落地的系统课程》系列中,这一篇聚焦 DeepSpeed ZeRO(Zero Redundancy Optimizer)1/2/3:它们各自“省显存”省在哪里、带来什么代价、以及实际配置时最常见的坑。
本文默认你已经能跑通分布式训练(至少知道 data parallel / DDP 的基本概念),并希望把同样的模型在更小显存或更大 batch 下稳定训练。
为什么 DDP 会“显存爆炸”:冗余来自哪三类东西
在最常见的 PyTorch DDP 里,每张卡上都会完整保存:
- 模型参数(Parameters):fp16/bf16 或 fp32(取决于配置)
- 梯度(Gradients):反向传播产生,通常与参数同尺寸
- 优化器状态(Optimizer States):以 Adam 为例,至少有
m、v两份动量,再加上 master weights(很多混合精度实现会保留 fp32 master param)
所以即使只做数据并行,每张卡都在重复保存一份“参数+梯度+优化器状态”。模型越大,这些冗余越致命。
粗略估算(以 Adam + 混合精度为例):
- 参数:fp16 2 bytes/param
- 梯度:fp16 2 bytes/param
- master 参数:fp32 4 bytes/param
- Adam m、v:各 fp32 4 bytes/param,共 8 bytes/param
合计约 2 + 2 + 4 + 8 = 16 bytes/param(还没算激活、临时 buffer)。
ZeRO 的核心就是:把这些可分片的东西在数据并行组内切开,每张卡只存其中 1/N,减少“冗余”。
ZeRO 总览:1/2/3 分别切哪一刀
可以用一句话记忆:
- ZeRO-1:切优化器状态(Optimizer State Partitioning)
- ZeRO-2:切优化器状态 + 切梯度(Grad Partitioning)
- ZeRO-3:优化器状态 + 梯度 + 参数 全切(Param Partitioning),训练时按需 all-gather
ZeRO-1:只分片优化器状态,最稳、最容易落地
省显存原理:
- Adam 的 m/v(以及可能的 fp32 master 权重)在数据并行组内做 partition
- 每张卡只保留 1/N 的 optimizer states
收益:
- 对 Adam 来说,优化器状态通常是最重的部分,ZeRO-1 就能显著降低显存
代价:
- 通信量增加不多,整体最接近 DDP
- 兼容性最好(对很多自定义模型/算子最友好)
适用建议:
- 你想“先稳住”,把同等模型的 batch 提上去,或在同显存下把模型做大一点
- 多数团队会先上 ZeRO-1 再逐步升级
ZeRO-2:再分片梯度,进一步缓解反向显存
省显存原理:
- 梯度在 reduce-scatter 后分片存放,每张卡最终只保留 1/N 的 grad partition
- 反向阶段的梯度聚合由 all-reduce 变为 reduce-scatter + all-gather(实现细节由 DeepSpeed 处理)
收益:
- 梯度也很大,尤其大模型下能再省一截
代价:
- 通信量和通信时序更复杂
- 更依赖良好的网络(IB/NVLink)与 bucket 设置
适用建议:
- ZeRO-1 仍然 OOM,且你确定主要是 optimizer+grad 压力
- 你有较稳定的网络与分布式环境
ZeRO-3:参数也分片,显存最省,但最容易踩坑
省显存原理(关键点):
- 参数不再每卡一份,而是每卡只存 1/N 的参数分片
- 前向/反向需要用到某层参数时,DeepSpeed 临时 all-gather 把该层参数拼齐到参与计算的 GPU,上下文用完再释放或重新分片
收益:
- 参数、梯度、优化器状态几乎都按 1/N 分摊,理论上显存最省
- 对“模型参数本身就装不下”的情况是救命方案
代价:
- 通信显著增多(按层 all-gather),对网络带宽和延迟敏感
- 对算子融合、checkpoint、参数访问方式更挑剔
- 配置项多,排障复杂
适用建议:
- 单卡/多卡都装不下参数,必须上 ZeRO-3
- 能接受调参、排坑,并具备较强的工程控制能力
显存节省“算账”:一个可落地的估算方法
你在选 ZeRO stage 前,建议先做一次估算,至少知道“瓶颈在参数、梯度、优化器状态、还是激活”。
步骤 1:估算每类占用
以参数量 P(单位:个参数)为基准:
- 参数(fp16/bf16):约
2 * Pbytes - 梯度(fp16/bf16):约
2 * Pbytes - Adam 状态(m/v,fp32):约
8 * Pbytes - master param(fp32,可选):约
4 * Pbytes
如果你用 bf16 且不保留 master param(某些实现可做到),会略不同,但“优化器状态最重”通常仍成立。
步骤 2:估算 ZeRO 后的理论变化(数据并行组大小 = N)
- ZeRO-1:optimizer state ~ 1/N;参数和梯度不变
- ZeRO-2:optimizer state ~ 1/N;梯度 ~ 1/N;参数不变
- ZeRO-3:optimizer state ~ 1/N;梯度 ~ 1/N;参数 ~ 1/N(但会有通信/临时 all-gather 峰值)
步骤 3:别忘了“真正的 OOM 常来自激活与临时张量”
即使 ZeRO 把参数相关显存压下去,
- 激活(activation) 仍可能因序列长度、batch、层数、注意力实现而爆
- 这时要结合 activation checkpointing、flash-attn、sequence parallel 等手段
DeepSpeed 配置模板(可直接改):从 ZeRO-1 到 ZeRO-3
下面用常见 JSON 配置字段说明(具体字段会随 DeepSpeed 版本略有差异,但核心一致)。
通用骨架:fp16/bf16 + 基础训练参数
{
"train_batch_size": 64,
"train_micro_batch_size_per_gpu": 2,
"gradient_accumulation_steps": 32,
"steps_per_print": 100,
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
}
}建议:
- 先把 micro-batch 调到不 OOM 的最小值,再用 accumulation 把全局 batch 拉起来
- loss_scale=0 让 DeepSpeed 自动动态缩放,初期更稳
ZeRO-1 配置示例
{
"zero_optimization": {
"stage": 1,
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_bucket_size": 500000000,
"allgather_bucket_size": 500000000
}
}关键点:
overlap_comm: 尝试通信与计算重叠,通常建议开启- bucket 太小通信频繁,太大可能峰值显存上升;可从 5e8(约 0.5GB)起试
ZeRO-2 配置示例
{
"zero_optimization": {
"stage": 2,
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_scatter": true,
"reduce_bucket_size": 500000000,
"allgather_bucket_size": 500000000
}
}关键点:
reduce_scatter开启后通常更符合 ZeRO-2 的通信方式- 若遇到吞吐不升反降,优先排查网络与 bucket、以及是否出现频繁的同步点
ZeRO-3 配置示例(含常见关键项)
{
"zero_optimization": {
"stage": 3,
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_scatter": true,
"stage3_prefetch_bucket_size": 500000000,
"stage3_param_persistence_threshold": 100000,
"stage3_max_live_parameters": 1000000000,
"stage3_max_reuse_distance": 1000000000,
"gather_16bit_weights_on_model_save": true
}
}字段解释与建议:
stage3_prefetch_bucket_size:预取参数的 bucket,大了可能更快但峰值更高;OOM 先降它stage3_param_persistence_threshold:小参数常驻可减少频繁 all-gather;但太大会增加常驻显存gather_16bit_weights_on_model_save:保存模型时把分片参数聚合为完整权重,便于下游加载(但保存时需要额外显存/时间)
训练脚本对接要点:让“配置生效”并可控
1)确认你是用 deepspeed 启动,而不是只用 torchrun
典型方式:
deepspeed --num_gpus=8 train.py --deepspeed ds_config.json
如果你用 Hugging Face Trainer:
- 在 TrainingArguments 里指定
deepspeed=ds_config.json
2)梯度累积与全局 batch 的一致性检查
确保:
train_batch_size = micro_batch_per_gpu * gradient_accumulation_steps * world_size
不一致会导致:
- 学习率调度与实际 batch 不匹配
- 日志里显示的 batch 概念混乱,影响复现
3)先用 ZeRO-1 跑通,再升 stage
实操策略:
- 单机多卡 + ZeRO-1:验证收敛、loss scale、吞吐
- 升 ZeRO-2:对比显存下降是否符合预期
- 需要时再上 ZeRO-3:一次只改一个关键参数,便于定位问题
配置避坑清单:最常见的 10 个问题与处理思路
坑 1:ZeRO-3 省了参数却依旧 OOM
原因常见于:
- 激活显存占比更大(长序列、attention 开销大)
- 临时 all-gather 峰值 + bucket 过大
处理:
- 开启 activation checkpointing(按层或按 block)
- 降低
stage3_prefetch_bucket_size、allgather_bucket_size - 优先上 FlashAttention / xFormers,降低 attention 激活
坑 2:吞吐骤降,GPU 利用率忽高忽低
原因:
- 通信成为瓶颈(尤其 ZeRO-3 分层 all-gather)
- bucket 设置不合理,导致频繁小通信
处理:
- 逐步增大 bucket(例如 1e8 -> 2e8 -> 5e8),观察吞吐与峰值显存
- 确认网络拓扑(NVLink/IB)与 NCCL 环境变量配置
坑 3:保存 checkpoint 时报错或巨慢
原因:
- ZeRO-3 参数分片,保存“完整权重”需要 gather
- IO 吞吐不足导致保存卡住
处理:
- 训练中间 checkpoint 用分片(更快),最终导出再 gather
- 开启/确认
gather_16bit_weights_on_model_save的策略符合你的加载需求 - 把 checkpoint 写到高吞吐存储(本地 NVMe > 网络盘)
坑 4:模型里手动访问 .param.data 或做奇怪的 in-place 操作
原因:
- ZeRO-3 会对参数生命周期做管理,某些手动操作破坏假设
处理:
- 避免对参数做 in-place 修改
- 自定义层里需要参数时遵循标准 forward 方式,不要在 forward 外缓存参数引用
坑 5:梯度裁剪(grad clipping)行为不符合预期
原因:
- ZeRO-2/3 下梯度是分片的,裁剪需要正确聚合范数
处理:
- 优先使用 DeepSpeed/HF 提供的内置裁剪配置或接口
- 不要自己在外面遍历
model.parameters()做 naive clipping
坑 6:混合精度下 loss 变 NaN 或震荡
原因:
- 动态 loss scale 不稳定、学习率过大、梯度溢出
处理:
- 降低学习率或 warmup 更长
- fp16 下保持
loss_scale=0动态缩放,观察溢出频率 - 能用 bf16 尽量用 bf16(对溢出更不敏感),前提是硬件支持
坑 7:发现显存“省了”,但显存碎片严重,偶发 OOM
原因:
- 频繁分配释放导致碎片(尤其 ZeRO-3 all-gather 临时 buffer)
处理:
- 尽量使用稳定的 batch/seq len,避免动态 shape
- 适当调大
contiguous_gradients、使用更连续的 buffer - 训练前做一次 profile warmup,让内存分配更稳定
坑 8:与梯度检查点(activation checkpointing)一起用时速度过慢
原因:
- checkpointing 会增加重算;ZeRO-3 又加重通信
处理:
- 优先对 attention/MLP block 做分段 checkpoint,而不是全模型无脑 checkpoint
- 在可接受显存下,宁愿用 ZeRO-2 + checkpointing,而不是 ZeRO-3 + 重度 checkpointing
坑 9:多节点训练时偶发 NCCL 超时
原因:
- 通信量大且同步点多,网络抖动/配置不当就容易超时
处理:
- 检查 IB/NCCL 配置、网卡绑定、NCCL 超时设置
- 降低通信压力(bucket 调整、必要时降低 ZeRO stage 或减少并行度)
坑 10:参数量很大但依然“看起来没省多少”
原因:
- 你省的是参数/优化器,但激活占比更大
- 或你观察的是
nvidia-smi的峰值/保留显存,而不是实际分解
处理:
- 用更细的工具(如 torch profiler、deepspeed memory breakdown)看参数/梯度/优化器/激活占比
- 用固定输入做可复现对比,避免动态 shape 影响
选择建议:如何在项目里做决策(可执行)
决策流程(推荐照着走)
- 先测 DDP 是否 OOM:记录 micro-batch=1 时的最大显存
- 如果主要卡在优化器状态(常见于 Adam):上 ZeRO-1
- 如果主要卡在梯度/反向显存:尝试 ZeRO-2
如果参数本体就装不下:直接 ZeRO-3,并配合
- 合理的 bucket
- activation checkpointing
- 更高效 attention 实现
- 吞吐优先:能 ZeRO-1/2 就别上 3;ZeRO-3 是“容量优先”的方案
一条经验线
- 你只是想把 batch 拉上去:ZeRO-1 往往就够
- 你要进一步扩大模型但还没到“参数装不下”:ZeRO-2 性价比很高
- 你要上超大模型(参数单卡完全不可能):ZeRO-3 + 工程化调参不可避免
小结:把 ZeRO 当作“显存预算器”来用
ZeRO 不是“开关一开就完事”,而是你在训练系统里做显存预算与通信预算的核心工具:
- ZeRO-1:最稳的第一步,立竿见影省优化器显存
- ZeRO-2:进一步省梯度,适合想扩大 batch/模型且网络条件不错的场景
- ZeRO-3:容量天花板方案,省得最多但通信与配置最复杂,务必按 bucket、保存策略、激活开销系统调优
在本系列后续文章里,如果你继续往“实战落地”推进,建议把 ZeRO 与 activation checkpointing、FlashAttention、以及不同并行策略(TP/PP)结合起来做端到端的吞吐与成本优化。
Prev:数据并行、张量并行、流水线并行怎么选:通信开销与适用规模