From 1491c8f99448a7fab2d74a3cf2bec0a02ba8a815 Mon Sep 17 00:00:00 2001 From: Marco Berlot Date: Wed, 11 Mar 2026 06:36:48 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 881983235 --- checkpoint/orbax/checkpoint/_src/path/step.py | 26 ++- .../checkpoint/_src/path/storage_backend.py | 190 ++++++++++++++++++ .../_src/path/storage_backend_test.py | 75 +++++++ 3 files changed, 283 insertions(+), 8 deletions(-) create mode 100644 checkpoint/orbax/checkpoint/_src/path/storage_backend.py create mode 100644 checkpoint/orbax/checkpoint/_src/path/storage_backend_test.py diff --git a/checkpoint/orbax/checkpoint/_src/path/step.py b/checkpoint/orbax/checkpoint/_src/path/step.py index 6fdc34eca..d478ca506 100644 --- a/checkpoint/orbax/checkpoint/_src/path/step.py +++ b/checkpoint/orbax/checkpoint/_src/path/step.py @@ -36,6 +36,7 @@ from orbax.checkpoint._src.metadata import step_metadata_serialization from orbax.checkpoint._src.multihost import multihost from orbax.checkpoint._src.path import gcs_utils +from orbax.checkpoint._src.path import storage_backend as storage_backend_lib # pylint: disable=unused-import from orbax.checkpoint._src.path import temporary_paths # Allowed checkpoint step naming using any non empty `step_prefix`. @@ -190,7 +191,9 @@ def build_step_path( base_path: epath.PathLike, name_format: NameFormat[Metadata], step: int ) -> epath.Path: """Returns `step` path under `base_path` for step `name_format`.""" - return epath.Path(base_path) / name_format.build_name(step) + backend = storage_backend_lib.resolve_storage_backend(str(base_path)) + label = name_format.build_name(step) + return epath.Path(base_path) / backend.name_to_path_component(label) def build_step_metadatas( @@ -354,11 +357,12 @@ def _build_metadata( return None if step is not None: - # step already known, just check exists. 
if step_path.exists(): return Metadata(step=step, path=step_path) - # Regex: [prefix]*(step) + name = step_path.name + backend = storage_backend_lib.resolve_storage_backend(str(step_path.parent)) + name = backend.path_component_to_name(name) if self.step_format_fixed_length and self.step_format_fixed_length > 0: zero_present = rf'0\d{{{self.step_format_fixed_length-1}}}' zero_not_present = rf'[1-9]\d{{{self.step_format_fixed_length-1}}}\d*' @@ -367,7 +371,7 @@ def _build_metadata( zero_padded_step_group = r'(0|[1-9]\d*)' name_regex = f'^{step_prefix_with_underscore(self.step_prefix)}{zero_padded_step_group}$' - match = re.search(name_regex, step_path.name) + match = re.search(name_regex, name) if match is None: return None (step_,) = match.groups() @@ -403,9 +407,15 @@ def _glob_step_paths(self, base_path: epath.PathLike) -> list[epath.Path]: os.path.join(path_prefix, self.step_prefix or '') ) ] - else: - prefix = step_prefix_with_underscore(self.step_prefix) - return [x for x in base_path.iterdir() if x.name.startswith(prefix)] + backend = storage_backend_lib.resolve_storage_backend(str(base_path)) + assets = backend.list_checkpoints(str(base_path)) + logical_prefix = step_prefix_with_underscore(self.step_prefix) + result = [] + for a in assets: + name = a.version if a.version is not None else epath.Path(a.path).name + if name.startswith(logical_prefix): + result.append(base_path / backend.name_to_path_component(name)) + return result def _get_step_paths_and_total_steps( self, base_path: epath.PathLike, is_primary_host: bool @@ -505,7 +515,7 @@ def _find_all_with_single_host_load_and_broadcast( ) base_path = epath.Path(base_path) paths_to_step_dict: dict[epath.Path, int] = { - base_path / self.build_name(step): step + build_step_path(base_path, self, step): step for step in padded_step_list if step >= 0 } diff --git a/checkpoint/orbax/checkpoint/_src/path/storage_backend.py b/checkpoint/orbax/checkpoint/_src/path/storage_backend.py new file mode 100644 index 
000000000..9a600a2d2 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/path/storage_backend.py @@ -0,0 +1,190 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Checkpoint storage interface and base implementations. + +This module defines the abstract StorageBackend interface for managing +checkpoint paths across different file systems. Base implementations for GCS +and local file systems are provided here. +""" + +import abc +import dataclasses +import enum + +from absl import logging +from etils import epath +from orbax.checkpoint._src.path import atomicity_types + + +@dataclasses.dataclass(frozen=True) +class CheckpointPathMetadata: + """Internal representation of checkpoint path metadata. + + Attributes: + path: The file system path of the checkpoint. + status: The status of the checkpoint. + version: The version of the checkpoint with an index and step number. (e.g. + '1.step_1') + tags: A set of tags associated with the checkpoint. May not be available in + all backend implementations; for unsupported backends this field will be + `None`. + """ + + class Status(enum.Enum): + COMMITTED = 1 + UNCOMMITTED = 2 + + path: str + status: Status + version: str | None + tags: set[str] | None = None + + +@dataclasses.dataclass(frozen=True) +class CheckpointFilter: + """Criteria for filtering checkpoints. + + TODO: b/466312058 This class will contain fields for filtering checkpoints by + various criteria. 
+ """ + + +@dataclasses.dataclass(frozen=True) +class CheckpointReadOptions: + """Options for reading checkpoints. + + Attributes: + filter: Optional filter criteria for selecting checkpoints. + enable_strong_reads: If True, enables strong read consistency when querying + checkpoints. This may have performance implications but ensures the most + up-to-date results. + """ + + filter: CheckpointFilter | None = None + enable_strong_reads: bool = False + + +class StorageBackend(abc.ABC): + """An abstract base class for a storage backend. + + This class defines a common interface for managing checkpoint paths in + different file systems. + """ + + def name_to_path_component(self, name: str) -> str: + """Converts a logical step name to a filesystem path component.""" + return name + + def path_component_to_name(self, path_component: str) -> str: + """Converts a filesystem path component back to a logical step name.""" + return path_component + + @abc.abstractmethod + def list_checkpoints( + self, + base_path: str | epath.PathLike, + ) -> list[CheckpointPathMetadata]: + """Lists checkpoints for a given base path.""" + raise NotImplementedError('Subclasses must provide implementation') + + @abc.abstractmethod + def get_temporary_path_class(self) -> type[atomicity_types.TemporaryPath]: + """Returns a TemporaryPath class for the storage backend.""" + raise NotImplementedError('Subclasses must provide implementation') + + @abc.abstractmethod + def delete_checkpoint( + self, + checkpoint_path: str | epath.PathLike, + ) -> None: + """Deletes a checkpoint from the storage backend.""" + raise NotImplementedError('Subclasses must provide implementation') + + +class GCSStorageBackend(StorageBackend): + """A StorageBackend implementation for GCS (Google Cloud Storage). + + # TODO(b/425293362): Implement this class. 
+ """ + + def get_temporary_path_class(self) -> type[atomicity_types.TemporaryPath]: + """Returns a TemporaryPath class (not yet implemented).""" + raise NotImplementedError( + 'get_temporary_path_class is not yet implemented for GCSStorageBackend.' + ) + + def list_checkpoints( + self, base_path: str | epath.PathLike + ) -> list[CheckpointPathMetadata]: + """Lists checkpoints for a given base path (not yet implemented).""" + raise NotImplementedError( + 'list_checkpoints is not yet implemented for GCSStorageBackend.' + ) + + def delete_checkpoint( + self, + checkpoint_path: str | epath.PathLike, + ) -> None: + """Deletes the checkpoint at the given path.""" + raise NotImplementedError( + 'delete_checkpoint is not yet implemented for GCSStorageBackend.' + ) + + +class LocalStorageBackend(StorageBackend): + """A StorageBackend implementation for local file systems.""" + + def get_temporary_path_class(self) -> type[atomicity_types.TemporaryPath]: + """Returns a TemporaryPath class (not yet implemented).""" + raise NotImplementedError( + 'get_temporary_path_class is not yet implemented for' + ' LocalStorageBackend.' 
+ ) + + def list_checkpoints( + self, + base_path: str | epath.PathLike, + ) -> list[CheckpointPathMetadata]: + """Lists checkpoints for a given base path.""" + base = epath.Path(base_path) + if not base.exists(): + return [] + return [ + CheckpointPathMetadata( + path=str(child), + status=CheckpointPathMetadata.Status.COMMITTED, + version=None, + ) + for child in base.iterdir() + ] + + def delete_checkpoint( + self, + checkpoint_path: str | epath.PathLike, + ) -> None: + """Deletes the checkpoint at the given path.""" + try: + epath.Path(checkpoint_path).rmtree() + logging.info('Removed old checkpoint (%s)', checkpoint_path) + except OSError: + logging.exception('Failed to remove checkpoint (%s)', checkpoint_path) + + +def resolve_storage_backend( + path: str, +) -> StorageBackend: + """Returns a StorageBackend for the given path.""" + del path + return LocalStorageBackend() diff --git a/checkpoint/orbax/checkpoint/_src/path/storage_backend_test.py b/checkpoint/orbax/checkpoint/_src/path/storage_backend_test.py new file mode 100644 index 000000000..cd762a711 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/path/storage_backend_test.py @@ -0,0 +1,75 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for checkpoint storage backend base implementations.""" + +from absl.testing import absltest +from etils import epath +from orbax.checkpoint._src.path import storage_backend + + +class LocalStorageBackendTest(absltest.TestCase): + + def test_name_to_path_component_is_identity(self): + backend = storage_backend.LocalStorageBackend() + self.assertEqual(backend.name_to_path_component('step_1'), 'step_1') + + def test_path_component_to_name_is_identity(self): + backend = storage_backend.LocalStorageBackend() + self.assertEqual(backend.path_component_to_name('step_1'), 'step_1') + + def test_list_checkpoints_returns_children(self): + tmpdir = self.create_tempdir() + base = epath.Path(tmpdir.full_path) + (base / 'step_0').mkdir() + (base / 'step_1').mkdir() + backend = storage_backend.LocalStorageBackend() + assets = backend.list_checkpoints(str(base)) + self.assertLen(assets, 2) + paths = sorted([a.path for a in assets]) + self.assertEqual( + paths, + sorted([ + str(base / 'step_0'), + str(base / 'step_1'), + ]), + ) + for asset in assets: + self.assertEqual( + asset.status, + storage_backend.CheckpointPathMetadata.Status.COMMITTED, + ) + self.assertIsNone(asset.version) + + def test_list_checkpoints_non_existent_path_returns_empty(self): + tmpdir = self.create_tempdir() + base = epath.Path(tmpdir.full_path) / 'non_existent' + backend = storage_backend.LocalStorageBackend() + assets = backend.list_checkpoints(str(base)) + self.assertEmpty(assets) + + def test_list_checkpoints_empty_directory(self): + tmpdir = self.create_tempdir() + base = epath.Path(tmpdir.full_path) + backend = storage_backend.LocalStorageBackend() + assets = backend.list_checkpoints(str(base)) + self.assertEmpty(assets) + + def test_resolve_storage_backend_returns_local(self): + backend = storage_backend.resolve_storage_backend('/tmp/some/path') + self.assertIsInstance(backend, storage_backend.LocalStorageBackend) + + +if __name__ == '__main__': + absltest.main()