diff --git a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py index e00667cbb..ce6f90fb5 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py @@ -1057,6 +1057,12 @@ class TrainState: serialized_item, value_metadata_tree ) else: + # Deserialize value metadata tree to the same structure as item to allow + # for comparison with item that contains rich types. + if self._pytree_metadata_options.support_rich_types: + value_metadata_tree = tree_utils.deserialize_tree( + value_metadata_tree, item + ) # is_empty_or_leaf is necessary here to treat empty nodes (e.g. empty # dicts, lists, custom nodes) as leaves, as they do not contain any # actual data to be restored, but are needed to maintain the structure. @@ -1083,12 +1089,11 @@ class TrainState: restore_args, self._pytree_metadata_options ) - value_metadata_tree_deserialized = tree_utils.deserialize_tree( - value_metadata_tree, item - ) - restore_args_deserialized = tree_utils.deserialize_tree(restore_args, item) - value_metadata_tree = value_metadata_tree_deserialized - restore_args = restore_args_deserialized + if not self._pytree_metadata_options.support_rich_types: + value_metadata_tree = tree_utils.deserialize_tree( + value_metadata_tree, item + ) + restore_args = tree_utils.deserialize_tree(restore_args, item) param_infos = self._get_param_infos( item=value_metadata_tree, diff --git a/checkpoint/orbax/checkpoint/_src/tree/utils.py b/checkpoint/orbax/checkpoint/_src/tree/utils.py index 87d4704bf..9a21804b0 100644 --- a/checkpoint/orbax/checkpoint/_src/tree/utils.py +++ b/checkpoint/orbax/checkpoint/_src/tree/utils.py @@ -16,6 +16,7 @@ from typing import Any, Callable, Mapping, Optional, Tuple, TypeVar, Union +import flax import jax import jax.tree_util as jtu from orbax.checkpoint._src.arrays import abstract_arrays @@ -235,10 +236,14 @@ def _reconstruct_from_keypath(keypath, _): result = serialized for key in keypath: key_name = get_key_name(key) - if isinstance(key, jax.tree_util.GetAttrKey) and isinstance_of_namedtuple( - result - ): - result = getattr(result, key_name) + if isinstance(key, jax.tree_util.GetAttrKey): + if isinstance_of_namedtuple(result): + result = getattr(result, key_name) + elif isinstance(result, flax.struct.PyTreeNode): + # Special case to support flax.struct.PyTreeNode + result = result.__dict__[key_name] + else: + result = result[key_name] else: # Special case to support Pax. if not isinstance(result, (list, tuple)) and key_name not in result: