Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ scan_pipeline_repeats: false
scan_layers_per_stage: false
set_remat_policy_on_pipeline_iterations: true
set_remat_policy_on_layers_per_stage: false
pipeline_save_decoder_layer_input: true # Set to false to reduce pipeline temporary memory (tmem) at the cost of recomputing decoder layer inputs in the backward pass.


# Choose 'remat_policy' between 'minimal_with_context', 'minimal', 'save_dot_with_context_except_mlp', 'save_dot_except_mlpwi', 'save_dot_except_mlp',
Expand Down
10 changes: 9 additions & 1 deletion src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,14 @@ class PipelineParallelism(BaseModel):
scan_layers_per_stage: bool = Field(False, description="Use jax.lax.scan over layers within a stage.")
set_remat_policy_on_pipeline_iterations: bool = Field(True, description="Set remat policy on the pipeline scan.")
set_remat_policy_on_layers_per_stage: bool = Field(False, description="Set remat policy on the inner layer scan.")
pipeline_save_decoder_layer_input: bool = Field(
True,
description=(
"Whether to save 'decoder_layer_input' activations in the pipeline remat policy. "
"Setting to False reduces temporary memory (tmem) during pipeline execution at the cost "
"of recomputing decoder layer inputs in the backward pass."
),
)


class RematAndOffload(BaseModel):
Expand Down Expand Up @@ -2850,7 +2858,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
# For AOT compilation and correctness, always prioritize the 'stage' axis for sharding when pipelining.
for rule in self.logical_axis_rules:
if rule and rule[0] == "activation_embed_and_logits_batch":
rule[1] = ["stage", "data", "fsdp", "fsdp_transpose", "expert"]
rule[1] = [ax for ax in ["stage", "data", "fsdp", "fsdp_transpose", "expert"] if ax in self.mesh_axes]
break

if "stage" in self.mesh_axes:
Expand Down
2 changes: 2 additions & 0 deletions src/maxtext/kernels/gather_reduce_sc.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __getitem__(self, shape):
_BF16 = VectorTypeHelper(ir.BF16Type.get)


# fmt: off
@jax.jit(
static_argnames=[
"reduce_group_size",
Expand All @@ -69,6 +70,7 @@ def __getitem__(self, shape):
"topk_wgt_zero_nan",
],
)
# fmt: on
def sc_gather_reduce(
op: jax.Array,
idx: jax.Array,
Expand Down
25 changes: 16 additions & 9 deletions src/maxtext/layers/attention_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1624,13 +1624,22 @@ def _sequence_descriptor(segment_ids):
dummy_attn_mask = None
mask_type = "causal"
else:
# Default case: no packing, no context parallelism
dummy_attn_mask = jnp.zeros(
(1, 1, 1, self.max_target_length, self.max_target_length),
dtype=jnp.uint8,
)
attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8)
# Default case: no packing, no context parallelism.
# For synthetic data, segment IDs are always all-ones (one segment per sequence), so
# the segment mask is all-True and the combined mask reduces to pure causal masking.
# Use mask_type="causal" directly to avoid materializing f32/s32[seq,seq] tensors that
# XLA loop_broadcast_fusion hoists into the pipeline scan carry (+5 GiB temp memory).
if self.config.dataset_type == "synthetic":
attn_mask = None
dummy_attn_mask = None
mask_type = "causal"
else:
dummy_attn_mask = jnp.zeros(
(1, 1, 1, self.max_target_length, self.max_target_length),
dtype=jnp.uint8,
)
attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8)

dpa_layer = DotProductAttention(
head_dim=head_dim,
Expand All @@ -1643,12 +1652,10 @@ def _sequence_descriptor(segment_ids):
dtype=self.dtype,
float32_logits=self.float32_logits,
qkv_layout=qkv_layout,
scale_factor=1.0,
transpose_batch_sequence=False,
window_size=sliding_window_size,
context_parallel_causal_load_balanced=self.config.context_parallel_load_balance,
context_parallel_axis=self.config.context_sharding,
context_parallel_strategy=self.config.context_parallel_strategy,
max_segments_per_seq=max_segments_per_seq,
)

