From f6735ef9e4135592ac7510e471f9ab030f7326a2 Mon Sep 17 00:00:00 2001 From: Alex Shraer Date: Mon, 4 May 2026 03:45:06 +0000 Subject: [PATCH] Resolve deprecation warnings, logic errors, and formatting in MaxText - Resolve Deprecation Warnings: - Replaced deprecated `.value` access on `nnx.Variable` with `[...]` across layers, models, trainers, and unit tests. - Replaced deprecated `_object__state` with `_pytree__state` in `nnx_wrappers.py`. - Resolved deprecation warnings for `Module.sow` and `Variable.value` in MLA, GPT3, and other attention layers. - Fixed deprecated sharding warnings in mHC and MoE layers. - Resolved Pydantic serializer warnings for `context_parallel_reorder_strategy` in `types.py`. - Logic & Compatibility Fixes: - Corrected shape broadcasting and RecursionError by utilizing `set_value` in GDN cache and distillation trainer. - Fixed NNX wrappers logic for handling variable transitions. - Resolved engram overflow and inconsistent state access in kvcache. - Corrected Codecov configuration and CI workflows. - Style & Formatting: - Formatted all modified source and test files with `pyink` (2-space indentation, 125 line length) to ensure complete compliance with repository style guidelines. --- .github/workflows/run_jupyter_notebooks.yml | 2 +- .../workflows/run_tests_against_package.yml | 6 +- src/maxtext/configs/types.py | 2 +- src/maxtext/inference/kvcache.py | 177 ++++++++++-------- src/maxtext/inference/paged_attention.py | 42 ++--- src/maxtext/layers/attention_mla.py | 2 +- src/maxtext/layers/attention_op.py | 6 +- src/maxtext/layers/embeddings.py | 4 +- src/maxtext/layers/engram.py | 22 +-- src/maxtext/layers/initializers.py | 16 +- src/maxtext/layers/mhc.py | 50 ++--- src/maxtext/layers/moe.py | 23 ++- src/maxtext/layers/nnx_wrappers.py | 33 ++-- src/maxtext/layers/normalizations.py | 2 +- src/maxtext/models/gpt3.py | 5 +- src/maxtext/models/qwen3.py | 8 +- .../post_train/distillation/train_distill.py | 14 +- src/maxtext/utils/model_creation_utils.py | 14 +- tests/inference/kvcache_test.py | 16 +- tests/unit/attention_test.py | 8 +- tests/unit/model_creation_utils_test.py | 2 +- tests/unit/multi_token_prediction_test.py | 10 +- 22 files changed, 236 insertions(+), 228 deletions(-) diff --git a/.github/workflows/run_jupyter_notebooks.yml b/.github/workflows/run_jupyter_notebooks.yml index 25d46b6acc..7d868e8d5c 100644 --- a/.github/workflows/run_jupyter_notebooks.yml +++ b/.github/workflows/run_jupyter_notebooks.yml @@ -94,7 +94,7 @@ jobs: PAPERMILL_EXE=".venv/bin/papermill" source .venv/bin/activate fi - export PYTHONPATH="${pwd}/src${PYTHONPATH:+:${PYTHONPATH}}" + export PYTHONPATH="${PWD}/src${PYTHONPATH:+:${PYTHONPATH}}" export MAXTEXT_REPO_ROOT=$(pwd) export MAXTEXT_PKG_DIR=$(pwd)/src/maxtext diff --git a/.github/workflows/run_tests_against_package.yml b/.github/workflows/run_tests_against_package.yml index 2955e082f0..08124bd8e9 100644 --- a/.github/workflows/run_tests_against_package.yml +++ b/.github/workflows/run_tests_against_package.yml @@ -139,9 +139,9 @@ jobs: PYTHON_EXE=".venv/bin/python3" # Ensure pytest-cov is available and enable coverage flags uv pip install pytest-cov - PYTEST_COV_ARGS="--cov=MaxText --cov=maxtext --cov-report=xml --cov-report=term" + PYTEST_COV_ARGS="--cov=maxtext --cov-report=xml --cov-report=term" fi - export PYTHONPATH="${pwd}/src${PYTHONPATH:+:${PYTHONPATH}}" + export PYTHONPATH="${PWD}/src${PYTHONPATH:+:${PYTHONPATH}}" if [ "${INPUTS_IS_SCHEDULED_RUN}" == "true" ]; then FINAL_PYTEST_MARKER="${INPUTS_PYTEST_MARKER}" @@ -209,7 +209,7 @@ jobs: continue-on-error: true with: token: ${{ secrets.CODECOV_TOKEN }} - file: ./coverage.xml + files: ./coverage.xml # If scheduled, upload to scheduled flag only. If PR, upload to regular flag only. flags: ${{ inputs.is_scheduled_run == 'true' && 'scheduled' || 'regular' }} verbose: true diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 9610bbab46..22988f324a 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -847,7 +847,7 @@ class HardwareAndMesh(BaseModel): "all_gather", description="Strategy for context parallelism ('all_gather' or 'ring').", ) - context_parallel_reorder_strategy: ReorderStrategy = Field( + context_parallel_reorder_strategy: str = Field( "auto", description="Reorder strategy for load-balanced context parallelism.", ) diff --git a/src/maxtext/inference/kvcache.py b/src/maxtext/inference/kvcache.py index 9f634069c4..897a0cc8a7 100644 --- a/src/maxtext/inference/kvcache.py +++ b/src/maxtext/inference/kvcache.py @@ -366,11 +366,11 @@ def _initialize_prefill_caches(self, model_mode): self.cached_prefill_key = nnx.Cache( jnp.zeros(cache_shape_key, dtype=dtype), - sharding=cache_axis_names, + out_sharding=cache_axis_names, ) self.cached_prefill_value = nnx.Cache( jnp.zeros(cache_shape_value, dtype=dtype), - sharding=cache_axis_names, + out_sharding=cache_axis_names, ) if model_mode == MODEL_MODE_PREFILL: @@ -380,7 +380,7 @@ def _initialize_prefill_caches(self, model_mode): self.cache_prefill_segment_id = nnx.Cache( jnp.zeros((cache_logical_shape[0], cache_length), dtype=jnp.int32), - sharding=segment_id_axis_names, + out_sharding=segment_id_axis_names, ) if self.kv_quant: @@ -394,11 +394,11 @@ def _initialize_prefill_caches(self, model_mode): self.cached_prefill_key_scale = nnx.Cache( jnp.zeros(cache_key_scale_shape, dtype=jnp.bfloat16), - sharding=cache_scale_axis_names, + out_sharding=cache_scale_axis_names, ) self.cached_prefill_value_scale = nnx.Cache( jnp.zeros(cache_value_scale_shape, dtype=jnp.bfloat16), - sharding=cache_scale_axis_names, + out_sharding=cache_scale_axis_names, ) else: self.cached_prefill_key_scale = None @@ -433,19 +433,19 @@ def _initialize_ar_cache_vars(self, model_mode): # TODO(b/339703100): investigate the issue why with_logical_partitioning doesn't enforce sharding self.cached_ar_key = nnx.Cache( jnp.zeros(cache_shape_key, dtype=dtype), - sharding=cache_axis_names, + out_sharding=cache_axis_names, ) - self.cached_ar_key.value = nn.with_logical_constraint( - self.cached_ar_key.value, + self.cached_ar_key[...] = nn.with_logical_constraint( + self.cached_ar_key[...], cache_axis_names, ) self.cached_ar_value = nnx.Cache( jnp.zeros(cache_shape_value, dtype=dtype), - sharding=cache_axis_names, + out_sharding=cache_axis_names, ) - self.cached_ar_value.value = nn.with_logical_constraint( - self.cached_ar_value.value, + self.cached_ar_value[...] = nn.with_logical_constraint( + self.cached_ar_value[...], cache_axis_names, ) @@ -455,12 +455,12 @@ def _initialize_ar_cache_vars(self, model_mode): segment_id_axis_names = (CACHE_BATCH, CACHE_SEQUENCE) self.cache_ar_segment_id = nnx.Cache( jnp.zeros((cache_logical_shape[0], cache_length), dtype=jnp.int32), - sharding=segment_id_axis_names, + out_sharding=segment_id_axis_names, ) self.cached_ar_lengths = nnx.Cache( jnp.zeros((cache_logical_shape[0],), dtype=jnp.int32), - sharding=(CACHE_BATCH,), + out_sharding=(CACHE_BATCH,), ) if self.kv_quant: @@ -474,11 +474,11 @@ def _initialize_ar_cache_vars(self, model_mode): self.cached_ar_key_scale = nnx.Cache( jnp.zeros(cache_key_scale_shape, dtype=jnp.bfloat16), - sharding=cache_scale_axis_names, + out_sharding=cache_scale_axis_names, ) self.cached_ar_value_scale = nnx.Cache( jnp.zeros(cache_value_scale_shape, dtype=jnp.bfloat16), - sharding=cache_scale_axis_names, + out_sharding=cache_scale_axis_names, ) else: self.cached_ar_key_scale = None @@ -486,7 +486,7 @@ def _initialize_ar_cache_vars(self, model_mode): self.cache_ar_index = nnx.Cache( jnp.zeros((1,), dtype=jnp.int32), - sharding=(), + out_sharding=(), ) def _get_ar_cache_vars(self): @@ -549,35 +549,45 @@ def kv_cache_chunked_prefill( ) # We don't zero out remain values. Use segment id to mask out. - cached_prefill_key_vars[0].value = jax.lax.dynamic_update_slice_in_dim( - cached_key_value, key_shaped_for_cache, next_pos, cache_seq_axis + cached_prefill_key_vars[0].set_value( + jax.lax.dynamic_update_slice_in_dim(cached_key_value, key_shaped_for_cache, next_pos, cache_seq_axis) ) - cached_prefill_value_vars[0].value = jax.lax.dynamic_update_slice_in_dim( - cached_value_value, value_shaped_for_cache, next_pos, cache_seq_axis + cached_prefill_value_vars[0].set_value( + jax.lax.dynamic_update_slice_in_dim(cached_value_value, value_shaped_for_cache, next_pos, cache_seq_axis) ) if decoder_segment_ids is not None: # Need zero out the remain values to prevent wrong mask in autoregressive. - previous_segment_id = cached_prefill_segment_id_var.value[:, :next_pos] - cached_prefill_segment_id_var.value = jnp.zeros_like(cached_prefill_segment_id_var.value, dtype=jnp.int32) - cached_prefill_segment_id_var.value = jax.lax.dynamic_update_slice_in_dim( - cached_prefill_segment_id_var.value, previous_segment_id, start_index=0, axis=1 + previous_segment_id = cached_prefill_segment_id_var.get_value()[:, :next_pos] + cached_prefill_segment_id_var.set_value(jnp.zeros_like(cached_prefill_segment_id_var.get_value(), dtype=jnp.int32)) + cached_prefill_segment_id_var.set_value( + jax.lax.dynamic_update_slice_in_dim( + cached_prefill_segment_id_var.get_value(), previous_segment_id, start_index=0, axis=1 + ) ) - cached_prefill_segment_id_var.value = jax.lax.dynamic_update_slice_in_dim( - cached_prefill_segment_id_var.value, decoder_segment_ids, next_pos, axis=1 + cached_prefill_segment_id_var.set_value( + jax.lax.dynamic_update_slice_in_dim( + cached_prefill_segment_id_var.get_value(), decoder_segment_ids, next_pos, axis=1 + ) ) # Return needed kv cache to reduce computation of attention. needed_prefill_key_value = jax.lax.dynamic_slice_in_dim( - cached_prefill_key_vars[0].value, start_index=0, slice_size=(next_pos + self.key_seq_len), axis=cache_seq_axis + cached_prefill_key_vars[0].get_value(), + start_index=0, + slice_size=(next_pos + self.key_seq_len), + axis=cache_seq_axis, ) needed_prefill_value_value = jax.lax.dynamic_slice_in_dim( - cached_prefill_value_vars[0].value, start_index=0, slice_size=(next_pos + self.value_seq_len), axis=cache_seq_axis + cached_prefill_value_vars[0].get_value(), + start_index=0, + slice_size=(next_pos + self.value_seq_len), + axis=cache_seq_axis, ) needed_segment_id = None if decoder_segment_ids is not None: needed_segment_id = jax.lax.dynamic_slice_in_dim( - cached_prefill_segment_id_var.value, start_index=0, slice_size=(next_pos + segment_id_seq_len), axis=1 + cached_prefill_segment_id_var.get_value(), start_index=0, slice_size=(next_pos + segment_id_seq_len), axis=1 ) return ( @@ -620,14 +630,14 @@ def kv_cache_prefill( value_shaped_for_cache, value_scale_shaped_for_cache = self.kv_quant.quantize( value_shaped_for_cache, prefill_key_axis_names ) - cached_prefill_key_vars[1].value = key_scale_shaped_for_cache - cached_prefill_value_vars[1].value = value_scale_shaped_for_cache + cached_prefill_key_vars[1].set_value(key_scale_shaped_for_cache) + cached_prefill_value_vars[1].set_value(value_scale_shaped_for_cache) - cached_prefill_key_vars[0].value = key_shaped_for_cache - cached_prefill_value_vars[0].value = value_shaped_for_cache + cached_prefill_key_vars[0].set_value(key_shaped_for_cache) + cached_prefill_value_vars[0].set_value(value_shaped_for_cache) if decoder_segment_ids is not None: - cached_prefill_segment_id_var.value = decoder_segment_ids + cached_prefill_segment_id_var.set_value(decoder_segment_ids) return key, value, decoder_segment_ids def update_ar_key_value( @@ -691,51 +701,60 @@ def value_body(i, val): new_token_locations[ar_cache_batch_axis] = i return val.at[tuple(cache_locations)].set(one_token_value_shaped_for_cache[tuple(new_token_locations)]) - cached_key.value = jax.lax.fori_loop( - 0, one_token_key_shaped_for_cache.shape[0], key_body, cached_key.value, unroll=8 - ) - cached_value.value = jax.lax.fori_loop( - 0, one_token_value_shaped_for_cache.shape[0], value_body, cached_value.value, unroll=8 + cached_key[...] = jax.lax.fori_loop(0, one_token_key_shaped_for_cache.shape[0], key_body, cached_key[...], unroll=8) + cached_value[...] = jax.lax.fori_loop( + 0, one_token_value_shaped_for_cache.shape[0], value_body, cached_value[...], unroll=8 ) else: one_hot_indices = one_hot_indices.astype(int) # Align batch size for cache with new token in decoding - if cached_key.value.shape[2] != one_token_key_shaped_for_cache.shape[2]: - cached_key.value = jnp.repeat(cached_key.value, one_token_key_shaped_for_cache.shape[2], axis=2) - cached_value.value = jnp.repeat(cached_value.value, one_token_value_shaped_for_cache.shape[2], axis=2) - - cached_key.value = jax.lax.dynamic_update_index_in_dim( - cached_key.value, one_token_key_shaped_for_cache, ar_cache_update_idx, ar_cache_update_axis + if cached_key.get_value().shape[2] != one_token_key_shaped_for_cache.shape[2]: + cached_key.set_value(jnp.repeat(cached_key.get_value(), one_token_key_shaped_for_cache.shape[2], axis=2)) + cached_value.set_value(jnp.repeat(cached_value.get_value(), one_token_value_shaped_for_cache.shape[2], axis=2)) + + cached_key.set_value( + jax.lax.dynamic_update_index_in_dim( + cached_key.get_value(), one_token_key_shaped_for_cache, ar_cache_update_idx, ar_cache_update_axis + ) ) - cached_value.value = jax.lax.dynamic_update_index_in_dim( - cached_value.value, one_token_value_shaped_for_cache, ar_cache_update_idx, ar_cache_update_axis + cached_value.set_value( + jax.lax.dynamic_update_index_in_dim( + cached_value.get_value(), one_token_value_shaped_for_cache, ar_cache_update_idx, ar_cache_update_axis + ) ) - cached_key.value = nn.with_logical_constraint(cached_key.value, ar_cache_axis_names) - cached_value.value = nn.with_logical_constraint(cached_value.value, ar_cache_axis_names) + cached_key.set_value(nn.with_logical_constraint(cached_key.get_value(), ar_cache_axis_names)) + cached_value.set_value(nn.with_logical_constraint(cached_value.get_value(), ar_cache_axis_names)) if self.kv_quant: ar_cache_scale_axis_names = transpose_tuple(self.cache_scale_logical_axis_names, self.ar_cache_axis_order) ar_cache_scale_update_axis = ar_cache_scale_axis_names.index(CACHE_SCALE_SEQUENCE) assert cached_key_scale is not None, "cached_key_scale_var cannot be None" assert cached_value_scale is not None, "cached_value_scale_var cannot be None" - cached_key_scale.value = jax.lax.dynamic_update_index_in_dim( - cached_key_scale.value, one_token_key_scale_shaped_for_cache, ar_cache_update_idx, ar_cache_scale_update_axis + cached_key_scale.set_value( + jax.lax.dynamic_update_index_in_dim( + cached_key_scale.get_value(), + one_token_key_scale_shaped_for_cache, + ar_cache_update_idx, + ar_cache_scale_update_axis, + ) ) - cached_value_scale.value = jax.lax.dynamic_update_index_in_dim( - cached_value_scale.value, - one_token_value_scale_shaped_for_cache, - ar_cache_update_idx, - ar_cache_scale_update_axis, + cached_value_scale.set_value( + jax.lax.dynamic_update_index_in_dim( + cached_value_scale.get_value(), + one_token_value_scale_shaped_for_cache, + ar_cache_update_idx, + ar_cache_scale_update_axis, + ) ) def get_cached_values(self, cache_vars, target_dtype, cache_axis_order) -> jax.Array | KVTensor: """get cached values""" cache_var, cache_scale_var = cache_vars - cache_value = cache_var.value + cache_value = cache_var.get_value() if cache_scale_var is not None: - scale_value = cache_scale_var.value + scale_value = cache_scale_var.get_value() dtype = cache_value.dtype if dtype == jnp.int8: scale_value /= MAX_INT8 @@ -771,44 +790,54 @@ def kv_cache_autoregressive( if sequence != 1: raise ValueError(f"Sequence length should be 1 during autoregression, got {sequence=}") - cached_ar_key_vars, cached_ar_value_vars, cached_ar_segment_id_var, cache_ar_index_var, cache_ar_lengths_var = ( - self._get_ar_cache_vars() - ) + ( + cached_ar_key_vars, + cached_ar_value_vars, + cached_ar_segment_id_var, + cache_ar_index_var, + cache_ar_lengths_var, + ) = self._get_ar_cache_vars() self.update_ar_key_value( key, value, cached_ar_key_vars, cached_ar_value_vars, - cache_ar_index_var.value, - cache_ar_lengths_var.value, + cache_ar_index_var.get_value(), + cache_ar_lengths_var.get_value(), use_ragged_attention, ) active_indicator = jnp.zeros((self.batch, 1), dtype=jnp.int32) + DECODING_ACTIVE_SEQUENCE_INDICATOR # Align batch size for cached segment IDs with indicator in decoding - if cached_ar_segment_id_var.value.shape[0] != active_indicator.shape[0]: - cached_ar_segment_id_var.value = jnp.repeat(cached_ar_segment_id_var.value, active_indicator.shape[0], axis=0) + if cached_ar_segment_id_var.get_value().shape[0] != active_indicator.shape[0]: + cached_ar_segment_id_var.set_value( + jnp.repeat(cached_ar_segment_id_var.get_value(), active_indicator.shape[0], axis=0) + ) - cached_ar_segment_id_var.value = jax.lax.dynamic_update_index_in_dim( - cached_ar_segment_id_var.value, active_indicator, jnp.squeeze(cache_ar_index_var.value), 1 + cached_ar_segment_id_var.set_value( + jax.lax.dynamic_update_index_in_dim( + cached_ar_segment_id_var.get_value(), active_indicator, jnp.squeeze(cache_ar_index_var.get_value()), 1 + ) + ) + cache_ar_index_var.set_value( + jnp.mod(cache_ar_index_var.get_value() + 1, self.max_target_length - self.max_prefill_length) ) - cache_ar_index_var.value = jnp.mod(cache_ar_index_var.value + 1, self.max_target_length - self.max_prefill_length) - cache_ar_lengths_var.value = cache_ar_lengths_var.value.at[:].add(1) + cache_ar_lengths_var.set_value(cache_ar_lengths_var.get_value().at[:].add(1)) cached_prefill_key_vars, cached_prefill_value_vars, cached_prefill_segment_id_var = self._get_prefill_cache_vars() cached_prefill = ( self.get_cached_values(cached_prefill_key_vars, key.dtype, self.prefill_cache_axis_order), self.get_cached_values(cached_prefill_value_vars, value.dtype, self.prefill_cache_axis_order), - cached_prefill_segment_id_var.value, + cached_prefill_segment_id_var.get_value(), ) cached_ar = ( self.get_cached_values(cached_ar_key_vars, key.dtype, self.ar_cache_axis_order), self.get_cached_values(cached_ar_value_vars, value.dtype, self.ar_cache_axis_order), - cached_ar_segment_id_var.value, - cache_ar_lengths_var.value, + cached_ar_segment_id_var.get_value(), + cache_ar_lengths_var.get_value(), ) return cached_prefill, cached_ar @@ -877,7 +906,7 @@ def __init__( self.recurrent_state = nnx.Cache( jnp.zeros((int(batch), num_heads, k_head_dim, v_head_dim), dtype=dtype), # Sharding: Batch, Heads, None (K), None (V) - sharding=(cache_batch_axis_name, cache_heads_axis_name, None, None), + out_sharding=(cache_batch_axis_name, cache_heads_axis_name, None, None), ) # 2. Convolution State for the 1D Conv @@ -886,7 +915,7 @@ def __init__( self.conv_state = nnx.Cache( jnp.zeros((int(batch), conv_kernel_size - 1, conv_dim), dtype=dtype), # Sharding: Batch, None (Time), None (Dim) - sharding=(cache_batch_axis_name, None, None), + out_sharding=(cache_batch_axis_name, None, None), ) def __call__(self): diff --git a/src/maxtext/inference/paged_attention.py b/src/maxtext/inference/paged_attention.py index 85dfb905b9..d3fe19270f 100644 --- a/src/maxtext/inference/paged_attention.py +++ b/src/maxtext/inference/paged_attention.py @@ -170,22 +170,22 @@ def __init__( self.key_pages = nnx.Cache( jnp.zeros(self.kv_pages_shape, dtype=self.dtype), - sharding=self.kv_pages_axis_names, + out_sharding=self.kv_pages_axis_names, ) self.value_pages = nnx.Cache( jnp.zeros(self.kv_pages_shape, dtype=self.dtype), - sharding=self.kv_pages_axis_names, + out_sharding=self.kv_pages_axis_names, ) def _maybe_materialize_cache(self, cache: nnx.Cache) -> nnx.Cache: """Materializes the cache if it's currently a ShapeDtypeStruct.""" - if isinstance(cache.value, jax.ShapeDtypeStruct): + if isinstance(cache.get_value(), jax.ShapeDtypeStruct): # This is needed because the Linen bridge lazily creates this state. We # need to ensure the cache state is accessible at runtime. # TODO: Delete this function when the to_linen bridge is no longer needed. return nnx.Cache( jnp.zeros(self.kv_pages_shape, dtype=self.dtype), - sharding=cache.sharding, + out_sharding=cache.get_metadata("out_sharding"), ) return cache @@ -204,8 +204,8 @@ def get_kv_pages(self): self.key_pages = self._maybe_materialize_cache(self.key_pages) self.value_pages = self._maybe_materialize_cache(self.value_pages) - self.key_pages.value = nn.with_logical_constraint(self.key_pages.value, self.kv_pages_axis_names) - self.value_pages.value = nn.with_logical_constraint(self.value_pages.value, self.kv_pages_axis_names) + self.key_pages.set_value(nn.with_logical_constraint(self.key_pages.get_value(), self.kv_pages_axis_names)) + self.value_pages.set_value(nn.with_logical_constraint(self.value_pages.get_value(), self.kv_pages_axis_names)) return self.key_pages, self.value_pages def pad_qkv(self, *qkv): @@ -264,9 +264,9 @@ def paged_attention_v2_prefill( is the batch_size is only 1 """ assert query.shape[0] == 1 # ensure the batch size is 0 - # shape of key_pages_cache.value is [num_kv_heads, num_pages, tokens_per_page, head_dim] - k_p = jnp.permute_dims(key_pages_cache.value, (1, 2, 0, 3)) - v_p = jnp.permute_dims(value_pages_cache.value, (1, 2, 0, 3)) + # shape of key_pages_cache.get_value() is [num_kv_heads, num_pages, tokens_per_page, head_dim] + k_p = jnp.permute_dims(key_pages_cache.get_value(), (1, 2, 0, 3)) + v_p = jnp.permute_dims(value_pages_cache.get_value(), (1, 2, 0, 3)) c_q_l = jnp.array([0, page_state.sequence_lengths[0]]) # [0, prefill_true_length] num_seqs = jnp.array([1]) query = query[0] # [batch_size, max_num_tokens, num_kv_heads, head_dim] to [max_num_tokens, num_kv_heads, head_dim] @@ -294,8 +294,8 @@ def paged_attention_v2_decode( """Apply ragged input Paged Attention in decode only.""" batch_size = query.shape[0] query = jnp.squeeze(query, axis=1) # [batch_size, seq_len, n_kv_head, head_dim] to [batch_size, n_kv_head, head_dim] - k_p = jnp.permute_dims(key_pages_cache.value, (1, 2, 0, 3)) - v_p = jnp.permute_dims(value_pages_cache.value, (1, 2, 0, 3)) + k_p = jnp.permute_dims(key_pages_cache.get_value(), (1, 2, 0, 3)) + v_p = jnp.permute_dims(value_pages_cache.get_value(), (1, 2, 0, 3)) c_q_l = jnp.arange(batch_size + 1) # one token per sequence num_seqs = jnp.array([batch_size]) # real number of requests, set it to batch_size result = paged_attention_kernel_v2.ragged_paged_attention( @@ -352,8 +352,8 @@ def wrap_paged_attention(q, k_pages, v_pages, lengths, page_indices, pages_per_c return wrap_paged_attention( query, - key_pages_cache.value, - value_pages_cache.value, + key_pages_cache.get_value(), + value_pages_cache.get_value(), page_state.sequence_lengths, page_state.page_map, self.pages_per_compute_block, @@ -441,12 +441,12 @@ def update_prefill_step_pages( ), f"prefill_step key/value should have the same shape, but getting {key.shape=} and {value.shape=} instead" batch_size, seq_len, n_kv_head, head_dim = key.shape assert seq_len % self.tokens_per_page == 0, f"seq_length {seq_len} and tokens_per_page {self.tokens_per_page}" - assert key_pages_cache.value.shape == value_pages_cache.value.shape, ( + assert key_pages_cache.get_value().shape == value_pages_cache.get_value().shape, ( f"prefill_step key/value_pages_cache should have the same shape, but " f"getting {key_pages_cache.shape=} and {value_pages_cache.shape=} instead" ) - v_n_kv, _, v_p, v_d = key_pages_cache.value.shape + v_n_kv, _, v_p, v_d = key_pages_cache.get_value().shape assert v_n_kv == n_kv_head, f"{v_n_kv=} {n_kv_head=}" assert v_p == self.tokens_per_page, f"{v_p=} {self.tokens_per_page=}" assert v_d == head_dim, f"{v_d=} {head_dim=}" @@ -485,13 +485,13 @@ def update_prefill_step_pages( ), ) - key_pages_cache.value = nn.with_logical_constraint(key, self.kv_pages_axis_names) - value_pages_cache.value = nn.with_logical_constraint(value, self.kv_pages_axis_names) + key_pages_cache.set_value(nn.with_logical_constraint(key, self.kv_pages_axis_names)) + value_pages_cache.set_value(nn.with_logical_constraint(value, self.kv_pages_axis_names)) def update_decode_step_pages(self, key_pages_cache, value_pages_cache, key, value, page_state): """Update decode-step pages""" - key_pages = key_pages_cache.value - value_pages = value_pages_cache.value + key_pages = key_pages_cache.get_value() + value_pages = value_pages_cache.get_value() batch_size, _, kv_heads, head_dim = key.shape kv_heads, _, _, head_dim = key_pages.shape @@ -511,6 +511,6 @@ def update_decode_step_pages(self, key_pages_cache, value_pages_cache, key, valu key_pages_updated = key_pages.at[kv_indices, broadcast_pages, broadcast_pos].set(new_key) value_pages_updated = value_pages.at[kv_indices, broadcast_pages, broadcast_pos].set(new_value) - key_pages_cache.value = key_pages_updated - value_pages_cache.value = value_pages_updated + key_pages_cache.set_value(key_pages_updated) + value_pages_cache.set_value(value_pages_updated) return key_pages_cache, value_pages_cache diff --git a/src/maxtext/layers/attention_mla.py b/src/maxtext/layers/attention_mla.py index ed0ca2f9a8..133273a36d 100644 --- a/src/maxtext/layers/attention_mla.py +++ b/src/maxtext/layers/attention_mla.py @@ -1200,7 +1200,7 @@ def __call__( sparse_loss=self.config.indexer_sparse_training, scaling_factor=self.config.indexer_loss_scaling_factor, ) - self.sow(nnx.Intermediate, "indexer_loss", indexer_loss) + self.indexer_loss = nnx.Intermediate(indexer_loss) # Check if we need QK Clip stats use_qk_clip = self.model_mode == MODEL_MODE_TRAIN and self.config.use_qk_clip diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index 2252ceceb2..e72cfe9134 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -902,7 +902,7 @@ def apply_attention( local_out, local_max, local_sum = impl(query, key, value, lengths, self.ragged_block_size) if record_max_logits: - self.sow("intermediates", "max_logits", local_max) + self.max_logits = nnx.Intermediate(local_max) return local_out, local_max, local_sum # 'vllm_rpa' uses the same dot-attention wrapper but routes to the vLLM @@ -951,7 +951,7 @@ def apply_attention( record_max_logits=record_max_logits, ) if max_logits is not None: - self.sow("intermediates", "max_logits", max_logits) + self.max_logits = nnx.Intermediate(max_logits) return out, None, None else: @@ -1861,7 +1861,7 @@ def apply_attention_dot( max_logits_per_group = jnp.max(attn_weights, axis=(-2, -1)) b, n_kv, g = max_logits_per_group.shape max_logits = max_logits_per_group.reshape(b, n_kv * g) - self.sow("intermediates", "max_logits", max_logits) + self.max_logits = nnx.Intermediate(max_logits) return self.compute_local_attention(attn_weights, value, q_seq_len, model_mode, wv_product_einsum, sinks) diff --git a/src/maxtext/layers/embeddings.py b/src/maxtext/layers/embeddings.py index f933d27440..525fff1ed5 100644 --- a/src/maxtext/layers/embeddings.py +++ b/src/maxtext/layers/embeddings.py @@ -152,7 +152,7 @@ def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array: raise ValueError("Input type must be an integer or unsigned integer.") embedding = jnp.asarray( - _maybe_move_embedding_to_device(self.embedding.value, self.config), + _maybe_move_embedding_to_device(self.embedding.get_value(), self.config), self.dtype, ) @@ -196,7 +196,7 @@ def attend(self, query: Array, out_sharding: NamedSharding | None = None) -> Arr Commonly used for weight-sharing between embeddings and logit transform in NLP models. """ - embedding = self.embedding.value + embedding = self.embedding.get_value() attend_dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype return attend_on_embedding(query, embedding, attend_dtype, self.config, out_sharding) diff --git a/src/maxtext/layers/engram.py b/src/maxtext/layers/engram.py index d218d88c58..3b2eb4e2b5 100644 --- a/src/maxtext/layers/engram.py +++ b/src/maxtext/layers/engram.py @@ -15,7 +15,7 @@ """ DeepSeek-AI, `Conditional Memory via Scalable Lookup: A New Axis of Sparsity for Large Language Models `_, 2026 - + Reference implementation: https://github.com/deepseek-ai/Engram/blob/main/engram_demo_v1.py """ @@ -53,7 +53,7 @@ class CompressedTokenizer: def __init__(self, tokenizer: HFTokenizer): normalizer = self._build_normalizer() self.lookup_table_np, self.num_new_token = self._build_lookup_table(tokenizer, normalizer) - self.lookup_table = jnp.array(self.lookup_table_np, dtype=jnp.int64) + self.lookup_table = jnp.array(self.lookup_table_np, dtype=jnp.int32) def __len__(self) -> int: return self.num_new_token @@ -125,7 +125,7 @@ def __call__(self, input_ids) -> Array: """ Maps original token IDs to compressed IDs. """ - input_ids = jnp.asarray(input_ids, dtype=jnp.int64) + input_ids = jnp.asarray(input_ids, dtype=jnp.int32) # Map negative IDs to 0 for lookup, then mask output back. safe_ids = jnp.where(input_ids < 0, 0, input_ids) @@ -187,7 +187,7 @@ def __init__( # Pre-calculate odd multipliers for hashing: {layer_id: multipliers} # Store as JAX arrays self.layer_multipliers = { - k: jnp.array(v, dtype=jnp.int64) for k, v in self._calculate_multipliers_across_layers(seed).items() + k: jnp.array(v, dtype=jnp.int32) for k, v in self._calculate_multipliers_across_layers(seed).items() } # Pre-calculate unique prime vocab sizes for every head @@ -201,9 +201,9 @@ def _calculate_multipliers_across_layers(self, seed: int) -> dict[int, np.ndarra Returns: A dictionary mapping layer_id to a list of `max_ngram_size` multipliers. """ - # Pre-calculate bounds for random generation - max_long = np.iinfo(np.int64).max - m_max = int(max_long // self.tokenizer_vocab_size) + # Pre-calculate bounds for random generation using int32 to avoid overflow + max_int = np.iinfo(np.int32).max + m_max = int(max_int // self.tokenizer_vocab_size) half_bound = max(1, m_max // 2) # Hard-code prime number to align with reference LAYER_PRIME_OFFSET = 10007 @@ -214,7 +214,7 @@ def _calculate_multipliers_across_layers(self, seed: int) -> dict[int, np.ndarra layer_seed = int(seed + LAYER_PRIME_OFFSET * int(layer_id)) np_rng = np.random.default_rng(layer_seed) # Generate random odd integers - random_value = np_rng.integers(low=0, high=half_bound, size=(self.max_ngram_size,), dtype=np.int64) + random_value = np_rng.integers(low=0, high=half_bound, size=(self.max_ngram_size,), dtype=np.int32) multipliers = random_value * 2 + 1 layer_multipliers[layer_id] = multipliers return layer_multipliers @@ -272,7 +272,7 @@ def _get_ngram_hashes(self, compressed_ids: Array, layer_id: int) -> Array: Returns: hash_ids: [B, S, H_total] where H_total = H * num_ngram_orders """ - x = jnp.asarray(compressed_ids, dtype=jnp.int64) + x = jnp.asarray(compressed_ids, dtype=jnp.int32) B, _ = x.shape # 1. Create Sliding Windows via Shifting @@ -282,7 +282,7 @@ def _get_ngram_hashes(self, compressed_ids: Array, layer_id: int) -> Array: shifted_inputs.append(x) else: # Pre-allocate full array with PAD_ID - padding = jnp.full((B, k), self.pad_id, dtype=jnp.int64) + padding = jnp.full((B, k), self.pad_id, dtype=jnp.int32) # Fast memory copy, slicing and assignment # e.g., k=1, [PAD, The, cat] # k=2, [PAD, PAD, The] @@ -309,7 +309,7 @@ def _get_ngram_hashes(self, compressed_ids: Array, layer_id: int) -> Array: # Retrieve prime vocab sizes for all heads of this n-gram order vocab_sizes_for_this_gram = vocab_sizes[n - 2] - mods = jnp.array(vocab_sizes_for_this_gram, dtype=jnp.int64) + mods = jnp.array(vocab_sizes_for_this_gram, dtype=jnp.int32) # Broadcast Modulo: Map hash to valid table indices # [B, S, 1] % [H] -> [B, S, H] diff --git a/src/maxtext/layers/initializers.py b/src/maxtext/layers/initializers.py index 20baf9a633..ed09ab83f5 100644 --- a/src/maxtext/layers/initializers.py +++ b/src/maxtext/layers/initializers.py @@ -60,10 +60,10 @@ def init_fn(key, shape, dtype, in_axis, out_axis): return init_fn -def variable_to_logically_partitioned(variable: nnx.VariableState): +def variable_to_logically_partitioned(variable: nnx.Variable): """Wraps an NNX variable's value in `nn.LogicallyPartitioned`. - This function inspects the metadata of an `nnx.VariableState` object. If + This function inspects the metadata of an `nnx.Variable` object. If sharding information ('out_sharding', 'sharding' or 'sharding_names') is present, it wraps the variable's value in `nn.LogicallyPartitioned` to apply the specified sharding constraints. @@ -73,16 +73,16 @@ def variable_to_logically_partitioned(variable: nnx.VariableState): wrapping. Args: - variable: The `nnx.VariableState` object to process. + variable: The `nnx.Variable` object to process. Returns: The variable's value, potentially wrapped in `nn.LogicallyPartitioned`. """ - if isinstance(variable.value, aqt_tensor.QTensor): - return variable.value + if isinstance(variable.get_value(), aqt_tensor.QTensor): + return variable.get_value() if variable.type.__name__ == "_overwrite_with_gradient": - return variable.value + return variable.get_value() metadata = variable.get_metadata() out_sharding = None @@ -95,10 +95,10 @@ def variable_to_logically_partitioned(variable: nnx.VariableState): if out_sharding is not None: return nn.LogicallyPartitioned( # type: ignore[wrong-keyword-args] - variable.value, + variable.get_value(), out_sharding, # type: ignore[arg-type] mesh=metadata.get("mesh"), rules=metadata.get("rules"), ) else: - return variable.value + return variable.get_value() diff --git a/src/maxtext/layers/mhc.py b/src/maxtext/layers/mhc.py index a4a4771c91..ce700aafcd 100644 --- a/src/maxtext/layers/mhc.py +++ b/src/maxtext/layers/mhc.py @@ -33,9 +33,7 @@ def get_functions(expansion_rate: int): def expand(x: Array): # (batch, length, dim) -> (batch, length, streams, dim) - return jnp.repeat( - jnp.expand_dims(x, axis=2), expansion_rate, axis=2 - ).astype(x.dtype) + return jnp.repeat(jnp.expand_dims(x, axis=2), expansion_rate, axis=2).astype(x.dtype) def reduce(x: Array): # (batch, length, streams, dim) -> (batch, length, dim) @@ -109,15 +107,15 @@ def __init__( # Scalars self.res_alpha_scale = nnx.Param( default_scalar_init(self.rngs.params(), (1,), self.weight_dtype), - sharding=(None,), + out_sharding=(None,), ) self.pre_alpha_scale = nnx.Param( default_scalar_init(self.rngs.params(), (1,), self.weight_dtype), - sharding=(None,), + out_sharding=(None,), ) self.post_alpha_scale = nnx.Param( default_scalar_init(self.rngs.params(), (1,), self.weight_dtype), - sharding=(None,), + out_sharding=(None,), ) # Weight matrices @@ -133,7 +131,7 @@ def __init__( in_axis=in_axis, out_axis=out_axis, ), - sharding=weight_sharding_axis_name, + out_sharding=weight_sharding_axis_name, ) self.pre_alpha = nnx.Param( scale_init( @@ -143,7 +141,7 @@ def __init__( in_axis=in_axis, out_axis=out_axis, ), - sharding=weight_sharding_axis_name, + out_sharding=weight_sharding_axis_name, ) self.post_alpha = nnx.Param( scale_init( @@ -153,23 +151,21 @@ def __init__( in_axis=in_axis, out_axis=out_axis, ), - sharding=weight_sharding_axis_name, + out_sharding=weight_sharding_axis_name, ) # Biases self.res_beta = nnx.Param( - default_bias_init( - self.rngs.params(), (self.k, self.k), self.weight_dtype - ), - sharding=(None, None), + default_bias_init(self.rngs.params(), (self.k, self.k), self.weight_dtype), + out_sharding=(None, None), ) self.pre_beta = nnx.Param( default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype), - sharding=(None,), + out_sharding=(None,), ) self.post_beta = nnx.Param( default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype), - sharding=(None,), + out_sharding=(None,), ) def res_mapping(self, x: Array): @@ -179,18 +175,14 @@ def res_mapping(self, x: Array): res_beta = jnp.asarray(self.res_beta[...], self.dtype) res_alpha_scale = jnp.asarray(self.res_alpha_scale[...], self.dtype) # Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k) - h_res = jnp.einsum( - "bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision - ) + h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision) b, s, _ = h_res.shape h_res = jnp.reshape(h_res, (b, s, self.k, self.k)) intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :] output = sinkhorn(intermediate, self.sinkhorn_iterations) return output - def mapping( - self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: int - ): + def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: int): """Helper function for both pre and post mappings.""" # In MaxText, we match weight precision to activations before Matmul alpha = jnp.asarray(alpha, self.dtype) @@ -236,9 +228,7 @@ def __call__( self.pre_beta[...], 1.0, ) - layer_input = jnp.einsum( - "bskd,bsk -> bsd", x, pre_mapping, precision=self.matmul_precision - ) + layer_input = jnp.einsum("bskd,bsk -> bsd", x, pre_mapping, precision=self.matmul_precision) # 3. Pre-norm layer_input = norm_fn(layer_input) @@ -246,15 +236,11 @@ def __call__( # 4. Attention or MLP metadata = {} if mhc_type == HyperConnectionType.ATTENTION: - layer_out, _ = branch_fn( - inputs_q=layer_input, inputs_kv=layer_input, **kwargs - ) + layer_out, _ = branch_fn(inputs_q=layer_input, inputs_kv=layer_input, **kwargs) elif mhc_type == HyperConnectionType.MLP_DENSE: layer_out = branch_fn(inputs=layer_input, **kwargs) elif mhc_type == HyperConnectionType.MLP_MOE: - layer_out, load_balance_loss, moe_bias_updates = branch_fn( - inputs=layer_input, **kwargs - ) + layer_out, load_balance_loss, moe_bias_updates = branch_fn(inputs=layer_input, **kwargs) metadata["load_balance_loss"] = load_balance_loss metadata["moe_bias_updates"] = moe_bias_updates else: @@ -277,7 +263,5 @@ def __call__( # 6. Residual mapping, res_out shape as [batch, seq, expansion_rate, emb] res_mapping = self.res_mapping(norm_x) - res_out = jnp.einsum( - "bskd,bskm -> bsmd", x, res_mapping, precision=self.matmul_precision - ) + res_out = jnp.einsum("bskd,bskm -> bsmd", x, res_mapping, precision=self.matmul_precision) return res_out + post_out, metadata diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 09f1620922..e23c3eba9f 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -236,7 +236,7 @@ def __init__( kernel_in_axis, kernel_out_axis, ), - sharding=self.kernel_axes, + out_sharding=self.kernel_axes, ) if self.use_bias: @@ -244,7 +244,7 @@ def __init__( bias_shape = kernel_shape[-len(self.out_features_shape) :] self.bias = nnx.Param( default_bias_init(rngs.params(), bias_shape, self.weight_dtype), - sharding=bias_axes, + out_sharding=bias_axes, ) else: self.bias = None @@ -267,7 +267,6 @@ def quant_dot_general(self) -> nnx_wrappers.ToNNX | None: return getattr(self, self._quant_dot_general_name) def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax.Array, Optional[jax.Array]]: - inputs = jnp.asarray(inputs, self.dtype) norm_axis = linears.normalize_axes(self.axis, inputs.ndim) @@ -460,7 +459,7 @@ def __init__( kernel_in_axis, kernel_out_axis, ), - sharding=self.wi_kernel_axes, + out_sharding=self.wi_kernel_axes, ) self.wo = nnx.Param( self.kernel_init( @@ -470,7 +469,7 @@ def __init__( kernel_in_axis, kernel_out_axis, ), - sharding=self.wo_kernel_axes, + out_sharding=self.wo_kernel_axes, ) else: # Pad model dimension in Unfused MoE weight kernels for GMM_v2 execution. @@ -487,7 +486,7 @@ def __init__( kernel_in_axis, kernel_out_axis, ), - sharding=self.wi_kernel_axes, + out_sharding=self.wi_kernel_axes, ) self.wi_1 = nnx.Param( self.kernel_init( @@ -497,7 +496,7 @@ def __init__( kernel_in_axis, kernel_out_axis, ), - sharding=self.wi_kernel_axes, + out_sharding=self.wi_kernel_axes, ) self.wo = nnx.Param( self.kernel_init( @@ -507,7 +506,7 @@ def __init__( kernel_in_axis, kernel_out_axis, ), - sharding=self.wo_kernel_axes, + out_sharding=self.wo_kernel_axes, ) if self.config.mlp_bias: @@ -517,15 +516,15 @@ def __init__( wo_bias_shape = (self.num_experts, self.moe_expert_input_dim) self.wi_0_bias = nnx.Param( default_bias_init(self.rngs.params(), wi_bias_shape, self.weight_dtype), - sharding=wi_bias_axes, + out_sharding=wi_bias_axes, ) self.wi_1_bias = nnx.Param( default_bias_init(self.rngs.params(), wi_bias_shape, self.weight_dtype), - sharding=wi_bias_axes, + out_sharding=wi_bias_axes, ) self.wo_bias = nnx.Param( default_bias_init(self.rngs.params(), wo_bias_shape, self.weight_dtype), - sharding=wo_bias_axes, + out_sharding=wo_bias_axes, ) else: self.wi_0_bias = None @@ -535,7 +534,7 @@ def __init__( if self.config.decoder_block == ctypes.DecoderBlockType.GEMMA4: self.per_expert_scale = nnx.Param( jnp.ones((self.num_experts,), dtype=self.weight_dtype), - sharding=("exp",), + out_sharding=("exp",), ) else: self.per_expert_scale = None diff --git a/src/maxtext/layers/nnx_wrappers.py b/src/maxtext/layers/nnx_wrappers.py index eb81d596d9..d41d924456 100644 --- a/src/maxtext/layers/nnx_wrappers.py +++ b/src/maxtext/layers/nnx_wrappers.py @@ -42,7 +42,7 @@ flax_config.update("flax_always_shard_variable", False) -def is_vanilla_variable(vs: variablelib.VariableState) -> bool: +def is_vanilla_variable(vs: variablelib.Variable) -> bool: """A variables state is vanilla if its metadata is essentially blank. Returns False only if it has non-empty hooks or any non-built-in attribute. @@ -56,16 +56,16 @@ def is_vanilla_variable(vs: variablelib.VariableState) -> bool: return True -def to_linen_var(vs: variablelib.VariableState) -> meta.AxisMetadata: +def to_linen_var(vs: variablelib.Variable) -> meta.AxisMetadata: metadata = vs.get_metadata() if "linen_meta_type" in metadata: linen_type = metadata["linen_meta_type"] if hasattr(linen_type, "from_nnx_metadata"): - return linen_type.from_nnx_metadata({"value": vs.value, **metadata}) - return linen_type(vs.value, **metadata) + return linen_type.from_nnx_metadata({"value": vs.get_value(), **metadata}) + return linen_type(vs.get_value(), **metadata) if is_vanilla_variable(vs): - return vs.value - return nnx.bridge.NNXMeta(vs.type, vs.value, metadata) + return vs.get_value() + return nnx.bridge.NNXMeta(vs.type, vs.get_value(), metadata) def get_col_name(keypath: tp.Sequence[Any]) -> str: @@ -126,9 +126,6 @@ def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict: linen_structured = {} for kp, v in nnx.traversals.flatten_mapping(nnx_attrs).items(): if isinstance(v, variablelib.Variable): - col_name = variablelib.variable_name_from_type(type(v)) - v = to_linen_var(v.to_state()) - elif isinstance(v, variablelib.VariableState): col_name = variablelib.variable_name_from_type(v.type) v = to_linen_var(v) else: @@ -141,7 +138,7 @@ def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict: def _set_initializing(module: Module, initializing: bool): for _, value in graph.iter_graph(module): if isinstance(value, Pytree): - value._object__state._initializing = initializing # pylint: disable=protected-access + value._pytree__state._initializing = initializing # pylint: disable=protected-access def lazy_init(fn: Module | tp.Callable[..., tp.Any], *args, **kwargs): @@ -249,7 +246,7 @@ def __call__( # rename default to params if "params" not in _rngs and "default" in _rngs: _rngs["params"] = _rngs.pop("default") - if self._object__state.initializing: + if self._pytree__state.initializing: out, updates = self.to_nnx__module.init_with_output(_rngs, *args, method=method, **kwargs) else: nnx_attrs = { @@ -415,7 +412,7 @@ class ToLinen(linen.Module): args: tp.Sequence = () kwargs: tp.Mapping[str, tp.Any] = FrozenDict({}) skip_rng: bool = False - metadata_fn: tp.Callable[[variablelib.VariableState], tp.Any] | None = to_linen_var + metadata_fn: tp.Callable[[variablelib.Variable], tp.Any] | None = to_linen_var @linen.compact def __call__(self, *args, nnx_method: tp.Callable[..., Any] | str | None = None, **kwargs): @@ -494,7 +491,7 @@ def _update_variables(self, module): # group state by collection for path, leaf in nnx.to_flat_state(state): - type_ = leaf.type if isinstance(leaf, nnx.VariableState) else type(leaf) + type_ = leaf.type if isinstance(leaf, nnx.Variable) else type(leaf) collection = variablelib.variable_name_from_type(type_, allow_register=True) if collection not in collection_flat_state: collection_flat_state[collection] = [] @@ -505,18 +502,18 @@ def _update_variables(self, module): if self.is_mutable_collection(collection): def _to_linen_var(x): - if isinstance(x, nnx.VariableState): + if isinstance(x, nnx.Variable): if self.metadata_fn is not None: return self.metadata_fn(x) # pylint: disable=too-many-function-args else: - return x.value + return x.get_value() return x collection_state = nnx.traversals.unflatten_mapping(flat_state) collection_state = jax.tree.map( _to_linen_var, collection_state, - is_leaf=lambda x: isinstance(x, nnx.VariableState), + is_leaf=lambda x: isinstance(x, nnx.Variable), ) for k, v in collection_state.items(): self.put_variable(collection, k, v) @@ -532,7 +529,7 @@ class _Missing: def to_linen( nnx_class: tp.Callable[..., Module], *args, - metadata_fn: tp.Callable[[variablelib.VariableState], tp.Any] | None = to_linen_var, + metadata_fn: tp.Callable[[variablelib.Variable], tp.Any] | None = to_linen_var, name: str | None = None, skip_rng: bool = False, abstract_init: bool = True, @@ -551,7 +548,7 @@ def to_linen( def to_linen_class( base_nnx_class: type[M], - base_metadata_fn: tp.Callable[[variablelib.VariableState], tp.Any] | None = to_linen_var, + base_metadata_fn: tp.Callable[[variablelib.Variable], tp.Any] | None = to_linen_var, base_skip_rng: bool = False, **partial_kwargs: tp.Any, ) -> type[ToLinen]: diff --git a/src/maxtext/layers/normalizations.py b/src/maxtext/layers/normalizations.py index 3bce30d44e..bf91262bf1 100644 --- a/src/maxtext/layers/normalizations.py +++ b/src/maxtext/layers/normalizations.py @@ -81,7 +81,7 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> y = jax.lax.with_sharding_constraint(y, out_sharding) return y - scale = self.scale.value + scale = self.scale.get_value() # Move scale to device if parameter offloading is enabled if self.parameter_memory_host_offload: max_logging.log("normalizations.py: Moving scale parameter to device") diff --git a/src/maxtext/models/gpt3.py b/src/maxtext/models/gpt3.py index 7a054ff9cf..2736b8aafb 100644 --- a/src/maxtext/models/gpt3.py +++ b/src/maxtext/models/gpt3.py @@ -90,7 +90,7 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> if self.reductions_in_fp32: normed_inputs = normed_inputs.astype(self.dtype) - scale = self.scale.value + scale = self.scale[...] # Move scale to device if parameter offloading is enabled if self.parameter_memory_host_offload: max_logging.log("gpt3.py: Moving scale parameter to device") @@ -106,7 +106,7 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> ) if self.bias is not None: - bias = self.bias.value + bias = self.bias[...] bias = jnp.asarray(bias, self.dtype) output += bias return output @@ -354,7 +354,6 @@ def __init__( rngs: nnx.Rngs, quant: Optional[Quant] = None, ): - self.config = config self.mesh = mesh self.quant = quant diff --git a/src/maxtext/models/qwen3.py b/src/maxtext/models/qwen3.py index 3491663ebe..bd65f04438 100644 --- a/src/maxtext/models/qwen3.py +++ b/src/maxtext/models/qwen3.py @@ -556,7 +556,7 @@ def __call__( conv_state = None if model_mode != MODEL_MODE_TRAIN: # Retrieve state from self.cache - conv_state = self.cache.conv_state.value + conv_state = self.cache.conv_state[...] if conv_state.shape[0] != batch: # Assumes zero-initialized state for testing if conv_state.shape[0] == 1: @@ -578,7 +578,7 @@ def extract_state(c_in, v_len): new_conv_state = conv_input[:, -(conv_kernel_size - 1) :, :] # Update self.cache in place - self.cache.conv_state.value = new_conv_state + self.cache.conv_state.set_value(new_conv_state) else: # Train: pad with zeros conv_input = jnp.pad(qkv, ((0, 0), (conv_kernel_size - 1, 0), (0, 0))) @@ -629,7 +629,7 @@ def extract_state(c_in, v_len): recurrent_state = None if model_mode != MODEL_MODE_TRAIN: # Retrieve state from self.cache - recurrent_state = self.cache.recurrent_state.value + recurrent_state = self.cache.recurrent_state[...] if recurrent_state.shape[0] != batch: if recurrent_state.shape[0] == 1: @@ -651,7 +651,7 @@ def extract_state(c_in, v_len): if model_mode != MODEL_MODE_TRAIN: # Update self.cache in place for both prefill and decode - self.cache.recurrent_state.value = recurrent_state_out + self.cache.recurrent_state.set_value(recurrent_state_out) # ========================================================================= # STEP D: Final Output Stage diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py index 7b009298a6..1a66a532fb 100644 --- a/src/maxtext/trainers/post_train/distillation/train_distill.py +++ b/src/maxtext/trainers/post_train/distillation/train_distill.py @@ -231,9 +231,11 @@ def __init__( self.strategy = strategy # Per-step per-device TFLOPs (constants for the run): student fwd+bwd + teacher fwd-only. - self._tflops_combined, self._tflops_student, self._tflops_teacher = ( - distillation_utils.calculate_distillation_tflops_per_device(student_config, teacher_config, is_offline=is_offline) - ) + ( + self._tflops_combined, + self._tflops_student, + self._tflops_teacher, + ) = distillation_utils.calculate_distillation_tflops_per_device(student_config, teacher_config, is_offline=is_offline) max_logging.log( f"Per-step per-device TFLOPs — combined: {self._tflops_combined:.2f}, " f"student (fwd+bwd): {self._tflops_student:.2f}, teacher (fwd-only): {self._tflops_teacher:.2f}" @@ -274,7 +276,7 @@ def _train_step(self, model, optimizer, inputs): """Overrides the main JIT block to natively handle ModelBundle module.""" batch = self.gen_model_input_fn(inputs) - current_step = model.training_step.value + current_step = model.training_step[...] def loss_wrapper(student, teacher, batch): if "teacher_output" in batch: @@ -317,8 +319,7 @@ def loss_wrapper(student, teacher, batch): out, grads = grad_fn(model.student_model, model.teacher_model, batch) - # Increment step counter after loss computation - model.training_step.value = current_step + 1 + model.training_step.set_value(current_step + 1) tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True) @@ -671,7 +672,6 @@ def train_distill( # Hardware Execution (Safe Context) max_logging.log("Applying logical axis rules for model initialization and training...") with mesh, nn_partitioning.axis_rules(student_config.logical_axis_rules): - # 2. Load Models if is_offline: max_logging.log("Offline Distillation: Skipping Teacher Model loading.") diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index e2eb29412a..ab85894832 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -852,9 +852,9 @@ def _adjust_target_for_moe_fusion(target, meta_tree, is_nnx): # structure of linen checkpoint: {'params': {'params': {'decoder': ...}}} is_nnx_checkpoint = False target_for_restore = jax.tree.map( - lambda v: v.value, + lambda v: v[...], sharded_state, - is_leaf=lambda n: hasattr(n, "value"), + is_leaf=lambda n: isinstance(n, nnx.Variable), ) target_for_restore = _adjust_target_for_moe_fusion( @@ -876,7 +876,7 @@ def _adjust_target_for_moe_fusion(target, meta_tree, is_nnx): # NNX checkpoint: {'decoder': {'value': ...}}, or NNX-RL with extra 'base' nesting. # Restore only nnx.Param — RNG variable shapes may differ between checkpoint and model. target_for_restore = jax.tree.map( - lambda v: {"value": v.value}, + lambda v: {"value": v[...]}, sharded_state, is_leaf=lambda n: isinstance(n, nnx.Variable), ) @@ -892,7 +892,7 @@ def _adjust_target_for_moe_fusion(target, meta_tree, is_nnx): # Free memory used by initial sharded_state before restore, to make room for the incoming checkpoint arrays. def _free_device_memory(node): if isinstance(node, nnx.Variable) and not isinstance(node, nnx.RngState): - val = node.value + val = node[...] else: val = node @@ -922,7 +922,7 @@ def _free_device_memory(node): if checkpoint: model_arrays = jax.tree.map( - lambda v: v.value, + lambda v: v[...], sharded_state, is_leaf=lambda n: isinstance(n, nnx.Variable), ) @@ -934,7 +934,7 @@ def _free_device_memory(node): # nnx.get_partition_spec returns Variables wrapping PartitionSpecs at the leaves; # unwrap to raw PartitionSpecs so _normalize_logical_axes can read them. logical_axes_tree = jax.tree.map( - lambda v: v.value, + lambda v: v.get_value(), specs, is_leaf=lambda n: isinstance(n, nnx.Variable), ) @@ -1024,7 +1024,7 @@ def setup_decode_state_from_nnx(model, config, rng, mesh): # Extract nnx.Param values, converting the State pytree to a plain nested dict. def _state_to_dict(tree): if isinstance(tree, nnx.Variable): - return tree.value + return tree.get_value() if hasattr(tree, "items") and not isinstance(tree, jax.Array): return {k: _state_to_dict(v) for k, v in tree.items()} return tree diff --git a/tests/inference/kvcache_test.py b/tests/inference/kvcache_test.py index 276c157158..ba068e0607 100644 --- a/tests/inference/kvcache_test.py +++ b/tests/inference/kvcache_test.py @@ -67,15 +67,15 @@ def test_update_kv_cache(self): model_mode, ) prefill_low_rank_main = jnp.transpose( - test_module.cached_prefill_key.value, + test_module.cached_prefill_key[...], test_module.key_axis_order, ) prefill_key_rope = jnp.transpose( - test_module.cached_prefill_value.value, + test_module.cached_prefill_value[...], test_module.key_axis_order, ) - ar_low_rank_main = jnp.transpose(test_module.cached_ar_key.value, test_module.key_axis_order) - ar_key_rope = jnp.transpose(test_module.cached_ar_value.value, test_module.key_axis_order) + ar_low_rank_main = jnp.transpose(test_module.cached_ar_key[...], test_module.key_axis_order) + ar_key_rope = jnp.transpose(test_module.cached_ar_value[...], test_module.key_axis_order) # Ensure prefill cache variables have correct shapes and values self.assertEqual( @@ -113,15 +113,15 @@ def test_update_kv_cache(self): model_mode, ) prefill_low_rank_main = jnp.transpose( - test_module.cached_prefill_key.value, + test_module.cached_prefill_key[...], test_module.key_axis_order, ) prefill_key_rope = jnp.transpose( - test_module.cached_prefill_value.value, + test_module.cached_prefill_value[...], test_module.key_axis_order, ) - ar_low_rank_main = jnp.transpose(test_module.cached_ar_key.value, test_module.key_axis_order) - ar_key_rope = jnp.transpose(test_module.cached_ar_value.value, test_module.key_axis_order) + ar_low_rank_main = jnp.transpose(test_module.cached_ar_key[...], test_module.key_axis_order) + ar_key_rope = jnp.transpose(test_module.cached_ar_value[...], test_module.key_axis_order) # Ensure prefill cache variables are same as before self.assertEqual( diff --git a/tests/unit/attention_test.py b/tests/unit/attention_test.py index c20741b40b..09d281c70a 100644 --- a/tests/unit/attention_test.py +++ b/tests/unit/attention_test.py @@ -614,10 +614,10 @@ def test_share_kv_projections(self): ) # Force unshared layer to copy weights from shared layer, mapping 'key' to 'value' - attention_no_share.query.kernel.value = attention_share_kv.query.kernel.value - attention_no_share.key.kernel.value = attention_share_kv.key.kernel.value - attention_no_share.value.kernel.value = attention_share_kv.key.kernel.value - attention_no_share.out.kernel.value = attention_share_kv.out.kernel.value + attention_no_share.query.kernel[...] = attention_share_kv.query.kernel[...] + attention_no_share.key.kernel[...] = attention_share_kv.key.kernel[...] + attention_no_share.value.kernel[...] = attention_share_kv.key.kernel[...] + attention_no_share.out.kernel[...] = attention_share_kv.out.kernel[...] output_no_share, _ = attention_no_share( lnx, diff --git a/tests/unit/model_creation_utils_test.py b/tests/unit/model_creation_utils_test.py index bb0c6c962b..8e9926dcde 100644 --- a/tests/unit/model_creation_utils_test.py +++ b/tests/unit/model_creation_utils_test.py @@ -453,7 +453,7 @@ def __init__(self, rngs: nnx.Rngs): # Mirror the unwrap done in from_pretrained. logical_axes_tree = jax.tree.map( - lambda v: v.value, + lambda v: v.get_value(), specs, is_leaf=lambda n: isinstance(n, nnx.Variable), ) diff --git a/tests/unit/multi_token_prediction_test.py b/tests/unit/multi_token_prediction_test.py index ffe30bea6e..99300b97df 100644 --- a/tests/unit/multi_token_prediction_test.py +++ b/tests/unit/multi_token_prediction_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" multi_token_prediction_test """ +"""multi_token_prediction_test""" import unittest @@ -248,8 +248,8 @@ def test_sow_functionality(self): self.assertTrue(hasattr(state.mtp_block, "weights")) # Access the actual data tuple inside the .value attribute. - losses_val = state.mtp_block.losses.value - weights_val = state.mtp_block.weights.value + losses_val = state.mtp_block.losses[...] + weights_val = state.mtp_block.weights[...] self.assertEqual(len(losses_val), self.cfg.mtp_num_layers) self.assertEqual(len(weights_val), self.cfg.mtp_num_layers) @@ -283,8 +283,8 @@ def test_loss_aggregation_logic(self): # Perform the aggregation logic exactly as in `loss_fn`. if mtp_losses_var and mtp_weights_var: - sum_of_all_mtp_losses = jnp.sum(jnp.array(mtp_losses_var.value)) - sum_of_all_mtp_weights = jnp.sum(jnp.array(mtp_weights_var.value)) + sum_of_all_mtp_losses = jnp.sum(jnp.array(mtp_losses_var[...])) + sum_of_all_mtp_weights = jnp.sum(jnp.array(mtp_weights_var[...])) self.assertGreater(sum_of_all_mtp_weights, 0)