From bbdbea638aed450223c2c86f5e670e43e462d99f Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Tue, 5 May 2026 04:13:44 -0700 Subject: [PATCH] Round-trip transform, crs, and tag metadata through to_geotiff/open_geotiff (#1484) Findings M-1 through M-4 from the geotiff metadata audit. - attrs['transform'] (rasterio 6-tuple) now lands on every read path so to_geotiff can rebuild the source GeoTransform without going through _coords_to_transform. That route can drift on fractional pixel sizes because x[1] - x[0] is recomputed in float64 from already-rounded coords. Coord-derived transform stays as the fallback when the attr is absent. - Document the int-EPSG convention for attrs['crs'] in open_geotiff and read_vrt; keep tolerating WKT strings on the write side. - Drop ColorMap (320) from _MANAGED_TAGS so the tag rides extra_tags and round-trips without a dedicated writer path. Surface the raw uint16 ColorMap value as attrs['colormap']. - Surface ImageDescription (270) as attrs['image_description'] and ExtraSamples (338) as attrs['extra_samples']. to_geotiff folds user-edited values back into extra_tags before write, with verbatim entries winning to keep round-trips byte-stable. - Document the int-with-nodata -> float64 promotion in open_geotiff's docstring along with the dtype= override. Tests added in test_metadata_round_trip_1484.py cover transform double round-trip on a fractional transform, ColorMap on a uint8 indexed raster, ImageDescription read and write, ExtraSamples surface, and the uint16 nodata promotion plus the dtype='uint16' ValueError. Mirrors the polish from #1462 (reproject) and #1472 (resample). 
Closes #1484 --- xrspatial/geotiff/__init__.py | 218 ++++++++- xrspatial/geotiff/_geotags.py | 43 +- .../tests/test_metadata_round_trip_1484.py | 418 ++++++++++++++++++ 3 files changed, 672 insertions(+), 7 deletions(-) create mode 100644 xrspatial/geotiff/tests/test_metadata_round_trip_1484.py diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py index 3c02b675..cb3b2a84 100644 --- a/xrspatial/geotiff/__init__.py +++ b/xrspatial/geotiff/__init__.py @@ -44,6 +44,14 @@ def _geo_to_coords(geo_info, height: int, width: int) -> dict: centers are at origin + 0.5*pixel_size. For PixelIsPoint: origin (tiepoint) is already the center of pixel (0,0), so no half-pixel offset is needed. + + Returned coords are pixel-center values in either raster type, matching + xarray convention. The raw GeoTransform (origin and pixel size) is + preserved separately on the DataArray as a rasterio-style 6-tuple in + ``attrs['transform']``: ``(pixel_width, 0, origin_x, 0, pixel_height, + origin_y)``. ``to_geotiff`` prefers that attr over recomputing the + transform from the coord arrays, which avoids float drift on + fractional-precision rasters. """ t = geo_info.transform if geo_info.raster_type == RASTER_PIXEL_IS_POINT:
@@ -57,6 +65,86 @@ def _geo_to_coords(geo_info, height: int, width: int) -> dict:
     return {'y': y, 'x': x}
 
 
+def _transform_tuple(geo_info) -> tuple | None:
+    """Return the rasterio-style 6-tuple for a GeoInfo's transform.
+
+    Format: ``(pixel_width, 0.0, origin_x, 0.0, pixel_height, origin_y)``,
+    i.e. the ``(a, b, c, d, e, f)`` coefficient order of an
+    ``affine.Affine`` with both rotation terms fixed at 0.0. This is NOT
+    the GDAL ``GetGeoTransform()`` / ``Affine.to_gdal()`` order, which is
+    ``(c, a, b, f, d, e)``.
+
+    Storing the tuple on the DataArray lets ``to_geotiff`` rebuild the
+    GeoTransform from the values that were read instead of re-deriving
+    pixel size and origin from the y/x coord arrays.
+
+    Returns None when ``geo_info`` or its ``transform`` is None.
+    """
+    t = None if geo_info is None else geo_info.transform
+    if t is None:
+        return None
+    return (
+        float(t.pixel_width), 0.0, float(t.origin_x),
+        0.0, float(t.pixel_height), float(t.origin_y),
+    )
+
+
+def _transform_from_attr(attr_val) -> 'GeoTransform | None':
+    """Build a GeoTransform from an ``attrs['transform']`` value.
+
+    Accepted inputs:
+
+    * a ``GeoTransform`` instance, returned unchanged;
+    * a 6-element sequence ``(a, b, c, d, e, f)`` in rasterio/Affine
+      coefficient order;
+    * a 9-element sequence ``(a, b, c, d, e, f, 0, 0, 1)``, which is
+      what ``tuple(affine.Affine(...))`` produces.
+
+    ``b`` and ``d`` are ignored, so only axis-aligned affines round-trip.
+    GDAL ordering ``(c, a, b, f, d, e)`` is NOT detected and must not be
+    passed.
+
+    Returns None -- letting the caller fall back to the coord-derived
+    transform -- for None, str/bytes, sequences of any other length,
+    non-numeric items, non-finite coefficients, or a zero pixel size.
+
+    NOTE(review): the value is not cross-checked against the array's
+    y/x coords. If attrs outlive a slice or resample of the DataArray
+    (presumably they do -- confirm against xarray's attr propagation),
+    a stale origin would be written; callers may need to validate it.
+    """
+    import math
+
+    if attr_val is None:
+        return None
+    if isinstance(attr_val, GeoTransform):
+        return attr_val
+    # tuple('123456') / tuple(b'123456') would parse as six valid floats.
+    if isinstance(attr_val, (str, bytes, bytearray)):
+        return None
+    try:
+        seq = tuple(attr_val)
+    except TypeError:
+        return None
+    if len(seq) not in (6, 9):
+        return None
+    try:
+        coeffs = tuple(float(x) for x in seq)
+    except (TypeError, ValueError):
+        return None
+    # A 9-tuple must carry the fixed bottom row of a 3x3 affine matrix.
+    if coeffs[6:] not in ((), (0.0, 0.0, 1.0)):
+        return None
+    a, _b, c, _d, e, f = coeffs[:6]
+    if a == 0.0 or e == 0.0:
+        return None
+    if not all(math.isfinite(v) for v in (a, c, e, f)):
+        return None
+    return GeoTransform(
+        origin_x=c, origin_y=f, pixel_width=a, pixel_height=e,
+    )
+
+
 def _validate_dtype_cast(source_dtype, target_dtype):
     """Validate that casting source_dtype to target_dtype is allowed.
 
@@ -226,6 +282,29 @@ def open_geotiff(source: str, *, dtype=None, window=None, ------- xr.DataArray NumPy, Dask, CuPy, or Dask+CuPy backed depending on options. + + Notes + ----- + The CRS is stored as an int EPSG code in ``attrs['crs']`` whenever the + file's GeoKeys carry a recognized EPSG. Files whose CRS can only be + expressed as WKT keep the WKT in ``attrs['crs_wkt']`` and leave + ``attrs['crs']`` unset. ``to_geotiff`` accepts either an int EPSG or a + WKT string in ``attrs['crs']`` for backward compatibility. + + The file's GeoTransform is also surfaced as ``attrs['transform']``, + a rasterio-style 6-tuple + ``(pixel_width, 0, origin_x, 0, pixel_height, origin_y)``. ``to_geotiff`` + uses this attr verbatim when present, falling back to recomputing the + transform from the y/x coord arrays only when it is missing. 
The attr + is what makes write -> read -> write -> read round-trips bit-stable for + rasters with fractional pixel sizes or origins. + + Integer rasters with a nodata sentinel are silently promoted to + ``float64`` with NaN replacing the sentinel so downstream NaN-aware + code works uniformly. Pass ``dtype=...`` to keep the source dtype + (the cast will fail with ``ValueError`` for float-to-int because that + is lossy in a way users rarely intend; cast explicitly after read if + you need it). """ # VRT files if source.lower().endswith('.vrt'): @@ -282,6 +361,22 @@ def open_geotiff(source: str, *, dtype=None, window=None, if geo_info.raster_type == RASTER_PIXEL_IS_POINT: attrs['raster_type'] = 'point' + # Preserve the source GeoTransform verbatim. For a windowed read the + # origin shifts to the window's top-left pixel so the transform stays + # consistent with the returned y/x coords. + src_t = geo_info.transform + if src_t is not None: + if window is not None: + r0, c0, _r1, _c1 = window + origin_x_w = float(src_t.origin_x) + c0 * float(src_t.pixel_width) + origin_y_w = float(src_t.origin_y) + r0 * float(src_t.pixel_height) + attrs['transform'] = ( + float(src_t.pixel_width), 0.0, origin_x_w, + 0.0, float(src_t.pixel_height), origin_y_w, + ) + else: + attrs['transform'] = _transform_tuple(geo_info) + # CRS description fields if geo_info.crs_name is not None: attrs['crs_name'] = geo_info.crs_name @@ -317,6 +412,15 @@ def open_geotiff(source: str, *, dtype=None, window=None, if geo_info.extra_tags is not None: attrs['extra_tags'] = geo_info.extra_tags + # Friendly accessors for a few common pass-through tags. The raw + # entry stays in attrs['extra_tags'] so the writer can re-emit the + # exact bytes; users who tweak these convenience attrs can rely on + # to_geotiff to fold the new value into extra_tags before write. 
+ if geo_info.image_description is not None: + attrs['image_description'] = geo_info.image_description + if geo_info.extra_samples is not None: + attrs['extra_samples'] = geo_info.extra_samples + # Resolution / DPI metadata if geo_info.x_resolution is not None: attrs['x_resolution'] = geo_info.x_resolution @@ -327,7 +431,10 @@ def open_geotiff(source: str, *, dtype=None, window=None, attrs['resolution_unit'] = _unit_names.get( geo_info.resolution_unit, str(geo_info.resolution_unit)) - # Attach palette colormap for indexed-color TIFFs + # Attach palette colormap for indexed-color TIFFs. The normalized + # RGBA triples drive matplotlib display; the raw uint16 ColorMap + # tag value lives in attrs['extra_tags'] for round-trip and is + # exposed here as attrs['colormap'] for convenience. if geo_info.colormap is not None: try: from matplotlib.colors import ListedColormap @@ -338,6 +445,13 @@ def open_geotiff(source: str, *, dtype=None, window=None, # matplotlib not available -- store raw RGBA tuples only attrs['colormap_rgba'] = geo_info.colormap + # Raw uint16 ColorMap tag value (3 * 2**bps entries, R-then-G-then-B) + if geo_info.extra_tags is not None: + for _tag_id, _tt, _tc, _tv in geo_info.extra_tags: + if _tag_id == 320: # TAG_COLORMAP + attrs['colormap'] = _tv + break + # Apply nodata mask: replace nodata sentinel values with NaN nodata = geo_info.nodata if nodata is not None:
@@ -399,6 +513,98 @@ def _is_gpu_data(data) -> bool:
 }
 
 
