
Feat: integrate NNX LoRA support via Qwix with unified configuration #3320

Open
RexBearIU wants to merge 1 commit into main from jackyf/feat/lora-nnx

Conversation

@RexBearIU (Collaborator) commented Mar 5, 2026

Description

Overview
This pull request introduces native LoRA support in MaxText by leveraging the NNX model definition and the Qwix library. It enables a seamless workflow for applying LoRA adapters during
training and provides utilities for bidirectional checkpoint conversion with the HuggingFace ecosystem.

Key Changes

  • Core NNX Integration:
    • Refactored NNXDecoder layer application logic to support nnx.scan with dynamic graph initialization, ensuring compatibility with Qwix's parameter materialization.
  • SFT Pipeline Enhancements:
    • Integrated apply_lora_to_model and restore_lora_from_path into the SFT trainer (sketched after this list).
    • Added dummy-input preparation to materialize LoRA parameters before trainer initialization.
  • Bidirectional Conversion Scripts:
    • hf_lora_to_maxtext.py: Converts HuggingFace PEFT adapters to MaxText checkpoint format. Updated to 2026 copyright and cleaned up comments.
    • maxtext_to_hf_lora.py: Converts MaxText LoRA checkpoints back to HuggingFace format. Updated to use max_logging and 2026 copyright.
  • Configuration & Type System:
    • Added lora_module_path auto-detection logic for popular models (Llama, etc.) via lora_module_path.yml.
    • Updated types.py with specific LoRA/QLoRA fields.
  • Current Limitations:
    • QLoRA flags (lora_weight_qtype, lora_tile_size) are included in the configuration but explicitly marked as TODO / Not Working for this initial release.
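
For orientation, the core flow looks roughly like the following. This is a minimal sketch assuming Qwix's LoraProvider / apply_lora_to_model interface as shown in its public examples; the module-path regex and the rank/alpha values are placeholders (in this PR they come from the unified configuration), and the actual MaxText wrapper signatures may differ.

  import qwix
  from flax import nnx

  def apply_lora(base_model: nnx.Module, dummy_inputs: dict) -> nnx.Module:
    # Placeholder regex and hyperparameters; the real values are read
    # from the unified LoRA configuration in this PR.
    provider = qwix.LoraProvider(
        module_path=".*self_attention.*|.*mlp.*",  # regex over the nnx module path tree
        rank=16,
        alpha=16.0,
    )
    # Qwix needs a forward trace to materialize the LoRA parameters,
    # which is why the SFT pipeline prepares dummy inputs first.
    return qwix.apply_lora_to_model(base_model, provider, **dummy_inputs)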

Tests

The Qwix-based LoRA implementation was validated through a new unit test suite and verified via a comprehensive tutorial.

  1. Unit Tests
    Implemented tests/unit/lora_utils_test.py to ensure structural correctness and trainer compatibility. Key areas covered:
  • Model Transformation: Verified that apply_lora_to_model correctly injects nnx.LoRAParam into the model state (a condensed version is sketched after this list).
  • Layer Scanning: Confirmed the implementation works with both scan_layers=True and scan_layers=False by handling the resulting differences in the nnx module path tree.
  • Trainer Compatibility: Validated that tunix.sft.peft_trainer.PeftTrainer correctly identifies the LoRA parameters for optimization, ensuring only adapter weights are trained.
  • Path Matching: Tested the regex logic for auto-detecting LoRA target modules across different model architectures (e.g., Llama).
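
A condensed, hypothetical version of the model-transformation check (the actual assertions in tests/unit/lora_utils_test.py are more thorough):

  import jax
  from flax import nnx

  def test_lora_params_injected(lora_model):
    # Filter the model state down to LoRA adapter parameters only.
    lora_state = nnx.state(lora_model, nnx.LoRAParam)
    # At least one adapter weight should have been materialized.
    assert len(jax.tree_util.tree_leaves(lora_state)) > 0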

Command to run unit tests:

  # From the maxtext root directory
  export PYTHONPATH=$PYTHONPATH:$(pwd)/src:$(pwd)
  python3 tests/unit/lora_utils_test.py

  2. Documentation
  • Added docs/tutorials/posttraining/lora.md, which provides a step-by-step guide for running LoRA fine-tuning, including environment setup and checkpoint conversion. This tutorial serves as the reference for end-to-end functional verification. A simplified view of the key remapping performed by the conversion scripts is sketched below.
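
As a rough illustration of what the conversion scripts do: HF PEFT stores adapters as lora_A/lora_B tensor pairs in adapter_model.safetensors, and the scripts remap those keys between naming schemes. The MaxText-side key names below are hypothetical placeholders, not the actual checkpoint layout.

  import re
  from safetensors import safe_open

  def load_peft_adapter(path: str) -> dict:
    """Reads an HF PEFT adapter and remaps keys to a MaxText-like layout."""
    tensors = {}
    with safe_open(path, framework="np") as f:
      for hf_key in f.keys():
        # e.g. base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight
        m = re.search(r"layers\.(\d+)\.(\w+)\.(\w+)\.lora_([AB])\.weight", hf_key)
        if m:
          layer, block, proj, ab = m.groups()
          # Hypothetical MaxText-side key; the real mapping lives in
          # hf_lora_to_maxtext.py.
          key = f"decoder/layers_{layer}/{block}/{proj}/lora_{ab.lower()}"
          tensors[key] = f.get_tensor(hf_key)
    return tensors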

Logit test result

https://paste.googleplex.com/6233928391327744

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

RexBearIU force-pushed the jackyf/feat/lora-nnx branch from 11939f9 to 6540bc8 on March 5, 2026 10:53

RexBearIU force-pushed the jackyf/feat/lora-nnx branch 7 times, most recently from 69e481b to 80b5592 on March 11, 2026 08:19
RexBearIU force-pushed the jackyf/feat/lora-nnx branch 11 times, most recently from f5a0f6d to 23e79c7 on March 25, 2026 08:33
RexBearIU force-pushed the jackyf/feat/lora-nnx branch 5 times, most recently from 5a05148 to 7570b3d on April 14, 2026 02:45
RexBearIU marked this pull request as ready for review on April 14, 2026 04:08
Review thread on src/maxtext/checkpoint_conversion/hf_lora_to_maxtext.py (outdated)
RexBearIU force-pushed the jackyf/feat/lora-nnx branch 2 times, most recently from 0dfeb76 to 2f91ad8 on April 16, 2026 10:32
RexBearIU changed the title from "Jackyf/feat/lora nnx" to "Feat: integrate NNX LoRA support via Qwix with unified configuration" on Apr 16, 2026
RexBearIU force-pushed the jackyf/feat/lora-nnx branch from 2f91ad8 to f5736a1 on April 16, 2026 10:54
Review threads on src/maxtext/layers/nnx_decoders.py and docs/tutorials/posttraining/lora.md (outdated)
RexBearIU force-pushed the jackyf/feat/lora-nnx branch 5 times, most recently from b747695 to a952892 on April 24, 2026 14:39
Review threads on docs/_static/js/editable_commands.js (3) and src/maxtext/utils/lora_utils.py (outdated)
@bvandermoon (Collaborator) left a comment

Thank you @RexBearIU. Left some comments but this is generally looking good to me

Review threads on docs/tutorials/posttraining/lora.md (7), pytest.ini, src/maxtext/examples/sft_llama3_demo_tpu.ipynb, src/maxtext/configs/types.py, src/maxtext/checkpoint_conversion/utils/utils.py, src/maxtext/utils/lora_utils.py, src/dependencies/requirements/base_requirements/requirements.txt, and src/maxtext/utils/maxtext_utils.py (several outdated)