Expand Down
1 change: 1 addition & 0 deletions src/maxtext/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,7 @@ def __init__(
mesh=mesh,
shard_mode=config.shard_mode,
debug_sharding=config.debug_sharding,
skip_trivial_specs=True,
)

def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None:
Expand Down
17 changes: 13 additions & 4 deletions src/maxtext/layers/normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import jax
from jax import lax
import jax.numpy as jnp
from jax.sharding import NamedSharding
from jax.sharding import NamedSharding, reshard
from maxtext.common.common_types import Array, DType, ShardMode
from maxtext.layers import nnx_wrappers
from maxtext.layers.initializers import Initializer, variable_to_logically_partitioned
Expand Down Expand Up @@ -78,7 +78,10 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) ->

if not self.with_scale:
if out_sharding is not None:
y = jax.lax.with_sharding_constraint(y, out_sharding)
if self.shard_mode == ShardMode.EXPLICIT:
y = reshard(y, out_sharding)
else:
y = jax.lax.with_sharding_constraint(y, out_sharding)
return y

scale = self.scale.get_value()
Expand All @@ -88,8 +91,14 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) ->
scale = jax.device_put(scale, max_utils.device_space())

scale = jnp.asarray(scale, self.dtype)
effective_scale = scale + self.scale_offset
return jnp.einsum("...k,k->...k", y, effective_scale, out_sharding=out_sharding)
effective_scale = scale + self.scale_offset if self.scale_offset != 0.0 else scale
y = y * effective_scale
if out_sharding is not None:
if self.shard_mode == ShardMode.EXPLICIT:
y = reshard(y, out_sharding)
else:
y = jax.lax.with_sharding_constraint(y, out_sharding)
return y


class GlobalRMSNorm(RMSNorm):
Expand Down
187 changes: 143 additions & 44 deletions src/maxtext/layers/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def _maybe_shard_with_logical(self, inputs, logical_axes):
rules=self.config.logical_axis_rules,
debug_sharding=self.config.debug_sharding,
extra_stack_level=1,
skip_trivial_specs=True,
)

def _maybe_shard_with_name(self, inputs, sharding_name):
Expand All @@ -138,7 +139,6 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift):
# Setup potential input from state_io, which has a rotating microbatch index (size of microbatches_per_stage)
state_io_batch_idx = loop_iteration % self.microbatches_per_stage
state_io_slice = state_io[:, state_io_batch_idx]
shift = self._maybe_shard_with_logical(shift, self.stages_in_logical)

if self.use_circ_storage:
# Setup potential input from circ_storage, which also has a rotating index for microbatch,
Expand All @@ -153,7 +153,6 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift):
# state_io we instead grab from the last stage's output (possibly buffered when num_microbatches > num_stages, e.g.
# from circ_storage).
first_stage_in = jnp.where(loop_iteration < self.config.num_pipeline_microbatches, state_io_slice, circular_stage_in)
first_stage_in = self._maybe_shard_with_logical(first_stage_in, self.stages_in_logical)

# Note that first_stage_in may correspond to bubble computation during the last few iterations.
# However, these bubble computation results remain in the shift buffer (do not make it back to state_io) and are
Expand All @@ -163,11 +162,7 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift):

def select_state_or_input(first_stage_in, shift):
# Selects input for stage 0, shift for other stages
return jnp.where(
jax.lax.broadcasted_iota("int32", shift.shape, 0, out_sharding=self.stages_in_sharding) == 0,
first_stage_in,
shift,
)
return jnp.where(jax.lax.broadcasted_iota("int32", shift.shape, 0) == 0, first_stage_in, shift)

