198 changes: 123 additions & 75 deletions src/maxtext/layers/moe.py
@@ -867,12 +867,93 @@ def local_permute(
sorted_experts_ids,
)

@staticmethod
def get_all_to_all_params_batch_sharded(
all_shards_group_sizes,
shard_id,
local_expert_size,
is_dispatch=True,
):
"""Generates granular offsets and sizes for ragged_all_to_all to avoid local sorting.

In standard batch-sharded MoE, tokens are initially distributed across sending devices.
This function calculates the necessary communication parameters so that each sending device
routes tokens directly to the correct receiving device and local expert.

Shape Key:
sd: Number of sending devices.
rd: Number of receiving devices (must equal sd).
te: Total number of experts across all devices.
le: Number of local experts per device (local_expert_size).
"""
# all_shards_group_sizes shape: [sd, te]
sd, te = all_shards_group_sizes.shape
# rd is the number of receiving devices.
rd = te // local_expert_size
le = local_expert_size
# The number of sending and receiving devices must be equal.
assert rd == sd

# r shape: [sending_device, receiving_device, local_expert]
r = all_shards_group_sizes.reshape((sd, rd, le))

# Transpose to [receiving_device, sending_device, local_expert].
# Flattening keeps local_expert in the minor-most dimension.
r_tokens_le_minor = r.transpose((1, 0, 2))
flat_r_tokens_le_minor = r_tokens_le_minor.reshape((rd, -1))

# Transpose to [receiving_device, local_expert, sending_device].
# Flattening keeps sending_device in the minor-most dimension.
r_tokens_sd_minor = r.transpose((1, 2, 0))
flat_r_tokens_sd_minor = r_tokens_sd_minor.reshape((rd, -1))

if is_dispatch:
# --- Dispatch (Send tokens to their assigned local experts) ---
# Send sizes: the amount of data the current shard needs to send to each expert.
send_sizes = all_shards_group_sizes[shard_id]

# Input offsets: local starting position for tokens sent to each expert.
# Calculated via exclusive cumulative sum over the local send amounts.
input_offsets = jnp.cumsum(send_sizes) - send_sizes

# Output offsets: destination layout in the receiving device's memory.
# The receiving device expects data for each local expert to be contiguous.
# Since `flat_r_tokens_sd_minor` organizes data by [receiving_device, local_expert, sending_device],
# we perform an exclusive cumulative sum along the sending_device dimension to compute offsets.
summed_reshaped = (jnp.cumsum(flat_r_tokens_sd_minor, axis=1) - flat_r_tokens_sd_minor).reshape((rd, le, sd))
output_offsets = summed_reshaped[:, :, shard_id].flatten()

# Receive sizes: the amount of data the current shard expects from each (sending device, local expert) pair.
recv_sizes = flat_r_tokens_le_minor[shard_id]
else:
# --- Combine (Return tokens to their originating sending devices) ---
# We send back exactly the amounts we received in the dispatch phase.
send_sizes = flat_r_tokens_le_minor[shard_id]

# Input offsets: determined by the layout of what was received during dispatch.
# We compute an exclusive cumulative sum over the previously received sizes,
# then transpose from expert-major to device-major order to match send_sizes.
recv_offsets = jnp.cumsum(flat_r_tokens_sd_minor[shard_id]) - flat_r_tokens_sd_minor[shard_id]
input_offsets = recv_offsets.reshape((le, sd)).transpose((1, 0)).flatten()

# Output offsets: destination layout in the original sending device's memory.
# We must place the tokens in the exact order they were originally sent.
all_batch_offsets = jnp.cumsum(all_shards_group_sizes, axis=1) - all_shards_group_sizes
output_offsets = all_batch_offsets.reshape((sd, rd, le))[:, shard_id, :].flatten()

# Receive sizes: equal to the amounts we originally sent.
recv_sizes = all_shards_group_sizes[shard_id]

return input_offsets, send_sizes, output_offsets, recv_sizes
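
A minimal worked example of the dispatch and combine parameters, assuming two expert shards with two local experts each; `sizes` is hypothetical routing data, not part of this change:

import jax.numpy as jnp

from maxtext.layers.moe import RoutedMoE  # assumed import path

# Hypothetical routing table: shard s sends sizes[s, e] tokens to global expert e.
sizes = jnp.array([[1, 2, 3, 4],
                   [5, 6, 7, 8]])  # [sd=2, te=4], so rd=2 and le=2

# Dispatch from shard 0.
in_off, send, out_off, recv = RoutedMoE.get_all_to_all_params_batch_sharded(
    sizes, shard_id=0, local_expert_size=2, is_dispatch=True)
# send    == [1, 2, 3, 4]   tokens for experts 0..3, grouped by receiving shard
# in_off  == [0, 1, 3, 6]   exclusive cumsum of the local send sizes
# out_off == [0, 6, 0, 10]  e.g. expert 1 on shard 0 starts at 1 + 5 = 6, after
#                           both shards' expert-0 tokens
# recv    == [1, 2, 5, 6]   one entry per (sending shard, local expert) pair

# Combine from shard 0: send back exactly what dispatch delivered.
in_off, send, out_off, recv = RoutedMoE.get_all_to_all_params_batch_sharded(
    sizes, shard_id=0, local_expert_size=2, is_dispatch=False)
# send    == [1, 2, 5, 6]   the dispatch-phase recv sizes
# in_off  == [0, 6, 1, 8]   where each (shard, expert) slice sits in local memory
# out_off == [0, 1, 0, 5]   each slice returns to its original spot on its sender
# recv    == [1, 2, 3, 4]   the dispatch-phase send sizes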

@staticmethod
def get_all_to_all_params(
all_shards_group_sizes,
shard_id,
num_expert_parallelism,
is_batch_sharded=True,
):
"""Generates input offsets, send sizes, output offsets, and receive sizes used for ragged_all_to_all."""

@@ -882,78 +963,47 @@ class TransformStrategy(enum.Enum):
OUTPUT_OFFSET = enum.auto()
RECV_SIZE = enum.auto()

def transform_array(input_array, shard_id, strategy, is_batch_sharded):
def transform_array(input_array, shard_id, strategy):
"""Transforms the input array based on the specified strategy."""
# Prepares the array for use with the `ragged_all_to_all` API. The
# transformation determines how data is sent and received between shards.
if is_batch_sharded:
if strategy == TransformStrategy.INPUT_OFFSET:
# Index of input array for the send
local_array = input_array[shard_id]
return jnp.concatenate((jnp.array([0]), jnp.cumsum(local_array)[:-1]))
elif strategy == TransformStrategy.SEND_SIZE:
# Size of input array for the send
return input_array[shard_id]
elif strategy == TransformStrategy.OUTPUT_OFFSET:
# Received index in the target output
zero_row = jnp.zeros((1,) + input_array.shape[1:], dtype=input_array.dtype)
array_with_zeros = jnp.concatenate((zero_row, input_array), axis=0)
cumulated_array = jnp.cumsum(array_with_zeros, axis=0, dtype=input_array.dtype)
return cumulated_array[shard_id]
elif strategy == TransformStrategy.RECV_SIZE:
# Received size in the target output
return input_array[:, shard_id]
else:
raise ValueError(f"Unknown transform array strategy: {strategy}")

# If the batch is unsharded then we send the same data slice to all other
# shards. We also assume each shard has its locally processed inputs
# sorted to start from index 0. Finally, len(input_array.shape) == 1 since
# there is only one batch shard.
if strategy == TransformStrategy.INPUT_OFFSET:
# The data on each shard always starts at 0.
return jnp.zeros(num_expert_parallelism, dtype=input_array.dtype)
elif strategy == TransformStrategy.SEND_SIZE:
# The send amount is always the amount of data the current expert
# shard needs to process.
return jnp.repeat(input_array[shard_id], num_expert_parallelism)
elif strategy == TransformStrategy.OUTPUT_OFFSET:
# The offset in each shard will just be the start of the group which
# that shard is responsible for.
output_offset = jnp.concatenate((jnp.array([0]), jnp.cumsum(input_array[:-1])))[shard_id]
return jnp.repeat(output_offset, num_expert_parallelism)
# The amount that each shard receives from all other shards is
# equivalent to the group sizes (aka input_array).
elif strategy == TransformStrategy.RECV_SIZE:
# Received size in the target output
return input_array
else:
if strategy == TransformStrategy.INPUT_OFFSET:
# The data on each shard always starts at 0.
return jnp.zeros(num_expert_parallelism, dtype=input_array.dtype)
elif strategy == TransformStrategy.SEND_SIZE:
# The send amount is always the amount of data the current expert
# shard needs to process.
return jnp.repeat(input_array[shard_id], num_expert_parallelism)
elif strategy == TransformStrategy.OUTPUT_OFFSET:
# The offset in each shard will just be the start of the group which
# that shard is responsible for.
output_offset = jnp.concatenate((jnp.array([0]), jnp.cumsum(input_array[:-1])))[shard_id]
return jnp.repeat(output_offset, num_expert_parallelism)
# The amount that each shard receives from all other shards is
# equivalent to the group sizes (aka input_array).
elif strategy == TransformStrategy.RECV_SIZE:
# Received size in the target output
return input_array
else:
raise ValueError(f"Unknown transform array strategy: {strategy}")
raise ValueError(f"Unknown transform array strategy: {strategy}")

input_offsets = transform_array(
all_shards_group_sizes,
shard_id,
TransformStrategy.INPUT_OFFSET,
is_batch_sharded,
)
send_sizes = transform_array(
all_shards_group_sizes,
shard_id,
TransformStrategy.SEND_SIZE,
is_batch_sharded,
)
output_offsets = transform_array(
all_shards_group_sizes,
shard_id,
TransformStrategy.OUTPUT_OFFSET,
is_batch_sharded,
)
recv_sizes = transform_array(
all_shards_group_sizes,
shard_id,
TransformStrategy.RECV_SIZE,
is_batch_sharded,
)
return input_offsets, send_sizes, output_offsets, recv_sizes
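
For the replicated-batch path (e.g. ring-of-experts), a minimal sketch of the resulting parameters, assuming two expert shards whose expert groups hold 3 and 5 tokens; the call mirrors the updated call site further down:

import jax.numpy as jnp

from maxtext.layers.moe import RoutedMoE  # assumed import path

sizes = jnp.array([3, 5])  # hypothetical per-expert-shard group sizes
in_off, send, out_off, recv = RoutedMoE.get_all_to_all_params(
    sizes, shard_id=1, num_expert_parallelism=2)
# in_off  == [0, 0]  each shard's processed slice starts at local index 0
# send    == [5, 5]  shard 1 broadcasts its 5 processed tokens to every shard
# out_off == [3, 3]  shard 1's group starts after shard 0's 3 tokens everywhere
# recv    == [3, 5]  every shard receives each shard's group back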

@@ -1215,11 +1265,12 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):
batch_size, sequence_length, _ = x.shape
num_expert_parallelism = self.get_expert_parallelism_size()
local_expert_size = self.config.num_experts // num_expert_parallelism
all_shards_group_sizes = None # Initialize for static analysis
if num_expert_parallelism > 1:
expert_shard_id = jax.lax.axis_index(self._expert_parallelism_name)
else:
expert_shard_id = 0
num_expert_parallelism = self.get_expert_parallelism_size()
if self.config.use_ring_of_experts:
# The ring-of-experts strategy first duplicates the inputs to all
# expert shards, and then routes within each shard.
@@ -1231,7 +1282,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
)

# "Route" tokens within each shard.
num_experts_per_shard = self.config.num_experts // num_expert_parallelism
num_experts_per_shard = local_expert_size
x, sorted_selected_experts, weights, group_sizes, selected_experts, lb_loss, bias_updates = self.permute(
x,
logits,
@@ -1254,16 +1305,16 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
if num_expert_parallelism > 1:
batch_axis = self._expert_parallelism_name if is_batch_sharded_by_expert else "data"
# get group sizes for all shards
local_expert_size = self.config.num_experts // num_expert_parallelism
reshaped_group_sizes = jnp.sum(group_sizes.reshape(-1, local_expert_size), axis=1)
global_group_sizes = group_sizes

if is_batch_sharded_by_expert:
all_shards_group_sizes = jax.lax.all_gather(reshaped_group_sizes, axis_name=batch_axis)
input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
all_shards_group_sizes = jax.lax.all_gather(group_sizes, axis_name=batch_axis)
input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params_batch_sharded(
all_shards_group_sizes,
expert_shard_id,
num_expert_parallelism,
local_expert_size,
is_dispatch=True,
)

buffer_size = self.get_ragged_buffer_size(
@@ -1284,13 +1335,16 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
recv_sizes,
axis_name=self._expert_parallelism_name,
)
global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=self._expert_parallelism_name)
x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute(
x,
global_group_sizes,
local_expert_size,
shard_index=expert_shard_id,
use_custom_sort_vjp=self.config.use_custom_sort_vjp,
# After ragged_all_to_all, x is already in expert order.
# We just need to update group_sizes and selected_experts for GMM.
# all_shards_group_sizes: [num_batch_shards, num_experts]
group_sizes = jnp.sum(
all_shards_group_sizes.reshape(all_shards_group_sizes.shape[0], -1, local_expert_size), axis=0
)[expert_shard_id]
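# Minimal sketch, assuming 2 batch shards and local_expert_size=2:
# all_shards_group_sizes == [[1, 2, 3, 4], [5, 6, 7, 8]] reshapes to
# [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]; summing over shards gives
# [[6, 8], [10, 12]], so expert shard 0 computes group_sizes == [6, 8], and
# selected_experts below becomes six 0s followed by eight 1s (padded or
# truncated to x.shape[0]).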
selected_experts = jnp.repeat(
jnp.arange(local_expert_size),
group_sizes,
total_repeat_length=x.shape[0],
)
else:
x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute(
@@ -1429,19 +1483,14 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
)

if is_batch_sharded_by_expert:
# locally unpermute back to the original order
local_output = _sort_activations(
intermediate_output,
jnp.argsort(local_sorted_indices), # pylint: disable=undefined-variable
self.config.use_custom_sort_vjp,
)
input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
jnp.transpose(all_shards_group_sizes), # pylint: disable=undefined-variable
input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params_batch_sharded(
all_shards_group_sizes,
expert_shard_id,
num_expert_parallelism,
local_expert_size,
is_dispatch=False,
)
intermediate_output = jax.lax.ragged_all_to_all(
local_output,
intermediate_output,
output_shape,
input_offsets,
send_sizes,
Expand All @@ -1458,7 +1507,6 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
reshaped_group_sizes, # pylint: disable=undefined-variable
expert_shard_id,
num_expert_parallelism,
is_batch_sharded=False,
)
intermediate_output = jax.lax.ragged_all_to_all(
intermediate_output,