From eff38e6e3b70c46b670f63a5c30198d7785f6b21 Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Mon, 1 Dec 2025 15:15:39 +0100 Subject: [PATCH 01/32] Add FrozenDataContainer class --- eitprocessing/datahandling/__init__.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/eitprocessing/datahandling/__init__.py b/eitprocessing/datahandling/__init__.py index 190182068..14a24d185 100644 --- a/eitprocessing/datahandling/__init__.py +++ b/eitprocessing/datahandling/__init__.py @@ -1,3 +1,4 @@ +import dataclasses from copy import deepcopy from dataclasses import dataclass @@ -16,3 +17,26 @@ def __bool__(self): def deepcopy(self) -> Self: """Return a deep copy of the object.""" return deepcopy(self) + + +@dataclass(eq=False, frozen=True) +class FrozenDataContainer(Equivalence): + """Base class for data container classes.""" + + def __bool__(self): + return True + + def deepcopy(self) -> Self: + """Return a deep copy of the object.""" + return deepcopy(self) + + def update(self: Self, **kwargs: object) -> Self: + """Return a copy of the object with specified fields replaced. + + Args: + **kwargs: Fields to replace. + + Returns: + A new instance of the object with the specified fields replaced. + """ + return dataclasses.replace(self, **kwargs) From 3a1532d9b22fa4d489f41476cab44510a10536cd Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Mon, 1 Dec 2025 15:16:54 +0100 Subject: [PATCH 02/32] Add function to freeze a numpy array --- eitprocessing/utils/frozen_array.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 eitprocessing/utils/frozen_array.py diff --git a/eitprocessing/utils/frozen_array.py b/eitprocessing/utils/frozen_array.py new file mode 100644 index 000000000..7a2955679 --- /dev/null +++ b/eitprocessing/utils/frozen_array.py @@ -0,0 +1,22 @@ +from typing import Literal + +import numpy as np + +DEFAULT_FREEZE_METHOD = "memoryview" + + +def freeze_array(a: np.ndarray, *, method: Literal["flag", "memoryview"] = DEFAULT_FREEZE_METHOD) -> np.ndarray: + """Return a read-only array that cannot be made writeable again.""" + match method: + case "flag": + if a.flags["WRITEABLE"]: + a = a.copy() + a.flags["WRITEABLE"] = False + return a # is already read-only, e.g., a view of a read-only array + case "memoryview": + a_c = np.ascontiguousarray(a) + ro_buf = memoryview(a_c).toreadonly() + return np.frombuffer(ro_buf, dtype=a_c.dtype).reshape(a_c.shape) + case _: + msg = f"Invalid method: {method!r}" + raise ValueError(msg) From 70d2b904ae3432d7405329e2c6cf06d5012b65d8 Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Mon, 1 Dec 2025 15:17:26 +0100 Subject: [PATCH 03/32] Make EITData objects frozen --- eitprocessing/datahandling/eitdata.py | 140 +++++++++++++++++++------- 1 file changed, 106 insertions(+), 34 deletions(-) diff --git a/eitprocessing/datahandling/eitdata.py b/eitprocessing/datahandling/eitdata.py index ea50d9bf0..03c5205b1 100644 --- a/eitprocessing/datahandling/eitdata.py +++ b/eitprocessing/datahandling/eitdata.py @@ -8,9 +8,10 @@ import numpy as np -from eitprocessing.datahandling import DataContainer +from eitprocessing.datahandling import FrozenDataContainer from eitprocessing.datahandling.continuousdata import ContinuousData from eitprocessing.datahandling.mixins.slicing import SelectByTime +from eitprocessing.utils.frozen_array import freeze_array if TYPE_CHECKING: from typing_extensions import Self @@ -19,8 +20,8 @@ T = TypeVar("T", bound="EITData") -@dataclass(eq=False) -class EITData(DataContainer, 
SelectByTime): +@dataclass(eq=False, frozen=True) +class EITData(FrozenDataContainer, SelectByTime): """Container for EIT impedance data. This class holds the pixel impedance from an EIT measurement, as well as metadata describing the measurement. The @@ -41,46 +42,99 @@ class is meant to hold data from (part of) a singular continuous measurement. """ # TODO: fix docstring path: str | Path | list[Path | str] = field(compare=False, repr=False) - nframes: int = field(repr=False) time: np.ndarray = field(repr=False) sample_frequency: float = field(metadata={"check_equivalence": True}, repr=False) vendor: Vendor = field(metadata={"check_equivalence": True}, repr=False) label: str | None = field(default=None, compare=False, metadata={"check_equivalence": True}) - description: str = field(default="", compare=False, repr=False) + description: str | None = field(default=None, compare=False, repr=False) name: str | None = field(default=None, compare=False, repr=False) - pixel_impedance: np.ndarray = field(repr=False, kw_only=True) + values: np.ndarray = field(repr=False, kw_only=True) suppress_simulated_warning: InitVar[bool] = False - def __post_init__(self, suppress_simulated_warning: bool) -> None: - if not self.label: - self.label = f"{self.__class__.__name__}_{id(self)}" + def __init__( + self, + *, + time: np.ndarray, + sample_frequency: float, + vendor: Vendor | str, + path: str | Path | list[Path | str], + values: np.ndarray | None = None, + label: str | None = None, + description: str | None = None, + name: str | None = None, + suppress_simulated_warning: bool = False, + **kwargs, + ): + if "pixel_impedance" in kwargs: + if values is not None: + msg = "Cannot provide both 'pixel_impedance' and 'values'." + raise ValueError(msg) + warnings.warn("`pixel_impedance` has been replaced by `values`.", DeprecationWarning, stacklevel=2) + values = kwargs.pop("pixel_impedance") + + if "nframes" in kwargs: + warnings.warn( + "`nframes` is no longer a constructor argument. Use `len(eitdata)` instead.", + DeprecationWarning, + stacklevel=2, + ) + _ = kwargs.pop("nframes") + + if kwargs: + msg = f"Unexpected keyword arguments: {', '.join(kwargs.keys())}." + raise TypeError(msg) + + if not isinstance(values, np.ndarray): + msg = f"'values' must be a numpy ndarray, not {type(values)}." + raise TypeError(msg) + + label = label or f"{self.__class__.__name__}_{id(self)}" + object.__setattr__(self, "label", label) + object.__setattr__(self, "name", name) + object.__setattr__(self, "description", description) - self.path = self.ensure_path_list(self.path) - if len(self.path) == 1: - self.path = self.path[0] + path_list = self.ensure_path_list(path) + if len(path_list) == 1: + object.__setattr__(self, "path", path_list[0]) + else: + object.__setattr__(self, "path", path_list) - self.name = self.name or self.label - old_sample_frequency = self.sample_frequency - self.sample_frequency = float(self.sample_frequency) - if self.sample_frequency != old_sample_frequency: + object.__setattr__(self, "sample_frequency", float(sample_frequency)) + if self.sample_frequency != sample_frequency: msg = ( "Sample frequency could not be correctly converted from " - f"{old_sample_frequency} ({type(old_sample_frequency)}) to " + f"{sample_frequency} ({type(sample_frequency)}) to " f"{self.sample_frequency:.1f} (float)." 
) raise TypeError(msg) - if (lv := len(self.pixel_impedance)) != (lt := len(self.time)): + if (lv := len(values)) != (lt := len(time)): msg = f"The number of time points ({lt}) does not match the number of pixel impedance values ({lv})." raise ValueError(msg) - if not suppress_simulated_warning and self.vendor == Vendor.SIMULATED: + object.__setattr__(self, "values", freeze_array(values)) + object.__setattr__(self, "time", freeze_array(time)) + + vendor = Vendor(vendor) + if not suppress_simulated_warning and vendor == Vendor.SIMULATED: warnings.warn( "The simulated vendor is used for testing purposes. " "It is not a real vendor and should not be used in production code.", UserWarning, stacklevel=2, ) + object.__setattr__(self, "vendor", vendor) + + @property + def pixel_impedance(self) -> np.ndarray: + """Alias to `values`.""" + return self.values + + @property + def nframes(self) -> int: + """Number of frames in the data.""" + warnings.warn("`nframes` is deprecated. Use `len(eitdata)` instead.", DeprecationWarning, stacklevel=2) + return self.pixel_impedance.shape[0] @property def framerate(self) -> float: @@ -135,20 +189,9 @@ def _sliced_copy( end_index: int, newlabel: str, # noqa: ARG002 ) -> Self: - cls = self.__class__ - time = np.copy(self.time[start_index:end_index]) - nframes = len(time) - - pixel_impedance = np.copy(self.pixel_impedance[start_index:end_index, :, :]) - - return cls( - path=self.path, - nframes=nframes, - vendor=self.vendor, - time=time, - sample_frequency=self.sample_frequency, - label=self.label, # newlabel gives errors - pixel_impedance=pixel_impedance, + return self.update( + time=self.time[start_index:end_index], + values=self.pixel_impedance[start_index:end_index, :, :], ) def __len__(self): @@ -174,12 +217,41 @@ def get_summed_impedance(self, *, return_label: str | None = None, **return_kwar "sample_frequency": self.sample_frequency, } | return_kwargs - return ContinuousData(label=return_label, time=np.copy(self.time), values=summed_impedance, **return_kwargs_) + return ContinuousData(label=return_label, time=self.time, values=summed_impedance, **return_kwargs_) def calculate_global_impedance(self) -> np.ndarray: """Return the global impedance, i.e. the sum of all included pixels at each frame.""" return np.nansum(self.pixel_impedance, axis=(1, 2)) + def update(self, **kwargs) -> Self: + """Return a copy of the object with specified fields replaced. + + Args: + **kwargs: Fields to replace. + + Returns: + A new instance of the object with the specified fields replaced. + """ + if "pixel_impedance" in kwargs: + if "values" in kwargs: + msg = "Cannot provide both 'pixel_impedance' and 'values'." + raise ValueError(msg) + warnings.warn("`pixel_impedance` has been replaced by `values`.", DeprecationWarning, stacklevel=2) + kwargs["values"] = kwargs.pop("pixel_impedance") + + if "framerate" in kwargs: + if "sample_frequency" in kwargs: + msg = "Cannot provide both 'framerate' and 'sample_frequency'." + raise ValueError(msg) + warnings.warn( + "`framerate` has been deprecated. Use `sample_frequency` instead.", + DeprecationWarning, + stacklevel=2, + ) + kwargs["sample_frequency"] = kwargs.pop("framerate") + + return super().update(**kwargs) + class Vendor(Enum): """Enum indicating the vendor (manufacturer) of the source EIT device. 
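Taken together, the first three patches replace in-place mutation with a copy-on-update model: EITData becomes a frozen dataclass, its arrays are made read-only through freeze_array(), and update() (backed by dataclasses.replace()) becomes the supported way to derive modified objects. A minimal sketch of the resulting usage, assuming the constructor as defined in patch 03 (the path and array contents here are illustrative only):

    import numpy as np

    from eitprocessing.datahandling.eitdata import EITData

    eit_data = EITData(
        path="measurement.bin",  # illustrative; `path` is still required at this point in the series
        time=np.arange(10) / 10.0,
        values=np.random.default_rng().random((10, 3, 3)),
        sample_frequency=10.0,
        vendor="simulated",
        suppress_simulated_warning=True,
    )

    # Both layers of mutation are now rejected:
    #   eit_data.values = other_array   -> FrozenInstanceError (frozen dataclass)
    #   eit_data.values[0, 0, 0] = 1.0  -> ValueError (read-only numpy array)

    # Deriving a modified copy is the supported route:
    shifted = eit_data.update(values=eit_data.values + 1.0)

Note that only the default "memoryview" freeze method is irreversible: an array frozen by clearing its WRITEABLE flag can be made writeable again by a caller if it owns its data, which is why test_cannot_unfreeze in patch 06 skips when the array turns out not to be memoryview-backed.
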
From bc6196ec84c79aa2ee71c908914de65afafa90ed Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Mon, 1 Dec 2025 15:17:48 +0100 Subject: [PATCH 04/32] Update MDN to use frozen EITData object --- eitprocessing/filters/mdn.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/eitprocessing/filters/mdn.py b/eitprocessing/filters/mdn.py index b3b2c2b3c..012e7d5e4 100644 --- a/eitprocessing/filters/mdn.py +++ b/eitprocessing/filters/mdn.py @@ -1,7 +1,7 @@ -import copy import math import warnings from dataclasses import dataclass +from dataclasses import replace as dataclass_replace from typing import TypeVar, cast, overload import numpy as np @@ -155,17 +155,17 @@ def apply( # pyright: ignore[reportInconsistentOverload] return new_data # TODO: Replace with input_data.update(...) when implemented - return_object = copy.deepcopy(input_data) - for attr, value in kwargs.items(): - setattr(return_object, attr, value) - if isinstance(return_object, ContinuousData): - return_object.values = new_data - elif isinstance(return_object, EITData): - return_object.pixel_impedance = new_data + kwargs["values"] = new_data - capture("filtered_data", return_object) - return return_object + if isinstance(input_data, ContinuousData): + filtered_data: T = dataclass_replace(input_data, **kwargs) + elif isinstance(input_data, EITData): + filtered_data: T = input_data.update(**kwargs) + + capture("filtered_data", filtered_data) + + return filtered_data def _validate_arguments( self, From 1b5a8845ddcb89c25773e10d21c56c0d4ff7252e Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Mon, 1 Dec 2025 15:18:43 +0100 Subject: [PATCH 05/32] Update ROI to use frozen EITData --- eitprocessing/roi/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/eitprocessing/roi/__init__.py b/eitprocessing/roi/__init__.py index 21b10adf8..d02daad5e 100644 --- a/eitprocessing/roi/__init__.py +++ b/eitprocessing/roi/__init__.py @@ -13,7 +13,7 @@ import re import sys import warnings -from dataclasses import InitVar, dataclass, field, replace +from dataclasses import InitVar, dataclass, field from dataclasses import replace as dataclass_replace from typing import TYPE_CHECKING, TypeVar, overload @@ -189,7 +189,7 @@ def __replace__(self, /, **changes) -> Self: elif isinstance(changes["plot_config"], dict): changes["plot_config"] = self._plot_config.update(**changes["plot_config"]) label = changes.pop("label", None) - return replace(self, label=label, **changes) + return dataclass_replace(self, label=label, **changes) update = __replace__ # TODO: add tests for update @@ -242,7 +242,7 @@ def transform_and_mask(data: np.ndarray) -> np.ndarray: case np.ndarray(): return transform_and_mask(data) case EITData(): - return dataclass_replace(data, pixel_impedance=transform_and_mask(data.pixel_impedance), **kwargs) + return data.update(pixel_impedance=transform_and_mask(data.pixel_impedance), **kwargs) case PixelMap(): return data.update(values=transform_and_mask(data.values), **kwargs) case _: From a3105f16d5013c269e075a1b38c35c1e352de7c4 Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Mon, 1 Dec 2025 15:19:13 +0100 Subject: [PATCH 06/32] Add test to see whether frozen objects work as intended --- tests/test_frozen.py | 76 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 tests/test_frozen.py diff --git a/tests/test_frozen.py b/tests/test_frozen.py new file mode 100644 index 000000000..7d98ebd1f --- /dev/null +++ b/tests/test_frozen.py 
@@ -0,0 +1,76 @@ +import contextlib + +import numpy as np +import pytest + +from eitprocessing.datahandling.eitdata import EITData + + +@pytest.fixture +def frozen_eitdata_object() -> EITData: + return EITData( + path="test_path", + label="test_label", + time=np.arange(10) / 10.0, + pixel_impedance=np.random.default_rng().random((10, 3, 3)), + sample_frequency=10.0, + vendor="simulated", + ) + + +def test_frozen_time_axis(frozen_eitdata_object: EITData): + with pytest.raises(AttributeError, match="cannot assign to field 'vendor'"): + frozen_eitdata_object.vendor = "new_vendor" + + with pytest.raises(ValueError, match="output array is read-only"): + frozen_eitdata_object.time += 1.0 + + with pytest.raises(AttributeError, match="cannot assign to field 'time'"): + frozen_eitdata_object.time = frozen_eitdata_object.time + 1.0 + + with pytest.raises(ValueError, match="assignment destination is read-only"): + frozen_eitdata_object.time[0] = 1.0 + + +def test_frozen_values(frozen_eitdata_object: EITData): + with pytest.raises(ValueError, match="output array is read-only"): + frozen_eitdata_object.pixel_impedance += 1.0 + + with pytest.raises(AttributeError, match="cannot assign to field 'pixel_impedance'"): + frozen_eitdata_object.pixel_impedance = frozen_eitdata_object.pixel_impedance + 1.0 + + with pytest.raises(ValueError, match="assignment destination is read-only"): + frozen_eitdata_object.pixel_impedance[0, 0, 0] = 1.0 + + +def test_unfreeze_array_on_copy(frozen_eitdata_object: EITData): + values_copy = frozen_eitdata_object.pixel_impedance.copy() + assert values_copy.flags["WRITEABLE"] + values_copy += 1.0 + new_frozen_eitdata_object = frozen_eitdata_object.update(pixel_impedance=values_copy) + assert not new_frozen_eitdata_object.pixel_impedance.flags["WRITEABLE"] + assert np.array_equal(values_copy, new_frozen_eitdata_object.pixel_impedance) + + +def test_frozen_slice(frozen_eitdata_object: EITData): + values_view = frozen_eitdata_object.pixel_impedance[:5, :, :] + assert not values_view.flags["WRITEABLE"] + with pytest.raises(ValueError, match="assignment destination is read-only"): + values_view[0, 0, 0] = 1.0 + + values_view = values_view.copy() + assert values_view.flags["WRITEABLE"] + values_view += 1.0 + + +def test_cannot_unfreeze(frozen_eitdata_object: EITData): + base = frozen_eitdata_object.pixel_impedance.base + with contextlib.suppress(AttributeError): + while True: + base = base.base + + if not isinstance(base, memoryview): + pytest.skip("Array is not based on a memoryview; cannot test unfreeze.") + + with pytest.raises(ValueError, match="cannot set WRITEABLE flag to True of this array"): + frozen_eitdata_object.pixel_impedance.flags["WRITEABLE"] = True From fe5af4ccc5fe20fa6e79fb1fdcb91ff0462e385d Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Mon, 1 Dec 2025 15:19:23 +0100 Subject: [PATCH 07/32] Update pixel breath tests to work with frozen EITData --- tests/test_pixel_breath.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_pixel_breath.py b/tests/test_pixel_breath.py index 43ee5791b..4efdecabe 100644 --- a/tests/test_pixel_breath.py +++ b/tests/test_pixel_breath.py @@ -393,11 +393,14 @@ def test_phase_modes(draeger_50hz_healthy_volunteer_pressure_pod: Sequence, pyte eit_data = sequence.eit_data["raw"] # reduce the pixel set to some 'well-behaved' pixels with positive TIV - eit_data.pixel_impedance = eit_data.pixel_impedance[:, 14:20, 14:20] + eit_data = eit_data.update(values=eit_data.pixel_impedance[:, 14:20, 14:20]) + 
np.savetxt("v_before_new.txt", eit_data.values[0]) # flip a single pixel, so the differences between algorithms becomes predictable flip_row, flip_col = 5, 5 - eit_data.pixel_impedance[:, flip_row, flip_col] = -eit_data.pixel_impedance[:, flip_row, flip_col] + new_values = eit_data.pixel_impedance.copy() + new_values[:, flip_row, flip_col] = -new_values[:, flip_row, flip_col] + eit_data = eit_data.update(pixel_impedance=new_values) cd = sequence.continuous_data["global_impedance_(raw)"] From d597d74e4460da1528fd8710dfe8088b4a7a9b13 Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Mon, 1 Dec 2025 15:45:29 +0100 Subject: [PATCH 08/32] Remove nframes and allow other EITData attributes to be None --- eitprocessing/datahandling/eitdata.py | 80 +++++++++++-------- eitprocessing/datahandling/loading/draeger.py | 3 +- eitprocessing/datahandling/loading/sentec.py | 3 +- eitprocessing/datahandling/loading/timpel.py | 3 +- tests/test_amplitude_lungspace.py | 4 +- tests/test_parameter_tiv.py | 3 +- tests/test_pixel_breath.py | 6 +- tests/test_rate_detection.py | 3 +- tests/test_tiv_lungspace.py | 5 +- tests/test_watershed.py | 10 +-- 10 files changed, 56 insertions(+), 64 deletions(-) diff --git a/eitprocessing/datahandling/eitdata.py b/eitprocessing/datahandling/eitdata.py index 03c5205b1..68fadd0a5 100644 --- a/eitprocessing/datahandling/eitdata.py +++ b/eitprocessing/datahandling/eitdata.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from dataclasses import InitVar, dataclass, field +from dataclasses import KW_ONLY, InitVar, dataclass, field from enum import Enum from pathlib import Path from typing import TYPE_CHECKING, Any, TypeVar @@ -31,58 +31,42 @@ class is meant to hold data from (part of) a singular continuous measurement. disk. Args: - path: The path of list of paths of the source from which data was derived. - nframes: Number of frames. time: The time of each frame (since start measurement). + values: Impedance values for each pixel at each frame. sample_frequency: The (average) frequency at which the frames are collected, in Hz. vendor: The vendor of the device the data was collected with. + path: The path of list of paths of the source from which data was derived. label: Computer readable label identifying this dataset. name: Human readable name for the data. - pixel_impedance: Impedance values for each pixel at each frame. + description: Human readable description of the data. 
""" # TODO: fix docstring - path: str | Path | list[Path | str] = field(compare=False, repr=False) time: np.ndarray = field(repr=False) + values: np.ndarray = field(repr=False) + _: KW_ONLY sample_frequency: float = field(metadata={"check_equivalence": True}, repr=False) vendor: Vendor = field(metadata={"check_equivalence": True}, repr=False) + path: str | Path | list[Path | str] | None = field(compare=False, repr=False, default=None) label: str | None = field(default=None, compare=False, metadata={"check_equivalence": True}) description: str | None = field(default=None, compare=False, repr=False) name: str | None = field(default=None, compare=False, repr=False) - values: np.ndarray = field(repr=False, kw_only=True) suppress_simulated_warning: InitVar[bool] = False def __init__( self, - *, time: np.ndarray, + values: np.ndarray | None = None, + *, sample_frequency: float, vendor: Vendor | str, - path: str | Path | list[Path | str], - values: np.ndarray | None = None, + path: str | Path | list[Path | str] | None = None, label: str | None = None, description: str | None = None, name: str | None = None, suppress_simulated_warning: bool = False, **kwargs, ): - if "pixel_impedance" in kwargs: - if values is not None: - msg = "Cannot provide both 'pixel_impedance' and 'values'." - raise ValueError(msg) - warnings.warn("`pixel_impedance` has been replaced by `values`.", DeprecationWarning, stacklevel=2) - values = kwargs.pop("pixel_impedance") - - if "nframes" in kwargs: - warnings.warn( - "`nframes` is no longer a constructor argument. Use `len(eitdata)` instead.", - DeprecationWarning, - stacklevel=2, - ) - _ = kwargs.pop("nframes") - - if kwargs: - msg = f"Unexpected keyword arguments: {', '.join(kwargs.keys())}." - raise TypeError(msg) + values = self._parse_kwargs(values, kwargs) if not isinstance(values, np.ndarray): msg = f"'values' must be a numpy ndarray, not {type(values)}." @@ -93,11 +77,14 @@ def __init__( object.__setattr__(self, "name", name) object.__setattr__(self, "description", description) - path_list = self.ensure_path_list(path) - if len(path_list) == 1: - object.__setattr__(self, "path", path_list[0]) + if path is None: + object.__setattr__(self, "path", None) else: - object.__setattr__(self, "path", path_list) + path_list = self.ensure_path_list(path) + if len(path_list) == 1: + object.__setattr__(self, "path", path_list[0]) + else: + object.__setattr__(self, "path", path_list) object.__setattr__(self, "sample_frequency", float(sample_frequency)) if self.sample_frequency != sample_frequency: @@ -125,6 +112,27 @@ def __init__( ) object.__setattr__(self, "vendor", vendor) + def _parse_kwargs(self, values: np.ndarray | None, kwargs: dict[str, Any]) -> np.ndarray | None: + if "pixel_impedance" in kwargs: + if values is not None: + msg = "Cannot provide both 'pixel_impedance' and 'values'." + raise ValueError(msg) + warnings.warn("`pixel_impedance` has been replaced by `values`.", DeprecationWarning, stacklevel=2) + values = kwargs.pop("pixel_impedance") + + if "nframes" in kwargs: + warnings.warn( + "`nframes` is no longer a constructor argument. Use `len(eitdata)` instead.", + DeprecationWarning, + stacklevel=2, + ) + _ = kwargs.pop("nframes") + + if kwargs: + msg = f"Unexpected keyword arguments: {', '.join(kwargs.keys())}." + raise TypeError(msg) + return values + @property def pixel_impedance(self) -> np.ndarray: """Alias to `values`.""" @@ -169,16 +177,18 @@ def concatenate(self: Self, other: Self, newlabel: str | None = None) -> Self: msg = f"Concatenation failed. 
Second dataset ({other.name}) may not start before first ({self.name}) ends." raise ValueError(msg) - self_path = self.ensure_path_list(self.path) - other_path = self.ensure_path_list(other.path) + self_path = [] if self.path is None else self.ensure_path_list(self.path) + other_path = [] if other.path is None else self.ensure_path_list(other.path) + concat_path = [*self_path, *other_path] + if not concat_path: + concat_path = None newlabel = newlabel or f"Merge of <{self.label}> and <{other.label}>" return self.__class__( vendor=self.vendor, - path=[*self_path, *other_path], + path=concat_path, label=self.label, # TODO: using newlabel leads to errors sample_frequency=self.sample_frequency, - nframes=self.nframes + other.nframes, time=np.concatenate((self.time, other.time)), pixel_impedance=np.concatenate((self.pixel_impedance, other.pixel_impedance), axis=0), ) diff --git a/eitprocessing/datahandling/loading/draeger.py b/eitprocessing/datahandling/loading/draeger.py index 5fb048fd8..1649b5c8c 100644 --- a/eitprocessing/datahandling/loading/draeger.py +++ b/eitprocessing/datahandling/loading/draeger.py @@ -120,10 +120,9 @@ def load_from_single_path( # noqa: PLR0915 vendor=Vendor.DRAEGER, path=path, sample_frequency=sample_frequency, - nframes=n_frames, time=time, label="raw", - pixel_impedance=pixel_impedance, + values=pixel_impedance, ) eitdata_collection = DataCollection(EITData, raw=eit_data) diff --git a/eitprocessing/datahandling/loading/sentec.py b/eitprocessing/datahandling/loading/sentec.py index 5a252eda2..aee08c606 100644 --- a/eitprocessing/datahandling/loading/sentec.py +++ b/eitprocessing/datahandling/loading/sentec.py @@ -86,10 +86,9 @@ def load_from_single_path( vendor=Vendor.SENTEC, path=path, sample_frequency=sample_frequency, - nframes=n_frames, time=time_array, label="raw", - pixel_impedance=image, + values=image, ), ) diff --git a/eitprocessing/datahandling/loading/timpel.py b/eitprocessing/datahandling/loading/timpel.py index 2a93107f7..024e4deba 100644 --- a/eitprocessing/datahandling/loading/timpel.py +++ b/eitprocessing/datahandling/loading/timpel.py @@ -96,10 +96,9 @@ def load_from_single_path( vendor=Vendor.TIMPEL, label="raw", path=path, - nframes=nframes, time=time, sample_frequency=sample_frequency, - pixel_impedance=pixel_impedance, + values=pixel_impedance, ) eitdata_collection = DataCollection(EITData, raw=eit_data) diff --git a/tests/test_amplitude_lungspace.py b/tests/test_amplitude_lungspace.py index 6032ae428..1e8ea3cd6 100644 --- a/tests/test_amplitude_lungspace.py +++ b/tests/test_amplitude_lungspace.py @@ -26,15 +26,13 @@ def factory(amplitudes: npt.ArrayLike, duration: float = 20.0) -> EITData: sine = np.sin(t * 2 * np.pi / 5) # 5 second period signal = amplitudes[None, :, :] * (sine[:, None, None] + 0.5) return EITData( - pixel_impedance=signal, + values=signal, path="", - nframes=len(t), time=t, sample_frequency=sample_frequency, vendor=Vendor.SIMULATED, label="simulated", description="Simulated EIT data for testing purposes", - name="", suppress_simulated_warning=True, ) diff --git a/tests/test_parameter_tiv.py b/tests/test_parameter_tiv.py index 3410e4eda..f3be9bcaf 100644 --- a/tests/test_parameter_tiv.py +++ b/tests/test_parameter_tiv.py @@ -86,13 +86,12 @@ def mock_eit_data(): """Fixture to provide an instance of EITData.""" return EITData( path="", - nframes=2000, time=np.linspace(0, 18, (18 * 1000), endpoint=False), sample_frequency=1000, vendor=Vendor.DRAEGER, label="mock_eit_data", name="mock_eit_data", - 
pixel_impedance=mock_pixel_impedance(), + values=mock_pixel_impedance(), ) diff --git a/tests/test_pixel_breath.py b/tests/test_pixel_breath.py index 4efdecabe..459dcf6e2 100644 --- a/tests/test_pixel_breath.py +++ b/tests/test_pixel_breath.py @@ -75,13 +75,12 @@ def mock_eit_data(): """Fixture to provide an instance of EITData.""" return EITData( path="", - nframes=400, time=np.linspace(0, 2 * np.pi, 400), sample_frequency=399 / 2 * np.pi, vendor=Vendor.DRAEGER, label="mock_eit_data", name="mock_eit_data", - pixel_impedance=mock_pixel_impedance(), + values=mock_pixel_impedance(), ) @@ -113,13 +112,12 @@ def mock_zero_eit_data(): """Fixture to provide an instance of EITData with one element set to zero.""" return EITData( path="", - nframes=400, time=np.linspace(0, 2 * np.pi, 400), sample_frequency=399 / 2 * np.pi, vendor=Vendor.DRAEGER, label="mock_eit_data", name="mock_eit_data", - pixel_impedance=mock_pixel_impedance_one_zero(), + values=mock_pixel_impedance_one_zero(), ) diff --git a/tests/test_rate_detection.py b/tests/test_rate_detection.py index 4c943ace4..cfd5e98e8 100644 --- a/tests/test_rate_detection.py +++ b/tests/test_rate_detection.py @@ -107,12 +107,11 @@ def factory( # noqa: PLR0913 return EITData( path=".", - nframes=nframes, time=time, sample_frequency=sample_frequency, vendor="draeger", label="test_signal", - pixel_impedance=pixel_impedance, + values=pixel_impedance, ) return factory diff --git a/tests/test_tiv_lungspace.py b/tests/test_tiv_lungspace.py index 3770351e1..6064c1f49 100644 --- a/tests/test_tiv_lungspace.py +++ b/tests/test_tiv_lungspace.py @@ -26,15 +26,12 @@ def factory(amplitudes: npt.ArrayLike, duration: float = 20.0) -> EITData: sine = np.sin(t * 2 * np.pi / 5) # 5 second period signal = amplitudes[None, :, :] * (sine[:, None, None] + 0.5) return EITData( - pixel_impedance=signal, - path="", - nframes=len(t), + values=signal, time=t, sample_frequency=sample_frequency, vendor=Vendor.SIMULATED, label="simulated", description="Simulated EIT data for testing purposes", - name="", suppress_simulated_warning=True, ) diff --git a/tests/test_watershed.py b/tests/test_watershed.py index 92390725f..6c456a265 100644 --- a/tests/test_watershed.py +++ b/tests/test_watershed.py @@ -48,15 +48,12 @@ def factory( pixel_impedance = sinusoid_shape * amplitude + end_expiratory_value return EITData( - pixel_impedance=pixel_impedance, - path="", - nframes=len(time), + values=pixel_impedance, time=time, sample_frequency=sample_frequency, vendor=Vendor.SIMULATED, label="simulated", description="Simulated EIT data for testing purposes", - name="", suppress_simulated_warning=True, ) @@ -164,15 +161,12 @@ def test_watershed_captures(draeger_50hz_healthy_volunteer_pressure_pod: Sequenc def test_watershed_no_amplitude(): eit_data = EITData( - pixel_impedance=np.ones((100, 32, 32)), - path="", - nframes=100, + values=np.ones((100, 32, 32)), time=np.arange(100) / 20, sample_frequency=20, vendor=Vendor.SIMULATED, label="simulated", description="Simulated EIT data with no amplitude", - name="", suppress_simulated_warning=True, ) with pytest.raises(ValueError, match="No breaths found in TIV or amplitude data"): From 74825980516dbb3cfe2aeea96cbda37bb449484b Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Mon, 1 Dec 2025 19:52:00 +0100 Subject: [PATCH 09/32] Migrate all code to use EITData.values instead of pixel_impedance --- eitprocessing/datahandling/eitdata.py | 12 ++++----- .../datahandling/loading/__init__.py | 2 +- eitprocessing/features/pixel_breath.py | 4 +-- 
eitprocessing/features/rate_detection.py | 2 +- eitprocessing/filters/mdn.py | 6 +---- .../parameters/tidal_impedance_variation.py | 2 +- eitprocessing/plotting/filter.py | 11 ++------ eitprocessing/roi/__init__.py | 9 +++---- tests/eitdata/test_loading_sentec.py | 4 +-- tests/test_frozen.py | 25 +++++++++---------- tests/test_mdn_filter.py | 4 +-- tests/test_pixel_breath.py | 11 ++++---- tests/test_pixelmask.py | 14 +++++------ 13 files changed, 45 insertions(+), 61 deletions(-) diff --git a/eitprocessing/datahandling/eitdata.py b/eitprocessing/datahandling/eitdata.py index 68fadd0a5..f3713c727 100644 --- a/eitprocessing/datahandling/eitdata.py +++ b/eitprocessing/datahandling/eitdata.py @@ -142,7 +142,7 @@ def pixel_impedance(self) -> np.ndarray: def nframes(self) -> int: """Number of frames in the data.""" warnings.warn("`nframes` is deprecated. Use `len(eitdata)` instead.", DeprecationWarning, stacklevel=2) - return self.pixel_impedance.shape[0] + return self.values.shape[0] @property def framerate(self) -> float: @@ -190,7 +190,7 @@ def concatenate(self: Self, other: Self, newlabel: str | None = None) -> Self: label=self.label, # TODO: using newlabel leads to errors sample_frequency=self.sample_frequency, time=np.concatenate((self.time, other.time)), - pixel_impedance=np.concatenate((self.pixel_impedance, other.pixel_impedance), axis=0), + values=np.concatenate((self.values, other.values), axis=0), ) def _sliced_copy( @@ -201,11 +201,11 @@ def _sliced_copy( ) -> Self: return self.update( time=self.time[start_index:end_index], - values=self.pixel_impedance[start_index:end_index, :, :], + values=self.values[start_index:end_index, :, :], ) def __len__(self): - return self.pixel_impedance.shape[0] + return self.values.shape[0] def get_summed_impedance(self, *, return_label: str | None = None, **return_kwargs) -> ContinuousData: """Return a ContinuousData-object with the same time axis and summed pixel values over time. @@ -215,7 +215,7 @@ def get_summed_impedance(self, *, return_label: str | None = None, **return_kwar the current object. **return_kwargs: Keyword arguments for the creation of the returned object. """ - summed_impedance = np.nansum(self.pixel_impedance, axis=(1, 2)) + summed_impedance = np.nansum(self.values, axis=(1, 2)) if return_label is None: return_label = f"summed {self.label}" @@ -231,7 +231,7 @@ def get_summed_impedance(self, *, return_label: str | None = None, **return_kwar def calculate_global_impedance(self) -> np.ndarray: """Return the global impedance, i.e. the sum of all included pixels at each frame.""" - return np.nansum(self.pixel_impedance, axis=(1, 2)) + return np.nansum(self.values, axis=(1, 2)) def update(self, **kwargs) -> Self: """Return a copy of the object with specified fields replaced. diff --git a/eitprocessing/datahandling/loading/__init__.py b/eitprocessing/datahandling/loading/__init__.py index b69180f54..b759eed8f 100644 --- a/eitprocessing/datahandling/loading/__init__.py +++ b/eitprocessing/datahandling/loading/__init__.py @@ -52,7 +52,7 @@ def load_eit_data( ... vendor="sentec", ... label="initial_measurement" ... 
) - >>> pixel_impedance = sequence.eit_data["raw"].pixel_impedance + >>> pixel_impedance = sequence.eit_data["raw"].values ``` """ from eitprocessing.datahandling.loading import ( # noqa: PLC0415 diff --git a/eitprocessing/features/pixel_breath.py b/eitprocessing/features/pixel_breath.py index 3f156afcf..032740a4d 100644 --- a/eitprocessing/features/pixel_breath.py +++ b/eitprocessing/features/pixel_breath.py @@ -141,7 +141,7 @@ def find_pixel_breaths( # noqa: C901, PLR0912, PLR0915 [breath.middle_time for breath in continuous_breaths.values], ) - _, n_rows, n_cols = eit_data.pixel_impedance.shape + _, n_rows, n_cols = eit_data.values.shape from eitprocessing.parameters.tidal_impedance_variation import TIV # noqa: PLC0415 @@ -169,7 +169,7 @@ def find_pixel_breaths( # noqa: C901, PLR0912, PLR0915 mean_tiv_pixel[~all_nan_mask] = np.nanmean(pixel_tiv_with_continuous_data_timing[:, ~all_nan_mask], axis=0) time = eit_data.time - pixel_impedance = eit_data.pixel_impedance + pixel_impedance = eit_data.values pixel_breaths = np.full((len(continuous_breaths), n_rows, n_cols), None, dtype=object) diff --git a/eitprocessing/features/rate_detection.py b/eitprocessing/features/rate_detection.py index 742467f47..7becb61dc 100644 --- a/eitprocessing/features/rate_detection.py +++ b/eitprocessing/features/rate_detection.py @@ -152,7 +152,7 @@ def apply( set to segment length - 1. """ - pixel_impedance = eit_data.pixel_impedance.copy().astype(np.float32) + pixel_impedance = eit_data.values.copy().astype(np.float32) summed_impedance = np.nansum(pixel_impedance, axis=(1, 2)) len_segment = int(self.welch_window * eit_data.sample_frequency) diff --git a/eitprocessing/filters/mdn.py b/eitprocessing/filters/mdn.py index 012e7d5e4..53fdf0319 100644 --- a/eitprocessing/filters/mdn.py +++ b/eitprocessing/filters/mdn.py @@ -182,14 +182,10 @@ def _validate_arguments( msg = "Axis should not be provided when using ContinuousData or EITData." raise ValueError(msg) - if isinstance(input_data, ContinuousData): + if isinstance(input_data, (ContinuousData, EITData)): data = input_data.values sample_frequency_ = cast("float", input_data.sample_frequency) axis_ = 0 - elif isinstance(input_data, EITData): - data = input_data.pixel_impedance - sample_frequency_ = cast("float", input_data.sample_frequency) - axis_ = 0 elif isinstance(input_data, np.ndarray): data = input_data axis_ = DEFAULT_AXIS if axis is MISSING else axis diff --git a/eitprocessing/parameters/tidal_impedance_variation.py b/eitprocessing/parameters/tidal_impedance_variation.py index 6b47086b6..50e01aef0 100644 --- a/eitprocessing/parameters/tidal_impedance_variation.py +++ b/eitprocessing/parameters/tidal_impedance_variation.py @@ -203,7 +203,7 @@ def compute_pixel_parameter( msg = f"Invalid {tiv_timing}. The tiv_timing must be either 'continuous' or 'pixel'." 
raise ValueError(msg) - data = eit_data.pixel_impedance + data = eit_data.values _, n_rows, n_cols = data.shape if tiv_timing == "pixel": diff --git a/eitprocessing/plotting/filter.py b/eitprocessing/plotting/filter.py index fc050344c..5b7e0efdd 100644 --- a/eitprocessing/plotting/filter.py +++ b/eitprocessing/plotting/filter.py @@ -202,19 +202,12 @@ def _get_data( unfiltered_signal = unfiltered_data filtered_signal = filtered_data sample_frequency_ = sample_frequency - case ContinuousData(), ContinuousData(): + case (ContinuousData(), ContinuousData()) | (EITData(), EITData()): unfiltered_signal = unfiltered_data.values filtered_signal = filtered_data.values sample_frequency_ = unfiltered_data.sample_frequency if sample_frequency is not MISSING: - msg = "Sample frequency should not be provided when using ContinuousData." - raise ValueError(msg) - case EITData(), EITData(): - unfiltered_signal = unfiltered_data.pixel_impedance - filtered_signal = filtered_data.pixel_impedance - sample_frequency_ = unfiltered_data.sample_frequency - if sample_frequency is not MISSING: - msg = "Sample frequency should not be provided when using EITData." + msg = "Sample frequency should not be provided when using ContinuousData or EITData." raise ValueError(msg) case _: msg = "Unfiltered and filtered data must be either numpy arrays, ContinuousData, or EITData." diff --git a/eitprocessing/roi/__init__.py b/eitprocessing/roi/__init__.py index d02daad5e..979c4e7a8 100644 --- a/eitprocessing/roi/__init__.py +++ b/eitprocessing/roi/__init__.py @@ -206,9 +206,8 @@ def apply(self, data: PixelMap, **kwargs) -> PixelMap: ... def apply(self, data, **kwargs): """Apply pixel mask to data, returning a copy of the object with pixel values masked. - Data can be a numpy array, an EITData object or PixelMap object. In case of an EITData object, the mask will be - applied to the `pixel_impedance` attribute. In case of a PixelMap, the mask will be applied to the `values` - attribute. + Data can be a numpy array, an EITData object or PixelMap object. In case of an EITData or PixelMap object, the + mask will be applied to the `values` attribute. The input data can have any dimension. The mask is applied to the last two dimensions. The size of the last two dimensions must match the size of the dimensions of the mask, and will generally (but do not have to) have the @@ -241,9 +240,7 @@ def transform_and_mask(data: np.ndarray) -> np.ndarray: match data: case np.ndarray(): return transform_and_mask(data) - case EITData(): - return data.update(pixel_impedance=transform_and_mask(data.pixel_impedance), **kwargs) - case PixelMap(): + case EITData() | PixelMap(): return data.update(values=transform_and_mask(data.values), **kwargs) case _: msg = f"Data should be an array, or EITData or PixelMap object, not {type(data)}." 
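
Patch 09 switches all internal access from pixel_impedance to values; the alias property and the update() shim introduced in patch 03 keep existing call sites working while steering them toward the new name. A short sketch of the compatibility behaviour (the object construction is illustrative; by this point in the series `path` is optional):

    import warnings

    import numpy as np

    from eitprocessing.datahandling.eitdata import EITData

    eit_data = EITData(
        time=np.arange(5) / 10.0,
        values=np.zeros((5, 3, 3)),
        sample_frequency=10.0,
        vendor="simulated",
        suppress_simulated_warning=True,
    )

    # `pixel_impedance` is a read-only property that returns the `values` array itself.
    assert eit_data.pixel_impedance is eit_data.values

    # The deprecated keyword is still accepted by `update()`, which remaps it to
    # `values` after emitting a DeprecationWarning.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        updated = eit_data.update(pixel_impedance=eit_data.values + 1.0)

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
    assert not updated.values.flags["WRITEABLE"]  # the new copy is frozen again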
diff --git a/tests/eitdata/test_loading_sentec.py b/tests/eitdata/test_loading_sentec.py index 73f8bb244..dc437d48f 100644 --- a/tests/eitdata/test_loading_sentec.py +++ b/tests/eitdata/test_loading_sentec.py @@ -29,8 +29,8 @@ def test_load_sentec_single_file( assert isinstance(eit_data, EITData) assert np.isclose(eit_data.sample_frequency, 50.2, rtol=2e-2), "Sample frequency should be approximately 50.2 Hz" assert len(eit_data.time) > 0, "Time axis should not be empty" - assert len(eit_data.time) == len(eit_data.pixel_impedance), "Length of time axis should match number of frames" - assert len(eit_data) == len(eit_data.pixel_impedance), "Length of EITData should match number of frames" + assert len(eit_data.time) == len(eit_data.values), "Length of time axis should match number of frames" + assert len(eit_data) == len(eit_data.values), "Length of EITData should match number of frames" assert len(sequence.continuous_data) == 0, "Sentec data should not have continuous data channels" assert len(sequence.sparse_data) == 0, "Sentec data should not have sparse data channels" diff --git a/tests/test_frozen.py b/tests/test_frozen.py index 7d98ebd1f..4ba452a70 100644 --- a/tests/test_frozen.py +++ b/tests/test_frozen.py @@ -9,10 +9,9 @@ @pytest.fixture def frozen_eitdata_object() -> EITData: return EITData( - path="test_path", label="test_label", time=np.arange(10) / 10.0, - pixel_impedance=np.random.default_rng().random((10, 3, 3)), + values=np.random.default_rng().random((10, 3, 3)), sample_frequency=10.0, vendor="simulated", ) @@ -34,26 +33,26 @@ def test_frozen_time_axis(frozen_eitdata_object: EITData): def test_frozen_values(frozen_eitdata_object: EITData): with pytest.raises(ValueError, match="output array is read-only"): - frozen_eitdata_object.pixel_impedance += 1.0 + frozen_eitdata_object.values += 1.0 - with pytest.raises(AttributeError, match="cannot assign to field 'pixel_impedance'"): - frozen_eitdata_object.pixel_impedance = frozen_eitdata_object.pixel_impedance + 1.0 + with pytest.raises(AttributeError, match="cannot assign to field 'values'"): + frozen_eitdata_object.values = frozen_eitdata_object.values + 1.0 with pytest.raises(ValueError, match="assignment destination is read-only"): - frozen_eitdata_object.pixel_impedance[0, 0, 0] = 1.0 + frozen_eitdata_object.values[0, 0, 0] = 1.0 def test_unfreeze_array_on_copy(frozen_eitdata_object: EITData): - values_copy = frozen_eitdata_object.pixel_impedance.copy() + values_copy = frozen_eitdata_object.values.copy() assert values_copy.flags["WRITEABLE"] values_copy += 1.0 - new_frozen_eitdata_object = frozen_eitdata_object.update(pixel_impedance=values_copy) - assert not new_frozen_eitdata_object.pixel_impedance.flags["WRITEABLE"] - assert np.array_equal(values_copy, new_frozen_eitdata_object.pixel_impedance) + new_frozen_eitdata_object = frozen_eitdata_object.update(values=values_copy) + assert not new_frozen_eitdata_object.values.flags["WRITEABLE"] + assert np.array_equal(values_copy, new_frozen_eitdata_object.values) def test_frozen_slice(frozen_eitdata_object: EITData): - values_view = frozen_eitdata_object.pixel_impedance[:5, :, :] + values_view = frozen_eitdata_object.values[:5, :, :] assert not values_view.flags["WRITEABLE"] with pytest.raises(ValueError, match="assignment destination is read-only"): values_view[0, 0, 0] = 1.0 @@ -64,7 +63,7 @@ def test_frozen_slice(frozen_eitdata_object: EITData): def test_cannot_unfreeze(frozen_eitdata_object: EITData): - base = frozen_eitdata_object.pixel_impedance.base + base = 
frozen_eitdata_object.values.base with contextlib.suppress(AttributeError): while True: base = base.base @@ -73,4 +72,4 @@ def test_cannot_unfreeze(frozen_eitdata_object: EITData): pytest.skip("Array is not based on a memoryview; cannot test unfreeze.") with pytest.raises(ValueError, match="cannot set WRITEABLE flag to True of this array"): - frozen_eitdata_object.pixel_impedance.flags["WRITEABLE"] = True + frozen_eitdata_object.values.flags["WRITEABLE"] = True diff --git a/tests/test_mdn_filter.py b/tests/test_mdn_filter.py index b5c8051b6..a140a1c06 100644 --- a/tests/test_mdn_filter.py +++ b/tests/test_mdn_filter.py @@ -133,9 +133,9 @@ def test_with_eit_data(draeger_20hz_healthy_volunteer: Sequence): ) filtered_data = mdn_filter.apply(eit_data) - filtered_signal = mdn_filter.apply(eit_data.pixel_impedance, sample_frequency=eit_data.sample_frequency, axis=0) + filtered_signal = mdn_filter.apply(eit_data.values, sample_frequency=eit_data.sample_frequency, axis=0) - assert np.allclose(filtered_data.pixel_impedance, filtered_signal) + assert np.allclose(filtered_data.values, filtered_signal) @pytest.mark.parametrize( diff --git a/tests/test_pixel_breath.py b/tests/test_pixel_breath.py index 459dcf6e2..eaf195fc7 100644 --- a/tests/test_pixel_breath.py +++ b/tests/test_pixel_breath.py @@ -316,7 +316,7 @@ def test_with_custom_mean_pixel_tiv( for row, col in itertools.product(range(2), range(2)): time_point = test_result[1, row, col].middle_time index = np.where(mock_eit_data.time == time_point)[0] - value_at_time = mock_eit_data.pixel_impedance[index[0], row, col] + value_at_time = mock_eit_data.values[index[0], row, col] if mean == -1: assert np.isclose(value_at_time, -1, atol=0.01) elif mean == 1: @@ -391,19 +391,18 @@ def test_phase_modes(draeger_50hz_healthy_volunteer_pressure_pod: Sequence, pyte eit_data = sequence.eit_data["raw"] # reduce the pixel set to some 'well-behaved' pixels with positive TIV - eit_data = eit_data.update(values=eit_data.pixel_impedance[:, 14:20, 14:20]) - np.savetxt("v_before_new.txt", eit_data.values[0]) + eit_data = eit_data.update(values=eit_data.values[:, 14:20, 14:20]) # flip a single pixel, so the differences between algorithms becomes predictable flip_row, flip_col = 5, 5 - new_values = eit_data.pixel_impedance.copy() + new_values = eit_data.values.copy() new_values[:, flip_row, flip_col] = -new_values[:, flip_row, flip_col] - eit_data = eit_data.update(pixel_impedance=new_values) + eit_data = eit_data.update(values=new_values) cd = sequence.continuous_data["global_impedance_(raw)"] # replace the 'global' data with the sum of the middly pixels - cd.values = np.sum(eit_data.pixel_impedance, axis=(1, 2)) + cd.values = np.sum(eit_data.values, axis=(1, 2)) pb_negative_amplitude = PixelBreath(phase_correction_mode="negative amplitude").find_pixel_breaths(eit_data, cd) pb_phase_shift = PixelBreath(phase_correction_mode="phase shift").find_pixel_breaths(eit_data, cd) diff --git a/tests/test_pixelmask.py b/tests/test_pixelmask.py index eb2d5e5a2..d5bdbfc34 100644 --- a/tests/test_pixelmask.py +++ b/tests/test_pixelmask.py @@ -158,19 +158,19 @@ def test_pixelmask_apply_eitdata(draeger_20hz_healthy_volunteer: Sequence): mask = PixelMask(np.full((32, 32), np.nan), suppress_all_nan_warning=True) masked_eit_data = mask.apply(eit_data) - assert masked_eit_data.pixel_impedance.shape == eit_data.pixel_impedance.shape - assert np.all(np.isnan(masked_eit_data.pixel_impedance)) + assert masked_eit_data.values.shape == eit_data.values.shape + assert 
np.all(np.isnan(masked_eit_data.values)) mask_values = np.full((32, 32), np.nan) mask_values[10:23, 10:23] = 1.0 # let center pixels pass mask = PixelMask(mask_values) masked_eit_data = mask.apply(eit_data) - assert np.array_equal(masked_eit_data.pixel_impedance[:, 10:23, 10:23], eit_data.pixel_impedance[:, 10:23, 10:23]) - assert np.all(np.isnan(masked_eit_data.pixel_impedance[:, :10, :])) - assert np.all(np.isnan(masked_eit_data.pixel_impedance[:, 23:, :])) - assert np.all(np.isnan(masked_eit_data.pixel_impedance[:, :, :10])) - assert np.all(np.isnan(masked_eit_data.pixel_impedance[:, :, 23:])) + assert np.array_equal(masked_eit_data.values[:, 10:23, 10:23], eit_data.values[:, 10:23, 10:23]) + assert np.all(np.isnan(masked_eit_data.values[:, :10, :])) + assert np.all(np.isnan(masked_eit_data.values[:, 23:, :])) + assert np.all(np.isnan(masked_eit_data.values[:, :, :10])) + assert np.all(np.isnan(masked_eit_data.values[:, :, 23:])) def test_pixelmask_apply_pixelmap(): From 7c3380bc5b919ddc16df146ebab707e493bcc2f8 Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Mon, 1 Dec 2025 20:30:49 +0100 Subject: [PATCH 10/32] Remove derived_from and parameters from data containers --- eitprocessing/datahandling/continuousdata.py | 7 +----- eitprocessing/datahandling/datacollection.py | 12 --------- eitprocessing/datahandling/intervaldata.py | 8 ------ eitprocessing/datahandling/loading/draeger.py | 4 --- eitprocessing/datahandling/loading/timpel.py | 3 --- eitprocessing/datahandling/sparsedata.py | 6 ----- eitprocessing/features/breath_detection.py | 2 -- eitprocessing/features/pixel_breath.py | 1 - eitprocessing/parameters/eeli.py | 2 -- .../parameters/tidal_impedance_variation.py | 2 -- tests/test_datacollection.py | 25 +++---------------- tests/test_parameter_tiv.py | 6 ----- tests/test_pixel_breath.py | 4 --- 13 files changed, 4 insertions(+), 78 deletions(-) diff --git a/eitprocessing/datahandling/continuousdata.py b/eitprocessing/datahandling/continuousdata.py index 8ca919465..fb29793b9 100644 --- a/eitprocessing/datahandling/continuousdata.py +++ b/eitprocessing/datahandling/continuousdata.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: from collections.abc import Callable - from typing_extensions import Any, Self + from typing_extensions import Self T = TypeVar("T", bound="ContinuousData") @@ -32,8 +32,6 @@ class ContinuousData(DataContainer, SelectByTime): unit: Unit of the data, if applicable. category: Category the data falls into, e.g. 'airway pressure'. description: Human readable extended description of the data. - parameters: Parameters used to derive this data. - derived_from: Traceback of intermediates from which the current data was derived. values: Data points. 
""" @@ -42,8 +40,6 @@ class ContinuousData(DataContainer, SelectByTime): unit: str = field(metadata={"check_equivalence": True}, repr=False) category: str = field(metadata={"check_equivalence": True}, repr=False) description: str = field(default="", compare=False, repr=False) - parameters: dict[str, Any] = field(default_factory=dict, repr=False, metadata={"check_equivalence": True}) - derived_from: Any | list[Any] = field(default_factory=list, repr=False, compare=False) time: np.ndarray = field(kw_only=True, repr=False) values: np.ndarray = field(kw_only=True, repr=False) sample_frequency: float | None = field(kw_only=True, repr=False, metadata={"check_equivalence": True}, default=None) @@ -128,7 +124,6 @@ def concatenate(self: Self, other: Self, newlabel: str | None = None) -> Self: category=self.category, time=np.concatenate((self.time, other.time)), values=np.concatenate((self.values, other.values)), - derived_from=[*self.derived_from, *other.derived_from, self, other], sample_frequency=self.sample_frequency, ) diff --git a/eitprocessing/datahandling/datacollection.py b/eitprocessing/datahandling/datacollection.py index 25bd8874a..a9952dbec 100644 --- a/eitprocessing/datahandling/datacollection.py +++ b/eitprocessing/datahandling/datacollection.py @@ -93,18 +93,6 @@ def _check_item( ) raise KeyError(msg) - def get_loaded_data(self) -> dict[str, V]: - """Return all data that was directly loaded from disk.""" - return {k: v for k, v in self.items() if v.loaded} - - def get_data_derived_from(self, obj: V) -> dict[str, V]: - """Return all data that was derived from a specific source.""" - return {k: v for k, v in self.items() if any(obj is item for item in v.derived_from)} - - def get_derived_data(self) -> dict[str, V]: - """Return all data that was derived from any source.""" - return {k: v for k, v in self.items() if v.derived_from} - def concatenate(self: Self, other: Self) -> Self: """Concatenate this collection with an equivalent collection. diff --git a/eitprocessing/datahandling/intervaldata.py b/eitprocessing/datahandling/intervaldata.py index 2ae95826a..7fa186069 100644 --- a/eitprocessing/datahandling/intervaldata.py +++ b/eitprocessing/datahandling/intervaldata.py @@ -43,8 +43,6 @@ class IntervalData(DataContainer, SelectByIndex, HasTimeIndexer): category: Category the data falls into, e.g. 'breath'. intervals: A list of intervals (tuples containing a start time and end time). values: An optional list of values associated with each interval. - parameters: Parameters used to derive the data. - derived_from: Traceback of intermediates from which the current data was derived. description: Extended human readible description of the data. 
default_partial_inclusion: Whether to include a trimmed version of an interval when selecting data """ @@ -55,8 +53,6 @@ class IntervalData(DataContainer, SelectByIndex, HasTimeIndexer): category: str = field(metadata={"check_equivalence": True}, repr=False) intervals: list[Interval | tuple[float, float]] = field(repr=False) values: list[Any] | None = field(repr=False, default=None) - parameters: dict[str, Any] = field(default_factory=dict, metadata={"check_equivalence": True}, repr=False) - derived_from: list[Any] = field(default_factory=list, compare=False, repr=False) description: str = field(compare=False, default="", repr=False) default_partial_inclusion: bool = field(repr=False, default=False) @@ -92,7 +88,6 @@ def _sliced_copy( unit=self.unit, category=self.category, description=description, - derived_from=[*self.derived_from, self], intervals=intervals, values=values, ) @@ -130,7 +125,6 @@ def select_by_time( if start_time is None and end_time is None: copy_ = copy.deepcopy(self) - copy_.derived_from.append(self) copy_.label = newlabel return copy_ @@ -157,7 +151,6 @@ def select_by_time( name=self.name, unit=self.unit, category=self.category, - derived_from=[*self.derived_from, self], intervals=list(filtered_intervals), values=values, ) @@ -228,7 +221,6 @@ def concatenate(self: Self, other: Self, newlabel: str | None = None) -> Self: unit=self.unit, category=self.category, description=self.description, - derived_from=[*self.derived_from, *other.derived_from, self, other], intervals=self.intervals + other.intervals, values=new_values, ) diff --git a/eitprocessing/datahandling/loading/draeger.py b/eitprocessing/datahandling/loading/draeger.py index 1649b5c8c..b211c1e44 100644 --- a/eitprocessing/datahandling/loading/draeger.py +++ b/eitprocessing/datahandling/loading/draeger.py @@ -140,7 +140,6 @@ def load_from_single_path( # noqa: PLR0915 name="Global impedance (raw)", unit="a.u.", category="impedance", - derived_from=[eit_data], time=eit_data.time, values=eit_data.calculate_global_impedance(), sample_frequency=sample_frequency, @@ -152,7 +151,6 @@ def load_from_single_path( # noqa: PLR0915 name="Minimum values detected by Draeger device.", unit=None, category="minvalue", - derived_from=[eit_data], time=np.array([t for t, d in phases if d == -1]), ), ) @@ -162,7 +160,6 @@ def load_from_single_path( # noqa: PLR0915 name="Maximum values detected by Draeger device.", unit=None, category="maxvalue", - derived_from=[eit_data], time=np.array([t for t, d in phases if d == 1]), ), ) @@ -178,7 +175,6 @@ def load_from_single_path( # noqa: PLR0915 name="Events loaded from Draeger data", unit=None, category="event", - derived_from=[eit_data], time=time, values=events, ), diff --git a/eitprocessing/datahandling/loading/timpel.py b/eitprocessing/datahandling/loading/timpel.py index 024e4deba..22d04506a 100644 --- a/eitprocessing/datahandling/loading/timpel.py +++ b/eitprocessing/datahandling/loading/timpel.py @@ -167,7 +167,6 @@ def load_from_single_path( name="Minimum values detected by Timpel device.", unit=None, category="minvalue", - derived_from=[eit_data], time=time[min_indices], ), ) @@ -179,7 +178,6 @@ def load_from_single_path( name="Maximum values detected by Timpel device.", unit=None, category="maxvalue", - derived_from=[eit_data], time=time[max_indices], ), ) @@ -207,7 +205,6 @@ def load_from_single_path( name="QRS complexes detected by Timpel device", unit=None, category="qrs_complex", - derived_from=[eit_data], time=time[qrs_indices], ), ) diff --git 
a/eitprocessing/datahandling/sparsedata.py b/eitprocessing/datahandling/sparsedata.py index 2f0bd8365..e4840f3d3 100644 --- a/eitprocessing/datahandling/sparsedata.py +++ b/eitprocessing/datahandling/sparsedata.py @@ -30,8 +30,6 @@ class SparseData(DataContainer, SelectByTime): unit: Unit of the data, if applicable. category: Category the data falls into, e.g. 'detected r peak'. description: Human readable extended description of the data. - parameters: Parameters used to derive the data. - derived_from: Traceback of intermediates from which the current data was derived. values: List or array of values. These van be numeric data, text or Python objects. """ @@ -41,8 +39,6 @@ class SparseData(DataContainer, SelectByTime): category: str = field(metadata={"check_equivalence": True}, repr=False) time: np.ndarray = field(repr=False) description: str = field(compare=False, default="", repr=False) - parameters: dict[str, Any] = field(default_factory=dict, metadata={"check_equivalence": True}, repr=False) - derived_from: list[Any] = field(default_factory=list, compare=False, repr=False) values: Any | None = None def __repr__(self) -> str: @@ -78,7 +74,6 @@ def _sliced_copy( unit=self.unit, category=self.category, description=description, - derived_from=[*self.derived_from, self], time=time, values=values, ) @@ -122,7 +117,6 @@ def concatenate(self, other: Self, newlabel: str | None = None) -> Self: # noqa unit=self.unit, category=self.category, description=self.description, - derived_from=[*self.derived_from, *other.derived_from, self, other], time=np.concatenate((self.time, other.time)), values=new_values, ) diff --git a/eitprocessing/features/breath_detection.py b/eitprocessing/features/breath_detection.py index 2cedf7fc9..64cb8c87e 100644 --- a/eitprocessing/features/breath_detection.py +++ b/eitprocessing/features/breath_detection.py @@ -123,8 +123,6 @@ def find_breaths( category="breath", intervals=[(breath.start_time, breath.end_time) for breath in breaths], values=breaths, - parameters={type(self): dict(vars(self))}, - derived_from=[continuous_data], ) if store: diff --git a/eitprocessing/features/pixel_breath.py b/eitprocessing/features/pixel_breath.py index 032740a4d..f0bcd915e 100644 --- a/eitprocessing/features/pixel_breath.py +++ b/eitprocessing/features/pixel_breath.py @@ -291,7 +291,6 @@ def find_pixel_breaths( # noqa: C901, PLR0912, PLR0915 values=list( pixel_breaths, ), ## TODO: change back to pixel_breaths array when IntervalData works with 3D array - derived_from=[eit_data], ) if store: sequence.interval_data.add(pixel_breaths_container) diff --git a/eitprocessing/parameters/eeli.py b/eitprocessing/parameters/eeli.py index 75fbff071..a2a224929 100644 --- a/eitprocessing/parameters/eeli.py +++ b/eitprocessing/parameters/eeli.py @@ -109,8 +109,6 @@ def compute_parameter( category="impedance", time=time, description="End-expiratory lung impedance (EELI) determined on continuous data", - parameters=self.breath_detection_kwargs, - derived_from=[continuous_data], values=values, ) if store: diff --git a/eitprocessing/parameters/tidal_impedance_variation.py b/eitprocessing/parameters/tidal_impedance_variation.py index 50e01aef0..29818260c 100644 --- a/eitprocessing/parameters/tidal_impedance_variation.py +++ b/eitprocessing/parameters/tidal_impedance_variation.py @@ -142,7 +142,6 @@ def compute_continuous_parameter( category="impedance difference", time=[breath.middle_time for breath in breaths.values if breath is not None], description="Tidal impedance variation determined on continuous 
data", - derived_from=[continuous_data], values=tiv_values, ) if store: @@ -253,7 +252,6 @@ def compute_pixel_parameter( category="impedance difference", time=list(all_pixels_breath_timings), description="Tidal impedance variation determined on pixel impedance", - derived_from=[eit_data], values=list(all_pixels_tiv_values.astype(float)), ) diff --git a/tests/test_datacollection.py b/tests/test_datacollection.py index e6589d413..e105f06c3 100644 --- a/tests/test_datacollection.py +++ b/tests/test_datacollection.py @@ -16,7 +16,6 @@ def create_data_object() -> Callable[[str, list | None], ContinuousData]: def internal( label: str, - derived_from: list | None = None, time: np.ndarray | None = None, values: np.ndarray | None = None, ) -> ContinuousData: @@ -28,7 +27,6 @@ def internal( time=time if isinstance(time, np.ndarray) else np.array([]), sample_frequency=0, values=values if isinstance(values, np.ndarray) else np.array([]), - derived_from=derived_from if isinstance(derived_from, list) else [], ) return internal @@ -121,9 +119,9 @@ def test_set_item(create_data_object: Callable[[str], ContinuousData]): def test_loaded_derived_data(create_data_object: Callable): data_object_loaded_1 = create_data_object("label 1") data_object_loaded_2 = create_data_object("label 2") - data_object_derived_1_a = create_data_object("label 1 der a", derived_from=[data_object_loaded_1]) - data_object_derived_1_b = create_data_object("label 1 der b", derived_from=[data_object_loaded_1]) - data_object_derived_2_a = create_data_object("label 2 der a", derived_from=[data_object_loaded_2]) + data_object_derived_1_a = create_data_object("label 1 der a") + data_object_derived_1_b = create_data_object("label 1 der b") + data_object_derived_2_a = create_data_object("label 2 der a") dc = DataCollection(ContinuousData) dc.add( @@ -134,23 +132,6 @@ def test_loaded_derived_data(create_data_object: Callable): data_object_derived_2_a, ) - assert dc.get_loaded_data() == { - "label 1": data_object_loaded_1, - "label 2": data_object_loaded_2, - } - assert dc.get_derived_data() == { - "label 1 der a": data_object_derived_1_a, - "label 1 der b": data_object_derived_1_b, - "label 2 der a": data_object_derived_2_a, - } - assert dc.get_data_derived_from(data_object_loaded_1) == { - "label 1 der a": data_object_derived_1_a, - "label 1 der b": data_object_derived_1_b, - } - assert dc.get_data_derived_from(data_object_loaded_2) == { - "label 2 der a": data_object_derived_2_a, - } - def test_concatenate(create_data_object: Callable): time = np.arange(100) / 20 diff --git a/tests/test_parameter_tiv.py b/tests/test_parameter_tiv.py index f3be9bcaf..bccdb530d 100644 --- a/tests/test_parameter_tiv.py +++ b/tests/test_parameter_tiv.py @@ -73,8 +73,6 @@ def mock_continuous_data(): unit="au", category="relative impedance", description="Global impedance created for testing TIV parameter", - parameters={}, - derived_from="mock_eit_data", time=np.linspace(0, 18, (18 * 1000), endpoint=False), values=mock_global_impedance(), sample_frequency=1000, @@ -420,8 +418,6 @@ def test_tiv_with_no_breaths_continuous(mock_continuous_data: ContinuousData): category="breath", intervals=[], values=[], - parameters={}, - derived_from=[], ), ), patch.object(tiv, "_calculate_tiv_values", return_value=[]), @@ -448,8 +444,6 @@ def test_tiv_with_no_breaths_pixel( category="breath", intervals=[], values=np.empty((0, 2, 2), dtype=object), - parameters={}, - derived_from=[], ), ), patch.object(tiv, "_calculate_tiv_values", return_value=[]), diff --git 
a/tests/test_pixel_breath.py b/tests/test_pixel_breath.py index eaf195fc7..724213a99 100644 --- a/tests/test_pixel_breath.py +++ b/tests/test_pixel_breath.py @@ -62,8 +62,6 @@ def mock_continuous_data(): unit="au", category="relative impedance", description="Global impedance created for testing pixel breath feature", - parameters={}, - derived_from="mock_eit_data", time=np.linspace(0, 2 * np.pi, 400), values=mock_global_impedance(), sample_frequency=399 / 2 * np.pi, @@ -170,8 +168,6 @@ def _mock(*_args, **_kwargs) -> SparseData: category="impedance difference", time=time, description="Mock tidal impedance variation", - parameters={}, - derived_from=[], values=values, ) From f116c3731fd63ce2aadc97f8daad4b84e8a4833f Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Mon, 1 Dec 2025 20:31:03 +0100 Subject: [PATCH 11/32] Fix import --- eitprocessing/parameters/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eitprocessing/parameters/__init__.py b/eitprocessing/parameters/__init__.py index 2ab04a6d0..41a0e14b9 100644 --- a/eitprocessing/parameters/__init__.py +++ b/eitprocessing/parameters/__init__.py @@ -2,7 +2,7 @@ import numpy as np -from eitprocessing.datahandling.continuousdata import DataContainer +from eitprocessing.datahandling import DataContainer class ParameterCalculation(ABC): From 414f07eef9165e7f50c93f0795cd8a5426ef7ebf Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Mon, 1 Dec 2025 20:32:40 +0100 Subject: [PATCH 12/32] Remove lock and unlock methods --- eitprocessing/datahandling/continuousdata.py | 59 ------------------- eitprocessing/datahandling/loading/draeger.py | 1 - 2 files changed, 60 deletions(-) diff --git a/eitprocessing/datahandling/continuousdata.py b/eitprocessing/datahandling/continuousdata.py index fb29793b9..6ccec236b 100644 --- a/eitprocessing/datahandling/continuousdata.py +++ b/eitprocessing/datahandling/continuousdata.py @@ -45,10 +45,6 @@ class ContinuousData(DataContainer, SelectByTime): sample_frequency: float | None = field(kw_only=True, repr=False, metadata={"check_equivalence": True}, default=None) def __post_init__(self) -> None: - if self.loaded: - self.lock() - self.lock("time") - if self.sample_frequency is None: msg = ( "`sample_frequency` is set to `None`. This will not be supported in future versions. " @@ -168,61 +164,6 @@ def convert_data(x, add=None, subtract=None, multiply=None, divide=None): copy.values = function(copy.values, **func_args) return copy - def lock(self, *attr: str) -> None: - """Lock attributes, essentially rendering them read-only. - - Locked attributes cannot be overwritten. Attributes can be unlocked using `unlock()`. - - Args: - *attr: any number of attributes can be passed here, all of which will be locked. Defaults to "values". - - Examples: - >>> # lock the `values` attribute of `data` - >>> data.lock() - >>> data.values = [1, 2, 3] # will result in an AttributeError - >>> data.values[0] = 1 # will result in a RuntimeError - """ - if not attr: - # default values are not allowed when using *attr, so set a default here if none is supplied - attr = ("values",) - for attr_ in attr: - getattr(self, attr_).flags["WRITEABLE"] = False - - def unlock(self, *attr: str) -> None: - """Unlock attributes, rendering them editable. - - Locked attributes cannot be overwritten, but can be unlocked with this function to make them editable. - - Args: - *attr: any number of attributes can be passed here, all of which will be unlocked. Defaults to "values". 
- - Examples: - >>> # lock the `values` attribute of `data` - >>> data.lock() - >>> data.values = [1, 2, 3] # will result in an AttributeError - >>> data.values[0] = 1 # will result in a RuntimeError - >>> data.unlock() - >>> data.values = [1, 2, 3] - >>> print(data.values) - [1,2,3] - >>> data.values[0] = 1 # will result in a RuntimeError - >>> print(data.values) - 1 - """ - if not attr: - # default values are not allowed when using *attr, so set a default here if none is supplied - attr = ("values",) - for attr_ in attr: - getattr(self, attr_).flags["WRITEABLE"] = True - - @property - def locked(self) -> bool: - """Return whether the values attribute is locked. - - See lock(). - """ - return not self.values.flags["WRITEABLE"] - @property def loaded(self) -> bool: """Return whether the data was loaded from disk, or derived from elsewhere.""" diff --git a/eitprocessing/datahandling/loading/draeger.py b/eitprocessing/datahandling/loading/draeger.py index b211c1e44..d3383c9ca 100644 --- a/eitprocessing/datahandling/loading/draeger.py +++ b/eitprocessing/datahandling/loading/draeger.py @@ -255,7 +255,6 @@ def _convert_medibus_data( category=field_info.signal_name, sample_frequency=sample_frequency, ) - continuous_data.lock() continuousdata_collection.add(continuous_data) else: From 23c1f469cd2bd8fba7f5f94daf6da9bbdd284ed0 Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Mon, 1 Dec 2025 20:41:23 +0100 Subject: [PATCH 13/32] Make ContinuousData frozen --- eitprocessing/datahandling/continuousdata.py | 69 +++----------------- tests/test_eeli.py | 5 +- tests/test_pixel_breath.py | 2 +- 3 files changed, 11 insertions(+), 65 deletions(-) diff --git a/eitprocessing/datahandling/continuousdata.py b/eitprocessing/datahandling/continuousdata.py index 6ccec236b..ec02bf65c 100644 --- a/eitprocessing/datahandling/continuousdata.py +++ b/eitprocessing/datahandling/continuousdata.py @@ -6,8 +6,9 @@ import numpy as np -from eitprocessing.datahandling import DataContainer +from eitprocessing.datahandling import FrozenDataContainer from eitprocessing.datahandling.mixins.slicing import SelectByTime +from eitprocessing.utils.frozen_array import freeze_array if TYPE_CHECKING: from collections.abc import Callable @@ -17,8 +18,8 @@ T = TypeVar("T", bound="ContinuousData") -@dataclass(eq=False) -class ContinuousData(DataContainer, SelectByTime): +@dataclass(eq=False, frozen=True) +class ContinuousData(FrozenDataContainer, SelectByTime): """Container for data with a continuous time axis. Continuous data is assumed to be sequential (i.e. a single data point at each time point, sorted by time) and @@ -56,48 +57,8 @@ def __post_init__(self) -> None: msg = f"The number of time points ({lt}) does not match the number of values ({lv})." raise ValueError(msg) - def __setattr__(self, attr: str, value: Any): # noqa: ANN401 - try: - old_value = getattr(self, attr) - except AttributeError: - pass - else: - if isinstance(old_value, np.ndarray) and old_value.flags["WRITEABLE"] is False: - msg = f"Attribute '{attr}' is locked and can't be overwritten." - raise AttributeError(msg) - super().__setattr__(attr, value) - - def copy( - self, - label: str, - *, - name: str | None = None, - unit: str | None = None, - description: str | None = None, - parameters: dict | None = None, - ) -> Self: - """Create a copy. - - Whenever data is altered, it should probably be copied first. The alterations should then be made in the copy. 
- """ - obj = self.__class__( - label=label, - name=name or label, - unit=unit or self.unit, - description=description or f"Derived from {self.name}", - parameters=self.parameters | (parameters or {}), - derived_from=[*self.derived_from, self], - category=self.category, - # copying data can become inefficient with large datasets if the - # data is not directly edited afer copying but overridden instead; - # consider creating a view and locking it, requiring the user to - # make a copy if they want to edit the data directly - time=np.copy(self.time), - values=np.copy(self.values), - sample_frequency=self.sample_frequency, - ) - obj.unlock() - return obj + object.__setattr__(self, "time", freeze_array(self.time)) + object.__setattr__(self, "values", freeze_array(self.values)) def __add__(self: Self, other: Self) -> Self: return self.concatenate(other) @@ -179,19 +140,7 @@ def _sliced_copy( newlabel: str, # noqa: ARG002 ) -> Self: # TODO: check correct implementation - cls = self.__class__ - time = np.copy(self.time[start_index:end_index]) - values = np.copy(self.values[start_index:end_index]) - description = f"Slice ({start_index}-{end_index}) of <{self.description}>" + time = self.time[start_index:end_index] + values = self.values[start_index:end_index] - return cls( - label=self.label, # TODO: newlabel gives errors - name=self.name, - unit=self.unit, - category=self.category, - description=description, - derived_from=[*self.derived_from, self], - time=time, - values=values, - sample_frequency=self.sample_frequency, - ) + return self.update(time=time, values=values) diff --git a/tests/test_eeli.py b/tests/test_eeli.py index c837685fa..c83ab461d 100644 --- a/tests/test_eeli.py +++ b/tests/test_eeli.py @@ -119,12 +119,9 @@ def test_with_data(draeger_20hz_healthy_volunteer_pressure_pod: Sequence, pytest def test_non_impedance_data(draeger_20hz_healthy_volunteer_pressure_pod: Sequence) -> None: cd = draeger_20hz_healthy_volunteer_pressure_pod.continuous_data["global_impedance_(raw)"] - original_category = cd.category _ = EELI().compute_parameter(cd) - cd.category = "foo" + cd = cd.update(category="foo") with pytest.raises(ValueError, match="This method will only work on 'impedance' data, not 'foo'."): _ = EELI().compute_parameter(cd) - - cd.category = original_category diff --git a/tests/test_pixel_breath.py b/tests/test_pixel_breath.py index 724213a99..4d6f3f06e 100644 --- a/tests/test_pixel_breath.py +++ b/tests/test_pixel_breath.py @@ -398,7 +398,7 @@ def test_phase_modes(draeger_50hz_healthy_volunteer_pressure_pod: Sequence, pyte cd = sequence.continuous_data["global_impedance_(raw)"] # replace the 'global' data with the sum of the middly pixels - cd.values = np.sum(eit_data.values, axis=(1, 2)) + cd = cd.update(values=np.sum(eit_data.values, axis=(1, 2))) pb_negative_amplitude = PixelBreath(phase_correction_mode="negative amplitude").find_pixel_breaths(eit_data, cd) pb_phase_shift = PixelBreath(phase_correction_mode="phase shift").find_pixel_breaths(eit_data, cd) From 405a85753c53ac224e8d40c50a07ec328f21f8fe Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Mon, 1 Dec 2025 20:43:05 +0100 Subject: [PATCH 14/32] Update ContinuousData attributes --- eitprocessing/datahandling/continuousdata.py | 17 +++++++++-------- eitprocessing/datahandling/loading/timpel.py | 10 +++++----- tests/test_breath_detection.py | 12 ++---------- tests/test_continuous_data.py | 12 ------------ tests/test_eeli.py | 8 ++++---- 5 files changed, 20 insertions(+), 39 deletions(-) diff --git 
a/eitprocessing/datahandling/continuousdata.py b/eitprocessing/datahandling/continuousdata.py index ec02bf65c..d82d0fcda 100644 --- a/eitprocessing/datahandling/continuousdata.py +++ b/eitprocessing/datahandling/continuousdata.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from dataclasses import dataclass, field +from dataclasses import KW_ONLY, dataclass, field from typing import TYPE_CHECKING, TypeVar import numpy as np @@ -36,13 +36,14 @@ class ContinuousData(FrozenDataContainer, SelectByTime): values: Data points. """ - label: str = field(compare=False) - name: str = field(compare=False, repr=False) - unit: str = field(metadata={"check_equivalence": True}, repr=False) - category: str = field(metadata={"check_equivalence": True}, repr=False) - description: str = field(default="", compare=False, repr=False) - time: np.ndarray = field(kw_only=True, repr=False) - values: np.ndarray = field(kw_only=True, repr=False) + time: np.ndarray = field(repr=False) + values: np.ndarray = field(repr=False) + _: KW_ONLY + label: str | None = field(compare=False, default=None) + name: str | None = field(compare=False, repr=False, default=None) + description: str | None = field(compare=False, repr=False, default=None) + unit: str | None = field(metadata={"check_equivalence": True}, repr=False, default=None) + category: str | None = field(metadata={"check_equivalence": True}, repr=False, default=None) sample_frequency: float | None = field(kw_only=True, repr=False, metadata={"check_equivalence": True}, default=None) def __post_init__(self) -> None: diff --git a/eitprocessing/datahandling/loading/timpel.py b/eitprocessing/datahandling/loading/timpel.py index 22d04506a..96198b889 100644 --- a/eitprocessing/datahandling/loading/timpel.py +++ b/eitprocessing/datahandling/loading/timpel.py @@ -108,13 +108,13 @@ def load_from_single_path( continuousdata_collection = DataCollection(ContinuousData) continuousdata_collection.add( ContinuousData( - "global_impedance_(raw)", - "Global impedance", - "a.u.", - "global_impedance", - "Global impedance calculated from raw EIT data", time=time, values=eit_data.calculate_global_impedance(), + label="global_impedance_(raw)", + name="Global impedance", + unit="a.u.", + category="global_impedance", + description="Global impedance calculated from raw EIT data", sample_frequency=sample_frequency, ), ) diff --git a/tests/test_breath_detection.py b/tests/test_breath_detection.py index 248722147..42cc8ce5a 100644 --- a/tests/test_breath_detection.py +++ b/tests/test_breath_detection.py @@ -481,13 +481,9 @@ def test_find_breaths(): label = "waveform_data" cd = ContinuousData( - label, - "Generated waveform data", - "", - "mock", - "", time=time, values=y, + label=label, sample_frequency=sample_frequency, ) seq = Sequence("sequence_label") @@ -514,13 +510,9 @@ def test_find_breaths(): y_copy = np.copy(y) y_copy[438] = -100 # single timepoint around the peak of the 4th breath cd = ContinuousData( - label, - "Generated waveform data", - "", - "mock", - "", time=time, values=y_copy, + label=label, sample_frequency=sample_frequency, ) seq.continuous_data.add(cd, overwrite=True) diff --git a/tests/test_continuous_data.py b/tests/test_continuous_data.py index 59dbb56b8..c3292efeb 100644 --- a/tests/test_continuous_data.py +++ b/tests/test_continuous_data.py @@ -14,20 +14,12 @@ def test_sample_frequency_deprecation_warning(): with pytest.warns(DeprecationWarning, match="`sample_frequency` is set to `None`"): ContinuousData( - "label", - "name", - "unit", - 
"category", time=time, values=values, ) with pytest.warns(DeprecationWarning, match="`sample_frequency` is set to `None`"): ContinuousData( - "label", - "name", - "unit", - "category", time=time, values=values, sample_frequency=None, @@ -35,10 +27,6 @@ def test_sample_frequency_deprecation_warning(): with warnings.catch_warnings(record=True) as w: ContinuousData( - "label", - "name", - "unit", - "category", time=time, values=values, sample_frequency=sample_frequency, diff --git a/tests/test_eeli.py b/tests/test_eeli.py index c83ab461d..0cf18fe3e 100644 --- a/tests/test_eeli.py +++ b/tests/test_eeli.py @@ -18,10 +18,10 @@ def create_continuous_data_object(): def internal(sample_frequency: float, duration: float, frequency: float) -> ContinuousData: time, values = _make_cosine_wave(sample_frequency, duration, frequency) return ContinuousData( - "label", - "name", - "unit", - "impedance", + label="label", + name="name", + unit="unit", + category="impedance", time=time, values=values, sample_frequency=sample_frequency, From 65443219cf2bff5b018925b37cbf9a18dc3a6194 Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Mon, 1 Dec 2025 20:53:54 +0100 Subject: [PATCH 15/32] Add tests for frozen ContinuousData --- tests/test_continuous_data.py | 42 +++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/test_continuous_data.py b/tests/test_continuous_data.py index c3292efeb..404ed53fc 100644 --- a/tests/test_continuous_data.py +++ b/tests/test_continuous_data.py @@ -1,3 +1,4 @@ +import dataclasses import warnings import numpy as np @@ -6,6 +7,47 @@ from eitprocessing.datahandling.continuousdata import ContinuousData +@pytest.fixture +def continuous_data_object(): + n = 100 + sample_frequency = 10 + time = np.arange(n) / sample_frequency + values = np.arange(n) + + return ContinuousData( + time=time, + values=values, + sample_frequency=sample_frequency, + ) + + +def test_continuous_data_frozen(continuous_data_object: ContinuousData): + with pytest.raises(dataclasses.FrozenInstanceError, match="cannot assign to field"): + continuous_data_object.sample_frequency = 20 + + with pytest.raises(ValueError, match="assignment destination is read-only"): + continuous_data_object.time[0] = -1 + + with pytest.raises(ValueError, match="assignment destination is read-only"): + continuous_data_object.values[0] = -1 + + +def test_continuous_data_copy_array(continuous_data_object: ContinuousData): + time = continuous_data_object.time + + with pytest.raises(ValueError, match="assignment destination is read-only"): + time[0] = -1 + + with pytest.raises(ValueError, match="cannot set WRITEABLE flag"): + time.flags["WRITEABLE"] = True + + new_time = time.copy() + new_time[0] = -1 + + new_continuous_data_object = continuous_data_object.update(time=new_time) + assert new_continuous_data_object.time[0] == -1 + + def test_sample_frequency_deprecation_warning(): n = 100 sample_frequency = 10 From 037c71bd8b37eea51bf585d3f02cb03f5c7ab508 Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Tue, 2 Dec 2025 08:45:29 +0100 Subject: [PATCH 16/32] Change EITData.path to be a tuple of Paths --- eitprocessing/datahandling/eitdata.py | 25 ++++++++++--------- .../datahandling/loading/__init__.py | 5 ++-- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/eitprocessing/datahandling/eitdata.py b/eitprocessing/datahandling/eitdata.py index f3713c727..5d1c14aa2 100644 --- a/eitprocessing/datahandling/eitdata.py +++ b/eitprocessing/datahandling/eitdata.py @@ -1,6 +1,7 @@ from __future__ import 
annotations import warnings +from collections.abc import Sequence as SequenceType from dataclasses import KW_ONLY, InitVar, dataclass, field from enum import Enum from pathlib import Path @@ -46,7 +47,7 @@ class is meant to hold data from (part of) a singular continuous measurement. _: KW_ONLY sample_frequency: float = field(metadata={"check_equivalence": True}, repr=False) vendor: Vendor = field(metadata={"check_equivalence": True}, repr=False) - path: str | Path | list[Path | str] | None = field(compare=False, repr=False, default=None) + path: tuple[Path] | None = field(compare=False, repr=False, default=None) label: str | None = field(default=None, compare=False, metadata={"check_equivalence": True}) description: str | None = field(default=None, compare=False, repr=False) name: str | None = field(default=None, compare=False, repr=False) @@ -59,7 +60,7 @@ def __init__( *, sample_frequency: float, vendor: Vendor | str, - path: str | Path | list[Path | str] | None = None, + path: str | Path | SequenceType[Path | str] | None = None, label: str | None = None, description: str | None = None, name: str | None = None, @@ -80,7 +81,7 @@ def __init__( if path is None: object.__setattr__(self, "path", None) else: - path_list = self.ensure_path_list(path) + path_list = self.ensure_path_tuple(path) if len(path_list) == 1: object.__setattr__(self, "path", path_list[0]) else: @@ -155,17 +156,17 @@ def framerate(self) -> float: return self.sample_frequency @staticmethod - def ensure_path_list( - path: str | Path | list[str | Path], - ) -> list[Path]: + def ensure_path_tuple( + path: str | Path | SequenceType[str | Path], + ) -> tuple[Path, ...]: """Return the path or paths as a list. The path of any EITData object can be a single str/Path or a list of str/Path objects. This method returns a list of Path objects given either a str/Path or list of str/Paths. """ - if isinstance(path, list): - return [Path(p) for p in path] - return [Path(path)] + if isinstance(path, SequenceType): + return tuple(Path(p) for p in path) + return (Path(path),) def __add__(self: Self, other: Self) -> Self: return self.concatenate(other) @@ -177,9 +178,9 @@ def concatenate(self: Self, other: Self, newlabel: str | None = None) -> Self: msg = f"Concatenation failed. Second dataset ({other.name}) may not start before first ({self.name}) ends." 
raise ValueError(msg) - self_path = [] if self.path is None else self.ensure_path_list(self.path) - other_path = [] if other.path is None else self.ensure_path_list(other.path) - concat_path = [*self_path, *other_path] + self_path = () if self.path is None else self.ensure_path_tuple(self.path) + other_path = () if other.path is None else self.ensure_path_tuple(other.path) + concat_path = (*self_path, *other_path) if not concat_path: concat_path = None newlabel = newlabel or f"Merge of <{self.label}> and <{other.label}>" diff --git a/eitprocessing/datahandling/loading/__init__.py b/eitprocessing/datahandling/loading/__init__.py index b759eed8f..e8acd5071 100644 --- a/eitprocessing/datahandling/loading/__init__.py +++ b/eitprocessing/datahandling/loading/__init__.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence as SequenceType from functools import reduce from pathlib import Path @@ -7,7 +8,7 @@ def load_eit_data( - path: str | Path | list[str | Path], + path: str | Path | SequenceType[str | Path], vendor: Vendor | str, label: str | None = None, name: str | None = None, @@ -70,7 +71,7 @@ def load_eit_data( first_frame = _check_first_frame(first_frame) - paths = EITData.ensure_path_list(path) + paths = EITData.ensure_path_tuple(path) eit_datasets: list[DataCollection] = [] continuous_datasets: list[DataCollection] = [] From 4117811de80cd90df9567bf4630cb17dc6750e2c Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Tue, 16 Dec 2025 16:07:35 +0100 Subject: [PATCH 17/32] Add option to use memoryview for freezing arrays --- eitprocessing/utils/frozen_array.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/eitprocessing/utils/frozen_array.py b/eitprocessing/utils/frozen_array.py index 7a2955679..32198609d 100644 --- a/eitprocessing/utils/frozen_array.py +++ b/eitprocessing/utils/frozen_array.py @@ -1,3 +1,4 @@ +import warnings from typing import Literal import numpy as np @@ -7,13 +8,26 @@ def freeze_array(a: np.ndarray, *, method: Literal["flag", "memoryview"] = DEFAULT_FREEZE_METHOD) -> np.ndarray: """Return a read-only array that cannot be made writeable again.""" + # Memory buffers cannot represent object/structured fields safely. + if method == "memoryview": + dt = a.dtype + if dt.hasobject: + warnings.warn( + "Cannot use 'memoryview' method for object or structured dtypes; falling back to 'flag' method.", + RuntimeWarning, + stacklevel=2, + ) + method = "flag" + match method: case "flag": + # Make a copy if needed and mark it readonly (can be flipped back by a user). if a.flags["WRITEABLE"]: a = a.copy() - a.flags["WRITEABLE"] = False - return a # is already read-only, e.g., a view of a read-only array + a.flags["WRITEABLE"] = False + return a case "memoryview": + # Numeric/plain dtypes → create a readonly buffer-backed view that can't be flipped. 
a_c = np.ascontiguousarray(a) ro_buf = memoryview(a_c).toreadonly() return np.frombuffer(ro_buf, dtype=a_c.dtype).reshape(a_c.shape) From c665170aeb69d999fd2f9df68614064a51d52467 Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Tue, 16 Dec 2025 16:08:15 +0100 Subject: [PATCH 18/32] Simplify logic --- eitprocessing/utils/frozen_array.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/eitprocessing/utils/frozen_array.py b/eitprocessing/utils/frozen_array.py index 32198609d..9c4d1b2b0 100644 --- a/eitprocessing/utils/frozen_array.py +++ b/eitprocessing/utils/frozen_array.py @@ -9,15 +9,13 @@ def freeze_array(a: np.ndarray, *, method: Literal["flag", "memoryview"] = DEFAULT_FREEZE_METHOD) -> np.ndarray: """Return a read-only array that cannot be made writeable again.""" # Memory buffers cannot represent object/structured fields safely. - if method == "memoryview": - dt = a.dtype - if dt.hasobject: - warnings.warn( - "Cannot use 'memoryview' method for object or structured dtypes; falling back to 'flag' method.", - RuntimeWarning, - stacklevel=2, - ) - method = "flag" + if method == "memoryview" and a.dtype.hasobject: + warnings.warn( + "Cannot use 'memoryview' method for object or structured dtypes; falling back to 'flag' method.", + RuntimeWarning, + stacklevel=2, + ) + method = "flag" match method: case "flag": @@ -27,7 +25,7 @@ def freeze_array(a: np.ndarray, *, method: Literal["flag", "memoryview"] = DEFAU a.flags["WRITEABLE"] = False return a case "memoryview": - # Numeric/plain dtypes → create a readonly buffer-backed view that can't be flipped. + # For numeric/plain dtypes, create a readonly buffer-backed view that can't be flipped. a_c = np.ascontiguousarray(a) ro_buf = memoryview(a_c).toreadonly() return np.frombuffer(ro_buf, dtype=a_c.dtype).reshape(a_c.shape) From 3cea05c0fbbfcc5ba5721127553e3bfda7b0c7c2 Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Thu, 18 Dec 2025 13:10:17 +0100 Subject: [PATCH 19/32] Add NamedTupleArray including tests --- .../datahandling/namedtuple_array.py | 494 +++++++++++++ tests/test_namedtuple_array.py | 659 ++++++++++++++++++ 2 files changed, 1153 insertions(+) create mode 100644 eitprocessing/datahandling/namedtuple_array.py create mode 100644 tests/test_namedtuple_array.py diff --git a/eitprocessing/datahandling/namedtuple_array.py b/eitprocessing/datahandling/namedtuple_array.py new file mode 100644 index 000000000..c231a1989 --- /dev/null +++ b/eitprocessing/datahandling/namedtuple_array.py @@ -0,0 +1,494 @@ +# %% +"""Array-like interface over NamedTuple collections enabling NumPy slicing. + +Motivation +---------- +Some data is best represented with multiple related data points (e.g., the start, middle and end time of a breath). They +can be represented as NamedTuple instances; lightweight containers that group related fields together. Lists (or tuples) +of NamedTuples are more difficult to handle efficiently compared to NumPy arrays, especially in multi-dimensional cases. +When NamedTuples are collected inside NumPy arrays, however, they loose their NamedTuple context, removing the field +names and data types. + +This module provides NamedTupleArray, a container that wraps homogeneous collections of NamedTuple instances into a +NumPy structured array, preserving the NamedTuple field names and types while enabling NumPy-style slicing and +field-wise access. It allows access to fields and even computed properties by name. 
+ +Key features +------------ +- Homogeneous type checking: ensures all items share the same NamedTuple type. +- Safe field views: returns read-only views for direct field access. +- Property evaluation: computes per-item properties, resolving postponed + annotations to pick appropriate NumPy dtypes. +- Shape preservation: supports nested sequences, maintaining their shape in the + structured array. +- Interop: from_ndarray helper to map last-axis columns to NamedTuple fields. + +Example: +```python +class Coordinate(NamedTuple): + x: float + y: float + z: float + + @property + def r(self) -> float: + \"\"\"The radial distance from the origin.\"\"\" + return (self.x**2 + self.y**2 + self.z**2) ** 0.5 + +coords = [Coordinate(1.0, 2.0, 2.0), Coordinate(3.0, 4.0, 0.0), Coordinate(0.0, 0.0, 5.0)] +arr = NamedTupleArray(coords) + +arr[1:] # Slice of Coordinates +# NamedTupleArray[Coordinate](array([(3., 4., 0.), (0., 0., 5.)], +# dtype=[('x', ' returns a read-only view of the field. + - Property/attribute: arr["duration"] -> computes the property for each + item, producing an ndarray. If the property has a type hint, the result + dtype is chosen accordingly (e.g., int -> int64, float -> float64). + Otherwise, a heuristic casts ints to int64, else tries float, else object. + - Direct array access via items property: arr.items (always returns the + underlying NumPy array; if frozen=True, modifications are prevented). + + Notes: + ----- + - Homogeneity: All elements must be the same NamedTuple type. + - String fields are kept as object dtype to avoid truncation. + - Properties are evaluated per element; heavy properties may be costly. + - Field views are always read-only to prevent accidental mutation. + - The .items property returns the underlying array; modifications are only + prevented if frozen=True was passed during construction. + + Example: + ------- + >>> class Breath(NamedTuple): + ... start_time: float + ... middle_time: float + ... end_time: float + ... @property + ... def duration(self) -> float: + ... return self.end_time - self.start_time + ... + >>> breaths = [Breath(0.0, 0.5, 1.0), Breath(1.0, 1.6, 2.1)] + >>> arr = NamedTupleArray(breaths) + >>> arr["duration"] + array([1. , 1.1]) + """ + + _type: type[T] + __items: np.ndarray + + def __init__(self, items: NonStringSeq[T] | np.ndarray | Nested[T], frozen: bool = False): + """Initialize a NamedTupleArray from a sequence or nested sequence of NamedTuple items. + + Args: + items: A sequence (or nested sequence) of NamedTuple instances, or a numpy ndarray containing them. + frozen: If True, makes the underlying array immutable. + """ + if isinstance(items, np.ndarray) and items.size == 0: + msg = "Cannot infer type from empty array." + raise ValueError(msg) + if not isinstance(items, np.ndarray) and not items: + msg = "Cannot infer type from empty sequence." 
+ raise ValueError(msg) + + # Infer NT type from first leaf element + leaf = _first_leaf(items) + self._type = type(leaf) # type: ignore[assignment] + + # Validate homogeneity + _check_homogeneous(items, self._type) + + # Build structured dtype and array with same shape + dt = _get_tuple_dtype(self._type) + self.__items = np.asarray(items, dtype=dt) + + if frozen: + self._freeze() + + object.__setattr__(self, "_initialized", True) + + def _freeze(self) -> None: + """Make the underlying array immutable.""" + dt = _get_tuple_dtype(self._type) + freeze_method = "flag" if dt.hasobject else "memoryview" + self.__items = freeze_array(self.__items, method=freeze_method) + + def __setattr__(self, name: str, value: object) -> None: + """Allow setting _type and __items only during initialization; block modification after.""" + # Check if initialization is complete (use object.__getattribute__ to bypass our __getattr__) + try: + initialized = object.__getattribute__(self, "_initialized") + except AttributeError: + initialized = False + + # Allow setting _type and __items only during initialization + if not initialized and name in ("_type", "_NamedTupleArray__items"): + super().__setattr__(name, value) + elif initialized and name in ("_type", "_NamedTupleArray__items"): + msg = f"{type(self).__name__!r} object is immutable; cannot modify {name!r} after initialization." + raise AttributeError(msg) + else: + msg = f"{type(self).__name__!r} object is immutable; cannot set attribute {name!r}." + raise AttributeError(msg) + + @classmethod + def from_ndarray(cls, arr: np.ndarray, namedtuple_type: type[T], frozen: bool = False) -> NamedTupleArray[T]: + """Build a NamedTupleArray from an unstructured numpy array. + + The length of the last axis must equal to the number of fields in the given NamedTuple type. + + Example: + This examples represents a sequence of 10 breaths for each of 32x32 pixels. Each breath contains 3 fields: + start_time, middle_time, end_time. + + ```python + assert breath_data.shape == (10, 32, 32, 3) + breaths = NamedTupleArray.from_ndarray(breath_data, Breath) + ``` + + This is equivalent to a list of 10 nested lists, each containing 32 lists (rows) of 32 (columns) Breath + objects. + """ + if arr.ndim < 1: + msg = "arr must have at least 1 dimension." + raise ValueError(msg) + n_fields = len(namedtuple_type._fields) + if (lal := arr.shape[-1]) != n_fields: + msg = f"Last axis must have size {n_fields} for {namedtuple_type.__name__}, not {lal}." + raise ValueError(msg) + dt = _get_tuple_dtype(namedtuple_type) + out = np.empty(arr.shape[:-1], dtype=dt) + + if not dt.fields: + msg = "Generated dtype has no fields; cannot proceed." 
+ raise RuntimeError(msg) + + fields = cast("dict[str, tuple[np.dtype, int]]", dt.fields) + for i, name in enumerate(namedtuple_type._fields): + # Cast each column to the target field dtype to avoid unintended promotion + target_dt = fields[name][0] + out[name] = arr[..., i].astype(target_dt, copy=False) + + inst = cls.__new__(cls) + inst._type = namedtuple_type # noqa: SLF001 + inst._NamedTupleArray__items = out # noqa: SLF001 + + if frozen: + inst._freeze() # noqa: SLF001 + + return inst + + @property + def shape(self) -> tuple[int, ...]: + """The shape of the NamedTupleArray.""" + return self.__items.shape + + @property + def ndim(self) -> int: + """The number of dimensions of the NamedTupleArray.""" + return self.__items.ndim + + @property + def dtype(self) -> np.dtype: + """The dtype of the underlying structured array.""" + return self.__items.dtype + + @property + def items(self) -> np.ndarray: + """The underlying NumPy structured array. + + Returns the private array. If this instance was created with frozen=True, + modifications via this reference are prevented. Otherwise, modifications + are allowed. + """ + return self.__items + + @property + def flags(self) -> flagsobj: + """The flags of the underlying structured array. + + If this instance was created with frozen=True, the WRITEABLE flag cannot + be changed. Otherwise, the flags are fully mutable. + """ + return self.__items.flags + + def __getattr__(self, attr: str): + """Block access to the private array. + + All array attributes should be accessed via explicit properties. + This prevents users from bypassing immutability controls. + """ + msg = f"{type(self).__name__!r} object has no attribute {attr!r}" + raise AttributeError(msg) + + def __array__(self, dtype: np.dtype | None = None) -> np.ndarray: + return self.__items.astype(dtype) if dtype is not None else self.__items + + def __iter__(self) -> Generator[T | NamedTupleArray[T], None, None]: + if self.ndim == 1: + for item in self.__items: + yield self._type(*item) # type: ignore[call-arg] + else: + # yield structured subarrays along axis 0 + for i in range(self.__items.shape[0]): + out = NamedTupleArray.__new__(NamedTupleArray) + out._type = self._type # noqa: SLF001 + out._NamedTupleArray__items = self.__items[i] # noqa: SLF001 + yield out + + def __len__(self) -> int: + return self.__items.shape[0] if self.__items.ndim > 0 else 0 + + def __repr__(self) -> str: + return f"NamedTupleArray[{self._type.__name__}]{repr(self.__items).removeprefix('array')}" + + @overload + def __getitem__(self, index: str) -> np.ndarray: ... + + @overload + def __getitem__(self, index: int) -> T: ... + + @overload + def __getitem__(self, index: slice) -> NamedTupleArray[T]: ... + + @overload + def __getitem__(self, index: NonStringSeq) -> NamedTupleArray[T]: ... 
+ + def __getitem__(self, index: str | int | slice | NonStringSeq) -> np.ndarray | NamedTupleArray[T] | T: + # Field-name access: return field view + if isinstance(index, str): + names = self.__items.dtype.names or () + if index in names: + view = self.__items[index] + # Ensure field view is read-only + with contextlib.suppress(Exception): + view.flags.writeable = False + return view + # Computed property or attribute on the NT → compute over all items + return self._compute_property(index) + + # NumPy-style indexing + result = self.__items[index] + + # Structured scalar (np.void) → return NamedTuple + if isinstance(result, np.void): + # For structured np.void, convert to NamedTuple + return self._type(*result.tolist()) # type: ignore[call-arg] + + # Zero-d structured ndarray (shape == ()) → convert to NamedTuple + if isinstance(result, np.ndarray) and result.dtype.fields is not None and result.ndim == 0: + scalar = result.item() # np.void + return self._type(*scalar.tolist()) # type: ignore[call-arg] + + # Structured ndarray → wrap + if isinstance(result, np.ndarray) and result.dtype.fields is not None: + out: NamedTupleArray[T] = type(self).__new__(type(self)) + out._type = self._type + out._NamedTupleArray__items = result + return out + + # Non-structured ndarray (e.g. field slice) → return as-is + return result + + def _compute_property(self, attr: str) -> np.ndarray: + """Compute a property or attribute across all items, preserving the array shape.""" + # Verify attribute exists on the NT instance + sample = self._type(*self.__items.flat[0].tolist()) # type: ignore[call-arg] + if not hasattr(sample, attr): + msg = f"Field or property '{attr}' not found in NamedTuple." + raise KeyError(msg) + + # Collect values (single pass using flat indexing) + out_obj = np.empty(self.shape, dtype=object) + for i, rec in enumerate(self.__items.reshape(-1)): + nt = self._type(*rec.tolist()) # type: ignore[call-arg] + out_obj.reshape(-1)[i] = getattr(nt, attr) + + # Determine target dtype from property annotation if available (handles postponed annotations) + target_dtype: np.dtype | None = None + attr_member = getattr(self._type, attr, None) + if isinstance(attr_member, property) and attr_member.fget is not None: + with contextlib.suppress(Exception): + hints = get_type_hints(attr_member.fget) + ret_ann = hints.get("return") + if ret_ann is not None: + target = _python_to_np_dtype(ret_ann) + target_dtype = np.dtype(target) + + # Cast accordingly + if target_dtype is not None and target_dtype != np.dtype(object): + with contextlib.suppress(Exception): + return out_obj.astype(target_dtype) + + # Heuristics: ints → int64; floats → float64; numpy scalar families respected + with contextlib.suppress(Exception): + if all(isinstance(v, (int, np.integer)) for v in out_obj.flat): + return out_obj.astype(np.int64) + with contextlib.suppress(Exception): + if all(isinstance(v, (float, np.floating)) for v in out_obj.flat): + return out_obj.astype(np.float64) + + return out_obj + + +def _first_leaf( + seq: NamedTuple | np.ndarray | list[NamedTuple] | tuple[NamedTuple, ...] | Nested[NamedTuple], +) -> NamedTuple: + """Recursively find the first NamedTuple instance in a nested sequence or ndarray.""" + if _is_namedtuple_instance(seq): + return seq + if isinstance(seq, np.ndarray): + if seq.size == 0: + msg = "Cannot infer type from empty ndarray." + raise ValueError(msg) + return _first_leaf(seq.flat[0]) + if isinstance(seq, (list, tuple)): + if not seq: + msg = "Cannot infer type from empty nested sequence." 
+ raise ValueError(msg) + return _first_leaf(seq[0]) + msg = "Items must be NamedTuple or nested sequences thereof." + raise TypeError(msg) + + +def _check_homogeneous(seq: SequenceType[NamedTuple] | np.ndarray | Nested[T], typ: type[T]) -> None: + """Recursively check that all NamedTuple instances in the nested sequence/ndarray are of the given type.""" + if isinstance(seq, np.ndarray): + for it in seq.flat: + _check_homogeneous(it, typ) + return + if isinstance(seq, (list, tuple)) and not _is_namedtuple_instance(seq): + seq_ = cast("SequenceType[NamedTuple | Nested[T]]", seq) + for it in seq_: + it: NamedTuple | Nested[T] + _check_homogeneous(it, typ) + return + if _is_namedtuple_instance(seq): + if type(seq) is not typ: + msg = "All items must be of the same NamedTuple type." + raise ValueError(msg) + return + msg = "Items must be NamedTuple or nested sequences thereof." + raise TypeError(msg) + + +def _python_to_np_dtype(py: type) -> np.dtype | str: + """Map basic Python types to NumPy dtypes.""" + if py is int: + return "i8" + if py is float: + return "f8" + if py is bool: + return "?" + if py is str: + return np.dtype(object) # keep Python str; avoids truncation + + if get_origin(py) is Union: + args = [a for a in get_args(py) if a is not type(None)] + if len(args) == 1: + return _python_to_np_dtype(args[0]) + return np.dtype(object) + + +def _is_namedtuple_instance(item: object) -> TypeGuard[NamedTuple]: + """Check if item is a NamedTuple instance.""" + return isinstance(item, tuple) and hasattr(item, "_fields") + + +def _is_namedtuple_type(item: object) -> TypeGuard[type[NamedTuple]]: + """Check if item is a NamedTuple type.""" + return isinstance(item, type) and issubclass(item, tuple) and hasattr(item, "_fields") + + +def _get_tuple_dtype(item: NamedTuple | type[NamedTuple]) -> np.dtype: + """Generate a NumPy structured dtype from a NamedTuple type.""" + if _is_namedtuple_instance(item): + item = type(item) + if not _is_namedtuple_type(item): + msg = "item must be a NamedTuple or a NamedTuple type." + raise TypeError(msg) + + hints = get_type_hints(item, include_extras=False) + names_in_order = list(item._fields) + + def to_np_dtype(py: type) -> str | np.dtype: + if py is int: + return "i8" + if py is float: + return "f8" + if py is bool: + return "?" + # everything else (incl. 
str) → object to avoid truncation + return np.dtype(object) + + fields = [(name, to_np_dtype(hints.get(name, object))) for name in names_in_order] + return np.dtype(fields) diff --git a/tests/test_namedtuple_array.py b/tests/test_namedtuple_array.py new file mode 100644 index 000000000..4fff3583d --- /dev/null +++ b/tests/test_namedtuple_array.py @@ -0,0 +1,659 @@ +from typing import NamedTuple + +import numpy as np +import pytest + +from eitprocessing.datahandling.namedtuple_array import NamedTupleArray + + +class Mixed(NamedTuple): + """NamedTuple with mixed field types and a computed property.""" + + a: int + b: float + c: bool + d: str + + @property + def d_length(self) -> int: + """Computed property returning the length of string d.""" + return len(self.d) + + +class Simple(NamedTuple): + """NamedTuple with simple numeric fields.""" + + x: int + y: float + + +class Breath(NamedTuple): + """NamedTuple representing a breath with start, mid, end times.""" + + start: float + mid: float + end: float + + @property + def duration(self) -> float: + """Computed property returning the duration of the breath.""" + return self.end - self.start + + +def test_1d_mixed_types_and_properties(): + items = [Mixed(1, 2.0, True, "foo"), Mixed(3, 4.5, False, "hello")] + nta = NamedTupleArray(items) + + assert nta.shape == (2,) + # scalar access + v0 = nta[0] + assert isinstance(v0, Mixed) + assert v0.a == 1 + assert v0.b == 2.0 + assert v0.c is True + assert v0.d == "foo" + + # field views have expected dtype and values + a = nta["a"] + assert a.dtype.kind in ("i", "u") + assert a.shape == (2,) + assert (a == np.array([1, 3])).all() + + b = nta["b"] + assert np.issubdtype(b.dtype, np.floating) + + c = nta["c"] + assert np.issubdtype(c.dtype, np.bool_) + + d = nta["d"] + # strings kept as object + assert d.dtype == object + + # computed property -> returns int dtype (annotated) + dl = nta["d_length"] + assert np.issubdtype(dl.dtype, np.integer) + assert list(dl) == [3, 5] + + +def test_2d_indexing_and_slicing(): + nested = [[Simple(i + j, float(i * j)) for j in range(3)] for i in range(2)] + nta2d = NamedTupleArray(nested) + assert nta2d.shape == (2, 3) + + # scalar multi-dimensional indexing returns NamedTuple + item = nta2d[0, 1] + assert isinstance(item, Simple) + assert item.x == 1 + assert item.y == 0.0 + + # row slice returns NamedTupleArray + row = nta2d[0] + assert isinstance(row, NamedTupleArray) + assert row.shape == (3,) + + # field access on 2D returns array with original shape + xs = nta2d["x"] + assert xs.shape == (2, 3) + assert xs[0, 1] == 1 + + +def test_3d_from_ndarray_and_indexing(): + # create shape (2,2,2,2) last axis 2 fields + arr = np.array( + [ + [[[1, 2.0], [3, 4.0]], [[5, 6.0], [7, 8.0]]], + [[[9, 10.0], [11, 12.0]], [[13, 14.0], [15, 16.0]]], + ], + dtype=float, + ) + nta = NamedTupleArray.from_ndarray(arr, Simple) + assert nta.shape == (2, 2, 2) + + # random 3D scalar access + s = nta[1, 0, 1] + assert isinstance(s, Simple) + assert s.x == 11 + assert s.y == 12.0 + + +def test_field_views_readonly_and_shape_preserved(): + items = [Simple(1, 2.0), Simple(3, 4.0)] + nta = NamedTupleArray(items) + assert nta.flags.writeable is True + vx = nta["x"] + assert vx.flags.writeable is False + assert vx.shape == (2,) + + +def test_calculated_property_float_dtype(): + breaths = [Breath(0.0, 0.5, 1.0), Breath(1.0, 1.4, 2.2)] + nta = NamedTupleArray(breaths) + dur = nta["duration"] + assert np.issubdtype(dur.dtype, np.floating) + assert pytest.approx(list(dur), rel=1e-9) == [1.0, 1.2] + + 
+def test_forwarding_attributes_and_methods(): + items = [Simple(1, 2.0), Simple(3, 4.0)] + nta = NamedTupleArray(items) + + # dtype via property + assert nta.dtype is not None + + # reshape is not available (array is private now) + assert not hasattr(nta, "reshape") + + # flags and ndim reflect the underlying array + assert nta.flags.writeable is not None + assert nta.ndim == 1 + + # field-view methods are available (e.g., sum) + vx = nta["x"] + assert hasattr(vx, "sum") + assert int(vx.sum()) == 4 + + +def test_frozen_namedtuple_array_numeric(): + items = [Simple(1, 2.0), Simple(3, 4.0)] + frozen_nta = NamedTupleArray(items, frozen=True) + nta = NamedTupleArray(items, frozen=False) + + # Field views are always read-only (for both frozen and unfrozen) + field_view = nta["x"] + with pytest.raises(ValueError, match="assignment destination is read-only"): + field_view[0] = 10 + + # Frozen array also prevents field modification + frozen_field_view = frozen_nta["x"] + with pytest.raises(ValueError, match="assignment destination is read-only"): + frozen_field_view[0] = 10 + + # Both frozen and unfrozen arrays block new attributes + with pytest.raises(AttributeError): + nta.new_attribute = 42 + with pytest.raises(AttributeError): + frozen_nta.new_attribute = 42 + + # Both frozen and unfrozen arrays block modification of _type + with pytest.raises(AttributeError): + nta._type = np.floating + with pytest.raises(AttributeError): + frozen_nta._type = np.floating + + # Frozen array prevents toggling writeable flag + with pytest.raises(ValueError, match="cannot set WRITEABLE flag to True of this array"): + frozen_nta.flags.writeable = True + + +def test_frozen_namedtuple_array_string(): + items = [Mixed(1, 2.0, True, "foo"), Mixed(3, 4.0, False, "bar")] + frozen_nta = NamedTupleArray(items, frozen=True) + + # Frozen array with object dtype prevents field modification via read-only view + frozen_field_view = frozen_nta["d"] + with pytest.raises(ValueError, match="assignment destination is read-only"): + frozen_field_view[0] = "baz" + + # Frozen array blocks new attributes + with pytest.raises(AttributeError): + frozen_nta.new_attribute = 42 + + # Frozen array blocks modifying _type + with pytest.raises(AttributeError): + frozen_nta._type = np.floating + + # For object dtypes, users cannot access the underlying array to toggle the writeable flag + # because __items is now private (inaccessible). This is the key benefit of the refactoring. + # The writeable flag is protected by preventing direct array access. 
+ with pytest.raises(AttributeError, match="has no attribute '_items'"): + _ = frozen_nta._items + + +def test_items_property_unfrozen(): + """Test that .items property returns the underlying array for unfrozen arrays.""" + items = [Simple(1, 2.0), Simple(3, 4.0)] + nta = NamedTupleArray(items, frozen=False) + + # Access via .items property works + underlying = nta.items + assert isinstance(underlying, np.ndarray) + assert underlying.shape == (2,) + assert underlying.dtype.names == ("x", "y") + + # For unfrozen arrays, modifications are allowed (though field views are still read-only) + assert nta.items.flags.writeable is True + + +def test_items_property_frozen(): + """Test that .items property is read-only for frozen arrays.""" + items = [Simple(1, 2.0), Simple(3, 4.0)] + nta = NamedTupleArray(items, frozen=True) + + # Access via .items property works + underlying = nta.items + assert isinstance(underlying, np.ndarray) + assert underlying.shape == (2,) + + # For frozen arrays with numeric dtypes, writeable is False + assert nta.items.flags.writeable is False + + +def test_slicing_frozen_array(): + """Test that slicing a frozen array returns a frozen sub-array.""" + items = [Simple(i, float(i)) for i in range(5)] + nta = NamedTupleArray(items, frozen=True) + + # Slice returns NamedTupleArray with same frozenness + sliced = nta[1:3] + assert isinstance(sliced, NamedTupleArray) + assert sliced.shape == (2,) + assert sliced.items.flags.writeable is False + + +def test_computed_property_frozen(): + """Test that computed properties work correctly on frozen arrays.""" + breaths = [Breath(0.0, 0.5, 1.0), Breath(1.0, 1.4, 2.2)] + nta = NamedTupleArray(breaths, frozen=True) + + # Computed property returns correct values + dur = nta["duration"] + assert pytest.approx(list(dur), rel=1e-9) == [1.0, 1.2] + # Computed properties create new arrays, so they're writable + assert dur.flags.writeable is True + + +def test_single_item_array(): + """Test NamedTupleArray with a single item.""" + items = [Simple(42, 3.14)] + nta = NamedTupleArray(items, frozen=True) + + assert nta.shape == (1,) + assert nta[0] == Simple(42, 3.14) + assert nta["x"][0] == 42 + + +def test_2d_frozen_array(): + """Test 2D frozen array slicing and access.""" + nested = [[Simple(i + j, float(i * j)) for j in range(2)] for i in range(2)] + nta = NamedTupleArray(nested, frozen=True) + + assert nta.shape == (2, 2) + assert nta.items.flags.writeable is False + + # Slicing should also be frozen + row = nta[0] + assert row.items.flags.writeable is False + + +def test_all_properties_accessible(): + """Test that all expected properties are accessible.""" + items = [Simple(1, 2.0), Simple(3, 4.0)] + nta = NamedTupleArray(items) + + # All properties should be accessible + assert nta.shape == (2,) + assert nta.ndim == 1 + assert nta.dtype is not None + assert nta.flags is not None + assert nta.items is not None + + +def test_setattr_blocked_post_init(): + """Test that __setattr__ blocks all attribute setting after init.""" + items = [Simple(1, 2.0)] + nta = NamedTupleArray(items, frozen=False) + + # Cannot set any new attributes + with pytest.raises(AttributeError, match="immutable"): + nta.custom_attr = "value" + + # Cannot modify internal attributes + with pytest.raises(AttributeError, match="immutable"): + nta._type = int + + +# ============================================================================ +# Edge Cases and Error Conditions +# ============================================================================ + + +def 
test_empty_sequence(): + """Test that empty sequences raise ValueError.""" + with pytest.raises(ValueError, match="Cannot infer type from empty"): + NamedTupleArray([]) + + +def test_empty_ndarray(): + """Test that empty ndarrays raise ValueError.""" + empty_array = np.array([], dtype=float) + with pytest.raises(ValueError, match="Cannot infer type from empty"): + NamedTupleArray(empty_array) + + +def test_non_namedtuple_items(): + """Test that passing non-NamedTuple items raises TypeError.""" + items = [(1, 2.0), (3, 4.0)] # Regular tuples, not NamedTuples + with pytest.raises(TypeError, match="NamedTuple"): + NamedTupleArray(items) + + +def test_ndarray_last_axis_mismatch(): + """Test that from_ndarray with mismatched last axis raises error.""" + # Simple has 2 fields, but array has 3 columns + arr = np.array([[1, 2.0, 3.0], [4, 5.0, 6.0]]) + with pytest.raises(ValueError): + NamedTupleArray.from_ndarray(arr, Simple) + + +def test_frozen_from_ndarray(): + """Test that from_ndarray with frozen=True works correctly.""" + arr = np.array([[1, 2.0], [3, 4.0]]) + nta = NamedTupleArray.from_ndarray(arr, Simple, frozen=True) + + assert nta.shape == (2,) + assert nta.items.flags.writeable is False + + +def test_iteration(): + """Test that NamedTupleArray is iterable.""" + items = [Simple(1, 2.0), Simple(3, 4.0), Simple(5, 6.0)] + nta = NamedTupleArray(items) + + # Iteration should yield NamedTuple instances + for i, item in enumerate(nta): + assert isinstance(item, Simple) + assert item == items[i] + + +def test_len(): + """Test that len() works on NamedTupleArray.""" + items = [Simple(1, 2.0), Simple(3, 4.0), Simple(5, 6.0)] + nta = NamedTupleArray(items) + + assert len(nta) == 3 + + +def test_repr(): + """Test that repr() produces a meaningful string.""" + items = [Simple(1, 2.0), Simple(3, 4.0)] + nta = NamedTupleArray(items) + + r = repr(nta) + assert "NamedTupleArray" in r + assert "Simple" in r + + +def test_computed_property_nonexistent_attribute(): + """Test accessing a computed property that doesn't exist.""" + items = [Simple(1, 2.0), Simple(3, 4.0)] + nta = NamedTupleArray(items) + + # Accessing a non-existent attribute should raise KeyError + with pytest.raises(KeyError): + _ = nta["nonexistent_property"] + + +def test_heterogeneous_items(): + """Test that arrays with different NamedTuple types raise error.""" + items = [Simple(1, 2.0), Mixed(3, 4.0, True, "foo")] + with pytest.raises(ValueError): + NamedTupleArray(items) + + +def test_mixed_items_and_non_items(): + """Test that mixing NamedTuples with non-NamedTuples raises error.""" + items = [Simple(1, 2.0), (3, 4.0)] # Mix of NamedTuple and tuple + with pytest.raises(TypeError): + NamedTupleArray(items) + + +def test_passing_array_in_list(): + """Test that passing a numpy array wrapped in a list works.""" + arr = np.array([(1, 2.0), (3, 4.0)], dtype=[("x", "i8"), ("y", "f8")]) + # This should be treated as a single-item list containing an array + # and should fail since arrays aren't NamedTuple instances + with pytest.raises(TypeError): + NamedTupleArray([arr]) + + +def test_empty_nested_list(): + """Test that empty nested lists raise ValueError.""" + with pytest.raises(ValueError, match="Cannot infer type from empty"): + NamedTupleArray([[]]) + + +def test_casting_computed_property(): + """Test computed properties with type annotations are cast correctly.""" + breaths = [Breath(0.0, 0.5, 1.0), Breath(1.0, 1.4, 2.2)] + nta = NamedTupleArray(breaths) + + # duration is annotated as float + dur = nta["duration"] + assert 
np.issubdtype(dur.dtype, np.floating)
+
+    # d_length (from Mixed) is annotated as int
+    mixed_items = [Mixed(1, 2.0, True, "foo"), Mixed(3, 4.0, False, "hello")]
+    nta_mixed = NamedTupleArray(mixed_items)
+    d_len = nta_mixed["d_length"]
+    assert np.issubdtype(d_len.dtype, np.integer)
+
+
+def test_array_protocol():
+    """Test that the __array__ interface works."""
+    items = [Simple(1, 2.0), Simple(3, 4.0)]
+    nta = NamedTupleArray(items)
+
+    # Should be convertible to numpy array
+    arr = np.asarray(nta)
+    assert isinstance(arr, np.ndarray)
+    assert arr.shape == (2,)
+
+
+def test_frozen_unfrozen_mixed_access():
+    """Test accessing frozen array via different methods."""
+    items = [Simple(1, 2.0), Simple(3, 4.0)]
+    frozen_nta = NamedTupleArray(items, frozen=True)
+
+    # Scalar access still works
+    assert frozen_nta[0] == Simple(1, 2.0)
+
+    # Field access returns read-only view
+    field = frozen_nta["x"]
+    assert field.flags.writeable is False
+
+    # items property returns frozen array
+    assert frozen_nta.items.flags.writeable is False
+
+
+def test_nested_list_homogeneity():
+    """Test that nested lists maintain homogeneity checks."""
+    # Valid nested structure
+    nested = [[Simple(i, float(i)) for i in range(2)], [Simple(j, float(j)) for j in range(2, 4)]]
+    nta = NamedTupleArray(nested)
+    assert nta.shape == (2, 2)
+
+    # Invalid: mixed types in nested structure
+    invalid_nested = [[Simple(1, 2.0), Mixed(3, 4.0, True, "foo")]]
+    with pytest.raises(ValueError):
+        NamedTupleArray(invalid_nested)
+
+
+def test_slicing_preserves_type():
+    """Test that slicing returns the correct type."""
+    items = [Simple(i, float(i)) for i in range(5)]
+    nta = NamedTupleArray(items)
+
+    # Integer indexing returns item
+    item = nta[2]
+    assert isinstance(item, Simple)
+
+    # Slice returns NamedTupleArray
+    sliced = nta[1:3]
+    assert isinstance(sliced, NamedTupleArray)
+    assert sliced.shape == (2,)
+
+
+def test_from_ndarray_0d_array():
+    """Test that from_ndarray with 0D array raises error."""
+    arr = np.array(2.0)
+    with pytest.raises(ValueError, match="at least 1 dimension"):
+        NamedTupleArray.from_ndarray(arr, Simple)
+
+
+def test_multidimensional_iteration():
+    """Test iteration over multi-dimensional NamedTupleArray."""
+    nested = [[Simple(i, float(i)) for i in range(2)], [Simple(j + 2, float(j + 2)) for j in range(2)]]
+    nta = NamedTupleArray(nested)
+
+    # Iteration over 2D array yields 1D NamedTupleArrays
+    rows = list(nta)
+    assert len(rows) == 2
+    assert all(isinstance(row, NamedTupleArray) for row in rows)
+    assert rows[0].shape == (2,)
+
+
+def test_computed_property_with_heuristic_int():
+    """Test computed property that infers int dtype via heuristic."""
+
+    # Create a NamedTuple with a property that returns int but has no return type annotation,
+    # so the dtype must be inferred via the heuristic rather than the annotation
+    class AnnotationlessNT(NamedTuple):
+        value: int
+
+        @property
+        def doubled(self):
+            """Unannotated property returning int."""
+            return self.value * 2
+
+    items = [AnnotationlessNT(1), AnnotationlessNT(2)]
+    nta = NamedTupleArray(items)
+
+    result = nta["doubled"]
+    assert np.issubdtype(result.dtype, np.integer)
+    assert list(result) == [2, 4]
+
+
+def test_computed_property_with_heuristic_float():
+    """Test computed property that infers float dtype via heuristic."""
+
+    class FloatPropertyNT(NamedTuple):
+        value: float
+
+        @property
+        def halved(self):
+            """Unannotated property returning float."""
+            return self.value / 2.0
+
+    items = [FloatPropertyNT(2.0), FloatPropertyNT(4.0)]
+    nta = NamedTupleArray(items)
+
+    result = nta["halved"]
+    assert np.issubdtype(result.dtype, np.floating)
+    assert pytest.approx(list(result)) == [1.0, 2.0]
+
+
+def test_computed_property_mixed_types_returns_object():
+    """Test computed property with mixed types returns object dtype."""
+
+    class MixedPropertyNT(NamedTuple):
+        value: int
+
+        @property
+        def mixed(self):
+            """Unannotated property that sometimes returns int, sometimes str."""
+            return self.value if self.value < 2 else f"str_{self.value}"
+
+    items = [MixedPropertyNT(1), MixedPropertyNT(3)]
+    nta = NamedTupleArray(items)
+
+    result = nta["mixed"]
+    assert result.dtype == object
+    assert result[0] == 1
+    assert result[1] == "str_3"
+
+
+def test_getitem_returns_2d_array():
+    """Test that indexing returns correct types for various index types."""
+    items = [Simple(i, float(i)) for i in range(6)]
+    nta = NamedTupleArray(items)
+
+    # Fancy indexing with list returns NamedTupleArray
+    indexed = nta[[0, 2, 4]]
+    assert isinstance(indexed, NamedTupleArray)
+    assert indexed.shape == (3,)
+
+
+def test_zero_d_ndarray_from_indexing():
+    """Test that scalar indexing returns NamedTuple correctly."""
+    items = [Simple(1, 2.0), Simple(3, 4.0)]
+    nta = NamedTupleArray(items)
+
+    # Scalar access should return NamedTuple
+    scalar = nta[0]
+    assert isinstance(scalar, Simple)
+    assert scalar.x == 1
+    assert scalar.y == 2.0
+
+
+def test_from_ndarray_empty_fields():
+    """Test that from_ndarray handles edge cases."""
+    # Just verify the function is defined and doesn't break in normal usage
+    # The empty fields check is hard to trigger naturally
+    arr = np.array([[1, 2.0], [3, 4.0]])
+    nta = NamedTupleArray.from_ndarray(arr, Simple)
+    assert nta.shape == (2,)
+
+
+def test_nested_empty_list_deeply():
+    """Test that deeply nested empty lists raise ValueError."""
+    with pytest.raises(ValueError):
+        NamedTupleArray([[[]]])
+
+
+def test_property_with_exception_in_heuristic():
+    """Test that computed property handles exceptions in type inference gracefully."""
+
+    class NTWithUnusualProperty(NamedTuple):
+        value: int
+
+        @property
+        def unusual(self) -> dict:
+            # Return something that can't be easily cast
+            return {"key": self.value}
+
+    items = [NTWithUnusualProperty(1), NTWithUnusualProperty(2)]
+    nta = NamedTupleArray(items)
+
+    # Should return object dtype (via exception handling)
+    result = nta["unusual"]
+    assert result.dtype == object
+
+
+def test_from_ndarray_3d_array():
+    """Test from_ndarray with 3D array."""
+    arr = np.array([[[1, 2.0], [3, 4.0]], [[5, 6.0], [7, 8.0]]], dtype=float)
+    nta = NamedTupleArray.from_ndarray(arr, Simple)
+
+    assert nta.shape == (2, 2)
+    assert nta[0, 1] == Simple(3, 4.0)
+
+
+def test_iteration_1d_direct_yield():
+    """Test that 1D iteration yields NamedTuple items directly."""
+    items = [Simple(1, 1.0), Simple(2, 2.0), Simple(3, 3.0)]
+    nta = NamedTupleArray(items)
+
+    yielded = list(nta)
+    assert len(yielded) == 3
+    assert all(isinstance(item, Simple) for item in yielded)
+    assert yielded == items
+
+
+def test_from_ndarray_empty_namedtuple():
+    """Test that from_ndarray with empty NamedTuple raises RuntimeError (lines 232-233)."""
+
+    class Empty(NamedTuple):
+        pass
+
+    arr = np.array([], dtype=float)
+    with pytest.raises(RuntimeError, match="no fields"):
+        NamedTupleArray.from_ndarray(arr, Empty)

From c96d549874842af085e787748d6a48747db5da91 Mon Sep 17 00:00:00 2001
From: Peter Somhorst
Date: Thu, 18 Dec 2025 13:10:58 +0100
Subject: [PATCH 20/32] Freeze Event dataclass

---
 eitprocessing/datahandling/event.py | 2 +-
 1 file changed, 1 insertion(+), 1
deletion(-) diff --git a/eitprocessing/datahandling/event.py b/eitprocessing/datahandling/event.py index d1a1c3aa4..19d7715ec 100644 --- a/eitprocessing/datahandling/event.py +++ b/eitprocessing/datahandling/event.py @@ -1,7 +1,7 @@ from dataclasses import dataclass -@dataclass +@dataclass(frozen=True) class Event: """Single time point event registered during an EIT measurement.""" From 5cddfc66d3de48486d99c610da9802db2480e9e6 Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Thu, 18 Dec 2025 13:11:06 +0100 Subject: [PATCH 21/32] Add target version for ruff --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 31563f828..51f7e86c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -190,3 +190,6 @@ replace = 'version: "{new_version}"' filename = "eitprocessing/__init__.py" search = '__version__ = "{current_version}"' replace = '__version__ = "{new_version}"' + + +[tool.pyrefly] From 97591e013e363e51d6ecc87c2b521201c8ee60e7 Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Thu, 18 Dec 2025 13:18:56 +0100 Subject: [PATCH 22/32] Set IntervalData intervals to NamedTupleArray --- eitprocessing/datahandling/intervaldata.py | 34 ++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/eitprocessing/datahandling/intervaldata.py b/eitprocessing/datahandling/intervaldata.py index 7fa186069..8bb798b75 100644 --- a/eitprocessing/datahandling/intervaldata.py +++ b/eitprocessing/datahandling/intervaldata.py @@ -7,6 +7,7 @@ from eitprocessing.datahandling import DataContainer from eitprocessing.datahandling.mixins.slicing import HasTimeIndexer, SelectByIndex +from eitprocessing.datahandling.namedtuple_array import NamedTupleArray, Nested T = TypeVar("T", bound="IntervalData") @@ -17,6 +18,11 @@ class Interval(NamedTuple): start_time: float end_time: float + @property + def duration(self) -> float: + """Duration of the interval.""" + return self.end_time - self.start_time + @dataclass(eq=False) class IntervalData(DataContainer, SelectByIndex, HasTimeIndexer): @@ -51,18 +57,42 @@ class IntervalData(DataContainer, SelectByIndex, HasTimeIndexer): name: str = field(compare=False, repr=False) unit: str | None = field(metadata={"check_equivalence": True}, repr=False) category: str = field(metadata={"check_equivalence": True}, repr=False) - intervals: list[Interval | tuple[float, float]] = field(repr=False) + intervals: NamedTupleArray[Interval] = field(repr=False) values: list[Any] | None = field(repr=False, default=None) description: str = field(compare=False, default="", repr=False) default_partial_inclusion: bool = field(repr=False, default=False) def __post_init__(self) -> None: - self.intervals = [Interval._make(interval) for interval in self.intervals] + self.intervals = self._parse_intervals(self.intervals) if self.has_values and (lv := len(self.values)) != (lt := len(self.intervals)): msg = f"The number of time points ({lt}) does not match the number of values ({lv})." 
raise ValueError(msg)
 
+    @staticmethod
+    def _parse_intervals(
+        intervals: list[Interval] | Nested[Interval] | np.ndarray | NamedTupleArray[Interval],
+    ) -> NamedTupleArray[Interval]:
+        """Parse intervals into a NamedTupleArray of Interval."""
+        if isinstance(intervals, NamedTupleArray):
+            # Compare the structured dtype's field names with Interval's fields
+            if intervals.dtype.names != Interval._fields:
+                msg = f"Expected intervals of type 'Interval', got fields {intervals.dtype.names}"
+                raise TypeError(msg)
+            return intervals
+
+        if isinstance(intervals, np.ndarray):
+            try:
+                return NamedTupleArray.from_ndarray(intervals, Interval)
+            except ValueError as e:
+                msg = f"Could not parse intervals from numpy array with dtype '{intervals.dtype}'"
+                raise TypeError(msg) from e
+
+        try:
+            return NamedTupleArray([Interval._make(interval) for interval in intervals])
+        except Exception as e:
+            msg = "Could not parse intervals from given input."
+            raise TypeError(msg) from e
+
     def __len__(self) -> int:
         return len(self.intervals)

From 903b0403d5ea9239aeee836390525958f73d9b02 Mon Sep 17 00:00:00 2001
From: Peter Somhorst
Date: Thu, 5 Feb 2026 17:17:59 +0100
Subject: [PATCH 23/32] Update NamedTupleArray

---
 .../datahandling/namedtuple_array.py          | 149 ++++++++++++++----
 tests/test_namedtuple_array.py                | 135 +++++++++++++++-
 2 files changed, 242 insertions(+), 42 deletions(-)

diff --git a/eitprocessing/datahandling/namedtuple_array.py b/eitprocessing/datahandling/namedtuple_array.py
index c231a1989..4f730ed06 100644
--- a/eitprocessing/datahandling/namedtuple_array.py
+++ b/eitprocessing/datahandling/namedtuple_array.py
@@ -79,7 +79,7 @@ def r(self) -> float:
     from numpy._core.multiarray import flagsobj
 
 
-T = TypeVar("T", bound=NamedTuple)
+T = TypeVar("T", bound=tuple)
 
 NonStringSeq: TypeAlias = tuple[T, ...] | list[T]
 Nested = T | NonStringSeq
@@ -143,32 +143,43 @@ class NamedTupleArray(Generic[T]):
         array([1. , 1.1])
     """
 
-    _type: type[T]
+    namedtuple_type: type[T]
    __items: np.ndarray
 
-    def __init__(self, items: NonStringSeq[T] | np.ndarray | Nested[T], frozen: bool = False):
+    def __init__(
+        self,
+        items: NonStringSeq[T] | np.ndarray | Nested[T],
+        namedtuple_type: type[T] | None = None,
+        frozen: bool = True,
+    ):
         """Initialize a NamedTupleArray from a sequence or nested sequence of NamedTuple items.
 
         Args:
             items: A sequence (or nested sequence) of NamedTuple instances, or a numpy ndarray containing them.
-            frozen: If True, makes the underlying array immutable.
+            namedtuple_type:
+                Optional explicit type of the NamedTuple items. If not provided, it will be inferred from the first leaf
+                item.
+            frozen: If True (default), makes the underlying array immutable.
         """
-        if isinstance(items, np.ndarray) and items.size == 0:
-            msg = "Cannot infer type from empty array."
-            raise ValueError(msg)
-        if not isinstance(items, np.ndarray) and not items:
-            msg = "Cannot infer type from empty sequence."
+        if namedtuple_type is not None:
+            self.namedtuple_type = namedtuple_type
+        else:
+            if isinstance(items, np.ndarray) and items.size == 0:
+                msg = "Cannot infer type from empty array."
+                raise ValueError(msg)
+            if not isinstance(items, np.ndarray) and not items:
+                msg = "Cannot infer type from empty sequence."
+                raise ValueError(msg)
 
-        # Infer NT type from first leaf element
-        leaf = _first_leaf(items)
-        self._type = type(leaf)  # type: ignore[assignment]
+            # Infer NT type from first leaf element
+            leaf = _first_leaf(items)
+            self.namedtuple_type = type(leaf)  # type: ignore[assignment]
 
         # Validate homogeneity
-        _check_homogeneous(items, self._type)
+        _check_homogeneous(items, self.namedtuple_type)
 
         # Build structured dtype and array with same shape
-        dt = _get_tuple_dtype(self._type)
+        dt = _get_tuple_dtype(self.namedtuple_type)
         self.__items = np.asarray(items, dtype=dt)
 
         if frozen:
@@ -178,22 +189,22 @@ def __init__(self, items: NonStringSeq[T] | np.ndarray | Nested[T], frozen: bool
 
     def _freeze(self) -> None:
         """Make the underlying array immutable."""
-        dt = _get_tuple_dtype(self._type)
+        dt = _get_tuple_dtype(self.namedtuple_type)
         freeze_method = "flag" if dt.hasobject else "memoryview"
         self.__items = freeze_array(self.__items, method=freeze_method)
 
     def __setattr__(self, name: str, value: object) -> None:
-        """Allow setting _type and __items only during initialization; block modification after."""
+        """Allow setting namedtuple_type and __items only during initialization; block modification after."""
         # Check if initialization is complete (use object.__getattribute__ to bypass our __getattr__)
         try:
             initialized = object.__getattribute__(self, "_initialized")
         except AttributeError:
             initialized = False
 
-        # Allow setting _type and __items only during initialization
-        if not initialized and name in ("_type", "_NamedTupleArray__items"):
+        # Allow setting namedtuple_type and __items only during initialization
+        if not initialized and name in ("namedtuple_type", "_NamedTupleArray__items"):
             super().__setattr__(name, value)
-        elif initialized and name in ("_type", "_NamedTupleArray__items"):
+        elif initialized and name in ("namedtuple_type", "_NamedTupleArray__items"):
             msg = f"{type(self).__name__!r} object is immutable; cannot modify {name!r} after initialization."
             raise AttributeError(msg)
         else:
@@ -201,9 +212,10 @@ def __setattr__(self, name: str, value: object) -> None:
             raise AttributeError(msg)
 
     @classmethod
-    def from_ndarray(cls, arr: np.ndarray, namedtuple_type: type[T], frozen: bool = False) -> NamedTupleArray[T]:
-        """Build a NamedTupleArray from an unstructured numpy array.
+    def from_array(cls, arr: np.ndarray | Nested, namedtuple_type: type[T], frozen: bool = True) -> NamedTupleArray[T]:
+        """Build a NamedTupleArray from an unstructured numpy array or nested list.
 
+        The list must be convertible to a numpy array. The last axis of the array is mapped to the fields.
         The length of the last axis must equal the number of fields in the given NamedTuple type.
 
         Example:
             This example represents a sequence of 10 breaths for each of 32x32 pixels. Each breath contains 3 fields:
             start_time, middle_time, end_time.
 
            ```python
-            assert breath_data.shape == (10, 32, 32, 3)
-            breaths = NamedTupleArray.from_ndarray(breath_data, Breath)
+            breath_data = load_breath_data()  # shape (10, 32, 32, 3)
+            breaths = NamedTupleArray.from_array(breath_data, Breath)
            ```
 
            This is equivalent to a list of 10 nested lists, each containing 32 lists (rows) of 32 (columns) Breath
            objects.
        """
+        if not isinstance(arr, np.ndarray):
+            arr = np.array(arr)
+
        if arr.ndim < 1:
            msg = "arr must have at least 1 dimension."
            raise ValueError(msg)
@@ -239,7 +254,7 @@ def from_ndarray(cls, arr: np.ndarray, namedtuple_type: type[T], frozen: bool =
            out[name] = arr[..., i].astype(target_dt, copy=False)
 
        inst = cls.__new__(cls)
-        inst._type = namedtuple_type  # noqa: SLF001
+        inst.namedtuple_type = namedtuple_type
        inst._NamedTupleArray__items = out  # noqa: SLF001
 
        if frozen:
@@ -281,6 +296,36 @@ def flags(self) -> flagsobj:
         """
         return self.__items.flags
 
+    def to_array(self) -> np.ndarray:
+        """Convert to an unstructured numpy array.
+
+        Returns a 2D array where each column corresponds to a field of the NamedTuple,
+        in the order of the NamedTuple fields. This allows convenient slicing by
+        column indices like `arr[:, [0, 2]]`.
+
+        Returns:
+            A 2D unstructured numpy array of shape (n_items, n_fields).
+
+        Example:
+            >>> class Point(NamedTuple):
+            ...     x: float
+            ...     y: float
+            ...     z: float
+            >>> nta = NamedTupleArray([Point(1.0, 2.0, 3.0), Point(4.0, 5.0, 6.0)])
+            >>> arr = nta.to_array()
+            >>> arr.shape
+            (2, 3)
+            >>> arr[:, [0, 2]]  # Get x and z columns
+            array([[1., 3.],
+                   [4., 6.]])
+        """
+        # Stack each field as a column to create an unstructured array
+        if not self.__items.dtype.names:
+            # No fields, return empty array
+            return np.empty((self.shape[0], 0))
+
+        return np.column_stack([self.__items[name] for name in self.__items.dtype.names])
+
     def __getattr__(self, attr: str):
         """Block access to the private array.
 
@@ -296,12 +341,12 @@ def __array__(self, dtype: np.dtype | None = None) -> np.ndarray:
     def __iter__(self) -> Generator[T | NamedTupleArray[T], None, None]:
         if self.ndim == 1:
             for item in self.__items:
-                yield self._type(*item)  # type: ignore[call-arg]
+                yield self.namedtuple_type(*item)  # type: ignore[call-arg]
         else:
             # yield structured subarrays along axis 0
             for i in range(self.__items.shape[0]):
                 out = NamedTupleArray.__new__(NamedTupleArray)
-                out._type = self._type  # noqa: SLF001
+                out.namedtuple_type = self.namedtuple_type
                 out._NamedTupleArray__items = self.__items[i]  # noqa: SLF001
                 yield out
 
@@ -309,7 +354,42 @@ def __len__(self) -> int:
         return self.__items.shape[0] if self.__items.ndim > 0 else 0
 
     def __repr__(self) -> str:
-        return f"NamedTupleArray[{self._type.__name__}]{repr(self.__items).removeprefix('array')}"
+        return f"NamedTupleArray[{self.namedtuple_type.__name__}]{repr(self.__items).removeprefix('array')}"
+
+    def __eq__(self, other: object) -> bool:
+        """Compare two NamedTupleArray instances for equality.
+
+        Two NamedTupleArray instances are equal if:
+        - They are both NamedTupleArray instances
+        - They have the same NamedTuple type
+        - Their underlying arrays are equal (including NaN equality for floats)
+        """
+        if not isinstance(other, NamedTupleArray):
+            return False
+
+        if self.namedtuple_type is not other.namedtuple_type:
+            return False
+
+        # Compare shapes
+        if self.__items.shape != other.__items.shape:
+            return False
+
+        # Compare dtypes
+        if self.__items.dtype != other.__items.dtype:
+            return False
+
+        # For structured arrays, compare field by field to handle NaN values properly
+        for name in self.__items.dtype.names or []:
+            self_field = self.__items[name]
+            other_field = other.__items[name]
+
+            # equal_nan requires a dtype that supports np.isnan; object-dtype fields
+            # (e.g. str) do not, so only enable it for inexact (float/complex) fields
+            equal_nan = bool(np.issubdtype(self_field.dtype, np.inexact))
+            if not np.array_equal(self_field, other_field, equal_nan=equal_nan):
+                return False
+
+        return True
+
+    __hash__ = None  # type: ignore[assignment]
 
     @overload
     def __getitem__(self, index: str) -> np.ndarray: ...
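The field-by-field comparison in `__eq__` above deserves a note: plain `==` treats NaN as unequal to itself, while `np.array_equal(..., equal_nan=True)` treats NaNs at matching positions as equal, but only for dtypes that support `np.isnan`. A minimal standalone illustration of that NumPy behavior (not part of the patch):

    import numpy as np

    a = np.array([1.0, np.nan])
    b = np.array([1.0, np.nan])
    assert not np.all(a == b)                    # NaN != NaN under elementwise ==
    assert np.array_equal(a, b, equal_nan=True)  # aligned NaNs compare equal

    s = np.array(["foo"], dtype=object)          # object dtype, as used for str fields
    assert np.array_equal(s, s.copy())           # fine without equal_nan
    # np.array_equal(s, s.copy(), equal_nan=True) raises TypeError, because
    # np.isnan is undefined for object arrays; hence equal_nan is only enabled
    # for inexact (floating/complex) fields in __eq__ above.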
@@ -342,17 +422,17 @@ def __getitem__(self, index: str | int | slice | NonStringSeq) -> np.ndarray | N # Structured scalar (np.void) → return NamedTuple if isinstance(result, np.void): # For structured np.void, convert to NamedTuple - return self._type(*result.tolist()) # type: ignore[call-arg] + return self.namedtuple_type(*result.tolist()) # type: ignore[call-arg] # Zero-d structured ndarray (shape == ()) → convert to NamedTuple if isinstance(result, np.ndarray) and result.dtype.fields is not None and result.ndim == 0: scalar = result.item() # np.void - return self._type(*scalar.tolist()) # type: ignore[call-arg] + return self.namedtuple_type(*scalar.tolist()) # type: ignore[call-arg] # Structured ndarray → wrap if isinstance(result, np.ndarray) and result.dtype.fields is not None: out: NamedTupleArray[T] = type(self).__new__(type(self)) - out._type = self._type + out.namedtuple_type = self.namedtuple_type out._NamedTupleArray__items = result return out @@ -362,7 +442,7 @@ def __getitem__(self, index: str | int | slice | NonStringSeq) -> np.ndarray | N def _compute_property(self, attr: str) -> np.ndarray: """Compute a property or attribute across all items, preserving the array shape.""" # Verify attribute exists on the NT instance - sample = self._type(*self.__items.flat[0].tolist()) # type: ignore[call-arg] + sample = self.namedtuple_type(*self.__items.flat[0].tolist()) # type: ignore[call-arg] if not hasattr(sample, attr): msg = f"Field or property '{attr}' not found in NamedTuple." raise KeyError(msg) @@ -370,12 +450,12 @@ def _compute_property(self, attr: str) -> np.ndarray: # Collect values (single pass using flat indexing) out_obj = np.empty(self.shape, dtype=object) for i, rec in enumerate(self.__items.reshape(-1)): - nt = self._type(*rec.tolist()) # type: ignore[call-arg] + nt = self.namedtuple_type(*rec.tolist()) # type: ignore[call-arg] out_obj.reshape(-1)[i] = getattr(nt, attr) # Determine target dtype from property annotation if available (handles postponed annotations) target_dtype: np.dtype | None = None - attr_member = getattr(self._type, attr, None) + attr_member = getattr(self.namedtuple_type, attr, None) if isinstance(attr_member, property) and attr_member.fget is not None: with contextlib.suppress(Exception): hints = get_type_hints(attr_member.fget) @@ -416,6 +496,7 @@ def _first_leaf( msg = "Cannot infer type from empty nested sequence." raise ValueError(msg) return _first_leaf(seq[0]) + msg = "Items must be NamedTuple or nested sequences thereof." 
raise TypeError(msg) @@ -469,7 +550,7 @@ def _is_namedtuple_type(item: object) -> TypeGuard[type[NamedTuple]]: return isinstance(item, type) and issubclass(item, tuple) and hasattr(item, "_fields") -def _get_tuple_dtype(item: NamedTuple | type[NamedTuple]) -> np.dtype: +def _get_tuple_dtype(item: NamedTuple | type[tuple]) -> np.dtype: """Generate a NumPy structured dtype from a NamedTuple type.""" if _is_namedtuple_instance(item): item = type(item) diff --git a/tests/test_namedtuple_array.py b/tests/test_namedtuple_array.py index 4fff3583d..3475999c4 100644 --- a/tests/test_namedtuple_array.py +++ b/tests/test_namedtuple_array.py @@ -106,7 +106,7 @@ def test_3d_from_ndarray_and_indexing(): ], dtype=float, ) - nta = NamedTupleArray.from_ndarray(arr, Simple) + nta = NamedTupleArray.from_array(arr, Simple) assert nta.shape == (2, 2, 2) # random 3D scalar access @@ -118,7 +118,7 @@ def test_3d_from_ndarray_and_indexing(): def test_field_views_readonly_and_shape_preserved(): items = [Simple(1, 2.0), Simple(3, 4.0)] - nta = NamedTupleArray(items) + nta = NamedTupleArray(items, frozen=False) assert nta.flags.writeable is True vx = nta["x"] assert vx.flags.writeable is False @@ -342,13 +342,13 @@ def test_ndarray_last_axis_mismatch(): # Simple has 2 fields, but array has 3 columns arr = np.array([[1, 2.0, 3.0], [4, 5.0, 6.0]]) with pytest.raises(ValueError): - NamedTupleArray.from_ndarray(arr, Simple) + NamedTupleArray.from_array(arr, Simple) def test_frozen_from_ndarray(): """Test that from_ndarray with frozen=True works correctly.""" arr = np.array([[1, 2.0], [3, 4.0]]) - nta = NamedTupleArray.from_ndarray(arr, Simple, frozen=True) + nta = NamedTupleArray.from_array(arr, Simple, frozen=True) assert nta.shape == (2,) assert nta.items.flags.writeable is False @@ -497,7 +497,7 @@ def test_from_ndarray_0d_array(): """Test that from_ndarray with 0D array raises error.""" arr = np.array(2.0) with pytest.raises(ValueError, match="at least 1 dimension"): - NamedTupleArray.from_ndarray(arr, Simple) + NamedTupleArray.from_array(arr, Simple) def test_multidimensional_iteration(): @@ -599,7 +599,7 @@ def test_from_ndarray_empty_fields(): # Just verify the function is defined and doesn't break in normal usage # The empty fields check is hard to trigger naturally arr = np.array([[1, 2.0], [3, 4.0]]) - nta = NamedTupleArray.from_ndarray(arr, Simple) + nta = NamedTupleArray.from_array(arr, Simple) assert nta.shape == (2,) @@ -631,7 +631,7 @@ def unusual(self) -> dict: def test_from_ndarray_3d_array(): """Test from_ndarray with 3D array.""" arr = np.array([[[1, 2.0], [3, 4.0]], [[5, 6.0], [7, 8.0]]], dtype=float) - nta = NamedTupleArray.from_ndarray(arr, Simple) + nta = NamedTupleArray.from_array(arr, Simple) assert nta.shape == (2, 2) assert nta[0, 1] == Simple(3, 4.0) @@ -656,4 +656,123 @@ class Empty(NamedTuple): arr = np.array([], dtype=float) with pytest.raises(RuntimeError, match="no fields"): - NamedTupleArray.from_ndarray(arr, Empty) + NamedTupleArray.from_array(arr, Empty) + + +def test_equality_identical_arrays(): + """Test that two NamedTupleArray instances with identical data are equal.""" + items1 = [Simple(1, 2.0), Simple(3, 4.5)] + items2 = [Simple(1, 2.0), Simple(3, 4.5)] + + arr1 = NamedTupleArray(items1) + arr2 = NamedTupleArray(items2) + + assert arr1 == arr2 + assert arr2 == arr1 # Test equality is symmetric + + +def test_equality_different_values(): + """Test that NamedTupleArray instances with different values are not equal.""" + items1 = [Simple(1, 2.0), Simple(3, 4.5)] + items2 = 
[Simple(1, 2.0), Simple(3, 5.0)]
+
+    arr1 = NamedTupleArray(items1)
+    arr2 = NamedTupleArray(items2)
+
+    assert arr1 != arr2
+
+
+def test_equality_different_lengths():
+    """Test that NamedTupleArray instances with different lengths are not equal."""
+    items1 = [Simple(1, 2.0), Simple(3, 4.5)]
+    items2 = [Simple(1, 2.0), Simple(3, 4.5), Simple(5, 6.0)]
+
+    arr1 = NamedTupleArray(items1)
+    arr2 = NamedTupleArray(items2)
+
+    assert arr1 != arr2
+
+
+def test_equality_different_types():
+    """Test that NamedTupleArray instances with different NamedTuple types are not equal."""
+    simple_items = [Simple(1, 2.0), Simple(3, 4.5)]
+    breath_items = [Breath(1.0, 2.0, 3.0), Breath(3.0, 4.0, 5.0)]
+
+    arr1 = NamedTupleArray(simple_items)
+    arr2 = NamedTupleArray(breath_items)
+
+    assert arr1 != arr2
+
+
+def test_equality_with_nan_values():
+    """Test that NamedTupleArray instances with NaN values can be compared correctly."""
+    items1 = [Simple(1, np.nan), Simple(3, 4.5)]
+    items2 = [Simple(1, np.nan), Simple(3, 4.5)]
+
+    arr1 = NamedTupleArray(items1)
+    arr2 = NamedTupleArray(items2)
+
+    # NaN values should be considered equal in this comparison
+    assert arr1 == arr2
+
+
+def test_equality_with_nan_different_positions():
+    """Test that NamedTupleArray instances with NaN in different positions are not equal."""
+    items1 = [Simple(1, np.nan), Simple(3, 4.5)]
+    items2 = [Simple(1, 2.0), Simple(3, np.nan)]
+
+    arr1 = NamedTupleArray(items1)
+    arr2 = NamedTupleArray(items2)
+
+    assert arr1 != arr2
+
+
+def test_equality_2d_arrays():
+    """Test equality comparison for 2D NamedTupleArray instances."""
+    nested1 = [[Simple(i + j, float(i * j)) for j in range(3)] for i in range(2)]
+    nested2 = [[Simple(i + j, float(i * j)) for j in range(3)] for i in range(2)]
+
+    arr1 = NamedTupleArray(nested1)
+    arr2 = NamedTupleArray(nested2)
+
+    assert arr1 == arr2
+
+
+def test_equality_2d_arrays_different():
+    """Test inequality for 2D NamedTupleArray instances with different values."""
+    nested1 = [[Simple(i + j, float(i * j)) for j in range(3)] for i in range(2)]
+    nested2 = [[Simple(i + j + 1, float(i * j)) for j in range(3)] for i in range(2)]
+
+    arr1 = NamedTupleArray(nested1)
+    arr2 = NamedTupleArray(nested2)
+
+    assert arr1 != arr2
+
+
+def test_equality_not_equal_to_non_namedtuplearray():
+    """Test that NamedTupleArray is not equal to other types."""
+    items = [Simple(1, 2.0), Simple(3, 4.5)]
+    arr = NamedTupleArray(items)
+
+    # Test inequality with list
+    assert arr != items
+
+    # Test inequality with numpy array
+    assert arr != np.array([(1, 2.0), (3, 4.5)])
+
+    # Test inequality with None (exercises __eq__ rather than identity)
+    assert arr != None  # noqa: E711
+
+    # Test inequality with string
+    assert arr != "not an array"
+
+
+def test_equality_frozen_and_unfrozen():
+    """Test that frozen and unfrozen arrays with same data are equal."""
+    items1 = [Simple(1, 2.0), Simple(3, 4.5)]
+    items2 = [Simple(1, 2.0), Simple(3, 4.5)]
+
+    arr_frozen = NamedTupleArray(items1, frozen=True)
+    arr_unfrozen = NamedTupleArray(items2, frozen=False)
+
+    assert arr_frozen == arr_unfrozen

From 2afd20afd10ae48fc379ea5d1ca125e88fb32e85 Mon Sep 17 00:00:00 2001
From: Peter Somhorst
Date: Mon, 9 Feb 2026 14:04:30 +0100
Subject: [PATCH 24/32] Make DataContainers frozen

---
 eitprocessing/datahandling/__init__.py       | 14 +-------------
 eitprocessing/datahandling/continuousdata.py |  4 ++--
 eitprocessing/datahandling/eitdata.py        |  4 ++--
 eitprocessing/datahandling/intervaldata.py   | 14 ++++++--------
 eitprocessing/datahandling/sparsedata.py     |  2 +-
 5 files changed, 12 insertions(+), 26 deletions(-)
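The intervaldata.py changes in this patch rely on the standard escape hatch for frozen dataclasses: a frozen dataclass blocks normal attribute assignment even inside __post_init__, so normalized values have to be written with object.__setattr__. A minimal sketch of the pattern, using a hypothetical Example class rather than anything from the package:

    from dataclasses import dataclass, field

    @dataclass(frozen=True)
    class Example:
        values: tuple[float, ...] | None = None
        count: int = field(init=False, default=0)

        def __post_init__(self) -> None:
            # self.count = ... would raise dataclasses.FrozenInstanceError;
            # object.__setattr__ bypasses the frozen __setattr__ once, at init time
            object.__setattr__(self, "count", len(self.values) if self.values else 0)

    e = Example((1.0, 2.0))
    assert e.count == 2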
diff --git a/eitprocessing/datahandling/__init__.py b/eitprocessing/datahandling/__init__.py index 14a24d185..823e57371 100644 --- a/eitprocessing/datahandling/__init__.py +++ b/eitprocessing/datahandling/__init__.py @@ -7,20 +7,8 @@ from eitprocessing.datahandling.mixins.equality import Equivalence -@dataclass(eq=False) -class DataContainer(Equivalence): - """Base class for data container classes.""" - - def __bool__(self): - return True - - def deepcopy(self) -> Self: - """Return a deep copy of the object.""" - return deepcopy(self) - - @dataclass(eq=False, frozen=True) -class FrozenDataContainer(Equivalence): +class DataContainer(Equivalence): """Base class for data container classes.""" def __bool__(self): diff --git a/eitprocessing/datahandling/continuousdata.py b/eitprocessing/datahandling/continuousdata.py index d82d0fcda..e119c8522 100644 --- a/eitprocessing/datahandling/continuousdata.py +++ b/eitprocessing/datahandling/continuousdata.py @@ -6,7 +6,7 @@ import numpy as np -from eitprocessing.datahandling import FrozenDataContainer +from eitprocessing.datahandling import DataContainer from eitprocessing.datahandling.mixins.slicing import SelectByTime from eitprocessing.utils.frozen_array import freeze_array @@ -19,7 +19,7 @@ @dataclass(eq=False, frozen=True) -class ContinuousData(FrozenDataContainer, SelectByTime): +class ContinuousData(DataContainer, SelectByTime): """Container for data with a continuous time axis. Continuous data is assumed to be sequential (i.e. a single data point at each time point, sorted by time) and diff --git a/eitprocessing/datahandling/eitdata.py b/eitprocessing/datahandling/eitdata.py index 5d1c14aa2..ae3ae34d1 100644 --- a/eitprocessing/datahandling/eitdata.py +++ b/eitprocessing/datahandling/eitdata.py @@ -9,7 +9,7 @@ import numpy as np -from eitprocessing.datahandling import FrozenDataContainer +from eitprocessing.datahandling import DataContainer from eitprocessing.datahandling.continuousdata import ContinuousData from eitprocessing.datahandling.mixins.slicing import SelectByTime from eitprocessing.utils.frozen_array import freeze_array @@ -22,7 +22,7 @@ @dataclass(eq=False, frozen=True) -class EITData(FrozenDataContainer, SelectByTime): +class EITData(DataContainer, SelectByTime): """Container for EIT impedance data. This class holds the pixel impedance from an EIT measurement, as well as metadata describing the measurement. The diff --git a/eitprocessing/datahandling/intervaldata.py b/eitprocessing/datahandling/intervaldata.py index 8bb798b75..c4aebbbec 100644 --- a/eitprocessing/datahandling/intervaldata.py +++ b/eitprocessing/datahandling/intervaldata.py @@ -24,7 +24,7 @@ def duration(self) -> float: return self.end_time - self.start_time -@dataclass(eq=False) +@dataclass(eq=False, frozen=True) class IntervalData(DataContainer, SelectByIndex, HasTimeIndexer): """Container for interval data existing over a period of time. 
@@ -63,9 +63,12 @@ class IntervalData(DataContainer, SelectByIndex, HasTimeIndexer): default_partial_inclusion: bool = field(repr=False, default=False) def __post_init__(self) -> None: - self.intervals = self._parse_intervals(self.intervals) + object.__setattr__(self, "intervals", self._parse_intervals(self.intervals)) - if self.has_values and (lv := len(self.values)) != (lt := len(self.intervals)): + if self.values is not None and len(self.values) == 0: + object.__setattr__(self, "values", None) + + if self.values is not None and (lv := len(self.values)) != (lt := len(self.intervals)): msg = f"The number of time points ({lt}) does not match the number of values ({lv})." raise ValueError(msg) @@ -153,11 +156,6 @@ def select_by_time( """ newlabel = newlabel or self.label - if start_time is None and end_time is None: - copy_ = copy.deepcopy(self) - copy_.label = newlabel - return copy_ - partial_inclusion = partial_inclusion or self.default_partial_inclusion selection_start = start_time or self.intervals[0].start_time diff --git a/eitprocessing/datahandling/sparsedata.py b/eitprocessing/datahandling/sparsedata.py index e4840f3d3..f23177fb9 100644 --- a/eitprocessing/datahandling/sparsedata.py +++ b/eitprocessing/datahandling/sparsedata.py @@ -11,7 +11,7 @@ T = TypeVar("T", bound="SparseData") -@dataclass(eq=False) +@dataclass(eq=False, frozen=True) class SparseData(DataContainer, SelectByTime): """Container for data related to individual time points. From 51e3cf3a5255b4b956e8b1fc264cca5054bc6491 Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Mon, 9 Feb 2026 14:22:26 +0100 Subject: [PATCH 25/32] Remove NamedTupleArray and replace with StructuredArray --- .../datahandling/namedtuple_array.py | 575 ---------- .../datahandling/structured_array.py | 993 ++++++++++++++++++ 2 files changed, 993 insertions(+), 575 deletions(-) delete mode 100644 eitprocessing/datahandling/namedtuple_array.py create mode 100644 eitprocessing/datahandling/structured_array.py diff --git a/eitprocessing/datahandling/namedtuple_array.py b/eitprocessing/datahandling/namedtuple_array.py deleted file mode 100644 index 4f730ed06..000000000 --- a/eitprocessing/datahandling/namedtuple_array.py +++ /dev/null @@ -1,575 +0,0 @@ -# %% -"""Array-like interface over NamedTuple collections enabling NumPy slicing. - -Motivation ----------- -Some data is best represented with multiple related data points (e.g., the start, middle and end time of a breath). They -can be represented as NamedTuple instances; lightweight containers that group related fields together. Lists (or tuples) -of NamedTuples are more difficult to handle efficiently compared to NumPy arrays, especially in multi-dimensional cases. -When NamedTuples are collected inside NumPy arrays, however, they loose their NamedTuple context, removing the field -names and data types. - -This module provides NamedTupleArray, a container that wraps homogeneous collections of NamedTuple instances into a -NumPy structured array, preserving the NamedTuple field names and types while enabling NumPy-style slicing and -field-wise access. It allows access to fields and even computed properties by name. - -Key features ------------- -- Homogeneous type checking: ensures all items share the same NamedTuple type. -- Safe field views: returns read-only views for direct field access. -- Property evaluation: computes per-item properties, resolving postponed - annotations to pick appropriate NumPy dtypes. 
-- Shape preservation: supports nested sequences, maintaining their shape in the - structured array. -- Interop: from_ndarray helper to map last-axis columns to NamedTuple fields. - -Example: -```python -class Coordinate(NamedTuple): - x: float - y: float - z: float - - @property - def r(self) -> float: - \"\"\"The radial distance from the origin.\"\"\" - return (self.x**2 + self.y**2 + self.z**2) ** 0.5 - -coords = [Coordinate(1.0, 2.0, 2.0), Coordinate(3.0, 4.0, 0.0), Coordinate(0.0, 0.0, 5.0)] -arr = NamedTupleArray(coords) - -arr[1:] # Slice of Coordinates -# NamedTupleArray[Coordinate](array([(3., 4., 0.), (0., 0., 5.)], -# dtype=[('x', ' returns a read-only view of the field. - - Property/attribute: arr["duration"] -> computes the property for each - item, producing an ndarray. If the property has a type hint, the result - dtype is chosen accordingly (e.g., int -> int64, float -> float64). - Otherwise, a heuristic casts ints to int64, else tries float, else object. - - Direct array access via items property: arr.items (always returns the - underlying NumPy array; if frozen=True, modifications are prevented). - - Notes: - ----- - - Homogeneity: All elements must be the same NamedTuple type. - - String fields are kept as object dtype to avoid truncation. - - Properties are evaluated per element; heavy properties may be costly. - - Field views are always read-only to prevent accidental mutation. - - The .items property returns the underlying array; modifications are only - prevented if frozen=True was passed during construction. - - Example: - ------- - >>> class Breath(NamedTuple): - ... start_time: float - ... middle_time: float - ... end_time: float - ... @property - ... def duration(self) -> float: - ... return self.end_time - self.start_time - ... - >>> breaths = [Breath(0.0, 0.5, 1.0), Breath(1.0, 1.6, 2.1)] - >>> arr = NamedTupleArray(breaths) - >>> arr["duration"] - array([1. , 1.1]) - """ - - namedtuple_type: type[T] - __items: np.ndarray - - def __init__( - self, - items: NonStringSeq[T] | np.ndarray | Nested[T], - namedtuple_type: type[T] | None = None, - frozen: bool = True, - ): - """Initialize a NamedTupleArray from a sequence or nested sequence of NamedTuple items. - - Args: - items: A sequence (or nested sequence) of NamedTuple instances, or a numpy ndarray containing them. - namedtuple_type: - Optional explicit type of the NamedTuple items. If not provided, it will be inferred from the first leaf - item. - frozen: If True (default), makes the underlying array immutable. - """ - if namedtuple_type is not None: - self.namedtuple_type = namedtuple_type - else: - if isinstance(items, np.ndarray) and items.size == 0: - msg = "Cannot infer type from empty array." - raise ValueError(msg) - if not isinstance(items, np.ndarray) and not items: - msg = "Cannot infer type from empty sequence." 
- raise ValueError(msg) - - # Infer NT type from first leaf element - leaf = _first_leaf(items) - self.namedtuple_type = type(leaf) # type: ignore[assignment] - - # Validate homogeneity - _check_homogeneous(items, self.namedtuple_type) - - # Build structured dtype and array with same shape - dt = _get_tuple_dtype(self.namedtuple_type) - self.__items = np.asarray(items, dtype=dt) - - if frozen: - self._freeze() - - object.__setattr__(self, "_initialized", True) - - def _freeze(self) -> None: - """Make the underlying array immutable.""" - dt = _get_tuple_dtype(self.namedtuple_type) - freeze_method = "flag" if dt.hasobject else "memoryview" - self.__items = freeze_array(self.__items, method=freeze_method) - - def __setattr__(self, name: str, value: object) -> None: - """Allow setting type and __items only during initialization; block modification after.""" - # Check if initialization is complete (use object.__getattribute__ to bypass our __getattr__) - try: - initialized = object.__getattribute__(self, "_initialized") - except AttributeError: - initialized = False - - # Allow setting type and __items only during initialization - if not initialized and name in ("namedtuple_type", "_NamedTupleArray__items"): - super().__setattr__(name, value) - elif initialized and name in ("namedtuple_type", "_NamedTupleArray__items"): - msg = f"{type(self).__name__!r} object is immutable; cannot modify {name!r} after initialization." - raise AttributeError(msg) - else: - msg = f"{type(self).__name__!r} object is immutable; cannot set attribute {name!r}." - raise AttributeError(msg) - - @classmethod - def from_array(cls, arr: np.ndarray | Nested, namedtuple_type: type[T], frozen: bool = True) -> NamedTupleArray[T]: - """Build a NamedTupleArray from an unstructured numpy array or nested list. - - The list must be convertible to a numpy array. The last axis of the array is mapped to the fields. - The length of the last axis must equal to the number of fields in the given NamedTuple type. - - Example: - This examples represents a sequence of 10 breaths for each of 32x32 pixels. Each breath contains 3 fields: - start_time, middle_time, end_time. - - ```python - breath_data = load_breath_data() # shape (10, 32, 32, 3) - breaths = NamedTupleArray.from_ndarray(breath_data, Breath) - ``` - - This is equivalent to a list of 10 nested lists, each containing 32 lists (rows) of 32 (columns) Breath - objects. - """ - if not isinstance(arr, np.ndarray): - arr = np.array(arr) - - if arr.ndim < 1: - msg = "arr must have at least 1 dimension." - raise ValueError(msg) - n_fields = len(namedtuple_type._fields) - if (lal := arr.shape[-1]) != n_fields: - msg = f"Last axis must have size {n_fields} for {namedtuple_type.__name__}, not {lal}." - raise ValueError(msg) - dt = _get_tuple_dtype(namedtuple_type) - out = np.empty(arr.shape[:-1], dtype=dt) - - if not dt.fields: - msg = "Generated dtype has no fields; cannot proceed." 
- raise RuntimeError(msg) - - fields = cast("dict[str, tuple[np.dtype, int]]", dt.fields) - for i, name in enumerate(namedtuple_type._fields): - # Cast each column to the target field dtype to avoid unintended promotion - target_dt = fields[name][0] - out[name] = arr[..., i].astype(target_dt, copy=False) - - inst = cls.__new__(cls) - inst.namedtuple_type = namedtuple_type - inst._NamedTupleArray__items = out # noqa: SLF001 - - if frozen: - inst._freeze() # noqa: SLF001 - - return inst - - @property - def shape(self) -> tuple[int, ...]: - """The shape of the NamedTupleArray.""" - return self.__items.shape - - @property - def ndim(self) -> int: - """The number of dimensions of the NamedTupleArray.""" - return self.__items.ndim - - @property - def dtype(self) -> np.dtype: - """The dtype of the underlying structured array.""" - return self.__items.dtype - - @property - def items(self) -> np.ndarray: - """The underlying NumPy structured array. - - Returns the private array. If this instance was created with frozen=True, - modifications via this reference are prevented. Otherwise, modifications - are allowed. - """ - return self.__items - - @property - def flags(self) -> flagsobj: - """The flags of the underlying structured array. - - If this instance was created with frozen=True, the WRITEABLE flag cannot - be changed. Otherwise, the flags are fully mutable. - """ - return self.__items.flags - - def to_array(self) -> np.ndarray: - """Convert to an unstructured numpy array. - - Returns a 2D array where each column corresponds to a field of the NamedTuple, - in the order of the NamedTuple fields. This allows convenient slicing by - column indices like `arr[:, [0, 2]]`. - - Returns: - A 2D unstructured numpy array of shape (n_items, n_fields). - - Example: - >>> class Point(NamedTuple): - ... x: float - ... y: float - ... z: float - >>> nta = NamedTupleArray([Point(1.0, 2.0, 3.0), Point(4.0, 5.0, 6.0)]) - >>> arr = nta.to_array() - >>> arr.shape - (2, 3) - >>> arr[:, [0, 2]] # Get x and z columns - array([[1., 3.], - [4., 6.]]) - """ - # Stack each field as a column to create unstructured array - if not self.__items.dtype.names: - # No fields, return empty array - return np.empty((self.shape[0], 0)) - - return np.column_stack([self.__items[name] for name in self.__items.dtype.names]) - - def __getattr__(self, attr: str): - """Block access to the private array. - - All array attributes should be accessed via explicit properties. - This prevents users from bypassing immutability controls. - """ - msg = f"{type(self).__name__!r} object has no attribute {attr!r}" - raise AttributeError(msg) - - def __array__(self, dtype: np.dtype | None = None) -> np.ndarray: - return self.__items.astype(dtype) if dtype is not None else self.__items - - def __iter__(self) -> Generator[T | NamedTupleArray[T], None, None]: - if self.ndim == 1: - for item in self.__items: - yield self.namedtuple_type(*item) # type: ignore[call-arg] - else: - # yield structured subarrays along axis 0 - for i in range(self.__items.shape[0]): - out = NamedTupleArray.__new__(NamedTupleArray) - out.namedtuple_type = self.namedtuple_type - out._NamedTupleArray__items = self.__items[i] # noqa: SLF001 - yield out - - def __len__(self) -> int: - return self.__items.shape[0] if self.__items.ndim > 0 else 0 - - def __repr__(self) -> str: - return f"NamedTupleArray[{self.namedtuple_type.__name__}]{repr(self.__items).removeprefix('array')}" - - def __eq__(self, other: object) -> bool: - """Compare two NamedTupleArray instances for equality. 
- - Two NamedTupleArray instances are equal if: - - They are both NamedTupleArray instances - - They have the same NamedTuple type - - Their underlying arrays are equal (including NaN equality for floats) - """ - if not isinstance(other, NamedTupleArray): - return False - - if self.namedtuple_type is not other.namedtuple_type: - return False - - # Compare shapes - if self.__items.shape != other.__items.shape: - return False - - # Compare dtypes - if self.__items.dtype != other.__items.dtype: - return False - - # For structured arrays, compare field by field to handle NaN values properly - for name in self.__items.dtype.names or []: - self_field = self.__items[name] - other_field = other.__items[name] - - # Use array_equal with equal_nan for each field - if not np.array_equal(self_field, other_field, equal_nan=True): - return False - - return True - - __hash__ = None # type: ignore[assignment] - - @overload - def __getitem__(self, index: str) -> np.ndarray: ... - - @overload - def __getitem__(self, index: int) -> T: ... - - @overload - def __getitem__(self, index: slice) -> NamedTupleArray[T]: ... - - @overload - def __getitem__(self, index: NonStringSeq) -> NamedTupleArray[T]: ... - - def __getitem__(self, index: str | int | slice | NonStringSeq) -> np.ndarray | NamedTupleArray[T] | T: - # Field-name access: return field view - if isinstance(index, str): - names = self.__items.dtype.names or () - if index in names: - view = self.__items[index] - # Ensure field view is read-only - with contextlib.suppress(Exception): - view.flags.writeable = False - return view - # Computed property or attribute on the NT → compute over all items - return self._compute_property(index) - - # NumPy-style indexing - result = self.__items[index] - - # Structured scalar (np.void) → return NamedTuple - if isinstance(result, np.void): - # For structured np.void, convert to NamedTuple - return self.namedtuple_type(*result.tolist()) # type: ignore[call-arg] - - # Zero-d structured ndarray (shape == ()) → convert to NamedTuple - if isinstance(result, np.ndarray) and result.dtype.fields is not None and result.ndim == 0: - scalar = result.item() # np.void - return self.namedtuple_type(*scalar.tolist()) # type: ignore[call-arg] - - # Structured ndarray → wrap - if isinstance(result, np.ndarray) and result.dtype.fields is not None: - out: NamedTupleArray[T] = type(self).__new__(type(self)) - out.namedtuple_type = self.namedtuple_type - out._NamedTupleArray__items = result - return out - - # Non-structured ndarray (e.g. field slice) → return as-is - return result - - def _compute_property(self, attr: str) -> np.ndarray: - """Compute a property or attribute across all items, preserving the array shape.""" - # Verify attribute exists on the NT instance - sample = self.namedtuple_type(*self.__items.flat[0].tolist()) # type: ignore[call-arg] - if not hasattr(sample, attr): - msg = f"Field or property '{attr}' not found in NamedTuple." 
- raise KeyError(msg) - - # Collect values (single pass using flat indexing) - out_obj = np.empty(self.shape, dtype=object) - for i, rec in enumerate(self.__items.reshape(-1)): - nt = self.namedtuple_type(*rec.tolist()) # type: ignore[call-arg] - out_obj.reshape(-1)[i] = getattr(nt, attr) - - # Determine target dtype from property annotation if available (handles postponed annotations) - target_dtype: np.dtype | None = None - attr_member = getattr(self.namedtuple_type, attr, None) - if isinstance(attr_member, property) and attr_member.fget is not None: - with contextlib.suppress(Exception): - hints = get_type_hints(attr_member.fget) - ret_ann = hints.get("return") - if ret_ann is not None: - target = _python_to_np_dtype(ret_ann) - target_dtype = np.dtype(target) - - # Cast accordingly - if target_dtype is not None and target_dtype != np.dtype(object): - with contextlib.suppress(Exception): - return out_obj.astype(target_dtype) - - # Heuristics: ints → int64; floats → float64; numpy scalar families respected - with contextlib.suppress(Exception): - if all(isinstance(v, (int, np.integer)) for v in out_obj.flat): - return out_obj.astype(np.int64) - with contextlib.suppress(Exception): - if all(isinstance(v, (float, np.floating)) for v in out_obj.flat): - return out_obj.astype(np.float64) - - return out_obj - - -def _first_leaf( - seq: NamedTuple | np.ndarray | list[NamedTuple] | tuple[NamedTuple, ...] | Nested[NamedTuple], -) -> NamedTuple: - """Recursively find the first NamedTuple instance in a nested sequence or ndarray.""" - if _is_namedtuple_instance(seq): - return seq - if isinstance(seq, np.ndarray): - if seq.size == 0: - msg = "Cannot infer type from empty ndarray." - raise ValueError(msg) - return _first_leaf(seq.flat[0]) - if isinstance(seq, (list, tuple)): - if not seq: - msg = "Cannot infer type from empty nested sequence." - raise ValueError(msg) - return _first_leaf(seq[0]) - - msg = "Items must be NamedTuple or nested sequences thereof." - raise TypeError(msg) - - -def _check_homogeneous(seq: SequenceType[NamedTuple] | np.ndarray | Nested[T], typ: type[T]) -> None: - """Recursively check that all NamedTuple instances in the nested sequence/ndarray are of the given type.""" - if isinstance(seq, np.ndarray): - for it in seq.flat: - _check_homogeneous(it, typ) - return - if isinstance(seq, (list, tuple)) and not _is_namedtuple_instance(seq): - seq_ = cast("SequenceType[NamedTuple | Nested[T]]", seq) - for it in seq_: - it: NamedTuple | Nested[T] - _check_homogeneous(it, typ) - return - if _is_namedtuple_instance(seq): - if type(seq) is not typ: - msg = "All items must be of the same NamedTuple type." - raise ValueError(msg) - return - msg = "Items must be NamedTuple or nested sequences thereof." - raise TypeError(msg) - - -def _python_to_np_dtype(py: type) -> np.dtype | str: - """Map basic Python types to NumPy dtypes.""" - if py is int: - return "i8" - if py is float: - return "f8" - if py is bool: - return "?" 
- if py is str: - return np.dtype(object) # keep Python str; avoids truncation - - if get_origin(py) is Union: - args = [a for a in get_args(py) if a is not type(None)] - if len(args) == 1: - return _python_to_np_dtype(args[0]) - return np.dtype(object) - - -def _is_namedtuple_instance(item: object) -> TypeGuard[NamedTuple]: - """Check if item is a NamedTuple instance.""" - return isinstance(item, tuple) and hasattr(item, "_fields") - - -def _is_namedtuple_type(item: object) -> TypeGuard[type[NamedTuple]]: - """Check if item is a NamedTuple type.""" - return isinstance(item, type) and issubclass(item, tuple) and hasattr(item, "_fields") - - -def _get_tuple_dtype(item: NamedTuple | type[tuple]) -> np.dtype: - """Generate a NumPy structured dtype from a NamedTuple type.""" - if _is_namedtuple_instance(item): - item = type(item) - if not _is_namedtuple_type(item): - msg = "item must be a NamedTuple or a NamedTuple type." - raise TypeError(msg) - - hints = get_type_hints(item, include_extras=False) - names_in_order = list(item._fields) - - def to_np_dtype(py: type) -> str | np.dtype: - if py is int: - return "i8" - if py is float: - return "f8" - if py is bool: - return "?" - # everything else (incl. str) → object to avoid truncation - return np.dtype(object) - - fields = [(name, to_np_dtype(hints.get(name, object))) for name in names_in_order] - return np.dtype(fields) diff --git a/eitprocessing/datahandling/structured_array.py b/eitprocessing/datahandling/structured_array.py new file mode 100644 index 000000000..defb47a92 --- /dev/null +++ b/eitprocessing/datahandling/structured_array.py @@ -0,0 +1,993 @@ +# %% +"""Array-like interface over NamedTuple and dataclass collections enabling NumPy slicing. + +Motivation +---------- +Some data is best represented with multiple related data points (e.g., the start, middle and end time of a breath). They +can be represented as NamedTuple or dataclass instances; lightweight containers that group related fields together. +Lists (or tuples) of such instances are more difficult to handle efficiently compared to NumPy arrays, especially in +multi-dimensional cases. When such instances are collected inside NumPy arrays, however, they lose their context, +removing the field names and data types. + +This module provides StructuredArray, a container that wraps homogeneous collections of NamedTuple or dataclass +instances into a NumPy structured array, preserving field names and types while enabling NumPy-style slicing and +field-wise access. It allows access to fields and even computed properties by name. Dataclass instances benefit from +automatic __post_init__ validation on access. + +Key features +------------ +- Homogeneous type checking: ensures all items share the same NamedTuple or dataclass type. +- Safe field views: returns read-only views for direct field access. +- Property evaluation: computes per-item properties, resolving postponed + annotations to pick appropriate NumPy dtypes. +- Shape preservation: supports nested sequences, maintaining their shape in the + structured array. +- Validation support: dataclass __post_init__ is called on item access. +- Interop: from_array helper to map array columns to fields. 
+ +Example with NamedTuple: +```python +from typing import NamedTuple + +class Coordinate(NamedTuple): + x: float + y: float + z: float + + @property + def r(self) -> float: + \"\"\"The radial distance from the origin.\"\"\" + return (self.x**2 + self.y**2 + self.z**2) ** 0.5 + +coords = [Coordinate(1.0, 2.0, 2.0), Coordinate(3.0, 4.0, 0.0), Coordinate(0.0, 0.0, 5.0)] +arr = StructuredArray(coords) + +arr[0] # Access a single Coordinate +# Coordinate(x=1.0, y=2.0, z=2.0) +arr["x"] # Access x field across all Coordinates +# array([1., 3., 0.]) +arr["r"] # Access computed property across all Coordinates +# array([3., 5., 5.]) +``` + +Example with dataclass: +```python +from dataclasses import dataclass + +@dataclass +class Point: + x: float + y: float + + def __post_init__(self): + if self.x < 0 or self.y < 0: + raise ValueError("Coordinates must be non-negative") + +points = [Point(1.0, 2.0), Point(3.0, 4.0)] +arr = StructuredArray(points) # Validation runs on item construction/access +arr[0] # Point(x=1.0, y=2.0) +``` +""" + +from __future__ import annotations + +import contextlib +import warnings +from dataclasses import fields as dc_fields +from dataclasses import is_dataclass +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Generic, + NamedTuple, + TypeAlias, + TypeGuard, + TypeVar, + Union, + cast, + get_args, + get_origin, + get_type_hints, + overload, +) + +import numpy as np + +from eitprocessing.utils.frozen_array import freeze_array + +if TYPE_CHECKING: + from collections.abc import Generator + from collections.abc import Sequence as SequenceType + + from numpy._core.multiarray import flagsobj + + +T = TypeVar("T", bound=tuple | object) # NamedTuple or dataclass instance +NonStringSeq: TypeAlias = tuple[T, ...] | list[T] +Nested = T | NonStringSeq + +# Mutable types that should trigger warnings when used in frozen dataclasses +MUTABLE_TYPES: tuple[type, ...] = (list, dict, set, bytearray, np.ndarray) + +# Immutable types that are allowed in dataclass fields +# Note: np.generic is included to allow all numpy scalar types (np.int64, np.float32, etc.) +ALLOWED_IMMUTABLE_TYPES: tuple[type, ...] = (str, int, float, bool, bytes, tuple, frozenset, Path, np.generic) + +# All allowed types (used for disallowed type checking) +ALLOWED_TYPES: tuple[type, ...] = ALLOWED_IMMUTABLE_TYPES + MUTABLE_TYPES + + +class StructuredArray(Generic[T]): + """An array-like container for homogeneous NamedTuple or dataclass instances. + + Overview + -------- + StructuredArray wraps a sequence (or nested sequence) of NamedTuple or dataclass items + into a NumPy structured ndarray, enabling: + - NumPy-style indexing and slicing (preserving shape). + - Field access by name that returns read-only NumPy views. + - Computation of per-item properties or attributes, returning a NumPy array + with dtype inferred from the property's type annotation when available. + - Immutability control: pass frozen=True to prevent all array modifications. + - Validation support: dataclass __post_init__ runs on item access. + + Construction + ------------ + - From a sequence: StructuredArray([item1, item2, ...]) + Validates that all items are of the same NamedTuple or dataclass type. + Type is inferred from the first item. For empty sequences, provide item_type explicitly. + - From a sequence with explicit type: StructuredArray([...], item_type=MyType) + Allows type specification upfront, enabling empty sequences and ensuring type safety. 
+ - From an ndarray: StructuredArray.from_array(arr, ItemType) + 'arr' must have last axis equal to the number of fields in ItemType. + Columns along the last axis are mapped to item fields. + - Use StructuredArray(..., frozen=True) or StructuredArray.from_array(..., frozen=True) to make the array + immutable (prevents all modifications). + + Access + ------ + - Field by name: arr["x"] -> returns a read-only view of the field. + - Property/attribute: arr["duration"] -> computes the property for each + item, producing an ndarray. If the property has a type hint, the result + dtype is chosen accordingly (e.g., int -> int64, float -> float64). + Otherwise, a heuristic casts ints to int64, else tries float, else object. + - Direct array access via items property: arr.items (always returns the + underlying NumPy array; if frozen=True, modifications are prevented). + + Notes: + ----- + - Homogeneity: All elements must be the same NamedTuple or dataclass type. + - String fields are kept as object dtype to avoid truncation. + - Properties are evaluated per element; heavy properties may be costly. + - Field views are always read-only to prevent accidental mutation. + - The .items property returns the underlying array; modifications are only + prevented if frozen=True was passed during construction. + - For dataclasses: __post_init__ is called when accessing items via iteration + or indexing to ensure validation happens. + + Example with NamedTuple: + ------- + >>> from typing import NamedTuple + >>> class Breath(NamedTuple): + ... start_time: float + ... middle_time: float + ... end_time: float + ... @property + ... def duration(self) -> float: + ... return self.end_time - self.start_time + ... + >>> breaths = [Breath(0.0, 0.5, 1.0), Breath(1.0, 1.6, 2.1)] + >>> arr = StructuredArray(breaths) + >>> arr["duration"] + array([1. , 1.1]) + + Example with dataclass: + ------- + >>> from dataclasses import dataclass + >>> @dataclass + ... class Event: + ... time: float + ... value: float + ... + ... def __post_init__(self): + ... if self.value < 0: + ... raise ValueError("value must be positive") + ... + >>> events = [Event(0.0, 1.0), Event(1.0, 2.0)] + >>> arr = StructuredArray(events) + >>> arr[0] # Runs __post_init__ validation + Event(time=0.0, value=1.0) + + Example with explicit type for empty list: + ------- + >>> # Create an empty StructuredArray with explicit type + >>> arr = StructuredArray([], item_type=Event) + >>> len(arr) + 0 + """ + + item_type: type[T] + _items: np.ndarray + _is_dataclass: bool + frozen: bool + + def __init__( + self, + items: NonStringSeq[T] | np.ndarray | Nested[T], + item_type: type[T] | None = None, + frozen: bool = True, + ): + """Initialize a StructuredArray from a sequence or nested sequence of items. + + Args: + items: A sequence (or nested sequence) of NamedTuple or dataclass instances, + or a numpy ndarray containing them. + item_type: + Optional explicit type of the items. If provided, all items are validated + against this type. If not provided, the type is inferred from the first + leaf item. Required when items is empty. + frozen: If True (default), makes the underlying array immutable. + + Raises: + ValueError: If items is empty and item_type is None, or if homogeneity check fails. + TypeError: If items contain unsupported types. + """ + if item_type is not None: + self.item_type = item_type + else: + if isinstance(items, np.ndarray) and items.size == 0: + msg = "Cannot infer type from empty array. Provide item_type explicitly." 
+ raise ValueError(msg) + if not isinstance(items, np.ndarray) and not items: + msg = "Cannot infer type from empty sequence. Provide item_type explicitly." + raise ValueError(msg) + + # Infer type from first leaf element + leaf = _first_leaf(items) + self.item_type = type(leaf) # type: ignore[assignment] + + # Validate homogeneity + _check_homogeneous(items, self.item_type) + + # Determine if item_type is a dataclass + self._is_dataclass = is_dataclass(self.item_type) + + # Build structured dtype and array with same shape + dt = _get_struct_dtype(self.item_type) + + # For dataclasses, numpy.asarray doesn't know how to convert them directly, + # so we convert to tuples first (via named tuples or tuples of field values) + if self._is_dataclass: + items = cast("NonStringSeq[T] | np.ndarray | Nested[T]", _dataclass_items_to_tuples(items, self.item_type)) + + self._items = np.asarray(items, dtype=dt) + + # Check for disallowed field types in dataclass + if self._is_dataclass: + disallowed_fields = _get_disallowed_field_names(self.item_type) + if disallowed_fields: + warnings.warn( + f"Dataclass '{self.item_type.__name__}' has fields with disallowed types: " + f"{', '.join(disallowed_fields)}. " + f"Allowed types are: str, int, float, bool, bytes, tuple, frozenset, pathlib.Path, " + f"numpy types, frozen dataclasses, and NamedTuples.", + UserWarning, + stacklevel=2, + ) + + self.frozen = frozen + if frozen: + # Check if dataclass is not frozen while array is + if self._is_dataclass and not _is_dataclass_frozen(self.item_type): + warnings.warn( + f"StructuredArray is frozen, but the underlying dataclass " + f"'{self.item_type.__name__}' is not. Items accessed from the array " + f"can still be modified. Consider freezing the dataclass with " + f"@dataclass(frozen=True).", + UserWarning, + stacklevel=2, + ) + # Check if dataclass has mutable fields (independent check - can warn even if dataclass is frozen) + if self._is_dataclass and _has_mutable_fields(self.item_type): + warnings.warn( + f"StructuredArray is frozen, but the dataclass '{self.item_type.__name__}' has mutable fields " + f"(list, dict, set, etc.). The contents of these fields can still be modified. " + f"Consider using immutable types (tuple) instead.", + UserWarning, + stacklevel=2, + ) + self._freeze() + + object.__setattr__(self, "_initialized", True) + + def _freeze(self) -> None: + """Make the underlying array immutable.""" + dt = _get_struct_dtype(self.item_type) + freeze_method = "flag" if dt.hasobject else "memoryview" + self._items = freeze_array(self._items, method=freeze_method) + + def __setattr__(self, name: str, value: object) -> None: + """Allow setting type and _items only during initialization; block modification after.""" + # Check if initialization is complete (use object.__getattribute__ to bypass our __getattr__) + try: + initialized = object.__getattribute__(self, "_initialized") + except AttributeError: + initialized = False + + # Allow setting type, _items, and _is_dataclass only during initialization + if not initialized and name in ("item_type", "_items", "_is_dataclass", "frozen"): + super().__setattr__(name, value) + elif initialized and name in ("item_type", "_items", "_is_dataclass", "frozen"): + msg = f"{type(self).__name__!r} object is immutable; cannot modify {name!r} after initialization." + raise AttributeError(msg) + else: + msg = f"{type(self).__name__!r} object is immutable; cannot set attribute {name!r}." 
+            raise AttributeError(msg)
+
+    @classmethod
+    def from_array(cls, arr: np.ndarray | Nested, item_type: type[T], frozen: bool = True) -> StructuredArray[T]:  # noqa: C901
+        """Build a StructuredArray from an unstructured numpy array or nested list.
+
+        The list must be convertible to a numpy array. The last axis of the array is mapped to the fields.
+        The length of the last axis must equal the number of fields in the given type.
+
+        Example:
+            This example represents a sequence of 10 breaths for each of 32x32 pixels. Each breath contains 3 fields:
+            start_time, middle_time, end_time.
+
+            ```python
+            breath_data = load_breath_data()  # shape (10, 32, 32, 3)
+            breaths = StructuredArray.from_array(breath_data, Breath)
+            ```
+
+            This is equivalent to a list of 10 nested lists, each containing 32 lists (rows) of 32 (columns) items.
+        """
+        if not isinstance(arr, np.ndarray):
+            arr = np.array(arr)
+
+        if arr.ndim < 1:
+            msg = "array must have at least 1 dimension."
+            raise ValueError(msg)
+
+        n_fields = _get_n_fields(item_type)
+        if (lal := arr.shape[-1]) != n_fields:
+            msg = f"Last axis must have size {n_fields} for {item_type.__name__}, not {lal}."
+            raise ValueError(msg)
+
+        dt = _get_struct_dtype(item_type)
+        out = np.empty(arr.shape[:-1], dtype=dt)
+
+        if not dt.fields:
+            msg = "Generated dtype has no fields; cannot proceed."
+            raise RuntimeError(msg)
+
+        fields = cast("dict[str, tuple[np.dtype, int]]", dt.fields)
+        field_names = _get_field_names(item_type)
+        for i, name in enumerate(field_names):
+            # Cast each column to the target field dtype to avoid unintended promotion
+            target_dt = fields[name][0]
+            out[name] = arr[..., i].astype(target_dt, copy=False)
+
+        inst = cls.__new__(cls)
+        inst.item_type = item_type
+        inst._is_dataclass = is_dataclass(item_type)  # noqa: SLF001
+        inst._items = out  # noqa: SLF001
+
+        if inst._is_dataclass and "__post_init__" in dir(inst.item_type):  # noqa: SLF001
+            # Run __post_init__ for all items to ensure validation
+            for record in inst._items.flat:  # noqa: SLF001
+                inst._reconstruct_item(record)  # noqa: SLF001
+
+        # Check for disallowed field types in dataclass
+        if inst._is_dataclass:  # noqa: SLF001
+            disallowed_fields = _get_disallowed_field_names(item_type)
+            if disallowed_fields:
+                warnings.warn(
+                    f"Dataclass '{item_type.__name__}' has fields with disallowed types: "
+                    f"{', '.join(disallowed_fields)}. "
+                    "Allowed types are: str, int, float, bool, bytes, tuple, frozenset, pathlib.Path, "
+                    "numpy types, frozen dataclasses, and NamedTuples.",
+                    UserWarning,
+                    stacklevel=2,
+                )
+
+        inst.frozen = frozen
+        if frozen:
+            # Check if dataclass is not frozen while array is
+            if inst._is_dataclass and not _is_dataclass_frozen(item_type):  # noqa: SLF001
+                warnings.warn(
+                    "StructuredArray is frozen, but the underlying dataclass "
+                    f"'{item_type.__name__}' is not. Consider freezing the dataclass with "
+                    "@dataclass(frozen=True).",
+                    UserWarning,
+                    stacklevel=2,
+                )
+            # Check if dataclass has mutable fields (independent check - can warn even if dataclass is frozen)
+            if inst._is_dataclass and _has_mutable_fields(item_type):  # noqa: SLF001
+                warnings.warn(
+                    f"StructuredArray is frozen, but the dataclass '{item_type.__name__}' has mutable fields "
+                    f"(list, dict, set, etc.). The contents of these fields can still be modified. "
+                    f"Consider using immutable types (tuple) instead.",
+                    UserWarning,
+                    stacklevel=2,
+                )
+            inst._freeze()  # noqa: SLF001
+
+        object.__setattr__(inst, "_initialized", True)
+
+        return inst
+
+    @property
+    def shape(self) -> tuple[int, ...]:
+        """The shape of the StructuredArray."""
+        return self._items.shape
+
+    @property
+    def ndim(self) -> int:
+        """The number of dimensions of the StructuredArray."""
+        return self._items.ndim
+
+    @property
+    def dtype(self) -> np.dtype:
+        """The dtype of the underlying structured array."""
+        return self._items.dtype
+
+    @property
+    def items(self) -> np.ndarray:
+        """The underlying NumPy structured array.
+
+        Returns the private array. If this instance was created with frozen=True,
+        modifications via this reference are prevented. Otherwise, modifications
+        are allowed.
+        """
+        return self._items
+
+    @property
+    def flags(self) -> flagsobj:
+        """The flags of the underlying structured array.
+
+        If this instance was created with frozen=True, the WRITEABLE flag cannot
+        be changed. Otherwise, the flags are fully mutable.
+        """
+        return self._items.flags
+
+    def to_array(self) -> np.ndarray:
+        """Convert to an unstructured numpy array.
+
+        Returns a 2D array where each column corresponds to a field of the item type,
+        in field order. This allows convenient slicing by column indices like
+        `arr[:, [0, 2]]`.
+
+        Returns:
+            A 2D unstructured numpy array of shape (n_items, n_fields).
+
+        Example:
+            >>> from typing import NamedTuple
+            >>> class Point(NamedTuple):
+            ...     x: float
+            ...     y: float
+            ...     z: float
+            >>> arr = StructuredArray([Point(1.0, 2.0, 3.0), Point(4.0, 5.0, 6.0)])
+            >>> arr_2d = arr.to_array()
+            >>> arr_2d.shape
+            (2, 3)
+            >>> arr_2d[:, [0, 2]]  # Get x and z columns
+            array([[1., 3.],
+                   [4., 6.]])
+        """
+        # Stack each field as a column to create unstructured array
+        if not self._items.dtype.names:
+            # No fields, return empty array
+            return np.empty((self.shape[0], 0))
+
+        return np.column_stack([self._items[name] for name in self._items.dtype.names])
+
+    def __array__(self, dtype: np.dtype | None = None) -> np.ndarray:
+        return self._items.astype(dtype) if dtype is not None else self._items
+
+    def __iter__(self) -> Generator[T | StructuredArray[T], None, None]:
+        if self.ndim == 1:
+            for item in self._items:
+                yield self._reconstruct_item(item)
+        else:
+            # yield structured subarrays along axis 0
+            for i in range(self._items.shape[0]):
+                out = StructuredArray.__new__(StructuredArray)
+                out.item_type = self.item_type
+                out._is_dataclass = self._is_dataclass  # noqa: SLF001
+                out._items = self._items[i]  # noqa: SLF001
+                out.frozen = self.frozen
+                object.__setattr__(out, "_initialized", True)
+                yield out
+
+    def __len__(self) -> int:
+        return self._items.shape[0] if self._items.ndim > 0 else 0
+
+    def __repr__(self) -> str:
+        return f"StructuredArray[{self.item_type.__name__}]{repr(self._items).removeprefix('array')}"
+
+    def __eq__(self, other: object) -> bool:
+        """Compare two StructuredArray instances for equality.
+ + Two StructuredArray instances are equal if: + - They are both StructuredArray instances + - They have the same item type + - Their underlying arrays are equal (including NaN equality for floats) + """ + if not isinstance(other, StructuredArray): + return False + + if self.item_type is not other.item_type: + return False + + # Compare shapes + if self._items.shape != other._items.shape: + return False + + # Compare dtypes + if self._items.dtype != other._items.dtype: + return False + + # For structured arrays, compare field by field to handle NaN values properly + for name in self._items.dtype.names or []: + self_field = self._items[name] + other_field = other._items[name] + + # Use array_equal with equal_nan for each field + if not np.array_equal(self_field, other_field, equal_nan=True): + return False + + return True + + __hash__ = None # type: ignore[assignment] + + def __add__(self, other: StructuredArray[T]) -> StructuredArray[T]: + if not isinstance(other, StructuredArray): + msg = f"Can only concatenate StructuredArray (not '{type(other).__name__}') to StructuredArray." + raise TypeError(msg) + + if self.item_type is not other.item_type: + msg = "Cannot concatenate StructuredArray with different item types." + raise TypeError(msg) + + new_items = np.concatenate((self._items, other._items), axis=0) + frozen = self.frozen or other.frozen + + # Create a new StructuredArray directly with the concatenated structured array + inst = self.__class__.__new__(self.__class__) + inst.item_type = self.item_type + inst._is_dataclass = self._is_dataclass + inst._items = new_items + inst.frozen = frozen + + if frozen: + inst._freeze() + + object.__setattr__(inst, "_initialized", True) + return inst + + @overload + def __getitem__(self, index: str) -> np.ndarray: ... + + @overload + def __getitem__(self, index: int) -> T: ... + + @overload + def __getitem__(self, index: slice) -> StructuredArray[T]: ... + + @overload + def __getitem__(self, index: NonStringSeq) -> StructuredArray[T]: ... + + def __getitem__(self, index: str | int | slice | NonStringSeq) -> np.ndarray | StructuredArray[T] | T: + # Field-name access: return field view + if isinstance(index, str): + names = self._items.dtype.names or () + if index in names: + view = self._items[index] + # Ensure field view is read-only + with contextlib.suppress(Exception): + view.flags.writeable = False + return view + # Computed property or attribute on the item type → compute over all items + return self._compute_property(index) + + # NumPy-style indexing + result = self._items[index] + + # Structured scalar (np.void) → return reconstructed item + if isinstance(result, np.void): + return self._reconstruct_item(result) + + # Zero-d structured ndarray (shape == ()) → convert to item + if isinstance(result, np.ndarray) and result.dtype.fields is not None and result.ndim == 0: + scalar = result.item() # np.void + return self._reconstruct_item(scalar) + + # Structured ndarray → wrap + if isinstance(result, np.ndarray) and result.dtype.fields is not None: + out: StructuredArray[T] = type(self).__new__(type(self)) + out.item_type = self.item_type + out._is_dataclass = self._is_dataclass + out._items = result + out.frozen = self.frozen + object.__setattr__(out, "_initialized", True) + return out + + # Non-structured ndarray (e.g. field slice) → return as-is + return result + + def _reconstruct_item(self, record: np.void) -> T: + """Reconstruct an item (NamedTuple or dataclass) from a numpy void record. 
+ + For dataclasses, this calls __post_init__ to ensure validation runs. + """ + values = record.tolist() + + return self.item_type(*values) + + def _compute_property(self, attr: str) -> np.ndarray: + """Compute a property or attribute across all items, preserving the array shape.""" + # Verify attribute exists on a sample item + sample_record = self._items.flat[0] + sample = self._reconstruct_item(sample_record) + if not hasattr(sample, attr): + msg = f"Field or property '{attr}' not found in {self.item_type.__name__}." + raise KeyError(msg) + + # Collect values (single pass using flat indexing) + out_obj = np.empty(self.shape, dtype=object) + for i, rec in enumerate(self._items.reshape(-1)): + item = self._reconstruct_item(rec) + out_obj.reshape(-1)[i] = getattr(item, attr) + + # Determine target dtype from property annotation if available (handles postponed annotations) + target_dtype: np.dtype | None = None + attr_member = getattr(self.item_type, attr, None) + if isinstance(attr_member, property) and attr_member.fget is not None: + with contextlib.suppress(Exception): + hints = get_type_hints(attr_member.fget) + ret_ann = hints.get("return") + if ret_ann is not None: + target = _python_to_np_dtype(ret_ann) + target_dtype = np.dtype(target) + + # Cast accordingly + if target_dtype is not None and target_dtype != np.dtype(object): + with contextlib.suppress(Exception): + return out_obj.astype(target_dtype) + + # Heuristics: ints → int64; floats → float64; numpy scalar families respected + with contextlib.suppress(Exception): + if all(isinstance(v, (int, np.integer)) for v in out_obj.flat): + return out_obj.astype(np.int64) + with contextlib.suppress(Exception): + if all(isinstance(v, (float, np.floating)) for v in out_obj.flat): + return out_obj.astype(np.float64) + + return out_obj + + +def _first_leaf( + seq: NamedTuple | np.ndarray | list[NamedTuple] | tuple[NamedTuple, ...] | Nested[NamedTuple], +) -> object: + """Recursively find the first NamedTuple or dataclass instance in a nested sequence or ndarray.""" + if _is_struct_instance(seq): + return seq + if isinstance(seq, np.ndarray): + if seq.size == 0: + msg = "Cannot infer type from empty ndarray." + raise ValueError(msg) + return _first_leaf(seq.flat[0]) + if isinstance(seq, (list, tuple)): + if not seq: + msg = "Cannot infer type from empty nested sequence." + raise ValueError(msg) + return _first_leaf(seq[0]) + + msg = "Items must be NamedTuple or dataclass or nested sequences thereof." + raise TypeError(msg) + + +def _check_homogeneous(seq: SequenceType | np.ndarray | Nested[T], typ: type[T]) -> None: + """Recursively check that all items in the nested sequence/ndarray are of the given type.""" + if isinstance(seq, np.ndarray): + for it in seq.flat: + _check_homogeneous(it, typ) + return + if isinstance(seq, (list, tuple)) and not _is_struct_instance(seq): + seq_ = cast("SequenceType", seq) + for it in seq_: + _check_homogeneous(it, typ) + return + if _is_struct_instance(seq): + if type(seq) is not typ: + msg = f"All items must be of the same type ({typ.__name__}), got {type(seq).__name__}." + raise ValueError(msg) + return + msg = "Items must be NamedTuple, dataclass, or nested sequences thereof." + raise TypeError(msg) + + +def _python_to_np_dtype(py: type) -> np.dtype | str: + """Map basic Python types to NumPy dtypes.""" + if py is int: + return "i8" + if py is float: + return "f8" + if py is bool: + return "?" 
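+    # str maps to object dtype (no truncation); Optional[X] unwraps to X; anything unrecognized falls back to object.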
+    if py is str:
+        return np.dtype(object)  # keep Python str; avoids truncation
+
+    if get_origin(py) is Union:
+        args = [a for a in get_args(py) if a is not type(None)]
+        if len(args) == 1:
+            return _python_to_np_dtype(args[0])
+        return np.dtype(object)
+
+    # An implicit None return would be misread by np.dtype(...) as float64; fall back to object explicitly.
+    return np.dtype(object)
+
+
+def _is_namedtuple_instance(item: object) -> TypeGuard[NamedTuple]:
+    """Check if item is a NamedTuple instance."""
+    return isinstance(item, tuple) and hasattr(item, "_fields") and not is_dataclass(item)
+
+
+def _dataclass_items_to_tuples(items: object, dc_type: type) -> object:
+    """Recursively convert dataclass items to tuples for numpy.asarray conversion.
+
+    numpy.asarray doesn't know how to convert arbitrary dataclass instances
+    to structured array rows, so we convert them to tuples first.
+    """
+    if isinstance(items, np.ndarray):
+        # For ndarrays, convert each element
+        result = np.array([_dataclass_items_to_tuples(item, dc_type) for item in items.flat])
+        return result.reshape(items.shape)
+
+    if isinstance(items, (list, tuple)) and not is_dataclass(items):
+        # It's a sequence - recursively convert each element
+        return [_dataclass_items_to_tuples(item, dc_type) for item in items]
+
+    # It's a single dataclass instance - convert to tuple of field values
+    if is_dataclass(items):
+        field_names = _get_field_names(type(items))
+        return tuple(getattr(items, name) for name in field_names)
+
+    # Shouldn't reach here, but just in case
+    return items
+
+
+def _is_namedtuple_type(item: object) -> TypeGuard[type[NamedTuple]]:
+    """Check if item is a NamedTuple type."""
+    return isinstance(item, type) and issubclass(item, tuple) and hasattr(item, "_fields") and not is_dataclass(item)
+
+
+def _is_struct_instance(item: object) -> bool:
+    """Check if item is a NamedTuple or dataclass instance (not type)."""
+    if isinstance(item, type):
+        # Exclude types themselves
+        return False
+    return _is_namedtuple_instance(item) or is_dataclass(item)
+
+
+def _is_struct_type(item: object) -> bool:
+    """Check if item is a NamedTuple or dataclass type."""
+    return _is_namedtuple_type(item) or (isinstance(item, type) and is_dataclass(item))
+
+
+def _get_field_names(item_type: type) -> list[str]:
+    """Get field names from a NamedTuple or dataclass type."""
+    if is_dataclass(item_type):
+        return [f.name for f in dc_fields(item_type)]
+    if _is_namedtuple_type(item_type):
+        return list(item_type._fields)  # type: ignore[return-value]
+    msg = f"item_type must be a NamedTuple or dataclass type, got {item_type}"
+    raise TypeError(msg)
+
+
+def _get_n_fields(item_type: type) -> int:
+    """Get number of fields from a NamedTuple or dataclass type."""
+    return len(_get_field_names(item_type))
+
+
+def _get_struct_dtype(item: object) -> np.dtype:
+    """Generate a NumPy structured dtype from a NamedTuple or dataclass type."""
+    if _is_struct_instance(item):
+        item = type(item)
+    if not _is_struct_type(item):
+        msg = "item must be a NamedTuple instance, NamedTuple type, dataclass instance, or dataclass type."
+ raise TypeError(msg) + + if is_dataclass(item): + return _get_dataclass_dtype(item) # type: ignore[arg-type] + return _get_namedtuple_dtype(item) # type: ignore[arg-type] + + +def _get_namedtuple_dtype(nt_type: type[NamedTuple]) -> np.dtype: + """Generate a NumPy structured dtype from a NamedTuple type.""" + hints = get_type_hints(nt_type, include_extras=False) + names_in_order = list(nt_type._fields) # type: ignore[attr-defined] + + def to_np_dtype(py: type) -> str | np.dtype: + if py is int: + return "i8" + if py is float: + return "f8" + if py is bool: + return "?" + # everything else (incl. str) → object to avoid truncation + return np.dtype(object) + + fields = [(name, to_np_dtype(hints.get(name, object))) for name in names_in_order] + return np.dtype(fields) + + +def _get_dataclass_dtype(dc_type: type) -> np.dtype: + """Generate a NumPy structured dtype from a dataclass type.""" + try: + hints = get_type_hints(dc_type, include_extras=False) + except (NameError, TypeError, AttributeError): + # If we can't resolve type hints (e.g., for types defined in test scopes), + # fall back to empty hints + hints = {} + + def to_np_dtype(py: type) -> str | np.dtype: + if py is int: + return "i8" + if py is float: + return "f8" + if py is bool: + return "?" + # everything else (incl. str) → object to avoid truncation + return np.dtype(object) + + fields = [(f.name, to_np_dtype(hints.get(f.name, object))) for f in dc_fields(dc_type)] + return np.dtype(fields) + + +def _is_dataclass_frozen(dc_type: type) -> bool: + """Check if a dataclass type is frozen. + + Args: + dc_type: A dataclass type. + + Returns: + True if the dataclass is frozen, False otherwise. + """ + if not is_dataclass(dc_type): + return False + + # Access the __dataclass_params__ which contains the frozen flag + if hasattr(dc_type, "__dataclass_params__"): + return dc_type.__dataclass_params__.frozen # type: ignore[attr-defined] + + return False + + +def _has_mutable_fields(dc_type: type) -> bool: + """Check if a dataclass has any mutable field types. + + A frozen dataclass with mutable fields like list, dict, or set can still have + its contents modified even though the dataclass instance itself is frozen. + + Args: + dc_type: A dataclass type. + + Returns: + True if the dataclass has any mutable field types, False otherwise. + """ + if not is_dataclass(dc_type): + return False + + try: + hints = get_type_hints(dc_type, include_extras=False) + except (NameError, TypeError, AttributeError): + # If we can't resolve type hints, skip the check + return False + + mutable_types = MUTABLE_TYPES + + for field_type in hints.values(): + # Get the origin type (e.g., list from list[int]) + origin = get_origin(field_type) + + # Check if directly mutable or is a subclass of mutable types + if _is_subclass_of_allowed(field_type, mutable_types) or ( + origin is not None and _is_subclass_of_allowed(origin, mutable_types) + ): + return True + + # Check if it's a Union with mutable types + if origin is Union: + args = get_args(field_type) + for arg in args: + arg_origin = get_origin(arg) + if _is_subclass_of_allowed(arg, mutable_types) or ( + arg_origin is not None and _is_subclass_of_allowed(arg_origin, mutable_types) + ): + return True + + return False + + +def _get_disallowed_field_names(dc_type: type) -> list[str]: + """Get field names that have disallowed types in a dataclass. + + Allowed types include: str, int, float, bool, bytes, tuple, frozenset, pathlib.Path, + numpy scalars/arrays, frozen dataclasses, and NamedTuples. 
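+
+    Note that mutable container types (list, dict, set, bytearray, numpy arrays) count as allowed here; fields of
+    those types are flagged separately by the mutable-fields warning (see _has_mutable_fields).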
+
+    Args:
+        dc_type: A dataclass type.
+
+    Returns:
+        List of field names with disallowed types, empty if all types are allowed.
+    """
+    if not is_dataclass(dc_type):
+        return []
+
+    try:
+        hints = get_type_hints(dc_type, include_extras=False)
+    except (NameError, TypeError, AttributeError):
+        # If we can't resolve type hints, skip the check
+        return []
+
+    disallowed_fields = []
+
+    # All allowed types (immutable and mutable)
+    allowed = ALLOWED_TYPES
+
+    for field_name, field_type in hints.items():
+        origin = get_origin(field_type)
+
+        # Check if directly allowed or is a subclass of allowed types
+        if _is_subclass_of_allowed(field_type, allowed):
+            continue
+
+        # Check if it's a generic type with allowed origin
+        if origin is not None and _is_subclass_of_allowed(origin, allowed):
+            continue
+
+        # Check for Union types
+        if origin is Union:
+            args = get_args(field_type)
+            if all(_is_allowed_type(arg) for arg in args if arg is not type(None)):
+                continue
+
+        # Check for frozen dataclass or NamedTuple
+        if _is_namedtuple_type(field_type) or (is_dataclass(field_type) and _is_dataclass_frozen(field_type)):
+            continue
+
+        # If we get here, the type is not allowed
+        disallowed_fields.append(field_name)
+
+    return disallowed_fields
+
+
+def _is_subclass_of_allowed(check_type: type, allowed_types: tuple[type, ...]) -> bool:
+    """Check if check_type is a subclass of any type in allowed_types."""
+    if check_type in allowed_types:
+        return True
+
+    try:
+        if isinstance(check_type, type):
+            for allowed in allowed_types:
+                if allowed is not type and issubclass(check_type, allowed):
+                    return True
+    except TypeError:
+        # issubclass() can raise TypeError for non-class types
+        pass
+
+    return False
+
+
+def _is_allowed_type(field_type: type) -> bool:
+    """Check if a field type is in the allowed list or is a subclass of an allowed type."""
+    # Check against basic allowed types (includes numpy types via np.generic)
+    if _is_subclass_of_allowed(field_type, ALLOWED_TYPES):
+        return True
+
+    if _is_namedtuple_type(field_type):
+        return True
+
+    if is_dataclass(field_type) and _is_dataclass_frozen(field_type):
+        return True
+
+    # Check for generic types
+    origin = get_origin(field_type)
+    return origin is not None and _is_subclass_of_allowed(origin, ALLOWED_TYPES)

From 59bcfc0da8fc086476512ab6249ec25329fdd412 Mon Sep 17 00:00:00 2001
From: Peter Somhorst
Date: Mon, 9 Feb 2026 14:23:12 +0100
Subject: [PATCH 26/32] Make Interval dataclass with post_init validation method

---
 eitprocessing/datahandling/intervaldata.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/eitprocessing/datahandling/intervaldata.py b/eitprocessing/datahandling/intervaldata.py
index c4aebbbec..eb06d62aa 100644
--- a/eitprocessing/datahandling/intervaldata.py
+++ b/eitprocessing/datahandling/intervaldata.py
@@ -1,6 +1,6 @@
 import copy
 from dataclasses import dataclass, field
-from typing import Any, NamedTuple, TypeVar
+from typing import Any, TypeVar
 
 import numpy as np
 from typing_extensions import Self
@@ -12,12 +12,18 @@
 T = TypeVar("T", bound="IntervalData")
 
 
-class Interval(NamedTuple):
-    """A tuple containing the start time and end time of an interval."""
+@dataclass(frozen=True)
+class Interval:
+    """The start time and end time of an interval."""
 
     start_time: float
     end_time: float
 
+    def __post_init__(self):
+        if self.start_time >= self.end_time:
+            msg = f"Interval start time ({self.start_time:.2f}) must be less than end time ({self.end_time:.2f})."
+            raise ValueError(msg)
+
     @property
     def duration(self) -> float:
         """Duration of the interval."""

From 19006aeddaca1bde43effadb91fa07daf21abc83 Mon Sep 17 00:00:00 2001
From: Peter Somhorst
Date: Mon, 9 Feb 2026 14:28:34 +0100
Subject: [PATCH 27/32] Update IntervalData for StructuredArray

---
 eitprocessing/datahandling/intervaldata.py | 60 ++++++++++++++++++----
 1 file changed, 49 insertions(+), 11 deletions(-)

diff --git a/eitprocessing/datahandling/intervaldata.py b/eitprocessing/datahandling/intervaldata.py
index eb06d62aa..73c0fd813 100644
--- a/eitprocessing/datahandling/intervaldata.py
+++ b/eitprocessing/datahandling/intervaldata.py
@@ -1,4 +1,7 @@
+import contextlib
 import copy
+import itertools
+from collections.abc import Iterator
 from dataclasses import dataclass, field
 from typing import Any, TypeVar
 
@@ -7,7 +10,7 @@
 
 from eitprocessing.datahandling import DataContainer
 from eitprocessing.datahandling.mixins.slicing import HasTimeIndexer, SelectByIndex
-from eitprocessing.datahandling.namedtuple_array import NamedTupleArray, Nested
+from eitprocessing.datahandling.structured_array import StructuredArray
 
 T = TypeVar("T", bound="IntervalData")
 
@@ -63,8 +66,8 @@ class IntervalData(DataContainer, SelectByIndex, HasTimeIndexer):
     name: str = field(compare=False, repr=False)
     unit: str | None = field(metadata={"check_equivalence": True}, repr=False)
     category: str = field(metadata={"check_equivalence": True}, repr=False)
-    intervals: NamedTupleArray[Interval] = field(repr=False)
-    values: list[Any] | None = field(repr=False, default=None)
+    intervals: StructuredArray[Interval] = field(repr=False)
+    values: list[Any] | np.ndarray | StructuredArray | None = field(repr=False, default=None)
     description: str = field(compare=False, default="", repr=False)
     default_partial_inclusion: bool = field(repr=False, default=False)
 
@@ -78,26 +81,59 @@ def __post_init__(self) -> None:
             msg = f"The number of time points ({lt}) does not match the number of values ({lv})."
             raise ValueError(msg)
 
+        if isinstance(self.values, list):
+            try:
+                object.__setattr__(self, "values", StructuredArray(self.values))
+            except (TypeError, ValueError):
+                object.__setattr__(self, "values", np.array(self.values))
+
+    def __iter__(self) -> Iterator[tuple[Interval, Any | None]]:
+        if self.values is not None:
+            return iter(zip(self.intervals, self.values, strict=True))
+
+        return iter(zip(self.intervals, itertools.repeat(None), strict=False))
+
     @staticmethod
     def _parse_intervals(
-        intervals: list[Interval] | Nested[Interval] | np.ndarray | NamedTupleArray[Interval],
-    ) -> NamedTupleArray[Interval]:
-        """Parse intervals into a NamedTupleArray of Interval."""
-        if isinstance(intervals, NamedTupleArray):
-            if intervals.dtype is not Interval:
-                msg = f"Expected intervals of type 'Interval', got '{intervals.dtype.__name__}'"
+        intervals: list[Interval] | np.ndarray | StructuredArray[Interval],
+    ) -> StructuredArray[Interval]:
+        """Parse intervals into a StructuredArray of Interval."""
+        if isinstance(intervals, StructuredArray):
+            if intervals.item_type is not Interval:
+                msg = f"Expected intervals of type 'Interval', got '{intervals.item_type.__name__}'"
                 raise TypeError(msg)
+            if intervals.items.ndim != 1:
+                msg = f"Intervals should be a 1D array of Interval, got a {intervals.items.ndim}D array."
+                raise ValueError(msg)
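+            # A validated 1D StructuredArray of Interval is used as-is.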
return intervals if isinstance(intervals, np.ndarray): + if intervals.ndim != 2 or intervals.shape[1] != 2: # noqa: PLR2004 + msg = f"Intervals should be a 2D array (1D array of Interval data), got {intervals.ndim}D array." + raise ValueError(msg) try: - return NamedTupleArray.from_numpy_array(intervals, Interval) + return StructuredArray.from_array(intervals, item_type=Interval) except ValueError as e: msg = f"Could not parse intervals from numpy array with dtype '{intervals.dtype}'" raise TypeError(msg) from e + if isinstance(intervals, list | tuple): + with contextlib.suppress(TypeError): + return StructuredArray(intervals, item_type=Interval) + + with contextlib.suppress(TypeError): + return StructuredArray([Interval(*interval) for interval in intervals], item_type=Interval) + + with contextlib.suppress(TypeError): + return StructuredArray.from_array(intervals, item_type=Interval) + + msg = ( + "Could not parse intervals from given input. Intervals should be a list of Interval objects, a 2D " + "numpy array, or a StructuredArray of Interval." + ) + raise TypeError(msg) + try: - return NamedTupleArray.from_nested(intervals, Interval) + return StructuredArray.from_array(intervals, Interval) except Exception as e: msg = "Could not parse intervals from given input." raise TypeError(msg) from e @@ -129,6 +165,7 @@ def _sliced_copy( description=description, intervals=intervals, values=values, + default_partial_inclusion=self.default_partial_inclusion, ) def select_by_time( @@ -187,6 +224,7 @@ def select_by_time( category=self.category, intervals=list(filtered_intervals), values=values, + default_partial_inclusion=self.default_partial_inclusion, ) @staticmethod From 3a32df92599c401f1713ea0d2c72a831ed7b9d1a Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Mon, 9 Feb 2026 14:33:34 +0100 Subject: [PATCH 28/32] Freeze Breath --- eitprocessing/datahandling/breath.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/eitprocessing/datahandling/breath.py b/eitprocessing/datahandling/breath.py index ddb5e5702..a68356525 100644 --- a/eitprocessing/datahandling/breath.py +++ b/eitprocessing/datahandling/breath.py @@ -1,8 +1,13 @@ -from collections.abc import Iterator +from __future__ import annotations + from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Iterator -@dataclass +@dataclass(frozen=True) class Breath: """Represents a breath with a start, middle and end time.""" @@ -14,7 +19,7 @@ def __post_init__(self): if self.start_time >= self.middle_time or self.middle_time >= self.end_time: msg = ( "Start, middle and end should be consecutive, not " - "{self.start_time:.2f}, {self.middle_time:.2f} and {self.end_time:.2f}" + f"{self.start_time:.2f}, {self.middle_time:.2f} and {self.end_time:.2f}" ) raise ValueError(msg) From 368fa6983aa738abea4c0d07f31878b6a654a179 Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Mon, 9 Feb 2026 14:34:26 +0100 Subject: [PATCH 29/32] Convert Timpel loading to StructuredArray --- eitprocessing/datahandling/loading/timpel.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/eitprocessing/datahandling/loading/timpel.py b/eitprocessing/datahandling/loading/timpel.py index 96198b889..286641a62 100644 --- a/eitprocessing/datahandling/loading/timpel.py +++ b/eitprocessing/datahandling/loading/timpel.py @@ -10,9 +10,10 @@ from eitprocessing.datahandling.continuousdata import ContinuousData from eitprocessing.datahandling.datacollection 
import DataCollection from eitprocessing.datahandling.eitdata import EITData, Vendor -from eitprocessing.datahandling.intervaldata import IntervalData +from eitprocessing.datahandling.intervaldata import Interval, IntervalData from eitprocessing.datahandling.loading import load_eit_data from eitprocessing.datahandling.sparsedata import SparseData +from eitprocessing.datahandling.structured_array import StructuredArray if TYPE_CHECKING: from pathlib import Path @@ -184,7 +185,7 @@ def load_from_single_path( gi = continuousdata_collection["global_impedance_(raw)"].values - time_ranges, breaths = _make_breaths(time, min_indices, max_indices, gi) + intervals, breaths = _make_breaths(time, min_indices, max_indices, gi) intervaldata_collection = DataCollection(IntervalData) intervaldata_collection.add( IntervalData( @@ -192,7 +193,7 @@ def load_from_single_path( name="Breaths (Timpel)", unit=None, category="breaths", - intervals=time_ranges, + intervals=intervals, values=breaths, default_partial_inclusion=False, ), @@ -222,12 +223,12 @@ def _make_breaths( min_indices: np.ndarray, max_indices: np.ndarray, gi: np.ndarray, -) -> tuple[list[tuple[float, float]], list[Breath]]: +) -> tuple[StructuredArray[Interval], StructuredArray[Breath]]: # TODO: replace section with BreathDetection._remove_doubles() and BreathDetection._remove_edge_cases() from # 41_breath_detection_psomhorst; this code was directly copied from b59ac54 if len(min_indices) < 2 or len(max_indices) < 1: # noqa: PLR2004 - return [], [] + return StructuredArray([], Interval), StructuredArray([], Breath) valley_indices = min_indices.copy() peak_indices = max_indices.copy() @@ -272,9 +273,8 @@ def _make_breaths( current_valley_index += 1 - breaths = [] - for start, end, middle in zip(valley_indices[:-1], valley_indices[1:], peak_indices, strict=True): - breaths.append(((time[start], time[end]), Breath(time[start], time[middle], time[end]))) + times = time[np.column_stack([valley_indices[:-1], peak_indices, valley_indices[1:]])] + breaths = StructuredArray.from_array(times, Breath) + intervals = StructuredArray.from_array(times[:, [0, 2]], Interval) - time_ranges, values = zip(*breaths, strict=True) - return list(time_ranges), list(values) + return intervals, breaths From 351c02c184bd04a7a2e25d2d95d163a1bbbf5f59 Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Mon, 9 Feb 2026 14:34:45 +0100 Subject: [PATCH 30/32] Convert BreathDetection to StructuredArray --- eitprocessing/features/breath_detection.py | 56 +++++++--------------- 1 file changed, 18 insertions(+), 38 deletions(-) diff --git a/eitprocessing/features/breath_detection.py b/eitprocessing/features/breath_detection.py index 64cb8c87e..5c0804821 100644 --- a/eitprocessing/features/breath_detection.py +++ b/eitprocessing/features/breath_detection.py @@ -1,4 +1,3 @@ -import itertools import math from collections.abc import Callable from dataclasses import dataclass @@ -9,8 +8,9 @@ from eitprocessing.datahandling.breath import Breath from eitprocessing.datahandling.continuousdata import ContinuousData -from eitprocessing.datahandling.intervaldata import IntervalData +from eitprocessing.datahandling.intervaldata import Interval, IntervalData from eitprocessing.datahandling.sequence import Sequence +from eitprocessing.datahandling.structured_array import StructuredArray from eitprocessing.features.moving_average import MovingAverage @@ -116,12 +116,13 @@ def find_breaths( valley_indices, ) breaths = self._remove_breaths_around_invalid_data(breaths, time, sample_frequency, 
invalid_data_indices) + intervals = StructuredArray.from_array(breaths.to_array()[:, [0, 2]], Interval) breaths_container = IntervalData( label=result_label, name="Breaths as determined by BreathDetection", unit=None, category="breath", - intervals=[(breath.start_time, breath.end_time) for breath in breaths], + intervals=intervals, values=breaths, ) @@ -370,56 +371,35 @@ def _create_breaths_from_peak_valley_data( time: np.ndarray, peak_indices: np.ndarray, valley_indices: np.ndarray, - ) -> list[Breath]: - return [ - Breath(time[start], time[middle], time[end]) - for middle, (start, end) in zip( - peak_indices, - itertools.pairwise(valley_indices), - strict=True, - ) - ] + ) -> StructuredArray[Breath]: + times = time[np.column_stack([valley_indices[:-1], peak_indices, valley_indices[1:]])] + return StructuredArray.from_array(times, Breath) def _remove_breaths_around_invalid_data( self, - breaths: list[Breath], + breaths: StructuredArray[Breath], time: np.ndarray, sample_frequency: float, invalid_data_indices: np.ndarray, - ) -> list[Breath]: - """Remove breaths overlapping with invalid data. - - Breaths that start within a window length (given by invalid_data_removal_window_length) of invalid data are - removed. - - Args: - breaths: list of detected breath objects - time: time axis belonging to the data - sample_frequency: sample frequency of the data and time - invalid_data_indices: indices of invalid data points - """ - # TODO: write more general(ized) method of determining invalid data - - new_breaths = breaths[:] - + ) -> StructuredArray[Breath]: + """Remove breaths overlapping with invalid data.""" if not len(invalid_data_indices): - return new_breaths + return breaths[:] - invalid_data_values = np.zeros(time.shape) - invalid_data_values[invalid_data_indices] = 1 # gives the value 1 to each invalid datapoint + invalid_data_values = np.zeros_like(time) + invalid_data_values[invalid_data_indices] = 1 window_length = math.ceil(self.invalid_data_removal_window_length * sample_frequency) - for breath in new_breaths[:]: + indices_to_keep = [] + for i, breath in enumerate(breaths): breath_start_minus_window = max(0, np.argmax(time == breath.start_time) - window_length) breath_end_plus_window = min(len(invalid_data_values), np.argmax(time == breath.end_time) + window_length) - # if no invalid datapoints are within the window, np.max() will return 0 - # if any invalid datapoints are within the window, np.max() will return 1 - if np.max(invalid_data_values[breath_start_minus_window:breath_end_plus_window]): - new_breaths.remove(breath) + if not np.max(invalid_data_values[breath_start_minus_window:breath_end_plus_window]): + indices_to_keep.append(i) - return new_breaths + return breaths[indices_to_keep] @staticmethod def _fill_nan_with_nearest_neighbour(data: np.ndarray) -> np.ndarray: From 3e33a9a6eebf5ab49b132452eb640e7f41780ce6 Mon Sep 17 00:00:00 2001 From: Peter Somhorst Date: Mon, 9 Feb 2026 14:35:43 +0100 Subject: [PATCH 31/32] Update tests --- tests/test_breath_detection.py | 10 +- tests/test_intervaldata.py | 24 +- tests/test_namedtuple_array.py | 778 --------------------------------- tests/test_sparsedata.py | 4 - 4 files changed, 17 insertions(+), 799 deletions(-) delete mode 100644 tests/test_namedtuple_array.py diff --git a/tests/test_breath_detection.py b/tests/test_breath_detection.py index 42cc8ce5a..3f69b6a7a 100644 --- a/tests/test_breath_detection.py +++ b/tests/test_breath_detection.py @@ -360,16 +360,16 @@ def test_create_breaths_from_peak_valley_data(): breaths = 
bd._create_breaths_from_peak_valley_data(time, peak_indices, valley_indices) assert len(breaths) == 5 assert all(isinstance(breath, Breath) for breath in breaths) - assert np.array_equal(np.array([breath.start_time for breath in breaths]), time[valley_indices[:-1]]) - assert np.array_equal(np.array([breath.middle_time for breath in breaths]), time[peak_indices]) - assert np.array_equal(np.array([breath.end_time for breath in breaths]), time[valley_indices[1:]]) + assert np.array_equal(breaths["start_time"], time[valley_indices[:-1]]) + assert np.array_equal(breaths["middle_time"], time[peak_indices]) + assert np.array_equal(breaths["end_time"], time[valley_indices[1:]]) fewer_valley_indices = valley_indices[:-1] - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="all the input array dimensions .* must match exactly"): bd._create_breaths_from_peak_valley_data(time, peak_indices, fewer_valley_indices) fewer_peak_indices = peak_indices[:-1] - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="all the input array dimensions .* must match exactly"): bd._create_breaths_from_peak_valley_data(time, fewer_peak_indices, valley_indices) peaks_out_of_order = np.concatenate([peak_indices[3:], peak_indices[:3]]) diff --git a/tests/test_intervaldata.py b/tests/test_intervaldata.py index e5bffc387..19c92f239 100644 --- a/tests/test_intervaldata.py +++ b/tests/test_intervaldata.py @@ -4,6 +4,7 @@ import pytest from eitprocessing.datahandling.intervaldata import Interval, IntervalData +from eitprocessing.datahandling.structured_array import StructuredArray @pytest.fixture @@ -77,7 +78,7 @@ def intervaldata_valuesarray_partialfalse(): def test_post_init(intervaldata_novalues_partialtrue: IntervalData): - assert isinstance(intervaldata_novalues_partialtrue.intervals, list) + assert isinstance(intervaldata_novalues_partialtrue.intervals, StructuredArray) assert all(isinstance(interval, Interval) for interval in intervaldata_novalues_partialtrue.intervals) @@ -92,11 +93,6 @@ def test_has_values( intervaldata_valuesarray_partialfalse: IntervalData, ) -> None: assert not intervaldata_novalues_partialtrue.has_values - intervaldata_novalues_partialtrue.values = [] - assert intervaldata_novalues_partialtrue.has_values - intervaldata_novalues_partialtrue.values = None - assert not intervaldata_novalues_partialtrue.has_values - assert not intervaldata_novalues_partialfalse.has_values assert intervaldata_valueslist_partialfalse.has_values assert intervaldata_valuesarray_partialfalse.has_values @@ -162,21 +158,23 @@ def test_select_by_time( def test_select_by_time_values(intervaldata_valueslist_partialfalse: IntervalData): - assert isinstance(intervaldata_valueslist_partialfalse.values, list) + assert isinstance(intervaldata_valueslist_partialfalse.values, np.ndarray) sliced_copy = intervaldata_valueslist_partialfalse[:10] assert len(sliced_copy.intervals) == len(sliced_copy.values) - assert sliced_copy.values == intervaldata_valueslist_partialfalse.values[:10] + assert np.all(sliced_copy.values == intervaldata_valueslist_partialfalse.values[:10]) def test_concatenate(intervaldata_novalues_partialtrue: IntervalData): sliced_copy_1 = intervaldata_novalues_partialtrue[:10] sliced_copy_2 = intervaldata_novalues_partialtrue[10:20] + assert isinstance(sliced_copy_1.intervals, StructuredArray) assert len(sliced_copy_1) == 10 assert len(sliced_copy_2) == 10 concatenated = sliced_copy_1 + sliced_copy_2 + assert isinstance(concatenated.intervals, StructuredArray) assert 
len(concatenated) == 20 assert concatenated == sliced_copy_1.concatenate(sliced_copy_2) @@ -197,8 +195,8 @@ def test_concatenate_values_list(intervaldata_valueslist_partialfalse: IntervalD concatenated = sliced_copy_1 + sliced_copy_2 assert len(concatenated.intervals) == len(concatenated.values) - assert isinstance(intervaldata_valueslist_partialfalse.values, list) - assert concatenated.values == intervaldata_valueslist_partialfalse.values[:20] + assert isinstance(intervaldata_valueslist_partialfalse.values, np.ndarray) + assert np.all(concatenated.values == intervaldata_valueslist_partialfalse.values[:20]) def test_concatenate_values_numpy(intervaldata_valuesarray_partialfalse: IntervalData): @@ -215,5 +213,7 @@ def test_concatenate_values_type_mismatch( intervaldata_valueslist_partialfalse: IntervalData, intervaldata_valuesarray_partialfalse: IntervalData, ): - with pytest.raises(TypeError): - intervaldata_valueslist_partialfalse[:10] + intervaldata_valuesarray_partialfalse[10:] + intervaldata_valueslist_partialfalse[:10] + intervaldata_valuesarray_partialfalse[10:] + + +# TODO: add tests for IntervalData selection and concatenation with StructuredArray diff --git a/tests/test_namedtuple_array.py b/tests/test_namedtuple_array.py deleted file mode 100644 index 3475999c4..000000000 --- a/tests/test_namedtuple_array.py +++ /dev/null @@ -1,778 +0,0 @@ -from typing import NamedTuple - -import numpy as np -import pytest - -from eitprocessing.datahandling.namedtuple_array import NamedTupleArray - - -class Mixed(NamedTuple): - """NamedTuple with mixed field types and a computed property.""" - - a: int - b: float - c: bool - d: str - - @property - def d_length(self) -> int: - """Computed property returning the length of string d.""" - return len(self.d) - - -class Simple(NamedTuple): - """NamedTuple with simple numeric fields.""" - - x: int - y: float - - -class Breath(NamedTuple): - """NamedTuple representing a breath with start, mid, end times.""" - - start: float - mid: float - end: float - - @property - def duration(self) -> float: - """Computed property returning the duration of the breath.""" - return self.end - self.start - - -def test_1d_mixed_types_and_properties(): - items = [Mixed(1, 2.0, True, "foo"), Mixed(3, 4.5, False, "hello")] - nta = NamedTupleArray(items) - - assert nta.shape == (2,) - # scalar access - v0 = nta[0] - assert isinstance(v0, Mixed) - assert v0.a == 1 - assert v0.b == 2.0 - assert v0.c is True - assert v0.d == "foo" - - # field views have expected dtype and values - a = nta["a"] - assert a.dtype.kind in ("i", "u") - assert a.shape == (2,) - assert (a == np.array([1, 3])).all() - - b = nta["b"] - assert np.issubdtype(b.dtype, np.floating) - - c = nta["c"] - assert np.issubdtype(c.dtype, np.bool_) - - d = nta["d"] - # strings kept as object - assert d.dtype == object - - # computed property -> returns int dtype (annotated) - dl = nta["d_length"] - assert np.issubdtype(dl.dtype, np.integer) - assert list(dl) == [3, 5] - - -def test_2d_indexing_and_slicing(): - nested = [[Simple(i + j, float(i * j)) for j in range(3)] for i in range(2)] - nta2d = NamedTupleArray(nested) - assert nta2d.shape == (2, 3) - - # scalar multi-dimensional indexing returns NamedTuple - item = nta2d[0, 1] - assert isinstance(item, Simple) - assert item.x == 1 - assert item.y == 0.0 - - # row slice returns NamedTupleArray - row = nta2d[0] - assert isinstance(row, NamedTupleArray) - assert row.shape == (3,) - - # field access on 2D returns array with original shape - xs = nta2d["x"] - assert 
xs.shape == (2, 3)
-    assert xs[0, 1] == 1
-
-
-def test_3d_from_ndarray_and_indexing():
-    # create shape (2,2,2,2) last axis 2 fields
-    arr = np.array(
-        [
-            [[[1, 2.0], [3, 4.0]], [[5, 6.0], [7, 8.0]]],
-            [[[9, 10.0], [11, 12.0]], [[13, 14.0], [15, 16.0]]],
-        ],
-        dtype=float,
-    )
-    nta = NamedTupleArray.from_array(arr, Simple)
-    assert nta.shape == (2, 2, 2)
-
-    # random 3D scalar access
-    s = nta[1, 0, 1]
-    assert isinstance(s, Simple)
-    assert s.x == 11
-    assert s.y == 12.0
-
-
-def test_field_views_readonly_and_shape_preserved():
-    items = [Simple(1, 2.0), Simple(3, 4.0)]
-    nta = NamedTupleArray(items, frozen=False)
-    assert nta.flags.writeable is True
-    vx = nta["x"]
-    assert vx.flags.writeable is False
-    assert vx.shape == (2,)
-
-
-def test_calculated_property_float_dtype():
-    breaths = [Breath(0.0, 0.5, 1.0), Breath(1.0, 1.4, 2.2)]
-    nta = NamedTupleArray(breaths)
-    dur = nta["duration"]
-    assert np.issubdtype(dur.dtype, np.floating)
-    assert pytest.approx(list(dur), rel=1e-9) == [1.0, 1.2]
-
-
-def test_forwarding_attributes_and_methods():
-    items = [Simple(1, 2.0), Simple(3, 4.0)]
-    nta = NamedTupleArray(items)
-
-    # dtype via property
-    assert nta.dtype is not None
-
-    # reshape is not available (array is private now)
-    assert not hasattr(nta, "reshape")
-
-    # flags and ndim reflect the underlying array
-    assert nta.flags.writeable is not None
-    assert nta.ndim == 1
-
-    # field-view methods are available (e.g., sum)
-    vx = nta["x"]
-    assert hasattr(vx, "sum")
-    assert int(vx.sum()) == 4
-
-
-def test_frozen_namedtuple_array_numeric():
-    items = [Simple(1, 2.0), Simple(3, 4.0)]
-    frozen_nta = NamedTupleArray(items, frozen=True)
-    nta = NamedTupleArray(items, frozen=False)
-
-    # Field views are always read-only (for both frozen and unfrozen)
-    field_view = nta["x"]
-    with pytest.raises(ValueError, match="assignment destination is read-only"):
-        field_view[0] = 10
-
-    # Frozen array also prevents field modification
-    frozen_field_view = frozen_nta["x"]
-    with pytest.raises(ValueError, match="assignment destination is read-only"):
-        frozen_field_view[0] = 10
-
-    # Both frozen and unfrozen arrays block new attributes
-    with pytest.raises(AttributeError):
-        nta.new_attribute = 42
-    with pytest.raises(AttributeError):
-        frozen_nta.new_attribute = 42
-
-    # Both frozen and unfrozen arrays block modification of _type
-    with pytest.raises(AttributeError):
-        nta._type = np.floating
-    with pytest.raises(AttributeError):
-        frozen_nta._type = np.floating
-
-    # Frozen array prevents toggling writeable flag
-    with pytest.raises(ValueError, match="cannot set WRITEABLE flag to True of this array"):
-        frozen_nta.flags.writeable = True
-
-
-def test_frozen_namedtuple_array_string():
-    items = [Mixed(1, 2.0, True, "foo"), Mixed(3, 4.0, False, "bar")]
-    frozen_nta = NamedTupleArray(items, frozen=True)
-
-    # Frozen array with object dtype prevents field modification via read-only view
-    frozen_field_view = frozen_nta["d"]
-    with pytest.raises(ValueError, match="assignment destination is read-only"):
-        frozen_field_view[0] = "baz"
-
-    # Frozen array blocks new attributes
-    with pytest.raises(AttributeError):
-        frozen_nta.new_attribute = 42
-
-    # Frozen array blocks modifying _type
-    with pytest.raises(AttributeError):
-        frozen_nta._type = np.floating
-
-    # For object dtypes, users cannot access the underlying array to toggle the writeable flag
-    # because __items is now private (inaccessible). This is the key benefit of the refactoring.
-    # The writeable flag is protected by preventing direct array access.
-    with pytest.raises(AttributeError, match="has no attribute '_items'"):
-        _ = frozen_nta._items
-
-
-def test_items_property_unfrozen():
-    """Test that .items property returns the underlying array for unfrozen arrays."""
-    items = [Simple(1, 2.0), Simple(3, 4.0)]
-    nta = NamedTupleArray(items, frozen=False)
-
-    # Access via .items property works
-    underlying = nta.items
-    assert isinstance(underlying, np.ndarray)
-    assert underlying.shape == (2,)
-    assert underlying.dtype.names == ("x", "y")
-
-    # For unfrozen arrays, modifications are allowed (though field views are still read-only)
-    assert nta.items.flags.writeable is True
-
-
-def test_items_property_frozen():
-    """Test that .items property is read-only for frozen arrays."""
-    items = [Simple(1, 2.0), Simple(3, 4.0)]
-    nta = NamedTupleArray(items, frozen=True)
-
-    # Access via .items property works
-    underlying = nta.items
-    assert isinstance(underlying, np.ndarray)
-    assert underlying.shape == (2,)
-
-    # For frozen arrays with numeric dtypes, writeable is False
-    assert nta.items.flags.writeable is False
-
-
-def test_slicing_frozen_array():
-    """Test that slicing a frozen array returns a frozen sub-array."""
-    items = [Simple(i, float(i)) for i in range(5)]
-    nta = NamedTupleArray(items, frozen=True)
-
-    # Slice returns NamedTupleArray with same frozenness
-    sliced = nta[1:3]
-    assert isinstance(sliced, NamedTupleArray)
-    assert sliced.shape == (2,)
-    assert sliced.items.flags.writeable is False
-
-
-def test_computed_property_frozen():
-    """Test that computed properties work correctly on frozen arrays."""
-    breaths = [Breath(0.0, 0.5, 1.0), Breath(1.0, 1.4, 2.2)]
-    nta = NamedTupleArray(breaths, frozen=True)
-
-    # Computed property returns correct values
-    dur = nta["duration"]
-    assert pytest.approx(list(dur), rel=1e-9) == [1.0, 1.2]
-    # Computed properties create new arrays, so they're writable
-    assert dur.flags.writeable is True
-
-
-def test_single_item_array():
-    """Test NamedTupleArray with a single item."""
-    items = [Simple(42, 3.14)]
-    nta = NamedTupleArray(items, frozen=True)
-
-    assert nta.shape == (1,)
-    assert nta[0] == Simple(42, 3.14)
-    assert nta["x"][0] == 42
-
-
-def test_2d_frozen_array():
-    """Test 2D frozen array slicing and access."""
-    nested = [[Simple(i + j, float(i * j)) for j in range(2)] for i in range(2)]
-    nta = NamedTupleArray(nested, frozen=True)
-
-    assert nta.shape == (2, 2)
-    assert nta.items.flags.writeable is False
-
-    # Slicing should also be frozen
-    row = nta[0]
-    assert row.items.flags.writeable is False
-
-
-def test_all_properties_accessible():
-    """Test that all expected properties are accessible."""
-    items = [Simple(1, 2.0), Simple(3, 4.0)]
-    nta = NamedTupleArray(items)
-
-    # All properties should be accessible
-    assert nta.shape == (2,)
-    assert nta.ndim == 1
-    assert nta.dtype is not None
-    assert nta.flags is not None
-    assert nta.items is not None
-
-
-def test_setattr_blocked_post_init():
-    """Test that __setattr__ blocks all attribute setting after init."""
-    items = [Simple(1, 2.0)]
-    nta = NamedTupleArray(items, frozen=False)
-
-    # Cannot set any new attributes
-    with pytest.raises(AttributeError, match="immutable"):
-        nta.custom_attr = "value"
-
-    # Cannot modify internal attributes
-    with pytest.raises(AttributeError, match="immutable"):
-        nta._type = int
-
-
-# ============================================================================
-# Edge Cases and Error Conditions
-# ============================================================================
-
-
-def test_empty_sequence():
-    """Test that empty sequences raise ValueError."""
-    with pytest.raises(ValueError, match="Cannot infer type from empty"):
-        NamedTupleArray([])
-
-
-def test_empty_ndarray():
-    """Test that empty ndarrays raise ValueError."""
-    empty_array = np.array([], dtype=float)
-    with pytest.raises(ValueError, match="Cannot infer type from empty"):
-        NamedTupleArray(empty_array)
-
-
-def test_non_namedtuple_items():
-    """Test that passing non-NamedTuple items raises TypeError."""
-    items = [(1, 2.0), (3, 4.0)]  # Regular tuples, not NamedTuples
-    with pytest.raises(TypeError, match="NamedTuple"):
-        NamedTupleArray(items)
-
-
-def test_ndarray_last_axis_mismatch():
-    """Test that from_ndarray with mismatched last axis raises error."""
-    # Simple has 2 fields, but array has 3 columns
-    arr = np.array([[1, 2.0, 3.0], [4, 5.0, 6.0]])
-    with pytest.raises(ValueError):
-        NamedTupleArray.from_array(arr, Simple)
-
-
-def test_frozen_from_ndarray():
-    """Test that from_ndarray with frozen=True works correctly."""
-    arr = np.array([[1, 2.0], [3, 4.0]])
-    nta = NamedTupleArray.from_array(arr, Simple, frozen=True)
-
-    assert nta.shape == (2,)
-    assert nta.items.flags.writeable is False
-
-
-def test_iteration():
-    """Test that NamedTupleArray is iterable."""
-    items = [Simple(1, 2.0), Simple(3, 4.0), Simple(5, 6.0)]
-    nta = NamedTupleArray(items)
-
-    # Iteration should yield NamedTuple instances
-    for i, item in enumerate(nta):
-        assert isinstance(item, Simple)
-        assert item == items[i]
-
-
-def test_len():
-    """Test that len() works on NamedTupleArray."""
-    items = [Simple(1, 2.0), Simple(3, 4.0), Simple(5, 6.0)]
-    nta = NamedTupleArray(items)
-
-    assert len(nta) == 3
-
-
-def test_repr():
-    """Test that repr() produces a meaningful string."""
-    items = [Simple(1, 2.0), Simple(3, 4.0)]
-    nta = NamedTupleArray(items)
-
-    r = repr(nta)
-    assert "NamedTupleArray" in r
-    assert "Simple" in r
-
-
-def test_computed_property_nonexistent_attribute():
-    """Test accessing a computed property that doesn't exist."""
-    items = [Simple(1, 2.0), Simple(3, 4.0)]
-    nta = NamedTupleArray(items)
-
-    # Accessing a non-existent attribute should raise KeyError
-    with pytest.raises(KeyError):
-        _ = nta["nonexistent_property"]
-
-
-def test_heterogeneous_items():
-    """Test that arrays with different NamedTuple types raise error."""
-    items = [Simple(1, 2.0), Mixed(3, 4.0, True, "foo")]
-    with pytest.raises(ValueError):
-        NamedTupleArray(items)
-
-
-def test_mixed_items_and_non_items():
-    """Test that mixing NamedTuples with non-NamedTuples raises error."""
-    items = [Simple(1, 2.0), (3, 4.0)]  # Mix of NamedTuple and tuple
-    with pytest.raises(TypeError):
-        NamedTupleArray(items)
-
-
-def test_passing_array_in_list():
-    """Test that passing a numpy array wrapped in a list works."""
-    arr = np.array([(1, 2.0), (3, 4.0)], dtype=[("x", "i8"), ("y", "f8")])
-    # This should be treated as a single-item list containing an array
-    # and should fail since arrays aren't NamedTuple instances
-    with pytest.raises(TypeError):
-        NamedTupleArray([arr])
-
-
-def test_empty_nested_list():
-    """Test that empty nested lists raise ValueError."""
-    with pytest.raises(ValueError, match="Cannot infer type from empty"):
-        NamedTupleArray([[]])
-
-
-def test_casting_computed_property():
-    """Test computed properties with type annotations are cast correctly."""
-    breaths = [Breath(0.0, 0.5, 1.0), Breath(1.0, 1.4, 2.2)]
-    nta = NamedTupleArray(breaths)
-
-    # duration is annotated as float
-    dur = nta["duration"]
-    assert np.issubdtype(dur.dtype, np.floating)
-
-    # d_length (from Mixed) is annotated as int
-    mixed_items = [Mixed(1, 2.0, True, "foo"), Mixed(3, 4.0, False, "hello")]
-    nta_mixed = NamedTupleArray(mixed_items)
-    d_len = nta_mixed["d_length"]
-    assert np.issubdtype(d_len.dtype, np.integer)
-
-
-def test_array_protocol():
-    """Test that __array__ interface works (if implemented)."""
-    items = [Simple(1, 2.0), Simple(3, 4.0)]
-    nta = NamedTupleArray(items)
-
-    # Should be convertible to numpy array
-    arr = np.asarray(nta)
-    assert isinstance(arr, np.ndarray)
-    assert arr.shape == (2,)
-
-
-def test_frozen_unfrozen_mixed_access():
-    """Test accessing frozen array via different methods."""
-    items = [Simple(1, 2.0), Simple(3, 4.0)]
-    frozen_nta = NamedTupleArray(items, frozen=True)
-
-    # Scalar access still works
-    assert frozen_nta[0] == Simple(1, 2.0)
-
-    # Field access returns read-only view
-    field = frozen_nta["x"]
-    assert field.flags.writeable is False
-
-    # items property returns frozen array
-    assert frozen_nta.items.flags.writeable is False
-
-
-def test_nested_list_homogeneity():
-    """Test that nested lists maintain homogeneity checks."""
-    # Valid nested structure
-    nested = [[Simple(i, float(i)) for i in range(2)], [Simple(j, float(j)) for j in range(2, 4)]]
-    nta = NamedTupleArray(nested)
-    assert nta.shape == (2, 2)
-
-    # Invalid: mixed types in nested structure
-    invalid_nested = [[Simple(1, 2.0), Mixed(3, 4.0, True, "foo")]]
-    with pytest.raises(ValueError):
-        NamedTupleArray(invalid_nested)
-
-
-def test_slicing_preserves_type():
-    """Test that slicing returns the correct type."""
-    items = [Simple(i, float(i)) for i in range(5)]
-    nta = NamedTupleArray(items)
-
-    # Integer indexing returns item
-    item = nta[2]
-    assert isinstance(item, Simple)
-
-    # Slice returns NamedTupleArray
-    sliced = nta[1:3]
-    assert isinstance(sliced, NamedTupleArray)
-    assert sliced.shape == (2,)
-
-
-def test_from_ndarray_0d_array():
-    """Test that from_ndarray with 0D array raises error."""
-    arr = np.array(2.0)
-    with pytest.raises(ValueError, match="at least 1 dimension"):
-        NamedTupleArray.from_array(arr, Simple)
-
-
-def test_multidimensional_iteration():
-    """Test iteration over multi-dimensional NamedTupleArray."""
-    nested = [[Simple(i, float(i)) for i in range(2)], [Simple(j + 2, float(j + 2)) for j in range(2)]]
-    nta = NamedTupleArray(nested)
-
-    # Iteration over 2D array yields 1D NamedTupleArrays
-    rows = list(nta)
-    assert len(rows) == 2
-    assert all(isinstance(row, NamedTupleArray) for row in rows)
-    assert rows[0].shape == (2,)
-
-
-def test_computed_property_with_heuristic_int():
-    """Test computed property that infers int dtype via heuristic."""
-
-    # Create a NamedTuple with a property that returns int but has no type annotation
-    class AnnotationlessNT(NamedTuple):
-        value: int
-
-        @property
-        def doubled(self) -> int:
-            """Unannotated property returning int."""
-            return self.value * 2
-
-    items = [AnnotationlessNT(1), AnnotationlessNT(2)]
-    nta = NamedTupleArray(items)
-
-    result = nta["doubled"]
-    assert np.issubdtype(result.dtype, np.integer)
-    assert list(result) == [2, 4]
-
-
-def test_computed_property_with_heuristic_float():
-    """Test computed property that infers float dtype via heuristic."""
-
-    class FloatPropertyNT(NamedTuple):
-        value: float
-
-        @property
-        def halved(self) -> float:
-            """Unannotated property returning float."""
-            return self.value / 2.0
-
-    items = [FloatPropertyNT(2.0), FloatPropertyNT(4.0)]
-    nta = NamedTupleArray(items)
-
-    result = nta["halved"]
-    assert np.issubdtype(result.dtype, np.floating)
-    assert pytest.approx(list(result)) == [1.0, 2.0]
-
-
-def test_computed_property_mixed_types_returns_object():
-    """Test computed property with mixed types returns object dtype."""
-
-    class MixedPropertyNT(NamedTuple):
-        value: int
-
-        @property
-        def mixed(self) -> int | str:
-            """Unannotated property that sometimes returns int, sometimes str."""
-            return self.value if self.value < 2 else f"str_{self.value}"
-
-    items = [MixedPropertyNT(1), MixedPropertyNT(3)]
-    nta = NamedTupleArray(items)
-
-    result = nta["mixed"]
-    assert result.dtype == object
-    assert result[0] == 1
-    assert result[1] == "str_3"
-
-
-def test_getitem_returns_2d_array():
-    """Test that indexing returns correct types for various index types."""
-    items = [Simple(i, float(i)) for i in range(6)]
-    nta = NamedTupleArray(items)
-
-    # Fancy indexing with list returns NamedTupleArray
-    indexed = nta[[0, 2, 4]]
-    assert isinstance(indexed, NamedTupleArray)
-    assert indexed.shape == (3,)
-
-
-def test_zero_d_ndarray_from_indexing():
-    """Test that scalar indexing returns NamedTuple correctly."""
-    items = [Simple(1, 2.0), Simple(3, 4.0)]
-    nta = NamedTupleArray(items)
-
-    # Scalar access should return NamedTuple
-    scalar = nta[0]
-    assert isinstance(scalar, Simple)
-    assert scalar.x == 1
-    assert scalar.y == 2.0
-
-
-def test_from_ndarray_empty_fields():
-    """Test that from_ndarray handles edge cases."""
-    # Just verify the function is defined and doesn't break in normal usage
-    # The empty fields check is hard to trigger naturally
-    arr = np.array([[1, 2.0], [3, 4.0]])
-    nta = NamedTupleArray.from_array(arr, Simple)
-    assert nta.shape == (2,)
-
-
-def test_nested_empty_list_deeply():
-    """Test that deeply nested empty lists raise ValueError."""
-    with pytest.raises(ValueError):
-        NamedTupleArray([[[]]])
-
-
-def test_property_with_exception_in_heuristic():
-    """Test that computed property handles exceptions in type inference gracefully."""
-
-    class NTWithUnusualProperty(NamedTuple):
-        value: int
-
-        @property
-        def unusual(self) -> dict:
-            # Return something that can't be easily cast
-            return {"key": self.value}
-
-    items = [NTWithUnusualProperty(1), NTWithUnusualProperty(2)]
-    nta = NamedTupleArray(items)
-
-    # Should return object dtype (via exception handling)
-    result = nta["unusual"]
-    assert result.dtype == object
-
-
-def test_from_ndarray_3d_array():
-    """Test from_ndarray with 3D array."""
-    arr = np.array([[[1, 2.0], [3, 4.0]], [[5, 6.0], [7, 8.0]]], dtype=float)
-    nta = NamedTupleArray.from_array(arr, Simple)
-
-    assert nta.shape == (2, 2)
-    assert nta[0, 1] == Simple(3, 4.0)
-
-
-def test_iteration_1d_direct_yield():
-    """Test that 1D iteration yields NamedTuple items directly."""
-    items = [Simple(1, 1.0), Simple(2, 2.0), Simple(3, 3.0)]
-    nta = NamedTupleArray(items)
-
-    yielded = list(nta)
-    assert len(yielded) == 3
-    assert all(isinstance(item, Simple) for item in yielded)
-    assert yielded == items
-
-
-def test_from_ndarray_empty_namedtuple():
-    """Test that from_ndarray with empty NamedTuple raises RuntimeError (lines 232-233)."""
-
-    class Empty(NamedTuple):
-        pass
-
-    arr = np.array([], dtype=float)
-    with pytest.raises(RuntimeError, match="no fields"):
-        NamedTupleArray.from_array(arr, Empty)
-
-
-def test_equality_identical_arrays():
-    """Test that two NamedTupleArray instances with identical data are equal."""
-    items1 = [Simple(1, 2.0), Simple(3, 4.5)]
-    items2 = [Simple(1, 2.0), Simple(3, 4.5)]
-
-    arr1 = NamedTupleArray(items1)
-    arr2 = NamedTupleArray(items2)
-
-    assert arr1 == arr2
-    assert arr2 == arr1  # Test equality is symmetric
-
-
-def test_equality_different_values():
-    """Test that NamedTupleArray instances with different values are not equal."""
-    items1 = [Simple(1, 2.0), Simple(3, 4.5)]
-    items2 = [Simple(1, 2.0), Simple(3, 5.0)]
-
-    arr1 = NamedTupleArray(items1)
-    arr2 = NamedTupleArray(items2)
-
-    assert arr1 != arr2
-
-
-def test_equality_different_lengths():
-    """Test that NamedTupleArray instances with different lengths are not equal."""
-    items1 = [Simple(1, 2.0), Simple(3, 4.5)]
-    items2 = [Simple(1, 2.0), Simple(3, 4.5), Simple(5, 6.0)]
-
-    arr1 = NamedTupleArray(items1)
-    arr2 = NamedTupleArray(items2)
-
-    assert arr1 != arr2
-
-
-def test_equality_different_types():
-    """Test that NamedTupleArray instances with different NamedTuple types are not equal."""
-    simple_items = [Simple(1, 2.0), Simple(3, 4.5)]
-    breath_items = [Breath(1.0, 2.0, 3.0), Breath(3.0, 4.0, 5.0)]
-
-    arr1 = NamedTupleArray(simple_items)
-    arr2 = NamedTupleArray(breath_items)
-
-    assert arr1 != arr2
-
-
-def test_equality_with_nan_values():
-    """Test that NamedTupleArray instances with NaN values can be compared correctly."""
-    items1 = [Simple(1, np.nan), Simple(3, 4.5)]
-    items2 = [Simple(1, np.nan), Simple(3, 4.5)]
-
-    arr1 = NamedTupleArray(items1)
-    arr2 = NamedTupleArray(items2)
-
-    # NaN values should be considered equal in this comparison
-    assert arr1 == arr2
-
-
-def test_equality_with_nan_different_positions():
-    """Test that NamedTupleArray with NaN in different positions are not equal."""
-    items1 = [Simple(1, np.nan), Simple(3, 4.5)]
-    items2 = [Simple(1, 2.0), Simple(3, np.nan)]
-
-    arr1 = NamedTupleArray(items1)
-    arr2 = NamedTupleArray(items2)
-
-    assert arr1 != arr2
-
-
-def test_equality_2d_arrays():
-    """Test equality comparison for 2D NamedTupleArray instances."""
-    nested1 = [[Simple(i + j, float(i * j)) for j in range(3)] for i in range(2)]
-    nested2 = [[Simple(i + j, float(i * j)) for j in range(3)] for i in range(2)]
-
-    arr1 = NamedTupleArray(nested1)
-    arr2 = NamedTupleArray(nested2)
-
-    assert arr1 == arr2
-
-
-def test_equality_2d_arrays_different():
-    """Test inequality for 2D NamedTupleArray instances with different values."""
-    nested1 = [[Simple(i + j, float(i * j)) for j in range(3)] for i in range(2)]
-    nested2 = [[Simple(i + j + 1, float(i * j)) for j in range(3)] for i in range(2)]
-
-    arr1 = NamedTupleArray(nested1)
-    arr2 = NamedTupleArray(nested2)
-
-    assert arr1 != arr2
-
-
-def test_equality_not_equal_to_non_namedtuplearray():
-    """Test that NamedTupleArray is not equal to other types."""
-    items = [Simple(1, 2.0), Simple(3, 4.5)]
-    arr = NamedTupleArray(items)
-
-    # Test inequality with list
-    assert arr != items
-
-    # Test inequality with numpy array
-    assert arr != np.array([(1, 2.0), (3, 4.5)])
-
-    # Test inequality with None
-    assert arr is not None
-
-    # Test inequality with string
-    assert arr != "not an array"
-
-
-def test_equality_frozen_and_unfrozen():
-    """Test that frozen and unfrozen arrays with same data are equal."""
-    items1 = [Simple(1, 2.0), Simple(3, 4.5)]
-    items2 = [Simple(1, 2.0), Simple(3, 4.5)]
-
-    arr_frozen = NamedTupleArray(items1, frozen=True)
-    arr_unfrozen = NamedTupleArray(items2, frozen=False)
-
-    assert arr_frozen == arr_unfrozen
diff --git a/tests/test_sparsedata.py b/tests/test_sparsedata.py
index d706cf8e9..e9aa02a63 100644
--- a/tests/test_sparsedata.py
+++ b/tests/test_sparsedata.py
@@ -62,10 +62,6 @@ def test_has_values(
     sparsedata_valuesarray: SparseData,
 ) -> None:
     assert not sparsedata_novalues.has_values
-    sparsedata_novalues.values = []
-    assert sparsedata_novalues.has_values
-    sparsedata_novalues.values = None
-
     assert sparsedata_valueslist.has_values
     assert sparsedata_valuesarray.has_values


From 010372a8bf7269bfedc3f75da9a9a364723a5ab8 Mon Sep 17 00:00:00 2001
From: Peter Somhorst
Date: Mon, 9 Feb 2026 14:36:14 +0100
Subject: [PATCH 32/32] Update TIV for StructuredArray

---
 eitprocessing/parameters/tidal_impedance_variation.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/eitprocessing/parameters/tidal_impedance_variation.py b/eitprocessing/parameters/tidal_impedance_variation.py
index 29818260c..3fb75b475 100644
--- a/eitprocessing/parameters/tidal_impedance_variation.py
+++ b/eitprocessing/parameters/tidal_impedance_variation.py
@@ -3,7 +3,7 @@
 import warnings
 from dataclasses import InitVar, dataclass, field
 from functools import singledispatchmethod
-from typing import Final, Literal, NoReturn
+from typing import Final, Literal
 
 import numpy as np
 
@@ -13,6 +13,7 @@
 from eitprocessing.datahandling.intervaldata import IntervalData
 from eitprocessing.datahandling.sequence import Sequence
 from eitprocessing.datahandling.sparsedata import SparseData
+from eitprocessing.datahandling.structured_array import StructuredArray
 from eitprocessing.features.breath_detection import BreathDetection
 from eitprocessing.features.pixel_breath import PixelBreath
 from eitprocessing.parameters import ParameterCalculation
@@ -73,7 +74,7 @@ def __post_init__(self, breath_detection_kwargs: dict | None) -> None:
     def compute_parameter(
         self,
         data: ContinuousData | EITData,
-    ) -> NoReturn:
+    ) -> SparseData:
         """Compute the tidal impedance variation per breath on either ContinuousData or EITData, depending on the input.
 
         Args:
@@ -282,12 +283,12 @@ def _calculate_tiv_values(
         self,
         data: np.ndarray,
         time: np.ndarray,
-        breaths: list[Breath],
+        breaths: StructuredArray[Breath],
         tiv_method: str,
         tiv_timing: str,  # noqa: ARG002  # remove when restructuring
     ) -> list:
-        # Filter out None breaths
-        breaths = np.array(breaths)
-        valid_breath_indices = np.flatnonzero([breath is not None for breath in breaths])
-        valid_breaths = breaths[valid_breath_indices]