diff --git a/THIRD_PARTY_NOTICES.md b/THIRD_PARTY_NOTICES.md
new file mode 100644
index 000000000000..3ec9ca860fdf
--- /dev/null
+++ b/THIRD_PARTY_NOTICES.md
@@ -0,0 +1,48 @@
+# Third-Party Notices
+
+This file records third-party source notices for code incorporated into
+DeepSpeed source files.
+
+## TorchTitan
+
+The following files contain portions derived from TorchTitan:
+
+- `deepspeed/module_inject/auto_ep_layer.py`
+- `deepspeed/moe/ep_experts.py`
+- `deepspeed/moe/ep_kernels.py`
+- `deepspeed/moe/ep_router.py`
+
+Source project: https://github.com/pytorch/torchtitan
+
+TorchTitan is licensed under the BSD 3-Clause License:
+
+```text
+BSD 3-Clause License
+
+Copyright (c) Meta Platforms, Inc. and affiliates.
+
+Redistribution and use in source and binary forms, with or without modification,
+are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+this list of conditions and the following disclaimer in the documentation
+and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its contributors
+may be used to endorse or promote products derived from this software without
+specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+```
diff --git a/deepspeed/checkpoint/autoep_universal.py b/deepspeed/checkpoint/autoep_universal.py
new file mode 100644
index 000000000000..3c19ab0c4183
--- /dev/null
+++ b/deepspeed/checkpoint/autoep_universal.py
@@ -0,0 +1,285 @@
+# Copyright (c) DeepSpeed Team.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+"""AutoEP universal checkpoint conversion utilities.
+
+Consolidates per-expert checkpoint files (and their optimizer states) into a
+topology-agnostic universal format to support EP resharding.
+"""
+
+import os
+import glob
+import torch
+
+from .constants import (
+ PARAM,
+ CAT_DIM,
+ EP_IS_EXPERT_PARAM,
+ EP_NUM_EXPERTS,
+)
+
+
+def _state_entry(state, param_id):
+ """Get optimizer state entry by param id, handling int/str key variants."""
+ if param_id in state:
+ return state[param_id]
+
+ pid_str = str(param_id)
+ if pid_str in state:
+ return state[pid_str]
+
+ if isinstance(param_id, str):
+ try:
+ pid_int = int(param_id)
+ except ValueError:
+ return None
+ return state.get(pid_int)
+
+ return None
+
+
+def _ordered_param_ids(optim_sd):
+ """Return optimizer param ids in param_groups order, deduplicated."""
+ ordered = []
+ seen = set()
+ for group in optim_sd.get('param_groups', []):
+ for param_id in group.get('params', []):
+ key = str(param_id)
+ if key in seen:
+ continue
+ seen.add(key)
+ ordered.append(param_id)
+
+ if ordered:
+ return ordered
+
+ # Fallback for unexpected optimizer formats.
+ state = optim_sd.get('state', {})
+ return list(state.keys())
+
+
+def _param_name_to_id(optim_sd):
+ """Build optional mapping from parameter name to optimizer param id."""
+ mapping = {}
+ for group in optim_sd.get('param_groups', []):
+ params = group.get('params', [])
+ param_names = group.get('param_names', None)
+ if not isinstance(param_names, list):
+ continue
+ if len(param_names) != len(params):
+ continue
+ for param_id, param_name in zip(params, param_names):
+ mapping[param_name] = param_id
+ return mapping
+
+
+def _is_expert_optimizer_state(param_state, num_local):
+ for state_key in ('exp_avg', 'exp_avg_sq'):
+ tensor = param_state.get(state_key)
+ if tensor is None:
+ continue
+ if tensor.dim() == 3 and tensor.shape[0] == num_local:
+ return True
+ return False
+
+
+def resolve_expert_ckpt_path(checkpoint_dir, moe_layer_id, global_expert_id):
+ """Find the expert checkpoint file for a given (layer, expert) pair.
+
+    Resolves the file via a glob pattern without assuming mp_rank=0.
+
+ Returns:
+ Path to the single matching expert checkpoint file.
+
+ Raises:
+ FileNotFoundError: No matching file found.
+ NotImplementedError: Multiple matching files found (multi-mp_rank).
+ """
+ pattern = os.path.join(checkpoint_dir, f'layer_{moe_layer_id}_expert_{global_expert_id}_mp_rank_*_model_states.pt')
+ matches = glob.glob(pattern)
+ if len(matches) == 0:
+ raise FileNotFoundError(f"Expert checkpoint file not found: layer_{moe_layer_id} "
+ f"expert_{global_expert_id} in {checkpoint_dir}")
+ if len(matches) > 1:
+ raise NotImplementedError(f"Multiple expert checkpoint files found for layer_{moe_layer_id} "
+ f"expert_{global_expert_id}: {matches}. Multi-mp_rank expert files "
+ f"are not yet supported.")
+ return matches[0]
+
+
+def consolidate_autoep_expert_files(checkpoint_dir, output_dir, autoep_layers_metadata):
+ """Consolidate per-expert checkpoint files into full-expert universal format.
+
+    For each AutoEP layer, loads all per-expert files, stacks them into
+ [E_total, H, D] tensors, and saves in universal checkpoint format.
+
+ Args:
+ checkpoint_dir: Path to DeepSpeed checkpoint directory.
+ output_dir: Path to universal checkpoint output directory.
+ autoep_layers_metadata: AutoEP metadata list from main checkpoint.
+
+ Raises:
+ FileNotFoundError: If expected expert files are missing.
+ NotImplementedError: If multiple mp_rank files match one (layer, expert).
+ RuntimeError: If metadata is missing or malformed.
+ """
+ if autoep_layers_metadata is None:
+ raise RuntimeError("AutoEP metadata is missing from checkpoint. Cannot consolidate "
+ "expert files without ds_autoep_layers metadata.")
+ if not isinstance(autoep_layers_metadata, list):
+ raise RuntimeError(f"AutoEP metadata is malformed: expected list, got "
+ f"{type(autoep_layers_metadata).__name__}")
+
+ for layer_info in autoep_layers_metadata:
+ moe_layer_id = layer_info['moe_layer_id']
+ num_experts = layer_info['num_experts']
+ prefix = layer_info['expert_key_prefix']
+
+ for wname in ('w1', 'w2', 'w3'):
+ expert_tensors = []
+ for global_eid in range(num_experts):
+ ckpt_path = resolve_expert_ckpt_path(checkpoint_dir, moe_layer_id, global_eid)
+ sd = torch.load(ckpt_path, map_location='cpu', weights_only=False)
+ key = f"{prefix}.{wname}.{global_eid}"
+ if key not in sd:
+ raise RuntimeError(f"Expected key '{key}' not found in {ckpt_path}")
+ expert_tensors.append(sd[key])
+
+ # Stack to full fused tensor [E_total, H, D]
+ full_tensor = torch.stack(expert_tensors, dim=0)
+
+ # Save in universal format
+ param_name = f"{prefix}.{wname}"
+ param_dir = os.path.join(output_dir, "zero", param_name)
+ os.makedirs(param_dir, exist_ok=True)
+ torch.save({
+ PARAM: full_tensor,
+ CAT_DIM: 0,
+ EP_IS_EXPERT_PARAM: True,
+ EP_NUM_EXPERTS: num_experts,
+ }, os.path.join(param_dir, "fp32.pt"))
+
+
+def consolidate_autoep_optimizer_states(checkpoint_dir, output_dir, autoep_layers_metadata, ep_size):
+ """Consolidate expert optimizer states from expp_rank files into universal format.
+
+ Loads optimizer states from all expp_rank_*_optim_states.pt files,
+ extracts per-expert-parameter states (exp_avg, exp_avg_sq, etc.),
+ concatenates along the expert dimension (dim 0) to form full
+ [E_total, H, D] optimizer states, and saves alongside the model
+ parameter in universal format.
+
+ Args:
+ checkpoint_dir: Path to DeepSpeed checkpoint directory.
+ output_dir: Path to universal checkpoint output directory.
+ autoep_layers_metadata: AutoEP metadata list from main checkpoint.
+ ep_size: Expert parallel world size (number of expp_rank files to load).
+
+ Raises:
+ FileNotFoundError: If expected optimizer state files are missing.
+ RuntimeError: If expert parameter states cannot be extracted.
+ """
+ if autoep_layers_metadata is None:
+ raise RuntimeError("AutoEP metadata is missing. Cannot consolidate optimizer states.")
+
+ # Load all expp_rank optimizer states
+ optim_states = []
+ for rank in range(ep_size):
+ pattern = os.path.join(checkpoint_dir, f'expp_rank_{rank}_mp_rank_*_optim_states.pt')
+ matches = glob.glob(pattern)
+ if not matches:
+ # No optimizer state files (e.g., ZeRO handles optimizer differently)
+ return
+ optim_path = matches[0]
+ sd = torch.load(optim_path, map_location='cpu', weights_only=False)
+ optim_states.append(sd)
+
+ if not optim_states:
+ return
+
+ # Extract optimizer state dict
+ optim_sd = optim_states[0].get('optimizer')
+ if optim_sd is None:
+ return
+
+ state = optim_sd.get('state', {})
+
+ if not state:
+ return
+
+ ordered_param_ids = _ordered_param_ids(optim_sd)
+ name_to_param_id = _param_name_to_id(optim_sd)
+ consumed_param_ids = set()
+
+ # For each AutoEP layer, extract and consolidate optimizer states
+ for layer_info in autoep_layers_metadata:
+ prefix = layer_info['expert_key_prefix']
+ num_experts = layer_info['num_experts']
+ num_local = layer_info['num_local_experts']
+ layer_param_ids = {}
+
+ # If optimizer state carries param names, map weights by exact identity.
+ for wname in ('w1', 'w2', 'w3'):
+ param_name = f"{prefix}.{wname}"
+ param_id = name_to_param_id.get(param_name)
+ if param_id is None:
+ continue
+ layer_param_ids[wname] = param_id
+ consumed_param_ids.add(str(param_id))
+
+ # Fallback: consume expert-like params in optimizer param_groups order.
+ missing_wnames = [w for w in ('w1', 'w2', 'w3') if w not in layer_param_ids]
+ if missing_wnames:
+ candidates = []
+ for param_id in ordered_param_ids:
+ if str(param_id) in consumed_param_ids:
+ continue
+ param_state = _state_entry(state, param_id)
+ if param_state is None:
+ continue
+ if not _is_expert_optimizer_state(param_state, num_local):
+ continue
+ candidates.append(param_id)
+
+ for wname, param_id in zip(missing_wnames, candidates):
+ layer_param_ids[wname] = param_id
+ consumed_param_ids.add(str(param_id))
+
+ for wname in ('w1', 'w2', 'w3'):
+ param_name = f"{prefix}.{wname}"
+ param_dir = os.path.join(output_dir, "zero", param_name)
+ os.makedirs(param_dir, exist_ok=True)
+ param_id = layer_param_ids.get(wname)
+ if param_id is None:
+ continue
+
+ # Consolidate optimizer states for this specific expert parameter id.
+ for state_key in ('exp_avg', 'exp_avg_sq'):
+ rank_tensors = []
+
+ for rank in range(ep_size):
+ rank_optim_sd = optim_states[rank].get('optimizer', {})
+ rank_state = rank_optim_sd.get('state', {})
+ param_state = _state_entry(rank_state, param_id)
+ if param_state is None:
+ rank_tensors = []
+ break
+ tensor = param_state.get(state_key)
+ if tensor is None:
+ rank_tensors = []
+ break
+ if tensor.dim() != 3 or tensor.shape[0] != num_local:
+ rank_tensors = []
+ break
+ rank_tensors.append(tensor)
+
+ if len(rank_tensors) == ep_size:
+ full_tensor = torch.cat(rank_tensors, dim=0)
+ torch.save(
+ {
+ PARAM: full_tensor,
+ CAT_DIM: 0,
+ EP_IS_EXPERT_PARAM: True,
+ EP_NUM_EXPERTS: num_experts,
+ }, os.path.join(param_dir, f"{state_key}.pt"))
diff --git a/deepspeed/checkpoint/constants.py b/deepspeed/checkpoint/constants.py
index dde5b16bd946..1ea9585c81c1 100644
--- a/deepspeed/checkpoint/constants.py
+++ b/deepspeed/checkpoint/constants.py
@@ -87,3 +87,16 @@
PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0 = 'parameter_with_2_sub_params_cat_dim_0'
PARAMETER_WITH_SUB_PARAMS = 'parameter_with_sub_params'
SUB_PARAMS_SHAPE = 'sub_params_shape'
+
+#########################################
+# AutoEP Checkpoint keys
+#########################################
+AUTOEP_LAYERS_KEY = 'ds_autoep_layers'
+AUTOEP_LAYERS_KEY_LEGACY = 'autoep_layers'
+
+#########################################
+# Universal Checkpoint EP keys
+#########################################
+EP_IS_EXPERT_PARAM = 'is_expert_param'
+EP_NUM_EXPERTS = 'ep_num_experts'
+EXPERT_PARAMETER_PATTERNS = 'expert_parameter_patterns'
diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py
index 8a39f6bb4c31..5c392ca52ec2 100755
--- a/deepspeed/checkpoint/ds_to_universal.py
+++ b/deepspeed/checkpoint/ds_to_universal.py
@@ -466,6 +466,14 @@ def _check_for_required_state(ds_checkpoint):
assert universal_checkpoint_info is not None, f'Required {UNIVERSAL_CHECKPOINT_INFO} state is missing in checkpoint. Verify that client creates this state.'
+def _classify_autoep_expert_file_consolidation(autoep_metadata, expert_files):
+ if autoep_metadata is not None:
+ return 'autoep'
+ if expert_files:
+ return 'native_moe'
+ return 'none'
+
+
def main(args):
print('Convert DeepSpeed Checkpoint to Universal Checkpoint')
@@ -501,17 +509,72 @@ def main(args):
print('*** 2. Merging slices .....')
_merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir)
+ print('*** 2.5. Consolidating AutoEP expert files')
+ from deepspeed.checkpoint.constants import (
+ AUTOEP_LAYERS_KEY,
+ AUTOEP_LAYERS_KEY_LEGACY,
+ EXPERT_PARAMETER_PATTERNS,
+ )
+ from deepspeed.checkpoint.autoep_universal import (
+ consolidate_autoep_expert_files,
+ consolidate_autoep_optimizer_states,
+ )
+
+ # Load AutoEP metadata from main checkpoint
+ main_sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu'), weights_only=False)
+ autoep_metadata = main_sd.get(AUTOEP_LAYERS_KEY)
+ if autoep_metadata is None:
+ autoep_metadata = main_sd.get(AUTOEP_LAYERS_KEY_LEGACY)
+
+ # Check for expert files in checkpoint directory
+ expert_files = glob.glob(os.path.join(args.input_folder, 'layer_*_expert_*_model_states.pt'))
+ autoep_expert_file_type = _classify_autoep_expert_file_consolidation(autoep_metadata, expert_files)
+
+ if autoep_expert_file_type == 'autoep':
+ consolidate_autoep_expert_files(args.input_folder, args.output_folder, autoep_metadata)
+ ep_size = autoep_metadata[0]['ep_size'] if autoep_metadata else 1
+ consolidate_autoep_optimizer_states(args.input_folder, args.output_folder, autoep_metadata, ep_size)
+ print(f' Consolidated {len(autoep_metadata)} AutoEP layer(s)')
+ elif autoep_expert_file_type == 'native_moe':
+ print(f' Found {len(expert_files)} expert checkpoint file(s) but no AutoEP metadata; '
+ 'assuming native DeepSpeed MoE and skipping AutoEP consolidation')
+ else:
+ print(' No AutoEP layers found, skipping')
+
print('*** 3. Saving common optimizer states')
_save_optimizer_state(args, ds_checkpoint)
if not args.keep_temp_folder:
shutil.rmtree(temp_dir, ignore_errors=True)
- # Copy mp* files into output folder
+ # Copy mp* files into output folder, injecting AutoEP metadata into UNIVERSAL_CHECKPOINT_INFO
for f in glob.glob(os.path.join(args.input_folder, 'mp*')):
- shutil.copy2(f, args.output_folder)
+ if autoep_metadata is not None:
+ # Load -> update with AutoEP metadata -> save
+ mp_sd = torch.load(f, map_location=torch.device('cpu'), weights_only=False)
+ if UNIVERSAL_CHECKPOINT_INFO not in mp_sd:
+ mp_sd[UNIVERSAL_CHECKPOINT_INFO] = {}
+ mp_sd[UNIVERSAL_CHECKPOINT_INFO][EXPERT_PARAMETER_PATTERNS] = [r'\.experts\.w[123]$']
+ mp_sd[UNIVERSAL_CHECKPOINT_INFO][AUTOEP_LAYERS_KEY] = autoep_metadata
+ out_path = os.path.join(args.output_folder, os.path.basename(f))
+ torch.save(mp_sd, out_path)
+ else:
+ shutil.copy2(f, args.output_folder)
else:
+ # Stage 3 path
+ # Check for AutoEP metadata - Stage 3 + AutoEP is not supported
+ stage3_expert_files = glob.glob(os.path.join(args.input_folder, 'layer_*_expert_*_model_states.pt'))
+ stage3_model_files_for_meta = glob.glob(os.path.join(args.input_folder, 'mp_rank_*_model_states.pt'))
+ if stage3_model_files_for_meta:
+ _stage3_sd = torch.load(stage3_model_files_for_meta[0],
+ map_location=torch.device('cpu'),
+ weights_only=False)
+ _stage3_autoep = _stage3_sd.get('ds_autoep_layers') or _stage3_sd.get('autoep_layers')
+ if _stage3_autoep is not None:
+ raise NotImplementedError("Stage 3 universal checkpoint conversion with AutoEP is not supported. "
+ "AutoEP currently requires ZeRO Stage 1 or 2.")
+
model_files = _get_model_state_files(args.input_folder)
param_shapes = _parse_model_states_stage3(model_files)
dp_degree = len(model_files)
@@ -531,8 +594,11 @@ def main(args):
if not args.keep_temp_folder:
shutil.rmtree(temp_dir, ignore_errors=True)
- # Copy *model_states files into output folder
+ # Copy *model_states files into output folder, filtering out expert files
for f in glob.glob(os.path.join(args.input_folder, '*model_states.pt')):
+ basename = os.path.basename(f)
+ if basename.startswith('layer_') and '_expert_' in basename:
+            continue  # Skip per-expert files (these would be consolidated separately if AutoEP supported Stage 3)
shutil.copy2(f, args.output_folder)
# Update latest to output folder
diff --git a/deepspeed/checkpoint/universal_checkpoint.py b/deepspeed/checkpoint/universal_checkpoint.py
index 7a9c2bcb068b..f057393ecdfc 100644
--- a/deepspeed/checkpoint/universal_checkpoint.py
+++ b/deepspeed/checkpoint/universal_checkpoint.py
@@ -10,7 +10,7 @@
from typing import List, Tuple, Union
from dataclasses import dataclass
from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM, PARAM_N_SUB_PARAMS, SUB_PARAM_SHAPE,
- DS_AUTOTP_UC_META)
+ EP_IS_EXPERT_PARAM, EP_NUM_EXPERTS, DS_AUTOTP_UC_META)
@dataclass
@@ -96,7 +96,7 @@ def _resolve_autotp_partition(current_param, ckpt_dict, full_hp_param, tp_rank,
return slice_tensor.flatten()
-def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
+def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size, ep_rank=0, ep_size=1):
hp_mapping = self._hp_mapping
hp_mapping.optim_fragment = {}
@@ -119,6 +119,23 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
full_hp_param = ckpt_dict[PARAM]
+ # EP-aware slicing for expert parameters saved in universal format.
+ # Must happen BEFORE shape-match check so that after slicing,
+ # full_hp_param.shape == self.shape triggers tp_rank=0, tp_world_size=1.
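+        # Illustrative numbers (not taken from any checkpoint): with ep_num_experts=8, ep_size=2
+        # and ep_rank=1, the saved [8, H, D] tensor is sliced to experts 4..7, i.e. a local [4, H, D].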
+ is_expert_param = ckpt_dict.get(EP_IS_EXPERT_PARAM, False)
+ if is_expert_param and ep_size > 1:
+ ep_num_experts = ckpt_dict.get(EP_NUM_EXPERTS)
+ assert ep_num_experts is not None, \
+ f"Expert param in {ckpt_file} missing '{EP_NUM_EXPERTS}' metadata"
+ assert full_hp_param.shape[0] == ep_num_experts, \
+ f"Expert param dim 0 ({full_hp_param.shape[0]}) != {EP_NUM_EXPERTS} ({ep_num_experts})"
+ assert ep_num_experts % ep_size == 0, \
+ f"num_experts ({ep_num_experts}) not divisible by ep_size ({ep_size})"
+ num_local = ep_num_experts // ep_size
+ ep_start = ep_rank * num_local
+ ep_end = ep_start + num_local
+ full_hp_param = full_hp_param[ep_start:ep_end]
+
# need to deal with slices that were averaged.
# the opposite of averaging here becomes an exact copy of the first slice
# I thought of 2 ways:
@@ -139,7 +156,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
# the converter to universal currently strips the original padding completely so the saved
# weight is padding-free and we just need to add new padding depending on the target TP
# degree
- is_vocab_tensor = ckpt_dict.get(VOCAB_TENSOR, False)
+ is_vocab_tensor = ckpt_dict.get(VOCAB_TENSOR, False) and not is_expert_param
if is_vocab_tensor:
# In the absence of data passed from the user wrt new padded vocab specific to tp degree
# we can again derive that data by reverse engineering the target shapes like so:
diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index 7e78a6b060fb..2ff8e381f702 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -464,7 +464,8 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
old_moe_load=old_moe_load,
model=self.module,
mpu=self.mpu,
- checkpoint_engine=self.checkpoint_engine)
+ checkpoint_engine=self.checkpoint_engine,
+ autoep_layers=None)
self.module.load_state_dict(state_dict=checkpoint[self._choose_module_key(checkpoint)],
strict=load_module_strict)
diff --git a/deepspeed/module_inject/auto_ep.py b/deepspeed/module_inject/auto_ep.py
new file mode 100644
index 000000000000..e80bebf15e74
--- /dev/null
+++ b/deepspeed/module_inject/auto_ep.py
@@ -0,0 +1,582 @@
+# Copyright (c) DeepSpeed Team.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+"""AutoEP: Automatic Expert Parallelism for MoE models.
+
+Phase 3: MoE layer detection and structural validation.
+Phase 5: MoE layer replacement (replace_moe_layer).
+"""
+
+from __future__ import annotations
+
+import math
+import re
+from collections import OrderedDict
+from typing import Literal
+
+import torch
+import torch.nn as nn
+
+from deepspeed.module_inject.auto_ep_config import (
+ AutoEPConfig,
+ MoELayerSpec,
+ MoEModelPreset,
+)
+from deepspeed.module_inject.auto_ep_presets.base import ForwardContract
+from deepspeed.module_inject.auto_ep_presets.registry import (
+ apply_config_overrides,
+ get_preset_adapter,
+ preset_name_for_hf_model_type,
+ resolve_preset_candidates,
+ unsupported_preset_for_hf_model_type,
+)
+from deepspeed.moe.fused_expert_layout import classify_fused_gate_up_layout
+from deepspeed.utils import logger
+
+
+def _remove_transformers_output_capture_hooks(model: nn.Module) -> int:
+ """Remove HF output-capturing hooks so they can be reinstalled after AutoEP conversion."""
+ removed = 0
+ for module in model.modules():
+ hooks = getattr(module, "_forward_hooks", None)
+ if not hooks:
+ continue
+
+ for hook_id, hook in list(hooks.items()):
+ if getattr(hook, "__name__", "") != "output_capturing_hook":
+ continue
+ del hooks[hook_id]
+ removed += 1
+ hooks_with_kwargs = getattr(module, "_forward_hooks_with_kwargs", None)
+ if hooks_with_kwargs is not None:
+ hooks_with_kwargs.pop(hook_id, None)
+ hooks_always_called = getattr(module, "_forward_hooks_always_called", None)
+ if hooks_always_called is not None:
+ hooks_always_called.pop(hook_id, None)
+ return removed
+
+
+def _is_known_hf_model_type(model_type: str | None) -> bool:
+ if model_type is None:
+ return False
+ return (preset_name_for_hf_model_type(model_type) is not None
+ or unsupported_preset_for_hf_model_type(model_type) is not None)
+
+
+def _raise_if_duplicate_moe_specs(specs: list[MoELayerSpec]) -> None:
+ by_module: dict[str, list[MoELayerSpec]] = {}
+ for spec in specs:
+ by_module.setdefault(spec.moe_module_name, []).append(spec)
+
+ duplicates = {name: matches for name, matches in by_module.items() if len(matches) > 1}
+ if not duplicates:
+ return
+
+ details = "; ".join(f"{name}: {', '.join(spec.model_family for spec in matches)}"
+ for name, matches in sorted(duplicates.items()))
+ raise ValueError("AutoEP detection is ambiguous and produced multiple replacement specs for the same "
+ f"MoE module(s): {details}. Set expert_parallel.preset_model or provide custom "
+ "AutoEP patterns so each MoE module matches exactly one preset.")
+
+
+def _has_3d_expert_params(module: nn.Module, preset: MoEModelPreset) -> bool:
+ """Check if module stores expert weights as 3D parameter tensors (transformers 5.0.0+).
+
+ Returns True if the module has a parameter named preset.expert_w1 (e.g., "gate_up_proj")
+ with 3 dimensions (num_experts, ..., ...).
+ """
+ w1_name = preset.expert_w1
+ param = getattr(module, w1_name, None)
+ if param is None:
+ return False
+    if isinstance(param, (nn.Parameter, torch.Tensor)):
+ return param.ndim == 3
+ return False
+
+
+def _get_num_experts_from_config(model_config, preset: MoEModelPreset) -> int | None:
+ """Extract num_experts from model.config using the preset's attribute name."""
+ return getattr(model_config, preset.num_experts_attr, None)
+
+
+def _get_top_k_from_config(model_config, preset: MoEModelPreset) -> int | None:
+ """Extract top_k from model.config using the preset's attribute name."""
+ return getattr(model_config, preset.top_k_attr, None)
+
+
+def _as_finite_float(value, field_name: str) -> float:
+ if isinstance(value, bool) or not isinstance(value, (int, float)):
+ raise ValueError(f"{field_name} must be a finite number")
+
+ value = float(value)
+ if not math.isfinite(value):
+ raise ValueError(f"{field_name} must be a finite number")
+ return value
+
+
+def _resolve_route_scale(config: AutoEPConfig, model_config) -> float:
+ """Resolve the single scale applied by TokenChoiceTopKRouter."""
+ routed_scaling_factor = config.routed_scaling_factor
+
+ if routed_scaling_factor != "auto":
+ route_scale = _as_finite_float(routed_scaling_factor, "routed_scaling_factor")
+ if config.route_scale != 1.0:
+ logger.warning("AutoEP: routed_scaling_factor=%s overrides route_scale=%s.", routed_scaling_factor,
+ config.route_scale)
+ return route_scale
+
+ cfg_routed_scaling_factor = getattr(model_config, 'routed_scaling_factor', None)
+ if cfg_routed_scaling_factor is not None:
+ route_scale = _as_finite_float(cfg_routed_scaling_factor, "model.config.routed_scaling_factor")
+ if config.route_scale != 1.0:
+ logger.warning("AutoEP: model.config.routed_scaling_factor=%s overrides route_scale=%s.",
+ cfg_routed_scaling_factor, config.route_scale)
+ return route_scale
+
+ return _as_finite_float(config.route_scale, "route_scale")
+
+
+def _detect_expert_storage(experts_module: nn.Module, preset: MoEModelPreset) -> Literal["fused_3d", "module_list"]:
+ """Determine whether experts are stored as fused 3D tensors or nn.ModuleList."""
+ if _has_3d_expert_params(experts_module, preset):
+ return "fused_3d"
+ if isinstance(experts_module, nn.ModuleList):
+ return "module_list"
+    # Fallback: check any parameter registered directly on the module for a 3D shape
+ for name, param in experts_module.named_parameters(recurse=False):
+ if param.ndim == 3:
+ return "fused_3d"
+ return "module_list"
+
+
+def _infer_hidden_and_ffn_size(
+ experts_module: nn.Module,
+ preset: MoEModelPreset,
+ storage: Literal["fused_3d", "module_list"],
+ num_experts: int,
+) -> tuple[int, int]:
+ """Infer hidden_size and ffn_hidden_size from expert weight shapes."""
+ if storage == "fused_3d":
+ w1_param = getattr(experts_module, preset.expert_w1, None)
+ w2_param = getattr(experts_module, preset.expert_w2, None)
+ if w1_param is not None and w2_param is not None:
+ if preset.expert_w3 is None:
+ layout = classify_fused_gate_up_layout(tuple(w1_param.shape), tuple(w2_param.shape))
+ if layout is None:
+ raise ValueError("expert_w3=None expects fused gate+up weights with either "
+ f"[E, 2*ffn, hidden]/[E, hidden, ffn] or [E, hidden, 2*ffn]/[E, ffn, hidden], "
+ f"but got {preset.expert_w1}={tuple(w1_param.shape)} and "
+ f"{preset.expert_w2}={tuple(w2_param.shape)}.")
+ hidden_size = layout.hidden_size
+ ffn_hidden_size = layout.ffn_hidden_size
+ else:
+ # Separate gate and up: w1 shape is [E, ffn, hidden]
+ w3_param = getattr(experts_module, preset.expert_w3, None)
+ if w3_param is None:
+ raise ValueError(f"expert_w3='{preset.expert_w3}' is set but no such weight "
+ f"exists on experts module.")
+ hidden_size = w1_param.shape[2]
+ ffn_hidden_size = w1_param.shape[1]
+ return hidden_size, ffn_hidden_size
+ elif storage == "module_list":
+ # Legacy: individual expert modules
+ if isinstance(experts_module, nn.ModuleList) and len(experts_module) > 0:
+ expert0 = experts_module[0]
+ w1 = getattr(expert0, preset.expert_w1, None)
+ if w1 is None:
+ # Try weight attribute for nn.Linear
+ for name, child in expert0.named_children():
+ if preset.expert_w1 in name:
+ w1 = child.weight if hasattr(child, 'weight') else None
+ break
+ if w1 is not None:
+ if isinstance(w1, nn.Linear):
+ return w1.in_features, w1.out_features
+ elif isinstance(w1, (nn.Parameter, torch.Tensor)):
+ if w1.ndim == 2:
+ return w1.shape[1], w1.shape[0]
+
+ raise ValueError(f"Could not infer hidden_size/ffn_hidden_size from experts module "
+ f"with storage={storage}, preset.expert_w1={preset.expert_w1}")
+
+
+def _detect_forward_contract(
+ moe_module: nn.Module,
+ router_module: nn.Module,
+) -> ForwardContract:
+ """Detect the forward contract for router logits capture.
+
+ Returns:
+ ForwardContract with router-logit return and capture metadata.
+ """
+    # Check for the OutputRecorder pattern (transformers 5.0.0+) by looking for a
+    # _can_record_outputs attribute on the router class and on the MoE block.
+ capture_target: Literal["moe_block", "router", "none"] = "none"
+ capture_index: int | None = None
+ capture_layer_name: str | None = None
+ return_router_logits = False
+
+ # Check for OutputRecorder pattern on router class
+ router_class = type(router_module)
+ if hasattr(router_class, '_can_record_outputs'):
+ capture_target = "router"
+ record_config = router_class._can_record_outputs
+ if isinstance(record_config, dict):
+ for key, val in record_config.items():
+ if isinstance(val, dict):
+ capture_index = val.get('index', 0)
+ capture_layer_name = val.get('layer_name', None)
+ else:
+ capture_index = 0
+ elif isinstance(record_config, (list, tuple)):
+ capture_index = 0
+ logger.debug(f"Detected OutputRecorder on router class {router_class.__name__}: "
+ f"index={capture_index}, layer_name={capture_layer_name}")
+
+ # Check if MoE block has tuple return contract (legacy transformers)
+ if hasattr(moe_module, '_can_record_outputs'):
+ record_config = moe_module._can_record_outputs
+ if record_config:
+ capture_target = "moe_block"
+ return_router_logits = True
+ if isinstance(record_config, dict):
+ for key, val in record_config.items():
+ if isinstance(val, dict):
+ capture_index = val.get('index', None)
+ elif isinstance(val, int):
+ capture_index = val
+
+ return ForwardContract(
+ return_router_logits=return_router_logits,
+ capture_target=capture_target,
+ capture_index=capture_index,
+ capture_layer_name=capture_layer_name,
+ )
+
+
+class AutoEP:
+ """Automatic Expert Parallelism: detect and replace MoE layers."""
+
+ def __init__(self, model: nn.Module, config: AutoEPConfig) -> None:
+ self.model = model
+ self.config = config
+ self.model_config = getattr(model, 'config', None)
+ self._retargeted_transformers_output_recorders: set[str] = set()
+
+ def ep_parser(self) -> list[MoELayerSpec]:
+ """Traverse model and detect MoE layers. Returns list of MoELayerSpec."""
+ specs = []
+
+ # Determine which preset(s) to use
+ presets_to_try = self._resolve_presets()
+
+ for preset_name, preset in presets_to_try:
+ adapter = get_preset_adapter(preset.preset_adapter)
+ pattern = re.compile(preset.moe_layer_pattern)
+
+ for module_name, module in self.model.named_modules():
+ if not pattern.fullmatch(module_name):
+ continue
+
+ # Structural validation: check for experts child
+ experts_child = getattr(module, preset.experts_pattern, None)
+ if experts_child is None:
+ logger.debug(
+ "Skipping %s: pattern matched but no '%s' child (likely dense FFN)",
+ module_name,
+ preset.experts_pattern,
+ )
+ continue
+
+ expert_layout = adapter.resolve_expert_layout(experts_child, preset)
+
+ # Accept both: nn.ModuleList (legacy) and Experts class (transformers 5.0.0+)
+ has_expert_params = (isinstance(experts_child, nn.ModuleList)
+ or _has_3d_expert_params(experts_child, expert_layout))
+ if not has_expert_params:
+ logger.debug(
+ "Skipping %s: '%s' child exists but has no expert parameters",
+ module_name,
+ preset.experts_pattern,
+ )
+ continue
+
+ # Check for router
+ router_child = getattr(module, preset.router_pattern, None)
+ if router_child is None:
+ logger.debug(
+ "Skipping %s: no router child '%s'",
+ module_name,
+ preset.router_pattern,
+ )
+ continue
+
+ # Detect storage format
+ storage = _detect_expert_storage(experts_child, expert_layout)
+
+ # Get num_experts and top_k from config or weights
+ num_experts = None
+ top_k = None
+
+ if self.model_config is not None:
+ num_experts = _get_num_experts_from_config(self.model_config, preset)
+ top_k = _get_top_k_from_config(self.model_config, preset)
+
+ # Validate/derive from router weight shape
+ router_weight = getattr(router_child, 'weight', None)
+ if router_weight is not None and router_weight.ndim == 2:
+ num_experts_from_weight = router_weight.shape[0]
+ hidden_from_weight = router_weight.shape[1]
+ if num_experts is not None and num_experts != num_experts_from_weight:
+ raise ValueError(f"Config num_experts={num_experts} mismatches router weight "
+ f"shape {router_weight.shape} (expected {num_experts_from_weight}) "
+ f"in layer '{module_name}'")
+ num_experts = num_experts_from_weight
+
+ if num_experts is None:
+ raise ValueError(f"Could not determine num_experts for layer '{module_name}'. "
+ f"Set model.config.{preset.num_experts_attr} or use a preset.")
+
+ # Override top_k from config if user specified
+ if isinstance(self.config.top_k, int):
+ top_k = self.config.top_k
+ elif top_k is None:
+ raise ValueError(f"Could not determine top_k for layer '{module_name}'. "
+ f"Set model.config.{preset.top_k_attr} or config top_k.")
+
+ # Infer hidden sizes
+ try:
+ hidden_size, ffn_hidden_size = _infer_hidden_and_ffn_size(experts_child, expert_layout, storage,
+ num_experts)
+ except ValueError as e:
+ if self._requires_selected_preset_detection():
+ raise ValueError(f"AutoEP: preset '{preset_name}' matched layer '{module_name}' "
+ f"with router and experts, but shape inference failed: {e}") from e
+ logger.warning(f"Skipping {module_name}: {e}")
+ continue
+
+ # Cross-validate hidden_size with router
+ if router_weight is not None and router_weight.ndim == 2:
+ if hidden_size != router_weight.shape[1]:
+ raise ValueError(f"hidden_size={hidden_size} from expert weights mismatches "
+ f"router weight dim={router_weight.shape[1]} in '{module_name}'")
+
+ # Validate top_k <= num_experts
+ if top_k > num_experts:
+ raise ValueError(f"top_k={top_k} exceeds num_experts={num_experts} "
+ f"in layer '{module_name}'")
+
+ # Resolve score_func
+ if self.config.score_func != "auto":
+ score_func = self.config.score_func
+ else:
+ # Check model config for scoring_func attribute
+ cfg_score = getattr(self.model_config, 'scoring_func', None)
+ if cfg_score in ("softmax", "sigmoid"):
+ score_func = cfg_score
+ else:
+ score_func = preset.score_func
+
+ # Resolve score_apply
+ if self.config.score_apply != "auto":
+ score_apply = self.config.score_apply
+ else:
+ score_apply = preset.score_apply
+
+ route_norm = adapter.resolve_route_norm(self.config, preset, self.model_config)
+
+ route_scale = _resolve_route_scale(self.config, self.model_config)
+
+ group_routing = adapter.resolve_group_routing(self.config, self.model_config)
+
+ # Check gate bias
+ gate_bias = preset.gate_bias
+ if router_weight is not None:
+ gate_bias = getattr(router_child, 'bias', None) is not None
+
+ forward_contract = adapter.adjust_forward_contract(_detect_forward_contract(module, router_child))
+
+ # Check shared experts
+ has_shared = False
+ shared_name = ""
+ shared_gate_name = ""
+ if preset.has_shared_experts and preset.shared_experts_pattern:
+ shared = getattr(module, preset.shared_experts_pattern, None)
+ if shared is not None:
+ has_shared = True
+ shared_name = preset.shared_experts_pattern
+ if preset.shared_experts_gate_pattern:
+ shared_gate = getattr(module, preset.shared_experts_gate_pattern, None)
+ if shared_gate is not None:
+ shared_gate_name = preset.shared_experts_gate_pattern
+
+ # Warn about router stochasticity/precision settings
+ if self.model_config is not None:
+ jitter = getattr(self.model_config, 'router_jitter_noise', 0.0)
+ if jitter and jitter > 0:
+ logger.warning(f"Layer {module_name}: model has router_jitter_noise={jitter}, "
+ f"AutoEP router does not implement jitter.")
+ z_loss = getattr(self.model_config, 'router_z_loss_coef', 0.0)
+ if z_loss and z_loss > 0:
+ logger.warning(f"Layer {module_name}: model has router_z_loss_coef={z_loss}, "
+ f"AutoEP router does not implement z-loss.")
+
+ spec = MoELayerSpec(
+ moe_module_name=module_name,
+ model_family=preset_name,
+ router_name=preset.router_pattern,
+ experts_name=preset.experts_pattern,
+ expert_storage=storage,
+ expert_w1_name=expert_layout.expert_w1,
+ expert_w2_name=expert_layout.expert_w2,
+ expert_w3_name=expert_layout.expert_w3,
+ num_experts=num_experts,
+ top_k=top_k,
+ hidden_size=hidden_size,
+ ffn_hidden_size=ffn_hidden_size,
+ score_func=score_func,
+ score_apply=score_apply,
+ route_norm=route_norm,
+ gate_bias=gate_bias,
+ return_router_logits=forward_contract.return_router_logits,
+ router_logits_capture_target=forward_contract.capture_target,
+ router_logits_capture_index=forward_contract.capture_index,
+ router_logits_capture_layer_name=forward_contract.capture_layer_name,
+ has_shared_experts=has_shared,
+ shared_experts_name=shared_name,
+ shared_experts_gate_name=shared_gate_name,
+ route_scale=route_scale,
+ num_expert_groups=group_routing.num_expert_groups,
+ num_limited_groups=group_routing.num_limited_groups,
+ group_score_func=group_routing.group_score_func,
+ supports_expert_bias=preset.supports_expert_bias,
+ unsupported_router_bias_names=preset.unsupported_router_bias_names,
+ preset_adapter=preset.preset_adapter,
+ router_logits_capture_mode=forward_contract.router_logits_capture_mode,
+ moe_output_shape=forward_contract.moe_output_shape,
+ )
+ specs.append(spec)
+ logger.debug(f"Detected MoE layer: {module_name} (family={preset_name}, "
+ f"experts={num_experts}, top_k={top_k}, storage={storage})")
+
+ if not specs:
+ if self._requires_selected_preset_detection():
+ self._raise_no_moe_layers_detected(presets_to_try)
+ logger.warning("AutoEP: no MoE layers detected in model.")
+ else:
+ _raise_if_duplicate_moe_specs(specs)
+
+ return specs
+
+ def _replace_moe_layer_without_retarget(
+ self,
+ spec: MoELayerSpec,
+ ep_size: int,
+ ep_rank: int,
+ ) -> nn.Module:
+ from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer
+
+ # Navigate to the parent module and get the child name
+ parts = spec.moe_module_name.split(".")
+ parent = self.model
+ for part in parts[:-1]:
+ parent = getattr(parent, part)
+ child_name = parts[-1]
+ source_module = getattr(parent, child_name)
+
+ # Create replacement layer
+ replacement = AutoEPMoELayer(
+ spec=spec,
+ source_module=source_module,
+ ep_size=ep_size,
+ ep_rank=ep_rank,
+ config=self.config,
+ )
+
+ # Replace in-place on parent
+ setattr(parent, child_name, replacement)
+ return replacement
+
+ def _retarget_transformers_output_recorders(self, spec: MoELayerSpec, replacement: nn.Module) -> None:
+ adapter = get_preset_adapter(spec.preset_adapter)
+ adapter.retarget_transformers_output_recorders(
+ self.model,
+ spec,
+ replacement,
+ self._retargeted_transformers_output_recorders,
+ _remove_transformers_output_capture_hooks,
+ )
+
+ def replace_moe_layer(
+ self,
+ spec: MoELayerSpec,
+ ep_size: int,
+ ep_rank: int,
+ ) -> None:
+ """Replace a single MoE module with AutoEPMoELayer in-place on the model."""
+ replacement = self._replace_moe_layer_without_retarget(spec, ep_size, ep_rank)
+ self._retarget_transformers_output_recorders(spec, replacement)
+
+ logger.info(f"AutoEP: replaced '{spec.moe_module_name}' with AutoEPMoELayer "
+ f"(ep_size={ep_size}, ep_rank={ep_rank}, "
+ f"local_experts={replacement.num_local_experts})")
+
+ def replace_moe_layers(
+ self,
+ specs: list[MoELayerSpec],
+ ep_size: int,
+ ep_rank: int,
+ ) -> None:
+ """Replace multiple MoE modules and batch post-replacement recorder retargeting."""
+ replacements: list[tuple[MoELayerSpec, nn.Module]] = []
+ for spec in specs:
+ replacement = self._replace_moe_layer_without_retarget(spec, ep_size, ep_rank)
+ replacements.append((spec, replacement))
+ logger.info(f"AutoEP: replaced '{spec.moe_module_name}' with AutoEPMoELayer "
+ f"(ep_size={ep_size}, ep_rank={ep_rank}, "
+ f"local_experts={replacement.num_local_experts})")
+
+ retarget_groups: OrderedDict[tuple[str, str, type], tuple[MoELayerSpec, nn.Module]] = OrderedDict()
+ for spec, replacement in replacements:
+ retarget_key = (spec.preset_adapter, spec.model_family, replacement.__class__)
+ retarget_groups.setdefault(retarget_key, (spec, replacement))
+
+ for spec, replacement in retarget_groups.values():
+ self._retarget_transformers_output_recorders(spec, replacement)
+
+ def _apply_config_overrides(self, preset: MoEModelPreset) -> MoEModelPreset:
+ return apply_config_overrides(self.config, preset)
+
+ def _requires_selected_preset_detection(self) -> bool:
+ """Return whether empty detection should fail for the selected preset."""
+ if self.config.preset_model is not None:
+ return True
+ if self.config.moe_layer_pattern is not None:
+ return True
+ if self.model_config is None:
+ return False
+ model_type = getattr(self.model_config, 'model_type', None)
+ return _is_known_hf_model_type(model_type)
+
+ def _raise_no_moe_layers_detected(self, presets_to_try: list[tuple[str, MoEModelPreset]]) -> None:
+ model_type = getattr(self.model_config, 'model_type', None)
+ if self.config.preset_model is not None:
+ source = f"preset_model='{self.config.preset_model}'"
+ elif self.config.moe_layer_pattern is not None:
+ source = f"moe_layer_pattern='{self.config.moe_layer_pattern}'"
+ else:
+ source = f"model_type='{model_type}'"
+
+ expected = "; ".join(f"{preset_name}: moe_layer_pattern='{preset.moe_layer_pattern}', "
+ f"router='{preset.router_pattern}', experts='{preset.experts_pattern}'"
+ for preset_name, preset in presets_to_try)
+ raise ValueError(f"AutoEP: no MoE layers detected for {source}. "
+ f"Expected MoE structure for selected preset(s): {expected}. "
+ "This usually means the selected preset does not match the model implementation, "
+ "or the installed Transformers version exposes a different structure. Choose a matching "
+ "preset, upgrade Transformers, or provide custom AutoEP patterns.")
+
+ def _resolve_presets(self) -> list[tuple[str, MoEModelPreset]]:
+ """Determine which preset(s) to use for detection."""
+ return resolve_preset_candidates(self.config, self.model_config)
diff --git a/deepspeed/module_inject/auto_ep_config.py b/deepspeed/module_inject/auto_ep_config.py
new file mode 100644
index 000000000000..addcc6094a38
--- /dev/null
+++ b/deepspeed/module_inject/auto_ep_config.py
@@ -0,0 +1,303 @@
+# Copyright (c) DeepSpeed Team.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+"""AutoEP configuration: config parsing, model presets, and validation."""
+
+from __future__ import annotations
+
+from deepspeed.module_inject.auto_ep_presets.base import (
+ _UNSET,
+ _raise_unsupported_load_balance_coeff,
+ AutoEPConfig,
+ MoELayerSpec,
+ MoEModelPreset,
+)
+from deepspeed.module_inject.auto_ep_presets.registry import (
+ PRESET_MODELS,
+ available_preset_names,
+ resolve_autoep_config_defaults,
+)
+from deepspeed.utils import logger
+
+__all__ = [
+ "_UNSET",
+ "AutoEPConfig",
+ "MoELayerSpec",
+ "MoEModelPreset",
+ "PRESET_MODELS",
+ "parse_autoep_config",
+ "resolve_autoep_config_defaults",
+ "validate_autoep_config",
+ "validate_autoep_post_detection",
+]
+
+# ---------------------------------------------------------------------------
+# Config parsing
+# ---------------------------------------------------------------------------
+
+
+def parse_autoep_config(param_dict: dict) -> AutoEPConfig:
+ """Parse the 'expert_parallel' section from DS config JSON."""
+ if not param_dict:
+ return AutoEPConfig()
+
+ config = AutoEPConfig()
+ config.enabled = param_dict.get("enabled", False)
+ config.autoep_size = param_dict.get("autoep_size", 1)
+ config.preset_model = param_dict.get("preset_model", None)
+ config.moe_layer_pattern = param_dict.get("moe_layer_pattern", None)
+ config.expert_pattern = param_dict.get("expert_pattern", None)
+ config.router_pattern = param_dict.get("router_pattern", None)
+ config.use_grouped_mm = param_dict.get("use_grouped_mm", True)
+ config.route_norm = param_dict.get("route_norm", None)
+ config.route_scale = param_dict.get("route_scale", 1.0)
+ config.score_apply = param_dict.get("score_apply", "auto")
+ config.combine_impl = param_dict.get("combine_impl", "auto")
+ config.num_expert_groups = param_dict.get("num_expert_groups", None)
+ config.num_limited_groups = param_dict.get("num_limited_groups", None)
+ config.score_func = param_dict.get("score_func", "auto")
+ config.top_k = param_dict.get("top_k", "auto")
+ if "load_balance_coeff" in param_dict:
+ value = param_dict["load_balance_coeff"]
+ if value is not None:
+ _raise_unsupported_load_balance_coeff(value)
+ config.load_balance_coeff = None
+ config._load_balance_coeff_explicit = True
+ else:
+ config.load_balance_coeff = None
+ config._load_balance_coeff_explicit = False
+ config.routed_scaling_factor = param_dict.get("routed_scaling_factor", "auto")
+ config.expert_w1 = param_dict.get("expert_w1", None)
+ config.expert_w2 = param_dict.get("expert_w2", None)
+ # expert_w3: key absent → _UNSET (preset default); key present with null → None (fused); key present with string → custom name
+ if "expert_w3" in param_dict:
+ config.expert_w3 = param_dict["expert_w3"] # None or string
+ else:
+ config.expert_w3 = _UNSET
+ config.num_experts_attr = param_dict.get("num_experts_attr", None)
+ config.top_k_attr = param_dict.get("top_k_attr", None)
+ config.has_shared_experts = param_dict.get("has_shared_experts", None)
+ config.shared_experts_pattern = param_dict.get("shared_experts_pattern", None)
+ config.shared_experts_gate_pattern = param_dict.get("shared_experts_gate_pattern", None)
+
+ return config
+
+
+# ---------------------------------------------------------------------------
+# Validation helpers
+# ---------------------------------------------------------------------------
+
+
+def validate_autoep_config(
+ config: AutoEPConfig,
+ world_size: int,
+ pp_size: int,
+ tp_size: int,
+ sp_size: int,
+) -> None:
+ """Validate config constraints. Raises ValueError on invalid config."""
+ if config.load_balance_coeff is not None:
+ _raise_unsupported_load_balance_coeff(config.load_balance_coeff)
+
+ if not config.enabled:
+ return
+
+ if tp_size > 1:
+ raise ValueError("AutoEP does not currently support AutoTP "
+ f"(tensor_parallel.autotp_size={tp_size}). Disable AutoTP for this run; "
+ "AutoEP+AutoTP support is planned as follow-up work.")
+
+ # ep_size must divide the stage size (world_size / pp_size)
+ stage_size = world_size // pp_size
+ if stage_size % config.autoep_size != 0:
+ raise ValueError(f"autoep_size={config.autoep_size} must divide the stage size "
+ f"(world_size={world_size} / pp_size={pp_size} = {stage_size}). "
+ f"Valid autoep_size values: {_divisors(stage_size)}")
+
+ # Validate preset_model if specified
+ if config.preset_model is not None and config.preset_model not in PRESET_MODELS:
+ raise ValueError(f"Unknown preset_model '{config.preset_model}'. "
+ f"Available presets: {list(available_preset_names())}")
+
+ # Validate score_apply
+ valid_score_apply = ("auto", "pre", "post")
+ if config.score_apply not in valid_score_apply:
+ raise ValueError(f"score_apply must be one of {valid_score_apply}, "
+ f"got '{config.score_apply}'")
+
+ # Validate combine_impl
+ valid_combine_impl = ("auto", "weighted_sum", "legacy_bmm")
+ if config.combine_impl not in valid_combine_impl:
+ raise ValueError(f"combine_impl must be one of {valid_combine_impl}, "
+ f"got '{config.combine_impl}'")
+
+ # Validate score_func
+ valid_score_func = ("auto", "softmax", "sigmoid")
+ if config.score_func not in valid_score_func:
+ raise ValueError(f"score_func must be one of {valid_score_func}, "
+ f"got '{config.score_func}'")
+
+ # Validate group-limited routing constraints
+ if config.num_limited_groups is not None:
+ if config.num_limited_groups < 1:
+ raise ValueError(f"num_limited_groups must be >= 1, got {config.num_limited_groups}")
+
+ if config.num_expert_groups is not None:
+ if config.num_expert_groups < 1:
+ raise ValueError(f"num_expert_groups must be >= 1, got {config.num_expert_groups}")
+ if config.num_limited_groups is not None and config.num_limited_groups > config.num_expert_groups:
+ raise ValueError(f"num_limited_groups ({config.num_limited_groups}) must be <= "
+ f"num_expert_groups ({config.num_expert_groups})")
+ logger.warning("num_expert_groups is set; interaction with EP topology "
+ "is not yet optimized.")
+
+ # Warn if autoep_size == 1 (no EP needed)
+ if config.autoep_size == 1:
+ logger.warning("autoep_size=1 means every rank owns all experts with no AllToAll. "
+ "AutoEP replacement remains enabled, but expert-parallel communication "
+ "is bypassed because every rank owns all experts.")
+
+ # Helper validators (local to validate_autoep_config)
+ def _validate_attr_name(field_name: str, value, *, allow_dot: bool = False) -> None:
+ if value is None:
+ return
+ if not isinstance(value, str) or value == "":
+ raise ValueError(f"{field_name} must be a non-empty string")
+ if not allow_dot and "." in value:
+ raise ValueError(f"{field_name} must be a direct attribute name (no dots)")
+
+ # Validate expert weight names
+ _validate_attr_name("expert_w1", config.expert_w1)
+ _validate_attr_name("expert_w2", config.expert_w2)
+ if config.expert_w3 is not _UNSET and config.expert_w3 is not None:
+ _validate_attr_name("expert_w3", config.expert_w3)
+
+ # Validate model.config attribute names
+ _validate_attr_name("num_experts_attr", config.num_experts_attr)
+ _validate_attr_name("top_k_attr", config.top_k_attr)
+
+ # Validate child-name fields (direct attribute names, not regex/path)
+ _validate_attr_name("router_pattern", config.router_pattern)
+ _validate_attr_name("expert_pattern", config.expert_pattern)
+ _validate_attr_name("shared_experts_pattern", config.shared_experts_pattern)
+ _validate_attr_name("shared_experts_gate_pattern", config.shared_experts_gate_pattern)
+
+ # Validate has_shared_experts type
+ if config.has_shared_experts is not None and not isinstance(config.has_shared_experts, bool):
+ raise ValueError("has_shared_experts must be a boolean when set")
+
+ # Warn if explicit top_k overrides top_k_attr
+ if isinstance(config.top_k, int) and config.top_k_attr is not None:
+ logger.warning("top_k is explicitly set; top_k_attr will be ignored.")
+
+ if config.routed_scaling_factor != "auto" and not isinstance(config.routed_scaling_factor, (int, float)):
+ raise ValueError("routed_scaling_factor must be a number or 'auto'")
+
+ # Validate shared expert field pairing
+ if config.has_shared_experts is True and not config.shared_experts_pattern:
+ logger.warning("has_shared_experts=True but shared_experts_pattern is not set. "
+ "Shared expert detection requires both fields.")
+ if config.shared_experts_pattern and config.has_shared_experts is not True:
+ logger.warning(f"shared_experts_pattern='{config.shared_experts_pattern}' is set "
+ f"but has_shared_experts is not True. Pattern will be ignored.")
+ if config.shared_experts_gate_pattern and config.has_shared_experts is not True:
+ logger.warning(f"shared_experts_gate_pattern='{config.shared_experts_gate_pattern}' is set "
+ f"but has_shared_experts is not True. Pattern will be ignored.")
+
+ # Warn if custom override fields are set alongside preset_model or auto-detect
+ custom_fields_set = []
+ if config.moe_layer_pattern is not None:
+ custom_fields_set.append("moe_layer_pattern")
+ if config.router_pattern is not None:
+ custom_fields_set.append("router_pattern")
+ if config.expert_pattern is not None:
+ custom_fields_set.append("expert_pattern")
+ if config.expert_w1 is not None:
+ custom_fields_set.append("expert_w1")
+ if config.expert_w2 is not None:
+ custom_fields_set.append("expert_w2")
+ if config.expert_w3 is not _UNSET:
+ custom_fields_set.append("expert_w3")
+ if config.num_experts_attr is not None:
+ custom_fields_set.append("num_experts_attr")
+ if config.top_k_attr is not None:
+ custom_fields_set.append("top_k_attr")
+ if config.has_shared_experts is not None:
+ custom_fields_set.append("has_shared_experts")
+ if config.shared_experts_pattern is not None:
+ custom_fields_set.append("shared_experts_pattern")
+ if config.shared_experts_gate_pattern is not None:
+ custom_fields_set.append("shared_experts_gate_pattern")
+ if custom_fields_set and config.preset_model is not None:
+ logger.warning(f"Custom preset fields {custom_fields_set} are set alongside "
+ f"preset_model='{config.preset_model}'. Custom fields will override "
+ f"preset defaults during detection.")
+ if custom_fields_set and config.preset_model is None and config.moe_layer_pattern is None:
+ logger.warning(f"Custom preset fields {custom_fields_set} are set without preset_model or "
+ f"moe_layer_pattern. Overrides will apply to auto-detected presets or try-all.")
+
+
+def validate_autoep_post_detection(
+ config: AutoEPConfig,
+ specs: list[MoELayerSpec],
+) -> None:
+ """Post-detection validation: ep_size vs num_experts constraints."""
+ if not config.enabled or not specs:
+ return
+
+ for spec in specs:
+ # ep_size must not exceed num_experts
+ if config.autoep_size > spec.num_experts:
+ valid_divisors = _divisors(spec.num_experts)
+ raise ValueError(f"autoep_size={config.autoep_size} exceeds num_experts="
+ f"{spec.num_experts} in layer '{spec.moe_module_name}'. "
+ f"Each rank must own at least one expert. "
+ f"Valid autoep_size values (divisors of {spec.num_experts}): "
+ f"{valid_divisors}")
+
+ # num_experts must be divisible by ep_size
+ if spec.num_experts % config.autoep_size != 0:
+            valid_sizes = _divisors(spec.num_experts)
+ raise ValueError(f"num_experts={spec.num_experts} in layer "
+ f"'{spec.moe_module_name}' is not divisible by "
+ f"autoep_size={config.autoep_size}. "
+ f"Suggested autoep_size values: {valid_sizes}")
+
+ num_expert_groups = spec.num_expert_groups if spec.num_expert_groups is not None else config.num_expert_groups
+ num_limited_groups = spec.num_limited_groups if spec.num_limited_groups is not None else config.num_limited_groups
+
+ # Validate group-limited routing constraints after layer-specific defaults.
+ if num_limited_groups is not None and num_expert_groups is None:
+ raise ValueError(f"num_limited_groups requires num_expert_groups to be set "
+ f"in layer '{spec.moe_module_name}'")
+
+ if num_expert_groups is not None:
+ if num_expert_groups < 1:
+ raise ValueError(f"num_expert_groups must be >= 1 in layer '{spec.moe_module_name}', "
+ f"got {num_expert_groups}")
+ if spec.num_experts % num_expert_groups != 0:
+ raise ValueError(f"num_expert_groups ({num_expert_groups}) must divide "
+ f"num_experts ({spec.num_experts}) in layer "
+ f"'{spec.moe_module_name}'")
+ if num_limited_groups is None:
+ raise ValueError(f"num_limited_groups must be set when num_expert_groups is set "
+ f"in layer '{spec.moe_module_name}'")
+ if num_limited_groups < 1:
+ raise ValueError(f"num_limited_groups must be >= 1 in layer '{spec.moe_module_name}', "
+ f"got {num_limited_groups}")
+ if num_limited_groups > num_expert_groups:
+ raise ValueError(f"num_limited_groups ({num_limited_groups}) must be <= "
+ f"num_expert_groups ({num_expert_groups}) in layer "
+ f"'{spec.moe_module_name}'")
+
+
+def _divisors(n: int) -> list[int]:
+ """Return sorted list of positive divisors of n."""
+ divs = []
+ for i in range(1, int(n**0.5) + 1):
+ if n % i == 0:
+ divs.append(i)
+ if i != n // i:
+ divs.append(n // i)
+ return sorted(divs)
diff --git a/deepspeed/module_inject/auto_ep_layer.py b/deepspeed/module_inject/auto_ep_layer.py
new file mode 100644
index 000000000000..e2c75219a0b7
--- /dev/null
+++ b/deepspeed/module_inject/auto_ep_layer.py
@@ -0,0 +1,602 @@
+# Copyright (c) DeepSpeed Team.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause
+#
+# Portions of this file are derived from TorchTitan.
+# See THIRD_PARTY_NOTICES.md for the BSD-3-Clause notice.
+
+# DeepSpeed Team
+"""AutoEP MoE Layer: drop-in replacement for HF MoE blocks with EP support.
+
+Contains AutoEPMoELayer, compute_split_plan, _AllToAllV, and helper functions.
+"""
+
+from __future__ import annotations
+
+from typing import Literal, NamedTuple
+
+import torch
+import torch.nn as nn
+import deepspeed.comm as dist
+from deepspeed.module_inject.auto_ep_config import AutoEPConfig, MoELayerSpec, resolve_autoep_config_defaults
+from deepspeed.utils import logger
+from deepspeed.moe.ep_router import TokenChoiceTopKRouter
+from deepspeed.moe.ep_count import count_tokens_per_expert
+from deepspeed.moe.ep_experts import GroupedExperts
+from deepspeed.moe.ep_kernels import TokenReorderer
+from deepspeed.moe.ep_repack import repack_expert_weights
+
+# ---------------------------------------------------------------------------
+# Named tuples
+# ---------------------------------------------------------------------------
+
+
+class RouterOutput(NamedTuple):
+ top_scores: torch.Tensor # [T, K]
+ selected_experts: torch.Tensor # [T, K]
+ num_tokens_per_expert: torch.Tensor # [E_global]
+
+
+class SplitPlan(NamedTuple):
+ input_splits: list[int] # len=ep_size
+ output_splits: list[int] # len=ep_size
+ local_counts: torch.Tensor # [E_local]
+ local_counts_by_source: torch.Tensor # [ep_size, E_local]
+
+
+# ---------------------------------------------------------------------------
+# Helper functions
+# ---------------------------------------------------------------------------
+
+
+def resolve_score_apply_mode(
+ spec: MoELayerSpec,
+ config_override: Literal["auto", "pre", "post"],
+) -> Literal["pre", "post"]:
+ """Resolve score-application mode from config override or preset default."""
+ if config_override != "auto":
+ return config_override
+ return spec.score_apply
+
+
+def resolve_combine_impl(
+ config_override: Literal["auto", "weighted_sum", "legacy_bmm"], ) -> Literal["weighted_sum", "legacy_bmm"]:
+ """Resolve combine implementation from config override or default."""
+ if config_override != "auto":
+ return config_override
+ return "weighted_sum"
+
+
+def apply_scores_before_experts_if_enabled(
+ routed_input: torch.Tensor,
+ top_scores: torch.Tensor,
+ score_apply: Literal["pre", "post"],
+) -> torch.Tensor:
+ """Pre-multiply token representations by router scores before expert compute."""
+ if score_apply == "pre":
+ return (routed_input.to(torch.float32) * top_scores.reshape(-1, 1)).to(routed_input.dtype)
+ return routed_input
+
+
+def compute_split_plan(
+ selected_experts: torch.Tensor, # [T, K]
+ num_experts: int,
+ ep_size: int,
+ num_local_experts: int,
+ ep_group: dist.ProcessGroup | None,
+) -> SplitPlan:
+ """Compute AllToAllV split sizes for token dispatch/combine.
+
+ Returns SplitPlan with input_splits, output_splits, local_counts, and
+ local_counts_by_source.
+ """
+ T_K = selected_experts.numel()
+
+ if ep_size == 1:
+ # No dispatch needed - all tokens stay local
+ num_tokens_per_expert = count_tokens_per_expert(
+ selected_experts,
+ num_experts,
+ out_dtype=torch.int32,
+ )
+ return SplitPlan(
+ input_splits=[T_K],
+ output_splits=[T_K],
+ local_counts=num_tokens_per_expert,
+ local_counts_by_source=num_tokens_per_expert.view(1, num_local_experts),
+ )
+
+ # Count tokens per expert globally
+ num_tokens_per_expert = count_tokens_per_expert(
+ selected_experts,
+ num_experts,
+ out_dtype=torch.int32,
+ )
+
+ # Reshape to [ep_size, num_local_experts] to get per-rank counts
+ count_matrix = num_tokens_per_expert.view(ep_size, num_local_experts)
+
+ # input_splits: how many tokens THIS rank sends to each destination rank
+ input_splits = count_matrix.sum(dim=1).cpu().tolist()
+
+ # Exchange counts with all ranks to get output_splits
+ # Each rank tells every other rank how many tokens it will send
+ local_counts_tensor = count_matrix.sum(dim=1).clone() # [ep_size]
+ remote_counts_tensor = torch.zeros_like(local_counts_tensor)
+
+ dist.all_to_all_single(
+ remote_counts_tensor,
+ local_counts_tensor,
+ group=ep_group,
+ )
+ output_splits = remote_counts_tensor.cpu().tolist()
+
+ # local_counts: how many tokens this rank will process for each local expert
+ # After receiving tokens, we need per-expert counts for this rank
+    local_expert_counts = count_matrix.clone()  # [ep_size, E_local]
+
+ # Exchange the detailed per-expert counts
+ # Each rank needs to know, for its local experts, how many tokens come from each source
+ local_expert_counts_flat = local_expert_counts.view(-1).contiguous() # [ep_size * E_local]
+ received_counts_flat = torch.zeros_like(local_expert_counts_flat)
+
+ dist.all_to_all_single(
+ received_counts_flat,
+ local_expert_counts_flat,
+ group=ep_group,
+ )
+
+ # Sum over source ranks to get total per local expert
+ received_counts = received_counts_flat.view(ep_size, num_local_experts)
+ local_counts = received_counts.sum(dim=0) # [E_local]
+
+ return SplitPlan(
+ input_splits=input_splits,
+ output_splits=output_splits,
+ local_counts=local_counts,
+ local_counts_by_source=received_counts,
+ )
+
+
+class _AllToAllV(torch.autograd.Function):
+ """Autograd-compatible all-to-all with variable split sizes."""
+
+ @staticmethod
+ def forward(ctx, group, x, input_splits, output_splits):
+ ctx.group = group
+ ctx.input_splits = input_splits
+ ctx.output_splits = output_splits
+
+ output_size = sum(output_splits)
+ output = torch.empty(
+ (output_size, x.shape[1]),
+ dtype=x.dtype,
+ device=x.device,
+ )
+
+ dist.all_to_all_single(
+ output,
+ x.contiguous(),
+ output_split_sizes=output_splits,
+ input_split_sizes=input_splits,
+ group=group,
+ )
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ # Reverse the splits for backward
+ grad_out = grad_out.contiguous()
+ input_size = sum(ctx.input_splits)
+ grad_input = torch.empty(
+ (input_size, grad_out.shape[1]),
+ dtype=grad_out.dtype,
+ device=grad_out.device,
+ )
+
+ dist.all_to_all_single(
+ grad_input,
+ grad_out,
+ output_split_sizes=ctx.input_splits,
+ input_split_sizes=ctx.output_splits,
+ group=ctx.group,
+ )
+ return None, grad_input, None, None
+
+
+def permute_by_local_expert(
+ tokens: torch.Tensor,
+ local_counts: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
+ """Reorder tokens so they are grouped contiguously by local expert ID.
+
+ Uses TorchTitan's Triton kernel for permutation index generation.
+
+ Returns:
+ tokens_permuted: [N_padded, H] (alignment-padded)
+ permuted_indices: [N_padded] (maps padded positions -> original positions)
+ aligned_counts: [E_local] aligned token counts per expert (for expert computation)
+ n_tokens: original token count before padding (for unpermute)
+ """
+ from deepspeed.moe.ep_kernels import generate_permute_indices, TOKEN_GROUP_ALIGN_SIZE_M
+
+ if local_counts.ndim == 1:
+ # [E_local]: already aggregated over sources (ep_degree=1)
+ ep_degree = 1
+ num_local_experts = local_counts.shape[0]
+ local_counts_flat = local_counts
+ elif local_counts.ndim == 2:
+ # [ep_size, E_local]: preserve per-source layout for correct regrouping
+ ep_degree, num_local_experts = local_counts.shape
+ local_counts_flat = local_counts.reshape(-1)
+ else:
+ raise ValueError(
+ f"local_counts must have shape [E_local] or [ep_degree, E_local], got {tuple(local_counts.shape)}")
+
+ n_tokens = tokens.shape[0]
+ alignment = TOKEN_GROUP_ALIGN_SIZE_M
+
+    # Worst-case padded length: every local expert group may grow by up to
+    # `alignment` rows; round the total up to a multiple of `alignment`.
+ x_padded_per_expert = n_tokens + num_local_experts * alignment
+ padded_max_len = ((x_padded_per_expert + alignment - 1) // alignment) * alignment
+
+ # Use the pure-PyTorch path for host tensors. The CPU accelerator reports
+ # CPU tensors as "on accelerator", but Triton still requires a GPU driver.
+ use_cpu = tokens.device.type == "cpu"
+ counts_for_permute = local_counts_flat.cpu() if use_cpu else local_counts_flat
+ with torch.no_grad():
+ permuted_indices, m_sizes, _offsets = generate_permute_indices(
+ counts_for_permute,
+ num_local_experts,
+ ep_degree,
+ padded_max_len,
+ alignment,
+ use_cpu=use_cpu,
+ )
+ if not use_cpu:
+ permuted_indices = permuted_indices.to(tokens.device)
+ m_sizes = m_sizes.to(tokens.device)
+
+ # Add padding row for out-of-bounds indices (index n_tokens -> zero row)
+ tokens_padded = torch.vstack((tokens, tokens.new_zeros((tokens.shape[-1], ))))
+ tokens_permuted = tokens_padded[permuted_indices, :]
+
+ return tokens_permuted, permuted_indices, m_sizes, n_tokens
+
+
+def unpermute_by_local_expert(
+ expert_output: torch.Tensor,
+ permuted_indices: torch.Tensor,
+ n_tokens: int,
+) -> torch.Tensor:
+ """Reverse permute_by_local_expert: restore original token order and strip padding.
+
+ Args:
+ expert_output: [N_padded, H] from expert computation
+ permuted_indices: [N_padded] index mapping from permute_by_local_expert
+ n_tokens: original token count before alignment padding
+ """
+ # Scatter expert outputs back to original positions.
+ # permuted_indices values range 0..n_tokens, where n_tokens is the zero-padding row.
+ out_unpermuted = expert_output.new_zeros((n_tokens + 1, expert_output.shape[-1]))
+ out_unpermuted[permuted_indices, :] = expert_output
+ # Strip the zero-padding row to get [n_tokens, H]
+ return out_unpermuted[:-1]
+
+
+def combine_from_routed(
+ expert_output: torch.Tensor, # [N, H]
+ top_scores: torch.Tensor, # [T, K]
+ token_indices_sorted: torch.Tensor, # [N]
+ top_k: int,
+ score_apply: Literal["pre", "post"],
+ combine_impl: Literal["weighted_sum", "legacy_bmm"],
+ shape: tuple[int, int, int], # (B, S, H)
+) -> torch.Tensor:
+ """Scatter-add expert outputs back to original token positions."""
+ bsz, seqlen, hdim = shape
+ T = bsz * seqlen
+
+ # Create output tensor
+ output = torch.zeros(T * top_k, hdim, dtype=expert_output.dtype, device=expert_output.device)
+
+ # Place expert outputs back in unsorted order
+ output[token_indices_sorted] = expert_output
+
+ # Reshape to [T, K, H]
+ output = output.reshape(T, top_k, hdim)
+
+ if score_apply == "post":
+ if combine_impl == "legacy_bmm":
+ # Legacy reduction path retained as a debug option for model-family
+ # verification. The weighted-sum path is the default.
+ output = torch.bmm(
+ top_scores.reshape(-1, 1, top_k).float(),
+ output.float(),
+ ).to(expert_output.dtype).squeeze(1)
+ else:
+ # Match the runtime HF grouped-mm path: apply routing weights per
+ # token-slot sample, then reduce over top-k.
+ output = (output.float() * top_scores.reshape(T, top_k, 1).float()).sum(dim=1).to(expert_output.dtype)
+ else:
+ # Scores already applied pre-experts, just sum over top_k
+ output = output.sum(dim=1)
+
+ return output.reshape(bsz, seqlen, hdim)
+
+
+# ---------------------------------------------------------------------------
+# AutoEPMoELayer
+# ---------------------------------------------------------------------------
+
+
+class AutoEPMoELayer(nn.Module):
+ """Drop-in replacement for HF MoE blocks with Expert Parallelism support."""
+
+ _is_autoep_layer = True # Marker for AutoTP skip handshake
+
+ def __init__(
+ self,
+ spec: MoELayerSpec,
+ source_module: nn.Module,
+ ep_size: int,
+ ep_rank: int,
+ config: AutoEPConfig,
+ ) -> None:
+ super().__init__()
+
+ self.model_family = spec.model_family
+ self.return_router_logits = spec.return_router_logits
+ self.router_logits_capture_target = spec.router_logits_capture_target
+ self.router_logits_capture_index = spec.router_logits_capture_index
+ self.router_logits_capture_mode = spec.router_logits_capture_mode
+ self.moe_output_shape = spec.moe_output_shape
+ self.top_k = spec.top_k
+ self.score_apply = resolve_score_apply_mode(spec, config.score_apply)
+ self.combine_impl = resolve_combine_impl(config.combine_impl)
+ route_norm = spec.route_norm if config.route_norm is None else config.route_norm
+ self.ep_size = ep_size
+ self.ep_rank = ep_rank
+ self.num_experts = spec.num_experts
+ self.num_local_experts = spec.num_experts // ep_size
+ self.hidden_size = spec.hidden_size
+ self.ep_group_name = f"ep_size_{ep_size}"
+ self.ep_group = None # Set by set_deepspeed_parallelism()
+ resolved_config = resolve_autoep_config_defaults(config, spec.model_family)
+
+ # Router: copy gate weights from source
+ source_gate = getattr(source_module, spec.router_name)
+ if not spec.supports_expert_bias and resolved_config.load_balance_coeff is not None:
+ raise ValueError(f"AutoEP preset '{spec.model_family}' does not support load_balance_coeff/expert_bias "
+ "yet. Set load_balance_coeff=None.")
+ for bias_name in spec.unsupported_router_bias_names:
+ router_bias = getattr(source_gate, bias_name, None)
+ if router_bias is None:
+ continue
+ if torch.is_tensor(router_bias) and torch.count_nonzero(router_bias.detach()).item() == 0:
+ continue
+ raise ValueError(f"AutoEP preset '{spec.model_family}' does not support nonzero router bias "
+ f"'{bias_name}' yet.")
+ self.router = TokenChoiceTopKRouter(
+ dim=spec.hidden_size,
+ num_experts=spec.num_experts,
+ num_expert_groups=spec.num_expert_groups,
+ num_limited_groups=spec.num_limited_groups,
+ top_k=spec.top_k,
+ score_func=spec.score_func,
+ route_norm=route_norm,
+ route_scale=spec.route_scale,
+ gate_bias=spec.gate_bias,
+ group_score_func=spec.group_score_func,
+ )
+ # Copy gate weights
+ self.router.gate.weight.data.copy_(source_gate.weight.data)
+ if spec.gate_bias and getattr(source_gate, 'bias', None) is not None:
+ self.router.gate.bias.data.copy_(source_gate.bias.data)
+
+ # Alias router under the name OutputRecorder expects (layer_name if provided),
+ # but only when OutputRecorder captures from the router child and the alias is safe.
+ alias_target = spec.router_logits_capture_layer_name or spec.router_name
+ if spec.router_logits_capture_target == "router" and alias_target != "router":
+ if "." in alias_target or alias_target in ("experts", "shared_experts") or hasattr(self, alias_target):
+ logger.warning(f"Skipping router alias '{alias_target}' to avoid name collision.")
+ else:
+ setattr(self, alias_target, self.router)
+
+ # Experts: extract local expert weights
+ w1, w2, w3 = repack_expert_weights(
+ experts_source=getattr(source_module, spec.experts_name),
+ spec=spec,
+ ep_rank=ep_rank,
+ ep_size=ep_size,
+ )
+ self.experts = GroupedExperts(
+ dim=spec.hidden_size,
+ hidden_dim=spec.ffn_hidden_size,
+ num_experts=self.num_local_experts,
+ use_grouped_mm=config.use_grouped_mm,
+ )
+ self.experts.w1.data.copy_(w1)
+ self.experts.w2.data.copy_(w2)
+ self.experts.w3.data.copy_(w3)
+
+ self.reorderer = TokenReorderer(num_experts=self.num_experts, top_k=self.top_k)
+ self.shared_experts = getattr(source_module, spec.shared_experts_name,
+ None) if spec.has_shared_experts else None
+ self.shared_experts_gate = getattr(source_module, spec.shared_experts_gate_name,
+ None) if spec.shared_experts_gate_name else None
+
+ # Mark expert params for EDP gradient reduction
+ for param in self.experts.parameters():
+ param.allreduce = False
+ param.group_name = self.ep_group_name
+
+ # Mark shared expert and router params for global DP reduction
+ for param in self.router.parameters():
+ param.allreduce = True
+ if self.shared_experts is not None:
+ for param in self.shared_experts.parameters():
+ param.allreduce = True
+ if self.shared_experts_gate is not None:
+ for param in self.shared_experts_gate.parameters():
+ param.allreduce = True
+
+ # Load balancing buffers
+ self.load_balance_coeff = resolved_config.load_balance_coeff
+ buf_device = source_gate.weight.device
+ if self.load_balance_coeff is not None:
+ self.register_buffer(
+ "expert_bias",
+ torch.zeros(spec.num_experts, dtype=torch.float32, device=buf_device),
+ persistent=True,
+ )
+ else:
+ self.expert_bias = None
+ self.register_buffer(
+ "tokens_per_expert",
+ torch.zeros(spec.num_experts, dtype=torch.float32, device=buf_device),
+ persistent=False,
+ )
+
+ # Router-logit cache
+ self._cached_router_logits = None
+ self._register_logit_hook()
+
+ def _register_logit_hook(self):
+ """Register a forward hook that caches gate logits for OutputRecorder capture."""
+ if self.router_logits_capture_target != "router":
+ return
+
+ def hook_fn(module, input, output):
+ x = input[0] # [T, H]
+ logits = module.gate(x) # [T, E_global]
+ if self.router_logits_capture_mode == "post_score":
+ if self.router.score_func == "softmax":
+ logits = torch.softmax(logits.float(), dim=-1).to(logits.dtype)
+ elif self.router.score_func == "sigmoid":
+ logits = torch.sigmoid(logits.float()).to(logits.dtype)
+ self._cached_router_logits = logits
+
+ self.router.register_forward_hook(hook_fn)
+
+ def set_deepspeed_parallelism(
+ self,
+ use_data_before_expert_parallel_: bool = False,
+ ) -> None:
+ """Bind EP group handle to this module."""
+ from deepspeed.utils import groups
+ from deepspeed.utils.bwc import bwc_pipeline_parallel_world_size
+
+ if self.ep_group_name not in groups._get_expert_parallel_group_dict():
+ mp_size = max(
+ getattr(groups, '_get_model_parallel_world_size', lambda: 1)(),
+ getattr(groups, '_get_sequence_parallel_world_size', lambda: 1)(),
+ )
+ mp_mode = "tp" if getattr(groups, '_get_model_parallel_world_size', lambda: 1)() > 1 else "sp"
+ pp_size = 1 if groups.mpu is None else bwc_pipeline_parallel_world_size(groups.mpu)
+ groups._create_expert_and_data_parallel(
+ expert_parallel_size_=self.ep_size,
+ mp_size=mp_size,
+ pp_size=pp_size,
+ mp_mode=mp_mode,
+ use_data_before_expert_parallel_=use_data_before_expert_parallel_,
+ )
+ self.ep_group = groups._get_expert_parallel_group(self.ep_group_name)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+ """Forward pass.
+
+ Args:
+ hidden_states: [B, S, H]
+
+ Returns:
+ [B, S, H] or ([B, S, H], [T, E]) if return_router_logits.
+ Some HF MoE contracts return ([T, H], [T, E]) instead.
+ """
+ bsz, seqlen, hdim = hidden_states.shape
+ x = hidden_states.reshape(-1, hdim) # [T, H]
+
+ # Router
+ ro: RouterOutput = RouterOutput(*self.router(x, self.expert_bias))
+
+ # Accumulate expert utilization
+ with torch.no_grad():
+ self.tokens_per_expert.add_(ro.num_tokens_per_expert)
+
+ # Reorder tokens by expert
+ top_scores_sorted, token_indices_sorted, _ = self.reorderer(ro.top_scores, ro.selected_experts)
+
+ routed_input = x[token_indices_sorted // self.top_k] # [N, H]
+ routed_input = apply_scores_before_experts_if_enabled(routed_input,
+ top_scores_sorted,
+ score_apply=self.score_apply)
+
+ if self.ep_size == 1:
+ # No AllToAll needed - local computation only
+ local_counts = count_tokens_per_expert(
+ ro.selected_experts,
+ self.num_local_experts,
+ out_dtype=torch.int32,
+ )
+
+ routed_input_permuted, perm_indices, aligned_counts, n_tokens = permute_by_local_expert(
+ routed_input, local_counts)
+ expert_output = self.experts(routed_input_permuted, aligned_counts)
+ expert_output = unpermute_by_local_expert(expert_output, perm_indices, n_tokens)
+ else:
+ # EP dispatch/compute/combine
+ plan = compute_split_plan(
+ selected_experts=ro.selected_experts,
+ num_experts=self.num_experts,
+ ep_size=self.ep_size,
+ num_local_experts=self.num_local_experts,
+ ep_group=self.ep_group,
+ )
+
+ routed_input = _AllToAllV.apply(self.ep_group, routed_input, plan.input_splits, plan.output_splits)
+
+ routed_input, perm_indices, aligned_counts, n_tokens = permute_by_local_expert(
+ routed_input, plan.local_counts_by_source)
+ expert_output = self.experts(routed_input, aligned_counts)
+ expert_output = unpermute_by_local_expert(expert_output, perm_indices, n_tokens)
+
+ expert_output = _AllToAllV.apply(self.ep_group, expert_output, plan.output_splits, plan.input_splits)
+
+ output = combine_from_routed(
+ expert_output,
+ top_scores=ro.top_scores,
+ token_indices_sorted=token_indices_sorted,
+ top_k=self.top_k,
+ score_apply=self.score_apply,
+ combine_impl=self.combine_impl,
+ shape=(bsz, seqlen, hdim),
+ )
+
+ if self.moe_output_shape == "flat":
+ output = output.reshape(-1, hdim)
+ shared_expert_input = x
+ elif self.shared_experts_gate is not None:
+ shared_expert_input = x
+ else:
+ shared_expert_input = hidden_states
+
+ if self.shared_experts is not None:
+ shared_expert_output = self.shared_experts(shared_expert_input)
+ if self.shared_experts_gate is not None:
+ shared_expert_gate = torch.sigmoid(self.shared_experts_gate(shared_expert_input))
+ shared_expert_output = shared_expert_gate * shared_expert_output
+ if shared_expert_output.shape != output.shape:
+ shared_expert_output = shared_expert_output.reshape_as(output)
+ output = output + shared_expert_output
+
+ if self.return_router_logits:
+ logits = self._cached_router_logits
+ self._cached_router_logits = None
+ return output, logits
+
+ self._cached_router_logits = None
+ return output
diff --git a/deepspeed/module_inject/auto_ep_preset_adapters.py b/deepspeed/module_inject/auto_ep_preset_adapters.py
new file mode 100644
index 000000000000..5574e32aa246
--- /dev/null
+++ b/deepspeed/module_inject/auto_ep_preset_adapters.py
@@ -0,0 +1,27 @@
+# SPDX-License-Identifier: Apache-2.0
+# DeepSpeed Team
+"""Compatibility shim for AutoEP preset adapter APIs."""
+
+from deepspeed.module_inject.auto_ep_presets.base import (
+ AutoEPPresetAdapter,
+ ForwardContract,
+ GroupRoutingConfig,
+ TransformersTopLevelRouterLogitsAdapter,
+)
+from deepspeed.module_inject.auto_ep_presets.deepseek_v2 import DeepSeekV2PresetAdapter
+from deepspeed.module_inject.auto_ep_presets.deepseek_v3 import DeepSeekV3PresetAdapter
+from deepspeed.module_inject.auto_ep_presets.llama4 import Llama4PresetAdapter
+from deepspeed.module_inject.auto_ep_presets.qwen3_5_moe import Qwen35MoePresetAdapter
+from deepspeed.module_inject.auto_ep_presets.registry import get_preset_adapter
+
+__all__ = [
+ "AutoEPPresetAdapter",
+ "DeepSeekV2PresetAdapter",
+ "DeepSeekV3PresetAdapter",
+ "ForwardContract",
+ "GroupRoutingConfig",
+ "Llama4PresetAdapter",
+ "Qwen35MoePresetAdapter",
+ "TransformersTopLevelRouterLogitsAdapter",
+ "get_preset_adapter",
+]
diff --git a/deepspeed/module_inject/auto_ep_presets/__init__.py b/deepspeed/module_inject/auto_ep_presets/__init__.py
new file mode 100644
index 000000000000..a94395d12a85
--- /dev/null
+++ b/deepspeed/module_inject/auto_ep_presets/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-License-Identifier: Apache-2.0
+# DeepSpeed Team
+"""AutoEP built-in preset registry package."""
diff --git a/deepspeed/module_inject/auto_ep_presets/base.py b/deepspeed/module_inject/auto_ep_presets/base.py
new file mode 100644
index 000000000000..342e6ff1abb5
--- /dev/null
+++ b/deepspeed/module_inject/auto_ep_presets/base.py
@@ -0,0 +1,431 @@
+# SPDX-License-Identifier: Apache-2.0
+# DeepSpeed Team
+"""Shared AutoEP preset dataclasses and adapter interface."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field, replace
+from typing import Any, Callable, Literal, NoReturn
+
+import torch.nn as nn
+from packaging.version import InvalidVersion, Version
+
+# Sentinel for "not specified in config, use preset default".
+# Unlike None (which means "fused gate+up, no separate w3"), _UNSET means
+# the user did not set the field at all. Compare with `is _UNSET`.
+_UNSET = object()
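+# For example (assuming the JSON key mirrors the field name), a DS config that
+# sets "expert_w3": null explicitly selects the fused gate+up layout
+# (expert_w3=None), whereas omitting the key leaves the field at _UNSET and the
+# preset default is used instead.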
+
+
+def _raise_unsupported_load_balance_coeff(value: object) -> NoReturn:
+ raise ValueError(f"load_balance_coeff={value!r} is not supported in this AutoEP build "
+ "(would register expert_bias and route through unsupported "
+ "auxiliary-loss-free load balancing). Set load_balance_coeff to null "
+ "or omit the key.")
+
+
+@dataclass
+class MoEModelPreset:
+ """Preset configuration for a known MoE model family."""
+
+ moe_layer_pattern: str
+ router_pattern: str
+ experts_pattern: str
+ expert_storage: Literal["fused_3d", "module_list"]
+ expert_w1: str
+ expert_w2: str
+ expert_w3: str | None
+ num_experts_attr: str
+ top_k_attr: str
+ score_func: Literal["softmax", "sigmoid"]
+ score_apply: Literal["pre", "post"]
+ route_norm: bool
+ gate_bias: bool
+ has_shared_experts: bool = False
+ shared_experts_pattern: str = ""
+ shared_experts_gate_pattern: str = ""
+ autoep_config_defaults: dict[str, Any] = field(default_factory=dict)
+ supports_expert_bias: bool = True
+ unsupported_router_bias_names: tuple[str, ...] = ()
+ preset_adapter: str = "default"
+ hf_model_types: tuple[str, ...] = ()
+ unsupported_hf_model_type_notes: dict[str, str] = field(default_factory=dict)
+ min_transformers_version: str | None = None
+ validated_transformers_versions: str = ""
+ docs_support_notes: str = ""
+
+
+@dataclass
+class MoELayerSpec:
+ """Detected MoE layer specification for a single module in the model."""
+
+ moe_module_name: str
+ model_family: str
+ router_name: str
+ experts_name: str
+ expert_storage: Literal["fused_3d", "module_list"]
+ expert_w1_name: str
+ expert_w2_name: str
+ expert_w3_name: str | None
+ num_experts: int
+ top_k: int
+ hidden_size: int
+ ffn_hidden_size: int
+ score_func: Literal["softmax", "sigmoid"]
+ score_apply: Literal["pre", "post"]
+ route_norm: bool
+ gate_bias: bool
+ return_router_logits: bool
+ router_logits_capture_target: Literal["moe_block", "router", "none"]
+ router_logits_capture_index: int | None
+ router_logits_capture_layer_name: str | None
+ has_shared_experts: bool
+ shared_experts_name: str
+ shared_experts_gate_name: str = ""
+ route_scale: float = 1.0
+ num_expert_groups: int | None = None
+ num_limited_groups: int | None = None
+ group_score_func: Literal["max", "top2_sum"] = "top2_sum"
+ supports_expert_bias: bool = True
+ unsupported_router_bias_names: tuple[str, ...] = ()
+ preset_adapter: str = "default"
+ router_logits_capture_mode: Literal["raw", "post_score"] = "post_score"
+ moe_output_shape: Literal["batched", "flat"] = "batched"
+
+
+@dataclass
+class AutoEPConfig:
+ """User-facing configuration parsed from DS config JSON."""
+
+ enabled: bool = False
+ autoep_size: int = 1
+ preset_model: str | None = None
+ moe_layer_pattern: str | None = None
+ expert_pattern: str | None = None
+ router_pattern: str | None = None
+ use_grouped_mm: bool = True
+ route_norm: bool | None = None
+ route_scale: float = 1.0
+ score_apply: Literal["auto", "pre", "post"] = "auto"
+ combine_impl: Literal["auto", "weighted_sum", "legacy_bmm"] = "auto"
+ num_expert_groups: int | None = None
+ num_limited_groups: int | None = None
+ score_func: Literal["auto", "softmax", "sigmoid"] = "auto"
+ top_k: int | str = "auto"
+ load_balance_coeff: float | None | object = _UNSET
+ routed_scaling_factor: float | str = "auto"
+ expert_w1: str | None = None
+ expert_w2: str | None = None
+ expert_w3: object = _UNSET
+ num_experts_attr: str | None = None
+ top_k_attr: str | None = None
+ has_shared_experts: bool | None = None
+ shared_experts_pattern: str | None = None
+ shared_experts_gate_pattern: str | None = None
+ _load_balance_coeff_explicit: bool = field(default=False, init=False, repr=False)
+
+ def __post_init__(self) -> None:
+ if self.load_balance_coeff is _UNSET:
+ self.load_balance_coeff = None
+ self._load_balance_coeff_explicit = False
+ else:
+ self._load_balance_coeff_explicit = True
+
+
+@dataclass(frozen=True)
+class GroupRoutingConfig:
+ num_expert_groups: int | None
+ num_limited_groups: int | None
+ group_score_func: Literal["max", "top2_sum"] = "top2_sum"
+
+
+@dataclass(frozen=True)
+class ForwardContract:
+ return_router_logits: bool = False
+ capture_target: Literal["moe_block", "router", "none"] = "none"
+ capture_index: int | None = None
+ capture_layer_name: str | None = None
+ router_logits_capture_mode: Literal["raw", "post_score"] = "post_score"
+ moe_output_shape: Literal["batched", "flat"] = "batched"
+
+
+class AutoEPPresetAdapter:
+ """Default behavior shared by presets without model-specific parser rules."""
+
+ def validate_compatibility(
+ self,
+ preset_name: str,
+ preset: MoEModelPreset,
+ model_config,
+ ) -> None:
+ """Validate public HF compatibility metadata for a selected preset."""
+ model_type = getattr(model_config, "model_type", None) if model_config is not None else None
+ self._validate_hf_model_type(preset_name, preset, model_type)
+ self._validate_transformers_version(preset_name, preset, model_type)
+
+ def _validate_hf_model_type(
+ self,
+ preset_name: str,
+ preset: MoEModelPreset,
+ model_type: str | None,
+ ) -> None:
+ if model_type is None:
+ return
+
+ unsupported_note = preset.unsupported_hf_model_type_notes.get(model_type)
+ if unsupported_note is None:
+ return
+
+ supported = ", ".join(repr(value) for value in preset.hf_model_types) or "none"
+ raise ValueError(f"AutoEP preset '{preset_name}' does not support model_type='{model_type}'. "
+ f"{unsupported_note} Supported HF model_type value(s): {supported}.")
+
+ def _validate_transformers_version(
+ self,
+ preset_name: str,
+ preset: MoEModelPreset,
+ model_type: str | None,
+ ) -> None:
+ min_version = preset.min_transformers_version
+ if min_version is None or model_type is None:
+ return
+ if not self._requires_transformers_version_validation():
+ return
+ if model_type not in preset.hf_model_types and model_type not in preset.unsupported_hf_model_type_notes:
+ return
+
+ try:
+ installed_version = self._installed_transformers_version()
+ except Exception as exc:
+ raise ValueError(f"AutoEP preset '{preset_name}' for model_type='{model_type}' requires "
+ f"Transformers >= {min_version}, but transformers could not be imported: {exc}.") from exc
+
+ try:
+ installed = Version(installed_version)
+ minimum = Version(min_version)
+ except InvalidVersion as exc:
+ raise ValueError(f"AutoEP preset '{preset_name}' for model_type='{model_type}' requires "
+ f"Transformers >= {min_version}, but the installed Transformers version "
+ f"'{installed_version}' could not be parsed.") from exc
+
+ if installed < minimum:
+ raise ValueError(f"AutoEP preset '{preset_name}' for model_type='{model_type}' requires "
+ f"Transformers >= {min_version}, but installed transformers=={installed_version}. "
+ "Upgrade Transformers or choose a preset/model combination supported by the "
+ "installed Transformers version.")
+
+ def _installed_transformers_version(self) -> str:
+ import transformers
+ return getattr(transformers, "__version__", "unknown")
+
+ def _requires_transformers_version_validation(self) -> bool:
+ # The default adapter also covers non-HF/mock/custom-compatible configs;
+ # specialized HF-only adapters opt in to minimum Transformers checks.
+ return False
+
+ def resolve_route_norm(
+ self,
+ config: AutoEPConfig,
+ preset: MoEModelPreset,
+ model_config,
+ ) -> bool:
+ if config.route_norm is not None:
+ return config.route_norm
+
+ cfg_norm = getattr(model_config, 'norm_topk_prob', None)
+ if cfg_norm is not None:
+ return bool(cfg_norm)
+ return preset.route_norm
+
+ def resolve_group_routing(
+ self,
+ config: AutoEPConfig,
+ model_config,
+ ) -> GroupRoutingConfig:
+ return GroupRoutingConfig(
+ num_expert_groups=config.num_expert_groups,
+ num_limited_groups=config.num_limited_groups,
+ )
+
+ def resolve_expert_layout(
+ self,
+ experts_module: nn.Module,
+ preset: MoEModelPreset,
+ ) -> MoEModelPreset:
+ return preset
+
+ def adjust_forward_contract(self, contract: ForwardContract) -> ForwardContract:
+ return contract
+
+ def retarget_transformers_output_recorders(
+ self,
+ model: nn.Module,
+ spec: MoELayerSpec,
+ replacement: nn.Module,
+ retargeted_keys: set[str],
+ remove_output_capture_hooks: Callable[[nn.Module], int],
+ ) -> None:
+ return
+
+
+_MISSING_REGISTRY_ENTRY = object()
+
+
+def _restore_transformers_output_capture_registry(
+ registry: dict[str, Any],
+ original_entries: dict[str, object],
+) -> None:
+ for registry_key, original_entry in original_entries.items():
+ if original_entry is _MISSING_REGISTRY_ENTRY:
+ registry.pop(registry_key, None)
+ else:
+ registry[registry_key] = original_entry
+
+
+def _install_instance_transformers_output_recorders(
+ model: nn.Module,
+ registry_entries: dict[str, dict[str, Any]],
+ output_capturing: Any,
+ remove_output_capture_hooks: Callable[[nn.Module], int],
+) -> bool:
+ maybe_install_capturing_hooks = getattr(output_capturing, "maybe_install_capturing_hooks", None)
+ registry = getattr(output_capturing, "_CAN_RECORD_REGISTRY", None)
+ if not callable(maybe_install_capturing_hooks) or not isinstance(registry, dict):
+ return False
+
+ remove_output_capture_hooks(model)
+ for module in model.modules():
+ if hasattr(module, "_output_capturing_hooks_installed"):
+ module._output_capturing_hooks_installed = False
+ model._output_capturing_hooks_installed = False
+
+ original_entries = {
+ registry_key: registry.get(registry_key, _MISSING_REGISTRY_ENTRY)
+ for registry_key in registry_entries
+ }
+ try:
+ registry.update(registry_entries)
+ maybe_install_capturing_hooks(model)
+ finally:
+ _restore_transformers_output_capture_registry(registry, original_entries)
+ return True
+
+
+def _retarget_transformers_output_recorders_for_modules(
+ *,
+ model: nn.Module,
+ display_name: str,
+ recorder_key: str,
+ retargeted_keys: set[str],
+ remove_output_capture_hooks: Callable[[nn.Module], int],
+ module_matches: Callable[[nn.Module], bool],
+ make_output_recorder: Callable[[Any], Any],
+) -> int:
+ try:
+ from transformers.utils import output_capturing
+ except Exception:
+ return 0
+
+ registry = getattr(output_capturing, "_CAN_RECORD_REGISTRY", None)
+ if not isinstance(registry, dict):
+ return 0
+
+ registry_entries: dict[str, dict[str, Any]] = {}
+ retargeted = 0
+ for module in model.modules():
+ if not module_matches(module):
+ continue
+
+ registry_key = str(module.__class__)
+ record_outputs = getattr(module, "_can_record_outputs", None)
+ registry_outputs = registry.get(registry_key)
+ base_outputs = record_outputs if isinstance(record_outputs, dict) else registry_outputs
+ if not isinstance(base_outputs, dict) or "router_logits" not in base_outputs:
+ continue
+
+ retargeted_outputs = dict(base_outputs)
+ retargeted_outputs["router_logits"] = make_output_recorder(output_capturing.OutputRecorder)
+ module._can_record_outputs = retargeted_outputs
+ registry_entries[registry_key] = retargeted_outputs
+ retargeted += 1
+
+ if retargeted == 0:
+ from deepspeed.utils import logger
+ logger.warning(f"AutoEP: {display_name} conversion did not find a HF output-capture registry "
+ "entry for router_logits.")
+ return 0
+
+ if _install_instance_transformers_output_recorders(
+ model,
+ registry_entries,
+ output_capturing,
+ remove_output_capture_hooks,
+ ):
+ return retargeted
+
+ if recorder_key in retargeted_keys:
+ return retargeted
+ retargeted_keys.add(recorder_key)
+ registry.update(registry_entries)
+ if getattr(model, "_output_capturing_hooks_installed", False):
+ remove_output_capture_hooks(model)
+ model._output_capturing_hooks_installed = False
+ return retargeted
+
+
+class TransformersTopLevelRouterLogitsAdapter(AutoEPPresetAdapter):
+ """Retarget Transformers model-level router-logit recorders to AutoEP."""
+
+ def __init__(
+ self,
+ *,
+ display_name: str,
+ hf_model_types: tuple[str, ...],
+ class_name_fragments: tuple[str, ...],
+ ) -> None:
+ self.display_name = display_name
+ self.hf_model_types = hf_model_types
+ self.class_name_fragments = class_name_fragments
+
+ def adjust_forward_contract(self, contract: ForwardContract) -> ForwardContract:
+ # Mixtral/Qwen3/Qwen2 capture raw router logits through Transformers'
+ # model-level OutputRecorder hooks. AutoEP keeps the MoE block tensor
+ # return contract intact and retargets the recorder to router.gate.
+ return replace(
+ contract,
+ return_router_logits=False,
+ capture_target="router",
+ capture_index=0,
+ router_logits_capture_mode="raw",
+ )
+
+ def retarget_transformers_output_recorders(
+ self,
+ model: nn.Module,
+ spec: MoELayerSpec,
+ replacement: nn.Module,
+ retargeted_keys: set[str],
+ remove_output_capture_hooks: Callable[[nn.Module], int],
+ ) -> None:
+ recorder_key = f"{spec.model_family}:{replacement.__class__.__module__}.{replacement.__class__.__qualname__}"
+
+ router_gate = getattr(getattr(replacement, "router", None), "gate", None)
+ if router_gate is None:
+ return
+
+ def module_matches(module: nn.Module) -> bool:
+ module_config = getattr(module, "config", None)
+ model_type = getattr(module_config, "model_type", None)
+ class_name = module.__class__.__name__
+ return (model_type in self.hf_model_types
+ or any(fragment in class_name for fragment in self.class_name_fragments))
+
+ _retarget_transformers_output_recorders_for_modules(
+ model=model,
+ display_name=self.display_name,
+ recorder_key=recorder_key,
+ retargeted_keys=retargeted_keys,
+ remove_output_capture_hooks=remove_output_capture_hooks,
+ module_matches=module_matches,
+ make_output_recorder=lambda OutputRecorder: OutputRecorder(
+ router_gate.__class__, index=0, layer_name="router.gate"),
+ )
diff --git a/deepspeed/module_inject/auto_ep_presets/deepseek_v2.py b/deepspeed/module_inject/auto_ep_presets/deepseek_v2.py
new file mode 100644
index 000000000000..529187eb45e1
--- /dev/null
+++ b/deepspeed/module_inject/auto_ep_presets/deepseek_v2.py
@@ -0,0 +1,76 @@
+# SPDX-License-Identifier: Apache-2.0
+# DeepSpeed Team
+"""DeepSeek-V2 AutoEP preset and parser adapter."""
+
+from __future__ import annotations
+
+from deepspeed.module_inject.auto_ep_presets.base import (
+ AutoEPConfig,
+ AutoEPPresetAdapter,
+ GroupRoutingConfig,
+ MoEModelPreset,
+)
+
+PRESET_NAME = "deepseek_v2"
+
+PRESET = MoEModelPreset(
+ moe_layer_pattern=r"model\.layers\.\d+\.mlp",
+ router_pattern="gate",
+ experts_pattern="experts",
+ expert_storage="fused_3d",
+ expert_w1="gate_up_proj",
+ expert_w2="down_proj",
+ expert_w3=None,
+ num_experts_attr="n_routed_experts",
+ top_k_attr="num_experts_per_tok",
+ score_func="softmax",
+ score_apply="post",
+ route_norm=False,
+ gate_bias=False,
+ has_shared_experts=True,
+ shared_experts_pattern="shared_experts",
+ autoep_config_defaults={"load_balance_coeff": None},
+ supports_expert_bias=False,
+ preset_adapter="deepseek_v2",
+ hf_model_types=("deepseek_v2", ),
+ min_transformers_version="5.0.0",
+ docs_support_notes=("load_balance_coeff / expert-bias auxiliary-loss-free load balancing "
+ "is not currently supported; non-null values are rejected."),
+)
+
+
+class DeepSeekV2PresetAdapter(AutoEPPresetAdapter):
+ """DeepSeek-V2 keeps native top-k normalization and optional group-limited routing."""
+
+ def _requires_transformers_version_validation(self) -> bool:
+ return True
+
+ def resolve_route_norm(
+ self,
+ config: AutoEPConfig,
+ preset: MoEModelPreset,
+ model_config,
+ ) -> bool:
+ if config.route_norm is not None:
+ return config.route_norm
+ return preset.route_norm
+
+ def resolve_group_routing(
+ self,
+ config: AutoEPConfig,
+ model_config,
+ ) -> GroupRoutingConfig:
+ group_routing = super().resolve_group_routing(config, model_config)
+ if getattr(model_config, 'topk_method', None) != "group_limited_greedy":
+ return group_routing
+
+ return GroupRoutingConfig(
+ num_expert_groups=group_routing.num_expert_groups or getattr(model_config, 'n_group', None),
+ num_limited_groups=group_routing.num_limited_groups or getattr(model_config, 'topk_group', None),
+ group_score_func="max",
+ )
+
+
+PRESET_ADAPTERS = {
+ "deepseek_v2": DeepSeekV2PresetAdapter(),
+}
diff --git a/deepspeed/module_inject/auto_ep_presets/deepseek_v3.py b/deepspeed/module_inject/auto_ep_presets/deepseek_v3.py
new file mode 100644
index 000000000000..7bf422166f88
--- /dev/null
+++ b/deepspeed/module_inject/auto_ep_presets/deepseek_v3.py
@@ -0,0 +1,97 @@
+# SPDX-License-Identifier: Apache-2.0
+# DeepSpeed Team
+"""DeepSeek-V3 AutoEP preset and parser adapter."""
+
+from __future__ import annotations
+
+from dataclasses import replace
+
+import torch.nn as nn
+
+from deepspeed.module_inject.auto_ep_presets.base import AutoEPConfig, AutoEPPresetAdapter, GroupRoutingConfig, MoEModelPreset
+
+PRESET_NAME = "deepseek_v3"
+
+PRESET = MoEModelPreset(
+ moe_layer_pattern=r"model\.layers\.\d+\.mlp",
+ router_pattern="gate",
+ experts_pattern="experts",
+ expert_storage="fused_3d",
+ expert_w1="gate_up_proj",
+ expert_w2="down_proj",
+ expert_w3=None,
+ num_experts_attr="n_routed_experts",
+ top_k_attr="num_experts_per_tok",
+ score_func="sigmoid",
+ score_apply="post",
+ route_norm=False,
+ gate_bias=False,
+ has_shared_experts=True,
+ shared_experts_pattern="shared_experts",
+ autoep_config_defaults={"load_balance_coeff": None},
+ supports_expert_bias=False,
+ unsupported_router_bias_names=("e_score_correction_bias", ),
+ preset_adapter="deepseek_v3",
+ hf_model_types=("deepseek_v3", ),
+ min_transformers_version="5.0.0",
+ docs_support_notes=("load_balance_coeff / expert-bias auxiliary-loss-free load balancing "
+ "is not currently supported; non-null values are rejected."),
+)
+
+
+class DeepSeekV3PresetAdapter(AutoEPPresetAdapter):
+ """DeepSeek-V3 always carries group-limited routing fields when present."""
+
+ def _requires_transformers_version_validation(self) -> bool:
+ return True
+
+ def resolve_group_routing(
+ self,
+ config: AutoEPConfig,
+ model_config,
+ ) -> GroupRoutingConfig:
+ group_routing = super().resolve_group_routing(config, model_config)
+ return GroupRoutingConfig(
+ num_expert_groups=group_routing.num_expert_groups or getattr(model_config, 'n_group', None),
+ num_limited_groups=group_routing.num_limited_groups or getattr(model_config, 'topk_group', None),
+ group_score_func=group_routing.group_score_func,
+ )
+
+ def resolve_expert_layout(
+ self,
+ experts_module: nn.Module,
+ preset: MoEModelPreset,
+ ) -> MoEModelPreset:
+ if not isinstance(experts_module, nn.ModuleList) or len(experts_module) == 0:
+ return preset
+
+ default_fused_layout = (preset.expert_storage == "fused_3d" and preset.expert_w1 == "gate_up_proj"
+ and preset.expert_w2 == "down_proj" and preset.expert_w3 is None)
+ if not default_fused_layout:
+ return preset
+
+ expert0 = experts_module[0]
+ if not all(_has_expert_projection(expert0, name) for name in ("gate_proj", "up_proj", "down_proj")):
+ return preset
+
+ return replace(
+ preset,
+ expert_storage="module_list",
+ expert_w1="gate_proj",
+ expert_w2="down_proj",
+ expert_w3="up_proj",
+ )
+
+
+def _has_expert_projection(expert_module: nn.Module, name: str) -> bool:
+ projection = getattr(expert_module, name, None)
+ if projection is None:
+ return False
+ if isinstance(projection, (nn.Linear, nn.Parameter)):
+ return True
+ return hasattr(projection, "weight")
+
+
+PRESET_ADAPTERS = {
+ "deepseek_v3": DeepSeekV3PresetAdapter(),
+}
diff --git a/deepspeed/module_inject/auto_ep_presets/llama4.py b/deepspeed/module_inject/auto_ep_presets/llama4.py
new file mode 100644
index 000000000000..23469a43f7af
--- /dev/null
+++ b/deepspeed/module_inject/auto_ep_presets/llama4.py
@@ -0,0 +1,55 @@
+# SPDX-License-Identifier: Apache-2.0
+# DeepSpeed Team
+"""Llama4 AutoEP preset and parser adapter."""
+
+from __future__ import annotations
+
+from dataclasses import replace
+
+from deepspeed.module_inject.auto_ep_presets.base import AutoEPPresetAdapter, ForwardContract, MoEModelPreset
+
+PRESET_NAME = "llama4"
+
+PRESET = MoEModelPreset(
+ moe_layer_pattern=r"model\.layers\.\d+\.feed_forward",
+ router_pattern="router",
+ experts_pattern="experts",
+ expert_storage="fused_3d",
+ expert_w1="gate_up_proj",
+ expert_w2="down_proj",
+ expert_w3=None,
+ num_experts_attr="num_local_experts",
+ top_k_attr="num_experts_per_tok",
+ score_func="sigmoid",
+ score_apply="pre",
+ route_norm=False,
+ gate_bias=False,
+ has_shared_experts=True,
+ shared_experts_pattern="shared_expert",
+ autoep_config_defaults={"load_balance_coeff": None},
+ preset_adapter="llama4",
+ hf_model_types=("llama4", "llama4_text"),
+ min_transformers_version="5.0.0",
+)
+
+
+class Llama4PresetAdapter(AutoEPPresetAdapter):
+ """Llama4 MoE returns a flat hidden-state tensor with raw router logits."""
+
+ def adjust_forward_contract(self, contract: ForwardContract) -> ForwardContract:
+ capture_target = contract.capture_target
+ if capture_target == "none":
+ capture_target = "router"
+
+ return replace(
+ contract,
+ return_router_logits=True,
+ capture_target=capture_target,
+ router_logits_capture_mode="raw",
+ moe_output_shape="flat",
+ )
+
+
+PRESET_ADAPTERS = {
+ "llama4": Llama4PresetAdapter(),
+}
diff --git a/deepspeed/module_inject/auto_ep_presets/mixtral.py b/deepspeed/module_inject/auto_ep_presets/mixtral.py
new file mode 100644
index 000000000000..8be0f27c1f7b
--- /dev/null
+++ b/deepspeed/module_inject/auto_ep_presets/mixtral.py
@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: Apache-2.0
+# DeepSpeed Team
+"""Mixtral AutoEP preset."""
+
+from __future__ import annotations
+
+from deepspeed.module_inject.auto_ep_presets.base import MoEModelPreset, TransformersTopLevelRouterLogitsAdapter
+
+PRESET_NAME = "mixtral"
+
+PRESET = MoEModelPreset(
+ moe_layer_pattern=r"model\.layers\.\d+\.mlp",
+ router_pattern="gate",
+ experts_pattern="experts",
+ expert_storage="fused_3d",
+ expert_w1="gate_up_proj",
+ expert_w2="down_proj",
+ expert_w3=None,
+ num_experts_attr="num_local_experts",
+ top_k_attr="num_experts_per_tok",
+ score_func="softmax",
+ score_apply="post",
+ route_norm=True,
+ gate_bias=False,
+ preset_adapter="mixtral",
+ hf_model_types=("mixtral", ),
+ min_transformers_version="5.0.0",
+)
+
+PRESET_ADAPTERS = {
+ "mixtral":
+ TransformersTopLevelRouterLogitsAdapter(
+ display_name="Mixtral",
+ hf_model_types=("mixtral", ),
+ class_name_fragments=("Mixtral", ),
+ ),
+}
diff --git a/deepspeed/module_inject/auto_ep_presets/qwen3_5_moe.py b/deepspeed/module_inject/auto_ep_presets/qwen3_5_moe.py
new file mode 100644
index 000000000000..0b17e9c27de5
--- /dev/null
+++ b/deepspeed/module_inject/auto_ep_presets/qwen3_5_moe.py
@@ -0,0 +1,97 @@
+# SPDX-License-Identifier: Apache-2.0
+# DeepSpeed Team
+"""Qwen3.5-MoE AutoEP preset and parser adapter."""
+
+from __future__ import annotations
+
+from dataclasses import replace
+from typing import Callable
+
+import torch.nn as nn
+
+from deepspeed.module_inject.auto_ep_presets.base import (
+ AutoEPPresetAdapter,
+ ForwardContract,
+ MoELayerSpec,
+ MoEModelPreset,
+ _retarget_transformers_output_recorders_for_modules,
+)
+
+PRESET_NAME = "qwen3_5_moe"
+
+PRESET = MoEModelPreset(
+ moe_layer_pattern=r"model\.layers\.\d+\.mlp",
+ router_pattern="gate",
+ experts_pattern="experts",
+ expert_storage="fused_3d",
+ expert_w1="gate_up_proj",
+ expert_w2="down_proj",
+ expert_w3=None,
+ num_experts_attr="num_experts",
+ top_k_attr="num_experts_per_tok",
+ score_func="softmax",
+ score_apply="post",
+ route_norm=True,
+ gate_bias=False,
+ has_shared_experts=True,
+ shared_experts_pattern="shared_expert",
+ shared_experts_gate_pattern="shared_expert_gate",
+ preset_adapter="qwen3_5_moe",
+ hf_model_types=("qwen3_5_moe_text", ),
+ unsupported_hf_model_type_notes={
+ "qwen3_5_moe": ("AutoEP supports the Qwen3.5 text backbone preset path; pass the "
+ "text-backbone model/config with model_type='qwen3_5_moe_text'.")
+ },
+ min_transformers_version="5.2.0",
+ docs_support_notes="Requires the Qwen3.5 text-backbone qwen3_5_moe_text model type.",
+)
+
+
+class Qwen35MoePresetAdapter(AutoEPPresetAdapter):
+ """Qwen3.5 MoE exposes router logits through HF output recording."""
+
+ def _requires_transformers_version_validation(self) -> bool:
+ return True
+
+ def adjust_forward_contract(self, contract: ForwardContract) -> ForwardContract:
+ # HF records Qwen3.5 router output on Qwen3_5MoeTopKRouter. AutoEP replaces
+ # the owning MoE block, so replacement output index 1 is used for recorder retargeting.
+ return replace(
+ contract,
+ return_router_logits=True,
+ capture_target="router",
+ capture_index=1,
+ )
+
+ def retarget_transformers_output_recorders(
+ self,
+ model: nn.Module,
+ spec: MoELayerSpec,
+ replacement: nn.Module,
+ retargeted_keys: set[str],
+ remove_output_capture_hooks: Callable[[nn.Module], int],
+ ) -> None:
+ recorder_key = f"{spec.model_family}:{replacement.__class__.__module__}.{replacement.__class__.__qualname__}"
+
+ replacement_cls = replacement.__class__
+
+ def module_matches(module: nn.Module) -> bool:
+ module_config = getattr(module, "config", None)
+ model_type = getattr(module_config, "model_type", None)
+ class_name = module.__class__.__name__
+ return model_type == "qwen3_5_moe_text" or "Qwen3_5Moe" in class_name
+
+ _retarget_transformers_output_recorders_for_modules(
+ model=model,
+ display_name="Qwen3.5 AutoEP",
+ recorder_key=recorder_key,
+ retargeted_keys=retargeted_keys,
+ remove_output_capture_hooks=remove_output_capture_hooks,
+ module_matches=module_matches,
+ make_output_recorder=lambda OutputRecorder: OutputRecorder(replacement_cls, index=1),
+ )
+
+
+PRESET_ADAPTERS = {
+ "qwen3_5_moe": Qwen35MoePresetAdapter(),
+}
diff --git a/deepspeed/module_inject/auto_ep_presets/qwen3_moe.py b/deepspeed/module_inject/auto_ep_presets/qwen3_moe.py
new file mode 100644
index 000000000000..8535c448e843
--- /dev/null
+++ b/deepspeed/module_inject/auto_ep_presets/qwen3_moe.py
@@ -0,0 +1,42 @@
+# SPDX-License-Identifier: Apache-2.0
+# DeepSpeed Team
+"""Qwen3-MoE AutoEP preset."""
+
+from __future__ import annotations
+
+from deepspeed.module_inject.auto_ep_presets.base import MoEModelPreset, TransformersTopLevelRouterLogitsAdapter
+
+PRESET_NAME = "qwen3_moe"
+
+PRESET = MoEModelPreset(
+ moe_layer_pattern=r"model\.layers\.\d+\.mlp",
+ router_pattern="gate",
+ experts_pattern="experts",
+ expert_storage="fused_3d",
+ expert_w1="gate_up_proj",
+ expert_w2="down_proj",
+ expert_w3=None,
+ num_experts_attr="num_experts",
+ top_k_attr="num_experts_per_tok",
+ score_func="softmax",
+ score_apply="post",
+ route_norm=True,
+ gate_bias=False,
+ has_shared_experts=True,
+ shared_experts_pattern="shared_expert",
+ shared_experts_gate_pattern="shared_expert_gate",
+ preset_adapter="qwen3_moe",
+ hf_model_types=("qwen3_moe", "qwen2_moe"),
+ min_transformers_version="5.0.0",
+ docs_support_notes=("Also covers Qwen2-MoE when the installed Transformers build uses the "
+ "validated fused expert layout."),
+)
+
+PRESET_ADAPTERS = {
+ "qwen3_moe":
+ TransformersTopLevelRouterLogitsAdapter(
+ display_name="Qwen3-MoE/Qwen2-MoE",
+ hf_model_types=("qwen3_moe", "qwen2_moe"),
+ class_name_fragments=("Qwen3Moe", "Qwen2Moe"),
+ ),
+}
diff --git a/deepspeed/module_inject/auto_ep_presets/registry.py b/deepspeed/module_inject/auto_ep_presets/registry.py
new file mode 100644
index 000000000000..7d5397d432cb
--- /dev/null
+++ b/deepspeed/module_inject/auto_ep_presets/registry.py
@@ -0,0 +1,212 @@
+# SPDX-License-Identifier: Apache-2.0
+# DeepSpeed Team
+"""AutoEP preset registry and config override helpers."""
+
+from __future__ import annotations
+
+import copy
+from dataclasses import replace
+
+from deepspeed.module_inject.auto_ep_presets.base import (
+ _UNSET,
+ AutoEPConfig,
+ AutoEPPresetAdapter,
+ MoEModelPreset,
+)
+from deepspeed.module_inject.auto_ep_presets import deepseek_v2, deepseek_v3, llama4, mixtral, qwen3_5_moe, qwen3_moe
+from deepspeed.utils import logger
+
+_PRESET_MODULES = (
+ mixtral,
+ qwen3_moe,
+ qwen3_5_moe,
+ deepseek_v2,
+ deepseek_v3,
+ llama4,
+)
+
+PRESET_MODELS: dict[str, MoEModelPreset] = {module.PRESET_NAME: module.PRESET for module in _PRESET_MODULES}
+
+_PRESET_ADAPTERS: dict[str, AutoEPPresetAdapter] = {
+ "default": AutoEPPresetAdapter(),
+}
+for _module in _PRESET_MODULES:
+ _PRESET_ADAPTERS.update(getattr(_module, "PRESET_ADAPTERS", {}))
+
+
+def _validate_registered_preset_adapters(
+ preset_models: dict[str, MoEModelPreset] | None = None,
+ preset_adapters: dict[str, AutoEPPresetAdapter] | None = None,
+) -> None:
+ """Fail fast if a registered preset references an adapter that is not registered."""
+ preset_models = PRESET_MODELS if preset_models is None else preset_models
+ preset_adapters = _PRESET_ADAPTERS if preset_adapters is None else preset_adapters
+
+ missing_presets = []
+ for preset_name, preset in preset_models.items():
+ if preset.preset_adapter not in preset_adapters:
+ missing_presets.append((preset_name, preset.preset_adapter))
+
+ if not missing_presets:
+ return
+
+ details = ", ".join(f"{preset_name}:{adapter_name}" for preset_name, adapter_name in missing_presets)
+ raise RuntimeError(f"AutoEP preset registry is inconsistent; missing preset_adapter registration(s): {details}")
+
+
+_validate_registered_preset_adapters()
+
+_PRESET_DEFAULT_EXPLICIT_FLAGS = {
+ "load_balance_coeff": "_load_balance_coeff_explicit",
+}
+
+
+def available_preset_names() -> tuple[str, ...]:
+ """Return built-in AutoEP preset names in registry order."""
+ return tuple(PRESET_MODELS.keys())
+
+
+def get_preset(preset_name: str) -> MoEModelPreset:
+ """Return a registered AutoEP preset by name."""
+ preset = PRESET_MODELS.get(preset_name)
+ if preset is None:
+ raise ValueError(f"Unknown preset_model '{preset_name}'. Available presets: {list(available_preset_names())}")
+ return preset
+
+
+def get_preset_adapter(adapter_name: str) -> AutoEPPresetAdapter:
+ """Return a registered AutoEP preset adapter by name."""
+ adapter = _PRESET_ADAPTERS.get(adapter_name)
+ if adapter is None:
+ raise ValueError(f"Unknown AutoEP preset adapter '{adapter_name}'")
+ return adapter
+
+
+def preset_name_for_hf_model_type(model_type: str) -> str | None:
+ """Return the AutoEP preset name for a supported HF model_type."""
+ for preset_name, preset in PRESET_MODELS.items():
+ if model_type in preset.hf_model_types:
+ return preset_name
+ return None
+
+
+def unsupported_preset_for_hf_model_type(model_type: str) -> tuple[str, MoEModelPreset] | None:
+ """Return a preset carrying an actionable diagnostic for an unsupported HF model_type."""
+ for preset_name, preset in PRESET_MODELS.items():
+ if model_type in preset.unsupported_hf_model_type_notes:
+ return preset_name, preset
+ return None
+
+
+def resolve_autoep_config_defaults(config: AutoEPConfig, preset_name: str | None) -> AutoEPConfig:
+ """Return config with preset-level AutoEP defaults applied where the user did not override."""
+ if preset_name is None or preset_name not in PRESET_MODELS:
+ return config
+
+ preset_defaults = PRESET_MODELS[preset_name].autoep_config_defaults
+ if not preset_defaults:
+ return config
+
+ resolved = copy.copy(config)
+ for field_name, default_value in preset_defaults.items():
+ explicit_flag = _PRESET_DEFAULT_EXPLICIT_FLAGS.get(field_name)
+ if explicit_flag is None:
+ continue
+ if not getattr(config, explicit_flag, False):
+ setattr(resolved, field_name, default_value)
+ return resolved
+
+
+def apply_config_overrides(config: AutoEPConfig, preset: MoEModelPreset) -> MoEModelPreset:
+ """Apply explicit AutoEP config overrides to a preset.
+
+ Return the original preset object when there are no overrides. When overrides
+ are present, return a dataclass copy so the registered preset remains unchanged.
+ """
+ overrides = {}
+ if config.moe_layer_pattern is not None:
+ overrides["moe_layer_pattern"] = config.moe_layer_pattern
+ if config.router_pattern is not None:
+ overrides["router_pattern"] = config.router_pattern
+ if config.expert_pattern is not None:
+ overrides["experts_pattern"] = config.expert_pattern
+ if config.expert_w1 is not None:
+ overrides["expert_w1"] = config.expert_w1
+ if config.expert_w2 is not None:
+ overrides["expert_w2"] = config.expert_w2
+ if config.expert_w3 is not _UNSET:
+ overrides["expert_w3"] = config.expert_w3
+ if config.num_experts_attr is not None:
+ overrides["num_experts_attr"] = config.num_experts_attr
+ if config.top_k_attr is not None:
+ overrides["top_k_attr"] = config.top_k_attr
+ if config.has_shared_experts is not None:
+ overrides["has_shared_experts"] = config.has_shared_experts
+ if config.shared_experts_pattern is not None:
+ overrides["shared_experts_pattern"] = config.shared_experts_pattern
+ if config.shared_experts_gate_pattern is not None:
+ overrides["shared_experts_gate_pattern"] = config.shared_experts_gate_pattern
+ if not overrides:
+ return preset
+ return replace(preset, **overrides)
+
+
+def resolve_preset_candidates(
+ config: AutoEPConfig,
+ model_config,
+) -> list[tuple[str, MoEModelPreset]]:
+ """Resolve ordered preset candidates for AutoEP detection."""
+ if config.preset_model is not None:
+ preset = apply_config_overrides(config, get_preset(config.preset_model))
+ _validate_preset_compatibility(config.preset_model, preset, model_config)
+ return [(config.preset_model, preset)]
+
+ if model_config is not None:
+ model_type = getattr(model_config, 'model_type', None)
+ if model_type:
+ preset_name = preset_name_for_hf_model_type(model_type)
+ if preset_name is not None:
+ logger.info(f"AutoEP: auto-detected model_type='{model_type}', using preset '{preset_name}'")
+ preset = apply_config_overrides(config, get_preset(preset_name))
+ _validate_preset_compatibility(preset_name, preset, model_config)
+ return [(preset_name, preset)]
+
+ unsupported_preset = unsupported_preset_for_hf_model_type(model_type)
+ if unsupported_preset is not None:
+ preset_name, preset = unsupported_preset
+ _validate_preset_compatibility(preset_name, preset, model_config)
+
+ if config.moe_layer_pattern:
+ return [("custom", _build_custom_preset(config))]
+
+ return [(name, apply_config_overrides(config, preset)) for name, preset in PRESET_MODELS.items()]
+
+
+def _validate_preset_compatibility(
+ preset_name: str,
+ preset: MoEModelPreset,
+ model_config,
+) -> None:
+ adapter = get_preset_adapter(preset.preset_adapter)
+ adapter.validate_compatibility(preset_name, preset, model_config)
+
+
+def _build_custom_preset(config: AutoEPConfig) -> MoEModelPreset:
+ return MoEModelPreset(
+ moe_layer_pattern=config.moe_layer_pattern,
+ router_pattern=config.router_pattern or "gate",
+ experts_pattern=config.expert_pattern or "experts",
+ expert_storage="fused_3d",
+ expert_w1=config.expert_w1 or "gate_up_proj",
+ expert_w2=config.expert_w2 or "down_proj",
+ expert_w3=(None if config.expert_w3 is _UNSET else config.expert_w3),
+ num_experts_attr=config.num_experts_attr or "num_local_experts",
+ top_k_attr=config.top_k_attr or "num_experts_per_tok",
+ score_func=(config.score_func if config.score_func != "auto" else "softmax"),
+ score_apply=(config.score_apply if config.score_apply != "auto" else "post"),
+ route_norm=(config.route_norm if config.route_norm is not None else True),
+ gate_bias=False,
+ has_shared_experts=(config.has_shared_experts if config.has_shared_experts is not None else False),
+ shared_experts_pattern=config.shared_experts_pattern or "",
+ shared_experts_gate_pattern=config.shared_experts_gate_pattern or "",
+ )
diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index 852c492f8b8e..4e47278e52c5 100755
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -354,6 +354,10 @@ def _replace(self, child, name, conv_linear_layer):
if getattr(child, "replaced", False) == True:
return
+ # Skip AutoEP-managed modules (expert weights are EP-sharded, not TP-sharded)
+ if getattr(child, "_is_autoep_layer", False):
+ return child
+
weight_shape = child.weight.shape
mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
@@ -546,6 +550,9 @@ def update_linear_policies(self):
def _replace_module(self, r_module, prev_name='', prev_class_name=''):
for name, child in r_module.named_children():
+ if getattr(child, "_is_autoep_layer", False):
+ continue
+
if prev_class_name == "":
class_name = prev_name
elif prev_name == "":
diff --git a/deepspeed/moe/ep_count.py b/deepspeed/moe/ep_count.py
new file mode 100644
index 000000000000..4b8d863d80bb
--- /dev/null
+++ b/deepspeed/moe/ep_count.py
@@ -0,0 +1,41 @@
+# Copyright (c) DeepSpeed Team.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+"""Helpers for expert token counting in AutoEP routing paths."""
+
+import torch
+
+from deepspeed.accelerator import get_accelerator
+
+
+def count_tokens_per_expert(
+ selected_experts_indices: torch.Tensor,
+ num_experts: int,
+ *,
+ out_dtype: torch.dtype = torch.float32,
+ deterministic_safe: bool = False,
+) -> torch.Tensor:
+ """Count routed tokens per expert.
+
+    The fast path uses ``torch.bincount`` on the current device.
+    If ``deterministic_safe=True``, deterministic algorithms are enabled, and
+    the indices live on the accelerator, this falls back to a CPU ``bincount``
+    (counts are moved back to the original device) to avoid errors from
+    kernels that have no deterministic implementation.
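+
+    Example (illustrative; 2 tokens routed to 4 experts with top_k=2)::
+
+        >>> idx = torch.tensor([[0, 2], [1, 2]])
+        >>> count_tokens_per_expert(idx, num_experts=4)
+        tensor([1., 1., 2., 0.])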
+ """
+ flat_indices = selected_experts_indices.reshape(-1).to(torch.int64)
+
+ if deterministic_safe and torch.are_deterministic_algorithms_enabled() and get_accelerator().on_accelerator(
+ flat_indices):
+ counts = torch.bincount(flat_indices.detach().cpu(), minlength=num_experts)
+ counts = counts.to(selected_experts_indices.device)
+ else:
+ counts = torch.bincount(flat_indices, minlength=num_experts)
+
+ if counts.numel() < num_experts:
+ pad = torch.zeros(num_experts - counts.numel(), device=counts.device, dtype=counts.dtype)
+ counts = torch.cat([counts, pad], dim=0)
+ elif counts.numel() > num_experts:
+ counts = counts[:num_experts]
+
+ return counts.to(out_dtype)
diff --git a/deepspeed/moe/ep_experts.py b/deepspeed/moe/ep_experts.py
new file mode 100644
index 000000000000..113986ea9952
--- /dev/null
+++ b/deepspeed/moe/ep_experts.py
@@ -0,0 +1,191 @@
+# Copyright (c) DeepSpeed Team.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause
+#
+# Portions of this file are derived from TorchTitan.
+# See THIRD_PARTY_NOTICES.md for the BSD-3-Clause notice.
+
+# DeepSpeed Team
+"""
+Grouped expert computation for expert parallelism.
+
+Ported from TorchTitan's GroupedExperts with adaptations for DeepSpeed:
+ - Replaced hardcoded .bfloat16() with input-dtype-aware casting
+ - Fail-fast RuntimeError when use_grouped_mm=True but torch._grouped_mm is unavailable
+ - Removed DTensor-specific code paths
+
+This module is self-contained: no imports from deepspeed.module_inject
+or deepspeed.runtime.
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# ---------------------------------------------------------------------------
+# Expert computation: sequential for-loop (reference path)
+# ---------------------------------------------------------------------------
+
+
+def _run_experts_for_loop(
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w3: torch.Tensor,
+ x: torch.Tensor,
+ num_tokens_per_expert: torch.Tensor,
+) -> torch.Tensor:
+ """Compute SwiGLU expert MLP via a sequential for-loop over experts.
+
+ This is the reference implementation that works on all PyTorch versions.
+
+ Args:
+        w1: Gate projection weight, shape ``(E, hidden_dim, dim)``.
+        w2: Down projection weight, shape ``(E, dim, hidden_dim)``.
+        w3: Up projection weight, shape ``(E, hidden_dim, dim)``.
+ x: Input tokens, shape ``(T, dim)``.
+ num_tokens_per_expert: Token counts per expert, shape ``(E,)``.
+
+ Returns:
+ Output tensor of shape ``(T, dim)``.
+ """
+ # NOTE: .tolist() incurs a device-host synchronization
+ num_tokens_per_expert_list = num_tokens_per_expert.tolist()
+
+ # Handle padding rows injected by generate_permute_indices
+ num_padding = x.shape[0] - sum(num_tokens_per_expert_list)
+
+ x_splits = torch.split(
+ x[:sum(num_tokens_per_expert_list)],
+ split_size_or_sections=num_tokens_per_expert_list,
+ dim=0,
+ )
+
+ cast_dtype = x.dtype
+ out_experts_splits = []
+ for expert_idx, x_expert in enumerate(x_splits):
+ w1_e = w1[expert_idx].to(cast_dtype).transpose(-2, -1)
+ w3_e = w3[expert_idx].to(cast_dtype).transpose(-2, -1)
+ w2_e = w2[expert_idx].to(cast_dtype).transpose(-2, -1)
+ h = F.silu(torch.matmul(x_expert, w1_e))
+ h = h * torch.matmul(x_expert, w3_e)
+ h = torch.matmul(h, w2_e)
+ out_experts_splits.append(h)
+
+ out = torch.cat(out_experts_splits, dim=0)
+
+ # Re-add padding rows (zeros) so output shape matches input shape
+ out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
+
+ return out
+
+
+# ---------------------------------------------------------------------------
+# Expert computation: grouped GEMM (torch._grouped_mm)
+# ---------------------------------------------------------------------------
+
+
+def _run_experts_grouped_mm(
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w3: torch.Tensor,
+ x: torch.Tensor,
+ num_tokens_per_expert: torch.Tensor,
+) -> torch.Tensor:
+ """Compute SwiGLU expert MLP via torch._grouped_mm (grouped GEMM).
+
+ Uses input dtype for casting instead of hardcoded bfloat16.
+
+ Args:
+        w1: Gate projection weight, shape ``(E, hidden_dim, dim)``.
+        w2: Down projection weight, shape ``(E, dim, hidden_dim)``.
+        w3: Up projection weight, shape ``(E, hidden_dim, dim)``.
+ x: Input tokens, shape ``(T, dim)``.
+ num_tokens_per_expert: Token counts per expert, shape ``(E,)``.
+
+ Returns:
+ Output tensor of shape ``(T, dim)``.
+ """
+ offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
+
+ cast_dtype = x.dtype
+ h = F.silu(torch._grouped_mm(
+ x.to(cast_dtype),
+ w1.to(cast_dtype).transpose(-2, -1),
+ offs=offsets,
+ ))
+ h = h * torch._grouped_mm(
+ x.to(cast_dtype),
+ w3.to(cast_dtype).transpose(-2, -1),
+ offs=offsets,
+ )
+ out = torch._grouped_mm(
+ h,
+ w2.to(cast_dtype).transpose(-2, -1),
+ offs=offsets,
+ ).type_as(x)
+
+ return out
+
+
+# ---------------------------------------------------------------------------
+# GroupedExperts module
+# ---------------------------------------------------------------------------
+
+
+class GroupedExperts(nn.Module):
+ """Grouped expert computation for MoE layers.
+
+ Supports two execution paths:
+ - **grouped_mm**: Uses ``torch._grouped_mm`` for fused grouped GEMM
+ (requires a sufficiently recent PyTorch build).
+ - **for-loop**: Sequential per-expert matmuls; always available.
+
+ If ``use_grouped_mm=True`` but ``torch._grouped_mm`` is not available, the
+ constructor raises ``RuntimeError``. Set ``use_grouped_mm=False`` to select
+ the sequential for-loop path without checking ``torch._grouped_mm``.
+
+ Args:
+ dim (int): Input / output dimension.
+ hidden_dim (int): Hidden dimension of the SwiGLU FFN.
+ num_experts (int): Number of experts.
+ use_grouped_mm (bool): Whether to attempt using grouped GEMM.
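+
+    Example (minimal sketch on the for-loop path; real weights come from
+    expert repacking rather than random init, shown here only for shapes)::
+
+        experts = GroupedExperts(dim=16, hidden_dim=32, num_experts=4,
+                                 use_grouped_mm=False)
+        for p in experts.parameters():
+            torch.nn.init.normal_(p, std=0.02)
+        x = torch.randn(10, 16)
+        counts = torch.tensor([3, 2, 4, 1])  # sums to 10 -> no padding rows
+        out = experts(x, counts)             # shape (10, 16)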
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ num_experts: int,
+ use_grouped_mm: bool = True,
+ ):
+ super().__init__()
+ self.num_experts = num_experts
+ self.w1 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
+ self.w2 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
+ self.w3 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
+
+ if use_grouped_mm and not hasattr(torch, "_grouped_mm"):
+ raise RuntimeError("GroupedExperts was constructed with use_grouped_mm=True but "
+ "torch._grouped_mm is not available in this PyTorch build. "
+ "Upgrade PyTorch to a build that provides torch._grouped_mm, or "
+ "set use_grouped_mm=False to use the sequential expert loop.")
+ self.use_grouped_mm = use_grouped_mm
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ num_tokens_per_expert: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Args:
+ x: Input tokens, shape ``(T, dim)``.
+ num_tokens_per_expert: Token counts per expert, shape ``(E,)``.
+
+ Returns:
+ Output tensor of shape ``(T, dim)``.
+ """
+ if self.use_grouped_mm:
+ return _run_experts_grouped_mm(self.w1, self.w2, self.w3, x, num_tokens_per_expert)
+ else:
+ return _run_experts_for_loop(self.w1, self.w2, self.w3, x, num_tokens_per_expert)
diff --git a/deepspeed/moe/ep_kernels.py b/deepspeed/moe/ep_kernels.py
new file mode 100644
index 000000000000..e5eaa1067157
--- /dev/null
+++ b/deepspeed/moe/ep_kernels.py
@@ -0,0 +1,381 @@
+# Copyright (c) DeepSpeed Team.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause
+#
+# Portions of this file are derived from TorchTitan.
+# See THIRD_PARTY_NOTICES.md for the BSD-3-Clause notice.
+
+# DeepSpeed Team
+"""
+Token reordering and permutation utilities for expert parallelism.
+
+Ported from TorchTitan's TokenReorderer, Triton kernels, and alignment
+utilities with adaptations for DeepSpeed:
+ - Triton import guarded with try/except; pure-PyTorch fallback provided
+ - Alignment config exposed as TOKEN_GROUP_ALIGN_SIZE_M
+
+This module is self-contained: no imports from deepspeed.module_inject
+or deepspeed.runtime.
+"""
+
+import logging
+from typing import Callable
+
+import torch
+import torch.nn as nn
+
+from deepspeed.moe.ep_count import count_tokens_per_expert
+
+logger = logging.getLogger(__name__)
+
+# ---------------------------------------------------------------------------
+# Try to import Triton; fall back gracefully
+# ---------------------------------------------------------------------------
+
+_TRITON_AVAILABLE = False
+try:
+ import triton
+ import triton.language as tl
+
+ _TRITON_AVAILABLE = True
+except ImportError:
+ logger.info("Triton not available; using pure-PyTorch CPU fallback for "
+ "permutation index generation.")
+
+# ---------------------------------------------------------------------------
+# Alignment constant
+# ---------------------------------------------------------------------------
+
+TOKEN_GROUP_ALIGN_SIZE_M = 8
+"""Alignment granularity for token groups in grouped GEMM.
+
+ - bf16: 8 (16 bytes / 2 bytes per elem)
+ - fp8: 16 (16 bytes / 1 byte per elem)
+ - mxfp8: 32 (scaling block size)
+"""
+
+# ---------------------------------------------------------------------------
+# Utility: round up
+# ---------------------------------------------------------------------------
+
+
+def _round_up(x: int, y: int) -> int:
+ """Round *x* up to the nearest multiple of *y*."""
+ return ((x + y - 1) // y) * y
+
+
+# ===================================================================
+# Triton kernel for filling permutation indices
+# ===================================================================
+
+if _TRITON_AVAILABLE:
+
+ @triton.jit
+ def _fill_indices_kernel(
+ tokens_per_expert_group_ptr,
+ start_index_values_ptr,
+ write_offsets_ptr,
+ output_ptr,
+ experts_per_rank: tl.constexpr,
+ num_ranks: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr,
+ ):
+ pid = tl.program_id(axis=0)
+ num_programs = tl.num_programs(axis=0)
+
+ for expert_id in range(pid, experts_per_rank, num_programs):
+ write_offset = tl.load(write_offsets_ptr + expert_id)
+
+ for r in range(num_ranks):
+ i = r * experts_per_rank + expert_id
+ start_index = tl.load(start_index_values_ptr + i)
+ length = tl.load(tokens_per_expert_group_ptr + i)
+
+ offsets = tl.arange(0, BLOCK_SIZE)
+ for chunk_start in range(0, length, BLOCK_SIZE):
+ chunk_offsets = chunk_start + offsets
+ mask = chunk_offsets < length
+ values = start_index + chunk_offsets
+ dest_indices = write_offset + chunk_offsets
+ tl.store(output_ptr + dest_indices, values, mask=mask)
+
+ write_offset += length
+
+
+# ===================================================================
+# Triton wrapper
+# ===================================================================
+
+
+def fill_indices_wrapper(
+ tokens_per_expert_group: torch.Tensor,
+ start_index_values: torch.Tensor,
+ write_offsets: torch.Tensor,
+ experts_per_rank: int,
+ num_ranks: int,
+ max_len: int,
+ block_size: int = 128,
+ max_blocks: int = 1024,
+) -> torch.Tensor:
+ """Launch the Triton kernel to fill permutation indices.
+
+ Falls back to :func:`fill_indices_cpu` when Triton is unavailable.
+ """
+ if not _TRITON_AVAILABLE:
+ return fill_indices_cpu(
+ tokens_per_expert_group,
+ start_index_values,
+ write_offsets,
+ experts_per_rank,
+ num_ranks,
+ max_len,
+ )
+
+ permuted_indices = torch.full((max_len, ), -1, dtype=torch.int32, device=tokens_per_expert_group.device)
+
+ num_blocks = min(experts_per_rank, max_blocks)
+ grid = (num_blocks, )
+
+ _fill_indices_kernel[grid](
+ tokens_per_expert_group,
+ start_index_values,
+ write_offsets,
+ permuted_indices,
+ experts_per_rank,
+ num_ranks,
+ BLOCK_SIZE=block_size,
+ )
+ return permuted_indices
+
+
+# ===================================================================
+# CPU reference implementation (always available)
+# ===================================================================
+
+
+def fill_indices_cpu(
+ tokens_per_expert_group: torch.Tensor,
+ start_index_values: torch.Tensor,
+ write_offsets: torch.Tensor,
+ experts_per_rank: int,
+ num_ranks: int,
+ max_len: int,
+) -> torch.Tensor:
+ """Pure-PyTorch CPU reference for filling permutation indices."""
+ permuted_indices = torch.full(
+ (max_len, ),
+ -1,
+ dtype=torch.int32,
+ )
+ for e in range(experts_per_rank):
+ write_start = write_offsets[e].item()
+ for r in range(num_ranks):
+ i = r * experts_per_rank + e
+ start_index = start_index_values[i].item()
+ length = tokens_per_expert_group[i].item()
+ if length > 0:
+ end_idx = min(write_start + length, max_len)
+ permuted_indices[write_start:end_idx] = torch.arange(
+ start_index,
+ start_index + (end_idx - write_start),
+ dtype=torch.int32,
+ )
+ write_start += length
+ return permuted_indices
+
+
+# ===================================================================
+# generate_permute_indices
+# ===================================================================
+
+
+def generate_permute_indices(
+ tokens_per_expert_group: torch.Tensor,
+ experts_per_rank: int,
+ num_ranks: int,
+ max_len: int,
+ alignment: int,
+ use_cpu: bool = False,
+) -> tuple:
+ """Prepare permutation indices and aligned token counts per expert.
+
+ Args:
+ tokens_per_expert_group: Token counts for each expert from all ranks,
+ shape ``(num_ranks * experts_per_rank,)``.
+ experts_per_rank: Number of experts per rank.
+ num_ranks: Number of ranks.
+ max_len: Maximum length of the output index vector.
+ alignment: Alignment for ``m_sizes`` and padding minimum.
+ use_cpu: Whether to force the CPU implementation.
+
+ Returns:
+ Tuple of:
+ - permuted_indices: Index mapping from original to expert-grouped order.
+ - m_sizes: Aligned token counts per expert.
+ - m_offsets: Cumulative sum of m_sizes.
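+
+    Example (illustrative, ``use_cpu=True``; 2 ranks x 2 local experts,
+    ``alignment=4``)::
+
+        counts = torch.tensor([1, 2, 3, 1])  # rank-major: [r0e0, r0e1, r1e0, r1e1]
+        idx, m_sizes, m_offsets = generate_permute_indices(
+            counts, experts_per_rank=2, num_ranks=2, max_len=16,
+            alignment=4, use_cpu=True)
+        # idx[:8]  == [0, 3, 4, 5, 1, 2, 6, -1]  (expert-grouped token slots)
+        # m_sizes  == [4, 4], m_offsets == [4, 8]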
+ """
+ # Prefix sum for start indices
+ start_index_values = (torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group)
+
+ # Total tokens per expert across all ranks
+ total_tokens_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)
+
+ # Pad empty experts to alignment minimum
+ total_tokens_per_expert = torch.clamp_min(total_tokens_per_expert, alignment)
+
+ # Align chunk sizes (ceiling division * alignment)
+ m_sizes = ((total_tokens_per_expert + alignment - 1) // alignment * alignment).to(torch.int32)
+
+ # Write offsets per local expert
+ m_offsets = torch.cumsum(m_sizes, 0)
+ write_offsets = m_offsets - m_sizes
+
+ if use_cpu:
+ permuted_indices = fill_indices_cpu(
+ tokens_per_expert_group,
+ start_index_values,
+ write_offsets,
+ experts_per_rank,
+ num_ranks,
+ max_len,
+ )
+ else:
+ permuted_indices = fill_indices_wrapper(
+ tokens_per_expert_group,
+ start_index_values,
+ write_offsets,
+ experts_per_rank,
+ num_ranks,
+ max_len,
+ )
+
+ return permuted_indices, m_sizes, m_offsets.to(torch.int32)
+
+
+# ===================================================================
+# _permute / _unpermute / indices_padding_wrapper
+# ===================================================================
+
+
+def _permute(
+ x: torch.Tensor,
+ num_tokens_per_expert: torch.Tensor,
+ ep_degree: int,
+ num_local_experts: int,
+) -> tuple:
+ """Permute tokens into expert-grouped order with alignment padding.
+
+ Returns:
+ Tuple of (input_shape, permuted_x, permuted_indices, aligned_counts).
+ """
+ global TOKEN_GROUP_ALIGN_SIZE_M
+ x_padded_per_expert = x.shape[0] + num_local_experts * TOKEN_GROUP_ALIGN_SIZE_M
+ padded_max_len = _round_up(x_padded_per_expert, TOKEN_GROUP_ALIGN_SIZE_M)
+
+ with torch.no_grad():
+ permuted_indices, num_tokens_per_expert, _offsets = generate_permute_indices(
+ num_tokens_per_expert,
+ num_local_experts,
+ ep_degree,
+ padded_max_len,
+ TOKEN_GROUP_ALIGN_SIZE_M,
+ )
+
+ # Append a single zero-row for safe indexing of padding slots
+ x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
+ input_shape = x.shape
+ x = x[permuted_indices, :]
+
+ return input_shape, x, permuted_indices, num_tokens_per_expert
+
+
+def _unpermute(
+ out: torch.Tensor,
+ input_shape: torch.Size,
+ permuted_indices: torch.Tensor,
+) -> torch.Tensor:
+ """Reverse the permutation produced by :func:`_permute`."""
+ out_unpermuted = out.new_empty(input_shape)
+ out_unpermuted[permuted_indices, :] = out
+ # Strip the extra zero-row appended during _permute
+ out = out_unpermuted[:-1]
+ return out
+
+
+def indices_padding_wrapper(func: Callable) -> Callable:
+ """Decorator that pads / aligns token groups for ``torch._grouped_mm``.
+
+ Wraps an expert-computation function so that each expert's token
+ count is a multiple of ``TOKEN_GROUP_ALIGN_SIZE_M``.
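+
+    Example (illustrative sketch; ``num_tokens_per_expert`` holds per-rank
+    counts for every local expert, i.e. length ``ep_degree * num_local_experts``)::
+
+        padded_grouped_mm = indices_padding_wrapper(_run_experts_grouped_mm)
+        out = padded_grouped_mm(w1, w2, w3, x, num_tokens_per_expert)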
+ """
+
+ def wrapper(
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w3: torch.Tensor,
+ x: torch.Tensor,
+ num_tokens_per_expert: torch.Tensor,
+ ) -> torch.Tensor:
+ num_local_experts = w1.shape[0]
+ ep_degree = num_tokens_per_expert.shape[0] // num_local_experts
+
+ input_shape, x, permuted_indices, num_tokens_per_expert = _permute(x, num_tokens_per_expert, ep_degree,
+ num_local_experts)
+
+ out = func(w1, w2, w3, x, num_tokens_per_expert)
+
+ out = _unpermute(out, input_shape, permuted_indices)
+ return out
+
+ return wrapper
+
+
+# ===================================================================
+# TokenReorderer
+# ===================================================================
+
+
+class TokenReorderer(nn.Module):
+ """Reorder token indices to match expert order for efficient parallel
+ processing.
+
+ Args:
+ num_experts (int): Number of experts in the MoE layer.
+ top_k (int): Number of experts each token is routed to.
+ """
+
+ def __init__(self, num_experts: int, top_k: int):
+ super().__init__()
+ self.num_experts = num_experts
+ self.top_k = top_k
+
+ def forward(
+ self,
+ top_scores: torch.Tensor,
+ selected_experts_indices: torch.Tensor,
+ ) -> tuple:
+ """
+ Args:
+ top_scores: Routing scores, shape ``(T, top_k)``.
+ selected_experts_indices: Expert indices, shape ``(T, top_k)``.
+
+ Returns:
+ Tuple of:
+ - top_scores_experts_sorted ``(T * top_k,)``: scores in
+ expert-sorted order.
+ - token_indices_experts_sorted ``(T * top_k,)``: flattened
+ token-slot indices sorted by expert.
+ - num_tokens_per_expert ``(num_experts,)``: histogram.
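+
+        Example (illustrative, T=2, top_k=2, num_experts=3): for
+        ``selected_experts_indices = [[0, 2], [1, 2]]`` the sorted slot order
+        is ``[0, 2, 1, 3]`` and the histogram is ``[1., 1., 2.]``.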
+ """
+ num_tokens_per_expert = count_tokens_per_expert(selected_experts_indices, self.num_experts)
+
+ token_indices_experts_sorted = torch.argsort(selected_experts_indices.view(-1), stable=True)
+
+ top_scores_experts_sorted = top_scores.view(-1)[token_indices_experts_sorted]
+
+ return (
+ top_scores_experts_sorted,
+ token_indices_experts_sorted,
+ num_tokens_per_expert,
+ )
diff --git a/deepspeed/moe/ep_repack.py b/deepspeed/moe/ep_repack.py
new file mode 100644
index 000000000000..84a1e60b669a
--- /dev/null
+++ b/deepspeed/moe/ep_repack.py
@@ -0,0 +1,174 @@
+# Copyright (c) DeepSpeed Team.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+"""Expert weight repacking for AutoEP.
+
+Converts HuggingFace expert weight formats into TorchTitan-compatible
+grouped tensors [E_local, hidden_dim, dim] for grouped GEMM.
+"""
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+
+from deepspeed.module_inject.auto_ep_config import MoELayerSpec
+from deepspeed.moe.fused_expert_layout import classify_fused_gate_up_layout
+
+
+def repack_expert_weights(
+ experts_source: nn.Module,
+ spec: MoELayerSpec,
+ ep_rank: int,
+ ep_size: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Repack expert weights from HF format to TorchTitan grouped format.
+
+ Returns (w1, w2, w3) where:
+ w1: [E_local, ffn_hidden_size, hidden_size]
+ w2: [E_local, hidden_size, ffn_hidden_size]
+ w3: [E_local, ffn_hidden_size, hidden_size]
+
+ For fused_3d storage where expert_w3 is None (gate+up fused):
+ Standard HF layout:
+ Source gate_up_proj: [E, 2*ffn_hidden, hidden]
+ Source down_proj: [E, hidden, ffn_hidden]
+
+ Llama4 layout:
+ Source gate_up_proj: [E, hidden, 2*ffn_hidden]
+ Source down_proj: [E, ffn_hidden, hidden]
+
+ In both cases, the returned grouped-expert tensors are normalized to:
+ w1 = gate_proj: [E_local, ffn_hidden, hidden]
+ w3 = up_proj: [E_local, ffn_hidden, hidden]
+ w2 = down_proj: [E_local, hidden, ffn_hidden]
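+
+    Example (illustrative; a hypothetical layer with 8 experts sharded over
+    ep_size=2, so each rank keeps 4 local experts)::
+
+        w1, w2, w3 = repack_expert_weights(moe_layer.experts, spec,
+                                           ep_rank=0, ep_size=2)
+        # rank 0 holds experts 0-3: w1.shape == (4, ffn_hidden, hidden)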
+ """
+ num_local_experts = spec.num_experts // ep_size
+ expert_start = ep_rank * num_local_experts
+ expert_end = expert_start + num_local_experts
+
+ if spec.expert_storage == "fused_3d":
+ return _repack_fused_3d(experts_source, spec, expert_start, expert_end)
+ elif spec.expert_storage == "module_list":
+ return _repack_module_list(experts_source, spec, expert_start, expert_end)
+ else:
+ raise ValueError(f"Unknown expert_storage type: {spec.expert_storage}")
+
+
+def _repack_fused_3d(
+ experts_source: nn.Module,
+ spec: MoELayerSpec,
+ expert_start: int,
+ expert_end: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Repack from fused 3D parameter tensors (transformers 5.0.0+)."""
+ w1_full = getattr(experts_source, spec.expert_w1_name)
+ w2_full = getattr(experts_source, spec.expert_w2_name)
+
+ if isinstance(w1_full, nn.Parameter):
+ w1_full = w1_full.data
+ if isinstance(w2_full, nn.Parameter):
+ w2_full = w2_full.data
+
+ # Slice to local experts
+ w1_local = w1_full[expert_start:expert_end].clone()
+ w2_local = w2_full[expert_start:expert_end].clone()
+
+ if spec.expert_w3_name is None:
+ layout = classify_fused_gate_up_layout(tuple(w1_local.shape), tuple(w2_local.shape))
+ if layout is None:
+ raise ValueError("Unsupported fused expert weight layout for AutoEP repacking: "
+ f"{spec.expert_w1_name}={tuple(w1_local.shape)}, "
+ f"{spec.expert_w2_name}={tuple(w2_local.shape)}")
+
+ ffn_hidden = layout.ffn_hidden_size
+ if layout.layout == "gate_up_first":
+ w1 = w1_local[:, :ffn_hidden, :].contiguous() # [E_local, ffn, hidden]
+ w3 = w1_local[:, ffn_hidden:, :].contiguous() # [E_local, ffn, hidden]
+ w2 = w2_local.contiguous() # [E_local, hidden, ffn]
+ else:
+ w1 = w1_local[:, :, :ffn_hidden].transpose(1, 2).contiguous() # [E_local, ffn, hidden]
+ w3 = w1_local[:, :, ffn_hidden:].transpose(1, 2).contiguous() # [E_local, ffn, hidden]
+ w2 = w2_local.transpose(1, 2).contiguous() # [E_local, hidden, ffn]
+ else:
+ # Separate w1 (gate), w3 (up)
+ w3_full = getattr(experts_source, spec.expert_w3_name)
+ if isinstance(w3_full, nn.Parameter):
+ w3_full = w3_full.data
+ w3_local = w3_full[expert_start:expert_end].clone()
+
+ w1 = w1_local.contiguous() # [E_local, ffn, hidden]
+ w2 = w2_local.contiguous() # [E_local, hidden, ffn]
+ w3 = w3_local.contiguous() # [E_local, ffn, hidden]
+
+ return w1, w2, w3
+
+
+def _repack_module_list(
+ experts_source: nn.Module,
+ spec: MoELayerSpec,
+ expert_start: int,
+ expert_end: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Repack from nn.ModuleList of individual expert modules (legacy transformers)."""
+ assert isinstance(experts_source, nn.ModuleList), \
+ f"Expected nn.ModuleList for module_list storage, got {type(experts_source)}"
+
+ w1_list = []
+ w2_list = []
+ w3_list = []
+
+ for expert_idx in range(expert_start, expert_end):
+ expert = experts_source[expert_idx]
+
+ # Get weight tensors - handle both nn.Linear children and direct attributes
+ w1_param = _get_expert_weight(expert, spec.expert_w1_name)
+ w2_param = _get_expert_weight(expert, spec.expert_w2_name)
+
+ # nn.Linear stores weight as [out_features, in_features]
+ # TorchTitan expects [ffn_hidden, hidden] for w1/w3 and [hidden, ffn_hidden] for w2
+ # nn.Linear.weight is already [out, in] which matches TorchTitan's [ffn, hidden] for w1
+ # No transpose needed - store as-is
+ w1_list.append(w1_param.data.clone())
+ w2_list.append(w2_param.data.clone())
+
+ if spec.expert_w3_name is not None:
+ w3_param = _get_expert_weight(expert, spec.expert_w3_name)
+ w3_list.append(w3_param.data.clone())
+
+ w1 = torch.stack(w1_list) # [E_local, ffn_hidden, hidden]
+ w2 = torch.stack(w2_list) # [E_local, hidden, ffn_hidden]
+
+ if spec.expert_w3_name is not None:
+ w3 = torch.stack(w3_list) # [E_local, ffn_hidden, hidden]
+ else:
+ # If no w3, this is fused gate+up - split w1
+ ffn_hidden = w1.shape[1] // 2
+ w3 = w1[:, ffn_hidden:, :].contiguous()
+ w1 = w1[:, :ffn_hidden, :].contiguous()
+
+ return w1, w2, w3
+
+
+def _get_expert_weight(expert_module: nn.Module, weight_name: str) -> torch.Tensor:
+ """Get expert weight tensor by name, handling both attribute and child module patterns."""
+ # Direct attribute
+ param = getattr(expert_module, weight_name, None)
+ if param is not None:
+ if isinstance(param, nn.Linear):
+ return param.weight
+ if isinstance(param, (nn.Parameter, torch.Tensor)):
+ return param
+
+ # Try as child module name
+ for name, child in expert_module.named_children():
+ if name == weight_name:
+ if isinstance(child, nn.Linear):
+ return child.weight
+ if hasattr(child, 'weight'):
+ return child.weight
+
+ raise ValueError(f"Could not find weight '{weight_name}' in expert module "
+ f"{type(expert_module).__name__}. Available attributes: "
+ f"{[n for n, _ in expert_module.named_parameters(recurse=False)]}")
diff --git a/deepspeed/moe/ep_router.py b/deepspeed/moe/ep_router.py
new file mode 100644
index 000000000000..9fee8ea80207
--- /dev/null
+++ b/deepspeed/moe/ep_router.py
@@ -0,0 +1,180 @@
+# Copyright (c) DeepSpeed Team.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause
+#
+# Portions of this file are derived from TorchTitan.
+# See THIRD_PARTY_NOTICES.md for the BSD-3-Clause notice.
+
+# DeepSpeed Team
+"""
+Token-choice top-K router for expert parallelism.
+
+Ported from TorchTitan's TokenChoiceTopKRouter with adaptations for DeepSpeed.
+This module is self-contained: no imports from deepspeed.module_inject
+or deepspeed.runtime.
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from deepspeed.moe.ep_count import count_tokens_per_expert
+
+
+class TokenChoiceTopKRouter(nn.Module):
+ """Token-choice top-K routing for Mixture of Experts.
+
+ Each token is routed to top-K experts based on router scores.
+ Optionally supports node-limited (group-limited) routing where experts
+ are divided into groups (e.g., by node), and only ``num_limited_groups``
+ groups are considered before selecting top_k experts. This reduces
+ cross-node communication in distributed settings.
+
+ Args:
+ dim (int): Dimension of input tokens.
+ num_experts (int): Number of experts in each MoE layer.
+ num_expert_groups (int | None): Number of expert groups for
+ node-limited routing. If None, standard top-k routing is used.
+ Must be a divisor of num_experts.
+ num_limited_groups (int | None): Number of groups to select in
+ node-limited routing. Required when num_expert_groups is set.
+ top_k (int): Number of experts each token will be routed to.
+ score_func (str): ``"softmax"`` or ``"sigmoid"`` scoring function.
+ route_norm (bool): Whether to normalize routing scores.
+ route_scale (float): Scaling factor applied to routing scores.
+ gate_bias (bool): Whether to include a bias term in the gate linear.
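+
+    Example (minimal sketch with hypothetical sizes; standard top-k routing,
+    no group limiting)::
+
+        router = TokenChoiceTopKRouter(dim=16, num_experts=8,
+                                       num_expert_groups=None,
+                                       num_limited_groups=None, top_k=2,
+                                       score_func="softmax", route_norm=True,
+                                       route_scale=1.0, gate_bias=False)
+        scores, experts, counts = router(torch.randn(10, 16))
+        # scores: (10, 2), experts: (10, 2), counts: (8,)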
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_experts: int,
+ num_expert_groups: int | None,
+ num_limited_groups: int | None,
+ top_k: int,
+ score_func: str,
+ route_norm: bool,
+ route_scale: float,
+ gate_bias: bool,
+ group_score_func: str = "top2_sum",
+ ):
+ super().__init__()
+ self.gate = nn.Linear(dim, num_experts, bias=gate_bias)
+ self.num_experts = num_experts
+ self.num_expert_groups = num_expert_groups
+ self.num_limited_groups = num_limited_groups
+ self.top_k = top_k
+ self.score_func = score_func
+ self.route_norm = route_norm
+ self.route_scale = route_scale
+ self.group_score_func = group_score_func
+
+ # ------------------------------------------------------------------
+ # Node-limited (group-limited) routing
+ # ------------------------------------------------------------------
+
+ def _get_node_limited_routing_scores(
+ self,
+ scores_for_choice: torch.Tensor,
+ ) -> torch.Tensor:
+ """Select ``num_limited_groups`` groups based on group scores and
+ mask out experts in non-selected groups.
+
+ Args:
+ scores_for_choice: Router scores with optional expert_bias,
+ shape ``(T, num_experts)``.
+
+ Returns:
+ Masked scores of the same shape, with non-selected group
+ entries set to ``-inf``.
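+
+        Example (illustrative, num_experts=4, num_expert_groups=2,
+        num_limited_groups=1, group_score_func="top2_sum"): for one token
+        with scores ``[0.1, 0.4, 0.3, 0.3]`` the group scores are
+        ``[0.5, 0.6]``, group 1 is kept, and the masked scores become
+        ``[-inf, -inf, 0.3, 0.3]``.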
+ """
+ if self.num_limited_groups is None:
+ raise ValueError("num_limited_groups must be set when num_expert_groups is set")
+ assert self.num_expert_groups is not None
+ if self.num_limited_groups < 1:
+ raise ValueError(f"num_limited_groups must be >= 1, got {self.num_limited_groups}")
+ if self.num_experts % self.num_expert_groups != 0:
+ raise ValueError(f"num_experts ({self.num_experts}) must be divisible by "
+ f"num_expert_groups ({self.num_expert_groups})")
+
+ experts_per_group = self.num_experts // self.num_expert_groups
+
+ scores_grouped = scores_for_choice.view(-1, self.num_expert_groups, experts_per_group)
+ if self.group_score_func == "max":
+ group_scores = scores_grouped.max(dim=-1).values
+ elif self.group_score_func == "top2_sum":
+ if experts_per_group < 2:
+ raise ValueError(f"experts_per_group ({experts_per_group}) must be >= 2")
+ # DeepSeek-V3 scores each group by the sum of its top-2 experts.
+ top2_scores_in_group, _ = scores_grouped.topk(2, dim=-1)
+ group_scores = top2_scores_in_group.sum(dim=-1)
+ else:
+ raise NotImplementedError(f"Unknown group score function: {self.group_score_func}")
+
+ # Select top groups
+ _, group_idx = torch.topk(group_scores, k=self.num_limited_groups, dim=-1, sorted=False)
+
+ # Build mask: True = masked out (non-selected groups)
+ group_mask = torch.ones_like(group_scores, dtype=torch.bool)
+ group_mask.scatter_(1, group_idx, False)
+
+ scores_for_choice = scores_grouped.masked_fill(group_mask.unsqueeze(-1),
+ float("-inf")).view(-1, self.num_experts)
+
+ return scores_for_choice
+
+ # ------------------------------------------------------------------
+ # Forward
+ # ------------------------------------------------------------------
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ expert_bias: torch.Tensor | None = None,
+ ) -> tuple:
+ """
+ Args:
+ x: Input tensor of shape ``(T, dim)``.
+ expert_bias: Optional bias tensor of shape ``(num_experts,)``
+ used for load balancing.
+
+ Returns:
+ Tuple of:
+ - top_scores ``(T, top_k)``: routing weights for selected experts.
+ - selected_experts ``(T, top_k)``: expert indices per token.
+ - num_tokens_per_expert ``(num_experts,)``: histogram of token counts.
+ """
+ # Gate projection -> (T, num_experts)
+ scores = self.gate(x)
+
+        # Score in float32; low-precision routing scores can cause training-loss spikes
+ if self.score_func == "sigmoid":
+ scores = torch.sigmoid(scores.to(torch.float32))
+ elif self.score_func == "softmax":
+ scores = F.softmax(scores.to(torch.float32), dim=1)
+ else:
+ raise NotImplementedError(f"Unknown score function: {self.score_func}")
+
+ scores_for_choice = (scores if expert_bias is None else scores + expert_bias)
+
+ # Apply node-limited routing if configured
+ if self.num_expert_groups is not None:
+ scores_for_choice = self._get_node_limited_routing_scores(scores_for_choice)
+
+ # Select top-k experts per token
+ _, selected_experts_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)
+
+ # Gather original (unbiased) scores for selected experts
+ top_scores = scores.gather(dim=1, index=selected_experts_indices)
+
+ # Optional normalization
+ if self.route_norm:
+ denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
+ top_scores = top_scores / denominator
+
+ top_scores = top_scores * self.route_scale
+
+ num_tokens_per_expert = count_tokens_per_expert(selected_experts_indices, self.num_experts)
+
+ return top_scores, selected_experts_indices, num_tokens_per_expert
diff --git a/deepspeed/moe/fused_expert_layout.py b/deepspeed/moe/fused_expert_layout.py
new file mode 100644
index 000000000000..4578fb9de5af
--- /dev/null
+++ b/deepspeed/moe/fused_expert_layout.py
@@ -0,0 +1,47 @@
+# Copyright (c) DeepSpeed Team.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+"""Shape helpers for fused gate/up expert tensors."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Literal
+
+FusedLayout = Literal["gate_up_first", "hidden_first"]
+
+
+@dataclass(frozen=True)
+class FusedExpertLayout:
+ layout: FusedLayout
+ hidden_size: int
+ ffn_hidden_size: int
+ needs_transpose: bool
+
+
+def classify_fused_gate_up_layout(
+ w1_shape: tuple[int, ...],
+ w2_shape: tuple[int, ...],
+) -> FusedExpertLayout | None:
+    """Classify fused gate/up expert weights from raw tensor shapes.
+
+    Returns ``None`` when the shapes match neither the standard HF
+    gate-up-first layout nor the Llama4-style hidden-first layout.
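+
+    Example (illustrative shapes; hypothetical hidden=1024, ffn_hidden=4096,
+    8 local experts, standard HF layout)::
+
+        layout = classify_fused_gate_up_layout((8, 8192, 1024), (8, 1024, 4096))
+        # layout.layout == "gate_up_first", layout.ffn_hidden_size == 4096
+    """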
+ if len(w1_shape) != 3 or len(w2_shape) != 3:
+ return None
+
+ if w1_shape[1] % 2 == 0 and w2_shape[1:] == (w1_shape[2], w1_shape[1] // 2):
+ return FusedExpertLayout(
+ layout="gate_up_first",
+ hidden_size=w1_shape[2],
+ ffn_hidden_size=w1_shape[1] // 2,
+ needs_transpose=False,
+ )
+
+ if w1_shape[2] % 2 == 0 and w2_shape[1:] == (w1_shape[2] // 2, w1_shape[1]):
+ return FusedExpertLayout(
+ layout="hidden_first",
+ hidden_size=w1_shape[1],
+ ffn_hidden_size=w1_shape[2] // 2,
+ needs_transpose=True,
+ )
+
+ return None
diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py
index 213b5c659499..747585f6c17f 100644
--- a/deepspeed/runtime/base_optimizer.py
+++ b/deepspeed/runtime/base_optimizer.py
@@ -18,6 +18,21 @@ class DeepSpeedOptimizer(object):
pass
+def _get_universal_checkpoint_ep_info() -> tuple[int, int]:
+ # Universal checkpoints use EP slicing only when an expert group exists.
+ try:
+ from deepspeed.utils import groups
+ expert_groups = groups._get_expert_parallel_group_dict()
+ if not expert_groups:
+ return 0, 1
+ max_ep_name = groups._get_max_expert_size_name()
+ if max_ep_name not in expert_groups:
+ return 0, 1
+ return groups._get_expert_parallel_rank(max_ep_name), groups._get_expert_parallel_world_size(max_ep_name)
+ except (RuntimeError, AttributeError, KeyError):
+ return 0, 1
+
+
class BackwardHookStateManager:
"""Manages backward pass state for ZeRO optimizers.
@@ -314,6 +329,8 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec
tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
else self.mpu.get_tensor_model_parallel_world_size()
+ ep_rank, ep_size = _get_universal_checkpoint_ep_info()
+
for i, (param_group,
loaded_param_group) in enumerate(zip(self.optimizer.param_groups, optim_sd['param_groups'])):
# We have an assumption that all params in the same param_group have the same keys
@@ -324,8 +341,11 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec
for lp in lp_groups[i]:
if lp._hp_mapping is not None:
#print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
- step = lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
- tp_world_size)
+ step = lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]),
+ tp_rank,
+ tp_world_size,
+ ep_rank=ep_rank,
+ ep_size=ep_size)
for key in lp._hp_mapping.get_optim_state_keys():
opt_keys.add(key)
steps.append(step)
diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py
index ec3833cbdcc6..012d977d31ec 100755
--- a/deepspeed/runtime/config.py
+++ b/deepspeed/runtime/config.py
@@ -66,6 +66,7 @@
from ..utils.config import get_timers_config
TENSOR_CORE_ALIGN_SIZE = 8
+EXPERT_PARALLEL = "expert_parallel"
ADAGRAD_OPTIMIZER = 'adagrad'
ADAM_OPTIMIZER = 'adam'
@@ -124,6 +125,14 @@ def __repr__(self):
)
+def get_expert_parallel_config(param_dict):
+ if EXPERT_PARALLEL in param_dict:
+ from deepspeed.module_inject.auto_ep_config import parse_autoep_config
+ return parse_autoep_config(param_dict[EXPERT_PARALLEL])
+ from deepspeed.module_inject.auto_ep_config import AutoEPConfig
+ return AutoEPConfig()
+
+
def get_pld_enabled(param_dict):
if PROGRESSIVE_LAYER_DROP in param_dict.keys():
return get_scalar_param(param_dict[PROGRESSIVE_LAYER_DROP], PLD_ENABLED, PLD_ENABLED_DEFAULT)
@@ -870,6 +879,7 @@ def _initialize_params(self, param_dict):
self.timers_config = get_timers_config(param_dict)
self.tensor_parallel_config = get_tensor_parallel_config(param_dict)
+ self.expert_parallel_config = get_expert_parallel_config(param_dict)
def _batch_assertion(self):
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 1e02ab12baef..90521dee5d8f 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -277,6 +277,7 @@ def __init__(self,
self._do_args_sanity_check(args)
self._configure_with_arguments(args, mpu)
self._do_sanity_check()
+ self._configure_expert_parallel(model)
if self.autotp_size() > 1:
self._configure_tensor_parallel(model, self.tensor_parallel_config())
see_memory_usage("DeepSpeed Engine: After args sanity test", force=self.memory_breakdown())
@@ -498,6 +499,56 @@ def _optimized_linear_offload_setup(self):
else:
p.ds_offload = False
+ def _configure_expert_parallel(self, model):
+ """Initialize AutoEP: detect MoE layers, create EP groups, replace with EP-enabled layers."""
+ autoep_config = self._config.expert_parallel_config
+ if autoep_config is None or not autoep_config.enabled:
+ return
+
+ from deepspeed.module_inject.auto_ep import AutoEP
+ from deepspeed.module_inject.auto_ep_config import validate_autoep_config, validate_autoep_post_detection
+
+ ep_size = autoep_config.autoep_size
+ tp_size = self.autotp_size()
+ sp_size = self._autoep_sequence_parallel_world_size()
+ pp_size = 1
+ if self.mpu is not None:
+ from deepspeed.utils.bwc import bwc_pipeline_parallel_world_size
+ pp_size = bwc_pipeline_parallel_world_size(self.mpu)
+
+ world_size = dist.get_world_size()
+ validate_autoep_config(autoep_config, world_size, pp_size, tp_size, sp_size)
+
+ # Create EP/EDP process groups
+ mp_size = max(tp_size, sp_size, 1)
+ mp_mode = "tp" if tp_size > 1 else "sp"
+ groups._create_expert_and_data_parallel(
+ expert_parallel_size_=ep_size,
+ mp_size=mp_size,
+ pp_size=pp_size,
+ mp_mode=mp_mode,
+ use_data_before_expert_parallel_=self._config.use_data_before_expert_parallel_,
+ )
+
+ # Derive EP rank
+ ep_group_name = f"ep_size_{ep_size}"
+ ep_group = groups._get_expert_parallel_group(ep_group_name)
+ ep_rank = dist.get_rank(group=ep_group)
+
+ # Detect and replace MoE layers
+ auto_ep = AutoEP(model, autoep_config)
+ specs = auto_ep.ep_parser()
+
+ if specs:
+ validate_autoep_post_detection(autoep_config, specs)
+ auto_ep.replace_moe_layers(specs, ep_size=ep_size, ep_rank=ep_rank)
+ logger.info(f"AutoEP: replaced {len(specs)} MoE layer(s) with ep_size={ep_size}")
+
+ def _autoep_sequence_parallel_world_size(self):
+ if self.mpu is not None and hasattr(self.mpu, 'get_sequence_parallel_world_size'):
+ return self.mpu.get_sequence_parallel_world_size()
+ return groups._get_sequence_parallel_world_size()
+
def _configure_tensor_parallel(self, model, tp_config):
self._configure_tensor_parallel_states(model)
configure_tensor_parallel_runtime(tp_config)
@@ -1469,10 +1520,17 @@ def _configure_distributed_model(self, model):
self.module.to(self.device)
# MoE related initialization
+ try:
+ from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer as _AutoEPMoELayer
+ except ImportError:
+ _AutoEPMoELayer = None
for _, module in self.module.named_modules():
if isinstance(module, MoE):
self.has_moe_layers = True
self.num_experts.append(module.num_experts)
+ elif _AutoEPMoELayer is not None and isinstance(module, _AutoEPMoELayer):
+ self.has_moe_layers = True
+ self.num_experts.append(module.num_experts)
if self.has_moe_layers:
for _, module in self.module.named_modules():
@@ -2585,9 +2643,9 @@ def scale(self, loss):
# Apply loss scaler based on optimizer type
scaled_loss = loss
if isinstance(self.optimizer, ZeROOptimizer):
- scaled_loss = self.optimizer.scale_if_loss(loss)
+ scaled_loss = self.optimizer.scale_if_loss(scaled_loss)
elif self.torch_autocast_z0_gradscaler:
- scaled_loss = self.torch_autocast_z0_gradscaler.scale(loss)
+ scaled_loss = self.torch_autocast_z0_gradscaler.scale(scaled_loss)
# Mark that scale() was called for validation in backward hook
self._manual_backward_expected = True
@@ -3248,8 +3306,20 @@ def load_moe_state_dict(checkpoint_path,
model=None,
mpu=None,
num_experts=1,
- checkpoint_engine=TorchCheckpointEngine()):
+ checkpoint_engine=TorchCheckpointEngine(),
+ autoep_layers=None):
+ try:
+ from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer as _AutoEPMoELayer
+ except ImportError:
+ _AutoEPMoELayer = None
+
+ has_autoep_layers = _AutoEPMoELayer is not None and model is not None and any(
+ isinstance(m, _AutoEPMoELayer) for _, m in model.named_modules())
+
if old_moe_load:
+ if has_autoep_layers:
+ raise RuntimeError("Legacy checkpoint format (old_moe_load) is incompatible with AutoEP layers. "
+ "Use Universal Checkpointing to convert the checkpoint first.")
expp_rank = groups._get_expert_data_parallel_rank(groups._get_max_expert_size_name())
num_local_experts = max(num_experts) // groups._get_expert_parallel_world_size(
@@ -3274,6 +3344,30 @@ def load_moe_state_dict(checkpoint_path,
state_dict.update(expert_state_dict)
else:
+ # Validate AutoEP metadata if present
+ if autoep_layers is not None:
+ if not isinstance(autoep_layers, list):
+ raise RuntimeError(
+ f"ds_autoep_layers metadata is malformed: expected list, got {type(autoep_layers).__name__}")
+ seen_ids = set()
+ required_fields = {
+ 'moe_layer_id', 'module_path', 'num_experts', 'num_local_experts', 'ep_size', 'expert_key_prefix'
+ }
+ for entry in autoep_layers:
+ if not isinstance(entry, dict):
+ raise RuntimeError(
+ f"ds_autoep_layers entry is malformed: expected dict, got {type(entry).__name__}")
+ missing = required_fields - entry.keys()
+ if missing:
+ raise RuntimeError(f"ds_autoep_layers entry is invalid: missing fields {sorted(missing)}")
+ lid = entry['moe_layer_id']
+ if lid in seen_ids:
+ raise RuntimeError(f"ds_autoep_layers metadata has duplicate moe_layer_id: {lid}")
+ seen_ids.add(lid)
+ elif has_autoep_layers:
+ logger.warning("Checkpoint does not contain ds_autoep_layers metadata. "
+ "Loading AutoEP expert weights using best-effort module detection.")
+
moe_layer_id = 0
for n_module, module in model.named_modules():
if isinstance(module, MoE): # and deepspeed.comm.get_rank() == 0:
@@ -3296,6 +3390,43 @@ def load_moe_state_dict(checkpoint_path,
state_dict.update(expert_state_dict)
moe_layer_id += 1
+ elif _AutoEPMoELayer is not None and isinstance(module, _AutoEPMoELayer):
+ group_name = module.ep_group_name
+ num_local_experts = module.num_local_experts
+ expp_rank = groups._get_expert_parallel_rank(group_name)
+ module_prefix = f"{n_module}." if n_module else ""
+
+ # Collect per-expert tensors to stack
+ stacked = {wname: [] for wname in ('w1', 'w2', 'w3')}
+
+ for local_expert_id in range(num_local_experts):
+ global_expert_id = expp_rank * num_local_experts + local_expert_id
+ expert_ckpt_path = DeepSpeedEngine._get_expert_ckpt_name(checkpoint_path, moe_layer_id,
+ global_expert_id, tag, mpu)
+ if not os.path.exists(expert_ckpt_path):
+ raise FileNotFoundError(f"Expert checkpoint file not found: {expert_ckpt_path}. "
+ f"Expected layer_{moe_layer_id} expert_{global_expert_id}.")
+ expert_sd = checkpoint_engine.load(expert_ckpt_path, map_location=torch.device('cpu'))
+
+ for wname in ('w1', 'w2', 'w3'):
+ fused_key = f"{module_prefix}experts.{wname}"
+ expert_key = f"{fused_key}.{global_expert_id}"
+ if expert_key not in expert_sd:
+ raise RuntimeError(f"Expert checkpoint file is corrupt: key '{expert_key}' not found "
+ f"in {expert_ckpt_path}")
+ tensor = expert_sd[expert_key]
+ if tensor.dim() != 2:
+ raise RuntimeError(f"Expert checkpoint file is corrupt: expected 2D tensor for "
+ f"'{expert_key}', got {tensor.dim()}D in {expert_ckpt_path}")
+ stacked[wname].append(tensor)
+
+ # Stack back to fused [E_local, ...] format
+ for wname in ('w1', 'w2', 'w3'):
+ fused_key = f"{module_prefix}experts.{wname}"
+ state_dict[fused_key] = torch.stack(stacked[wname], dim=0)
+
+ moe_layer_id += 1
+
def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None, fetch_z3_params=False):
if fetch_z3_params:
params_to_fetch = [
@@ -3526,11 +3657,17 @@ def _load_checkpoint(self,
# Pipeline parallelism uses this to load its own checkpoint files.
self._curr_ckpt_path = os.path.join(load_dir, tag)
- if self.has_moe_layers:
+ # Universal Checkpoint restores parameters from the zero/ layout, so
+ # do not require regular MoE expert checkpoint files in that path.
+ if self.has_moe_layers and not self.load_universal_checkpoint():
# print(checkpoint.keys())
old_moe_load = False
if not isinstance(checkpoint['num_experts'], list):
old_moe_load = True
+ from deepspeed.checkpoint.constants import AUTOEP_LAYERS_KEY, AUTOEP_LAYERS_KEY_LEGACY
+ autoep_layers = checkpoint.get(AUTOEP_LAYERS_KEY)
+ if autoep_layers is None:
+ autoep_layers = checkpoint.get(AUTOEP_LAYERS_KEY_LEGACY)
DeepSpeedEngine.load_moe_state_dict(load_dir,
tag,
state_dict=checkpoint['module'],
@@ -3538,7 +3675,8 @@ def _load_checkpoint(self,
model=self.module,
mpu=self.mpu,
num_experts=self.num_experts,
- checkpoint_engine=self.checkpoint_engine)
+ checkpoint_engine=self.checkpoint_engine,
+ autoep_layers=autoep_layers)
if not self.load_universal_checkpoint():
self.load_module_state_dict(checkpoint=checkpoint,
strict=load_module_strict,
@@ -3864,23 +4002,53 @@ def _commit_decoupled_checkpoint(self):
dist.barrier()
def _get_non_moe_state_dict(self, full_state_dict):
+ """Remove expert-param keys from state dict, keeping all non-expert params.
+
+ Handles both native MoE (deepspeed_moe.experts.*) and AutoEP (experts.w1/w2/w3).
+ Preserves: router weights, shared_experts, any legacy/manually-built
+ expert_bias keys, all non-MoE params.
"""
- Get the state dict of the non-moe layers
- """
- for key in list(full_state_dict.keys()):
- if 'expert' in key and 'moe.gate.wg.weight' not in key:
- full_state_dict.pop(key)
+ try:
+ from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer as _AutoEPMoELayer
+ except ImportError:
+ _AutoEPMoELayer = None
+
+ expert_param_keys = set()
+
+ for n_module, module in self.module.named_modules():
+ module_prefix = f"{n_module}." if n_module else ""
+ if isinstance(module, MoE):
+ # Native MoE: remove keys with 'expert' except gate, scoped to this module
+ for key in full_state_dict.keys():
+ if key.startswith(module_prefix) and 'expert' in key and 'moe.gate.wg.weight' not in key:
+ expert_param_keys.add(key)
+ elif _AutoEPMoELayer is not None and isinstance(module, _AutoEPMoELayer):
+ # AutoEP: remove only the fused expert weight keys (w1, w2, w3)
+ experts_prefix = f"{module_prefix}experts."
+ for key in full_state_dict.keys():
+ if key.startswith(experts_prefix) and key[len(experts_prefix):] in ('w1', 'w2', 'w3'):
+ expert_param_keys.add(key)
+
+ for key in expert_param_keys:
+ full_state_dict.pop(key)
return full_state_dict
def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parameters=False):
save_path = self._get_ckpt_name(save_dir, tag)
+ try:
+ from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer as _AutoEPMoELayer
+ except ImportError:
+ _AutoEPMoELayer = None
+
# A hack to save the checkpointing directory. Pipeline parallelism overrides
# module_state_dict() and uses this path to save the model. module_state_dict()
# then instead just returns None.
# Using layer_#_export_# to save the model's expert state_dict
+ autoep_layer_info = []
+ autoep_group_names = set()
moe_layer_id = 0
for n_module, module in self.module.named_modules():
if isinstance(module, MoE): # and deepspeed.comm.get_rank() == 0:
@@ -3932,6 +4100,51 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa
self.checkpoint_engine.save(saveable_state_dict, moe_save_path)
moe_layer_id += 1
+ elif _AutoEPMoELayer is not None and isinstance(module, _AutoEPMoELayer):
+ group_name = module.ep_group_name
+ num_local_experts = module.num_local_experts
+ expp_rank = groups._get_expert_parallel_rank(group_name)
+ exp_dp_rank = groups._get_expert_data_parallel_rank(group_name)
+ module_prefix = f"{n_module}." if n_module else ""
+
+ # Collect metadata on ALL ranks (before writer guard)
+ autoep_layer_info.append({
+ 'moe_layer_id': moe_layer_id,
+ 'module_path': n_module,
+ 'num_experts': module.num_experts,
+ 'num_local_experts': num_local_experts,
+ 'ep_size': module.ep_size,
+ 'expert_key_prefix': f"{module_prefix}experts",
+ })
+ autoep_group_names.add(group_name)
+ if len(autoep_group_names) > 1:
+ raise RuntimeError(f"AutoEP checkpointing requires a single EP group size, but found "
+ f"multiple groups: {sorted(autoep_group_names)}. "
+ f"All AutoEPMoELayer instances must use the same ep_size.")
+
+ # Gate file writes behind writer guard
+ if not self.checkpoint_engine.is_data_parallel_writer(exp_dp_rank):
+ moe_layer_id += 1
+ continue
+
+ # Slice fused 3D tensors into per-expert state dicts
+ for local_expert_id in range(num_local_experts):
+ global_expert_id = expp_rank * num_local_experts + local_expert_id
+ expert_state_dict = {}
+ for wname in ('w1', 'w2', 'w3'):
+ fused_key = f"{module_prefix}experts.{wname}"
+ param = getattr(module.experts, wname)
+ expert_state_dict[f"{fused_key}.{global_expert_id}"] = (
+ param[local_expert_id].clone().detach())
+
+ moe_save_path = self._get_expert_ckpt_name(save_dir, moe_layer_id, global_expert_id, tag, self.mpu)
+ saveable = expert_state_dict
+ if self.checkpoint_engine.preserves_storage_sharing():
+ saveable = clone_tensors_for_torch_save(expert_state_dict)
+ self.checkpoint_engine.save(saveable, moe_save_path)
+
+ moe_layer_id += 1
+
self._curr_ckpt_path = os.path.join(save_dir, tag)
largest_group_name = groups._get_max_expert_size_name()
@@ -3988,8 +4201,16 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa
'mp_world_size':
self.mp_world_size,
'num_experts':
- self.num_experts
+ self.num_experts,
+ 'ds_autoep_layers':
+ autoep_layer_info if autoep_layer_info else None,
}
+ # Check for reserved-key collisions with client_state
+ reserved_keys = {'ds_autoep_layers', 'autoep_layers'}
+ collisions = reserved_keys.intersection(client_state.keys())
+ if collisions:
+ raise KeyError(f"client_state contains reserved checkpoint keys: {sorted(collisions)}. "
+ f"These keys are used internally by DeepSpeed for AutoEP metadata.")
state.update(client_state)
logger.info(f'Saving model checkpoint: {save_path}')
saveable_state_dict = state
diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py
index 2392683db81d..f39f73d20281 100755
--- a/deepspeed/runtime/utils.py
+++ b/deepspeed/runtime/utils.py
@@ -1121,7 +1121,7 @@ def get_norm_with_moe_layers(non_expert_norm, mpu, expert_tensors, norm_type=2):
"""
def to_tensor(v):
- return get_accelerator().FloatTensor(float(v)).detach()
+ return get_accelerator().FloatTensor([float(v)]).detach()
group_norms = [non_expert_norm]
for exp_name, tensors in expert_tensors.items():
diff --git a/deepspeed/utils/groups.py b/deepspeed/utils/groups.py
index a6f0a7228977..d912625c544b 100644
--- a/deepspeed/utils/groups.py
+++ b/deepspeed/utils/groups.py
@@ -237,25 +237,47 @@ def _create_model_parallel(model_parallel_size_):
return _DATA_PARALLEL_GROUP, _MODEL_PARALLEL_GROUP
-def _create_expert_and_data_parallel(expert_parallel_size_, use_data_before_expert_parallel_=False):
- """
- Create expert and data parallel groups.
-
- Note: Caller of this function is responsible to check if the groups already exist.
+def _create_expert_and_data_parallel(expert_parallel_size_,
+ mp_size=None,
+ pp_size=None,
+ mp_mode="tp",
+ use_data_before_expert_parallel_=False):
+ """Create expert and data parallel groups.
+
+ When mp_size is None or 1: legacy consecutive ordering (backward compatible).
+ When mp_size > 1 and mp_mode=="tp": TP-strided rank ordering.
+ When mp_size > 1 and mp_mode=="sp": consecutive rank ordering.
+
+ Note: Caller of this function is responsible to check if the groups already exist.
+
+ Example - E + D parallel (legacy path)
+ world_size = 16
+ expert_parallel_size = 2 # number of experts in same group
+ expert_data_parallel_group = [0,2,4,6,8,10,12,14], [1,3,5,7,9,11,13,15] - all reduce is only on MoE params
+ expert_parallel_group = [0, 1], [2,3], [4,5], [6,7], [8,9] - no all reduce, but all to all
+ data_parallel_group = [0,1,...,15] - all reduce is only on non-MoE
- Example - E + D parallel
- world_size = 16
- expert_parallel_size = 2 # number of experts in same group
- expert_data_parallel_group = [0,2,4,6,8,10,12,14], [1,3,5,7,9,11,13,15] - all reduce is only on MoE params
- expert_parallel_group = [0, 1], [2,3], [4,5], [6,7], [8,9] - no all reduce, but all to all
- data_parallel_group = [0,1,...,15] - all reduce is only on non-MoE
- use_data_before_expert_parallel_ (bool): Use the D + E instead of E + D topology
+ Args:
+ expert_parallel_size_ (int): Expert parallel group size.
+ mp_size (int, optional): Model parallel size (TP or SP). None treated as 1.
+ pp_size (int, optional): Pipeline parallel size. None falls back to mpu.
+ mp_mode (str): "tp" for TP-strided ordering, "sp" for consecutive ordering.
+ use_data_before_expert_parallel_ (bool): Use the D + E instead of E + D topology.
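+
+    Example - TP-strided ordering (mp_mode="tp", new path)
+        world_size = 8, mp_size = 2 (TP), expert_parallel_size = 2
+        ordered_stage_ranks = [0, 2, 4, 6, 1, 3, 5, 7]
+        expert_parallel_group = [0,2], [4,6], [1,3], [5,7] - consecutive chunks of the ordering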
"""
assert dist.is_initialized()
+ # Resolve parameters for backward compat
+ effective_mp_size = 1 if mp_size is None else mp_size
+
log_dist(f'Creating expert and data parallel groups with size {expert_parallel_size_}', ranks=[0])
world_size = dist.get_world_size()
- pp_world_size = 1 if mpu is None else bwc_pipeline_parallel_world_size(mpu)
+
+ # Resolve pp_size
+ if pp_size is not None:
+ pp_world_size = pp_size
+ else:
+ pp_world_size = 1 if mpu is None else bwc_pipeline_parallel_world_size(mpu)
+
rank = dist.get_rank()
pp_stride = world_size // pp_world_size
@@ -263,37 +285,49 @@ def _create_expert_and_data_parallel(expert_parallel_size_, use_data_before_expe
group_name = f"ep_size_{expert_parallel_size_}"
- # Build the expert data parallel groups.
global _EXPERT_DATA_PARALLEL_GROUP
global _EXPERT_DATA_PARALLEL_GROUP_RANKS
-
- ep_stride = pp_stride // expert_parallel_size_
-
- # Only create group if it does not already exist
- if group_name not in _EXPERT_DATA_PARALLEL_GROUP:
- for pp_stage_start in range(0, world_size, pp_stride):
- for i in range(expert_parallel_size_):
- if use_data_before_expert_parallel_:
- ranks = range(pp_stage_start + i * ep_stride, pp_stage_start + (i + 1) * ep_stride)
- else:
- ranks = range(pp_stage_start + i, pp_stage_start + pp_stride, expert_parallel_size_)
- group = dist.new_group(ranks)
- log_dist(f'Creating expert data parallel process group named {group_name} with ranks: {list(ranks)}',
- [0])
- if rank in ranks:
- _EXPERT_DATA_PARALLEL_GROUP[group_name] = group
- _EXPERT_DATA_PARALLEL_GROUP_RANKS[group_name] = ranks
-
- # Build the expert parallel groups.
global _EXPERT_PARALLEL_GROUP
global _EXPERT_PARALLEL_GROUP_RANKS
- # Only create group if it does not already exist
- if group_name not in _EXPERT_PARALLEL_GROUP:
- if use_data_before_expert_parallel_:
+ # Legacy path: mp_size <= 1 (preserves exact original behavior)
+ if effective_mp_size <= 1:
+ ep_stride = pp_stride // expert_parallel_size_
+
+ # Build the expert data parallel groups.
+ # Only create group if it does not already exist
+ if group_name not in _EXPERT_DATA_PARALLEL_GROUP:
for pp_stage_start in range(0, world_size, pp_stride):
- for i in range(ep_stride):
- ranks = range(pp_stage_start + i, pp_stage_start + pp_stride, ep_stride)
+ for i in range(expert_parallel_size_):
+ if use_data_before_expert_parallel_:
+ ranks = range(pp_stage_start + i * ep_stride, pp_stage_start + (i + 1) * ep_stride)
+ else:
+ ranks = range(pp_stage_start + i, pp_stage_start + pp_stride, expert_parallel_size_)
+ group = dist.new_group(ranks)
+ log_dist(
+ f'Creating expert data parallel process group named {group_name} with ranks: {list(ranks)}',
+ [0])
+ if rank in ranks:
+ _EXPERT_DATA_PARALLEL_GROUP[group_name] = group
+ _EXPERT_DATA_PARALLEL_GROUP_RANKS[group_name] = ranks
+
+ # Build the expert parallel groups.
+ # Only create group if it does not already exist
+ if group_name not in _EXPERT_PARALLEL_GROUP:
+ if use_data_before_expert_parallel_:
+ for pp_stage_start in range(0, world_size, pp_stride):
+ for i in range(ep_stride):
+ ranks = range(pp_stage_start + i, pp_stage_start + pp_stride, ep_stride)
+ group = dist.new_group(ranks)
+ log_dist(
+ f'creating expert parallel process group named {group_name} '
+ f'with ranks: {list(ranks)}', [0])
+ if rank in ranks:
+ _EXPERT_PARALLEL_GROUP[group_name] = group
+ _EXPERT_PARALLEL_GROUP_RANKS[group_name] = ranks
+ else:
+ for i in range(world_size // expert_parallel_size_):
+ ranks = range(i * expert_parallel_size_, (i + 1) * expert_parallel_size_)
group = dist.new_group(ranks)
log_dist(
f'creating expert parallel process group named {group_name} '
@@ -301,15 +335,51 @@ def _create_expert_and_data_parallel(expert_parallel_size_, use_data_before_expe
if rank in ranks:
_EXPERT_PARALLEL_GROUP[group_name] = group
_EXPERT_PARALLEL_GROUP_RANKS[group_name] = ranks
+ return
+
+ # New path: mp_size > 1
+ if use_data_before_expert_parallel_:
+ raise NotImplementedError("use_data_before_expert_parallel_ is not supported with mp_size > 1")
+
+ if group_name in _EXPERT_PARALLEL_GROUP:
+ return # Already created
+
+ for pp_stage_start in range(0, world_size, pp_stride):
+ stage_ranks = list(range(pp_stage_start, pp_stage_start + pp_stride))
+
+ # Build ordered_stage_ranks based on mp_mode
+ if mp_mode == "tp" and effective_mp_size > 1:
+ # TP-strided: group by TP, then interleave DP lanes
+ num_tp_groups = len(stage_ranks) // effective_mp_size
+ ordered = []
+ for dp_lane in range(effective_mp_size):
+ for tp_group_idx in range(num_tp_groups):
+ ordered.append(stage_ranks[tp_group_idx * effective_mp_size + dp_lane])
+ ordered_stage_ranks = ordered
else:
- for i in range(world_size // expert_parallel_size_):
- ranks = range(i * expert_parallel_size_, (i + 1) * expert_parallel_size_)
- group = dist.new_group(ranks)
- log_dist(f'creating expert parallel process group named {group_name} '
- f'with ranks: {list(ranks)}', [0])
- if rank in ranks:
- _EXPERT_PARALLEL_GROUP[group_name] = group
- _EXPERT_PARALLEL_GROUP_RANKS[group_name] = ranks
+ # SP or no-MP: consecutive
+ ordered_stage_ranks = stage_ranks
+
+ # Create EP groups by chunking ordered ranks
+ num_ep_groups = len(ordered_stage_ranks) // expert_parallel_size_
+ ep_groups_list = []
+ for g in range(num_ep_groups):
+ ep_ranks = ordered_stage_ranks[g * expert_parallel_size_:(g + 1) * expert_parallel_size_]
+ ep_groups_list.append(ep_ranks)
+ group = dist.new_group(ep_ranks)
+ log_dist(f'creating expert parallel process group named {group_name} with ranks: {ep_ranks}', [0])
+ if rank in ep_ranks:
+ _EXPERT_PARALLEL_GROUP[group_name] = group
+ _EXPERT_PARALLEL_GROUP_RANKS[group_name] = ep_ranks
+
+ # Create EDP groups: same position across EP groups
+ for pos in range(expert_parallel_size_):
+ edp_ranks = [ep_groups_list[g][pos] for g in range(num_ep_groups)]
+ group = dist.new_group(edp_ranks)
+ log_dist(f'Creating expert data parallel process group named {group_name} with ranks: {edp_ranks}', [0])
+ if rank in edp_ranks:
+ _EXPERT_DATA_PARALLEL_GROUP[group_name] = group
+ _EXPERT_DATA_PARALLEL_GROUP_RANKS[group_name] = edp_ranks
def _get_expert_parallel_ranks(world_size,
diff --git a/docs/_data/navigation.yml b/docs/_data/navigation.yml
index fe72f8890bab..360f2ac28a44 100755
--- a/docs/_data/navigation.yml
+++ b/docs/_data/navigation.yml
@@ -93,7 +93,7 @@ lnav:
url: /tutorials/lrrt/
- title: 'Megatron-LM GPT2'
url: /tutorials/megatron/
- - title: 'Mixture-of-Experts (MoE)'
+ - title: 'Mixture of Experts (DeepSpeed MoE)'
url: /tutorials/mixture-of-experts/
- title: 'MoE for NLG'
url: /tutorials/mixture-of-experts-nlg/
diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md
index 543339fad43f..b1a1fe8d0bab 100755
--- a/docs/_pages/config-json.md
+++ b/docs/_pages/config-json.md
@@ -860,6 +860,219 @@ When a HuggingFace model provides a built-in `tp_plan` (via `model.config.base_m
| --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Unused parameters in modules may be unexpected in static networks, but could be normal in dynamic networks. This controls whether or not training should terminate with an error message when unused parameters are detected. This is set to `True` by default, which means unused parameters are ignored and training continues. Now is just used in stage 2. | `True` |
+### Expert Parallel (AutoEP)
+Configure AutoEP expert parallelism for MoE models. AutoEP automatically detects MoE layers in HuggingFace models and replaces them with EP-enabled versions built on TorchTitan's grouped GEMM kernels, with no model code changes required. It supports ZeRO stages 0, 1, and 2 (stage 3 is not supported).
+```json
+ "expert_parallel": {
+ "enabled": true,
+ "autoep_size": 4,
+ "preset_model": "mixtral"
+ }
+```
+**expert_parallel**: [dictionary]
+
+| Description | Default |
+| ------------------------------------------------------------------------------------------ | ------- |
+| Enable AutoEP expert parallelism and configure MoE layer detection and replacement. | `{}` |
+
+***enabled***: [boolean]
+
+| Description | Default |
+| --------------------------------------------------------------------------- | ------- |
+| Enable AutoEP. When `false`, all other expert_parallel settings are ignored. | `false` |
+
+***autoep_size***: [integer]
+
+| Description | Default |
+| -------------------------------------------------------------------------------------------------- | ------- |
+| Expert-parallel degree (number of ranks sharing expert computation). Must divide `world_size / pp_size`. `1` = all experts local (no AllToAll), useful for testing. | `1` |
+
+***preset_model***: [string]
+
+| Description | Default |
+| -------------------------------------------------------------------------------------------------------------------------------------- | ------- |
+| Built-in model preset for MoE detection: `mixtral`, `qwen3_moe`, `qwen3_5_moe`, `deepseek_v2`, `deepseek_v3`, `llama4`. Determines router, expert, and weight naming patterns. | `null` |
+
+Built-in AutoEP presets describe DeepSpeed's router/expert/weight-pattern support for a model family.
+Running a HuggingFace model also requires the installed Transformers package to expose the corresponding
+config/model classes, `model.config.model_type` value, and fused expert layout. The minimum Transformers
+versions below were established by the tiny-model HuggingFace smoke tests that cover this AutoEP surface:
+
+| Preset | Minimum Transformers version | Notes |
+| ------ | ---------------------------- | ----- |
+| `mixtral` | `5.0.0` | |
+| `qwen3_moe` | `5.0.0` | Also covers Qwen2-MoE when the installed Transformers build uses the validated fused expert layout. Qwen3-MoE classes appear in `4.51.3`, but the tested `4.x` builds do not match the validated AutoEP layout. |
+| `qwen3_5_moe` | `5.2.0` | Requires the Qwen3.5 text-backbone `qwen3_5_moe_text` model type. For performance on Qwen3.5's Gated DeltaNet layers, install optimized kernels; see the [Hugging Face Transformers kernel loading docs](https://huggingface.co/docs/transformers/kernel_doc/loading_kernels) and the [Qwen FlashQLA blog](https://qwen.ai/blog?id=flashqla). |
+| `deepseek_v2` | `5.0.0` | `load_balance_coeff` / expert-bias auxiliary-loss-free load balancing is not currently supported; non-null values are rejected. |
+| `deepseek_v3` | `5.0.0` | `load_balance_coeff` / expert-bias auxiliary-loss-free load balancing is not currently supported; non-null values are rejected. |
+| `llama4` | `5.0.0` | |
+
+***use_grouped_mm***: [boolean]
+
+| Description | Default |
+| ---------------------------------------------------------------------------------------------- | ------- |
+| Use `torch._grouped_mm` for fused grouped GEMM. Raises `RuntimeError` at `GroupedExperts` construction time when `torch._grouped_mm` is unavailable; set `use_grouped_mm=false` to use the sequential for-loop. | `true` |
+
+***moe_layer_pattern***: [string]
+
+| Description | Default |
+| ------------------------------------------------------------------------------------------------------------- | ------- |
+| Regex pattern matching MoE module names (e.g., `"model\\.layers\\.\\d+\\.mlp"`). When set, uses the custom preset path instead of auto-detecting from `model_type`. | `null` |
+
+***router_pattern***: [string]
+
+| Description | Default |
+| -------------------------------------------------------------------------------------------- | ------- |
+| Direct child attribute name for the router/gate module (e.g., `"gate"`, `"router"`). Not a regex. | `null` |
+
+***expert_pattern***: [string]
+
+| Description | Default |
+| ------------------------------------------------------------------------------------------- | ------- |
+| Direct child attribute name for the experts module (e.g., `"experts"`). Not a regex. | `null` |
+
+***score_func***: [string]
+
+| Description | Default |
+| ------------------------------------------------------------------------------------------------------------------------ | -------- |
+| Router scoring function: `"softmax"`, `"sigmoid"`, or `"auto"` (detect from `model.config.scoring_func` or use preset). | `"auto"` |
+
+***score_apply***: [string]
+
+| Description | Default |
+| -------------------------------------------------------------------------------------------------------------- | -------- |
+| When to apply router scores: `"pre"` (before experts), `"post"` (during combine), or `"auto"` (from preset). | `"auto"` |
+
+***route_norm***: [boolean]
+
+| Description | Default |
+| --------------------------------------------------------------------------------------------------------------- | ------- |
+| Renormalize top-k router scores. `null` = auto-detect from `model.config.norm_topk_prob` or use preset default. | `null` |
+
+***route_scale***: [float]
+
+| Description | Default |
+| -------------------------------------------------------- | ------- |
+| Scale factor applied to router scores after computation. | `1.0` |
+
+***top_k***: [integer|string]
+
+| Description | Default |
+| --------------------------------------------------------------------------------------------------------------------------------------------------- | -------- |
+| Number of experts each token is routed to. An explicit integer overrides `top_k_attr` lookup. `"auto"` = read from `model.config` using `top_k_attr`. | `"auto"` |
+
+***routed_scaling_factor***: [float|string]
+
+| Description | Default |
+| ---------------------------------------------------------------------------------------------- | -------- |
+| Scaling factor for routed expert outputs. `"auto"` = detect from `model.config` if available. | `"auto"` |
+
+***num_expert_groups***: [integer]
+
+| Description | Default |
+| -------------------------------------------------------------------------- | ------- |
+| Number of expert groups for group-limited routing (DeepSeek-V3 style). | `null` |
+
+***num_limited_groups***: [integer]
+
+| Description | Default |
+| -------------------------------------------------------------------------------------------------- | ------- |
+| Number of groups to select from in group-limited routing. Must be <= `num_expert_groups` when set. | `null` |
+
+***load_balance_coeff***: [null]
+
+| Description | Default |
+| ---------------------------------------------------------------------------------------------------- | ------- |
+| Reserved for future auxiliary-loss-free load balancing via `expert_bias`. Currently unsupported: must be unset or `null`; any other value is rejected. | `null` |
+
+***expert_w1***: [string]
+
+| Description | Default |
+| -------------------------------------------------------------------------------------------------------- | ------- |
+| Expert weight name for gate (or fused gate+up) projection (e.g., `"gate_up_proj"`, `"w1"`). `null` = use preset default. | `null` |
+
+***expert_w2***: [string]
+
+| Description | Default |
+| -------------------------------------------------------------------------------------------- | ------- |
+| Expert weight name for down projection (e.g., `"down_proj"`, `"w2"`). `null` = use preset default. | `null` |
+
+***expert_w3***: [string|null]
+
+| Description | Default |
+| -------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------- |
+| Expert weight name for up projection (separate from gate). Three states: key absent = use preset default; `null` = fused gate+up (no separate w3); string = custom weight name. | absent (preset default) |
+
+***num_experts_attr***: [string]
+
+| Description | Default |
+| -------------------------------------------------------------------------------------------------------- | ------- |
+| Name of `model.config` attribute for number of experts (e.g., `"num_local_experts"`). `null` = use preset default. | `null` |
+
+***top_k_attr***: [string]
+
+| Description | Default |
+| -------------------------------------------------------------------------------------------------------------------------------- | ------- |
+| Name of `model.config` attribute for top-k value (e.g., `"num_experts_per_tok"`). `null` = use preset default. If `top_k` is explicitly set as an integer, `top_k_attr` is ignored. | `null` |
+
+***has_shared_experts***: [boolean]
+
+| Description | Default |
+| ---------------------------------------------------------------------------------------------------------- | ------- |
+| Whether the MoE layer has shared (non-routed) experts. `null` = auto-detect from preset. Must be paired with `shared_experts_pattern`. | `null` |
+
+***shared_experts_pattern***: [string]
+
+| Description | Default |
+| -------------------------------------------------------------------------------------------------------- | ------- |
+| Direct child attribute name for shared experts (e.g., `"shared_expert"`). `null` = use preset default. | `null` |
+
+#### Custom Model Example
+
+For a model with non-standard naming conventions that is not covered by built-in presets:
+
+```json
+{
+ "expert_parallel": {
+ "enabled": true,
+ "autoep_size": 4,
+ "moe_layer_pattern": "model\\.layers\\.\\d+\\.moe",
+ "router_pattern": "router",
+ "expert_pattern": "mlp_experts",
+ "expert_w1": "w1",
+ "expert_w2": "w2",
+ "expert_w3": "w3",
+ "num_experts_attr": "num_moe_experts",
+ "top_k_attr": "moe_top_k",
+ "has_shared_experts": false
+ }
+}
+```
+
+#### Preset Override Example
+
+Use a built-in preset but override specific naming/weight fields for a fine-tuned model with renamed module paths:
+
+```json
+{
+ "expert_parallel": {
+ "enabled": true,
+ "preset_model": "mixtral",
+ "moe_layer_pattern": "model\\.layers\\.\\d+\\.moe",
+ "router_pattern": "router",
+ "expert_w1": "w1",
+ "expert_w2": "w2"
+ }
+}
+```
+
+> **Note:** `expert_storage` and `gate_bias` are auto-detected from model weights and cannot be overridden. `router_pattern`, `expert_pattern`, and `shared_experts_pattern` are direct child attribute names, not regex patterns.
+
+**Constraints:**
+- `autoep_size` must divide `num_experts` for all detected MoE layers
+- AutoEP currently cannot be combined with AutoTP (`tensor_parallel.autotp_size > 1`); support is planned as follow-up work
+- ZeRO Stage 3 is not supported with AutoEP (initialization fails with an assertion error)
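+
+As a minimal usage sketch (the model checkpoint, optimizer, and batch settings below are illustrative and not required by AutoEP), the same settings can also be passed to `deepspeed.initialize` as a Python dict:
+
+```python
+import deepspeed
+from transformers import AutoModelForCausalLM
+
+# Illustrative model choice; any HuggingFace MoE model covered by a preset works the same way.
+model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
+
+ds_config = {
+    "train_micro_batch_size_per_gpu": 1,
+    "optimizer": {"type": "AdamW", "params": {"lr": 1e-5}},
+    "zero_optimization": {"stage": 2},
+    "bf16": {"enabled": True},
+    "expert_parallel": {
+        "enabled": True,
+        "autoep_size": 4,
+        "preset_model": "mixtral"
+    }
+}
+
+# AutoEP detects and replaces the MoE layers during initialize(); no model code changes are needed.
+engine, optimizer, _, _ = deepspeed.initialize(model=model, config=ds_config)
+```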
+
### Logging
**steps_per_print**: [integer]
diff --git a/docs/_pages/training.md b/docs/_pages/training.md
index e31651cc487a..870670ae0a3d 100644
--- a/docs/_pages/training.md
+++ b/docs/_pages/training.md
@@ -180,7 +180,7 @@ Below we provide a brief feature list, see our detailed [feature overview](https
* Efficient and robust compressed training
* Up to 2.5x convergence speedup for pre-training
* [Performance Analysis and Debugging](https://www.deepspeed.ai/features/#performance-analysis-and-debugging)
-* [Mixture of Experts (MoE)](https://www.deepspeed.ai/tutorials/mixture-of-experts/)
+* [Mixture of Experts (DeepSpeed MoE)](https://www.deepspeed.ai/tutorials/mixture-of-experts/)
---
diff --git a/docs/_tutorials/mixture-of-experts.md b/docs/_tutorials/mixture-of-experts.md
index b4a1c2f86d6a..700b32b01ee7 100644
--- a/docs/_tutorials/mixture-of-experts.md
+++ b/docs/_tutorials/mixture-of-experts.md
@@ -1,5 +1,5 @@
---
-title: "Mixture of Experts"
+title: "Mixture of Experts (DeepSpeed MoE)"
tags: MoE training
---
diff --git a/docs/_tutorials/universal-checkpointing.md b/docs/_tutorials/universal-checkpointing.md
index 994ea408bf52..013a14431c45 100644
--- a/docs/_tutorials/universal-checkpointing.md
+++ b/docs/_tutorials/universal-checkpointing.md
@@ -18,76 +18,143 @@ Before you begin, ensure you have the following:
## How to use DeepSpeed Universal Checkpointing
-Follow the three simple steps below:
+Universal Checkpointing uses the same high-level flow for dense models, AutoTP
+(Automatic Tensor Parallelism), and AutoEP (Automatic Expert Parallelism): save a
+regular DeepSpeed ZeRO checkpoint, convert that checkpoint to Universal format,
+then load it with `checkpoint.load_universal` enabled.
### Step 1: Create ZeRO Checkpoint
-The first step in leveraging DeepSpeed Universal Checkpointing is to create a ZeRO checkpoint. [ZeRO](/tutorials/zero/) (Zero Redundancy Optimizer) is a memory optimization technology in DeepSpeed that allows for efficient training of large models. To create a ZeRO checkpoint, you'll need to:
+Start by creating a regular DeepSpeed checkpoint from a run that uses
+[ZeRO](/tutorials/zero/) (Zero Redundancy Optimizer). Use the normal DeepSpeed
+checkpoint API from your training script:
- - Initialize your model with DeepSpeed using the ZeRO optimizer.
- - Train your model to the desired state (iterations).
- - Save a checkpoint using DeepSpeed's checkpointing feature.
+```python
+engine.save_checkpoint(save_dir, tag=tag)
+```
+This is the same save call used for AutoTP and AutoEP training runs. AutoTP
+checkpoints include Universal Checkpoint metadata that describes tensor-parallel
+parameter layouts. AutoEP checkpoints also use the normal save API; AutoEP's
+expert-specific layout is described in the AutoEP requirements section below.
### Step 2: Convert ZeRO Checkpoint to Universal Format
-Once you have a ZeRO checkpoint, the next step is to convert it into the Universal format. This format is designed to be flexible and compatible across different model architectures and DeepSpeed configurations. To convert a checkpoint:
-
- - Use the [ds_to_universal.py](https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/checkpoint/ds_to_universal.py) script provided by DeepSpeed.
- - Specify the path to your ZeRO checkpoint and the desired output path for the Universal checkpoint.
+Once you have a ZeRO checkpoint, convert it to Universal format with the
+`ds_to_universal.py` script provided by DeepSpeed:
```bash
-python ds_to_universal.py --input_folder /path/to/zero/checkpoint --output_folder /path/to/universal/checkpoint
+python deepspeed/checkpoint/ds_to_universal.py \
+ --input_folder /path/to/ds_checkpoint \
+ --output_folder /path/to/universal_checkpoint
```
-This script will process the ZeRO checkpoint and generate a new checkpoint in the Universal format. Pass `--help` flag to see other options.
-
-### Step 3: Resume Training with Universal Checkpoint
-With the Universal checkpoint ready, you can now resume training on potentially with different parallelism topologies or training configurations. To do this add `--universal-checkpoint` to your DeepSpeed config (json) file
-
-
-## Universal Checkpointing with AutoTP (Automatic Tensor Parallelism)
-
-DeepSpeed AutoTP (Automatic Tensor Parallelism) can produce checkpoints that are compatible with Universal
-Checkpoint conversion and restore.
-
-### What gets saved
+This script processes the saved ZeRO checkpoint and writes a Universal
+checkpoint to the output folder. Pass the `--help` flag to see additional
+options.
-When AutoTP is enabled, DeepSpeed will attach Universal Checkpoint metadata (`UNIVERSAL_CHECKPOINT_INFO`)
-to the saved training checkpoint. This metadata describes how tensor-parallel parameters were partitioned
-(e.g. row-parallel vs column-parallel, replicated parameters, and fused/sub-parameter layouts).
+For AutoTP checkpoints, the converter uses the saved Universal Checkpoint
+metadata (`UNIVERSAL_CHECKPOINT_INFO`) to reconstruct tensor-parallel parameters
+correctly, including row-parallel, column-parallel, replicated, fused, and
+sub-parameter layouts.
-This enables:
-- converting a TP-sharded training checkpoint into a Universal checkpoint via `ds_to_universal.py`
-- restoring the checkpoint correctly even when TP partitioning uses fused weights (e.g. QKV)
-
-### Enablement
+### Step 3: Resume Training with Universal Checkpoint
-AutoTP is enabled by setting `tensor_parallel` in your DeepSpeed config:
+With the Universal checkpoint ready, resume training by enabling Universal
+Checkpoint loading in your DeepSpeed config:
```json
{
- "zero_optimization": { "stage": 2 },
- "bf16": { "enabled": true },
- "tensor_parallel": { "autotp_size": 4 }
+ "checkpoint": {
+ "load_universal": true
+ }
}
```
-Save a regular DeepSpeed checkpoint during training:
+Then load the converted checkpoint through the normal DeepSpeed checkpoint API:
+```python
+engine.load_checkpoint("/path/to/universal_checkpoint", tag=tag)
```
-engine.save_checkpoint(save_dir, tag=tag)
-```
-
-### Conversion
-
-Convert the saved DeepSpeed checkpoint to the universal format:
-```
-python deepspeed/checkpoint/ds_to_universal.py \
- --input_folder /path/to/ds_checkpoint \
- --output_folder /path/to/universal_checkpoint
-```
+The target run still needs the DeepSpeed parallelism configuration that matches
+the model and topology you want to use for resumed training.
+
+### AutoEP Requirements and Limitations
+
+AutoEP checkpoints are saved as regular DeepSpeed checkpoints, but routed expert
+weights have an additional layout. With AutoEP enabled, DeepSpeed writes the
+routed expert weights (`w1`, `w2`, and `w3`) into per-expert files named like
+`layer_*_expert_*_mp_rank_*_model_states.pt`.
+The regular model checkpoint records AutoEP metadata in `ds_autoep_layers`; older
+checkpoints may use the legacy `autoep_layers` key. Router, gate, shared-expert,
+and other non-routed-expert parameters stay in the regular
+`mp_rank_*_model_states.pt` files and use the standard Universal Checkpointing
+path.
+
+Use ZeRO Stage 1 or ZeRO Stage 2 for the current AutoEP Universal Checkpoint
+conversion path. ZeRO Stage 3 AutoEP Universal Checkpoint conversion is not
+supported; when AutoEP metadata is present, the converter raises a
+`NotImplementedError` stating that AutoEP currently requires ZeRO Stage 1 or 2.
+
+During conversion, `ds_to_universal.py` reads `ds_autoep_layers` or the legacy
+`autoep_layers` key, consolidates each AutoEP layer's routed expert files, and
+writes full expert tensors to paths such as `zero/<expert module path>.w1/fp32.pt`.
+These files are tagged with `is_expert_param` and `ep_num_experts`, which are the
+load-time signals used for AutoEP expert resharding. When matching expert
+optimizer shards are available, the converter also writes optimizer state files
+such as `exp_avg.pt` and `exp_avg_sq.pt` next to the converted parameter.
+
+Regular AutoEP checkpoint load requires the target run to use the same
+`autoep_size` as the save run. To change `autoep_size` for the same
+AutoEP-detected model topology, convert the saved checkpoint to Universal format
+and load the Universal checkpoint.
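+
+The snippet below is a minimal sketch of that flow. The paths, optimizer, and
+ZeRO settings are illustrative, `model` is assumed to be the same MoE model
+constructed in both runs, and the example assumes 8 routed experts per MoE
+layer so that both `autoep_size=4` and `autoep_size=2` divide the expert count.
+
+```python
+import deepspeed
+
+def make_config(autoep_size, load_universal=False):
+    # Illustrative training settings; only "expert_parallel" and "checkpoint" are AutoEP-specific.
+    config = {
+        "train_micro_batch_size_per_gpu": 1,
+        "optimizer": {"type": "AdamW", "params": {"lr": 1e-5}},
+        "zero_optimization": {"stage": 2},
+        "expert_parallel": {"enabled": True, "autoep_size": autoep_size, "preset_model": "mixtral"},
+    }
+    if load_universal:
+        config["checkpoint"] = {"load_universal": True}
+    return config
+
+# Run 1: train with autoep_size=4 and save a regular ZeRO checkpoint.
+engine, _, _, _ = deepspeed.initialize(model=model, config=make_config(autoep_size=4))
+engine.save_checkpoint("/path/to/ds_checkpoint", tag="step100")
+
+# Offline, convert the saved checkpoint to Universal format:
+#   python deepspeed/checkpoint/ds_to_universal.py \
+#       --input_folder /path/to/ds_checkpoint --output_folder /path/to/universal_checkpoint
+
+# Run 2: resume with autoep_size=2 and Universal Checkpoint loading enabled.
+engine, _, _, _ = deepspeed.initialize(model=model, config=make_config(autoep_size=2, load_universal=True))
+engine.load_checkpoint("/path/to/universal_checkpoint", tag="step100")
+```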
+
+In the Universal Checkpoint load path, AutoEP routed experts are restored from
+the `zero/` parameter layout rather than from the regular
+`layer_*_expert_*_model_states.pt` files. The target run's AutoEP process group
+supplies the load-side expert-parallel rank and size. For each tagged expert
+tensor, the loader slices the saved expert dimension by `ep_rank` and `ep_size`
+when `ep_size > 1`.
+
+The target model still needs to expose matching AutoEP parameter names and
+compatible shapes, for example `<module path>.experts.w1`,
+`<module path>.experts.w2`, and `<module path>.experts.w3`. Universal
+Checkpointing changes the expert-parallel sharding for matching tensors; it does
+not translate between different model families, different module paths, or
+arbitrary expert parameter names. The target AutoEP configuration must also be
+valid before checkpoint loading: `autoep_size` must divide the target pipeline
+stage size (`world_size / pp_size`) and every detected target layer's expert
+count.
+
+Topology changes are limited to `autoep_size` resharding for matching
+AutoEP-managed expert parameters. For every AutoEP layer in the checkpoint, the
+saved `ep_num_experts` must be divisible by the target `autoep_size` when the
+target `ep_size > 1`. For example, an 8-expert checkpoint can load with target
+`autoep_size` values of 1, 2, 4, or 8, but not 3. With `autoep_size=1`, the expert
+tensor is not sliced, but the target parameter must still have the compatible
+full expert shape.
+
+Additional AutoEP failure cases:
+
+- For ZeRO Stage 1 and ZeRO Stage 2 conversion, expert checkpoint files without
+ `ds_autoep_layers` or `autoep_layers` metadata raise a `RuntimeError`.
+- Existing DeepSpeed MoE or Megatron-DeepSpeed expert checkpoint files may share
+  the `layer_*_expert_*_mp_rank_*_model_states.pt`
+ naming convention, but they use native `deepspeed_moe` expert parameter names
+ and do not carry AutoEP metadata. Loading or converting those checkpoints into
+ AutoEP requires a separate model-specific migration step.
+- If AutoEP metadata is present but an expected per-expert model file is missing,
+ conversion raises `FileNotFoundError`.
+- More than one `mp_rank_*` expert file for the same `(layer, expert)` pair
+ raises `NotImplementedError`; combined AutoEP + AutoTP topology changes are
+ not documented by this path.
+- AutoEP optimizer-state consolidation is best effort. It succeeds for the usual
+ ZeRO Stage 1 or ZeRO Stage 2 AutoEP training checkpoints that include matching
+ expert optimizer shards. If `expp_rank_*_mp_rank_*_optim_states.pt` files or
+ matching state entries are absent, the converter still writes the model
+ parameter `fp32.pt` files and skips unavailable optimizer state files.
## Conclusion
diff --git a/docs/code-docs/source/moe.rst b/docs/code-docs/source/moe.rst
index 097a4b0bc27d..12081676cc8e 100644
--- a/docs/code-docs/source/moe.rst
+++ b/docs/code-docs/source/moe.rst
@@ -1,7 +1,108 @@
Mixture of Experts (MoE)
========================
-Layer specification
---------------------
+DeepSpeed provides two MoE implementations: AutoEP (Automatic Expert
+Parallelism), which automatically detects and replaces supported Hugging Face MoE
+layers, and DeepSpeed MoE, the explicit ``deepspeed.moe.layer.MoE`` API for
+constructing MoE layers in model code.
+
+AutoEP (Automatic Expert Parallelism)
+---------------------------------------
+
+AutoEP automatically detects MoE layers in Hugging Face models and replaces them
+with EP-enabled versions, requiring zero model code changes. It follows the
+pattern of AutoTP (Automatic Tensor Parallelism).
+
+**Built-in AutoEP presets:** ``mixtral`` (Mixtral), ``qwen3_moe`` (Qwen3-MoE),
+``qwen3_5_moe`` (Qwen3.5-MoE), ``deepseek_v2`` (DeepSeek-V2),
+``deepseek_v3`` (DeepSeek-V3), and ``llama4`` (LLaMA-4).
+
+A built-in preset means that AutoEP knows the router, expert, and weight naming
+patterns for that model family. Running a Hugging Face model also requires a
+Transformers build that exposes the matching config/model classes,
+``model.config.model_type`` value, and fused expert layout.
+
+.. list-table:: AutoEP preset compatibility by Transformers version
+ :header-rows: 1
+
+ * - Preset
+ - Minimum Transformers version
+ - Notes
+ * - ``mixtral``
+ - ``5.0.0``
+ -
+ * - ``qwen3_moe``
+ - ``5.0.0``
+ - Also covers Qwen2-MoE when the installed Transformers build uses the
+ validated fused expert layout. Qwen3-MoE classes appear in ``4.51.3``,
+ but the tested ``4.x`` builds do not match the validated AutoEP layout.
+ * - ``qwen3_5_moe``
+ - ``5.2.0``
+ - Requires the Qwen3.5 text-backbone ``qwen3_5_moe_text`` model type;
+ for performance on Qwen3.5's Gated DeltaNet layers, install optimized
+ kernels. See the `Hugging Face Transformers kernel loading docs
+     <https://huggingface.co/docs/transformers/kernel_doc/loading_kernels>`__
+     and the `Qwen FlashQLA blog <https://qwen.ai/blog?id=flashqla>`__.
+ * - ``deepseek_v2``
+ - ``5.0.0``
+ - ``load_balance_coeff`` / expert-bias auxiliary-loss-free load balancing
+ is not currently supported; non-null values are rejected.
+ * - ``deepseek_v3``
+ - ``5.0.0``
+ - ``load_balance_coeff`` / expert-bias auxiliary-loss-free load balancing
+ is not currently supported; non-null values are rejected.
+ * - ``llama4``
+ - ``5.0.0``
+ -
+
+**ZeRO compatibility:** Stages 0, 1, and 2. Stage 3 is not supported.
+
+**Usage:**
+
+.. code-block:: json
+
+ {
+ "expert_parallel": {
+ "enabled": true,
+ "autoep_size": 4,
+ "preset_model": "mixtral"
+ }
+ }
+
+**How it works:**
+
+1. During ``deepspeed.initialize()``, AutoEP scans the model for MoE layers
+ using preset-defined patterns (router name, expert name, weight shapes).
+2. Detected MoE blocks are replaced with ``AutoEPMoELayer``, which uses
+ TorchTitan's grouped GEMM kernels and AllToAll token dispatch.
+3. EP/EDP process groups are created automatically based on ``autoep_size``.
+4. Expert parameters are marked for expert-data-parallel gradient reduction;
+ router and shared-expert parameters use standard data-parallel reduction.
+
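+A minimal post-initialization check (mirroring the unit tests added with
+AutoEP; ``model`` and ``ds_config`` are illustrative, with ``ds_config``
+containing the ``expert_parallel`` block shown above):
+
+.. code-block:: python
+
+   import deepspeed
+   from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer
+
+   engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)
+
+   # Every detected MoE block should now be an AutoEPMoELayer.
+   replaced = [name for name, module in engine.module.named_modules()
+               if isinstance(module, AutoEPMoELayer)]
+   print(f"AutoEP replaced {len(replaced)} MoE layer(s): {replaced}")
+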
+**Constraints:**
+
+- ``autoep_size`` must divide ``num_experts`` for all detected MoE layers.
+- ``autoep_size=1`` is valid: all experts remain local (no AllToAll), useful
+ for functional testing on a single GPU.
+- AutoEP currently cannot be combined with AutoTP
+ (``tensor_parallel.autotp_size > 1``); support is planned as follow-up work.
+- Checkpoint save/load requires matching ``autoep_size``.
+ To change ``autoep_size`` across runs for the same AutoEP-detected model
+ topology, convert the checkpoint to Universal Checkpoint format and load it
+ with ``checkpoint.load_universal``; see the
+  `Universal Checkpointing tutorial <https://www.deepspeed.ai/tutorials/universal-checkpointing/>`__
+ for the detailed flow and constraints.
+- DeepSeek-V2 and DeepSeek-V3 AutoEP do not support load-balance expert bias
+ yet. The built-in DeepSeek presets disable it by default; explicit non-null
+ values fail.
+
+DeepSpeed MoE
+-------------
+
+DeepSpeed MoE exposes the explicit ``deepspeed.moe.layer.MoE`` layer API for
+models that construct MoE layers directly. See the `Mixture of Experts
+(DeepSpeed MoE) tutorial <https://www.deepspeed.ai/tutorials/mixture-of-experts/>`__ for training
+examples and configuration details.
+
.. autoclass:: deepspeed.moe.layer.MoE
:members:
diff --git a/scripts/check-license.py b/scripts/check-license.py
index 0d0e1e578faa..daffab199dc0 100755
--- a/scripts/check-license.py
+++ b/scripts/check-license.py
@@ -20,7 +20,12 @@ def err(s: str) -> None:
COPYRIGHT = [
# (r"^# Copyright (c) Microsoft Corporation.$", r"^\/\/ Copyright (c) Microsoft Corporation.$"),
- (r"^# SPDX-License-Identifier: Apache-2.0$", r"^\/\/ SPDX-License-Identifier: Apache-2.0$"),
+ (
+ r"^# SPDX-License-Identifier: Apache-2.0$",
+ r"^\/\/ SPDX-License-Identifier: Apache-2.0$",
+ r"^# SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause$",
+ r"^\/\/ SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause$",
+ ),
(r"^# DeepSpeed Team$", r"^\/\/ DeepSpeed Team$"),
]
diff --git a/tests/unit/moe/__init__.py b/tests/unit/moe/__init__.py
new file mode 100644
index 000000000000..c8d652d4dc49
--- /dev/null
+++ b/tests/unit/moe/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: Apache-2.0
+# DeepSpeed Team
diff --git a/tests/unit/moe/autoep_test_utils.py b/tests/unit/moe/autoep_test_utils.py
new file mode 100644
index 000000000000..96ea8a430b9e
--- /dev/null
+++ b/tests/unit/moe/autoep_test_utils.py
@@ -0,0 +1,318 @@
+# SPDX-License-Identifier: Apache-2.0
+# DeepSpeed Team
+"""Shared fixtures and assertions for compact AutoEP tests."""
+
+import copy
+
+import deepspeed
+import pytest
+import torch
+import torch.nn as nn
+
+from deepspeed.accelerator import get_accelerator
+
+UNSET = object()
+UNSUPPORTED_LOAD_BALANCE_VALUES = [0, 0.0, 1e-3, 0.02, False, True, "1e-3", [1e-3], {"coeff": 1e-3}]
+
+
+class MockHFConfig:
+ model_type = "mixtral"
+ num_local_experts = 4
+ num_experts_per_tok = 2
+ hidden_size = 64
+ intermediate_size = 128
+
+
+class MockMoEExperts(nn.Module):
+
+ def __init__(self, num_experts=4, ffn_hidden=128, hidden_size=64, intermediate_size=None):
+ super().__init__()
+ if intermediate_size is not None:
+ ffn_hidden = intermediate_size
+ self.gate_up_proj = nn.Parameter(torch.randn(num_experts, 2 * ffn_hidden, hidden_size))
+ self.down_proj = nn.Parameter(torch.randn(num_experts, hidden_size, ffn_hidden))
+
+
+class MockMoEBlock(nn.Module):
+
+ def __init__(self, num_experts=4, ffn_hidden=128, hidden_size=64, intermediate_size=None):
+ super().__init__()
+ if intermediate_size is not None:
+ ffn_hidden = intermediate_size
+ self.gate = nn.Linear(hidden_size, num_experts, bias=False)
+ self.experts = MockMoEExperts(num_experts, ffn_hidden, hidden_size, intermediate_size)
+ self.top_k = 2
+
+ def forward(self, x):
+ original_shape = x.shape
+ hidden_states = x.reshape(-1, original_shape[-1])
+ scores = torch.softmax(self.gate(hidden_states), dim=-1)
+ top_scores, top_indices = torch.topk(scores, k=self.top_k, dim=-1)
+ top_scores = top_scores / top_scores.sum(dim=-1, keepdim=True)
+ output = torch.zeros_like(hidden_states)
+
+ for expert_idx in range(self.gate.out_features):
+ expert_mask = top_indices == expert_idx
+ if not expert_mask.any():
+ continue
+ token_indices, route_indices = expert_mask.nonzero(as_tuple=True)
+ expert_input = hidden_states[token_indices]
+ gate_up = torch.matmul(expert_input, self.experts.gate_up_proj[expert_idx].transpose(0, 1))
+ gate_part, up_part = gate_up.chunk(2, dim=-1)
+ expert_output = torch.matmul(
+ torch.nn.functional.silu(gate_part) * up_part, self.experts.down_proj[expert_idx].transpose(0, 1))
+ output[token_indices] += expert_output * top_scores[token_indices, route_indices].unsqueeze(-1)
+
+ return output.reshape(original_shape)
+
+
+class MockDenseBlock(nn.Module):
+
+ def __init__(self, hidden_size=64, ffn_hidden=128):
+ super().__init__()
+ self.gate_proj = nn.Linear(hidden_size, ffn_hidden, bias=False)
+ self.up_proj = nn.Linear(hidden_size, ffn_hidden, bias=False)
+ self.down_proj = nn.Linear(ffn_hidden, hidden_size, bias=False)
+
+
+class MockMoETransformer(nn.Module):
+
+ def __init__(self, num_layers=2, num_experts=4, hidden_size=64, intermediate_size=128, moe_every_n=1):
+ super().__init__()
+ self.config = MockHFConfig()
+ self.config.num_local_experts = num_experts
+ self.config.hidden_size = hidden_size
+ self.config.intermediate_size = intermediate_size
+ self.model = nn.Module()
+ self.model.layers = nn.ModuleList([
+ self._make_layer(layer_idx, num_experts, hidden_size, intermediate_size, moe_every_n)
+ for layer_idx in range(num_layers)
+ ])
+ self.lm_head = nn.Linear(hidden_size, 100)
+
+ @staticmethod
+ def _make_layer(layer_idx, num_experts, hidden_size, intermediate_size, moe_every_n):
+ layer = nn.Module()
+ layer.self_attn = nn.MultiheadAttention(hidden_size, 1, batch_first=True)
+ if layer_idx % moe_every_n == 0:
+ layer.mlp = MockMoEBlock(num_experts, intermediate_size, hidden_size)
+ else:
+ layer.mlp = MockDenseBlock(hidden_size, intermediate_size)
+ layer.input_layernorm = nn.LayerNorm(hidden_size)
+ layer.post_attention_layernorm = nn.LayerNorm(hidden_size)
+ return layer
+
+ def forward(self, x):
+ for layer_module in self.model.layers:
+ residual = x
+ x = layer_module.input_layernorm(x)
+ x, _ = layer_module.self_attn(x, x, x)
+ x = residual + x
+ residual = x
+ x = layer_module.post_attention_layernorm(x)
+ x = residual + layer_module.mlp(x)
+ return self.lm_head(x)
+
+
+def assert_load_balance_coeff_rejection_message(exc: BaseException, value: object) -> None:
+ text = str(exc)
+ for needle in ("load_balance_coeff", "expert_bias", "not supported", "null", "omit"):
+ assert needle in text
+ assert repr(value) in text
+
+
+def mixed_precision_config():
+ accelerator = get_accelerator()
+ if accelerator.is_fp16_supported() and accelerator.device_name() != "cpu":
+ return {"fp16": {"enabled": True, "initial_scale_power": 8}}
+ if accelerator.is_bf16_supported():
+ return {"bf16": {"enabled": True}}
+ if accelerator.is_fp16_supported():
+ return {"fp16": {"enabled": True, "initial_scale_power": 8}}
+ pytest.skip("AutoEP tests require fp16 or bf16 support")
+
+
+def make_autoep_config(zero_stage=0, ep_size=1, load_balance_coeff=UNSET, mixed_precision=True):
+ config = {
+ "train_micro_batch_size_per_gpu": 1,
+ "optimizer": {
+ "type": "Adam",
+ "params": {
+ "lr": 1e-4
+ },
+ },
+ "expert_parallel": {
+ "enabled": True,
+ "autoep_size": ep_size,
+ "preset_model": "mixtral",
+ "use_grouped_mm": False,
+ },
+ "zero_optimization": {
+ "stage": zero_stage,
+ },
+ }
+ if get_accelerator().device_name() == "cpu":
+ config["optimizer"]["torch_adam"] = True
+ if mixed_precision:
+ config.update(mixed_precision_config())
+ if load_balance_coeff is not UNSET:
+ config["expert_parallel"]["load_balance_coeff"] = load_balance_coeff
+ return config
+
+
+def make_autoep_integration_config(zero_stage=0, ep_size=2):
+ return make_autoep_config(zero_stage=zero_stage, ep_size=ep_size, mixed_precision=False)
+
+
+def seed_everything(seed=1234):
+ torch.manual_seed(seed)
+ get_accelerator().manual_seed(seed)
+ get_accelerator().manual_seed_all(seed)
+
+
+def engine_input_dtype(engine):
+ if engine.bfloat16_enabled():
+ return torch.bfloat16
+ if engine.fp16_enabled():
+ return torch.float16
+ return torch.float32
+
+
+def init_autoep_engine(ep_size=1, zero_stage=0, load_balance_coeff=UNSET):
+ seed_everything(42)
+ engine, _, _, _ = deepspeed.initialize(
+ model=MockMoETransformer(),
+ config=make_autoep_config(zero_stage=zero_stage, ep_size=ep_size, load_balance_coeff=load_balance_coeff),
+ )
+ return engine
+
+
+def run_training_steps(engine, num_steps=3, seq_len=8, hidden_dim=64):
+ losses = []
+ grad_norms = []
+ for _ in range(num_steps):
+ x = torch.randn(1, seq_len, hidden_dim, device=engine.device)
+ loss = engine(x).mean()
+ engine.backward(loss)
+
+ total_norm = 0.0
+ for param in engine.module.parameters():
+ if param.grad is not None:
+ total_norm += param.grad.data.float().norm(2).item()**2
+ grad_norms.append(total_norm**0.5)
+ engine.step()
+ losses.append(loss.item())
+ return losses, grad_norms
+
+
+def tiny_causal_lm_inputs():
+ input_ids = torch.tensor([[1, 5, 7, 9, 11]], dtype=torch.long)
+ return input_ids, input_ids.clone()
+
+
+def state_matched_models(model_cls, config):
+ native_model = model_cls(config)
+ autoep_model = model_cls(config)
+ autoep_model.load_state_dict(copy.deepcopy(native_model.state_dict()))
+ return native_model, autoep_model
+
+
+def replace_autoep_layers(model, preset_model, expected_count=1, **config_overrides):
+ from deepspeed.module_inject.auto_ep import AutoEP
+ from deepspeed.module_inject.auto_ep_config import parse_autoep_config
+
+ config = {
+ "enabled": True,
+ "autoep_size": 1,
+ "preset_model": preset_model,
+ "use_grouped_mm": False,
+ **config_overrides,
+ }
+ auto_ep = AutoEP(model, parse_autoep_config(config))
+ specs = auto_ep.ep_parser()
+ assert len(specs) == expected_count
+ for spec in specs:
+ auto_ep.replace_moe_layer(spec, ep_size=1, ep_rank=0)
+ return specs
+
+
+def assert_causal_lm_outputs_close(native_model,
+ autoep_model,
+ *,
+ output_router_logits=False,
+ compare_router_logits=False,
+ compare_aux_loss=False,
+ compare_logits=True,
+ rtol=1e-5,
+ atol=1e-6):
+ input_ids, labels = tiny_causal_lm_inputs()
+ native_model.eval()
+ autoep_model.eval()
+ with torch.no_grad():
+ native_outputs = native_model(input_ids=input_ids, labels=labels, output_router_logits=output_router_logits)
+ autoep_outputs = autoep_model(input_ids=input_ids, labels=labels, output_router_logits=output_router_logits)
+
+ if compare_router_logits:
+ assert autoep_outputs.router_logits
+ torch.testing.assert_close(autoep_outputs.router_logits[0],
+ native_outputs.router_logits[0],
+ rtol=rtol,
+ atol=atol)
+ if compare_aux_loss:
+ assert autoep_outputs.aux_loss is not None
+ torch.testing.assert_close(autoep_outputs.aux_loss, native_outputs.aux_loss, rtol=rtol, atol=atol)
+ if compare_logits:
+ torch.testing.assert_close(autoep_outputs.logits, native_outputs.logits, rtol=rtol, atol=atol)
+ torch.testing.assert_close(autoep_outputs.loss, native_outputs.loss, rtol=rtol, atol=atol)
+
+
+def skip_unless_transformers_has(transformers, *names, min_version=None, reason="AutoEP coverage"):
+ from packaging.version import Version
+
+ if min_version is not None and Version(transformers.__version__) < Version(min_version):
+ pytest.skip(f"{reason} requires Transformers >= {min_version}")
+ missing = [name for name in names if not hasattr(transformers, name)]
+ if missing:
+ pytest.skip(f"Installed transformers does not expose required classes: {missing}")
+
+
+def tiny_mixtral_config(transformers):
+ return transformers.MixtralConfig(
+ vocab_size=64,
+ hidden_size=32,
+ intermediate_size=64,
+ num_hidden_layers=1,
+ num_attention_heads=4,
+ num_key_value_heads=2,
+ max_position_embeddings=64,
+ num_local_experts=4,
+ num_experts_per_tok=2,
+ output_router_logits=True,
+ tie_word_embeddings=False,
+ use_cache=False,
+ )
+
+
+def tiny_llama4_text_config(transformers):
+ return transformers.Llama4TextConfig(
+ vocab_size=64,
+ hidden_size=32,
+ intermediate_size=16,
+ intermediate_size_mlp=16,
+ num_hidden_layers=1,
+ num_attention_heads=4,
+ num_key_value_heads=2,
+ head_dim=8,
+ max_position_embeddings=64,
+ num_local_experts=4,
+ num_experts_per_tok=1,
+ moe_layers=[0],
+ interleave_moe_layer_step=1,
+ output_router_logits=False,
+ router_jitter_noise=0.0,
+ tie_word_embeddings=False,
+ use_cache=False,
+ attention_chunk_size=64,
+ attn_temperature_tuning=False,
+ no_rope_layers=[0],
+ )
diff --git a/tests/unit/moe/test_autoep_checkpoint.py b/tests/unit/moe/test_autoep_checkpoint.py
new file mode 100644
index 000000000000..d2f7211466ca
--- /dev/null
+++ b/tests/unit/moe/test_autoep_checkpoint.py
@@ -0,0 +1,104 @@
+# Copyright (c) DeepSpeed Team.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+"""Compact AutoEP checkpoint tests."""
+
+import os
+
+import pytest
+import torch
+import torch.nn as nn
+
+from deepspeed.runtime.config import DeepSpeedConfig
+from unit.common import DistributedTest
+from unit.moe.autoep_test_utils import (
+ UNSUPPORTED_LOAD_BALANCE_VALUES,
+ assert_load_balance_coeff_rejection_message,
+ init_autoep_engine,
+)
+
+
+@pytest.mark.parametrize("enabled", [True, False])
+@pytest.mark.parametrize("include_key", [False, True])
+def test_load_balance_coeff_disabled_values_accepted_by_deepspeed_config(enabled, include_key):
+ config = {
+ "train_micro_batch_size_per_gpu": 1,
+ "expert_parallel": {
+ "enabled": enabled,
+ "autoep_size": 1,
+ "preset_model": "mixtral",
+ },
+ }
+ if include_key:
+ config["expert_parallel"]["load_balance_coeff"] = None
+
+ ds_config = DeepSpeedConfig(config)
+
+ assert ds_config.expert_parallel_config.load_balance_coeff is None
+ assert ds_config.expert_parallel_config._load_balance_coeff_explicit is include_key
+
+
+@pytest.mark.parametrize("enabled", [True, False])
+@pytest.mark.parametrize("value", UNSUPPORTED_LOAD_BALANCE_VALUES)
+def test_load_balance_coeff_rejected_by_deepspeed_config(enabled, value):
+ config = {
+ "train_micro_batch_size_per_gpu": 1,
+ "expert_parallel": {
+ "enabled": enabled,
+ "autoep_size": 1,
+ "preset_model": "mixtral",
+ "load_balance_coeff": value,
+ },
+ }
+
+ with pytest.raises(ValueError) as exc_info:
+ DeepSpeedConfig(config)
+ assert_load_balance_coeff_rejection_message(exc_info.value, value)
+
+
+class TestAutoEPCheckpointSaveLoad(DistributedTest):
+ world_size = 1
+
+ def test_save_load_same_ep_and_metadata(self, tmpdir):
+ engine = init_autoep_engine(ep_size=1)
+ params_before = {name: param.detach().clone() for name, param in engine.module.named_parameters()}
+ save_dir = str(tmpdir)
+ tag = "autoep"
+
+ engine.save_checkpoint(save_dir, tag=tag)
+
+ checkpoint = torch.load(os.path.join(save_dir, tag, "mp_rank_00_model_states.pt"),
+ map_location="cpu",
+ weights_only=False)
+ metadata = checkpoint["ds_autoep_layers"]
+ assert len(metadata) == 2
+ for entry in metadata:
+ assert {"moe_layer_id", "module_path", "num_experts", "num_local_experts", "ep_size"} <= entry.keys()
+ assert entry["num_experts"] == entry["num_local_experts"] * entry["ep_size"]
+
+ reloaded = init_autoep_engine(ep_size=1)
+ reloaded.load_checkpoint(save_dir, tag=tag)
+ for name, param in reloaded.module.named_parameters():
+ assert torch.equal(param, params_before[name]), f"{name} changed after same-EP reload"
+
+ def test_autoep_metadata_schema_validation(self):
+ from deepspeed.runtime.engine import DeepSpeedEngine
+
+ with pytest.raises(RuntimeError, match="malformed"):
+ DeepSpeedEngine.load_moe_state_dict(checkpoint_path="/fake",
+ tag="fake",
+ state_dict={},
+ old_moe_load=False,
+ model=nn.Linear(1, 1),
+ autoep_layers="not_a_list")
+
+ with pytest.raises(RuntimeError, match="missing fields"):
+ DeepSpeedEngine.load_moe_state_dict(checkpoint_path="/fake",
+ tag="fake",
+ state_dict={},
+ old_moe_load=False,
+ model=nn.Linear(1, 1),
+ autoep_layers=[{
+ "moe_layer_id": 0
+ }])
diff --git a/tests/unit/moe/test_autoep_grad_parity.py b/tests/unit/moe/test_autoep_grad_parity.py
new file mode 100644
index 000000000000..8452bbcb88f8
--- /dev/null
+++ b/tests/unit/moe/test_autoep_grad_parity.py
@@ -0,0 +1,180 @@
+# Copyright (c) DeepSpeed Team.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+"""One ZeRO-2 AutoEP gradient parity path."""
+
+import deepspeed
+import deepspeed.comm as dist
+import torch
+from deepspeed.utils import safe_get_full_grad
+from unit.common import DistributedTest
+from unit.moe.autoep_test_utils import (
+ MockMoETransformer,
+ engine_input_dtype as _engine_input_dtype,
+ mixed_precision_config as _mixed_precision_config,
+ seed_everything as _seed_everything,
+)
+
+
+def _make_model():
+ return MockMoETransformer(num_layers=1, num_experts=4, hidden_size=128, intermediate_size=256)
+
+
+def _make_zero2_config():
+ return {
+ **_mixed_precision_config(),
+ "train_micro_batch_size_per_gpu": 1,
+ "gradient_accumulation_steps": 2,
+ "gradient_clipping": 0.0,
+ "optimizer": {
+ "type": "AdamW",
+ "params": {
+ "lr": 3e-3,
+ "betas": [0.9, 0.999],
+ "eps": 1e-8,
+ "weight_decay": 0.01,
+ },
+ },
+ "zero_optimization": {
+ "stage": 2,
+ "allgather_partitions": True,
+ "allgather_bucket_size": 5e8,
+ "overlap_comm": True,
+ "reduce_scatter": True,
+ "reduce_bucket_size": 5e8,
+ },
+ }
+
+
+def _make_autoep_zero2_config(ep_size):
+ config = _make_zero2_config()
+ config["expert_parallel"] = {
+ "enabled": True,
+ "autoep_size": ep_size,
+ "preset_model": "mixtral",
+ "load_balance_coeff": None,
+ "use_grouped_mm": False,
+ }
+ return config
+
+
+def _make_local_batches(*, logical_dp_world_size, logical_dp_rank, grad_accum, seed, seq_len, micro_batch_size,
+ hidden_size, device, dtype):
+ batches = []
+ for accum_idx in range(grad_accum):
+ batch_idx = accum_idx * logical_dp_world_size + logical_dp_rank
+ generator = torch.Generator().manual_seed(seed + batch_idx)
+ batches.append(
+ torch.randn((micro_batch_size, seq_len, hidden_size), generator=generator, dtype=dtype).to(device))
+ return batches
+
+
+def _run_until_boundary(engine, *, logical_dp_world_size, logical_dp_rank, grad_accum, seed):
+ batches = _make_local_batches(
+ logical_dp_world_size=logical_dp_world_size,
+ logical_dp_rank=logical_dp_rank,
+ grad_accum=grad_accum,
+ seed=seed,
+ seq_len=16,
+ micro_batch_size=1,
+ hidden_size=128,
+ device=engine.device,
+ dtype=_engine_input_dtype(engine),
+ )
+ for batch_idx, batch in enumerate(batches):
+ loss = engine(batch).mean()
+ engine.backward(loss)
+ if batch_idx + 1 < len(batches):
+ engine.step()
+
+
+def _gather_autoep_expert_grad(param, group):
+ grad = safe_get_full_grad(param)
+ assert grad is not None, "Expected full expert grad"
+ group_size = dist.get_world_size(group=group)
+ shards = [torch.zeros_like(grad) for _ in range(group_size)]
+ dist.all_gather(shards, grad.detach(), group=group)
+ # The gather reconstructs expert shards; gradient reduction has already
+ # applied the data-parallel normalization, so do not average by EP size.
+ return torch.cat([shard.float().cpu() for shard in shards], dim=0)
+
+
+def _collect_autoep_expert_grads(engine):
+ from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer
+
+ grads = {}
+ for module_name, module in engine.module.named_modules():
+ if not isinstance(module, AutoEPMoELayer):
+ continue
+ prefix = f"{module_name}.experts"
+ w1 = _gather_autoep_expert_grad(module.experts.w1, module.ep_group)
+ w2 = _gather_autoep_expert_grad(module.experts.w2, module.ep_group)
+ w3 = _gather_autoep_expert_grad(module.experts.w3, module.ep_group)
+ grads[f"{prefix}.gate_up_proj"] = torch.cat([w1, w3], dim=1)
+ grads[f"{prefix}.down_proj"] = w2
+ return grads
+
+
+def _collect_zero2_expert_grads(engine):
+ grads = {}
+ for name, param in engine.module.named_parameters():
+ if name.endswith(".experts.gate_up_proj") or name.endswith(".experts.down_proj"):
+ grad = safe_get_full_grad(param)
+ assert grad is not None, f"Expected full grad for {name}"
+ grads[name] = grad.detach().float().cpu().clone()
+ return grads
+
+
+def _assert_grad_maps_close(actual, expected, *, lhs_name, rhs_name):
+ for name in sorted(expected):
+ assert name in actual, f"Missing {lhs_name} param snapshot for {name}"
+ diff = (actual[name] - expected[name]).abs()
+ torch.testing.assert_close(actual[name],
+ expected[name],
+ atol=1e-1,
+ rtol=5e-3,
+ msg=(f"Gradient mismatch for {name} between {lhs_name} and {rhs_name}; "
+ f"max_diff={diff.max().item()} "
+ f"actual_norm={actual[name].norm().item()} "
+ f"expected_norm={expected[name].norm().item()}"))
+
+
+class TestAutoEPGradParity(DistributedTest):
+ world_size = 4
+
+ def test_zero2_autoep_matches_zero2_after_one_update(self):
+ ep_size = 2
+ seed = 1234
+
+ _seed_everything(seed)
+ reference_state = _make_model().state_dict()
+
+ autoep_model = _make_model()
+ zero2_model = _make_model()
+ autoep_model.load_state_dict(reference_state)
+ zero2_model.load_state_dict(reference_state)
+
+ autoep_engine, _, _, _ = deepspeed.initialize(model=autoep_model, config=_make_autoep_zero2_config(ep_size))
+ zero2_engine, _, _, _ = deepspeed.initialize(model=zero2_model, config=_make_zero2_config())
+
+ autoep_rank = dist.get_rank() // ep_size
+ _run_until_boundary(autoep_engine,
+ logical_dp_world_size=self.world_size // ep_size,
+ logical_dp_rank=autoep_rank,
+ grad_accum=2,
+ seed=seed)
+ _run_until_boundary(zero2_engine,
+ logical_dp_world_size=self.world_size // ep_size,
+ logical_dp_rank=autoep_rank,
+ grad_accum=2,
+ seed=seed)
+
+ autoep_expert = _collect_autoep_expert_grads(autoep_engine)
+ zero2_expert = _collect_zero2_expert_grads(zero2_engine)
+
+ dist.barrier()
+ if dist.get_rank() != 0:
+ return
+
+ _assert_grad_maps_close(autoep_expert, zero2_expert, lhs_name="AutoEP expert", rhs_name="ZeRO-2 expert")
diff --git a/tests/unit/moe/test_autoep_integration.py b/tests/unit/moe/test_autoep_integration.py
new file mode 100644
index 000000000000..13c28e405474
--- /dev/null
+++ b/tests/unit/moe/test_autoep_integration.py
@@ -0,0 +1,72 @@
+# Copyright (c) DeepSpeed Team.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+"""Integration tests for AutoEP (multi-GPU, requires distributed backend)."""
+
+import pytest
+import torch
+import deepspeed
+from unit.moe.autoep_test_utils import (
+ MockMoETransformer,
+ make_autoep_integration_config as _make_autoep_config,
+ run_training_steps as _run_training_steps,
+ seed_everything as _seed_everything,
+)
+from unit.common import DistributedTest
+
+# ---------------------------------------------------------------------------
+# Test class: AutoEP integration (world_size=2)
+# ---------------------------------------------------------------------------
+
+
+class TestAutoEPOnly(DistributedTest):
+ world_size = 2
+
+ def test_zero2_ep_2gpu(self):
+ """EP with ZeRO-2 training.
+
+ Verifies EP and ZeRO Stage 2 work together: finite losses
+ and parameters actually update across training steps.
+ Note: ZeRO-2 partitions gradients, so p.grad may be None on some ranks.
+ """
+ _seed_everything(1234)
+
+ model = MockMoETransformer()
+ config = _make_autoep_config(zero_stage=2, ep_size=2)
+ engine, _, _, _ = deepspeed.initialize(model=model, config=config)
+
+ # Verify replacement
+ from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer
+ replaced_count = sum(1 for _, m in engine.module.named_modules() if isinstance(m, AutoEPMoELayer))
+ assert replaced_count == 2, (f"Expected 2 MoE layers replaced with ZeRO-2, found {replaced_count}")
+
+ # Snapshot parameter values before training
+ params_before = {n: p.data.clone().float() for n, p in engine.module.named_parameters() if p.requires_grad}
+
+ # Run training steps (ignore grad norms since ZeRO-2 partitions them)
+ losses, _ = _run_training_steps(engine, num_steps=3)
+
+ for i, loss_val in enumerate(losses):
+ assert torch.isfinite(torch.tensor(loss_val)), (f"Loss at step {i} is not finite: {loss_val}")
+
+ # Verify at least some parameters changed (optimizer step took effect)
+ params_changed = 0
+ for n, p in engine.module.named_parameters():
+ if n in params_before and not torch.equal(p.data.float(), params_before[n]):
+ params_changed += 1
+ assert params_changed > 0, "No parameters changed after 3 training steps with ZeRO-2"
+
+ def test_zero3_ep_rejected_2gpu(self):
+ """EP with ZeRO-3 should trigger an assertion error.
+
+ ZeRO Stage 3 is incompatible with MoE. The engine should raise
+ an AssertionError with the message 'MoE not supported with Stage 3'.
+ """
+ _seed_everything(1234)
+
+ model = MockMoETransformer()
+ config = _make_autoep_config(zero_stage=3, ep_size=2)
+
+ with pytest.raises(AssertionError, match="MoE not supported with Stage 3"):
+ deepspeed.initialize(model=model, config=config)
diff --git a/tests/unit/moe/test_autoep_unit.py b/tests/unit/moe/test_autoep_unit.py
new file mode 100644
index 000000000000..e1687a673886
--- /dev/null
+++ b/tests/unit/moe/test_autoep_unit.py
@@ -0,0 +1,486 @@
+# Copyright (c) DeepSpeed Team.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+"""Compact critical-path tests for AutoEP."""
+
+from types import SimpleNamespace
+
+import pytest
+import torch
+import torch.nn as nn
+
+import deepspeed.runtime.engine as ds_engine
+from deepspeed.module_inject.auto_ep import AutoEP, _resolve_route_scale
+from deepspeed.module_inject.auto_ep_config import (
+ AutoEPConfig,
+ MoELayerSpec,
+ PRESET_MODELS,
+ parse_autoep_config,
+ validate_autoep_config,
+ validate_autoep_post_detection,
+)
+from deepspeed.module_inject.auto_ep_layer import (
+ AutoEPMoELayer,
+ apply_scores_before_experts_if_enabled,
+ combine_from_routed,
+ resolve_score_apply_mode,
+)
+from deepspeed.module_inject.auto_ep_preset_adapters import get_preset_adapter
+from deepspeed.module_inject.auto_ep_presets.registry import (
+ preset_name_for_hf_model_type,
+ unsupported_preset_for_hf_model_type,
+)
+from deepspeed.moe.ep_experts import GroupedExperts
+from deepspeed.moe.ep_kernels import TokenReorderer
+from deepspeed.moe.ep_repack import repack_expert_weights
+from deepspeed.moe.ep_router import TokenChoiceTopKRouter
+from deepspeed.runtime.engine import DeepSpeedEngine
+from deepspeed.utils import groups
+from unit.moe.autoep_test_utils import (
+ MockMoEBlock,
+ MockMoETransformer,
+ UNSUPPORTED_LOAD_BALANCE_VALUES,
+ assert_causal_lm_outputs_close,
+ assert_load_balance_coeff_rejection_message,
+ replace_autoep_layers,
+ skip_unless_transformers_has,
+ state_matched_models,
+ tiny_llama4_text_config,
+ tiny_mixtral_config,
+)
+
+
+def _runtime_config(**kwargs):
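+    """Build an AutoEPConfig for unit tests; use_grouped_mm defaults to False unless overridden."""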
+ kwargs.setdefault("use_grouped_mm", False)
+ return AutoEPConfig(**kwargs)
+
+
+def _make_spec(**kwargs):
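+    """Build a Mixtral-style MoELayerSpec with test defaults; keyword args override individual fields."""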
+ defaults = dict(
+ moe_module_name="model.layers.0.mlp",
+ model_family="mixtral",
+ router_name="gate",
+ experts_name="experts",
+ expert_storage="fused_3d",
+ expert_w1_name="gate_up_proj",
+ expert_w2_name="down_proj",
+ expert_w3_name=None,
+ num_experts=4,
+ top_k=2,
+ hidden_size=64,
+ ffn_hidden_size=128,
+ score_func="softmax",
+ score_apply="post",
+ route_norm=True,
+ gate_bias=False,
+ return_router_logits=False,
+ router_logits_capture_target="none",
+ router_logits_capture_index=None,
+ router_logits_capture_layer_name=None,
+ has_shared_experts=False,
+ shared_experts_name="",
+ shared_experts_gate_name="",
+ )
+ defaults.update(kwargs)
+ return MoELayerSpec(**defaults)
+
+
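+# The Llama4 mocks below are intended to mirror the HF Llama4 module layout used
+# by the llama4 preset detection: a router plus fused 3D expert weights
+# (gate_up_proj / down_proj) and a shared expert on each MoE block.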
+class MockLlama4Config:
+ model_type = "llama4"
+ num_local_experts = 8
+ num_experts_per_tok = 1
+ hidden_size = 64
+ intermediate_size = 128
+
+
+class MockLlama4Experts(nn.Module):
+
+ def __init__(self, num_experts=8, ffn_hidden=128, hidden_size=64):
+ super().__init__()
+ self.gate_up_proj = nn.Parameter(torch.randn(num_experts, hidden_size, 2 * ffn_hidden))
+ self.down_proj = nn.Parameter(torch.randn(num_experts, ffn_hidden, hidden_size))
+
+
+class MockSharedExpert(nn.Module):
+
+ def __init__(self, hidden_size=64):
+ super().__init__()
+ self.up_proj = nn.Linear(hidden_size, hidden_size, bias=False)
+ self.gate_proj = nn.Linear(hidden_size, hidden_size, bias=False)
+ self.down_proj = nn.Linear(hidden_size, hidden_size, bias=False)
+
+
+class MockLlama4MoEBlock(nn.Module):
+
+ def __init__(self, num_experts=8, ffn_hidden=128, hidden_size=64):
+ super().__init__()
+ self.router = nn.Linear(hidden_size, num_experts, bias=False)
+ self.experts = MockLlama4Experts(num_experts, ffn_hidden, hidden_size)
+ self.shared_expert = MockSharedExpert(hidden_size)
+
+
+class MockLlama4Transformer(nn.Module):
+
+ def __init__(self, num_layers=2, num_experts=8):
+ super().__init__()
+ self.config = MockLlama4Config()
+ self.config.num_local_experts = num_experts
+ self.model = nn.Module()
+ self.model.layers = nn.ModuleList([self._make_layer(num_experts) for _ in range(num_layers)])
+
+ @staticmethod
+ def _make_layer(num_experts):
+ layer = nn.Module()
+ layer.feed_forward = MockLlama4MoEBlock(num_experts)
+ return layer
+
+
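+# The DeepSeek-V3 mocks use one nn.Linear-based expert module per expert
+# ("module_list" storage) plus shared experts, matching the layout asserted in
+# test_deepseek_v3_detection_and_router_bias_guard.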
+class MockDeepSeekV3Config:
+ model_type = "deepseek_v3"
+ n_routed_experts = 8
+ num_experts_per_tok = 2
+ hidden_size = 64
+ moe_intermediate_size = 128
+
+
+class MockDeepSeekV3Expert(nn.Module):
+
+ def __init__(self, hidden_size=64, ffn_hidden=128):
+ super().__init__()
+ self.gate_proj = nn.Linear(hidden_size, ffn_hidden, bias=False)
+ self.up_proj = nn.Linear(hidden_size, ffn_hidden, bias=False)
+ self.down_proj = nn.Linear(ffn_hidden, hidden_size, bias=False)
+
+
+class MockDeepSeekV3MoEBlock(nn.Module):
+
+ def __init__(self, num_experts=8, ffn_hidden=128, hidden_size=64):
+ super().__init__()
+ self.gate = nn.Linear(hidden_size, num_experts, bias=False)
+ self.experts = nn.ModuleList([MockDeepSeekV3Expert(hidden_size, ffn_hidden) for _ in range(num_experts)])
+ self.shared_experts = MockSharedExpert(hidden_size)
+
+
+class MockDeepSeekV3Transformer(nn.Module):
+
+ def __init__(self, num_layers=2, num_experts=8):
+ super().__init__()
+ self.config = MockDeepSeekV3Config()
+ self.config.n_routed_experts = num_experts
+ self.model = nn.Module()
+ self.model.layers = nn.ModuleList([self._make_layer(num_experts) for _ in range(num_layers)])
+
+ @staticmethod
+ def _make_layer(num_experts):
+ layer = nn.Module()
+ layer.mlp = MockDeepSeekV3MoEBlock(num_experts)
+ return layer
+
+
+class TestAutoEPConfig:
+
+ def test_parse_and_validate_enabled_size_contract(self):
+ disabled = parse_autoep_config({})
+ assert disabled.enabled is False
+ assert disabled.autoep_size == 1
+ assert disabled.load_balance_coeff is None
+ assert disabled._load_balance_coeff_explicit is False
+
+ config = parse_autoep_config({
+ "enabled": True,
+ "autoep_size": 4,
+ "preset_model": "mixtral",
+ "load_balance_coeff": None,
+ "score_apply": "pre",
+ "route_scale": 2.0,
+ })
+
+ assert config.enabled is True
+ assert config.autoep_size == 4
+ assert config.preset_model == "mixtral"
+ assert config.load_balance_coeff is None
+ assert config._load_balance_coeff_explicit is True
+ assert config.score_apply == "pre"
+ assert config.route_scale == 2.0
+ validate_autoep_config(config, world_size=4, pp_size=1, tp_size=1, sp_size=1)
+
+ @pytest.mark.parametrize("value", UNSUPPORTED_LOAD_BALANCE_VALUES)
+ def test_load_balance_coeff_rejected_at_parse(self, value):
+ with pytest.raises(ValueError) as exc_info:
+ parse_autoep_config({"enabled": True, "load_balance_coeff": value})
+ assert_load_balance_coeff_rejection_message(exc_info.value, value)
+
+ @pytest.mark.parametrize("enabled", [True, False])
+ @pytest.mark.parametrize("value", [0.01, False, "0.01"])
+ def test_load_balance_coeff_rejected_by_validate(self, enabled, value):
+ config = AutoEPConfig(enabled=enabled, load_balance_coeff=value)
+
+ with pytest.raises(ValueError) as exc_info:
+ validate_autoep_config(config, world_size=1, pp_size=1, tp_size=1, sp_size=1)
+ assert_load_balance_coeff_rejection_message(exc_info.value, value)
+
+ def test_ep_size_validation_rejects_invalid_topology(self):
+ with pytest.raises(ValueError, match="AutoTP"):
+ validate_autoep_config(AutoEPConfig(enabled=True, autoep_size=2),
+ world_size=8,
+ pp_size=1,
+ tp_size=2,
+ sp_size=1)
+ with pytest.raises(ValueError, match="must divide the stage size"):
+ validate_autoep_config(AutoEPConfig(enabled=True, autoep_size=3),
+ world_size=8,
+ pp_size=1,
+ tp_size=1,
+ sp_size=1)
+ with pytest.raises(ValueError, match="exceeds num_experts"):
+ validate_autoep_post_detection(AutoEPConfig(enabled=True, autoep_size=16), [_make_spec(num_experts=8)])
+
+ def test_configure_expert_parallel_uses_engine_mpu_sequence_parallel_size(self, monkeypatch):
+
+ class SequenceParallelMPU:
+
+ def get_sequence_parallel_world_size(self):
+ return 2
+
+ class EmptyAutoEP:
+
+ def __init__(self, model, config):
+ pass
+
+ def ep_parser(self):
+ return []
+
+ observed = {}
+
+ def record_validate(config, world_size, pp_size, tp_size, sp_size):
+ observed["validate"] = {
+ "world_size": world_size,
+ "pp_size": pp_size,
+ "tp_size": tp_size,
+ "sp_size": sp_size,
+ }
+
+ def record_create(**kwargs):
+ observed["create"] = kwargs
+
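+        # Stub out group/process-topology plumbing so _configure_expert_parallel runs single-process.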
+ monkeypatch.setattr(groups, "mpu", None)
+ monkeypatch.setattr(groups, "_get_sequence_parallel_world_size", lambda: 1)
+ monkeypatch.setattr(groups, "_create_expert_and_data_parallel", record_create)
+ monkeypatch.setattr(groups, "_get_expert_parallel_group", lambda name: object())
+ monkeypatch.setattr(ds_engine.dist, "get_world_size", lambda: 4)
+ monkeypatch.setattr(ds_engine.dist, "get_rank", lambda group=None: 0)
+ monkeypatch.setattr("deepspeed.module_inject.auto_ep.AutoEP", EmptyAutoEP)
+ monkeypatch.setattr("deepspeed.module_inject.auto_ep_config.validate_autoep_config", record_validate)
+
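+        # Build a bare engine (skipping __init__) with only the attributes the method under test reads.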
+ engine = object.__new__(DeepSpeedEngine)
+ engine.mpu = SequenceParallelMPU()
+ engine._config = SimpleNamespace(
+ expert_parallel_config=AutoEPConfig(enabled=True, autoep_size=2),
+ tensor_parallel_config=SimpleNamespace(autotp_size=1),
+ use_data_before_expert_parallel_=False,
+ )
+
+ engine._configure_expert_parallel(model=nn.Module())
+
+ assert groups.mpu is None
+ assert observed["validate"]["sp_size"] == 2
+ assert observed["create"]["mp_size"] == 2
+ assert observed["create"]["mp_mode"] == "sp"
+
+ def test_autoep_sequence_parallel_size_falls_back_to_groups_helper(self, monkeypatch):
+ monkeypatch.setattr(groups, "_get_sequence_parallel_world_size", lambda: 3)
+
+ engine = object.__new__(DeepSpeedEngine)
+ engine.mpu = object()
+
+ assert engine._autoep_sequence_parallel_world_size() == 3
+
+ def test_preset_registry_core_contracts(self):
+ assert set(PRESET_MODELS) == {"mixtral", "qwen3_moe", "qwen3_5_moe", "deepseek_v2", "deepseek_v3", "llama4"}
+ assert preset_name_for_hf_model_type("mixtral") == "mixtral"
+ assert preset_name_for_hf_model_type("qwen2_moe") == "qwen3_moe"
+ assert preset_name_for_hf_model_type("llama4_text") == "llama4"
+
+ qwen35 = unsupported_preset_for_hf_model_type("qwen3_5_moe")
+ assert qwen35 is not None
+ assert "qwen3_5_moe_text" in qwen35[1].unsupported_hf_model_type_notes["qwen3_5_moe"]
+ assert PRESET_MODELS["deepseek_v2"].supports_expert_bias is False
+ assert PRESET_MODELS["deepseek_v3"].unsupported_router_bias_names == ("e_score_correction_bias", )
+ assert PRESET_MODELS["llama4"].has_shared_experts is True
+
+ @pytest.mark.parametrize("value", ["2.5", True, float("nan"), float("inf")])
+ def test_invalid_routed_scaling_factor_rejected(self, value):
+ with pytest.raises(ValueError, match="routed_scaling_factor"):
+ _resolve_route_scale(AutoEPConfig(enabled=True, routed_scaling_factor=value), None)
+
+
+class TestRoutingAndLayerSemantics:
+
+ def test_router_route_scale_and_group_limited_routing(self):
+ base = TokenChoiceTopKRouter(64, 8, 4, 2, 2, "softmax", False, 1.0, False)
+ scaled = TokenChoiceTopKRouter(64, 8, 4, 2, 2, "softmax", False, 2.5, False)
+ scaled.load_state_dict(base.state_dict())
+ x = torch.randn(50, 64)
+
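+        # Identical router weights, different route_scale (1.0 vs 2.5): expert selection
+        # must match and the top scores must scale linearly.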
+ base_scores, base_experts, base_counts = base(x)
+ scaled_scores, scaled_experts, scaled_counts = scaled(x)
+
+ assert torch.equal(scaled_experts, base_experts)
+ assert torch.allclose(scaled_scores, base_scores * 2.5, atol=1e-5)
+ assert torch.equal(scaled_counts, base_counts)
+ assert base_counts.shape == (8, )
+
+ def test_grouped_experts_and_token_reorderer(self):
+ experts = GroupedExperts(dim=64, hidden_dim=128, num_experts=4, use_grouped_mm=False)
+ nn.init.normal_(experts.w1, std=0.02)
+ nn.init.normal_(experts.w2, std=0.02)
+ nn.init.normal_(experts.w3, std=0.02)
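+        # Route 8 tokens evenly, 2 per expert, per the per-expert counts tensor below.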
+ out = experts(torch.randn(8, 64), torch.tensor([2, 2, 2, 2]))
+ assert out.shape == (8, 64)
+ assert not torch.isnan(out).any()
+
+ top_scores = torch.randn(20, 2)
+ selected_experts = torch.randint(0, 4, (20, 2))
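+        # 20 tokens x top_k=2 -> 40 routed entries; indices_sorted must be a permutation of 0..39.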
+ scores_sorted, indices_sorted, counts = TokenReorderer(num_experts=4, top_k=2)(top_scores, selected_experts)
+ assert scores_sorted.shape == (40, )
+ assert set(indices_sorted.tolist()) == set(range(40))
+ assert torch.equal(counts, torch.bincount(selected_experts.reshape(-1), minlength=4).to(counts.dtype))
+
+ def test_score_application_and_combine(self):
+ x = torch.randn(4, 8)
+ scores = torch.tensor([0.25, 0.5, 0.75, 1.0])
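+        # "pre" mode multiplies each token's activations by its routing score before the experts run.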
+ expected = x.float() * scores.reshape(-1, 1)
+ torch.testing.assert_close(apply_scores_before_experts_if_enabled(x, scores, "pre"), expected.to(x.dtype))
+
+ spec = _make_spec(score_apply="post")
+ assert resolve_score_apply_mode(spec, "auto") == "post"
+ expert_output = torch.ones(4, 8)
+ top_scores = torch.tensor([[0.6, 0.4], [0.7, 0.3]])
+ out = combine_from_routed(expert_output, top_scores, torch.arange(4), 2, "post", "weighted_sum", (1, 2, 8))
+ torch.testing.assert_close(out[0, 0], torch.ones(8))
+
+ def test_autoep_layer_forward_and_expert_bias_rejection(self):
+ source = MockMoEBlock(num_experts=4, ffn_hidden=128, hidden_size=64)
+ layer = AutoEPMoELayer(_make_spec(route_scale=2.5),
+ source,
+ ep_size=1,
+ ep_rank=0,
+ config=_runtime_config(enabled=True, autoep_size=1))
+ out = layer(torch.randn(2, 8, 64))
+ assert layer._is_autoep_layer is True
+ assert layer.num_experts == 4
+ assert layer.router.route_scale == pytest.approx(2.5)
+ assert out.shape == (2, 8, 64)
+ assert not torch.isnan(out).any()
+
+ with pytest.raises(ValueError, match="load_balance_coeff/expert_bias"):
+ AutoEPMoELayer(_make_spec(model_family="no_bias_family", supports_expert_bias=False),
+ source,
+ ep_size=1,
+ ep_rank=0,
+ config=AutoEPConfig(enabled=True, autoep_size=1, load_balance_coeff=0.02))
+
+
+class TestModelDetectionAndReplacement:
+
+ def test_mixtral_detect_replace_and_mock_forward(self):
+ model = MockMoETransformer(num_layers=2, moe_every_n=1)
+ auto_ep = AutoEP(model, _runtime_config(enabled=True, autoep_size=1, preset_model="mixtral"))
+ specs = auto_ep.ep_parser()
+
+ assert len(specs) == 2
+ assert specs[0].model_family == "mixtral"
+ auto_ep.replace_moe_layer(specs[0], ep_size=1, ep_rank=0)
+ assert isinstance(model.model.layers[0].mlp, AutoEPMoELayer)
+ assert model(torch.randn(1, 4, 64)).shape == (1, 4, 100)
+
+ def test_hf_mixtral_causal_lm_matches_autoep_with_router_logits(self):
+ transformers = pytest.importorskip("transformers")
+ skip_unless_transformers_has(transformers,
+ "MixtralConfig",
+ "MixtralForCausalLM",
+ min_version="5.0.0",
+ reason="Mixtral AutoEP router-logit capture")
+
+ torch.manual_seed(1234)
+ config = tiny_mixtral_config(transformers)
+ native_model, autoep_model = state_matched_models(transformers.MixtralForCausalLM, config)
+ replace_autoep_layers(autoep_model, "mixtral")
+ assert_causal_lm_outputs_close(native_model,
+ autoep_model,
+ output_router_logits=True,
+ compare_router_logits=True,
+ compare_aux_loss=True,
+ compare_logits=False)
+
+ def test_llama4_detection_and_repack_contract(self):
+ model = MockLlama4Transformer(num_layers=1, num_experts=8)
+ specs = AutoEP(model, _runtime_config(enabled=True, autoep_size=2, preset_model="llama4")).ep_parser()
+ spec = specs[0]
+
+ assert spec.model_family == "llama4"
+ assert spec.score_apply == "pre"
+ assert spec.return_router_logits is True
+ assert spec.moe_output_shape == "flat"
+ assert spec.has_shared_experts is True
+
+ experts = model.model.layers[0].feed_forward.experts
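+        # ep_size=2 leaves 8 / 2 = 4 local experts on rank 0; the fused gate_up weight is split into w1 and w3.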
+ w1, w2, w3 = repack_expert_weights(experts, spec, ep_rank=0, ep_size=2)
+ assert w1.shape == (4, 128, 64)
+ assert w2.shape == (4, 64, 128)
+ assert w3.shape == (4, 128, 64)
+
+ def test_hf_llama4_autoep_direct_moe_returns_flat_contract(self):
+ transformers = pytest.importorskip("transformers")
+ skip_unless_transformers_has(transformers,
+ "Llama4ForCausalLM",
+ "Llama4TextConfig",
+ min_version="5.0.0",
+ reason="Llama4 AutoEP preset")
+ config = tiny_llama4_text_config(transformers)
+ native_model, autoep_model = state_matched_models(transformers.Llama4ForCausalLM, config)
+ replace_autoep_layers(autoep_model, "llama4")
+ native_moe = native_model.model.layers[0].feed_forward
+ autoep_moe = next(module for module in autoep_model.modules() if isinstance(module, AutoEPMoELayer))
+ hidden_states = torch.randn(2, 5, 32)
+
+ with torch.no_grad():
+ native_output, native_router_logits = native_moe(hidden_states)
+ autoep_output, autoep_router_logits = autoep_moe(hidden_states)
+
+ assert autoep_output.shape == (10, 32)
+ assert autoep_router_logits.shape == (10, 4)
+ torch.testing.assert_close(autoep_output, native_output, rtol=1e-5, atol=1e-6)
+ torch.testing.assert_close(autoep_router_logits, native_router_logits, rtol=1e-5, atol=1e-6)
+
+ def test_qwen_adapter_guards(self, monkeypatch):
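+        # Pretend transformers 5.0.0 is installed so the preset adapter's version guard passes.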
+ monkeypatch.setattr(get_preset_adapter("qwen3_moe"), "_installed_transformers_version", lambda: "5.0.0")
+ model = MockMoETransformer(num_layers=1, num_experts=4, moe_every_n=1)
+ model.config.model_type = "qwen2_moe"
+ model.config.num_experts = model.config.num_local_experts
+
+ specs = AutoEP(model, _runtime_config(enabled=True, autoep_size=1)).ep_parser()
+
+ assert len(specs) == 1
+ assert specs[0].model_family == "qwen3_moe"
+
+ model.config.model_type = "qwen3_5_moe"
+ with pytest.raises(ValueError, match="qwen3_5_moe_text"):
+ AutoEP(model, _runtime_config(enabled=True, autoep_size=1))._resolve_presets()
+
+ def test_deepseek_v3_detection_and_router_bias_guard(self, monkeypatch):
+ monkeypatch.setattr(get_preset_adapter("deepseek_v3"), "_installed_transformers_version", lambda: "5.0.0")
+ model = MockDeepSeekV3Transformer(num_layers=1, num_experts=8)
+ auto_ep = AutoEP(model, _runtime_config(enabled=True, autoep_size=2))
+ specs = auto_ep.ep_parser()
+
+ assert len(specs) == 1
+ assert specs[0].model_family == "deepseek_v3"
+ assert specs[0].expert_storage == "module_list"
+ assert specs[0].expert_w1_name == "gate_proj"
+ assert specs[0].has_shared_experts is True
+
+ model.model.layers[0].mlp.gate.register_buffer("e_score_correction_bias", torch.ones(8))
+ with pytest.raises(ValueError, match="e_score_correction_bias"):
+ auto_ep.replace_moe_layer(specs[0], ep_size=2, ep_rank=0)
diff --git a/tests/unit/moe/test_fused_expert_layout.py b/tests/unit/moe/test_fused_expert_layout.py
new file mode 100644
index 000000000000..69b90e848ee5
--- /dev/null
+++ b/tests/unit/moe/test_fused_expert_layout.py
@@ -0,0 +1,38 @@
+# Copyright (c) DeepSpeed Team.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from deepspeed.moe.fused_expert_layout import classify_fused_gate_up_layout
+
+
+def test_classifies_gate_up_first():
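+    # gate_up weight (experts, 2*ffn, hidden) = (8, 256, 64); down weight (experts, hidden, ffn): the fused dimension leads.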
+ layout = classify_fused_gate_up_layout((8, 256, 64), (8, 64, 128))
+
+ assert layout is not None
+ assert layout.layout == "gate_up_first"
+ assert layout.hidden_size == 64
+ assert layout.ffn_hidden_size == 128
+ assert layout.needs_transpose is False
+
+
+def test_classifies_hidden_first():
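+    # gate_up weight (experts, hidden, 2*ffn) = (8, 64, 256); down weight (experts, ffn, hidden): hidden leads, so repacking needs a transpose.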
+ layout = classify_fused_gate_up_layout((8, 64, 256), (8, 128, 64))
+
+ assert layout is not None
+ assert layout.layout == "hidden_first"
+ assert layout.hidden_size == 64
+ assert layout.ffn_hidden_size == 128
+ assert layout.needs_transpose is True
+
+
+def test_returns_none_for_unknown_layout():
+ assert classify_fused_gate_up_layout((8, 64, 64), (8, 64, 64)) is None
+
+
+def test_returns_none_for_odd_inner_dim():
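+    # 255 cannot be split into equal gate/up halves, so no layout is recognised.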
+ assert classify_fused_gate_up_layout((8, 255, 64), (8, 64, 127)) is None
+
+
+def test_returns_none_for_non_3d():
+ assert classify_fused_gate_up_layout((64, 64), (64, 64)) is None