diff --git a/transformers4rec/torch/model/base.py b/transformers4rec/torch/model/base.py index 3aeb4eb3f..61427cf16 100644 --- a/transformers4rec/torch/model/base.py +++ b/transformers4rec/torch/model/base.py @@ -914,7 +914,7 @@ def load( max_sequence_length=max_sequence_length, top_k=top_k, ) - if isinstance(state_dict, Dict[str, torch.Tensor]): + if isinstance(state_dict, dict): model.load_state_dict(state_dict, strict=strict) else: raise ValueError("`state_dict` must be a dictionary of parameter (torch) tensors.")