From 7b317e3df5743fbd619c31f3a330d0f52a3e480a Mon Sep 17 00:00:00 2001 From: Alex Shraer Date: Sat, 9 May 2026 01:12:00 +0000 Subject: [PATCH] Add PR comment commands for tests and optimize Jupyter notebooks for CI - Add pr_comment_commands.yml to trigger standard or scheduled PR tests via issue comments. - Simplify build_and_test_maxtext.yml test orchestration logic and inputs. - Support CI_STEPS environment variable in post-training notebooks to dynamically scale down step count (e.g. to 2 steps) for fast CI smoke tests. - Skip pathways unit tests in CI due to GHA service container TPU access limitations. --- .github/workflows/build_and_test_maxtext.yml | 14 +- .github/workflows/pr_comment_commands.yml | 76 ++++ .github/workflows/run_jupyter_notebooks.yml | 1 + .../examples/sft_llama3_demo_tpu.ipynb | 208 +++++----- src/maxtext/examples/sft_qwen3_demo.ipynb | 364 ++++++++++-------- 5 files changed, 400 insertions(+), 263 deletions(-) create mode 100644 .github/workflows/pr_comment_commands.yml diff --git a/.github/workflows/build_and_test_maxtext.yml b/.github/workflows/build_and_test_maxtext.yml index a818da8725..d538ac692c 100644 --- a/.github/workflows/build_and_test_maxtext.yml +++ b/.github/workflows/build_and_test_maxtext.yml @@ -20,6 +20,10 @@ on: pull_request: workflow_call: workflow_dispatch: + inputs: + force_scheduled_run: + type: boolean + default: false schedule: # Run the job every 4 hours - cron: '0 */4 * * *' @@ -154,7 +158,7 @@ jobs: with: flavor: ${{ matrix.flavor }} base_image: maxtext-unit-test-tpu:py312 - is_scheduled_run: ${{ github.event_name == 'schedule' }} + is_scheduled_run: ${{ github.event_name == 'schedule' || github.event.inputs.force_scheduled_run == 'true' }} maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} gpu-tests: @@ -169,7 +173,7 @@ jobs: with: flavor: ${{ matrix.flavor }} base_image: maxtext-unit-test-cuda12:py312 - is_scheduled_run: ${{ github.event_name == 'schedule' }} + is_scheduled_run: ${{ github.event_name == 'schedule' || github.event.inputs.force_scheduled_run == 'true' }} maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} cpu-tests: @@ -184,7 +188,7 @@ jobs: with: flavor: ${{ matrix.flavor }} base_image: maxtext-unit-test-tpu:py312 - is_scheduled_run: ${{ github.event_name == 'schedule' }} + is_scheduled_run: ${{ github.event_name == 'schedule' || github.event.inputs.force_scheduled_run == 'true' }} maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} maxtext_tpu_pathways_unit_tests: @@ -203,7 +207,7 @@ jobs: xla_python_client_mem_fraction: 0.75 tf_force_gpu_allow_growth: false container_resource_option: "--privileged" - is_scheduled_run: ${{ github.event_name == 'schedule' }} + is_scheduled_run: ${{ github.event_name == 'schedule' || github.event.inputs.force_scheduled_run == 'true' }} maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} maxtext_tpu_pathways_integration_tests: @@ -222,7 +226,7 @@ jobs: xla_python_client_mem_fraction: 0.75 tf_force_gpu_allow_growth: false container_resource_option: "--privileged" - is_scheduled_run: ${{ github.event_name == 'schedule' }} + is_scheduled_run: ${{ github.event_name == 'schedule' || github.event.inputs.force_scheduled_run == 'true' }} maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} all_tests_passed: diff --git a/.github/workflows/pr_comment_commands.yml b/.github/workflows/pr_comment_commands.yml new file mode 100644 index 0000000000..e84c3b8148 --- /dev/null +++ b/.github/workflows/pr_comment_commands.yml @@ -0,0 +1,76 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: PR Comment Commands + +on: + issue_comment: + types: [created] + +permissions: + actions: write + pull-requests: write + issues: write + contents: read + +jobs: + trigger_workflows: + name: Trigger Workflows via Comment + # Only run if it's a pull request comment, and the author is a collaborator/member/owner + if: | + github.event.issue.pull_request != null && + contains(fromJSON('["OWNER", "MEMBER", "COLLABORATOR"]'), github.event.comment.author_association) + runs-on: ubuntu-latest + steps: + - name: Parse and Dispatch Comment Command + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + COMMENT_BODY: ${{ github.event.comment.body }} + ISSUE_NUMBER: ${{ github.event.issue.number }} + REPO: ${{ github.repository }} + run: | + # Normalize comment body (lowercase, trim whitespace) + COMMENT=$(echo "$COMMENT_BODY" | tr '[:upper:]' '[:lower:]' | xargs) + + echo "Processing comment: '$COMMENT'" + + # Get PR branch name using gh CLI + PR_BRANCH=$(gh pr view "$ISSUE_NUMBER" --json headRefName --jq '.headRefName' --repo "$REPO") + echo "PR branch ref is '$PR_BRANCH'" + + MATCHED=false + WORKFLOW_NAME="" + PARAMS=() + + # 1. Match "test this please" / "test this pr" / "run tests" + if [[ "$COMMENT" =~ ^(test this please|test this pr|run tests)$ ]]; then + MATCHED=true + WORKFLOW_NAME="build_and_test_maxtext.yml" + PARAMS=(-f force_scheduled_run=false) + + # 2. Match "run scheduled tests" + elif [[ "$COMMENT" == "run scheduled tests" ]]; then + MATCHED=true + WORKFLOW_NAME="build_and_test_maxtext.yml" + PARAMS=(-f force_scheduled_run=true) + fi + + if [ "$MATCHED" = "true" ]; then + echo "Dispatching workflow $WORKFLOW_NAME with parameters: ${PARAMS[*]}" + + # Dispatch the workflow + gh workflow run "$WORKFLOW_NAME" --ref "$PR_BRANCH" "${PARAMS[@]}" --repo "$REPO" + else + echo "No matching command found in comment." + fi diff --git a/.github/workflows/run_jupyter_notebooks.yml b/.github/workflows/run_jupyter_notebooks.yml index 7d868e8d5c..c10ee9e275 100644 --- a/.github/workflows/run_jupyter_notebooks.yml +++ b/.github/workflows/run_jupyter_notebooks.yml @@ -82,6 +82,7 @@ jobs: MAXTEXT_INSTALLED: ${{ inputs.maxtext_installed }} # TODO: Fix evaluation in sft_qwen3_demo.ipynb and remove this env variable RUN_EVALUATION: "False" + CI_STEPS: "2" run: | if [ "${MAXTEXT_INSTALLED}" == "true" ]; then # Move to the directory where code is baked into the image. See the Dockerfile. diff --git a/src/maxtext/examples/sft_llama3_demo_tpu.ipynb b/src/maxtext/examples/sft_llama3_demo_tpu.ipynb index 3cb7997126..a8fe6fcb6e 100644 --- a/src/maxtext/examples/sft_llama3_demo_tpu.ipynb +++ b/src/maxtext/examples/sft_llama3_demo_tpu.ipynb @@ -1,49 +1,51 @@ { "cells": [ { + "id": "b16050df", "cell_type": "markdown", - "metadata": { - "id": "iBmRfde4Kgv4" - }, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_llama3_demo_tpu.ipynb)\n", "\n", "# Llama3.1-8B-Instruct Supervised Fine-Tuning (SFT) Demo\n" - ] + ], + "metadata": { + "id": "iBmRfde4Kgv4" + }, + "execution_count": null }, { + "id": "683b2546", "cell_type": "markdown", - "metadata": { - "id": "x8QDt1zoKgv4" - }, "source": [ "## Overview\n", "\n", "This notebook demonstrates how to perform Supervised Fine-Tuning (SFT) on Llama3.1-8B-Instruct using the Hugging Face ultrachat_200k dataset with MaxText and Tunix integration for efficient training.\n", "\n", "This notebook can run on **TPU v6e-8** or **v5p-8**." - ] + ], + "metadata": { + "id": "x8QDt1zoKgv4" + }, + "execution_count": null }, { + "id": "b1b76e00", "cell_type": "markdown", - "metadata": { - "id": "SDLlkqKJKgv4" - }, "source": [ "## Prerequisites\n", "\n", "Before running this notebook, make sure your environment is set up for the method you are using. Follow the [Run MaxText Python Notebooks on TPUs](https://maxtext.readthedocs.io/en/latest/guides/run_python_notebook.html) guide and complete all steps for your chosen method (Google Colab, VS Code, or Local Jupyter Lab) before proceeding.\n", "\n", - "If you run into issues, refer to the [Common Pitfalls & Debugging](https://maxtext.readthedocs.io/en/latest/guides/run_python_notebook.html#common-pitfalls-debugging) section of the guide." - ] + "If you run into issues, refer to the [Common Pitfalls \u0026 Debugging](https://maxtext.readthedocs.io/en/latest/guides/run_python_notebook.html#common-pitfalls-debugging) section of the guide." + ], + "metadata": { + "id": "SDLlkqKJKgv4" + }, + "execution_count": null }, { + "id": "8cbe1936", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "__TGXwlkKgv4" - }, - "outputs": [], "source": [ "try:\n", " import google.colab\n", @@ -52,24 +54,28 @@ "except ImportError:\n", " print(\"Running the notebook on Visual Studio or JupyterLab\")\n", " IN_COLAB = False" - ] + ], + "metadata": { + "id": "__TGXwlkKgv4" + }, + "execution_count": null }, { + "id": "7d56bcbe", "cell_type": "markdown", - "metadata": { - "id": "yuvt9qDPKgv4" - }, "source": [ "## Installation: MaxText and Post training Dependencies\n", "\n", "**Running the notebook on Visual Studio or JupyterLab:** Before proceeding, create a virtual environment and install the required post-training dependencies by following `Option 3: Installing [tpu-post-train]` in the [MaxText installation guide](https://maxtext.readthedocs.io/en/latest/install_maxtext.html#from-source). Once the environment is set up, ensure the notebook is running within it." - ] + ], + "metadata": { + "id": "yuvt9qDPKgv4" + }, + "execution_count": null }, { + "id": "c8d71fcb", "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "if IN_COLAB:\n", " # Clone the MaxText repository\n", @@ -82,11 +88,13 @@ " # Install MaxText and post-training dependencies\n", " !uv pip install -e .[tpu-post-train] --resolution=lowest\n", " !install_tpu_post_train_extra_deps" - ] + ], + "metadata": {}, + "execution_count": null }, { + "id": "795f8ccf", "cell_type": "markdown", - "metadata": {}, "source": [ "**Session restart Instructions for Colab:**\n", "1. Navigate to the menu at the top of the screen.\n", @@ -94,25 +102,24 @@ "3. Select **Restart session** from the dropdown menu.\n", "\n", "You will be asked to confirm the action in a pop-up dialog. Click on **Yes**." - ] + ], + "metadata": {}, + "execution_count": null }, { + "id": "b26d948c", "cell_type": "markdown", + "source": [ + "## Environment Setup" + ], "metadata": { "id": "o_w4iLJyKgv5" }, - "source": [ - "## Environment Setup" - ] + "execution_count": null }, { + "id": "92701170", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "BXLzWnjUKgv5", - "language": "python" - }, - "outputs": [], "source": [ "import datetime\n", "import jax\n", @@ -126,27 +133,28 @@ "\n", "\n", "print(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")" - ] + ], + "metadata": { + "id": "BXLzWnjUKgv5", + "language": "python" + }, + "execution_count": null }, { + "id": "38c89f34", "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "if not jax.distributed.is_initialized():\n", " jax.distributed.initialize()\n", "print(f\"JAX version: {jax.__version__}\")\n", "print(f\"JAX devices: {jax.devices()}\")" - ] + ], + "metadata": {}, + "execution_count": null }, { + "id": "444869aa", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "dzMQJ6RXKgv5" - }, - "outputs": [], "source": [ "if IN_COLAB:\n", " from huggingface_hub import notebook_login\n", @@ -154,24 +162,26 @@ "else:\n", " from huggingface_hub import login\n", " login()" - ] + ], + "metadata": { + "id": "dzMQJ6RXKgv5" + }, + "execution_count": null }, { + "id": "ddd5e444", "cell_type": "markdown", + "source": [ + "## Model Configurations" + ], "metadata": { "id": "v2evOnbOKgv5" }, - "source": [ - "## Model Configurations" - ] + "execution_count": null }, { + "id": "48480b45", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "pdvqPIB_Kgv5" - }, - "outputs": [], "source": [ "MODEL_NAME = \"llama3.1-8b-Instruct\"\n", "\n", @@ -185,24 +195,26 @@ " print(\"Set MODEL_CHECKPOINT_PATH if you do not wish to download the checkpoint.\")\n", "\n", "RUN_NAME = datetime.datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")" - ] + ], + "metadata": { + "id": "pdvqPIB_Kgv5" + }, + "execution_count": null }, { + "id": "2c655d35", "cell_type": "markdown", + "source": [ + "## Download Llama3.1-8B Model Checkpoint from Hugging Face" + ], "metadata": { "id": "w8hpE43wKgv5" }, - "source": [ - "## Download Llama3.1-8B Model Checkpoint from Hugging Face" - ] + "execution_count": null }, { + "id": "bc0dacee", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "uupzLQl5Kgv5" - }, - "outputs": [], "source": [ "if not epath.Path(MODEL_CHECKPOINT_PATH).exists():\n", " # Install torch for the conversion script\n", @@ -238,24 +250,26 @@ " MODEL_CHECKPOINT_PATH = os.path.join(MODEL_CHECKPOINT_PATH, \"0/items\")\n", "else:\n", " print(f\"Model checkpoint exists at {MODEL_CHECKPOINT_PATH}\")" - ] + ], + "metadata": { + "id": "uupzLQl5Kgv5" + }, + "execution_count": null }, { + "id": "6fb6fafa", "cell_type": "markdown", + "source": [ + "## MaxText Configurations" + ], "metadata": { "id": "tFlbPuOAKgv6" }, - "source": [ - "## MaxText Configurations" - ] + "execution_count": null }, { + "id": "742cb9ed", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "In-jdp1AAwrL" - }, - "outputs": [], "source": [ "# Load configuration for SFT training\n", "config_argv = [\n", @@ -263,7 +277,7 @@ " f\"{MAXTEXT_PKG_DIR}/configs/post_train/sft.yml\",\n", " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}\",\n", " f\"model_name={MODEL_NAME}\",\n", - " \"steps=100\",\n", + " f\"steps={os.environ.get('CI_STEPS', '100')}\",\n", " \"per_device_batch_size=1\",\n", " \"max_target_length=1024\",\n", " \"learning_rate=2.0e-5\",\n", @@ -281,24 +295,26 @@ "print(f\" Model: {config.model_name}\")\n", "print(f\" Training Steps: {config.steps}\")\n", "print(f\" Output Directory: {config.base_output_directory}\")" - ] + ], + "metadata": { + "id": "In-jdp1AAwrL" + }, + "execution_count": null }, { + "id": "7d3ca270", "cell_type": "markdown", + "source": [ + "## SFT Training" + ], "metadata": { "id": "fFSQOH3CKgv6" }, - "source": [ - "## SFT Training" - ] + "execution_count": null }, { + "id": "09315374", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "mgwpNgQYCJEd" - }, - "outputs": [], "source": [ "import traceback\n", "\n", @@ -317,23 +333,29 @@ " print(\"❌Training Failed!\")\n", " print(\"=\" * 60)\n", " traceback.print_exc()\n", - " print(\"\\nFor troubleshooting, refer to the Common Pitfalls & Debugging section:\")\n", + " print(\"\\nFor troubleshooting, refer to the Common Pitfalls \u0026 Debugging section:\")\n", " print(\"https://maxtext.readthedocs.io/en/latest/guides/run_python_notebook.html#common-pitfalls-debugging\")\n", " sys.exit(1)" - ] + ], + "metadata": { + "id": "mgwpNgQYCJEd" + }, + "execution_count": null }, { + "id": "e31902b3", "cell_type": "markdown", - "metadata": { - "id": "vzW1NXX6Kgv6" - }, "source": [ "## 📚 Learn More\n", "\n", "- **CLI Usage**: https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft.html\n", "- **Configuration**: See `src/maxtext/configs/post_train/sft.yml` for all available options\n", "- **Documentation**: Check `src/maxtext/trainers/post_train/sft/train_sft.py` for the `train` function implementation" - ] + ], + "metadata": { + "id": "vzW1NXX6Kgv6" + }, + "execution_count": null } ], "metadata": { @@ -360,6 +382,6 @@ "version": "3.12.11" } }, - "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 0, + "nbformat": 4 } diff --git a/src/maxtext/examples/sft_qwen3_demo.ipynb b/src/maxtext/examples/sft_qwen3_demo.ipynb index afba9e9a04..7003f5fdec 100644 --- a/src/maxtext/examples/sft_qwen3_demo.ipynb +++ b/src/maxtext/examples/sft_qwen3_demo.ipynb @@ -1,53 +1,55 @@ { "cells": [ { + "id": "542a54a7", "cell_type": "markdown", - "metadata": { - "id": "1nb_Ppf2ZUQL" - }, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_qwen3_demo.ipynb)\n", "\n", "# Qwen3-0.6B Supervised Fine-Tuning (SFT) Demo\n" - ] + ], + "metadata": { + "id": "1nb_Ppf2ZUQL" + }, + "execution_count": null }, { + "id": "07845e83", "cell_type": "markdown", - "metadata": { - "id": "FGbe4_YQZUQL" - }, "source": [ "## Overview\n", "\n", "This notebook performs SFT training and evaluation workflow on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k).\n", "The primary goal is to demonstrate the end-to-end process of:\n", "1. Pre-SFT Evaluation: Calcuating baseline accuracy for the model before training.\n", - "2. SFT Training: Fine-tune the model using MaxText & Tunix SFT trainer.\n", + "2. SFT Training: Fine-tune the model using MaxText \u0026 Tunix SFT trainer.\n", "3. Post-SFT Evaluation: Re-running the evaluation loop after training to measure the performance gain achieved by SFT.\n", "\n", "This notebook can run on the **public TPU v5e-1**." - ] + ], + "metadata": { + "id": "FGbe4_YQZUQL" + }, + "execution_count": null }, { + "id": "ba50bca1", "cell_type": "markdown", - "metadata": { - "id": "zolxPWhQZUQL" - }, "source": [ "## Prerequisites\n", "\n", "Before running this notebook, make sure your environment is set up for the method you are using. Follow the [Run MaxText Python Notebooks on TPUs](https://maxtext.readthedocs.io/en/latest/guides/run_python_notebook.html) guide and complete all steps for your chosen method (Google Colab, VS Code, or Local Jupyter Lab) before proceeding.\n", "\n", - "If you run into issues, refer to the [Common Pitfalls & Debugging](https://maxtext.readthedocs.io/en/latest/guides/run_python_notebook.html#common-pitfalls-debugging) section of the guide." - ] + "If you run into issues, refer to the [Common Pitfalls \u0026 Debugging](https://maxtext.readthedocs.io/en/latest/guides/run_python_notebook.html#common-pitfalls-debugging) section of the guide." + ], + "metadata": { + "id": "zolxPWhQZUQL" + }, + "execution_count": null }, { + "id": "78b176f5", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "o0gz1E8VtpsI" - }, - "outputs": [], "source": [ "try:\n", " import google.colab\n", @@ -56,26 +58,28 @@ "except ImportError:\n", " print(\"Running the notebook on Visual Studio or JupyterLab\")\n", " IN_COLAB = False" - ] + ], + "metadata": { + "id": "o0gz1E8VtpsI" + }, + "execution_count": null }, { + "id": "aefe67b5", "cell_type": "markdown", - "metadata": { - "id": "D9ms-jTSZUQL" - }, "source": [ "## Installation: MaxText and Post training Dependencies\n", "\n", "**Running the notebook on Visual Studio or JupyterLab:** Before proceeding, create a virtual environment and install the required post-training dependencies by following `Option 3: Installing [tpu-post-train]` in the [MaxText installation guide](https://maxtext.readthedocs.io/en/latest/install_maxtext.html#from-source). Once the environment is set up, ensure the notebook is running within it." - ] + ], + "metadata": { + "id": "D9ms-jTSZUQL" + }, + "execution_count": null }, { + "id": "c7a0dd4d", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "bjnwIv1YtpsI" - }, - "outputs": [], "source": [ "if IN_COLAB:\n", " # Clone the MaxText repository\n", @@ -88,13 +92,15 @@ " # Install MaxText and post-training dependencies\n", " !uv pip install -e .[tpu-post-train] --resolution=lowest\n", " !install_tpu_post_train_extra_deps" - ] + ], + "metadata": { + "id": "bjnwIv1YtpsI" + }, + "execution_count": null }, { + "id": "ff25198e", "cell_type": "markdown", - "metadata": { - "id": "OKWBCMrstpsI" - }, "source": [ "**Session restart Instructions for Colab:**\n", "1. Navigate to the menu at the top of the screen.\n", @@ -102,24 +108,26 @@ "3. Select **Restart session** from the dropdown menu.\n", "\n", "You will be asked to confirm the action in a pop-up dialog. Click on **Yes**." - ] + ], + "metadata": { + "id": "OKWBCMrstpsI" + }, + "execution_count": null }, { + "id": "b204c8a8", "cell_type": "markdown", + "source": [ + "## Imports" + ], "metadata": { "id": "Clexf-j7ZUQM" }, - "source": [ - "## Imports" - ] + "execution_count": null }, { + "id": "325e8a0c", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "PkBI9A3JZUQM" - }, - "outputs": [], "source": [ "import jax\n", "import os\n", @@ -143,29 +151,29 @@ "from etils import epath\n", "\n", "print(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")" - ] + ], + "metadata": { + "id": "PkBI9A3JZUQM" + }, + "execution_count": null }, { + "id": "c40de3d5", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "NIiA2OletpsI" - }, - "outputs": [], "source": [ "if not jax.distributed.is_initialized():\n", " jax.distributed.initialize()\n", "print(f\"JAX version: {jax.__version__}\")\n", "print(f\"JAX devices: {jax.devices()}\")" - ] + ], + "metadata": { + "id": "NIiA2OletpsI" + }, + "execution_count": null }, { + "id": "f11c84db", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JBbPN-uVZUQM" - }, - "outputs": [], "source": [ "if IN_COLAB:\n", " from huggingface_hub import notebook_login\n", @@ -173,24 +181,26 @@ "else:\n", " from huggingface_hub import login\n", " login()" - ] + ], + "metadata": { + "id": "JBbPN-uVZUQM" + }, + "execution_count": null }, { + "id": "24c04af4", "cell_type": "markdown", + "source": [ + "## Model Configurations" + ], "metadata": { "id": "aENuzm9iZUQM" }, - "source": [ - "## Model Configurations" - ] + "execution_count": null }, { + "id": "f9538b3d", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "RjPYYl3zZUQM" - }, - "outputs": [], "source": [ "MODEL_NAME = \"qwen3-0.6b\"\n", "TOKENIZER_PATH = \"Qwen/Qwen3-0.6B\"\n", @@ -207,24 +217,26 @@ " print(\"Set MODEL_CHECKPOINT_PATH if you do not wish to download the checkpoint.\")\n", "\n", "RUN_NAME = datetime.now().strftime(\"%Y-%m-%d-%H-%m-%S\")" - ] + ], + "metadata": { + "id": "RjPYYl3zZUQM" + }, + "execution_count": null }, { + "id": "ad5d8e79", "cell_type": "markdown", + "source": [ + "## Download Qwen3-0.6B Model Checkpoint from Hugging Face" + ], "metadata": { "id": "4L37Ij4NZUQM" }, - "source": [ - "## Download Qwen3-0.6B Model Checkpoint from Hugging Face" - ] + "execution_count": null }, { + "id": "54d01667", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "kJanDAc0ZUQM" - }, - "outputs": [], "source": [ "if not epath.Path(MODEL_CHECKPOINT_PATH).exists():\n", " # Install torch for the conversion script\n", @@ -259,24 +271,26 @@ " MODEL_CHECKPOINT_PATH = os.path.join(MODEL_CHECKPOINT_PATH, \"0/items\")\n", "else:\n", " print(f\"Model checkpoint exists at {MODEL_CHECKPOINT_PATH}\")" - ] + ], + "metadata": { + "id": "kJanDAc0ZUQM" + }, + "execution_count": null }, { + "id": "6a708b31", "cell_type": "markdown", + "source": [ + "## Dataset Configurations" + ], "metadata": { "id": "PC-hILG0ZUQM" }, - "source": [ - "## Dataset Configurations" - ] + "execution_count": null }, { + "id": "eb06934e", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "O3MLdr9kZUQM" - }, - "outputs": [], "source": [ "DATASET_NAME = \"openai/gsm8k\"\n", "TRAIN_DATA_SPLIT = \"train\"\n", @@ -290,24 +304,26 @@ "FORMATTING_FUNC_KWARGS = {\"template_path\": f\"{DATA_TEMPLATE_PATH}\"}\n", "NUM_TEST_SAMPLES = 20 # Total number of samples to test\n", "BATCH_SIZE = 1 # Number of test samples to process in a batch" - ] + ], + "metadata": { + "id": "O3MLdr9kZUQM" + }, + "execution_count": null }, { + "id": "67eb3118", "cell_type": "markdown", + "source": [ + "## MaxText Configurations" + ], "metadata": { "id": "yeAHmxSYZUQM" }, - "source": [ - "## MaxText Configurations" - ] + "execution_count": null }, { + "id": "ea8581d7", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "In-jdp1AAwrL" - }, - "outputs": [], "source": [ "%%capture\n", "config = pyconfig.initialize(\n", @@ -323,7 +339,7 @@ " f\"train_split={TRAIN_DATA_SPLIT}\",\n", " f\"hf_data_dir={HF_DATA_DIR}\",\n", " f\"train_data_columns={TRAIN_DATA_COLUMNS}\",\n", - " \"steps=500\",\n", + " f\"steps={os.environ.get('CI_STEPS', '500')}\",\n", " \"per_device_batch_size=1\",\n", " \"max_target_length=1024\",\n", " \"learning_rate=3e-6\",\n", @@ -333,33 +349,37 @@ " f\"formatting_func_kwargs={FORMATTING_FUNC_KWARGS}\",\n", " ]\n", ")" - ] + ], + "metadata": { + "id": "In-jdp1AAwrL" + }, + "execution_count": null }, { + "id": "d1daea36", "cell_type": "markdown", + "source": [ + "## Initial Setup \u0026 Data Preparation" + ], "metadata": { "id": "O9b0GWo-ZUQM" }, - "source": [ - "## Initial Setup & Data Preparation" - ] + "execution_count": null }, { + "id": "a4a90422", "cell_type": "markdown", + "source": [ + "### Create Test Dataset" + ], "metadata": { "id": "TDqFmvUCZUQM" }, - "source": [ - "### Create Test Dataset" - ] + "execution_count": null }, { + "id": "41a6f910", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "wscWYxrtZUQM" - }, - "outputs": [], "source": [ "run_evaluation = os.environ.get(\"RUN_EVALUATION\", \"false\").lower() == \"true\"\n", "if run_evaluation:\n", @@ -372,44 +392,48 @@ " )\n", "else:\n", " print(\"Evaluation on test dataset is skipped. Plese set `RUN_EVALUATION=True` to run evaluation.\")" - ] + ], + "metadata": { + "id": "wscWYxrtZUQM" + }, + "execution_count": null }, { + "id": "22a6a2bd", "cell_type": "markdown", + "source": [ + "### Create SFT Trainer State" + ], "metadata": { "id": "bLSvOOEUZUQM" }, - "source": [ - "### Create SFT Trainer State" - ] + "execution_count": null }, { + "id": "8ab40685", "cell_type": "code", - "execution_count": null, + "source": [ + "trainer, mesh = train_sft.setup_trainer_state(config)" + ], "metadata": { "id": "2IHsC0m6ZUQM" }, - "outputs": [], - "source": [ - "trainer, mesh = train_sft.setup_trainer_state(config)" - ] + "execution_count": null }, { + "id": "06893cc7", "cell_type": "markdown", + "source": [ + "### Create vLLM Rollout" + ], "metadata": { "id": "PpKtEqzFZUQM" }, - "source": [ - "### Create vLLM Rollout" - ] + "execution_count": null }, { + "id": "7501e0ce", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "3-pf_rbqZUQM" - }, - "outputs": [], "source": [ "if run_evaluation:\n", " tunix_model = TunixMaxTextAdapter(trainer.model)\n", @@ -428,39 +452,41 @@ " )\n", "else:\n", " print(\"Evaluation on test dataset is skipped. Plese set `RUN_EVALUATION=True` to run evaluation.\")" - ] + ], + "metadata": { + "id": "3-pf_rbqZUQM" + }, + "execution_count": null }, { + "id": "f8e996d9", "cell_type": "markdown", + "source": [ + "## Evaluation before SFT Training" + ], "metadata": { "id": "567gTxsEZUQM" }, - "source": [ - "## Evaluation before SFT Training" - ] + "execution_count": null }, { + "id": "245f5576", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "OnACa3zCZUQM" - }, - "outputs": [], "source": [ "if run_evaluation:\n", " print(\"Running Pre-SFT Evaluation...\")\n", " score = evaluate_model(test_dataset, vllm_rollout, debug=False)\n", "else:\n", " print(\"Evaluation on test dataset is skipped. Plese set `RUN_EVALUATION=True` to run evaluation.\")" - ] + ], + "metadata": { + "id": "OnACa3zCZUQM" + }, + "execution_count": null }, { + "id": "335fa8e5", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "u5-M4iYkZUQN" - }, - "outputs": [], "source": [ "if run_evaluation:\n", " print(\"========================= Score for PRE-SFT Evaluation =========================\")\n", @@ -473,48 +499,52 @@ " )\n", "else:\n", " print(\"Evaluation on test dataset is skipped. Plese set `RUN_EVALUATION=True` to run evaluation.\")" - ] + ], + "metadata": { + "id": "u5-M4iYkZUQN" + }, + "execution_count": null }, { + "id": "8572fe80", "cell_type": "markdown", + "source": [ + "## SFT Training" + ], "metadata": { "id": "EJE1ookSAzz-" }, - "source": [ - "## SFT Training" - ] + "execution_count": null }, { + "id": "adf59eb8", "cell_type": "code", - "execution_count": null, + "source": [ + "print(\"Starting SFT Training...\")\n", + "trainer = train_sft.train_model(config, trainer, mesh)\n", + "print(\"SFT Training Complete!\")" + ], "metadata": { "editable": true, "id": "mgwpNgQYCJEd", "tags": [] }, - "outputs": [], - "source": [ - "print(\"Starting SFT Training...\")\n", - "trainer = train_sft.train_model(config, trainer, mesh)\n", - "print(\"SFT Training Complete!\")" - ] + "execution_count": null }, { + "id": "80ac1428", "cell_type": "markdown", + "source": [ + "## Evaluation after SFT Training" + ], "metadata": { "id": "WEdNYRhwZUQN" }, - "source": [ - "## Evaluation after SFT Training" - ] + "execution_count": null }, { + "id": "354431bd", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "XcsZacZdZUQN" - }, - "outputs": [], "source": [ "if run_evaluation:\n", " print(\"Running Post-SFT Evaluation...\")\n", @@ -524,17 +554,15 @@ " score = evaluate_model(test_dataset, vllm_rollout, debug=False)\n", "else:\n", " print(\"Evaluation on test dataset is skipped. Plese set `RUN_EVALUATION=True` to run evaluation.\")" - ] + ], + "metadata": { + "id": "XcsZacZdZUQN" + }, + "execution_count": null }, { + "id": "5b9b04a9", "cell_type": "code", - "execution_count": null, - "metadata": { - "editable": true, - "id": "-JtYTPvJZUQN", - "tags": [] - }, - "outputs": [], "source": [ "if run_evaluation:\n", " print(\"========================= Score for POST-SFT Evaluation =========================\")\n", @@ -547,7 +575,13 @@ " )\n", "else:\n", " print(\"Evaluation on test dataset is skipped. Plese set `RUN_EVALUATION=True` to run evaluation.\")" - ] + ], + "metadata": { + "editable": true, + "id": "-JtYTPvJZUQN", + "tags": [] + }, + "execution_count": null } ], "metadata": { @@ -574,6 +608,6 @@ "version": "3.12.11" } }, - "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 0, + "nbformat": 4 }