ZeRO 优化器
Premium渐进式去冗余,从优化器状态到参数的三级分片
Companion Code上一章我们看到 DDP 的内存问题:为了保证训练一致性(通过 All-Reduce 同步梯度),每个 GPU 都需要存储完整的模型状态。4 个 GPU 就是 4 份完整副本(参数、梯度、优化器状态)。ZeRO(Zero Redundancy Optimizer)的核心思想很直接:既然最终状态是一致的,那就每个 GPU 只存一部分,需要的时候再通信取回。
训练状态的冗余分析
先量化一下 DDP 的浪费。以混合精度 + Adam 为例, 个 GPU 训练一个 参数的模型,每个 GPU 需要存储:
- 参数(fp16): bytes
- 梯度(fp16): bytes
- 优化器状态(fp32): bytes(参数副本 + 一阶矩 + 二阶矩)
合计 bytes,其中优化器状态占了 75%。
个 GPU 就是 倍冗余:全局存储 bytes,但实际只需要 bytes。ZeRO 的三个 Stage 就是按从大到小的顺序,依次消除这些冗余。
ZeRO Stage 1:分片优化器状态
Stage 1 只做一件事:把优化器状态均分到 个 GPU 上。
每个参数有一个"owner" rank,只有 owner 存储该参数的 Adam 状态(fp32 参数副本、一阶矩 、二阶矩 )。训练流程变为:
- 前向传播:和 DDP 一样,各自独立计算
- 反向传播:计算梯度后,通过
reduce(不是all_reduce)发送到 owner rank - 参数更新:owner rank 更新参数,然后
broadcast广播更新后的参数
参数分配策略
首先要决定每个参数的 owner。最简单的方式是轮询(round-robin):
# 为每个参数分配 owner rank
param_to_rank = {}
owned_params = [] # 当前 rank 拥有的参数
for i, (name, param) in enumerate(model.named_parameters()):
owner_rank = i % world_size
param_to_rank[name] = owner_rank
if owner_rank == rank:
owned_params.append(param)
# 只为拥有的参数创建优化器状态(这是 ZeRO-1 节省内存的关键)
optimizer = torch.optim.AdamW(owned_params, lr=1e-3)这样可以保证参数均匀分布到各个 GPU 上。例如 4 个 GPU 训练时:
- 参数 0, 4, 8, ... → GPU 0(GPU 0 只为这些参数创建优化器状态)
- 参数 1, 5, 9, ... → GPU 1(GPU 1 只为这些参数创建优化器状态)
- 参数 2, 6, 10, ... → GPU 2
- 参数 3, 7, 11, ... → GPU 3
ZeRO-1 节省的是优化器状态,不是梯度!
每个 rank 只为自己拥有的参数创建 Adam 状态(fp32 参数副本、一阶矩、二阶矩),这样优化器状态从 降到 。
梯度在 ZeRO-1 中仍然占用完整内存(),要到 ZeRO-2 才会释放。
梯度同步
关键代码变化只有一行:
# DDP: all_reduce, 所有 rank 都拿到完整梯度
dist.all_reduce(grad, async_op=True)
# ZeRO-1: reduce, 只有 owner rank 拿到梯度
dist.reduce(grad, dst=rank_id, async_op=True)看 ZeRO-1 的核心函数:
def sync_grad(grad, dst_rank):
"""Reduce gradient to owner rank (ZeRO-1 key operation)"""
dist.reduce(grad, dst=dst_rank)
def sync_param(param, src_rank):
"""Broadcast parameter from owner rank to all ranks"""
dist.broadcast(param.data, src=src_rank)训练循环对比
正常的单卡训练循环很简单:
# 正常训练
_, loss = model(x, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()DDP 的自动同步:
# DDP 用 wrapper 包装模型
model = DDP(model, device_ids=[rank])
# 训练循环和单卡一样
_, loss = model(x, y)
optimizer.zero_grad()
loss.backward() # DDP wrapper 自动在这里插入 all_reduce
optimizer.step()DDP 的 wrapper 会在 backward() 时自动注册 hooks,完成梯度的 all_reduce,对用户透明。
ZeRO-1 的手动同步:
ZeRO-1 不使用 DDP wrapper,而是直接用原始模型,所以需要手动控制:
# 不使用 DDP wrapper
model = GPT(cfg).to(device)
# 前向 + 反向
_, loss = model(x, y)
optimizer.zero_grad()
loss.backward() # 只计算梯度,没有自动同步
# ZeRO-1: 手动梯度同步(插入位置 1)
for name, param in model.named_parameters():
if param.grad is not None:
owner_rank = param_to_rank[name]
sync_grad(param.grad, dst_rank=owner_rank) # reduce 到 owner
# 更新参数(只有 owner 的更新有效)
optimizer.step()
# ZeRO-1: 广播更新后的参数(插入位置 2)
for name, param in model.named_parameters():
owner_rank = param_to_rank[name]
sync_param(param, src_rank=owner_rank) # broadcast 给所有 rank为什么要在这两个位置插入?
backward()之后:每个 rank 都计算了自己的梯度(基于自己的数据),需要 reduce 到 owner rank 求和step()之后:owner 已经更新了参数,需要 broadcast 给其他 ranks,保证所有 GPU 的参数一致
梯度状态的变化:
-
backward() 之后:每个 rank 都有梯度(基于自己的 batch 计算的,不是 0)
- GPU 0: grad = ∇L(batch_0)
- GPU 1: grad = ∇L(batch_1)
-
reduce() 之后:
- Owner rank: grad = ∇L(batch_0) + ∇L(batch_1) + ... (完整梯度)
- 非 owner ranks: grad 内容未定义(可能还保留原值,但不应使用)
dist.reduce() 不会自动清空非 owner 的梯度,这就是为什么 ZeRO-2 需要显式释放。
如果不做第 2 步,各个 GPU 的参数就会不一致,训练就乱了。
内存节省:优化器状态从 降到 。对于 4 个 GPU,这部分省了 75%。
| 组件 | DDP | ZeRO-1 |
|---|---|---|
| 参数 | ||
| 梯度 | ||
| 优化器状态 | ||
| 合计 |
ZeRO Stage 2:分片梯度
Stage 2 在 Stage 1 的基础上,梯度也只保留在 owner rank 上。
回顾 ZeRO-1:dist.reduce() 之后,非 owner ranks 的梯度虽然不会被使用,但仍然占用内存。ZeRO-2 显式释放这些梯度,进一步节省内存。
ZeRO-2 的关键改动:
def desync_grad(param, owner_rank, current_rank):
"""Release gradient on non-owner ranks (ZeRO-2 key operation)"""
if current_rank != owner_rank:
param.grad = None训练循环中的梯度同步和释放:
# 梯度同步 + 释放
for name, param in model.named_parameters():
if param.grad is not None:
owner_rank = param_to_rank[name]
# Step 1: Reduce gradient to owner
sync_grad(param.grad, dst_rank=owner_rank)
# Step 2: Non-owner ranks release gradient (NEW in ZeRO-2)
desync_grad(param, owner_rank, rank)| 组件 | DDP | ZeRO-1 | ZeRO-2 |
|---|---|---|---|
| 参数 | |||
| 梯度 | |||
| 优化器状态 | |||
| 合计 |
ZeRO Stage 3:分片参数
Stage 3 更进一步:参数本身也分片了。每个 GPU 只存储 的参数。
这意味着前向传播也需要通信了。每个层在计算前,需要 broadcast 聚合完整参数,计算完后立即释放。
参数分片
ZeRO-3 的参数分配和优化器创建与 ZeRO-1/2 完全一样(轮询分配 + 只为 owned_params 创建优化器)。
关键区别:
- ZeRO-1/2:所有 ranks 都保留完整参数()
- ZeRO-3:初始化后立即释放非 owner 的参数,只保留占位符
# 初始化后立即释放非 owner 的参数数据
for name, param in param_list:
owner_rank = param_to_rank[name]
if owner_rank != rank:
# 非 owner:释放参数数据,保留小占位符
param.data = torch.empty(1, device=param.data.device, dtype=param.data.dtype)这样从一开始就只占用 的参数内存。
参数对象 vs 参数数据:
param是参数对象(Python 对象,内存占用很小)param.data是参数数据(张量,占用大量内存)- 释放的是
param.data,把大张量替换成小占位符 param_list保存的是参数对象的引用,用于训练循环中遍历参数
通信模式
ZeRO-3 的核心函数:
def sync_param(param, src_rank, current_rank):
"""Broadcast parameter from owner rank"""
if current_rank != src_rank:
# Non-owner: create buffer to receive data
if not hasattr(param, '_full_param_shape'):
param._full_param_shape = param.data.shape
param._full_param_dtype = param.data.dtype
dist.broadcast(param.data, src=src_rank)
def desync_param_data(param, owner_rank, current_rank):
"""Release full parameter on non-owner ranks"""
if current_rank != owner_rank:
# Non-owner: release full parameter (keep small placeholder)
param.data = torch.empty(1, device=param.data.device, dtype=param.data.dtype)训练循环的关键步骤:
# 前向传播前:聚合参数
for name, param in param_list:
owner_rank = param_to_rank[name]
sync_param(param, src_rank=owner_rank, current_rank=rank)
# 前向传播
_, loss = model(x, y)
# 前向传播后:立即释放参数
for name, param in param_list:
owner_rank = param_to_rank[name]
desync_param_data(param, owner_rank, rank)
# 反向传播前:再次聚合参数(用于计算梯度)
for name, param in param_list:
owner_rank = param_to_rank[name]
sync_param(param, src_rank=owner_rank, current_rank=rank)
loss.backward()
# 梯度同步后:释放非 owner 参数
for name, param in param_list:
owner_rank = param_to_rank[name]
sync_param(param, src_rank=owner_rank, current_rank=rank)
desync_param_data(param, owner_rank, rank)为什么需要两次聚合参数?
因为 ZeRO-3 在前向传播后立即释放了参数(节省内存),所以反向传播前需要再次聚合。这是 ZeRO-3 的核心特点:用通信换内存。
为什么不能前向后直接反向?
你可能会想:前向传播后不释放参数,直接进行反向传播,不是可以省掉一次通信吗?
关键在于峰值内存和逐层处理:
真正的 ZeRO-3(如 DeepSpeed)是逐层聚合/释放参数:
- Layer 1 前向:聚合 → 计算 → 立即释放
- Layer 2 前向:聚合 → 计算 → 立即释放
- ...(反向传播同理)
这样任何时刻只有一层的参数在内存中,大大降低峰值内存。
我们的简化实现是整个模型一起聚合/释放,主要用于演示概念。生产环境中应该使用逐层控制的实现(如 PyTorch FSDP)。
| 组件 | DDP | ZeRO-1 | ZeRO-2 | ZeRO-3 |
|---|---|---|---|---|
| 参数 | ||||
| 梯度 | ||||
| 优化器状态 | ||||
| 合计 |
Memory per GPU
通信开销对比
内存省了,通信是不是更贵了?我们来对比一下。
以下分析按每个训练步统计,通信量以参数量 为单位。All-Reduce 在 Ring 实现下的通信量为 (一轮 reduce-scatter + 一轮 all-gather)。
| 策略 | 前向通信 | 反向通信 | 参数更新后通信 | 总通信量 |
|---|---|---|---|---|
| DDP | 无 | All-Reduce: | 无 | |
| ZeRO-1 | 无 | Reduce: | Broadcast: | |
| ZeRO-2 | 无 | Reduce: | Broadcast: | |
| ZeRO-3 | Broadcast: | Reduce: | 无(已含在前向) |
三个 Stage 的总通信量都是 ,与 DDP 完全相同。区别只是通信发生的时机不同:Stage 1/2 在反向传播后同步,Stage 3 在前向传播前也需要通信。ZeRO 论文的核心贡献正是:在不增加通信量的前提下,实现线性的内存扩展。
注意:如果 Stage 3 的反向传播也需要重新 broadcast 参数(因为前向后已释放),总通信量会增加到 。具体取决于实现是否缓存了反向所需的参数。
ZeRO-3 的分片方式:Inter-Tensor
ZeRO-3 的分片策略是 Inter-Tensor(跨张量分配):把整个参数张量分配给某一个 owner rank。
# ZeRO-3: 每个完整的参数张量有一个 owner
# weight_0 → Rank 0 (owner)
# weight_1 → Rank 1 (owner)
# weight_2 → Rank 2 (owner)
# weight_3 → Rank 3 (owner)这种方式实现简单,但有一个潜在问题:如果某些层的参数特别大,负载可能不均衡。比如 Embedding 层往往比 Linear 层大得多,持有 Embedding 的 rank 会占用更多内存。
这个问题引出了下一章的 FSDP,它用一种更均匀的分片方式来解决。
总结
本章我们学习了 ZeRO 优化器的三级分片策略:
- Stage 1:分片优化器状态,梯度用
reduce替代all_reduce,内存降至 - Stage 2:额外分片梯度,非 owner rank 及时释放梯度内存,降至
- Stage 3:参数也分片,前向传播前
broadcast聚合,用完立即释放,降至 - 通信代价:三个 Stage 的总通信量都是 ,与 DDP 相同(若反向传播不缓存参数则 Stage 3 增至 )
在下一章,我们将探讨 FSDP,看它如何用 Intra-Tensor 分片解决 ZeRO-3 的负载均衡问题。
Log in to continue reading
This is premium content. Please log in to access the full article.
CookLLM Docs