Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1057,6 +1057,12 @@ class TrainState:
serialized_item, value_metadata_tree
)
else:
# Deserialize value metadata tree to the same structure as item to allow
# for comparison with item that contains rich types.
if self._pytree_metadata_options.support_rich_types:
value_metadata_tree = tree_utils.deserialize_tree(
value_metadata_tree, item
)
# is_empty_or_leaf is necessary here to treat empty nodes (e.g. empty
# dicts, lists, custom nodes) as leaves, as they do not contain any
# actual data to be restored, but are needed to maintain the structure.
Expand All @@ -1083,12 +1089,11 @@ class TrainState:
restore_args, self._pytree_metadata_options
)

value_metadata_tree_deserialized = tree_utils.deserialize_tree(
value_metadata_tree, item
)
restore_args_deserialized = tree_utils.deserialize_tree(restore_args, item)
value_metadata_tree = value_metadata_tree_deserialized
restore_args = restore_args_deserialized
if not self._pytree_metadata_options.support_rich_types:
value_metadata_tree = tree_utils.deserialize_tree(
value_metadata_tree, item
)
restore_args = tree_utils.deserialize_tree(restore_args, item)

param_infos = self._get_param_infos(
item=value_metadata_tree,
Expand Down
13 changes: 9 additions & 4 deletions checkpoint/orbax/checkpoint/_src/tree/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from typing import Any, Callable, Mapping, Optional, Tuple, TypeVar, Union

import flax
import jax
import jax.tree_util as jtu
from orbax.checkpoint._src.arrays import abstract_arrays
Expand Down Expand Up @@ -235,10 +236,14 @@ def _reconstruct_from_keypath(keypath, _):
result = serialized
for key in keypath:
key_name = get_key_name(key)
if isinstance(key, jax.tree_util.GetAttrKey) and isinstance_of_namedtuple(
result
):
result = getattr(result, key_name)
if isinstance(key, jax.tree_util.GetAttrKey):
if isinstance_of_namedtuple(result):
result = getattr(result, key_name)
elif isinstance(result, flax.struct.PyTreeNode):
# Special case to support flax.struct.PyTreeNode
result = result.__dict__[key_name]
else:
result = result[key_name]
else:
# Special case to support Pax.
if not isinstance(result, (list, tuple)) and key_name not in result:
Expand Down
Loading