diff --git a/dlclive/pose_estimation_pytorch/runner.py b/dlclive/pose_estimation_pytorch/runner.py index 11188f1..551640c 100644 --- a/dlclive/pose_estimation_pytorch/runner.py +++ b/dlclive/pose_estimation_pytorch/runner.py @@ -13,6 +13,7 @@ from dataclasses import dataclass from pathlib import Path from typing import Literal +import warnings import numpy as np import torch @@ -131,15 +132,25 @@ def __init__( path: str | Path, device: str = "auto", precision: Literal["FP16", "FP32"] = "FP32", - single_animal: bool = True, + single_animal: bool | None = None, dynamic: dict | dynamic_cropping.DynamicCropper | None = None, top_down_config: dict | TopDownConfig | None = None, ) -> None: super().__init__(path) self.device = _parse_device(device) self.precision = precision + if single_animal is not None: + warnings.warn( + "The `single_animal` parameter is deprecated and will be removed " + "in a future version. The number of individuals will be automaticalliy inferred " + "from the model configuration. Remove argument `single_animal` or set " + "`single_animal=None` to accept the inferred value and silence this warning.", + DeprecationWarning, + stacklevel=2, + ) self.single_animal = single_animal - + self.n_individuals = None + self.n_bodyparts = None self.cfg = None self.detector = None self.model = None @@ -191,9 +202,14 @@ def get_pose(self, frame: np.ndarray) -> np.ndarray: frame_batch, offsets_and_scales = self._prepare_top_down(tensor, detections) if len(frame_batch) == 0: - offsets_and_scales = [(0, 0), 1] - else: - tensor = frame_batch # still CHW, batched + zero_pose = ( + np.zeros((self.n_bodyparts, 3)) + if self.n_individuals < 2 else + np.zeros((self.n_individuals, self.n_bodyparts, 3)) + ) + return zero_pose + + tensor = frame_batch # still CHW, batched if self.dynamic is not None: tensor = self.dynamic.crop(tensor) @@ -260,6 +276,15 @@ def load_model(self) -> None: raw_data = torch.load(self.path, map_location="cpu", weights_only=True) self.cfg = raw_data["config"] + + # Infer single animal mode and n_bodyparts from model configuration + individuals = self.cfg.get("metadata", {}).get("individuals", ['idv1']) + bodyparts = self.cfg.get("metadata", {}).get("bodyparts", []) + self.n_individuals = len(individuals) + self.n_bodyparts = len(bodyparts) + if self.single_animal is None: + self.single_animal = self.n_individuals == 1 + self.model = models.PoseModel.build(self.cfg["model"]) self.model.load_state_dict(raw_data["pose"]) self.model = self.model.to(self.device)