LogoCookLLM Docs
LogoCookLLM Docs
HomeCookLLM

Principles

Tokenization
Tokenization BasicsBPE AlgorithmGPT TokenizersBPE Training Engineering
Model Architecture
Attention Mechanisms
Position Encoding
Position Encoding BasicsRoPE Math DerivationRoPE ImplementationLength Extrapolation
GPU Programming Basics
GPU Architecture BasicsTensor LayoutTriton Basics: Vector Add
FlashAttention
Flash Attention PrinciplesFrom Naive to Auto-TuningBlock Pointers and Multi-Dim SupportCausal Masking OptimizationGrouped Query AttentionBackward Pass

Hands-on Training

X (Twitter)

ZeRO 优化器

Premium

渐进式去冗余,从优化器状态到参数的三级分片

Companion Code

上一章我们看到 DDP 的内存问题:为了保证训练一致性(通过 All-Reduce 同步梯度),每个 GPU 都需要存储完整的模型状态。4 个 GPU 就是 4 份完整副本(参数、梯度、优化器状态)。ZeRO(Zero Redundancy Optimizer)的核心思想很直接:既然最终状态是一致的,那就每个 GPU 只存一部分,需要的时候再通信取回。

训练状态的冗余分析

先量化一下 DDP 的浪费。以混合精度 + Adam 为例,NNN 个 GPU 训练一个 Φ\PhiΦ 参数的模型,每个 GPU 需要存储:

  • 参数(fp16):2Φ2\Phi2Φ bytes
  • 梯度(fp16):2Φ2\Phi2Φ bytes
  • 优化器状态(fp32):12Φ12\Phi12Φ bytes(参数副本 + 一阶矩 + 二阶矩)

合计 16Φ16\Phi16Φ bytes,其中优化器状态占了 75%。

Log in to continue reading

This is premium content. Please log in to access the full article.

NNN 个 GPU 就是 NNN 倍冗余:全局存储 16NΦ16N\Phi16NΦ bytes,但实际只需要 16Φ16\Phi16Φ bytes。ZeRO 的三个 Stage 就是按从大到小的顺序,依次消除这些冗余。

ZeRO Stage 1:分片优化器状态

Stage 1 只做一件事:把优化器状态均分到 NNN 个 GPU 上。

每个参数有一个"owner" rank,只有 owner 存储该参数的 Adam 状态(fp32 参数副本、一阶矩 mmm、二阶矩 vvv)。训练流程变为:

  1. 前向传播:和 DDP 一样,各自独立计算
  2. 反向传播:计算梯度后,通过 reduce(不是 all_reduce)发送到 owner rank
  3. 参数更新:owner rank 更新参数,然后 broadcast 广播更新后的参数

参数分配策略

首先要决定每个参数的 owner。最简单的方式是轮询(round-robin):

systems/distributed_training/02_zero1.py
# 为每个参数分配 owner rank
param_to_rank = {}
owned_params = []  # 当前 rank 拥有的参数
for i, (name, param) in enumerate(model.named_parameters()):
    owner_rank = i % world_size
    param_to_rank[name] 




这样可以保证参数均匀分布到各个 GPU 上。例如 4 个 GPU 训练时:

  • 参数 0, 4, 8, ... → GPU 0(GPU 0 只为这些参数创建优化器状态)
  • 参数 1, 5, 9, ... → GPU 1(GPU 1 只为这些参数创建优化器状态)
  • 参数 2, 6, 10, ... → GPU 2
  • 参数 3, 7, 11, ... → GPU 3

ZeRO-1 节省的是优化器状态,不是梯度!

每个 rank 只为自己拥有的参数创建 Adam 状态(fp32 参数副本、一阶矩、二阶矩),这样优化器状态从 12Φ12\Phi12Φ 降到 12Φ/N12\Phi/N12Φ/N。

梯度在 ZeRO-1 中仍然占用完整内存(2Φ2\Phi2Φ),要到 ZeRO-2 才会释放。

梯度同步

关键代码变化只有一行:

# DDP: all_reduce, 所有 rank 都拿到完整梯度
dist.all_reduce(grad, async_op=True)

# ZeRO-1: reduce, 只有 owner rank 拿到梯度
dist.reduce(grad, dst=rank_id, async_op=True)

看 ZeRO-1 的核心函数:

systems/distributed_training/02_zero1.py
def sync_grad(grad, dst_rank):
    """Reduce gradient to owner rank (ZeRO-1 key operation)"""
    dist.reduce(grad, dst=dst_rank)

