2 changes: 1 addition & 1 deletion .github/workflows/run_jupyter_notebooks.yml
@@ -94,7 +94,7 @@ jobs:
PAPERMILL_EXE=".venv/bin/papermill"
source .venv/bin/activate
fi
export PYTHONPATH="${pwd}/src${PYTHONPATH:+:${PYTHONPATH}}"
export PYTHONPATH="${PWD}/src${PYTHONPATH:+:${PYTHONPATH}}"
export MAXTEXT_REPO_ROOT=$(pwd)
export MAXTEXT_PKG_DIR=$(pwd)/src/maxtext
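For context on the one-line fix above: shell parameter expansion is case-sensitive, so `${pwd}` expands an almost-always-unset lowercase variable and the exported PYTHONPATH collapsed to just "/src". A quick illustrative check from Python (not part of the workflow itself):

```python
# Illustrative only: the shell exports the uppercase PWD variable, not a lowercase pwd.
import os

print(os.environ.get("PWD"))   # e.g. /home/runner/work/maxtext/maxtext
print(os.environ.get("pwd"))   # normally None, which is why ${pwd}/src collapsed to /src
```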
6 changes: 3 additions & 3 deletions .github/workflows/run_tests_against_package.yml
@@ -139,9 +139,9 @@ jobs:
PYTHON_EXE=".venv/bin/python3"
# Ensure pytest-cov is available and enable coverage flags
uv pip install pytest-cov
PYTEST_COV_ARGS="--cov=MaxText --cov=maxtext --cov-report=xml --cov-report=term"
PYTEST_COV_ARGS="--cov=maxtext --cov-report=xml --cov-report=term"
fi
export PYTHONPATH="${pwd}/src${PYTHONPATH:+:${PYTHONPATH}}"
export PYTHONPATH="${PWD}/src${PYTHONPATH:+:${PYTHONPATH}}"
if [ "${INPUTS_IS_SCHEDULED_RUN}" == "true" ]; then
FINAL_PYTEST_MARKER="${INPUTS_PYTEST_MARKER}"
@@ -209,7 +209,7 @@ jobs:
continue-on-error: true
with:
token: ${{ secrets.CODECOV_TOKEN }}
file: ./coverage.xml
files: ./coverage.xml
# If scheduled, upload to scheduled flag only. If PR, upload to regular flag only.
flags: ${{ inputs.is_scheduled_run == 'true' && 'scheduled' || 'regular' }}
verbose: true
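If it helps to reproduce the coverage run locally, the flags above map directly onto a programmatic pytest invocation. This is only a sketch: it assumes pytest and pytest-cov are installed and that tests live under a `tests/` directory, which this diff does not spell out.

```python
# Sketch of the same coverage flags, run via pytest's Python API.
import pytest

# --cov now targets only the "maxtext" package (the old MaxText target was dropped).
exit_code = pytest.main(["--cov=maxtext", "--cov-report=xml", "--cov-report=term", "tests"])
raise SystemExit(exit_code)
```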
2 changes: 1 addition & 1 deletion src/maxtext/configs/types.py
@@ -847,7 +847,7 @@ class HardwareAndMesh(BaseModel):
"all_gather",
description="Strategy for context parallelism ('all_gather' or 'ring').",
)
context_parallel_reorder_strategy: ReorderStrategy = Field(
context_parallel_reorder_strategy: str = Field(
"auto",
description="Reorder strategy for load-balanced context parallelism.",
)
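The hunk above loosens the annotation from `ReorderStrategy` to plain `str`, so the `"auto"` default validates as-is. A minimal sketch of the resulting field, trimmed to the one attribute touched here (the real `HardwareAndMesh` model defines many more):

```python
from pydantic import BaseModel, Field

class HardwareAndMesh(BaseModel):
  # Only the field changed in this hunk is shown.
  context_parallel_reorder_strategy: str = Field(
      "auto",
      description="Reorder strategy for load-balanced context parallelism.",
  )

cfg = HardwareAndMesh()
assert cfg.context_parallel_reorder_strategy == "auto"
```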
167 changes: 96 additions & 71 deletions src/maxtext/inference/kvcache.py
@@ -366,11 +366,11 @@ def _initialize_prefill_caches(self, model_mode):

self.cached_prefill_key = nnx.Cache(
jnp.zeros(cache_shape_key, dtype=dtype),
sharding=cache_axis_names,
out_sharding=cache_axis_names,
)
self.cached_prefill_value = nnx.Cache(
jnp.zeros(cache_shape_value, dtype=dtype),
sharding=cache_axis_names,
out_sharding=cache_axis_names,
)

if model_mode == MODEL_MODE_PREFILL:
@@ -380,7 +380,7 @@ def _initialize_prefill_caches(self, model_mode):

self.cache_prefill_segment_id = nnx.Cache(
jnp.zeros((cache_logical_shape[0], cache_length), dtype=jnp.int32),
sharding=segment_id_axis_names,
out_sharding=segment_id_axis_names,
)

if self.kv_quant:
@@ -394,11 +394,11 @@ def _initialize_prefill_caches(self, model_mode):

self.cached_prefill_key_scale = nnx.Cache(
jnp.zeros(cache_key_scale_shape, dtype=jnp.bfloat16),
sharding=cache_scale_axis_names,
out_sharding=cache_scale_axis_names,
)
self.cached_prefill_value_scale = nnx.Cache(
jnp.zeros(cache_value_scale_shape, dtype=jnp.bfloat16),
sharding=cache_scale_axis_names,
out_sharding=cache_scale_axis_names,
)
else:
self.cached_prefill_key_scale = None
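The pattern in this hunk, and in the kvcache.py hunks below, is the same rename everywhere: the logical axis names that used to be passed to `nnx.Cache` as `sharding=` are now passed as `out_sharding=`. A minimal sketch of the call shape, assuming a Flax version whose nnx variables accept arbitrary metadata keywords; the dimensions and axis-name constants are stand-ins for MaxText's:

```python
import jax.numpy as jnp
from flax import nnx

# Stand-in logical axis names; MaxText defines its own CACHE_* constants.
cache_axis_names = ("cache_batch", "cache_sequence", "cache_heads", "cache_kv")

cached_prefill_key = nnx.Cache(
    jnp.zeros((1, 128, 8, 128), dtype=jnp.bfloat16),
    out_sharding=cache_axis_names,  # was sharding=... before this change
)
```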
@@ -433,19 +433,19 @@ def _initialize_ar_cache_vars(self, model_mode):
# TODO(b/339703100): investigate the issue why with_logical_partitioning doesn't enforce sharding
self.cached_ar_key = nnx.Cache(
jnp.zeros(cache_shape_key, dtype=dtype),
sharding=cache_axis_names,
out_sharding=cache_axis_names,
)
self.cached_ar_key.value = nn.with_logical_constraint(
self.cached_ar_key.value,
self.cached_ar_key[...] = nn.with_logical_constraint(
self.cached_ar_key.get_value(),
cache_axis_names,
)

self.cached_ar_value = nnx.Cache(
jnp.zeros(cache_shape_value, dtype=dtype),
sharding=cache_axis_names,
out_sharding=cache_axis_names,
)
self.cached_ar_value.value = nn.with_logical_constraint(
self.cached_ar_value.value,
self.cached_ar_value[...] = nn.with_logical_constraint(
self.cached_ar_value.get_value(),
cache_axis_names,
)

@@ -455,12 +455,12 @@ def _initialize_ar_cache_vars(self, model_mode):
segment_id_axis_names = (CACHE_BATCH, CACHE_SEQUENCE)
self.cache_ar_segment_id = nnx.Cache(
jnp.zeros((cache_logical_shape[0], cache_length), dtype=jnp.int32),
sharding=segment_id_axis_names,
out_sharding=segment_id_axis_names,
)

self.cached_ar_lengths = nnx.Cache(
jnp.zeros((cache_logical_shape[0],), dtype=jnp.int32),
sharding=(CACHE_BATCH,),
out_sharding=(CACHE_BATCH,),
)

if self.kv_quant:
@@ -474,19 +474,19 @@ def _initialize_ar_cache_vars(self, model_mode):

self.cached_ar_key_scale = nnx.Cache(
jnp.zeros(cache_key_scale_shape, dtype=jnp.bfloat16),
sharding=cache_scale_axis_names,
out_sharding=cache_scale_axis_names,
)
self.cached_ar_value_scale = nnx.Cache(
jnp.zeros(cache_value_scale_shape, dtype=jnp.bfloat16),
sharding=cache_scale_axis_names,
out_sharding=cache_scale_axis_names,
)
else:
self.cached_ar_key_scale = None
self.cached_ar_value_scale = None

self.cache_ar_index = nnx.Cache(
jnp.zeros((1,), dtype=jnp.int32),
sharding=(),
out_sharding=(),
)

def _get_ar_cache_vars(self):
@@ -549,35 +549,45 @@ def kv_cache_chunked_prefill(
)

# We don't zero out the remaining values. Use segment ids to mask them out.
cached_prefill_key_vars[0].value = jax.lax.dynamic_update_slice_in_dim(
cached_key_value, key_shaped_for_cache, next_pos, cache_seq_axis
cached_prefill_key_vars[0].set_value(
jax.lax.dynamic_update_slice_in_dim(cached_key_value, key_shaped_for_cache, next_pos, cache_seq_axis)
)
cached_prefill_value_vars[0].value = jax.lax.dynamic_update_slice_in_dim(
cached_value_value, value_shaped_for_cache, next_pos, cache_seq_axis
cached_prefill_value_vars[0].set_value(
jax.lax.dynamic_update_slice_in_dim(cached_value_value, value_shaped_for_cache, next_pos, cache_seq_axis)
)

if decoder_segment_ids is not None:
# Zero out the remaining values to prevent a wrong mask in autoregressive decoding.
previous_segment_id = cached_prefill_segment_id_var.value[:, :next_pos]
cached_prefill_segment_id_var.value = jnp.zeros_like(cached_prefill_segment_id_var.value, dtype=jnp.int32)
cached_prefill_segment_id_var.value = jax.lax.dynamic_update_slice_in_dim(
cached_prefill_segment_id_var.value, previous_segment_id, start_index=0, axis=1
previous_segment_id = cached_prefill_segment_id_var.get_value()[:, :next_pos]
cached_prefill_segment_id_var.set_value(jnp.zeros_like(cached_prefill_segment_id_var.get_value(), dtype=jnp.int32))
cached_prefill_segment_id_var.set_value(
jax.lax.dynamic_update_slice_in_dim(
cached_prefill_segment_id_var.get_value(), previous_segment_id, start_index=0, axis=1
)
)
cached_prefill_segment_id_var.value = jax.lax.dynamic_update_slice_in_dim(
cached_prefill_segment_id_var.value, decoder_segment_ids, next_pos, axis=1
cached_prefill_segment_id_var.set_value(
jax.lax.dynamic_update_slice_in_dim(
cached_prefill_segment_id_var.get_value(), decoder_segment_ids, next_pos, axis=1
)
)

# Return needed kv cache to reduce computation of attention.
needed_prefill_key_value = jax.lax.dynamic_slice_in_dim(
cached_prefill_key_vars[0].value, start_index=0, slice_size=(next_pos + self.key_seq_len), axis=cache_seq_axis
cached_prefill_key_vars[0].get_value(),
start_index=0,
slice_size=(next_pos + self.key_seq_len),
axis=cache_seq_axis,
)
needed_prefill_value_value = jax.lax.dynamic_slice_in_dim(
cached_prefill_value_vars[0].value, start_index=0, slice_size=(next_pos + self.value_seq_len), axis=cache_seq_axis
cached_prefill_value_vars[0].get_value(),
start_index=0,
slice_size=(next_pos + self.value_seq_len),
axis=cache_seq_axis,
)
needed_segment_id = None
if decoder_segment_ids is not None:
needed_segment_id = jax.lax.dynamic_slice_in_dim(
cached_prefill_segment_id_var.value, start_index=0, slice_size=(next_pos + segment_id_seq_len), axis=1
cached_prefill_segment_id_var.get_value(), start_index=0, slice_size=(next_pos + segment_id_seq_len), axis=1
)

return (
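The chunked-prefill update above boils down to two primitives: write the new chunk at `next_pos` along the cache's sequence axis, then slice back only the filled prefix so attention skips unused positions. A self-contained sketch with toy shapes (the sequence axis is hard-coded to 1 here; the real code derives it from the configured cache axis order):

```python
import jax
import jax.numpy as jnp

cache_len, chunk_len, heads, head_dim = 32, 8, 2, 4
cache = jnp.zeros((1, cache_len, heads, head_dim), dtype=jnp.bfloat16)
chunk = jnp.ones((1, chunk_len, heads, head_dim), dtype=jnp.bfloat16)
next_pos = 8  # tokens already written by earlier chunks

# Write the incoming chunk at next_pos along the sequence axis.
cache = jax.lax.dynamic_update_slice_in_dim(cache, chunk, next_pos, axis=1)

# Return only the filled prefix; the remainder is masked via segment ids, not zeroed.
filled = jax.lax.dynamic_slice_in_dim(cache, start_index=0, slice_size=next_pos + chunk_len, axis=1)
assert filled.shape[1] == next_pos + chunk_len
```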
@@ -620,14 +630,14 @@ def kv_cache_prefill(
value_shaped_for_cache, value_scale_shaped_for_cache = self.kv_quant.quantize(
value_shaped_for_cache, prefill_key_axis_names
)
cached_prefill_key_vars[1].value = key_scale_shaped_for_cache
cached_prefill_value_vars[1].value = value_scale_shaped_for_cache
cached_prefill_key_vars[1].set_value(key_scale_shaped_for_cache)
cached_prefill_value_vars[1].set_value(value_scale_shaped_for_cache)

cached_prefill_key_vars[0].value = key_shaped_for_cache
cached_prefill_value_vars[0].value = value_shaped_for_cache
cached_prefill_key_vars[0].set_value(key_shaped_for_cache)
cached_prefill_value_vars[0].set_value(value_shaped_for_cache)

if decoder_segment_ids is not None:
cached_prefill_segment_id_var.value = decoder_segment_ids
cached_prefill_segment_id_var.set_value(decoder_segment_ids)
return key, value, decoder_segment_ids

def update_ar_key_value(
@@ -691,51 +701,60 @@ def value_body(i, val):
new_token_locations[ar_cache_batch_axis] = i
return val.at[tuple(cache_locations)].set(one_token_value_shaped_for_cache[tuple(new_token_locations)])

cached_key.value = jax.lax.fori_loop(
0, one_token_key_shaped_for_cache.shape[0], key_body, cached_key.value, unroll=8
)
cached_value.value = jax.lax.fori_loop(
0, one_token_value_shaped_for_cache.shape[0], value_body, cached_value.value, unroll=8
cached_key[...] = jax.lax.fori_loop(0, one_token_key_shaped_for_cache.shape[0], key_body, cached_key[...], unroll=8)
cached_value[...] = jax.lax.fori_loop(
0, one_token_value_shaped_for_cache.shape[0], value_body, cached_value[...], unroll=8
)

else:
one_hot_indices = one_hot_indices.astype(int)

# Align batch size for cache with new token in decoding
if cached_key.value.shape[2] != one_token_key_shaped_for_cache.shape[2]:
cached_key.value = jnp.repeat(cached_key.value, one_token_key_shaped_for_cache.shape[2], axis=2)
cached_value.value = jnp.repeat(cached_value.value, one_token_value_shaped_for_cache.shape[2], axis=2)

cached_key.value = jax.lax.dynamic_update_index_in_dim(
cached_key.value, one_token_key_shaped_for_cache, ar_cache_update_idx, ar_cache_update_axis
if cached_key.get_value().shape[2] != one_token_key_shaped_for_cache.shape[2]:
cached_key.set_value(jnp.repeat(cached_key.get_value(), one_token_key_shaped_for_cache.shape[2], axis=2))
cached_value.set_value(jnp.repeat(cached_value.get_value(), one_token_value_shaped_for_cache.shape[2], axis=2))

cached_key.set_value(
jax.lax.dynamic_update_index_in_dim(
cached_key.get_value(), one_token_key_shaped_for_cache, ar_cache_update_idx, ar_cache_update_axis
)
)
cached_value.value = jax.lax.dynamic_update_index_in_dim(
cached_value.value, one_token_value_shaped_for_cache, ar_cache_update_idx, ar_cache_update_axis
cached_value.set_value(
jax.lax.dynamic_update_index_in_dim(
cached_value.get_value(), one_token_value_shaped_for_cache, ar_cache_update_idx, ar_cache_update_axis
)
)
cached_key.value = nn.with_logical_constraint(cached_key.value, ar_cache_axis_names)
cached_value.value = nn.with_logical_constraint(cached_value.value, ar_cache_axis_names)
cached_key.set_value(nn.with_logical_constraint(cached_key.get_value(), ar_cache_axis_names))
cached_value.set_value(nn.with_logical_constraint(cached_value.get_value(), ar_cache_axis_names))

if self.kv_quant:
ar_cache_scale_axis_names = transpose_tuple(self.cache_scale_logical_axis_names, self.ar_cache_axis_order)
ar_cache_scale_update_axis = ar_cache_scale_axis_names.index(CACHE_SCALE_SEQUENCE)
assert cached_key_scale is not None, "cached_key_scale_var cannot be None"
assert cached_value_scale is not None, "cached_value_scale_var cannot be None"
cached_key_scale.value = jax.lax.dynamic_update_index_in_dim(
cached_key_scale.value, one_token_key_scale_shaped_for_cache, ar_cache_update_idx, ar_cache_scale_update_axis
cached_key_scale.set_value(
jax.lax.dynamic_update_index_in_dim(
cached_key_scale.get_value(),
one_token_key_scale_shaped_for_cache,
ar_cache_update_idx,
ar_cache_scale_update_axis,
)
)
cached_value_scale.value = jax.lax.dynamic_update_index_in_dim(
cached_value_scale.value,
one_token_value_scale_shaped_for_cache,
ar_cache_update_idx,
ar_cache_scale_update_axis,
cached_value_scale.set_value(
jax.lax.dynamic_update_index_in_dim(
cached_value_scale.get_value(),
one_token_value_scale_shaped_for_cache,
ar_cache_update_idx,
ar_cache_scale_update_axis,
)
)
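For the ragged-attention branch above, each batch row writes its single new token at that row's own position, which is why the update is a `fori_loop` over the batch instead of one dynamic slice. A small sketch of that scatter, using a toy `(batch, seq, heads, head_dim)` layout (the real cache axis order is configurable):

```python
import jax
import jax.numpy as jnp

batch, seq_len, heads, head_dim = 4, 16, 2, 8
cache = jnp.zeros((batch, seq_len, heads, head_dim), dtype=jnp.bfloat16)
new_token = jnp.ones((batch, 1, heads, head_dim), dtype=jnp.bfloat16)
lengths = jnp.array([3, 0, 7, 5], dtype=jnp.int32)  # per-row write position

def body(i, buf):
  # Row i writes its one new token at its own current length.
  return buf.at[i, lengths[i]].set(new_token[i, 0])

# unroll=8 mirrors the unrolling used in update_ar_key_value.
cache = jax.lax.fori_loop(0, batch, body, cache, unroll=8)
```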

def get_cached_values(self, cache_vars, target_dtype, cache_axis_order) -> jax.Array | KVTensor:
"""get cached values"""
cache_var, cache_scale_var = cache_vars
cache_value = cache_var.value
cache_value = cache_var.get_value()
if cache_scale_var is not None:
scale_value = cache_scale_var.value
scale_value = cache_scale_var.get_value()
dtype = cache_value.dtype
if dtype == jnp.int8:
scale_value /= MAX_INT8
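The `scale_value /= MAX_INT8` step above is the dequantization half of symmetric int8 KV caching. A toy round trip, purely illustrative; MaxText's `kv_quant` chooses the scaling axes itself:

```python
import jax.numpy as jnp

MAX_INT8 = 127.0
x = jnp.linspace(-2.0, 2.0, 8, dtype=jnp.float32)

scale = jnp.max(jnp.abs(x))                             # one scale for the block
q = jnp.round(x / scale * MAX_INT8).astype(jnp.int8)    # stored in the cache
x_hat = q.astype(jnp.float32) * (scale / MAX_INT8)      # recovered at attention time
```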
@@ -780,35 +799,41 @@ def kv_cache_autoregressive(
value,
cached_ar_key_vars,
cached_ar_value_vars,
cache_ar_index_var.value,
cache_ar_lengths_var.value,
cache_ar_index_var.get_value(),
cache_ar_lengths_var.get_value(),
use_ragged_attention,
)
active_indicator = jnp.zeros((self.batch, 1), dtype=jnp.int32) + DECODING_ACTIVE_SEQUENCE_INDICATOR

# Align batch size for cached segment IDs with indicator in decoding
if cached_ar_segment_id_var.value.shape[0] != active_indicator.shape[0]:
cached_ar_segment_id_var.value = jnp.repeat(cached_ar_segment_id_var.value, active_indicator.shape[0], axis=0)
if cached_ar_segment_id_var.get_value().shape[0] != active_indicator.shape[0]:
cached_ar_segment_id_var.set_value(
jnp.repeat(cached_ar_segment_id_var.get_value(), active_indicator.shape[0], axis=0)
)

cached_ar_segment_id_var.value = jax.lax.dynamic_update_index_in_dim(
cached_ar_segment_id_var.value, active_indicator, jnp.squeeze(cache_ar_index_var.value), 1
cached_ar_segment_id_var.set_value(
jax.lax.dynamic_update_index_in_dim(
cached_ar_segment_id_var.get_value(), active_indicator, jnp.squeeze(cache_ar_index_var.get_value()), 1
)
)
cache_ar_index_var.set_value(
jnp.mod(cache_ar_index_var.get_value() + 1, self.max_target_length - self.max_prefill_length)
)
cache_ar_index_var.value = jnp.mod(cache_ar_index_var.value + 1, self.max_target_length - self.max_prefill_length)
cache_ar_lengths_var.value = cache_ar_lengths_var.value.at[:].add(1)
cache_ar_lengths_var.set_value(cache_ar_lengths_var.get_value().at[:].add(1))

cached_prefill_key_vars, cached_prefill_value_vars, cached_prefill_segment_id_var = self._get_prefill_cache_vars()

cached_prefill = (
self.get_cached_values(cached_prefill_key_vars, key.dtype, self.prefill_cache_axis_order),
self.get_cached_values(cached_prefill_value_vars, value.dtype, self.prefill_cache_axis_order),
cached_prefill_segment_id_var.value,
cached_prefill_segment_id_var.get_value(),
)

cached_ar = (
self.get_cached_values(cached_ar_key_vars, key.dtype, self.ar_cache_axis_order),
self.get_cached_values(cached_ar_value_vars, value.dtype, self.ar_cache_axis_order),
cached_ar_segment_id_var.value,
cache_ar_lengths_var.value,
cached_ar_segment_id_var.get_value(),
cache_ar_lengths_var.get_value(),
)
return cached_prefill, cached_ar
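The index arithmetic above is a plain ring buffer: the autoregressive cache has `max_target_length - max_prefill_length` slots, one token is written per decode step, and the write index wraps with `jnp.mod` while the per-sequence lengths keep growing. A compact sketch:

```python
import jax.numpy as jnp

max_target_length, max_prefill_length = 16, 12
ar_slots = max_target_length - max_prefill_length   # 4 autoregressive slots

ar_index = jnp.zeros((1,), dtype=jnp.int32)
lengths = jnp.zeros((4,), dtype=jnp.int32)           # one counter per sequence

for step in range(6):                                # six decode steps
  # ... the new K/V and segment id are written at ar_index here ...
  ar_index = jnp.mod(ar_index + 1, ar_slots)         # 1, 2, 3, 0, 1, 2
  lengths = lengths.at[:].add(1)
```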

@@ -877,7 +902,7 @@ def __init__(
self.recurrent_state = nnx.Cache(
jnp.zeros((int(batch), num_heads, k_head_dim, v_head_dim), dtype=dtype),
# Sharding: Batch, Heads, None (K), None (V)
sharding=(cache_batch_axis_name, cache_heads_axis_name, None, None),
out_sharding=(cache_batch_axis_name, cache_heads_axis_name, None, None),
)

# 2. Convolution State for the 1D Conv
Expand All @@ -886,7 +911,7 @@ def __init__(
self.conv_state = nnx.Cache(
jnp.zeros((int(batch), conv_kernel_size - 1, conv_dim), dtype=dtype),
# Sharding: Batch, None (Time), None (Dim)
sharding=(cache_batch_axis_name, None, None),
out_sharding=(cache_batch_axis_name, None, None),
)

def __call__(self):
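The final hunk declares two decode-time state buffers for what the comments describe as a recurrent block with a 1D convolution: a per-head recurrent state and the trailing `conv_kernel_size - 1` inputs of the convolution, both carried across decode steps. Their shapes, with stand-in sizes:

```python
import jax.numpy as jnp

batch, num_heads, k_head_dim, v_head_dim = 2, 4, 64, 64
conv_kernel_size, conv_dim = 4, 512

# Recurrent state: (batch, heads, k_head_dim, v_head_dim), sharded on batch and heads.
recurrent_state = jnp.zeros((batch, num_heads, k_head_dim, v_head_dim), dtype=jnp.bfloat16)

# Convolution tail: the last (kernel - 1) timesteps, carried between decode steps.
conv_state = jnp.zeros((batch, conv_kernel_size - 1, conv_dim), dtype=jnp.bfloat16)
```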