From 4bf2ded580027ea97cea919b1585664bafc51c44 Mon Sep 17 00:00:00 2001
From: maxtext authors
Date: Tue, 28 Apr 2026 12:57:17 -0700
Subject: [PATCH] Remove local sort after ragged all-to-all

PiperOrigin-RevId: 907133776
---
 src/maxtext/layers/moe.py | 198 +++++++++++++++++++++++---------------
 1 file changed, 123 insertions(+), 75 deletions(-)

diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py
index a08c1d10ff..280964ced6 100644
--- a/src/maxtext/layers/moe.py
+++ b/src/maxtext/layers/moe.py
@@ -867,12 +867,93 @@ def local_permute(
       sorted_experts_ids,
   )
 
+  @staticmethod
+  def get_all_to_all_params_batch_sharded(
+      all_shards_group_sizes,
+      shard_id,
+      local_expert_size,
+      is_dispatch=True,
+  ):
+    """Generates granular offsets and sizes for ragged_all_to_all to avoid local sorting.
+
+    In standard batch-sharded MoE, tokens start out distributed across sending devices.
+    This function calculates the communication parameters so that each sending device
+    routes tokens directly to the correct receiving device and local expert, removing
+    the need for a local sort after the all-to-all.
+
+    Shape Key:
+      sd: Number of sending devices.
+      rd: Number of receiving devices (must equal sd).
+      te: Total number of experts across all devices.
+      le: Number of local experts per device (local_expert_size).
+    """
+    # all_shards_group_sizes shape: [sd, te]
+    sd, te = all_shards_group_sizes.shape
+    # rd is the number of receiving devices.
+    rd = te // local_expert_size
+    le = local_expert_size
+    # The number of sending and receiving devices must be equal.
+    assert rd == sd
+
+    # r shape: [sending_device, receiving_device, local_expert]
+    r = all_shards_group_sizes.reshape((sd, rd, le))
+
+    # Transpose to [receiving_device, sending_device, local_expert].
+    # Flattening keeps local_expert in the minor-most dimension.
+    r_tokens_le_minor = r.transpose((1, 0, 2))
+    flat_r_tokens_le_minor = r_tokens_le_minor.reshape((rd, -1))
+
+    # Transpose to [receiving_device, local_expert, sending_device].
+    # Flattening keeps sending_device in the minor-most dimension.
+    r_tokens_sd_minor = r.transpose((1, 2, 0))
+    flat_r_tokens_sd_minor = r_tokens_sd_minor.reshape((rd, -1))
+
+    if is_dispatch:
+      # --- Dispatch (Send tokens to their assigned local experts) ---
+      # Send sizes: the amount of data the current shard needs to send to each expert.
+      send_sizes = all_shards_group_sizes[shard_id]
+
+      # Input offsets: local starting position for tokens sent to each expert,
+      # calculated via an exclusive cumulative sum over the local send amounts.
+      input_offsets = jnp.cumsum(send_sizes) - send_sizes
+
+      # Output offsets: destination layout in the receiving device's memory.
+      # The receiving device expects data for each local expert to be contiguous.
+      # Since `flat_r_tokens_sd_minor` organizes data by [receiving_device, local_expert, sending_device],
+      # an exclusive cumulative sum over the flattened (local_expert, sending_device) axis gives each
+      # chunk's starting position in the receive buffer.
+      summed_reshaped = (jnp.cumsum(flat_r_tokens_sd_minor, axis=1) - flat_r_tokens_sd_minor).reshape((rd, le, sd))
+      output_offsets = summed_reshaped[:, :, shard_id].flatten()
+
+      # Receive sizes: the amount of data the current shard expects from all sending devices.
+      recv_sizes = flat_r_tokens_le_minor[shard_id]
+    else:
+      # --- Combine (Return tokens to their originating sending devices) ---
+      # We send back exactly the amounts we received in the dispatch phase.
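+      # For example (hypothetical values; sd = rd = 2, le = 2): with
+      # all_shards_group_sizes = [[1, 2, 3, 4], [5, 6, 7, 8]], shard 0
+      # received sizes [1, 2, 5, 6] during dispatch, so those are exactly
+      # its send sizes here.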
+      send_sizes = flat_r_tokens_le_minor[shard_id]
+
+      # Input offsets: determined from the perspective of what was received from neighbors.
+      # Compute an exclusive cumulative sum over the previously received sizes, then
+      # transpose so the chunks are ordered by destination (sending) device.
+      expert_sizes = jnp.cumsum(flat_r_tokens_sd_minor[shard_id]) - flat_r_tokens_sd_minor[shard_id]
+      input_offsets = (
+          expert_sizes.reshape((le, sd)).transpose((1, 0)).flatten()
+      )
+
+      # Output offsets: destination layout in the original sending device's memory.
+      # We must place the tokens in the exact order they were originally sent.
+      all_batch_offsets = jnp.cumsum(all_shards_group_sizes, axis=1) - all_shards_group_sizes
+      output_offsets = all_batch_offsets.reshape((sd, rd, le))[:, shard_id, :].flatten()
+
+      # Receive sizes: equal to the amounts we originally sent.
+      recv_sizes = all_shards_group_sizes[shard_id]
+
+    return input_offsets, send_sizes, output_offsets, recv_sizes
+
   @staticmethod
   def get_all_to_all_params(
       all_shards_group_sizes,
       shard_id,
       num_expert_parallelism,
-      is_batch_sharded=True,
   ):
     """Generates input offsets, send sizes, output offsets, and receive sizes used for ragged_all_to_all."""
 
@@ -882,78 +963,47 @@ class TransformStrategy(enum.Enum):
       OUTPUT_OFFSET = enum.auto()
       RECV_SIZE = enum.auto()
 
-    def transform_array(input_array, shard_id, strategy, is_batch_sharded):
+    def transform_array(input_array, shard_id, strategy):
       """Transforms the input array based on the specified strategy."""
-      # Prepares it for the usage with `ragged_all_to_all` API. The
-      # transformation determines how data is sent and received between shards.
-      if is_batch_sharded:
-        if strategy == TransformStrategy.INPUT_OFFSET:
-          # Index of input array for the send
-          local_array = input_array[shard_id]
-          return jnp.concatenate((jnp.array([0]), jnp.cumsum(local_array)[:-1]))
-        elif strategy == TransformStrategy.SEND_SIZE:
-          # Size of input array for the send
-          return input_array[shard_id]
-        elif strategy == TransformStrategy.OUTPUT_OFFSET:
-          # Received index in the target output
-          zero_row = jnp.zeros((1,) + input_array.shape[1:], dtype=input_array.dtype)
-          array_with_zeros = jnp.concatenate((zero_row, input_array), axis=0)
-          cumulated_array = jnp.cumsum(array_with_zeros, axis=0, dtype=input_array.dtype)
-          return cumulated_array[shard_id]
-        elif strategy == TransformStrategy.RECV_SIZE:
-          # Received size in the target output
-          return input_array[:, shard_id]
-        else:
-          raise ValueError(f"Unknown transform array strategy: {strategy}")
-
-      # If the batch is unsharded then we send the same data slice to all other
-      # shards. We also assume each shard will have the local processed inputs
-      # sorted to start from index 0. Finally, len(input_array.shape) == 1 since
-      # there is only one batch shard.
+      if strategy == TransformStrategy.INPUT_OFFSET:
+        # The data on each shard always starts at 0.
+        return jnp.zeros(num_expert_parallelism, dtype=input_array.dtype)
+      elif strategy == TransformStrategy.SEND_SIZE:
+        # The send amount is always the amount of data the current expert
+        # shard needs to process.
+        return jnp.repeat(input_array[shard_id], num_expert_parallelism)
+      elif strategy == TransformStrategy.OUTPUT_OFFSET:
+        # The offset in each shard will just be the start of the group which
+        # that shard is responsible for.
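+        # For example (hypothetical sizes): with group sizes [3, 5, 2] across
+        # three expert shards, the output offsets are [0, 3, 8]; shard 1's
+        # offset is 3.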
+        output_offset = jnp.concatenate((jnp.array([0]), jnp.cumsum(input_array[:-1])))[shard_id]
+        return jnp.repeat(output_offset, num_expert_parallelism)
+      elif strategy == TransformStrategy.RECV_SIZE:
+        # The amount each shard receives from all other shards is equal to
+        # the group sizes (i.e., input_array).
+        return input_array
+      else:
-        if strategy == TransformStrategy.INPUT_OFFSET:
-          # The data on each shard always starts at 0.
-          return jnp.zeros(num_expert_parallelism, dtype=input_array.dtype)
-        elif strategy == TransformStrategy.SEND_SIZE:
-          # The send amount is always the amount of data the current expert
-          # shard needs to process.
-          return jnp.repeat(input_array[shard_id], num_expert_parallelism)
-        elif strategy == TransformStrategy.OUTPUT_OFFSET:
-          # The offset in each shard will just be the start of the group which
-          # that shard is responsible for.
-          output_offset = jnp.concatenate((jnp.array([0]), jnp.cumsum(input_array[:-1])))[shard_id]
-          return jnp.repeat(output_offset, num_expert_parallelism)
-        # The amount that each shard receives from all other shards is
-        # equivalent to the group sizes (aka input_array).
-        elif strategy == TransformStrategy.RECV_SIZE:
-          # Received size in the target output
-          return input_array
-        else:
-          raise ValueError(f"Unknown transform array strategy: {strategy}")
+        raise ValueError(f"Unknown transform array strategy: {strategy}")
 
     input_offsets = transform_array(
         all_shards_group_sizes,
         shard_id,
         TransformStrategy.INPUT_OFFSET,
-        is_batch_sharded,
     )
     send_sizes = transform_array(
         all_shards_group_sizes,
         shard_id,
         TransformStrategy.SEND_SIZE,
-        is_batch_sharded,
    )
     output_offsets = transform_array(
         all_shards_group_sizes,
         shard_id,
         TransformStrategy.OUTPUT_OFFSET,
-        is_batch_sharded,
     )
     recv_sizes = transform_array(
         all_shards_group_sizes,
         shard_id,
         TransformStrategy.RECV_SIZE,
-        is_batch_sharded,
     )
     return input_offsets, send_sizes, output_offsets, recv_sizes
 
@@ -1215,11 +1265,12 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
     def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):
       batch_size, sequence_length, _ = x.shape
       num_expert_parallelism = self.get_expert_parallelism_size()
+      local_expert_size = self.config.num_experts // num_expert_parallelism
+      all_shards_group_sizes = None  # Initialize for static analysis
       if num_expert_parallelism > 1:
         expert_shard_id = jax.lax.axis_index(self._expert_parallelism_name)
       else:
         expert_shard_id = 0
-      num_expert_parallelism = self.get_expert_parallelism_size()
       if self.config.use_ring_of_experts:
         # The ring-of-experts strategy first duplicates the inputs to all
         # expert shards, and then routes within each shard.
@@ -1231,7 +1282,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
       )
       # "Route" tokens within each shard.
- num_experts_per_shard = self.config.num_experts // num_expert_parallelism + num_experts_per_shard = local_expert_size x, sorted_selected_experts, weights, group_sizes, selected_experts, lb_loss, bias_updates = self.permute( x, logits, @@ -1254,16 +1305,16 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r if num_expert_parallelism > 1: batch_axis = self._expert_parallelism_name if is_batch_sharded_by_expert else "data" # get group sizes for all shards - local_expert_size = self.config.num_experts // num_expert_parallelism reshaped_group_sizes = jnp.sum(group_sizes.reshape(-1, local_expert_size), axis=1) global_group_sizes = group_sizes if is_batch_sharded_by_expert: - all_shards_group_sizes = jax.lax.all_gather(reshaped_group_sizes, axis_name=batch_axis) - input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( + all_shards_group_sizes = jax.lax.all_gather(group_sizes, axis_name=batch_axis) + input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params_batch_sharded( all_shards_group_sizes, expert_shard_id, - num_expert_parallelism, + local_expert_size, + is_dispatch=True, ) buffer_size = self.get_ragged_buffer_size( @@ -1284,13 +1335,16 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r recv_sizes, axis_name=self._expert_parallelism_name, ) - global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=self._expert_parallelism_name) - x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute( - x, - global_group_sizes, - local_expert_size, - shard_index=expert_shard_id, - use_custom_sort_vjp=self.config.use_custom_sort_vjp, + # After ragged_all_to_all, x is already in expert order. + # We just need to update group_sizes and selected_experts for GMM. + # all_shards_group_sizes: [num_batch_shards, num_experts] + group_sizes = jnp.sum( + all_shards_group_sizes.reshape(all_shards_group_sizes.shape[0], -1, local_expert_size), axis=0 + )[expert_shard_id] + selected_experts = jnp.repeat( + jnp.arange(local_expert_size), + group_sizes, + total_repeat_length=x.shape[0], ) else: x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute( @@ -1429,19 +1483,14 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): ) if is_batch_sharded_by_expert: - # locally unpermute back to the original order - local_output = _sort_activations( - intermediate_output, - jnp.argsort(local_sorted_indices), # pylint: disable=undefined-variable - self.config.use_custom_sort_vjp, - ) - input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( - jnp.transpose(all_shards_group_sizes), # pylint: disable=undefined-variable + input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params_batch_sharded( + all_shards_group_sizes, expert_shard_id, - num_expert_parallelism, + local_expert_size, + is_dispatch=False, ) intermediate_output = jax.lax.ragged_all_to_all( - local_output, + intermediate_output, output_shape, input_offsets, send_sizes, @@ -1458,7 +1507,6 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): reshaped_group_sizes, # pylint: disable=undefined-variable expert_shard_id, num_expert_parallelism, - is_batch_sharded=False, ) intermediate_output = jax.lax.ragged_all_to_all( intermediate_output,
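
A toy check of the dispatch math in get_all_to_all_params_batch_sharded (a standalone
sketch, not part of the patch; the two-device sizes below are hypothetical):

import jax.numpy as jnp

local_expert_size = 2  # le: experts hosted on each device
shard_id = 0
# [sending_device, total_experts]: tokens each device routes to each expert.
all_shards_group_sizes = jnp.array([[1, 2, 3, 4],
                                    [5, 6, 7, 8]])
sd, te = all_shards_group_sizes.shape
rd, le = te // local_expert_size, local_expert_size

# Receiver-major views of the same counts, mirroring the patch.
r = all_shards_group_sizes.reshape((sd, rd, le))
flat_r_tokens_le_minor = r.transpose((1, 0, 2)).reshape((rd, -1))
flat_r_tokens_sd_minor = r.transpose((1, 2, 0)).reshape((rd, -1))

send_sizes = all_shards_group_sizes[shard_id]        # [1 2 3 4]
input_offsets = jnp.cumsum(send_sizes) - send_sizes  # [0 1 3 6]
summed = (jnp.cumsum(flat_r_tokens_sd_minor, axis=1)
          - flat_r_tokens_sd_minor).reshape((rd, le, sd))
output_offsets = summed[:, :, shard_id].flatten()    # [0 6 0 10]
recv_sizes = flat_r_tokens_le_minor[shard_id]        # [1 2 5 6]

# Receiving device 0's buffer is already expert-major after the all-to-all:
# [e0/dev0 (1), e0/dev1 (5), e1/dev0 (2), e1/dev1 (6)]: no local sort needed.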