From b04c4ffa0858b94c3900535f0543d707d1d70edb Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Tue, 5 May 2026 04:09:07 -0700 Subject: [PATCH 1/2] Geotiff polish: validation, caching caps, parallelism thresholds, memory guards (#1488) --- xrspatial/geotiff/__init__.py | 95 ++++++++++++++++++++------ xrspatial/geotiff/_gpu_decode.py | 53 +++++++++++++++ xrspatial/geotiff/_reader.py | 112 ++++++++++++++++++++++++------- xrspatial/geotiff/_writer.py | 13 +++- 4 files changed, 225 insertions(+), 48 deletions(-) diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py index 3c02b675..5a79fad3 100644 --- a/xrspatial/geotiff/__init__.py +++ b/xrspatial/geotiff/__init__.py @@ -398,6 +398,14 @@ def _is_gpu_data(data) -> bool: 'lz4': (0, 16), } +# Names accepted by ``compression=`` in :func:`to_geotiff`. Kept in sync with +# ``_compression_tag`` in ``_writer.py``. Validated up-front so users see a +# friendly error rather than the deeper traceback from ``_compression_tag``. +_VALID_COMPRESSIONS = ( + 'none', 'deflate', 'lzw', 'jpeg', 'packbits', 'zstd', 'lz4', + 'jpeg2000', 'j2k', 'lerc', +) + def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *, crs: int | str | None = None, @@ -452,12 +460,17 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *, tiled : bool Use tiled layout (default True). tile_size : int - Tile size in pixels (default 256). + Tile size in pixels (default 256). Ignored when ``tiled=False``; + a warning is emitted if a non-default value is passed alongside + strip mode. predictor : bool or int - TIFF predictor. ``False``/``0``/``1`` -> none, ``True``/``2`` -> - horizontal differencing (good for integer data), ``3`` -> - floating-point predictor (float dtypes only; typically gives - better deflate/zstd ratios on float data than predictor 2). + TIFF predictor. Accepted values: + + * ``False``, ``0``, or ``1`` -> no predictor. + * ``True`` or ``2`` -> horizontal differencing (good for integer + data; ``True`` and ``2`` are exactly equivalent). + * ``3`` -> floating-point predictor (float dtypes only; typically + gives better deflate/zstd ratios on float data than predictor 2). cog : bool Write as Cloud Optimized GeoTIFF. overview_levels : list[int] or None @@ -468,6 +481,27 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *, gpu : bool or None Force GPU compression. None (default) auto-detects CuPy data. """ + # Up-front validation: catch bad compression names before they reach + # any of the deeper write paths (streaming, GPU, VRT, COG) where the + # error surfaces from _compression_tag with a less obvious traceback. + if isinstance(compression, str): + if compression.lower() not in _VALID_COMPRESSIONS: + raise ValueError( + f"Unknown compression {compression!r}. " + f"Valid options: {list(_VALID_COMPRESSIONS)}.") + + # tile_size only applies to tiled output; warn if the caller passed a + # non-default size alongside strip mode (it would otherwise be silently + # ignored). + if not tiled and tile_size != 256: + import warnings + warnings.warn( + f"tile_size={tile_size} is ignored when tiled=False " + "(strip layout). Pass tiled=True to use tile_size, or drop " + "tile_size to silence this warning.", + stacklevel=2, + ) + # VRT tiled output if path.lower().endswith('.vrt'): if cog: @@ -900,7 +934,11 @@ def read_geotiff_dask(source: str, *, dtype=None, chunks: int | tuple = 512, """ import dask.array as da - # VRT files: delegate to read_vrt which handles chunks + # ``read_geotiff`` already routes ``.vrt`` to ``read_vrt`` before + # reaching here, so this branch is only hit when ``read_geotiff_dask`` + # is called directly with a VRT path. Keep it as a defensive fallback + # rather than letting the windowed-read path try to parse VRT XML as + # TIFF bytes. ``read_vrt`` is the single source of truth for VRT. if source.lower().endswith('.vrt'): return read_vrt(source, dtype=dtype, name=name, chunks=chunks) @@ -944,23 +982,24 @@ def read_geotiff_dask(source: str, *, dtype=None, chunks: int | tuple = 512, # Graph-size guard. Each chunk becomes a delayed task whose Python graph # entry retains ~1KB. At very large chunk counts the graph itself OOMs # the driver before any read executes (30TB at chunks=256 => ~500M tasks - # => ~500GB graph on host). Auto-scale chunks up to cap total task count. - _MAX_DASK_CHUNKS = 1_000_000 + # => ~500GB graph on host). Refuse anything past the cap and ask the + # caller to pick a chunk size, rather than silently rescaling -- the + # rescaled chunks may not align with the user's downstream pipeline. + _MAX_DASK_CHUNKS = 50_000 n_chunks = ((full_h + ch_h - 1) // ch_h) * ((full_w + ch_w - 1) // ch_w) if n_chunks > _MAX_DASK_CHUNKS: import math scale = math.sqrt(n_chunks / _MAX_DASK_CHUNKS) - new_ch_h = int(math.ceil(ch_h * scale)) - new_ch_w = int(math.ceil(ch_w * scale)) - import warnings - warnings.warn( - f"read_geotiff_dask: requested chunks=({ch_h}, {ch_w}) on a " - f"{full_h}x{full_w} image would produce {n_chunks} dask tasks, " - f"exceeding the {_MAX_DASK_CHUNKS}-task cap. Auto-scaling to " - f"chunks=({new_ch_h}, {new_ch_w}).", - stacklevel=2, + suggested_h = int(math.ceil(ch_h * scale)) + suggested_w = int(math.ceil(ch_w * scale)) + raise ValueError( + f"read_geotiff_dask: chunks=({ch_h}, {ch_w}) on a " + f"{full_h}x{full_w} image would produce {n_chunks:,} dask " + f"tasks, exceeding the {_MAX_DASK_CHUNKS:,}-task cap. Pass a " + f"larger chunks=... value explicitly (e.g. chunks=" + f"({suggested_h}, {suggested_w}) keeps the task count under " + "the cap)." ) - ch_h, ch_w = new_ch_h, new_ch_w # Build dask array from delayed windowed reads rows = list(range(0, full_h, ch_h)) @@ -1355,12 +1394,14 @@ def _gpu_compress_to_part(gpu_arr, w, h, spp): # Full resolution parts = [_gpu_compress_to_part(arr, width, height, samples)] - # Overview generation + # Overview generation -- mirrors the CPU writer's 8-level cap. if cog: if overview_levels is None: + from ._writer import _MAX_OVERVIEW_LEVELS overview_levels = [] oh, ow = height, width - while oh > tile_size and ow > tile_size: + while (oh > tile_size and ow > tile_size and + len(overview_levels) < _MAX_OVERVIEW_LEVELS): oh //= 2 ow //= 2 if oh > 0 and ow > 0: @@ -1505,13 +1546,23 @@ def write_vrt(vrt_path: str, source_files: list[str], **kwargs) -> str: Output .vrt file path. source_files : list of str Paths to the source GeoTIFF files. - **kwargs - relative, crs_wkt, nodata -- see _vrt.write_vrt. + relative : bool, optional + Store source paths relative to the VRT file (default True). + crs_wkt : str or None, optional + CRS as a WKT string. If None, the CRS is taken from the first + source GeoTIFF. + nodata : float or None, optional + NoData value. If None, taken from the first source GeoTIFF. Returns ------- str Path to the written VRT file. + + Notes + ----- + Only the keyword arguments listed above are accepted. Passing any + other keyword raises ``TypeError`` from the underlying writer. """ from ._vrt import write_vrt as _write_vrt_internal return _write_vrt_internal(vrt_path, source_files, **kwargs) diff --git a/xrspatial/geotiff/_gpu_decode.py b/xrspatial/geotiff/_gpu_decode.py index 242852cf..67c36829 100644 --- a/xrspatial/geotiff/_gpu_decode.py +++ b/xrspatial/geotiff/_gpu_decode.py @@ -11,6 +11,51 @@ import numpy as np from numba import cuda +#: Fraction of free GPU memory we're willing to allocate in a single call. +#: Above this, raise MemoryError up-front so the caller gets an actionable +#: error rather than a CUDA OOM deep inside the kernel launch. +_GPU_FREE_MEMORY_FRACTION = 0.9 + + +def _check_gpu_memory(required_bytes: int, what: str = "tile buffer") -> None: + """Raise MemoryError if *required_bytes* would exhaust the GPU. + + Calls ``cupy.cuda.runtime.memGetInfo()`` and refuses any allocation + that would consume more than ``_GPU_FREE_MEMORY_FRACTION`` of the + currently free memory. This is a soft guard -- another process can + grab memory between the check and the allocation -- but it catches + the common 'this single tensor is way too big' case before CUDA + raises a less informative error. + + Parameters + ---------- + required_bytes : int + Bytes the caller is about to allocate (sum across all buffers in + the same logical step). + what : str + Short label included in the error message, e.g. ``"tile buffer"``. + """ + if required_bytes <= 0: + return + try: + import cupy + free, total = cupy.cuda.runtime.memGetInfo() + except Exception: + # If we can't query, fall through and let the real allocation + # surface the error. Don't add a second failure mode here. + return + + budget = int(free * _GPU_FREE_MEMORY_FRACTION) + if required_bytes > budget: + raise MemoryError( + f"GPU out of memory: {what} needs {required_bytes:,} bytes " + f"but only {free:,} bytes free on device (cap is " + f"{_GPU_FREE_MEMORY_FRACTION:.0%} of free = {budget:,} " + "bytes). Consider reading the file in chunks via " + "read_geotiff_dask(..., chunks=...) or freeing GPU memory " + "with cupy.get_default_memory_pool().free_all_blocks()." + ) + # LZW constants (same as _compression.py) LZW_CLEAR_CODE = 256 LZW_EOI_CODE = 257 @@ -1006,6 +1051,8 @@ class _NvjpegImage(ctypes.Structure): ('pitch', ctypes.c_size_t * 4), ] + _check_gpu_memory(n_tiles * tile_bytes, + what="nvJPEG output buffer") d_all = cupy.empty(n_tiles * tile_bytes, dtype=cupy.uint8) decode_fn = getattr(lib, 'nvjpegDecode') @@ -1353,6 +1400,8 @@ def _apply_predictor_and_assemble(d_decomp, d_decomp_offsets, n_tiles, tiles_across = math.ceil(image_width / tile_width) total_pixels = image_width * image_height + _check_gpu_memory(total_pixels * bytes_per_pixel, + what="full-image output buffer") d_output = cupy.empty(total_pixels * bytes_per_pixel, dtype=cupy.uint8) tpb = 256 @@ -1440,6 +1489,7 @@ def gpu_decode_tiles( # Allocate decompressed buffer on device decomp_offsets = np.arange(n_tiles, dtype=np.int64) * tile_bytes + _check_gpu_memory(n_tiles * tile_bytes, what="tile decode buffer") d_decomp = cupy.zeros(n_tiles * tile_bytes, dtype=cupy.uint8) d_decomp_offsets = cupy.asarray(decomp_offsets) d_tile_sizes = cupy.full(n_tiles, tile_bytes, dtype=cupy.int64) @@ -1470,6 +1520,7 @@ def gpu_decode_tiles( d_comp_sizes = cupy.asarray(np.array(comp_sizes, dtype=np.int64)) decomp_offsets = np.arange(n_tiles, dtype=np.int64) * tile_bytes + _check_gpu_memory(n_tiles * tile_bytes, what="tile decode buffer") d_decomp = cupy.zeros(n_tiles * tile_bytes, dtype=cupy.uint8) d_decomp_offsets = cupy.asarray(decomp_offsets) d_tile_sizes = cupy.full(n_tiles, tile_bytes, dtype=cupy.int64) @@ -1602,6 +1653,8 @@ def gpu_decode_tiles( # Assemble tiles into output image on GPU tiles_across = math.ceil(image_width / tile_width) total_pixels = image_width * image_height + _check_gpu_memory(total_pixels * bytes_per_pixel, + what="full-image output buffer") d_output = cupy.empty(total_pixels * bytes_per_pixel, dtype=cupy.uint8) tpb = 256 diff --git a/xrspatial/geotiff/_reader.py b/xrspatial/geotiff/_reader.py index 4604edf2..635ca585 100644 --- a/xrspatial/geotiff/_reader.py +++ b/xrspatial/geotiff/_reader.py @@ -3,8 +3,10 @@ import math import mmap +import os as _os_module import threading import urllib.request +from collections import OrderedDict import numpy as np @@ -45,28 +47,55 @@ def _check_dimensions(width, height, samples, max_pixels): # Data source abstraction # --------------------------------------------------------------------------- +#: Soft cap on the number of mmap entries the reader keeps open at once. +#: When the cache size exceeds this, the least-recently-used *idle* entry +#: (refcount 0) is closed. In-use entries are never evicted. Override via +#: the ``XRSPATIAL_GEOTIFF_MMAP_CACHE_SIZE`` environment variable. +_DEFAULT_MMAP_CACHE_SIZE = 32 + + +def _mmap_cache_size_from_env() -> int: + """Read the cache size cap from the environment, falling back to the default.""" + raw = _os_module.environ.get('XRSPATIAL_GEOTIFF_MMAP_CACHE_SIZE') + if raw is None: + return _DEFAULT_MMAP_CACHE_SIZE + try: + val = int(raw) + except (TypeError, ValueError): + return _DEFAULT_MMAP_CACHE_SIZE + return max(1, val) + + class _MmapCache: - """Thread-safe, reference-counted mmap cache. + """Thread-safe, reference-counted, bounded LRU mmap cache. Multiple threads reading the same file share a single read-only mmap. - The mmap is closed when the last reference is released. + The cache keeps idle (refcount 0) mmaps around so repeated opens of the + same file avoid the cost of re-mapping. When the number of entries + exceeds the cap (default 32, or ``XRSPATIAL_GEOTIFF_MMAP_CACHE_SIZE``), + the least-recently-used *idle* entry is evicted. Entries with active + references are never evicted. + mmap slicing on a read-only mapping is thread-safe (no seek involved). """ - def __init__(self): + def __init__(self, max_size: int | None = None): self._lock = threading.Lock() - # path -> (fh, mm, refcount) - self._entries: dict[str, tuple] = {} + # path -> [fh, mm, size, refcount] (list so we can mutate in place) + # OrderedDict gives LRU semantics via move_to_end on access. + self._entries: OrderedDict[str, list] = OrderedDict() + self._max_size = (max_size if max_size is not None + else _mmap_cache_size_from_env()) def acquire(self, path: str): """Get or create a read-only mmap for *path*. Returns (mm, size).""" - import os - real = os.path.realpath(path) + real = _os_module.path.realpath(path) with self._lock: - if real in self._entries: - fh, mm, size, rc = self._entries[real] - self._entries[real] = (fh, mm, size, rc + 1) - return mm, size + entry = self._entries.get(real) + if entry is not None: + entry[3] += 1 + self._entries.move_to_end(real) + return entry[1], entry[2] fh = open(real, 'rb') fh.seek(0, 2) @@ -76,26 +105,56 @@ def acquire(self, path: str): mm = mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ) else: mm = None - self._entries[real] = (fh, mm, size, 1) + self._entries[real] = [fh, mm, size, 1] + self._evict_locked() return mm, size def release(self, path: str): - """Decrement the reference count; close the mmap when it hits zero.""" - import os - real = os.path.realpath(path) + """Decrement the reference count. + + When the count hits zero the entry stays cached (keyed by realpath) + until LRU eviction or :meth:`clear` is called. + """ + real = _os_module.path.realpath(path) with self._lock: entry = self._entries.get(real) if entry is None: return - fh, mm, size, rc = entry - rc -= 1 - if rc <= 0: - del self._entries[real] + entry[3] -= 1 + if entry[3] <= 0: + # Idle but still cached; mark LRU position. + self._entries.move_to_end(real) + self._evict_locked() + + def _evict_locked(self): + """Drop oldest *idle* entries until the cache is at or below the cap.""" + if len(self._entries) <= self._max_size: + return + # Walk from the front (oldest); only close idle (refcount 0) entries. + # An in-use entry can still happen to be at the front if the same + # file was acquired long ago and held; skip it. + to_drop = [] + for key, entry in list(self._entries.items()): + if len(self._entries) - len(to_drop) <= self._max_size: + break + if entry[3] <= 0: + to_drop.append(key) + for key in to_drop: + entry = self._entries.pop(key) + _, mm, _, _ = entry + if mm is not None: + mm.close() + entry[0].close() + + def clear(self): + """Close and drop all idle entries (used by tests).""" + with self._lock: + for key in [k for k, v in self._entries.items() if v[3] <= 0]: + entry = self._entries.pop(key) + _, mm, _, _ = entry if mm is not None: mm.close() - fh.close() - else: - self._entries[real] = (fh, mm, size, rc) + entry[0].close() # Module-level cache shared across all reads @@ -550,9 +609,14 @@ def _read_tiles(data: bytes, ifd: IFD, header: TIFFHeader, continue tile_jobs.append((band_idx, tr, tc, tile_idx, tile_samples)) - # Decode tiles -- parallel for compressed, sequential for uncompressed + # Decode tiles in parallel when the work per tile is large enough to + # outweigh the thread-pool overhead. Uncompressed multi-tile reads also + # benefit because numpy frombuffer + slice copies aren't free at large + # tile sizes. Threshold (~64K decoded pixels per tile) was picked to + # avoid pool overhead on small 64x64 / 128x128 tile reads. n_tiles = len(tile_jobs) - use_parallel = (compression != 1 and n_tiles > 4) # 1 = COMPRESSION_NONE + tile_pixels = tw * th + use_parallel = (n_tiles > 1 and tile_pixels > 64 * 1024) def _decode_one(job): band_idx, tr, tc, tile_idx, tile_samples = job diff --git a/xrspatial/geotiff/_writer.py b/xrspatial/geotiff/_writer.py index 13311a94..081840e4 100644 --- a/xrspatial/geotiff/_writer.py +++ b/xrspatial/geotiff/_writer.py @@ -127,6 +127,12 @@ def _compression_tag(compression_name: str) -> int: OVERVIEW_METHODS = ('mean', 'nearest', 'min', 'max', 'median', 'mode', 'cubic') +#: Maximum number of overview levels generated by auto-overview mode in COG +#: writes. 8 halvings = 1/256 of the original resolution, which is enough +#: for any practical raster. Pass ``overview_levels=[...]`` explicitly to +#: override. +_MAX_OVERVIEW_LEVELS = 8 + def _block_reduce_2d(arr2d, method): """2x block-reduce a single 2D plane using *method*.""" @@ -1027,10 +1033,13 @@ def write(data: np.ndarray, path: str, *, # Overviews if cog: if overview_levels is None: - # Auto-generate: keep halving until < tile_size + # Auto-generate: keep halving until < tile_size, capped at 8 levels. + # 8 halvings = 1/256 resolution, which is more than enough for + # interactive zoom on any realistic raster. Past that, overview + # write cost dominates without benefiting consumers. overview_levels = [] oh, ow = h, w - while oh > tile_size and ow > tile_size: + while oh > tile_size and ow > tile_size and len(overview_levels) < _MAX_OVERVIEW_LEVELS: oh //= 2 ow //= 2 if oh > 0 and ow > 0: From 7285ca0ba212e7dc84c701627a357d16c92334e1 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Tue, 5 May 2026 04:12:21 -0700 Subject: [PATCH 2/2] Add tests for #1488 polish bundle --- xrspatial/geotiff/tests/test_edge_cases.py | 6 +- xrspatial/geotiff/tests/test_polish_1488.py | 339 ++++++++++++++++++++ 2 files changed, 344 insertions(+), 1 deletion(-) create mode 100644 xrspatial/geotiff/tests/test_polish_1488.py diff --git a/xrspatial/geotiff/tests/test_edge_cases.py b/xrspatial/geotiff/tests/test_edge_cases.py index 7e6d2a94..edbdab56 100644 --- a/xrspatial/geotiff/tests/test_edge_cases.py +++ b/xrspatial/geotiff/tests/test_edge_cases.py @@ -49,7 +49,11 @@ def test_0d_scalar(self, tmp_path): def test_unsupported_compression(self, tmp_path): arr = np.zeros((4, 4), dtype=np.float32) - with pytest.raises(ValueError, match="Unsupported compression"): + # ``to_geotiff`` validates ``compression`` up-front (#1488). The + # earlier "Unsupported compression" message comes from the deeper + # ``_compression_tag`` and is now only seen when callers reach + # the writer directly. Both phrasings are acceptable. + with pytest.raises(ValueError, match="(Unknown|Unsupported) compression"): to_geotiff(arr, str(tmp_path / 'bad.tif'), compression='webp') def test_complex_dtype(self, tmp_path): diff --git a/xrspatial/geotiff/tests/test_polish_1488.py b/xrspatial/geotiff/tests/test_polish_1488.py new file mode 100644 index 00000000..e9c7ad76 --- /dev/null +++ b/xrspatial/geotiff/tests/test_polish_1488.py @@ -0,0 +1,339 @@ +"""Tests for the geotiff polish bundle (issue #1488). + +Covers ten low-severity audit items: + +* C-1 -- early ``compression`` validation in ``to_geotiff`` +* C-2 -- read dispatch leaves ``read_geotiff_dask`` with a defensive + ``.vrt`` fallback that delegates to ``read_vrt`` +* C-5 -- ``write_vrt`` docstring lists kwargs (rejects unknown ones) +* C-6 -- predictor doc covers True/2 equivalence and 3=fp +* C-7 -- ``tile_size`` warns when ``tiled=False`` and non-default +* P-3 -- mmap cache eviction (LRU + env var override) +* P-4 -- decode parallelism gate based on tile pixel count +* P-5 -- dask read raises with a hint instead of silently rescaling +* P-6 -- GPU memory pre-check raises before the cupy call +* P-9 -- COG auto-overview generation capped at 8 levels +""" +from __future__ import annotations + +import os +import warnings + +import numpy as np +import pytest + +from xrspatial.geotiff import to_geotiff, read_geotiff_dask, write_vrt +from xrspatial.geotiff._reader import _MmapCache, read_to_array +from xrspatial.geotiff._writer import _MAX_OVERVIEW_LEVELS, write + + +# --------------------------------------------------------------------------- +# C-1: early compression validation +# --------------------------------------------------------------------------- + +class TestC1CompressionValidation: + def test_unknown_compression_raises_at_top(self, tmp_path): + arr = np.arange(64, dtype=np.float32).reshape(8, 8) + path = str(tmp_path / 'bogus_1488.tif') + with pytest.raises(ValueError, match="Unknown compression"): + to_geotiff(arr, path, compression='bzip2') + + def test_known_compression_passes(self, tmp_path): + arr = np.arange(64, dtype=np.float32).reshape(8, 8) + path = str(tmp_path / 'ok_1488.tif') + # Should not raise. + to_geotiff(arr, path, compression='deflate') + assert os.path.exists(path) + + def test_validation_runs_before_vrt_dispatch(self, tmp_path): + # Bad compression on a .vrt path should still surface from the + # up-front check, not from deeper in _write_vrt_tiled. + arr = np.arange(64, dtype=np.float32).reshape(8, 8) + path = str(tmp_path / 'bogus_1488.vrt') + with pytest.raises(ValueError, match="Unknown compression"): + to_geotiff(arr, path, compression='nope') + + +# --------------------------------------------------------------------------- +# C-2: read dispatch comment (behavior verification) +# --------------------------------------------------------------------------- + +class TestC2ReadDispatch: + def test_read_geotiff_dask_handles_vrt_directly(self, tmp_path): + # Build a 2-tile VRT and confirm read_geotiff_dask routes to the + # VRT reader without trying to parse XML as TIFF. + from xrspatial.geotiff import write_vrt as wv + arr = np.arange(64, dtype=np.float32).reshape(8, 8) + a_path = str(tmp_path / 'a_1488.tif') + b_path = str(tmp_path / 'b_1488.tif') + # Two tiles side-by-side via geo-transform attrs would normally be + # generated upstream; for this test we just need both files to + # share a CRS and the writer's default transform. + write(arr, a_path, compression='none') + write(arr, b_path, compression='none') + vrt_path = str(tmp_path / 'mosaic_1488.vrt') + wv(vrt_path, [a_path, b_path]) + + result = read_geotiff_dask(vrt_path, chunks=8) + assert result.dims == ('y', 'x') + + +# --------------------------------------------------------------------------- +# C-5: write_vrt kwargs documented +# --------------------------------------------------------------------------- + +class TestC5WriteVrtKwargs: + def test_known_kwargs_accepted(self, tmp_path): + arr = np.arange(64, dtype=np.float32).reshape(8, 8) + a_path = str(tmp_path / 'a_c5_1488.tif') + write(arr, a_path, compression='none') + vrt_path = str(tmp_path / 'mosaic_c5_1488.vrt') + # All three documented kwargs should be accepted. + write_vrt(vrt_path, [a_path], relative=False, crs_wkt=None, + nodata=-9999.0) + assert os.path.exists(vrt_path) + + def test_unknown_kwarg_raises_typeerror(self, tmp_path): + arr = np.arange(64, dtype=np.float32).reshape(8, 8) + a_path = str(tmp_path / 'a_c5b_1488.tif') + write(arr, a_path, compression='none') + vrt_path = str(tmp_path / 'mosaic_c5b_1488.vrt') + with pytest.raises(TypeError): + write_vrt(vrt_path, [a_path], not_a_real_kwarg=True) + + def test_docstring_lists_kwargs(self): + # Defensive: the whole point of C-5 is the docstring -- guard + # against future regressions. + assert 'relative' in write_vrt.__doc__ + assert 'crs_wkt' in write_vrt.__doc__ + assert 'nodata' in write_vrt.__doc__ + + +# --------------------------------------------------------------------------- +# C-6: predictor docstring polish +# --------------------------------------------------------------------------- + +class TestC6PredictorDoc: + def test_to_geotiff_doc_covers_predictor_modes(self): + doc = to_geotiff.__doc__ + # True and 2 documented as equivalent + assert 'True' in doc and '2' in doc + # 3 documented as fp predictor for floats + assert '3' in doc and 'float' in doc.lower() + + +# --------------------------------------------------------------------------- +# C-7: tile_size warning in strip mode +# --------------------------------------------------------------------------- + +class TestC7TileSizeWarn: + def test_warning_when_tile_size_set_with_tiled_false(self, tmp_path): + arr = np.arange(64, dtype=np.float32).reshape(8, 8) + path = str(tmp_path / 'strip_1488.tif') + with pytest.warns(UserWarning, match="tile_size.*ignored"): + to_geotiff(arr, path, tiled=False, tile_size=128, + compression='none') + + def test_no_warning_when_tile_size_default(self, tmp_path): + arr = np.arange(64, dtype=np.float32).reshape(8, 8) + path = str(tmp_path / 'strip2_1488.tif') + with warnings.catch_warnings(): + warnings.simplefilter('error') + to_geotiff(arr, path, tiled=False, compression='none') + + def test_no_warning_when_tiled_true(self, tmp_path): + arr = np.arange(64, dtype=np.float32).reshape(8, 8) + path = str(tmp_path / 'tiled_1488.tif') + with warnings.catch_warnings(): + warnings.simplefilter('error') + to_geotiff(arr, path, tiled=True, tile_size=128, + compression='none') + + +# --------------------------------------------------------------------------- +# P-3: mmap LRU cache cap +# --------------------------------------------------------------------------- + +class TestP3MmapLRU: + def test_cap_evicts_oldest_idle_entry(self, tmp_path): + cache = _MmapCache(max_size=2) + files = [] + for i in range(3): + p = tmp_path / f'p3_{i}_1488.bin' + p.write_bytes(b'x' * 32) + files.append(str(p)) + + # Acquire and release each: all become idle. + for f in files: + cache.acquire(f) + cache.release(f) + + # Cache should be at the cap (2), with the oldest evicted. + assert len(cache._entries) == 2 + oldest = os.path.realpath(files[0]) + assert oldest not in cache._entries + + def test_inuse_entries_not_evicted(self, tmp_path): + cache = _MmapCache(max_size=1) + a = tmp_path / 'p3_a_1488.bin' + a.write_bytes(b'a' * 32) + b = tmp_path / 'p3_b_1488.bin' + b.write_bytes(b'b' * 32) + + cache.acquire(str(a)) # rc=1, in use + cache.acquire(str(b)) # would exceed cap, but a is in use + cache.release(str(b)) # b idle now + + # a still in cache because rc > 0. + assert os.path.realpath(str(a)) in cache._entries + + def test_env_var_override(self, tmp_path, monkeypatch): + monkeypatch.setenv('XRSPATIAL_GEOTIFF_MMAP_CACHE_SIZE', '5') + # New cache picks up the env setting. + cache = _MmapCache() + assert cache._max_size == 5 + + +# --------------------------------------------------------------------------- +# P-4: decode parallelism threshold +# --------------------------------------------------------------------------- + +class TestP4ParallelThreshold: + def test_uncompressed_large_tiles_round_trip(self, tmp_path): + # The threshold is purely an internal heuristic; the test confirms + # the new gate doesn't regress correctness on uncompressed files + # with multiple large tiles (which previously skipped the pool). + rng = np.random.default_rng(0) + arr = rng.integers(0, 255, size=(512, 512), dtype=np.uint8) + path = str(tmp_path / 'p4_1488.tif') + write(arr, path, compression='none', tiled=True, tile_size=256) + out, _ = read_to_array(path) + np.testing.assert_array_equal(out, arr) + + def test_small_tile_path_still_correct(self, tmp_path): + # 64x64 tiles of 64x64 image: tile_pixels = 4096 < 64K, so the + # gate stays sequential. Round-trip must still match. + arr = np.arange(64 * 64, dtype=np.float32).reshape(64, 64) + path = str(tmp_path / 'p4b_1488.tif') + write(arr, path, compression='deflate', tiled=True, tile_size=64) + out, _ = read_to_array(path) + np.testing.assert_array_equal(out, arr) + + +# --------------------------------------------------------------------------- +# P-5: dask task cap +# --------------------------------------------------------------------------- + +class TestP5DaskTaskCap: + def test_excessive_chunks_raises_with_hint(self, tmp_path): + # 1024x1024 image with chunks=2 => ~262K chunks, well over 50K cap. + arr = np.zeros((1024, 1024), dtype=np.uint8) + path = str(tmp_path / 'p5_1488.tif') + write(arr, path, compression='none', tiled=True, tile_size=64) + with pytest.raises(ValueError, match=r"50,000-task cap"): + read_geotiff_dask(path, chunks=2) + + def test_normal_chunks_pass(self, tmp_path): + arr = np.zeros((512, 512), dtype=np.uint8) + path = str(tmp_path / 'p5b_1488.tif') + write(arr, path, compression='none', tiled=True, tile_size=128) + result = read_geotiff_dask(path, chunks=128) + assert result.shape == (512, 512) + + +# --------------------------------------------------------------------------- +# P-6: GPU memory pre-check +# --------------------------------------------------------------------------- + +class TestP6GpuMemoryCheck: + def test_helper_raises_when_request_exceeds_free(self, monkeypatch): + cupy = pytest.importorskip('cupy') + from xrspatial.geotiff import _gpu_decode + + # Stub memGetInfo to a tiny free budget so the check trips. + class _FakeRuntime: + @staticmethod + def memGetInfo(): + return (1024, 8 * 1024 * 1024 * 1024) # 1KB free + + monkeypatch.setattr(cupy.cuda, 'runtime', _FakeRuntime, + raising=False) + with pytest.raises(MemoryError, match="GPU out of memory"): + _gpu_decode._check_gpu_memory(10 * 1024 * 1024, + what="test buffer") + + def test_helper_noop_for_zero_or_negative(self): + from xrspatial.geotiff import _gpu_decode + # Should not even try to query CUDA. + _gpu_decode._check_gpu_memory(0, what="empty") + _gpu_decode._check_gpu_memory(-100, what="negative") + + def test_helper_silent_when_cupy_unavailable(self, monkeypatch): + # When cupy isn't importable, the helper falls through silently + # so the real allocation can produce its own error. + from xrspatial.geotiff import _gpu_decode + import builtins + + real_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name == 'cupy': + raise ImportError("simulated") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, '__import__', fake_import) + _gpu_decode._check_gpu_memory(10**9, what="probe") # no raise + + +# --------------------------------------------------------------------------- +# P-9: COG auto-overview level cap +# --------------------------------------------------------------------------- + +class TestP9OverviewCap: + def test_max_levels_constant(self): + assert _MAX_OVERVIEW_LEVELS == 8 + + def test_auto_overview_capped(self, tmp_path): + # 4096x4096 with tile_size=64 would generate ceil(log2(4096/64)) + # = 6 levels (under the cap), so use a tinier tile_size to push + # the natural count past 8. + arr = np.zeros((4096, 4096), dtype=np.uint8) + path = str(tmp_path / 'p9_1488.tif') + # tile_size=4 -> log2(4096/4)=10 natural levels, but cap is 8. + write(arr, path, compression='none', tiled=True, tile_size=4, + cog=True) + + # Re-open and count IFDs (overviews + full-res). + from xrspatial.geotiff._header import parse_header, parse_all_ifds + from xrspatial.geotiff._reader import _FileSource + src = _FileSource(path) + try: + data = src.read_all() + header = parse_header(data) + ifds = parse_all_ifds(data, header) + finally: + src.close() + + # 1 full-res IFD + at most 8 overview IFDs. + assert len(ifds) <= 1 + _MAX_OVERVIEW_LEVELS + + def test_explicit_overview_levels_not_capped(self, tmp_path): + # When the caller passes overview_levels explicitly, the cap is + # not applied -- they get exactly what they asked for. + arr = np.zeros((1024, 1024), dtype=np.uint8) + path = str(tmp_path / 'p9b_1488.tif') + write(arr, path, compression='none', tiled=True, tile_size=64, + cog=True, overview_levels=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + + from xrspatial.geotiff._header import parse_header, parse_all_ifds + from xrspatial.geotiff._reader import _FileSource + src = _FileSource(path) + try: + data = src.read_all() + header = parse_header(data) + ifds = parse_all_ifds(data, header) + finally: + src.close() + + # 10 explicit overviews + 1 full-res = 11 IFDs. + assert len(ifds) == 11