diff --git a/Stoner/Image/core.py b/Stoner/Image/core.py index 9ae69cb20..4160f6060 100755 --- a/Stoner/Image/core.py +++ b/Stoner/Image/core.py @@ -396,6 +396,78 @@ def __array_finalize__(self, obj): super().__array_finalize__(obj=obj) + # ============================================================================================================== + ############################ Descriptor Protocol ################################################ + # ============================================================================================================== + + def __set_name__(self, owner, name): + """Record the attribute name when the descriptor is assigned as a class attribute. + + Args: + owner (type): + The class that owns this descriptor. + name (str): + The attribute name used for this descriptor on the owner class. + """ + self._attr_name = name + self._private_name = f"_{name}" + + def __get__(self, obj, objtype=None): + """Return the ImageArray stored on the owner instance (descriptor getter). + + Args: + obj: The owner instance, or ``None`` when accessed on the class itself. + objtype: The owner class. + + Returns: + ImageArray: The stored image array, or this descriptor instance when accessed on the class. + """ + if obj is None: + return self + private = getattr(self, "_private_name", "_image") + try: + return object.__getattribute__(obj, private) + except AttributeError: + return ImageArray() + + def __set__(self, obj, value): + """Store *value* as an :class:`ImageArray` on the owner instance (descriptor setter). + + The setter preserves the existing filename and metadata, and handles the special + case where the image is part of a stack (in-place update when shapes match). + + Args: + obj: The owner instance. + value (array-like): New image data to store. + """ + private = getattr(self, "_private_name", "_image") + try: + old_image = object.__getattribute__(obj, private) + filename = old_image.filename + metadata = old_image.metadata + except AttributeError: + old_image = None + filename = "" + metadata = {} + if ( + old_image is not None + and hasattr(obj, "_fromstack") + and obj._fromstack + and old_image.shape == value.shape + and old_image.dtype == value.dtype + ): + old_image[:] = np.copy(value) + new_image = old_image + elif isinstance(value, np.ndarray): + new_image = np.copy(value).view(ImageArray) + object.__setattr__(obj, private, new_image) + else: + new_image = ImageArray(value) + object.__setattr__(obj, private, new_image) + obj.filename = filename + new_image.metadata.update(metadata) + new_image.metadata.update(getattr(value, "metadata", {})) + def _load(self, filename, *args, **kwargs): """Load an image from a file and return as a ImageArray.""" cls = type(self) @@ -894,31 +966,8 @@ def draw(self): """Access the DrawProxy object for accessing the skimage draw sub module.""" return DrawProxy(self.image, self) - @property - def image(self): - """Access the image data.""" - return self._image - - @image.setter - def image(self, v): - """Ensure stored image is always an ImageArray.""" - filename = self._image.filename - metadata = self._image.metadata - # ensure setting image goes into the same memory block if from stack - if ( - hasattr(self, "_fromstack") - and self._fromstack - and self._image.shape == v.shape - and self._image.dtype == v.dtype - ): - self._image[:] = np.copy(v) - elif isinstance(v, np.ndarray): - self._image = np.copy(v).view(ImageArray) - else: - self._image = ImageArray(v) - self.filename = filename - self._image.metadata.update(metadata) - self._image.metadata.update(getattr(v, "metadata", {})) + image = ImageArray() + """ImageArray descriptor that enforces the image attribute is always an :class:`ImageArray` instance.""" @property def mask(self): diff --git a/Stoner/Image/kerr.py b/Stoner/Image/kerr.py index 5c93513cf..ba05d665d 100755 --- a/Stoner/Image/kerr.py +++ b/Stoner/Image/kerr.py @@ -109,12 +109,12 @@ def __init__(self: Self, *args: Args, **kwargs: Kwargs) -> None: super().__init__(*args, **kwargs) self._image = self.image.view(KerrArray) - @ImageFile.image.getter + @property def image(self: Self) -> ImageArray: # pylint: disable=invalid-overridden-method """Access the image data.""" return self._image.view(KerrArray) - @ImageFile.image.setter + @image.setter def image(self: Self, v) -> None: # noqa: F811 # pylint: disable=redefined-outer-name, function-redefined """Ensure stored image is always an ImageArray.""" filename = self.filename diff --git a/tests/Stoner/Image/test_core.py b/tests/Stoner/Image/test_core.py index 24a8fd540..035ef9de1 100755 --- a/tests/Stoner/Image/test_core.py +++ b/tests/Stoner/Image/test_core.py @@ -511,5 +511,38 @@ def test_operators(): assert i.sum() == 50 * 255, "Negate operators failed" +def test_image_descriptor(): + """Test that ImageArray acts as a descriptor on ImageFile.image.""" + # Class-level: ImageFile.image should return the descriptor itself + assert isinstance(ImageFile.__dict__["image"], ImageArray), "ImageFile.image class attr is not an ImageArray descriptor" + + # Setting a plain numpy array should produce an ImageArray + imf = ImageFile() + arr = np.ones((5, 6)) + imf.image = arr + assert isinstance(imf.image, ImageArray), "image is not ImageArray after setting numpy array" + assert imf.image.shape == (5, 6), "shape mismatch after setting numpy array" + + # Setting an ImageArray should keep it as an ImageArray + ima = ImageArray(np.zeros((3, 4))) + imf2 = ImageFile() + imf2.image = ima + assert isinstance(imf2.image, ImageArray), "image is not ImageArray after setting ImageArray" + + # metadata should be preserved across assignment + imf3 = ImageFile() + imf3.image = ImageArray(np.ones((4, 4))) + imf3.image.metadata["test_key"] = "test_value" + imf3.image = np.zeros((4, 4)) + assert imf3.image.metadata.get("test_key") == "test_value", "metadata not preserved across image assignment" + + # filename should be preserved when image is replaced + imf4 = ImageFile() + imf4.image = ImageArray(np.ones((4, 4))) + imf4.filename = "test.png" + imf4.image = np.zeros((4, 4)) + assert imf4.filename == "test.png", "filename not preserved across image assignment" + + if __name__ == "__main__": # Run some tests manually to allow debugging pytest.main([__file__, "--pdb"])