Skip to content

Add AutoEP + AutoTP parallel folding#8064

Open
tohtana wants to merge 3 commits into
deepspeedai:masterfrom
tohtana:tohtana/autoep-autotp-parallel-folding-design
Open

Add AutoEP + AutoTP parallel folding#8064
tohtana wants to merge 3 commits into
deepspeedai:masterfrom
tohtana:tohtana/autoep-autotp-parallel-folding-design

Conversation

@tohtana

@tohtana tohtana commented Jun 13, 2026

Copy link
Copy Markdown
Collaborator

This PR adds parallel folding for AutoEP: tensor parallelism (AutoTP) for the dense/attention path can now coexist with expert parallelism (AutoEP) for the routed-expert path on the same set of ranks, without forcing EP to be a subset of DP.
(This PR should be adjusted for ZeRO3 support after #8060 is merged)

Design

Attention/dense and MoE are treated as two independent partitionings of the same rank set, parameterized per parameter family:

  • Dense / attention / shared-expert params: stage_size = tp * dp
  • Routed-expert params: stage_size = ep * etp * edp

dp and edp are always derived, never user-configured, so the invariant tp * dp == ep * etp * edp == stage_size cannot be broken from config.

Configuration

No new config section. Folding is expressed by the coexistence of the existing tensor_parallel and expert_parallel sections:

{
  "tensor_parallel":  { "autotp_size": 2 },
  "expert_parallel":  { "enabled": true, "autoep_size": 4,
                        "expert_tensor_parallel_size": 1 }
}

expert_tensor_parallel_size is carried as a config field but currently must be 1 (expert-internal TP is reserved as follow-up and rejected fail-fast). Validation enforces divisibility, TP/sequence-parallel exclusivity, and preset_model consistency between the two sections.

What's included

  • Folded process-group derivation using the generalized expert/data-parallel group creation (mp_mode TP-strided vs SP-consecutive ordering).
  • Route-full / partition-dispatch path for folded MoE (deepspeed/moe/ep_tp_dispatch.py), with AutoTP skipping AutoEP subtrees.
  • Mode-aware TP-replicated gradient reduction for router/gate params: summed when the parallelism mode partitions tokens, averaged when tokens are replicated — matching standard sequence-parallel / tensor-parallel gradient semantics.
  • Per-parameter-family ZeRO checkpoint metadata (routed-expert vs dense/router/shared placement) and folded ZeRO-1/2 optimizer-state handling.

Correctness & validation

  • Router/gate gradient parity against a non-folded ZeRO baseline on a TP2 × EP4 (8-GPU) shape: folded gradient matches baseline to ~1e-7 (scale 1.0).
  • New folding unit tests for config, group layout, dispatch, runtime, gradient parity, and checkpoint save/load (multi-rank cases gated for GPU runners).
  • Passes the full unit test suite (aws-torch-latest-full) on H100 GPUs.

Scope / follow-ups

  • This PR covers AutoEP + AutoTP folding. The replicated-grad reduction is mode-aware so the sequence-parallel (Ulysses) folding case fits the same contract; AutoTP + AutoEP is the validated path here.
  • Expert-internal tensor parallelism (expert_tensor_parallel_size > 1) is reserved for a follow-up.
  • ZeRO-3 composition with folding is planned as separate follow-up work. (It should be done after Support AutoEP with ZeRO-3 zero.Init source modules #8060 is merged)

Allow tensor parallelism (AutoTP) for the dense/attention path to coexist
with expert parallelism (AutoEP) for routed experts on the same rank set,
without requiring EP to be a subset of DP.

- Treat dense and MoE as independent partitionings: dense view tp*dp,
  expert view ep*etp*edp, with dp/edp derived so tp*dp == ep*etp*edp ==
  stage_size. expert_tensor_parallel_size is reserved (must currently be 1).
- Express folding via the existing tensor_parallel/expert_parallel config
  sections, with divisibility, TP/sequence-parallel exclusivity, and
  preset_model consistency validation.
- Add the route-full / partition-dispatch MoE path and AutoTP skipping of
  AutoEP subtrees; derive folded process groups via the generalized
  expert/data-parallel group creation.
- Reduce TP-replicated router/gate gradients mode-aware (sum when tokens are
  partitioned, average when replicated); record per-parameter-family ZeRO
  checkpoint metadata and handle folded ZeRO-1/2 optimizer state.
- Add folding unit tests (config, groups, dispatch, runtime, gradient parity,
  checkpoint), including multi-rank GPU-gated cases.

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 278c919489

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread deepspeed/moe/ep_tp_dispatch.py Outdated
Comment on lines +216 to +220
chunks = torch.split(grad_output, ctx.counts, dim=0)
grad_padded = grad_output.new_zeros((ctx.max_rows, *grad_output.shape[1:]))
if local_count:
grad_padded[:local_count].copy_(chunks[ctx.group_rank])
return grad_padded[:local_count].contiguous(), None, None, None

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Sum gathered-row gradients across TP lanes

When folded MoE output is consumed differently on each TP lane (for example by a row-parallel/lm-head layer that slices the hidden dimension), every gathered row participates in the loss on every lane. This backward path only returns chunks[ctx.group_rank] from the local rank's grad_output, so contributions from peer lanes to this rank's local expert outputs and routing weights are dropped; the padded local gradient needs to be accumulated across ctx.group before returning.

Useful? React with 👍 / 👎.

Comment thread deepspeed/runtime/zero/stage_1_and_2.py Outdated
Comment on lines +1120 to +1121
grad_reduc = self.get_gradient_for_reduction(param)
self._maybe_reduce_autoep_folding_tp_gradient(param, grad_reduc)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Honor ds_grad_is_ready before TP reduction

In ZeRO-2 folded runs, parameters with ds_grad_is_ready=False are intentionally skipped until their transient/tiled gradient is complete, as the guard immediately below documents. Calling the new TP reduction before that guard mutates and all-reduces incomplete gradients for those parameters, which can corrupt the final accumulated gradient once the ready shard is eventually reduced.

Useful? React with 👍 / 👎.

tohtana added 2 commits June 13, 2026 11:44
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
@PKUWZP PKUWZP self-requested a review June 14, 2026 02:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant