diff --git a/eitprocessing/datahandling/__init__.py b/eitprocessing/datahandling/__init__.py index 190182068..823e57371 100644 --- a/eitprocessing/datahandling/__init__.py +++ b/eitprocessing/datahandling/__init__.py @@ -1,3 +1,4 @@ +import dataclasses from copy import deepcopy from dataclasses import dataclass @@ -6,7 +7,7 @@ from eitprocessing.datahandling.mixins.equality import Equivalence -@dataclass(eq=False) +@dataclass(eq=False, frozen=True) class DataContainer(Equivalence): """Base class for data container classes.""" @@ -16,3 +17,14 @@ def __bool__(self): def deepcopy(self) -> Self: """Return a deep copy of the object.""" return deepcopy(self) + + def update(self: Self, **kwargs: object) -> Self: + """Return a copy of the object with specified fields replaced. + + Args: + **kwargs: Fields to replace. + + Returns: + A new instance of the object with the specified fields replaced. + """ + return dataclasses.replace(self, **kwargs) diff --git a/eitprocessing/datahandling/breath.py b/eitprocessing/datahandling/breath.py index ddb5e5702..a68356525 100644 --- a/eitprocessing/datahandling/breath.py +++ b/eitprocessing/datahandling/breath.py @@ -1,8 +1,13 @@ -from collections.abc import Iterator +from __future__ import annotations + from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Iterator -@dataclass +@dataclass(frozen=True) class Breath: """Represents a breath with a start, middle and end time.""" @@ -14,7 +19,7 @@ def __post_init__(self): if self.start_time >= self.middle_time or self.middle_time >= self.end_time: msg = ( "Start, middle and end should be consecutive, not " - "{self.start_time:.2f}, {self.middle_time:.2f} and {self.end_time:.2f}" + f"{self.start_time:.2f}, {self.middle_time:.2f} and {self.end_time:.2f}" ) raise ValueError(msg) diff --git a/eitprocessing/datahandling/continuousdata.py b/eitprocessing/datahandling/continuousdata.py index 8ca919465..e119c8522 100644 --- a/eitprocessing/datahandling/continuousdata.py +++ b/eitprocessing/datahandling/continuousdata.py @@ -1,23 +1,24 @@ from __future__ import annotations import warnings -from dataclasses import dataclass, field +from dataclasses import KW_ONLY, dataclass, field from typing import TYPE_CHECKING, TypeVar import numpy as np from eitprocessing.datahandling import DataContainer from eitprocessing.datahandling.mixins.slicing import SelectByTime +from eitprocessing.utils.frozen_array import freeze_array if TYPE_CHECKING: from collections.abc import Callable - from typing_extensions import Any, Self + from typing_extensions import Self T = TypeVar("T", bound="ContinuousData") -@dataclass(eq=False) +@dataclass(eq=False, frozen=True) class ContinuousData(DataContainer, SelectByTime): """Container for data with a continuous time axis. @@ -32,27 +33,20 @@ class ContinuousData(DataContainer, SelectByTime): unit: Unit of the data, if applicable. category: Category the data falls into, e.g. 'airway pressure'. description: Human readable extended description of the data. - parameters: Parameters used to derive this data. - derived_from: Traceback of intermediates from which the current data was derived. values: Data points. 
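The new `DataContainer.update()` added above is a thin wrapper around `dataclasses.replace()`, which is the supported way to "modify" the now-frozen containers. A minimal sketch of the pattern, using a hypothetical stand-in class rather than the real `DataContainer`:

```python
import dataclasses
from dataclasses import dataclass


@dataclass(frozen=True)
class Container:
    # hypothetical stand-in for a DataContainer subclass
    label: str
    unit: str = "a.u."

    def update(self, **kwargs: object) -> "Container":
        # dataclasses.replace builds a new instance, re-running __init__
        # (and __post_init__) with the given fields overridden
        return dataclasses.replace(self, **kwargs)


raw = Container(label="raw")
filtered = raw.update(label="filtered")
assert filtered.label == "filtered" and filtered.unit == "a.u."
assert raw.label == "raw"  # the original object is left untouched
```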
""" - label: str = field(compare=False) - name: str = field(compare=False, repr=False) - unit: str = field(metadata={"check_equivalence": True}, repr=False) - category: str = field(metadata={"check_equivalence": True}, repr=False) - description: str = field(default="", compare=False, repr=False) - parameters: dict[str, Any] = field(default_factory=dict, repr=False, metadata={"check_equivalence": True}) - derived_from: Any | list[Any] = field(default_factory=list, repr=False, compare=False) - time: np.ndarray = field(kw_only=True, repr=False) - values: np.ndarray = field(kw_only=True, repr=False) + time: np.ndarray = field(repr=False) + values: np.ndarray = field(repr=False) + _: KW_ONLY + label: str | None = field(compare=False, default=None) + name: str | None = field(compare=False, repr=False, default=None) + description: str | None = field(compare=False, repr=False, default=None) + unit: str | None = field(metadata={"check_equivalence": True}, repr=False, default=None) + category: str | None = field(metadata={"check_equivalence": True}, repr=False, default=None) sample_frequency: float | None = field(kw_only=True, repr=False, metadata={"check_equivalence": True}, default=None) def __post_init__(self) -> None: - if self.loaded: - self.lock() - self.lock("time") - if self.sample_frequency is None: msg = ( "`sample_frequency` is set to `None`. This will not be supported in future versions. " @@ -64,48 +58,8 @@ def __post_init__(self) -> None: msg = f"The number of time points ({lt}) does not match the number of values ({lv})." raise ValueError(msg) - def __setattr__(self, attr: str, value: Any): # noqa: ANN401 - try: - old_value = getattr(self, attr) - except AttributeError: - pass - else: - if isinstance(old_value, np.ndarray) and old_value.flags["WRITEABLE"] is False: - msg = f"Attribute '{attr}' is locked and can't be overwritten." - raise AttributeError(msg) - super().__setattr__(attr, value) - - def copy( - self, - label: str, - *, - name: str | None = None, - unit: str | None = None, - description: str | None = None, - parameters: dict | None = None, - ) -> Self: - """Create a copy. - - Whenever data is altered, it should probably be copied first. The alterations should then be made in the copy. 
- """ - obj = self.__class__( - label=label, - name=name or label, - unit=unit or self.unit, - description=description or f"Derived from {self.name}", - parameters=self.parameters | (parameters or {}), - derived_from=[*self.derived_from, self], - category=self.category, - # copying data can become inefficient with large datasets if the - # data is not directly edited afer copying but overridden instead; - # consider creating a view and locking it, requiring the user to - # make a copy if they want to edit the data directly - time=np.copy(self.time), - values=np.copy(self.values), - sample_frequency=self.sample_frequency, - ) - obj.unlock() - return obj + object.__setattr__(self, "time", freeze_array(self.time)) + object.__setattr__(self, "values", freeze_array(self.values)) def __add__(self: Self, other: Self) -> Self: return self.concatenate(other) @@ -128,7 +82,6 @@ def concatenate(self: Self, other: Self, newlabel: str | None = None) -> Self: category=self.category, time=np.concatenate((self.time, other.time)), values=np.concatenate((self.values, other.values)), - derived_from=[*self.derived_from, *other.derived_from, self, other], sample_frequency=self.sample_frequency, ) @@ -173,61 +126,6 @@ def convert_data(x, add=None, subtract=None, multiply=None, divide=None): copy.values = function(copy.values, **func_args) return copy - def lock(self, *attr: str) -> None: - """Lock attributes, essentially rendering them read-only. - - Locked attributes cannot be overwritten. Attributes can be unlocked using `unlock()`. - - Args: - *attr: any number of attributes can be passed here, all of which will be locked. Defaults to "values". - - Examples: - >>> # lock the `values` attribute of `data` - >>> data.lock() - >>> data.values = [1, 2, 3] # will result in an AttributeError - >>> data.values[0] = 1 # will result in a RuntimeError - """ - if not attr: - # default values are not allowed when using *attr, so set a default here if none is supplied - attr = ("values",) - for attr_ in attr: - getattr(self, attr_).flags["WRITEABLE"] = False - - def unlock(self, *attr: str) -> None: - """Unlock attributes, rendering them editable. - - Locked attributes cannot be overwritten, but can be unlocked with this function to make them editable. - - Args: - *attr: any number of attributes can be passed here, all of which will be unlocked. Defaults to "values". - - Examples: - >>> # lock the `values` attribute of `data` - >>> data.lock() - >>> data.values = [1, 2, 3] # will result in an AttributeError - >>> data.values[0] = 1 # will result in a RuntimeError - >>> data.unlock() - >>> data.values = [1, 2, 3] - >>> print(data.values) - [1,2,3] - >>> data.values[0] = 1 # will result in a RuntimeError - >>> print(data.values) - 1 - """ - if not attr: - # default values are not allowed when using *attr, so set a default here if none is supplied - attr = ("values",) - for attr_ in attr: - getattr(self, attr_).flags["WRITEABLE"] = True - - @property - def locked(self) -> bool: - """Return whether the values attribute is locked. - - See lock(). 
- """ - return not self.values.flags["WRITEABLE"] - @property def loaded(self) -> bool: """Return whether the data was loaded from disk, or derived from elsewhere.""" @@ -243,19 +141,7 @@ def _sliced_copy( newlabel: str, # noqa: ARG002 ) -> Self: # TODO: check correct implementation - cls = self.__class__ - time = np.copy(self.time[start_index:end_index]) - values = np.copy(self.values[start_index:end_index]) - description = f"Slice ({start_index}-{end_index}) of <{self.description}>" + time = self.time[start_index:end_index] + values = self.values[start_index:end_index] - return cls( - label=self.label, # TODO: newlabel gives errors - name=self.name, - unit=self.unit, - category=self.category, - description=description, - derived_from=[*self.derived_from, self], - time=time, - values=values, - sample_frequency=self.sample_frequency, - ) + return self.update(time=time, values=values) diff --git a/eitprocessing/datahandling/datacollection.py b/eitprocessing/datahandling/datacollection.py index 25bd8874a..a9952dbec 100644 --- a/eitprocessing/datahandling/datacollection.py +++ b/eitprocessing/datahandling/datacollection.py @@ -93,18 +93,6 @@ def _check_item( ) raise KeyError(msg) - def get_loaded_data(self) -> dict[str, V]: - """Return all data that was directly loaded from disk.""" - return {k: v for k, v in self.items() if v.loaded} - - def get_data_derived_from(self, obj: V) -> dict[str, V]: - """Return all data that was derived from a specific source.""" - return {k: v for k, v in self.items() if any(obj is item for item in v.derived_from)} - - def get_derived_data(self) -> dict[str, V]: - """Return all data that was derived from any source.""" - return {k: v for k, v in self.items() if v.derived_from} - def concatenate(self: Self, other: Self) -> Self: """Concatenate this collection with an equivalent collection. diff --git a/eitprocessing/datahandling/eitdata.py b/eitprocessing/datahandling/eitdata.py index ea50d9bf0..ae3ae34d1 100644 --- a/eitprocessing/datahandling/eitdata.py +++ b/eitprocessing/datahandling/eitdata.py @@ -1,7 +1,8 @@ from __future__ import annotations import warnings -from dataclasses import InitVar, dataclass, field +from collections.abc import Sequence as SequenceType +from dataclasses import KW_ONLY, InitVar, dataclass, field from enum import Enum from pathlib import Path from typing import TYPE_CHECKING, Any, TypeVar @@ -11,6 +12,7 @@ from eitprocessing.datahandling import DataContainer from eitprocessing.datahandling.continuousdata import ContinuousData from eitprocessing.datahandling.mixins.slicing import SelectByTime +from eitprocessing.utils.frozen_array import freeze_array if TYPE_CHECKING: from typing_extensions import Self @@ -19,7 +21,7 @@ T = TypeVar("T", bound="EITData") -@dataclass(eq=False) +@dataclass(eq=False, frozen=True) class EITData(DataContainer, SelectByTime): """Container for EIT impedance data. @@ -30,57 +32,118 @@ class is meant to hold data from (part of) a singular continuous measurement. disk. Args: - path: The path of list of paths of the source from which data was derived. - nframes: Number of frames. time: The time of each frame (since start measurement). + values: Impedance values for each pixel at each frame. sample_frequency: The (average) frequency at which the frames are collected, in Hz. vendor: The vendor of the device the data was collected with. + path: The path of list of paths of the source from which data was derived. label: Computer readable label identifying this dataset. name: Human readable name for the data. 
- pixel_impedance: Impedance values for each pixel at each frame. + description: Human readable description of the data. """ # TODO: fix docstring - path: str | Path | list[Path | str] = field(compare=False, repr=False) - nframes: int = field(repr=False) time: np.ndarray = field(repr=False) + values: np.ndarray = field(repr=False) + _: KW_ONLY sample_frequency: float = field(metadata={"check_equivalence": True}, repr=False) vendor: Vendor = field(metadata={"check_equivalence": True}, repr=False) + path: tuple[Path] | None = field(compare=False, repr=False, default=None) label: str | None = field(default=None, compare=False, metadata={"check_equivalence": True}) - description: str = field(default="", compare=False, repr=False) + description: str | None = field(default=None, compare=False, repr=False) name: str | None = field(default=None, compare=False, repr=False) - pixel_impedance: np.ndarray = field(repr=False, kw_only=True) suppress_simulated_warning: InitVar[bool] = False - def __post_init__(self, suppress_simulated_warning: bool) -> None: - if not self.label: - self.label = f"{self.__class__.__name__}_{id(self)}" - - self.path = self.ensure_path_list(self.path) - if len(self.path) == 1: - self.path = self.path[0] + def __init__( + self, + time: np.ndarray, + values: np.ndarray | None = None, + *, + sample_frequency: float, + vendor: Vendor | str, + path: str | Path | SequenceType[Path | str] | None = None, + label: str | None = None, + description: str | None = None, + name: str | None = None, + suppress_simulated_warning: bool = False, + **kwargs, + ): + values = self._parse_kwargs(values, kwargs) + + if not isinstance(values, np.ndarray): + msg = f"'values' must be a numpy ndarray, not {type(values)}." + raise TypeError(msg) - self.name = self.name or self.label - old_sample_frequency = self.sample_frequency - self.sample_frequency = float(self.sample_frequency) - if self.sample_frequency != old_sample_frequency: + label = label or f"{self.__class__.__name__}_{id(self)}" + object.__setattr__(self, "label", label) + object.__setattr__(self, "name", name) + object.__setattr__(self, "description", description) + + if path is None: + object.__setattr__(self, "path", None) + else: + path_list = self.ensure_path_tuple(path) + if len(path_list) == 1: + object.__setattr__(self, "path", path_list[0]) + else: + object.__setattr__(self, "path", path_list) + + object.__setattr__(self, "sample_frequency", float(sample_frequency)) + if self.sample_frequency != sample_frequency: msg = ( "Sample frequency could not be correctly converted from " - f"{old_sample_frequency} ({type(old_sample_frequency)}) to " + f"{sample_frequency} ({type(sample_frequency)}) to " f"{self.sample_frequency:.1f} (float)." ) raise TypeError(msg) - if (lv := len(self.pixel_impedance)) != (lt := len(self.time)): + if (lv := len(values)) != (lt := len(time)): msg = f"The number of time points ({lt}) does not match the number of pixel impedance values ({lv})." raise ValueError(msg) - if not suppress_simulated_warning and self.vendor == Vendor.SIMULATED: + object.__setattr__(self, "values", freeze_array(values)) + object.__setattr__(self, "time", freeze_array(time)) + + vendor = Vendor(vendor) + if not suppress_simulated_warning and vendor == Vendor.SIMULATED: warnings.warn( "The simulated vendor is used for testing purposes. 
" "It is not a real vendor and should not be used in production code.", UserWarning, stacklevel=2, ) + object.__setattr__(self, "vendor", vendor) + + def _parse_kwargs(self, values: np.ndarray | None, kwargs: dict[str, Any]) -> np.ndarray | None: + if "pixel_impedance" in kwargs: + if values is not None: + msg = "Cannot provide both 'pixel_impedance' and 'values'." + raise ValueError(msg) + warnings.warn("`pixel_impedance` has been replaced by `values`.", DeprecationWarning, stacklevel=2) + values = kwargs.pop("pixel_impedance") + + if "nframes" in kwargs: + warnings.warn( + "`nframes` is no longer a constructor argument. Use `len(eitdata)` instead.", + DeprecationWarning, + stacklevel=2, + ) + _ = kwargs.pop("nframes") + + if kwargs: + msg = f"Unexpected keyword arguments: {', '.join(kwargs.keys())}." + raise TypeError(msg) + return values + + @property + def pixel_impedance(self) -> np.ndarray: + """Alias to `values`.""" + return self.values + + @property + def nframes(self) -> int: + """Number of frames in the data.""" + warnings.warn("`nframes` is deprecated. Use `len(eitdata)` instead.", DeprecationWarning, stacklevel=2) + return self.values.shape[0] @property def framerate(self) -> float: @@ -93,17 +156,17 @@ def framerate(self) -> float: return self.sample_frequency @staticmethod - def ensure_path_list( - path: str | Path | list[str | Path], - ) -> list[Path]: + def ensure_path_tuple( + path: str | Path | SequenceType[str | Path], + ) -> tuple[Path, ...]: """Return the path or paths as a list. The path of any EITData object can be a single str/Path or a list of str/Path objects. This method returns a list of Path objects given either a str/Path or list of str/Paths. """ - if isinstance(path, list): - return [Path(p) for p in path] - return [Path(path)] + if isinstance(path, SequenceType): + return tuple(Path(p) for p in path) + return (Path(path),) def __add__(self: Self, other: Self) -> Self: return self.concatenate(other) @@ -115,18 +178,20 @@ def concatenate(self: Self, other: Self, newlabel: str | None = None) -> Self: msg = f"Concatenation failed. Second dataset ({other.name}) may not start before first ({self.name}) ends." 
raise ValueError(msg) - self_path = self.ensure_path_list(self.path) - other_path = self.ensure_path_list(other.path) + self_path = () if self.path is None else self.ensure_path_tuple(self.path) + other_path = () if other.path is None else self.ensure_path_tuple(other.path) + concat_path = (*self_path, *other_path) + if not concat_path: + concat_path = None newlabel = newlabel or f"Merge of <{self.label}> and <{other.label}>" return self.__class__( vendor=self.vendor, - path=[*self_path, *other_path], + path=concat_path, label=self.label, # TODO: using newlabel leads to errors sample_frequency=self.sample_frequency, - nframes=self.nframes + other.nframes, time=np.concatenate((self.time, other.time)), - pixel_impedance=np.concatenate((self.pixel_impedance, other.pixel_impedance), axis=0), + values=np.concatenate((self.values, other.values), axis=0), ) def _sliced_copy( @@ -135,24 +200,13 @@ def _sliced_copy( end_index: int, newlabel: str, # noqa: ARG002 ) -> Self: - cls = self.__class__ - time = np.copy(self.time[start_index:end_index]) - nframes = len(time) - - pixel_impedance = np.copy(self.pixel_impedance[start_index:end_index, :, :]) - - return cls( - path=self.path, - nframes=nframes, - vendor=self.vendor, - time=time, - sample_frequency=self.sample_frequency, - label=self.label, # newlabel gives errors - pixel_impedance=pixel_impedance, + return self.update( + time=self.time[start_index:end_index], + values=self.values[start_index:end_index, :, :], ) def __len__(self): - return self.pixel_impedance.shape[0] + return self.values.shape[0] def get_summed_impedance(self, *, return_label: str | None = None, **return_kwargs) -> ContinuousData: """Return a ContinuousData-object with the same time axis and summed pixel values over time. @@ -162,7 +216,7 @@ def get_summed_impedance(self, *, return_label: str | None = None, **return_kwar the current object. **return_kwargs: Keyword arguments for the creation of the returned object. """ - summed_impedance = np.nansum(self.pixel_impedance, axis=(1, 2)) + summed_impedance = np.nansum(self.values, axis=(1, 2)) if return_label is None: return_label = f"summed {self.label}" @@ -174,11 +228,40 @@ def get_summed_impedance(self, *, return_label: str | None = None, **return_kwar "sample_frequency": self.sample_frequency, } | return_kwargs - return ContinuousData(label=return_label, time=np.copy(self.time), values=summed_impedance, **return_kwargs_) + return ContinuousData(label=return_label, time=self.time, values=summed_impedance, **return_kwargs_) def calculate_global_impedance(self) -> np.ndarray: """Return the global impedance, i.e. the sum of all included pixels at each frame.""" - return np.nansum(self.pixel_impedance, axis=(1, 2)) + return np.nansum(self.values, axis=(1, 2)) + + def update(self, **kwargs) -> Self: + """Return a copy of the object with specified fields replaced. + + Args: + **kwargs: Fields to replace. + + Returns: + A new instance of the object with the specified fields replaced. + """ + if "pixel_impedance" in kwargs: + if "values" in kwargs: + msg = "Cannot provide both 'pixel_impedance' and 'values'." + raise ValueError(msg) + warnings.warn("`pixel_impedance` has been replaced by `values`.", DeprecationWarning, stacklevel=2) + kwargs["values"] = kwargs.pop("pixel_impedance") + + if "framerate" in kwargs: + if "sample_frequency" in kwargs: + msg = "Cannot provide both 'framerate' and 'sample_frequency'." + raise ValueError(msg) + warnings.warn( + "`framerate` has been deprecated. 
Use `sample_frequency` instead.", + DeprecationWarning, + stacklevel=2, + ) + kwargs["sample_frequency"] = kwargs.pop("framerate") + + return super().update(**kwargs) class Vendor(Enum): diff --git a/eitprocessing/datahandling/event.py b/eitprocessing/datahandling/event.py index d1a1c3aa4..19d7715ec 100644 --- a/eitprocessing/datahandling/event.py +++ b/eitprocessing/datahandling/event.py @@ -1,7 +1,7 @@ from dataclasses import dataclass -@dataclass +@dataclass(frozen=True) class Event: """Single time point event registered during an EIT measurement.""" diff --git a/eitprocessing/datahandling/intervaldata.py b/eitprocessing/datahandling/intervaldata.py index 2ae95826a..73c0fd813 100644 --- a/eitprocessing/datahandling/intervaldata.py +++ b/eitprocessing/datahandling/intervaldata.py @@ -1,24 +1,39 @@ +import contextlib import copy +import itertools +from collections.abc import Iterator from dataclasses import dataclass, field -from typing import Any, NamedTuple, TypeVar +from typing import Any, TypeVar import numpy as np from typing_extensions import Self from eitprocessing.datahandling import DataContainer from eitprocessing.datahandling.mixins.slicing import HasTimeIndexer, SelectByIndex +from eitprocessing.datahandling.structured_array import StructuredArray T = TypeVar("T", bound="IntervalData") -class Interval(NamedTuple): +@dataclass(frozen=True) +class Interval: """A tuple containing the start time and end time of an interval.""" start_time: float end_time: float + def __post_init__(self): + if self.start_time >= self.end_time: + msg = f"Interval start time ({self.start_time:.2f}) must be less than end time ({self.end_time:.2f})." + raise ValueError(msg) + + @property + def duration(self) -> float: + """Duration of the interval.""" + return self.end_time - self.start_time -@dataclass(eq=False) + +@dataclass(eq=False, frozen=True) class IntervalData(DataContainer, SelectByIndex, HasTimeIndexer): """Container for interval data existing over a period of time. @@ -43,8 +58,6 @@ class IntervalData(DataContainer, SelectByIndex, HasTimeIndexer): category: Category the data falls into, e.g. 'breath'. intervals: A list of intervals (tuples containing a start time and end time). values: An optional list of values associated with each interval. - parameters: Parameters used to derive the data. - derived_from: Traceback of intermediates from which the current data was derived. description: Extended human readible description of the data. 
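`Interval` changes from a NamedTuple into a frozen dataclass that validates its bounds in `__post_init__` and exposes a computed `duration`. Roughly how it behaves, with made-up values and assuming this branch of the package is installed:

```python
from eitprocessing.datahandling.intervaldata import Interval

breath_hold = Interval(start_time=12.5, end_time=17.0)
print(breath_hold.duration)  # 4.5

try:
    Interval(start_time=17.0, end_time=12.5)  # start_time must be less than end_time
except ValueError as error:
    print(error)
```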
default_partial_inclusion: Whether to include a trimmed version of an interval when selecting data """ @@ -53,20 +66,73 @@ label: str = field(metadata={"check_equivalence": True}) name: str = field(compare=False, repr=False) unit: str | None = field(metadata={"check_equivalence": True}, repr=False) category: str = field(metadata={"check_equivalence": True}, repr=False) - intervals: list[Interval | tuple[float, float]] = field(repr=False) - values: list[Any] | None = field(repr=False, default=None) - parameters: dict[str, Any] = field(default_factory=dict, metadata={"check_equivalence": True}, repr=False) - derived_from: list[Any] = field(default_factory=list, compare=False, repr=False) + intervals: StructuredArray[Interval] = field(repr=False) + values: list[Any] | np.ndarray | None | StructuredArray = field(repr=False, default=None) description: str = field(compare=False, default="", repr=False) default_partial_inclusion: bool = field(repr=False, default=False) def __post_init__(self) -> None: - self.intervals = [Interval._make(interval) for interval in self.intervals] + object.__setattr__(self, "intervals", self._parse_intervals(self.intervals)) + + if self.values is not None and len(self.values) == 0: + object.__setattr__(self, "values", None) - if self.has_values and (lv := len(self.values)) != (lt := len(self.intervals)): + if self.values is not None and (lv := len(self.values)) != (lt := len(self.intervals)): msg = f"The number of time points ({lt}) does not match the number of values ({lv})." raise ValueError(msg) + if isinstance(self.values, list): + try: + object.__setattr__(self, "values", StructuredArray(self.values)) + except TypeError: + object.__setattr__(self, "values", np.array(self.values)) + + def __iter__(self) -> Iterator[tuple[Interval, Any | None]]: + if self.values is not None: + return iter(zip(self.intervals, self.values, strict=True)) + + return iter(zip(self.intervals, itertools.repeat(None), strict=False)) + + @staticmethod + def _parse_intervals( + intervals: list[Interval] | np.ndarray | StructuredArray[Interval], + ) -> StructuredArray[Interval]: + """Parse intervals into a StructuredArray of Interval.""" + if isinstance(intervals, StructuredArray): + if intervals.item_type is not Interval: + msg = f"Expected intervals of type 'Interval', got '{intervals.dtype}'" + raise TypeError(msg) + if intervals.items.ndim != 1: + msg = f"Intervals should be a 1D array of Interval, got {intervals.items.ndim}D array." + raise ValueError(msg) + return intervals + + if isinstance(intervals, np.ndarray): + if intervals.ndim != 2 or intervals.shape[1] != 2: # noqa: PLR2004 + msg = f"Intervals should be a 2D array (1D array of Interval data), got {intervals.ndim}D array." + raise ValueError(msg) + try: + return StructuredArray.from_array(intervals, item_type=Interval) + except ValueError as e: + msg = f"Could not parse intervals from numpy array with dtype '{intervals.dtype}'" + raise TypeError(msg) from e + + if isinstance(intervals, list | tuple): + with contextlib.suppress(TypeError): + return StructuredArray(intervals, item_type=Interval) + + with contextlib.suppress(TypeError): + return StructuredArray([Interval(*interval) for interval in intervals], item_type=Interval) + + with contextlib.suppress(TypeError): + return StructuredArray.from_array(intervals, item_type=Interval) + + msg = ( + "Could not parse intervals from given input. 
Intervals should be a list of Interval objects, a 2D " + "numpy array, or a StructuredArray of Interval." + ) + raise TypeError(msg) + def __len__(self) -> int: return len(self.intervals) @@ -92,9 +163,9 @@ def _sliced_copy( unit=self.unit, category=self.category, description=description, - derived_from=[*self.derived_from, self], intervals=intervals, values=values, + default_partial_inclusion=self.default_partial_inclusion, ) def select_by_time( @@ -128,12 +199,6 @@ def select_by_time( """ newlabel = newlabel or self.label - if start_time is None and end_time is None: - copy_ = copy.deepcopy(self) - copy_.derived_from.append(self) - copy_.label = newlabel - return copy_ - partial_inclusion = partial_inclusion or self.default_partial_inclusion selection_start = start_time or self.intervals[0].start_time @@ -157,9 +222,9 @@ def select_by_time( name=self.name, unit=self.unit, category=self.category, - derived_from=[*self.derived_from, self], intervals=list(filtered_intervals), values=values, + default_partial_inclusion=self.default_partial_inclusion, ) @staticmethod @@ -228,7 +293,6 @@ def concatenate(self: Self, other: Self, newlabel: str | None = None) -> Self: unit=self.unit, category=self.category, description=self.description, - derived_from=[*self.derived_from, *other.derived_from, self, other], intervals=self.intervals + other.intervals, values=new_values, ) diff --git a/eitprocessing/datahandling/loading/__init__.py b/eitprocessing/datahandling/loading/__init__.py index b69180f54..e8acd5071 100644 --- a/eitprocessing/datahandling/loading/__init__.py +++ b/eitprocessing/datahandling/loading/__init__.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence as SequenceType from functools import reduce from pathlib import Path @@ -7,7 +8,7 @@ def load_eit_data( - path: str | Path | list[str | Path], + path: str | Path | SequenceType[str | Path], vendor: Vendor | str, label: str | None = None, name: str | None = None, @@ -52,7 +53,7 @@ def load_eit_data( ... vendor="sentec", ... label="initial_measurement" ... 
) - >>> pixel_impedance = sequence.eit_data["raw"].pixel_impedance + >>> pixel_impedance = sequence.eit_data["raw"].values ``` """ from eitprocessing.datahandling.loading import ( # noqa: PLC0415 @@ -70,7 +71,7 @@ def load_eit_data( first_frame = _check_first_frame(first_frame) - paths = EITData.ensure_path_list(path) + paths = EITData.ensure_path_tuple(path) eit_datasets: list[DataCollection] = [] continuous_datasets: list[DataCollection] = [] diff --git a/eitprocessing/datahandling/loading/draeger.py b/eitprocessing/datahandling/loading/draeger.py index 5fb048fd8..d3383c9ca 100644 --- a/eitprocessing/datahandling/loading/draeger.py +++ b/eitprocessing/datahandling/loading/draeger.py @@ -120,10 +120,9 @@ def load_from_single_path( # noqa: PLR0915 vendor=Vendor.DRAEGER, path=path, sample_frequency=sample_frequency, - nframes=n_frames, time=time, label="raw", - pixel_impedance=pixel_impedance, + values=pixel_impedance, ) eitdata_collection = DataCollection(EITData, raw=eit_data) @@ -141,7 +140,6 @@ def load_from_single_path( # noqa: PLR0915 name="Global impedance (raw)", unit="a.u.", category="impedance", - derived_from=[eit_data], time=eit_data.time, values=eit_data.calculate_global_impedance(), sample_frequency=sample_frequency, @@ -153,7 +151,6 @@ def load_from_single_path( # noqa: PLR0915 name="Minimum values detected by Draeger device.", unit=None, category="minvalue", - derived_from=[eit_data], time=np.array([t for t, d in phases if d == -1]), ), ) @@ -163,7 +160,6 @@ def load_from_single_path( # noqa: PLR0915 name="Maximum values detected by Draeger device.", unit=None, category="maxvalue", - derived_from=[eit_data], time=np.array([t for t, d in phases if d == 1]), ), ) @@ -179,7 +175,6 @@ def load_from_single_path( # noqa: PLR0915 name="Events loaded from Draeger data", unit=None, category="event", - derived_from=[eit_data], time=time, values=events, ), @@ -260,7 +255,6 @@ def _convert_medibus_data( category=field_info.signal_name, sample_frequency=sample_frequency, ) - continuous_data.lock() continuousdata_collection.add(continuous_data) else: diff --git a/eitprocessing/datahandling/loading/sentec.py b/eitprocessing/datahandling/loading/sentec.py index 5a252eda2..aee08c606 100644 --- a/eitprocessing/datahandling/loading/sentec.py +++ b/eitprocessing/datahandling/loading/sentec.py @@ -86,10 +86,9 @@ def load_from_single_path( vendor=Vendor.SENTEC, path=path, sample_frequency=sample_frequency, - nframes=n_frames, time=time_array, label="raw", - pixel_impedance=image, + values=image, ), ) diff --git a/eitprocessing/datahandling/loading/timpel.py b/eitprocessing/datahandling/loading/timpel.py index 2a93107f7..286641a62 100644 --- a/eitprocessing/datahandling/loading/timpel.py +++ b/eitprocessing/datahandling/loading/timpel.py @@ -10,9 +10,10 @@ from eitprocessing.datahandling.continuousdata import ContinuousData from eitprocessing.datahandling.datacollection import DataCollection from eitprocessing.datahandling.eitdata import EITData, Vendor -from eitprocessing.datahandling.intervaldata import IntervalData +from eitprocessing.datahandling.intervaldata import Interval, IntervalData from eitprocessing.datahandling.loading import load_eit_data from eitprocessing.datahandling.sparsedata import SparseData +from eitprocessing.datahandling.structured_array import StructuredArray if TYPE_CHECKING: from pathlib import Path @@ -96,10 +97,9 @@ def load_from_single_path( vendor=Vendor.TIMPEL, label="raw", path=path, - nframes=nframes, time=time, sample_frequency=sample_frequency, - 
pixel_impedance=pixel_impedance, + values=pixel_impedance, ) eitdata_collection = DataCollection(EITData, raw=eit_data) @@ -109,13 +109,13 @@ def load_from_single_path( continuousdata_collection = DataCollection(ContinuousData) continuousdata_collection.add( ContinuousData( - "global_impedance_(raw)", - "Global impedance", - "a.u.", - "global_impedance", - "Global impedance calculated from raw EIT data", time=time, values=eit_data.calculate_global_impedance(), + label="global_impedance_(raw)", + name="Global impedance", + unit="a.u.", + category="global_impedance", + description="Global impedance calculated from raw EIT data", sample_frequency=sample_frequency, ), ) @@ -168,7 +168,6 @@ def load_from_single_path( name="Minimum values detected by Timpel device.", unit=None, category="minvalue", - derived_from=[eit_data], time=time[min_indices], ), ) @@ -180,14 +179,13 @@ def load_from_single_path( name="Maximum values detected by Timpel device.", unit=None, category="maxvalue", - derived_from=[eit_data], time=time[max_indices], ), ) gi = continuousdata_collection["global_impedance_(raw)"].values - time_ranges, breaths = _make_breaths(time, min_indices, max_indices, gi) + intervals, breaths = _make_breaths(time, min_indices, max_indices, gi) intervaldata_collection = DataCollection(IntervalData) intervaldata_collection.add( IntervalData( @@ -195,7 +193,7 @@ def load_from_single_path( name="Breaths (Timpel)", unit=None, category="breaths", - intervals=time_ranges, + intervals=intervals, values=breaths, default_partial_inclusion=False, ), @@ -208,7 +206,6 @@ def load_from_single_path( name="QRS complexes detected by Timpel device", unit=None, category="qrs_complex", - derived_from=[eit_data], time=time[qrs_indices], ), ) @@ -226,12 +223,12 @@ def _make_breaths( min_indices: np.ndarray, max_indices: np.ndarray, gi: np.ndarray, -) -> tuple[list[tuple[float, float]], list[Breath]]: +) -> tuple[StructuredArray[Interval], StructuredArray[Breath]]: # TODO: replace section with BreathDetection._remove_doubles() and BreathDetection._remove_edge_cases() from # 41_breath_detection_psomhorst; this code was directly copied from b59ac54 if len(min_indices) < 2 or len(max_indices) < 1: # noqa: PLR2004 - return [], [] + return StructuredArray([], Interval), StructuredArray([], Breath) valley_indices = min_indices.copy() peak_indices = max_indices.copy() @@ -276,9 +273,8 @@ def _make_breaths( current_valley_index += 1 - breaths = [] - for start, end, middle in zip(valley_indices[:-1], valley_indices[1:], peak_indices, strict=True): - breaths.append(((time[start], time[end]), Breath(time[start], time[middle], time[end]))) + times = time[np.column_stack([valley_indices[:-1], peak_indices, valley_indices[1:]])] + breaths = StructuredArray.from_array(times, Breath) + intervals = StructuredArray.from_array(times[:, [0, 2]], Interval) - time_ranges, values = zip(*breaths, strict=True) - return list(time_ranges), list(values) + return intervals, breaths diff --git a/eitprocessing/datahandling/sparsedata.py b/eitprocessing/datahandling/sparsedata.py index 2f0bd8365..f23177fb9 100644 --- a/eitprocessing/datahandling/sparsedata.py +++ b/eitprocessing/datahandling/sparsedata.py @@ -11,7 +11,7 @@ T = TypeVar("T", bound="SparseData") -@dataclass(eq=False) +@dataclass(eq=False, frozen=True) class SparseData(DataContainer, SelectByTime): """Container for data related to individual time points. @@ -30,8 +30,6 @@ class SparseData(DataContainer, SelectByTime): unit: Unit of the data, if applicable. 
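In `_make_breaths`, the per-breath times are stacked so the last axis holds (start, middle, end), and `StructuredArray.from_array` maps that axis onto the fields of `Breath` (and columns 0 and 2 onto `Interval`). A sketch of the same construction with made-up indices, assuming this branch of the package is installed:

```python
import numpy as np

from eitprocessing.datahandling.breath import Breath
from eitprocessing.datahandling.intervaldata import Interval
from eitprocessing.datahandling.structured_array import StructuredArray

time = np.linspace(0.0, 10.0, 101)
valley_indices = np.array([10, 40, 75])  # hypothetical breath boundaries
peak_indices = np.array([25, 60])        # hypothetical peaks between them

# one row per breath; last axis holds (start, middle, end)
times = time[np.column_stack([valley_indices[:-1], peak_indices, valley_indices[1:]])]

breaths = StructuredArray.from_array(times, Breath)
intervals = StructuredArray.from_array(times[:, [0, 2]], Interval)

print(breaths["middle_time"])  # field access across all breaths
print(intervals[0].duration)   # indexing reconstructs an Interval instance
```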
category: Category the data falls into, e.g. 'detected r peak'. description: Human readable extended description of the data. - parameters: Parameters used to derive the data. - derived_from: Traceback of intermediates from which the current data was derived. values: List or array of values. These van be numeric data, text or Python objects. """ @@ -41,8 +39,6 @@ class SparseData(DataContainer, SelectByTime): category: str = field(metadata={"check_equivalence": True}, repr=False) time: np.ndarray = field(repr=False) description: str = field(compare=False, default="", repr=False) - parameters: dict[str, Any] = field(default_factory=dict, metadata={"check_equivalence": True}, repr=False) - derived_from: list[Any] = field(default_factory=list, compare=False, repr=False) values: Any | None = None def __repr__(self) -> str: @@ -78,7 +74,6 @@ def _sliced_copy( unit=self.unit, category=self.category, description=description, - derived_from=[*self.derived_from, self], time=time, values=values, ) @@ -122,7 +117,6 @@ def concatenate(self, other: Self, newlabel: str | None = None) -> Self: # noqa unit=self.unit, category=self.category, description=self.description, - derived_from=[*self.derived_from, *other.derived_from, self, other], time=np.concatenate((self.time, other.time)), values=new_values, ) diff --git a/eitprocessing/datahandling/structured_array.py b/eitprocessing/datahandling/structured_array.py new file mode 100644 index 000000000..defb47a92 --- /dev/null +++ b/eitprocessing/datahandling/structured_array.py @@ -0,0 +1,993 @@ +# %% +"""Array-like interface over NamedTuple and dataclass collections enabling NumPy slicing. + +Motivation +---------- +Some data is best represented with multiple related data points (e.g., the start, middle and end time of a breath). They +can be represented as NamedTuple or dataclass instances; lightweight containers that group related fields together. +Lists (or tuples) of such instances are more difficult to handle efficiently compared to NumPy arrays, especially in +multi-dimensional cases. When such instances are collected inside NumPy arrays, however, they lose their context, +removing the field names and data types. + +This module provides StructuredArray, a container that wraps homogeneous collections of NamedTuple or dataclass +instances into a NumPy structured array, preserving field names and types while enabling NumPy-style slicing and +field-wise access. It allows access to fields and even computed properties by name. Dataclass instances benefit from +automatic __post_init__ validation on access. + +Key features +------------ +- Homogeneous type checking: ensures all items share the same NamedTuple or dataclass type. +- Safe field views: returns read-only views for direct field access. +- Property evaluation: computes per-item properties, resolving postponed + annotations to pick appropriate NumPy dtypes. +- Shape preservation: supports nested sequences, maintaining their shape in the + structured array. +- Validation support: dataclass __post_init__ is called on item access. +- Interop: from_array helper to map array columns to fields. 
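Beyond the examples that follow, field access on a `StructuredArray` returns a read-only view, and `to_array()` flattens the records back into a plain 2D array. A short sketch with a hypothetical `Sample` record, assuming this branch of the package is installed:

```python
from typing import NamedTuple

import numpy as np

from eitprocessing.datahandling.structured_array import StructuredArray


class Sample(NamedTuple):
    # hypothetical two-field record type
    time: float
    value: float


arr = StructuredArray([Sample(0.0, 1.5), Sample(0.5, 2.5), Sample(1.0, 3.5)])

times = arr["time"]     # read-only view of a single field
assert not times.flags.writeable

plain = arr.to_array()  # unstructured (n_items, n_fields) array
assert plain.shape == (3, 2)
assert np.allclose(plain[:, 1], [1.5, 2.5, 3.5])
```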
+ +Example with NamedTuple: +```python +from typing import NamedTuple + +class Coordinate(NamedTuple): + x: float + y: float + z: float + + @property + def r(self) -> float: + \"\"\"The radial distance from the origin.\"\"\" + return (self.x**2 + self.y**2 + self.z**2) ** 0.5 + +coords = [Coordinate(1.0, 2.0, 2.0), Coordinate(3.0, 4.0, 0.0), Coordinate(0.0, 0.0, 5.0)] +arr = StructuredArray(coords) + +arr[0] # Access a single Coordinate +# Coordinate(x=1.0, y=2.0, z=2.0) +arr["x"] # Access x field across all Coordinates +# array([1., 3., 0.]) +arr["r"] # Access computed property across all Coordinates +# array([3., 5., 5.]) +``` + +Example with dataclass: +```python +from dataclasses import dataclass + +@dataclass +class Point: + x: float + y: float + + def __post_init__(self): + if self.x < 0 or self.y < 0: + raise ValueError("Coordinates must be non-negative") + +points = [Point(1.0, 2.0), Point(3.0, 4.0)] +arr = StructuredArray(points) # Validation runs on item construction/access +arr[0] # Point(x=1.0, y=2.0) +``` +""" + +from __future__ import annotations + +import contextlib +import warnings +from dataclasses import fields as dc_fields +from dataclasses import is_dataclass +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Generic, + NamedTuple, + TypeAlias, + TypeGuard, + TypeVar, + Union, + cast, + get_args, + get_origin, + get_type_hints, + overload, +) + +import numpy as np + +from eitprocessing.utils.frozen_array import freeze_array + +if TYPE_CHECKING: + from collections.abc import Generator + from collections.abc import Sequence as SequenceType + + from numpy._core.multiarray import flagsobj + + +T = TypeVar("T", bound=tuple | object) # NamedTuple or dataclass instance +NonStringSeq: TypeAlias = tuple[T, ...] | list[T] +Nested = T | NonStringSeq + +# Mutable types that should trigger warnings when used in frozen dataclasses +MUTABLE_TYPES: tuple[type, ...] = (list, dict, set, bytearray, np.ndarray) + +# Immutable types that are allowed in dataclass fields +# Note: np.generic is included to allow all numpy scalar types (np.int64, np.float32, etc.) +ALLOWED_IMMUTABLE_TYPES: tuple[type, ...] = (str, int, float, bool, bytes, tuple, frozenset, Path, np.generic) + +# All allowed types (used for disallowed type checking) +ALLOWED_TYPES: tuple[type, ...] = ALLOWED_IMMUTABLE_TYPES + MUTABLE_TYPES + + +class StructuredArray(Generic[T]): + """An array-like container for homogeneous NamedTuple or dataclass instances. + + Overview + -------- + StructuredArray wraps a sequence (or nested sequence) of NamedTuple or dataclass items + into a NumPy structured ndarray, enabling: + - NumPy-style indexing and slicing (preserving shape). + - Field access by name that returns read-only NumPy views. + - Computation of per-item properties or attributes, returning a NumPy array + with dtype inferred from the property's type annotation when available. + - Immutability control: pass frozen=True to prevent all array modifications. + - Validation support: dataclass __post_init__ runs on item access. + + Construction + ------------ + - From a sequence: StructuredArray([item1, item2, ...]) + Validates that all items are of the same NamedTuple or dataclass type. + Type is inferred from the first item. For empty sequences, provide item_type explicitly. + - From a sequence with explicit type: StructuredArray([...], item_type=MyType) + Allows type specification upfront, enabling empty sequences and ensuring type safety. 
+ - From an ndarray: StructuredArray.from_array(arr, ItemType) + 'arr' must have last axis equal to the number of fields in ItemType. + Columns along the last axis are mapped to item fields. + - Use StructuredArray(..., frozen=True) or StructuredArray.from_array(..., frozen=True) to make the array + immutable (prevents all modifications). + + Access + ------ + - Field by name: arr["x"] -> returns a read-only view of the field. + - Property/attribute: arr["duration"] -> computes the property for each + item, producing an ndarray. If the property has a type hint, the result + dtype is chosen accordingly (e.g., int -> int64, float -> float64). + Otherwise, a heuristic casts ints to int64, else tries float, else object. + - Direct array access via items property: arr.items (always returns the + underlying NumPy array; if frozen=True, modifications are prevented). + + Notes: + ----- + - Homogeneity: All elements must be the same NamedTuple or dataclass type. + - String fields are kept as object dtype to avoid truncation. + - Properties are evaluated per element; heavy properties may be costly. + - Field views are always read-only to prevent accidental mutation. + - The .items property returns the underlying array; modifications are only + prevented if frozen=True was passed during construction. + - For dataclasses: __post_init__ is called when accessing items via iteration + or indexing to ensure validation happens. + + Example with NamedTuple: + ------- + >>> from typing import NamedTuple + >>> class Breath(NamedTuple): + ... start_time: float + ... middle_time: float + ... end_time: float + ... @property + ... def duration(self) -> float: + ... return self.end_time - self.start_time + ... + >>> breaths = [Breath(0.0, 0.5, 1.0), Breath(1.0, 1.6, 2.1)] + >>> arr = StructuredArray(breaths) + >>> arr["duration"] + array([1. , 1.1]) + + Example with dataclass: + ------- + >>> from dataclasses import dataclass + >>> @dataclass + ... class Event: + ... time: float + ... value: float + ... + ... def __post_init__(self): + ... if self.value < 0: + ... raise ValueError("value must be positive") + ... + >>> events = [Event(0.0, 1.0), Event(1.0, 2.0)] + >>> arr = StructuredArray(events) + >>> arr[0] # Runs __post_init__ validation + Event(time=0.0, value=1.0) + + Example with explicit type for empty list: + ------- + >>> # Create an empty StructuredArray with explicit type + >>> arr = StructuredArray([], item_type=Event) + >>> len(arr) + 0 + """ + + item_type: type[T] + _items: np.ndarray + _is_dataclass: bool + frozen: bool + + def __init__( + self, + items: NonStringSeq[T] | np.ndarray | Nested[T], + item_type: type[T] | None = None, + frozen: bool = True, + ): + """Initialize a StructuredArray from a sequence or nested sequence of items. + + Args: + items: A sequence (or nested sequence) of NamedTuple or dataclass instances, + or a numpy ndarray containing them. + item_type: + Optional explicit type of the items. If provided, all items are validated + against this type. If not provided, the type is inferred from the first + leaf item. Required when items is empty. + frozen: If True (default), makes the underlying array immutable. + + Raises: + ValueError: If items is empty and item_type is None, or if homogeneity check fails. + TypeError: If items contain unsupported types. + """ + if item_type is not None: + self.item_type = item_type + else: + if isinstance(items, np.ndarray) and items.size == 0: + msg = "Cannot infer type from empty array. Provide item_type explicitly." 
+ raise ValueError(msg) + if not isinstance(items, np.ndarray) and not items: + msg = "Cannot infer type from empty sequence. Provide item_type explicitly." + raise ValueError(msg) + + # Infer type from first leaf element + leaf = _first_leaf(items) + self.item_type = type(leaf) # type: ignore[assignment] + + # Validate homogeneity + _check_homogeneous(items, self.item_type) + + # Determine if item_type is a dataclass + self._is_dataclass = is_dataclass(self.item_type) + + # Build structured dtype and array with same shape + dt = _get_struct_dtype(self.item_type) + + # For dataclasses, numpy.asarray doesn't know how to convert them directly, + # so we convert to tuples first (via named tuples or tuples of field values) + if self._is_dataclass: + items = cast("NonStringSeq[T] | np.ndarray | Nested[T]", _dataclass_items_to_tuples(items, self.item_type)) + + self._items = np.asarray(items, dtype=dt) + + # Check for disallowed field types in dataclass + if self._is_dataclass: + disallowed_fields = _get_disallowed_field_names(self.item_type) + if disallowed_fields: + warnings.warn( + f"Dataclass '{self.item_type.__name__}' has fields with disallowed types: " + f"{', '.join(disallowed_fields)}. " + f"Allowed types are: str, int, float, bool, bytes, tuple, frozenset, pathlib.Path, " + f"numpy types, frozen dataclasses, and NamedTuples.", + UserWarning, + stacklevel=2, + ) + + self.frozen = frozen + if frozen: + # Check if dataclass is not frozen while array is + if self._is_dataclass and not _is_dataclass_frozen(self.item_type): + warnings.warn( + f"StructuredArray is frozen, but the underlying dataclass " + f"'{self.item_type.__name__}' is not. Items accessed from the array " + f"can still be modified. Consider freezing the dataclass with " + f"@dataclass(frozen=True).", + UserWarning, + stacklevel=2, + ) + # Check if dataclass has mutable fields (independent check - can warn even if dataclass is frozen) + if self._is_dataclass and _has_mutable_fields(self.item_type): + warnings.warn( + f"StructuredArray is frozen, but the dataclass '{self.item_type.__name__}' has mutable fields " + f"(list, dict, set, etc.). The contents of these fields can still be modified. " + f"Consider using immutable types (tuple) instead.", + UserWarning, + stacklevel=2, + ) + self._freeze() + + object.__setattr__(self, "_initialized", True) + + def _freeze(self) -> None: + """Make the underlying array immutable.""" + dt = _get_struct_dtype(self.item_type) + freeze_method = "flag" if dt.hasobject else "memoryview" + self._items = freeze_array(self._items, method=freeze_method) + + def __setattr__(self, name: str, value: object) -> None: + """Allow setting type and _items only during initialization; block modification after.""" + # Check if initialization is complete (use object.__getattribute__ to bypass our __getattr__) + try: + initialized = object.__getattribute__(self, "_initialized") + except AttributeError: + initialized = False + + # Allow setting type, _items, and _is_dataclass only during initialization + if not initialized and name in ("item_type", "_items", "_is_dataclass", "frozen"): + super().__setattr__(name, value) + elif initialized and name in ("item_type", "_items", "_is_dataclass", "frozen"): + msg = f"{type(self).__name__!r} object is immutable; cannot modify {name!r} after initialization." + raise AttributeError(msg) + else: + msg = f"{type(self).__name__!r} object is immutable; cannot set attribute {name!r}." 
+ raise AttributeError(msg) + + @classmethod + def from_array(cls, arr: np.ndarray | Nested, item_type: type[T], frozen: bool = True) -> StructuredArray[T]: # noqa: C901 + """Build a StructuredArray from an unstructured numpy array or nested list. + + The list must be convertible to a numpy array. The last axis of the array is mapped to the fields. + The length of the last axis must equal to the number of fields in the given type. + + Example: + This examples represents a sequence of 10 breaths for each of 32x32 pixels. Each breath contains 3 fields: + start_time, middle_time, end_time. + + ```python + breath_data = load_breath_data() # shape (10, 32, 32, 3) + breaths = StructuredArray.from_array(breath_data, Breath) + ``` + + This is equivalent to a list of 10 nested lists, each containing 32 lists (rows) of 32 (columns) items. + """ + if not isinstance(arr, np.ndarray): + arr = np.array(arr) + + if arr.ndim < 1: + msg = "array must have at least 1 dimension." + raise ValueError(msg) + + n_fields = _get_n_fields(item_type) + if (lal := arr.shape[-1]) != n_fields: + msg = f"Last axis must have size {n_fields} for {item_type.__name__}, not {lal}." + raise ValueError(msg) + + dt = _get_struct_dtype(item_type) + out = np.empty(arr.shape[:-1], dtype=dt) + + if not dt.fields: + msg = "Generated dtype has no fields; cannot proceed." + raise RuntimeError(msg) + + fields = cast("dict[str, tuple[np.dtype, int]]", dt.fields) + field_names = _get_field_names(item_type) + for i, name in enumerate(field_names): + # Cast each column to the target field dtype to avoid unintended promotion + target_dt = fields[name][0] + out[name] = arr[..., i].astype(target_dt, copy=False) + + inst = cls.__new__(cls) + inst.item_type = item_type + inst._is_dataclass = is_dataclass(item_type) # noqa: SLF001 + inst._items = out # noqa: SLF001 + + if inst._is_dataclass and "__post_init__" in dir(inst.item_type): # noqa: SLF001 + # Run __post_init__ for all items to ensure validation + for record in inst._items.flat: # noqa: SLF001 + inst._reconstruct_item(record) # noqa: SLF001 + + # Check for disallowed field types in dataclass + if inst._is_dataclass: # noqa: SLF001 + disallowed_fields = _get_disallowed_field_names(item_type) + if disallowed_fields: + warnings.warn( + f"Dataclass '{item_type.__name__}' has fields with disallowed types: " + f"{', '.join(disallowed_fields)}. " + "Allowed types are: str, int, float, bool, bytes, tuple, frozenset, pathlib.Path, " + "numpy types, frozen dataclasses, and NamedTuples.", + UserWarning, + stacklevel=2, + ) + + inst.frozen = frozen + if frozen: + # Check if dataclass is not frozen while array is + if inst._is_dataclass and not _is_dataclass_frozen(item_type): # noqa: SLF001 + warnings.warn( + "StructuredArray is frozen, but the underlying dataclass " + f"'{item_type.__name__}' is not. Consider freezing the dataclass with " + "@dataclass(frozen=True).", + UserWarning, + stacklevel=2, + ) + # Check if dataclass has mutable fields (independent check - can warn even if dataclass is frozen) + if inst._is_dataclass and _has_mutable_fields(item_type): # noqa: SLF001 + warnings.warn( + f"StructuredArray is frozen, but the dataclass '{item_type.__name__}' has mutable fields " + f"(list, dict, set, etc.). The contents of these fields can still be modified. 
" + f"Consider using immutable types (tuple) instead.", + UserWarning, + stacklevel=2, + ) + inst._freeze() # noqa: SLF001 + + object.__setattr__(inst, "_initialized", True) + + return inst + + @property + def shape(self) -> tuple[int, ...]: + """The shape of the StructuredArray.""" + return self._items.shape + + @property + def ndim(self) -> int: + """The number of dimensions of the StructuredArray.""" + return self._items.ndim + + @property + def dtype(self) -> np.dtype: + """The dtype of the underlying structured array.""" + return self._items.dtype + + @property + def items(self) -> np.ndarray: + """The underlying NumPy structured array. + + Returns the private array. If this instance was created with frozen=True, + modifications via this reference are prevented. Otherwise, modifications + are allowed. + """ + return self._items + + @property + def flags(self) -> flagsobj: + """The flags of the underlying structured array. + + If this instance was created with frozen=True, the WRITEABLE flag cannot + be changed. Otherwise, the flags are fully mutable. + """ + return self._items.flags + + def to_array(self) -> np.ndarray: + """Convert to an unstructured numpy array. + + Returns a 2D array where each column corresponds to a field of the item type, + in field order. This allows convenient slicing by column indices like + `arr[:, [0, 2]]`. + + Returns: + A 2D unstructured numpy array of shape (n_items, n_fields). + + Example: + >>> from typing import NamedTuple + >>> class Point(NamedTuple): + ... x: float + ... y: float + ... z: float + >>> arr = StructuredArray([Point(1.0, 2.0, 3.0), Point(4.0, 5.0, 6.0)]) + >>> arr_2d = arr.to_array() + >>> arr_2d.shape + (2, 3) + >>> arr_2d[:, [0, 2]] # Get x and z columns + array([[1., 3.], + [4., 6.]]) + """ + # Stack each field as a column to create unstructured array + if not self._items.dtype.names: + # No fields, return empty array + return np.empty((self.shape[0], 0)) + + return np.column_stack([self._items[name] for name in self._items.dtype.names]) + + def __array__(self, dtype: np.dtype | None = None) -> np.ndarray: + return self._items.astype(dtype) if dtype is not None else self._items + + def __iter__(self) -> Generator[T | StructuredArray[T], None, None]: + if self.ndim == 1: + for item in self._items: + yield self._reconstruct_item(item) + else: + # yield structured subarrays along axis 0 + for i in range(self._items.shape[0]): + out = StructuredArray.__new__(StructuredArray) + out.item_type = self.item_type + out._is_dataclass = self._is_dataclass # noqa: SLF001 + out._items = self._items[i] # noqa: SLF001 + yield out + + def __len__(self) -> int: + return self._items.shape[0] if self._items.ndim > 0 else 0 + + def __repr__(self) -> str: + return f"StructuredArray[{self.item_type.__name__}]{repr(self._items).removeprefix('array')}" + + def __eq__(self, other: object) -> bool: + """Compare two StructuredArray instances for equality. 
+ + Two StructuredArray instances are equal if: + - They are both StructuredArray instances + - They have the same item type + - Their underlying arrays are equal (including NaN equality for floats) + """ + if not isinstance(other, StructuredArray): + return False + + if self.item_type is not other.item_type: + return False + + # Compare shapes + if self._items.shape != other._items.shape: + return False + + # Compare dtypes + if self._items.dtype != other._items.dtype: + return False + + # For structured arrays, compare field by field to handle NaN values properly + for name in self._items.dtype.names or []: + self_field = self._items[name] + other_field = other._items[name] + + # Use array_equal with equal_nan for each field + if not np.array_equal(self_field, other_field, equal_nan=True): + return False + + return True + + __hash__ = None # type: ignore[assignment] + + def __add__(self, other: StructuredArray[T]) -> StructuredArray[T]: + if not isinstance(other, StructuredArray): + msg = f"Can only concatenate StructuredArray (not '{type(other).__name__}') to StructuredArray." + raise TypeError(msg) + + if self.item_type is not other.item_type: + msg = "Cannot concatenate StructuredArray with different item types." + raise TypeError(msg) + + new_items = np.concatenate((self._items, other._items), axis=0) + frozen = self.frozen or other.frozen + + # Create a new StructuredArray directly with the concatenated structured array + inst = self.__class__.__new__(self.__class__) + inst.item_type = self.item_type + inst._is_dataclass = self._is_dataclass + inst._items = new_items + inst.frozen = frozen + + if frozen: + inst._freeze() + + object.__setattr__(inst, "_initialized", True) + return inst + + @overload + def __getitem__(self, index: str) -> np.ndarray: ... + + @overload + def __getitem__(self, index: int) -> T: ... + + @overload + def __getitem__(self, index: slice) -> StructuredArray[T]: ... + + @overload + def __getitem__(self, index: NonStringSeq) -> StructuredArray[T]: ... + + def __getitem__(self, index: str | int | slice | NonStringSeq) -> np.ndarray | StructuredArray[T] | T: + # Field-name access: return field view + if isinstance(index, str): + names = self._items.dtype.names or () + if index in names: + view = self._items[index] + # Ensure field view is read-only + with contextlib.suppress(Exception): + view.flags.writeable = False + return view + # Computed property or attribute on the item type → compute over all items + return self._compute_property(index) + + # NumPy-style indexing + result = self._items[index] + + # Structured scalar (np.void) → return reconstructed item + if isinstance(result, np.void): + return self._reconstruct_item(result) + + # Zero-d structured ndarray (shape == ()) → convert to item + if isinstance(result, np.ndarray) and result.dtype.fields is not None and result.ndim == 0: + scalar = result.item() # np.void + return self._reconstruct_item(scalar) + + # Structured ndarray → wrap + if isinstance(result, np.ndarray) and result.dtype.fields is not None: + out: StructuredArray[T] = type(self).__new__(type(self)) + out.item_type = self.item_type + out._is_dataclass = self._is_dataclass + out._items = result + out.frozen = self.frozen + object.__setattr__(out, "_initialized", True) + return out + + # Non-structured ndarray (e.g. field slice) → return as-is + return result + + def _reconstruct_item(self, record: np.void) -> T: + """Reconstruct an item (NamedTuple or dataclass) from a numpy void record. 
+ + For dataclasses, this calls __post_init__ to ensure validation runs. + """ + values = record.tolist() + + return self.item_type(*values) + + def _compute_property(self, attr: str) -> np.ndarray: + """Compute a property or attribute across all items, preserving the array shape.""" + # Verify attribute exists on a sample item + sample_record = self._items.flat[0] + sample = self._reconstruct_item(sample_record) + if not hasattr(sample, attr): + msg = f"Field or property '{attr}' not found in {self.item_type.__name__}." + raise KeyError(msg) + + # Collect values (single pass using flat indexing) + out_obj = np.empty(self.shape, dtype=object) + for i, rec in enumerate(self._items.reshape(-1)): + item = self._reconstruct_item(rec) + out_obj.reshape(-1)[i] = getattr(item, attr) + + # Determine target dtype from property annotation if available (handles postponed annotations) + target_dtype: np.dtype | None = None + attr_member = getattr(self.item_type, attr, None) + if isinstance(attr_member, property) and attr_member.fget is not None: + with contextlib.suppress(Exception): + hints = get_type_hints(attr_member.fget) + ret_ann = hints.get("return") + if ret_ann is not None: + target = _python_to_np_dtype(ret_ann) + target_dtype = np.dtype(target) + + # Cast accordingly + if target_dtype is not None and target_dtype != np.dtype(object): + with contextlib.suppress(Exception): + return out_obj.astype(target_dtype) + + # Heuristics: ints → int64; floats → float64; numpy scalar families respected + with contextlib.suppress(Exception): + if all(isinstance(v, (int, np.integer)) for v in out_obj.flat): + return out_obj.astype(np.int64) + with contextlib.suppress(Exception): + if all(isinstance(v, (float, np.floating)) for v in out_obj.flat): + return out_obj.astype(np.float64) + + return out_obj + + +def _first_leaf( + seq: NamedTuple | np.ndarray | list[NamedTuple] | tuple[NamedTuple, ...] | Nested[NamedTuple], +) -> object: + """Recursively find the first NamedTuple or dataclass instance in a nested sequence or ndarray.""" + if _is_struct_instance(seq): + return seq + if isinstance(seq, np.ndarray): + if seq.size == 0: + msg = "Cannot infer type from empty ndarray." + raise ValueError(msg) + return _first_leaf(seq.flat[0]) + if isinstance(seq, (list, tuple)): + if not seq: + msg = "Cannot infer type from empty nested sequence." + raise ValueError(msg) + return _first_leaf(seq[0]) + + msg = "Items must be NamedTuple or dataclass or nested sequences thereof." + raise TypeError(msg) + + +def _check_homogeneous(seq: SequenceType | np.ndarray | Nested[T], typ: type[T]) -> None: + """Recursively check that all items in the nested sequence/ndarray are of the given type.""" + if isinstance(seq, np.ndarray): + for it in seq.flat: + _check_homogeneous(it, typ) + return + if isinstance(seq, (list, tuple)) and not _is_struct_instance(seq): + seq_ = cast("SequenceType", seq) + for it in seq_: + _check_homogeneous(it, typ) + return + if _is_struct_instance(seq): + if type(seq) is not typ: + msg = f"All items must be of the same type ({typ.__name__}), got {type(seq).__name__}." + raise ValueError(msg) + return + msg = "Items must be NamedTuple, dataclass, or nested sequences thereof." + raise TypeError(msg) + + +def _python_to_np_dtype(py: type) -> np.dtype | str: + """Map basic Python types to NumPy dtypes.""" + if py is int: + return "i8" + if py is float: + return "f8" + if py is bool: + return "?" 
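# Illustrative usage sketch (not part of the diff) of the StructuredArray indexing
# behaviour implemented above, reusing the Point NamedTuple from the to_array
# docstring example. The import path follows the module added in this diff.
from typing import NamedTuple

import numpy as np

from eitprocessing.datahandling.structured_array import StructuredArray


class Point(NamedTuple):
    x: float
    y: float
    z: float


points = StructuredArray([Point(1.0, 2.0, 3.0), Point(4.0, 5.0, 6.0)])
points["x"]         # read-only field view: array([1., 4.])
points[0]           # reconstructed item: Point(x=1.0, y=2.0, z=3.0)
points[1:]          # StructuredArray[Point] containing only the second item
points + points     # concatenated StructuredArray[Point] of length 4
np.asarray(points)  # the underlying structured array, via __array__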
+ if py is str: + return np.dtype(object) # keep Python str; avoids truncation + + if get_origin(py) is Union: + args = [a for a in get_args(py) if a is not type(None)] + if len(args) == 1: + return _python_to_np_dtype(args[0]) + return np.dtype(object) + + +def _is_namedtuple_instance(item: object) -> TypeGuard[NamedTuple]: + """Check if item is a NamedTuple instance.""" + return isinstance(item, tuple) and hasattr(item, "_fields") and not is_dataclass(item) + + +def _dataclass_items_to_tuples(items: object, dc_type: type) -> object: + """Recursively convert dataclass items to tuples for numpy.asarray conversion. + + numpy.asarray doesn't know how to convert arbitrary dataclass instances + to structured array rows, so we convert them to tuples first. + """ + if isinstance(items, np.ndarray): + # For ndarrays, convert each element + result = np.array([_dataclass_items_to_tuples(item, dc_type) for item in items.flat]) + return result.reshape(items.shape) + + if isinstance(items, (list, tuple)) and not is_dataclass(items): + # It's a sequence - recursively convert each element + return [_dataclass_items_to_tuples(item, dc_type) for item in items] + + # It's a single dataclass instance - convert to tuple of field values + if is_dataclass(items): + field_names = _get_field_names(type(items)) + return tuple(getattr(items, name) for name in field_names) + + # Shouldn't reach here, but just in case + return items + + +def _is_namedtuple_type(item: object) -> TypeGuard[type[NamedTuple]]: + """Check if item is a NamedTuple type.""" + return isinstance(item, type) and issubclass(item, tuple) and hasattr(item, "_fields") and not is_dataclass(item) + + +def _is_struct_instance(item: object) -> bool: + """Check if item is a NamedTuple or dataclass instance (not type).""" + if isinstance(item, type): + # Exclude types themselves + return False + return _is_namedtuple_instance(item) or is_dataclass(item) + + +def _is_struct_type(item: object) -> bool: + """Check if item is a NamedTuple or dataclass type.""" + return _is_namedtuple_type(item) or (isinstance(item, type) and is_dataclass(item)) + + +def _get_field_names(item_type: type) -> list[str]: + """Get field names from a NamedTuple or dataclass type.""" + if is_dataclass(item_type): + return [f.name for f in dc_fields(item_type)] + if _is_namedtuple_type(item_type): + return list(item_type._fields) # type: ignore[return-value] + msg = f"item_type must be a NamedTuple or dataclass type, got {item_type}" + raise TypeError(msg) + + +def _get_n_fields(item_type: type) -> int: + """Get number of fields from a NamedTuple or dataclass type.""" + return len(_get_field_names(item_type)) + + +def _get_struct_dtype(item: object) -> np.dtype: + """Generate a NumPy structured dtype from a NamedTuple or dataclass type.""" + if _is_struct_instance(item): + item = type(item) + if not _is_struct_type(item): + msg = "item must be a NamedTuple instance, NamedTuple type, dataclass instance, or dataclass type." 
+ raise TypeError(msg) + + if is_dataclass(item): + return _get_dataclass_dtype(item) # type: ignore[arg-type] + return _get_namedtuple_dtype(item) # type: ignore[arg-type] + + +def _get_namedtuple_dtype(nt_type: type[NamedTuple]) -> np.dtype: + """Generate a NumPy structured dtype from a NamedTuple type.""" + hints = get_type_hints(nt_type, include_extras=False) + names_in_order = list(nt_type._fields) # type: ignore[attr-defined] + + def to_np_dtype(py: type) -> str | np.dtype: + if py is int: + return "i8" + if py is float: + return "f8" + if py is bool: + return "?" + # everything else (incl. str) → object to avoid truncation + return np.dtype(object) + + fields = [(name, to_np_dtype(hints.get(name, object))) for name in names_in_order] + return np.dtype(fields) + + +def _get_dataclass_dtype(dc_type: type) -> np.dtype: + """Generate a NumPy structured dtype from a dataclass type.""" + try: + hints = get_type_hints(dc_type, include_extras=False) + except (NameError, TypeError, AttributeError): + # If we can't resolve type hints (e.g., for types defined in test scopes), + # fall back to empty hints + hints = {} + + def to_np_dtype(py: type) -> str | np.dtype: + if py is int: + return "i8" + if py is float: + return "f8" + if py is bool: + return "?" + # everything else (incl. str) → object to avoid truncation + return np.dtype(object) + + fields = [(f.name, to_np_dtype(hints.get(f.name, object))) for f in dc_fields(dc_type)] + return np.dtype(fields) + + +def _is_dataclass_frozen(dc_type: type) -> bool: + """Check if a dataclass type is frozen. + + Args: + dc_type: A dataclass type. + + Returns: + True if the dataclass is frozen, False otherwise. + """ + if not is_dataclass(dc_type): + return False + + # Access the __dataclass_params__ which contains the frozen flag + if hasattr(dc_type, "__dataclass_params__"): + return dc_type.__dataclass_params__.frozen # type: ignore[attr-defined] + + return False + + +def _has_mutable_fields(dc_type: type) -> bool: + """Check if a dataclass has any mutable field types. + + A frozen dataclass with mutable fields like list, dict, or set can still have + its contents modified even though the dataclass instance itself is frozen. + + Args: + dc_type: A dataclass type. + + Returns: + True if the dataclass has any mutable field types, False otherwise. + """ + if not is_dataclass(dc_type): + return False + + try: + hints = get_type_hints(dc_type, include_extras=False) + except (NameError, TypeError, AttributeError): + # If we can't resolve type hints, skip the check + return False + + mutable_types = MUTABLE_TYPES + + for field_type in hints.values(): + # Get the origin type (e.g., list from list[int]) + origin = get_origin(field_type) + + # Check if directly mutable or is a subclass of mutable types + if _is_subclass_of_allowed(field_type, mutable_types) or ( + origin is not None and _is_subclass_of_allowed(origin, mutable_types) + ): + return True + + # Check if it's a Union with mutable types + if origin is Union: + args = get_args(field_type) + for arg in args: + arg_origin = get_origin(arg) + if _is_subclass_of_allowed(arg, mutable_types) or ( + arg_origin is not None and _is_subclass_of_allowed(arg_origin, mutable_types) + ): + return True + + return False + + +def _get_disallowed_field_names(dc_type: type) -> list[str]: + """Get field names that have disallowed types in a dataclass. + + Allowed types include: str, int, float, bool, bytes, tuple, frozenset, pathlib.Path, + numpy scalars/arrays, frozen dataclasses, and NamedTuples. 
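# Illustrative sketch (not part of the diff): the structured dtype that
# _get_namedtuple_dtype/_get_dataclass_dtype would produce for a hypothetical item
# type, following the mapping shown above (int -> 'i8', float -> 'f8', bool -> '?',
# str and anything else -> object, to avoid string truncation).
from typing import NamedTuple

import numpy as np


class Sample(NamedTuple):  # hypothetical item type, for illustration only
    index: int
    value: float
    valid: bool
    label: str


expected_dtype = np.dtype([("index", "i8"), ("value", "f8"), ("valid", "?"), ("label", object)])
# _get_struct_dtype(Sample) is expected to equal `expected_dtype`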
+ + Args: + dc_type: A dataclass type. + + Returns: + List of field names with disallowed types, empty if all types are allowed. + """ + if not is_dataclass(dc_type): + return [] + + try: + hints = get_type_hints(dc_type, include_extras=False) + except (NameError, TypeError, AttributeError): + # If we can't resolve type hints, skip the check + return [] + + disallowed_fields = [] + + # All allowed types (immutable and mutable) + allowed = ALLOWED_TYPES + + for field_name, field_type in hints.items(): + origin = get_origin(field_type) + + # Check if directly allowed or is a subclass of allowed types + if _is_subclass_of_allowed(field_type, allowed): + continue + + # Check if it's a generic type with allowed origin + if origin is not None and _is_subclass_of_allowed(origin, allowed): + continue + + # Check for Union types + if origin is Union: + args = get_args(field_type) + if all(_is_allowed_type(arg) for arg in args if arg is not type(None)): + continue + + # Check for frozen dataclass or NamedTuple + if _is_namedtuple_type(field_type) or (is_dataclass(field_type) and _is_dataclass_frozen(field_type)): + continue + + # If we get here, the type is not allowed + disallowed_fields.append(field_name) + + return disallowed_fields + + +def _is_subclass_of_allowed(check_type: type, allowed_types: tuple[type, ...]) -> bool: + """Check if check_type is a subclass of any type in allowed_types.""" + if check_type in allowed_types: + return True + + try: + if isinstance(check_type, type): + for allowed in allowed_types: + if allowed is not type and issubclass(check_type, allowed): + return True + except TypeError: + # issubclass() can raise TypeError for non-class types + pass + + return False + + +def _is_allowed_type(field_type: type) -> bool: + """Check if a field type is in the allowed list or is a subclass of an allowed type.""" + # Check against basic allowed types (includes numpy types via np.generic) + if _is_subclass_of_allowed(field_type, ALLOWED_TYPES): + return True + + if _is_namedtuple_type(field_type): + return True + + if is_dataclass(field_type) and _is_dataclass_frozen(field_type): + return True + + # Check for generic types + origin = get_origin(field_type) + return origin is not None and _is_subclass_of_allowed(origin, ALLOWED_TYPES) diff --git a/eitprocessing/features/breath_detection.py b/eitprocessing/features/breath_detection.py index 2cedf7fc9..5c0804821 100644 --- a/eitprocessing/features/breath_detection.py +++ b/eitprocessing/features/breath_detection.py @@ -1,4 +1,3 @@ -import itertools import math from collections.abc import Callable from dataclasses import dataclass @@ -9,8 +8,9 @@ from eitprocessing.datahandling.breath import Breath from eitprocessing.datahandling.continuousdata import ContinuousData -from eitprocessing.datahandling.intervaldata import IntervalData +from eitprocessing.datahandling.intervaldata import Interval, IntervalData from eitprocessing.datahandling.sequence import Sequence +from eitprocessing.datahandling.structured_array import StructuredArray from eitprocessing.features.moving_average import MovingAverage @@ -116,15 +116,14 @@ def find_breaths( valley_indices, ) breaths = self._remove_breaths_around_invalid_data(breaths, time, sample_frequency, invalid_data_indices) + intervals = StructuredArray.from_array(breaths.to_array()[:, [0, 2]], Interval) breaths_container = IntervalData( label=result_label, name="Breaths as determined by BreathDetection", unit=None, category="breath", - intervals=[(breath.start_time, breath.end_time) for breath in 
breaths], + intervals=intervals, values=breaths, - parameters={type(self): dict(vars(self))}, - derived_from=[continuous_data], ) if store: @@ -372,56 +371,35 @@ def _create_breaths_from_peak_valley_data( time: np.ndarray, peak_indices: np.ndarray, valley_indices: np.ndarray, - ) -> list[Breath]: - return [ - Breath(time[start], time[middle], time[end]) - for middle, (start, end) in zip( - peak_indices, - itertools.pairwise(valley_indices), - strict=True, - ) - ] + ) -> StructuredArray[Breath]: + times = time[np.column_stack([valley_indices[:-1], peak_indices, valley_indices[1:]])] + return StructuredArray.from_array(times, Breath) def _remove_breaths_around_invalid_data( self, - breaths: list[Breath], + breaths: StructuredArray[Breath], time: np.ndarray, sample_frequency: float, invalid_data_indices: np.ndarray, - ) -> list[Breath]: - """Remove breaths overlapping with invalid data. - - Breaths that start within a window length (given by invalid_data_removal_window_length) of invalid data are - removed. - - Args: - breaths: list of detected breath objects - time: time axis belonging to the data - sample_frequency: sample frequency of the data and time - invalid_data_indices: indices of invalid data points - """ - # TODO: write more general(ized) method of determining invalid data - - new_breaths = breaths[:] - + ) -> StructuredArray[Breath]: + """Remove breaths overlapping with invalid data.""" if not len(invalid_data_indices): - return new_breaths + return breaths[:] - invalid_data_values = np.zeros(time.shape) - invalid_data_values[invalid_data_indices] = 1 # gives the value 1 to each invalid datapoint + invalid_data_values = np.zeros_like(time) + invalid_data_values[invalid_data_indices] = 1 window_length = math.ceil(self.invalid_data_removal_window_length * sample_frequency) - for breath in new_breaths[:]: + indices_to_keep = [] + for i, breath in enumerate(breaths): breath_start_minus_window = max(0, np.argmax(time == breath.start_time) - window_length) breath_end_plus_window = min(len(invalid_data_values), np.argmax(time == breath.end_time) + window_length) - # if no invalid datapoints are within the window, np.max() will return 0 - # if any invalid datapoints are within the window, np.max() will return 1 - if np.max(invalid_data_values[breath_start_minus_window:breath_end_plus_window]): - new_breaths.remove(breath) + if not np.max(invalid_data_values[breath_start_minus_window:breath_end_plus_window]): + indices_to_keep.append(i) - return new_breaths + return breaths[indices_to_keep] @staticmethod def _fill_nan_with_nearest_neighbour(data: np.ndarray) -> np.ndarray: diff --git a/eitprocessing/features/pixel_breath.py b/eitprocessing/features/pixel_breath.py index 3f156afcf..f0bcd915e 100644 --- a/eitprocessing/features/pixel_breath.py +++ b/eitprocessing/features/pixel_breath.py @@ -141,7 +141,7 @@ def find_pixel_breaths( # noqa: C901, PLR0912, PLR0915 [breath.middle_time for breath in continuous_breaths.values], ) - _, n_rows, n_cols = eit_data.pixel_impedance.shape + _, n_rows, n_cols = eit_data.values.shape from eitprocessing.parameters.tidal_impedance_variation import TIV # noqa: PLC0415 @@ -169,7 +169,7 @@ def find_pixel_breaths( # noqa: C901, PLR0912, PLR0915 mean_tiv_pixel[~all_nan_mask] = np.nanmean(pixel_tiv_with_continuous_data_timing[:, ~all_nan_mask], axis=0) time = eit_data.time - pixel_impedance = eit_data.pixel_impedance + pixel_impedance = eit_data.values pixel_breaths = np.full((len(continuous_breaths), n_rows, n_cols), None, dtype=object) @@ -291,7 +291,6 @@ def 
find_pixel_breaths( # noqa: C901, PLR0912, PLR0915 values=list( pixel_breaths, ), ## TODO: change back to pixel_breaths array when IntervalData works with 3D array - derived_from=[eit_data], ) if store: sequence.interval_data.add(pixel_breaths_container) diff --git a/eitprocessing/features/rate_detection.py b/eitprocessing/features/rate_detection.py index 742467f47..7becb61dc 100644 --- a/eitprocessing/features/rate_detection.py +++ b/eitprocessing/features/rate_detection.py @@ -152,7 +152,7 @@ def apply( set to segment length - 1. """ - pixel_impedance = eit_data.pixel_impedance.copy().astype(np.float32) + pixel_impedance = eit_data.values.copy().astype(np.float32) summed_impedance = np.nansum(pixel_impedance, axis=(1, 2)) len_segment = int(self.welch_window * eit_data.sample_frequency) diff --git a/eitprocessing/filters/mdn.py b/eitprocessing/filters/mdn.py index b3b2c2b3c..53fdf0319 100644 --- a/eitprocessing/filters/mdn.py +++ b/eitprocessing/filters/mdn.py @@ -1,7 +1,7 @@ -import copy import math import warnings from dataclasses import dataclass +from dataclasses import replace as dataclass_replace from typing import TypeVar, cast, overload import numpy as np @@ -155,17 +155,17 @@ def apply( # pyright: ignore[reportInconsistentOverload] return new_data # TODO: Replace with input_data.update(...) when implemented - return_object = copy.deepcopy(input_data) - for attr, value in kwargs.items(): - setattr(return_object, attr, value) - if isinstance(return_object, ContinuousData): - return_object.values = new_data - elif isinstance(return_object, EITData): - return_object.pixel_impedance = new_data + kwargs["values"] = new_data - capture("filtered_data", return_object) - return return_object + if isinstance(input_data, ContinuousData): + filtered_data: T = dataclass_replace(input_data, **kwargs) + elif isinstance(input_data, EITData): + filtered_data: T = input_data.update(**kwargs) + + capture("filtered_data", filtered_data) + + return filtered_data def _validate_arguments( self, @@ -182,14 +182,10 @@ def _validate_arguments( msg = "Axis should not be provided when using ContinuousData or EITData." 
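# Illustrative sketch (not part of the diff): because the data containers are frozen,
# applying the filter returns a new instance instead of mutating the input.
# `mdn_filter` and `eit_data` are assumed to be set up as in tests/test_mdn_filter.py.
filtered_data = mdn_filter.apply(eit_data)     # new EITData with filtered values
assert filtered_data is not eit_data
assert not eit_data.values.flags["WRITEABLE"]  # input arrays stay read-only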
raise ValueError(msg) - if isinstance(input_data, ContinuousData): + if isinstance(input_data, (ContinuousData, EITData)): data = input_data.values sample_frequency_ = cast("float", input_data.sample_frequency) axis_ = 0 - elif isinstance(input_data, EITData): - data = input_data.pixel_impedance - sample_frequency_ = cast("float", input_data.sample_frequency) - axis_ = 0 elif isinstance(input_data, np.ndarray): data = input_data axis_ = DEFAULT_AXIS if axis is MISSING else axis diff --git a/eitprocessing/parameters/__init__.py b/eitprocessing/parameters/__init__.py index 2ab04a6d0..41a0e14b9 100644 --- a/eitprocessing/parameters/__init__.py +++ b/eitprocessing/parameters/__init__.py @@ -2,7 +2,7 @@ import numpy as np -from eitprocessing.datahandling.continuousdata import DataContainer +from eitprocessing.datahandling import DataContainer class ParameterCalculation(ABC): diff --git a/eitprocessing/parameters/eeli.py b/eitprocessing/parameters/eeli.py index 75fbff071..a2a224929 100644 --- a/eitprocessing/parameters/eeli.py +++ b/eitprocessing/parameters/eeli.py @@ -109,8 +109,6 @@ def compute_parameter( category="impedance", time=time, description="End-expiratory lung impedance (EELI) determined on continuous data", - parameters=self.breath_detection_kwargs, - derived_from=[continuous_data], values=values, ) if store: diff --git a/eitprocessing/parameters/tidal_impedance_variation.py b/eitprocessing/parameters/tidal_impedance_variation.py index 6b47086b6..3fb75b475 100644 --- a/eitprocessing/parameters/tidal_impedance_variation.py +++ b/eitprocessing/parameters/tidal_impedance_variation.py @@ -3,7 +3,7 @@ import warnings from dataclasses import InitVar, dataclass, field from functools import singledispatchmethod -from typing import Final, Literal, NoReturn +from typing import Final, Literal import numpy as np @@ -13,6 +13,7 @@ from eitprocessing.datahandling.intervaldata import IntervalData from eitprocessing.datahandling.sequence import Sequence from eitprocessing.datahandling.sparsedata import SparseData +from eitprocessing.datahandling.structured_array import StructuredArray from eitprocessing.features.breath_detection import BreathDetection from eitprocessing.features.pixel_breath import PixelBreath from eitprocessing.parameters import ParameterCalculation @@ -73,7 +74,7 @@ def __post_init__(self, breath_detection_kwargs: dict | None) -> None: def compute_parameter( self, data: ContinuousData | EITData, - ) -> NoReturn: + ) -> SparseData: """Compute the tidal impedance variation per breath on either ContinuousData or EITData, depending on the input. Args: @@ -142,7 +143,6 @@ def compute_continuous_parameter( category="impedance difference", time=[breath.middle_time for breath in breaths.values if breath is not None], description="Tidal impedance variation determined on continuous data", - derived_from=[continuous_data], values=tiv_values, ) if store: @@ -203,7 +203,7 @@ def compute_pixel_parameter( msg = f"Invalid {tiv_timing}. The tiv_timing must be either 'continuous' or 'pixel'." 
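# Illustrative sketch (not part of the diff): breaths are now stored in a
# StructuredArray[Breath], so invalid breaths are dropped by index selection
# (as in _remove_breaths_around_invalid_data above) rather than list.remove().
import numpy as np

from eitprocessing.datahandling.breath import Breath
from eitprocessing.datahandling.structured_array import StructuredArray

breaths = StructuredArray([Breath(0.0, 1.0, 2.0), Breath(2.0, 3.0, 4.0), Breath(4.0, 5.0, 6.0)])
keep = np.flatnonzero([True, False, True])  # indices of breaths to keep
filtered = breaths[keep]                    # StructuredArray[Breath] of length 2
filtered["middle_time"]                     # array([1., 5.])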
raise ValueError(msg) - data = eit_data.pixel_impedance + data = eit_data.values _, n_rows, n_cols = data.shape if tiv_timing == "pixel": @@ -253,7 +253,6 @@ def compute_pixel_parameter( category="impedance difference", time=list(all_pixels_breath_timings), description="Tidal impedance variation determined on pixel impedance", - derived_from=[eit_data], values=list(all_pixels_tiv_values.astype(float)), ) @@ -284,12 +283,12 @@ def _calculate_tiv_values( self, data: np.ndarray, time: np.ndarray, - breaths: list[Breath], + breaths: StructuredArray[Breath], tiv_method: str, tiv_timing: str, # noqa: ARG002 # remove when restructuring ) -> list: # Filter out None breaths - breaths = np.array(breaths) + valid_breath_indices = np.flatnonzero([breath is not None for breath in breaths]) valid_breaths = breaths[valid_breath_indices] diff --git a/eitprocessing/plotting/filter.py b/eitprocessing/plotting/filter.py index fc050344c..5b7e0efdd 100644 --- a/eitprocessing/plotting/filter.py +++ b/eitprocessing/plotting/filter.py @@ -202,19 +202,12 @@ def _get_data( unfiltered_signal = unfiltered_data filtered_signal = filtered_data sample_frequency_ = sample_frequency - case ContinuousData(), ContinuousData(): + case (ContinuousData(), ContinuousData()) | (EITData(), EITData()): unfiltered_signal = unfiltered_data.values filtered_signal = filtered_data.values sample_frequency_ = unfiltered_data.sample_frequency if sample_frequency is not MISSING: - msg = "Sample frequency should not be provided when using ContinuousData." - raise ValueError(msg) - case EITData(), EITData(): - unfiltered_signal = unfiltered_data.pixel_impedance - filtered_signal = filtered_data.pixel_impedance - sample_frequency_ = unfiltered_data.sample_frequency - if sample_frequency is not MISSING: - msg = "Sample frequency should not be provided when using EITData." + msg = "Sample frequency should not be provided when using ContinuousData or EITData." raise ValueError(msg) case _: msg = "Unfiltered and filtered data must be either numpy arrays, ContinuousData, or EITData." diff --git a/eitprocessing/roi/__init__.py b/eitprocessing/roi/__init__.py index 21b10adf8..979c4e7a8 100644 --- a/eitprocessing/roi/__init__.py +++ b/eitprocessing/roi/__init__.py @@ -13,7 +13,7 @@ import re import sys import warnings -from dataclasses import InitVar, dataclass, field, replace +from dataclasses import InitVar, dataclass, field from dataclasses import replace as dataclass_replace from typing import TYPE_CHECKING, TypeVar, overload @@ -189,7 +189,7 @@ def __replace__(self, /, **changes) -> Self: elif isinstance(changes["plot_config"], dict): changes["plot_config"] = self._plot_config.update(**changes["plot_config"]) label = changes.pop("label", None) - return replace(self, label=label, **changes) + return dataclass_replace(self, label=label, **changes) update = __replace__ # TODO: add tests for update @@ -206,9 +206,8 @@ def apply(self, data: PixelMap, **kwargs) -> PixelMap: ... def apply(self, data, **kwargs): """Apply pixel mask to data, returning a copy of the object with pixel values masked. - Data can be a numpy array, an EITData object or PixelMap object. In case of an EITData object, the mask will be - applied to the `pixel_impedance` attribute. In case of a PixelMap, the mask will be applied to the `values` - attribute. + Data can be a numpy array, an EITData object or PixelMap object. In case of an EITData or PixelMap object, the + mask will be applied to the `values` attribute. The input data can have any dimension. 
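# Illustrative sketch (not part of the diff): applying a PixelMask to EITData now
# returns an updated copy with masked `values` (as for PixelMap), instead of
# replacing `pixel_impedance`. `eit_data` is assumed to be an EITData object with
# 32x32 pixels, as in tests/test_pixelmask.py; the import path is assumed.
import numpy as np

from eitprocessing.roi import PixelMask

mask_values = np.full((32, 32), np.nan)
mask_values[10:23, 10:23] = 1.0          # let the centre pixels pass
mask = PixelMask(mask_values)

masked_eit_data = mask.apply(eit_data)   # new EITData; the input is left untouched
assert np.all(np.isnan(masked_eit_data.values[:, :10, :]))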
The mask is applied to the last two dimensions. The size of the last two dimensions must match the size of the dimensions of the mask, and will generally (but do not have to) have the @@ -241,9 +240,7 @@ def transform_and_mask(data: np.ndarray) -> np.ndarray: match data: case np.ndarray(): return transform_and_mask(data) - case EITData(): - return dataclass_replace(data, pixel_impedance=transform_and_mask(data.pixel_impedance), **kwargs) - case PixelMap(): + case EITData() | PixelMap(): return data.update(values=transform_and_mask(data.values), **kwargs) case _: msg = f"Data should be an array, or EITData or PixelMap object, not {type(data)}." diff --git a/eitprocessing/utils/frozen_array.py b/eitprocessing/utils/frozen_array.py new file mode 100644 index 000000000..9c4d1b2b0 --- /dev/null +++ b/eitprocessing/utils/frozen_array.py @@ -0,0 +1,34 @@ +import warnings +from typing import Literal + +import numpy as np + +DEFAULT_FREEZE_METHOD = "memoryview" + + +def freeze_array(a: np.ndarray, *, method: Literal["flag", "memoryview"] = DEFAULT_FREEZE_METHOD) -> np.ndarray: + """Return a read-only array that cannot be made writeable again.""" + # Memory buffers cannot represent object/structured fields safely. + if method == "memoryview" and a.dtype.hasobject: + warnings.warn( + "Cannot use 'memoryview' method for object or structured dtypes; falling back to 'flag' method.", + RuntimeWarning, + stacklevel=2, + ) + method = "flag" + + match method: + case "flag": + # Make a copy if needed and mark it readonly (can be flipped back by a user). + if a.flags["WRITEABLE"]: + a = a.copy() + a.flags["WRITEABLE"] = False + return a + case "memoryview": + # For numeric/plain dtypes, create a readonly buffer-backed view that can't be flipped. + a_c = np.ascontiguousarray(a) + ro_buf = memoryview(a_c).toreadonly() + return np.frombuffer(ro_buf, dtype=a_c.dtype).reshape(a_c.shape) + case _: + msg = f"Invalid method: {method!r}" + raise ValueError(msg) diff --git a/pyproject.toml b/pyproject.toml index 31563f828..51f7e86c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -190,3 +190,6 @@ replace = 'version: "{new_version}"' filename = "eitprocessing/__init__.py" search = '__version__ = "{current_version}"' replace = '__version__ = "{new_version}"' + + +[tool.pyrefly] diff --git a/tests/eitdata/test_loading_sentec.py b/tests/eitdata/test_loading_sentec.py index 73f8bb244..dc437d48f 100644 --- a/tests/eitdata/test_loading_sentec.py +++ b/tests/eitdata/test_loading_sentec.py @@ -29,8 +29,8 @@ def test_load_sentec_single_file( assert isinstance(eit_data, EITData) assert np.isclose(eit_data.sample_frequency, 50.2, rtol=2e-2), "Sample frequency should be approximately 50.2 Hz" assert len(eit_data.time) > 0, "Time axis should not be empty" - assert len(eit_data.time) == len(eit_data.pixel_impedance), "Length of time axis should match number of frames" - assert len(eit_data) == len(eit_data.pixel_impedance), "Length of EITData should match number of frames" + assert len(eit_data.time) == len(eit_data.values), "Length of time axis should match number of frames" + assert len(eit_data) == len(eit_data.values), "Length of EITData should match number of frames" assert len(sequence.continuous_data) == 0, "Sentec data should not have continuous data channels" assert len(sequence.sparse_data) == 0, "Sentec data should not have sparse data channels" diff --git a/tests/test_amplitude_lungspace.py b/tests/test_amplitude_lungspace.py index 6032ae428..1e8ea3cd6 100644 --- a/tests/test_amplitude_lungspace.py +++ 
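# Illustrative sketch (not part of the diff) of the two freeze methods in
# eitprocessing/utils/frozen_array.py above: the 'flag' method can be undone by the
# caller, while the 'memoryview' method cannot be made writeable again.
import numpy as np

from eitprocessing.utils.frozen_array import freeze_array

a = np.arange(5.0)

flagged = freeze_array(a, method="flag")
assert not flagged.flags["WRITEABLE"]
flagged.flags["WRITEABLE"] = True        # a plain flag can be flipped back

frozen = freeze_array(a, method="memoryview")
assert not frozen.flags["WRITEABLE"]
try:
    frozen.flags["WRITEABLE"] = True     # raises: cannot set WRITEABLE flag to True
except ValueError:
    pass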
b/tests/test_amplitude_lungspace.py @@ -26,15 +26,13 @@ def factory(amplitudes: npt.ArrayLike, duration: float = 20.0) -> EITData: sine = np.sin(t * 2 * np.pi / 5) # 5 second period signal = amplitudes[None, :, :] * (sine[:, None, None] + 0.5) return EITData( - pixel_impedance=signal, + values=signal, path="", - nframes=len(t), time=t, sample_frequency=sample_frequency, vendor=Vendor.SIMULATED, label="simulated", description="Simulated EIT data for testing purposes", - name="", suppress_simulated_warning=True, ) diff --git a/tests/test_breath_detection.py b/tests/test_breath_detection.py index 248722147..3f69b6a7a 100644 --- a/tests/test_breath_detection.py +++ b/tests/test_breath_detection.py @@ -360,16 +360,16 @@ def test_create_breaths_from_peak_valley_data(): breaths = bd._create_breaths_from_peak_valley_data(time, peak_indices, valley_indices) assert len(breaths) == 5 assert all(isinstance(breath, Breath) for breath in breaths) - assert np.array_equal(np.array([breath.start_time for breath in breaths]), time[valley_indices[:-1]]) - assert np.array_equal(np.array([breath.middle_time for breath in breaths]), time[peak_indices]) - assert np.array_equal(np.array([breath.end_time for breath in breaths]), time[valley_indices[1:]]) + assert np.array_equal(breaths["start_time"], time[valley_indices[:-1]]) + assert np.array_equal(breaths["middle_time"], time[peak_indices]) + assert np.array_equal(breaths["end_time"], time[valley_indices[1:]]) fewer_valley_indices = valley_indices[:-1] - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="all the input array dimensions .* must match exactly"): bd._create_breaths_from_peak_valley_data(time, peak_indices, fewer_valley_indices) fewer_peak_indices = peak_indices[:-1] - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="all the input array dimensions .* must match exactly"): bd._create_breaths_from_peak_valley_data(time, fewer_peak_indices, valley_indices) peaks_out_of_order = np.concatenate([peak_indices[3:], peak_indices[:3]]) @@ -481,13 +481,9 @@ def test_find_breaths(): label = "waveform_data" cd = ContinuousData( - label, - "Generated waveform data", - "", - "mock", - "", time=time, values=y, + label=label, sample_frequency=sample_frequency, ) seq = Sequence("sequence_label") @@ -514,13 +510,9 @@ def test_find_breaths(): y_copy = np.copy(y) y_copy[438] = -100 # single timepoint around the peak of the 4th breath cd = ContinuousData( - label, - "Generated waveform data", - "", - "mock", - "", time=time, values=y_copy, + label=label, sample_frequency=sample_frequency, ) seq.continuous_data.add(cd, overwrite=True) diff --git a/tests/test_continuous_data.py b/tests/test_continuous_data.py index 59dbb56b8..404ed53fc 100644 --- a/tests/test_continuous_data.py +++ b/tests/test_continuous_data.py @@ -1,3 +1,4 @@ +import dataclasses import warnings import numpy as np @@ -6,6 +7,47 @@ from eitprocessing.datahandling.continuousdata import ContinuousData +@pytest.fixture +def continuous_data_object(): + n = 100 + sample_frequency = 10 + time = np.arange(n) / sample_frequency + values = np.arange(n) + + return ContinuousData( + time=time, + values=values, + sample_frequency=sample_frequency, + ) + + +def test_continuous_data_frozen(continuous_data_object: ContinuousData): + with pytest.raises(dataclasses.FrozenInstanceError, match="cannot assign to field"): + continuous_data_object.sample_frequency = 20 + + with pytest.raises(ValueError, match="assignment destination is read-only"): + continuous_data_object.time[0] = 
-1 + + with pytest.raises(ValueError, match="assignment destination is read-only"): + continuous_data_object.values[0] = -1 + + +def test_continuous_data_copy_array(continuous_data_object: ContinuousData): + time = continuous_data_object.time + + with pytest.raises(ValueError, match="assignment destination is read-only"): + time[0] = -1 + + with pytest.raises(ValueError, match="cannot set WRITEABLE flag"): + time.flags["WRITEABLE"] = True + + new_time = time.copy() + new_time[0] = -1 + + new_continuous_data_object = continuous_data_object.update(time=new_time) + assert new_continuous_data_object.time[0] == -1 + + def test_sample_frequency_deprecation_warning(): n = 100 sample_frequency = 10 @@ -14,20 +56,12 @@ def test_sample_frequency_deprecation_warning(): with pytest.warns(DeprecationWarning, match="`sample_frequency` is set to `None`"): ContinuousData( - "label", - "name", - "unit", - "category", time=time, values=values, ) with pytest.warns(DeprecationWarning, match="`sample_frequency` is set to `None`"): ContinuousData( - "label", - "name", - "unit", - "category", time=time, values=values, sample_frequency=None, @@ -35,10 +69,6 @@ def test_sample_frequency_deprecation_warning(): with warnings.catch_warnings(record=True) as w: ContinuousData( - "label", - "name", - "unit", - "category", time=time, values=values, sample_frequency=sample_frequency, diff --git a/tests/test_datacollection.py b/tests/test_datacollection.py index e6589d413..e105f06c3 100644 --- a/tests/test_datacollection.py +++ b/tests/test_datacollection.py @@ -16,7 +16,6 @@ def create_data_object() -> Callable[[str, list | None], ContinuousData]: def internal( label: str, - derived_from: list | None = None, time: np.ndarray | None = None, values: np.ndarray | None = None, ) -> ContinuousData: @@ -28,7 +27,6 @@ def internal( time=time if isinstance(time, np.ndarray) else np.array([]), sample_frequency=0, values=values if isinstance(values, np.ndarray) else np.array([]), - derived_from=derived_from if isinstance(derived_from, list) else [], ) return internal @@ -121,9 +119,9 @@ def test_set_item(create_data_object: Callable[[str], ContinuousData]): def test_loaded_derived_data(create_data_object: Callable): data_object_loaded_1 = create_data_object("label 1") data_object_loaded_2 = create_data_object("label 2") - data_object_derived_1_a = create_data_object("label 1 der a", derived_from=[data_object_loaded_1]) - data_object_derived_1_b = create_data_object("label 1 der b", derived_from=[data_object_loaded_1]) - data_object_derived_2_a = create_data_object("label 2 der a", derived_from=[data_object_loaded_2]) + data_object_derived_1_a = create_data_object("label 1 der a") + data_object_derived_1_b = create_data_object("label 1 der b") + data_object_derived_2_a = create_data_object("label 2 der a") dc = DataCollection(ContinuousData) dc.add( @@ -134,23 +132,6 @@ def test_loaded_derived_data(create_data_object: Callable): data_object_derived_2_a, ) - assert dc.get_loaded_data() == { - "label 1": data_object_loaded_1, - "label 2": data_object_loaded_2, - } - assert dc.get_derived_data() == { - "label 1 der a": data_object_derived_1_a, - "label 1 der b": data_object_derived_1_b, - "label 2 der a": data_object_derived_2_a, - } - assert dc.get_data_derived_from(data_object_loaded_1) == { - "label 1 der a": data_object_derived_1_a, - "label 1 der b": data_object_derived_1_b, - } - assert dc.get_data_derived_from(data_object_loaded_2) == { - "label 2 der a": data_object_derived_2_a, - } - def test_concatenate(create_data_object: 
Callable): time = np.arange(100) / 20 diff --git a/tests/test_eeli.py b/tests/test_eeli.py index c837685fa..0cf18fe3e 100644 --- a/tests/test_eeli.py +++ b/tests/test_eeli.py @@ -18,10 +18,10 @@ def create_continuous_data_object(): def internal(sample_frequency: float, duration: float, frequency: float) -> ContinuousData: time, values = _make_cosine_wave(sample_frequency, duration, frequency) return ContinuousData( - "label", - "name", - "unit", - "impedance", + label="label", + name="name", + unit="unit", + category="impedance", time=time, values=values, sample_frequency=sample_frequency, @@ -119,12 +119,9 @@ def test_with_data(draeger_20hz_healthy_volunteer_pressure_pod: Sequence, pytest def test_non_impedance_data(draeger_20hz_healthy_volunteer_pressure_pod: Sequence) -> None: cd = draeger_20hz_healthy_volunteer_pressure_pod.continuous_data["global_impedance_(raw)"] - original_category = cd.category _ = EELI().compute_parameter(cd) - cd.category = "foo" + cd = cd.update(category="foo") with pytest.raises(ValueError, match="This method will only work on 'impedance' data, not 'foo'."): _ = EELI().compute_parameter(cd) - - cd.category = original_category diff --git a/tests/test_frozen.py b/tests/test_frozen.py new file mode 100644 index 000000000..4ba452a70 --- /dev/null +++ b/tests/test_frozen.py @@ -0,0 +1,75 @@ +import contextlib + +import numpy as np +import pytest + +from eitprocessing.datahandling.eitdata import EITData + + +@pytest.fixture +def frozen_eitdata_object() -> EITData: + return EITData( + label="test_label", + time=np.arange(10) / 10.0, + values=np.random.default_rng().random((10, 3, 3)), + sample_frequency=10.0, + vendor="simulated", + ) + + +def test_frozen_time_axis(frozen_eitdata_object: EITData): + with pytest.raises(AttributeError, match="cannot assign to field 'vendor'"): + frozen_eitdata_object.vendor = "new_vendor" + + with pytest.raises(ValueError, match="output array is read-only"): + frozen_eitdata_object.time += 1.0 + + with pytest.raises(AttributeError, match="cannot assign to field 'time'"): + frozen_eitdata_object.time = frozen_eitdata_object.time + 1.0 + + with pytest.raises(ValueError, match="assignment destination is read-only"): + frozen_eitdata_object.time[0] = 1.0 + + +def test_frozen_values(frozen_eitdata_object: EITData): + with pytest.raises(ValueError, match="output array is read-only"): + frozen_eitdata_object.values += 1.0 + + with pytest.raises(AttributeError, match="cannot assign to field 'values'"): + frozen_eitdata_object.values = frozen_eitdata_object.values + 1.0 + + with pytest.raises(ValueError, match="assignment destination is read-only"): + frozen_eitdata_object.values[0, 0, 0] = 1.0 + + +def test_unfreeze_array_on_copy(frozen_eitdata_object: EITData): + values_copy = frozen_eitdata_object.values.copy() + assert values_copy.flags["WRITEABLE"] + values_copy += 1.0 + new_frozen_eitdata_object = frozen_eitdata_object.update(values=values_copy) + assert not new_frozen_eitdata_object.values.flags["WRITEABLE"] + assert np.array_equal(values_copy, new_frozen_eitdata_object.values) + + +def test_frozen_slice(frozen_eitdata_object: EITData): + values_view = frozen_eitdata_object.values[:5, :, :] + assert not values_view.flags["WRITEABLE"] + with pytest.raises(ValueError, match="assignment destination is read-only"): + values_view[0, 0, 0] = 1.0 + + values_view = values_view.copy() + assert values_view.flags["WRITEABLE"] + values_view += 1.0 + + +def test_cannot_unfreeze(frozen_eitdata_object: EITData): + base = 
frozen_eitdata_object.values.base + with contextlib.suppress(AttributeError): + while True: + base = base.base + + if not isinstance(base, memoryview): + pytest.skip("Array is not based on a memoryview; cannot test unfreeze.") + + with pytest.raises(ValueError, match="cannot set WRITEABLE flag to True of this array"): + frozen_eitdata_object.values.flags["WRITEABLE"] = True diff --git a/tests/test_intervaldata.py b/tests/test_intervaldata.py index e5bffc387..19c92f239 100644 --- a/tests/test_intervaldata.py +++ b/tests/test_intervaldata.py @@ -4,6 +4,7 @@ import pytest from eitprocessing.datahandling.intervaldata import Interval, IntervalData +from eitprocessing.datahandling.structured_array import StructuredArray @pytest.fixture @@ -77,7 +78,7 @@ def intervaldata_valuesarray_partialfalse(): def test_post_init(intervaldata_novalues_partialtrue: IntervalData): - assert isinstance(intervaldata_novalues_partialtrue.intervals, list) + assert isinstance(intervaldata_novalues_partialtrue.intervals, StructuredArray) assert all(isinstance(interval, Interval) for interval in intervaldata_novalues_partialtrue.intervals) @@ -92,11 +93,6 @@ def test_has_values( intervaldata_valuesarray_partialfalse: IntervalData, ) -> None: assert not intervaldata_novalues_partialtrue.has_values - intervaldata_novalues_partialtrue.values = [] - assert intervaldata_novalues_partialtrue.has_values - intervaldata_novalues_partialtrue.values = None - assert not intervaldata_novalues_partialtrue.has_values - assert not intervaldata_novalues_partialfalse.has_values assert intervaldata_valueslist_partialfalse.has_values assert intervaldata_valuesarray_partialfalse.has_values @@ -162,21 +158,23 @@ def test_select_by_time( def test_select_by_time_values(intervaldata_valueslist_partialfalse: IntervalData): - assert isinstance(intervaldata_valueslist_partialfalse.values, list) + assert isinstance(intervaldata_valueslist_partialfalse.values, np.ndarray) sliced_copy = intervaldata_valueslist_partialfalse[:10] assert len(sliced_copy.intervals) == len(sliced_copy.values) - assert sliced_copy.values == intervaldata_valueslist_partialfalse.values[:10] + assert np.all(sliced_copy.values == intervaldata_valueslist_partialfalse.values[:10]) def test_concatenate(intervaldata_novalues_partialtrue: IntervalData): sliced_copy_1 = intervaldata_novalues_partialtrue[:10] sliced_copy_2 = intervaldata_novalues_partialtrue[10:20] + assert isinstance(sliced_copy_1.intervals, StructuredArray) assert len(sliced_copy_1) == 10 assert len(sliced_copy_2) == 10 concatenated = sliced_copy_1 + sliced_copy_2 + assert isinstance(concatenated.intervals, StructuredArray) assert len(concatenated) == 20 assert concatenated == sliced_copy_1.concatenate(sliced_copy_2) @@ -197,8 +195,8 @@ def test_concatenate_values_list(intervaldata_valueslist_partialfalse: IntervalD concatenated = sliced_copy_1 + sliced_copy_2 assert len(concatenated.intervals) == len(concatenated.values) - assert isinstance(intervaldata_valueslist_partialfalse.values, list) - assert concatenated.values == intervaldata_valueslist_partialfalse.values[:20] + assert isinstance(intervaldata_valueslist_partialfalse.values, np.ndarray) + assert np.all(concatenated.values == intervaldata_valueslist_partialfalse.values[:20]) def test_concatenate_values_numpy(intervaldata_valuesarray_partialfalse: IntervalData): @@ -215,5 +213,7 @@ def test_concatenate_values_type_mismatch( intervaldata_valueslist_partialfalse: IntervalData, intervaldata_valuesarray_partialfalse: IntervalData, ): - with 
pytest.raises(TypeError): - intervaldata_valueslist_partialfalse[:10] + intervaldata_valuesarray_partialfalse[10:] + intervaldata_valueslist_partialfalse[:10] + intervaldata_valuesarray_partialfalse[10:] + + +# TODO: add tests for IntervalData selection and concatenation with StructuredArray diff --git a/tests/test_mdn_filter.py b/tests/test_mdn_filter.py index b5c8051b6..a140a1c06 100644 --- a/tests/test_mdn_filter.py +++ b/tests/test_mdn_filter.py @@ -133,9 +133,9 @@ def test_with_eit_data(draeger_20hz_healthy_volunteer: Sequence): ) filtered_data = mdn_filter.apply(eit_data) - filtered_signal = mdn_filter.apply(eit_data.pixel_impedance, sample_frequency=eit_data.sample_frequency, axis=0) + filtered_signal = mdn_filter.apply(eit_data.values, sample_frequency=eit_data.sample_frequency, axis=0) - assert np.allclose(filtered_data.pixel_impedance, filtered_signal) + assert np.allclose(filtered_data.values, filtered_signal) @pytest.mark.parametrize( diff --git a/tests/test_parameter_tiv.py b/tests/test_parameter_tiv.py index 3410e4eda..bccdb530d 100644 --- a/tests/test_parameter_tiv.py +++ b/tests/test_parameter_tiv.py @@ -73,8 +73,6 @@ def mock_continuous_data(): unit="au", category="relative impedance", description="Global impedance created for testing TIV parameter", - parameters={}, - derived_from="mock_eit_data", time=np.linspace(0, 18, (18 * 1000), endpoint=False), values=mock_global_impedance(), sample_frequency=1000, @@ -86,13 +84,12 @@ def mock_eit_data(): """Fixture to provide an instance of EITData.""" return EITData( path="", - nframes=2000, time=np.linspace(0, 18, (18 * 1000), endpoint=False), sample_frequency=1000, vendor=Vendor.DRAEGER, label="mock_eit_data", name="mock_eit_data", - pixel_impedance=mock_pixel_impedance(), + values=mock_pixel_impedance(), ) @@ -421,8 +418,6 @@ def test_tiv_with_no_breaths_continuous(mock_continuous_data: ContinuousData): category="breath", intervals=[], values=[], - parameters={}, - derived_from=[], ), ), patch.object(tiv, "_calculate_tiv_values", return_value=[]), @@ -449,8 +444,6 @@ def test_tiv_with_no_breaths_pixel( category="breath", intervals=[], values=np.empty((0, 2, 2), dtype=object), - parameters={}, - derived_from=[], ), ), patch.object(tiv, "_calculate_tiv_values", return_value=[]), diff --git a/tests/test_pixel_breath.py b/tests/test_pixel_breath.py index 43ee5791b..4d6f3f06e 100644 --- a/tests/test_pixel_breath.py +++ b/tests/test_pixel_breath.py @@ -62,8 +62,6 @@ def mock_continuous_data(): unit="au", category="relative impedance", description="Global impedance created for testing pixel breath feature", - parameters={}, - derived_from="mock_eit_data", time=np.linspace(0, 2 * np.pi, 400), values=mock_global_impedance(), sample_frequency=399 / 2 * np.pi, @@ -75,13 +73,12 @@ def mock_eit_data(): """Fixture to provide an instance of EITData.""" return EITData( path="", - nframes=400, time=np.linspace(0, 2 * np.pi, 400), sample_frequency=399 / 2 * np.pi, vendor=Vendor.DRAEGER, label="mock_eit_data", name="mock_eit_data", - pixel_impedance=mock_pixel_impedance(), + values=mock_pixel_impedance(), ) @@ -113,13 +110,12 @@ def mock_zero_eit_data(): """Fixture to provide an instance of EITData with one element set to zero.""" return EITData( path="", - nframes=400, time=np.linspace(0, 2 * np.pi, 400), sample_frequency=399 / 2 * np.pi, vendor=Vendor.DRAEGER, label="mock_eit_data", name="mock_eit_data", - pixel_impedance=mock_pixel_impedance_one_zero(), + values=mock_pixel_impedance_one_zero(), ) @@ -172,8 +168,6 @@ def _mock(*_args, 
**_kwargs) -> SparseData: category="impedance difference", time=time, description="Mock tidal impedance variation", - parameters={}, - derived_from=[], values=values, ) @@ -318,7 +312,7 @@ def test_with_custom_mean_pixel_tiv( for row, col in itertools.product(range(2), range(2)): time_point = test_result[1, row, col].middle_time index = np.where(mock_eit_data.time == time_point)[0] - value_at_time = mock_eit_data.pixel_impedance[index[0], row, col] + value_at_time = mock_eit_data.values[index[0], row, col] if mean == -1: assert np.isclose(value_at_time, -1, atol=0.01) elif mean == 1: @@ -393,16 +387,18 @@ def test_phase_modes(draeger_50hz_healthy_volunteer_pressure_pod: Sequence, pyte eit_data = sequence.eit_data["raw"] # reduce the pixel set to some 'well-behaved' pixels with positive TIV - eit_data.pixel_impedance = eit_data.pixel_impedance[:, 14:20, 14:20] + eit_data = eit_data.update(values=eit_data.values[:, 14:20, 14:20]) # flip a single pixel, so the differences between algorithms becomes predictable flip_row, flip_col = 5, 5 - eit_data.pixel_impedance[:, flip_row, flip_col] = -eit_data.pixel_impedance[:, flip_row, flip_col] + new_values = eit_data.values.copy() + new_values[:, flip_row, flip_col] = -new_values[:, flip_row, flip_col] + eit_data = eit_data.update(values=new_values) cd = sequence.continuous_data["global_impedance_(raw)"] # replace the 'global' data with the sum of the middly pixels - cd.values = np.sum(eit_data.pixel_impedance, axis=(1, 2)) + cd = cd.update(values=np.sum(eit_data.values, axis=(1, 2))) pb_negative_amplitude = PixelBreath(phase_correction_mode="negative amplitude").find_pixel_breaths(eit_data, cd) pb_phase_shift = PixelBreath(phase_correction_mode="phase shift").find_pixel_breaths(eit_data, cd) diff --git a/tests/test_pixelmask.py b/tests/test_pixelmask.py index eb2d5e5a2..d5bdbfc34 100644 --- a/tests/test_pixelmask.py +++ b/tests/test_pixelmask.py @@ -158,19 +158,19 @@ def test_pixelmask_apply_eitdata(draeger_20hz_healthy_volunteer: Sequence): mask = PixelMask(np.full((32, 32), np.nan), suppress_all_nan_warning=True) masked_eit_data = mask.apply(eit_data) - assert masked_eit_data.pixel_impedance.shape == eit_data.pixel_impedance.shape - assert np.all(np.isnan(masked_eit_data.pixel_impedance)) + assert masked_eit_data.values.shape == eit_data.values.shape + assert np.all(np.isnan(masked_eit_data.values)) mask_values = np.full((32, 32), np.nan) mask_values[10:23, 10:23] = 1.0 # let center pixels pass mask = PixelMask(mask_values) masked_eit_data = mask.apply(eit_data) - assert np.array_equal(masked_eit_data.pixel_impedance[:, 10:23, 10:23], eit_data.pixel_impedance[:, 10:23, 10:23]) - assert np.all(np.isnan(masked_eit_data.pixel_impedance[:, :10, :])) - assert np.all(np.isnan(masked_eit_data.pixel_impedance[:, 23:, :])) - assert np.all(np.isnan(masked_eit_data.pixel_impedance[:, :, :10])) - assert np.all(np.isnan(masked_eit_data.pixel_impedance[:, :, 23:])) + assert np.array_equal(masked_eit_data.values[:, 10:23, 10:23], eit_data.values[:, 10:23, 10:23]) + assert np.all(np.isnan(masked_eit_data.values[:, :10, :])) + assert np.all(np.isnan(masked_eit_data.values[:, 23:, :])) + assert np.all(np.isnan(masked_eit_data.values[:, :, :10])) + assert np.all(np.isnan(masked_eit_data.values[:, :, 23:])) def test_pixelmask_apply_pixelmap(): diff --git a/tests/test_rate_detection.py b/tests/test_rate_detection.py index 4c943ace4..cfd5e98e8 100644 --- a/tests/test_rate_detection.py +++ b/tests/test_rate_detection.py @@ -107,12 +107,11 @@ def factory( # noqa: PLR0913 
return EITData( path=".", - nframes=nframes, time=time, sample_frequency=sample_frequency, vendor="draeger", label="test_signal", - pixel_impedance=pixel_impedance, + values=pixel_impedance, ) return factory diff --git a/tests/test_sparsedata.py b/tests/test_sparsedata.py index d706cf8e9..e9aa02a63 100644 --- a/tests/test_sparsedata.py +++ b/tests/test_sparsedata.py @@ -62,10 +62,6 @@ def test_has_values( sparsedata_valuesarray: SparseData, ) -> None: assert not sparsedata_novalues.has_values - sparsedata_novalues.values = [] - assert sparsedata_novalues.has_values - sparsedata_novalues.values = None - assert sparsedata_valueslist.has_values assert sparsedata_valuesarray.has_values diff --git a/tests/test_tiv_lungspace.py b/tests/test_tiv_lungspace.py index 3770351e1..6064c1f49 100644 --- a/tests/test_tiv_lungspace.py +++ b/tests/test_tiv_lungspace.py @@ -26,15 +26,12 @@ def factory(amplitudes: npt.ArrayLike, duration: float = 20.0) -> EITData: sine = np.sin(t * 2 * np.pi / 5) # 5 second period signal = amplitudes[None, :, :] * (sine[:, None, None] + 0.5) return EITData( - pixel_impedance=signal, - path="", - nframes=len(t), + values=signal, time=t, sample_frequency=sample_frequency, vendor=Vendor.SIMULATED, label="simulated", description="Simulated EIT data for testing purposes", - name="", suppress_simulated_warning=True, ) diff --git a/tests/test_watershed.py b/tests/test_watershed.py index 92390725f..6c456a265 100644 --- a/tests/test_watershed.py +++ b/tests/test_watershed.py @@ -48,15 +48,12 @@ def factory( pixel_impedance = sinusoid_shape * amplitude + end_expiratory_value return EITData( - pixel_impedance=pixel_impedance, - path="", - nframes=len(time), + values=pixel_impedance, time=time, sample_frequency=sample_frequency, vendor=Vendor.SIMULATED, label="simulated", description="Simulated EIT data for testing purposes", - name="", suppress_simulated_warning=True, ) @@ -164,15 +161,12 @@ def test_watershed_captures(draeger_50hz_healthy_volunteer_pressure_pod: Sequenc def test_watershed_no_amplitude(): eit_data = EITData( - pixel_impedance=np.ones((100, 32, 32)), - path="", - nframes=100, + values=np.ones((100, 32, 32)), time=np.arange(100) / 20, sample_frequency=20, vendor=Vendor.SIMULATED, label="simulated", description="Simulated EIT data with no amplitude", - name="", suppress_simulated_warning=True, ) with pytest.raises(ValueError, match="No breaths found in TIV or amplitude data"):