Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
### Constants shared by all layouts. ###

PYTREE_CHECKPOINTABLE_KEY = "pytree"
EMPTY_CHECKPOINTABLE_KEY = ""

METRICS_CHECKPOINTABLE_KEY = "metrics"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,17 @@ def load_pytree(
"""
start_time = time.time()
logging.info('Loading checkpoint from %s.', path)

abstract_pytree = _standardize_abstract_checkpointables(abstract_pytree)
validation.validate_pytree_checkpointable_name(checkpointable_name)

ctx = context_lib.get_context()
path = ctx.file_options.path_class(path)
layout = asyncio_utils.run_sync(
layout_registry.get_checkpoint_layout_pytree(
path, ctx.checkpoint_layout, checkpointable_name
)
)
abstract_pytree = _standardize_abstract_checkpointables(abstract_pytree)

validation.validate_pytree_checkpointable_name(checkpointable_name)

loaded_pytree = _load_impl(
path,
Expand Down Expand Up @@ -257,17 +258,18 @@ def load_checkpointables(
"""
start_time = time.time()
logging.info('Loading checkpoint from %s.', path)
ctx = context_lib.get_context()
path = ctx.file_options.path_class(path)
layout = asyncio_utils.run_sync(
layout_registry.get_checkpoint_layout(path, ctx.checkpoint_layout)
)

abstract_checkpointables = _standardize_abstract_checkpointables(
abstract_checkpointables
)
validation.validate_abstract_checkpointables(abstract_checkpointables)

ctx = context_lib.get_context()
path = ctx.file_options.path_class(path)
layout = asyncio_utils.run_sync(
layout_registry.get_checkpoint_layout(path, ctx.checkpoint_layout)
)

if not hasattr(layout, 'load_checkpointables'):
raise NotImplementedError(
f'Layout {type(layout)} does not support loading checkpointables.'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout

RESERVED_CHECKPOINTABLE_KEYS = checkpoint_layout.RESERVED_CHECKPOINTABLE_KEYS
EMPTY_CHECKPOINTABLE_KEY = checkpoint_layout.EMPTY_CHECKPOINTABLE_KEY


def validate_pytree_checkpointable_name(
checkpointable_name: str | None,
Expand All @@ -30,7 +33,13 @@ def validate_pytree_checkpointable_name(
"""
if checkpointable_name is None:
return
if checkpointable_name in checkpoint_layout.RESERVED_CHECKPOINTABLE_KEYS:
if checkpointable_name == EMPTY_CHECKPOINTABLE_KEY:
raise ValueError(
'Empty string is not supported as a checkpointable name in'
' `load_pytree`. Checkpointable name must be a valid non-empty string'
' name or None if loading a legacy V0 direct pytree checkpoint.'
)
if checkpointable_name in RESERVED_CHECKPOINTABLE_KEYS:
raise ValueError(
f'Provided reserved checkpointable key: {checkpointable_name}.'
)
Expand All @@ -47,9 +56,15 @@ def validate_abstract_checkpointables(abstract_checkpointables):
"""
if abstract_checkpointables is None:
return
if EMPTY_CHECKPOINTABLE_KEY in abstract_checkpointables:
raise ValueError(
'Empty string is not supported as a checkpointable name in'
' `load_checkpointables`. Each checkpointable name must be a valid'
' non-empty string name.'
)
if (
provided_reserved_keys := abstract_checkpointables.keys()
& checkpoint_layout.RESERVED_CHECKPOINTABLE_KEYS
& RESERVED_CHECKPOINTABLE_KEYS
):
raise ValueError(
f'Provided reserved checkpointable keys: {provided_reserved_keys}.'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,34 @@
from unittest import mock

from absl.testing import absltest
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
from orbax.checkpoint.experimental.v1._src.loading import validation


class ValidationTest(absltest.TestCase):
def test_validate_pytree_checkpointable_name(self):
validation.validate_pytree_checkpointable_name(None)
validation.validate_pytree_checkpointable_name('pytree')
validation.validate_pytree_checkpointable_name('a')

with self.assertRaisesRegex(ValueError, 'Empty string is not supported'):
validation.validate_pytree_checkpointable_name('')

with mock.patch.object(
validation, 'RESERVED_CHECKPOINTABLE_KEYS', {'reserved'}
):
with self.assertRaisesRegex(ValueError, 'reserved'):
validation.validate_pytree_checkpointable_name('reserved')

def test_validate_abstract_checkpointables(self):
validation.validate_abstract_checkpointables(None)
validation.validate_abstract_checkpointables({})
validation.validate_abstract_checkpointables({'a': 1})

with self.assertRaisesRegex(ValueError, 'Empty string is not supported'):
validation.validate_abstract_checkpointables({'': 1})

with mock.patch.object(
checkpoint_layout, 'RESERVED_CHECKPOINTABLE_KEYS', {'reserved'}
validation, 'RESERVED_CHECKPOINTABLE_KEYS', {'reserved'}
):
with self.assertRaisesRegex(ValueError, 'reserved'):
validation.validate_abstract_checkpointables({'reserved': 1})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import orbax.checkpoint.experimental.v1._src.handlers.global_registration # pylint: disable=unused-import
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
from orbax.checkpoint.experimental.v1._src.layout import registry as layout_registry
from orbax.checkpoint.experimental.v1._src.loading import validation
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types
from orbax.checkpoint.experimental.v1._src.path import types as path_types

Expand All @@ -30,6 +31,7 @@
InvalidLayoutError = errors.InvalidLayoutError
PyTreeMetadata = metadata_types.PyTreeMetadata
PYTREE_CHECKPOINTABLE_KEY = checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY
EMPTY_CHECKPOINTABLE_KEY = checkpoint_layout.EMPTY_CHECKPOINTABLE_KEY


def pytree_metadata(
Expand Down Expand Up @@ -85,6 +87,7 @@ def _get_abstract_array(arr):
Returns:
A `CheckpointMetadata[PyTreeMetadata]` object.
"""
validation.validate_pytree_checkpointable_name(checkpointable_name)
ctx = context_lib.get_context()
path = ctx.file_options.path_class(path)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from orbax.checkpoint.experimental.v1._src.path import async_utils as path_async_utils
from orbax.checkpoint.experimental.v1._src.path import types as path_types
from orbax.checkpoint.experimental.v1._src.saving import path_utils as saving_path_utils
from orbax.checkpoint.experimental.v1._src.saving import validation
from orbax.checkpoint.experimental.v1._src.synchronization import multihost
from orbax.checkpoint.experimental.v1._src.synchronization import thread_utils
from orbax.checkpoint.experimental.v1._src.synchronization import types as async_types
Expand Down Expand Up @@ -371,6 +372,7 @@ def save_checkpointables_impl(
partial_save: bool = False,
) -> async_types.AsyncResponse[None]:
"""See caller docstrings."""
validation.validate_abstract_checkpointables(checkpointables)
start_time = time.time()
event_tracking.record_save_start(path, async_origin=async_origin)
# Ensure the operation ID is incremented as soon as possible. This must be
Expand Down
18 changes: 8 additions & 10 deletions checkpoint/orbax/checkpoint/experimental/v1/_src/saving/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
from orbax.checkpoint.experimental.v1._src.path import types as path_types
from orbax.checkpoint.experimental.v1._src.saving import execution
from orbax.checkpoint.experimental.v1._src.saving import validation
from orbax.checkpoint.experimental.v1._src.synchronization import types as async_types
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types

Expand Down Expand Up @@ -70,12 +71,13 @@ def save_pytree(
JSON-serializable dictionary the user can use to store additional
information. The field is treated as opaque by Orbax.
"""
save_checkpointables(
execution.save_checkpointables_impl(
path,
{PYTREE_CHECKPOINTABLE_KEY: pytree},
overwrite=overwrite,
custom_metadata=custom_metadata,
)
async_origin=False,
).result()


def save_checkpointables(
Expand Down Expand Up @@ -131,6 +133,7 @@ def save_checkpointables(
JSON-serializable dictionary the user can use to store additional
information. The field is treated as opaque by Orbax.
"""
validation.validate_abstract_checkpointables(checkpointables)
execution.save_checkpointables_impl(
path,
checkpointables,
Expand Down Expand Up @@ -200,11 +203,12 @@ def save_pytree_async(
An `AsyncResponse` that can be used to block until the save is complete.
Blocking can be done using `response.result()`, which returns `None`.
"""
return save_checkpointables_async(
return execution.save_checkpointables_impl(
path,
{PYTREE_CHECKPOINTABLE_KEY: pytree},
overwrite=overwrite,
custom_metadata=custom_metadata,
async_origin=True,
)


Expand Down Expand Up @@ -275,6 +279,7 @@ def save_checkpointables_async(
An `AsyncResponse` that can be used to block until the save is complete.
Blocking can be done using `response.result()`, which returns `None`.
"""
validation.validate_abstract_checkpointables(checkpointables)
return execution.save_checkpointables_impl(
path,
checkpointables,
Expand Down Expand Up @@ -303,13 +308,6 @@ def get_v0_checkpointer_and_args(
Returns:
A tuple containing the V0 Checkpointer and Args.
"""
if (
provided_reserved_keys := checkpointables.keys()
& checkpoint_layout.RESERVED_CHECKPOINTABLE_KEYS
):
raise ValueError(
f'Provided reserved checkpointable keys: {provided_reserved_keys}.'
)
checkpointables = execution.add_internal_checkpointables(
checkpointables, context=context, metrics=metrics
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Validation functions involved in saving."""

from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout

RESERVED_CHECKPOINTABLE_KEYS = checkpoint_layout.RESERVED_CHECKPOINTABLE_KEYS
EMPTY_CHECKPOINTABLE_KEY = checkpoint_layout.EMPTY_CHECKPOINTABLE_KEY


def validate_abstract_checkpointables(abstract_checkpointables):
  """Checks the checkpointable names supplied to a save operation.

  Validation is skipped entirely when `abstract_checkpointables` is None.
  Otherwise the keys are checked in two steps: first for the empty-string
  name, then for any overlap with the reserved checkpointable keys.

  Args:
    abstract_checkpointables: A dictionary keyed by checkpointable name, or
      None.

  Raises:
    ValueError: If the empty string is used as a checkpointable name, or if
      any of the keys in abstract_checkpointables are reserved.
  """
  # Nothing to validate when no checkpointables were provided.
  if abstract_checkpointables is None:
    return

  # The empty name is rejected before the reserved-key check so that it gets
  # its own, more specific error message.
  if EMPTY_CHECKPOINTABLE_KEY in abstract_checkpointables:
    raise ValueError(
        'Empty string is not supported as a checkpointable name in'
        ' `save_checkpointables`. Each checkpointable name must be a valid'
        ' non-empty string name.'
    )

  # Intersect the provided names with the reserved names; any overlap is an
  # error and is reported in full.
  reserved_in_use = (
      abstract_checkpointables.keys() & RESERVED_CHECKPOINTABLE_KEYS
  )
  if reserved_in_use:
    raise ValueError(
        f'Provided reserved checkpointable keys: {reserved_in_use}.'
    )
Original file line number Diff line number Diff line change
Expand Up @@ -437,34 +437,30 @@ def get_v0_type_handler_registry(
type handler registry.
context: The Context to be used to default construct the LeafHandlers.
"""

def _get_typestr(leaf_type: Any) -> str:
if leaf_type == jax.Array:
return type_handlers_v0.JAX_ARRAY_TYPE_STR
elif leaf_type == np.ndarray:
return 'np.ndarray'
elif leaf_type in (int, float, bytes, np.number):
return 'scalar'
elif leaf_type == str:
return 'string'
else:
return f'{leaf_type!r}'

# Register standard v1 leaf handlers to the v0 type handler registry.
handlers = []
for leaf_type, _, leaf_handler_type in leaf_handler_registry.get_all():
# Reverse the order of the leaf handlers so that the last registered handler
# is the first one used: the V1 registry is ordered by priority from generic
# to specific, while the V0 type handler registry is ordered from specific to
# generic.
for leaf_type, _, leaf_handler_type in reversed(
leaf_handler_registry.get_all()
):
try:
leaf_handler = leaf_handler_type(context=context) # pytype: disable=wrong-keyword-args
except TypeError as e:
raise ValueError(
f'Failed to default construct LeafHandler[{leaf_type}]. All'
' LeafHandler types must be able to be constructed with a context.'
) from e

typestrs = leaf_handler_registry.get_secondary_typestrs(leaf_handler_type)
typestr = typestrs[0] if typestrs else f'{leaf_type!r}'
handlers.append((
leaf_type,
CompatibleTypeHandler(
leaf_handler,
typestr=_get_typestr(leaf_type),
typestr=typestr,
),
))
return type_handler_registry.create_type_handler_registry(*handlers)
Loading
Loading