[JAX] Expert Parallelism: JAX primitives + VJPs#3036

Open
phu0ngng wants to merge 9 commits into NVIDIA:main from phu0ngng:phuong/ep-3-jax

Conversation

Collaborator

@phu0ngng phu0ngng commented May 22, 2026

Summary

Third PR in the TE Expert Parallelism (EP) series, built on top of #3034. Lands the JAX bindings: an XLA FFI layer over the nvte_ep_* C API, a Python wrapper with custom_vjp for autograd, mesh-aware sharding rules, a multi-process test suite, and an end-to-end MoE example. NCCL ncclEpDispatch/ncclEpCombine are exposed as XLA primitives and work with CUDA-graph capture.

Implementation

Public Python API (transformer_engine/jax/ep.py)

from transformer_engine.jax.ep import (
    EpHandle,        # opaque (id, handle_mem) pair from ep_prepare
    ep_bootstrap,    # one-shot per-process: init NCCL comm + nvte_ep_initialize
    ep_dispatch,     # custom_vjp-wrapped dispatch
    ep_combine,      # custom_vjp-wrapped combine
)

ep_dispatch / ep_combine are jax.custom_vjp functions: forward is the FFI primitive, backward calls the matching nvte_ep_*_bwd FFI primitive directly (no ep_prepare in the bwd — routing state is already cached in handle.mem). Note that ep_dispatch also calls ep_prepare in the forward path, which all-gathers and preprocesses routing maps.
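
To make the forward/backward split above concrete, here is a minimal toy sketch in plain Python. It is not the TE API: the function names, the list-based "buffers", and the routing representation are all assumptions made for illustration. It only shows the pattern described: the forward pass prepares routing state and keeps it as a residual, and the backward pass reuses that cached state instead of preparing again.

```python
# Toy model only: Python lists stand in for device buffers, and none of these
# names exist in transformer_engine.


def toy_prepare(topk_idx, num_experts):
    """Build routing state once: for each expert, the token indices it receives."""
    routes = [[] for _ in range(num_experts)]
    for token, experts in enumerate(topk_idx):
        for expert in experts:
            routes[expert].append(token)
    return routes


def toy_dispatch_fwd(tokens, topk_idx, num_experts):
    """Forward: prepare routing, gather tokens per expert, keep routing as residual."""
    routes = toy_prepare(topk_idx, num_experts)
    recv = [[tokens[t] for t in expert_tokens] for expert_tokens in routes]
    return recv, (routes, len(tokens))


def toy_dispatch_bwd(residual, grad_recv):
    """Backward: scatter-add cotangents using the cached routing (no re-prepare)."""
    routes, num_tokens = residual
    grad_tokens = [0.0] * num_tokens
    for expert_tokens, expert_grads in zip(routes, grad_recv):
        for token, grad in zip(expert_tokens, expert_grads):
            grad_tokens[token] += grad
    return grad_tokens


tokens = [1.0, 2.0, 3.0]
topk_idx = [[0, 1], [1, 2], [0, 2]]  # each token is routed to two experts
recv, residual = toy_dispatch_fwd(tokens, topk_idx, num_experts=3)
grad_recv = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
grad_tokens = toy_dispatch_bwd(residual, grad_recv)
```

With all-ones cotangents, each token's gradient equals the number of experts it was sent to, which is the scatter-add behaviour one would expect from the transpose of a gather.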

XLA FFI bindings (transformer_engine/jax/csrc/extensions/ep.cpp)

Five XLA_FFI_DEFINE_HANDLER_SYMBOL entries — EpPrepareHandler, EpDispatchHandler, EpCombineHandler, EpDispatchBwdHandler, EpCombineBwdHandler — each calling the corresponding nvte_ep_* C entry point. All marked FFI_CudaGraph_Traits so they capture cleanly. handle_id is a static FFI attribute baked at jit trace time.

Primitives + Python layer (transformer_engine/jax/cpp_extensions/ep.py, +951 lines)

Standard TE primitive plumbing: abstract_eval (shape/dtype inference), lowering, impl, outer_primitive registration, and partitioning rules so the EP collective is treated as a single sharded op by XLA (no spurious resharding around it).

Sharding (transformer_engine/jax/sharding.py, +12 lines)

Adds the EP mesh axis to the global mesh resource set so downstream sharding rules can reference it.

Build wiring (build_tools/jax.py, +41 lines)

Threads NCCL EP linkage through the JAX transformer_engine_jax extension. No new top-level build flags; rides on the parent PR's NVTE_BUILD_WITH_NCCL_EP.

Tests & example

  • tests/jax/test_multi_process_ep.py (+690 lines): 13 tests covering bootstrap, ep_prepare shape/handle contracts, primitive-level dispatch/combine identity (uniform + skewed routing), custom_vjp fwd+bwd correctness, and HLO inspection (must not insert XLA collectives outside the EP FFI).
  • tests/jax/multi_process_launch_ep.sh: 4-rank launcher; sets XLA_FLAGS to keep XLA command-buffer capture off for the EP FFI sequence (NCCL EP graph-destroy interaction).
  • examples/jax/ep/ep_moe.py (+394 lines) + run_test_ep.sh: end-to-end MoE with EP, dp=ep=2 mesh, includes a ref-comparison --check that verifies fwd+bwd vs a single-process reference.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Contributor

greptile-apps Bot commented May 22, 2026

Greptile Summary

This PR lands the JAX bindings for Expert Parallelism: XLA FFI handlers over the nvte_ep_* C API, jax.custom_vjp-wrapped ep_dispatch/ep_combine with SPMD partitioning rules, mesh-aware sharding via a new ep_resource axis, build wiring for the NCCL EP submodule, a 13-test multi-process suite, and an end-to-end MoE example.

  • Core EP ops are exposed as XLA FFI primitives with FFI_CudaGraph_Traits, proper abstract_eval/lowering/partition plumbing, and sharding constraints re-pinned in backward passes to prevent XLA transpose from dropping the EP axis.
  • Bootstrap (ep_bootstrap) eagerly initialises the NCCL communicator after a UID all-gather; pre-flight guards for divisibility and single-device-per-process all use raise ValueError.
  • Build integration adds conditional NCCL EP linkage via NVTE_BUILD_WITH_NCCL_EP; ep_api_stub.cpp provides throwing stubs for non-EP builds.

Confidence Score: 4/5

Safe to merge after adding a dtype guard for topk_weights; without it, mixed-precision MoE training silently produces wrong routing weight values with no error or warning.

The topk_weights dtype assumption is an active data-corruption path in mixed-precision MoE training: the C++ FFI hardcodes DType::kFloat32 without checking the buffer element type, and the Python abstract eval deletes the aval before any inspection. Everything else — bootstrap validation, sharding rules, VJP correctness, build wiring — is solid.

transformer_engine/jax/csrc/extensions/ep.cpp (EpDispatchFFI topk_weights wrapper) and transformer_engine/jax/cpp_extensions/ep.py (EpDispatchPrimitive.abstract) need coordinated dtype guards before the DType::kFloat32 assumption.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/jax/csrc/extensions/ep.cpp | Five XLA FFI handlers wrapping nvte_ep_* C API; topk_weights unconditionally wrapped as DType::kFloat32 without element-type guard, handle_mem mutated in-place while declared read-only Arg. |
| transformer_engine/jax/cpp_extensions/ep.py | JAX primitives with abstract eval/lowering/partitioning for all five EP ops; topk_weights_aval deleted without dtype check in EpDispatchPrimitive.abstract. |
| transformer_engine/jax/ep.py | Public Python EP API with custom_vjp wrappers and sharding re-pinning in bwd; overly-broad except clause in _allgather_uid. |
| build_tools/jax.py | Adds NCCL EP linkage; arch guard fires only for explicitly-listed sub-90 arches, silent when NVTE_CUDA_ARCHS is unset. |
| transformer_engine/jax/sharding.py | Adds ep_resource to MeshResource and ep_axis_size() helper; minimal and correct. |
| transformer_engine/common/ep/ep_api.cpp | Thin C API delegation to EPBackend with null-checks on handle.mem before every call. |
| transformer_engine/common/ep/ep_api_stub.cpp | Throwing stubs for non-EP builds; nvte_ep_shutdown correctly a no-op. |
| tests/jax/test_multi_process_ep.py | 13-test multi-process suite covering bootstrap, shape contracts, fwd+bwd correctness, and HLO inspection. |
| examples/jax/ep/ep_moe.py | End-to-end MoE example with dp=ep=2 mesh and reference comparison check for fwd+bwd correctness. |

Reviews (7): Last reviewed commit: "jax/ep: introduce per-layer EpHandle, dr..."

Comment thread build_tools/jax.py
Comment thread build_tools/jax.py
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/ep.cpp Outdated
Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type token_counts,
                        Result_Type handle_mem, Result_Type workspace, EpPrepareConfig config) {
  auto topk_dims = topk_idx.dimensions();
  NVTE_CHECK(topk_dims.size() >= 2,
Collaborator

nit: can we return an FFI InvalidArgument error instead of using NVTE_CHECK for these inputs?

Collaborator Author

This is probably a good idea. I suggest we make another follow-up MR to do so for all the FFIs.

@phu0ngng phu0ngng requested a review from tdophung May 22, 2026 15:51
@phu0ngng
Collaborator Author

I would appreciate your help reviewing this PR, @tdophung @jberchtold-nvidia!
Please focus on the changes on the JAX side, as the TE/Common ones will be discussed in #3034.

Comment thread examples/jax/ep/ep_moe.py Outdated
Comment thread tests/jax/multi_process_launch_ep.sh Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread examples/jax/ep/ep_moe.py
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Comment thread transformer_engine/jax/ep.py Outdated
Comment thread transformer_engine/jax/ep.py Outdated
Comment on lines +81 to +82
assert ret == 0, f"ncclGetUniqueId failed with code {ret}"
uid_bytes = bytes(uid_arr)
Contributor

P1 assert disabled by -O in ctypes UID path

assert ret == 0 is silently elided when Python runs under the -O optimisation flag (common in production or Numba/Conda environments). If ncclGetUniqueId fails, uid_bytes would be all zeros; the all-gather propagates those zeros to every rank in the EP group, causing ncclCommInitRank to either produce mismatched communicators or hang indefinitely with no diagnostic message.

Suggested change
- assert ret == 0, f"ncclGetUniqueId failed with code {ret}"
- uid_bytes = bytes(uid_arr)
+ ret = libnccl.ncclGetUniqueId(ctypes.cast(uid_arr, ctypes.c_void_p))
+ if ret != 0:
+     raise RuntimeError(f"ncclGetUniqueId failed with code {ret}")
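
The point about `assert` being stripped under `-O` can be illustrated with a small standalone sketch. The helper name `check_nccl_result` is hypothetical and not part of this PR; it only shows an explicit check that behaves the same whether or not Python runs with optimisation enabled.

```python
def check_nccl_result(ret, call_name):
    """Raise on a non-zero return code.

    Unlike `assert ret == 0`, this check is not compiled away when Python runs
    with -O (where __debug__ is False and assert statements are removed).
    """
    if ret != 0:
        raise RuntimeError(f"{call_name} failed with code {ret}")


# A zero return code passes silently.
check_nccl_result(0, "ncclGetUniqueId")

# A non-zero code always raises, regardless of interpreter flags.
try:
    check_nccl_result(2, "ncclGetUniqueId")
except RuntimeError as err:
    message = str(err)
```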

phu0ngng added 4 commits May 23, 2026 19:36
…em_reloc gating

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…s, MoE example)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment


LGTM pending CI

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
ep_size,
num_experts,
max_tokens_per_rank,
recv_capacity_per_rank,
Collaborator

In the ep.h definition of NVTEEpGroupConfig, you used max_recv_tokens_per_rank instead of this name. For consistency, maybe we should use the same name in both places?

f"ep_bootstrap requires world_size >= 4 (got {world_size}); NCCL EP requires"
" at least 4 ranks on the node for its HT mode."
)
UID_SIZE = 128
Collaborator

nit: after looking into the NCCL headers I figured out what this is. However, it might be helpful to have an inline comment saying what it is, like # NCCL_UNIQUE_ID_BYTES from nccl.h to store host name, listening port, etc.

def _dispatch_bwd(recv_capacity_per_rank, dispatch_output_per_expert_alignment, res, g_outputs):
    del recv_capacity_per_rank, dispatch_output_per_expert_alignment
    handle, out_leading, top_k = res
    # Re-pin cotangent sharding: XLA transpose can drop the EP axis on a
Collaborator

I now understand that the sharding for other fwd-output cotangents can be lost when propagated to bwd. But is this a fault of JAX that we should ask to be fixed? Did you write this defensively because you ran into a bug where it was trying to read the entire global tensor?

Comment on lines +890 to +891
f = _sys._getframe(1)
cache_key = (f.f_code.co_filename, f.f_lineno, top_k, alignment)
Collaborator

Would the public ep_dispatch wrapper interact badly with this? Every ep_dispatch(...) call in a user's program ultimately routes through
token_counts, handle = tex.ep_prepare(topk_idx, dispatch_output_per_expert_alignment)
from jax/ep.py:191

so _sys._getframe(1) always sees the same (jax/ep.py, 191). That means a model with multiple MoE layers all using ep_dispatch ends up sharing one handle_id across layers, which would corrupt the cache for the routing state?
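
The concern above can be reproduced in isolation with a short sketch. The function names here are hypothetical stand-ins, not the PR's code; the only real call is `sys._getframe`. A key derived from the immediate caller's file and line is distinct for direct calls from different lines, but collapses to a single value once every call funnels through one wrapper line.

```python
import sys


def prepare_keyed_by_caller(top_k, alignment):
    """Hypothetical stand-in: derive a cache key from the immediate caller's location."""
    frame = sys._getframe(1)
    return (frame.f_code.co_filename, frame.f_lineno, top_k, alignment)


def dispatch_wrapper(top_k, alignment):
    """Hypothetical public wrapper: every user call passes through this one line."""
    return prepare_keyed_by_caller(top_k, alignment)


def two_layer_model():
    key_layer_1 = dispatch_wrapper(top_k=2, alignment=16)  # first "MoE layer"
    key_layer_2 = dispatch_wrapper(top_k=2, alignment=16)  # second "MoE layer"
    return key_layer_1, key_layer_2


# Through the wrapper, both layers see the wrapper's own line, so the keys match.
key_1, key_2 = two_layer_model()
keys_collide = key_1 == key_2

# Called directly from two different lines, the keys differ.
direct_1 = prepare_keyed_by_caller(2, 16)
direct_2 = prepare_keyed_by_caller(2, 16)
direct_keys_differ = direct_1 != direct_2
```

An explicit per-layer identifier passed in by the caller avoids depending on call-site location at all.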

Collaborator

Maybe you don't see it because the test has only 1 layer? I am not entirely sure because I have not tried running a case with multiple layers. This is just from reading the code.

Collaborator Author

This is valid. Thanks for finding it!
I will expose an option so that the user can host the handle_id per layer.

phu0ngng added 2 commits May 29, 2026 11:10
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ache

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Comment on lines +264 to +268
Buffer_Type topk_idx, Buffer_Type tokens, Buffer_Type topk_weights,
Result_Type recv_tokens, Result_Type recv_topk_weights,
Result_Type workspace, EpDispatchConfig config) {
(void)ep_state;
auto token_dims = tokens.dimensions();
Contributor

P1 topk_weights dtype unconditionally assumed float32 — silent data corruption

EpDispatchFFI wraps topk_weights as DType::kFloat32 regardless of the buffer's actual element type. If a caller passes bfloat16 or float16 weights (common in mixed-precision MoE training), the bytes are silently reinterpreted, producing completely wrong routing weights without any error. The Python abstract eval compounds the problem: topk_weights_aval is deleted before any dtype inspection, so JAX tracing also provides no defence. A dtype guard is needed in both the C++ FFI handler and in EpDispatchPrimitive.abstract.
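
As a rough illustration of what "silently reinterpreted" means, the sketch below packs four half-precision values with the standard `struct` module and then reads the same bytes back as float32. It does not touch the FFI or JAX; it only mimics viewing a float16 buffer through a float32 lens, and the weight values are made up.

```python
import struct

# Hypothetical routing weights, all exactly representable in half precision.
weights = [0.5, 0.25, 0.125, 0.125]

# Pack as little-endian IEEE half precision ("e"): 2 bytes per value, 8 bytes total.
raw = struct.pack("<4e", *weights)

# Correct read: interpret the bytes with the element type they were written with.
as_f16 = list(struct.unpack("<4e", raw))

# Wrong read: view the same 8 bytes as float32, which yields only two values,
# neither of which matches the original weights.
as_f32 = list(struct.unpack("<2f", raw))
```

No exception is raised by the wrong read, which is why a dtype guard has to be an explicit check rather than something that falls out of the buffer access.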

Collaborator

Currently our router is fixed to fp32, so there is no chance the weights will be bf16 or f16. If we ever change the router to output more than one dtype, we can change this one too.

arg_shardings = tuple(a.sharding for a in arg_infos)
out_shardings = [
NamedSharding(mesh, PartitionSpec(*resolved)),
NamedSharding(mesh, PartitionSpec(*resolved, None)),
Collaborator

Given that len(resolved) is always num_leading_dims + 1 based on _resolve_out_partition_spec, wouldn't this PartitionSpec have one more dimension than needed (num_leading_dims + 1 + 1)?

maybe we should fix it with:

Suggested change
- NamedSharding(mesh, PartitionSpec(*resolved, None)),
+ NamedSharding(mesh, PartitionSpec(*resolved[:-1], None)),

I notice that in your test script you already clarified that XLA will drop all trailing None in the partitionspec, hence there wasn't any error message for mismatching number of dims. However, why not size it exactly the same number of dimensions as grad_topk_weights_aval?
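
The length arithmetic in this thread can be checked with plain tuples. This is a sketch under the reviewer's stated assumption that `resolved` has num_leading_dims + 1 entries; the axis name "ep" and the helper `drop_trailing_none` are hypothetical, and plain tuples stand in for PartitionSpec.

```python
num_leading_dims = 2

# Stand-in for the resolved spec: one entry per leading dim plus one trailing entry.
resolved = ("ep", None, None)

as_written = (*resolved, None)          # num_leading_dims + 2 entries
as_suggested = (*resolved[:-1], None)   # num_leading_dims + 1 entries


def drop_trailing_none(spec):
    """Mimic the behaviour mentioned above, where trailing None entries are dropped."""
    spec = list(spec)
    while spec and spec[-1] is None:
        spec.pop()
    return tuple(spec)


# Once trailing None entries are dropped, the two specs are indistinguishable,
# which would explain why the extra entry produced no error.
same_after_drop = drop_trailing_none(as_written) == drop_trailing_none(as_suggested)
```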

4 participants