Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions src/jabs/scripts/cli/convert_to_nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from pathlib import Path

import h5py
import numpy as np

from jabs.core.abstract.pose_est import PoseEstimation
Expand Down Expand Up @@ -32,6 +33,70 @@ def _segments_to_edges(segments) -> list[tuple[int, int]]:
return edges


def _h5_attr_to_jsonable(value: object) -> object:
"""Convert an HDF5 attribute value to a JSON-serializable Python object.

``h5py`` returns attribute values as numpy scalars or arrays, ``bytes``
(for fixed-length string attributes), or native Python objects. This
normalizes them to plain JSON-friendly types (``str``, ``int``, ``float``,
``bool``, ``list``, ``None``) so they can be embedded losslessly in the NWB
metadata JSON. A value of an unrecognized type is preserved as its string
representation rather than dropped.

Args:
value: A raw attribute value as returned by ``h5py``.

Returns:
A JSON-serializable representation of ``value``.
"""
if isinstance(value, bytes):
return value.decode("utf-8", errors="replace")
if isinstance(value, np.ndarray):
# h5py returns scalar attributes as 0-d arrays; .tolist() yields a bare
# scalar (not a list), so unwrap to the scalar before recursing.
if value.shape == ():
return _h5_attr_to_jsonable(value.item())
return [_h5_attr_to_jsonable(item) for item in value.tolist()]
if isinstance(value, np.generic):
return _h5_attr_to_jsonable(value.item())
if isinstance(value, list | tuple):
return [_h5_attr_to_jsonable(item) for item in value]
if value is None or isinstance(value, str | int | float | bool):
return value
logger.warning(
"Preserving HDF5 attribute of unsupported type %s as a string", type(value).__name__
)
return str(value)


def _collect_hdf5_attributes(path: Path) -> dict[str, dict[str, object]]:
    """Gather the attributes attached to every object in an HDF5 file.

    The root group is recorded under the key ``"/"``; every other object
    reached by ``visititems`` is recorded under the name that ``visititems``
    reports for it. Objects without any attributes get no entry. Each
    attribute value is passed through :func:`_h5_attr_to_jsonable` so the
    result contains only JSON-friendly values.

    Args:
        path: Path to the HDF5 file to read.

    Returns:
        Mapping of HDF5 object path to a dict of that object's attributes.
    """
    attrs_by_path: dict[str, dict[str, object]] = {}

    def _visit(obj_path: str, node: h5py.Group | h5py.Dataset) -> None:
        # Must return None: a non-None return would stop the visititems walk.
        if len(node.attrs) == 0:
            return
        normalized: dict[str, object] = {}
        for attr_name, raw_value in node.attrs.items():
            normalized[attr_name] = _h5_attr_to_jsonable(raw_value)
        attrs_by_path[obj_path] = normalized

    with h5py.File(path, "r") as h5_file:
        # visititems skips the root group, so handle it explicitly first.
        _visit("/", h5_file)
        h5_file.visititems(_visit)

    return attrs_by_path


