diff --git a/docs/_static/js/editable_commands.js b/docs/_static/js/editable_commands.js index 861cf6564d..46b217dc88 100644 --- a/docs/_static/js/editable_commands.js +++ b/docs/_static/js/editable_commands.js @@ -16,16 +16,23 @@ document.addEventListener('DOMContentLoaded', () => { "", "", "", + "", "", "", "", "", + "", + "", + "", + "", + "", "", "", "", "", "", "", + "", "", "", "", diff --git a/docs/conf.py b/docs/conf.py index 8710f8f913..614b41e71a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -86,7 +86,7 @@ # -- Options for autodoc ---------------------------------------------------- autodoc_member_order = "bysource" autodoc_typehints = "description" -autodoc_mock_imports = [ +packages_to_mock = [ "safetensors", "tensorflow_datasets", "torch", @@ -96,6 +96,12 @@ "librosa", "sentencepiece", ] +autodoc_mock_imports = [] +for pkg in packages_to_mock: + try: + __import__(pkg) + except ImportError: + autodoc_mock_imports.append(pkg) autosummary_generate = True # Theme-specific options @@ -122,7 +128,7 @@ os.path.join("run_maxtext", "run_maxtext_via_multihost_runner.md"), os.path.join("reference", "core_concepts", "llm_calculator.ipynb"), os.path.join("reference", "api.rst"), - os.path.join("reference", "api_generated", "MaxText*.rst"), + os.path.join("reference", "api_generated", "maxtext*.rst"), os.path.join("reference", "api_generated", "modules.rst"), os.path.join("reference", "api_generated", "dependencies.github_deps.rst"), os.path.join("reference", "api_generated", "dependencies.github_deps.install_pre_train_deps.rst"), @@ -167,7 +173,7 @@ r"https://huggingface\.co/settings/tokens", # Ignore GitHub PRs and blobs that trigger rate limiting r"https://github\.com/AI-Hypercomputer/maxtext/pull/.*", - r"https://github\.com/google/maxtext/blob/.*", + r"https://github\.com/AI-Hypercomputer/maxtext/blob/.*", ] @@ -209,12 +215,10 @@ def run_apidoc(_): os.path.join(MAXTEXT_REPO_ROOT, "src"), # Paths to exclude os.path.join(MAXTEXT_REPO_ROOT, "tests"), - os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "experimental"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "experimental"), os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "inference"), os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "scratch_code"), - os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "utils", "ckpt_conversion"), - os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "rl"), - os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "multimodal_utils.py"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "checkpoint_conversion"), ] # Run the command and check for errors diff --git a/docs/tutorials/post_training_index.md b/docs/tutorials/post_training_index.md index 5500f60808..d277cfb4be 100644 --- a/docs/tutorials/post_training_index.md +++ b/docs/tutorials/post_training_index.md @@ -26,6 +26,8 @@ MaxText was co-designed with key Google led innovations to provide a unified pos - **SFT (Supervised Fine-Tuning)** - [SFT on Single-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft.html) - [SFT on Multi-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft_on_multi_host.html) +- **LoRA (Low-Rank Adaptation)** + - [LoRA on Single-Host TPUs](posttraining/lora.md) - **Multimodal SFT** - [Multimodal Support](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/multimodal.html) - **Reinforcement Learning (RL)** @@ -68,6 +70,7 @@ posttraining/sft_on_multi_host.md posttraining/rl.md posttraining/rl_on_multi_host.md posttraining/knowledge_distillation.md 
+posttraining/lora.md posttraining/multimodal.md posttraining/full_finetuning.md posttraining/gepa_optimization.md diff --git a/docs/tutorials/posttraining/lora.md b/docs/tutorials/posttraining/lora.md new file mode 100644 index 0000000000..9f1d6377ed --- /dev/null +++ b/docs/tutorials/posttraining/lora.md @@ -0,0 +1,202 @@ + + +# LoRA Fine-tuning on single-host TPUs + +**Low-Rank Adaptation (LoRA)** is a Parameter-Efficient Fine-Tuning (PEFT) technique designed to optimize large language models while minimizing resource consumption. + +Unlike traditional full-parameter fine-tuning, LoRA: + +- **Freezes the pre-trained model weights**, preserving the original knowledge. +- **Injects trainable rank decomposition matrices** into the Transformer layers. + +This approach **greatly reduces the number of trainable parameters** required for downstream tasks, making the process faster and more memory-efficient. + +This tutorial provides step-by-step instructions for setting up the environment and performing LoRA fine-tuning on a Hugging Face dataset using MaxText. + +We use [Tunix](https://github.com/google/tunix), a JAX-based library, to power these post-training tasks. + +In this tutorial we use a single host TPU VM such as `v6e-8/v5p-8`. Let's get started! + +## Setup environment variables + +Login to Hugging Face. Provide your access token when prompted: + +```bash +hf auth login +``` + +Set the following environment variables before running LoRA Fine-tuning. + +```sh +# -- Model configuration -- +export MODEL_NAME= # e.g., 'gemma3-4b' + +# -- MaxText configuration -- +export BASE_OUTPUT_DIRECTORY= # e.g., gs://my-bucket/my-output-directory or /path/to/my-output-directory +export RUN_NAME= # e.g., $(date +%Y-%m-%d-%H-%M-%S) +export STEPS= # e.g., 1000 +export PER_DEVICE_BATCH_SIZE= # e.g., 1 +export LORA_RANK= # e.g., 16 +export LORA_ALPHA= # e.g., 32.0 +export LEARNING_RATE= # e.g., 3e-6 +export MAX_TARGET_LENGTH= # e.g., 1024 + +# -- Dataset configuration -- +export DATASET_NAME= # e.g., openai/gsm8k +export TRAIN_SPLIT= # e.g., train +export HF_DATA_DIR= # e.g., main +export TRAIN_DATA_COLUMNS= # e.g., ['question','answer'] +export CHAT_TEMPLATE_PATH= # e.g., maxtext/examples/chat_templates/math_qa.json + +# -- LoRA Conversion configuration (Optional) -- +export HF_LORA_ADAPTER_PATH= # e.g., 'username/adapter-name' +``` + +## Customizing Trainable Layers (Optional) + +By default, MaxText determines which layers to apply LoRA to based on the model's architecture by reading `src/maxtext/configs/post_train/lora_module_path.yml`. + +If you need to fine-tune specific components (e.g., targeting only Attention layers to optimize memory usage), you can override these defaults through the following hierarchy: + +### Configuration Hierarchy + +1. **Command Line Argument**: Pass the `lora_module_path` argument directly in your training command. This is the most flexible way for experimental iterations. +2. **Task-Specific Config (`sft.yml`)**: Define the `lora_module_path` parameter in `src/maxtext/configs/post_train/sft.yml` to set a persistent configuration for your SFT runs. +3. **Global Defaults**: Automatic detection via the model-to-regex mapping defined in `lora_module_path.yml`. + +## Get your model checkpoint + +This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint. 
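+
+(As a brief aside on the **Customizing Trainable Layers** section above: the entries in `lora_module_path.yml` are plain regular expressions matched against parameter paths — any path that matches receives LoRA adapters, everything else stays frozen. The sketch below is illustrative only: it uses the `default` regex added in this change, and the parameter paths are hypothetical examples rather than exact MaxText names.)
+
+```python
+import re
+
+# The `default` fallback regex from src/maxtext/configs/post_train/lora_module_path.yml.
+DEFAULT_LORA_MODULES = r"decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
+
+# Hypothetical parameter paths, for illustration only.
+example_paths = [
+    "decoder/layers/self_attention/query/kernel",    # targeted: attention query projection
+    "decoder/layers/mlp/wi_0/kernel",                # targeted: MLP up-projection
+    "decoder/layers/pre_self_attention_norm/scale",  # not targeted: layer norm stays frozen
+    "token_embedder/embedding",                      # not targeted: embeddings stay frozen
+]
+
+for path in example_paths:
+    targeted = re.search(DEFAULT_LORA_MODULES, path) is not None
+    print(f"{path}: {'LoRA adapter injected' if targeted else 'frozen'}")
+```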
+
+### Option 1: Using an existing MaxText checkpoint
+
+If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.
+
+```sh
+export MAXTEXT_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items or /path/to/my-model-checkpoint/0/items
+```
+
+### Option 2: Converting a Hugging Face checkpoint
+
+Refer to the steps in [Hugging Face to MaxText](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/guides/checkpointing_solutions/convert_checkpoint.html#hugging-face-to-maxtext) to convert a Hugging Face checkpoint to MaxText. Make sure you have the correct checkpoint files converted and saved. As in Option 1, set the following environment variable and move on.
+
+```sh
+export MAXTEXT_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items or /path/to/my-model-checkpoint/0/items
+```
+
+## Run a Fresh LoRA Fine-Tuning on a Hugging Face Dataset
+
+Once your environment variables and checkpoints are ready, you can start the LoRA fine-tuning process.
+
+Execute the following command to begin training:
+
+```sh
+python3 -m maxtext.trainers.post_train.sft.train_sft \
+    run_name="${RUN_NAME?}" \
+    base_output_directory="${BASE_OUTPUT_DIRECTORY?}" \
+    model_name="${MODEL_NAME?}" \
+    load_parameters_path="${MAXTEXT_CKPT_PATH?}" \
+    hf_path="${DATASET_NAME?}" \
+    train_split="${TRAIN_SPLIT?}" \
+    hf_data_dir="${HF_DATA_DIR?}" \
+    train_data_columns="${TRAIN_DATA_COLUMNS?}" \
+    steps="${STEPS?}" \
+    per_device_batch_size="${PER_DEVICE_BATCH_SIZE?}" \
+    max_target_length="${MAX_TARGET_LENGTH?}" \
+    learning_rate="${LEARNING_RATE?}" \
+    chat_template_path="${CHAT_TEMPLATE_PATH?}" \
+    enable_nnx=True \
+    pure_nnx_decoder=True \
+    lora.enable_lora=True \
+    lora.lora_rank="${LORA_RANK?}" \
+    lora.lora_alpha="${LORA_ALPHA?}"
+```
+
+Your fine-tuned model checkpoints will be saved here: `$BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints`.
+
+## (Optional) Resume from a previous LoRA checkpoint
+
+If you want to resume training from a previous run or further fine-tune an existing LoRA adapter, you can specify the LoRA checkpoint path.
+
+### Step 1: Convert HF LoRA adapter to MaxText format
+
+If your LoRA adapter is currently in Hugging Face format, you must convert it to MaxText format before it can be loaded. Use the integrated conversion utility:
+
+```sh
+python3 -m maxtext.checkpoint_conversion.to_maxtext \
+    model_name="${MODEL_NAME?}" \
+    hf_lora_adapter_path="${HF_LORA_ADAPTER_PATH?}" \
+    base_output_directory="${BASE_OUTPUT_DIRECTORY?}/converted_adapter" \
+    hardware=cpu skip_jax_distributed_system=True
+```
+
+### Step 2: Set the restore path
+
+Point `LORA_RESTORE_PATH` to the converted MaxText adapter directory (the directory containing the `0/items` or Orbax files).
+
+- **load_parameters_path**: Points to the frozen base model weights (the original model).
+- **lora_restore_path**: Points to the previous LoRA adapter weights you wish to load.
+
+```sh
+export LORA_RESTORE_PATH= # e.g., gs://my-bucket/run-1/checkpoints/0/items or /path/to/run-1/checkpoints/0/items
+```
+
+### Step 3: Run LoRA Fine-Tuning with the Restore Path
+
+Once your environment variables and checkpoints are ready, you can start the LoRA fine-tuning process.
+ +Execute the following command to begin training: + +```sh +python3 -m maxtext.trainers.post_train.sft.train_sft \ + run_name="${RUN_NAME?}" \ + base_output_directory="${BASE_OUTPUT_DIRECTORY?}" \ + model_name="${MODEL_NAME?}" \ + load_parameters_path="${MAXTEXT_CKPT_PATH?}" \ + lora.lora_restore_path="${LORA_RESTORE_PATH?}" \ + hf_path="${DATASET_NAME?}" \ + train_split="${TRAIN_SPLIT?}" \ + hf_data_dir="${HF_DATA_DIR?}" \ + train_data_columns="${TRAIN_DATA_COLUMNS?}" \ + steps="${STEPS?}" \ + per_device_batch_size="${PER_DEVICE_BATCH_SIZE?}" \ + max_target_length="${MAX_TARGET_LENGTH?}" \ + learning_rate="${LEARNING_RATE?}" \ + chat_template_path="${CHAT_TEMPLATE_PATH?}" \ + enable_nnx=True \ + pure_nnx_decoder=True \ + lora.enable_lora=True \ + lora.lora_rank="${LORA_RANK?}" \ + lora.lora_alpha="${LORA_ALPHA?}" +``` + +Your fine-tuned model checkpoints will be saved here: `$BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints`. + +## (Optional) Convert Fine-tuned LoRA to Hugging Face Format + +After completing the fine-tuning process, your LoRA weights are stored in MaxText/Orbax format. To use these weights with the Hugging Face ecosystem (e.g., for inference or sharing), convert them back using the `to_huggingface.py` script. + +```sh +python3 -m maxtext.checkpoint_conversion.to_huggingface \ + model_name="${MODEL_NAME?}" \ + lora.lora_restore_path="${BASE_OUTPUT_DIRECTORY?}/${RUN_NAME?}/checkpoints//model_params" \ + base_output_directory="${BASE_OUTPUT_DIRECTORY?}/hf_lora_adapter" +``` + +- `lora.lora_restore_path`: Point this to the specific checkpoint directory (e.g., `.../checkpoints/1000/items`) that you want to export. +- `base_output_directory`: The local or GCS directory where the Hugging Face `adapter_model.safetensors` and `adapter_config.json` will be saved. +- `lora.lora_rank` / `lora.lora_alpha`: Must match the values used during the training phase to ensure the `adapter_config.json` is generated correctly. diff --git a/pytest.ini b/pytest.ini index 10ed0cc6f5..5494d8e000 100644 --- a/pytest.ini +++ b/pytest.ini @@ -15,6 +15,7 @@ addopts = --ignore=tests/unit/gemma3_layers_test.py --ignore=tests/unit/gpt_vs_reference_test.py --ignore=tests/unit/llama4_layers_test.py + --ignore=tests/unit/hf_checkpoint_conversion_test.py --ignore=tests/unit/yarn_vs_reference_test.py --ignore=tests/unit/moba_vs_reference_test.py --ignore=tests/unit/offline_engine_test.py diff --git a/src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt b/src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt index b757fc7c57..954636d0c1 100644 --- a/src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt +++ b/src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt @@ -22,6 +22,7 @@ openai openai-harmony papermill partial-json-parser +peft perfetto prometheus-fastapi-instrumentator py-cpuinfo diff --git a/src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt b/src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt index 4744ce0315..78ac583ffd 100644 --- a/src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt +++ b/src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt @@ -2,6 +2,7 @@ # See https://maxtext.readthedocs.io/en/latest/development/update_dependencies.html for details. 
absl-py>=2.4.0 +accelerate>=1.13.0 aiofiles>=25.1.0 aiohappyeyeballs>=2.6.1 aiohttp>=3.13.5 @@ -77,15 +78,15 @@ gcsfs>=2026.2.0 gepa>=0.1.1 gguf>=0.19.0 google-api-core>=2.30.3 -google-api-python-client>=2.195.0 -google-auth>=2.50.0 -google-auth-httplib2>=0.3.1 -google-auth-oauthlib>=1.3.1 +google-api-python-client>=2.196.0 +google-auth>=2.51.0 +google-auth-httplib2>=0.4.0 +google-auth-oauthlib>=1.4.0 google-cloud-aiplatform>=1.150.0 google-cloud-appengine-logging>=1.9.0 google-cloud-audit-log>=0.5.0 google-cloud-bigquery>=3.41.0 -google-cloud-core>=2.5.1 +google-cloud-core>=2.6.0 google-cloud-logging>=3.15.0 google-cloud-mldiagnostics>=1.0.2 google-cloud-monitoring>=2.30.0 @@ -96,9 +97,9 @@ google-crc32c>=1.8.0 google-genai>=1.75.0 google-metrax>=0.2.3 google-pasta>=0.2.0 -google-resumable-media>=2.8.2 +google-resumable-media>=2.9.0 google-tunix>=0.1.3 -googleapis-common-protos>=1.74.0 +googleapis-common-protos>=1.75.0 grain>=0.2.16 grpc-google-iam-v1>=0.14.4 grpcio>=1.78.0 @@ -219,6 +220,7 @@ parso>=0.8.7 partial-json-parser>=0.2.1.1.post7 pathspec>=1.1.1 pathwaysutils>=0.1.8 +peft>=0.19.1 perfetto>=0.16.0 pexpect>=4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32' pillow>=12.1.1 @@ -231,7 +233,7 @@ prometheus-fastapi-instrumentator>=7.1.0 promise>=2.3 prompt-toolkit>=3.0.52 propcache>=0.4.1 -proto-plus>=1.27.2 +proto-plus>=1.28.0 protobuf>=6.33.6 psutil>=7.2.2 ptyprocess>=0.7.0 ; sys_platform != 'emscripten' and sys_platform != 'win32' diff --git a/src/maxtext/checkpoint_conversion/to_huggingface.py b/src/maxtext/checkpoint_conversion/to_huggingface.py index 13814b5403..dbc259d92c 100644 --- a/src/maxtext/checkpoint_conversion/to_huggingface.py +++ b/src/maxtext/checkpoint_conversion/to_huggingface.py @@ -12,7 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Converts a MaxText checkpoint to a HuggingFace-compatible model checkpoint. +"""Converts a MaxText checkpoint to a HuggingFace-compatible format. + +This script supports three conversion modes: +1. Base: Converts a standard MaxText model to a full Hugging Face model. + (Requires `load_parameters_path`) +2. Adapter: Converts a standalone MaxText LoRA checkpoint to HF PEFT format. + (Requires `lora.lora_restore_path`) +3. Merged: Merges MaxText LoRA weights into the base model to produce a full HF model. + (Requires both `load_parameters_path` and `lora.lora_restore_path`) It is invoked using MaxText's pyconfig, which means you provide a base config file and can override parameters on the command line. @@ -20,8 +28,10 @@ Key Parameters (to be set in the config file or as command-line overrides): model_name: (Required) The name of the model to convert (e.g., "gemma2-2b"). Must be a key in `maxtext.utils.globals.HF_IDS`. - load_parameters_path: (Required) Path to the MaxText checkpoint directory - containing the parameter-only checkpoint. + load_parameters_path: (Required for Base/Merged) Path to the MaxText base + parameter-only checkpoint. + lora.lora_restore_path: (Required for Adapter/Merged) Path to the MaxText + LoRA checkpoint directory. base_output_directory: (Optional) The directory where the converted HuggingFace checkpoint will be saved. Can be a local path, a GCS path (gs://...), or a HuggingFace Hub repo ID (hf://...). @@ -44,22 +54,20 @@ is a Hub repo ID (e.g., "hf://my-user/my-model"). 
Example Usage: - To convert a gemma2-2b MaxText checkpoint and save it to a local directory: + To merge a LoRA adapter into a base model and save as a full HF model: export HF_AUTH_TOKEN="hf_YOUR_TOKEN" python src/maxtext/checkpoint_conversion/to_huggingface.py \ src/maxtext/configs/base.yml \ - model_name="gemma2-2b" \ - load_parameters_path="/path/to/your/maxtext/checkpoint/" \ - base_output_directory="/path/to/your/output/directory" \ - scan_layers=False - - Note: Other parameters in base.yml (like per_device_batch_size, max_target_length, etc.) - are used to initialize the model structure and should be consistent with the - checkpoint being converted, but often don't need to be changed from their defaults. + model_name="gemma3-4b" \ + load_parameters_path="/path/to/base/checkpoint/" \ + lora.lora_restore_path="/path/to/lora/checkpoint/" \ + base_output_directory="/path/to/output/" \ + scan_layers=True """ import jax +import jax.numpy as jnp import os from typing import Sequence import time @@ -84,6 +92,7 @@ detect_and_extract_checkpoint, MemoryMonitorTqdm, print_peak_memory, + save_adapter_files, ) from maxtext.utils import max_logging from maxtext.utils import max_utils @@ -100,6 +109,22 @@ FLAGS = flags.FLAGS +def _get_lora_delta(key, lora_state_dict, lora_scaling): + """Calculates the LoRA delta for a given parameter key.""" + a_key, b_key = key + "_lora_a", key + "_lora_b" + if a_key not in lora_state_dict and key.startswith("params-"): + a_key, b_key = key[7:] + "_lora_a", key[7:] + "_lora_b" + + if a_key in lora_state_dict and b_key in lora_state_dict: + data_a, data_b = jnp.asarray(lora_state_dict[a_key], dtype=jnp.float32), jnp.asarray( + lora_state_dict[b_key], dtype=jnp.float32 + ) + if data_a.ndim > 2: + return jnp.einsum("ipr,rpo->ipo", data_a, data_b) * lora_scaling + return jnp.matmul(data_a, data_b) * lora_scaling + return None + + def _get_model_mappings( model_name: str, scan_layers: bool, hf_config_dict: dict, maxtext_config: pyconfig.HyperParameters ): @@ -246,6 +271,89 @@ def _validate_or_update_architecture(hf_config, max_config, override: bool): raise ValueError(error_msg) +def _transform_weights_to_adapter(param_map, state_dict): + """Extracts standalone PEFT weights from MaxText state dict.""" + processed_params_list = [] + found_hf_modules = set() + for mt_key, hf_paths in param_map.items(): + a_key, b_key = mt_key + "_lora_a", mt_key + "_lora_b" + if a_key not in state_dict and mt_key.startswith("params-"): + a_key, b_key = mt_key[7:] + "_lora_a", mt_key[7:] + "_lora_b" + if a_key in state_dict and b_key in state_dict: + data_a, data_b = state_dict[a_key], state_dict[b_key] + hf_paths = [hf_paths] if not isinstance(hf_paths, list) else hf_paths + for i in range(min(data_a.shape[1] if data_a.ndim > 2 else 1, len(hf_paths))): + found_hf_modules.add(hf_paths[i].split(".")[-2]) + name = hf_paths[i].replace(".weight", "") + processed_params_list.append( + ( + f"base_model.model.{name}.lora_A.weight", + jax.numpy.asarray((data_a[:, i, :] if data_a.ndim > 2 else data_a).T), + ) + ) + processed_params_list.append( + ( + f"base_model.model.{name}.lora_B.weight", + jax.numpy.asarray((data_b[:, i, :] if data_b.ndim > 2 else data_b).T), + ) + ) + return dict(processed_params_list), found_hf_modules + + +def _transform_weights_to_full_model(config, filtered_map_keys, state_dict, param_map, hook_fn_map, shape_map): + """Transforms MaxText weights to HF full model format, with optional LoRA merging.""" + processed_params_list = [] + lora_scaling = config.lora.lora_alpha / 
config.lora.lora_rank if config.lora.lora_rank > 0 else 1.0 + for key in MemoryMonitorTqdm(filtered_map_keys, leave=True): + weight = [state_dict[subkey] for subkey in key] if isinstance(key, tuple) else state_dict.get(key) + if weight is not None and not isinstance(key, tuple): + delta = _get_lora_delta(key, state_dict, lora_scaling) + if delta is not None: + if delta.shape != weight.shape and delta.size == weight.size: + delta = delta.reshape(weight.shape) + weight = (jnp.asarray(weight, dtype=jnp.float32) + delta).astype(weight.dtype) + if weight is not None: + processed_params_list.extend(process_maxtext_param(key, weight, param_map, hook_fn_map, shape_map, config)) + return dict(processed_params_list) + + +def _transform_and_save_weights( + config, + lora_restore_path, + load_parameters_path, + param_map, + maxtext_state_dict, + filtered_map_keys, + hook_fn_map, + shape_map, + output_directory, + hf_config_obj, + tokenizer, + processor, +): + """Orchestrates weight transformation and saving based on conversion mode.""" + start = time.time() + if lora_restore_path and not load_parameters_path: + # Adapter Mode + transformed_hf_weights, found_hf_modules = _transform_weights_to_adapter(param_map, maxtext_state_dict) + save_adapter_files(output_directory, transformed_hf_weights, config, found_hf_modules, HF_IDS.get(config.model_name)) + max_logging.log(f"✅ LoRA adapter successfully saved at {output_directory}") + else: + # Base or Merged Mode + transformed_hf_weights = _transform_weights_to_full_model( + config, filtered_map_keys, maxtext_state_dict, param_map, hook_fn_map, shape_map + ) + + if not transformed_hf_weights: + raise ValueError("Error: No weights were transformed. Check mappings and parameter paths.") + + max_logging.log("\nSaving HuggingFace model...") + save_model_files(transformed_hf_weights, hf_config_obj, tokenizer, processor, output_directory) + max_logging.log(f"✅ MaxText model successfully saved in HuggingFace format at {output_directory}") + + max_logging.log(f"Elapse for transform and save: {(time.time() - start) / 60:.2f} min") + + def main(argv: Sequence[str]) -> None: """Main function to convert a MaxText checkpoint to HuggingFace format. 
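
The merge path above (`_get_lora_delta` feeding `_transform_weights_to_full_model`) reduces to the standard LoRA update W' = W + (alpha / r) * (A @ B), with the factors stored in the MaxText checkpoint as `<param>_lora_a` and `<param>_lora_b` and the scaling taken from `lora.lora_alpha / lora.lora_rank`. A minimal standalone sketch of that arithmetic, using synthetic shapes rather than real checkpoint tensors:

```python
import numpy as np

# Synthetic sizes for illustration; real kernels come from the Orbax checkpoint.
d_in, d_out, rank = 64, 64, 16
lora_alpha = 32.0
scaling = lora_alpha / rank  # mirrors lora.lora_alpha / lora.lora_rank

base_w = np.zeros((d_in, d_out), dtype=np.float32)        # frozen base kernel
lora_a = np.random.randn(d_in, rank).astype(np.float32)   # <param>_lora_a (2-D case)
lora_b = np.random.randn(rank, d_out).astype(np.float32)  # <param>_lora_b (2-D case)

# Low-rank delta, folded into the base weight exactly once at export time.
delta = (lora_a @ lora_b) * scaling
merged_w = (base_w + delta).astype(base_w.dtype)
assert merged_w.shape == base_w.shape
```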
@@ -265,8 +373,14 @@ def main(argv: Sequence[str]) -> None: max_utils.print_system_information() overall_start = time.time() - # Load Maxtext checkpoint using Orbax to get full parameter dict - max_logging.log(f"\nLoading Orbax checkpoint from: {config.load_parameters_path}") + lora_restore_path = config.lora.lora_restore_path + load_parameters_path = config.load_parameters_path + + if not load_parameters_path and not lora_restore_path: + raise ValueError("Either load_parameters_path or lora_restore_path must be specified.") + + # Load Maxtext checkpoint using Orbax (now smart enough to load both if present) + max_logging.log("\nLoading Orbax checkpoint(s)...") start = time.time() checkpoint_dict = load_orbax_checkpoint(config) max_logging.log(f"Elapse for checkpoint load: {(time.time() - start) / 60:.2f} min") @@ -306,7 +420,10 @@ def main(argv: Sequence[str]) -> None: maxtext_state_dict = detect_and_extract_checkpoint(checkpoint_dict) # Validate that checkpoint keys match the parameter mapping - filtered_map_keys = validate_and_filter_param_map_keys(param_map.keys(), maxtext_state_dict.keys()) + state_keys = set(maxtext_state_dict) | { + k.replace("_lora_a", "").replace("_lora_b", "") for k in maxtext_state_dict if "_lora_" in k + } + filtered_map_keys = validate_and_filter_param_map_keys(param_map, state_keys) # When not converting a multimodal model, skip vision encoder weights even if # they are present in the checkpoint (e.g. converting text-only from a @@ -320,44 +437,22 @@ def main(argv: Sequence[str]) -> None: ] # Iterate through the parameter map to transform and collect weights. - # This loop handles both simple 1-to-1 mappings and complex N-to-1 mappings - # (where multiple MaxText weights are combined into a single HF weight). - max_logging.log("\nProccessing weight...") - start = time.time() - processed_params_list = [] - - for key in MemoryMonitorTqdm(filtered_map_keys, total=len(filtered_map_keys), leave=True): - if isinstance(key, tuple): - # if key is tuple of param names, weight is list of param weights - weight = [maxtext_state_dict[subkey] for subkey in key] - else: - # if key is single param name, weight is single param weight - weight = maxtext_state_dict[key] - - processed_params = process_maxtext_param(key, weight, param_map, hook_fn_map, shape_map, config) - processed_params_list.extend(processed_params) - - max_logging.log(f"Weight dtype after transform: {type(processed_params[0][1].dtype)}") - - transformed_hf_weights = dict(processed_params_list) - max_logging.log(f"Elapse for transform: {(time.time() - start) / 60:.2f} min") - - # 5. Save in HuggingFace Format - if not transformed_hf_weights: - print("Error: No weights were transformed. 
Check mappings and parameter paths.") - return - - max_logging.log("\nSaving HuggingFace model...") - start = time.time() - save_model_files( - weight_arrays=transformed_hf_weights, - config=hf_config_obj, - tokenizer=tokenizer, - processor=processor, - output_dir=output_directory, + max_logging.log("\nProcessing weights...") + _transform_and_save_weights( + config, + lora_restore_path, + load_parameters_path, + param_map, + maxtext_state_dict, + filtered_map_keys, + hook_fn_map, + shape_map, + output_directory, + hf_config_obj, + tokenizer, + processor, ) - max_logging.log(f"✅ MaxText model successfully saved in HuggingFace format at {output_directory}") - max_logging.log(f"Elapse for save: {(time.time() - start) / 60:.2f} min") + max_logging.log(f"Overall Elapse: {(time.time() - overall_start) / 60:.2f} min") print_peak_memory() diff --git a/src/maxtext/checkpoint_conversion/to_maxtext.py b/src/maxtext/checkpoint_conversion/to_maxtext.py index ce8354f0c5..a01a44b146 100644 --- a/src/maxtext/checkpoint_conversion/to_maxtext.py +++ b/src/maxtext/checkpoint_conversion/to_maxtext.py @@ -12,52 +12,41 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -This script converts a HuggingFace model checkpoint to a MaxText-compatible -Orbax checkpoint. +"""Converts a HuggingFace model checkpoint to a MaxText-compatible Orbax checkpoint. + +This script supports three conversion modes: +1. Base: Converts a standard Hugging Face model to MaxText format. +2. Adapter: Converts a standalone Hugging Face LoRA adapter to MaxText PEFT format. + (Requires `hf_lora_adapter_path` in config, and `load_parameters_path` should be empty) +3. Merged: Merges a Hugging Face LoRA adapter into the base weights during conversion. + (Requires both `hf_lora_adapter_path` and `load_parameters_path` to be set/not empty) Key Parameters (to be set in the config file or as command-line overrides): - model_name: (Required) The name of the model to convert (e.g., "gemma2-2b"). + model_name: (Required) The name of the model to convert (e.g., "gemma3-4b"). Must be a key in `maxtext.utils.globals.HF_IDS`. - base_output_directory: (Optional) The directory where the converted HuggingFace - checkpoint will be saved. Can be a local path, a GCS - path (gs://...), or a HuggingFace Hub repo ID (hf://...). - Defaults to "./mt_output/". + base_output_directory: (Optional) The directory where the converted checkpoint + will be saved. Can be a local or GCS path. + load_parameters_path: (Optional) For Merged mode, path to the MaxText base weights. + hf_lora_adapter_path: (Optional) For Adapter or Merged mode, path to the HF LoRA adapter. scan_layers: (bool) Whether the MaxText model was trained with scanned layers. - This must match the training configuration of the checkpoint. --lazy_load_tensors: (bool) If True, uses an on-demand loading strategy to minimize RAM - usage during conversion. Recommended if, 2 * model_size (GB) >= system RAM - Defaults to False. - --hf_model_path: (Optional) Specifies a local or remote directory containing the model weights. - If unspecified, we use the default Hugging Face repository ID - (e.g., openai/gpt-oss-20b; see `HF_IDS[model_name]` in `maxtext.utils.globals`). - This is necessary for locally dequantized models like GPT-OSS or DeepSeek. - --save_dtype: (Optional) Specifies the data type of saved model weights. - Default to `bfloat16` to save memory. + usage during conversion. Recommended for large models. 
+ --hf_model_path: (Optional) Specifies a local or remote directory containing the base HF weights. + --save_dtype: (Optional) Data type of saved weights. Default to `bfloat16`. Environment Variables: - HF_AUTH_TOKEN: (Required) HuggingFace authentication token, needed to - download models from HuggingFace Hub. + HF_AUTH_TOKEN: (Required) HuggingFace authentication token. Example Usage: - To convert a gemma2-2b model and save it to a specific directory: - - python -m maxtext.checkpoint_conversion.to_maxtext \ - maxtext/configs/base.yml model_name="gemma2-2b" \ - base_output_directory="/path/to/your/output/directory" \ - hf_access_token=${HF_TOKEN?} hardware=cpu skip_jax_distributed_system=True \ - scan_layers=False - - For models with scanned layers (e.g., some custom architectures), you might - need to set scan_layers=True and param_scan_axis accordingly. - - To convert a 70B model with minimal RAM usage: + To merge a HF LoRA adapter into base weights and save as a MaxText checkpoint: python -m maxtext.checkpoint_conversion.to_maxtext \ - maxtext/configs/base.yml model_name="llama3.1-70b" \ - base_output_directory="gs://my-bucket/maxtext-checkpoints" \ + maxtext/configs/base.yml model_name="gemma3-4b" \ + load_parameters_path="gs://my-bucket/maxtext-base-weights" \ + hf_lora_adapter_path="my-user/my-lora-adapter" \ + base_output_directory="gs://my-bucket/maxtext-merged-output" \ hf_access_token=${HF_TOKEN?} hardware=cpu skip_jax_distributed_system=True \ - --lazy_load_tensors=True + scan_layers=True """ import argparse @@ -574,6 +563,248 @@ def _slicing_loader(base_loader, slice_idx): ) +def convert_hf_lora_key_to_maxtext(hf_key: str, param_mapping: dict) -> tuple[str | None, int | None]: + """Convert HF LoRA key to MaxText parameter path and optional layer index.""" + hf_param_key = hf_key.replace(".lora_A.weight", ".weight").replace(".lora_B.weight", ".weight") + hf_param_key = hf_param_key.replace(".lora_A", "").replace(".lora_B", "") + + if hf_param_key.startswith("base_model.model."): + hf_param_key = hf_param_key[len("base_model.model.") :] + + if hf_param_key.startswith("language_model.model."): + hf_param_key = "model.language_model." 
+ hf_param_key[len("language_model.model.") :] + + for mt_key, hf_keys in param_mapping.items(): + if isinstance(hf_keys, str): + if hf_keys == hf_param_key: + return mt_key, None + continue + + if not hf_keys: + continue + + if isinstance(hf_keys[0], list): + for i, sub_list in enumerate(hf_keys): + for j, hf_k in enumerate(sub_list): + if hf_k == hf_param_key: + return mt_key, (i, j) + else: + for i, hf_k in enumerate(hf_keys): + if hf_k == hf_param_key: + return mt_key, i + + return None, None + + +def _process_and_stack_weights( + indexed_weights: dict[str, Any], + is_scanned: bool, + num_layers: int, + axis_to_stack: int, + target_dtype: np.dtype, + mt_key: str, + suffix: str, + config: Any, +) -> np.ndarray: + """Transposes and optionally stacks weights across layers.""" + # Llama 3.1 models require a specific layout transformation for their RoPE embeddings + needs_llama31_rope_shuffle = config.rope_type == "llama3.1" or "llama3.1" in config.model_name.lower() + is_2d_indexed = any(isinstance(k, tuple) for k in indexed_weights.keys()) + + for idx in list(indexed_weights.keys()): + w = indexed_weights[idx].T + + if needs_llama31_rope_shuffle: + if "query-kernel" in mt_key and suffix == "kernel_lora_b": + w = w * (1.0 / np.sqrt(config.head_dim)) + + if ("query-kernel" in mt_key or "key-kernel" in mt_key) and suffix == "kernel_lora_b": + num_heads = config.num_query_heads if "query-kernel" in mt_key else config.num_kv_heads + head_dim = config.head_dim + orig_shape = w.shape + + work_val = w.reshape(orig_shape[0], num_heads, head_dim) + half = head_dim // 2 + + first_half = work_val[..., :half] + second_half = work_val[..., half:] + interleaved = np.stack([first_half, second_half], axis=-1).reshape(work_val.shape) + w = interleaved.reshape(orig_shape) + + indexed_weights[idx] = w + + if not is_scanned: + return np.array(indexed_weights[0], dtype=target_dtype) + + if is_2d_indexed: + num_experts = max(k[0] for k in indexed_weights.keys()) + 1 + num_layers_2d = max(k[1] for k in indexed_weights.keys()) + 1 + + sample_weight = next(iter(indexed_weights.values())) + weights_array = np.zeros((num_experts, num_layers_2d) + sample_weight.shape, dtype=target_dtype) + + for (e_idx, l_idx), w in indexed_weights.items(): + weights_array[e_idx, l_idx] = w.astype(target_dtype) + + return weights_array + + weights_list = [None] * num_layers + for idx, w in indexed_weights.items(): + if isinstance(idx, int) and idx < num_layers: + weights_list[idx] = w + + sample_weight = next((w for w in weights_list if w is not None), None) + if sample_weight is None: + return np.array([], dtype=target_dtype) + + for i in range(num_layers): + if weights_list[i] is None: + weights_list[i] = np.zeros_like(sample_weight) + + return np.stack(weights_list, axis=axis_to_stack).astype(target_dtype) + + +def convert_lora_to_maxtext_adapter( + config, + lora_weights: dict[str, Any], + save_dtype: str = "bfloat16", +) -> dict[str, Any]: + """Converts HF LoRA weights to MaxText adapter format.""" + model_key = config.model_name + if "-Instruct" in model_key: + max_logging.log("Warning: You want an Instruct version, so we are using the base model architecture instead.") + model_key = model_key.replace("-Instruct", "") + hf_config_obj = HF_MODEL_CONFIGS[model_key] + hf_config_dict = hf_config_obj.to_dict() + param_map_mt_to_hf = PARAM_MAPPING[model_key](hf_config_dict, config, config.scan_layers) + + mt_adapter_tree = {} + mapped_count = 0 + target_dtype = ml_dtypes.bfloat16 if save_dtype == "bfloat16" else np.float32 + + 
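+  # Pass 1: group every HF LoRA tensor by its MaxText parameter key, by which LoRA factor it is
+  # (kernel_lora_a vs. kernel_lora_b), and by its layer/expert index, so that scanned models can
+  # be stacked along the scan axis in the second pass below.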
collected_weights = {} + + for hf_key, weight in lora_weights.items(): + mt_key, index = convert_hf_lora_key_to_maxtext(hf_key, param_map_mt_to_hf) + + if mt_key: + if hasattr(weight, "numpy"): + # bfloat16 to numpy direct conversion is not fully supported in all PyTorch versions + if weight.dtype == torch.bfloat16: + weight = weight.to(torch.float32) + weight = weight.detach().cpu().numpy() + suffix = "kernel_lora_a" if "lora_A" in hf_key or "lora_a" in hf_key else "kernel_lora_b" + + if isinstance(mt_key, tuple): + mt_key = mt_key[0] # Fallback for composite keys, though LoRA usually doesn't target them directly + + if mt_key not in collected_weights: + collected_weights[mt_key] = {} + if suffix not in collected_weights[mt_key]: + collected_weights[mt_key][suffix] = {} + + idx = index if index is not None else 0 + collected_weights[mt_key][suffix][idx] = weight + mapped_count += 1 + + for mt_key, suffixes in collected_weights.items(): + clean_mt_key = mt_key.replace("-kernel", "") + parts = clean_mt_key.split("-") + if parts[0] == "params": + parts = parts[1:] + + for suffix, indexed_weights in suffixes.items(): + is_scanned = isinstance(param_map_mt_to_hf.get(mt_key), list) + num_layers = len(param_map_mt_to_hf[mt_key]) if is_scanned else 1 + + final_weight = _process_and_stack_weights( + indexed_weights, is_scanned, num_layers, config.param_scan_axis, target_dtype, mt_key, suffix, config + ) + + current = mt_adapter_tree + for part in parts: + if part not in current: + current[part] = {} + current = current[part] + current[suffix] = {"value": final_weight} + + max_logging.log(f"Successfully mapped {mapped_count} out of {len(lora_weights)} LoRA parameters") + return mt_adapter_tree + + +def _setup_merge_mode_getter(tensor_getter, config, hf_lora_adapter_path, revision): + """Helper function to intercept the tensor_getter and inject LoRA weights dynamically.""" + max_logging.log("LoRA adapter path provided and load_parameters_path provided. Merging LoRA into base weights.") + hf_access_token = config.hf_access_token + lora_weights = load_hf_dict_from_safetensors(hf_lora_adapter_path, hf_access_token, revision) + + # Load adapter config to get scaling factor + if os.path.isdir(hf_lora_adapter_path): + config_path = os.path.join(hf_lora_adapter_path, "adapter_config.json") + else: + config_path = hf_hub_download(hf_lora_adapter_path, "adapter_config.json", token=hf_access_token) + with open(config_path, "r", encoding="utf-8") as f: + adapter_config = json.load(f) + + lora_alpha = adapter_config.get("lora_alpha", 8) + lora_rank = adapter_config.get("r", 8) + scaling = lora_alpha / lora_rank if lora_rank > 0 else 1.0 + + base_to_lora = {} + for k, w in lora_weights.items(): + if hasattr(w, "numpy"): + if w.dtype == torch.bfloat16: + w = w.to(torch.float32) + w = w.detach().cpu().numpy() + + hf_param_key = k.replace(".lora_A.weight", ".weight").replace(".lora_B.weight", ".weight") + hf_param_key = hf_param_key.replace(".lora_A", "").replace(".lora_B", "") + + if hf_param_key.startswith("base_model.model."): + hf_param_key = hf_param_key[len("base_model.model.") :] + if hf_param_key.startswith("language_model.model."): + hf_param_key = "model.language_model." 
+ hf_param_key[len("language_model.model.") :] + + if hf_param_key not in base_to_lora: + base_to_lora[hf_param_key] = {} + + if "lora_A" in k or "lora_a" in k: + base_to_lora[hf_param_key]["A"] = w + else: + base_to_lora[hf_param_key]["B"] = w + + original_getter = tensor_getter + + def _merged_getter(key): + base_w = original_getter(key) + if key in base_to_lora: + lora_dict = base_to_lora[key] + if "A" in lora_dict and "B" in lora_dict: + lora_a = np.array(lora_dict["A"], dtype=np.float32) + lora_b = np.array(lora_dict["B"], dtype=np.float32) + + if lora_a.ndim > 2 or lora_b.ndim > 2: + # Use einsum for multi-dimensional LoRA weights to contract on rank dimension + delta = np.einsum("...ir,rj...->...ij...", lora_b, lora_a) * scaling + else: + delta = np.matmul(lora_b, lora_a) * scaling + + if hasattr(base_w, "dtype"): + original_dtype = base_w.dtype + else: + original_dtype = np.float32 + + if delta.shape != base_w.shape and delta.size == base_w.size: + delta = delta.reshape(base_w.shape) + + base_w = np.array(base_w, dtype=np.float32) + delta + return base_w.astype(original_dtype) + + return base_w + + return _merged_getter + + def main( args: Sequence[str], lazy_load_tensors: bool = False, @@ -615,159 +846,178 @@ def main( hf_token = config.hf_access_token - if lazy_load_tensors and config.use_multimodal: - raise ValueError("lazy loading of HF tensors is not supported for multimodal models yet.") + hf_lora_adapter_path = config.hf_lora_adapter_path - hf_state_dict_numpy = None - hf_loader = None + is_adapter_only = bool(hf_lora_adapter_path and not config.load_parameters_path) + is_merge_mode = bool(hf_lora_adapter_path and config.load_parameters_path) - # Define the appropriate tensor getter based on mode - if lazy_load_tensors: - max_logging.log(f"Lazy loading ENABLED. Initializing LazyHFLoader for: {model_id}...") - hf_loader = LazyHFLoader(model_id, hf_token, revision=revision) + if is_adapter_only: + max_logging.log("LoRA adapter path provided and load_parameters_path NOT provided. Converting LoRA adapter ONLY.") + hf_access_token = config.hf_access_token + lora_weights = load_hf_dict_from_safetensors(hf_lora_adapter_path, hf_access_token, revision) - print_ram_usage("After LazyLoader init") - tensor_getter = hf_loader.get_tensor + model_name_for_path = model_name_original or config.model_name + jax_weights = convert_lora_to_maxtext_adapter(config, lora_weights, save_dtype) + adapter_name = os.path.basename(os.path.normpath(hf_lora_adapter_path)) + output_directory = os.path.join(output_directory, model_name_for_path, adapter_name) else: - max_logging.log(f"Lazy loading DISABLED. Loading full HuggingFace model: {model_id}...") - - # Eager load methods: - # - Method 1: transformers_class.from_pretrained(..., dtype="auto") - # - Method 2: safetensors.safe_open(..., framework="pt") - # - # Comparison: - # - Both methods result in the same dtype (usually bfloat16) and model structure - # for most models (e.g., DeepSeek-V2), with similar loading times. - # - Exception: Gemma-3 uses different internal naming (prefixes) between - # Method 1 and Method 2. Current MaxText 'param_mapping' for Gemma-3 assumes - # the Transformers-style structure (Method 1). - # - The 'safetensors' method is a necessary fallback for: - # 1. "Day-0" models where the official Transformers code hasn't been merged yet - # (e.g., DeepSeek-V3.2 during its initial release). - # 2. Weights omitted by official Transformers class - # (e.g., Multi-Token Prediction weights (`layers.61`) in DeepSeek-V3). 
- # - # Recommendation: - # - Use 'transformers' as the default for backward compatibility of mapping. - # - 'safetensors' is an interchangeable and valid alternative for most models, - # and is strictly required if the model or specific weights lack Transformers support. - if eager_load_method == "transformers": - max_logging.log("Eager load with Transformers backend, from_pretrained with auto dtype") - # For auto mode, loaded dtype is the same as `dtype` specified in config.json (or `torch_dtype` for older version) - # e.g., https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/config.json#L54 - hf_state_dict_numpy = load_hf_dict_from_transformers(model_id, token=hf_token, revision=revision, dtype="auto") - elif eager_load_method == "safetensors": - max_logging.log("Eager load with Safetensors backend, safe_open with pt framework") - # For safe_open, loaded dtype is the same as original safetensor - # e.g., https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/model.safetensors.index.json - hf_state_dict_numpy = load_hf_dict_from_safetensors(model_id, token=hf_token, revision=revision, framework="pt") + + if lazy_load_tensors and config.use_multimodal: + raise ValueError("lazy loading of HF tensors is not supported for multimodal models yet.") + + hf_state_dict_numpy = None + hf_loader = None + + # Define the appropriate tensor getter based on mode + if lazy_load_tensors: + max_logging.log(f"Lazy loading ENABLED. Initializing LazyHFLoader for: {model_id}...") + hf_loader = LazyHFLoader(model_id, hf_token, revision=revision) + + print_ram_usage("After LazyLoader init") + tensor_getter = hf_loader.get_tensor else: - raise NotImplementedError - - unique_dtypes = {tensor.dtype for tensor in hf_state_dict_numpy.values()} - max_logging.log(f"HuggingFace model loaded. 
dtypes: {unique_dtypes}") - print_ram_usage("After full HF model load") - - def _eager_getter(key): - if key not in hf_state_dict_numpy: - raise ValueError(f"HuggingFace key {key} not found in state_dict.") - v = hf_state_dict_numpy[key] - # target dtype is "float32" - if save_dtype == DType.FLOAT32: - return v.to(torch.float32).numpy() - # target dtype is "bfloat16" - elif save_dtype == DType.BFLOAT16: - # - torch.bfloat16 -> torch.float32 -> np.float32 -> ml_dtypes.bfloat16 - # As numpy doesn't accept bfloat16 directly, we convert to float32 first - # - torch.float16 -> np.float16 -> ml_dtypes.bfloat16 - # - torch.float32 -> np.float32 -> ml_dtypes.bfloat16 - if v.dtype == torch.bfloat16: - v = v.to(torch.float32) - return v.numpy().astype(ml_dtypes.bfloat16) - raise NotImplementedError(f"Save dtype {save_dtype} is not currently implemented.") - - tensor_getter = _eager_getter - - # Get parameter mappings and hooks - model_key = config.model_name - # load config - hf_config_obj = HF_MODEL_CONFIGS[model_key] - hf_config_dict = hf_config_obj.to_dict() - # example of param mapping (gemma2, maxtext:huggingface): - # "params-decoder-layers_{maxtext_layer_idx}-pre_self_attention_norm_global-scale": - # f"model.layers.{global_layer_idx}.input_layernorm.weight", - param_map_mt_to_hf = PARAM_MAPPING[model_key](hf_config_dict, config, config.scan_layers) - # Example of Hook FN mapping, to perform reshape: - # f"params-decoder-layers_{maxtext_layer_idx}-self_attention_global-key-kernel": reshape_kernel, - hook_fn_map_mt = HOOK_FNS[model_key](hf_config_dict, config, config.scan_layers, saving_to_hf=False) - max_logging.log("Parameter mappings and hooks obtained.") - - maxtext_abstract_dict, abstract_params_treedef = get_maxtext_model_info(config) - - # Weight transformation - max_logging.log("Starting weight transformation...") - start = time.time() - # Stores MaxText weights: numpy.ndarray - final_mt_weights = [None] * len(maxtext_abstract_dict) - - # Preprocess key - filtered_map_keys = validate_and_filter_param_map_keys(param_map_mt_to_hf.keys(), maxtext_abstract_dict.keys()) - - for mt_param_key_or_keys in MemoryMonitorTqdm( - filtered_map_keys, - desc="Transforming weights", - unit="param", - leave=True, - dynamic_ncols=True, - smoothing=0, - ): - if not lazy_load_tensors: - max_logging.log(f"maxtext param: {mt_param_key_or_keys}") - - hf_source_keys_or_key = param_map_mt_to_hf.get(mt_param_key_or_keys) - if hf_source_keys_or_key is None: - raise ValueError(f"MaxText parameter {mt_param_key_or_keys} not found in mapping.") - hook_fn = hook_fn_map_mt.get(mt_param_key_or_keys) - - # Step 1: Resolves MaxText key(s) to target indices and shapes - # based on MaxText key form (`atomic_mt_key` or `composite_mt_key`) - mt_target_idx_or_indices, mt_target_shape_or_shapes = _get_maxtext_indices_and_shapes( - mt_param_key_or_keys, maxtext_abstract_dict - ) + max_logging.log(f"Lazy loading DISABLED. Loading full HuggingFace model: {model_id}...") + + # Eager load methods: + # - Method 1: transformers_class.from_pretrained(..., dtype="auto") + # - Method 2: safetensors.safe_open(..., framework="pt") + # + # Comparison: + # - Both methods result in the same dtype (usually bfloat16) and model structure + # for most models (e.g., DeepSeek-V2), with similar loading times. + # - Exception: Gemma-3 uses different internal naming (prefixes) between + # Method 1 and Method 2. Current MaxText 'param_mapping' for Gemma-3 assumes + # the Transformers-style structure (Method 1). 
+ # - The 'safetensors' method is a necessary fallback for: + # 1. "Day-0" models where the official Transformers code hasn't been merged yet + # (e.g., DeepSeek-V3.2 during its initial release). + # 2. Weights omitted by official Transformers class + # (e.g., Multi-Token Prediction weights (`layers.61`) in DeepSeek-V3). + # + # Recommendation: + # - Use 'transformers' as the default for backward compatibility of mapping. + # - 'safetensors' is an interchangeable and valid alternative for most models, + # and is strictly required if the model or specific weights lack Transformers support. + if eager_load_method == "transformers": + max_logging.log("Eager load with Transformers backend, from_pretrained with auto dtype") + # For auto mode, loaded dtype is the same as `dtype` specified in config.json (or `torch_dtype` for older version) + # e.g., https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/config.json#L54 + hf_state_dict_numpy = load_hf_dict_from_transformers(model_id, token=hf_token, revision=revision, dtype="auto") + elif eager_load_method == "safetensors": + max_logging.log("Eager load with Safetensors backend, safe_open with pt framework") + # For safe_open, loaded dtype is the same as original safetensor + # e.g., https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/model.safetensors.index.json + hf_state_dict_numpy = load_hf_dict_from_safetensors(model_id, token=hf_token, revision=revision, framework="pt") + else: + raise NotImplementedError + + unique_dtypes = {tensor.dtype for tensor in hf_state_dict_numpy.values()} + max_logging.log(f"HuggingFace model loaded. dtypes: {unique_dtypes}") + print_ram_usage("After full HF model load") + + def _eager_getter(key): + if key not in hf_state_dict_numpy: + raise ValueError(f"HuggingFace key {key} not found in state_dict.") + v = hf_state_dict_numpy[key] + # target dtype is "float32" + if save_dtype == DType.FLOAT32: + return v.to(torch.float32).numpy() + # target dtype is "bfloat16" + elif save_dtype == DType.BFLOAT16: + # - torch.bfloat16 -> torch.float32 -> np.float32 -> ml_dtypes.bfloat16 + # As numpy doesn't accept bfloat16 directly, we convert to float32 first + # - torch.float16 -> np.float16 -> ml_dtypes.bfloat16 + # - torch.float32 -> np.float32 -> ml_dtypes.bfloat16 + if v.dtype == torch.bfloat16: + v = v.to(torch.float32) + return v.numpy().astype(ml_dtypes.bfloat16) + raise NotImplementedError(f"Save dtype {save_dtype} is not currently implemented.") + + tensor_getter = _eager_getter + + if is_merge_mode: + tensor_getter = _setup_merge_mode_getter(tensor_getter, config, hf_lora_adapter_path, revision) + + # Get parameter mappings and hooks + model_key = config.model_name + # load config + hf_config_obj = HF_MODEL_CONFIGS[model_key] + hf_config_dict = hf_config_obj.to_dict() + # example of param mapping (gemma2, maxtext:huggingface): + # "params-decoder-layers_{maxtext_layer_idx}-pre_self_attention_norm_global-scale": + # f"model.layers.{global_layer_idx}.input_layernorm.weight", + param_map_mt_to_hf = PARAM_MAPPING[model_key](hf_config_dict, config, config.scan_layers) + # Example of Hook FN mapping, to perform reshape: + # f"params-decoder-layers_{maxtext_layer_idx}-self_attention_global-key-kernel": reshape_kernel, + hook_fn_map_mt = HOOK_FNS[model_key](hf_config_dict, config, config.scan_layers, saving_to_hf=False) + max_logging.log("Parameter mappings and hooks obtained.") + + maxtext_abstract_dict, abstract_params_treedef = get_maxtext_model_info(config) + + # Weight transformation + 
max_logging.log("Starting weight transformation...") + start = time.time() + # Stores MaxText weights: numpy.ndarray + final_mt_weights = [None] * len(maxtext_abstract_dict) + + # Preprocess key + filtered_map_keys = validate_and_filter_param_map_keys(param_map_mt_to_hf.keys(), maxtext_abstract_dict.keys()) + + for mt_param_key_or_keys in MemoryMonitorTqdm( + filtered_map_keys, + desc="Transforming weights", + unit="param", + leave=True, + dynamic_ncols=True, + smoothing=0, + ): + if not lazy_load_tensors: + max_logging.log(f"maxtext param: {mt_param_key_or_keys}") + + hf_source_keys_or_key = param_map_mt_to_hf.get(mt_param_key_or_keys) + if hf_source_keys_or_key is None: + raise ValueError(f"MaxText parameter {mt_param_key_or_keys} not found in mapping.") + hook_fn = hook_fn_map_mt.get(mt_param_key_or_keys) + + # Step 1: Resolves MaxText key(s) to target indices and shapes + # based on MaxText key form (`atomic_mt_key` or `composite_mt_key`) + mt_target_idx_or_indices, mt_target_shape_or_shapes = _get_maxtext_indices_and_shapes( + mt_param_key_or_keys, maxtext_abstract_dict + ) - # Step 2: Determine the loading function for hf key - # based on hf_key form (unscanned, scanned, unscanned with expert stacking, or scanned with expert stacking) - load_fn = _get_hf_loading_function( - hf_source_keys_or_key, - tensor_getter, - hook_fn, - mt_target_shape_or_shapes, - config, - ) + # Step 2: Determine the loading function for hf key + # based on hf_key form (unscanned, scanned, unscanned with expert stacking, or scanned with expert stacking) + load_fn = _get_hf_loading_function( + hf_source_keys_or_key, + tensor_getter, + hook_fn, + mt_target_shape_or_shapes, + config, + ) - # Step 3: Load hf keys and convert to maxtext keys - # based on tensor load mode (lazy, eager) and MaxText key form (`atomic_mt_key` or `composite_mt_key`) - _get_maxtext_weight( - load_fn, - mt_target_idx_or_indices, - mt_target_shape_or_shapes, - mt_param_key_or_keys, - final_mt_weights, - save_dtype, - lazy_load_tensors, - ) + # Step 3: Load hf keys and convert to maxtext keys + # based on tensor load mode (lazy, eager) and MaxText key form (`atomic_mt_key` or `composite_mt_key`) + _get_maxtext_weight( + load_fn, + mt_target_idx_or_indices, + mt_target_shape_or_shapes, + mt_param_key_or_keys, + final_mt_weights, + save_dtype, + lazy_load_tensors, + ) - del hf_state_dict_numpy - max_logging.log("Weight transformation preparation complete.") - max_logging.log(f"Elapse for transform: {(time.time() - start) / 60:.2f} min") - print_ram_usage("Before creating full JAX tree") + del hf_state_dict_numpy + max_logging.log("Weight transformation preparation complete.") + max_logging.log(f"Elapse for transform: {(time.time() - start) / 60:.2f} min") + print_ram_usage("Before creating full JAX tree") - # Create final MaxText parameters tree - jax_weights = jax.tree_util.tree_unflatten(abstract_params_treedef, final_mt_weights) - del final_mt_weights, abstract_params_treedef + # Create final MaxText parameters tree + jax_weights = jax.tree_util.tree_unflatten(abstract_params_treedef, final_mt_weights) + del final_mt_weights, abstract_params_treedef print_ram_usage("Before saving") - if lazy_load_tensors: + if lazy_load_tensors and not is_adapter_only: max_logging.log("Starting checkpoint save (loading weights just-in-time)...") else: max_logging.log("Starting checkpoint save...") @@ -784,7 +1034,10 @@ def _eager_getter(key): ) print_ram_usage("Program Ends") - max_logging.log(f"Conversion complete. 
Checkpoint saved to {output_directory}") + if is_adapter_only: + max_logging.log(f"LoRA adapter conversion completed successfully. Saved to {output_directory}") + else: + max_logging.log(f"Conversion complete. Checkpoint saved to {output_directory}") max_logging.log(f"Overall Elapse: {(time.time() - overall_start) / 60:.2f} min") print_peak_memory() diff --git a/src/maxtext/checkpoint_conversion/utils/utils.py b/src/maxtext/checkpoint_conversion/utils/utils.py index 26724fe863..15d7a63aad 100644 --- a/src/maxtext/checkpoint_conversion/utils/utils.py +++ b/src/maxtext/checkpoint_conversion/utils/utils.py @@ -790,13 +790,13 @@ def format_meter( def load_orbax_checkpoint(config) -> dict: - """Loads a full Orbax checkpoint from disk with unsharded arrays. + """Loads Orbax checkpoints from Base and/or LoRA paths in config. Args: config: MaxText config containing checkpoint storage settings Returns: - Dictionary containing the full checkpoint structure + Dictionary containing all weights merged into a single structure. """ # Create Orbax checkpointer ckptr = ocp.Checkpointer( @@ -807,10 +807,6 @@ def load_orbax_checkpoint(config) -> dict: ) ) - # Get checkpoint metadata - checkpoint_path = epath.Path(config.load_parameters_path) - metadata = ckptr.metadata(checkpoint_path) - # Create a mesh with all devices for unsharded restoration devices = np.array(jax.devices()).reshape((-1,)) single_device_mesh = jax.sharding.Mesh(devices, ("x",)) @@ -824,14 +820,44 @@ def create_restore_args(tree_metadata): else: return None - restore_args = jax.tree_util.tree_map( - lambda x: create_restore_args(x) if hasattr(x, "shape") else None, - metadata.item_metadata.tree, - is_leaf=lambda x: hasattr(x, "shape"), - ) - - # Restore the entire checkpoint - return ckptr.restore(checkpoint_path, restore_args=restore_args) + lora_path = config.lora.lora_restore_path + paths = [p for p in [config.load_parameters_path, lora_path] if p] + + merged_dict = {} + for path in paths: + checkpoint_path = epath.Path(path) + metadata = ckptr.metadata(checkpoint_path) + restore_args = jax.tree_util.tree_map( + lambda x: create_restore_args(x) if hasattr(x, "shape") else None, + metadata.item_metadata.tree, + is_leaf=lambda x: hasattr(x, "shape"), + ) + merged_dict.update(ckptr.restore(checkpoint_path, restore_args=restore_args)) + + return merged_dict + + +def save_adapter_files(output_dir, weights, config, found_modules, model_id): + """Saves HF LoRA adapter weights and config.""" + os.makedirs(output_dir, exist_ok=True) + adapter_file = os.path.join(output_dir, "adapter_model.safetensors") + numpy_save_file(weights, adapter_file) + + # Create PEFT adapter_config.json + adapter_config = { + "base_model_name_or_path": model_id, + "peft_type": "LORA", + "task_type": "CAUSAL_LM", + "r": config.lora.lora_rank, + "lora_alpha": config.lora.lora_alpha, + "target_modules": list(found_modules), + "lora_dropout": 0.0, + "bias": "none", + "inference_mode": True, + } + config_file = os.path.join(output_dir, "adapter_config.json") + with open(config_file, "w", encoding="utf-8") as f: + json.dump(adapter_config, f, indent=4) def extract_nnx_weights(weights_dict: dict) -> dict[str, np.ndarray]: diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index cc2f674fd4..6e19ccc445 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -43,6 +43,7 @@ load_parameters_path: "" # LoRA adapter support configs lora_input_adapters_path: "" # Input GCS path for a parent directory which has all the LoRA adapters 
(lora_id as subdir) +hf_lora_adapter_path: "" # Input HF repo ID or local path for HF LoRA adapter # Loads a full checkpoint including optimizer state and step count from a specific directory # e.g. gs://my-base-output-directory/my-previous-run-name/checkpoints/items/NUMBER or NUMBER/items diff --git a/src/maxtext/configs/post_train/lora_module_path.yml b/src/maxtext/configs/post_train/lora_module_path.yml new file mode 100644 index 0000000000..11f81d52c5 --- /dev/null +++ b/src/maxtext/configs/post_train/lora_module_path.yml @@ -0,0 +1,28 @@ +# Copyright 2023-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Recommended LoRA module paths by model architecture prefix. +# These models have been explicitly tested and verified for LoRA. + +llama3.1: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))" +qwen3: "decoder/layers/self_attention/(query|key|value|out)|decoder/layers/mlp/(wi_0|wi_1|wo)" +mistral: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))" +deepseek2: "decoder/(dense_layers|moe_stack)/self_attention/(query|out|wkv_a|wkv_b)|decoder/(dense_layers|moe_stack)/(mlp|shared_experts)/(wi_0|wi_1|wo)" +gemma2: "decoder/layers/(self_attention_local|self_attention_global)/(query|key|value|out)|decoder/layers/(mlp_local|mlp_global)/(wi_0|wi_1|wo)" +gemma3: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo|gate|up|down))" +olmo3: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))" +gpt3: "decoder/layers/(self_attention/(qkv_proj|out)|mlp/(wi|wo))" + +# Fallback for unverified models +default: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))" diff --git a/src/maxtext/configs/post_train/sft.yml b/src/maxtext/configs/post_train/sft.yml index 32c86ddb31..3ba5cf2161 100644 --- a/src/maxtext/configs/post_train/sft.yml +++ b/src/maxtext/configs/post_train/sft.yml @@ -21,6 +21,15 @@ sft_train_on_completion_only: True packing: True learning_rate: 2.e-5 +# -------------- LoRA / QLoRA -------------- +lora: + enable_lora: False + lora_rank: 0 + lora_alpha: 0.0 + lora_module_path: "" + # Optional path to LoRA weights to load before training. Ignored if the current run is resumed. + lora_restore_path: "" + # -------------- HF pipeline -------------- dataset_type: hf hf_path: 'HuggingFaceH4/ultrachat_200k' diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index c364484c7f..20594bccc3 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -308,6 +308,13 @@ class Checkpointing(BaseModel): load_parameters_path: PathStr = Field("", description="Loads only model parameters from a specific checkpoint path.") lora_input_adapters_path: PathStr = Field("", description="Input GCS path for LoRA adapters.") + hf_lora_adapter_path: PathStr = Field( + "", + description=( + "HuggingFace LoRA adapter repo ID (e.g., 'username/adapter-repo') or local " + "path to directory containing adapter_model.safetensors." 
+ ), + ) load_full_state_path: PathStr = Field("", description="Loads the complete training state from a checkpoint path.") enable_checkpointing: bool = Field(True, description="If True, enables saving checkpoints during training.") load_checkpoint_only_once: bool = Field(False, description="If True, deep copy the reference model to the actor model.") @@ -345,7 +352,8 @@ class Checkpointing(BaseModel): description="If True, enables checkpointing from remote TPU VMs instead of head node on pathways.", ) enable_autocheckpoint: bool = Field( - False, description="If True, enables autocheckpoint or preemption induced checkpointing." + False, + description="If True, enables autocheckpoint or preemption induced checkpointing.", ) @@ -495,7 +503,8 @@ class ModelArchitecture(BaseModel): ) fused_mlp: bool = Field(False, description="If supported, fuse the MLP layers.") qk_norm_with_scale: bool = Field( - True, description="Whether to apply scale on query and key normalizations (default True)." + True, + description="Whether to apply scale on query and key normalizations (default True).", ) v_norm_with_scale: bool = Field(True, description="Whether to apply scale on value normalization (default True).") @@ -542,9 +551,13 @@ class Attention(BaseModel): "global", description="The variant of attention to use." ) share_kv_projections: bool = Field( - False, description="If True, for global attention, Key and Value projections share the same weights." + False, + description="If True, for global attention, Key and Value projections share the same weights.", + ) + global_num_kv_heads: int = Field( + 0, + description="If greater than 0, sets the number of KV heads for global attention.", ) - global_num_kv_heads: int = Field(0, description="If greater than 0, sets the number of KV heads for global attention.") attention_sink: bool = Field(False, description="If True, enables attention sinks.") float32_qk_product: bool = Field(False, description="In dot-product attention, cast query-key product to fp32.") float32_logits: bool = Field( @@ -1046,7 +1059,8 @@ class Tokenizer(BaseModel): use_chat_template: bool = Field(False, description="Whether to use the chat template for tokenization.") chat_template_path: str = Field("", description="Path to chat template json file.") chat_template: str = Field( - "", description="Chat template to use with HF tokenizers. It should be a valid Jinja2-formatted template." + "", + description="Chat template to use with HF tokenizers. It should be a valid Jinja2-formatted template.", ) tokenize_train_data: bool = Field(True, description="If False, assumes the training dataset is pre-tokenized.") tokenize_eval_data: bool = Field(True, description="If False, assumes the evaluation dataset is pre-tokenized.") @@ -1138,7 +1152,8 @@ class GrainDataset(BaseModel): description="Path to a JSON file specifying the mixture weights for Grain training data.", ) grain_file_type: str = Field( - "arrayrecord", description="File type for Grain data. Supported: arrayrecord, tfrecord, parquet." + "arrayrecord", + description="File type for Grain data. 
Supported: arrayrecord, tfrecord, parquet.", ) grain_use_elastic_iterator: bool = Field( False, @@ -1212,6 +1227,26 @@ class FineTuning(BaseModel): use_grpo: None | bool = Field(None, description="If True, enables Group Relative Policy Optimization.") +class LoRA(BaseModel): + """Configuration for LoRA / QLoRA adapters.""" + + model_config = ConfigDict(extra="forbid") + + enable_lora: bool = Field(False, description="If True, enables LoRA/QLoRA during fine-tuning.") + lora_rank: NonNegativeInt = Field(0, description="LoRA rank. Set >0 when LoRA is enabled.") + lora_alpha: NonNegativeFloat = Field(0.0, description="LoRA alpha scaling factor.") + lora_module_path: str = Field( + "", + description=( + "Regex identifying target modules for LoRA, e.g." " '.*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj'." + ), + ) + lora_restore_path: PathStr = Field( + "", + description=("Optional path to LoRA weights to load before training. Ignored if the current run is resumed."), + ) + + class Distillation(BaseModel): """Configuration for Knowledge Distillation.""" @@ -1229,7 +1264,8 @@ class Distillation(BaseModel): # --- Offline Distillation Field --- offline_data_dir: Optional[str] = Field( - None, description="GCS or local path to the pre-generated ArrayRecord teacher data." + None, + description="GCS or local path to the pre-generated ArrayRecord teacher data.", ) # --- Loss Params --- @@ -1237,7 +1273,8 @@ class Distillation(BaseModel): distill_temperature: float = Field(1.0, description="Temperature for distillation softening.") distill_beta: float = Field(0.0, description="Weight for the feature loss component. Use 0.0 to disable") distill_feature_loss_type: Literal["cosine", "l2"] = Field( - "cosine", description="The type of loss to use for feature distillation ('cosine' or 'l2')." + "cosine", + description="The type of loss to use for feature distillation ('cosine' or 'l2').", ) distill_layer_indices: None | list = Field(None, description="Feature indices for feature loss.") distill_alpha_end: Optional[float] = Field(None, description="Target alpha at end of training. None keeps alpha fixed.") @@ -1344,10 +1381,12 @@ class Optimizer(BaseModel): opt_type: OptimizerType = Field(OptimizerType.ADAMW, description="The type of optimizer to use.") skip_step_on_spikes: bool = Field( - False, description="If True, skip the training step when loss or gradient spike is detected." + False, + description="If True, skip the training step when loss or gradient spike is detected.", ) skip_step_interval: PositiveInt = Field( - 128, description="The rolling interval to calculate the mean and standard deviation." 
+ 128, + description="The rolling interval to calculate the mean and standard deviation.", ) skip_step_scaling_factor: float = Field(6.0, description="The scaling factor to determine if a spike occurred.") gradient_accumulation_steps: PositiveInt = Field( @@ -1796,7 +1835,10 @@ class VisionTower(BaseModel): temporal_patch_size_for_vit: int = Field(2, description="Temporal patch size for video inputs.") num_position_embeddings_for_vit: int = Field(1024, description="Number of position embeddings for ViT.") deepstack_visual_indexes_for_vit: list[int] = Field([], description="Layer indices to extract deep visual features.") - vision_output_length: int = Field(-1, description="The output length (number of soft tokens) from the vision encoder.") + vision_output_length: int = Field( + -1, + description="The output length (number of soft tokens) from the vision encoder.", + ) class VisionProjector(BaseModel): @@ -1900,18 +1942,28 @@ class RL(BaseModel): grpo_epsilon: float = Field(0.2, description="Epsilon value for clipping in the GRPO loss.") loss_algo: Literal["grpo", "gspo-token"] = Field("grpo", description="Loss algorithm, i.e., 'grpo' or 'gspo-token'.") use_agentic_rollout: bool = Field( - False, description="If True, uses the asynchronous AgenticGRPOLearner for online vLLM rollouts." + False, + description="If True, uses the asynchronous AgenticGRPOLearner for online vLLM rollouts.", + ) + max_concurrency: int = Field( + 256, + description="Maximum number of concurrent rollout requests (agentic rollout only).", ) - max_concurrency: int = Field(256, description="Maximum number of concurrent rollout requests (agentic rollout only).") off_policy_steps: int = Field( - 0, description="Number of off-policy steps tolerated before requiring a policy update (agentic only)." + 0, + description="Number of off-policy steps tolerated before requiring a policy update (agentic only).", + ) + system_prompt: str = Field( + "", + description="System prompt injected into the agent at rollout time (agentic only).", ) - system_prompt: str = Field("", description="System prompt injected into the agent at rollout time (agentic only).") degenerate_group_masking: bool = Field( - True, description="Mask degenerate groups (all-zero advantages) from contributing to loss (agentic only)." + True, + description="Mask degenerate groups (all-zero advantages) from contributing to loss (agentic only).", ) epsilon_high: Optional[float] = Field( - None, description="Upper-bound clipping epsilon for GRPO loss. Defaults to epsilon when None (agentic only)." + None, + description="Upper-bound clipping epsilon for GRPO loss. 
Defaults to epsilon when None (agentic only).", ) reshard_chunk_size: Optional[int] = Field( None, @@ -2255,6 +2307,10 @@ class MaxTextConfig( default_factory=RL, description="Configuration for RL algorithms like Group Relative Policy Optimization (GRPO).", ) + lora: LoRA = Field( + default_factory=LoRA, + description="Configuration for LoRA / QLoRA adapters.", + ) model_config = ConfigDict(extra="forbid", protected_namespaces=()) @model_validator(mode="before") diff --git a/src/maxtext/examples/sft_llama3_demo_gpu.ipynb b/src/maxtext/examples/sft_llama3_demo_gpu.ipynb index 1eed01572d..7a7f3f36dc 100644 --- a/src/maxtext/examples/sft_llama3_demo_gpu.ipynb +++ b/src/maxtext/examples/sft_llama3_demo_gpu.ipynb @@ -190,8 +190,7 @@ "from maxtext.trainers.post_train.sft import train_sft\n", "\n", "MAXTEXT_REPO_ROOT = os.path.dirname(os.path.dirname(maxtext.__file__))\n", - "print(f\"MaxText installation path: {MAXTEXT_REPO_ROOT}\")\n" -, + "print(f\"MaxText installation path: {MAXTEXT_REPO_ROOT}\")\n", "\n", "print(f\"JAX version: {jax.__version__}\")\n", "print(f\"JAX devices: {jax.devices()}\")\n", @@ -434,7 +433,7 @@ "_prev_level = _pyconfig_logger.level\n", "_pyconfig_logger.setLevel(logging.WARNING)\n", "\n", - "config = pyconfig.initialize(config_argv)\n", + "config = pyconfig.initialize_pydantic(config_argv)\n", "\n", "_pyconfig_logger.setLevel(_prev_level)\n", "\n", diff --git a/src/maxtext/examples/sft_llama3_demo_tpu.ipynb b/src/maxtext/examples/sft_llama3_demo_tpu.ipynb index 3cb7997126..b0f763d3a0 100644 --- a/src/maxtext/examples/sft_llama3_demo_tpu.ipynb +++ b/src/maxtext/examples/sft_llama3_demo_tpu.ipynb @@ -275,7 +275,7 @@ " \"profiler=xplane\",\n", "]\n", "\n", - "config = pyconfig.initialize(config_argv)\n", + "config = pyconfig.initialize_pydantic(config_argv)\n", "\n", "print(\"✓ SFT configuration loaded:\")\n", "print(f\" Model: {config.model_name}\")\n", diff --git a/src/maxtext/examples/sft_qwen3_demo.ipynb b/src/maxtext/examples/sft_qwen3_demo.ipynb index afba9e9a04..f1b8ec7a24 100644 --- a/src/maxtext/examples/sft_qwen3_demo.ipynb +++ b/src/maxtext/examples/sft_qwen3_demo.ipynb @@ -310,7 +310,7 @@ "outputs": [], "source": [ "%%capture\n", - "config = pyconfig.initialize(\n", + "config = pyconfig.initialize_pydantic(\n", " [\n", " \"\",\n", " f\"{MAXTEXT_PKG_DIR}/configs/post_train/sft.yml\",\n", diff --git a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py index 9fa42d6abe..07231f965e 100644 --- a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py +++ b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py @@ -323,4 +323,4 @@ def load_weights(self, rng_key: jax.Array) -> None: model = model_creation_utils.from_pretrained( self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key ) - self.model = nnx.data(model) \ No newline at end of file + self.model = nnx.data(model) diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index 606e81afd1..262eb62277 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -65,6 +65,7 @@ ) from maxtext.multimodal import utils as mm_utils from maxtext.utils import max_logging, max_utils, maxtext_utils, sharding +from maxtext.utils.maxtext_utils_nnx import nnx_ensure_scan_leading_axis from maxtext.utils.sharding import create_sharding # ------------------------------------------------------------------------------ @@ -382,7 +383,7 @@ def __init__( 
RemattedGemma4Block = gemma4.Gemma4ScannableBlock if scan_length > 0: - self.layers = self._create_scanned_layers( + self.scanned_blocks = self._create_scanned_layers( RemattedGemma4Block, length=scan_length, metadata_axis_name="layers", rngs=rngs, **layer_kwargs ) self.layers_remainder = RemattedGemma4Block( @@ -495,7 +496,7 @@ def scan_body(carry, rng_state_slice): _, (stacked_params, stacked_rest) = jax.lax.scan(scan_body, None, rngs_state) if scan_axis != 0: - stacked_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), stacked_params) + stacked_params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), stacked_params) def _add_scan_metadata(state, axis): def _update_leaf(leaf): @@ -588,9 +589,13 @@ def _extract_matching_state(template, full): return {k: _extract_matching_state(v, full[k]) for k, v in template.items()} return full + dynamic_graph_init = bool(getattr(self, "disable_quant_stats_update", False)) + updated_graphdef = [graphdef] + use_kv = kv_caches_stacked is not None def layer_fn(carry, scanned_vars): + # Unpack the sliced variables for THIS layer if use_kv: current_params, current_state, kv_cache_layer = scanned_vars @@ -618,13 +623,19 @@ def layer_fn(carry, scanned_vars): updated_kv = None # Extract the updated state to return it - new_current_state = nnx.state(layer) + if dynamic_graph_init: + new_graphdef, updated_params, updated_state = nnx.split(layer, nnx.Param, ...) + updated_graphdef[0] = new_graphdef + returned_params = updated_params + new_current_state = nnx.State.merge(returned_params, updated_state) + else: + new_current_state = nnx.state(layer) if use_kv: return new_carry, (new_current_state, updated_kv) return new_carry, new_current_state - layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse) + layer_fn_wrapped = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse) if use_kv: # If kv_caches is provided (e.g., from vLLM), we CANNOT use jax.lax.scan @@ -634,7 +645,6 @@ def layer_fn(carry, scanned_vars): # kv_caches_stacked is actually the original kv_caches list in this new flow kv_caches_list = kv_caches_stacked - current_carry = x_in for i in range(length): @@ -643,7 +653,9 @@ def layer_fn(carry, scanned_vars): current_state = jax.tree.map(lambda x, i=i: x[i], state) # Call the layer - current_carry, (_, updated_kv) = layer_fn(current_carry, (current_params, current_state, kv_caches_list[i])) + current_carry, (_, updated_kv) = layer_fn_wrapped( + current_carry, (current_params, current_state, kv_caches_list[i]) + ) # Update the list in-place (mutates the list passed by reference) kv_caches_list[i] = updated_kv @@ -652,16 +664,27 @@ def layer_fn(carry, scanned_vars): # inference with vLLM, parameters do not change and we don't need intermediates. return current_carry, layers, None else: - final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state)) + params = nnx_ensure_scan_leading_axis(params, length) + state = nnx_ensure_scan_leading_axis(state, length) + + final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state)) returned_kv_stacked = None if scan_axis != 0: new_params, new_rest = scanned_state.split(nnx.Param, ...) 
- new_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), new_params) + new_params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), new_params) scanned_state = nnx.merge_state(new_params, new_rest) - nnx.update(layers, scanned_state) - return final_carry, layers, returned_kv_stacked if use_kv else None + if dynamic_graph_init: + # If graph changed, we need to merge with the new graphdef. + # Note: scanned_state here is the full state (Params + rest). + new_params, new_rest = scanned_state.split(nnx.Param, ...) + out_layers = nnx.merge(updated_graphdef[0], new_params, new_rest) + else: + nnx.update(layers, scanned_state) + out_layers = layers + + return final_carry, out_layers, returned_kv_stacked if use_kv else None def get_decoder_layers(self): """Retrieves decoder layer classes based on config using a dictionary lookup.""" @@ -671,6 +694,8 @@ def get_scannable(normal_cls, scannable_cls): return [scannable_cls] if cfg.scan_layers else [normal_cls] def get_deepseek(): + if cfg.use_batch_split_schedule: + return [deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer] return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer] layer_map = { @@ -1351,7 +1376,9 @@ def _apply_gemma4_scanned_blocks( # Apply the main scan over the full blocks if scan_length > 0: - y, self.layers, _ = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs) + y, self.scanned_blocks, _ = self._apply_layers_sequentially( + self.scanned_blocks, y, *layer_args, length=scan_length, **layer_kwargs + ) # Apply any remaining layers that did not fit into a full scanned block num_remaining_layers = cfg.num_decoder_layers % attention_pattern_length diff --git a/src/maxtext/multimodal/processor_llama4.py b/src/maxtext/multimodal/processor_llama4.py index 172da260be..1da08857e0 100644 --- a/src/maxtext/multimodal/processor_llama4.py +++ b/src/maxtext/multimodal/processor_llama4.py @@ -279,7 +279,8 @@ def split_to_tiles(images: np.ndarray, num_tiles_height: int, num_tiles_width: i def preprocess_mm_data_llama4(images): """ Pre-process image for Llama4 model. Find best resolution and split into tiles with an additional global tile. - Original implementation from image_processing_llama4.py: http://shortn/_VXLgQ1lmkz + Original implementation from image_processing_llama4.py: + https://github.com/huggingface/transformers/blob/28d3148b079fa50b82f4888dfcc3cd3de953f956/src/transformers/models/llama4/image_processing_llama4_fast.py Args: images: The np.array image [H, W, C] or images [N, H, W, C] to pre-process. Returns: diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py index c7c726cec9..3674ab70ff 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft.py +++ b/src/maxtext/trainers/post_train/sft/train_sft.py @@ -49,6 +49,7 @@ from tunix.sft import metrics_logger, peft_trainer, profiler +from maxtext.optimizers import optimizers from maxtext.configs import pyconfig from maxtext.trainers.pre_train.train import loss_fn from maxtext.common.goodput import ( @@ -60,8 +61,8 @@ maybe_record_goodput, record_goodput, ) -from maxtext.optimizers import optimizers from maxtext.trainers.post_train.sft import hooks +from maxtext.utils import lora_utils from maxtext.utils import max_utils from maxtext.utils import max_logging from maxtext.utils import maxtext_utils @@ -126,7 +127,15 @@ def use_maxtext_loss_function(trainer, mt_config): The trainer configured with the MaxText loss function. 
""" - def loss_func(model, inputs, inputs_position, inputs_segmentation, targets, targets_position, targets_segmentation): + def loss_func( + model, + inputs, + inputs_position, + inputs_segmentation, + targets, + targets_position, + targets_segmentation, + ): data = { "inputs": inputs, "inputs_position": inputs_position, @@ -146,7 +155,11 @@ def setup_trainer_state(mt_config, goodput_recorder=None): tunix_config = get_tunix_config(mt_config) with maybe_record_goodput(goodput_recorder, GoodputEvent.TPU_INIT): + model, mesh = model_creation_utils.from_pretrained(mt_config) + if mt_config.lora.enable_lora: + model = lora_utils.apply_lora_to_model(model, mesh, mt_config) + learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(mt_config) # pass in model for muon optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule, model) @@ -160,9 +173,13 @@ def setup_trainer_state(mt_config, goodput_recorder=None): with maybe_record_goodput(goodput_recorder, GoodputEvent.TRAINING_PREPARATION): training_hooks = hooks.SFTTrainingHooks(mt_config, mesh, learning_rate_schedule, goodput_recorder) data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder) + # Provide rules context so 'norm' is translated to mesh axes during maybe_restore with nn_partitioning.axis_rules(mt_config.logical_axis_rules): trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config) + if mt_config.lora.lora_restore_path: + trainer = lora_utils.restore_lora_from_path(trainer, mt_config) + trainer.with_training_hooks(training_hooks) trainer.with_data_hooks(data_hooks) trainer = use_maxtext_loss_function(trainer, mt_config) @@ -173,7 +190,10 @@ def setup_trainer_state(mt_config, goodput_recorder=None): def train_model(mt_config, trainer, mesh): """Runs the SFT training loop in Tunix.""" with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules): - trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator) + trainer.train( + trainer.data_hooks.train_data_iterator, + trainer.data_hooks.eval_data_iterator, + ) return trainer @@ -204,7 +224,7 @@ def main(argv: Sequence[str]) -> None: pathwaysutils.initialize() os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" - mt_config = pyconfig.initialize(argv) + mt_config = pyconfig.initialize_pydantic(argv) max_utils.print_system_information() goodput_recorder = create_goodput_recorder(mt_config) diff --git a/src/maxtext/utils/lora_utils.py b/src/maxtext/utils/lora_utils.py index 24099ef22a..8554d46e3e 100644 --- a/src/maxtext/utils/lora_utils.py +++ b/src/maxtext/utils/lora_utils.py @@ -1,4 +1,4 @@ -# Copyright 2023–2025 Google LLC +# Copyright 2023–2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,21 +13,28 @@ # limitations under the License. 
""" Common LoRA utils needed to support LoRA adapters.""" - from functools import partial import json +import os +import re +from typing import Any, Optional +from flax import nnx +from flax.linen import partitioning as nn_partitioning +from flax.training import train_state import jax import jax.numpy as jnp - -from flax.training import train_state -from flax.linen import partitioning as nn_partitioning +from orbax import checkpoint as ocp +import qwix from maxtext.common import checkpointing +from maxtext.configs import pyconfig from maxtext.utils import gcs_utils +from maxtext.utils import max_logging from maxtext.utils import max_utils from maxtext.utils import maxtext_utils -from maxtext.utils import max_logging +from maxtext.utils import sharding +from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR def apply_lora_on_base_params(base_params, lora_params, lora_scale_factor=1.0): @@ -243,7 +250,11 @@ def get_lora_param_shape(base_array_shape, lora_rank, lora_module): f"Encountered unexpected shape={base_array_shape} of array in base params. Array dimensions > 4 not supported." ) - if lora_module in ["self_attention.query", "self_attention.key", "self_attention.value"]: + if lora_module in [ + "self_attention.query", + "self_attention.key", + "self_attention.value", + ]: lora_a_shape = base_array_shape[:-2] + (lora_rank,) lora_b_shape = (lora_rank,) + base_array_shape[1:] elif lora_module in ["self_attention.out"]: @@ -270,7 +281,11 @@ def get_lora_param_sharding(base_param_sharding, lora_module): base_memory_kind = base_param_sharding.memory_kind base_pspec = base_param_sharding.spec - if lora_module in ["self_attention.query", "self_attention.key", "self_attention.value"]: + if lora_module in [ + "self_attention.query", + "self_attention.key", + "self_attention.value", + ]: lora_a_pspec_tuple = base_pspec[:-2] + ((),) lora_a_pspec = jax.sharding.PartitionSpec(*lora_a_pspec_tuple) @@ -311,7 +326,13 @@ def add_lora_params(lora_params, module_name, base_params, lora_rank, lora_targe for name, param in base_params.items(): if isinstance(param, dict): lora_params[name] = {} - add_lora_params(lora_params[name], f"{module_name}.{name}", param, lora_rank, lora_target_modules) + add_lora_params( + lora_params[name], + f"{module_name}.{name}", + param, + lora_rank, + lora_target_modules, + ) else: if name not in ["kernel", "scale", "embedding"]: raise ValueError(f"Unexpected key={name} exists in the abstract params of base model.") @@ -349,3 +370,238 @@ def get_lora_annotations(lora_abstract_params): ) return unboxed_abstract_lora_state, lora_state_mesh_annotations + + +# --- Qwix LoRA Utils --- + + +def _get_lora_module_path(mt_config: pyconfig.HyperParameters) -> str: + """Gets the regex for modules to apply LoRA on from config, architecture map, or fallback.""" + if mt_config.lora.lora_module_path: + return mt_config.lora.lora_module_path + + config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "post_train", "lora_module_path.yml") + lora_configs = pyconfig._load_config(config_path) # pylint: disable=protected-access + model_name = mt_config.model_name.lower() + + # Find the first matching architecture prefix or use 'default' + matched_key = next((k for k in lora_configs if k != "default" and model_name.startswith(k)), "default") + + if matched_key == "default": + max_logging.log(f"Warning: Model '{model_name}' is unverified; falling back to default LoRA path.") + else: + max_logging.log(f"Auto-detected lora_module_path for model '{model_name}' (matched: '{matched_key}')") + + raw_path = 
lora_configs.get(matched_key, "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))") + + # This regex makes the layer index optional, matching both scanned and unscanned layer paths + # (e.g. 'layers/0/mlp/...' vs 'layers/mlp/...'). + optional_layer_index = "(?:[0-9]+/)?" + final_path = str(raw_path).replace("layers/", f"layers/{optional_layer_index}") + + max_logging.log(f"Using lora_module_path: {final_path}") + return final_path + + +def _build_lora_provider(mt_config: pyconfig.HyperParameters) -> qwix.LoraProvider: + """Builds a Qwix LoRA provider from MaxText LoRA settings.""" + lora_module_path = _get_lora_module_path(mt_config) + lora_kwargs = { + "module_path": lora_module_path, + "rank": mt_config.lora.lora_rank, + "alpha": mt_config.lora.lora_alpha, + "dropout": 0.0, + } + max_logging.log( + f"LoRA configured: module_path={lora_module_path} " + f"rank={mt_config.lora.lora_rank} alpha={mt_config.lora.lora_alpha}" + ) + return qwix.LoraProvider(**lora_kwargs) + + +def _prepare_dummy_inputs() -> tuple[jnp.ndarray, jnp.ndarray]: + """Builds dummy decoder inputs used to materialize LoRA parameters.""" + # Keep LoRA warmup as small as possible to minimize compile/memory overhead. + dummy_bs = 1 + seq_len = 1 + decoder_input_tokens = jnp.zeros((dummy_bs, seq_len), dtype=jnp.int32) + decoder_positions = jnp.zeros((dummy_bs, seq_len), dtype=jnp.int32) + return decoder_input_tokens, decoder_positions + + +def is_lora_enabled(model: nnx.Module) -> bool: + """Checks if the model has LoRA parameters.""" + for _, value in nnx.iter_graph(model): + if isinstance(value, nnx.LoRAParam): + return True + return False + + +def _verify_lora_parameters(lora_model: nnx.Module, mt_config: pyconfig.HyperParameters): + """Validates that LoRA is active or that target modules were matched.""" + + if is_lora_enabled(lora_model): + return + + lora_module_path = _get_lora_module_path(mt_config) + compiled_module_path = re.compile(lora_module_path) + matched_module_paths = [] + sample_module_paths = [] + + for path, _ in nnx.iter_modules(lora_model): + module_path = "/".join(str(p) for p in path) + if len(sample_module_paths) < 100: + sample_module_paths.append(module_path) + if compiled_module_path.search(module_path): + matched_module_paths.append(module_path) + + if not matched_module_paths: + max_logging.log( + f"LoRA module_path='{lora_module_path}' did not match any weights. " f"Sample module paths: {sample_module_paths}" + ) + raise ValueError("LoRA enabled but no LoRA parameters found in decoder/model state.") + + raise ValueError( + "LoRA module path matched target modules, but nnx.LoRAParam is still " + "missing. For Tunix PeftTrainer, LoRA params must be materialized before " + "trainer initialization, otherwise it falls back to full-model training. 
" + f"Sample matches: {matched_module_paths[:10]}" + ) + + +def apply_lora_to_model( + model: nnx.Module, + mesh: Optional[jax.sharding.Mesh], + mt_config: pyconfig.HyperParameters, +) -> nnx.Module: + """Optionally applies LoRA/QLoRA to a MaxText model using Qwix.""" + # Skip Qwix LoRA if MaxText LoRA adapters are loaded + if mt_config.lora_input_adapters_path: + max_logging.log("MaxText LoRA adapters loaded, skipping Qwix LoRA application") + return model + + if not mt_config.lora.enable_lora: + return model + + # Dynamically detect and set LoRA rank before model creation if restoring + + lora_provider = _build_lora_provider(mt_config) + + model_rngs = getattr(model.decoder, "rngs", None) + decoder_input_tokens, decoder_positions = _prepare_dummy_inputs() + + lora_model = qwix.apply_lora_to_model( + model, + lora_provider, + decoder_input_tokens=decoder_input_tokens, + decoder_positions=decoder_positions, + rngs=model_rngs, + ) + + if mesh is not None: + with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules): + graph_def, state = nnx.split(lora_model) + + # We handle explicit replication for LoRA to ensure safety and efficiency. + state = jax.tree_util.tree_map( + lambda x: x.replace(sharding=jax.sharding.PartitionSpec(), out_sharding=None, sharding_names=None) + if isinstance(x, nnx.LoRAParam) + else x, + state, + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) + + # Use logical_to_mesh_sharding to correctly map logical axes like 'embed' + # to physical mesh axes. + dst_shardings = sharding.logical_to_mesh_sharding( + nnx.get_partition_spec(state), mesh, rules=mt_config.logical_axis_rules + ) + + from tunix.rl import reshard # pylint: disable=import-outside-toplevel + + state = reshard.reshard_pytree(state, dst_shardings) + lora_model = nnx.merge(graph_def, state) + + _verify_lora_parameters(lora_model, mt_config) + + return lora_model + + +def restore_lora_from_path(trainer: Any, mt_config: pyconfig.HyperParameters) -> Any: + """Restores LoRA parameter weights from an external Orbax checkpoint for a fresh run.""" + lora_restore_path = mt_config.lora.lora_restore_path + + train_steps = getattr(trainer, "train_steps", 0) + if train_steps > 0: + max_logging.log( + f"PeftTrainer restored current run at step {train_steps}; " f"ignoring lora_restore_path '{lora_restore_path}'." + ) + return trainer + + if not is_lora_enabled(trainer.model): + lora_module_path = _get_lora_module_path(mt_config) + if not mt_config.lora.enable_lora: + raise ValueError( + "lora_restore_path is set but LoRA is not enabled on the model. " + f"Set lora.enable_lora=True and verify lora_module_path ('{lora_module_path}') matches model modules." + ) + + abstract_lora_params = nnx.state(trainer.model, nnx.LoRAParam) + + target_for_restore = jax.tree.map( + lambda v: {"value": v.value}, + abstract_lora_params, + is_leaf=lambda n: isinstance(n, nnx.Variable), + ) + + sharding_tree = jax.tree.map(lambda x: x.sharding if hasattr(x, "sharding") else None, target_for_restore) + restore_args_tree = ocp.checkpoint_utils.construct_restore_args(target_for_restore, sharding_tree) + + try: + restore_args = ocp.args.PyTreeRestore( + item=target_for_restore, + restore_args=restore_args_tree, + partial_restore=True, + ) + restored_lora_params = ocp.PyTreeCheckpointer().restore( + lora_restore_path, + args=restore_args, + ) + except Exception as e: # pylint: disable=broad-exception-caught + max_logging.log(f"Guided restore failed: {e}. 
Falling back to basic restore.") + restored_lora_params = ocp.PyTreeCheckpointer().restore(lora_restore_path) + + # Post processing + def _map_to_state(path, variable): + if not isinstance(variable, nnx.Variable): + return + + str_path = [str(k.key if hasattr(k, "key") else (k.name if hasattr(k, "name") else k)) for k in path] + + curr = restored_lora_params + for p in str_path: + if isinstance(curr, dict) and p in curr: + curr = curr[p] + elif hasattr(curr, p): + curr = getattr(curr, p) + else: + return + + if isinstance(curr, dict) and "value" in curr: + matched_val = curr["value"] + elif hasattr(curr, "value"): + matched_val = getattr(curr, "value") + else: + matched_val = curr + + variable.value = matched_val + + jax.tree_util.tree_map_with_path( + _map_to_state, + abstract_lora_params, + is_leaf=lambda n: isinstance(n, nnx.Variable), + ) + + nnx.update(trainer.model, abstract_lora_params) + max_logging.log(f"LoRA restore complete from '{lora_restore_path}'.") + return trainer diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 09ccdbe804..4a14b461e7 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -1741,7 +1741,13 @@ def save_quantized_checkpoint_if_configured(config, params): def add_config_to_summary_writer(config, summary_writer): """Writes config params to tensorboard""" if jax.process_index() == 0: - for key, value in config.get_keys().items(): + if hasattr(config, "get_keys"): + config_dict = config.get_keys() + elif hasattr(config, "model_dump"): + config_dict = config.model_dump() + else: + config_dict = dict(config) + for key, value in config_dict.items(): max_utils.add_text_to_summary_writer(key, str(value), summary_writer) diff --git a/src/maxtext/utils/maxtext_utils_nnx.py b/src/maxtext/utils/maxtext_utils_nnx.py index 32fa5ec80b..5b645b85ca 100644 --- a/src/maxtext/utils/maxtext_utils_nnx.py +++ b/src/maxtext/utils/maxtext_utils_nnx.py @@ -173,3 +173,17 @@ def create_sharded_state(): with jax.set_mesh(mesh): sharded_state = create_sharded_state() return nnx.merge(graphdef, sharded_state) + + +def nnx_ensure_scan_leading_axis(tree, length): + """Broadcasts scalar-like variables to have a leading scan axis.""" + + def _op(x): + is_var = isinstance(x, nnx.Variable) + val = x.get_value() if is_var else x + if hasattr(val, "shape") and len(val.shape) == 0: + new_val = jax.numpy.broadcast_to(val, (length,)) + return x.replace(value=new_val) if is_var else new_val + return x + + return jax.tree.map(_op, tree, is_leaf=lambda x: isinstance(x, nnx.Variable)) diff --git a/tests/post_training/unit/lora_utils_test.py b/tests/post_training/unit/lora_utils_test.py new file mode 100644 index 0000000000..fea833c9dd --- /dev/null +++ b/tests/post_training/unit/lora_utils_test.py @@ -0,0 +1,242 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for Qwix LoRA utils in lora_utils.py""" +import sys +import unittest +from unittest import mock +import jax +import optax +import pytest +from flax import nnx + +# Skip the entire test suite if dependencies are missing +pytestmark = [pytest.mark.post_training] + +# Now safe to do top-level imports +from tunix.sft import peft_trainer +from maxtext.utils import lora_utils +from maxtext.utils import model_creation_utils +from maxtext.configs import pyconfig +from tests.utils.test_helpers import get_decoupled_parallelism_overrides, get_test_config_path # pylint: disable=no-name-in-module + +# --------------------------------------------------------------------------- +# Shared minimal config overrides +# --------------------------------------------------------------------------- +_BASE_CONFIG = { + "per_device_batch_size": 1.0, + "run_name": "lora_utils_test", + "enable_checkpointing": False, + "base_num_decoder_layers": 1, + "attention": "dot_product", + "max_target_length": 8, + "base_emb_dim": 128, + "base_num_query_heads": 2, + "base_num_kv_heads": 2, + "base_mlp_dim": 256, + "max_prefill_predict_length": 4, + "model_name": "llama2-7b", + "enable_nnx": True, + "pure_nnx_decoder": True, + "override_model_config": True, + "weight_dtype": "bfloat16", +} + + +def _make_config(**overrides): + """Return a MaxTextConfig object suitable for unit tests.""" + extra_args = get_decoupled_parallelism_overrides() + # Use initialize_pydantic to get nested models as objects (attribute access) + return pyconfig.initialize_pydantic( + [sys.argv[0], get_test_config_path()], + **_BASE_CONFIG, + **extra_args, + **overrides, + ) + + +class LoraUtilsTest(unittest.TestCase): + """Tests for lora_utils.py (Qwix LoRA Utils)""" + + # pylint: disable=protected-access + + def test_get_lora_module_path(self): + """Test retrieving LoRA module path from config.""" + mock_config = mock.MagicMock(spec=pyconfig.HyperParameters) + mock_config.lora = mock.MagicMock() + mock_config.lora.lora_module_path = "" + + mock_config.model_name = "llama3.1-8b" + path = lora_utils._get_lora_module_path(mock_config) + self.assertEqual( + path, + "decoder/layers/(?:[0-9]+/)?.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))", + ) + + mock_config.model_name = "unknown_model" + # Ensure lora.lora_module_path is still empty string to trigger fallback + mock_config.lora.lora_module_path = "" + path = lora_utils._get_lora_module_path(mock_config) + # Fallback to default + self.assertEqual( + path, + "decoder/layers/(?:[0-9]+/)?.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))", + ) + + mock_config.lora.lora_module_path = "custom/path" + path = lora_utils._get_lora_module_path(mock_config) + self.assertEqual(path, "custom/path") + + def test_build_lora_provider(self): + """Test building Qwix LoraProvider from config.""" + mock_config = mock.MagicMock(spec=pyconfig.HyperParameters) + mock_config.model_name = "default" + mock_config.lora = mock.MagicMock() + mock_config.lora.lora_module_path = "custom/path" + mock_config.lora.lora_rank = 8 + mock_config.lora.lora_alpha = 16.0 + + with mock.patch("qwix.LoraProvider") as mock_provider: + lora_utils._build_lora_provider(mock_config) + mock_provider.assert_called_once_with(module_path="custom/path", rank=8, alpha=16.0, dropout=0.0) + + def test_prepare_dummy_inputs(self): + """Test preparation of dummy inputs for LoRA verification.""" + tokens, positions = lora_utils._prepare_dummy_inputs() + self.assertEqual(tokens.shape, (1, 1)) + self.assertEqual(positions.shape, (1, 
1)) + + def test_verify_lora_parameters_enabled(self): + """Test verification of LoRA parameters when enabled.""" + mock_model = mock.MagicMock() + mock_config = mock.MagicMock(spec=pyconfig.HyperParameters) + + # Note: we use our local is_lora_enabled now + with mock.patch("maxtext.utils.lora_utils.is_lora_enabled", return_value=True): + # Should not raise + lora_utils._verify_lora_parameters(mock_model, mock_config) + + def test_verify_lora_parameters_not_enabled_no_match(self): + """Test verification fails when LoRA parameters are expected but not found.""" + mock_model = mock.MagicMock() + mock_config = mock.MagicMock(spec=pyconfig.HyperParameters) + mock_config.lora = mock.MagicMock() + mock_config.model_name = "llama" + mock_config.lora.lora_module_path = "non_existent" + + with mock.patch("maxtext.utils.lora_utils.is_lora_enabled", return_value=False): + mock_model.iter_modules.return_value = [] + with self.assertRaisesRegex(ValueError, "no LoRA parameters found"): + lora_utils._verify_lora_parameters(mock_model, mock_config) + + def test_apply_lora_to_model_disabled(self): + """Test applying LoRA when it is disabled in config.""" + cfg = _make_config(lora={"enable_lora": False}) + model, _ = model_creation_utils.from_pretrained(cfg, mesh=None, model_mode=model_creation_utils.MODEL_MODE_TRAIN) + # Pydantic MaxTextConfig supports direct attribute access + self.assertFalse(cfg.lora.enable_lora) + result = lora_utils.apply_lora_to_model(model, None, cfg) + self.assertEqual(result, model) + self.assertFalse(lora_utils.is_lora_enabled(result)) + + def test_apply_lora_to_model_adapters_loaded(self): + """Test applying LoRA when adapters are already provided.""" + cfg = _make_config(**{"lora_input_adapters_path": "some/path"}) + model, _ = model_creation_utils.from_pretrained(cfg, mesh=None, model_mode=model_creation_utils.MODEL_MODE_TRAIN) + result = lora_utils.apply_lora_to_model(model, None, cfg) + self.assertEqual(result, model) + # is_lora_enabled checks for LoRAParam which Qwix adds. + # If we skip Qwix, it should stay False. 
+ self.assertFalse(lora_utils.is_lora_enabled(result)) + + def _run_apply_lora_test(self, scan_layers: bool): + """Helper to run LoRA application test with/without scanned layers.""" + # Passing nested dict as 'lora' kwarg to _make_config + cfg = _make_config( + lora={ + "enable_lora": True, + "lora_rank": 4, + "lora_alpha": 8.0, + "lora_module_path": ".*mlp/wi_.*", + }, + scan_layers=scan_layers, + ) + + # Create a real small model using standard creation utils + model, _ = model_creation_utils.from_pretrained(cfg, mesh=None, model_mode=model_creation_utils.MODEL_MODE_TRAIN) + + # Verify model is NOT lora enabled initially + self.assertFalse(lora_utils.is_lora_enabled(model)) + + # Apply LoRA + lora_model = lora_utils.apply_lora_to_model(model, model.mesh, cfg) + + # Verify we can find LoRAParam in the state + _, state = nnx.split(lora_model) + lora_params = nnx.filter_state(state, nnx.LoRAParam) + self.assertGreater(len(jax.tree_util.tree_leaves(lora_params)), 0) + + # Verify it IS now LoRA enabled + self.assertTrue(lora_utils.is_lora_enabled(lora_model)) + + # Test fit for PeftTrainer + trainer_cfg = peft_trainer.TrainingConfig(eval_every_n_steps=10) + optimizer = optax.adam(1e-4) + + # This instantiation will fail if wrt=nnx.LoRAParam cannot find any matching params + trainer = peft_trainer.PeftTrainer(model=lora_model, optimizer=optimizer, training_config=trainer_cfg) + + # Verify optimizer is indeed targeting LoRAParams + opt_state = nnx.state(trainer.optimizer) + self.assertGreater(len(jax.tree_util.tree_leaves(opt_state)), 0) + + def test_apply_lora_to_model_scan_layers_false(self): + """Test applying LoRA to model with scan_layers=False.""" + self._run_apply_lora_test(scan_layers=False) + + def test_apply_lora_to_model_scan_layers_true(self): + """Test applying LoRA to model with scan_layers=True.""" + self._run_apply_lora_test(scan_layers=True) + + def test_restore_lora_from_path(self): + """Test restoration of LoRA parameters from a path.""" + cfg = _make_config( + lora={"enable_lora": True, "lora_restore_path": "some/path", "lora_rank": 4, "lora_alpha": 8.0}, + scan_layers=False, + ) + model, _ = model_creation_utils.from_pretrained(cfg, mesh=None, model_mode=model_creation_utils.MODEL_MODE_TRAIN) + model = lora_utils.apply_lora_to_model(model, None, cfg) + + trainer = mock.MagicMock() + trainer.model = model + trainer.train_steps = 0 + + restored_state = nnx.state(model, nnx.LoRAParam) + + with mock.patch("orbax.checkpoint.PyTreeCheckpointer.restore", return_value=restored_state) as mock_restore: + with mock.patch("flax.nnx.update") as mock_update: + lora_utils.restore_lora_from_path(trainer, cfg) + mock_restore.assert_called_once() + args, kwargs = mock_restore.call_args + self.assertEqual(args[0], "some/path") + # Handle cases where partial_restore is passed as kwarg or within args object + if "partial_restore" in kwargs: + self.assertTrue(kwargs["partial_restore"]) + elif "args" in kwargs and hasattr(kwargs["args"], "partial_restore"): + self.assertTrue(kwargs["args"].partial_restore) + mock_update.assert_called_once() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/hf_checkpoint_conversion_test.py b/tests/unit/hf_checkpoint_conversion_test.py index edf914af76..02ed7a5598 100644 --- a/tests/unit/hf_checkpoint_conversion_test.py +++ b/tests/unit/hf_checkpoint_conversion_test.py @@ -14,9 +14,20 @@ """ Tests for kernels """ +import unittest +from unittest.mock import MagicMock import numpy as np from maxtext.utils.max_utils import 
permute_to_match_maxtext_rope, unpermute_from_match_maxtext_rope -import unittest +from maxtext.checkpoint_conversion import to_huggingface as to_hf +from maxtext.checkpoint_conversion.to_huggingface import ( + _get_lora_delta, + _transform_weights_to_adapter, + _transform_weights_to_full_model, +) +from maxtext.checkpoint_conversion.to_maxtext import ( + convert_hf_lora_key_to_maxtext, + _process_and_stack_weights, +) class HFCheckpointConversionTest(unittest.TestCase): @@ -51,5 +62,112 @@ def test_huggingface_to_maxtext_back_to_huggingface_flow(self): print("Test failed: wq2 does not match wq4") +class MaxTextToHFLoRAConversionTest(unittest.TestCase): + """Tests the conversion modes (Base, Adapter, Merged) in to_huggingface with LoRA support.""" + + def setUp(self): + super().setUp() + self.base_key = "params-decoder-layers-layers_0-self_attention-query-kernel" + self.a_key = self.base_key + "_lora_a" + self.b_key = self.base_key + "_lora_b" + self.scaling = 2.0 + + # Simple weights for verification + # W: (10, 2, 20), A: (10, 2, 4), B: (4, 2, 20) + self.w_base = np.ones((10, 2, 20), dtype=np.float32) + self.w_a = np.ones((10, 2, 4), dtype=np.float32) * 0.5 + self.w_b = np.ones((4, 2, 20), dtype=np.float32) * 0.5 + + # Expected Merged: W + (B@A)*scaling + # B@A for each head: (20, 4) @ (4, 10) -> (20, 10) wait, MaxText shapes: + # MaxText A: (in, heads, rank), B: (rank, heads, out) + # Merging logic: matmul(A[:, i, :], B[:, i, :]) -> (in, out) + # head_delta = (0.5 * 0.5) * rank * scaling = 0.25 * 4 * 2.0 = 2.0 + # W_merged head = 1.0 + 2.0 = 3.0 + self.expected_merged_val = 3.0 + + def test_get_lora_delta(self): + lora_dict = {self.a_key: self.w_a, self.b_key: self.w_b} + delta = _get_lora_delta(self.base_key, lora_dict, self.scaling) + + self.assertEqual(delta.shape, (10, 2, 20)) + self.assertTrue(np.allclose(delta, 2.0)) + + def test_transform_weights_to_adapter(self): + param_map = {self.base_key: "model.layers.0.self_attn.q_proj.weight"} + lora_dict = {self.a_key: self.w_a, self.b_key: self.w_b} + + weights, modules = _transform_weights_to_adapter(param_map, lora_dict) + + self.assertIn("base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight", weights) + self.assertIn("base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight", weights) + self.assertEqual(weights["base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"].shape, (4, 10)) + self.assertEqual(weights["base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight"].shape, (20, 4)) + self.assertIn("q_proj", modules) + + def test_transform_weights_to_full_model_merged(self): + config = MagicMock() + config.lora.lora_alpha = 32.0 + config.lora.lora_rank = 16.0 # scaling = 2.0 + + state_dict = {self.base_key: self.w_base, self.a_key: self.w_a, self.b_key: self.w_b} + param_map = {self.base_key: "model.layers.0.self_attn.q_proj.weight"} + + # Mock process_maxtext_param to just return the weight + orig_proc = to_hf.process_maxtext_param + to_hf.process_maxtext_param = lambda k, w, pm, hfm, sm, c: [(pm[k], w)] + + try: + weights = _transform_weights_to_full_model(config, [self.base_key], state_dict, param_map, {}, {}) + finally: + to_hf.process_maxtext_param = orig_proc + + self.assertIn("model.layers.0.self_attn.q_proj.weight", weights) + self.assertTrue(np.allclose(weights["model.layers.0.self_attn.q_proj.weight"], self.expected_merged_val)) + + +class HFToMaxTextLoRAConversionTest(unittest.TestCase): + """Tests the conversion logic in to_maxtext with LoRA support.""" + + def 
test_convert_hf_lora_key_to_maxtext(self): + param_mapping = { + "params-decoder-layers-layers_0-self_attention-query-kernel": "model.layers.0.self_attn.q_proj.weight", + "params-decoder-layers-layers_1-mlp-wi_0-kernel": [ + "model.layers.1.mlp.gate_proj.weight", + "model.layers.1.mlp.up_proj.weight", + ], + } + + # Simple 1-to-1 + mt_key, idx = convert_hf_lora_key_to_maxtext( + "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight", param_mapping + ) + self.assertEqual(mt_key, "params-decoder-layers-layers_0-self_attention-query-kernel") + self.assertIsNone(idx) + + # Scanned/List mapping + mt_key, idx = convert_hf_lora_key_to_maxtext( + "base_model.model.model.layers.1.mlp.up_proj.lora_B.weight", param_mapping + ) + self.assertEqual(mt_key, "params-decoder-layers-layers_1-mlp-wi_0-kernel") + self.assertEqual(idx, 1) + + def test_process_and_stack_weights(self): + config = MagicMock() + config.model_name = "llama3.1-8b" + config.head_dim = 128 + + # 1. Non-scanned case + indexed = {0: np.ones((10, 20))} + stacked = _process_and_stack_weights(indexed, False, 1, 0, np.float32, "test", "suffix", config) + self.assertEqual(stacked.shape, (20, 10)) # Transposed + + # 2. Scanned case (stacking along layers) + indexed = {0: np.ones((10, 20)) * 1.0, 1: np.ones((10, 20)) * 2.0} + stacked = _process_and_stack_weights(indexed, True, 2, 0, np.float32, "test", "suffix", config) + self.assertEqual(stacked.shape, (2, 20, 10)) + self.assertEqual(stacked[1, 0, 0], 2.0) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/maxtext_utils_nnx_test.py b/tests/unit/maxtext_utils_nnx_test.py index 2ac59b7326..10e2b8621f 100644 --- a/tests/unit/maxtext_utils_nnx_test.py +++ b/tests/unit/maxtext_utils_nnx_test.py @@ -17,6 +17,7 @@ from dataclasses import dataclass from typing import Any import jax +import jax.numpy as jnp from flax import nnx from jax.sharding import Mesh, NamedSharding, PartitionSpec as P from jax.experimental import mesh_utils @@ -178,6 +179,28 @@ def test_get_partition_spec_nnx(self): self.assertEqual(spec["linear"]["bias"], expected_spec_b) self.assertNotIsInstance(spec["linear"]["kernel"], NamedSharding) + def test_nnx_ensure_scan_leading_axis_mixed(self): + """Test broadcasting on a mixed state of scalars and arrays.""" + length = 8 + state = nnx.State( + { + "scalar": nnx.Param(jnp.array(1.0)), + "array": nnx.Param(jnp.zeros((16,))), + "raw_scalar": jnp.array(2.0), + "raw_array": jnp.zeros((10,)), + } + ) + + broadcast_state = maxtext_utils_nnx.nnx_ensure_scan_leading_axis(state, length) + + # NNX Variables + self.assertEqual(broadcast_state["scalar"].get_value().shape, (length,)) + self.assertEqual(broadcast_state["array"].get_value().shape, (16,)) + + # Raw JAX types + self.assertEqual(broadcast_state["raw_scalar"].shape, (length,)) + self.assertEqual(broadcast_state["raw_array"].shape, (10,)) + if __name__ == "__main__": unittest.main() diff --git a/tests/utils/forward_pass_logit_checker.py b/tests/utils/forward_pass_logit_checker.py index b9d4295a94..3392bc86b9 100644 --- a/tests/utils/forward_pass_logit_checker.py +++ b/tests/utils/forward_pass_logit_checker.py @@ -54,6 +54,7 @@ from maxtext.utils import max_logging from maxtext.utils import maxtext_utils from maxtext.utils import model_creation_utils +from maxtext.utils import lora_utils import numpy as np import torch import torch.nn.functional as F @@ -245,21 +246,28 @@ def get_data(golden_data_point, config): def main(config, test_args): # pylint: disable=W0621 """Test the Whole Model of model_name""" + 
init_rng = jax.random.PRNGKey(config.init_weights_seed) + init_rng, rng1 = jax.random.split(init_rng) + devices_array = maxtext_utils.create_device_mesh(config) + mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) + if not test_args.run_hf_model: """Comparing maxtext/huggingface model with pre-loaded golden logitis""" max_logging.log("Initializing MaxText model") - init_rng = jax.random.PRNGKey(config.init_weights_seed) - init_rng, rng1 = jax.random.split(init_rng) - devices_array = maxtext_utils.create_device_mesh(config) - mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) quant = quantizations.configure_quantization(config) - if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") + if config.pure_nnx_decoder and config.enable_nnx: + model = model_creation_utils.from_pretrained(config, mesh=mesh, model_mode=MODEL_MODE_TRAIN) + + if config.lora.enable_lora: + model = lora_utils.apply_lora_to_model(model, mesh, config) + if config.lora.lora_restore_path: + mock_trainer = type("MockTrainer", (), {"model": model, "train_steps": 0}) + lora_utils.restore_lora_from_path(mock_trainer, config) + state = None else: model = models.transformer_as_linen(config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, config, False, rng1) - state, _ = maxtext_utils.setup_decode_state(config, mesh, None, init_state_fn) + state, _ = maxtext_utils.setup_decode_state(config, mesh, None, init_state_fn) if test_args.golden_logits_path == "": input_golden_data_path = os.path.join( @@ -284,15 +292,24 @@ def main(config, test_args): # pylint: disable=W0621 max_logging.log(f"\n--- Comparing forward pass for golden data index: {golden_data_index} ---") ids, decoder_segment_ids, decoder_positions, golden_logits, seq_len, images = get_data(golden_data_point, config) max_logging.log("maxtext forward pass") - full_train_logits = model.apply( - state.params, - ids, - decoder_positions, - decoder_segment_ids, - encoder_images=images, - enable_dropout=False, - rngs={"aqt": init_rng}, - ) + if state is None: + full_train_logits = model( + decoder_input_tokens=ids, + decoder_positions=decoder_positions, + decoder_segment_ids=decoder_segment_ids, + encoder_images=images, + enable_dropout=False, + ) + else: + full_train_logits = model.apply( + state.params, + ids, + decoder_positions, + decoder_segment_ids, + encoder_images=images, + enable_dropout=False, + rngs={"aqt": init_rng}, + ) full_train_logits = jax.experimental.multihost_utils.process_allgather(full_train_logits, tiled=True) # if full_train_logits shape is [num_hosts, batch_size, seq_len, vocab_size] @@ -374,7 +391,7 @@ def main(config, test_args): # pylint: disable=W0621 max_logging.log("\n[test criteria]") max_logging.log( f"Checking Numerical Differences between train logits and golden logits against " - f"atol={test_args.rtol} rtol={test_args.atol}." + f"atol={test_args.atol} rtol={test_args.rtol}." 
      )
      rtol_val = float(test_args.rtol)
      atol_val = float(test_args.atol)
@@ -414,7 +431,15 @@ def main(config, test_args):  # pylint: disable=W0621
     torch_dtype = dtype_mapping.get(config.dtype.name, torch.bfloat16)
     max_logging.log(f"Loading HF model with dtype: {torch_dtype} (derived from config.dtype: {config.dtype})")
-    hf_model = AutoModelForCausalLM.from_pretrained(test_args.hf_model_path, dtype=torch_dtype, token=hf_token)
+    hf_model = AutoModelForCausalLM.from_pretrained(test_args.hf_model_path, torch_dtype=torch_dtype, token=hf_token)
+    hf_lora_path = config.hf_lora_adapter_path
+    if hf_lora_path:
+      max_logging.log(f"Loading HF PEFT LoRA adapter from {hf_lora_path}")
+      try:
+        from peft import PeftModel  # pylint: disable=import-outside-toplevel
+      except ImportError as exc:
+        raise ImportError("peft library is required to load HF LoRA adapter. Run `pip install peft`.") from exc
+      hf_model = PeftModel.from_pretrained(hf_model, hf_lora_path)
 
     # Load tokenizer: `test_args.hf_model_path` or fallback to `config.tokenizer_path`
     try:
@@ -431,21 +456,23 @@ def main(config, test_args):  # pylint: disable=W0621
     if any(config.model_name.startswith(prefix) for prefix in pad_token_prefixes):
       tokenizer.pad_token = tokenizer.eos_token
 
-    init_rng = jax.random.PRNGKey(config.init_weights_seed)
-    init_rng, rng1 = jax.random.split(init_rng)
-    devices_array = maxtext_utils.create_device_mesh(config)
-    mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)
     quant = quantizations.configure_quantization(config)
-    if config.pure_nnx:
-      # NNX has a different function to init the training state.
-      raise NotImplementedError("Pure NNX support has not been implemented yet.")
+    if config.pure_nnx_decoder and config.enable_nnx:
+      maxtext_model = model_creation_utils.from_pretrained(config, mesh=mesh, model_mode=MODEL_MODE_TRAIN)
+
+      if config.lora.enable_lora:
+        maxtext_model = lora_utils.apply_lora_to_model(maxtext_model, mesh, config)
+        if config.lora.lora_restore_path:
+          mock_trainer = type("MockTrainer", (), {"model": maxtext_model, "train_steps": 0})
+          lora_utils.restore_lora_from_path(mock_trainer, config)
+      maxtext_state = None
     else:
       maxtext_model = models.transformer_as_linen(config, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
       init_state_fn = functools.partial(maxtext_utils.init_initial_state, maxtext_model, None, config, False, rng1)
-    if test_args.ckpt_type == "linen":
-      maxtext_state, _ = maxtext_utils.setup_decode_state(config, mesh, None, init_state_fn)
-    else:
-      maxtext_state, _ = model_creation_utils.setup_decode_state_from_nnx(maxtext_model, config, rng1, mesh)
+      if test_args.ckpt_type == "linen":
+        maxtext_state, _ = maxtext_utils.setup_decode_state(config, mesh, None, init_state_fn)
+      else:
+        maxtext_state, _ = model_creation_utils.setup_decode_state_from_nnx(maxtext_model, config, rng1, mesh)
 
     prompts = ["I love to", "Today is a", "What is the"]
     all_data_to_save = []
@@ -480,14 +507,22 @@ def main(config, test_args):  # pylint: disable=W0621
       hf_logits_torch = hf_model(**inputs).logits
 
       # --- MaxText Forward Pass ---
-      mt_logits_jax = maxtext_model.apply(
-          maxtext_state.params,
-          mt_ids,
-          mt_decoder_positions,
-          mt_decoder_segment_ids,
-          enable_dropout=False,
-          rngs={"aqt": init_rng},
-      )
+      if maxtext_state is None:
+        mt_logits_jax = maxtext_model(
+            decoder_input_tokens=mt_ids,
+            decoder_positions=mt_decoder_positions,
+            decoder_segment_ids=mt_decoder_segment_ids,
+            enable_dropout=False,
+        )
+      else:
+        mt_logits_jax = maxtext_model.apply(
+            maxtext_state.params,
+            mt_ids,
+            mt_decoder_positions,
+            mt_decoder_segment_ids,
+            enable_dropout=False,
+            rngs={"aqt": init_rng},
+        )
 
       mt_logits_jax_sliced = mt_logits_jax[:, :actual_seq_len, :]
       mt_logits_torch = convert_jax_weight_to_torch(mt_logits_jax_sliced)
@@ -566,7 +601,7 @@ def main(config, test_args):  # pylint: disable=W0621
   # Reconstruct model_args (script name + the args MaxText needs)
   model_args = [sys.argv[0]] + remaining_args
-  cfg = pyconfig.initialize(model_args)
+  cfg = pyconfig.initialize_pydantic(model_args)
   assert (
       test_args.atol is not None or test_args.max_kl_div is not None
   ), "At least one of --atol or --max_kl_div must be specified to define the test criteria."
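
# Illustrative sketch, not part of the diff above: the hunks in forward_pass_logit_checker.py
# make the checker handle two model flavors, a pure-NNX module (where the decode state is None
# and the module is called directly) and a Linen module (where parameters live in `state.params`
# and go through `model.apply`). The helper below condenses that dispatch; its name and argument
# list are hypothetical and simply mirror the local variables used in the script.
def run_forward_pass(model, state, ids, positions, segment_ids, images, init_rng):
  """Return train logits from either a pure-NNX module or a Linen transformer."""
  if state is None:
    # Pure-NNX path: the module owns its parameters, so it is invoked directly.
    return model(
        decoder_input_tokens=ids,
        decoder_positions=positions,
        decoder_segment_ids=segment_ids,
        encoder_images=images,
        enable_dropout=False,
    )
  # Linen path: parameters come from the decode state and an AQT rng is threaded through.
  return model.apply(
      state.params,
      ids,
      positions,
      segment_ids,
      encoder_images=images,
      enable_dropout=False,
      rngs={"aqt": init_rng},
  )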