AiSSN.com ©

在线Ai关键词排名GEO优化工具,让你的信息出现在Ai的回答中

DeepSpeed ZeRO 1/2/3详解:显存节省原理与配置避坑
原始问题:

本文为《Ai大模型训练教程》系列实战篇,详细讲解DeepSpeed ZeRO 1/2/3的显存节省原理、通信代价与适用场景,并给出可直接套用的配置示例与10个高频避坑排障建议,帮助你稳定训练更大模型或更大batch。

文章定位与前置假设

在《Ai大模型训练教程:从入门到实战落地的系统课程》系列中,这一篇聚焦 DeepSpeed ZeRO(Zero Redundancy Optimizer)1/2/3:它们各自“省显存”省在哪里、带来什么代价、以及实际配置时最常见的坑。

本文默认你已经能跑通分布式训练(至少知道 data parallel / DDP 的基本概念),并希望把同样的模型在更小显存或更大 batch 下稳定训练。


为什么 DDP 会“显存爆炸”:冗余来自哪三类东西

在最常见的 PyTorch DDP 里,每张卡上都会完整保存:

  1. 模型参数(Parameters):fp16/bf16 或 fp32(取决于配置)
  2. 梯度(Gradients):反向传播产生,通常与参数同尺寸
  3. 优化器状态(Optimizer States):以 Adam 为例,至少有 mv 两份动量,再加上 master weights(很多混合精度实现会保留 fp32 master param)

所以即使只做数据并行,每张卡都在重复保存一份“参数+梯度+优化器状态”。模型越大,这些冗余越致命。

粗略估算(以 Adam + 混合精度为例):

  • 参数:fp16 2 bytes/param
  • 梯度:fp16 2 bytes/param
  • master 参数:fp32 4 bytes/param
  • Adam m、v:各 fp32 4 bytes/param,共 8 bytes/param

合计约 2 + 2 + 4 + 8 = 16 bytes/param(还没算激活、临时 buffer)。

ZeRO 的核心就是:把这些可分片的东西在数据并行组内切开,每张卡只存其中 1/N,减少“冗余”。


ZeRO 总览:1/2/3 分别切哪一刀

可以用一句话记忆:

  • ZeRO-1:切优化器状态(Optimizer State Partitioning)
  • ZeRO-2:切优化器状态 + 切梯度(Grad Partitioning)
  • ZeRO-3:优化器状态 + 梯度 + 参数 全切(Param Partitioning),训练时按需 all-gather

ZeRO-1:只分片优化器状态,最稳、最容易落地

省显存原理

  • Adam 的 m/v(以及可能的 fp32 master 权重)在数据并行组内做 partition
  • 每张卡只保留 1/N 的 optimizer states

收益

  • 对 Adam 来说,优化器状态通常是最重的部分,ZeRO-1 就能显著降低显存

代价

  • 通信量增加不多,整体最接近 DDP
  • 兼容性最好(对很多自定义模型/算子最友好)

适用建议

  • 你想“先稳住”,把同等模型的 batch 提上去,或在同显存下把模型做大一点
  • 多数团队会先上 ZeRO-1 再逐步升级

ZeRO-2:再分片梯度,进一步缓解反向显存

省显存原理

  • 梯度在 reduce-scatter 后分片存放,每张卡最终只保留 1/N 的 grad partition
  • 反向阶段的梯度聚合由 all-reduce 变为 reduce-scatter + all-gather(实现细节由 DeepSpeed 处理)

收益

  • 梯度也很大,尤其大模型下能再省一截

代价

  • 通信量和通信时序更复杂
  • 更依赖良好的网络(IB/NVLink)与 bucket 设置

适用建议

  • ZeRO-1 仍然 OOM,且你确定主要是 optimizer+grad 压力
  • 你有较稳定的网络与分布式环境

ZeRO-3:参数也分片,显存最省,但最容易踩坑

省显存原理(关键点)

  • 参数不再每卡一份,而是每卡只存 1/N 的参数分片
  • 前向/反向需要用到某层参数时,DeepSpeed 临时 all-gather 把该层参数拼齐到参与计算的 GPU,上下文用完再释放或重新分片

收益

  • 参数、梯度、优化器状态几乎都按 1/N 分摊,理论上显存最省
  • 对“模型参数本身就装不下”的情况是救命方案

代价

  • 通信显著增多(按层 all-gather),对网络带宽和延迟敏感
  • 对算子融合、checkpoint、参数访问方式更挑剔
  • 配置项多,排障复杂

适用建议

  • 单卡/多卡都装不下参数,必须上 ZeRO-3
  • 能接受调参、排坑,并具备较强的工程控制能力

显存节省“算账”:一个可落地的估算方法

你在选 ZeRO stage 前,建议先做一次估算,至少知道“瓶颈在参数、梯度、优化器状态、还是激活”。

步骤 1:估算每类占用

