42 changes: 21 additions & 21 deletions docs/guides/checkpointing_solutions/gcs_checkpointing.md
@@ -28,30 +28,30 @@ startup. The first valid condition met is the one executed:

### MaxText configuration

Flag | Description | Type | Default
:------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | :-------- | :------
`enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False`
`async_checkpointing` | When set to (`True`), this flag makes checkpoint saving asynchronous. The training step is only blocked for the minimal time needed to capture the model's state, and the actual writing to storage happens in a background thread. This is highly recommended for performance. It's enabled by default. | `boolean` | `True`
`checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000`
`enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.<br>**Note**: This feature is only compatible with training jobs that utilize a Distributed Data Parallel (DDP) strategy. | `boolean` | `False`
`checkpoint_todelete_subdir` | Subdirectory to move checkpoints to before deletion. For example: `".todelete"` (Ignored if directory is prefixed with gs://) | `string` | `""`
`checkpoint_todelete_full_path` | Full path to move checkpoints to before deletion. | `string` | `""`
`load_parameters_path` | Specifies a path to a checkpoint directory to load a parameter only checkpoint.<br>**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled)
`load_full_state_path` | Specifies a path to a checkpoint directory to load a full checkpoint including optimizer state and step count from a specific directory.<br>**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled)
`lora_input_adapters_path` | Specifies a parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled)
`force_unroll` | If `True`, unrolls the loop when generating a parameter-only checkpoint. | `boolean` | `False`
| Flag | Description | Type | Default |
| :------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | :-------- | :-------------- |
| `enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False` |
| `async_checkpointing`                  | When set to `True`, checkpoint saving is asynchronous: the training step is blocked only for the minimal time needed to capture the model's state, and the actual write to storage happens in a background thread. This is highly recommended for performance.                                                                                | `boolean` | `True`          |
| `checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000` |
| `enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.<br>**Note**: This feature is only compatible with training jobs that utilize a Distributed Data Parallel (DDP) strategy. | `boolean` | `False` |
| `checkpoint_todelete_subdir` | Subdirectory to move checkpoints to before deletion. For example: `".todelete"` (Ignored if directory is prefixed with `gs://`) | `string` | `""` |
| `checkpoint_todelete_full_path` | Full path to move checkpoints to before deletion. | `string` | `""` |
| `load_parameters_path`                 | Path to a checkpoint directory from which to load a parameter-only checkpoint.<br>**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"`                                                                                                                                                                                    | `string`  | `""` (disabled) |
| `load_full_state_path`                 | Path to a checkpoint directory from which to load a full checkpoint, including optimizer state and step count.<br>**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"`                                                                                                                                                   | `string`  | `""` (disabled) |
| `lora_input_adapters_path` | Specifies a parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled) |
| `force_unroll` | If `True`, unrolls the loop when generating a parameter-only checkpoint. | `boolean` | `False` |
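
As a sketch of how the flags above combine (the bucket, run name, and step values below are placeholders for illustration, not values from this guide):

```shell
# Hypothetical invocation sketch: save an async checkpoint every 1000 steps,
# and warm-start from a parameter-only checkpoint of an earlier run.
python3 -m maxtext.trainers.pre_train.train \
  run_name=<run-name> \
  base_output_directory=gs://<your-bucket> \
  enable_checkpointing=True \
  async_checkpointing=True \
  checkpoint_period=1000 \
  load_parameters_path=gs://<your-bucket>/<previous-run>/checkpoints/items/1000
```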

## Storage and format configuration

These settings control the underlying storage mechanism
([Orbax](https://orbax.readthedocs.io)) for performance and compatibility.

Flag | Description | Type | Default
:----------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------- | :------
`checkpoint_storage_target_data_file_size_bytes` | Sets a target file size for Orbax to chunk large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer` | `2147483648` (2 GB)
`checkpoint_storage_use_ocdbt` | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree)) key-value store as the underlying storage format for checkpointing. Set to `0` for Pathways. | `boolean` | `True`
`checkpoint_storage_use_zarr3` | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. Set to `0` for Pathways. | `boolean` | `True`
`checkpoint_storage_concurrent_gb` | Controls the concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer` | `96`
`enable_orbax_v1` | A boolean flag to explicitly enable features and behaviors from Orbax version 1. | `boolean` | `False`
`source_checkpoint_layout` | Specifies the format of the checkpoint being **loaded**. This tells the system how to interpret the files at the source path.<br>**Options**: `"orbax"`, `"safetensors"` | `string` | `"orbax"`
`checkpoint_conversion_fn` | A user-defined function to process a loaded checkpoint dictionary into a format that the model can understand. This is essential for loading checkpoints from different frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None`
| Flag | Description | Type | Default |
| :----------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------- | :------------------ |
| `checkpoint_storage_target_data_file_size_bytes` | Sets a target file size for Orbax to chunk large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer` | `2147483648` (2 GB) |
| `checkpoint_storage_use_ocdbt`                   | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree) key-value store as the underlying storage format for checkpointing. Set to `0` for Pathways.                                                                          | `boolean`            | `True`              |
| `checkpoint_storage_use_zarr3` | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. Set to `0` for Pathways. | `boolean` | `True` |
| `checkpoint_storage_concurrent_gb` | Controls the concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer` | `96` |
| `enable_orbax_v1` | A boolean flag to explicitly enable features and behaviors from Orbax version 1. | `boolean` | `False` |
| `source_checkpoint_layout` | Specifies the format of the checkpoint being **loaded**. This tells the system how to interpret the files at the source path.<br>**Options**: `"orbax"`, `"safetensors"` | `string` | `"orbax"` |
| `checkpoint_conversion_fn` | A user-defined function to process a loaded checkpoint dictionary into a format that the model can understand. This is essential for loading checkpoints from different frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None` |
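
For example, the table notes that OCDBT and Zarr3 should be set to `0` for Pathways; a hedged sketch of that override (run name, bucket, and the concurrency value are illustrative placeholders):

```shell
# Sketch: disable OCDBT and Zarr3 (e.g. for Pathways) and raise the
# concurrent I/O budget for a larger model. Values are illustrative only.
python3 -m maxtext.trainers.pre_train.train \
  run_name=<run-name> \
  base_output_directory=gs://<your-bucket> \
  enable_checkpointing=True \
  checkpoint_storage_use_ocdbt=0 \
  checkpoint_storage_use_zarr3=0 \
  checkpoint_storage_concurrent_gb=128
```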
4 changes: 2 additions & 2 deletions docs/guides/data_input_pipeline/data_input_grain.md
@@ -110,10 +110,10 @@ Note that `FILE_PATH` is optional; when provided, the script runs `ls -R` for pr

```sh
bash src/dependencies/scripts/setup_gcsfuse.sh \
DATASET_GCS_BUCKET=maxtext-dataset \
DATASET_GCS_BUCKET=gs://<your-dataset-bucket> \
MOUNT_PATH=/tmp/gcsfuse && \
python3 -m maxtext.trainers.pre_train.train \
run_name=<RUN_NAME> base_output_directory=gs://<MY_BUCKET> \
run_name=<run-name> base_output_directory=gs://<your-bucket> \
dataset_type=grain \
  grain_file_type=arrayrecord `# or parquet` \
grain_train_files=/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record* \
@@ -87,7 +87,7 @@ export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 -m maxtext.trainers.pre_train.train run_name=example_load_compile \
compiled_trainstep_file=my_compiled_train.pickle \
global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3 \
base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
base_output_directory=gs://<your-output-bucket> dataset_path=gs://<your-dataset-bucket>
```

In the save step of example 2 above we included exporting the compiler flag `LIBTPU_INIT_ARGS` and `learning_rate` because those affect the compiled object `my_compiled_train.pickle`. The sizes of the model (e.g. `global_parameter_scale`, `max_sequence_length`, and `per_device_batch`) are fixed when you initially compile via `compile_train.py`; if you try to run the saved compiled object with different sizes than you compiled with, you will see a size error. A subtle note is that the **learning rate schedule** is also fixed when you run `compile_train`, since it is determined by both `steps` and `learning_rate`. Optimizer parameters such as `adam_b1` are passed to the compiler only as shaped objects, so their real values are determined when you run `train.py`, not during compilation. If you pass in different shapes (e.g. `per_device_batch`), you will get a clear error message reporting that the compiled signature expects different shapes than what was input. If you attempt to run on hardware other than the compilation target requested via `compile_topology`, you will get an error saying the devices from the compiled object cannot be mapped to your real devices. Using different XLA flags or a different LIBTPU than what you compiled with will probably run silently without error, but there is no guaranteed behavior in that case; you should run in the same environment you compiled in.
@@ -125,7 +125,7 @@ export XLA_FLAGS="--xla_gpu_enable_async_collectives=true"
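
The load step above assumes a compiled object was saved first. A hedged sketch of that save step follows; the module path and the `compile_topology` value are assumptions inferred from the text's `compile_train.py` and `compile_topology` references, and may differ in your checkout:

```shell
# Hypothetical save step. Sizes, steps, and learning_rate must match the
# later load step exactly, and the compiler flags must match the run
# environment, because all of these are baked into the compiled object.
export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 -m maxtext.compile_train \
  compiled_trainstep_file=my_compiled_train.pickle \
  compile_topology=<target-topology> \
  global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3
```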
python3 -m maxtext.trainers.pre_train.train run_name=example_load_compile \
compiled_trainstep_file=my_compiled_train.pickle \
attention=dot_product global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3 \
base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
base_output_directory=gs://<your-output-bucket> dataset_path=gs://<your-dataset-bucket>
```

As in the TPU case, note that the compilation environment must match the execution environment, in this case by setting the same `XLA_FLAGS`.
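
A minimal sketch of keeping the two environments in step (the `XLA_FLAGS` value comes from the excerpt above; the compile module path and elided arguments are illustrative assumptions):

```shell
# Define the flags once and reuse them for both the compile step and the run
# step, so the compiled object and the execution environment cannot drift.
export XLA_FLAGS="--xla_gpu_enable_async_collectives=true"
python3 -m maxtext.compile_train \
  compiled_trainstep_file=my_compiled_train.pickle <other-args>   # compile
python3 -m maxtext.trainers.pre_train.train \
  compiled_trainstep_file=my_compiled_train.pickle <other-args>   # run
```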
@@ -37,8 +37,8 @@ MaxText has integrated the ML Diagnostics [SDK](https://github.com/AI-Hypercompu
```
python3 -m maxtext.trainers.pre_train.train \
run_name=${USER}-tpu-job \
base_output_directory="gs://your-output-bucket/" \
dataset_path="gs://your-dataset-bucket/" \
base_output_directory="gs://<your-output-bucket>/" \
dataset_path="gs://<your-dataset-bucket>/" \
steps=100 \
log_period=10 \
managed_mldiagnostics=True
@@ -49,8 +49,8 @@ MaxText has integrated the ML Diagnostics [SDK](https://github.com/AI-Hypercompu
```
python3 -m maxtext.trainers.pre_train.train \
run_name=${USER}-tpu-job \
base_output_directory="gs://your-output-bucket/" \
dataset_path="gs://your-dataset-bucket/" \
base_output_directory="gs://<your-output-bucket>/" \
dataset_path="gs://<your-dataset-bucket>/" \
steps=100 \
log_period=10 \
profiler=xplane \
@@ -62,8 +62,8 @@ MaxText has integrated the ML Diagnostics [SDK](https://github.com/AI-Hypercompu
```
python3 -m maxtext.trainers.pre_train.train \
run_name=${USER}-tpu-job \
base_output_directory="gs://your-output-bucket/" \
dataset_path="gs://your-dataset-bucket/" \
base_output_directory="gs://<your-output-bucket>/" \
dataset_path="gs://<your-dataset-bucket>/" \
steps=100 \
log_period=10 \
profiler=xplane \