+# TIFF type ids needed when synthesizing extra_tags entries from attrs.
+_TIFF_BYTE = 1
+_TIFF_ASCII = 2
+_TIFF_SHORT = 3
+
+
+def _tag_text(value) -> str | None:
+    """Normalize an ImageDescription value to NUL-free 7-bit ASCII text.
+
+    ``bytes`` are decoded, anything else goes through ``str()``, and
+    trailing NULs are dropped. Characters outside ASCII become ``?`` so
+    ``len(text) + 1`` is the exact size of the NUL-terminated field for
+    any ASCII-compatible encoding the writer applies. None stays None.
+    """
+    if value is None:
+        return None
+    if isinstance(value, (bytes, bytearray)):
+        text = bytes(value).decode('ascii', errors='replace')
+    else:
+        text = str(value)
+    text = text.rstrip('\x00')
+    return text.encode('ascii', errors='replace').decode('ascii')
+
+
+def _tag_shorts(value) -> tuple | None:
+    """Normalize a SHORT tag value to a non-empty tuple of uint16 ints.
+
+    Accepts a scalar or any iterable of int-convertible items. Returns
+    None when the value is empty, not numeric, or holds a number outside
+    0..65535 (the range of a TIFF SHORT).
+    """
+    try:
+        items = tuple(value)
+    except TypeError:
+        items = (value,)
+    try:
+        shorts = tuple(int(x) for x in items)
+    except (TypeError, ValueError):
+        return None
+    if not shorts or not all(0 <= s <= 0xFFFF for s in shorts):
+        return None
+    return shorts
+
+
+def _merge_friendly_extra_tags(extra_tags_list, attrs: dict) -> list | None:
+    """Combine ``attrs['extra_tags']`` with friendly tag attrs.
+
+    Synthesizes ``(tag_id, type_id, count, value)`` entries from
+    ``attrs['image_description']`` (270, ASCII),
+    ``attrs['extra_samples']`` (338, SHORT) and ``attrs['colormap']``
+    (320, SHORT).
+
+    * Tag absent from ``extra_tags``: the synthesized entry is appended.
+    * Tag present and equal to the friendly attr once both are
+      normalized: the existing entry is kept verbatim, so an unedited
+      round-trip re-emits the same value.
+    * Tag present but different: the friendly attr was edited after the
+      read, so its entry replaces the stale one in place. If both were
+      edited and disagree, the friendly attr wins.
+
+    Friendly values that cannot be normalized (see ``_tag_text`` and
+    ``_tag_shorts``) are skipped. ``extra_tags_list`` is never mutated.
+    Returns the merged list, or None when it would be empty.
+    """
+    merged = list(extra_tags_list) if extra_tags_list else []
+    # First position of every tag id already carried by extra_tags.
+    slots = {}
+    for idx, entry in enumerate(merged):
+        slots.setdefault(entry[0], idx)
+
+    def fold(tag_id, type_id, count, value, normalize):
+        idx = slots.get(tag_id)
+        if idx is None:
+            slots[tag_id] = len(merged)
+            merged.append((tag_id, type_id, count, value))
+        elif normalize(merged[idx][3]) != normalize(value):
+            # Edited after the read: swap in place so tag order holds.
+            merged[idx] = (tag_id, type_id, count, value)
+
+    text = _tag_text(attrs.get('image_description'))
+    if text is not None:
+        fold(270, _TIFF_ASCII, len(text) + 1, text, _tag_text)
+
+    for attr_name, tag_id in (('extra_samples', 338), ('colormap', 320)):
+        shorts = _tag_shorts(attrs.get(attr_name))
+        if shorts is not None:
+            value = shorts if len(shorts) > 1 else shorts[0]
+            fold(tag_id, _TIFF_SHORT, len(shorts), value, _tag_shorts)
+
+    return merged or None
+
+
 def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
                crs: int | str | None = None,
                nodata=None,
