diff --git a/.github/README.md b/.github/README.md new file mode 120000 index 0000000000..525f9bef8c --- /dev/null +++ b/.github/README.md @@ -0,0 +1 @@ +../jaxpp.README.md \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json index c0d04607f2..6666df3920 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -92,6 +92,6 @@ "quantization=int8", "quantize_kvcache=True" ] - } + }, ] } \ No newline at end of file diff --git a/jaxpp.Dockerfile b/jaxpp.Dockerfile new file mode 100644 index 0000000000..451e51059a --- /dev/null +++ b/jaxpp.Dockerfile @@ -0,0 +1,26 @@ +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +ARG BASE_IMAGE +FROM $BASE_IMAGE AS base +ARG JAX_INSTALL_URL + +COPY requirements.txt /tmp/requirements.txt +RUN uv pip install -U pip && uv pip install --no-cache-dir -U -r /tmp/requirements.txt + +COPY --chown=$USER_UID:$USER_GID . 
maxtext + +RUN uv pip install --no-cache-dir -e '/workdir/jaxpp[dev]' +RUN uv pip install --no-cache-dir -e /workdir/maxtext[cuda_12] --resolution=lowest && \ + if [[ -n "$JAX_INSTALL_URL" ]]; then uv pip install $JAX_INSTALL_URL; fi diff --git a/jaxpp.README.md b/jaxpp.README.md new file mode 100644 index 0000000000..e852cd5614 --- /dev/null +++ b/jaxpp.README.md @@ -0,0 +1,78 @@ +# Overview + +This repository is a fork of [MaxText](https://github.com/AI-Hypercomputer/maxtext) created for training with [JaxPP](https://github.com/NVIDIA/jaxpp). + +# Notable changes + +The changes between this repo and the upstream MaxText are kept minimal in general. +Some of the notable changes are listed below. + +* The `__call__` method of the `Decoder` class in [src/MaxText/layers/decoders.py](src/MaxText/layers/decoders.py) + calls `jaxpp.pipeline_enter_stage` to mark stage boundaries for pipeline parallelism. +* The `maybe_initialize_jax_distributed_system` function in [src/MaxText/max_utils.py](src/MaxText/max_utils.py) + creates `RemoteMpmdMesh` to be used by JaxPP. +* [src/MaxText/train.py](src/MaxText/train.py) contains changes to + * Enable pipeline parallelism for the train step, and + * Mark the pipeline loop in the train step with `jaxpp.treduce`. + +# Docker image + +For ease of use, we provide a docker image with this fork under `/workdir/maxtext`. +The docker image has all the dependencies that are needed to use MaxText with JaxPP installed. + +## Building and Testing Docker Container + +The build process uses the JaxPP base image as a starting point. Follow the instructions at [JaxPP's Building the Base Image](https://github.com/NVIDIA/jaxpp#building-the-base-image) to build the `jaxpp-base` image first. 
+ +### Prerequisites +- Docker installed and configured +- NVIDIA Container Toolkit installed +- JaxPP base image built and available locally + +### Building the Main Image + +After building the base image, you can build the main image: + +```bash +# Check if jaxpp-base image exists +if [ -z "$(docker images -q jaxpp-base)" ]; then + echo "Error: jaxpp-base image not found. Please build it first using the instructions at https://github.com/NVIDIA/jaxpp#building-the-base-image." +else + docker build --force-rm=true \ + -f jaxpp.Dockerfile \ + --build-arg BASE_IMAGE=jaxpp-base \ + -t maxtext-jaxpp . +fi +``` + +### Running Tests + +The container includes several test suites for different models: + +1. **Tiny Llama4 Model Tests**: +```bash +docker run --gpus=all --shm-size=10.24gb --ulimit memlock=-1 --ulimit stack=67108864 \ + -e CUDA_VISIBLE_DEVICES=0 --rm --workdir /workdir/maxtext maxtext-jaxpp \ + "nvidia-smi && CONFIG_FILE=./scripts/llama4_proxy_config.sh bash scripts/test_1gpu_config.sh" +``` + +2. **Tiny Mixtral Model Tests**: +```bash +docker run --gpus=all --shm-size=10.24gb --ulimit memlock=-1 --ulimit stack=67108864 \ + -e CUDA_VISIBLE_DEVICES=0 --rm --workdir /workdir/maxtext maxtext-jaxpp \ + "nvidia-smi && MODEL_CONFIG='model_name=mixtral-8x7b override_model_config=True base_num_decoder_layers=2 base_emb_dim=512 base_mlp_dim=1792' bash scripts/test_1gpu_config.sh" +``` + +3. **Tiny Mistral Model Tests**: +```bash +docker run --gpus=all --shm-size=10.24gb --ulimit memlock=-1 --ulimit stack=67108864 \ + -e CUDA_VISIBLE_DEVICES=0 --rm --workdir /workdir/maxtext maxtext-jaxpp \ + "nvidia-smi && MODEL_CONFIG='model_name=mistral-7b override_model_config=True base_num_decoder_layers=2' bash scripts/test_1gpu_config.sh" +``` + +Note: The tests require GPU access and sufficient GPU memory. + +# Profiling + +Profiling is enabled by default in the 6th step, and the first 7 steps are ignored in the performance statistics. 
+This allows the performance statistics to be collected without the profiling overhead, while still producing profiling data during the benchmark runs. \ No newline at end of file diff --git a/scripts/deepseek3_proxy_config.sh b/scripts/deepseek3_proxy_config.sh new file mode 100644 index 0000000000..eda5d6f234 --- /dev/null +++ b/scripts/deepseek3_proxy_config.sh @@ -0,0 +1,8 @@ +export MODEL_CONFIG=" + model_name=deepseek3-671b + override_model_config=True + base_num_decoder_layers=4 + base_emb_dim=896 + base_mlp_dim=2304 + base_moe_mlp_dim=256 +" diff --git a/scripts/llama3.3_proxy_config.sh b/scripts/llama3.3_proxy_config.sh new file mode 100644 index 0000000000..8efc81e6ff --- /dev/null +++ b/scripts/llama3.3_proxy_config.sh @@ -0,0 +1,5 @@ +export MODEL_CONFIG=" + model_name=llama3.3-70b + override_model_config=True + base_num_decoder_layers=2 +" diff --git a/scripts/llama4_proxy_config.sh b/scripts/llama4_proxy_config.sh new file mode 100644 index 0000000000..139746b2be --- /dev/null +++ b/scripts/llama4_proxy_config.sh @@ -0,0 +1,8 @@ +export MODEL_CONFIG=" + model_name=llama4-17b-16e + override_model_config=True + base_num_decoder_layers=2 + base_emb_dim=640 + base_mlp_dim=2048 + base_moe_mlp_dim=2048 +" diff --git a/scripts/local_mc.sh b/scripts/local_mc.sh new file mode 100644 index 0000000000..187968b293 --- /dev/null +++ b/scripts/local_mc.sh @@ -0,0 +1,25 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +if [ -z "$N_PROCS" ] || [ -z "$N_GPUS" ] || [ -z "$COMMAND" ]; then + echo "N_PROCS, N_GPUS, and COMMAND must be set" + exit 1 +fi + +seq 0 $(($N_PROCS - 1)) | xargs -P $N_PROCS -I {} bash -c ' \ +n_gpus=$2; \ +start=$(({} * n_gpus)); \ +end=$((start + n_gpus - 1)); \ +JAX_COORDINATOR_IP="localhost" JAX_COORDINATOR_PORT=1234 NNODES=$1 NODE_RANK={} \ +CUDA_VISIBLE_DEVICES=$(seq -s, $start $end) $3' _ $N_PROCS $N_GPUS "$COMMAND" diff --git a/scripts/run_local_mc.sh b/scripts/run_local_mc.sh new file mode 100644 index 0000000000..4f0cd40813 --- /dev/null +++ b/scripts/run_local_mc.sh @@ -0,0 +1,18 @@ +export TEST_CONFIG=" + override_model_config=True dataset_type=synthetic steps=10 +" + +export COMMAND="python3 -u -m MaxText.train src/MaxText/configs/base.yml \ + base_output_directory=run_local_mc_outputs \ + run_name=run_$(date +%Y-%m-%d-%H:%M:%S) \ + enable_checkpointing=false \ + async_checkpointing=false \ + dtype=bfloat16 \ + weight_dtype=bfloat16 \ + hardware=gpu \ + $MODEL_CONFIG \ + $TEST_CONFIG \ + $PARALLELISM_CONFIG + $JAXPP_CONFIG" + +bash ./scripts/local_mc.sh diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh new file mode 100644 index 0000000000..2be4249b48 --- /dev/null +++ b/scripts/run_tests.sh @@ -0,0 +1,24 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +bash scripts/test_8gpu_llama4_proxy.sh + +RAY_ADDRESS=local python -m MaxText.train src/MaxText/configs/base.yml run_name=runner_jaxpp_$(date +%Y-%m-%d-%H-%M) base_output_directory=/tmp/log hardware=gpu dataset_type=synthetic model_name=gpt3-52k steps=20 dtype=bfloat16 max_target_length=2048 per_device_batch_size=4 dcn_data_parallelism=1 ici_data_parallelism=2 ici_tensor_parallelism=2 ici_pipeline_parallelism=2 num_pipeline_repeats=1 num_pipeline_microbatches=8 enable_checkpointing=false use_jaxpp=True schedule=interleaved_1f1b + +tests=$(python3 -m pytest --co -q -W ignore::DeprecationWarning tests/train_compile_jaxpp_test.py | awk '/^[[:space:]]*$/{exit} {print}') + +for t in $tests; do + echo $t + python3 -m pytest --log-cli-level=INFO -s "$t" +done diff --git a/scripts/test_1gpu_config.sh b/scripts/test_1gpu_config.sh new file mode 100644 index 0000000000..083e1199ae --- /dev/null +++ b/scripts/test_1gpu_config.sh @@ -0,0 +1,28 @@ +if [ -n "$MODEL_CONFIG" ] && [ -n "$CONFIG_FILE" ]; then + echo "Error: both MODEL_CONFIG and CONFIG_FILE are set" + exit 1 +fi + +if [ -n "$CONFIG_FILE" ]; then + source $CONFIG_FILE +fi + +export N_PROCS=1 +export N_GPUS=1 + +# Run plain JAX config +bash ./scripts/run_local_mc.sh + +# Run JaxPP config +export PARALLELISM_CONFIG="ici_pipeline_parallelism=1" + +export JAXPP_CONFIG=" + scan_layers=False + use_jaxpp=True + schedule=interleaved_1f1b + num_pipeline_microbatches=4 + num_pipeline_repeats=1 + max_target_length=64 +" + +bash ./scripts/run_local_mc.sh diff --git a/scripts/test_8gpu_deepseek3_proxy.sh b/scripts/test_8gpu_deepseek3_proxy.sh new file mode 100644 index 0000000000..5b42422d49 --- /dev/null +++ b/scripts/test_8gpu_deepseek3_proxy.sh @@ -0,0 +1,22 @@ +source scripts/deepseek3_proxy_config.sh + +export PARALLELISM_CONFIG=" + dcn_pipeline_parallelism=1 ici_pipeline_parallelism=2 + ici_data_parallelism=1 + ici_tensor_parallelism=2 + ici_expert_parallelism=2 + ici_fsdp_parallelism=1 +" + +export JAXPP_CONFIG=" + 
scan_layers=False + use_jaxpp=True + schedule=interleaved_1f1b + num_pipeline_microbatches=4 + num_pipeline_repeats=1 +" + +export N_PROCS=2 +export N_GPUS=4 + +bash ./scripts/run_local_mc.sh diff --git a/scripts/test_8gpu_llama3.3_proxy.sh b/scripts/test_8gpu_llama3.3_proxy.sh new file mode 100644 index 0000000000..16eb980c05 --- /dev/null +++ b/scripts/test_8gpu_llama3.3_proxy.sh @@ -0,0 +1,30 @@ +export JAX_USE_SHARDY_PARTITIONER=0 +export JAXPP_ENABLE_LICM=1 +export NVTE_FUSED_ATTN=1 +# --xla_dump_hlo_pass_re=.* +export XLA_FLAGS="--xla_dump_hlo_as_html --xla_dump_hlo_as_text --xla_dump_to='./llama3-hlos-pp2' --xla_gpu_enable_latency_hiding_scheduler=true" +source scripts/llama3.3_proxy_config.sh + +export PARALLELISM_CONFIG=" + dcn_pipeline_parallelism=1 ici_pipeline_parallelism=2 + ici_data_parallelism=1 + ici_context_parallelism=2 + ici_tensor_parallelism=2 + ici_fsdp_parallelism=1 + per_device_batch_size=1 + max_target_length=8192 +" + +export JAXPP_CONFIG=" + scan_layers=False + use_jaxpp=True + schedule=interleaved_1f1b + num_pipeline_microbatches=4 + num_pipeline_repeats=1 + profiler=xplane +" + +export N_PROCS=8 +export N_GPUS=1 + +bash ./scripts/run_local_mc.sh diff --git a/scripts/test_8gpu_llama4_proxy.sh b/scripts/test_8gpu_llama4_proxy.sh new file mode 100644 index 0000000000..73d7c8749a --- /dev/null +++ b/scripts/test_8gpu_llama4_proxy.sh @@ -0,0 +1,22 @@ +source scripts/llama4_proxy_config.sh + +export PARALLELISM_CONFIG=" + dcn_pipeline_parallelism=1 ici_pipeline_parallelism=2 + ici_data_parallelism=1 + ici_tensor_parallelism=2 + ici_expert_parallelism=2 + ici_fsdp_parallelism=1 +" + +export JAXPP_CONFIG=" + scan_layers=False + use_jaxpp=True + schedule=interleaved_1f1b + num_pipeline_microbatches=4 + num_pipeline_repeats=1 +" + +export N_PROCS=2 +export N_GPUS=4 + +bash ./scripts/run_local_mc.sh diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index c204c65ea7..f75847c94f 100644 --- a/src/MaxText/configs/base.yml +++ 
b/src/MaxText/configs/base.yml @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -88,7 +89,7 @@ gcs_metrics: False save_config_to_gcs: False # Gradient dtype -grad_dtype: "float32" +grad_dtype: "bfloat16" # Activation dtypes. dtype: "bfloat16" @@ -170,8 +171,8 @@ mtp_eval_target_module: 0 # mixture of experts (moe) num_experts: 1 num_experts_per_tok: 1 -megablox: True -sparse_matmul: True +megablox: False # Only used when sparse_matmul=True +sparse_matmul: False capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default load_balance_loss_weight: 0.01 # weight for the load balance loss use_random_routing: False # whether to use random routing for debug/test purpose @@ -221,7 +222,7 @@ inhomogeneous_layer_cycle_interval: 1 # but a smaller size per microbatch which may hurt per-stage performance. Additionally, note when microbatches > num_stages we have the opportunity to # perform the circular transfer (last stage to first) asynchronously. # The bubble fraction is (num_stages - 1) / (num_pipeline_repeats * num_pipeline_microbatches + num_stages - 1) -num_layers_per_pipeline_stage: 1 +num_layers_per_pipeline_stage: 1 # NOTE(jaxpp) unused in JaxPP # The number of repeats will be set to num_decoder_layers / (num_pipeline_stages * num_layers_per_pipeline_stage) num_pipeline_repeats: -1 pipeline_parallel_layers: -1 # Pipeline only this number of layers - for the remaining layers the "stage" mesh axes will act like data parallelism. 
@@ -264,7 +265,7 @@ set_remat_policy_on_layers_per_stage: False # Choose 'remat_policy' between 'minimal_with_context', 'minimal', 'save_dot_with_context_except_mlp', 'save_dot_except_mlpwi', 'save_dot_except_mlp', # 'save_qkv_proj', 'qkv_proj_offloaded', 'custom', 'minimal_offloaded', 'save_out_proj' and 'full'. # These options offer a trade-off between speed (fastest to slowest) and HBM usage (highest to lowest) -remat_policy: 'full' +remat_policy: 'minimal' # If "custom" remat_policy is chosen, you can select tensors from the following list to offload on host memory, rematerialize or save on device memory. # Pick one of these options for following tensors: ['remat','device','offload'] decoder_layer_input: 'device' # this tensor cannot be rematerialized - it serves as periodic checkpoints that act as the remat start points @@ -281,12 +282,12 @@ out_proj: 'remat' optimizer_memory_host_offload: False parameter_memory_host_offload: False -scan_layers: True # We recommend setting this to false when using pipeline parallelism, instead scanning the PP iterations. -param_scan_axis: 1 +scan_layers: False # We recommend setting this to false when using pipeline parallelism, instead scanning the PP iterations. +param_scan_axis: 0 # NOTE(jaxpp) we set to 0 instead of 1 to avoid flax transposes # The attention parameter dictates the specific algorithm/methodology used to compute the attention scores # The attention_type parameter determines the variants of attention, e.g. 
global or local_sliding -attention: 'autoselected' # Supported attention: autoselected, dot_product, flash, cudnn_flash_te +attention: 'cudnn_flash_te' # Supported attention: autoselected, dot_product, flash, cudnn_flash_te attention_type: 'global' # Supported attention_type: global, local_sliding, chunk, mla attention_bias: False # If True, adds a learnable bias to the query, key, and value projections attention_sink: False @@ -361,7 +362,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' shard_mode: "auto" # can be either auto or explicit mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'] logical_axis_rules: [ - ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_batch', ['fsdp', 'fsdp_transpose', 'expert']], ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], @@ -381,8 +382,8 @@ logical_axis_rules: [ ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_kv_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], + ['activation_kv_batch', ['fsdp', 'fsdp_transpose', 'expert']], + ['activation_kv_batch_no_exp', ['fsdp', 'fsdp_transpose']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose']], @@ -460,7 +461,7 @@ dcn_pipeline_parallelism: 1 
dcn_expert_parallelism: 1 dcn_autoregressive_parallelism: 1 # never recommended ici_data_parallelism: 1 -ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded +ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_fsdp_transpose_parallelism: 1 ici_sequence_parallelism: 1 ici_context_parallelism: 1 @@ -620,10 +621,10 @@ autoregressive_decode_assert: "" # e.g. nsys profile -s none --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop {training command} profiler: "" # Supported profiler: '', xplane, nsys # If set to true, upload all profiler results from all hosts. Otherwise, only upload the profiler result from the first host. -upload_all_profiler_results: False +upload_all_profiler_results: True # Skip first n steps for profiling, to omit things like compilation and to give # the iteration time a chance to stabilize. -skip_first_n_steps_for_profiler: 1 +skip_first_n_steps_for_profiler: 5 # Profile for a small number of steps to avoid a large profile file size. profiler_steps: 5 hide_profiler_step_metric: False @@ -716,6 +717,12 @@ eval_interval: -1 # the specific number of train step between eval_step eval_steps: -1 # run this number of steps for eval, recommend setting this to prevent error due to running out of evel data target_eval_loss: 0. 
# early stop once reaching target eval_loss +# NOTE(jaxpp): begin parameters +use_jaxpp: False +schedule: "eager_1f1b" +fuse_steady_state: False +# NOTE(jaxpp): end parameters + # Goodput parameters enable_goodput_recording: False monitor_goodput: False @@ -800,7 +807,7 @@ allow_split_physical_axes: False # Apply transformations to the mesh to optimize for TPU v6e optimize_mesh_for_tpu_v6e: False -shardy: True # Whether to use shardy XLA backend (default in Jax starting 0.7.0), or GSPMD (to be fully deprecated ~2026) +shardy: False # Whether to use shardy XLA backend (default in Jax starting 0.7.0), or GSPMD (to be fully deprecated ~2026) use_ragged_attention: False ragged_block_size: 256 @@ -910,6 +917,7 @@ subslice_shape: "" # NNX enable_nnx: false +shard_optimizer_over_data: True ################################## Qwen3-Next Specific Configs ################################## # Kernel size for the 1D convolution in the Gated Delta Net diff --git a/src/MaxText/configs/models/gpt3-175b.yml b/src/MaxText/configs/models/gpt3-175b.yml index 5e24cf4268..af515e66b2 100644 --- a/src/MaxText/configs/models/gpt3-175b.yml +++ b/src/MaxText/configs/models/gpt3-175b.yml @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -28,13 +29,13 @@ logits_via_embedding: True normalize_embedding_logits: False logits_dot_in_fp32: False normalization_layer_epsilon: 1.e-05 -use_iota_embed: True +use_iota_embed: False fused_qkv: True opt_type: "adam_pax" decoder_block: "gpt3" -dataset_path: "gs://mlperf-llm-public2" -dataset_name: "c4/en:3.0.4" -eval_dataset_name: "c4/en:3.0.5" +#dataset_path: "gs://mlperf-llm-public2" +#dataset_name: "c4/en:3.0.4" +#eval_dataset_name: "c4/en:3.0.5" gradient_clipping_threshold: 1. 
adam_b1: 0.9 adam_b2: 0.95 diff --git a/src/MaxText/configs/models/gpt3-52k.yml b/src/MaxText/configs/models/gpt3-52k.yml index 5513663f82..d41f94076c 100644 --- a/src/MaxText/configs/models/gpt3-52k.yml +++ b/src/MaxText/configs/models/gpt3-52k.yml @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,7 +19,7 @@ base_emb_dim: 16 base_num_query_heads: 2 base_num_kv_heads: 2 base_mlp_dim: 64 -base_num_decoder_layers: 1 +base_num_decoder_layers: 8 head_dim: 8 trainable_position_size: 2048 mlp_activations: ["gelu"] @@ -28,7 +29,7 @@ logits_via_embedding: True normalize_embedding_logits: False logits_dot_in_fp32: False normalization_layer_epsilon: 1.e-05 -use_iota_embed: True +use_iota_embed: False fused_qkv: True opt_type: "adam_pax" decoder_block: "gpt3" diff --git a/src/MaxText/configs/models/llama2-70b.yml b/src/MaxText/configs/models/llama2-70b.yml index 67dd87f68f..ceecd6819d 100644 --- a/src/MaxText/configs/models/llama2-70b.yml +++ b/src/MaxText/configs/models/llama2-70b.yml @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-# model config for llama2-7b +# model config for llama2-70b base_emb_dim: 8192 base_num_query_heads: 64 @@ -25,4 +26,3 @@ vocab_size: 32000 logits_via_embedding: False normalization_layer_epsilon: 1.0e-5 decoder_block: "llama2" -logical_axis_rules: [['norm', 'fsdp']] diff --git a/src/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py b/src/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py index 361cd3ea75..bd1179423f 100644 --- a/src/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py +++ b/src/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -323,7 +324,7 @@ def make_c4_mlperf_train_iterator( """Make train iterator of customized C4 dataset for mlperf gpt3 training.""" train_ds = get_dataset( dataset_name=config.dataset_name, - split="train2", + split="train" if config.dataset_name == "c4/en:3.1.0" else "train2", dataloading_host_index=process_indices.index(jax.process_index()), dataloading_host_count=len(process_indices), enable_data_shuffling=config.enable_data_shuffling, @@ -334,10 +335,14 @@ def make_c4_mlperf_train_iterator( sp_tokenizer = get_tokenizer( config.tokenizer_path, config.tokenizer_type, config.add_bos, config.add_eos, config.hf_access_token ) + # A hack that (global_batch_size_to_load * num_process) when using jaxpp + # because preprocess_train_dataset divides batch sizes with num_process + # natively, and jaxpp uses only one process + global_batch_size_to_load = config.global_batch_size_to_load if not config.use_jaxpp else config.global_batch_size_to_load * jax.process_count() train_ds = preprocess_train_dataset( train_ds, sp_tokenizer=sp_tokenizer, - train_global_batch_size_to_load=config.global_batch_size_to_load, + 
train_global_batch_size_to_load=global_batch_size_to_load, max_target_length=config.max_target_length, shuffle_buffer_size=128, data_shuffle_seed=config.data_shuffle_seed, diff --git a/src/MaxText/input_pipeline/input_pipeline_interface.py b/src/MaxText/input_pipeline/input_pipeline_interface.py index 27b105bb21..4cd3a660c0 100644 --- a/src/MaxText/input_pipeline/input_pipeline_interface.py +++ b/src/MaxText/input_pipeline/input_pipeline_interface.py @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -59,7 +60,7 @@ def create_data_iterator(config: pyconfig.HyperParameters, mesh): # Return synthetic dataset if selected if config.dataset_type == "synthetic": - return SyntheticDataIterator(config, mesh), None + return SyntheticDataIterator(config, mesh), SyntheticDataIterator(config, mesh) dataset_type_to_train_eval_iterator = { "tfds": (make_tfds_train_iterator, make_tfds_eval_iterator), "grain": (make_grain_train_iterator, make_grain_eval_iterator), @@ -94,7 +95,7 @@ def create_data_iterator(config: pyconfig.HyperParameters, mesh): # Generate output eval iterator output_eval_iterator = None - if config.eval_interval > 0: + if config.eval_interval > 0 and not config.use_jaxpp: process_indices_eval = get_process_loading_real_data( config.data_sharding, config.global_batch_size_to_load_eval, diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index aba50a8d66..74c292c015 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -59,6 +60,9 @@ simple_layer, ) +import jaxpp +from packaging.version import Version + # ------------------------------------------------------------------------------ # The network: Decoder Definitions # ------------------------------------------------------------------------------ @@ -310,6 +314,8 @@ def get_remat_policy(self): elif cfg.remat_policy == "minimal": # save all except context policy = self.minimal_policy() + elif cfg.remat_policy == "save_dot_only": + policy = jax.checkpoint_policies.checkpoint_dots elif cfg.remat_policy == "save_dot_with_context_except_mlp": policy = jax.checkpoint_policies.save_only_these_names( "query_proj", @@ -726,7 +732,7 @@ def __call__( deterministic, model_mode, ) - if cfg.using_pipeline_parallelism: + if cfg.using_pipeline_parallelism and not cfg.use_jaxpp: if cfg.pipeline_fsdp_ag_once: partition_spec = self.pipeline_module.get_weight_sharding( y, decoder_segment_ids, decoder_positions, deterministic, model_mode @@ -779,6 +785,7 @@ def __call__( )(y, *broadcast_args) else: if cfg.scan_layers: + assert not cfg.use_jaxpp, "Layer scanning is not supported with JaxPP" if cfg.decoder_block == DecoderBlockType.DEEPSEEK: assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." 
layer_call_kwargs = { @@ -841,11 +848,26 @@ def __call__( **layer_kwargs, )(y, *broadcast_args) else: + num_logical_stages = 1 + layers_per_stage = cfg.num_decoder_layers + cutoffs = [cfg.num_decoder_layers] + if cfg.use_jaxpp: + num_logical_stages = cfg.dcn_pipeline_parallelism * cfg.ici_pipeline_parallelism * cfg.num_pipeline_repeats + layers_per_stage, rem = divmod(cfg.num_decoder_layers, num_logical_stages) + assert layers_per_stage > 0, (cfg.num_decoder_layers, num_logical_stages) + cutoffs = [] + tot = 0 + for stage in range(num_logical_stages): + num_layers_in_pipeline_stage = layers_per_stage + (1 if stage < rem else 0) + tot += num_layers_in_pipeline_stage + cutoffs.append(tot - 1) + + stage_id = 0 + add_last_enter_stage = Version(jaxpp.__version__) > Version("0.6.1") if cfg.decoder_block == DecoderBlockType.DEEPSEEK: assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek." - dense_layer = RemattedBlockLayers[0] - moe_layer = RemattedBlockLayers[1] +<<<<<<< HEAD layers = [dense_layer, moe_layer] layer_prefixes = ["dense_layers", "moe_layers"] num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers @@ -870,6 +892,40 @@ def __call__( ) if kv_caches is not None and kv_cache is not None: kv_caches[index] = kv_cache +======= + for index in range(cfg.first_num_dense_layers): + dense_layer = self.decoder_layer[0] if stage_id != num_logical_stages - 1 else RemattedBlockLayers[0] + y = dense_layer(config=cfg, mesh=mesh, name=f"dense_layers_{index}", quant=self.quant, model_mode=model_mode)( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=previous_chunk, + page_state=page_state, + slot=slot, + ) + if index != cfg.num_decoder_layers - 1 and cutoffs[stage_id] == index: + y = jaxpp.api.pipeline_enter_stage(y, f"stage_{stage_id}") + stage_id += 1 + + for index in range(cfg.first_num_dense_layers, cfg.num_decoder_layers): + moe_layer = RemattedBlockLayers[1] if stage_id 
!= num_logical_stages - 1 else self.decoder_layer[1] + y = moe_layer(config=cfg, mesh=mesh, name=f"moe_layers_{index - cfg.first_num_dense_layers}", quant=self.quant, model_mode=model_mode)( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=previous_chunk, + page_state=page_state, + slot=slot, + ) + if index != cfg.num_decoder_layers - 1 and cutoffs[stage_id] == index: + y = jaxpp.api.pipeline_enter_stage(y, f"stage_{stage_id}") + stage_id += 1 + +>>>>>>> jaxpp/main else: for lyr in range(cfg.num_decoder_layers): RemattedBlockLayer = RemattedBlockLayers[0] @@ -888,7 +944,8 @@ def __call__( layer_kwargs = {"layer_idx": lyr} if cfg.decoder_block == DecoderBlockType.GPT_OSS: layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)} - layer = RemattedBlockLayer( + layer_ctor = RemattedBlockLayer if stage_id != num_logical_stages - 1 else self.decoder_layer[0] + layer = layer_ctor( config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs ) kv_cache = kv_caches[lyr] if kv_caches is not None else None @@ -905,8 +962,14 @@ def __call__( attention_metadata=attention_metadata, **layer_call_kwargs, ) +<<<<<<< HEAD if kv_caches is not None and kv_cache is not None: kv_caches[lyr] = kv_cache +======= + if lyr != cfg.num_decoder_layers - 1 and cutoffs[stage_id] == lyr: + y = jaxpp.api.pipeline_enter_stage(y, f"stage_{stage_id}") + stage_id += 1 +>>>>>>> jaxpp/main assert isinstance(y, jax.Array) @@ -921,6 +984,9 @@ def __call__( else: logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode) + if add_last_enter_stage: + logits = jaxpp.api.pipeline_enter_stage(logits, f"stage_{stage_id}") + # The API of the Decoder is now a tuple, providing both the main output # and the raw hidden state needed for auxiliary tasks. 
return logits, hidden_state, kv_caches diff --git a/src/MaxText/max_utils.py b/src/MaxText/max_utils.py index a8e0897031..5880a2c824 100644 --- a/src/MaxText/max_utils.py +++ b/src/MaxText/max_utils.py @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -43,6 +44,13 @@ from MaxText.gcloud_stub import is_decoupled from MaxText.common_types import MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN +import tensorflow as tf +import jaxpp.api as jaxpp +import optax + +# jaxpp related imports +import jaxpp.api as jaxpp + initialize_multi_tier_checkpointing = initialization.initialize_multi_tier_checkpointing HYBRID_RING_64X4 = "hybrid_ring_64x4" HYBRID_RING_32X8 = "hybrid_ring_32x8" @@ -50,6 +58,12 @@ # pylint: disable=too-many-positional-arguments +def maybe_unwrap(a: jaxpp.MpmdArray | jax.Array): + if isinstance(a, jaxpp.MpmdArray): + return v if (v := a.first_mpmd_replica) is not None else 0 + return a + + def with_memory_kind(t, memory_kind): return jax.tree_util.tree_map(lambda x: x.with_memory_kind(kind=memory_kind), t) @@ -69,7 +83,8 @@ def finder(x): def l2norm_pytree(x): """L2 norm of a pytree of arrays.""" - return jnp.sqrt(jax.tree_util.tree_reduce(lambda x, y: x + jnp.sum(jnp.square(y)), x, initializer=0.0)) + per_param_sum = [jnp.sum(jnp.square(x)) for x in jax.tree.leaves(x)] + return jnp.sqrt(jaxpp.cross_mpmd_all_reduce(*(e.astype(jnp.float32) for e in per_param_sum))) def calculate_num_params_from_pytree(params): @@ -635,6 +650,23 @@ def _cross_entropy_with_logits_bwd( cross_entropy_with_logits.defvjp(_cross_entropy_with_logits_fwd, _cross_entropy_with_logits_bwd) +def maybe_update_params_sharding_with_opt(config, state_mesh_shardings): + prev_params_shardings = state_mesh_shardings.params + if config.shard_optimizer_over_data: + 
if isinstance(state_mesh_shardings.opt_state, optax.ScaleByAdamState): + sharded_fp32_params = state_mesh_shardings.opt_state.mu + elif isinstance(state_mesh_shardings.opt_state, tuple) and isinstance(state_mesh_shardings.opt_state[0], optax.ScaleByAdamState): + sharded_fp32_params = state_mesh_shardings.opt_state[0].mu + else: + raise NotImplementedError(f"Could not find optimizer state shardings from optimizer of type {type(state_mesh_shardings.opt_state)}") + if "params" not in sharded_fp32_params.keys(): + # When quantization=fp8 is enabled the sharded_fp32_params + # are not wrapped in `params`. Here we wrap them back. + sharded_fp32_params = {"params": sharded_fp32_params} + state_mesh_shardings = state_mesh_shardings.replace(params=dict(prev_params_shardings, **sharded_fp32_params)) + return prev_params_shardings, state_mesh_shardings + + def print_pytree_shape(print_str, ptree): print("\n") print(print_str) @@ -701,7 +733,8 @@ def print_mem_stats(label: str): stats = d.memory_stats() used = round(stats["bytes_in_use"] / 2**30, 2) limit = round(stats["bytes_limit"] / 2**30, 2) - max_logging.log(f"\tUsing (GB) {used} / {limit} ({used/limit:%}) on {d}") + peak_size = round(stats["peak_bytes_in_use"] / 2**30, 2) + max_logging.log(f"\tUsing (GB) {used} / {limit} ({used/limit:%}) ({peak_size=} GiB) on {d}") except (RuntimeError, KeyError, TypeError) as ex: max_logging.log(f"\tMemstats unavailable, error: {ex}") diff --git a/src/MaxText/maxtext_utils.py b/src/MaxText/maxtext_utils.py index 30abfbdf92..24ff5ca5a0 100644 --- a/src/MaxText/maxtext_utils.py +++ b/src/MaxText/maxtext_utils.py @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -43,6 +44,10 @@ from MaxText.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE from MaxText.inference.page_manager import PageState +import chex +from optax._src import base +import jaxpp.api as jaxpp + OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" @@ -668,7 +673,6 @@ def calculate_prefill_tflops_per_device(num_model_parameters, prefill_length, co ) return total_tflops, learnable_weight_tflops, causal_attention_tflops - def apply_gradient_clipping(raw_grads, state, clipping_threshold): """Applies gradient clipping to raw gradients, with special handing for FLAX fp8 stats. @@ -824,7 +828,7 @@ def setup_initial_state( tx, config, rng, - mesh, + maybe_mpmd_mesh, checkpoint_manager, is_training=True, ): @@ -845,12 +849,19 @@ def setup_initial_state( state_mesh_annotations: the mesh annotations for the train state """ + mesh = maybe_mpmd_mesh + if isinstance(maybe_mpmd_mesh, jaxpp.MpmdMesh): + mesh = maybe_mpmd_mesh.lowering_mesh() + unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = get_abstract_state( model, tx, config, rng, mesh, is_training ) # Initialization with nn_partitioning.axis_rules(config.logical_axis_rules): + if checkpoint_manager is not None: + assert not config.use_jaxpp + restored, raw_params = checkpointing.load_state_if_possible( checkpoint_manager, data_iterator, @@ -883,20 +894,91 @@ def setup_initial_state( else: init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training) init_state_partial.__name__ = "initialize_state" - # pylint: disable=not-callable - state = jax.jit( - init_state_partial, - in_shardings=None, - out_shardings=state_mesh_shardings, - )(rng) - if raw_params: # If we loaded a partial state, we need to merge it. 
- state = state.replace(params=raw_params) + if config.use_jaxpp: + # First infer placement based on loop usage + # Imported here to avoid circular import errors + from MaxText import maxtext_utils + from MaxText.train import train_step + params_shardings, _state_mesh_shardings = max_utils.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) + data_sharding = maxtext_utils.get_input_data_sharding(config, mesh) + ( + functional_train, + in_shard_train, + out_shard_train, + static_argnums_train, + donate_argnums_train, + ) = maxtext_utils.get_functional_train_with_signature(train_step, data_sharding, _state_mesh_shardings, model, config, params_shardings=params_shardings) + + p_train_step = jaxpp.mpmd_jit_with_loop( + functional_train, + mpmd_mesh=maybe_mpmd_mesh, + donate_argnums=donate_argnums_train, + in_shardings=in_shard_train, + out_shardings=out_shard_train, + ) + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + global_mpmd_train_step = p_train_step.trace_and_place( + unboxed_abstract_state, next(data_iterator), rng + ) + unboxed_abstract_state_placements = global_mpmd_train_step.in_shardings[0][0] + def attach_right_mesh(shaped: jax.ShapeDtypeStruct, dist_sharding): + sharding: jax.sharding.NamedSharding = shaped.sharding + jax_mesh = maybe_mpmd_mesh.mpmd_submesh(sorted(dist_sharding.mesh_ids)).jax_mesh + mpmd_sharding = jax.sharding.NamedSharding(jax_mesh, sharding.spec) + return jax.ShapeDtypeStruct(shaped.shape, shaped.dtype, sharding=mpmd_sharding, weak_type=shaped.weak_type) + + unboxed_mpmd_abstract_state = jax.tree.map(attach_right_mesh, unboxed_abstract_state, unboxed_abstract_state_placements) + replicated_sharding = jax.sharding.NamedSharding( + maybe_mpmd_mesh.lowering_mesh(), jax.sharding.PartitionSpec() + ) + state = jaxpp.mpmd_jit_rev( + lambda rng: jax.tree.map(jax._src.numpy.lax_numpy._array_copy, max_utils.unbox_logicallypartioned(init_state_partial(rng))), + out_refs=jax.tree.map(lambda s: s.mesh_ids, 
unboxed_abstract_state_placements), + mpmd_mesh=maybe_mpmd_mesh, + in_shardings=replicated_sharding, + out_shardings=in_shard_train[0], + )(rng) + else: + # pylint: disable=not-callable + state = jax.jit( + init_state_partial, + in_shardings=None, + out_shardings=state_mesh_shardings, + )(rng) + if raw_params: # If we loaded a partial state, we need to merge it. + state = state.replace(params=raw_params) - state = max_utils.unbox_logicallypartioned(state) + state = max_utils.unbox_logicallypartioned(state) return state, state_mesh_annotations, state_mesh_shardings, data_iterator +def add_data_to_sharding(mesh, path, aval, sharding): + if not isinstance(sharding, jax.sharding.NamedSharding): + raise AssertionError(f"Expected NamedSharding, found {sharding} of {type(sharding)=} at {jax.tree_util.keystr(path)}") + try: + sharded_shape = sharding.shard_shape(aval.shape) + except Exception as e: + raise AssertionError(f"Could not shard value {jax.tree_util.keystr(path)} of shape={aval.shape} with {sharding=}") from e + pspec = sharding.spec + + if 'data' in jax.tree.leaves(pspec): + return sharding + + for idx, (size, partition) in enumerate(zip(sharded_shape, pspec)): + if partition is None: + partition = () + + if isinstance(partition, str): + partition = (partition,) + + if size % mesh.shape['data'] == 0 and (partition is None or 'tensor' not in partition): + added_component = ('data',) + partition + new_pspec = jax.sharding.PartitionSpec(*(pspec[:idx] + (added_component,) + pspec[idx+1:])) + return sharding.update(spec=new_pspec) + return sharding + + def get_abstract_state(model, tx, config, rng, mesh, is_training=True): """Get a shaped abstraction of the state (including optimizer)""" init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training, rng) diff --git a/src/MaxText/metric_logger.py b/src/MaxText/metric_logger.py index 2a840df211..4d02c6a827 100644 --- a/src/MaxText/metric_logger.py +++ b/src/MaxText/metric_logger.py @@ -92,6 
+92,7 @@ def reset_eval_metrics(self): def write_metrics(self, metrics, step, is_training=True): """Entry point for all metrics writing in Train's Main.""" if metrics: + metrics = jax.tree.map(max_utils.maybe_unwrap, metrics) self.log_metrics(metrics, step, is_training) if self.config.enable_tensorboard: @@ -293,21 +294,21 @@ def record_eval_metrics(self, step, metrics=None, eval_step_count=None): """Records eval metrics and writes the metrics to GCS and/or to TensorBoard.""" if metrics: self.cumulative_eval_metrics["scalar"]["eval/total_loss"] += float( - metrics["scalar"].get("evaluation/total_loss", 0.0) + max_utils.maybe_unwrap(metrics["scalar"].get("evaluation/total_loss", 0.0)) ) self.cumulative_eval_metrics["scalar"]["eval/total_weights"] += float( - metrics["scalar"].get("evaluation/total_weights", 0.0) + max_utils.maybe_unwrap(metrics["scalar"].get("evaluation/total_weights", 0.0)) ) self.cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] += float( - metrics["scalar"].get("evaluation/moe_lb_loss", 0.0) + max_utils.maybe_unwrap(metrics["scalar"].get("evaluation/moe_lb_loss", 0.0)) ) - self.cumulative_eval_metrics["scalar"]["eval/mtp_loss"] += float(metrics["scalar"].get("evaluation/mtp_loss", 0.0)) + self.cumulative_eval_metrics["scalar"]["eval/mtp_loss"] += float(max_utils.maybe_unwrap(metrics["scalar"].get("evaluation/mtp_loss", 0.0))) self.cumulative_eval_metrics["scalar"]["eval/mtp_acceptance_rate_percent"] += float( - metrics["scalar"].get("evaluation/mtp_acceptance_rate_percent", 0.0) + max_utils.maybe_unwrap(metrics["scalar"].get("evaluation/mtp_acceptance_rate_percent", 0.0)) ) if self.config.use_dpo: self.cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] += float( - metrics["scalar"].get("evaluation/dpo_reward_accuracy", 0.0) + max_utils.maybe_unwrap(metrics["scalar"].get("evaluation/dpo_reward_accuracy", 0.0)) ) if eval_step_count: diff --git a/src/MaxText/model_creation_utils.py b/src/MaxText/model_creation_utils.py index 
18ca85c2df..9f52bbc411 100644 --- a/src/MaxText/model_creation_utils.py +++ b/src/MaxText/model_creation_utils.py @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,6 +32,9 @@ from functools import partial from etils import epath +# jaxpp +import jaxpp.api as jaxpp + @overload def from_config( @@ -80,16 +84,20 @@ def from_config( """ devices_array = maxtext_utils.create_device_mesh(config, devices) - if mesh is None: + if not config.use_jaxpp: if config.shard_mode == ShardMode.EXPLICIT: axis_types = tuple([AxisType.Explicit] * len(config.mesh_axes)) else: axis_types = tuple([AxisType.Auto] * len(config.mesh_axes)) - mesh = Mesh(devices_array, config.mesh_axes, axis_types=axis_types) + else: + mesh = jaxpp.MpmdMesh(Mesh(devices_array, config.mesh_axes), 'stage') + model = create_model(config, mesh.lowering_mesh() if config.use_jaxpp else mesh, model_mode=model_mode, rngs=rngs) - model = create_model(config, mesh, model_mode=model_mode, rngs=rngs) - + if config.use_jaxpp: + # At this point, model.mesh has mesh.lowering_mesh() as its value, but we need to set it to the original mesh + # so that the caller can have access to the original mesh. 
+ model.mesh = mesh # Return only the model return model diff --git a/src/MaxText/profiler.py b/src/MaxText/profiler.py index e32e49ff2a..1131a4342c 100644 --- a/src/MaxText/profiler.py +++ b/src/MaxText/profiler.py @@ -40,8 +40,9 @@ def __init__(self, config, offset_step=0): self.finished_initial_profile_step = self._set_last_profiler_step(config.profiler_steps, config.steps) if config.profiler != "" and self.start_initial_profile_step >= config.steps: raise ValueError("Profiling requested but initial profiling step set past training final step") + self.use_jaxpp = config.use_jaxpp - def maybe_activate_profiler(self, step, state): + def maybe_activate_profiler(self, step, state, maybe_mpmd_mesh=None, profiling_process_ids=None): """Conditionally activates the profiler based on the current step. This method checks if the current training step matches the step designated for starting an initial profile, or if it meets the criteria for @@ -49,14 +50,21 @@ def maybe_activate_profiler(self, step, state): """ if self.mode != "" and (step == self.start_initial_profile_step or self.should_activate_periodic_profile(step)): optional_postfix = f"step_{step}" if self.profile_period > 0 else "" - self.activate(blocking_object=state, optional_postfix=optional_postfix) - - def activate(self, blocking_object=None, optional_postfix=""): + if self.use_jaxpp: + assert maybe_mpmd_mesh is not None + assert profiling_process_ids is not None + if maybe_mpmd_mesh.jax_mesh.is_multi_process and jax.process_index() in profiling_process_ids: + optional_postfix = f"mpmd_{maybe_mpmd_mesh.my_mpmd_axis_index:02}_gpu_{profiling_process_ids[jax.process_index()].id:06}_{optional_postfix}" + optional_postfix = f"proc_{jax.process_index():06}_{optional_postfix}" + profile = profiling_process_ids is None or jax.process_index() in profiling_process_ids + self.activate(blocking_object=state, optional_postfix=optional_postfix, profile=profile) + + def activate(self, blocking_object=None, 
optional_postfix="", profile=True): """Start the profiler. nsys profiler becomes no-op when libcudart.so is not available on the system.""" if self.profile_cleanly and blocking_object is not None: jax.block_until_ready(blocking_object) - if not (self.upload_all_profiler_results or jax.process_index() == 0): + if not (self.upload_all_profiler_results or jax.process_index() == 0) or not profile: return if self.mode != "": self.output_path = os.path.join(self.base_output_dir, optional_postfix) @@ -70,21 +78,22 @@ def activate(self, blocking_object=None, optional_postfix=""): elif self.mode == "xplane": jax.profiler.start_trace(self.output_path) - def maybe_deactivate_profiler(self, step, state): + def maybe_deactivate_profiler(self, step, state, profiling_process_ids=None): """Conditionally deactivates the profiler based on the current step. This method checks if the current training step matches the step designated for finishing the initial profile, or if it meets the criteria for deactivating a periodic profile. """ if self.mode != "" and (step == self.finished_initial_profile_step or self.should_deactivate_periodic_profile(step)): - self.deactivate(blocking_object=state) + profile = profiling_process_ids is None or jax.process_index() in profiling_process_ids + self.deactivate(blocking_object=state, profile=profile) - def deactivate(self, blocking_object=None): + def deactivate(self, blocking_object=None, profile=True): """End the profiler. 
The result is uploaded to the output bucket.""" if self.profile_cleanly and blocking_object is not None: jax.block_until_ready(blocking_object) - if not (self.upload_all_profiler_results or jax.process_index() == 0): + if not (self.upload_all_profiler_results or jax.process_index() == 0) or not profile: return if self.mode == "nsys": if self.libcudart is not None: diff --git a/src/MaxText/pyconfig.py b/src/MaxText/pyconfig.py index 532731a59b..099a955994 100644 --- a/src/MaxText/pyconfig.py +++ b/src/MaxText/pyconfig.py @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -73,13 +74,254 @@ def _load_config(config_name: str) -> omegaconf.DictConfig: loaded_parent_config_filename = os.path.join(os.path.dirname(config_name), base_path) if not os.path.isfile(loaded_parent_config_filename): dir_path = os.path.dirname(os.path.realpath(__file__)) +<<<<<<< HEAD loaded_parent_config_filename = os.path.join(dir_path, "configs", base_path) +======= + file_path = os.path.join(dir_path, "configs", "models", f"{model_name}.yml") + # Use omegaconf.OmegaConf to load the model-specific configuration. 
+ model_vars = omegaconf.OmegaConf.load(file_path) + model_vars = omegaconf.OmegaConf.to_container(model_vars, resolve=True) + if raw_keys["override_model_config"]: + model_vars = {key: value for key, value in model_vars.items() if key not in keys_from_env_and_command_line} + updated_keys = list(model_vars.keys()) + raw_keys = validate_and_update_keys(raw_keys, model_vars, config_name) + return updated_keys + + +def create_parallelisms_list(raw_keys): + ici_parallelism = [ + raw_keys["ici_data_parallelism"], + raw_keys["ici_pipeline_parallelism"], + raw_keys["ici_fsdp_parallelism"], + raw_keys["ici_fsdp_transpose_parallelism"], + raw_keys["ici_sequence_parallelism"], + raw_keys["ici_context_parallelism"], + raw_keys["ici_context_autoregressive_parallelism"], + raw_keys["ici_tensor_parallelism"], + raw_keys["ici_tensor_transpose_parallelism"], + raw_keys["ici_tensor_sequence_parallelism"], + raw_keys["ici_expert_parallelism"], + raw_keys["ici_autoregressive_parallelism"], + ] + dcn_parallelism = [ + raw_keys["dcn_data_parallelism"], + raw_keys["dcn_pipeline_parallelism"], + raw_keys["dcn_fsdp_parallelism"], + raw_keys["dcn_fsdp_transpose_parallelism"], + raw_keys["dcn_sequence_parallelism"], + raw_keys["dcn_context_parallelism"], + raw_keys["dcn_context_autoregressive_parallelism"], + raw_keys["dcn_tensor_parallelism"], + raw_keys["dcn_tensor_transpose_parallelism"], + raw_keys["dcn_tensor_sequence_parallelism"], + raw_keys["dcn_expert_parallelism"], + raw_keys["dcn_autoregressive_parallelism"], + ] + raw_keys["ici_parallelism"] = ici_parallelism + raw_keys["dcn_parallelism"] = dcn_parallelism + return raw_keys + + +def set_mu_dtype(raw_keys): + # Default mu_dtype to weight_dtype if unset + if raw_keys["mu_dtype"]: + assert raw_keys["opt_type"] != "adam_pax", "opt_type adam_pax doesn't support explicitly setting mu_dtype" + + if raw_keys["mu_dtype"] == "": + return raw_keys["weight_dtype"] + else: + return jax.numpy.dtype(raw_keys["mu_dtype"]) + + +def 
validate_and_set_hlo_dump_defaults(raw_keys): + if not raw_keys["dump_hlo"]: + return raw_keys + if os.environ.get("XLA_FLAGS") and raw_keys["dump_hlo_xla_flags"]: + raise ValueError("You must set either XLA_FLAGS or dump_hlo_xla_flags to dump HLO, but not both.") + if not os.environ.get("XLA_FLAGS") and not raw_keys["dump_hlo_xla_flags"]: + raw_keys["dump_hlo_xla_flags"] = f"--xla_dump_to={raw_keys['dump_hlo_local_dir']} --xla_dump_large_constants" + if raw_keys["dump_hlo_local_module_name"]: + raw_keys["dump_hlo_xla_flags"] = ( + f"{raw_keys['dump_hlo_xla_flags']} --xla_dump_hlo_module_re={raw_keys['dump_hlo_local_module_name']}" + ) + if not raw_keys["dump_hlo_gcs_dir"]: + raw_keys["dump_hlo_gcs_dir"] = os.path.join(raw_keys["base_output_directory"], raw_keys["run_name"], "xla_dump") + else: + raw_keys["dump_hlo_gcs_dir"] = gcs_utils.add_trailing_slash(raw_keys["dump_hlo_gcs_dir"]) + if not os.environ.get("XLA_FLAGS"): + os.environ["XLA_FLAGS"] = raw_keys["dump_hlo_xla_flags"] + return raw_keys + + +def validate_multiple_slices(raw_keys): + if ( + math.fabs( + math.prod( + [ + raw_keys["dcn_data_parallelism"], + raw_keys["dcn_pipeline_parallelism"], + raw_keys["dcn_fsdp_parallelism"], + raw_keys["dcn_fsdp_transpose_parallelism"], + raw_keys["dcn_sequence_parallelism"], + raw_keys["dcn_context_parallelism"], + raw_keys["dcn_tensor_parallelism"], + raw_keys["dcn_tensor_sequence_parallelism"], + raw_keys["dcn_expert_parallelism"], + raw_keys["dcn_context_autoregressive_parallelism"], + raw_keys["dcn_autoregressive_parallelism"], + ] + ) + ) + > 1 + ): + assert raw_keys["num_slices"] > 1, "DCN parallelism requested but only one slice available." + + +def set_and_validate_pipeline_config(raw_keys): + if using_pipeline_parallelism(raw_keys): + # For pipeline parallelism, model_fsdp_ag_once should be False, and pipeline_fsdp_ag_once is typically True. 
+ if raw_keys["model_fsdp_ag_once"]: + raise ValueError( + "You should only set pipeline_fsdp_once=True, leave model_fsdp_ag_once=False with pipeline parallelism." + ) + + def modify_activation_embed_and_logits_batch(logical_axis_rules): + for idx, logical_rule in enumerate(logical_axis_rules): + if logical_rule[0] == "activation_embed_and_logits_batch": + # For pipeline parallelism the pre and post decoder layer tensors' batch dimension is sharded by stages. + # Microbatches are sharded by stage, so moving out of and into this sharding should be a local reshape. + # The "stage" needs to be listed first since the microbatch dimension is first before the reshape. + logical_axis_rules[idx] = [ + "activation_embed_and_logits_batch", + ["stage", "data", "fsdp", "fsdp_transpose", "expert"] if not raw_keys["use_jaxpp"] else + ["stage", "fsdp", "fsdp_transpose", "expert"], + ] + break # Exit the loop after modifying the list + return logical_axis_rules + + def pipeline_first_axis(raw_keys): + # We have seen better performance when axes used for DCN are earlier in this list than ICI, see (b/339009148) for details + ici_parallelism = [ + raw_keys["ici_pipeline_parallelism"], + raw_keys["ici_data_parallelism"], + raw_keys["ici_fsdp_parallelism"], + raw_keys["ici_fsdp_transpose_parallelism"], + raw_keys["ici_sequence_parallelism"], + raw_keys["ici_context_parallelism"], + raw_keys["ici_context_autoregressive_parallelism"], + raw_keys["ici_tensor_parallelism"], + raw_keys["ici_tensor_transpose_parallelism"], + raw_keys["ici_tensor_sequence_parallelism"], + raw_keys["ici_expert_parallelism"], + raw_keys["ici_autoregressive_parallelism"], + ] + dcn_parallelism = [ + raw_keys["dcn_pipeline_parallelism"], + raw_keys["dcn_data_parallelism"], + raw_keys["dcn_fsdp_parallelism"], + raw_keys["dcn_fsdp_transpose_parallelism"], + raw_keys["dcn_sequence_parallelism"], + raw_keys["dcn_context_parallelism"], + raw_keys["dcn_context_autoregressive_parallelism"], + 
raw_keys["dcn_tensor_parallelism"], + raw_keys["dcn_tensor_transpose_parallelism"], + raw_keys["dcn_tensor_sequence_parallelism"], + raw_keys["dcn_expert_parallelism"], + raw_keys["dcn_autoregressive_parallelism"], + ] + mesh_axes = [ + "stage", + "data", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive", + ] + data_sharding = [ + [ + "stage", + "data", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive", + ] + ] + + raw_keys["ici_parallelism"] = ici_parallelism + raw_keys["dcn_parallelism"] = dcn_parallelism + raw_keys["mesh_axes"] = mesh_axes + raw_keys["data_sharding"] = data_sharding + return raw_keys + + raw_keys["using_pipeline_parallelism"] = True + raw_keys["logical_axis_rules"] = modify_activation_embed_and_logits_batch(raw_keys["logical_axis_rules"]) + raw_keys = pipeline_first_axis(raw_keys) + num_stages = int(raw_keys["ici_pipeline_parallelism"] * raw_keys["dcn_pipeline_parallelism"]) + if raw_keys["use_jaxpp"]: + assert raw_keys["pipeline_delay_activation_forwarding"] is False + assert raw_keys["num_pipeline_repeats"] >= 1 + assert raw_keys["num_pipeline_microbatches"] >= 1 + return raw_keys + + if raw_keys["pipeline_parallel_layers"] == -1: + if raw_keys["decoder_block"] == "deepseek": + moe_layers = raw_keys["num_decoder_layers"] - raw_keys["first_num_dense_layers"] + raw_keys["pipeline_parallel_layers"] = moe_layers + else: + raw_keys["pipeline_parallel_layers"] = raw_keys["num_decoder_layers"] +>>>>>>> jaxpp/main else: loaded_parent_config_filename = base_path +<<<<<<< HEAD base_cfg = _load_config(loaded_parent_config_filename) cfg = omegaconf.OmegaConf.merge(base_cfg, cfg) return cfg +======= + if raw_keys["num_pipeline_repeats"] == -1: + num_pipeline_repeats, remainder = divmod( + 
raw_keys["pipeline_parallel_layers"], num_stages * raw_keys["num_layers_per_pipeline_stage"] + ) + assert ( + not remainder + ), f"The number of layers per stage ({raw_keys['num_layers_per_pipeline_stage']}) times the number of stages ({num_stages}) must divide the number of pipeline_parallel_layers which defaults to decoder layers ({raw_keys['pipeline_parallel_layers']}) " + raw_keys["num_pipeline_repeats"] = num_pipeline_repeats + assert ( + num_stages * raw_keys["num_pipeline_repeats"] * raw_keys["num_layers_per_pipeline_stage"] + == raw_keys["pipeline_parallel_layers"] + ), f"The product of pipeline stages ({num_stages}), repeats ({raw_keys['num_pipeline_repeats']}), and layers per stage ({raw_keys['num_layers_per_pipeline_stage']}) must be equal to pipeline_parallel_layers which defaults to decoder layers ({raw_keys['pipeline_parallel_layers']})" + if raw_keys["num_pipeline_microbatches"] == -1: + if raw_keys["pipeline_delay_activation_forwarding"]: + raw_keys["num_pipeline_microbatches"] = 2 * num_stages + else: + raw_keys["num_pipeline_microbatches"] = num_stages + assert ( + raw_keys["num_pipeline_microbatches"] % num_stages == 0 or raw_keys["use_jaxpp"] + ), f"The number of microbatches ({raw_keys['num_pipeline_microbatches']}) must be divisible by the number of stages ({num_stages})" + assert ( + raw_keys["micro_batch_size_to_train_on"] % raw_keys["num_pipeline_microbatches"] == 0 + ), f"The batch size ({raw_keys['micro_batch_size_to_train_on']}) must be divisible by the number of microbatches ({raw_keys['num_pipeline_microbatches']})" + if raw_keys["pipeline_delay_activation_forwarding"]: + assert ( + raw_keys["num_pipeline_microbatches"] >= 2 * num_stages + ), f"Delayed activation forwarding requires at least 2 * num_stages microbatches, but {num_stages} stages are used with {raw_keys['num_pipeline_microbatches']} microbatches" + else: + raw_keys["using_pipeline_parallelism"] = False + return raw_keys +>>>>>>> jaxpp/main def _tuples_to_lists(l: list | 
tuple | Any) -> list | Any: @@ -102,9 +344,72 @@ def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]: if raw_keys.get("dataset_type") == "hf" and "tokenizer_type" not in raw_keys: raw_keys["tokenizer_type"] = "huggingface" +<<<<<<< HEAD for key, value in raw_keys.items(): if key not in valid_fields: logger.warning("Ignoring invalid/unsupported field from YAML/CLI: %s", repr(key)) +======= +def validate_sparse_matmul_parallelism(raw_keys): + # TODO: remove once b/434699033 resolved + if raw_keys["sparse_matmul"] and (using_expert_parallelism(raw_keys) and (not raw_keys["use_jaxpp"] and using_pipeline_parallelism(raw_keys))): + raise ValueError("Sparse matmul doesn't support using expert and pipeline parallelism together.") + + # TODO: remove once b/435539039 resolved + if raw_keys["sparse_matmul"] and ( + using_fsdp_and_transpose_parallelism(raw_keys) + and using_expert_parallelism(raw_keys) + and using_tensor_parallelism(raw_keys) + ): + raise ValueError("Sparse matmul doesn't support using fsdp, expert, and tensor parallelism together.") + tensor_parallelism = ( + raw_keys["ici_tensor_parallelism"] + * raw_keys["dcn_tensor_parallelism"] + * raw_keys["ici_tensor_sequence_parallelism"] + * raw_keys["dcn_tensor_sequence_parallelism"] + * raw_keys["ici_tensor_transpose_parallelism"] + * raw_keys["dcn_tensor_transpose_parallelism"] + ) + if raw_keys["sparse_matmul"] and using_tensor_parallelism(raw_keys) and (raw_keys["emb_dim"] % tensor_parallelism): + raise ValueError( + f"The embedding dimension {raw_keys['emb_dim']} is not divisible by tensor parallelism setting {tensor_parallelism}." + ) + expert_parallelism = raw_keys["ici_expert_parallelism"] * raw_keys["dcn_expert_parallelism"] + if raw_keys["sparse_matmul"] and using_expert_parallelism(raw_keys) and (raw_keys["num_experts"] % expert_parallelism): + raise ValueError( + f"The expert dimension {raw_keys['num_experts']} is not divisible by expert parallelism setting {expert_parallelism}." 
+ ) + + +def validate_ring_of_experts_parallelism(raw_keys): + if raw_keys["use_ring_of_experts"] and not using_expert_parallelism(raw_keys): + raise ValueError("Ring-of-experts requires expert-parallelism to be enabled.") + + +def validate_shard_fsdp_on_expert_parallelism(raw_keys): + if raw_keys["fsdp_shard_on_exp"] and raw_keys["num_experts"] % raw_keys["ici_fsdp_parallelism"] != 0: + raise ValueError("fsdp_shard_on_exp requires num_experts is divisiable by ici_fsdp_parallelism.") + if raw_keys["fsdp_shard_on_exp"] and (using_tensor_parallelism(raw_keys) or using_expert_parallelism(raw_keys)): + raise ValueError( + "fsdp_shard_on_exp requires ici_expert_parallelism = 1 and ici_tensor_parallelism/ici_tensor_transpose_parallelism = 1." + ) + + +def validate_ragged_dot(raw_keys): + if raw_keys["sparse_matmul"] and not raw_keys["megablox"]: + config_flag = "jax_ragged_dot_use_ragged_dot_instruction" + try: + jax.config.update(config_flag, True) + except AttributeError: + max_logging.log(f"JAX config {config_flag} not found, possibly due to old JAX version.") + + +def create_new_logical_axis_rules(old_logical_axis_rules, new_logical_axis_rules): + new_logical_axis = set() + replacements = [] + for logical_axis, mesh_axes in new_logical_axis_rules: + logical_axis_exists = any(rule for rule in old_logical_axis_rules if rule[0] == logical_axis) + if not logical_axis_exists: +>>>>>>> jaxpp/main continue new_value = value @@ -127,6 +432,141 @@ def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]: return pydantic_kwargs +<<<<<<< HEAD +======= +def update_model_keys(raw_keys, model_keys, key): + """Update `key` value in `raw_keys` from the value in `model_keys`.""" + assert key in model_keys and key in raw_keys + if key == "logical_axis_rules": + raw_keys[key] = create_new_logical_axis_rules( + old_logical_axis_rules=raw_keys[key], new_logical_axis_rules=model_keys[key] + ) + return + raw_keys[key] = model_keys[key] + + +def 
validate_and_update_keys(raw_keys, model_keys, config_name: str): + """Validate and update model specific config keys""" + max_logging.log("Updating following parameters in config\n") + + for k in model_keys: + max_logging.log(f"{k}: {model_keys[k]}") + if k not in raw_keys: + raise ValueError(f"Key {k} does not exist in config {config_name}.") + elif not isinstance(raw_keys[k], type(model_keys[k])): + raise ValueError(f"Type of key:{k} does not match with {type(model_keys[k])}") + else: + update_model_keys(raw_keys, model_keys, k) + return raw_keys + + +def get_individual_scales(scale): + """Choose appropriate scales for individual dimensions based on global scale + We choose to rotate between doubling: + num_head and mlp_dim + embed_dim + num_layers + Any one of these steps is not a perfect doubling, although going through a cycle + of three is a near perfect 8x scaling except for the linear -> softmax -> output step""" + + log_2_scale = math.floor((math.log2(scale))) + if 2**log_2_scale != scale: + raise ValueError( + "Global parameter scale should be a power of 2. If you want finer grained control of the model sizes " + "then you can explicitly set base_embed_dim, base_num_heads, base_mlp_dim, base_num_decoder_layers and/or head_dim." 
+ ) + base_scale, rem = divmod(log_2_scale, 3) + num_head_scale = base_scale + int(rem > 0) + mlp_dim_scale = num_head_scale + emb_scale = base_scale + int(rem > 1) + layer_scale = base_scale + return emb_scale, num_head_scale, mlp_dim_scale, layer_scale + + +def calculate_global_batch_sizes( + per_device_batch_size, expansion_factor_real_data, num_devices, gradient_accumulation_steps +): + """Calculates target global batch size from target devices and per_device_batch""" + if per_device_batch_size < 1.0: + # For per_device_batch_size<1, we load the data as if per_device_batch_size=1 + if expansion_factor_real_data != -1: + micro_batch_size_to_load = num_devices * expansion_factor_real_data + else: + micro_batch_size_to_load = num_devices + else: + if expansion_factor_real_data != -1: + micro_batch_size_to_load = int(num_devices * per_device_batch_size * expansion_factor_real_data) + else: + micro_batch_size_to_load = int(num_devices * per_device_batch_size) + + micro_batch_size_to_train_on = int(num_devices * per_device_batch_size) + global_batch_size_to_load = int(micro_batch_size_to_load * gradient_accumulation_steps) + global_batch_size_to_train_on = int(micro_batch_size_to_train_on * gradient_accumulation_steps) + return global_batch_size_to_load, global_batch_size_to_train_on, micro_batch_size_to_train_on + + +def get_num_target_devices(raw_keys): + # In AOT case compile_topology is set (e.g. is not the empty string), and we determine the + # number of devices from the compile_topology. In non-AOT settings we simply can use jax.devices(). 
+ if raw_keys.get("compile_topology"): + compile_topology = accelerator_to_spec_map.get_system_characteristics(raw_keys["compile_topology"]) + devices_per_slice = compile_topology.devices_per_slice + return int(devices_per_slice * raw_keys["compile_topology_num_slices"]) + elif raw_keys.get("subslice_shape") and raw_keys.get("enable_single_controller"): + subslice_shape = tuple(int(x) for x in raw_keys["subslice_shape"].split(",")) + return prod(subslice_shape) + else: + return len(jax.devices()) + + +def get_quantization_local_shard_count(raw_keys): + if raw_keys["quantization_local_shard_count"] == -1: + return raw_keys["num_slices"] + else: + return raw_keys["quantization_local_shard_count"] + + +def get_context_parallel_size(raw_keys): + cp_size = raw_keys["ici_context_parallelism"] * raw_keys["dcn_context_parallelism"] + # ep acts as cp in attention + if raw_keys["expert_shard_attention_option"] == "context": + cp_size = cp_size * raw_keys["ici_expert_parallelism"] * raw_keys["dcn_expert_parallelism"] + return cp_size + + +def using_pipeline_parallelism(raw_keys) -> bool: + return raw_keys["use_jaxpp"] or int(raw_keys["ici_pipeline_parallelism"]) > 1 or int(raw_keys["dcn_pipeline_parallelism"]) > 1 + +def using_tensor_parallelism(raw_keys) -> bool: + return ( + int(raw_keys["ici_tensor_parallelism"]) > 1 + or int(raw_keys["dcn_tensor_parallelism"]) > 1 + or int(raw_keys["ici_tensor_sequence_parallelism"]) > 1 + or int(raw_keys["dcn_tensor_sequence_parallelism"]) > 1 + ) + + +def using_sequence_parallelism(raw_keys) -> bool: + return int(raw_keys["ici_sequence_parallelism"]) > 1 or int(raw_keys["dcn_sequence_parallelism"]) > 1 + + +def using_expert_parallelism(raw_keys) -> bool: + if int(raw_keys["ici_expert_parallelism"]) > 1 and int(raw_keys["dcn_expert_parallelism"]) > 1: + raise ValueError("Expert parallelism can only be enabled on ICI or DCN, not both.") + return int(raw_keys["ici_expert_parallelism"]) > 1 or int(raw_keys["dcn_expert_parallelism"]) > 1 + + 
+def using_fsdp_and_transpose_parallelism(raw_keys) -> bool: + return ( + int(raw_keys["ici_fsdp_parallelism"]) > 1 + or int(raw_keys["dcn_fsdp_parallelism"]) > 1 + or int(raw_keys["ici_fsdp_transpose_parallelism"]) > 1 + or int(raw_keys["dcn_fsdp_transpose_parallelism"]) > 1 + ) + + +@register_pytree_node_class +>>>>>>> jaxpp/main class HyperParameters: """ Wrapper class to expose the configuration in a read-only manner, diff --git a/src/MaxText/train.py b/src/MaxText/train.py index 96bd799b1e..243d68c905 100644 --- a/src/MaxText/train.py +++ b/src/MaxText/train.py @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -41,6 +42,8 @@ _diag_modules = _cloud_diag() diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration = _diag_modules +from packaging.version import Version + from MaxText import checkpointing from MaxText import exceptions from MaxText import max_logging @@ -71,11 +74,25 @@ from MaxText.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn from MaxText.train_utils import validate_train_config from MaxText.metric_logger import record_activation_metrics + +""" +JaxPP related imports +""" +# system +import subprocess + +from statistics import mean + +# jaxpp +from jaxpp import __version__ as jaxpp_version +from packaging.version import Version +import jaxpp.api as jaxpp + # pylint: disable=too-many-positional-arguments def get_first_step(state): - return int(state.step) + return int(max_utils.maybe_unwrap(state.step)) # ----------------------------------------------------------------------------- @@ -99,7 +116,7 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): aux: a dictionary including intermediate_outputs, total_loss, and total_weights """ # decimate proportion of data when 
per_device_batch_size<1 - if is_train: + if is_train and not config.use_jaxpp: for k, v in data.items(): data[k] = v[: config.micro_batch_size_to_train_on, :] else: @@ -213,6 +230,44 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): return loss, aux +def load_schedule(config): + pipeline_parallel_dim = config.dcn_pipeline_parallelism * config.ici_pipeline_parallelism + num_logical_stages = config.num_pipeline_repeats * pipeline_parallel_dim + schedule = None + if config.schedule == "1f1b": + assert num_logical_stages <= pipeline_parallel_dim + schedule = jaxpp.Std1F1B(num_logical_stages) + elif config.schedule == "eager_1f1b": + assert num_logical_stages <= pipeline_parallel_dim + schedule = jaxpp.Eager1F1B(num_logical_stages) + elif config.schedule == "interleaved_1f1b": + if Version(jaxpp_version) > Version("0.6.1"): + schedule = jaxpp.Interleaved1F1B(num_logical_stages, pipeline_parallel_dim, config.fuse_steady_state) + else: + schedule = jaxpp.Interleaved1F1B(num_logical_stages, pipeline_parallel_dim) + elif config.schedule == "zero_bubble": + assert num_logical_stages <= pipeline_parallel_dim + schedule = jaxpp.ZeroBubble(num_logical_stages) + elif config.schedule == "dualpipev": + schedule = jaxpp.DualPipeV(num_logical_stages, pipeline_parallel_dim) + else: + raise NotImplementedError(f"Unknown schedule {config.schedule}") + return schedule + + +def add_leading_axis( + axis_name: str, path: jax.tree_util.KeyPath, s: jax.sharding.NamedSharding +): + assert isinstance(s, jax.sharding.NamedSharding) + used = {n for ns in s.spec for n in (ns if isinstance(ns, tuple) else (ns,))} + if axis_name in used: + raise ValueError( + f"mesh axis name {axis_name} cannot appear in " + f"out_shardings. 
Found out_shardings{jax.tree_util.keystr(path)}={s.spec}" + ) + return jax.sharding.NamedSharding(s.mesh, jax.sharding.PartitionSpec(axis_name, *s.spec), memory_kind=s.memory_kind) + + def train_step(model, config, state_mesh_shardings, params_shardings, state, data, dropout_rng): """ @@ -261,24 +316,92 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat max_utils.with_memory_kind(reference_params_sharding, "device"), ) extra_dpo_args = [reference_params] + if config.shard_optimizer_over_data: params = jax.tree.map(jax.lax.with_sharding_constraint, params, params_shardings) - grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) - (loss, aux), raw_grads = grad_func(model, config, data, dropout_rng, params, *extra_dpo_args, is_train=True) - + def compute_grads(data): + grad_func = jax.value_and_grad(loss_fn, argnums=4, has_aux=True) + (loss, aux), raw_grads = grad_func(model, config, data, dropout_rng, params, *extra_dpo_args, is_train=True) + def cast(p, a): + sp = jax.tree_util.keystr(p) + if 'token_embedder' in sp or 'position_embedder' in sp: + return a + return a.astype(jnp.dtype(config.grad_dtype)) + raw_grads['params'] = jax.tree_util.tree_map_with_path(cast, raw_grads['params']) + return ((loss, aux), raw_grads) + +<<<<<<< HEAD raw_grads = jax.tree_util.tree_map( lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, raw_grads, ) +======= + if not config.use_jaxpp: + (loss, aux), raw_grads = compute_grads(data) + else: + def microbatched(a): + shape = ( + state_mesh_shardings.step.mesh.shape["data"], + config.num_pipeline_microbatches, + -1, + config.max_target_length, + ) + if shape[0] == 1: + shape = shape[1:] + return a.reshape(*shape) + data = jax.tree.map(microbatched, data) + + # Perform data parallelism manually through `vmap` + vmapped_compute_grads = compute_grads + if state_mesh_shardings.step.mesh.shape["data"] > 1: + vmapped_compute_grads = jax.vmap(compute_grads, spmd_axis_name="data") 
+ + loss_aux_sharding = jax.sharding.NamedSharding(state_mesh_shardings.step.mesh, jax.sharding.PartitionSpec()) + param_operation = {'params': jaxpp.Add} + if nn.fp8_ops.OVERWRITE_WITH_GRADIENT in params: + param_operation[nn.fp8_ops.OVERWRITE_WITH_GRADIENT] = jaxpp.Max + assert all(k in param_operation for k in params.keys()) + + axis = 1 if state_mesh_shardings.step.mesh.shape["data"] > 1 else 0 + (loss, aux), raw_grads = jaxpp.treduce( + vmapped_compute_grads, + data, + axis=axis, + schedule=load_schedule(config), + operation=(jaxpp.Concat(axis=axis), param_operation) + ) + + if state_mesh_shardings.step.mesh.shape["data"] > 1: + (loss, aux), raw_grads = jax.lax.with_sharding_constraint( + ((loss, aux), raw_grads), + jax.tree.map_with_path( + functools.partial(add_leading_axis, "data"), + (loss_aux_sharding, params_shardings) + ), + ) + # reduce-scatter gradients across "data" + owg = raw_grads.pop(nn.fp8_ops.OVERWRITE_WITH_GRADIENT, None) + raw_grads = jax.tree.map(functools.partial(jax.numpy.sum, axis=0), raw_grads) + if owg is not None: + owg = jax.tree.map(functools.partial(jax.numpy.max, axis=0), owg) + raw_grads[nn.fp8_ops.OVERWRITE_WITH_GRADIENT] = owg + + raw_grads = jax.lax.with_sharding_constraint(raw_grads, state_mesh_shardings.params) + raw_grads = jax.tree_util.tree_map(lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, raw_grads) + owg = raw_grads.pop(nn.fp8_ops.OVERWRITE_WITH_GRADIENT, None) + raw_grad_norm = max_utils.l2norm_pytree(raw_grads) +>>>>>>> jaxpp/main intermediate_outputs = aux["intermediate_outputs"] total_weights = aux["total_weights"] moe_lb_loss = aux["moe_lb_loss"] mtp_loss = aux["mtp_loss"] if config.gradient_clipping_threshold > 0: - grads = maxtext_utils.apply_gradient_clipping(raw_grads, state, config.gradient_clipping_threshold) + grads = maxtext_utils.apply_gradient_clipping(raw_grads, state, raw_grad_norm, config.gradient_clipping_threshold) else: grads = raw_grads + if owg is not None: + 
grads[nn.fp8_ops.OVERWRITE_WITH_GRADIENT] = owg if config.optimizer_memory_host_offload: state = state.replace( opt_state=jax.device_put( @@ -305,16 +428,36 @@ def move(path, value): ) new_state = state.apply_gradients(grads=grads) - scalar_metrics = { - "learning/loss": loss, - "learning/moe_lb_loss": moe_lb_loss, - "learning/mtp_loss": mtp_loss, - "learning/total_weights": total_weights, - } + if config.use_jaxpp: + # TODO: refine logic to match the one in MaxText's gradient accumulation + # or use that altogether (add support for scan instead of + # treduce in JaxPP) + scalar_metrics = { + "learning/loss": loss.sum() / total_weights.sum(), + "learning/moe_lb_loss": moe_lb_loss.sum(), + "learning/mtp_loss": mtp_loss.sum(), + "learning/total_weights": total_weights.sum(), + } + else: + scalar_metrics = { + "learning/loss": loss, + "learning/moe_lb_loss": moe_lb_loss, + "learning/mtp_loss": mtp_loss, + "learning/total_weights": total_weights, + } if not config.optimizer_memory_host_offload: - scalar_metrics["learning/grad_norm"] = max_utils.l2norm_pytree(grads) - scalar_metrics["learning/raw_grad_norm"] = max_utils.l2norm_pytree(raw_grads) - scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_state.params) + owg = grads.pop(nn.fp8_ops.OVERWRITE_WITH_GRADIENT, None) + scalar_metrics["learning/grad_norm"] = max_utils.l2norm_pytree(grads["params"]) + if owg is not None: + grads[nn.fp8_ops.OVERWRITE_WITH_GRADIENT] = owg + scalar_metrics["learning/raw_grad_norm"] = raw_grad_norm + + new_params = new_state.params + owg = new_params.pop(nn.fp8_ops.OVERWRITE_WITH_GRADIENT, None) + scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_params) + if owg is not None: + new_params[nn.fp8_ops.OVERWRITE_WITH_GRADIENT] = owg + new_state = new_state.replace(params=new_params) if config.use_dpo: scalar_metrics["learning/dpo_reward_accuracy"] = aux["reward_accuracy"] metrics = { @@ -364,7 +507,7 @@ def eval_step(model, config, state, data, dropout_rng): 
if config.use_dpo: metrics["scalar"]["evaluation/dpo_reward_accuracy"] = aux["reward_accuracy"] - return metrics + return jax.tree.map(jax._src.numpy.lax_numpy._array_copy, metrics) def train_loop(config, recorder, state=None): @@ -374,7 +517,7 @@ def train_loop(config, recorder, state=None): checkpoint_manager, state_mesh_shardings, model, - mesh, + maybe_mpmd_mesh, learning_rate_schedule, data_iterator, eval_data_iterator, @@ -387,6 +530,7 @@ def train_loop(config, recorder, state=None): state = _merge_dpo_state(state, reference_params) state_mesh_shardings = _merge_dpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) +<<<<<<< HEAD params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( @@ -406,6 +550,22 @@ def train_loop(config, recorder, state=None): if config.shard_optimizer_over_data: state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) if config.compiled_trainstep_file == "": # compile only when there is no pre-compiled file loaded +======= + mesh = maybe_mpmd_mesh.lowering_mesh() if config.use_jaxpp else maybe_mpmd_mesh + params_shardings, state_mesh_shardings = maxtext_utils.maybe_update_params_sharding_with_opt( + config, state_mesh_shardings + ) + + p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( + config, model, maybe_mpmd_mesh, state, state_mesh_shardings, train_step, eval_step, eval_data_iterator, params_shardings + ) + + if not config.use_jaxpp: + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + shaped_batch = maxtext_utils.get_shaped_batch(config) + if config.shard_optimizer_over_data: + state = jax.lax.with_sharding_constraint(state, state_mesh_shardings) +>>>>>>> jaxpp/main compiled = p_train_step.lower(state, shaped_batch, init_rng).compile() compiled_stats = compiled.memory_analysis() 
max_utils.print_compiled_memory_stats(compiled_stats) @@ -419,9 +579,18 @@ def train_loop(config, recorder, state=None): metric_logger.write_setup_info_to_tensorboard(state.params) try: + step_time = [] + step_tflops = [] + # NOTE: The dict values are unused when use_jaxpp is False. + profiling_process_ids = {pid: "" for pid in jax.process_indices()} + if config.use_jaxpp: + idx = tuple(slice(None) if i == maybe_mpmd_mesh.mpmd_axis else 0 for i in range(len(maybe_mpmd_mesh.jax_mesh.shape))) + first_device_per_mpmd_rank = maybe_mpmd_mesh.jax_mesh.devices[idx] + profiling_process_ids = {d.process_index: d for d in first_device_per_mpmd_rank} + last_step_completion = datetime.datetime.now() for step in np.arange(start_step, config.steps): - prof.maybe_activate_profiler(step, state) + prof.maybe_activate_profiler(step, state, maybe_mpmd_mesh=maybe_mpmd_mesh, profiling_process_ids=profiling_process_ids) with jax.profiler.StepTraceAnnotation("train", step_num=step): example_batch = data_loader.load_next_batch() @@ -435,8 +604,13 @@ def train_loop(config, recorder, state=None): nextrng = jax.jit(jax.random.fold_in)(init_rng, step) with maybe_record_goodput(recorder, GoodputEvent.STEP, step): with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): +<<<<<<< HEAD if config.shard_optimizer_over_data: state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) +======= + if config.shard_optimizer_over_data and not config.use_jaxpp: + state = jax.lax.with_sharding_constraint(state, state_mesh_shardings) +>>>>>>> jaxpp/main state, metrics = p_train_step(state, example_batch, nextrng) step_time_delta = datetime.datetime.now() - last_step_completion @@ -467,6 +641,7 @@ def train_loop(config, recorder, state=None): if config.eval_steps > 0 and eval_step_count >= config.eval_steps: break with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + eval_batch = jax.tree_util.tree_map(lambda a: a[:2], eval_batch) eval_metrics = 
p_eval_step(state, eval_batch, nextrng) metric_logger.record_eval_metrics(step, metrics=eval_metrics) max_logging.log(f"Completed eval step {eval_step_count}") @@ -476,12 +651,21 @@ def train_loop(config, recorder, state=None): prof.deactivate() raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.") - prof.maybe_deactivate_profiler(step, state) + prof.maybe_deactivate_profiler(step, state, profiling_process_ids=profiling_process_ids) if step == start_step: max_utils.print_mem_stats("After params initialized") metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta) + step_time.append(metrics['scalar']['perf/step_time_seconds']) + step_tflops.append(metrics['scalar']['perf/per_device_tflops_per_sec']) + + if config.use_jaxpp and prof.mode != "": + command = """find . -wholename '*proc_*_mpmd*/*.xplane.pb' | sort | awk '{line=$0; sub(/.*mpmd_/, "", line); sub(/_.*/, "", line); printf "%d:%s:0 ", line, $0}'""" + subprocess.run( + [f"merge_multihost_xplanes $({command})"], + shell=True, cwd=config.tensorboard_dir, check=True + ) if config.save_checkpoint_on_completion: state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] @@ -494,6 +678,13 @@ def train_loop(config, recorder, state=None): finally: metric_logger.flush_metrics_and_cleanup() + # last_profiling_step + 2 as (1) we count steps from 0, and (2) the execution time for merge_multihost_xplanes is + # counted toward the execution time for the step right after the last profiling step. 
+ num_warmup_steps = (prof.finished_initial_profile_step + 2) if prof.mode != "" else 6 + max_logging.log( + f"excluding the first {num_warmup_steps} steps: avg time per step {mean(step_time[num_warmup_steps:])}, avg tflops per step {mean(step_tflops[num_warmup_steps:])}" + ) + return state diff --git a/src/MaxText/train_compile.py b/src/MaxText/train_compile.py index 4b0d5e88df..f4edcf1807 100644 --- a/src/MaxText/train_compile.py +++ b/src/MaxText/train_compile.py @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -46,6 +47,9 @@ from MaxText.layers import quantizations from MaxText.utils import gcs_utils +import jaxpp.api as jaxpp +from jaxpp.api import MpmdMesh + # pylint: disable=too-many-positional-arguments Transformer = models.transformer_as_linen @@ -118,18 +122,32 @@ def jit_and_compile( static_argnums, donate_argnums, logical_axis_rules, + mpmd_mesh=None, ): """Jit, lower, and compile func.""" with mesh, logical_axis_rules: - jitted = jax.jit( + if mpmd_mesh is not None: + p_train_step = jaxpp.mpmd_jit_with_loop( func, + mpmd_mesh=mpmd_mesh, + donate_argnums=donate_argnums, in_shardings=in_shardings, out_shardings=out_shardings, - static_argnums=static_argnums, - donate_argnums=donate_argnums, - ) - lowered = jitted.lower(*func_input_args, **func_input_kwargs) - compiled = lowered.compile() + ) + assert len(func_input_kwargs) == 0 + compiled = p_train_step.compile(*func_input_args) + else: + jitted = jax.jit( + func, + in_shardings=in_shardings, + out_shardings=out_shardings, + static_argnums=static_argnums, + donate_argnums=donate_argnums, + ) + lowered = jitted.lower(*func_input_args, **func_input_kwargs) + + if mpmd_mesh is None: + compiled = lowered.compile() return compiled @@ -203,21 +221,28 @@ def main(argv: Sequence[str]) -> None: # Create 
target mesh topology_mesh = get_topology_mesh(config) + if config.use_jaxpp: + mpmd_mesh = MpmdMesh(topology_mesh, 'stage') + mesh = mpmd_mesh.lowering_mesh() + else: + mpmd_mesh = None + mesh = topology_mesh # Print system information after building the compile topology to avoid # prematurely initializing the backend. max_utils.print_system_information() # Get shaped inputs - shaped_train_args, shaped_train_kwargs, state_mesh_shardings, model = get_shaped_inputs(topology_mesh, config) + shaped_train_args, shaped_train_kwargs, state_mesh_shardings, model = get_shaped_inputs(mesh, config) + params_shardings, state_mesh_shardings = max_utils.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) # Get data sharding - data_sharding = sharding.get_input_data_sharding(config, topology_mesh) + data_sharding = sharding.get_input_data_sharding(config, mesh) # Get function to compile and shardings func_to_compile, in_shard, out_shard, static_argnums, donate_argnums = ( maxtext_utils.get_functional_train_with_signature( - train.train_step, data_sharding, state_mesh_shardings, model, config + train.train_step, data_sharding, state_mesh_shardings, model, config, params_shardings=params_shardings ) ) @@ -227,12 +252,13 @@ def main(argv: Sequence[str]) -> None: func_to_compile, shaped_train_args, shaped_train_kwargs, - topology_mesh, + mesh, in_shard, out_shard, static_argnums, donate_argnums, nn_partitioning.axis_rules(config.logical_axis_rules), + mpmd_mesh=mpmd_mesh ) print("Jitting and compilation complete!", flush=True) @@ -241,9 +267,11 @@ def main(argv: Sequence[str]) -> None: print("Saving compiled object...") save_compiled(compiled, config.compiled_trainstep_file) print(f"Successfully saved compiled object as {config.compiled_trainstep_file}") - print("Finished train_compile.py successfully!", flush=True) - print(f"Cost analysis: {compiled.cost_analysis()}") - print(f"Memory analysis: {compiled.memory_analysis()}") + + if not config.use_jaxpp: + 
print("Finished train_compile.py successfully!", flush=True) + print(f"Cost analysis: {compiled.cost_analysis()}") + print(f"Memory analysis: {compiled.memory_analysis()}") # Dump HLO if requested if config.dump_hlo: diff --git a/src/MaxText/train_utils.py b/src/MaxText/train_utils.py index 0a5788c250..762af93725 100644 --- a/src/MaxText/train_utils.py +++ b/src/MaxText/train_utils.py @@ -28,6 +28,7 @@ from MaxText.utils.goodput_utils import maybe_record_goodput from MaxText import model_creation_utils +import jaxpp.api as jaxpp def create_training_tools(config, model, mesh): """Creates the init_rng, optimizer, learning rate schedule, and checkpoint manager.""" @@ -76,7 +77,7 @@ def create_training_tools(config, model, mesh): return init_rng, checkpoint_manager, learning_rate_schedule, tx -def jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings): +def jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, maybe_mpmd_mesh, params_shardings): """Returns a JIT-compiled train step function, which is loaded from a file if specified in the config.""" ( functional_train, @@ -96,18 +97,27 @@ def jit_train_step(config, model, state, state_mesh_shardings, data_sharding, tr p_train_step = maxtext_utils.load_compiled(config, functional_train, state, execution_devices) max_logging.log("Loaded compiled function!") else: - p_train_step = jax.jit( + if not config.use_jaxpp: + p_train_step = jax.jit( functional_train, in_shardings=in_shardings, out_shardings=out_shardings, static_argnums=static_argnums, + donate_argnums=donate_argnums) + else: + max_logging.log("Running with jaxpp") + p_train_step = jaxpp.mpmd_jit_with_loop( + functional_train, + mpmd_mesh=maybe_mpmd_mesh, donate_argnums=donate_argnums, - ) + in_shardings=in_shardings, + out_shardings=out_shardings, + ) return p_train_step -def jit_eval_step(config, model, state_mesh_shardings, data_sharding, eval_step): +def jit_eval_step(config, 
model, state_mesh_shardings, data_sharding, eval_step, maybe_mpmd_mesh): """Returns a JIT-compiled eval step function.""" ( functional_eval, @@ -119,13 +129,23 @@ def jit_eval_step(config, model, state_mesh_shardings, data_sharding, eval_step) p_eval_step = None if config.compiled_trainstep_file == "": - p_eval_step = jax.jit( + if not config.use_jaxpp: + p_eval_step = jax.jit( + functional_eval, + in_shardings=in_shardings, + out_shardings=out_shardings, + static_argnums=static_argnums, + donate_argnums=donate_argnums, + ) + else: + p_eval_step = jaxpp.mpmd_jit_by_yield( functional_eval, + mpmd_mesh=maybe_mpmd_mesh, in_shardings=in_shardings, out_shardings=out_shardings, static_argnums=static_argnums, donate_argnums=donate_argnums, - ) + ) return p_eval_step @@ -133,7 +153,7 @@ def jit_eval_step(config, model, state_mesh_shardings, data_sharding, eval_step) def jit_train_and_eval_step( config, model, - mesh, + maybe_mpmd_mesh, state, state_mesh_shardings, train_step, @@ -142,11 +162,12 @@ def jit_train_and_eval_step( params_shardings=None, ): """Returns a JIT-compiled train and eval step function.""" + mesh = maybe_mpmd_mesh.lowering_mesh() if config.use_jaxpp else maybe_mpmd_mesh data_sharding = sharding.get_input_data_sharding(config, mesh) - p_train_step = jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings) + p_train_step = jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, maybe_mpmd_mesh, params_shardings) p_eval_step = None if eval_data_iterator: - p_eval_step = jit_eval_step(config, model, state_mesh_shardings, data_sharding, eval_step) + p_eval_step = jit_eval_step(config, model, state_mesh_shardings, data_sharding, eval_step, maybe_mpmd_mesh) return p_train_step, p_eval_step @@ -172,7 +193,13 @@ def setup_train_loop(config, recorder, devices=None): with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT): model = model_creation_utils.from_config(config, devices) - 
mesh = model.mesh + maybe_mpmd_mesh = model.mesh + if config.use_jaxpp: + assert isinstance(maybe_mpmd_mesh, jaxpp.MpmdMesh) + model.mesh = mesh = maybe_mpmd_mesh.lowering_mesh() + else: + assert isinstance(maybe_mpmd_mesh, jax.sharding.Mesh) + mesh = maybe_mpmd_mesh init_rng, checkpoint_manager, learning_rate_schedule, tx = create_training_tools(config, model, mesh) with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION): @@ -197,9 +224,17 @@ def setup_train_loop(config, recorder, devices=None): ) state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state( - model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager + model, data_iterator, tx, config, init_rng, maybe_mpmd_mesh, checkpoint_manager ) + def make_line(keypath, array_or_array_ref): + sharding = array_or_array_ref.sharding + return (f"{jax.tree_util.keystr(keypath):<120}, {str(array_or_array_ref.dtype):<10}, " + f"{str(array_or_array_ref.shape):<26}, {sharding._to_xla_hlo_sharding(array_or_array_ref.ndim)}") + + max_logging.log("shardings/weights") + max_logging.log("\n".join(make_line(keypath, array_ref) for keypath, array_ref in jax.tree_util.tree_leaves_with_path(state))) + # TODO(aireenmei, hengtaoguo): support sharding in vit for multimodal if not config.using_pipeline_parallelism and not config.use_multimodal: # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage @@ -242,7 +277,7 @@ def setup_train_loop(config, recorder, devices=None): checkpoint_manager, state_mesh_shardings, model, - mesh, + maybe_mpmd_mesh, learning_rate_schedule, data_iterator, eval_data_iterator, diff --git a/tests/train_compile_jaxpp_test.py b/tests/train_compile_jaxpp_test.py new file mode 100644 index 0000000000..9755ce5a3e --- /dev/null +++ b/tests/train_compile_jaxpp_test.py @@ -0,0 +1,127 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from MaxText.train_compile import main as train_compile_main + + +def get_args( + model_name, + num_nodes, + ici_dp, + ici_tp, + dcn_pp, + vp, + ga, + per_device_batch_size=2, + ici_cp=1, + ici_ep=1, + dcn_ep=1, + quantization=None, + disable_cache: bool = False, +): + res = ( + None, + "MaxText/configs/base.yml", + f"model_name={model_name}", + "attention=dot_product", + "remat_policy=minimal", + "dtype=bfloat16", + "max_target_length=2048", + f"per_device_batch_size={per_device_batch_size}", + "hardware=gpu", + # SPMD Parallelism + f"ici_data_parallelism={ici_dp}", + f"ici_tensor_parallelism={ici_tp}", + f"ici_context_parallelism={ici_cp}", + # Pipeline + f"dcn_pipeline_parallelism={dcn_pp}", + f"num_pipeline_microbatches={ga}", + f"num_pipeline_repeats={vp}", + # JaxPP + "use_jaxpp=True", + "schedule=interleaved_1f1b", + "compile_topology=a3", + f"compile_topology_num_slices={num_nodes}", + ) + if ici_ep > 1: + res = res + (f"ici_expert_parallelism={ici_ep}",) + if dcn_ep > 1: + res = res + (f"dcn_expert_parallelism={dcn_ep}",) + if quantization is not None: + res = res + (f"quantization={quantization}",) + if disable_cache: + res = res + ("jax_cache_dir=",) + return res + + +class TrainCompile(unittest.TestCase): + def test_compile_llama4(self): + train_compile_main( + get_args( + model_name="llama4-17b-16e", + num_nodes=32, + ici_dp=1, + ici_ep=2, + ici_tp=4, + dcn_ep=8, + dcn_pp=4, + vp=4, + 
ga=64, + per_device_batch_size=4 + ) + ) + + def test_compile_gpt3(self): + train_compile_main( + get_args( + model_name="gpt3-175b", + num_nodes=16, + ici_dp=2, + ici_tp=4, + dcn_pp=8, + vp=6, + ga=32, + ) + ) + + def test_compile_gpt3_fp8(self): + train_compile_main( + get_args( + model_name="gpt3-175b", + num_nodes=16, + ici_dp=2, + ici_tp=4, + dcn_pp=8, + vp=6, + ga=32, + quantization="fp8", + ) + ) + + def test_compile_llama3(self): + train_compile_main( + get_args( + model_name="llama3.3-70b", + num_nodes=8, + ici_dp=1, + ici_cp=2, + ici_tp=4, + dcn_pp=4, + vp=5, + ga=64, + per_device_batch_size=2 + ) + )