Support AutoEP with ZeRO-3 zero.Init source modules#8060
Conversation
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
- Honor load_module_strict=False in load_module_state_dict when AutoEP expert keys are allowed to be missing; the mismatch error now only fires for strict loads. - Add a 4-GPU universal checkpoint round trip with expert-DP world size 2 so converter consolidation and universal/module-only loads cover real partition shards instead of the degenerate single-rank case. - Check AutoEP partition-native metadata inside parse_model_states so zero_to_fp32 loads each model state file once and rejects unsupported checkpoints before the optimizer shard load. Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
There was a problem hiding this comment.
Pull request overview
This PR extends DeepSpeed’s AutoEP (Automatic Expert Parallelism) support to work with ZeRO Stage 3 by introducing parameter placement/partitioning across expert replica groups for expert parameters while keeping router/replicated parameters on the global data-parallel group, and updating checkpointing + Universal Checkpoint conversion paths accordingly.
Changes:
- Add ZeRO-3 compatibility gates for MoE, and resolve per-parameter ZeRO-3 partition placement (AutoEP expert vs replicated).
- Update ZeRO-3 internals (partition groups, gradient averaging, grad norm, all-gather/broadcast handling) to support mixed partition process groups.
- Extend checkpoint save/load and
ds_to_universal.pylogic to handle AutoEP ZeRO-3 “partition-native” expert metadata, plus add extensive unit/integration coverage.
Reviewed changes
Copilot reviewed 19 out of 19 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
deepspeed/runtime/engine.py |
Adds ZeRO-3 MoE compatibility validation, per-parameter ZeRO partition placement, and AutoEP-aware checkpoint save/load behaviors. |
deepspeed/runtime/zero/stage3.py |
Teaches ZeRO-3 optimizer logic to operate with per-subgroup process groups and AutoEP expert vs replicated gradient semantics. |
deepspeed/runtime/zero/partition_parameters.py |
Extends parameter partition/all-gather/broadcast logic to support per-parameter partition process groups and multi-handle waits. |
deepspeed/moe/ep_repack.py |
Gathers ZeRO params when repacking expert weights from ZeRO.Init-created source modules. |
deepspeed/module_inject/auto_ep_layer.py |
Ensures router/expert weight copies gather ZeRO params first; annotates params with ZeRO placement metadata. |
deepspeed/module_inject/auto_ep.py |
Uses ds_shape for ZeRO params when inferring shapes during AutoEP detection/parsing. |
deepspeed/module_inject/auto_ep_config.py |
Parses/validates expert_tensor_parallel_size (currently constrained to 1). |
deepspeed/module_inject/auto_ep_presets/base.py |
Adds expert_tensor_parallel_size to AutoEPConfig. |
deepspeed/checkpoint/ds_to_universal.py |
Adds ZeRO-3 AutoEP partition-native expert consolidation into Universal Checkpoint format and excludes expert params from standard Stage3 slice merge. |
deepspeed/checkpoint/constants.py |
Introduces AutoEP ZeRO-3 checkpoint format keys/versioning constants. |
deepspeed/utils/zero_to_fp32.py |
Rejects AutoEP ZeRO-3 partition-native checkpoints (directs users to expert-aware conversion). |
deepspeed/runtime/pipe/engine.py |
Extends pipeline load_module_state_dict signature to pass through new ZeRO-3 fetch/missing-key controls. |
tests/unit/v1/moe/test_autoep_unit.py |
Adds unit tests covering ZeRO-3 gates, placement metadata, Stage3 averaging/norm behavior, and zero.Init source gathering. |
tests/unit/v1/moe/test_autoep_integration.py |
Adds distributed integration tests for ZeRO-3 + AutoEP training, placement, checkpoint save/load, and replica-group behavior. |
tests/unit/v1/moe/test_autoep_checkpoint.py |
Adds distributed tests for ZeRO-3 → Universal conversion round-trips and load_universal restore. |
tests/unit/runtime/zero/test_zero_context.py |
Adds coverage ensuring multi-handle waits pass handle_dependency by keyword. |
docs/code-docs/source/autoep.rst |
Updates AutoEP ZeRO compatibility and checkpointing constraints (but currently conflicts with new code/tests). |
docs/_tutorials/universal-checkpointing.md |
Updates AutoEP Universal Checkpoint documentation (but currently conflicts with new ZeRO-3 conversion support added here). |
docs/_pages/config-json.md |
Updates config docs for constrained ZeRO-3 support and expert TP placeholder (but currently conflicts with new ZeRO-3 universal conversion support added here). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| convert a ZeRO Stage 1 or ZeRO Stage 2 checkpoint to Universal Checkpoint | ||
| format and load it with ``checkpoint.load_universal``; see the | ||
| `Universal Checkpointing tutorial </tutorials/universal-checkpointing/>`__ | ||
| for the detailed flow and constraints. | ||
| for the detailed flow and constraints. ZeRO Stage 3 AutoEP checkpoints must | ||
| be loaded with the same topology. |
| Regular AutoEP checkpoint load requires the target run to use the same | ||
| `autoep_size` as the save run. To change `autoep_size` for the same | ||
| AutoEP-detected model topology, convert the saved checkpoint to Universal format | ||
| and load the Universal checkpoint. | ||
| AutoEP-detected model topology, convert a ZeRO Stage 1 or ZeRO Stage 2 checkpoint | ||
| to Universal format and load the Universal checkpoint. For ZeRO Stage 3 AutoEP | ||
| checkpoints, use regular same-topology checkpoint load instead. |
A parameter that is already ZeRO-partitioned when _resolve_zero3_param_placement runs, for example an AutoEPMoELayer wrapped directly in zero.Init, keeps the partition group fixed at conversion time. Recording the freshly resolved expert replica group in ds_zero_partition_* metadata for such a parameter would silently reduce-scatter different experts across the wrong ranks. Raise instead, and derive the metadata from the actual partition group so it always describes the real partitioning. Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
The AutoEP and Universal Checkpointing docs still described the pre-redesign behavior: per-expert checkpoint files for every stage, no ZeRO-3 universal conversion, and no module-only or optimizer-state-free ZeRO-3 loads. Describe the partition-native ZeRO-3 layout and the supported conversion/load paths, and limit the remaining constraints to topology changes and zero_to_fp32 consolidation. Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
The DistributedFixture subclass alone was not collected as a fixture, so all six topology-change tests failed in setup with 'fixture not found'. Wrap it in an explicit @pytest.fixture function. Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Bring the config.md constraint bullet in line with autoep.rst and the universal checkpointing tutorial: ZeRO-3 AutoEP optimizer-including universal loads can resume at a different data-parallel world size and/or autoep_size; only weights-only/module-only universal loads remain unsupported. Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
| averaging_world_size = self._gradient_averaging_world_size(params_to_reduce, partition_world_size) | ||
| if averaging_world_size != partition_world_size: | ||
| scale = partition_world_size / float(averaging_world_size) | ||
| grad_partitions_for_rank = [g.mul(scale) for g in grad_partitions_for_rank] |
There was a problem hiding this comment.
zero3 expert gradient averaging is more complicated than zero2, suggest to add a UT to ensure zero3 averaged gradients are close to zero2 averaged gradients.
| for module_name, module in model.named_modules() if isinstance(module, _AutoEPMoELayer) | ||
| } | ||
|
|
||
| required_fields = { |
There was a problem hiding this comment.
What is the relationship between required_fields and partitioned_fields here and the two fields in ds_to_universal.py? They look similiar but not exactly identical. If they need to be identical, is there a way to ensure the coupling?
from ds_to_universal.py
required_fields = {
'moe_layer_id',
'module_path',
'num_experts',
'num_local_experts',
'ep_size',
'expert_key_prefix',
}
partitioned_fields = {
AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY,
'ep_rank',
'expert_data_parallel_rank',
'expert_data_parallel_world_size',
'global_expert_start',
'global_expert_end',
}
This PR enables ZeRO3 support for AutoEP-managed MoE layers by partitioning expert parameters over expert replica groups while router and replicated parameters use the global data-parallel group.
With ZeRO3 enable, AutoEP preserves global data-parallel gradient averaging for AutoEP expert parameters while reducing them over expert replica groups. ZeRO parameters are gathered before AutoEP reads router or expert tensors when replacing MoE modules created under
deepspeed.zero.Init().