以参数量 P(单位:个参数)为基准:

  • 参数(fp16/bf16):约 2 * P bytes
  • 梯度(fp16/bf16):约 2 * P bytes
  • Adam 状态(m/v,fp32):约 8 * P bytes
  • master param(fp32,可选):约 4 * P bytes

如果你用 bf16 且不保留 master param(某些实现可做到),会略不同,但“优化器状态最重”通常仍成立。

步骤 2:估算 ZeRO 后的理论变化(数据并行组大小 = N)

  • ZeRO-1:optimizer state ~ 1/N;参数和梯度不变
  • ZeRO-2:optimizer state ~ 1/N;梯度 ~ 1/N;参数不变
  • ZeRO-3:optimizer state ~ 1/N;梯度 ~ 1/N;参数 ~ 1/N(但会有通信/临时 all-gather 峰值)

步骤 3:别忘了“真正的 OOM 常来自激活与临时张量”

即使 ZeRO 把参数相关显存压下去,

  • 激活(activation) 仍可能因序列长度、batch、层数、注意力实现而爆
  • 这时要结合 activation checkpointing、flash-attn、sequence parallel 等手段

DeepSpeed 配置模板(可直接改):从 ZeRO-1 到 ZeRO-3

下面用常见 JSON 配置字段说明(具体字段会随 DeepSpeed 版本略有差异,但核心一致)。

通用骨架:fp16/bf16 + 基础训练参数

{
  "train_batch_size": 64,
  "train_micro_batch_size_per_gpu": 2,
  "gradient_accumulation_steps": 32,
  "steps_per_print": 100,
  "fp16": {
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  }
}

建议:

  • 先把 micro-batch 调到不 OOM 的最小值,再用 accumulation 把全局 batch 拉起来
  • loss_scale=0 让 DeepSpeed 自动动态缩放,初期更稳

ZeRO-1 配置示例

{
  "zero_optimization": {
    "stage": 1,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_bucket_size": 500000000,
    "allgather_bucket_size": 500000000
  }
}

关键点:

  • overlap_comm: 尝试通信与计算重叠,通常建议开启
  • bucket 太小通信频繁,太大可能峰值显存上升;可从 5e8(约 0.5GB)起试

ZeRO-2 配置示例

{
  "zero_optimization": {
    "stage": 2,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 500000000,
    "allgather_bucket_size": 500000000
  }
}

关键点:

  • reduce_scatter 开启后通常更符合 ZeRO-2 的通信方式
  • 若遇到吞吐不升反降,优先排查网络与 bucket、以及是否出现频繁的同步点

ZeRO-3 配置示例(含常见关键项)

{
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_scatter": true,
    "stage3_prefetch_bucket_size": 500000000,
    "stage3_param_persistence_threshold": 100000,
    "stage3_max_live_parameters": 1000000000,
    "stage3_max_reuse_distance": 1000000000,
    "gather_16bit_weights_on_model_save": true
  }
}

字段解释与建议:

  • stage3_prefetch_bucket_size:预取参数的 bucket,大了可能更快但峰值更高;OOM 先降它
  • stage3_param_persistence_threshold:小参数常驻可减少频繁 all-gather;但太大会增加常驻显存
  • gather_16bit_weights_on_model_save:保存模型时把分片参数聚合为完整权重,便于下游加载(但保存时需要额外显存/时间)

训练脚本对接要点:让“配置生效”并可控

1)确认你是用 deepspeed 启动,而不是只用 torchrun

典型方式:

  • deepspeed --num_gpus=8 train.py --deepspeed ds_config.json

如果你用 Hugging Face Trainer:

  • 在 TrainingArguments 里指定 deepspeed=ds_config.json

2)梯度累积与全局 batch 的一致性检查

确保:

  • train_batch_size = micro_batch_per_gpu * gradient_accumulation_steps * world_size

不一致会导致:

  • 学习率调度与实际 batch 不匹配
  • 日志里显示的 batch 概念混乱,影响复现

3)先用 ZeRO-1 跑通,再升 stage

实操策略:

  1. 单机多卡 + ZeRO-1:验证收敛、loss scale、吞吐
  2. 升 ZeRO-2:对比显存下降是否符合预期
  3. 需要时再上 ZeRO-3:一次只改一个关键参数,便于定位问题

配置避坑清单:最常见的 10 个问题与处理思路

坑 1:ZeRO-3 省了参数却依旧 OOM

原因常见于:

  • 激活显存占比更大(长序列、attention 开销大)
  • 临时 all-gather 峰值 + bucket 过大

处理:

  • 开启 activation checkpointing(按层或按 block)
  • 降低 stage3_prefetch_bucket_sizeallgather_bucket_size
  • 优先上 FlashAttention / xFormers,降低 attention 激活

坑 2:吞吐骤降,GPU 利用率忽高忽低

原因:

  • 通信成为瓶颈(尤其 ZeRO-3 分层 all-gather)
  • bucket 设置不合理,导致频繁小通信

