本文属于《Ai大模型训练教程》系列,围绕提升大模型训练吞吐的三类实用技巧展开:梯度累积实现更大有效batch、FlashAttention/SDPA加速注意力并节省显存、以及torch.compile与算子融合等编译加速手段。提供可落地的配置思路、实现要点与常见问题排查。
提升训练吞吐的实用技巧:Gradient Accumulation、FlashAttention与编译加速
在《Ai大模型训练教程:从入门到实战落地的系统课程》系列里,很多读者做到能“跑起来”后,下一关往往是“跑得快、跑得稳、跑得省”。大模型训练吞吐(throughput)上不去,常见表现是:显存不够导致 batch 很小、GPU 利用率低、attention 成为瓶颈、训练一步耗时巨大、同样的预算跑不出足够 token。
这一篇聚焦三类立竿见影且工程上最常用的吞吐提升手段:
- Gradient Accumulation(梯度累积):在显存不够时“模拟更大的 batch”。
- FlashAttention:用更高效的 attention kernel 提升速度并节省显存。
- 编译加速(torch.compile / CUDA Graph / fused kernels):减少 Python 调度开销,融合算子,提高算子级吞吐。
文章会尽量用可直接落地的步骤、配置与排错建议,帮助你把“训练吞吐”从玄学变成可控的工程指标。
1. 先把吞吐指标和瓶颈定位清楚
吞吐优化前,先明确你优化的对象是什么,否则容易“优化错方向”。建议至少记录下面三类指标:
1.1 三个最常用指标
- tokens/s(或 samples/s):最直观。大模型预训练/指令微调都建议用 tokens/s(尤其是可变长度)。
- step time(每步耗时):便于对比启用某个优化前后的变化。
GPU 利用率与显存曲线:
- 利用率长期 < 70%:多半是数据加载/CPU瓶颈/小 batch/频繁同步。
- 显存接近上限且有抖动:可能在反向图或 attention 上爆显存。
1.2 快速定位瓶颈的方法(建议照做)
- 先看数据侧:DataLoader worker、pin_memory、prefetch 等是否足够;I/O 是否拖后腿。
- 再看计算侧:用 PyTorch Profiler / Nsight Systems 观察最耗时的 kernel,注意是否 attention 占比很高。
- 再看同步点:频繁
torch.cuda.synchronize()、频繁打印 loss、过于频繁的all_reduce都会拉低吞吐。
如果你发现 attention 占比很高、显存也吃紧,那么 FlashAttention 往往是第一优先级;如果显存限制导致 batch 太小,那么梯度累积是最常用的工程解法;如果 GPU 利用率不高但并非数据瓶颈,则编译加速/算子融合可能更有效。
2. Gradient Accumulation:显存不够时“凑”出大 batch
梯度累积的核心目的:在不增加显存峰值的情况下,得到更大的有效 batch size(global batch)。
2.1 概念与公式
通常我们把 batch 分成三层:
- micro batch:一次前向+反向真正喂给 GPU 的 batch(受显存限制)。
- gradient accumulation steps(累积步数):累积多少次 micro batch 后再更新一次参数。
- global batch:一次优化器 step 对应的总样本量。
公式:
global_batch = micro_batch * accumulation_steps * data_parallel_world_size
例如:
- 单卡 micro_batch=2,accumulation=8,8 卡数据并行:global=288=128。
2.2 什么时候必须用梯度累积
- 模型大、序列长、显存紧,导致 micro batch 只能是 1 或 2。
- 想保持训练稳定性(大 batch 更平滑),但硬件不允许直接增大 batch。
2.3 关键实现细节(避免“算错梯度”)
(1) loss 归一化
若你在每个 micro step 都 loss.backward(),常见做法是:
- 把 loss 除以 accumulation_steps,保证累积后的梯度尺度与直接大 batch 一致。
伪代码(PyTorch 风格):
- 每次 micro batch:
loss = loss / grad_accum_steps loss.backward()- 每累积到
grad_accum_steps:optimizer.step(); optimizer.zero_grad()
若你使用 Hugging Face Trainer / Accelerate,通常会自动处理,但你仍需确认日志里显示的 loss 是否已被平均。
(2) 梯度同步时机(分布式训练尤其重要)
在 DDP(DistributedDataParallel)下,如果每个 micro step 都触发 all-reduce,会浪费大量通信时间。
正确做法:
- 累积期间使用
no_sync(),只在最后一个 micro step 同步梯度。
这样会显著降低通信开销,提高吞吐。
(3) 学习率与调度器的“step 频率”
使用梯度累积后,optimizer.step 的次数减少。
你需要确保:
- 学习率 scheduler 的 step 是按 optimizer step 走,而不是按 micro step 走。
- warmup_steps、total_steps 的计算基于 optimizer step 数。
否则会出现:warmup 过短/过长、lr 下降节奏错乱,训练不稳定。
(4) 梯度裁剪(clip grad)放在哪里
- 应当在真正
optimizer.step()前裁剪一次(即在累积完成后),而不是每个 micro step 裁剪。
2.4 梯度累积的副作用与应对
更慢的参数更新频率:同样的 epoch,optimizer step 变少。
- 应对:关注 token 数而非 step 数,必要时增加训练总 token。
更长的“一个 step”时间:因为一个 optimizer step 包含多个 micro step。
- 这不一定是坏事,吞吐(tokens/s)可能更高。
梯度噪声尺度变化:理论上大 batch 改变优化动力学。
- 应对:尝试线性缩放学习率或使用更稳健的优化器/调度。
2.5 一个可操作的调参顺序
当你显存不足时,建议按以下顺序尝试(从“性价比最高”到“侵入性更强”):
- 开启混合精度(bf16 优先,其次 fp16)。
- 开启梯度检查点(checkpointing)降低激活显存。
- 把 micro batch 降到能跑的最小值。
- 通过 Gradient Accumulation 把 global batch 拉回目标区间。
3. FlashAttention:把 attention 从瓶颈变成优势
对于 Transformer,attention 通常既耗时又吃显存,尤其在长序列下。FlashAttention 的价值在于:
- 通过更好的内存访问与 kernel 融合,显著减少 HBM 读写。
- 避免显式构建巨大的 attention 矩阵,显存更省。
- 训练吞吐常见提升 20%~数倍(视序列长度、模型结构、硬件而定)。
3.1 你适合用 FlashAttention 的典型信号
- profile 显示
scaled_dot_product_attention或相关 attention kernel 占比很高。 - seq_len 较长(例如 2k、4k、8k),吞吐随长度增长下降明显。
- 显存卡在 attention 的中间张量上。
3.2 PyTorch 原生 SDPA:优先试“最省心”的路径
从 PyTorch 2.x 开始,推荐优先走 Scaled Dot Product Attention(SDPA),它会在可用时自动选择更快的实现(可能包括 FlashAttention 类内核)。
实操建议:
- 使用较新的 PyTorch(2.1+ 通常更稳,越新对 SDPA/Flash 支持越完善)。
- 确保模型的 attention 实现走
torch.nn.functional.scaled_dot_product_attention。 开启/确认 backend:
- 允许 flash kernel
- 允许 mem-efficient
- 允许 math(作为 fallback)
如果你的框架支持配置(如某些 Transformer 实现),优先用“flash/sdpa”开关,而不要自己手写内核。
3.3 使用 flash-attn 库:适合追求极致与特定模型实现
在部分模型(尤其是自定义 attention、或想要更稳定的 flash kernel)上,会使用 flash-attn 第三方库。
落地建议:
- 版本匹配非常关键:CUDA、PyTorch、显卡架构(A100/H100/3090/4090)与 flash-attn wheel 是否兼容。
- 安装后要确认模型代码确实调用了 flash-attn 的 attention 模块,而不是仍然走原生实现。
3.4 常见“启用了但没变快”的原因排查
没走到 flash kernel:实际 fallback 到 math kernel。
- 解决:打开 debug 日志/检查 profiler 中 kernel 名称;确认 dtype、head_dim 等条件满足。
- 序列太短:flash 的优势在长序列更明显,短序列提升有限。
- 被其他瓶颈掩盖:如数据加载、通信、embedding/MLP 才是主要耗时。
- 开启了不兼容的 mask/attention 形式:某些复杂 mask、滑窗注意力、特殊位置编码实现可能导致 fallback。
3.5 FlashAttention 与训练稳定性/数值注意点
- 建议优先使用 bf16(若硬件支持),通常比 fp16 更稳。
对于极端长序列或非常大的 logits,注意是否出现 NaN/Inf:
- 检查 gradient scaling(fp16 时)
- 检查是否有异常样本导致 loss 爆炸
- 必要时开启更保守的数值选项或回退到 mem-efficient 实现
4. 编译加速:torch.compile、算子融合与减少 Python 开销
当你已经解决显存与 attention 后,仍然可能遇到 GPU 利用率上不去、kernel 很碎、CPU 调度开销大的情况。此时编译加速往往能带来额外收益。
4.1 torch.compile:最通用的“开关式”加速
torch.compile(PyTorch 2.x)通过动态图捕获与图级优化,把许多小算子融合、减少 Python 解释器开销。
落地步骤建议:
- 先保证模型在 eager 模式稳定训练(loss 正常下降、无 NaN)。
- 对模型主体
model = torch.compile(model, mode=...)(不同 mode 在速度/编译时长/稳定性间权衡)。 - 首次迭代会有编译开销,吞吐评估需跳过前若干 step。
适用场景:
- 小 kernel 多、Python 调度重的训练脚本。
- 非常标准的 Transformer/MLP 堆叠结构。
不适用或收益小的场景:
- 动态控制流很重(可变 shape 大量出现)。
- 模型里混入了无法捕获/无法编译的自定义 CUDA op。
4.2 编译加速的常见坑与处理建议
(1) 动态 shape 导致频繁重新编译
表现:训练过程中吞吐忽高忽低,日志显示不断 recompilation。
处理建议:
- 尽量 固定 seq_len(padding 到固定长度或按桶分组)。
- 减少 batch 内 shape 变化。
- 对可变长度任务,用 bucketed batching:把相近长度样本分到同一 batch。
(2) 训练中断点/调试变困难
编译后堆栈更复杂,不利于逐行调试。
处理建议:
- 调试阶段先关 compile;稳定后再开。
- 将 compile 封装成可切换开关,便于回退。
(3) 数值差异
编译与融合可能带来极小数值差异,通常可接受,但对某些极端情况可能触发不稳定。
处理建议:
- 首次启用后对比:loss 曲线、梯度是否 NaN、评估指标是否一致。
- 若问题出现,尝试更保守的模式或只编译部分模块。
4.3 CUDA Graph:适合“形状固定、重复性强”的训练
CUDA Graph 可以把一段 GPU 工作流捕获成图,减少 CPU 发射 kernel 的开销。
适合:
- 固定 batch、固定 seq_len、固定计算图的训练(例如预训练阶段、长度固定的数据管线)。
不适合:
- 每步 shape 变化大、控制流变化多的任务。
如果你的训练是固定长度、固定 micro batch,且 profile 显示 CPU 开销明显,CUDA Graph 常能带来额外吞吐提升。
4.4 Fused Kernels:把“零碎小算子”变成“大算子”
典型融合点:
- fused Adam/AdamW(优化器融合)
- fused layernorm / RMSNorm
- fused dropout + add + layernorm(某些框架提供)
- fused MLP(如 GeLU/SiLU 与线性层融合)
落地建议:
- 优先使用成熟训练框架/库提供的 fused 组件(如 Apex、xFormers、TransformerEngine、DeepSpeed 等视栈而定)。
- 评估收益要以 tokens/s 为准,避免只看单个 kernel 变快但整体未提升。
5. 把三招组合起来:一套可复用的“吞吐优化流程”
为了避免盲目堆开关,建议按下面顺序做 A/B 测试,每次只改一个变量,记录 tokens/s、显存峰值与稳定性:
5.1 基线准备
- 固定数据集子集(例如 1万条样本)用于重复测试。
- 固定 seq_len 策略(padding 或 bucket)。
- 固定日志:记录 step time、tokens/s、max memory allocated。
5.2 优化顺序建议(高成功率路径)
- 混合精度(bf16):几乎必开。
- FlashAttention / SDPA:attention 是大头时最值。
- Gradient Accumulation:把 global batch 提到目标范围,提升训练稳定性与吞吐(尤其在多卡时减少通信频率的收益更明显)。
- torch.compile:在模型与数据 shape 相对稳定时追加收益。
- (可选)fused optimizer / CUDA Graph:对“固定图训练”继续榨干。
5.3 一个具体示例(便于你照抄思路)
假设你在 8xA100 上做指令微调:
- 目标:global batch=128,seq_len=4096
- 受限:单卡 micro_batch=1 才不 OOM
可行配置:
- micro_batch=1
- grad_accum_steps=16
- data_parallel=8
- global=1168=128
然后:
- attention 用 SDPA/Flash
- bf16 训练
- 固定长度或 bucket batching 减少 shape 抖动
- 最后尝试 torch.compile
你应当观察到:显存不再爆、吞吐明显上升、loss 曲线更平滑。
6. 常见问题清单(快速对照)
6.1 开了梯度累积后,loss 变大/变小很多
优先检查:
- 是否对 loss 做了
/ grad_accum_steps。 - 日志记录的是“每个 micro step 的 loss”还是“累积后的 loss”。
6.2 开了 FlashAttention 反而变慢
优先检查:
- 是否真的走到 flash kernel(profiler 验证)。
- seq_len 是否足够长(短序列提升有限)。
- 是否被 dataloader 或通信瓶颈掩盖。
6.3 torch.compile 训练中途开始卡顿
优先检查:
- 是否频繁触发重新编译(动态 shape)。
- 是否存在某些 step 走了不同分支(控制流变化)。
7. 小结:把吞吐当成可工程化的指标
在 Ai 大模型训练落地中,“吞吐”直接决定你能用同样预算跑多少 token、试多少组超参、多久能交付一个可用模型。最实用的三类提速手段可以概括为:
- Gradient Accumulation:解决显存限制下的 batch 问题,同时减少分布式同步频率。
- FlashAttention(或 SDPA):Transformer 训练最常见的核心提速点,兼顾速度与显存。
- 编译加速(torch.compile / CUDA Graph / fused):减少 kernel 碎片与 Python 调度,进一步榨干 GPU。
建议你用“固定基线 + 单变量 A/B + profiler 验证”的方式,把每个开关的收益量化出来,最终形成你团队可复用的训练配方。下一篇如果你希望继续深入,我们可以围绕:如何用 profiler 给出可复现的吞吐报告、如何对齐不同框架(HF/DeepSpeed/Megatron)的吞吐对比方法。
Prev:训练过程如何看懂曲线:Loss、PPL、Grad Norm、吞吐与显存的监控方法