diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 238758da92..a684f92787 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -20,7 +20,7 @@ from typing import Sequence from flax import nnx, linen as nn -from flax.core.spmd import composite_rules, from_sharding_rules, get_logical_axis_rules +from flax.core.spmd import get_logical_axis_rules from flax.linen import partitioning as nn_partitioning from flax.training.train_state import TrainState @@ -1612,6 +1612,43 @@ def move(path, x): ) +def _resolve_logical_sharding(out_sharding, context_rules, local_rules) -> list: + """Resolves logical sharding annotations into physical sharding specs. + + This matches rules sequentially (first-match-wins) and ensures that physical + mesh axes are bound to at most one dimension per tensor, preventing JAX + DuplicateSpecError. + """ + local_rules_list = list(local_rules) if local_rules is not None else [] + context_rules_list = list(context_rules) if context_rules is not None else [] + merged_rules = local_rules_list + context_rules_list + raw_sharding = list(out_sharding) + assigned_positions = set() + assigned_axes = set() + + for rule_logical, rule_physical in merged_rules: + if rule_logical not in out_sharding: + continue + pos = out_sharding.index(rule_logical) + if pos in assigned_positions: + continue + + if rule_physical is None: + raw_sharding[pos] = None + assigned_positions.add(pos) + continue + + physical_axes = [rule_physical] if isinstance(rule_physical, str) else list(rule_physical) + if any(axis in assigned_axes for axis in physical_axes): + continue + + raw_sharding[pos] = rule_physical + assigned_positions.add(pos) + assigned_axes.update(physical_axes) + + return raw_sharding + + def get_nnx_named_sharding_with_scan_axis(abs_var_state: nnx.State, mesh) -> nnx.State: """Compute NamedSharding for each NNX variable, correctly handling the scan (stacked layers) axis. @@ -1669,20 +1706,22 @@ def _make_named_sharding(v): context_rules = get_logical_axis_rules() local_rules = metadata.get("sharding_rules", ()) if context_rules or local_rules: - rules = composite_rules(context_rules, local_rules) - raw_sharding = from_sharding_rules(out_sharding, rules) + raw_sharding = _resolve_logical_sharding(out_sharding, context_rules, local_rules) mesh_axis_names = mesh.axis_names if mesh is not None else () - # from_sharding_rules leaves a logical name with no matching rule unchanged, so a - # name missing from logical_axis_rules (e.g. concat_embed on the MTP kernel) - # reaches NamedSharding and is rejected as an unknown mesh axis. Map any such - # leftover name to None (replicated), matching Linen, whose logical_to_mesh_axes - # replicates unmatched names. + # Map unmatched logical names to None (replicated), matching Linen's behavior. + # Also clean up tuples to only keep physical axes present in the active mesh. def _sanitize(x): if isinstance(x, list): x = tuple(x) - if x is None or (isinstance(x, str) and x in mesh_axis_names) or isinstance(x, tuple): - return x + if x is None: + return None + if isinstance(x, str): + return x if x in mesh_axis_names else None + if isinstance(x, tuple): + # Only keep axes that actually exist in the physical mesh. + sanitized_tuple = tuple(i for i in x if i in mesh_axis_names) + return sanitized_tuple if sanitized_tuple else None return None sanitized_sharding = [_sanitize(x) for x in raw_sharding] diff --git a/tests/unit/maxtext_utils_test.py b/tests/unit/maxtext_utils_test.py index 2e90880a83..a77d2ff4b3 100644 --- a/tests/unit/maxtext_utils_test.py +++ b/tests/unit/maxtext_utils_test.py @@ -25,6 +25,7 @@ import optax from flax import linen as nn +from flax.linen import partitioning as nn_partitioning from flax import nnx from flax.core.scope import FrozenVariableDict from flax.training import train_state @@ -1690,6 +1691,60 @@ def test_string_out_sharding_is_wrapped_into_tuple(self): # The single string 'fsdp' is turned into a list, and 'layers' is prepended. self.assertEqual(result_sharding.spec, PartitionSpec("layers", "fsdp")) + def test_sequential_matching_first_match_wins(self): + """Multiple rules for the same logical axis are matched sequentially, first-match-wins.""" + # We define rules for 'embed' mapping to 'fsdp' (specific) then 'layers' (fallback) + rules = ( + ("embed", "fsdp"), + ("embed", "layers"), + ) + with nn_partitioning.axis_rules(rules): + with jax.set_mesh(self.mesh): + v = nnx.Param( + jnp.zeros((3,)), + out_sharding=("embed",), + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + # 'embed' must match the first rule ('fsdp'), not the second ('layers'). + self.assertEqual(result_sharding.spec, PartitionSpec("fsdp")) + + def test_deduplicates_assigned_physical_axes(self): + """Physical axes already bound to a dimension cannot be bound to another dimension.""" + # Define rules where 'embed' maps to ('fsdp', 'layers') and 'mlp' maps to 'fsdp'. + # Because 'embed' is defined first, it binds 'fsdp'. + # When matching 'mlp', 'fsdp' is already bound, so it is skipped (unassigned/None). + rules = ( + ("embed", ("fsdp", "layers")), + ("mlp", "fsdp"), + ) + with nn_partitioning.axis_rules(rules): + with jax.set_mesh(self.mesh): + v = nnx.Param( + jnp.zeros((3, 4)), + out_sharding=("embed", "mlp"), + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + # 'embed' maps to ('fsdp', 'layers'). + # 'mlp' maps to None (replicated) because 'fsdp' is already bound. + self.assertEqual(result_sharding.spec, PartitionSpec(("fsdp", "layers"), None)) + + def test_resolves_when_context_rules_is_none(self): + """When context_rules is None but local_rules are defined, resolution should succeed.""" + # Ensure get_logical_axis_rules() returns None (which is the default outside axis_rules) + # We define local rules on the variable metadata. + with jax.set_mesh(self.mesh): + v = nnx.Param( + jnp.zeros((3,)), + out_sharding=("embed",), + sharding_rules=(("embed", "fsdp"),), + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + # 'embed' must match the local rules even when context_rules is None. + self.assertEqual(result_sharding.spec, PartitionSpec("fsdp")) + if __name__ == "__main__": unittest.main()