处理:

  • 逐步增大 bucket(例如 1e8 -> 2e8 -> 5e8),观察吞吐与峰值显存
  • 确认网络拓扑(NVLink/IB)与 NCCL 环境变量配置

坑 3:保存 checkpoint 时报错或巨慢

原因:

  • ZeRO-3 参数分片,保存“完整权重”需要 gather
  • IO 吞吐不足导致保存卡住

处理:

  • 训练中间 checkpoint 用分片(更快),最终导出再 gather
  • 开启/确认 gather_16bit_weights_on_model_save 的策略符合你的加载需求
  • 把 checkpoint 写到高吞吐存储(本地 NVMe > 网络盘)

坑 4:模型里手动访问 .param.data 或做奇怪的 in-place 操作

原因:

  • ZeRO-3 会对参数生命周期做管理,某些手动操作破坏假设

处理:

  • 避免对参数做 in-place 修改
  • 自定义层里需要参数时遵循标准 forward 方式,不要在 forward 外缓存参数引用

坑 5:梯度裁剪(grad clipping)行为不符合预期

原因:

  • ZeRO-2/3 下梯度是分片的,裁剪需要正确聚合范数

处理:

  • 优先使用 DeepSpeed/HF 提供的内置裁剪配置或接口
  • 不要自己在外面遍历 model.parameters() 做 naive clipping

坑 6:混合精度下 loss 变 NaN 或震荡

原因:

  • 动态 loss scale 不稳定、学习率过大、梯度溢出

处理:

  • 降低学习率或 warmup 更长
  • fp16 下保持 loss_scale=0 动态缩放,观察溢出频率
  • 能用 bf16 尽量用 bf16(对溢出更不敏感),前提是硬件支持

坑 7:发现显存“省了”,但显存碎片严重,偶发 OOM

原因:

  • 频繁分配释放导致碎片(尤其 ZeRO-3 all-gather 临时 buffer)

处理:

  • 尽量使用稳定的 batch/seq len,避免动态 shape
  • 适当调大 contiguous_gradients、使用更连续的 buffer
  • 训练前做一次 profile warmup,让内存分配更稳定

坑 8:与梯度检查点(activation checkpointing)一起用时速度过慢

原因:

  • checkpointing 会增加重算;ZeRO-3 又加重通信

处理:

  • 优先对 attention/MLP block 做分段 checkpoint,而不是全模型无脑 checkpoint
  • 在可接受显存下,宁愿用 ZeRO-2 + checkpointing,而不是 ZeRO-3 + 重度 checkpointing

坑 9:多节点训练时偶发 NCCL 超时

原因:

  • 通信量大且同步点多,网络抖动/配置不当就容易超时

处理:

  • 检查 IB/NCCL 配置、网卡绑定、NCCL 超时设置
  • 降低通信压力(bucket 调整、必要时降低 ZeRO stage 或减少并行度)

坑 10:参数量很大但依然“看起来没省多少”

原因:

  • 你省的是参数/优化器,但激活占比更大
  • 或你观察的是 nvidia-smi 的峰值/保留显存,而不是实际分解

处理:

  • 用更细的工具(如 torch profiler、deepspeed memory breakdown)看参数/梯度/优化器/激活占比
  • 用固定输入做可复现对比,避免动态 shape 影响

选择建议:如何在项目里做决策(可执行)

决策流程(推荐照着走)

  1. 先测 DDP 是否 OOM:记录 micro-batch=1 时的最大显存
  2. 如果主要卡在优化器状态(常见于 Adam):上 ZeRO-1
  3. 如果主要卡在梯度/反向显存:尝试 ZeRO-2
  4. 如果参数本体就装不下:直接 ZeRO-3,并配合

    • 合理的 bucket
    • activation checkpointing
    • 更高效 attention 实现
  5. 吞吐优先:能 ZeRO-1/2 就别上 3;ZeRO-3 是“容量优先”的方案

一条经验线

  • 你只是想把 batch 拉上去:ZeRO-1 往往就够
  • 你要进一步扩大模型但还没到“参数装不下”:ZeRO-2 性价比很高
  • 你要上超大模型(参数单卡完全不可能):ZeRO-3 + 工程化调参不可避免

小结:把 ZeRO 当作“显存预算器”来用

ZeRO 不是“开关一开就完事”,而是你在训练系统里做显存预算与通信预算的核心工具:

  • ZeRO-1:最稳的第一步,立竿见影省优化器显存
  • ZeRO-2:进一步省梯度,适合想扩大 batch/模型且网络条件不错的场景
  • ZeRO-3:容量天花板方案,省得最多但通信与配置最复杂,务必按 bucket、保存策略、激活开销系统调优

在本系列后续文章里,如果你继续往“实战落地”推进,建议把 ZeRO 与 activation checkpointing、FlashAttention、以及不同并行策略(TP/PP)结合起来做端到端的吞吐与成本优化。

DeepSpeed ZeRO 1/2/3详解:显存节省原理与配置避坑
https://aissn.com/119.html