# Selects input (from stream_io) for stage 0, other stages get from shift (the rotated previous output)
stages_in = select_state_or_input(first_stage_in, shift)
Expand All @@ -178,7 +173,6 @@ def get_microbatch_and_repeat_ids(self, loop_iteration):
non-circular"""
# Stage 0 has processed one microbatch every loop_iter, but Stage 1 is 1 behind due to bubble, etc for other stages
microbatches_processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0)
microbatches_processed = self._maybe_shard_with_name(microbatches_processed, NamedSharding(self.mesh, P("stage")))
microbatch_ids = microbatches_processed % self.config.num_pipeline_microbatches
repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches
return microbatch_ids, repeat_ids
Expand All @@ -187,10 +181,133 @@ def get_pipeline_remat_policy(self):
"""Returns the pipeline remat policy for this pipeline."""
if self.config.remat_policy == "custom":
return self.remat_policy
save_input_policy = jax.checkpoint_policies.save_only_these_names("iteration_input", "decoder_layer_input")

names_to_save = ["iteration_input"]
if self.config.pipeline_save_decoder_layer_input:
names_to_save.append("decoder_layer_input")
save_input_policy = jax.checkpoint_policies.save_only_these_names(*names_to_save)
if self.remat_policy is not None:
return jax.checkpoint_policies.save_from_both_policies(self.remat_policy, save_input_policy)
return save_input_policy
remat_policy = jax.checkpoint_policies.save_from_both_policies(self.remat_policy, save_input_policy)
else:
remat_policy = save_input_policy
return remat_policy

def get_weight_sharding(self, *init_args):
  """Return the partition specs of this pipeline's layer parameters.

  The module is initialized once with a fixed PRNG key (the same key is used
  for the "params", "dropout" and "aqt" entries), every
  `nn.spmd.LogicallyPartitioned` node of the resulting variable tree is replaced
  by the result of its `get_partition_spec()`, and the sub-tree found under
  ["params"]["layers"] is returned re-wrapped as {"params": ...}.

  Args:
    *init_args: Positional arguments forwarded to `self.init` after the RNG dict.

  Returns:
    A dict with the single key "params" holding the partition-spec tree taken
    from ["params"]["layers"] of the initialized variables.
  """
  seed_key = jax.random.PRNGKey(0)
  rng_streams = dict.fromkeys(("params", "dropout", "aqt"), seed_key)
  initialized = self.init(rng_streams, *init_args)

  # Stop tree traversal at each LogicallyPartitioned node so the node as a whole
  # is swapped for its partition spec.
  all_specs = jax.tree.map(
      lambda boxed: boxed.get_partition_spec(),
      initialized,
      is_leaf=lambda node: isinstance(node, nn.spmd.LogicallyPartitioned),
  )
  return {"params": all_specs["params"]["layers"]}

def get_vmap_func_for_init(self):
  """Build the `nn.vmap`-wrapped layer call that is used only to initialize weights.

  Returns:
    The function produced by `nn.vmap`. It takes
    (body_instance, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode)
    and calls `body_instance` directly on the remaining arguments.
  """

  def func_to_vmap(body_instance, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode):
    return body_instance(stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode)

  spmd_axis = self.spmd_axis_name
  # Inputs, segment ids and positions are mapped over axis 0; `deterministic` and
  # `model_mode` are passed through unmapped.
  mapped_axes = (0, 0, 0, None, None)
  stacked_collections = {"params": 0, "_overwrite_with_gradient": 0}
  rng_split = {"params": self.is_initializing(), "dropout": self.config.enable_dropout}
  partition_metadata = {
      nn.PARTITION_NAME: "layers",
      # NOTE(review): `(None)` is plain None, not a one-element tuple - confirm this is intended.
      "sub_weight_split_dims_mapping": (None),
      "is_initializing": self.is_initializing(),
      "x_times": self.num_stages,
  }
  return nn.vmap(
      func_to_vmap,
      in_axes=mapped_axes,
      spmd_axis_name=spmd_axis,
      variable_axes=stacked_collections,
      split_rngs=rng_split,
      metadata_params=partition_metadata,
  )

def get_main_vmap_func_for_iterations(self):
  """Build the `nn.vmap`-wrapped stage function that applies explicit weights.

  The wrapped function calls `body_instance.apply(weights, ...)`, so whether the
  vmap covers a single layer or a set of layers depends on what `body_instance` is.

  Returns:
    The function produced by `nn.vmap`. It takes
    (body_instance, weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode).
  """

  def func_to_vmap(
      body_instance, weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode
  ):
    return body_instance.apply(weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode)

  spmd_axis = self.spmd_axis_name
  # Weights, inputs, segment ids and positions are mapped over axis 0;
  # `deterministic` and `model_mode` are passed through unmapped.
  mapped_axes = (0, 0, 0, 0, None, None)
  stacked_collections = {"params": 0}
  rng_split = {"params": self.is_initializing(), "dropout": self.config.enable_dropout}
  partition_metadata = {
      nn.PARTITION_NAME: "layers",
      # NOTE(review): `(None)` is plain None, not a one-element tuple - confirm this is intended.
      "sub_weight_split_dims_mapping": (None),
      "is_initializing": self.is_initializing(),
      "x_times": self.num_stages,
  }
  return nn.vmap(
      func_to_vmap,
      in_axes=mapped_axes,
      spmd_axis_name=spmd_axis,
      variable_axes=stacked_collections,
      split_rngs=rng_split,
      metadata_params=partition_metadata,
  )

def _run_weight_initialization(
    self, example_inputs, example_segmentation, example_position, segment_idx, position_idx, deterministic, model_mode
):
  """Run the layers once through the init vmap and return an output of the full batch shape.

  When `config.num_pipeline_repeats > 1` the init function is wrapped in a second
  `nn.vmap` (partition name "circular_repeats") and the example inputs are
  broadcast with a new leading dimension of that size.

  Args:
    example_inputs: Example input passed to the vmapped layers.
    example_segmentation: Example segment ids, or None.
    example_position: Example positions, or None.
    segment_idx: `in_axes` entry used for the segmentation argument in the outer vmap.
    position_idx: `in_axes` entry used for the position argument in the outer vmap.
    deterministic: Forwarded unchanged to the layers.
    model_mode: Forwarded unchanged to the layers.

  Returns:
    The first stage output, broadcast and reshaped to
    [micro_batch_size_to_train_on, max_target_length, emb_dim] with `self.output_sharding`.
  """
  config = self.config
  num_repeats = config.num_pipeline_repeats
  has_repeats = num_repeats > 1

  init_fn = self.get_vmap_func_for_init()

  if has_repeats:
    init_fn = nn.vmap(
        init_fn,
        in_axes=(0, segment_idx, position_idx, None, None),
        variable_axes={"params": 0, "_overwrite_with_gradient": 0, "non_trainable": 0, "hyper_params": 0},
        split_rngs={"params": True, "dropout": config.enable_dropout},
        metadata_params={
            nn.PARTITION_NAME: "circular_repeats",
            "sub_weight_split_dims_mapping": (None,),
            "is_initializing": True,
            "x_times": num_repeats,
            "optimizer_dims_mapping": None,
        },
    )
    # Add the leading repeat dimension to the inputs; None arguments stay None.
    repeat_dims = [num_repeats]
    example_inputs = jax.lax.broadcast(example_inputs, repeat_dims)
    if example_segmentation is not None:
      example_segmentation = jax.lax.broadcast(example_segmentation, repeat_dims)
    if example_position is not None:
      example_position = jax.lax.broadcast(example_position, repeat_dims)

  example_inputs = self._maybe_shard_with_logical(example_inputs, (None, None, None, None))
  stage_outputs = init_fn(
      self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode
  )
  # Each enabled option contributes one extra level that is stripped by taking element 0.
  if config.scan_layers:
    stage_outputs = stage_outputs[0]
  if has_repeats:
    stage_outputs = stage_outputs[0]

  # Tile a single output enough times to fill the full batch dimension.
  num_copies = config.micro_batch_size_to_train_on // self.pipeline_microbatch_size
  tiled_output = jax.lax.broadcast(stage_outputs[0], [num_copies])
  full_shape = [config.micro_batch_size_to_train_on, config.max_target_length, config.emb_dim]
  return jnp.reshape(tiled_output, full_shape, out_sharding=self.output_sharding)

@staticmethod
def _remove_fsdp_from_physical_partition_spec(physical_partition_spec):
Expand Down Expand Up @@ -355,9 +472,7 @@ def vmap_gather(self, xs, ids, ids_dim):
ndim = xs.ndim

def _gather_one(x, i):
idx = tuple(i if d == ids_dim else slice(None) for d in range(ndim))
replicated_sharding = NamedSharding(self.mesh, P())
return x.at[idx].get(out_sharding=replicated_sharding)
return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, i, 1, ids_dim), ids_dim)

ids = self.shard_dim_by_stages(ids, 0, physical_partition_spec=None)
outs = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)
Expand All @@ -381,20 +496,16 @@ def get_new_loop_state(self, output, loop_state):
loop_iteration = loop_state["loop_iteration"]
old_prev_outputs = loop_state["prev_outputs"]

@jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True)
def _rotate_right(arr):
# we use +1 for right shifting
stage_size = jax.lax.axis_size("stage")
perm = [(i, (i + 1) % stage_size) for i in range(stage_size)]
return jax.lax.ppermute(arr, axis_name="stage", perm=perm)
# Use lax.slice to avoid generating a gather.
last = jax.lax.slice_in_dim(arr, self.num_stages - 1, self.num_stages, axis=0)
except_last = jax.lax.slice_in_dim(arr, 0, self.num_stages - 1, axis=0)
return jnp.concatenate([last, except_last], axis=0)

@jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True)
def _shift_right(arr):
stage_idx = jax.lax.axis_index("stage")
stage_size = jax.lax.axis_size("stage")
perm = [(i, (i + 1) % stage_size) for i in range(stage_size)]
arr = jax.lax.ppermute(arr, axis_name="stage", perm=perm)
return jnp.where(stage_idx == 0, jnp.zeros_like(arr), arr)
padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1)
# Use lax.slice to guarantee the gradient is a pad.
return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape)

# Shift either rotates or shifts depending on if the last stage immediately must send to first or not
# For non-circular pipelines, the last stage does not need to send to first
Expand Down Expand Up @@ -437,29 +548,17 @@ def _rotate_right_and_update(circ_storage_mover_in, circ_storage_in):
stream_buf_idx = loop_iteration % self.microbatches_per_stage
stream_slice = old_state_io[:, stream_buf_idx]

def _rotate_left(arr, stage_size):
# we use -1 for left shifting
perm = [(i, (i - 1) % stage_size) for i in range(stage_size)]
return jax.lax.ppermute(arr, axis_name="stage", perm=perm)

def _shift_left(arr, stage_size, output):
stage_idx = jax.lax.axis_index("stage")
arr = _rotate_left(arr, stage_size)
return jnp.where(stage_idx == stage_size - 1, output, arr)

@jax.shard_map(
mesh=self.mesh,
in_specs=(self.state_io_spec, self.stages_in_spec, self.stages_in_spec, P()),
out_specs=self.state_io_spec,
)
def _update_state_io(state_in, stream_slice, output, stream_buf_idx):
def _update_state_io(state_in, stream_slice, output):
# Shift the current slice to the left, then fill the last stage with the final output.
stage_size = jax.lax.axis_size("stage")
stream_slice = _shift_left(stream_slice, stage_size, output)
padding = [[0, 1]] + [[0, 0]] * (stream_slice.ndim - 1)
stream_slice = jax.lax.slice_in_dim(jnp.pad(stream_slice, padding), 1, stream_slice.shape[0] + 1, axis=0)
stream_slice = jnp.where(
jax.lax.broadcasted_iota("int32", stream_slice.shape, 0) == self.num_stages - 1, output, stream_slice
)
stream_slice = jnp.expand_dims(stream_slice, 1)
return jax.lax.dynamic_update_slice_in_dim(state_in, stream_slice, stream_buf_idx, axis=1)

new_state = _update_state_io(old_state_io, stream_slice, output, stream_buf_idx)
new_state = _update_state_io(old_state_io, stream_slice, output)

return {
"state_io": new_state,
Expand Down
Loading
Loading