Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 215 additions & 3 deletions xrspatial/geotiff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ def _geo_to_coords(geo_info, height: int, width: int) -> dict:
centers are at origin + 0.5*pixel_size.
For PixelIsPoint: origin (tiepoint) is already the center of pixel (0,0),
so no half-pixel offset is needed.

Returned coords are pixel-center values in either raster type, matching
xarray convention. The raw GeoTransform (origin and pixel size) is
preserved separately on the DataArray as a rasterio-style 6-tuple in
``attrs['transform']``: ``(pixel_width, 0, origin_x, 0, pixel_height,
origin_y)``. ``to_geotiff`` prefers that attr over recomputing the
transform from the coord arrays, which avoids float drift on
fractional-precision rasters.
"""
t = geo_info.transform
if geo_info.raster_type == RASTER_PIXEL_IS_POINT:
Expand All @@ -57,6 +65,54 @@ def _geo_to_coords(geo_info, height: int, width: int) -> dict:
return {'y': y, 'x': x}


def _transform_tuple(geo_info) -> tuple | None:
"""Return the rasterio-style 6-tuple for a GeoInfo's transform.

Format: ``(pixel_width, 0.0, origin_x, 0.0, pixel_height, origin_y)``.

This matches ``rasterio.Affine.to_gdal()``-adjacent ordering used by
rioxarray's ``rio.transform()`` output. Storing the tuple on the
DataArray lets ``to_geotiff`` reproduce the source GeoTransform
byte-for-byte, side-stepping float drift in the y/x coord arrays.
"""
if geo_info is None:
return None
t = geo_info.transform
if t is None:
return None
return (
float(t.pixel_width), 0.0, float(t.origin_x),
0.0, float(t.pixel_height), float(t.origin_y),
)


def _transform_from_attr(attr_val) -> 'GeoTransform | None':
    """Interpret an ``attrs['transform']`` value as a GeoTransform.

    Parameters
    ----------
    attr_val : GeoTransform, iterable of six numbers, or None
        A ``GeoTransform`` instance is returned unchanged. Any other
        iterable is read positionally as rasterio Affine coefficients
        ``(a, b, c, d, e, f)``.

    Returns
    -------
    GeoTransform or None
        None when ``attr_val`` is None, is not iterable, does not hold
        exactly six items, or holds an item that ``float()`` rejects.

    Notes
    -----
    ``b`` and ``d`` must be numeric but are otherwise discarded, so only
    axis-aligned affines round-trip. Values are matched purely by
    position: a GDAL-ordered tuple ``(c, a, b, f, d, e)`` cannot be
    told apart from Affine ordering and would be misread, so callers
    must supply Affine ordering.
    """
    if attr_val is None:
        return None
    if isinstance(attr_val, GeoTransform):
        return attr_val

    # Materialize first so generators and other one-shot iterables can
    # be length-checked before conversion.
    try:
        coeffs = tuple(attr_val)
    except TypeError:
        return None
    if len(coeffs) != 6:
        return None

    # Convert all six entries, including the ignored off-diagonals, so a
    # non-numeric value anywhere rejects the whole tuple.
    try:
        numbers = [float(item) for item in coeffs]
    except (TypeError, ValueError):
        return None

    return GeoTransform(
        pixel_width=numbers[0],
        pixel_height=numbers[4],
        origin_x=numbers[2],
        origin_y=numbers[5],
    )


def _validate_dtype_cast(source_dtype, target_dtype):
"""Validate that casting source_dtype to target_dtype is allowed.

