ZeRO 优化器 Premium渐进式去冗余,从优化器状态到参数的三级分片
Companion Code 上一章我们看到 DDP 的内存问题:为了保证训练一致性(通过 All-Reduce 同步梯度),每个 GPU 都需要存储完整的模型状态。4 个 GPU 就是 4 份完整副本(参数、梯度、优化器状态)。ZeRO(Zero Redundancy Optimizer)的核心思想很直接:既然最终状态是一致的,那就每个 GPU 只存一部分 ,需要的时候再通信取回。
先量化一下 DDP 的浪费。以混合精度 + Adam 为例,N N N 个 GPU 训练一个 Φ \Phi Φ 参数的模型,每个 GPU 需要存储:
参数(fp16):2 Φ 2\Phi 2Φ bytes
梯度(fp16):2 Φ 2\Phi 2Φ bytes
优化器状态(fp32):12 Φ 12\Phi 12Φ bytes(参数副本 + 一阶矩 + 二阶矩)
合计 16 Φ 16\Phi 16Φ bytes,其中优化器状态占了 75%。
Log in to continue reading This is premium content. Please log in to access the full article.
N N N 个 GPU 就是 N N N 倍冗余:全局存储 16 N Φ 16N\Phi 16 N Φ bytes,但实际只需要 16 Φ 16\Phi 16Φ bytes。ZeRO 的三个 Stage 就是按从大到小的顺序,依次消除这些冗余。
Stage 1 只做一件事:把优化器状态均分到 N N N 个 GPU 上 。
每个参数有一个"owner" rank,只有 owner 存储该参数的 Adam 状态(fp32 参数副本、一阶矩 m m m 、二阶矩 v v v )。训练流程变为:
前向传播 :和 DDP 一样,各自独立计算
反向传播 :计算梯度后,通过 reduce(不是 all_reduce)发送到 owner rank
参数更新 :owner rank 更新参数,然后 broadcast 广播更新后的参数
首先要决定每个参数的 owner。最简单的方式是轮询(round-robin):
systems/distributed_training/02_zero1.py # 为每个参数分配 owner rank
param_to_rank = {}
owned_params = [] # 当前 rank 拥有的参数
for i, (name, param) in enumerate (model.named_parameters()):
owner_rank = i % world_size
param_to_rank[name]
这样可以保证参数均匀分布到各个 GPU 上。例如 4 个 GPU 训练时:
参数 0, 4, 8, ... → GPU 0(GPU 0 只为这些参数创建优化器状态)
参数 1, 5, 9, ... → GPU 1(GPU 1 只为这些参数创建优化器状态)
参数 2, 6, 10, ... → GPU 2
参数 3, 7, 11, ... → GPU 3
ZeRO-1 节省的是优化器状态,不是梯度!
每个 rank 只为自己拥有的参数创建 Adam 状态(fp32 参数副本、一阶矩、二阶矩),这样优化器状态从 12 Φ 12\Phi 12Φ 降到 12 Φ / N 12\Phi/N 12Φ/ N 。
梯度在 ZeRO-1 中仍然占用完整内存(2 Φ 2\Phi 2Φ ),要到 ZeRO-2 才会释放。
# DDP: all_reduce, 所有 rank 都拿到完整梯度
dist.all_reduce(grad, async_op = True )
# ZeRO-1: reduce, 只有 owner rank 拿到梯度
dist.reduce(grad, dst = rank_id, async_op = True ) systems/distributed_training/02_zero1.py def sync_grad (grad, dst_rank):
"""Reduce gradient to owner rank (ZeRO-1 key operation)"""
dist.reduce(grad, dst = dst_rank)
def sync_param (param, src_rank):
"""Broadcast parameter from owner rank to all ranks"""
dist.broadcast(param.data, src = src_rank) # 正常训练
_, loss = model(x, y)
optimizer.zero_grad()
loss.backward()
optimizer.step() # DDP 用 wrapper 包装模型
model = DDP(model, device_ids = [rank])
# 训练循环和单卡一样
_, loss = model(x, y)
optimizer.zero_grad()
loss.backward() # DDP wrapper 自动在这里插入 all_reduce
optimizer.step() DDP 的 wrapper 会在 backward() 时自动注册 hooks,完成梯度的 all_reduce,对用户透明。
ZeRO-1 不使用 DDP wrapper,而是直接用原始模型,所以需要手动控制:
systems/distributed_training/02_zero1.py # 不使用 DDP wrapper
model = GPT(cfg).to(device)
# 前向 + 反向
_, loss = model(x, y)
optimizer.zero_grad()
loss.backward() # 只计算梯度,没有自动同步
# ZeRO-1: 手动梯度同步(插入位置 1)
for name, param in model.named_parameters():
backward() 之后 :每个 rank 都计算了自己的梯度(基于自己的数据),需要 reduce 到 owner rank 求和
step() 之后 :owner 已经更新了参数,需要 broadcast 给其他 ranks,保证所有 GPU 的参数一致
梯度状态的变化 :
dist.reduce() 不会自动清空非 owner 的梯度,这就是为什么 ZeRO-2 需要显式释放。
如果不做第 2 步,各个 GPU 的参数就会不一致,训练就乱了。
内存节省 :优化器状态从 12 Φ 12\Phi 12Φ 降到 12 Φ / N 12\Phi/N 12Φ/ N 。对于 4 个 GPU,这部分省了 75%。
组件 DDP ZeRO-1 参数 2 Φ 2\Phi 2Φ 2 Φ 2\Phi 2Φ 梯度 2 Φ 2\Phi 2Φ 2 Φ 2\Phi 2Φ 优化器状态 12 Φ 12\Phi 12Φ 12 Φ / N 12\Phi/N 12Φ/ N
Stage 2 在 Stage 1 的基础上,梯度也只保留在 owner rank 上 。
回顾 ZeRO-1:dist.reduce() 之后,非 owner ranks 的梯度虽然不会被使用,但仍然占用内存。ZeRO-2 显式释放这些梯度,进一步节省内存。
systems/distributed_training/03_zero2.py def desync_grad (param, owner_rank, current_rank):
"""Release gradient on non-owner ranks (ZeRO-2 key operation)"""
if current_rank != owner_rank:
param.grad = None systems/distributed_training/03_zero2.py # 梯度同步 + 释放
for name, param in model.named_parameters():
if param.grad is not None :
owner_rank = param_to_rank[name]
# Step 1: Reduce gradient to owner
sync_grad(param.grad, dst_rank =
组件 DDP ZeRO-1 ZeRO-2 参数 2 Φ 2\Phi 2Φ 2 Φ 2\Phi 2Φ 2 Φ 2\Phi 2Φ 梯度 2 Φ 2\Phi 2Φ 2 Φ 2\Phi 2Φ 2 Φ / N 2\Phi/N 2Φ/ N
Stage 3 更进一步:参数本身也分片了 。每个 GPU 只存储 1 / N 1/N 1/ N 的参数。
这意味着前向传播也需要通信了。每个层在计算前,需要 broadcast 聚合完整参数,计算完后立即释放。
ZeRO-3 的参数分配和优化器创建与 ZeRO-1/2 完全一样(轮询分配 + 只为 owned_params 创建优化器)。
ZeRO-1/2 :所有 ranks 都保留完整参数(2 Φ 2\Phi 2Φ )
ZeRO-3 :初始化后立即释放非 owner 的参数,只保留占位符
systems/distributed_training/04_zero3.py # 初始化后立即释放非 owner 的参数数据
for name, param in param_list:
owner_rank = param_to_rank[name]
if owner_rank != rank:
# 非 owner:释放参数数据,保留小占位符
param.data = torch.empty( 1 , device 这样从一开始就只占用 2 Φ / N 2\Phi/N 2Φ/ N 的参数内存。
参数对象 vs 参数数据 :
param 是参数对象(Python 对象,内存占用很小)
param.data 是参数数据(张量,占用大量内存)
释放的是 param.data,把大张量替换成小占位符
param_list 保存的是参数对象的引用,用于训练循环中遍历参数
systems/distributed_training/04_zero3.py def sync_param (param, src_rank, current_rank):
"""Broadcast parameter from owner rank"""
if current_rank != src_rank:
# Non-owner: create buffer to receive data
if not hasattr (param, '_full_param_shape' ):
param._full_param_shape = param.data.shape
systems/distributed_training/04_zero3.py # 前向传播前:聚合参数
for name, param in param_list:
owner_rank = param_to_rank[name]
sync_param(param, src_rank = owner_rank, current_rank = rank)
# 前向传播
_, loss = model(x, y)
因为 ZeRO-3 在前向传播后立即释放了参数(节省内存),所以反向传播前需要再次聚合。这是 ZeRO-3 的核心特点:用通信换内存 。
为什么不能前向后直接反向?
你可能会想:前向传播后不释放参数,直接进行反向传播,不是可以省掉一次通信吗?
关键在于峰值内存 和逐层处理 :
真正的 ZeRO-3(如 DeepSpeed)是逐层 聚合/释放参数:
Layer 1 前向:聚合 → 计算 → 立即释放
Layer 2 前向:聚合 → 计算 → 立即释放
...(反向传播同理)
这样任何时刻只有一层 的参数在内存中,大大降低峰值内存。
我们的简化实现是整个模型一起聚合/释放,主要用于演示概念。生产环境中应该使用逐层控制的实现(如 PyTorch FSDP)。
组件 DDP ZeRO-1 ZeRO-2 ZeRO-3 参数 2 Φ 2\Phi 2Φ 2 Φ 2\Phi 2Φ 2 Φ 2\Phi 2Φ 2 Φ / N 2\Phi/N 2Φ/ N 梯度 2 Φ 2\Phi 2Φ 2 Φ 2\Phi
Memory per GPU DDP ZeRO-1 ZeRO-2 ZeRO-3
以下分析按每个训练步 统计,通信量以参数量 Φ \Phi Φ 为单位。All-Reduce 在 Ring 实现下的通信量为 2 Φ 2\Phi 2Φ (一轮 reduce-scatter + 一轮 all-gather)。
策略 前向通信 反向通信 参数更新后通信 总通信量 DDP 无 All-Reduce: 2 Φ 2\Phi 2Φ 无 2 Φ 2\Phi 2Φ ZeRO-1 无 Reduce: Φ \Phi Φ Broadcast: Φ \Phi Φ 2 Φ 2\Phi 2Φ ZeRO-2 无 Reduce: Φ \Phi
三个 Stage 的总通信量都是 2 Φ 2\Phi 2Φ ,与 DDP 完全相同。区别只是通信发生的时机不同:Stage 1/2 在反向传播后同步,Stage 3 在前向传播前也需要通信。ZeRO 论文的核心贡献正是:在不增加通信量的前提下,实现线性的内存扩展 。
注意:如果 Stage 3 的反向传播也需要重新 broadcast 参数(因为前向后已释放),总通信量会增加到 3 Φ 3\Phi 3Φ 。具体取决于实现是否缓存了反向所需的参数。
ZeRO-3 的分片策略是 Inter-Tensor (跨张量分配):把整个参数张量分配给某一个 owner rank。
# ZeRO-3: 每个完整的参数张量有一个 owner
# weight_0 → Rank 0 (owner)
# weight_1 → Rank 1 (owner)
# weight_2 → Rank 2 (owner)
# weight_3 → Rank 3 (owner) 这种方式实现简单,但有一个潜在问题:如果某些层的参数特别大,负载可能不均衡。比如 Embedding 层往往比 Linear 层大得多,持有 Embedding 的 rank 会占用更多内存。
这个问题引出了下一章的 FSDP,它用一种更均匀的分片方式来解决。
Stage 1 :分片优化器状态,梯度用 reduce 替代 all_reduce,内存降至 4 Φ + 12 Φ / N 4\Phi + 12\Phi/N 4Φ + 12Φ/ N
Stage 2 :额外分片梯度,非 owner rank 及时释放梯度内存,降至 2 Φ + 14 Φ / N 2\Phi + 14\Phi/N 2Φ + 14Φ/ N
Stage 3 :参数也分片,前向传播前 broadcast 聚合,用完立即释放,降至 16 Φ / N 16\Phi/N 16Φ/ N
通信代价 :三个 Stage 的总通信量都是 2 Φ 2\Phi 2Φ ,与 DDP 相同(若反向传播不缓存参数则 Stage 3 增至 )
在下一章,我们将探讨 FSDP ,看它如何用 Intra-Tensor 分片解决 ZeRO-3 的负载均衡问题。
=
owner_rank
if owner_rank == rank:
owned_params.append(param)
# 只为拥有的参数创建优化器状态(这是 ZeRO-1 节省内存的关键)
optimizer = torch.optim.AdamW(owned_params, lr = 1e-3 )
if
param.grad
is
not
None
:
owner_rank = param_to_rank[name]
sync_grad(param.grad, dst_rank = owner_rank) # reduce 到 owner
# 更新参数(只有 owner 的更新有效)
optimizer.step()
# ZeRO-1: 广播更新后的参数(插入位置 2)
for name, param in model.named_parameters():
owner_rank = param_to_rank[name]
sync_param(param, src_rank = owner_rank) # broadcast 给所有 rank
合计 16 Φ 16\Phi 16Φ 4 Φ + 12 Φ / N 4\Phi + 12\Phi/N 4Φ + 12Φ/ N
owner_rank)
# Step 2: Non-owner ranks release gradient (NEW in ZeRO-2)
desync_grad(param, owner_rank, rank)
优化器状态 12 Φ 12\Phi 12Φ 12 Φ / N 12\Phi/N 12Φ/ N 12 Φ / N 12\Phi/N 12Φ/ N
合计 16 Φ 16\Phi 16Φ 4 Φ + 12 Φ / N 4\Phi + 12\Phi/N 4Φ + 12Φ/ N 2 Φ + 14 Φ / N 2\Phi + 14\Phi/N 2Φ + 14Φ/ N
=
param.data.device,
dtype
=
param.data.dtype)
param._full_param_dtype = param.data.dtype
dist.broadcast(param.data, src = src_rank)
def desync_param_data (param, owner_rank, current_rank):
"""Release full parameter on non-owner ranks"""
if current_rank != owner_rank:
# Non-owner: release full parameter (keep small placeholder)
param.data = torch.empty( 1 , device = param.data.device, dtype = param.data.dtype)
# 前向传播后:立即释放参数
for name, param in param_list:
owner_rank = param_to_rank[name]
desync_param_data(param, owner_rank, rank)
# 反向传播前:再次聚合参数(用于计算梯度)
for name, param in param_list:
owner_rank = param_to_rank[name]
sync_param(param, src_rank = owner_rank, current_rank = rank)
loss.backward()
# 梯度同步后:释放非 owner 参数
for name, param in param_list:
owner_rank = param_to_rank[name]
sync_param(param, src_rank = owner_rank, current_rank = rank)
desync_param_data(param, owner_rank, rank)
2Φ
优化器状态 12 Φ 12\Phi 12Φ 12 Φ / N 12\Phi/N 12Φ/ N 12 Φ / N 12\Phi/N 12Φ/ N 12 Φ / N 12\Phi/N 12Φ/ N
合计 16 Φ 16\Phi 16Φ 4 Φ + 12 Φ / N 4\Phi + 12\Phi/N 4Φ + 12Φ/ N 2 Φ + 14 Φ / N 2\Phi + 14\Phi/N 2Φ + 14Φ/ N 16 Φ / N 16\Phi/N 16Φ/ N
Φ
ZeRO-3 Broadcast: Φ \Phi Φ Reduce: Φ \Phi Φ 无(已含在前向) 2 Φ 2\Phi 2Φ
3Φ