@@ -524,7 +687,13 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *, if isinstance(data, xr.DataArray): raw = data.data - # Extract metadata from DataArray attrs (no materialisation needed) + # Extract metadata from DataArray attrs (no materialisation needed). + # Prefer attrs['transform'] (from open_geotiff) over the coord-derived + # transform: that path is bit-stable across round-trips, while + # _coords_to_transform can drift on fractional pixel sizes because + # x[1] - x[0] is computed in float64 from already-rounded coords. 
+ if geo_transform is None: + geo_transform = _transform_from_attr(data.attrs.get('transform')) if geo_transform is None: geo_transform = _coords_to_transform(data) if epsg is None and crs is None: @@ -552,6 +721,12 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *, from ._geotags import _build_gdal_metadata_xml gdal_meta_xml = _build_gdal_metadata_xml(gdal_meta_dict) extra_tags_list = data.attrs.get('extra_tags') + # Fold friendly attrs into extra_tags so a user-edited + # attrs['image_description'] / ['extra_samples'] / ['colormap'] + # actually reaches the file. Existing entries with the same tag id + # win, which keeps verbatim round-trips byte-stable. + extra_tags_list = _merge_friendly_extra_tags( + extra_tags_list, data.attrs) x_res = data.attrs.get('x_resolution') y_res = data.attrs.get('y_resolution') unit_str = data.attrs.get('resolution_unit') @@ -771,7 +946,9 @@ def _write_vrt_tiled(data, vrt_path, *, crs=None, nodata=None, wkt_fallback = wkt if nodata is None: nodata = data.attrs.get('nodata') - geo_transform = _coords_to_transform(data) + geo_transform = _transform_from_attr(data.attrs.get('transform')) + if geo_transform is None: + geo_transform = _coords_to_transform(data) else: raw = data @@ -935,6 +1112,9 @@ def read_geotiff_dask(source: str, *, dtype=None, chunks: int | tuple = 512, attrs['raster_type'] = 'point' if nodata is not None: attrs['nodata'] = nodata + transform_tuple = _transform_tuple(geo_info) + if transform_tuple is not None: + attrs['transform'] = transform_tuple if isinstance(chunks, int): ch_h = ch_w = chunks @@ -1114,6 +1294,9 @@ def read_geotiff_gpu(source: str, *, attrs = {} if geo_info.crs_epsg is not None: attrs['crs'] = geo_info.crs_epsg + t_tuple = _transform_tuple(geo_info) + if t_tuple is not None: + attrs['transform'] = t_tuple if dtype is not None: target = np.dtype(dtype) _validate_dtype_cast(np.dtype(str(arr_gpu.dtype)), target) @@ -1201,6 +1384,9 @@ def read_geotiff_gpu(source: str, *, attrs['crs'] = 
geo_info.crs_epsg if geo_info.crs_wkt is not None: attrs['crs_wkt'] = geo_info.crs_wkt + t_tuple = _transform_tuple(geo_info) + if t_tuple is not None: + attrs['transform'] = t_tuple if arr_gpu.ndim == 3: dims = ['y', 'x', 'band'] @@ -1415,6 +1601,14 @@ def read_vrt(source: str, *, dtype=None, window=None, ------- xr.DataArray NumPy, Dask, CuPy, or Dask+CuPy backed depending on options. + + Notes + ----- + Like ``open_geotiff``, the CRS lands as an int EPSG in + ``attrs['crs']`` when the VRT's WKT resolves to a known EPSG code. + Otherwise ``attrs['crs']`` stays unset and ``attrs['crs_wkt']`` carries + the original WKT. The source GeoTransform is preserved as a + rasterio-style 6-tuple in ``attrs['transform']``. """ from ._vrt import read_vrt as _read_vrt_internal @@ -1467,6 +1661,24 @@ def read_vrt(source: str, *, dtype=None, window=None, if nodata is not None: attrs['nodata'] = nodata + # Surface the source GeoTransform in the same rasterio ordering used + # by open_geotiff: (pixel_width, 0, origin_x, 0, pixel_height, origin_y). + # vrt.geo_transform is GDAL ordering, so reorder. For a windowed read + # the origin shifts by (col_offset * res_x, row_offset * res_y). 
+ if gt is not None: + if window is not None: + r0w, c0w, _r1w, _c1w = window + r0w = max(0, r0w) + c0w = max(0, c0w) + else: + r0w = c0w = 0 + origin_x_out = float(origin_x) + c0w * float(res_x) + origin_y_out = float(origin_y) + r0w * float(res_y) + attrs['transform'] = ( + float(res_x), 0.0, origin_x_out, + 0.0, float(res_y), origin_y_out, + ) + # Transfer to GPU if requested if gpu: import cupy diff --git a/xrspatial/geotiff/_geotags.py b/xrspatial/geotiff/_geotags.py index e7f95c82..6173e828 100644 --- a/xrspatial/geotiff/_geotags.py +++ b/xrspatial/geotiff/_geotags.py @@ -15,13 +15,22 @@ TAG_PREDICTOR, TAG_COLORMAP, TAG_TILE_WIDTH, TAG_TILE_LENGTH, TAG_TILE_OFFSETS, TAG_TILE_BYTE_COUNTS, + TAG_EXTRA_SAMPLES, TAG_SAMPLE_FORMAT, TAG_GDAL_METADATA, TAG_GDAL_NODATA, TAG_MODEL_PIXEL_SCALE, TAG_MODEL_TIEPOINT, TAG_MODEL_TRANSFORMATION, TAG_GEO_KEY_DIRECTORY, TAG_GEO_DOUBLE_PARAMS, TAG_GEO_ASCII_PARAMS, ) -# Tags that the writer manages -- everything else can be passed through +# ImageDescription tag (270). Captured for round-trip but not managed +# by the writer -- it flows through extra_tags pass-through. +TAG_IMAGE_DESCRIPTION = 270 + +# Tags the writer manages directly. Tags not in this set are collected +# into GeoInfo.extra_tags on read and re-emitted on write via the +# extra_tags pass-through. ColorMap (320), ExtraSamples (338, only emitted +# automatically when samples > 1), and ImageDescription (270) intentionally +# stay OUT of this set so they round-trip without dedicated writer plumbing. 
_MANAGED_TAGS = frozenset({ TAG_IMAGE_WIDTH, TAG_IMAGE_LENGTH, TAG_BITS_PER_SAMPLE, TAG_COMPRESSION, TAG_PHOTOMETRIC, @@ -29,7 +38,7 @@ TAG_ROWS_PER_STRIP, TAG_STRIP_BYTE_COUNTS, TAG_X_RESOLUTION, TAG_Y_RESOLUTION, TAG_PLANAR_CONFIG, TAG_RESOLUTION_UNIT, - TAG_PREDICTOR, TAG_COLORMAP, + TAG_PREDICTOR, TAG_TILE_WIDTH, TAG_TILE_LENGTH, TAG_TILE_OFFSETS, TAG_TILE_BYTE_COUNTS, TAG_SAMPLE_FORMAT, TAG_GDAL_METADATA, TAG_GDAL_NODATA, @@ -139,6 +148,11 @@ class GeoInfo: # Extra TIFF tags not managed by the writer (pass-through on round-trip) # List of (tag_id, type_id, count, raw_value) tuples. extra_tags: list | None = None + # ImageDescription tag (270) decoded as a Python str, when present. + image_description: str | None = None + # ExtraSamples tag (338) as a tuple of int alpha/extra-sample codes, + # when present. + extra_samples: tuple | None = None # Raw geokeys dict for anything else geokeys: dict[int, int | float | str] = field(default_factory=dict) @@ -478,9 +492,28 @@ def extract_geo_info(ifd: IFD, data: bytes | memoryview, # Collect extra (non-managed) tags for pass-through extra_tags = [] + image_description = None + extra_samples = None for tag_id, entry in ifd.entries.items(): - if tag_id not in _MANAGED_TAGS: - extra_tags.append((tag_id, entry.type_id, entry.count, entry.value)) + if tag_id in _MANAGED_TAGS: + continue + extra_tags.append((tag_id, entry.type_id, entry.count, entry.value)) + # Surface a few well-known extras as friendly attrs while still + # carrying the raw entry in extra_tags so to_geotiff can rewrite + # it byte-for-byte. 
+ if tag_id == TAG_IMAGE_DESCRIPTION: + v = entry.value + if isinstance(v, bytes): + v = v.rstrip(b'\x00').decode('ascii', errors='replace') + elif isinstance(v, str): + v = v.rstrip('\x00') + image_description = v + elif tag_id == TAG_EXTRA_SAMPLES: + v = entry.value + if isinstance(v, tuple): + extra_samples = tuple(int(x) for x in v) + elif isinstance(v, int): + extra_samples = (int(v),) if not extra_tags: extra_tags = None @@ -518,6 +551,8 @@ def extract_geo_info(ifd: IFD, data: bytes | memoryview, gdal_metadata=gdal_metadata, gdal_metadata_xml=gdal_metadata_xml, extra_tags=extra_tags, + image_description=image_description, + extra_samples=extra_samples, geokeys=geokeys, )
diff --git a/xrspatial/geotiff/tests/test_metadata_round_trip_1484.py b/xrspatial/geotiff/tests/test_metadata_round_trip_1484.py
new file mode 100644
index 00000000..708b1f74
--- /dev/null
+++ b/xrspatial/geotiff/tests/test_metadata_round_trip_1484.py
@@ -0,0 +1,338 @@
+"""Round-trip tests for transform / crs / tag metadata (issue #1484).
+
+These cover findings M-1 through M-4 from the geotiff metadata audit:
+
+* M-1 / M-2: ``attrs['crs']`` stays as the same int EPSG and
+  ``attrs['transform']`` survives write -> read -> write -> read with
+  the same numeric values up to float precision.
+* M-3: ColorMap, ExtraSamples, and ImageDescription survive a single
+  write -> read cycle. ColorMap exits the writer through the
+  ``extra_tags`` pass-through (the tag is no longer in
+  ``_MANAGED_TAGS``); ImageDescription gets a friendly ``attrs`` entry.
+* M-4: integer rasters with a nodata sentinel get promoted to float64
+  with NaN, and a user-requested ``dtype='uint16'`` cast on the read
+  side raises ValueError (existing float-to-int guard).
+"""
+from __future__ import annotations
+
+import struct
+
+import numpy as np
+import pytest
+
+from xrspatial.geotiff import open_geotiff, to_geotiff
+from xrspatial.geotiff._writer import write
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+_BO = '<'  # the hand-built fixtures are little-endian ('II') TIFFs
+
+
+def _short(tag, val):
+    """IFD entry carrying a single SHORT (type 3) value."""
+    return (tag, 3, 1, struct.pack(f'{_BO}H', val))
+
+
+def _long(tag, val):
+    """IFD entry carrying a single LONG (type 4) value."""
+    return (tag, 4, 1, struct.pack(f'{_BO}I', val))
+
+
+def _write_single_strip_tiff(path, tag_list, pixel_bytes):
+    """Serialize one IFD plus one uncompressed strip as a classic TIFF.
+
+    ``tag_list`` holds ``(tag, type_id, count, raw_bytes)`` entries.
+    Values longer than 4 bytes go to an overflow area right after the
+    IFD; the pixel data follows it. The StripOffsets entry (273) is a
+    placeholder whose value is patched to the real pixel-data offset.
+    """
+    entries = sorted(tag_list, key=lambda t: t[0])
+    ifd_start = 8
+    # entry count (2 bytes) + 12 bytes per entry + next-IFD offset (4)
+    overflow_start = ifd_start + 2 + 12 * len(entries) + 4
+
+    overflow_buf = bytearray()
+    tag_offsets = {}
+    for tag, _typ, _count, raw in entries:
+        if len(raw) > 4:
+            tag_offsets[tag] = len(overflow_buf)
+            overflow_buf.extend(raw)
+            if len(overflow_buf) % 2:
+                overflow_buf.append(0)  # keep the next value word-aligned
+    # Patching 273 swaps one inline 4-byte value for another, so the
+    # overflow layout computed above is already final.
+    pixel_data_start = overflow_start + len(overflow_buf)
+
+    out = bytearray(b'II')
+    out.extend(struct.pack(f'{_BO}H', 42))
+    out.extend(struct.pack(f'{_BO}I', ifd_start))
+    out.extend(struct.pack(f'{_BO}H', len(entries)))
+    for tag, typ, count, raw in entries:
+        if tag == 273:
+            raw = struct.pack(f'{_BO}I', pixel_data_start)
+        out.extend(struct.pack(f'{_BO}HHI', tag, typ, count))
+        if len(raw) <= 4:
+            out.extend(raw.ljust(4, b'\x00'))
+        else:
+            ptr = overflow_start + tag_offsets[tag]
+            out.extend(struct.pack(f'{_BO}I', ptr))
+    out.extend(struct.pack(f'{_BO}I', 0))  # no further IFDs
+    out.extend(overflow_buf)
+    out.extend(pixel_bytes)
+
+    with open(path, 'wb') as f:
+        f.write(bytes(out))
+
+
+def _make_palette_uint8_tiff(path, pixels, palette_rgb16):
+    """Write an 8-bit, 256-entry palette TIFF directly (no writer support
+    for ColorMap on the write side).
+
+    palette_rgb16 must have 256 (R, G, B) tuples of uint16 values.
+    """
+    height, width = pixels.shape[0], pixels.shape[1]
+    assert len(palette_rgb16) == 256
+    pixel_bytes = pixels.ravel().astype(np.uint8).tobytes()
+
+    # ColorMap layout: all reds, then all greens, then all blues.
+    cmap_values = [c[ch] for ch in range(3) for c in palette_rgb16]
+    cmap_raw = struct.pack(f'{_BO}{len(cmap_values)}H', *cmap_values)
+
+    _write_single_strip_tiff(path, [
+        _short(256, width),
+        _short(257, height),
+        _short(258, 8),        # bits per sample
+        _short(259, 1),        # no compression
+        _short(262, 3),        # photometric = palette
+        _short(277, 1),        # samples per pixel = 1
+        _short(278, height),   # rows per strip
+        _long(273, 0),         # strip offsets placeholder
+        _long(279, len(pixel_bytes)),
+        (320, 3, len(cmap_values), cmap_raw),   # ColorMap
+        _short(339, 1),        # sample format = uint
+    ], pixel_bytes)
+
+
+def _write_simple_tiff_with_image_description(path, pixels, description):
+    """Write an uncompressed, single-strip TIFF that carries an
+    ImageDescription tag (270) so we can test the read side."""
+    height, width = pixels.shape
+    pixel_bytes = pixels.astype(np.float32).tobytes()
+    desc_bytes = description.encode('ascii') + b'\x00'
+    if len(desc_bytes) % 2:
+        desc_bytes += b'\x00'
+
+    _write_single_strip_tiff(path, [
+        _short(256, width),
+        _short(257, height),
+        _short(258, 32),
+        _short(259, 1),
+        _short(262, 1),
+        (270, 2, len(description) + 1, desc_bytes),
+        _short(277, 1),
+        _short(278, height),
+        _long(273, 0),
+        _long(279, len(pixel_bytes)),
+        _short(339, 3),        # sample format = float
+    ], pixel_bytes)
+
+
+# ---------------------------------------------------------------------------
+# M-1 / M-2: transform & crs round-trip stability
+# ---------------------------------------------------------------------------
+
+class TestTransformCrsRoundTrip:
+
+    def test_transform_attr_present_on_read(self, tmp_path):
+        arr = np.arange(20, dtype=np.float32).reshape(4, 5)
+        from xrspatial.geotiff._geotags import GeoTransform
+        gt = GeoTransform(
+            origin_x=500000.0, origin_y=4000000.0,
+            pixel_width=30.0, pixel_height=-30.0,
+        )
+        path = str(tmp_path / 'transform_present_1484.tif')
+        write(arr, path, geo_transform=gt, crs_epsg=32610,
+              compression='none', tiled=False)
+        da = open_geotiff(path)
+        assert 'transform' in da.attrs
+        a, b, c, d, e, f = da.attrs['transform']
+        assert b == 0.0 and d == 0.0
+        assert a == pytest.approx(30.0)
+        assert e == pytest.approx(-30.0)
+        assert c == pytest.approx(500000.0)
+        assert f == pytest.approx(4000000.0)
+        assert da.attrs['crs'] == 32610
+
+    def test_double_round_trip_fractional_transform(self, tmp_path):
+        """Fractional pixel size + non-grid origin: writing twice must not
+        drift the transform. This is the case ``_coords_to_transform`` can
+        miss because ``x[1] - x[0]`` is recomputed from already-rounded
+        coords."""
+        from xrspatial.geotiff._geotags import GeoTransform
+        arr = np.linspace(0, 1, 8 * 12, dtype=np.float64).reshape(8, 12)
+        gt = GeoTransform(
+            origin_x=-122.123456789,
+            origin_y=37.987654321,
+            pixel_width=1.0 / 3600.0 + 1e-12,  # ~ 1 arc-second + tiny offset
+            pixel_height=-(1.0 / 3600.0 + 1e-12),
+        )
+        path1 = str(tmp_path / 'rt1_1484.tif')
+        write(arr, path1, geo_transform=gt, crs_epsg=4326,
+              compression='none', tiled=False)
+        da1 = open_geotiff(path1)
+        assert da1.attrs['crs'] == 4326
+
+        path2 = str(tmp_path / 'rt2_1484.tif')
+        to_geotiff(da1, path2, compression='none')
+        da2 = open_geotiff(path2)
+
+        path3 = str(tmp_path / 'rt3_1484.tif')
+        to_geotiff(da2, path3, compression='none')
+        da3 = open_geotiff(path3)
+
+        # CRS stays an int EPSG unchanged across cycles
+        assert da3.attrs['crs'] == 4326
+        # Transform tuple equal up to float precision
+        t1 = da1.attrs['transform']
+        t3 = da3.attrs['transform']
+        for v1, v3 in zip(t1, t3):
+            assert v3 == pytest.approx(v1, abs=1e-15, rel=1e-12)
+
+    def test_crs_string_input_still_tolerated(self, tmp_path):
+        """Backward compat: passing a WKT string in attrs['crs'] still works
+        on the write side. open_geotiff turns it back into an int EPSG."""
+        import xarray as xr
+        from xrspatial.geotiff._geotags import _epsg_to_wkt
+        wkt = _epsg_to_wkt(4326)
+        if wkt is None:
+            pytest.skip("pyproj not available")
+        arr = np.zeros((3, 3), dtype=np.float32)
+        da = xr.DataArray(
+            arr,
+            dims=['y', 'x'],
+            coords={
+                'y': np.array([0.5, -0.5, -1.5]),
+                'x': np.array([0.5, 1.5, 2.5]),
+            },
+            attrs={'crs': wkt},
+        )
+        path = str(tmp_path / 'wkt_string_crs_1484.tif')
+        to_geotiff(da, path, compression='none')
+        result = open_geotiff(path)
+        assert result.attrs['crs'] == 4326
+
+
+# ---------------------------------------------------------------------------
+# M-3: tag pass-through (ColorMap, ImageDescription, ExtraSamples)
+# ---------------------------------------------------------------------------
+
+class TestTagPassThrough:
+
+    def test_colormap_round_trip(self, tmp_path):
+        palette = [(i * 257, (255 - i) * 257, (i * 2) % 65536)
+                   for i in range(256)]
+        pixels = np.array([[0, 1, 2, 254, 255],
+                           [10, 20, 30, 40, 50]], dtype=np.uint8)
+        in_path = str(tmp_path / 'colormap_in_1484.tif')
+        _make_palette_uint8_tiff(in_path, pixels, palette)
+
+        da = open_geotiff(in_path)
+        assert da.dtype == np.uint8
+        assert 'colormap' in da.attrs
+        # Raw uint16 ColorMap: 3 * 256 = 768 entries
+        assert len(da.attrs['colormap']) == 768
+
+        # Round-trip through to_geotiff: ColorMap rides extra_tags
+        out_path = str(tmp_path / 'colormap_out_1484.tif')
+        to_geotiff(da, out_path, compression='none')
+        da2 = open_geotiff(out_path)
+
+        np.testing.assert_array_equal(da2.values, pixels)
+        assert 'colormap' in da2.attrs
+        assert tuple(da2.attrs['colormap']) == tuple(da.attrs['colormap'])
+
+    def test_image_description_round_trip(self, tmp_path):
+        pixels = np.arange(12, dtype=np.float32).reshape(3, 4)
+        desc = "elevation tile from issue 1484"
+        in_path = str(tmp_path / 'desc_in_1484.tif')
+        _write_simple_tiff_with_image_description(in_path, pixels, desc)
+
+        da = open_geotiff(in_path)
+        assert da.attrs.get('image_description') == desc
+        # Also reachable through extra_tags by tag id 270
+        et_ids = {t[0] for t in da.attrs['extra_tags']}
+        assert 270 in et_ids
+
+        out_path = str(tmp_path / 'desc_out_1484.tif')
+        to_geotiff(da, out_path, compression='none')
+        da2 = open_geotiff(out_path)
+        assert da2.attrs.get('image_description') == desc
+
+    def test_image_description_added_via_attrs(self, tmp_path):
+        """Setting attrs['image_description'] on a fresh DataArray flows
+        through to the output file even when extra_tags is empty."""
+        import xarray as xr
+        arr = np.zeros((4, 4), dtype=np.float32)
+        da = xr.DataArray(
+            arr, dims=['y', 'x'],
+            coords={'y': np.arange(4), 'x': np.arange(4)},
+            attrs={'image_description': 'synthetic test 1484'},
+        )
+        path = str(tmp_path / 'desc_synth_1484.tif')
+        to_geotiff(da, path, compression='none')
+
+        result = open_geotiff(path)
+        assert result.attrs.get('image_description') == 'synthetic test 1484'
+
+    def test_extra_samples_attr_surfaces_on_read(self, tmp_path):
+        """A 4-band write produces ExtraSamples internally; reading it back
+        surfaces the codes as attrs['extra_samples']."""
+        rgba = np.zeros((4, 5, 4), dtype=np.uint8)
+        rgba[..., 3] = 255
+        path = str(tmp_path / 'rgba_es_1484.tif')
+        write(rgba, path, compression='none', tiled=False)
+        da = open_geotiff(path)
+        assert da.attrs.get('extra_samples') is not None
+        # TIFF ExtraSamples codes: 1 = associated, 2 = unassociated alpha
+        assert da.attrs['extra_samples'][0] in (1, 2)
+
+
+# ---------------------------------------------------------------------------
+# M-4: integer-with-nodata dtype promotion
+# ---------------------------------------------------------------------------
+
+class TestIntegerNodataPromotion:
+
+    def test_uint16_with_nodata_promotes_to_float64(self, tmp_path):
+        arr = np.array([[1, 2, 3], [65535, 5, 6]], dtype=np.uint16)
+        path = str(tmp_path / 'u16_nodata_1484.tif')
+        write(arr, path, nodata=65535, compression='none', tiled=False)
+
+        da = open_geotiff(path)
+        assert da.dtype == np.float64
+        assert np.isnan(da.values[1, 0])
+        np.testing.assert_array_equal(
+            da.values[~np.isnan(da.values)],
+            np.array([1.0, 2.0, 3.0, 5.0, 6.0]),
+        )
+
+    def test_uint16_with_nodata_dtype_uint16_raises(self, tmp_path):
+        """Promotion happens before the user-requested dtype check, so
+        passing dtype='uint16' on an integer-with-nodata raster hits the
+        float-to-int guard and raises ValueError."""
+        arr = np.array([[1, 2, 3], [65535, 5, 6]], dtype=np.uint16)
+        path = str(tmp_path / 'u16_nodata_cast_1484.tif')
+        write(arr, path, nodata=65535, compression='none', tiled=False)
+        with pytest.raises(ValueError, match='float.*int'):
+            open_geotiff(path, dtype='uint16')
+
+    def test_uint16_no_nodata_keeps_dtype(self, tmp_path):
+        """Without a nodata sentinel, no promotion; original dtype stays."""
+        arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16)
+        path = str(tmp_path / 'u16_no_nodata_1484.tif')
+        write(arr, path, compression='none', tiled=False)
+        da = open_geotiff(path)
+        assert da.dtype == np.uint16