diff --git a/src/jabs/scripts/cli/convert_to_nwb.py b/src/jabs/scripts/cli/convert_to_nwb.py index f3e110b8..8287bde5 100644 --- a/src/jabs/scripts/cli/convert_to_nwb.py +++ b/src/jabs/scripts/cli/convert_to_nwb.py @@ -4,6 +4,7 @@ import logging from pathlib import Path +import h5py import numpy as np from jabs.core.abstract.pose_est import PoseEstimation @@ -32,6 +33,70 @@ def _segments_to_edges(segments) -> list[tuple[int, int]]: return edges +def _h5_attr_to_jsonable(value: object) -> object: + """Convert an HDF5 attribute value to a JSON-serializable Python object. + + ``h5py`` returns attribute values as numpy scalars or arrays, ``bytes`` + (for fixed-length string attributes), or native Python objects. This + normalizes them to plain JSON-friendly types (``str``, ``int``, ``float``, + ``bool``, ``list``, ``None``) so they can be embedded losslessly in the NWB + metadata JSON. A value of an unrecognized type is preserved as its string + representation rather than dropped. + + Args: + value: A raw attribute value as returned by ``h5py``. + + Returns: + A JSON-serializable representation of ``value``. + """ + if isinstance(value, bytes): + return value.decode("utf-8", errors="replace") + if isinstance(value, np.ndarray): + # h5py returns scalar attributes as 0-d arrays; .tolist() yields a bare + # scalar (not a list), so unwrap to the scalar before recursing. + if value.shape == (): + return _h5_attr_to_jsonable(value.item()) + return [_h5_attr_to_jsonable(item) for item in value.tolist()] + if isinstance(value, np.generic): + return _h5_attr_to_jsonable(value.item()) + if isinstance(value, list | tuple): + return [_h5_attr_to_jsonable(item) for item in value] + if value is None or isinstance(value, str | int | float | bool): + return value + logger.warning( + "Preserving HDF5 attribute of unsupported type %s as a string", type(value).__name__ + ) + return str(value) + + +def _collect_hdf5_attributes(path: Path) -> dict[str, dict[str, object]]: + """Collect every attribute from every object in an HDF5 file. + + Walks the whole file - the root group, all sub-groups, and all datasets - + recording each object's attributes keyed by its HDF5 path (``"/"`` for the + root). Objects that carry no attributes are omitted. Attribute values are + normalized to JSON-serializable types via :func:`_h5_attr_to_jsonable` so + they survive the NWB metadata round-trip. + + Args: + path: Path to the HDF5 file to read. + + Returns: + Mapping of HDF5 object path to a dict of that object's attributes. + """ + collected: dict[str, dict[str, object]] = {} + + def _record(name: str, obj: h5py.Group | h5py.Dataset) -> None: + if len(obj.attrs) == 0: + return + collected[name] = {key: _h5_attr_to_jsonable(val) for key, val in obj.attrs.items()} + + with h5py.File(path, "r") as h5: + _record("/", h5) # visititems does not visit the root group itself + h5.visititems(_record) + return collected + + def pose_to_pose_data( pose: PoseEstimation, subjects: dict[str, dict] | None = None, @@ -40,6 +105,10 @@ def pose_to_pose_data( Handles all supported JABS pose versions (v2-v8). + Every attribute stored anywhere in the source pose HDF5 file is captured + into ``PoseData.metadata["hdf5_attributes"]`` (keyed by HDF5 object path) + so arbitrary provenance attributes are not lost in the NWB conversion. + Args: pose: A loaded PoseEstimation object (any version). subjects: Optional per-animal biological metadata, keyed by identity @@ -84,6 +153,10 @@ def pose_to_pose_data( if file_hash is not None: metadata["source_file_hash"] = file_hash + hdf5_attributes = _collect_hdf5_attributes(Path(pose.pose_file)) + if hdf5_attributes: + metadata["hdf5_attributes"] = hdf5_attributes + return PoseData( points=points_array, point_mask=point_mask_array, diff --git a/tests/scripts/test_convert_to_nwb.py b/tests/scripts/test_convert_to_nwb.py index 3b745437..4f706ffd 100644 --- a/tests/scripts/test_convert_to_nwb.py +++ b/tests/scripts/test_convert_to_nwb.py @@ -1,10 +1,17 @@ """Tests for convert_to_nwb helper functions.""" import datetime +import json +import h5py +import numpy as np import pytest -from jabs.scripts.cli.convert_to_nwb import _parse_session_start_time +from jabs.scripts.cli.convert_to_nwb import ( + _collect_hdf5_attributes, + _h5_attr_to_jsonable, + _parse_session_start_time, +) def test_parse_utc_offset(): @@ -49,3 +56,100 @@ def test_parse_non_string_raises(value): """Test that ValueError is raised if value is not a string.""" with pytest.raises(ValueError, match="must be a string"): _parse_session_start_time(value) + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + (b"hello", "hello"), + (np.bytes_(b"world"), "world"), + ("plain", "plain"), + (np.int64(7), 7), + (np.float64(1.5), 1.5), + (np.bool_(True), True), + (np.array(7), 7), + (np.array(1.5), 1.5), + (np.array(b"scalar"), "scalar"), + (42, 42), + (3.14, 3.14), + (None, None), + ], + ids=[ + "bytes", + "np_bytes", + "str", + "np_int", + "np_float", + "np_bool", + "zerod_int_array", + "zerod_float_array", + "zerod_bytes_array", + "py_int", + "py_float", + "none", + ], +) +def test_h5_attr_to_jsonable_scalars(value, expected): + """Scalar HDF5 attribute values normalize to plain JSON-friendly types.""" + assert _h5_attr_to_jsonable(value) == expected + + +def test_h5_attr_to_jsonable_numeric_array(): + """A numeric numpy array becomes a plain list of Python numbers.""" + result = _h5_attr_to_jsonable(np.array([6, 0, 0], dtype=np.uint16)) + assert result == [6, 0, 0] + assert all(isinstance(x, int) for x in result) + + +def test_h5_attr_to_jsonable_byte_string_array(): + """An array of fixed-length byte strings is decoded to a list of str.""" + result = _h5_attr_to_jsonable(np.array([b"a", b"b"], dtype="S1")) + assert result == ["a", "b"] + + +def test_h5_attr_to_jsonable_unsupported_falls_back_to_str(caplog): + """An unrecognized type is preserved as its string representation with a warning.""" + import logging + + value = complex(1, 2) + with caplog.at_level(logging.WARNING): + result = _h5_attr_to_jsonable(value) + + assert result == str(value) + assert "unsupported type" in caplog.text.lower() + + +def test_collect_hdf5_attributes(tmp_path): + """All attributes across the file are collected, keyed by object path.""" + h5_path = tmp_path / "sample_pose_est_v6.h5" + with h5py.File(h5_path, "w") as h5: + h5.attrs["experimenter"] = "Jane Doe" + h5.attrs["custom_flag"] = np.int64(1) + poseest = h5.create_group("poseest") + poseest.attrs["version"] = np.array([6, 0], dtype=np.uint16) + poseest.attrs["cm_per_pixel"] = np.float64(0.07) + points = poseest.create_dataset("points", data=np.zeros((2, 12, 2))) + points.attrs["note"] = b"raw bytes note" + # group with no attributes should be omitted + h5.create_group("static_objects") + + collected = _collect_hdf5_attributes(h5_path) + + assert collected["/"] == {"experimenter": "Jane Doe", "custom_flag": 1} + assert collected["poseest"] == {"version": [6, 0], "cm_per_pixel": pytest.approx(0.07)} + assert collected["poseest/points"] == {"note": "raw bytes note"} + assert "static_objects" not in collected + + +def test_collect_hdf5_attributes_is_json_serializable(tmp_path): + """The collected attributes survive json.dumps without a custom encoder.""" + h5_path = tmp_path / "sample_pose_est_v6.h5" + with h5py.File(h5_path, "w") as h5: + h5.attrs["str_attr"] = "value" + h5.attrs["int_array"] = np.array([1, 2, 3], dtype=np.int32) + h5.create_group("poseest").attrs["bytes_attr"] = b"bytes" + + collected = _collect_hdf5_attributes(h5_path) + + # Should not raise; round-trips back to the same structure. + assert json.loads(json.dumps(collected)) == collected