Expand Down Expand Up @@ -226,6 +282,29 @@ def open_geotiff(source: str, *, dtype=None, window=None,
-------
xr.DataArray
NumPy, Dask, CuPy, or Dask+CuPy backed depending on options.

Notes
-----
The CRS is stored as an int EPSG code in ``attrs['crs']`` whenever the
file's GeoKeys carry a recognized EPSG. Files whose CRS can only be
expressed as WKT keep the WKT in ``attrs['crs_wkt']`` and leave
``attrs['crs']`` unset. ``to_geotiff`` accepts either an int EPSG or a
WKT string in ``attrs['crs']`` for backward compatibility.

The file's GeoTransform is also surfaced as ``attrs['transform']``,
a rasterio-style 6-tuple
``(pixel_width, 0, origin_x, 0, pixel_height, origin_y)``. ``to_geotiff``
uses this attr verbatim when present, falling back to recomputing the
transform from the y/x coord arrays only when it is missing. The attr
is what makes write -> read -> write -> read round-trips bit-stable for
rasters with fractional pixel sizes or origins.

Integer rasters with a nodata sentinel are silently promoted to
``float64`` with NaN replacing the sentinel so downstream NaN-aware
code works uniformly. Pass ``dtype=...`` to keep the source dtype
(the cast will fail with ``ValueError`` for float-to-int because that
is lossy in a way users rarely intend; cast explicitly after read if
you need it).
"""
# VRT files
if source.lower().endswith('.vrt'):
Expand Down Expand Up @@ -282,6 +361,22 @@ def open_geotiff(source: str, *, dtype=None, window=None,
if geo_info.raster_type == RASTER_PIXEL_IS_POINT:
attrs['raster_type'] = 'point'

# Preserve the source GeoTransform verbatim. For a windowed read the
# origin shifts to the window's top-left pixel so the transform stays
# consistent with the returned y/x coords.
src_t = geo_info.transform
if src_t is not None:
if window is not None:
r0, c0, _r1, _c1 = window
origin_x_w = float(src_t.origin_x) + c0 * float(src_t.pixel_width)
origin_y_w = float(src_t.origin_y) + r0 * float(src_t.pixel_height)
attrs['transform'] = (
float(src_t.pixel_width), 0.0, origin_x_w,
0.0, float(src_t.pixel_height), origin_y_w,
)
else:
attrs['transform'] = _transform_tuple(geo_info)

# CRS description fields
if geo_info.crs_name is not None:
attrs['crs_name'] = geo_info.crs_name
Expand Down Expand Up @@ -317,6 +412,15 @@ def open_geotiff(source: str, *, dtype=None, window=None,
if geo_info.extra_tags is not None:
attrs['extra_tags'] = geo_info.extra_tags

# Friendly accessors for a few common pass-through tags. The raw
# entry stays in attrs['extra_tags'] so the writer can re-emit the
# exact bytes; users who tweak these convenience attrs can rely on
# to_geotiff to fold the new value into extra_tags before write.
if geo_info.image_description is not None:
attrs['image_description'] = geo_info.image_description
if geo_info.extra_samples is not None:
attrs['extra_samples'] = geo_info.extra_samples

# Resolution / DPI metadata
if geo_info.x_resolution is not None:
attrs['x_resolution'] = geo_info.x_resolution
Expand All @@ -327,7 +431,10 @@ def open_geotiff(source: str, *, dtype=None, window=None,
attrs['resolution_unit'] = _unit_names.get(
geo_info.resolution_unit, str(geo_info.resolution_unit))

# Attach palette colormap for indexed-color TIFFs
# Attach palette colormap for indexed-color TIFFs. The normalized
# RGBA triples drive matplotlib display; the raw uint16 ColorMap
# tag value lives in attrs['extra_tags'] for round-trip and is
# exposed here as attrs['colormap'] for convenience.
if geo_info.colormap is not None:
try:
from matplotlib.colors import ListedColormap
Expand All @@ -338,6 +445,13 @@ def open_geotiff(source: str, *, dtype=None, window=None,
# matplotlib not available -- store raw RGBA tuples only
attrs['colormap_rgba'] = geo_info.colormap

# Raw uint16 ColorMap tag value (3 * 2**bps entries, R-then-G-then-B)
if geo_info.extra_tags is not None:
for _tag_id, _tt, _tc, _tv in geo_info.extra_tags:
if _tag_id == 320: # TAG_COLORMAP
attrs['colormap'] = _tv
break

# Apply nodata mask: replace nodata sentinel values with NaN
nodata = geo_info.nodata
if nodata is not None:
Expand Down Expand Up @@ -399,6 +513,55 @@ def _is_gpu_data(data) -> bool:
}


# TIFF type ids needed when synthesizing extra_tags entries from attrs.
_TIFF_BYTE = 1
_TIFF_ASCII = 2
_TIFF_SHORT = 3


def _merge_friendly_extra_tags(extra_tags_list, attrs: dict) -> list | None:
"""Combine ``attrs['extra_tags']`` with friendly tag attrs.

Synthesizes ``(tag_id, type_id, count, value)`` entries from
``attrs['image_description']`` (270, ASCII),
``attrs['extra_samples']`` (338, SHORT) and ``attrs['colormap']``
(320, SHORT). An entry already present in ``extra_tags`` wins, so
a verbatim round-trip stays byte-identical.
"""
existing = list(extra_tags_list) if extra_tags_list else []
seen_ids = {t[0] for t in existing}

img_desc = attrs.get('image_description')
if img_desc is not None and 270 not in seen_ids:
s = str(img_desc)
existing.append((270, _TIFF_ASCII, len(s) + 1, s))
seen_ids.add(270)

extra_samples = attrs.get('extra_samples')
if extra_samples is not None and 338 not in seen_ids:
try:
vals = tuple(int(x) for x in extra_samples)
except (TypeError, ValueError):
vals = None
if vals:
value = vals if len(vals) > 1 else vals[0]
existing.append((338, _TIFF_SHORT, len(vals), value))
seen_ids.add(338)

colormap = attrs.get('colormap')
if colormap is not None and 320 not in seen_ids:
try:
cmap_vals = tuple(int(x) for x in colormap)
except (TypeError, ValueError):
cmap_vals = None
if cmap_vals:
value = cmap_vals if len(cmap_vals) > 1 else cmap_vals[0]
existing.append((320, _TIFF_SHORT, len(cmap_vals), value))
seen_ids.add(320)

return existing or None


def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
crs: int | str | None = None,
nodata=None,
Expand Down Expand Up @@ -524,7 +687,13 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
if isinstance(data, xr.DataArray):
raw = data.data

# Extract metadata from DataArray attrs (no materialisation needed)
# Extract metadata from DataArray attrs (no materialisation needed).
# Prefer attrs['transform'] (from open_geotiff) over the coord-derived
# transform: that path is bit-stable across round-trips, while
# _coords_to_transform can drift on fractional pixel sizes because
# x[1] - x[0] is computed in float64 from already-rounded coords.
if geo_transform is None:
geo_transform = _transform_from_attr(data.attrs.get('transform'))
if geo_transform is None:
geo_transform = _coords_to_transform(data)
if epsg is None and crs is None:
Expand Down Expand Up @@ -552,6 +721,12 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
from ._geotags import _build_gdal_metadata_xml
gdal_meta_xml = _build_gdal_metadata_xml(gdal_meta_dict)
extra_tags_list = data.attrs.get('extra_tags')
# Fold friendly attrs into extra_tags so a user-edited
# attrs['image_description'] / ['extra_samples'] / ['colormap']
# actually reaches the file. Existing entries with the same tag id
# win, which keeps verbatim round-trips byte-stable.
extra_tags_list = _merge_friendly_extra_tags(
extra_tags_list, data.attrs)
x_res = data.attrs.get('x_resolution')
y_res = data.attrs.get('y_resolution')
unit_str = data.attrs.get('resolution_unit')
Expand Down Expand Up @@ -771,7 +946,9 @@ def _write_vrt_tiled(data, vrt_path, *, crs=None, nodata=None,
wkt_fallback = wkt
if nodata is None:
nodata = data.attrs.get('nodata')
geo_transform = _coords_to_transform(data)
geo_transform = _transform_from_attr(data.attrs.get('transform'))
if geo_transform is None:
geo_transform = _coords_to_transform(data)
else:
raw = data

Expand Down Expand Up @@ -935,6 +1112,9 @@ def read_geotiff_dask(source: str, *, dtype=None, chunks: int | tuple = 512,
attrs['raster_type'] = 'point'
if nodata is not None:
attrs['nodata'] = nodata
transform_tuple = _transform_tuple(geo_info)
if transform_tuple is not None:
attrs['transform'] = transform_tuple

if isinstance(chunks, int):
ch_h = ch_w = chunks
Expand Down Expand Up @@ -1114,6 +1294,9 @@ def read_geotiff_gpu(source: str, *,
attrs = {}
if geo_info.crs_epsg is not None:
attrs['crs'] = geo_info.crs_epsg
t_tuple = _transform_tuple(geo_info)
if t_tuple is not None:
attrs['transform'] = t_tuple
if dtype is not None:
target = np.dtype(dtype)
_validate_dtype_cast(np.dtype(str(arr_gpu.dtype)), target)
Expand Down Expand Up @@ -1201,6 +1384,9 @@ def read_geotiff_gpu(source: str, *,
attrs['crs'] = geo_info.crs_epsg
if geo_info.crs_wkt is not None:
attrs['crs_wkt'] = geo_info.crs_wkt
t_tuple = _transform_tuple(geo_info)
if t_tuple is not None:
attrs['transform'] = t_tuple

if arr_gpu.ndim == 3:
dims = ['y', 'x', 'band']
Expand Down Expand Up @@ -1415,6 +1601,14 @@ def read_vrt(source: str, *, dtype=None, window=None,
-------
xr.DataArray
NumPy, Dask, CuPy, or Dask+CuPy backed depending on options.

Notes
-----
Like ``open_geotiff``, the CRS lands as an int EPSG in
``attrs['crs']`` when the VRT's WKT resolves to a known EPSG code.
Otherwise ``attrs['crs']`` stays unset and ``attrs['crs_wkt']`` carries
the original WKT. The source GeoTransform is preserved as a
rasterio-style 6-tuple in ``attrs['transform']``.
"""
from ._vrt import read_vrt as _read_vrt_internal

Expand Down Expand Up @@ -1467,6 +1661,24 @@ def read_vrt(source: str, *, dtype=None, window=None,
if nodata is not None:
attrs['nodata'] = nodata

# Surface the source GeoTransform in the same rasterio ordering used
# by open_geotiff: (pixel_width, 0, origin_x, 0, pixel_height, origin_y).
# vrt.geo_transform is GDAL ordering, so reorder. For a windowed read
# the origin shifts by (col_offset * res_x, row_offset * res_y).
if gt is not None:
if window is not None:
r0w, c0w, _r1w, _c1w = window
r0w = max(0, r0w)
c0w = max(0, c0w)
else:
r0w = c0w = 0
origin_x_out = float(origin_x) + c0w * float(res_x)
origin_y_out = float(origin_y) + r0w * float(res_y)
attrs['transform'] = (
float(res_x), 0.0, origin_x_out,
0.0, float(res_y), origin_y_out,
)

# Transfer to GPU if requested
if gpu:
import cupy
Expand Down
Loading
Loading