def sync_param(param, src_rank):
    """Broadcast parameter from owner rank to all ranks"""
    dist.broadcast(param.data, src=src_rank)

训练循环对比

正常的单卡训练循环很简单:

# 正常训练
_, loss = model(x, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

DDP 的自动同步:

# DDP 用 wrapper 包装模型
model = DDP(model, device_ids=[rank])

# 训练循环和单卡一样
_, loss = model(x, y)
optimizer.zero_grad()
loss.backward()  # DDP wrapper 自动在这里插入 all_reduce
optimizer.step()

DDP 的 wrapper 会在 backward() 时自动注册 hooks,完成梯度的 all_reduce,对用户透明。

ZeRO-1 的手动同步:

ZeRO-1 不使用 DDP wrapper,而是直接用原始模型,所以需要手动控制:

systems/distributed_training/02_zero1.py
# 不使用 DDP wrapper
model = GPT(cfg).to(device)

# 前向 + 反向
_, loss = model(x, y)
optimizer.zero_grad()
loss.backward()  # 只计算梯度,没有自动同步

# ZeRO-1: 手动梯度同步(插入位置 1)
for name, param in model.named_parameters():










为什么要在这两个位置插入?

  1. backward() 之后:每个 rank 都计算了自己的梯度(基于自己的数据),需要 reduce 到 owner rank 求和
  2. step() 之后:owner 已经更新了参数,需要 broadcast 给其他 ranks,保证所有 GPU 的参数一致

梯度状态的变化:

  • backward() 之后:每个 rank 都有梯度(基于自己的 batch 计算的,不是 0)

    • GPU 0: grad = ∇L(batch_0)
    • GPU 1: grad = ∇L(batch_1)
  • reduce() 之后:

    • Owner rank: grad = ∇L(batch_0) + ∇L(batch_1) + ... (完整梯度)
    • 非 owner ranks: grad 内容未定义(可能还保留原值,但不应使用)

dist.reduce() 不会自动清空非 owner 的梯度,这就是为什么 ZeRO-2 需要显式释放。

如果不做第 2 步,各个 GPU 的参数就会不一致,训练就乱了。

内存节省:优化器状态从 12Φ12\Phi12Φ 降到 12Φ/N12\Phi/N12Φ/N。对于 4 个 GPU,这部分省了 75%。

组件DDPZeRO-1
参数2Φ2\Phi2Φ2Φ2\Phi2Φ
梯度2Φ2\Phi2Φ2Φ2\Phi2Φ
优化器状态12Φ12\Phi12Φ12Φ/N12\Phi/N12Φ/N

ZeRO Stage 2:分片梯度

Stage 2 在 Stage 1 的基础上,梯度也只保留在 owner rank 上。

回顾 ZeRO-1:dist.reduce() 之后,非 owner ranks 的梯度虽然不会被使用,但仍然占用内存。ZeRO-2 显式释放这些梯度,进一步节省内存。

ZeRO-2 的关键改动:

systems/distributed_training/03_zero2.py
def desync_grad(param, owner_rank, current_rank):
    """Release gradient on non-owner ranks (ZeRO-2 key operation)"""
    if current_rank != owner_rank:
        param.grad = None

训练循环中的梯度同步和释放:

systems/distributed_training/03_zero2.py
# 梯度同步 + 释放
for name, param in model.named_parameters():
    if param.grad is not None:
        owner_rank = param_to_rank[name]
        # Step 1: Reduce gradient to owner
        sync_grad(param.grad, dst_rank=

组件DDPZeRO-1ZeRO-2
参数2Φ2\Phi2Φ2Φ2\Phi2Φ2Φ2\Phi2Φ
梯度2Φ2\Phi2Φ2Φ2\Phi2Φ2Φ/N2\Phi/N2Φ/N

ZeRO Stage 3:分片参数

Stage 3 更进一步:参数本身也分片了。每个 GPU 只存储 1/N1/N1/N 的参数。

这意味着前向传播也需要通信了。每个层在计算前,需要 broadcast 聚合完整参数,计算完后立即释放。

参数分片

ZeRO-3 的参数分配和优化器创建与 ZeRO-1/2 完全一样(轮询分配 + 只为 owned_params 创建优化器)。

关键区别:

  • ZeRO-1/2:所有 ranks 都保留完整参数(2Φ2\Phi2Φ)
  • ZeRO-3:初始化后立即释放非 owner 的参数,只保留占位符
systems/distributed_training/04_zero3.py
# 初始化后立即释放非 owner 的参数数据
for name, param in param_list:
    owner_rank = param_to_rank[name]
    if owner_rank != rank:
        # 非 owner:释放参数数据,保留小占位符
        param.data = torch.empty(1, device

这样从一开始就只占用 2Φ/N2\Phi/N2Φ/N 的参数内存。

参数对象 vs 参数数据:

  • param 是参数对象(Python 对象,内存占用很小)
  • param.data 是参数数据(张量,占用大量内存)
  • 释放的是 param.data,把大张量替换成小占位符
  • param_list 保存的是参数对象的引用,用于训练循环中遍历参数

通信模式

ZeRO-3 的核心函数:

systems/distributed_training/04_zero3.py
def sync_param(param, src_rank, current_rank):
    """Broadcast parameter from owner rank"""
    if current_rank != src_rank:
        # Non-owner: create buffer to receive data
        if not hasattr(param, '_full_param_shape'):
            param._full_param_shape = param.data.shape







训练循环的关键步骤:

systems/distributed_training/04_zero3.py
# 前向传播前:聚合参数
for name, param in param_list:
    owner_rank = param_to_rank[name]
    sync_param(param, src_rank=owner_rank, current_rank=rank)

# 前向传播
_, loss = model(x, y)

















为什么需要两次聚合参数?

因为 ZeRO-3 在前向传播后立即释放了参数(节省内存),所以反向传播前需要再次聚合。这是 ZeRO-3 的核心特点:用通信换内存。

为什么不能前向后直接反向?

你可能会想:前向传播后不释放参数,直接进行反向传播,不是可以省掉一次通信吗?

关键在于峰值内存和逐层处理:

真正的 ZeRO-3(如 DeepSpeed)是逐层聚合/释放参数:

  • Layer 1 前向:聚合 → 计算 → 立即释放
  • Layer 2 前向:聚合 → 计算 → 立即释放
  • ...(反向传播同理)

这样任何时刻只有一层的参数在内存中,大大降低峰值内存。

我们的简化实现是整个模型一起聚合/释放,主要用于演示概念。生产环境中应该使用逐层控制的实现(如 PyTorch FSDP)。

组件DDPZeRO-1ZeRO-2ZeRO-3
参数2Φ2\Phi2Φ2Φ2\Phi2Φ2Φ2\Phi2Φ2Φ/N2\Phi/N2Φ/N
梯度2Φ2\Phi2Φ2Φ2\Phi

Memory per GPU

GPU 0
Optimizer12Φ
Grads2Φ
Params2Φ
GPU 1
Optimizer12Φ
Grads2Φ
Params2Φ
GPU 2
Optimizer12Φ
Grads2Φ
Params2Φ
GPU 3
Optimizer12Φ
Grads2Φ
Params2Φ
Per GPU: 16Φ
|
Total: 16NΦ

通信开销对比

内存省了,通信是不是更贵了?我们来对比一下。

以下分析按每个训练步统计,通信量以参数量 Φ\PhiΦ 为单位。All-Reduce 在 Ring 实现下的通信量为 2Φ2\Phi2Φ(一轮 reduce-scatter + 一轮 all-gather)。

策略前向通信反向通信参数更新后通信总通信量
DDP无All-Reduce: 2Φ2\Phi2Φ无2Φ2\Phi2Φ
ZeRO-1无Reduce: Φ\PhiΦBroadcast: Φ\PhiΦ2Φ2\Phi2Φ
ZeRO-2无Reduce: Φ\Phi

三个 Stage 的总通信量都是 2Φ2\Phi2Φ,与 DDP 完全相同。区别只是通信发生的时机不同:Stage 1/2 在反向传播后同步,Stage 3 在前向传播前也需要通信。ZeRO 论文的核心贡献正是:在不增加通信量的前提下,实现线性的内存扩展。

注意:如果 Stage 3 的反向传播也需要重新 broadcast 参数(因为前向后已释放),总通信量会增加到 3Φ3\Phi3Φ。具体取决于实现是否缓存了反向所需的参数。

ZeRO-3 的分片方式:Inter-Tensor

ZeRO-3 的分片策略是 Inter-Tensor(跨张量分配):把整个参数张量分配给某一个 owner rank。

# ZeRO-3: 每个完整的参数张量有一个 owner
# weight_0 → Rank 0 (owner)
# weight_1 → Rank 1 (owner)
# weight_2 → Rank 2 (owner)
# weight_3 → Rank 3 (owner)

这种方式实现简单,但有一个潜在问题:如果某些层的参数特别大,负载可能不均衡。比如 Embedding 层往往比 Linear 层大得多,持有 Embedding 的 rank 会占用更多内存。

这个问题引出了下一章的 FSDP,它用一种更均匀的分片方式来解决。

总结

本章我们学习了 ZeRO 优化器的三级分片策略:

  • Stage 1:分片优化器状态,梯度用 reduce 替代 all_reduce,内存降至 4Φ+12Φ/N4\Phi + 12\Phi/N4Φ+12Φ/N
  • Stage 2:额外分片梯度,非 owner rank 及时释放梯度内存,降至 2Φ+14Φ/N2\Phi + 14\Phi/N2Φ+14Φ/N
  • Stage 3:参数也分片,前向传播前 broadcast 聚合,用完立即释放,降至 16Φ/N16\Phi/N16Φ/N
  • 通信代价:三个 Stage 的总通信量都是 2Φ2\Phi2Φ,与 DDP 相同(若反向传播不缓存参数则 Stage 3 增至 )

在下一章,我们将探讨 FSDP,看它如何用 Intra-Tensor 分片解决 ZeRO-3 的负载均衡问题。

Table of Contents

训练状态的冗余分析
ZeRO Stage 1:分片优化器状态
参数分配策略
梯度同步
训练循环对比
ZeRO Stage 2:分片梯度
ZeRO Stage 3:分片参数
参数分片
通信模式
通信开销对比
ZeRO-3 的分片方式:Inter-Tensor
总结
=
owner_rank
if owner_rank == rank:
owned_params.append(param)
# 只为拥有的参数创建优化器状态(这是 ZeRO-1 节省内存的关键)
optimizer = torch.optim.AdamW(owned_params, lr=1e-3)
if
param.grad
is
not
None
:
owner_rank = param_to_rank[name]
sync_grad(param.grad, dst_rank=owner_rank) # reduce 到 owner
# 更新参数(只有 owner 的更新有效)
optimizer.step()
# ZeRO-1: 广播更新后的参数(插入位置 2)
for name, param in model.named_parameters():
owner_rank = param_to_rank[name]
sync_param(param, src_rank=owner_rank) # broadcast 给所有 rank
合计16Φ16\Phi16Φ4Φ+12Φ/N4\Phi + 12\Phi/N4Φ+12Φ/N
owner_rank)
# Step 2: Non-owner ranks release gradient (NEW in ZeRO-2)
desync_grad(param, owner_rank, rank)
优化器状态12Φ12\Phi12Φ12Φ/N12\Phi/N12Φ/N12Φ/N12\Phi/N12Φ/N
合计16Φ16\Phi16Φ4Φ+12Φ/N4\Phi + 12\Phi/N4Φ+12Φ/N2Φ+14Φ/N2\Phi + 14\Phi/N2Φ+14Φ/N
=
param.data.device,
dtype
=
param.data.dtype)
param._full_param_dtype = param.data.dtype
dist.broadcast(param.data, src=src_rank)
def desync_param_data(param, owner_rank, current_rank):
"""Release full parameter on non-owner ranks"""
if current_rank != owner_rank:
# Non-owner: release full parameter (keep small placeholder)
param.data = torch.empty(1, device=param.data.device, dtype=param.data.dtype)
# 前向传播后:立即释放参数
for name, param in param_list:
owner_rank = param_to_rank[name]
desync_param_data(param, owner_rank, rank)
# 反向传播前:再次聚合参数(用于计算梯度)
for name, param in param_list:
owner_rank = param_to_rank[name]
sync_param(param, src_rank=owner_rank, current_rank=rank)
loss.backward()
# 梯度同步后:释放非 owner 参数
for name, param in param_list:
owner_rank = param_to_rank[name]
sync_param(param, src_rank=owner_rank, current_rank=rank)
desync_param_data(param, owner_rank, rank)
2Φ
2Φ/N2\Phi/N2Φ/N
2Φ/N2\Phi/N2Φ/N
优化器状态12Φ12\Phi12Φ12Φ/N12\Phi/N12Φ/N12Φ/N12\Phi/N12Φ/N12Φ/N12\Phi/N12Φ/N
合计16Φ16\Phi16Φ4Φ+12Φ/N4\Phi + 12\Phi/N4Φ+12Φ/N2Φ+14Φ/N2\Phi + 14\Phi/N2Φ+14Φ/N16Φ/N16\Phi/N16Φ/N
Φ
Broadcast: Φ\PhiΦ
2Φ2\Phi2Φ
ZeRO-3Broadcast: Φ\PhiΦReduce: Φ\PhiΦ无(已含在前向)2Φ2\Phi2Φ
3Φ3\Phi
3Φ