From 9de53213b38a87d1ed9f9645ea0604e848d1885d Mon Sep 17 00:00:00 2001 From: Gagik Amirkhanyan Date: Sun, 26 Apr 2026 20:31:38 +0000 Subject: [PATCH] Add dataset type olmo_grain for AI2 OLMo numpy pretrain mixes. --- docs/guides/data_input_pipeline/olmo_grain.md | 76 +++ scripts/run_olmo3_7b_grain_resume_test.sh | 135 +++++ scripts/run_olmo3_7b_grain_smoke.sh | 96 ++++ src/maxtext/configs/base.yml | 8 + src/maxtext/configs/types.py | 22 + .../input_pipeline_interface.py | 5 +- src/maxtext/input_pipeline/olmo_data.py | 415 +++++++++++++++ src/maxtext/input_pipeline/olmo_data_grain.py | 497 ++++++++++++++++++ .../olmo_grain_data_processing.py | 195 +++++++ tests/unit/input_pipeline/__init__.py | 0 .../olmo_data_grain_resume_test.py | 195 +++++++ .../input_pipeline/olmo_data_grain_test.py | 353 +++++++++++++ tests/unit/input_pipeline/olmo_data_test.py | 350 ++++++++++++ tools/data_generation/build_olmo_npy_index.py | 278 ++++++++++ .../download_olmo_data_to_gcs.py | 331 ++++++++++++ 15 files changed, 2955 insertions(+), 1 deletion(-) create mode 100644 docs/guides/data_input_pipeline/olmo_grain.md create mode 100755 scripts/run_olmo3_7b_grain_resume_test.sh create mode 100755 scripts/run_olmo3_7b_grain_smoke.sh create mode 100644 src/maxtext/input_pipeline/olmo_data.py create mode 100644 src/maxtext/input_pipeline/olmo_data_grain.py create mode 100644 src/maxtext/input_pipeline/olmo_grain_data_processing.py create mode 100644 tests/unit/input_pipeline/__init__.py create mode 100644 tests/unit/input_pipeline/olmo_data_grain_resume_test.py create mode 100644 tests/unit/input_pipeline/olmo_data_grain_test.py create mode 100644 tests/unit/input_pipeline/olmo_data_test.py create mode 100755 tools/data_generation/build_olmo_npy_index.py create mode 100755 tools/data_generation/download_olmo_data_to_gcs.py diff --git a/docs/guides/data_input_pipeline/olmo_grain.md b/docs/guides/data_input_pipeline/olmo_grain.md new file mode 100644 index 0000000000..8b519c1cd2 --- /dev/null +++ b/docs/guides/data_input_pipeline/olmo_grain.md @@ -0,0 +1,76 @@ +# OLMo numpy pipeline (`dataset_type=olmo_grain`) + +Grain-based input pipeline for AI2's pre-tokenized OLMo data mixes (e.g. +`OLMo-mix-0925-official.txt`). Reads headerless flat `.npy` token streams +from a gcsfuse mount, shards across hosts, optionally masks repeated-n-gram +instances, and yields the shapes the MaxText pretrain trainer expects. + +## Quick start + +1. **Download the data** to a GCS bucket: + + ```bash + python tools/data_generation/download_olmo_data_to_gcs.py \ + --mix-file /path/to/OLMo-mix-0925-official.txt \ + --gcs-dest gs://my-bucket/dataset/ \ + --staging-dir /mnt/local-ssd/olmo-staging \ + --workers 16 + ``` + +2. **Mount it read-only** with gcsfuse (`np.memmap` needs a local path): + + ```bash + gcsfuse --implicit-dirs --o ro my-bucket /mnt/olmo-readonly + ``` + +3. **Build the index**: + + ```bash + python tools/data_generation/build_olmo_npy_index.py \ + --mix-file /path/to/OLMo-mix-0925-official.txt \ + --gcs-base gs://my-bucket/dataset/ \ + --tokenizer allenai/dolma3-tokenizer \ + --sequence-length 8192 \ + --output /path/to/olmo_index_seq8192.json + ``` + +4. 
**Configure + run** the trainer: + + ```yaml + dataset_type: olmo_grain + olmo_index_path: /path/to/olmo_index_seq8192.json + olmo_path_remap_from: "gs://my-bucket/" + olmo_path_remap_to: "/mnt/olmo-readonly/" + max_target_length: 8192 # must equal index sequence_length + tokenizer_type: huggingface + tokenizer_path: allenai/Olmo-3-7B-Instruct + ``` + + See `scripts/run_olmo3_7b_grain_smoke.sh` for a runnable smoke launcher. + +## Resume + +Stateless sampler: record at step *k* is a pure function of `(seed, shard, k)`. On startup, the trainer adapter reads the latest step from +`config.checkpoint_dir` and shifts the sampler so the data stream picks +up where it left off — no Grain-iterator-state in the checkpoint. + +`scripts/run_olmo3_7b_grain_resume_test.sh` validates this end-to-end. + +## Notes + +- Files are headerless raw uint32 by default (matches AI2's published + format). The numpy `.npy` extension is misleading. +- Documents may span instance boundaries; this matches OLMo-core. +- `olmo_apply_ngram_filter: True` (default) zeroes loss on instances with + ≥ 32 repetitions of any 1–13-gram, per OLMo-core. +- For mixing pretraining + midtraining, build a combined index by + concatenating the two .txt mix files. + +## Troubleshooting + +| Symptom | Fix | +| ------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------- | +| `OLMo index sequence_length=N but config.max_target_length=M` | Rebuild the index with `--sequence-length M`. | +| `q_block_size=512 should divide q_seq_len=…` | Set `max_target_length` to a multiple of 512. | +| OOM during compile on a small TPU | Shrink with `override_model_config=True base_num_decoder_layers=N`, use `weight_dtype=bfloat16`. | +| Resume restarts at step 0 | Iterator log should print `resumed_step=N initial_step=…`; if both 0, `checkpoint_dir` is empty or wrong. | diff --git a/scripts/run_olmo3_7b_grain_resume_test.sh b/scripts/run_olmo3_7b_grain_resume_test.sh new file mode 100755 index 0000000000..8854b94a6b --- /dev/null +++ b/scripts/run_olmo3_7b_grain_resume_test.sh @@ -0,0 +1,135 @@ +#!/bin/bash +# End-to-end resume test for the OLMo grain pipeline (stateless sampler + +# step-derived initial_step). See scripts/run_olmo3_7b_grain_smoke.sh for +# the env-var contract; this script accepts the same vars. +# +# Plan: +# Run A: train 50 steps from scratch, save checkpoint at step 50, exit. +# Run B: relaunch with the SAME run_name (so the checkpoint dir is reused). +# The trainer restores model state at step 50; our iterator factory +# detects the latest checkpoint step and sets ``initial_step`` so +# the data stream picks up at absolute position 50 * per_host_batch. +# Train 25 more steps (to step 75). +# +# What success looks like: +# * Run B's first step (step 51) reports a loss similar to Run A's step 50 +# loss. A spike or jump → model state didn't restore. +# * No repeats: Run B's batches are NOT the same as Run A's batches at the +# same absolute step. (Hard to assert without batch-content hashing in +# the trainer; for the smoke we rely on the unit tests + loss continuity.) +# * No regression: Run B's loss continues to decrease. +# +# Outputs: +# ${LOG_A} — first 50 steps +# ${LOG_B} — resumed 25 steps +# $OUTPUT_DIR//checkpoints/ — Orbax checkpoint(s) + +set -euo pipefail + +MAXTEXT_ROOT="$(cd "$(dirname "$0")/.." 
&& pwd)" +VENV_PATH="${VENV_PATH:-${MAXTEXT_ROOT}/maxtext_venv}" +HF_SECRETS="${HF_SECRETS:-}" +INDEX_PATH="${INDEX_PATH:?INDEX_PATH is required (path to olmo index JSON)}" +GCS_BASE="${GCS_BASE:?GCS_BASE is required (e.g. gs://my-bucket/)}" +LOCAL_MOUNT="${LOCAL_MOUNT:?LOCAL_MOUNT is required (gcsfuse mount path of GCS_BASE)}" +OUTPUT_DIR="${OUTPUT_DIR:-/tmp/olmo_resume_test_out}" +RUN_NAME="${RUN_NAME:-olmo_resume_$(date +%Y%m%d-%H%M%S)}" + +# Where each run's stdout is teed. Keep them under OUTPUT_DIR so the +# script doesn't depend on a hard-coded absolute path. +LOG_A="${LOG_A:-${OUTPUT_DIR}/${RUN_NAME}.runA.log}" +LOG_B="${LOG_B:-${OUTPUT_DIR}/${RUN_NAME}.runB.log}" + +PER_DEVICE_BATCH="${PER_DEVICE_BATCH:-1}" +SEQ_LEN="${SEQ_LEN:-8192}" +WEIGHT_DTYPE="${WEIGHT_DTYPE:-bfloat16}" +NUM_LAYERS="${NUM_LAYERS:-4}" +DATA_SEED="${DATA_SEED:-42}" + +# Run A trains 50 steps + saves a checkpoint at step 50; Run B continues to 75. +STEPS_A="${STEPS_A:-50}" +STEPS_B="${STEPS_B:-75}" +CHECKPOINT_PERIOD="${CHECKPOINT_PERIOD:-50}" + +# shellcheck disable=SC1090,SC1091 +source "${VENV_PATH}/bin/activate" +if [[ -n "${HF_SECRETS:-}" && -f "${HF_SECRETS}" ]]; then + # shellcheck disable=SC1090 + source "${HF_SECRETS}" +fi +: "${HF_TOKEN:?HF_TOKEN must be set (or HF_SECRETS pointing at a file that exports it)}" +export PYTHONPATH="${MAXTEXT_ROOT}/src:${PYTHONPATH:-}" +export PYTHONUNBUFFERED=1 + +mkdir -p "${OUTPUT_DIR}" + +TOKENIZER_PATH="${TOKENIZER_PATH:-allenai/Olmo-3-7B-Instruct}" + +run_train() { + local steps="$1" + local logfile="$2" + echo "----- launching: steps=${steps} log=${logfile} -----" + python -m maxtext.trainers.pre_train.train \ + "${MAXTEXT_ROOT}/src/maxtext/configs/base.yml" \ + model_name=olmo3-7b-pt \ + run_name="${RUN_NAME}" \ + base_output_directory="${OUTPUT_DIR}" \ + dataset_type=olmo_grain \ + olmo_index_path="${INDEX_PATH}" \ + olmo_path_remap_from="${GCS_BASE}" \ + olmo_path_remap_to="${LOCAL_MOUNT}" \ + data_shuffle_seed="${DATA_SEED}" \ + olmo_apply_ngram_filter=True \ + grain_worker_count=0 \ + per_device_batch_size="${PER_DEVICE_BATCH}" \ + max_target_length="${SEQ_LEN}" \ + steps="${steps}" \ + enable_checkpointing=True \ + async_checkpointing=False \ + checkpoint_period="${CHECKPOINT_PERIOD}" \ + save_checkpoint_on_completion=True \ + tokenizer_type=huggingface \ + tokenizer_path="${TOKENIZER_PATH}" \ + weight_dtype="${WEIGHT_DTYPE}" \ + override_model_config=True \ + base_num_decoder_layers="${NUM_LAYERS}" \ + sharding_tolerance=0.05 \ + 2>&1 | tee "${logfile}" +} + +echo "=== OLMo 3 grain resume test ===" +echo " run_name : ${RUN_NAME}" +echo " output_dir : ${OUTPUT_DIR}/${RUN_NAME}" +echo " per_device_bs : ${PER_DEVICE_BATCH}" +echo " seq_len : ${SEQ_LEN}" +echo " num_layers : ${NUM_LAYERS}" +echo " Run A steps : ${STEPS_A} (will checkpoint at step ${CHECKPOINT_PERIOD})" +echo " Run B steps : ${STEPS_B} (resumed via initial_step)" +echo + +# Run A +run_train "${STEPS_A}" "${LOG_A}" + +echo +echo "=== Run A done. Last 3 step events: ===" +grep -E "completed step:" "${LOG_A}" | tail -3 +echo + +# Run B (resume) +run_train "${STEPS_B}" "${LOG_B}" + +echo +echo "=== Run B done ===" +echo "First 3 step events from Run B (expect step >= ${STEPS_A}):" +grep -E "completed step:" "${LOG_B}" | head -3 +echo +echo "Last 3 step events from Run B:" +grep -E "completed step:" "${LOG_B}" | tail -3 +echo + +echo "=== Pass criteria (manual check): ===" +echo " 1. Run B's first step number >= ${STEPS_A} (model state restored)" +echo " 2. 
Run B's first step loss within ~5% of Run A's last step loss" +echo " (model continued, no re-init)" +echo " 3. Loss continues to decrease across Run B" +echo " 4. iterator log line shows 'resumed_step=${STEPS_A} initial_step=...' on Run B" diff --git a/scripts/run_olmo3_7b_grain_smoke.sh b/scripts/run_olmo3_7b_grain_smoke.sh new file mode 100755 index 0000000000..d5bfad3937 --- /dev/null +++ b/scripts/run_olmo3_7b_grain_smoke.sh @@ -0,0 +1,96 @@ +#!/bin/bash +# Smoke training run for OLMo 3 7B on the OLMo numpy grain pipeline. +# +# Validates that dataset_type=olmo_grain wires through the trainer, that +# OlmoNpyDataSource reads .npy data via a gcsfuse mount, and that 50 steps +# execute without crashes / shape mismatches with monotonically decreasing +# loss. +# +# Required env vars: +# INDEX_PATH JSON index from tools/data_generation/build_olmo_npy_index.py +# GCS_BASE gs:// prefix recorded in the index (e.g. gs://my-bucket/) +# LOCAL_MOUNT gcsfuse mount of GCS_BASE on this host +# HF_TOKEN HuggingFace token for the tokenizer (or HF_SECRETS=) +# Optional: VENV_PATH, OUTPUT_DIR, PER_DEVICE_BATCH, SEQ_LEN, STEPS, +# WEIGHT_DTYPE, NUM_LAYERS. +# +# Usage: +# INDEX_PATH=/path/to/olmo_index_seq8192.json \ +# LOCAL_MOUNT=/mnt/your-mount/ \ +# GCS_BASE=gs://your-bucket/ \ +# HF_TOKEN=hf_... \ +# bash scripts/run_olmo3_7b_grain_smoke.sh + +set -euo pipefail + +MAXTEXT_ROOT="$(cd "$(dirname "$0")/.." && pwd)" + +VENV_PATH="${VENV_PATH:-${MAXTEXT_ROOT}/maxtext_venv}" +HF_SECRETS="${HF_SECRETS:-}" +INDEX_PATH="${INDEX_PATH:?INDEX_PATH is required (path to olmo index JSON)}" +GCS_BASE="${GCS_BASE:?GCS_BASE is required (e.g. gs://my-bucket/)}" +LOCAL_MOUNT="${LOCAL_MOUNT:?LOCAL_MOUNT is required (gcsfuse mount path of GCS_BASE)}" +OUTPUT_DIR="${OUTPUT_DIR:-/tmp/olmo_smoke_out}" + +PER_DEVICE_BATCH="${PER_DEVICE_BATCH:-1}" +SEQ_LEN="${SEQ_LEN:-8192}" +STEPS="${STEPS:-50}" +DATA_SEED="${DATA_SEED:-42}" +# Smoke test uses a reduced model (bf16, 4 layers) so it fits small TPU +# slices; we're validating the data path, not full-size convergence. +WEIGHT_DTYPE="${WEIGHT_DTYPE:-bfloat16}" +NUM_LAYERS="${NUM_LAYERS:-4}" + +RUN_NAME="${RUN_NAME:-olmo_grain_smoke_$(date +%Y%m%d-%H%M%S)}" + +# Activate venv + load HF secrets. +# shellcheck disable=SC1090,SC1091 +source "${VENV_PATH}/bin/activate" +if [[ -n "${HF_SECRETS:-}" && -f "${HF_SECRETS}" ]]; then + # shellcheck disable=SC1090 + source "${HF_SECRETS}" +fi +: "${HF_TOKEN:?HF_TOKEN must be set (or HF_SECRETS pointing at a file that exports it)}" +export PYTHONPATH="${MAXTEXT_ROOT}/src:${PYTHONPATH:-}" +export PYTHONUNBUFFERED=1 + +mkdir -p "${OUTPUT_DIR}" + +echo "=== OLMo 3 7B + olmo_grain smoke run ===" +echo " run_name : ${RUN_NAME}" +echo " index : ${INDEX_PATH}" +echo " path remap : ${GCS_BASE} → ${LOCAL_MOUNT}" +echo " per_device_bs : ${PER_DEVICE_BATCH}" +echo " seq_len : ${SEQ_LEN}" +echo " steps : ${STEPS}" +echo " weight_dtype : ${WEIGHT_DTYPE}" +echo " num_layers : ${NUM_LAYERS} (full 7B has 32)" +echo " output_dir : ${OUTPUT_DIR}" +echo + +# Data is already tokenized; the tokenizer is loaded only for pad/eos IDs + +# vocab_size checks. Olmo-3-7B-Instruct uses the same dolma3 tokenizer. 
+TOKENIZER_PATH="${TOKENIZER_PATH:-allenai/Olmo-3-7B-Instruct}" + +python -m maxtext.trainers.pre_train.train \ + "${MAXTEXT_ROOT}/src/maxtext/configs/base.yml" \ + model_name=olmo3-7b-pt \ + run_name="${RUN_NAME}" \ + base_output_directory="${OUTPUT_DIR}" \ + dataset_type=olmo_grain \ + olmo_index_path="${INDEX_PATH}" \ + olmo_path_remap_from="${GCS_BASE}" \ + olmo_path_remap_to="${LOCAL_MOUNT}" \ + data_shuffle_seed="${DATA_SEED}" \ + olmo_apply_ngram_filter=True \ + grain_worker_count=0 \ + per_device_batch_size="${PER_DEVICE_BATCH}" \ + max_target_length="${SEQ_LEN}" \ + steps="${STEPS}" \ + enable_checkpointing=False \ + tokenizer_type=huggingface \ + tokenizer_path="${TOKENIZER_PATH}" \ + weight_dtype="${WEIGHT_DTYPE}" \ + override_model_config=True \ + base_num_decoder_layers="${NUM_LAYERS}" \ + sharding_tolerance=0.05 diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 39ef5b2dee..7671414fd3 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -724,6 +724,14 @@ grain_shuffle_buffer_size: 100 # shuffle buffer when using sequential access for # for using pathways colocated_python_data_input: False # experimental feature, under testing +# OLMo numpy pipeline (dataset_type=olmo_grain). Worker count, buffer size, +# and shuffle seed reuse grain_worker_count / grain_per_worker_buffer_size / +# data_shuffle_seed. +olmo_index_path: '' # JSON from tools/data_generation/build_olmo_npy_index.py +olmo_path_remap_from: '' # rewrite index paths starting with this prefix... +olmo_path_remap_to: '' # ...to this one (e.g. gs://bucket/ -> /mnt/.../ for gcsfuse). +olmo_apply_ngram_filter: True # mask instances with repetitive n-grams (OLMo-core filter) + # Training loop steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps log_period: 100 # The frequency of Tensorboard flush, gcs metrics writing, and managed profiler metrics updating. diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 43338ed507..9c1277aae7 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -176,6 +176,7 @@ class DatasetType(str, Enum): GRAIN = "grain" TFDS = "tfds" C4MLPERF = "c4_mlperf" + OLMO_GRAIN = "olmo_grain" class SamplingStrategy(str, Enum): @@ -1128,6 +1129,26 @@ class GrainDataset(BaseModel): grain_shuffle_buffer_size: int = Field(100, description="Shuffle buffer size when using Parquet or TFRecord.") +class OlmoGrainDataset(BaseModel): + """Configuration for the OLMo numpy fixed-seq-length input pipeline (dataset_type=olmo_grain). + + Worker count, per-worker buffer size, and shuffle seed reuse the standard + grain flags (``grain_worker_count``, ``grain_per_worker_buffer_size``, + ``data_shuffle_seed``); only OLMo-specific fields are listed here. + """ + + olmo_index_path: PathStr = Field("", description="Path or gs:// URI to the JSON index from build_olmo_npy_index.py.") + olmo_path_remap_from: PathStr = Field( + "", + description="If set, rewrite index file paths starting with this prefix to olmo_path_remap_to.", + ) + olmo_path_remap_to: PathStr = Field( + "", + description="Replacement prefix used together with olmo_path_remap_from (e.g. 
/mnt/disks/.../).", + ) + olmo_apply_ngram_filter: bool = Field(True, description="Mask repetitive instances per OLMo-core's repetition filter.") + + class FineTuning(BaseModel): """Configuration for fine-tuning methods like DPO, SFT, and GRPO.""" @@ -2092,6 +2113,7 @@ class MaxTextConfig( TfdsDataset, HfDataset, GrainDataset, + OlmoGrainDataset, Tokenizer, # Inference InferenceGeneral, diff --git a/src/maxtext/input_pipeline/input_pipeline_interface.py b/src/maxtext/input_pipeline/input_pipeline_interface.py index ac37a7bdda..0bef3763c9 100644 --- a/src/maxtext/input_pipeline/input_pipeline_interface.py +++ b/src/maxtext/input_pipeline/input_pipeline_interface.py @@ -23,6 +23,8 @@ from maxtext.input_pipeline.grain_data_processing import make_grain_eval_iterator from maxtext.input_pipeline.hf_data_processing import make_hf_train_iterator from maxtext.input_pipeline.hf_data_processing import make_hf_eval_iterator +from maxtext.input_pipeline.olmo_grain_data_processing import make_olmo_grain_train_iterator +from maxtext.input_pipeline.olmo_grain_data_processing import make_olmo_grain_eval_iterator from maxtext.input_pipeline.tfds_data_processing import make_tfds_train_iterator from maxtext.input_pipeline.tfds_data_processing import make_tfds_eval_iterator from maxtext.input_pipeline.tfds_data_processing_c4_mlperf import make_c4_mlperf_train_iterator @@ -71,10 +73,11 @@ def create_data_iterator(config: pyconfig.HyperParameters, mesh): "grain": (make_grain_train_iterator, make_grain_eval_iterator), "hf": (make_hf_train_iterator, make_hf_eval_iterator), "c4_mlperf": (make_c4_mlperf_train_iterator, make_c4_mlperf_eval_iterator), + "olmo_grain": (make_olmo_grain_train_iterator, make_olmo_grain_eval_iterator), } # Collect train and eval iterators - if config.dataset_type in ["tfds", "grain", "hf", "c4_mlperf"]: + if config.dataset_type in ["tfds", "grain", "hf", "c4_mlperf", "olmo_grain"]: if config.dataset_type == "c4_mlperf": assert config.packing, "c4_mlperf dataloader only works with packing. For padded version, use tfds dataloader" train_iterator, eval_iterator = dataset_type_to_train_eval_iterator[config.dataset_type] diff --git a/src/maxtext/input_pipeline/olmo_data.py b/src/maxtext/input_pipeline/olmo_data.py new file mode 100644 index 0000000000..96151035a0 --- /dev/null +++ b/src/maxtext/input_pipeline/olmo_data.py @@ -0,0 +1,415 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared utilities for OLMo-core-style numpy FSL datasets. + +Dependency-free layer. AI2's mix files describe a virtual concatenation of +flat token-ID arrays; instances are non-overlapping ``sequence_length``-token +windows of that stream. This module builds the index that maps a global +instance index to (file, byte-offset), and ports OLMo-core's repeated-n-gram +filter (``olmo_core/data/utils.py::find_periodic_sequences``). 
+""" + +from __future__ import annotations + +import ast +import bisect +import dataclasses +import hashlib +import io +import json +import os +import struct +from dataclasses import dataclass, field +from typing import Generator, List, NamedTuple, Optional, Sequence, Tuple + +import numpy as np + + +# Bumped whenever the on-disk index format or fingerprint inputs change. +INDEX_FORMAT_VERSION = "1" + + +@dataclass(frozen=True) +class OlmoNpyFileEntry: + """One file in the mix: ``n_tokens // sequence_length`` instances starting + at global index ``instance_offset``. Trailing tokens are dropped (matches + OLMo-core).""" + + path: str + label: str + n_tokens: int + n_instances: int + instance_offset: int + + +@dataclass +class OlmoNpyIndex: + """Index over the files in an OLMo data mix. Build via + :func:`build_index`, persist via :meth:`save`, restore via + :func:`load_index`. Mutating fields invalidates :attr:`fingerprint`.""" + + format_version: str + sequence_length: int + dtype: str # numpy dtype string, e.g. "uint32" + tokenizer: str # informational, e.g. "allenai/dolma3-tokenizer" + files: Tuple[OlmoNpyFileEntry, ...] + total_instances: int + total_tokens: int + fingerprint: str = "" + + # Lazily computed bisect helper: cumulative instance_offsets+sentinel for + # binary search in ``global_to_local``. Not serialized. + _instance_offset_starts: Optional[List[int]] = field(default=None, repr=False, compare=False) + + def __post_init__(self): + starts = [f.instance_offset for f in self.files] + starts.append(self.total_instances) # sentinel for bisect + object.__setattr__(self, "_instance_offset_starts", starts) + + def to_json_dict(self) -> dict: + """Return a JSON-serializable view (drops cached lookup helpers).""" + return { + "format_version": self.format_version, + "sequence_length": self.sequence_length, + "dtype": self.dtype, + "tokenizer": self.tokenizer, + "total_instances": self.total_instances, + "total_tokens": self.total_tokens, + "fingerprint": self.fingerprint, + "files": [dataclasses.asdict(f) for f in self.files], + } + + def save(self, path: str) -> None: + """Write the index as JSON to ``path`` (local filesystem).""" + with open(path, "w", encoding="utf-8") as fh: + json.dump(self.to_json_dict(), fh, indent=2) + + +def load_index(path: str) -> OlmoNpyIndex: + """Load an index from JSON written by :meth:`OlmoNpyIndex.save`. + + Args: + path: Local filesystem path to the JSON file. + + Returns: + The materialized :class:`OlmoNpyIndex`. + + Raises: + ValueError: If ``format_version`` doesn't match this code's expectation. + """ + with open(path, encoding="utf-8") as fh: + data = json.load(fh) + if data.get("format_version") != INDEX_FORMAT_VERSION: + raise ValueError( + f"Index format version mismatch: file has " + f"{data.get('format_version')!r}, code expects {INDEX_FORMAT_VERSION!r}." + ) + files = tuple(OlmoNpyFileEntry(**entry) for entry in data["files"]) + return OlmoNpyIndex( + format_version=data["format_version"], + sequence_length=data["sequence_length"], + dtype=data["dtype"], + tokenizer=data["tokenizer"], + files=files, + total_instances=data["total_instances"], + total_tokens=data["total_tokens"], + fingerprint=data["fingerprint"], + ) + + +def global_to_local(index: OlmoNpyIndex, instance_id: int) -> Tuple[int, int]: + """Global instance index → ``(file_idx, token_offset)``. + + ``token_offset`` is in *tokens* (not bytes); the slice + ``arr[token_offset : token_offset + sequence_length]`` is the instance. 
+ """ + if instance_id < 0 or instance_id >= index.total_instances: + raise IndexError(f"instance_id {instance_id} out of range " f"[0, {index.total_instances})") + starts = index._instance_offset_starts # type: ignore[attr-defined] # pylint: disable=protected-access + file_idx = bisect.bisect_right(starts, instance_id) - 1 + local_instance = instance_id - index.files[file_idx].instance_offset + token_offset = local_instance * index.sequence_length + return file_idx, token_offset + + +def compute_fingerprint( + sequence_length: int, + dtype: str, + tokenizer: str, + files: Sequence[OlmoNpyFileEntry], +) -> str: + """Stable hash over the fields a restart must preserve. + + If any of these change, the global instance ordering changes and resuming + training from a checkpoint would silently produce different batches. + """ + h = hashlib.sha256() + h.update(INDEX_FORMAT_VERSION.encode("utf-8")) + h.update(b"\x00") + h.update(str(sequence_length).encode("utf-8")) + h.update(b"\x00") + h.update(dtype.encode("utf-8")) + h.update(b"\x00") + h.update(tokenizer.encode("utf-8")) + h.update(b"\x00") + for f in files: + h.update(f.path.encode("utf-8")) + h.update(b"\x00") + h.update(str(f.n_tokens).encode("utf-8")) + h.update(b"\x00") + return f"sha256:{h.hexdigest()}" + + +_NPY_MAGIC = b"\x93NUMPY" + + +def parse_npy_header(stream: io.RawIOBase) -> Tuple[str, Tuple[int, ...]]: + """Parse a .npy v1/v2/v3 header. Returns ``(dtype_str, shape)``.""" + magic = stream.read(6) + if magic != _NPY_MAGIC: + raise ValueError(f"Not a .npy file (magic={magic!r})") + major = stream.read(1)[0] + stream.read(1) # minor version byte — unused + if major == 1: + header_len = struct.unpack(" "uint32" + return dtype_str, shape + + +def read_npy_header_from_path(path: str) -> Tuple[str, Tuple[int, ...]]: + """Convenience wrapper for :func:`parse_npy_header` on a local file.""" + with open(path, "rb") as fh: + return parse_npy_header(fh) + + +def read_raw_metadata_from_path(path: str, dtype: str) -> Tuple[str, Tuple[int, ...]]: + """Headerless raw binary: ``n_tokens = file_size // itemsize``. + + AI2's ``.npy``-extension files are actually raw uint32 dumps, no header; + olmo-core reads them with ``np.memmap`` and a known dtype. + """ + itemsize = np.dtype(dtype).itemsize + size_bytes = os.path.getsize(path) + if size_bytes % itemsize != 0: + raise ValueError( + f"File size {size_bytes} of {path} is not a multiple of itemsize " + f"{itemsize} for dtype {dtype}; this is unexpected." + ) + return dtype, (size_bytes // itemsize,) + + +def has_npy_magic(first_bytes: bytes) -> bool: + """Quick check: does this look like a real .npy file?""" + return len(first_bytes) >= 6 and first_bytes[:6] == _NPY_MAGIC + + +def _file_entry_from_header( + path: str, + label: str, + dtype: str, + shape: Tuple[int, ...], + sequence_length: int, + instance_offset: int, +) -> OlmoNpyFileEntry: + """Build a file entry from a parsed .npy header (validates shape is 1-D).""" + if len(shape) != 1: + raise ValueError(f"Expected 1-D .npy array for {path}, got shape {shape}.") + n_tokens = int(shape[0]) + n_instances = n_tokens // sequence_length + return OlmoNpyFileEntry( + path=path, + label=label, + n_tokens=n_tokens, + n_instances=n_instances, + instance_offset=instance_offset, + ) + + +def build_index( + paths_and_labels: Sequence[Tuple[str, str]], + sequence_length: int, + *, + tokenizer: str, + header_reader=read_npy_header_from_path, +) -> OlmoNpyIndex: + """Build an :class:`OlmoNpyIndex` from ``(path, label)`` entries. 
+ + Order matters — global instance ordering is the concatenation in this + order. ``header_reader`` is the seam tests use to avoid disk; production + paths pass a GCS-aware reader. + """ + if sequence_length <= 0: + raise ValueError(f"sequence_length must be positive, got {sequence_length}") + if not paths_and_labels: + raise ValueError("paths_and_labels must be non-empty") + + entries: List[OlmoNpyFileEntry] = [] + observed_dtype: Optional[str] = None + cum_offset = 0 + for path, label in paths_and_labels: + dtype, shape = header_reader(path) + if observed_dtype is None: + observed_dtype = dtype + elif dtype != observed_dtype: + raise ValueError(f"Heterogeneous dtypes across mix files: {observed_dtype!r} " f"and {dtype!r} (at {path}).") + entry = _file_entry_from_header( + path=path, + label=label, + dtype=dtype, + shape=shape, + sequence_length=sequence_length, + instance_offset=cum_offset, + ) + entries.append(entry) + cum_offset += entry.n_instances + + files = tuple(entries) + total_instances = cum_offset + total_tokens = sum(f.n_tokens for f in files) + + fingerprint = compute_fingerprint( + sequence_length=sequence_length, + dtype=observed_dtype or "", + tokenizer=tokenizer, + files=files, + ) + + return OlmoNpyIndex( + format_version=INDEX_FORMAT_VERSION, + sequence_length=sequence_length, + dtype=observed_dtype or "", + tokenizer=tokenizer, + files=files, + total_instances=total_instances, + total_tokens=total_tokens, + fingerprint=fingerprint, + ) + + +class RepetitionTuple(NamedTuple): + """``arr[start:end]`` is a periodic span of length ``period``, + ``times = (end - start) // period``.""" + + start: int + end: int + period: int + times: int + + +def _find_end_first_consecutive_true(arr: np.ndarray) -> int: + """End offset (exclusive) of the leading run of True in ``arr``. + + Returns 0 if ``arr[0]`` is False, ``len(arr)`` if all True. + """ + if not arr[0]: + return 0 + prog = np.cumsum(arr) + if prog[-1] == len(arr): + return int(len(arr)) + # First index where the cumulative sum stops increasing == start of False run. + true_locs = np.where(prog[:-1:] == prog[1::])[0] + return int(true_locs[0] + 1) + + +def _find_start_last_consecutive_true(arr: np.ndarray) -> int: + """Start offset of the trailing run of True in ``arr``, or -1 if none.""" + reverse = _find_end_first_consecutive_true(arr[::-1]) + return len(arr) - reverse if reverse > 0 else -1 + + +def _group_consecutive_values(arr: np.ndarray, stepsize: int = 1) -> List[np.ndarray]: + """Split a 1-D array of ints into runs of consecutive values.""" + if len(arr) == 0: + return [] + return np.split(arr, np.where(np.diff(arr) != stepsize)[0] + 1) + + +def find_periodic_sequences( + arr: np.ndarray, + max_period: int, + min_period: int = 1, + mask_value: int = -1, +) -> Generator[RepetitionTuple, None, None]: + """Yield :class:`RepetitionTuple` for periodic spans of length ≥ 3 in + ``arr``. 
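+
+  Example (illustrative)::
+
+      list(find_periodic_sequences(np.array([9, 1, 2, 1, 2, 1, 2, 1, 2, 9]), max_period=3))
+      # -> [RepetitionTuple(start=1, end=9, period=2, times=4)]
+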
``mask_value`` must not appear in the array (used internally as + reshape padding).""" + if (arr == mask_value).sum() > 0: + raise ValueError("`mask_value` is in the array") + + max_period = min(max_period, len(arr) // 3) + + for period in range(min_period, max_period + 1): + pad = (period - (len(arr) % period)) % period + padded_arr = np.pad(arr, (0, pad), constant_values=mask_value) if pad else arr + shaped_arr = padded_arr.reshape(-1, period) + + is_equal_to_prev_row = shaped_arr == np.roll(shaped_arr, shift=1, axis=0) + rows_with_period = np.where(is_equal_to_prev_row.all(axis=1))[0] + if len(rows_with_period) == 0: + continue + + for sequence in _group_consecutive_values(rows_with_period): + start_row = int(sequence[0]) + end_row = int(sequence[-1]) + + start_offset = _find_start_last_consecutive_true(is_equal_to_prev_row[start_row - 1]) + start_offset = period - start_offset if start_offset > 0 else 0 + + end_offset = _find_end_first_consecutive_true(is_equal_to_prev_row[(end_row + 1) % shaped_arr.shape[0]]) + + start_pos = (start_row - 1) * period - start_offset + end_pos = ((end_row + 1) * period) + end_offset + + out = RepetitionTuple( + start=start_pos, + end=end_pos, + period=period, + times=(end_pos - start_pos) // period, + ) + if out.times > 2: + yield out + + +def is_clean_instance( + input_ids: np.ndarray, + *, + repetition_max_period: int = 13, + repetition_min_period: int = 1, + repetition_max_count: int = 32, + mask_value: int = -1, +) -> bool: + """``False`` iff ``input_ids`` has any periodic span (period ∈ + [min, max]) that repeats ≥ ``repetition_max_count`` times. Defaults + match OLMo-core's ``_validate_instance``.""" + for m in find_periodic_sequences( + input_ids, + max_period=repetition_max_period, + min_period=repetition_min_period, + mask_value=mask_value, + ): + if m.times >= repetition_max_count: + return False + return True diff --git a/src/maxtext/input_pipeline/olmo_data_grain.py b/src/maxtext/input_pipeline/olmo_data_grain.py new file mode 100644 index 0000000000..fe2ca062f8 --- /dev/null +++ b/src/maxtext/input_pipeline/olmo_data_grain.py @@ -0,0 +1,497 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OLMo numpy fixed-seq-length dataset on top of Grain. + +A Grain ``RandomAccessDataSource`` over the AI2 OLMo virtual token stream +plus a deterministic global-shuffle sampler. See +``docs/guides/data_input_pipeline/olmo_grain.md`` for an overview. +""" + +from __future__ import annotations + +import hashlib +import threading +from typing import Any, Dict, List, Optional + +import numpy as np + +import grain + +from maxtext.input_pipeline.olmo_data import ( + OlmoNpyIndex, + global_to_local, + is_clean_instance, +) + + +class OlmoNpyDataSource(grain.sources.RandomAccessDataSource): + """Random-access view of an OLMo numpy mix as a stream of token windows. + + Files are opened lazily and cached as ``np.memmap`` per worker. 
Open mmaps + are reference-counted by :class:`_MmapCache` so we don't blow past + ``ulimit -n`` when iterating over the full 950-file mix. + + The data source is **process-safe**: every Grain worker subprocess builds + its own ``_MmapCache`` after the fork. No shared mutable state. + + Args: + index: The :class:`OlmoNpyIndex` describing the mix. Path strings must be + reachable from the data-loading host (typically a GCSFUSE mount path + like ``/mnt//...``). + path_remap: Optional dict to rewrite ``index.files[i].path``. Useful when + the index was built with ``gs://`` paths and you want to read from a + gcsfuse mount, or vice versa. A path is rewritten if it starts with + any key in this dict. + max_open_files: Soft cap on the number of mmaps held open in the + per-worker cache. The cache is LRU. + """ + + def __init__( + self, + index: OlmoNpyIndex, + *, + path_remap: Optional[Dict[str, str]] = None, + max_open_files: int = 64, + ): + self._index = index + self._dtype = np.dtype(index.dtype) + self._sequence_length = index.sequence_length + self._path_remap = dict(path_remap or {}) + self._mmaps = _MmapCache(max_open_files=max_open_files) + + # ---- Grain's RandomAccessDataSource interface --------------------------- # + + def __len__(self) -> int: + return self._index.total_instances + + def __getitem__(self, instance_id: int) -> Dict[str, Any]: + file_idx, token_offset = global_to_local(self._index, instance_id) + file_entry = self._index.files[file_idx] + arr = self._mmaps.get(self._resolve_path(file_entry.path), self._dtype) + # Always copy: the memmap is opened read-only (mode="r"), and we need to + # hand a writable, picklable array back to Grain so transforms can + # mutate it freely and worker processes can serialize it without + # dragging the memmap object along. + tokens = np.array(arr[token_offset : token_offset + self._sequence_length], copy=True) + return { + "tokens": tokens, + "instance_id": int(instance_id), + "file_id": int(file_idx), + } + + # ---- Helpers ------------------------------------------------------------ # + + def _resolve_path(self, path: str) -> str: + for prefix, replacement in self._path_remap.items(): + if path.startswith(prefix): + return replacement + path[len(prefix) :] + return path + + def __getstate__(self): + # Mmap caches don't survive pickling; rebuild after unpickle. + state = self.__dict__.copy() + state["_mmaps"] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self._mmaps = _MmapCache(max_open_files=64) + + def __repr__(self) -> str: + """Stable repr — Grain compares ``repr(data_source)`` between the + checkpoint and the live source on resume. The default repr embeds the + object id, which breaks resume across separate construction. 
We hash + the index fingerprint + seq + path-remap instead so equivalent sources + compare equal as strings.""" + return ( + f"OlmoNpyDataSource(fingerprint={self._index.fingerprint!r}, " + f"seq={self._sequence_length}, dtype={self._dtype.str!r}, " + f"remap={sorted(self._path_remap.items())!r})" + ) + + +class _MmapCache: + """Tiny LRU-ish cache of open ``np.memmap`` handles.""" + + def __init__(self, max_open_files: int = 64): + self._max = max_open_files + self._mmaps: Dict[str, np.memmap] = {} + self._lock = threading.Lock() + + def get(self, path: str, dtype: np.dtype) -> np.memmap: + """Return a cached ``np.memmap`` for ``path``, opening it lazily.""" + with self._lock: + arr = self._mmaps.get(path) + if arr is not None: + # touch: re-insert at end for LRU ordering. + del self._mmaps[path] + self._mmaps[path] = arr + return arr + # Open the file as 1-D memmap; for raw .npy (no header) we need to know + # length, but np.memmap can derive it from file size when shape=(-1,). + arr = np.memmap(path, dtype=dtype, mode="r") + self._mmaps[path] = arr + while len(self._mmaps) > self._max: + # Evict oldest. + oldest = next(iter(self._mmaps)) + del self._mmaps[oldest] + return arr + + +class OlmoIndexSampler: + """Global-shuffle sampler over an OLMo numpy mix. + + Mirrors OLMo-core's :class:`NumpyDataLoaderBase` shuffle math: a single + Fisher-Yates over ``[0, total_instances)`` keyed by ``hash(seed, epoch)``, + then partitioned across ``shard_count`` hosts. + + Implements Grain's ``Sampler`` protocol — i.e. ``__getitem__`` returning + :class:`grain.python.RecordMetadata`. Grain calls + ``sampler[index]`` for each global step; the sampler is responsible for + mapping that to the actual record_key fed to ``data_source[record_key]``. + + Indexing semantics: + + * ``index`` here is a *per-host* (per-data-loader) global step counter + starting at 0 and advancing without bound (we support infinite epochs). + * ``epoch = index // num_local_instances_per_epoch`` selects which + permutation to use; ``in_epoch = index % num_local_instances_per_epoch`` + selects the position within this host's shard of that permutation. + + Checkpointing is trivial: the only mutable state is "which epoch's + permutation is currently cached" (a perf optimization). The user-visible + position is just the index passed to ``__getitem__``. + + Args: + total_instances: ``index.total_instances`` from the OLMo index. + seed: Base seed for the shuffle. + shard_index: Zero-based index of this data-loading host. Typically + ``jax.process_index()``. + shard_count: Number of data-loading hosts. Typically + ``jax.process_count()``. + shuffle: If ``False``, instances are emitted in linear order — useful + for debugging. + initial_step: Per-host batch step at which the *training run* should + resume. ``__getitem__(local_idx)`` returns the record at absolute + position ``local_idx + initial_step``. Use this to resume a run from + a saved trainer step without saving Grain's iterator state — our + sampler is a pure function of its inputs, so the (seed, shard, + absolute step) tuple fully determines the next record. 
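+
+  Example (illustrative sketch of the resume contract, single host)::
+
+      fresh = OlmoIndexSampler(total_instances=128, seed=7)
+      resumed = OlmoIndexSampler(total_instances=128, seed=7, initial_step=40)
+      assert resumed[0].record_key == fresh[40].record_key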
+ """ + + def __init__( + self, + *, + total_instances: int, + seed: int, + shard_index: int = 0, + shard_count: int = 1, + shuffle: bool = True, + initial_step: int = 0, + ): + if shard_count <= 0 or shard_index < 0 or shard_index >= shard_count: + raise ValueError(f"Invalid shard config: shard_index={shard_index} of {shard_count}") + if total_instances <= 0: + raise ValueError(f"total_instances must be positive, got {total_instances}") + if initial_step < 0: + raise ValueError(f"initial_step must be non-negative, got {initial_step}") + self._total = int(total_instances) + self._seed = int(seed) + self._shard_index = int(shard_index) + self._shard_count = int(shard_count) + self._shuffle = bool(shuffle) + self._initial_step = int(initial_step) + # Cache the shuffled-and-sharded indices for the most-recently-touched + # epoch. Cheap to recompute on epoch boundaries; expensive to keep many + # epochs resident at once for the full 724 M-instance mix. + self._cached_epoch: Optional[int] = None + self._cached_shard_indices: Optional[np.ndarray] = None + self._cache_lock = threading.Lock() + + # ---- Public API --------------------------------------------------------- # + + @property + def num_instances(self) -> int: + return self._total + + @property + def num_local_instances_per_epoch(self) -> int: + """Instances assigned to *this* host per epoch (drops trailing remainder).""" + return self._total // self._shard_count + + def shuffled_global_indices(self, *, seed: int, epoch: int) -> np.ndarray: + """Build the full shuffled list for ``(seed, epoch)``. + + For the production 724 M-instance mix this allocates ~5.8 GB at uint64 + (numpy's default for ``permutation``). For production we should swap to + an on-disk memmap scheme like olmo-core's + ``build_and_save_global_indices``. Sized for unit tests + the initial + smoke training run for now. + """ + if not self._shuffle: + return np.arange(self._total, dtype=np.uint64) + rng = np.random.default_rng(_combine_seed_epoch(seed, epoch)) + order = rng.permutation(self._total) + return order.astype(np.uint64, copy=False) + + def shard_indices(self, *, seed: int, epoch: int) -> np.ndarray: + """Slice the global shuffled order down to this host's share.""" + full = self.shuffled_global_indices(seed=seed, epoch=epoch) + n_per = self.num_local_instances_per_epoch + start = self._shard_index * n_per + end = start + n_per + return full[start:end] + + def _shard_indices_for_epoch(self, epoch: int) -> np.ndarray: + with self._cache_lock: + if self._cached_epoch == epoch and self._cached_shard_indices is not None: + return self._cached_shard_indices + shard = self.shard_indices(seed=self._seed, epoch=epoch) + self._cached_epoch = epoch + self._cached_shard_indices = shard + return shard + + def __getstate__(self): + # threading.Lock can't be pickled, and the per-epoch cache is a pure perf + # optimization — drop both before serialization to forked Grain workers. + state = self.__dict__.copy() + state["_cache_lock"] = None + state["_cached_epoch"] = None + state["_cached_shard_indices"] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self._cache_lock = threading.Lock() + + # ---- Sampler protocol --------------------------------------------------- # + + def __getitem__(self, index: int) -> grain.RecordMetadata: + """Map a per-host global step ``index`` to the next record to fetch. 
+ + The lookup applies ``initial_step`` as a transparent offset: the caller + sees a fresh stream starting at index 0, but the underlying record + pointer is at absolute position ``index + initial_step``. That's the + mechanism that lets resume work without persisting any iterator state. + """ + if index < 0: + raise IndexError(f"sampler index must be non-negative, got {index}") + n_per = self.num_local_instances_per_epoch + if n_per == 0: + raise IndexError( + f"No instances assigned to shard {self._shard_index}/{self._shard_count} " f"(total_instances={self._total})" + ) + absolute = index + self._initial_step + epoch = absolute // n_per + in_epoch = absolute % n_per + shard = self._shard_indices_for_epoch(epoch) + record_key = int(shard[in_epoch]) + return grain.RecordMetadata(index=index, record_key=record_key) + + # Grain >=0.2.16 expects either a finite ``__len__`` or that the sampler + # raises ``IndexError`` on out-of-bounds. We support infinite training and + # never raise IndexError for non-negative indices, so we omit ``__len__``. + + def __repr__(self) -> str: + """Stable repr — Grain compares ``repr(sampler)`` between the checkpoint + and the live sampler to validate the sampler is unchanged on resume. + + We deliberately **exclude** ``initial_step`` from the repr: a sampler + rebuilt with a different ``initial_step`` produces a different absolute + position via offset arithmetic, but it's still the *same logical sampler* + over the same data. Including the step here would break interop with + Grain's iterator-state checkpointing path (different reprs reject each + other). The repr captures only the immutable config that defines the + sample space; the offset is just a starting cursor. + """ + return ( + f"OlmoIndexSampler(total_instances={self._total}, seed={self._seed}, " + f"shard_index={self._shard_index}, shard_count={self._shard_count}, " + f"shuffle={self._shuffle})" + ) + + +def _combine_seed_epoch(seed: int, epoch: int) -> int: + """Stable 64-bit mix of (seed, epoch) for the per-epoch shuffle RNG. + + Uses SHA-256 truncated to 64 bits — no fixed points (unlike a raw multiply + by a constant when seed=epoch=0), and avoids the numpy uint64 multiplication + overflow warnings that dog SplitMix-style mixers in pure numpy. + """ + digest = hashlib.sha256(f"olmo-shuffle:{int(seed)}:{int(epoch)}".encode("utf-8")).digest() + return int.from_bytes(digest[:8], "little") + + +class NgramFilterTransform(grain.transforms.Map): + """Add an ``instance_mask`` field per OLMo-core's repetition filter. + + ``instance_mask = True`` if the instance is "clean" (kept fully in the + loss); ``False`` if it has too-repetitive periodic spans (zero-out at + loss time). We don't drop the instance — that would mess with sharding — + matching OLMo-core's behavior. 
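+
+  Example (illustrative; the token values are arbitrary)::
+
+      t = NgramFilterTransform()
+      varied = t.map({"tokens": np.arange(8192, dtype=np.int32)})
+      looped = t.map({"tokens": np.tile(np.array([5, 6, 7], dtype=np.int32), 2731)[:8192]})
+      # varied["instance_mask"] is True; looped["instance_mask"] is False
+      # (a period-3 span repeated far more than 32 times).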
+ """ + + def __init__( + self, + *, + max_period: int = 13, + min_period: int = 1, + max_count: int = 32, + mask_value: int = -1, + ): + self._max_period = int(max_period) + self._min_period = int(min_period) + self._max_count = int(max_count) + self._mask_value = int(mask_value) + + def map(self, element: Dict[str, Any]) -> Dict[str, Any]: + """Add ``instance_mask`` to ``element`` based on the n-gram filter.""" + tokens = element["tokens"] + clean = is_clean_instance( + tokens, + repetition_max_period=self._max_period, + repetition_min_period=self._min_period, + repetition_max_count=self._max_count, + mask_value=self._mask_value, + ) + out = dict(element) + out["instance_mask"] = bool(clean) + return out + + +class ShiftToInputsTargets(grain.transforms.Map): + """Convert a ``tokens`` array into the keys MaxText's pretrain trainer expects. + + Produces, for a single instance of length ``L = sequence_length``: + + * ``inputs``: ``tokens.astype(int32)``, shape ``(L,)`` + * ``targets``: ``tokens`` shifted left by one, padded with 0 at position + ``L-1``, shape ``(L,)`` + * ``inputs_position``: ``[0, 1, ..., L-1]`` int32 + * ``inputs_segmentation``: ``int32`` ones, shape ``(L,)`` — single segment + * ``targets_segmentation``: ``int32`` ones, shape ``(L,)`` with the last + position zeroed (loss masked at the padded position); the entire row is + zero if ``instance_mask`` is False (n-gram filter flagged the instance). + + Outputs are the full ``L`` tokens (not ``L-1``) because the TPU splash + attention kernel requires ``q_seq_len`` divisible by 512; producing length + ``L-1`` would break that invariant for typical OLMo ``L=8192``. + + The OLMo dataset has no document boundaries inside an instance — sequences + span doc boundaries with no special masking — so ``segmentation`` and + ``position`` are trivially uniform within an instance. + """ + + def map(self, element: Dict[str, Any]) -> Dict[str, Any]: + """Convert ``tokens`` into ``inputs`` / ``targets`` / segmentation tensors.""" + tokens = element["tokens"].astype(np.int32, copy=False) + L = tokens.shape[0] # == sequence_length from the index + instance_mask = bool(element.get("instance_mask", True)) + seg_value = np.int32(1) if instance_mask else np.int32(0) + + # Output rank-2 (batch, seq) tensors of length L (= max_target_length). + # The TPU splash-attention kernel requires q_seq_len to be divisible by + # 512, which means the trainer-side seq length must be the full L — + # using ``tokens[:-1]`` (length L-1) breaks that invariant. + # + # For next-token prediction we still want ``targets[i] = tokens[i+1]``, + # so we shift and pad the last position with 0 then *mask it out* via + # ``targets_segmentation[L-1] = 0``. The trainer's segmentation-aware + # loss skips positions where targets_segmentation == 0, so the padded + # last token contributes nothing to the loss. Information loss is + # 1 token per ``L``-token instance (~0.012% at L=8192). 
+ inputs = tokens + targets = np.empty(L, dtype=np.int32) + targets[:-1] = tokens[1:] + targets[-1] = 0 # pad; loss masked below + + targets_seg = np.full(L, seg_value, dtype=np.int32) + targets_seg[-1] = 0 # never compute loss on the boundary position + + return { + "inputs": inputs, + "targets": targets, + "inputs_position": np.arange(L, dtype=np.int32), + "inputs_segmentation": np.ones(L, dtype=np.int32), + "targets_segmentation": targets_seg, + } + + +def make_olmo_grain_data_loader( + index: OlmoNpyIndex, + *, + seed: int, + batch_size: int, + shard_index: int, + shard_count: int, + apply_ngram_filter: bool = True, + shift_to_inputs_targets: bool = True, + path_remap: Optional[Dict[str, str]] = None, + grain_worker_count: int = 0, + grain_worker_buffer_size: int = 1, + initial_step: int = 0, +): + """Build a Grain ``DataLoader`` for OLMo-style fixed-seq-length training. + + Args: + index: Loaded :class:`OlmoNpyIndex`. + seed: Shuffle seed (paired with the implicit per-step ``epoch = + step // n_per_host`` to drive the per-epoch permutation). + batch_size: Per-host batch size (i.e. global_batch / shard_count). + shard_index: This host's data-loading rank. + shard_count: Total data-loading hosts. + apply_ngram_filter: Add :class:`NgramFilterTransform` (recommended). + shift_to_inputs_targets: Add :class:`ShiftToInputsTargets` so the loader + yields the ``inputs``/``targets`` shape MaxText's trainer expects. + path_remap: Pass-through to :class:`OlmoNpyDataSource`. + grain_worker_count: ``0`` runs in-process; otherwise Grain forks workers. + grain_worker_buffer_size: Per-worker batch prefetch. + initial_step: Start the *underlying sampler* at this absolute step. + The Grain DataLoader still iterates from its own 0, but every record + lookup is shifted by ``initial_step``. Set this to ``train_step * + batch_size`` on resume to pick up the data stream where it left off + *without* needing Grain's iterator-state checkpointing. + + Returns: + A ``grain.DataLoader``. + """ + source = OlmoNpyDataSource(index, path_remap=path_remap) + sampler = OlmoIndexSampler( + total_instances=index.total_instances, + seed=seed, + shard_index=shard_index, + shard_count=shard_count, + initial_step=initial_step, + ) + + ops: List[Any] = [] + if apply_ngram_filter: + ops.append(NgramFilterTransform()) + if shift_to_inputs_targets: + ops.append(ShiftToInputsTargets()) + ops.append(grain.transforms.Batch(batch_size=batch_size, drop_remainder=True)) + + # Grain expects ``shard_options`` on the DataLoader (sharding used to live + # on the Sampler). Our sampler already does the shard-by-rank slicing, but + # Grain still requires this object to validate checkpoint compatibility. + shard_options = grain.sharding.ShardOptions(shard_index=shard_index, shard_count=shard_count, drop_remainder=True) + return grain.DataLoader( + data_source=source, + sampler=sampler, + operations=ops, + shard_options=shard_options, + worker_count=grain_worker_count, + worker_buffer_size=grain_worker_buffer_size, + ) diff --git a/src/maxtext/input_pipeline/olmo_grain_data_processing.py b/src/maxtext/input_pipeline/olmo_grain_data_processing.py new file mode 100644 index 0000000000..b674aa0b90 --- /dev/null +++ b/src/maxtext/input_pipeline/olmo_grain_data_processing.py @@ -0,0 +1,195 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""MaxText trainer adapter for the OLMo numpy fixed-seq-length pipeline.
+
+The trainer expects ``dataset_type`` to map to two factory functions
+``(make_<name>_train_iterator, make_<name>_eval_iterator)`` that take
+``(config, mesh, process_indices)`` and return a
+:class:`MultiHostDataLoadIterator`.
+
+This module provides those for ``dataset_type=olmo_grain``. The hard work
+lives in :mod:`maxtext.input_pipeline.olmo_data_grain` (data source +
+sampler + transforms); here we just wire it to MaxText's config + the
+multihost dataloading wrapper.
+
+Notes
+-----
+
+* **Sequence length match**: ``config.max_target_length`` must match the
+  ``sequence_length`` recorded in the index JSON. Mismatches raise at load
+  time.
+* **Path remap**: AI2's index typically holds ``gs://`` URIs. For training,
+  we read via a GCSFUSE mount on each TPU host. The
+  ``olmo_path_remap_from`` / ``olmo_path_remap_to`` config pair rewrites
+  the prefix at runtime.
+* **Sharding**: each data-loading host is assigned a non-overlapping shard
+  of the global instance space via ``OlmoIndexSampler``. We use
+  ``process_indices.index(jax.process_index())`` as the local shard index
+  (matches the pattern in :mod:`grain_data_processing`).
+"""
+
+from __future__ import annotations
+
+import os
+from typing import List
+
+import jax
+
+from maxtext.input_pipeline import multihost_dataloading
+from maxtext.input_pipeline.olmo_data import load_index
+from maxtext.input_pipeline.olmo_data_grain import make_olmo_grain_data_loader
+from maxtext.utils import max_logging
+
+
+def _build_path_remap(config) -> dict:
+  src = getattr(config, "olmo_path_remap_from", "") or ""
+  dst = getattr(config, "olmo_path_remap_to", "") or ""
+  if src and dst:
+    return {src: dst}
+  if src or dst:
+    raise ValueError("olmo_path_remap_from and olmo_path_remap_to must both be set or both empty.")
+  return {}
+
+
+def _detect_resumed_step(config) -> int:
+  """Find the latest checkpoint step in ``config.checkpoint_dir``, or 0.
+
+  We read the directory listing directly rather than instantiating an Orbax
+  CheckpointManager because (a) the trainer creates its own manager later,
+  and (b) Orbax stores each step as an integer-named subdir
+  (``<checkpoint_dir>/<step>/``), so the latest step is just
+  ``max(int(d) for d in listdir if d.isdigit())``. Cheap, no I/O on the
+  actual checkpoint payload.
+
+  Returning 0 means "fresh run, start from the beginning of the data
+  stream"; any positive return value flows into ``initial_step`` on the
+  Grain DataLoader so resume picks up where the model state did.
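+
+  Example (illustrative): a ``checkpoint_dir`` whose listing is
+  ``["0", "50", "100", "some_metadata"]`` resolves to a resumed step of 100;
+  the non-numeric entry is ignored.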
+ """ + if not getattr(config, "enable_checkpointing", False): + return 0 + ckpt_dir = getattr(config, "checkpoint_dir", "") or "" + if not ckpt_dir or not os.path.isdir(ckpt_dir): + return 0 + steps = [] + for name in os.listdir(ckpt_dir): + if name.isdigit(): + steps.append(int(name)) + return max(steps) if steps else 0 + + +def _make_loader_for_host( + config, + *, + process_indices: List[int], + seed: int, +): + """Construct an OLMo grain DataLoader for the current data-loading host.""" + index = load_index(config.olmo_index_path) + if index.sequence_length != config.max_target_length: + raise ValueError( + f"OLMo index sequence_length={index.sequence_length} but " + f"config.max_target_length={config.max_target_length}. Either rebuild " + f"the index with the matching seq length or update the config." + ) + + this_proc = jax.process_index() + shard_index = process_indices.index(this_proc) + shard_count = len(process_indices) + + per_host_batch = config.global_batch_size_to_load // shard_count + if per_host_batch * shard_count != config.global_batch_size_to_load: + raise ValueError( + f"global_batch_size_to_load={config.global_batch_size_to_load} is not " f"divisible by shard_count={shard_count}" + ) + + # Resume = step counter from the latest checkpoint (if any) × per-host + # batch. Our sampler is stateless, so this single integer is enough to + # rejoin the stream — no Grain iterator-state serialization needed. + resumed_step = _detect_resumed_step(config) + initial_step = resumed_step * per_host_batch + + max_logging.log( + f"OLMo grain loader: index={config.olmo_index_path} " + f"total_instances={index.total_instances:,} " + f"shard={shard_index}/{shard_count} per_host_batch={per_host_batch} " + f"seq={index.sequence_length} resumed_step={resumed_step} " + f"initial_step={initial_step}" + ) + + # Worker count and per-worker buffer reuse the standard grain flags. The + # ``-1`` value of ``grain_worker_count`` is the auto-tuning sentinel for + # the standard pipeline; we don't auto-tune yet, so treat it as 0 + # (in-process) for safety. + worker_count = max(int(getattr(config, "grain_worker_count", 0) or 0), 0) + worker_buffer = int(getattr(config, "grain_per_worker_buffer_size", 1) or 1) + + return make_olmo_grain_data_loader( + index, + seed=seed, + batch_size=per_host_batch, + shard_index=shard_index, + shard_count=shard_count, + apply_ngram_filter=getattr(config, "olmo_apply_ngram_filter", True), + shift_to_inputs_targets=True, + path_remap=_build_path_remap(config), + grain_worker_count=worker_count, + grain_worker_buffer_size=worker_buffer, + initial_step=initial_step, + ) + + +def make_olmo_grain_train_iterator(config, global_mesh, process_indices): + """Train iterator for ``dataset_type=olmo_grain``.""" + if not getattr(config, "olmo_index_path", ""): + raise ValueError( + "When dataset_type=olmo_grain, please set config.olmo_index_path to " + "the JSON produced by tools/data_generation/build_olmo_npy_index.py." + ) + loader = _make_loader_for_host( + config, + process_indices=process_indices, + seed=int(getattr(config, "data_shuffle_seed", 0)), + ) + return multihost_dataloading.MultiHostDataLoadIterator( + loader, + global_mesh, + config.generate_padding_batch_train, + expansion_loading_factor_for_grain=config.expansion_factor_real_data, + ) + + +def make_olmo_grain_eval_iterator(config, global_mesh, process_indices): + """Eval iterator for ``dataset_type=olmo_grain``. 
+ + Currently reuses the train data with a different seed: the OLMo mix is a + pretraining corpus with no canonical eval partition, so eval here means + "deterministic held-out shuffle" rather than "held-out documents". For a + real eval split, point a future ``config.eval_olmo_index_path`` at a + separate index built over different files; the rest of this function is + unchanged. + """ + if not getattr(config, "olmo_index_path", ""): + raise ValueError("When dataset_type=olmo_grain, please set config.olmo_index_path.") + loader = _make_loader_for_host( + config, + process_indices=process_indices, + # Distinct seed so eval doesn't overlap train batch order. + seed=int(getattr(config, "data_shuffle_seed", 0)) ^ 0x1F1F1F1F, + ) + return multihost_dataloading.MultiHostDataLoadIterator( + loader, + global_mesh, + config.generate_padding_batch_eval, + expansion_loading_factor_for_grain=config.expansion_factor_real_data, + ) diff --git a/tests/unit/input_pipeline/__init__.py b/tests/unit/input_pipeline/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/input_pipeline/olmo_data_grain_resume_test.py b/tests/unit/input_pipeline/olmo_data_grain_resume_test.py new file mode 100644 index 0000000000..bc1f547496 --- /dev/null +++ b/tests/unit/input_pipeline/olmo_data_grain_resume_test.py @@ -0,0 +1,195 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Resume-via-step-offset tests for the OLMo Grain loader (Option B). + +Our :class:`OlmoIndexSampler` is a pure ``__getitem__(idx) → record_key`` +function of ``(seed, shard, idx + initial_step)``. That means resume +doesn't need Grain's iterator-state checkpoint; supplying +``initial_step = saved_step * per_host_batch`` at construction time is +enough. These tests pin that contract: + + A) Two loaders with the same seed but different ``initial_step`` produce + a stream where the late-starter sees exactly the same records the + early-starter did from that absolute step onward. + B) The sampler's ``__repr__`` is independent of ``initial_step`` so it + stays compatible with Grain's repr-based sampler-equality validation + (i.e. "step offset" and "Grain checkpoint" can coexist). 
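+
+Worked example (numbers are hypothetical): with a per-host batch of 4 and a
+checkpoint written at trainer step 50, resume constructs the loader with
+``initial_step = 50 * 4 = 200``, so its first sampled record equals the
+record an unbroken run would have produced at absolute index 200.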
+""" + +from __future__ import annotations + +import os +import tempfile +import unittest +from typing import List, Tuple + +import numpy as np + +from maxtext.input_pipeline.olmo_data import build_index, OlmoNpyIndex +from maxtext.input_pipeline.olmo_data_grain import ( + OlmoIndexSampler, + make_olmo_grain_data_loader, +) + + +def _write_raw_uint32(tmpdir: str, name: str, values: np.ndarray) -> str: + assert values.dtype == np.uint32 and values.ndim == 1 + path = os.path.join(tmpdir, name) + values.tofile(path) + return path + + +def _build_synthetic_index(tmpdir: str, *, sizes: Tuple[int, ...], sequence_length: int) -> OlmoNpyIndex: + """Build a small in-memory OLMo index over raw-binary uint32 token files.""" + paths: List[str] = [] + start = 0 + for i, n in enumerate(sizes): + arr = np.arange(start, start + n, dtype=np.uint32) + paths.append(_write_raw_uint32(tmpdir, f"f_{i:03d}.bin", arr)) + start += n + + def _reader(path: str): + return "uint32", (os.path.getsize(path) // 4,) + + return build_index( + [(p, "lab") for p in paths], + sequence_length=sequence_length, + tokenizer="test", + header_reader=_reader, + ) + + +def _take(iterator, n: int): + return [next(iterator) for _ in range(n)] + + +def _assert_batch_equal(a: dict, b: dict, msg: str = "") -> None: + assert set(a.keys()) == set(b.keys()), f"{msg}: key set differs: {sorted(a)} vs {sorted(b)}" + for k in a: + np.testing.assert_array_equal(a[k], b[k], err_msg=f"{msg}: batch field {k!r} differs") + + +class SamplerInitialStepTest(unittest.TestCase): + + def test_offset_matches_unbroken_stream(self): + """``OlmoIndexSampler(initial_step=N)[i]`` must equal an unbroken + sampler's record at absolute index ``N + i`` for any (i, N).""" + s_full = OlmoIndexSampler(total_instances=128, seed=11, shard_count=1) + for n in [0, 1, 7, 23, 127]: + s_offset = OlmoIndexSampler(total_instances=128, seed=11, shard_count=1, initial_step=n) + for i in [0, 1, 5, 17]: + self.assertEqual( + s_offset[i].record_key, + s_full[n + i].record_key, + msg=f"mismatch at initial_step={n}, i={i}", + ) + + def test_initial_step_does_not_change_repr(self): + """Grain validates samplers by ``repr(sampler)`` on resume; the offset + is a runtime cursor, not part of the sampler's identity.""" + a = OlmoIndexSampler(total_instances=64, seed=1, shard_count=1) + b = OlmoIndexSampler(total_instances=64, seed=1, shard_count=1, initial_step=42) + self.assertEqual(repr(a), repr(b)) + + def test_negative_initial_step_raises(self): + with self.assertRaises(ValueError): + OlmoIndexSampler(total_instances=8, seed=0, initial_step=-1) + + def test_offset_crosses_epoch_boundary(self): + """An offset large enough to roll the epoch must use the next epoch's + permutation, not wrap inside epoch 0.""" + n = 32 + s = OlmoIndexSampler(total_instances=n, seed=2, shard_count=1) + # epoch 0 has permutation P0; epoch 1 has P1. + p0 = s.shard_indices(seed=2, epoch=0) + p1 = s.shard_indices(seed=2, epoch=1) + + # offset just past end of epoch 0 → first lookups should come from P1. + s_off = OlmoIndexSampler(total_instances=n, seed=2, shard_count=1, initial_step=n + 3) + self.assertEqual(s_off[0].record_key, int(p1[3])) + self.assertEqual(s_off[1].record_key, int(p1[4])) + # And the canonical first 3 of epoch 1 are NOT visited (because we + # started 3 records into epoch 1). 
+ self.assertNotIn(int(p0[0]), [s_off[i].record_key for i in range(3)]) + + +class LoaderResumeViaInitialStepTest(unittest.TestCase): + + def test_loader_offset_matches_unbroken_loader(self): + """A fresh loader with ``initial_step=K`` produces the same batches as + an unbroken loader's batches starting at the K-th batch.""" + with tempfile.TemporaryDirectory() as d: + # 256 tokens at seq=4 → 64 instances. Batch=4 → 16 batches/epoch. + idx = _build_synthetic_index(d, sizes=(256,), sequence_length=4) + common_kwargs = { + "seed": 7, + "batch_size": 4, + "shard_index": 0, + "shard_count": 1, + "apply_ngram_filter": False, + "shift_to_inputs_targets": True, + "grain_worker_count": 0, + } + + # Reference: 30 batches uninterrupted. + ref = make_olmo_grain_data_loader(idx, initial_step=0, **common_kwargs) + it_ref = iter(ref) + ref_batches = _take(it_ref, 30) + + # Resume: skip the first 15 batches by starting the *sampler* at + # absolute index 15 * batch_size = 60. + resumed = make_olmo_grain_data_loader( + idx, + initial_step=15 * common_kwargs["batch_size"], + **common_kwargs, + ) + it_res = iter(resumed) + res_batches = _take(it_res, 15) + + for i, (a, b) in enumerate(zip(ref_batches[15:], res_batches)): + _assert_batch_equal(a, b, msg=f"resumed batch {i}") + + def test_resume_works_across_an_epoch_boundary(self): + """The harder case: ``initial_step`` lands past the end of epoch 0.""" + with tempfile.TemporaryDirectory() as d: + # 128 tokens at seq=4 → 32 instances. Batch=4 → 8 batches/epoch. + idx = _build_synthetic_index(d, sizes=(128,), sequence_length=4) + common_kwargs = { + "seed": 3, + "batch_size": 4, + "shard_index": 0, + "shard_count": 1, + "apply_ngram_filter": False, + "shift_to_inputs_targets": True, + } + + # Take 12 batches uninterrupted (covers ~1.5 epochs). + ref = make_olmo_grain_data_loader(idx, initial_step=0, **common_kwargs) + ref_batches = _take(iter(ref), 12) + + # Skip 10 batches (= 40 instances, well past epoch boundary at 32). + resumed = make_olmo_grain_data_loader( + idx, + initial_step=10 * common_kwargs["batch_size"], + **common_kwargs, + ) + res_batches = _take(iter(resumed), 2) + + for i, (a, b) in enumerate(zip(ref_batches[10:], res_batches)): + _assert_batch_equal(a, b, msg=f"epoch-cross batch {i}") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/input_pipeline/olmo_data_grain_test.py b/tests/unit/input_pipeline/olmo_data_grain_test.py new file mode 100644 index 0000000000..08a493d1c3 --- /dev/null +++ b/tests/unit/input_pipeline/olmo_data_grain_test.py @@ -0,0 +1,353 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``maxtext.input_pipeline.olmo_data_grain``. 
+ +Covers test plan items A.3-A.7, A.9, A.10: + A.3 data source __getitem__ returns the right tokens + A.4 data source shape/dtype contract + A.5 (deferred to equivalence test once Option 1 lands) + A.6 sampler partitions the global index space exactly + A.7 reshuffle determinism + fingerprint + A.9 checkpoint / restore — get_state() round-trip + A.10 iteration termination at end-of-epoch and epoch roll-over + +All tests use *synthetic* raw-binary token files (matching AI2's headerless +``.npy`` layout) so they're fast and deterministic. +""" + +from __future__ import annotations + +import os +import tempfile +import unittest +from typing import List, Tuple + +import numpy as np + +from maxtext.input_pipeline.olmo_data import build_index, OlmoNpyIndex +from maxtext.input_pipeline.olmo_data_grain import ( + NgramFilterTransform, + OlmoIndexSampler, + OlmoNpyDataSource, + ShiftToInputsTargets, + _combine_seed_epoch, + make_olmo_grain_data_loader, +) + + +def _write_raw_uint32(tmpdir: str, name: str, values: np.ndarray) -> str: + """Write a 1-D uint32 array as raw binary (no .npy header) — matches AI2.""" + assert values.dtype == np.uint32 and values.ndim == 1 + path = os.path.join(tmpdir, name) + values.tofile(path) + return path + + +def _build_synthetic_index( + tmpdir: str, + *, + sizes: Tuple[int, ...], + sequence_length: int, +) -> Tuple[OlmoNpyIndex, List[str]]: + """Make ``len(sizes)`` raw-binary uint32 files and an index over them. + + Tokens in file ``i`` are ``[start_i, start_i + size_i)`` so every global + token has a unique value, making chunk-extraction tests easy to assert. + """ + paths: List[str] = [] + start = 0 + for i, n in enumerate(sizes): + arr = np.arange(start, start + n, dtype=np.uint32) + paths.append(_write_raw_uint32(tmpdir, f"f_{i:03d}.bin", arr)) + start += n + + def _reader(path: str): + return "uint32", (os.path.getsize(path) // 4,) + + idx = build_index( + [(p, "lab") for p in paths], + sequence_length=sequence_length, + tokenizer="test", + header_reader=_reader, + ) + return idx, paths + + +class DataSourceTest(unittest.TestCase): + + def test_getitem_returns_correct_tokens(self): + """A.3: instance i covers the tokens we expect from the synthetic mix.""" + with tempfile.TemporaryDirectory() as d: + sizes = (12, 8, 16) + seq = 4 + idx, _ = _build_synthetic_index(d, sizes=sizes, sequence_length=seq) + ds = OlmoNpyDataSource(idx) + + self.assertEqual(len(ds), idx.total_instances) + + # Build the expected concatenation: file0 starts at 0, file1 at 12, + # file2 at 20. + starts = [0] + for s in sizes: + starts.append(starts[-1] + s) + # Rebuild a virtual stream that drops trailing remainder per file. + dropped: List[np.ndarray] = [] + for fi, s in enumerate(sizes): + full = np.arange(starts[fi], starts[fi] + s, dtype=np.uint32) + keep = (s // seq) * seq + dropped.append(full[:keep]) + concat = np.concatenate(dropped) + + # Every global instance i should equal the i-th seq-length window of the + # concatenated, remainder-dropped stream. + for i, item in enumerate(ds): + self.assertEqual(item["tokens"].dtype, np.uint32) + self.assertEqual(item["tokens"].shape, (seq,)) + np.testing.assert_array_equal(item["tokens"], concat[i * seq : (i + 1) * seq]) + self.assertEqual(item["instance_id"], i) + # file_id should match the file the instance lives in. 
+ self.assertGreaterEqual(item["file_id"], 0) + self.assertLess(item["file_id"], len(sizes)) + + def test_returned_token_array_is_safe_to_mutate(self): + """A.4: the returned ``tokens`` array is a copy, not a view into mmap.""" + with tempfile.TemporaryDirectory() as d: + idx, _ = _build_synthetic_index(d, sizes=(16,), sequence_length=4) + ds = OlmoNpyDataSource(idx) + a = ds[0]["tokens"] + original = a.copy() + a[0] = 99999 + b = ds[0]["tokens"] + np.testing.assert_array_equal(b, original) + + def test_path_remap(self): + with tempfile.TemporaryDirectory() as d: + idx, paths = _build_synthetic_index(d, sizes=(8,), sequence_length=4) + # Move the file so the original path doesn't work, then remap. + moved = paths[0] + ".moved" + os.rename(paths[0], moved) + + ds_no_remap = OlmoNpyDataSource(idx) + with self.assertRaises(FileNotFoundError): + _ = ds_no_remap[0] + + ds = OlmoNpyDataSource(idx, path_remap={paths[0]: moved}) + self.assertEqual(ds[0]["tokens"].shape, (4,)) + + +class SamplerTest(unittest.TestCase): + + def test_sharding_partition_is_disjoint_and_complete(self): + """A.6: with shard_count=N, the union of the N hosts' shard_indices + covers every global index in the per-epoch shuffle exactly once, + after dropping the trailing remainder.""" + total = 23 + n_shards = 4 + s_full = OlmoIndexSampler(total_instances=total, seed=42, shard_index=0, shard_count=1) + full = s_full.shuffled_global_indices(seed=42, epoch=0) + self.assertEqual(len(full), total) + + seen = [] + for shard in range(n_shards): + s = OlmoIndexSampler( + total_instances=total, + seed=42, + shard_index=shard, + shard_count=n_shards, + ) + seen.append(s.shard_indices(seed=42, epoch=0)) + + cat = np.concatenate(seen) + # 23 // 4 = 5 per shard, 4*5 = 20 covered, 3 trailing dropped. + self.assertEqual(len(cat), n_shards * (total // n_shards)) + # No duplicates. + self.assertEqual(len(np.unique(cat)), len(cat)) + # All from the global shuffle's first 20 entries. + np.testing.assert_array_equal(np.sort(cat), np.sort(full[:20])) + + def test_reshuffle_determinism(self): + """A.7: same (seed, epoch) ⇒ same shuffle; different epoch ⇒ different.""" + s = OlmoIndexSampler(total_instances=100, seed=7) + a = s.shuffled_global_indices(seed=7, epoch=0) + b = s.shuffled_global_indices(seed=7, epoch=0) + np.testing.assert_array_equal(a, b) + + c = s.shuffled_global_indices(seed=7, epoch=1) + self.assertFalse(np.array_equal(a, c)) + + d = s.shuffled_global_indices(seed=8, epoch=0) + self.assertFalse(np.array_equal(a, d)) + + def test_combine_seed_epoch_distinguishes_inputs(self): + self.assertEqual(_combine_seed_epoch(0, 0), _combine_seed_epoch(0, 0)) + self.assertNotEqual(_combine_seed_epoch(0, 0), _combine_seed_epoch(0, 1)) + self.assertNotEqual(_combine_seed_epoch(0, 0), _combine_seed_epoch(1, 0)) + + def test_getitem_emits_each_local_index_per_epoch_then_rolls(self): + """A.10: ``sampler[i]`` matches the i-th element of this host's shard, + rolling over to the next epoch's shuffle at i == per_epoch.""" + total = 16 + s = OlmoIndexSampler( + total_instances=total, + seed=3, + shard_index=1, + shard_count=4, + ) + expected_per_epoch = total // 4 # = 4 + expected_e0 = list(s.shard_indices(seed=3, epoch=0)) + expected_e1 = list(s.shard_indices(seed=3, epoch=1)) + + seen = [int(s[i].record_key) for i in range(expected_per_epoch * 2)] + self.assertEqual(seen, expected_e0 + expected_e1) + + # Index field equals the input. 
+ for i in range(expected_per_epoch * 2): + self.assertEqual(s[i].index, i) + + def test_getitem_negative_raises(self): + s = OlmoIndexSampler(total_instances=8, seed=0) + with self.assertRaises(IndexError): + _ = s[-1] + + +class SamplerCheckpointTest(unittest.TestCase): + """A.9: with a __getitem__-style sampler, the only checkpoint state is the + global step counter (Grain handles persisting that). The sampler itself is + stateless across runs given the same (seed, shard_options).""" + + def test_resume_index_yields_same_record_key(self): + """Two independent samplers with the same config + same index produce + the same RecordMetadata — i.e. there is no hidden mutable state that + would differ across an in-process restart.""" + s_a = OlmoIndexSampler(total_instances=40, seed=11, shard_count=1, shard_index=0) + s_b = OlmoIndexSampler(total_instances=40, seed=11, shard_count=1, shard_index=0) + for i in [0, 1, 2, 7, 39, 40, 79, 80, 1000]: + self.assertEqual(s_a[i].record_key, s_b[i].record_key) + self.assertEqual(s_a[i].index, i) + + def test_different_seeds_diverge(self): + s_a = OlmoIndexSampler(total_instances=40, seed=11, shard_count=1, shard_index=0) + s_b = OlmoIndexSampler(total_instances=40, seed=12, shard_count=1, shard_index=0) + keys_a = [s_a[i].record_key for i in range(40)] + keys_b = [s_b[i].record_key for i in range(40)] + self.assertNotEqual(keys_a, keys_b) + self.assertEqual(sorted(keys_a), sorted(keys_b)) # same set, different order + + +class TransformsTest(unittest.TestCase): + + def test_ngram_filter_marks_clean(self): + rng = np.random.default_rng(0) + arr = rng.integers(0, 10_000, size=512, dtype=np.uint32) + out = NgramFilterTransform().map({"tokens": arr, "instance_id": 0, "file_id": 0}) + self.assertTrue(out["instance_mask"]) + + def test_ngram_filter_marks_dirty(self): + arr = np.full(200, 7, dtype=np.uint32) # period=1, repeats=200, dirty + out = NgramFilterTransform().map({"tokens": arr, "instance_id": 0, "file_id": 0}) + self.assertFalse(out["instance_mask"]) + + def test_shift_to_inputs_targets_clean(self): + arr = np.arange(8, dtype=np.uint32) + out = ShiftToInputsTargets().map({"tokens": arr, "instance_id": 0, "file_id": 0, "instance_mask": True}) + # Only rank-2 (batch, seq) tensors are returned — scalar metadata is + # dropped to satisfy the trainer's 2D sharding contract. + self.assertEqual( + set(out.keys()), + { + "inputs", + "targets", + "inputs_position", + "inputs_segmentation", + "targets_segmentation", + }, + ) + self.assertEqual(out["inputs"].dtype, np.int32) + self.assertEqual(out["targets"].dtype, np.int32) + # All outputs have length L (= len(tokens)) so the trainer sees + # ``max_target_length`` exactly (required by splash attention kernel + # block-size constraints). + np.testing.assert_array_equal(out["inputs"], np.arange(8, dtype=np.int32)) + # Targets shifted by 1; last position padded with 0 and masked out below. + np.testing.assert_array_equal(out["targets"], np.array([1, 2, 3, 4, 5, 6, 7, 0], dtype=np.int32)) + np.testing.assert_array_equal(out["inputs_position"], np.arange(8, dtype=np.int32)) + np.testing.assert_array_equal(out["inputs_segmentation"], np.ones(8, dtype=np.int32)) + # Last target position masked even when instance is clean. 
+ np.testing.assert_array_equal(out["targets_segmentation"], np.array([1, 1, 1, 1, 1, 1, 1, 0], dtype=np.int32)) + + def test_shift_to_inputs_targets_dirty_zeros_target_segmentation(self): + arr = np.arange(8, dtype=np.uint32) + out = ShiftToInputsTargets().map({"tokens": arr, "instance_id": 0, "file_id": 0, "instance_mask": False}) + # inputs_segmentation still 1 (data is fed to the model); the dirty + # instance flag zeroes the entire targets_segmentation row. + np.testing.assert_array_equal(out["inputs_segmentation"], np.ones(8, dtype=np.int32)) + np.testing.assert_array_equal(out["targets_segmentation"], np.zeros(8, dtype=np.int32)) + + +class FactoryTest(unittest.TestCase): + + def test_make_olmo_grain_data_loader_yields_batches(self): + with tempfile.TemporaryDirectory() as d: + idx, _ = _build_synthetic_index(d, sizes=(16, 16, 16), sequence_length=4) + # 12 global instances, 1 host, batch 2 → 6 batches per epoch. + loader = make_olmo_grain_data_loader( + idx, + seed=0, + batch_size=2, + shard_index=0, + shard_count=1, + apply_ngram_filter=True, + shift_to_inputs_targets=True, + ) + batches = [] + for i, batch in enumerate(loader): + batches.append(batch) + if i >= 5: + break + self.assertEqual(len(batches), 6) + for b in batches: + # Inputs and targets are both length seq = 4; targets are shifted + # within that window with the last position padded + masked. + self.assertEqual(b["inputs"].shape, (2, 4)) + self.assertEqual(b["targets"].shape, (2, 4)) + self.assertEqual(b["targets_segmentation"].shape, (2, 4)) + + def test_two_workers_preserve_record_count(self): + with tempfile.TemporaryDirectory() as d: + idx, _ = _build_synthetic_index(d, sizes=(64,), sequence_length=4) + # 16 instances; batch=4 ⇒ 4 batches/epoch. + loader = make_olmo_grain_data_loader( + idx, + seed=0, + batch_size=4, + shard_index=0, + shard_count=1, + apply_ngram_filter=False, + # Disable shift so we can read the raw ``instance_id`` field for + # this audit-style test. (Production runs always use shift=True.) + shift_to_inputs_targets=False, + grain_worker_count=2, + ) + ids = [] + for i, batch in enumerate(loader): + ids.extend(batch["instance_id"].tolist()) + if i >= 3: + break + self.assertEqual(len(ids), 16) + self.assertEqual(sorted(ids), list(range(16))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/input_pipeline/olmo_data_test.py b/tests/unit/input_pipeline/olmo_data_test.py new file mode 100644 index 0000000000..8e779f76c3 --- /dev/null +++ b/tests/unit/input_pipeline/olmo_data_test.py @@ -0,0 +1,350 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``maxtext.input_pipeline.olmo_data``. + +Covers test plan items A.1 (index correctness), A.2 (global→local mapping), +A.7 (fingerprint stability), and A.8 (n-gram instance filter). Tests for +the loader paths (Options 1 & 2) live alongside their respective modules. 
+""" + +from __future__ import annotations + +import os +import tempfile +import unittest +from typing import Tuple + +import numpy as np + +from maxtext.input_pipeline.olmo_data import ( + OlmoNpyFileEntry, + build_index, + compute_fingerprint, + find_periodic_sequences, + global_to_local, + is_clean_instance, + load_index, + read_npy_header_from_path, +) + + +def _write_uint32_npy(tmpdir: str, name: str, values: np.ndarray) -> str: + """Write a 1-D uint32 array to ``tmpdir/name`` and return the path.""" + assert values.dtype == np.uint32 + assert values.ndim == 1 + path = os.path.join(tmpdir, name) + np.save(path, values) + # numpy adds .npy if missing + return path if path.endswith(".npy") else path + ".npy" + + +def _make_synthetic_mix(tmpdir: str, sizes: Tuple[int, ...]) -> Tuple[str, ...]: + """Make a mix of files where file i contains tokens [start_i .. start_i + size_i).""" + paths = [] + start = 0 + for i, n in enumerate(sizes): + arr = np.arange(start, start + n, dtype=np.uint32) + paths.append(_write_uint32_npy(tmpdir, f"file_{i:03d}.npy", arr)) + start += n + return tuple(paths) + + +def _stub_reader(spec): + """Return a header_reader that pretends each path has the given (dtype, shape).""" + + def _reader(path: str): + return spec[path] + + return _reader + + +class IndexCorrectnessTest(unittest.TestCase): + """A.1: index counts and offsets match the underlying files.""" + + def test_index_total_counts_match_files(self): + with tempfile.TemporaryDirectory() as d: + sizes = (10, 7, 25, 16) + paths = _make_synthetic_mix(d, sizes) + seq = 4 + + idx = build_index( + [(p, "label") for p in paths], + sequence_length=seq, + tokenizer="test", + ) + + self.assertEqual(idx.sequence_length, seq) + self.assertEqual(idx.dtype, "uint32") + self.assertEqual(idx.total_tokens, sum(sizes)) + # Trailing tokens are dropped per OLMo-core convention. 
+ expected_instances = sum(s // seq for s in sizes) + self.assertEqual(idx.total_instances, expected_instances) + + cum = 0 + for entry, n_tokens in zip(idx.files, sizes): + self.assertEqual(entry.n_tokens, n_tokens) + self.assertEqual(entry.n_instances, n_tokens // seq) + self.assertEqual(entry.instance_offset, cum) + cum += entry.n_instances + + def test_round_trip_save_load(self): + with tempfile.TemporaryDirectory() as d: + paths = _make_synthetic_mix(d, (32, 64)) + idx = build_index([(p, "lab") for p in paths], sequence_length=8, tokenizer="t") + out = os.path.join(d, "idx.json") + idx.save(out) + restored = load_index(out) + self.assertEqual(restored.fingerprint, idx.fingerprint) + self.assertEqual(restored.total_instances, idx.total_instances) + self.assertEqual(len(restored.files), len(idx.files)) + for a, b in zip(restored.files, idx.files): + self.assertEqual(a, b) + + def test_dtype_mismatch_raises(self): + spec = { + "a.npy": ("uint32", (10,)), + "b.npy": ("uint16", (10,)), + } + with self.assertRaisesRegex(ValueError, "Heterogeneous"): + build_index( + [("a.npy", "x"), ("b.npy", "y")], + sequence_length=4, + tokenizer="t", + header_reader=_stub_reader(spec), + ) + + def test_non_1d_raises(self): + spec = {"a.npy": ("uint32", (10, 2))} + with self.assertRaisesRegex(ValueError, "1-D"): + build_index( + [("a.npy", "x")], + sequence_length=4, + tokenizer="t", + header_reader=_stub_reader(spec), + ) + + def test_empty_paths_raises(self): + with self.assertRaisesRegex(ValueError, "non-empty"): + build_index([], sequence_length=4, tokenizer="t") + + +class GlobalToLocalTest(unittest.TestCase): + """A.2: every global instance maps to the expected (file, token offset).""" + + def setUp(self): + self.spec = { + "a.npy": ("uint32", (10,)), # 2 instances at seq=4 + "b.npy": ("uint32", (7,)), # 1 instance + "c.npy": ("uint32", (25,)), # 6 instances + } + self.idx = build_index( + [("a.npy", "x"), ("b.npy", "y"), ("c.npy", "z")], + sequence_length=4, + tokenizer="t", + header_reader=_stub_reader(self.spec), + ) + + def test_first_index_each_file(self): + self.assertEqual(global_to_local(self.idx, 0), (0, 0)) + self.assertEqual(global_to_local(self.idx, 2), (1, 0)) + self.assertEqual(global_to_local(self.idx, 3), (2, 0)) + + def test_last_index_each_file(self): + # File 0 has 2 instances → last is global 1, local-token 4. + self.assertEqual(global_to_local(self.idx, 1), (0, 4)) + # File 1 has 1 instance → only index 2. + # File 2 last instance is global 8, local-token 20. + self.assertEqual(global_to_local(self.idx, 8), (2, 20)) + + def test_full_partition_is_a_function(self): + # For each global index, the (file, offset) pair is well-defined and + # within bounds of the file's instance count. 
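+    # With the setUp spec (10, 7, 25 tokens at seq=4), globals 0-1 live in
+    # a.npy, global 2 in b.npy and globals 3-8 in c.npy.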
+ for i in range(self.idx.total_instances): + file_idx, tok_off = global_to_local(self.idx, i) + f = self.idx.files[file_idx] + local_inst = tok_off // self.idx.sequence_length + self.assertGreaterEqual(local_inst, 0) + self.assertLess(local_inst, f.n_instances) + + def test_out_of_range_raises(self): + with self.assertRaises(IndexError): + global_to_local(self.idx, -1) + with self.assertRaises(IndexError): + global_to_local(self.idx, self.idx.total_instances) + + def test_real_npy_round_trip(self): + """Build index from real .npy headers, then read instance i and confirm + its first/last token match what the global index predicts.""" + with tempfile.TemporaryDirectory() as d: + sizes = (10, 7, 25) + paths = _make_synthetic_mix(d, sizes) + seq = 4 + idx = build_index([(p, "lab") for p in paths], sequence_length=seq, tokenizer="t") + # Recall: file i contains sequential ints starting at sum(sizes[:i]). + starts = [0] + for s in sizes: + starts.append(starts[-1] + s) + for i in range(idx.total_instances): + file_idx, tok_off = global_to_local(idx, i) + path = idx.files[file_idx].path + arr = np.load(path) + chunk = arr[tok_off : tok_off + seq] + # Tokens are sequential ints; the chunk's first value should be the + # right global token offset. + global_tok_off_expected = starts[file_idx] + tok_off + self.assertEqual(int(chunk[0]), global_tok_off_expected) + self.assertEqual(int(chunk[-1]), global_tok_off_expected + seq - 1) + + +class FingerprintTest(unittest.TestCase): + """A.7: the fingerprint changes iff the relevant inputs change.""" + + def _entries(self): + return ( + OlmoNpyFileEntry("a.npy", "x", n_tokens=10, n_instances=2, instance_offset=0), + OlmoNpyFileEntry("b.npy", "y", n_tokens=8, n_instances=2, instance_offset=2), + ) + + def test_same_inputs_same_fingerprint(self): + e = self._entries() + self.assertEqual( + compute_fingerprint(seq_ := 8, "uint32", "tok", e), + compute_fingerprint(seq_, "uint32", "tok", e), + ) + + def test_different_seq_changes_fingerprint(self): + e = self._entries() + a = compute_fingerprint(8, "uint32", "tok", e) + b = compute_fingerprint(16, "uint32", "tok", e) + self.assertNotEqual(a, b) + + def test_different_dtype_changes_fingerprint(self): + e = self._entries() + self.assertNotEqual( + compute_fingerprint(8, "uint32", "tok", e), + compute_fingerprint(8, "uint16", "tok", e), + ) + + def test_different_tokenizer_changes_fingerprint(self): + e = self._entries() + self.assertNotEqual( + compute_fingerprint(8, "uint32", "alpha", e), + compute_fingerprint(8, "uint32", "beta", e), + ) + + def test_file_reorder_changes_fingerprint(self): + a, b = self._entries() + e1 = (a, b) + e2 = (b, a) + self.assertNotEqual( + compute_fingerprint(8, "uint32", "tok", e1), + compute_fingerprint(8, "uint32", "tok", e2), + ) + + def test_file_size_change_changes_fingerprint(self): + a, b = self._entries() + a2 = OlmoNpyFileEntry( + a.path, + a.label, + n_tokens=a.n_tokens + 1, + n_instances=a.n_instances, + instance_offset=a.instance_offset, + ) + self.assertNotEqual( + compute_fingerprint(8, "uint32", "tok", (a, b)), + compute_fingerprint(8, "uint32", "tok", (a2, b)), + ) + + +class NpyHeaderTest(unittest.TestCase): + + def test_header_round_trip(self): + with tempfile.TemporaryDirectory() as d: + path = _write_uint32_npy(d, "h.npy", np.arange(123, dtype=np.uint32)) + dtype, shape = read_npy_header_from_path(path) + self.assertEqual(dtype, "uint32") + self.assertEqual(shape, (123,)) + + +class NgramFilterTest(unittest.TestCase): + """A.8: ``is_clean_instance`` flags 
excessive periodic repetition.""" + + def test_clean_random_input(self): + rng = np.random.default_rng(0) + arr = rng.integers(low=0, high=10_000, size=2048, dtype=np.uint32) + # Random tokens are extremely unlikely to have 32+ repetitions of any + # period in [1, 13]. + self.assertTrue(is_clean_instance(arr)) + + def test_dirty_period_one(self): + # A run of 100 identical tokens is period=1, times>=32 → dirty. + arr = np.concatenate( + [ + np.arange(50, dtype=np.uint32), + np.full(100, 7, dtype=np.uint32), + np.arange(50, dtype=np.uint32) + 1000, + ] + ) + self.assertFalse(is_clean_instance(arr)) + + def test_dirty_period_three(self): + # Repeat the pattern (1, 2, 3) forty times → period=3, times=40 → dirty. + pattern = np.array([1, 2, 3], dtype=np.uint32) + repeats = np.tile(pattern, 40) + surround = np.array([99, 88, 77, 66, 55], dtype=np.uint32) + arr = np.concatenate([surround, repeats, surround]) + self.assertFalse(is_clean_instance(arr)) + + def test_below_threshold_is_clean(self): + # 10 repetitions of a 3-gram is well below 32. + pattern = np.array([1, 2, 3], dtype=np.uint32) + arr = np.concatenate( + [ + np.arange(20, dtype=np.uint32) + 100, + np.tile(pattern, 10), + np.arange(20, dtype=np.uint32) + 200, + ] + ) + self.assertTrue(is_clean_instance(arr)) + + def test_period_above_max_ignored(self): + # Period 50 > default repetition_max_period (13). Even a long repeat is + # ignored. Construct a 50-token unit repeated 40 times = 2000 tokens. + rng = np.random.default_rng(1) + unit = rng.integers(0, 10_000, size=50, dtype=np.uint32) + arr = np.tile(unit, 40) + # Use the default config (max_period=13) → should be clean. + self.assertTrue(is_clean_instance(arr)) + # If we raise max_period to 50 we should now flag it. + self.assertFalse(is_clean_instance(arr, repetition_max_period=50, repetition_max_count=32)) + + def test_find_periodic_sequences_smoke(self): + arr = np.concatenate( + [ + np.array([5, 6, 7], dtype=np.uint32), + np.tile(np.array([1, 2], dtype=np.uint32), 5), + np.array([8, 9, 10], dtype=np.uint32), + ] + ) + matches = list(find_periodic_sequences(arr, max_period=5)) + # Expect at least one match with period=2 covering the [1,2]*5 region. + found = [m for m in matches if m.period == 2 and m.times >= 3] + self.assertTrue(found, f"no period=2 matches; got {matches}") + m = found[0] + self.assertEqual(arr[m.start : m.end].tolist(), [1, 2] * m.times) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/data_generation/build_olmo_npy_index.py b/tools/data_generation/build_olmo_npy_index.py new file mode 100755 index 0000000000..06a4f3423a --- /dev/null +++ b/tools/data_generation/build_olmo_npy_index.py @@ -0,0 +1,278 @@ +#!/usr/bin/env python3 +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Build an OLMo-style numpy mix index from a mix file. 
+ +Reads each ``.npy`` file's header (small range read from GCS, no download of +the array data) and writes a JSON index suitable for +``maxtext.input_pipeline.olmo_data.load_index``. + +Usage: + + python tools/data_generation/build_olmo_npy_index.py \\ + --mix-file /home/.../OLMo-mix-0925-official.txt \\ + --gcs-base gs://my-bucket/dataset/ \\ + --tokenizer allenai/dolma3-tokenizer \\ + --sequence-length 8192 \\ + --output /tmp/olmo_index.json \\ + --workers 32 + +The mix file format is the same as AI2's OLMo-core data mix files: + + label,relative/path/{TOKENIZER}/...000000.npy + +``{TOKENIZER}`` is substituted with the value of ``--tokenizer``. +""" + +from __future__ import annotations + +import argparse +import io +import os +import sys +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import List, Tuple + +# Allow running from the repo root with src/ on the path. +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../src")) + +from maxtext.input_pipeline.olmo_data import ( # noqa: E402 + build_index, + has_npy_magic, + parse_npy_header, + read_raw_metadata_from_path, +) +import numpy as np # noqa: E402 + + +def parse_mix_file(mix_path: str, tokenizer: str) -> List[Tuple[str, str]]: + """Parse an OLMo data-mix .txt file. Returns list of (label, rel_path).""" + entries: List[Tuple[str, str]] = [] + with open(mix_path, encoding="utf-8") as f: + for line_num, line in enumerate(f, start=1): + line = line.strip() + if not line or line.startswith("#"): + continue + parts = line.split(",", maxsplit=1) + if len(parts) != 2: + print( + f"WARNING: {mix_path}:{line_num}: skipping malformed line: {line!r}", + file=sys.stderr, + ) + continue + label, rel_path = parts + entries.append((label, rel_path.replace("{TOKENIZER}", tokenizer))) + return entries + + +def _split_gs_uri(uri: str) -> Tuple[str, str]: + if not uri.startswith("gs://"): + raise ValueError(f"Not a gs:// URI: {uri}") + path = uri[len("gs://") :] + bucket, _, key = path.partition("/") + if not bucket or not key: + raise ValueError(f"Malformed gs:// URI: {uri}") + return bucket, key + + +def _read_gcs_prefix(client, uri: str, n_bytes: int = 4096) -> bytes: + """Range-read the first ``n_bytes`` of a GCS object. ~1 GCS roundtrip.""" + bucket_name, blob_name = _split_gs_uri(uri) + blob = client.bucket(bucket_name).blob(blob_name) + # download_as_bytes(start=, end=) does a Range request. ``end`` is inclusive. + return blob.download_as_bytes(start=0, end=n_bytes - 1) + + +def _gcs_blob_size(client, uri: str) -> int: + """Return blob size in bytes (single metadata roundtrip, no data read).""" + bucket_name, blob_name = _split_gs_uri(uri) + blob = client.bucket(bucket_name).blob(blob_name) + blob.reload() + if blob.size is None: + raise RuntimeError(f"GCS blob {uri} has unknown size") + return int(blob.size) + + +def make_header_reader(dtype_for_raw: str): + """Return a header_reader(path) -> (dtype, shape) accepting gs:// or local. + + Auto-detects whether each file is a real ``.npy`` (has the magic bytes and + a parseable header) or AI2's headerless raw binary (the OLMo `.npy` files). + + In raw mode we trust ``dtype_for_raw`` and compute ``n_tokens`` from the + blob size — no bytes need to be downloaded for the array data. + """ + client_holder = {"client": None} + + def _get_client(): + if client_holder["client"] is None: + # Lazy import: this script can run in dry-run / local-only mode without + # google-cloud-storage installed. 
+ from google.cloud import storage # pylint: disable=import-outside-toplevel + + client_holder["client"] = storage.Client() + return client_holder["client"] + + def _gcs_reader(uri: str): + client = _get_client() + head = _read_gcs_prefix(client, uri, n_bytes=8) + if has_npy_magic(head): + # Real .npy — fetch enough bytes to cover the header. + header_bytes = _read_gcs_prefix(client, uri, n_bytes=4096) + return parse_npy_header(io.BytesIO(header_bytes)) + # Raw binary: derive shape from blob size + dtype itemsize. + size = _gcs_blob_size(client, uri) + itemsize = np.dtype(dtype_for_raw).itemsize + if size % itemsize != 0: + raise ValueError( + f"GCS blob {uri} size {size} is not a multiple of dtype " f"{dtype_for_raw} itemsize ({itemsize})." + ) + return dtype_for_raw, (size // itemsize,) + + def _local_reader(path: str): + with open(path, "rb") as fh: + head = fh.read(8) + if has_npy_magic(head): + return parse_npy_header(io.BytesIO(open(path, "rb").read())) + return read_raw_metadata_from_path(path, dtype_for_raw) + + def _reader(path: str): + if path.startswith("gs://"): + return _gcs_reader(path) + return _local_reader(path) + + return _reader + + +def _scan_one(reader, idx: int, label: str, path: str): + dtype, shape = reader(path) + return idx, label, path, dtype, shape + + +def scan_headers_parallel( + paths_and_labels: List[Tuple[str, str]], + *, + dtype_for_raw: str, + workers: int = 32, + progress_every: int = 50, +) -> List[Tuple[str, str, str, Tuple[int, ...]]]: + """Read .npy headers for all entries in parallel; preserve input order. + + Returns a list of (label, path, dtype, shape) tuples in the same order as + the input. + """ + reader = make_header_reader(dtype_for_raw=dtype_for_raw) + results: List[Tuple[int, str, str, str, Tuple[int, ...]]] = [None] * len(paths_and_labels) # type: ignore[list-item] + start = time.time() + with ThreadPoolExecutor(max_workers=workers) as pool: + futures = {pool.submit(_scan_one, reader, i, label, path): i for i, (label, path) in enumerate(paths_and_labels)} + done = 0 + for fut in as_completed(futures): + idx, label, path, dtype, shape = fut.result() + results[idx] = (idx, label, path, dtype, shape) + done += 1 + if done % progress_every == 0 or done == len(paths_and_labels): + elapsed = time.time() - start + print( + f" scanned {done}/{len(paths_and_labels)} headers ({elapsed:.0f}s)", + file=sys.stderr, + flush=True, + ) + # Drop the index helper column. + return [(label, path, dtype, shape) for (_, label, path, dtype, shape) in results] + + +def parse_args(): + """Parse CLI args for the index builder.""" + p = argparse.ArgumentParser(description="Build an OLMo-style numpy mix index by scanning .npy headers.") + p.add_argument("--mix-file", required=True, help="Path to the mix .txt file.") + p.add_argument( + "--gcs-base", + required=True, + help=( + "Base prefix for resolved file paths, e.g. gs://my-bucket/dataset/." + " Mix-file relative paths are joined to this." + ), + ) + p.add_argument( + "--tokenizer", + default="allenai/dolma3-tokenizer", + help="Substituted for {TOKENIZER} in mix paths. Also stored in the index.", + ) + p.add_argument( + "--sequence-length", + type=int, + required=True, + help="Tokens per training instance (e.g. 8192).", + ) + p.add_argument("--output", required=True, help="Output JSON path.") + p.add_argument( + "--dtype", + default="uint32", + help=( + "Numpy dtype for files lacking a .npy header (the AI2 'pseudo-.npy'" + " files are headerless uint32 streams). Default: uint32." 
+ ), + ) + p.add_argument("--workers", type=int, default=32, help="Parallel header-scan threads.") + return p.parse_args() + + +def main(): + args = parse_args() + + print(f"Parsing mix file: {args.mix_file}", file=sys.stderr) + entries = parse_mix_file(args.mix_file, args.tokenizer) + print(f" {len(entries)} entries", file=sys.stderr) + + base = args.gcs_base.rstrip("/") + "/" + resolved = [(label, base + rel.lstrip("/")) for label, rel in entries] + + print( + f"Scanning {len(resolved)} .npy headers ({args.workers} threads, " f"raw dtype={args.dtype})...", + file=sys.stderr, + ) + headers = scan_headers_parallel(resolved, dtype_for_raw=args.dtype, workers=args.workers) + + # Cache the (dtype, shape) we already read so build_index doesn't re-scan. + header_cache = {path: (dtype, shape) for (_, path, dtype, shape) in headers} + + def _cached_reader(path: str): + return header_cache[path] + + paths_for_build = [(path, label) for (label, path, _, _) in headers] + index = build_index( + paths_for_build, + sequence_length=args.sequence_length, + tokenizer=args.tokenizer, + header_reader=_cached_reader, + ) + + # Sanity: total token count for human inspection. + total_t = index.total_tokens + total_i = index.total_instances + print( + f"Total tokens: {total_t:,} | instances at SEQ={args.sequence_length}: " + f"{total_i:,} | fingerprint: {index.fingerprint}", + file=sys.stderr, + ) + + index.save(args.output) + print(f"Wrote index to {args.output}", file=sys.stderr) + + +if __name__ == "__main__": + main() diff --git a/tools/data_generation/download_olmo_data_to_gcs.py b/tools/data_generation/download_olmo_data_to_gcs.py new file mode 100755 index 0000000000..071f4ca9e1 --- /dev/null +++ b/tools/data_generation/download_olmo_data_to_gcs.py @@ -0,0 +1,331 @@ +#!/usr/bin/env python3 +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Download OLMo .npy dataset files from HTTP to GCS with HTTP-Range resume. + +Each HTTP transfer is staged to local disk and resumes via the HTTP +``Range: bytes=N-`` header on any RequestException — a 240 GB file whose +upstream connection drops at 140 GB resumes from there instead of +restarting at 0. After the local file is complete (and matches +Content-Length), it's uploaded to GCS in one shot, verified against +``blob.size``, and the local copy removed. + +Local disk usage at peak is bounded by ``--workers * largest_file_size``. + +Usage: + python download_olmo_data_to_gcs.py \\ + --mix-file /path/to/OLMo-mix-0925-official.txt \\ + --gcs-dest gs://my-bucket/olmo/ \\ + --staging-dir /mnt/local-ssd/olmo-staging \\ + --workers 2 + +Skip-existing is on by default, so re-running against the full mix file +finishes only the entries missing from GCS. A partial local stage file is +detected via os.path.getsize() and resumed. 
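+
+Illustrative resume exchange (values are hypothetical): with a partial stage
+file on disk, the next attempt sends ``Range: bytes=<bytes_have>-``; a server
+that honors ranges answers ``206 Partial Content`` and the stager appends,
+while a ``200 OK`` reply means ranges are not honored, so the partial file is
+discarded and the transfer restarts from byte 0.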
+""" + +import argparse +import os +import sys +import time +import traceback +from concurrent.futures import ThreadPoolExecutor, as_completed +from urllib.parse import urljoin + +import requests +from requests.exceptions import RequestException +from google.cloud import storage + +from maxtext.utils.gcs_utils import gcs_path_exists, parse_gcs_bucket_and_prefix + + +def parse_mix_file(mix_path: str, tokenizer: str) -> list[tuple[str, str]]: + """Parse an OLMo data-mix .txt file. Returns list of (label, rel_path).""" + entries = [] + with open(mix_path, encoding="utf-8") as f: + for line_num, line in enumerate(f, start=1): + line = line.strip() + if not line or line.startswith("#"): + continue + parts = line.split(",", maxsplit=1) + if len(parts) != 2: + print(f"WARNING: {mix_path}:{line_num}: skipping malformed line: {line!r}") + continue + label, rel_path = parts + entries.append((label, rel_path.replace("{TOKENIZER}", tokenizer))) + return entries + + +def _total_size_from_response(resp, bytes_have): + """Extract total file size from a response, preferring Content-Range.""" + cr = resp.headers.get("Content-Range") + if cr and "/" in cr: + tail = cr.split("/")[-1].strip() + if tail.isdigit(): + return int(tail) + cl = resp.headers.get("Content-Length") + if cl is not None: + # 200 OK ⇒ Content-Length is the full size; 206 ⇒ remaining size. + return int(cl) + (bytes_have if resp.status_code == 206 else 0) + return None + + +def http_resumable_to_local( + url: str, + local_path: str, + max_retries: int = 20, + chunk_size: int = 8 * 1024 * 1024, + timeout: int = 300, + progress_every_s: float = 30.0, +) -> tuple[int, int | None]: + """Download ``url`` to ``local_path`` with Range-based resume. + + Returns (bytes_written, total_size). total_size may be None if the server + refuses to disclose it on every attempt. + """ + os.makedirs(os.path.dirname(local_path), exist_ok=True) + session = requests.Session() + total_size = None + bytes_have = os.path.getsize(local_path) if os.path.exists(local_path) else 0 + if bytes_have: + print(f" resuming {os.path.basename(local_path)} from {bytes_have / 1e9:.2f} GB", flush=True) + + for attempt in range(1, max_retries + 1): + try: + headers = {} + if bytes_have > 0: + headers["Range"] = f"bytes={bytes_have}-" + + with session.get(url, stream=True, headers=headers, timeout=timeout) as resp: + resp.raise_for_status() + # If we asked for a Range and the server returned 200 (full body), + # discard our partial file — server doesn't honor ranges and is + # restarting from byte 0. + if bytes_have > 0 and resp.status_code == 200: + print( + f" server returned 200 to Range request; restarting " f"{os.path.basename(local_path)} from 0", + flush=True, + ) + bytes_have = 0 + # Truncate; we'll reopen "wb" below. 
+ if os.path.exists(local_path): + os.remove(local_path) + + new_total = _total_size_from_response(resp, bytes_have) + if new_total is not None: + if total_size is None: + total_size = new_total + elif new_total != total_size: + print( + f" WARNING: total size changed mid-download " f"({total_size} → {new_total}); using new value", + flush=True, + ) + total_size = new_total + + mode = "ab" if bytes_have > 0 else "wb" + last_log = time.time() + attempt_start = time.time() + attempt_start_bytes = bytes_have + with open(local_path, mode) as fp: + for chunk in resp.iter_content(chunk_size=chunk_size): + if not chunk: + continue + fp.write(chunk) + bytes_have += len(chunk) + now = time.time() + if now - last_log >= progress_every_s: + dt = max(now - attempt_start, 1e-6) + dbytes = bytes_have - attempt_start_bytes + rate_mb = (dbytes / 1024 / 1024) / dt + if total_size: + pct = 100.0 * bytes_have / total_size + eta_s = (total_size - bytes_have) / max(dbytes / dt, 1) + print( + f" [{os.path.basename(local_path)}] " + f"{bytes_have/1e9:.2f}/{total_size/1e9:.2f} GB " + f"({pct:.1f}%) @ {rate_mb:.0f} MB/s; ETA {eta_s/60:.0f} min", + flush=True, + ) + else: + print( + f" [{os.path.basename(local_path)}] " f"{bytes_have/1e9:.2f} GB @ {rate_mb:.0f} MB/s", + flush=True, + ) + last_log = now + + # Clean exit. Validate full size if known. + if total_size is not None and bytes_have != total_size: + raise RuntimeError(f"truncated: have {bytes_have} bytes, expected {total_size}") + return bytes_have, total_size + + except (RequestException, RuntimeError, ConnectionError) as exc: + wait = min(2**attempt, 60) + print( + f" [retry {attempt}/{max_retries}] @{bytes_have/1e9:.2f} GB: " + f"{type(exc).__name__}: {str(exc)[:200]}; sleeping {wait}s", + flush=True, + ) + time.sleep(wait) + # Refresh bytes_have from disk in case the partial write was flushed. 
+ bytes_have = os.path.getsize(local_path) if os.path.exists(local_path) else 0 + + raise RuntimeError( + f"exhausted {max_retries} retries; got {bytes_have} bytes" f"{'/' + str(total_size) if total_size else ''}" + ) + + +def download_one( + rel_path: str, + http_base_url: str, + gcs_dest_prefix: str, + staging_dir: str, + skip_existing: bool, + max_retries: int, +) -> tuple[str, str, int]: + """Download one mix-file entry: HTTP → local stage → GCS upload.""" + if not http_base_url.endswith("/"): + http_base_url += "/" + http_url = urljoin(http_base_url, rel_path) + gcs_dest = gcs_dest_prefix.rstrip("/") + "/" + rel_path.lstrip("/") + + if skip_existing and gcs_path_exists(gcs_dest): + return rel_path, "skipped", 0 + + local_path = os.path.join(staging_dir, rel_path) + + try: + bytes_written, _ = http_resumable_to_local(http_url, local_path, max_retries=max_retries) + + bucket_name, blob_name = parse_gcs_bucket_and_prefix(gcs_dest) + client = storage.Client() + bucket = client.bucket(bucket_name) + blob = bucket.blob(blob_name) + blob.upload_from_filename(local_path) + blob.reload() + if blob.size != bytes_written: + return ( + rel_path, + f"error:size_mismatch local={bytes_written} gcs={blob.size}", + bytes_written, + ) + + os.remove(local_path) + return rel_path, "ok", bytes_written + + except Exception: # pylint: disable=broad-except + return rel_path, f"error:{traceback.format_exc()[-800:]}", 0 + + +def parse_args(): + """Parse CLI args for the HTTP → GCS downloader.""" + parser = argparse.ArgumentParser( + description="Resumable HTTP→GCS downloader for OLMo .npy mix files.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument("--mix-file", required=True) + parser.add_argument("--gcs-dest", required=True) + parser.add_argument("--tokenizer", default="allenai/dolma3-tokenizer") + parser.add_argument("--http-base-url", default="http://olmo-data.org/") + parser.add_argument("--workers", type=int, default=2) + parser.add_argument( + "--staging-dir", + default="/tmp/olmo-staging", + help="Local scratch dir for partial downloads.", + ) + parser.add_argument("--max-retries", type=int, default=20) + parser.add_argument("--no-skip-existing", action="store_true") + parser.add_argument("--dry-run", action="store_true") + return parser.parse_args() + + +def main(): + args = parse_args() + + print(f"Parsing mix file: {args.mix_file}") + entries = parse_mix_file(args.mix_file, args.tokenizer) + print(f" {len(entries)} files found (tokenizer={args.tokenizer!r})") + + gcs_dest_prefix = args.gcs_dest.rstrip("/") + http_base = args.http_base_url.rstrip("/") + "/" + + if args.dry_run: + for label, rel_path in entries: + url = urljoin(http_base, rel_path) + gcs = gcs_dest_prefix + "/" + rel_path.lstrip("/") + print(f" [{label}] {url} → {gcs}") + print(f"\nDry run complete. 
{len(entries)} files listed.") + return + + os.makedirs(args.staging_dir, exist_ok=True) + skip_existing = not args.no_skip_existing + if skip_existing: + print("Existing GCS files will be skipped (use --no-skip-existing to re-download).") + print(f"Staging to: {args.staging_dir}") + print(f"Workers: {args.workers} Max retries per file: {args.max_retries}") + + rel_paths = [r for _, r in entries] + n_ok = n_skipped = n_error = 0 + total_bytes = 0 + errors = [] + start = time.time() + + with ThreadPoolExecutor(max_workers=args.workers) as pool: + futures = { + pool.submit( + download_one, + rel_path, + args.http_base_url, + gcs_dest_prefix, + args.staging_dir, + skip_existing, + args.max_retries, + ): rel_path + for rel_path in rel_paths + } + for i, future in enumerate(as_completed(futures), start=1): + rel_path, status, nbytes = future.result() + total_bytes += nbytes + if status == "ok": + n_ok += 1 + elif status == "skipped": + n_skipped += 1 + else: + n_error += 1 + errors.append((rel_path, status)) + + elapsed = time.time() - start + tag = status.split(":", 1)[0] + print( + f" [{i}/{len(rel_paths)}] ok={n_ok} skipped={n_skipped} err={n_error}" + f" | {total_bytes/1e9:.2f} GB | {elapsed:.0f}s | {tag} {rel_path}", + flush=True, + ) + + elapsed = time.time() - start + print( + f"\nDone in {elapsed:.1f}s. ok={n_ok} skipped={n_skipped} errors={n_error}" + f" | {total_bytes/1e9:.2f} GB transferred" + ) + + if errors: + print(f"\nFailed files ({len(errors)}):") + for rel_path, status in errors: + print(f" {rel_path}\n {status}") + sys.exit(1) + + +if __name__ == "__main__": + main()