A proposal for RDMA weight streaming: a master node reads model weights from disk and broadcasts only the slices each rank needs via all_sum, eliminating the requirement to store the full model on every node. This builds on the existing mlx_lm.share and sharded_load infrastructure.
Motivation
Currently, distributed inference in MLX requires the full model to be stored on every participating node:
- **TP (Tensor Parallel):** `sharded_load` downloads all weight files to every node, then `model.shard()` selects the needed slices at eval time. Each node stores the full model on disk even though it only uses 1/N of the weights in memory.
- **PP (Pipeline Parallel):** `sharded_load` already optimizes this — it downloads only the files needed for the local pipeline stage. This is the closest to what we're proposing.
- **EP (Expert Parallel, MoE):** No built-in support. Each node stores and loads the full model.
For large models this is a significant constraint:
| Model | FP16 Size | 4-Node TP4 Disk Usage (current) | With Weight Streaming |
|---|---|---|---|
| Llama 405B | 810 GB | 3.2 TB (810 GB × 4) | 810 GB (master only) |
| DeepSeek V3 | 670 GB | 2.7 TB | 670 GB |
| Kimi K2.5 | 612 GB | 2.4 TB | 612 GB |
With mlx_lm.share demonstrating 5+ GB/s broadcast throughput over Thunderbolt 5 RDMA, streaming weights at startup is fast enough to be practical.
What Exists Today
| Component | Status | What It Does |
|---|---|---|
| `mlx_lm.share` | Working | Broadcasts entire directories via `all_sum` at 5+ GB/s |
| `sharded_load` (PP) | Working | Downloads only the files each pipeline stage needs |
| `sharded_load` (TP) | Working but redundant | Downloads full model to every node, shards in memory |
| `model.shard()` | Working | Knows how to slice weight matrices for TP |
| `model.pipeline()` | Working | Knows which layers belong to which pipeline stage |
| MoE expert placement | Not implemented | No built-in expert-level distribution |
Proposed Approach
Phase 1: TP Weight Streaming (Low-Hanging Fruit)
The master node (rank 0) reads each safetensor file, and instead of broadcasting the entire file, slices the weight matrices according to the TP sharding plan and broadcasts only the relevant slice to each rank.
Current TP flow:
Every node: read full model from disk → shard() → eval (loads 1/N into RAM)
Disk: N copies of full model
Proposed TP flow:
Rank 0: read full model from disk → slice weights → broadcast slices via all_sum
Rank 1-N: receive only their slices → load into RAM
Disk: 1 copy on rank 0 only
This requires:
1. Rank 0 inspects the model's shard plan (which `model.shard()` already computes).
2. For each weight file, rank 0 reads it and extracts each rank's slice.
3. Rank 0 broadcasts each slice via `all_sum` (rank 0 sends real data, others send zeros — same pattern as `mlx_lm.share`).
4. Each rank receives only its portion and loads it directly into the model.
The safetensors.index.json file already maps weight names to files, and model.shard() already knows the slicing dimensions. The main engineering work is connecting these two systems.
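The steps above can be sketched as a minimal in-process simulation. NumPy stands in for `mx.distributed`, the ranks are loop iterations, and the shard plan is reduced to a single split axis; only the zero-contribution `all_sum` pattern is faithful to the real system:

```python
import numpy as np

def simulate_tp_stream(full_weight, num_ranks, shard_dim=0):
    """Simulate rank 0 streaming TP slices with the zero-contribution
    all_sum trick: rank 0 contributes real data, everyone else zeros."""
    slices = np.split(full_weight, num_ranks, axis=shard_dim)
    received = []
    for r in range(num_ranks):
        # Every rank participates in the collective for slice r.
        contributions = [
            slices[r] if sender == 0 else np.zeros_like(slices[r])
            for sender in range(num_ranks)
        ]
        # all_sum: since only rank 0 sent non-zeros, the sum IS the slice.
        summed = np.sum(contributions, axis=0)
        received.append(summed)  # rank r keeps this result; others discard it
    return received

w = np.arange(16, dtype=np.float32).reshape(4, 4)
parts = simulate_tp_stream(w, num_ranks=4, shard_dim=0)
# Each simulated rank ends up holding exactly its 1/N slice of the weight.
assert all(np.array_equal(p, s) for p, s in zip(parts, np.split(w, 4, axis=0)))
```

In the real system each loop body is one collective over RDMA, so each slice transits the wire once per broadcast round rather than being summed in memory.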
Phase 2: PP Weight Streaming
PP is partially solved — sharded_load already knows which files each stage needs and only downloads those. The missing piece is sourcing files from a master node via RDMA instead of from disk/HuggingFace:
Current PP flow:
Each node: download only needed files from HF → load
Disk: partial model per node (already efficient)
Proposed PP flow:
Rank 0: read files from disk → broadcast only the files each stage needs
Rank 1-N: receive only their stage's files → load
Disk: 1 copy on rank 0 only
This is simpler than TP streaming since it operates at the file level (no matrix slicing). sharded_load already computes the file-to-stage mapping.
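A sketch of the file-to-stage mapping this phase relies on, assuming a standard `safetensors.index.json` weight map and a contiguous layer split. `files_for_stage` is a hypothetical helper for illustration, not `sharded_load`'s actual code:

```python
import re

def files_for_stage(weight_map, stage, num_stages, num_layers):
    """Return the weight files one pipeline stage needs, given the
    name -> file mapping from safetensors.index.json."""
    start = stage * num_layers // num_stages
    end = (stage + 1) * num_layers // num_stages
    needed = set()
    for name, fname in weight_map.items():
        m = re.search(r"layers\.(\d+)\.", name)
        if m is None:
            # Non-layer weights (embeddings, final norm, lm_head): the real
            # loader assigns these to first/last stages; kept everywhere here.
            needed.add(fname)
        elif start <= int(m.group(1)) < end:
            needed.add(fname)
    return needed

# Toy index: 4 layers spread over 2 files, split across 2 stages.
weight_map = {
    "model.layers.0.mlp.weight": "model-00001.safetensors",
    "model.layers.1.mlp.weight": "model-00001.safetensors",
    "model.layers.2.mlp.weight": "model-00002.safetensors",
    "model.layers.3.mlp.weight": "model-00002.safetensors",
}
assert files_for_stage(weight_map, 0, 2, 4) == {"model-00001.safetensors"}
assert files_for_stage(weight_map, 1, 2, 4) == {"model-00002.safetensors"}
```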
Phase 3: MoE Expert Placement
This is the most impactful but also most complex. For MoE models (Mixtral, DeepSeek V3, Kimi K2.5), experts are independent weight blocks that can be distributed across nodes:
Proposed EP flow:
Rank 0: read model → distribute expert subsets to each node
Each node: holds 1/N of experts in memory
Inference: router selects active experts → each node runs its local experts → all_sum collects results
Benefits:
A 4-node cluster with 512 GB each can serve a model with up to ~2 TB of expert weights
Only active experts compute per token — idle experts consume zero GPU time
Could extend to demand-paging: evict cold experts, stream hot ones at 5 GB/s
This requires deeper integration with the model architecture — the router needs to know which experts are on which nodes, and the forward pass needs all_sum collectives after expert computation.
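The collective structure of that forward pass can be shown with a NumPy stand-in: the nodes are simulated in-process, `all_sum` becomes a plain sum over per-node partials, and the expert weights and routing are toy values:

```python
import numpy as np

def ep_forward(x, experts, routed_ids, experts_per_node):
    """Simulated expert-parallel forward: each node computes only its
    local experts, contributes zeros for the rest, then all_sum merges."""
    partials = []
    for local_ids in experts_per_node:          # one iteration per "node"
        out = np.zeros_like(x)
        for e in routed_ids:
            if e in local_ids:                  # node holds this expert
                out = out + x @ experts[e]
        partials.append(out)
    return np.sum(partials, axis=0)             # the all_sum collective

rng = np.random.default_rng(0)
experts = [rng.standard_normal((8, 8)) for _ in range(4)]
x = rng.standard_normal(8)
# 2 nodes, 2 experts each; the router picked experts 1 and 2 for this token.
y = ep_forward(x, experts, routed_ids=[1, 2], experts_per_node=[{0, 1}, {2, 3}])
dense = x @ experts[1] + x @ experts[2]         # same result without EP
assert np.allclose(y, dense)
```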
Phase 4: Dynamic Expert Demand-Paging (Speculative)
The most advanced version: experts are streamed on-demand during inference based on routing patterns. Hot experts stay resident, cold experts are evicted, and when a cold expert is needed, it's streamed from the master at 5 GB/s (~50 ms for a typical expert block).
This is speculative and may not be practical for real-time inference, but for batch/offline workloads the latency could be acceptable.
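One plausible shape for the hot/cold policy is an LRU cache over expert blocks. This is a sketch under stated assumptions: the `fetch` callback stands in for the RDMA stream from the master, and `capacity` would be derived from free memory:

```python
from collections import OrderedDict

class ExpertCache:
    """LRU cache of expert weight blocks for demand-paged MoE inference."""
    def __init__(self, capacity, fetch):
        self.capacity = capacity
        self.fetch = fetch                      # streams an expert from master
        self._cache = OrderedDict()             # expert_id -> weights, LRU order

    def get(self, expert_id):
        if expert_id in self._cache:
            self._cache.move_to_end(expert_id)  # mark hot: stays resident
            return self._cache[expert_id]
        weights = self.fetch(expert_id)         # cold miss: ~50 ms at 5 GB/s
        self._cache[expert_id] = weights
        if len(self._cache) > self.capacity:
            self._cache.popitem(last=False)     # evict the coldest expert
        return weights

fetches = []
cache = ExpertCache(capacity=2, fetch=lambda e: fetches.append(e) or f"W{e}")
cache.get(0); cache.get(1); cache.get(0)  # expert 0 is hot, no refetch
cache.get(2)                              # evicts expert 1 (coldest)
cache.get(1)                              # cold again: must re-stream
assert fetches == [0, 1, 2, 1]
```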
Benefits Beyond Disk Space
Model switching without pre-staging — swap from one model to another by streaming from the master node. No need to copy hundreds of GB to every node first.
Serve models larger than any single node's storage — a node with 1 TB NVMe can participate in serving a 2 TB model if it only holds its shard.
Heterogeneous clusters ([BUG] Distributed inference OOMs on machines with different RAM size #1804) — nodes with different RAM sizes can receive proportional weight slices. A 64 GB node gets fewer layers/experts than a 512 GB node, with the master orchestrating the placement.
Faster cold start for PP — pipeline stages can begin processing as soon as their layers arrive. Stage 0 starts inference while stages 2-3 are still receiving weights.
Dynamic reconfiguration — switch from TP2 to TP4 by re-streaming with different slicing. No file management, no pre-staging, just a different sharding plan.
Single source of truth — model updates, quantization changes, and fine-tune checkpoints only need to exist on the master node.
Hardware Context
Tested on a 5-node cluster of Mac Studio M3 Ultra (512 GB each) connected via Thunderbolt 5 full mesh (4 cables per node):
- `mlx_lm.share` measured at 5.2 GB/s broadcast to 3 nodes simultaneously
- `all_sum`-based broadcast measured at 3.7 GB/s (without `async_eval` pipelining)

At these speeds, streaming a 200 GB TP4 slice takes ~38 seconds — a one-time startup cost that eliminates terabytes of redundant storage.
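The startup-cost arithmetic behind these figures, as a quick back-of-envelope using the table's model sizes and the measured rate:

```python
# Per-rank TP4 slice size and one-time streaming cost at the measured rate.
MEASURED_RATE_GB_S = 5.2        # mlx_lm.share broadcast throughput

def stream_seconds(model_gb, tp_degree=4, rate=MEASURED_RATE_GB_S):
    """Seconds to stream one rank's slice of a TP-sharded model."""
    return (model_gb / tp_degree) / rate

for name, size in [("Llama 405B", 810), ("DeepSeek V3", 670), ("Kimi K2.5", 612)]:
    print(f"{name}: {size / 4:.0f} GB/rank, ~{stream_seconds(size):.0f} s to stream")
```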
Questions for the Team
- Should this extend `sharded_load` or create a new loading path?
- The `Group.split()` proposal ([Enhancement] Group.split() support for JACCL and Ring backends (parity with NCCL #3172) #3205) would be useful for hybrid TP+EP — is that on the roadmap?

Happy to contribute implementation work if this aligns with the project's direction.