diff --git a/transformers4rec/torch/model/base.py b/transformers4rec/torch/model/base.py index 3aeb4eb3f..6d697ebce 100644 --- a/transformers4rec/torch/model/base.py +++ b/transformers4rec/torch/model/base.py @@ -269,7 +269,7 @@ def __init__( for i, task in enumerate(prediction_tasks): self.prediction_task_dict[task.task_name] = task - self._task_weights = defaultdict() + self._task_weights = defaultdict(lambda: 1.0) if task_weights: for task, val in zip(cast(List[PredictionTask], prediction_tasks), task_weights): self._task_weights[task.task_name] = val