def pose_to_pose_data(
pose: PoseEstimation,
subjects: dict[str, dict] | None = None,
Expand All @@ -40,6 +105,10 @@ def pose_to_pose_data(

Handles all supported JABS pose versions (v2-v8).

Every attribute stored anywhere in the source pose HDF5 file is captured
into ``PoseData.metadata["hdf5_attributes"]`` (keyed by HDF5 object path)
so arbitrary provenance attributes are not lost in the NWB conversion.

Args:
pose: A loaded PoseEstimation object (any version).
subjects: Optional per-animal biological metadata, keyed by identity
Expand Down Expand Up @@ -84,6 +153,10 @@ def pose_to_pose_data(
if file_hash is not None:
metadata["source_file_hash"] = file_hash

hdf5_attributes = _collect_hdf5_attributes(Path(pose.pose_file))
if hdf5_attributes:
metadata["hdf5_attributes"] = hdf5_attributes

return PoseData(
points=points_array,
point_mask=point_mask_array,
Expand Down
106 changes: 105 additions & 1 deletion tests/scripts/test_convert_to_nwb.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
"""Tests for convert_to_nwb helper functions."""

import datetime
import json

import h5py
import numpy as np
import pytest

from jabs.scripts.cli.convert_to_nwb import _parse_session_start_time
from jabs.scripts.cli.convert_to_nwb import (
_collect_hdf5_attributes,
_h5_attr_to_jsonable,
_parse_session_start_time,
)


def test_parse_utc_offset():
Expand Down Expand Up @@ -49,3 +56,100 @@ def test_parse_non_string_raises(value):
"""Test that ValueError is raised if value is not a string."""
with pytest.raises(ValueError, match="must be a string"):
_parse_session_start_time(value)


@pytest.mark.parametrize(
    ("value", "expected"),
    [
        pytest.param(b"hello", "hello", id="bytes"),
        pytest.param(np.bytes_(b"world"), "world", id="np_bytes"),
        pytest.param("plain", "plain", id="str"),
        pytest.param(np.int64(7), 7, id="np_int"),
        pytest.param(np.float64(1.5), 1.5, id="np_float"),
        pytest.param(np.bool_(True), True, id="np_bool"),
        pytest.param(np.array(7), 7, id="zerod_int_array"),
        pytest.param(np.array(1.5), 1.5, id="zerod_float_array"),
        pytest.param(np.array(b"scalar"), "scalar", id="zerod_bytes_array"),
        pytest.param(42, 42, id="py_int"),
        pytest.param(3.14, 3.14, id="py_float"),
        pytest.param(None, None, id="none"),
    ],
)
def test_h5_attr_to_jsonable_scalars(value, expected):
    """Each scalar-like input is converted to the expected plain Python value."""
    converted = _h5_attr_to_jsonable(value)
    assert converted == expected


def test_h5_attr_to_jsonable_numeric_array():
    """A uint16 numpy array is converted to a list of plain Python ints."""
    raw = np.array([6, 0, 0], dtype=np.uint16)

    converted = _h5_attr_to_jsonable(raw)

    assert converted == [6, 0, 0]
    for element in converted:
        assert isinstance(element, int)


def test_h5_attr_to_jsonable_byte_string_array():
    """Every element of a fixed-length byte-string array comes back as str."""
    raw = np.array([b"a", b"b"], dtype="S1")

    converted = _h5_attr_to_jsonable(raw)

    assert converted == ["a", "b"]


def test_h5_attr_to_jsonable_unsupported_falls_back_to_str(caplog):
    """A value of an unhandled type is returned as ``str(value)`` and a warning is logged."""
    import logging

    unsupported = complex(1, 2)
    expected_text = str(unsupported)

    with caplog.at_level(logging.WARNING):
        converted = _h5_attr_to_jsonable(unsupported)

    captured_log = caplog.text.lower()
    assert converted == expected_text
    assert "unsupported type" in captured_log


def test_collect_hdf5_attributes(tmp_path):
    """Attributes on the root, a group, and a dataset are all gathered by object path."""
    pose_file = tmp_path / "sample_pose_est_v6.h5"
    with h5py.File(pose_file, "w") as handle:
        # Root-level attributes.
        handle.attrs["experimenter"] = "Jane Doe"
        handle.attrs["custom_flag"] = np.int64(1)

        # Group-level attributes.
        pose_group = handle.create_group("poseest")
        pose_group.attrs["version"] = np.array([6, 0], dtype=np.uint16)
        pose_group.attrs["cm_per_pixel"] = np.float64(0.07)

        # Dataset-level attribute stored as raw bytes.
        points_ds = pose_group.create_dataset("points", data=np.zeros((2, 12, 2)))
        points_ds.attrs["note"] = b"raw bytes note"

        # An attribute-less group must not show up in the result.
        handle.create_group("static_objects")

    result = _collect_hdf5_attributes(pose_file)

    expected_root = {"experimenter": "Jane Doe", "custom_flag": 1}
    expected_poseest = {"version": [6, 0], "cm_per_pixel": pytest.approx(0.07)}
    expected_points = {"note": "raw bytes note"}

    assert result["/"] == expected_root
    assert result["poseest"] == expected_poseest
    assert result["poseest/points"] == expected_points
    assert "static_objects" not in result


def test_collect_hdf5_attributes_is_json_serializable(tmp_path):
    """The collected mapping can be dumped with plain ``json`` and loaded back unchanged."""
    pose_file = tmp_path / "sample_pose_est_v6.h5"
    with h5py.File(pose_file, "w") as handle:
        handle.attrs["str_attr"] = "value"
        handle.attrs["int_array"] = np.array([1, 2, 3], dtype=np.int32)
        pose_group = handle.create_group("poseest")
        pose_group.attrs["bytes_attr"] = b"bytes"

    result = _collect_hdf5_attributes(pose_file)

    # json.dumps must succeed without a custom encoder, and decoding the
    # text must reproduce the same structure.
    encoded = json.dumps(result)
    decoded = json.loads(encoded)
    assert decoded == result
Loading