From 335a6c7e44b040302793b4c41e19dcd9c2ff34f1 Mon Sep 17 00:00:00 2001 From: hrodmn Date: Wed, 13 May 2026 12:44:26 -0500 Subject: [PATCH 1/7] chore: spec out rust-warp integration --- dev-docs/specs/rust-warp-integration.md | 617 ++++++++++++++++++++++++ 1 file changed, 617 insertions(+) create mode 100644 dev-docs/specs/rust-warp-integration.md diff --git a/dev-docs/specs/rust-warp-integration.md b/dev-docs/specs/rust-warp-integration.md new file mode 100644 index 0000000..bfaae3e --- /dev/null +++ b/dev-docs/specs/rust-warp-integration.md @@ -0,0 +1,617 @@ +# Spec: Replacing lazycogs' numpy/pyproj Warp Engine with rust-warp + +## Context + +lazycogs currently reprojects source COG windows with a small in-repo engine in `src/lazycogs/_reproject.py`. That engine: + +- uses `pyproj.Transformer` to map destination pixel centers into the source CRS +- converts projected coordinates to source pixel coordinates with the inverse affine +- floors to integer indices +- samples with numpy fancy indexing + +This design is simple and fast enough for nearest-neighbor, but it has hard limits: + +- only nearest-neighbor resampling exists +- the warp math and sampling code are maintained locally +- performance is constrained by Python/numpy memory traffic +- the implementation is tightly specialized to today's warp path + +[rust-warp](https://github.com/jakenotjay/rust-warp) is a GDAL-free Rust/Python warp engine that already provides: + +- inverse-mapping reprojection +- multiple resampling kernels +- GIL-free compute +- pure-Rust projection math for several common CRSes, with `proj4rs` fallback +- Python bindings that accept affine transforms and CRS strings + +The question is not "can we call rust-warp somewhere". We obviously can. The question is whether it cleanly replaces lazycogs' current per-window reprojection path without breaking the parts lazycogs is already good at: per-chunk DuckDB search, async COG window reads, overview selection, and mosaic ordering. + +This spec defines the architecture for that replacement. + +## Goals + +- Replace the internal warp implementation in `src/lazycogs/_reproject.py` with a rust-warp-backed adapter. +- Preserve lazycogs' existing compute-time pipeline: DuckDB search, async-geotiff window reads, overview selection, and mosaic methods. +- Enable multiple reprojection resampling methods through lazycogs' public API. +- Preserve the current fast path for same-grid reads where reprojection is unnecessary. +- Keep the migration incremental so behavior can be parity-tested against the current implementation before full cutover. + +## Non-goals + +- Replacing `rustac`, DuckDB search, or `async-geotiff`. +- Replacing the search-time `pyproj` usage that converts request bboxes to EPSG:4326. That is a separate concern. +- Adopting rust-warp's dask graph builder or xarray accessor. lazycogs already owns those layers. +- Re-architecting mosaic methods. +- Promising bit-for-bit parity with current `pyproj` output across every CRS. + +## Constraints and Assumptions + +- lazycogs remains a Python-first package with Rust dependencies only through installed wheels. +- The replacement must work with per-item window reads produced by `async-geotiff`; rust-warp does not become the source reader. +- lazycogs must continue to support chunk-local reprojection of small source windows, not only full-scene arrays. +- The current same-CRS same-affine fast path should remain in Python because it avoids any warp call at all. +- rust-warp's useful integration surface is its low-level array reprojection API, not its xarray or dask APIs. +- rust-warp currently documents 2D low-level warps. lazycogs must therefore continue to iterate over band planes in Python unless upstream adds a true 3D band-aware kernel. +- rust-warp supports a limited dtype set. lazycogs must validate or cast unsupported dtypes explicitly. + +## Why the high-level rust-warp APIs are the wrong fit + +lazycogs already has the hard parts that matter for its product shape: + +- chunk-local STAC search +- overview selection before read +- selective source window reads +- per-item mosaicking in caller-defined order +- xarray backend integration + +rust-warp's dask planner and xarray accessor solve a different problem: lazy reprojection of arrays you already have. lazycogs does not already have the full source array. It discovers and reads tiny windows on demand. Replacing lazycogs' chunk orchestration with rust-warp's planner would be a step backward. + +Therefore the correct integration point is the low-level warp kernel: + +- keep lazycogs' orchestration +- replace only the local reprojection core + +## Architecture Overview + +```text +open(..., resampling=...) + -> MultiBandStacBackendArray + -> _async_getitem(...) + -> _read_chunk_all_dates(...) + -> read_chunk_async(...) + -> _read_item_band(...) + -> GeoTIFF.open / overview select / window read + -> rust-warp adapter + -> mosaic method feed +``` + +### What stays in lazycogs + +- `src/lazycogs/_backend.py` +- `src/lazycogs/_chunk_reader.py` item/window/overview logic +- mosaic methods in `src/lazycogs/_mosaic_methods.py` +- grid construction in `src/lazycogs/_grid.py` +- all chunk/time concurrency behavior + +### What changes + +- `src/lazycogs/_reproject.py` stops implementing warp math itself +- lazycogs adds a thin adapter around `rust_warp.reproject_array` +- lazycogs adds public resampling selection +- caching shifts from cached integer `WarpMap`s to a simpler geometry/argument cache only if benchmarking proves it worthwhile + +## Proposed Integration Design + +### 1. Introduce a backend-neutral reprojection interface + +`src/lazycogs/_reproject.py` should stop exposing "warp map" as the primary abstraction. That is an implementation detail of the current engine, and it leaks too much. + +Instead define a narrow operation-level interface: + +```python +@dataclass(frozen=True) +class ReprojectRequest: + data: np.ndarray # shape (bands, src_h, src_w) + src_transform: Affine + src_crs: CRS + dst_transform: Affine + dst_crs: CRS + dst_width: int + dst_height: int + nodata: float | None + resampling: str + + +def reproject_tile(request: ReprojectRequest) -> np.ndarray: + ... +``` + +Initial implementation paths: + +- `same-grid fast path`: return input unchanged +- `python-legacy engine`: existing implementation, temporarily retained for rollout +- `rust-warp engine`: new default path once validated + +This lets us swap implementations without rewriting `_chunk_reader.py` again. + +### 2. Implement a rust-warp adapter, not a direct scatter of API calls + +Add a new private module, e.g. `src/lazycogs/_rust_warp.py`, responsible for: + +- converting `Affine` to the 6-tuple rust-warp expects +- converting `pyproj.CRS` objects to CRS strings accepted by rust-warp +- handling per-band iteration for `(bands, h, w)` arrays +- validating dtypes and nodata compatibility +- mapping lazycogs resampling names to rust-warp resampling names + +Proposed function: + +```python +def reproject_array_rust_warp( + data: np.ndarray, + src_transform: Affine, + src_crs: CRS, + dst_transform: Affine, + dst_crs: CRS, + dst_width: int, + dst_height: int, + nodata: float | None = None, + resampling: str = "nearest", +) -> np.ndarray: + ... +``` + +Implementation sketch: + +1. Fast-return if source and destination grid already match. +2. Normalize CRS strings. +3. For each band plane: + - call `rust_warp.reproject_array(...)` +4. Stack band outputs back to `(bands, dst_h, dst_w)`. + +This is the right first integration even if it feels slightly boring. It replaces the risky part and leaves working infrastructure alone. + +### 3. Stop centering the design on warp-map caching + +The current implementation caches `WarpMap` objects because computing destination-to-source coordinate maps in Python is expensive and repeated. Once the warp loop moves into Rust, that specific cache may become unnecessary or even counterproductive. + +Decision: + +- remove `WarpMap` from the long-term design +- keep the `warp_cache` plumbing during migration only if needed for A/B comparison +- benchmark before reintroducing any geometry-plan cache + +Rationale: + +- caching tied to old internals makes the new design harder to reason about +- rust-warp already has its own optimization strategy +- a bad cache can waste memory across concurrent chunk reads + +If post-migration benchmarks show repeated geometry setup is still expensive, reintroduce a new cache based on a backend-neutral key such as `(src_transform, src_crs, dst_transform, dst_crs, dst_shape, resampling)`. + +### 4. Add public resampling selection now that the backend can support it + +`lazycogs.open()` should gain: + +```python +resampling: Literal["nearest", "bilinear", "cubic", "lanczos", "average"] = "nearest" +``` + +This parameter must flow through: + +- `_core.open()` +- `MultiBandStacBackendArray` +- `_async_getitem()` +- `read_chunk_async()` / `read_chunk()` +- `_read_item_band()` +- `reproject_tile()` + +Notes: + +- `average` should be documented as downsampling-oriented. +- `nearest` remains the default for backward compatibility. +- If a later decision is made to expose only a smaller supported subset, the interface can still start with the rust-warp names and reject unsupported ones centrally. + +### 5. Keep search-time bbox reprojection on pyproj for now + +Do not try to delete `pyproj` entirely in this change. + +Current uses of `pyproj` fall into two buckets: + +1. search/grid plumbing + - bbox to EPSG:4326 for STAC queries + - output grid and metadata helpers +2. per-pixel warp math + - current `_reproject.py` + +This spec replaces only bucket 2. + +Reasons: + +- rust-warp is solving raster reprojection, not all CRS concerns in lazycogs +- bbox transforms are not a performance bottleneck worth destabilizing +- widening scope makes it harder to debug accuracy differences + +After migration, `pyproj` may still remain a dependency. That is acceptable. + +## API or Interface Design + +### Public API + +```python +def open( + href: str, + *, + ..., + resampling: str = "nearest", + ..., +) -> xr.DataArray: + ... +``` + +Validation rules: + +- accepted values initially: `nearest`, `bilinear`, `cubic`, `lanczos`, `average` +- unknown values raise `ValueError` +- docs must state that quality and performance differ by method + +### Internal API + +```python +@dataclass(frozen=True) +class ReprojectRequest: + data: np.ndarray + src_transform: Affine + src_crs: CRS + dst_transform: Affine + dst_crs: CRS + dst_width: int + dst_height: int + nodata: float | None + resampling: str + + +def reproject_tile(request: ReprojectRequest) -> np.ndarray: + """Reproject one `(bands, y, x)` source tile onto the destination chunk grid.""" +``` + +### Optional migration hook + +During rollout only: + +```python +reproject_engine: Literal["legacy", "rust-warp"] = "rust-warp" +``` + +This should be private or test-only, not a documented permanent user-facing API. + +## Data Model + +No external data model changes. + +Internal changes: + +- `WarpMap` becomes deprecated and then removable. +- `MultiBandStacBackendArray` gains a `resampling: str` field. +- `_ChunkContext` gains `resampling: str`. + +## Detailed Behavior + +### Same-grid fast path + +If all of the following are true: + +- `src_crs.equals(dst_crs)` +- `raster.transform == dst_transform` +- source width/height match destination width/height + +then lazycogs must return the source data unchanged without calling rust-warp. + +This preserves the existing zero-overhead case. + +### Band handling + +rust-warp's documented low-level API is 2D. lazycogs reads `(bands, h, w)` windows. Therefore the adapter must: + +- iterate over band axis in Python +- preserve input band order +- preserve output shape `(bands, dst_h, dst_w)` + +Future optimization: + +- if upstream rust-warp adds multi-band low-level kernels, lazycogs can switch internally without changing public API + +### CRS normalization + +lazycogs currently carries `pyproj.CRS` objects. rust-warp wants EPSG or PROJ strings. + +Adapter rules: + +1. Prefer `CRS.to_epsg()` when available. +2. If EPSG exists, pass `f"EPSG:{epsg}"`. +3. Otherwise pass `CRS.to_proj4()`. +4. If neither produces a usable value, raise a clear error. + +Do not pass WKT unless benchmarking and compatibility work proves it is necessary. rust-warp's own docs say WKT handling depends on `pyproj` assistance and is not its strongest path. + +### Dtype handling + +rust-warp README says low-level support includes: + +- `float32` +- `float64` +- `int8` +- `uint8` +- `uint16` +- `int16` + +Before calling rust-warp, lazycogs must: + +- allow these dtypes directly +- explicitly reject or cast unsupported dtypes + +Recommended initial policy: + +- direct pass-through for supported dtypes +- raise `TypeError` for unsupported dtypes during migration + +Reason: silent casting is how you ship a bug and only discover it six weeks later in someone's science pipeline. + +A later follow-up can add an explicit casting policy if needed. + +### Nodata behavior + +lazycogs currently fills out-of-bounds pixels with `nodata` or zero. + +Required behavior with rust-warp: + +- preserve current caller-facing semantics +- pass explicit `nodata` whenever known +- keep lazycogs' existing `effective_nodata` logic per band asset + +Open detail: + +- verify how rust-warp treats integer arrays when `nodata=None` +- verify whether NaN propagation for float arrays matches lazycogs expectations + +This must be locked down in tests before cutover. + +## Integration Points + +### `src/lazycogs/_core.py` + +- add `resampling` parameter to `open()` +- store on backend object +- validate once at open time + +### `src/lazycogs/_backend.py` + +- propagate `resampling` through chunk reads +- no change to indexing model + +### `src/lazycogs/_chunk_reader.py` + +Current `_apply_bands_with_warp_cache()` should be replaced or renamed to reflect the new responsibility. Suggested shape: + +```python +def _reproject_band_rasters( + band_rasters: list[tuple[str, RasterArray, CRS, float | None]], + dst_transform: Affine, + dst_crs: CRS, + dst_width: int, + dst_height: int, + resampling: str, +) -> dict[str, tuple[np.ndarray, float | None]]: + ... +``` + +Responsibilities: + +- preserve same-grid fast path +- call `reproject_tile()` for actual warps +- stop exposing backend-specific cache mechanics to callers + +### `src/lazycogs/_reproject.py` + +Migration plan: + +- stage 1: keep legacy implementation under renamed helpers +- stage 2: add backend-neutral `reproject_tile()` dispatcher +- stage 3: make rust-warp the default implementation +- stage 4: delete legacy warp-map code if benchmarks and parity are acceptable + +### Dependencies + +Add `rust-warp` as a dependency if licensing, wheel availability, and platform coverage are acceptable. + +Before adding it, verify: + +- Python versions supported by lazycogs +- Linux/macOS wheel availability for CI and target users +- whether `maturin`-built wheels include everything needed for downstream installs + +This is a release engineering issue, not just a code issue. + +## Migration Path + +### Phase 1: Adapter and hidden A/B mode + +- Add dependency. +- Add `reproject_tile()` abstraction. +- Keep legacy engine in place. +- Add rust-warp adapter behind an internal switch. +- Add parity tests that run both engines on the same cases. + +Exit criteria: + +- supported dtypes work +- same-grid fast path preserved +- nearest parity acceptable on representative CRS pairs + +### Phase 2: Public resampling API + +- Add `resampling=` to `open()`. +- Route `nearest` through rust-warp too in test environments. +- Add tests for `bilinear`, `cubic`, `lanczos`, `average` where appropriate. +- Benchmark representative workloads. + +Exit criteria: + +- API stable +- docs updated +- no obvious regressions in common nearest-neighbor path + +### Phase 3: Default cutover + +- Make rust-warp the default backend. +- Retain legacy engine only as short-lived fallback if needed. + +Exit criteria: + +- CI green +- benchmark deltas understood +- accuracy deltas documented + +### Phase 4: Cleanup + +- Remove `WarpMap`, `compute_warp_map`, `apply_warp_map`, and related cache plumbing if no longer needed. +- Simplify docs and architecture notes. + +## Testing Strategy + +### Unit tests + +Add tests for: + +- affine tuple conversion +- CRS normalization to EPSG/PROJ strings +- supported dtype pass-through +- unsupported dtype rejection +- same-grid fast path bypassing rust-warp +- output shape and dtype preservation + +### Parity tests vs legacy engine + +For `nearest` only, compare legacy and rust-warp on: + +- identical-grid reads +- partial overlap +- out-of-bounds nodata fill +- common CRS pairs in docs/tests +- multi-band windows + +Comparison rule: + +- exact equality for identity and same-CRS simple cases +- exact or near-exact equality for common reprojection cases, depending on measured behavior + +### Parity tests vs rasterio/GDAL reference + +Keep and extend the existing raster parity suite to compare: + +- `nearest` +- `bilinear` +- `cubic` +- maybe `average` for downsampling cases + +This matters because switching away from `pyproj` changes the projection engine, not just the resampler. + +### Integration tests + +Exercise full lazycogs chunk reads with: + +- multiple overlapping items +- same-CRS overview reads +- cross-CRS reads +- preserved mosaic ordering with `FirstMethod` + +### Benchmark tests + +Benchmark at least these scenarios: + +1. same-grid fast path +2. nearest reprojection, same dtype as current tests +3. bilinear reprojection on continuous data +4. many small chunk reads across dates +5. overview-backed read vs full-resolution read + +The benchmark goal is not "rust is always faster". The goal is to confirm the whole lazycogs pipeline gets better or at least does not regress in the dominant workloads. + +## Decision Log + +| Decision | Options Considered | Rationale | +|----------|--------------------|-----------| +| Integrate at low-level warp API only | Use rust-warp xarray/dask APIs; replace only `_reproject.py` | lazycogs already owns orchestration and selective reads; replacing more would duplicate working logic | +| Keep `pyproj` for search-time CRS work | Remove `pyproj` entirely; replace only per-pixel warp math | search-time transforms are not the problem and do not justify widening scope | +| Add backend-neutral `reproject_tile()` | Keep `compute_warp_map` API and swap internals | current API is overfit to the old implementation | +| Iterate over band planes in Python | Wait for upstream multi-band support; try to coerce 3D input now | documented low-level API is 2D; a thin Python loop is the safest integration | +| Validate dtypes explicitly | Implicit casts; best-effort support | silent casts are risky for scientific output | +| Retain same-grid fast path | Always call rust-warp | avoiding unnecessary warp calls preserves a proven optimization | + +## Risks + +### 1. CRS accuracy differences + +rust-warp uses native Rust projections for common CRSes and `proj4rs` fallback elsewhere. That means output may differ slightly from current `pyproj`-backed behavior. + +Mitigation: + +- document expected differences +- benchmark and parity-test common CRS pairs used by lazycogs +- prefer EPSG normalization so common CRSes route to rust-warp's native implementations + +### 2. Packaging risk + +A compiled extension dependency can break installation in environments where pure Python currently works. + +Mitigation: + +- verify wheel availability before merge +- test install in CI on supported platforms +- do not remove legacy path until packaging is proven + +### 3. Small-window overhead + +lazycogs often reprojects many small windows, not one giant array. If rust-warp has high fixed call overhead, theoretical wins may not show up in real workloads. + +Mitigation: + +- benchmark realistic lazycogs chunk shapes, not only synthetic full-image cases + +### 4. Nodata/resampling semantics drift + +Different kernels and nodata rules can subtly change mosaic outcomes. + +Mitigation: + +- lock behavior down with integration tests +- document that `average` is for downsampling, not categorical data + +## Open Questions + +- Does rust-warp's per-call overhead remain favorable for the small window sizes lazycogs reads most often? +- Which lazycogs-supported datasets use dtypes outside rust-warp's current low-level support matrix? +- Are there important lazycogs CRS workflows that would hit `proj4rs` fallback instead of rust-warp native implementations? +- Do we want a temporary environment-variable fallback to the legacy engine during rollout? +- Should `average` ship in the initial public API, or only after dedicated downsampling tests? + +## Recommended Implementation Order + +1. Add adapter module and backend-neutral `reproject_tile()`. +2. Thread `resampling` through internal call sites. +3. Add hidden rust-warp path and parity tests. +4. Benchmark realistic workloads. +5. Expose public resampling API. +6. Cut over default backend. +7. Delete legacy warp-map code. + +## References + +- `src/lazycogs/_reproject.py` +- `src/lazycogs/_chunk_reader.py` +- `tests/test_reproject.py` +- `README.md` +- `ARCHITECTURE.md` +- https://github.com/jakenotjay/rust-warp +- https://github.com/jakenotjay/rust-warp/blob/main/docs/architecture.md +- https://github.com/jakenotjay/rust-warp/blob/main/docs/proj4rs-differences.md From bb7202929d20834b5cce83914908aa7b15d430b3 Mon Sep 17 00:00:00 2001 From: hrodmn Date: Wed, 13 May 2026 13:32:54 -0500 Subject: [PATCH 2/7] refactor: add scaffolding for rust-warp integration --- ARCHITECTURE.md | 15 ++-- README.md | 4 +- pyproject.toml | 4 + src/lazycogs/_chunk_reader.py | 64 +++++++--------- src/lazycogs/_reproject.py | 137 ++++++++++++++++++++++++++++++---- tests/test_chunk_reader.py | 40 +++++++--- tests/test_core.py | 7 ++ tests/test_reproject.py | 54 ++++++++++++++ uv.lock | 10 +++ 9 files changed, 263 insertions(+), 72 deletions(-) diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index e8e6de6..4ef091f 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -1,6 +1,6 @@ # Architecture: lazycogs -lazycogs turns a geoparquet STAC item index into a lazy `(band, time, y, x)` xarray DataArray backed by Cloud-Optimized GeoTIFFs. It requires no GDAL. All raster I/O is done through `async-geotiff` (Rust-backed), spatial queries go through DuckDB via `rustac`, and reprojection is pure `pyproj` + numpy. +lazycogs turns a geoparquet STAC item index into a lazy `(band, time, y, x)` xarray DataArray backed by Cloud-Optimized GeoTIFFs. It requires no GDAL. All raster I/O is done through `async-geotiff` (Rust-backed), spatial queries go through DuckDB via `rustac`, and reprojection now flows through a backend-neutral dispatcher in `_reproject.py`. Today that dispatcher still defaults to the legacy `pyproj` + numpy nearest-neighbor engine while the rust-warp adapter is being integrated. ## Why parquet, not a STAC API URL @@ -30,7 +30,7 @@ src/lazycogs/ _executor.py Per-chunk reprojection thread pool configuration. Exposes set_reproject_workers() and get_max_workers(); the actual pool is created per event loop in _backend.py. _explain.py Dry-run read estimator. Registers the da.lazycogs.explain() xarray accessor. _grid.py Compute output affine transform and dimensions from bbox + resolution. - _reproject.py Nearest-neighbor reprojection using pyproj Transformer + numpy fancy indexing. + _reproject.py Backend-neutral reprojection dispatcher; legacy nearest-neighbor backend still uses pyproj Transformer + numpy fancy indexing. _storage_ext.py STAC Storage Extension metadata parsing (version detection, kwargs extraction for v1 and v2). _store.py Resolve cloud HREFs into obstore Store instances (or route through a user-supplied store) with a thread-local cache; store_for() factory for constructing stores from parquet STAC files. _temporal.py Temporal grouping strategies (day, week, month, year, fixed-day-count). @@ -125,9 +125,11 @@ If the chunk bbox falls entirely outside the source image after clamping, `_nati `await reader.read(window=window)` fetches the windowed pixel data from the selected overview level (or full-res). The result is a `(bands, window_h, window_w)` array in the source CRS/grid. -### 4. Nearest-neighbor reprojection +### 4. Reprojection dispatch and current legacy backend -`reproject_array()` in `_reproject.py` warps the source tile onto the destination chunk grid without GDAL: +`_chunk_reader.py` now builds a `ReprojectRequest` and calls `reproject_tile()` in `_reproject.py` rather than reaching directly into warp-map helpers. That gives lazycogs a clean seam for swapping reprojection engines while keeping chunk orchestration unchanged. + +The dispatcher currently short-circuits exact same-grid reads and otherwise routes to the legacy nearest-neighbor backend, which still warps the source tile onto the destination chunk grid without GDAL: 1. Build a meshgrid of destination pixel-centre coordinates. 2. Transform all coordinates from `dst_crs` to `src_crs` in one vectorised `Transformer.transform()` call. @@ -135,7 +137,7 @@ If the chunk bbox falls entirely outside the source image after clamping, `_nati 4. `np.floor` rounds to the nearest-neighbor sample; numpy fancy indexing populates the output array. 5. Out-of-bounds pixels get the nodata fill value. -Nearest-neighbor is the only supported resampling method. +Nearest-neighbor is still the only active resampling method in production code. The `rust-warp` dependency is present for development and adapter work, but it is not the default engine yet. ## Concurrency model @@ -264,7 +266,8 @@ When the store root does not align with the URL structure of the asset HREFs — | `arro3-core` | Zero-copy Arrow table output from DuckDB queries (installed via `rustac[arrow]`) | | `async-geotiff` | Async COG header reads and windowed tile reads (Rust, no GDAL) | | `obstore` | Cloud object store abstraction layer for async-geotiff | -| `pyproj` | CRS transforms: bbox reprojection, warp map generation | +| `rust-warp` | Experimental reprojection backend dependency, currently sourced from GitHub during integration work | +| `pyproj` | CRS transforms: bbox reprojection, target-resolution estimation, legacy warp backend | | `xarray` | DataArray / Dataset assembly, `BackendArray` / `LazilyIndexedArray` protocol | | `rasterix` | CRS-aware `RasterIndex` for lazy spatial coordinates | | `xproj` | CRS accessor and alignment for xarray Flexible Indexes | diff --git a/README.md b/README.md index 39f59f7..5711b03 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ One constraint worth naming: lazycogs only reads Cloud Optimized GeoTIFFs. If yo | STAC search + spatial indexing | `rustac` (DuckDB + geoparquet) | | COG I/O | `async-geotiff` (Rust, no GDAL) | | Cloud storage | `obstore` | -| Reprojection | `pyproj` + numpy | +| Reprojection | backend-neutral seam in `lazycogs`; legacy engine is `pyproj` + numpy | | Lazy dataset construction | xarray `BackendEntrypoint` + `LazilyIndexedArray` | ## Installation @@ -33,6 +33,8 @@ One constraint worth naming: lazycogs only reads Cloud Optimized GeoTIFFs. If yo pip install lazycogs ``` +Current development work also pins `rust-warp` from GitHub via uv for local testing of the upcoming reprojection backend swap. That source dependency is still experimental and is not a release-ready packaging story yet. + ## Coordinate convention `lazycogs.open()` returns a DataArray whose `y` coordinates follow the standard diff --git a/pyproject.toml b/pyproject.toml index 6375fa5..079cadc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "pandas>=3.0.2", "pyproj>=3.7.2", "rasterix>=0.2.0", + "rust-warp", "rustac[arrow]>=0.9.10", "xarray>=2026.2.0", "xproj>=0.2.0", @@ -115,3 +116,6 @@ python_files = ["test_*.py", "bench_*.py"] markers = [ "benchmark: performance benchmarks (run with --benchmark-enable)", ] + +[tool.uv.sources] +rust-warp = { git = "https://github.com/jakenotjay/rust-warp.git" } diff --git a/src/lazycogs/_chunk_reader.py b/src/lazycogs/_chunk_reader.py index b3b7439..dd71302 100644 --- a/src/lazycogs/_chunk_reader.py +++ b/src/lazycogs/_chunk_reader.py @@ -14,12 +14,7 @@ from lazycogs._executor import _run_coroutine from lazycogs._mosaic_methods import FirstMethod, MosaicMethodBase -from lazycogs._reproject import ( - WarpMap, - _get_transformer, - apply_warp_map, - compute_warp_map, -) +from lazycogs._reproject import ReprojectRequest, _get_transformer, reproject_tile from lazycogs._store import resolve as _resolve_store if TYPE_CHECKING: @@ -48,7 +43,7 @@ class _ChunkContext: nodata: float | None store: ObjectStore | None path_fn: Callable[[str], str] | None - warp_cache: dict[tuple[tuple[float, ...], CRS], WarpMap] | None + warp_cache: dict[object, object] | None def _log_batch_failure( @@ -275,45 +270,32 @@ def _apply_bands_with_warp_cache( dst_crs: CRS, dst_width: int, dst_height: int, - warp_cache: dict[tuple[tuple[float, ...], CRS], WarpMap] | None = None, + warp_cache: dict[object, object] | None = None, ) -> dict[str, tuple[np.ndarray, float | None]]: - """Apply warp maps to multiple band rasters, reusing maps for identical geometries. + """Reproject multiple band rasters through the backend-neutral interface. - Checks ``warp_cache`` (keyed on ``(tuple(raster.transform), src_crs.to_wkt())``) - before computing a new warp map. When ``warp_cache`` is shared across calls - (e.g. across time steps in a single chunk read), warp maps for recurring tile - geometries are computed only once. Bands with different geometries each get - their own correct warp map. - - This function is designed to run inside a thread executor — it is CPU-bound - and must not be called from the async event loop directly. When ``warp_cache`` - is shared across concurrent executor calls, two threads may both compute the - same warp map before either stores it; this is safe because ``compute_warp_map`` - is deterministic and the duplicate result is simply overwritten. + The optional ``warp_cache`` is currently forwarded to the legacy backend so + repeated source geometries can still reuse precomputed mappings during the + migration away from warp-map-specific call sites. Args: band_rasters: List of ``(band_name, raster, src_crs, effective_nodata)`` - tuples. ``raster`` must have ``.transform`` (Affine) and ``.data`` + tuples. ``raster`` must have ``.transform`` (Affine) and ``.data`` (ndarray of shape ``(bands, h, w)``) attributes. dst_transform: Affine transform of the destination grid. dst_crs: CRS of the destination grid. dst_width: Width of the destination grid in pixels. dst_height: Height of the destination grid in pixels. - warp_cache: Optional external cache shared across calls. When ``None`` - a fresh local dict is used (original per-item behaviour). + warp_cache: Optional migration-time cache shared across calls. Returns: ``dict`` mapping band name to ``(reprojected_array, effective_nodata)``. """ - cache: dict[tuple[tuple[float, ...], CRS], WarpMap] = ( - warp_cache if warp_cache is not None else {} - ) + cache = warp_cache if warp_cache is not None else {} results: dict[str, tuple[np.ndarray, float | None]] = {} for band, raster, src_crs, effective_nodata in band_rasters: - # Fast path: skip reprojection when the read window already matches the - # destination chunk exactly (same CRS, same affine, same pixel dimensions). if ( src_crs.equals(dst_crs) and raster.transform == dst_transform @@ -322,18 +304,22 @@ def _apply_bands_with_warp_cache( ): results[band] = (raster.data, effective_nodata) continue - cache_key = (tuple(raster.transform), src_crs) - if cache_key not in cache: - cache[cache_key] = compute_warp_map( - src_transform=raster.transform, - src_crs=src_crs, - dst_transform=dst_transform, - dst_crs=dst_crs, - dst_width=dst_width, - dst_height=dst_height, - ) + results[band] = ( - apply_warp_map(raster.data, cache[cache_key], effective_nodata), + reproject_tile( + ReprojectRequest( + data=raster.data, + src_transform=raster.transform, + src_crs=src_crs, + dst_transform=dst_transform, + dst_crs=dst_crs, + dst_width=dst_width, + dst_height=dst_height, + nodata=effective_nodata, + resampling="nearest", + ), + warp_cache=cache, + ), effective_nodata, ) diff --git a/src/lazycogs/_reproject.py b/src/lazycogs/_reproject.py index 481d331..0044794 100644 --- a/src/lazycogs/_reproject.py +++ b/src/lazycogs/_reproject.py @@ -1,10 +1,10 @@ -"""Reproject raster arrays using pyproj and numpy nearest-neighbor sampling.""" +"""Reproject raster arrays using a backend-neutral request interface.""" from __future__ import annotations import functools from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal import numpy as np from pyproj import CRS, Transformer @@ -34,6 +34,35 @@ def _get_transformer(src_crs: CRS, dst_crs: CRS) -> Transformer: return Transformer.from_crs(src_crs, dst_crs, always_xy=True) +@dataclass(frozen=True) +class ReprojectRequest: + """All inputs required to reproject one source tile. + + Attributes: + data: Source data with shape ``(bands, src_h, src_w)``. + src_transform: Affine transform of the source array. + src_crs: CRS of the source array. + dst_transform: Affine transform of the destination grid. + dst_crs: CRS of the destination grid. + dst_width: Destination width in pixels. + dst_height: Destination height in pixels. + nodata: Fill value for pixels that fall outside the source extent. + resampling: Requested resampling method. Only ``"nearest"`` is + supported by the legacy backend. + + """ + + data: np.ndarray + src_transform: Affine + src_crs: CRS + dst_transform: Affine + dst_crs: CRS + dst_width: int + dst_height: int + nodata: float | None = None + resampling: str = "nearest" + + @dataclass class WarpMap: """Precomputed pixel-coordinate mapping from a destination grid to a source grid. @@ -145,6 +174,84 @@ def apply_warp_map( return out +def _same_grid(request: ReprojectRequest) -> bool: + """Return ``True`` when reprojection is an exact no-op.""" + return ( + request.src_crs.equals(request.dst_crs) + and request.src_transform == request.dst_transform + and request.data.shape[1] == request.dst_height + and request.data.shape[2] == request.dst_width + ) + + +def _legacy_cache_key(request: ReprojectRequest) -> tuple[tuple[float, ...], CRS]: + """Return the cache key for the legacy nearest-neighbor backend.""" + return (tuple(request.src_transform), request.src_crs) + + +def _reproject_tile_legacy( + request: ReprojectRequest, + warp_map: WarpMap | None = None, +) -> np.ndarray: + """Reproject a tile with the legacy pyproj/numpy nearest backend.""" + if request.resampling != "nearest": + raise ValueError( + "The legacy reprojection backend only supports resampling='nearest'.", + ) + + resolved_warp_map = warp_map or compute_warp_map( + request.src_transform, + request.src_crs, + request.dst_transform, + request.dst_crs, + request.dst_width, + request.dst_height, + ) + return apply_warp_map(request.data, resolved_warp_map, request.nodata) + + +def reproject_tile( + request: ReprojectRequest, + *, + backend: Literal["legacy"] = "legacy", + warp_cache: dict[tuple[tuple[float, ...], CRS], WarpMap] | None = None, +) -> np.ndarray: + """Reproject one ``(bands, y, x)`` source tile onto a destination grid. + + Args: + request: Reprojection inputs for one tile. + backend: Internal backend selector. Only ``"legacy"`` is currently + implemented. + warp_cache: Optional cache for the legacy backend's precomputed warp + maps, keyed by source transform and CRS. + + Returns: + Reprojected array with shape ``(bands, dst_height, dst_width)``. + + """ + if _same_grid(request): + return request.data + if backend != "legacy": + raise ValueError(f"Unsupported reprojection backend: {backend}") + + warp_map: WarpMap | None = None + if warp_cache is not None: + cache_key = _legacy_cache_key(request) + warp_map = warp_cache.get(cache_key) + if warp_map is None: + warp_map = compute_warp_map( + request.src_transform, + request.src_crs, + request.dst_transform, + request.dst_crs, + request.dst_width, + request.dst_height, + ) + warp_cache[cache_key] = warp_map + + return _reproject_tile_legacy(request, warp_map) + + def reproject_array( data: np.ndarray, src_transform: Affine, @@ -157,10 +264,9 @@ def reproject_array( ) -> np.ndarray: """Reproject a raster array using nearest-neighbor sampling. - Convenience wrapper around :func:`compute_warp_map` and - :func:`apply_warp_map`. Use those functions directly when the same source - CRS and window transform are shared across multiple bands, so the warp map - can be computed once and reused. + Convenience wrapper around :class:`ReprojectRequest` and + :func:`reproject_tile`. Use :func:`reproject_tile` directly for the new + backend-neutral path. Args: data: Source data with shape ``(bands, src_h, src_w)``. @@ -178,12 +284,15 @@ def reproject_array( the same dtype as ``data``. """ - warp_map = compute_warp_map( - src_transform, - src_crs, - dst_transform, - dst_crs, - dst_width, - dst_height, + return reproject_tile( + ReprojectRequest( + data=data, + src_transform=src_transform, + src_crs=src_crs, + dst_transform=dst_transform, + dst_crs=dst_crs, + dst_width=dst_width, + dst_height=dst_height, + nodata=nodata, + ), ) - return apply_warp_map(data, warp_map, nodata) diff --git a/tests/test_chunk_reader.py b/tests/test_chunk_reader.py index a8634b0..88b549b 100644 --- a/tests/test_chunk_reader.py +++ b/tests/test_chunk_reader.py @@ -273,6 +273,25 @@ def _make_raster(transform: Affine, value: float, h: int = 4, w: int = 4) -> Mag return raster +def test_apply_bands_with_warp_cache_same_grid_bypasses_reproject_tile(): + """Same-grid reads return directly without invoking the reprojection backend.""" + crs = CRS.from_epsg(4326) + transform = Affine(1.0, 0.0, 0.0, 0.0, -1.0, 4.0) + raster = _make_raster(transform, 1.0) + + with patch("lazycogs._chunk_reader.reproject_tile") as reproject_tile_mock: + results = _apply_bands_with_warp_cache( + [("B01", raster, crs, None)], + transform, + crs, + dst_width=4, + dst_height=4, + ) + + reproject_tile_mock.assert_not_called() + np.testing.assert_array_equal(results["B01"][0], raster.data) + + def test_apply_bands_with_warp_cache_shared_geometry(): """Bands with the same transform/CRS share a single warp map computation.""" crs = CRS.from_epsg(4326) @@ -285,16 +304,15 @@ def test_apply_bands_with_warp_cache_shared_geometry(): raster_b = _make_raster(transform, 2.0) warp_map_calls = [] + from lazycogs._reproject import compute_warp_map as real_compute_warp_map def _spy_compute_warp_map(*args, **kwargs): - from lazycogs._reproject import compute_warp_map as _real - - result = _real(*args, **kwargs) + result = real_compute_warp_map(*args, **kwargs) warp_map_calls.append(True) return result with patch( - "lazycogs._chunk_reader.compute_warp_map", + "lazycogs._reproject.compute_warp_map", side_effect=_spy_compute_warp_map, ): results = _apply_bands_with_warp_cache( @@ -325,16 +343,15 @@ def test_apply_bands_with_warp_cache_different_geometry(): raster_b = _make_raster(transform_b, 2.0, h=2, w=2) warp_map_calls = [] + from lazycogs._reproject import compute_warp_map as real_compute_warp_map def _spy_compute_warp_map(*args, **kwargs): - from lazycogs._reproject import compute_warp_map as _real - - result = _real(*args, **kwargs) + result = real_compute_warp_map(*args, **kwargs) warp_map_calls.append(True) return result with patch( - "lazycogs._chunk_reader.compute_warp_map", + "lazycogs._reproject.compute_warp_map", side_effect=_spy_compute_warp_map, ): results = _apply_bands_with_warp_cache( @@ -361,16 +378,15 @@ def test_apply_bands_with_warp_cache_shared_across_calls(): shared_cache: dict = {} warp_map_calls = [] + from lazycogs._reproject import compute_warp_map as real_compute_warp_map def _spy_compute_warp_map(*args, **kwargs): - from lazycogs._reproject import compute_warp_map as _real - - result = _real(*args, **kwargs) + result = real_compute_warp_map(*args, **kwargs) warp_map_calls.append(True) return result with patch( - "lazycogs._chunk_reader.compute_warp_map", + "lazycogs._reproject.compute_warp_map", side_effect=_spy_compute_warp_map, ): _apply_bands_with_warp_cache( diff --git a/tests/test_core.py b/tests/test_core.py index c3a0a05..b371f81 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -84,6 +84,13 @@ def test_open_accepts_parquet_extension_passes_validation(tmp_path): assert "must be a .parquet" not in str(exc_info.value) +def test_rust_warp_dependency_available(): + """The configured rust-warp dependency is importable in the test environment.""" + import rust_warp + + assert callable(rust_warp.reproject_array) + + # --------------------------------------------------------------------------- # _build_time_steps # --------------------------------------------------------------------------- diff --git a/tests/test_reproject.py b/tests/test_reproject.py index cca8221..559784d 100644 --- a/tests/test_reproject.py +++ b/tests/test_reproject.py @@ -6,10 +6,12 @@ from pyproj import CRS from lazycogs._reproject import ( + ReprojectRequest, WarpMap, apply_warp_map, compute_warp_map, reproject_array, + reproject_tile, ) @@ -44,6 +46,58 @@ def test_output_shape(wgs84): assert out.shape == (2, 2, 1) +def test_reproject_tile_same_grid_returns_original_array(wgs84): + """The backend-neutral dispatcher short-circuits exact same-grid reads.""" + transform = _make_transform(0.0, 3.0, 1.0) + data = np.arange(9, dtype=np.float32).reshape(1, 3, 3) + + out = reproject_tile( + ReprojectRequest( + data=data, + src_transform=transform, + src_crs=wgs84, + dst_transform=transform, + dst_crs=wgs84, + dst_width=3, + dst_height=3, + ), + ) + + assert out is data + + +def test_reproject_tile_matches_legacy_wrapper(wgs84): + """The backend-neutral path preserves current nearest-neighbor behavior.""" + src_transform = _make_transform(0.0, 2.0, 1.0) + dst_transform = _make_transform(0.0, 4.0, 2.0) + data = np.ones((2, 2, 2), dtype=np.float32) + + request = ReprojectRequest( + data=data, + src_transform=src_transform, + src_crs=wgs84, + dst_transform=dst_transform, + dst_crs=wgs84, + dst_width=1, + dst_height=2, + nodata=-9999.0, + ) + + np.testing.assert_array_equal( + reproject_tile(request), + reproject_array( + data, + src_transform, + wgs84, + dst_transform, + wgs84, + 1, + 2, + nodata=-9999.0, + ), + ) + + def test_out_of_bounds_pixels_get_nodata(wgs84): """Destination pixels outside the source extent are filled with nodata.""" src_transform = _make_transform(5.0, 5.0, 1.0) # covers x=5..8, y=2..5 diff --git a/uv.lock b/uv.lock index 1d790fb..426f8d6 100644 --- a/uv.lock +++ b/uv.lock @@ -1296,6 +1296,7 @@ dependencies = [ { name = "pandas" }, { name = "pyproj" }, { name = "rasterix" }, + { name = "rust-warp" }, { name = "rustac", extra = ["arrow"] }, { name = "xarray" }, { name = "xproj" }, @@ -1335,6 +1336,7 @@ requires-dist = [ { name = "pandas", specifier = ">=3.0.2" }, { name = "pyproj", specifier = ">=3.7.2" }, { name = "rasterix", specifier = ">=0.2.0" }, + { name = "rust-warp", git = "https://github.com/jakenotjay/rust-warp.git" }, { name = "rustac", extras = ["arrow"], specifier = ">=0.9.10" }, { name = "xarray", specifier = ">=2026.2.0" }, { name = "xproj", specifier = ">=0.2.0" }, @@ -2723,6 +2725,14 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c0/98/6beb4b351e472e5f4c4613f7c35a5290b8be2497e183825310c4c3a3984b/ruff-0.15.12-py3-none-win_arm64.whl", hash = "sha256:a538f7a82d061cee7be55542aca1d86d1393d55d81d4fcc314370f4340930d4f", size = 11120821, upload-time = "2026-04-24T18:16:57.979Z" }, ] +[[package]] +name = "rust-warp" +version = "0.1.0" +source = { git = "https://github.com/jakenotjay/rust-warp.git#c384ebe90063a1c94d6deb0ecef2c54cc561309e" } +dependencies = [ + { name = "numpy" }, +] + [[package]] name = "rustac" version = "0.9.10" From 3541c8c1902bcef02e667926033508b46637921b Mon Sep 17 00:00:00 2001 From: hrodmn Date: Wed, 13 May 2026 13:58:50 -0500 Subject: [PATCH 3/7] refactor: Add the rust-warp adapter and backend selection --- src/lazycogs/_reproject.py | 31 +++++++-- src/lazycogs/_rust_warp.py | 135 +++++++++++++++++++++++++++++++++++++ tests/integration_test.py | 88 ++++++++++++++++-------- tests/test_chunk_reader.py | 27 ++++++++ tests/test_reproject.py | 107 +++++++++++++++++++++++++++++ 5 files changed, 353 insertions(+), 35 deletions(-) create mode 100644 src/lazycogs/_rust_warp.py diff --git a/src/lazycogs/_reproject.py b/src/lazycogs/_reproject.py index 0044794..f55e21f 100644 --- a/src/lazycogs/_reproject.py +++ b/src/lazycogs/_reproject.py @@ -9,6 +9,8 @@ import numpy as np from pyproj import CRS, Transformer +from lazycogs._rust_warp import reproject_array_rust_warp + if TYPE_CHECKING: from affine import Affine @@ -210,20 +212,23 @@ def _reproject_tile_legacy( return apply_warp_map(request.data, resolved_warp_map, request.nodata) +_DEFAULT_REPROJECT_BACKEND: Literal["legacy", "rust-warp"] = "legacy" + + def reproject_tile( request: ReprojectRequest, *, - backend: Literal["legacy"] = "legacy", + backend: Literal["legacy", "rust-warp"] | None = None, warp_cache: dict[tuple[tuple[float, ...], CRS], WarpMap] | None = None, ) -> np.ndarray: """Reproject one ``(bands, y, x)`` source tile onto a destination grid. Args: request: Reprojection inputs for one tile. - backend: Internal backend selector. Only ``"legacy"`` is currently - implemented. + backend: Internal backend selector. When ``None``, the module's + current default backend is used. warp_cache: Optional cache for the legacy backend's precomputed warp - maps, keyed by source transform and CRS. + maps, keyed by source transform and CRS. Ignored by rust-warp. Returns: Reprojected array with shape ``(bands, dst_height, dst_width)``. @@ -231,8 +236,22 @@ def reproject_tile( """ if _same_grid(request): return request.data - if backend != "legacy": - raise ValueError(f"Unsupported reprojection backend: {backend}") + + resolved_backend = _DEFAULT_REPROJECT_BACKEND if backend is None else backend + if resolved_backend == "rust-warp": + return reproject_array_rust_warp( + data=request.data, + src_transform=request.src_transform, + src_crs=request.src_crs, + dst_transform=request.dst_transform, + dst_crs=request.dst_crs, + dst_width=request.dst_width, + dst_height=request.dst_height, + nodata=request.nodata, + resampling=request.resampling, + ) + if resolved_backend != "legacy": + raise ValueError(f"Unsupported reprojection backend: {resolved_backend}") warp_map: WarpMap | None = None if warp_cache is not None: diff --git a/src/lazycogs/_rust_warp.py b/src/lazycogs/_rust_warp.py new file mode 100644 index 0000000..9a9e33c --- /dev/null +++ b/src/lazycogs/_rust_warp.py @@ -0,0 +1,135 @@ +"""Private adapter for rust-warp's low-level reprojection API.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import rust_warp + +if TYPE_CHECKING: + from affine import Affine + from pyproj import CRS + +_SUPPORTED_DTYPES = frozenset( + { + np.dtype(np.float32), + np.dtype(np.float64), + np.dtype(np.int8), + np.dtype(np.uint8), + np.dtype(np.uint16), + np.dtype(np.int16), + }, +) +_EXPECTED_ARRAY_NDIM = 3 + + +def _affine_to_rust_warp( + transform: Affine, +) -> tuple[float, float, float, float, float, float]: + """Return affine coefficients in rust-warp's 6-value rasterio order.""" + return ( + transform.a, + transform.b, + transform.c, + transform.d, + transform.e, + transform.f, + ) + + +def _normalize_crs(crs: CRS) -> str: + """Return a rust-warp-compatible CRS string for ``crs``. + + Prefers ``EPSG:`` when ``pyproj`` can resolve one, then falls back to + the CRS's PROJ string representation. + """ + epsg = crs.to_epsg() + if epsg is not None: + return f"EPSG:{epsg}" + + proj4 = crs.to_proj4().strip() + if proj4: + return proj4 + + raise ValueError(f"Could not normalize CRS {crs!r} to an EPSG or PROJ string.") + + +def _validate_supported_dtype(data: np.ndarray) -> np.dtype: + """Return ``data.dtype`` when rust-warp supports it, else raise ``TypeError``.""" + dtype = data.dtype + if dtype not in _SUPPORTED_DTYPES: + supported = ", ".join(sorted(dt.name for dt in _SUPPORTED_DTYPES)) + raise TypeError( + "rust-warp does not support dtype " + f"{dtype.name!r}. Supported dtypes: {supported}.", + ) + return dtype + + +def _normalize_nodata(nodata: float | None, dtype: np.dtype) -> float | int | None: + """Cast ``nodata`` to the source dtype so fill semantics match numpy's.""" + if nodata is None: + return None + return np.array([nodata]).astype(dtype, casting="unsafe")[0].item() + + +def reproject_array_rust_warp( + data: np.ndarray, + src_transform: Affine, + src_crs: CRS, + dst_transform: Affine, + dst_crs: CRS, + dst_width: int, + dst_height: int, + nodata: float | None = None, + resampling: str = "nearest", +) -> np.ndarray: + """Reproject a ``(bands, y, x)`` array via rust-warp's 2D kernel. + + Args: + data: Source array with shape ``(bands, src_h, src_w)``. + src_transform: Affine transform of the source array. + src_crs: CRS of the source array. + dst_transform: Affine transform of the destination grid. + dst_crs: CRS of the destination grid. + dst_width: Destination width in pixels. + dst_height: Destination height in pixels. + nodata: Fill value for pixels outside the source extent. + resampling: rust-warp resampling method name. + + Returns: + Reprojected array with shape ``(bands, dst_height, dst_width)``. + + Raises: + ValueError: If ``data`` is not 3D or the CRS cannot be normalized. + TypeError: If the input dtype is unsupported by rust-warp. + + """ + if data.ndim != _EXPECTED_ARRAY_NDIM: + raise ValueError( + "rust-warp adapter expects data with shape (bands, src_height, src_width).", + ) + + dtype = _validate_supported_dtype(data) + src_crs_str = _normalize_crs(src_crs) + dst_crs_str = _normalize_crs(dst_crs) + src_transform_tuple = _affine_to_rust_warp(src_transform) + dst_transform_tuple = _affine_to_rust_warp(dst_transform) + dst_shape = (dst_height, dst_width) + normalized_nodata = _normalize_nodata(nodata, dtype) + + reprojected_bands = [ + rust_warp.reproject_array( + src=band, + src_crs=src_crs_str, + src_transform=src_transform_tuple, + dst_crs=dst_crs_str, + dst_transform=dst_transform_tuple, + dst_shape=dst_shape, + resampling=resampling, + nodata=normalized_nodata, + ) + for band in data + ] + return np.stack(reprojected_bands, axis=0) diff --git a/tests/integration_test.py b/tests/integration_test.py index 3d312e6..8289c94 100644 --- a/tests/integration_test.py +++ b/tests/integration_test.py @@ -1,3 +1,4 @@ +import argparse import asyncio import contextlib import hashlib @@ -10,9 +11,11 @@ from pyproj import Transformer import lazycogs +from lazycogs import _reproject logging.basicConfig(level="WARN") logging.getLogger("lazycogs").setLevel("DEBUG") +logger = logging.getLogger(__name__) def _parquet_path( @@ -68,16 +71,40 @@ def measure(label: str): yield elapsed = time.perf_counter() - t0 rss_after = _rss_mb() - print( - f"[{label}] " - f"time={elapsed:.2f}s " - f"rss_before={rss_before:.0f}MB " - f"rss_after={rss_after:.0f}MB " - f"delta={rss_after - rss_before:+.0f}MB", + logger.warning( + "[%s] time=%.2fs rss_before=%.0fMB rss_after=%.0fMB delta=%+.0fMB", + label, + elapsed, + rss_before, + rss_after, + rss_after - rss_before, ) -async def run(): +@contextlib.contextmanager +def _reproject_backend(backend: str): + """Temporarily force the internal reprojection backend for this script.""" + previous_backend = _reproject._DEFAULT_REPROJECT_BACKEND + _reproject._DEFAULT_REPROJECT_BACKEND = backend + try: + yield + finally: + _reproject._DEFAULT_REPROJECT_BACKEND = previous_backend + + +def _parse_args() -> argparse.Namespace: + """Parse command-line options for the integration script.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--reproject-backend", + choices=["legacy", "rust-warp"], + default="legacy", + help="Internal reprojection backend to exercise during the run.", + ) + return parser.parse_args() + + +async def run(reproject_backend: str): dst_crs = "epsg:5070" dst_bbox = (-700_000, 2_220_000, 600_000, 2_930_000) @@ -96,7 +123,7 @@ async def run(): bbox=bbox_4326, limit=limit, ) - print(f"cache: {items_parquet}") + logger.warning("cache: %s", items_parquet) if not items_parquet.exists(): await rustac.search_to( @@ -109,29 +136,32 @@ async def run(): ) # --- daily time steps --- - store = lazycogs.store_for(str(items_parquet), skip_signature=True) - da = lazycogs.open( - str(items_parquet), - crs=dst_crs, - bbox=dst_bbox, - resolution=100, - time_period="P1D", - bands=["red", "green", "blue"], - dtype="int16", - store=store, - ) - print(f"\ndaily array: {da}") + with _reproject_backend(reproject_backend): + logger.warning("using reprojection backend: %s", reproject_backend) + store = lazycogs.store_for(str(items_parquet), skip_signature=True) + da = lazycogs.open( + str(items_parquet), + crs=dst_crs, + bbox=dst_bbox, + resolution=100, + time_period="P1D", + bands=["red", "green", "blue"], + dtype="int16", + store=store, + ) + logger.warning("daily array: %s", da) - with measure("daily point (chunked)"): - _ = da.chunk(time=1).sel(x=299965, y=2653947, method="nearest").compute() + with measure("daily point"): + _ = da.sel(x=299965, y=2653947, method="nearest").compute() - subset = da.sel( - x=slice(100_000, 400_000), - y=slice(2_800_000, 2_600_000), - ) - with measure("daily spatial subset isel(time=1)"): - _ = subset.isel(time=1).load() + subset = da.sel( + x=slice(100_000, 400_000), + y=slice(2_800_000, 2_600_000), + ) + with measure("daily spatial subset isel(time=1)"): + _ = subset.isel(time=1).load() if __name__ == "__main__": - asyncio.run(run()) + args = _parse_args() + asyncio.run(run(args.reproject_backend)) diff --git a/tests/test_chunk_reader.py b/tests/test_chunk_reader.py index 88b549b..895a40e 100644 --- a/tests/test_chunk_reader.py +++ b/tests/test_chunk_reader.py @@ -411,6 +411,33 @@ def _spy_compute_warp_map(*args, **kwargs): assert len(shared_cache) == 1 +def test_apply_bands_with_warp_cache_uses_rust_backend_when_selected(): + """The chunk-reader seam can dispatch to rust-warp without caller changes.""" + crs = CRS.from_epsg(4326) + src_transform = Affine(1.0, 0.0, 0.5, 0.0, -1.0, 4.0) + dst_transform = Affine(1.0, 0.0, 0.0, 0.0, -1.0, 4.0) + raster = _make_raster(src_transform, 1.0) + expected = np.full((1, 4, 4), 7.0, dtype=np.float32) + + with ( + patch("lazycogs._reproject._DEFAULT_REPROJECT_BACKEND", "rust-warp"), + patch( + "lazycogs._reproject.reproject_array_rust_warp", + return_value=expected, + ) as rust_backend_mock, + ): + results = _apply_bands_with_warp_cache( + [("B01", raster, crs, None)], + dst_transform, + crs, + dst_width=4, + dst_height=4, + ) + + rust_backend_mock.assert_called_once() + np.testing.assert_array_equal(results["B01"][0], expected) + + # --------------------------------------------------------------------------- # read_chunk_async (multi-band) # --------------------------------------------------------------------------- diff --git a/tests/test_reproject.py b/tests/test_reproject.py index 559784d..7c34823 100644 --- a/tests/test_reproject.py +++ b/tests/test_reproject.py @@ -13,6 +13,7 @@ reproject_array, reproject_tile, ) +from lazycogs._rust_warp import _affine_to_rust_warp, _normalize_crs @pytest.fixture @@ -98,6 +99,112 @@ def test_reproject_tile_matches_legacy_wrapper(wgs84): ) +def test_affine_to_rust_warp_uses_six_value_rasterio_order(): + """Affine conversion emits the 6-tuple rust-warp expects.""" + transform = Affine(2.0, 0.5, 10.0, -0.25, -3.0, 20.0) + + assert _affine_to_rust_warp(transform) == (2.0, 0.5, 10.0, -0.25, -3.0, 20.0) + + +def test_normalize_crs_prefers_epsg_strings(wgs84): + """EPSG-backed CRSes normalize to ``EPSG:`` strings.""" + assert _normalize_crs(wgs84) == "EPSG:4326" + + +def test_normalize_crs_falls_back_to_proj_string(): + """Non-EPSG CRSes fall back to a usable PROJ string.""" + custom_crs = CRS.from_proj4( + "+proj=aea +lat_1=29.5 +lat_2=45.5 +lat_0=23 +lon_0=-96 " + "+x_0=0 +y_0=0 +datum=WGS84 +units=m +no_defs", + ) + + normalized = _normalize_crs(custom_crs) + + assert normalized.startswith("+proj=aea") + assert "+lat_1=29.5" in normalized + + +@pytest.mark.parametrize( + "dtype", + [np.float32, np.float64, np.int8, np.uint8, np.uint16, np.int16], +) +def test_reproject_tile_rust_warp_supports_expected_dtypes(wgs84, dtype): + """Supported dtypes pass through the rust-warp backend unchanged.""" + src_transform = _make_transform(0.0, 2.0, 1.0) + dst_transform = _make_transform(0.0, 2.0, 0.5) + data = np.arange(4, dtype=dtype).reshape(1, 2, 2) + + out = reproject_tile( + ReprojectRequest( + data=data, + src_transform=src_transform, + src_crs=wgs84, + dst_transform=dst_transform, + dst_crs=wgs84, + dst_width=4, + dst_height=4, + nodata=-1.0, + ), + backend="rust-warp", + ) + + assert out.shape == (1, 4, 4) + assert out.dtype == data.dtype + + +def test_reproject_tile_rust_warp_rejects_unsupported_dtype(wgs84): + """Unsupported dtypes fail deterministically before async chunk execution.""" + src_transform = _make_transform(0.0, 2.0, 1.0) + dst_transform = _make_transform(0.0, 2.0, 0.5) + data = np.arange(4, dtype=np.int32).reshape(1, 2, 2) + + with pytest.raises(TypeError, match="does not support dtype"): + reproject_tile( + ReprojectRequest( + data=data, + src_transform=src_transform, + src_crs=wgs84, + dst_transform=dst_transform, + dst_crs=wgs84, + dst_width=4, + dst_height=4, + ), + backend="rust-warp", + ) + + +def test_reproject_tile_rust_warp_preserves_band_order(wgs84): + """Band-plane iteration keeps shape and band ordering intact.""" + src_transform = _make_transform(0.0, 2.0, 1.0) + dst_transform = _make_transform(0.0, 2.0, 0.5) + data = np.stack( + [ + np.full((2, 2), 10.0, dtype=np.float32), + np.full((2, 2), 20.0, dtype=np.float32), + np.full((2, 2), 30.0, dtype=np.float32), + ], + ) + + out = reproject_tile( + ReprojectRequest( + data=data, + src_transform=src_transform, + src_crs=wgs84, + dst_transform=dst_transform, + dst_crs=wgs84, + dst_width=4, + dst_height=4, + nodata=-9999.0, + ), + backend="rust-warp", + ) + + assert out.shape == (3, 4, 4) + np.testing.assert_array_equal(out[0], 10.0) + np.testing.assert_array_equal(out[1], 20.0) + np.testing.assert_array_equal(out[2], 30.0) + + def test_out_of_bounds_pixels_get_nodata(wgs84): """Destination pixels outside the source extent are filled with nodata.""" src_transform = _make_transform(5.0, 5.0, 1.0) # covers x=5..8, y=2..5 From d06538ada203488dd464298fecd0f4c94d8ffbf5 Mon Sep 17 00:00:00 2001 From: hrodmn Date: Wed, 13 May 2026 14:20:23 -0500 Subject: [PATCH 4/7] feat: add resampling arg to open, thread through API --- ARCHITECTURE.md | 12 +++--- README.md | 8 +++- src/lazycogs/__init__.py | 7 +++- src/lazycogs/_backend.py | 7 ++++ src/lazycogs/_chunk_reader.py | 12 +++++- src/lazycogs/_core.py | 38 ++++++++++++++++++- src/lazycogs/_reproject.py | 16 ++++++-- src/lazycogs/_rust_warp.py | 12 ++++-- tests/test_backend.py | 23 ++++++++++++ tests/test_chunk_reader.py | 45 +++++++++++++--------- tests/test_core.py | 70 ++++++++++++++++++++++++++++++++++- tests/test_reproject.py | 35 ++++++++++++++++++ 12 files changed, 249 insertions(+), 36 deletions(-) diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 4ef091f..7d6cc55 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -1,6 +1,6 @@ # Architecture: lazycogs -lazycogs turns a geoparquet STAC item index into a lazy `(band, time, y, x)` xarray DataArray backed by Cloud-Optimized GeoTIFFs. It requires no GDAL. All raster I/O is done through `async-geotiff` (Rust-backed), spatial queries go through DuckDB via `rustac`, and reprojection now flows through a backend-neutral dispatcher in `_reproject.py`. Today that dispatcher still defaults to the legacy `pyproj` + numpy nearest-neighbor engine while the rust-warp adapter is being integrated. +lazycogs turns a geoparquet STAC item index into a lazy `(band, time, y, x)` xarray DataArray backed by Cloud-Optimized GeoTIFFs. It requires no GDAL. All raster I/O is done through `async-geotiff` (Rust-backed), spatial queries go through DuckDB via `rustac`, and reprojection now flows through a backend-neutral dispatcher in `_reproject.py`. Public reprojection methods currently route through the rust-warp adapter, while the legacy pyproj + numpy path remains available only as internal migration scaffolding. ## Why parquet, not a STAC API URL @@ -42,7 +42,7 @@ src/lazycogs/ `open()` in `_core.py`: 1. Resolves `duckdb_client`: if not provided, creates a plain `DuckdbClient()`. Validates that `href` ends in `.parquet`/`.geoparquet` when no client is supplied (a directory path is accepted when a custom client is passed). -2. Parses `time_period` into a `_TemporalGrouper` (see `_temporal.py`). +2. Validates `time_period` and `resampling` up front in `_core.py`, so unsupported public options fail before any storage or DuckDB I/O. 3. Converts `bbox` from the target CRS to EPSG:4326 using `pyproj.Transformer`. 4. Calls `_discover_bands()`: queries the parquet source via `duckdb_client.search(..., max_items=1)` to find asset keys. Assets with role `"data"` or media type `"image/tiff"` are returned first. 5. Calls `_smoketest_store()`: fetches one sample item from the parquet, resolves the object store for a representative data asset HREF, and calls `head()` to confirm access. Raises `RuntimeError` immediately with a clear message if the store cannot reach the asset, so misconfiguration is surfaced at `open()` time rather than deferred to the first chunk read. @@ -125,11 +125,13 @@ If the chunk bbox falls entirely outside the source image after clamping, `_nati `await reader.read(window=window)` fetches the windowed pixel data from the selected overview level (or full-res). The result is a `(bands, window_h, window_w)` array in the source CRS/grid. -### 4. Reprojection dispatch and current legacy backend +### 4. Reprojection dispatch and current backend state `_chunk_reader.py` now builds a `ReprojectRequest` and calls `reproject_tile()` in `_reproject.py` rather than reaching directly into warp-map helpers. That gives lazycogs a clean seam for swapping reprojection engines while keeping chunk orchestration unchanged. -The dispatcher currently short-circuits exact same-grid reads and otherwise routes to the legacy nearest-neighbor backend, which still warps the source tile onto the destination chunk grid without GDAL: +The dispatcher always short-circuits exact same-grid reads. After that, public reprojection requests currently route to rust-warp for `nearest`, `bilinear`, and `cubic` alike. + +The legacy nearest-neighbor path is still present as internal migration scaffolding and still warps the source tile onto the destination chunk grid without GDAL: 1. Build a meshgrid of destination pixel-centre coordinates. 2. Transform all coordinates from `dst_crs` to `src_crs` in one vectorised `Transformer.transform()` call. @@ -137,7 +139,7 @@ The dispatcher currently short-circuits exact same-grid reads and otherwise rout 4. `np.floor` rounds to the nearest-neighbor sample; numpy fancy indexing populates the output array. 5. Out-of-bounds pixels get the nodata fill value. -Nearest-neighbor is still the only active resampling method in production code. The `rust-warp` dependency is present for development and adapter work, but it is not the default engine yet. +Public resampling is validated at `open()` time. The currently supported values are `nearest`, `bilinear`, and `cubic`. Same-grid reads bypass reprojection regardless of the selected resampling mode. ## Concurrency model diff --git a/README.md b/README.md index 5711b03..e17a050 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ One constraint worth naming: lazycogs only reads Cloud Optimized GeoTIFFs. If yo | STAC search + spatial indexing | `rustac` (DuckDB + geoparquet) | | COG I/O | `async-geotiff` (Rust, no GDAL) | | Cloud storage | `obstore` | -| Reprojection | backend-neutral seam in `lazycogs`; legacy engine is `pyproj` + numpy | +| Reprojection | backend-neutral seam in `lazycogs`; all public resampling methods currently route through `rust-warp` | | Lazy dataset construction | xarray `BackendEntrypoint` + `LazilyIndexedArray` | ## Installation @@ -33,7 +33,7 @@ One constraint worth naming: lazycogs only reads Cloud Optimized GeoTIFFs. If yo pip install lazycogs ``` -Current development work also pins `rust-warp` from GitHub via uv for local testing of the upcoming reprojection backend swap. That source dependency is still experimental and is not a release-ready packaging story yet. +Current development work also pins `rust-warp` from GitHub via uv while the reprojection migration is in progress. ## Coordinate convention @@ -75,7 +75,11 @@ da = lazycogs.open( bbox=dst_bbox, crs=dst_crs, resolution=10.0, + resampling="nearest", # also supports "bilinear" and "cubic" ) + +# Or use the enum if you prefer: +# lazycogs.open(..., resampling=lazycogs.ResamplingMethod.CUBIC) ``` ### Async loading diff --git a/src/lazycogs/__init__.py b/src/lazycogs/__init__.py index 828a60c..41d4777 100644 --- a/src/lazycogs/__init__.py +++ b/src/lazycogs/__init__.py @@ -1,7 +1,8 @@ """lazycogs: lazy xarray DataArrays from STAC COG collections.""" from lazycogs._chunk_reader import read_chunk, read_chunk_async -from lazycogs._core import open # noqa: A004 +from lazycogs._core import DEFAULT_RESAMPLING, SUPPORTED_RESAMPLING +from lazycogs._core import open as open # noqa: A004 from lazycogs._executor import set_reproject_workers from lazycogs._explain import ( # noqa: F401 — registers da.lazycogs accessor ChunkRead, @@ -20,9 +21,12 @@ MosaicMethodBase, StdevMethod, ) +from lazycogs._reproject import ResamplingMethod from lazycogs._store import store_for __all__ = [ + "DEFAULT_RESAMPLING", + "SUPPORTED_RESAMPLING", "ChunkRead", "CogRead", "CountMethod", @@ -33,6 +37,7 @@ "MeanMethod", "MedianMethod", "MosaicMethodBase", + "ResamplingMethod", "StdevMethod", "align_bbox", "open", diff --git a/src/lazycogs/_backend.py b/src/lazycogs/_backend.py index e9ec514..02388b6 100644 --- a/src/lazycogs/_backend.py +++ b/src/lazycogs/_backend.py @@ -20,6 +20,7 @@ _DUCKDB_EXECUTOR, _run_coroutine, ) +from lazycogs._reproject import ResamplingMethod logger = logging.getLogger(__name__) @@ -61,6 +62,7 @@ class _ChunkReadPlan: chunk_height: Chunk height in pixels. nodata: No-data fill value, or ``None``. mosaic_method_cls: Mosaic method class, or ``None`` for the default. + resampling: Reprojection resampling method for this chunk. store: Pre-configured obstore ``ObjectStore`` instance, or ``None``. max_concurrent_reads: Maximum concurrent COG reads per chunk. warp_cache: Shared warp map cache across time steps. @@ -83,6 +85,7 @@ class _ChunkReadPlan: chunk_height: int nodata: float | None mosaic_method_cls: type[MosaicMethodBase] | None + resampling: ResamplingMethod store: Any | None max_concurrent_reads: int warp_cache: dict @@ -259,6 +262,7 @@ async def _run_one_date( chunk_height=plan.chunk_height, nodata=plan.nodata, mosaic_method_cls=plan.mosaic_method_cls, + resampling=plan.resampling, store=plan.store, max_concurrent_reads=plan.max_concurrent_reads, warp_cache=plan.warp_cache, @@ -335,6 +339,7 @@ class MultiBandStacBackendArray(BackendArray): mosaic_method_cls: Mosaic method class instantiated per chunk, or ``None`` to use the default :class:`~lazycogs._mosaic_methods.FirstMethod`. + resampling: Reprojection resampling method used for chunk reads. store: Pre-configured obstore ``ObjectStore`` instance shared across all chunk reads. When ``None``, each asset HREF is resolved to a store via the thread-local cache in @@ -369,6 +374,7 @@ class MultiBandStacBackendArray(BackendArray): dtype: np.dtype nodata: float | None mosaic_method_cls: type[MosaicMethodBase] | None = field(default=None) + resampling: ResamplingMethod = field(default=ResamplingMethod.NEAREST) store: Any | None = field(default=None) max_concurrent_reads: int = field(default=32) path_from_href: Callable[[str], str] | None = field(default=None) @@ -559,6 +565,7 @@ async def _async_getitem(self, key: tuple[Any, ...]) -> np.ndarray: chunk_height=win.chunk_height, nodata=self.nodata, mosaic_method_cls=self.mosaic_method_cls, + resampling=self.resampling, store=self.store, max_concurrent_reads=self.max_concurrent_reads, warp_cache={}, diff --git a/src/lazycogs/_chunk_reader.py b/src/lazycogs/_chunk_reader.py index dd71302..0d4a46c 100644 --- a/src/lazycogs/_chunk_reader.py +++ b/src/lazycogs/_chunk_reader.py @@ -41,6 +41,7 @@ class _ChunkContext: chunk_width: int chunk_height: int nodata: float | None + resampling: str store: ObjectStore | None path_fn: Callable[[str], str] | None warp_cache: dict[object, object] | None @@ -270,6 +271,7 @@ def _apply_bands_with_warp_cache( dst_crs: CRS, dst_width: int, dst_height: int, + resampling: str = "nearest", warp_cache: dict[object, object] | None = None, ) -> dict[str, tuple[np.ndarray, float | None]]: """Reproject multiple band rasters through the backend-neutral interface. @@ -286,6 +288,7 @@ def _apply_bands_with_warp_cache( dst_crs: CRS of the destination grid. dst_width: Width of the destination grid in pixels. dst_height: Height of the destination grid in pixels. + resampling: Reprojection resampling method. warp_cache: Optional migration-time cache shared across calls. Returns: @@ -316,7 +319,7 @@ def _apply_bands_with_warp_cache( dst_width=dst_width, dst_height=dst_height, nodata=effective_nodata, - resampling="nearest", + resampling=resampling, ), warp_cache=cache, ), @@ -435,6 +438,7 @@ async def _read_band( ctx.dst_crs, ctx.chunk_width, ctx.chunk_height, + ctx.resampling, ctx.warp_cache, ), ) @@ -506,6 +510,7 @@ async def read_chunk_async( chunk_height: int, nodata: float | None = None, mosaic_method_cls: type[MosaicMethodBase] | None = None, + resampling: str = "nearest", store: ObjectStore | None = None, max_concurrent_reads: int = 32, warp_cache: dict | None = None, @@ -530,6 +535,7 @@ async def read_chunk_async( nodata: No-data fill value. mosaic_method_cls: Mosaic method class instantiated once per band. Defaults to :class:`~lazycogs._mosaic_methods.FirstMethod`. + resampling: Reprojection resampling method. store: Optional pre-configured obstore ``ObjectStore`` instance. max_concurrent_reads: Maximum number of COG reads to run concurrently. warp_cache: Optional cache shared across calls for reusing warp maps @@ -553,6 +559,7 @@ async def read_chunk_async( chunk_width=chunk_width, chunk_height=chunk_height, nodata=nodata, + resampling=resampling, store=store, path_fn=path_fn, warp_cache=warp_cache, @@ -612,6 +619,7 @@ def read_chunk( chunk_height: int, nodata: float | None = None, mosaic_method_cls: type[MosaicMethodBase] | None = None, + resampling: str = "nearest", store: ObjectStore | None = None, max_concurrent_reads: int = 32, warp_cache: dict | None = None, @@ -631,6 +639,7 @@ def read_chunk( nodata: No-data fill value. mosaic_method_cls: Mosaic method class instantiated once per band. Defaults to :class:`~lazycogs._mosaic_methods.FirstMethod`. + resampling: Reprojection resampling method. store: Optional pre-configured obstore ``ObjectStore`` instance. max_concurrent_reads: Maximum number of COG reads to run concurrently. warp_cache: Optional cache shared across calls for reusing warp maps @@ -654,6 +663,7 @@ def read_chunk( chunk_height=chunk_height, nodata=nodata, mosaic_method_cls=mosaic_method_cls, + resampling=resampling, store=store, max_concurrent_reads=max_concurrent_reads, warp_cache=warp_cache, diff --git a/src/lazycogs/_core.py b/src/lazycogs/_core.py index 20e28a0..c368169 100644 --- a/src/lazycogs/_core.py +++ b/src/lazycogs/_core.py @@ -17,6 +17,7 @@ from lazycogs._cql2 import _extract_filter_fields, _sortby_fields from lazycogs._grid import compute_output_grid from lazycogs._mosaic_methods import FirstMethod, MosaicMethodBase +from lazycogs._reproject import ResamplingMethod from lazycogs._store import resolve from lazycogs._temporal import _TemporalGrouper, grouper_from_period @@ -28,6 +29,9 @@ logger = logging.getLogger(__name__) +DEFAULT_RESAMPLING = ResamplingMethod.NEAREST +SUPPORTED_RESAMPLING = tuple(ResamplingMethod) + class _CompactDateArray(np.ndarray): """Numpy datetime64 array subclass with a compact display for xarray HTML repr.""" @@ -251,6 +255,28 @@ def _build_time_steps( return filter_strings, time_coords +def _validate_resampling(resampling: str | ResamplingMethod) -> ResamplingMethod: + """Return ``resampling`` as a supported enum, else raise ``ValueError``. + + Args: + resampling: User-provided resampling method name. + + Returns: + Normalized resampling enum value. + + Raises: + ValueError: If *resampling* is not currently supported. + + """ + try: + return ResamplingMethod(resampling) + except ValueError as exc: + supported = ", ".join(SUPPORTED_RESAMPLING) + raise ValueError( + f"Unsupported resampling {resampling!r}. Supported values: {supported}.", + ) from exc + + def _build_dataarray( *, parquet_path: str, @@ -269,6 +295,7 @@ def _build_dataarray( out_dtype: np.dtype, method_cls: type[MosaicMethodBase], chunks: dict[str, int] | None, + resampling: ResamplingMethod, store: ObjectStore | None = None, max_concurrent_reads: int = 32, path_from_href: Callable[[str], str] | None = None, @@ -299,6 +326,7 @@ def _build_dataarray( out_dtype: Output array dtype. method_cls: Mosaic method class. chunks: Passed to ``DataArray.chunk()`` if not ``None``. + resampling: Reprojection resampling method for chunk reads. store: Pre-configured obstore ``ObjectStore`` instance. When provided, it is used directly for all asset reads instead of resolving a store from each HREF. @@ -333,6 +361,7 @@ def _build_dataarray( dtype=out_dtype, nodata=nodata, mosaic_method_cls=method_cls, + resampling=resampling, store=store, max_concurrent_reads=max_concurrent_reads, path_from_href=path_from_href, @@ -439,6 +468,7 @@ def open( # noqa: A001 nodata: float | None = None, dtype: str | np.dtype | None = None, mosaic_method: type[MosaicMethodBase] | None = None, + resampling: str | ResamplingMethod = DEFAULT_RESAMPLING, time_period: str = "P1D", store: ObjectStore | None = None, max_concurrent_reads: int = 32, @@ -476,6 +506,10 @@ def open( # noqa: A001 dtype: Output array dtype. Defaults to ``float32``. mosaic_method: Mosaic method class (not instance) to use. Defaults to :class:`~lazycogs._mosaic_methods.FirstMethod`. + resampling: Reprojection resampling method. Supported values are + ``"nearest"`` (default), ``"bilinear"``, and ``"cubic"``. + Validation happens at open time so unsupported values fail before + any chunk reads begin. time_period: ISO 8601 duration string controlling how items are grouped into time steps. Supported forms: ``PnD`` (days), ``P1W`` (ISO calendar week), ``P1M`` (calendar month), ``P1Y`` @@ -558,8 +592,9 @@ def strip_bucket(href: str) -> str: "To query a hive-partitioned directory, pass a duckdb_client.", ) - # Validate time_period early before any I/O so bad values fail fast. + # Validate user-facing options early before any I/O so bad values fail fast. grouper = grouper_from_period(time_period) + resolved_resampling = _validate_resampling(resampling) dst_crs = CRS.from_user_input(crs) @@ -654,6 +689,7 @@ def strip_bucket(href: str) -> str: out_dtype=out_dtype, method_cls=method_cls, chunks=chunks, + resampling=resolved_resampling, store=store, max_concurrent_reads=max_concurrent_reads, path_from_href=path_from_href, diff --git a/src/lazycogs/_reproject.py b/src/lazycogs/_reproject.py index f55e21f..77fe87b 100644 --- a/src/lazycogs/_reproject.py +++ b/src/lazycogs/_reproject.py @@ -4,6 +4,7 @@ import functools from dataclasses import dataclass +from enum import StrEnum from typing import TYPE_CHECKING, Literal import numpy as np @@ -15,6 +16,14 @@ from affine import Affine +class ResamplingMethod(StrEnum): + """Supported public reprojection resampling methods.""" + + NEAREST = "nearest" + BILINEAR = "bilinear" + CUBIC = "cubic" + + @functools.lru_cache(maxsize=256) def _get_transformer(src_crs: CRS, dst_crs: CRS) -> Transformer: """Return a cached ``Transformer`` for a CRS pair. @@ -49,8 +58,7 @@ class ReprojectRequest: dst_width: Destination width in pixels. dst_height: Destination height in pixels. nodata: Fill value for pixels that fall outside the source extent. - resampling: Requested resampling method. Only ``"nearest"`` is - supported by the legacy backend. + resampling: Requested resampling method. """ @@ -62,7 +70,7 @@ class ReprojectRequest: dst_width: int dst_height: int nodata: float | None = None - resampling: str = "nearest" + resampling: ResamplingMethod = ResamplingMethod.NEAREST @dataclass @@ -212,7 +220,7 @@ def _reproject_tile_legacy( return apply_warp_map(request.data, resolved_warp_map, request.nodata) -_DEFAULT_REPROJECT_BACKEND: Literal["legacy", "rust-warp"] = "legacy" +_DEFAULT_REPROJECT_BACKEND: Literal["legacy", "rust-warp"] = "rust-warp" def reproject_tile( diff --git a/src/lazycogs/_rust_warp.py b/src/lazycogs/_rust_warp.py index 9a9e33c..581676e 100644 --- a/src/lazycogs/_rust_warp.py +++ b/src/lazycogs/_rust_warp.py @@ -67,10 +67,16 @@ def _validate_supported_dtype(data: np.ndarray) -> np.dtype: return dtype -def _normalize_nodata(nodata: float | None, dtype: np.dtype) -> float | int | None: - """Cast ``nodata`` to the source dtype so fill semantics match numpy's.""" +def _normalize_nodata(nodata: float | None, dtype: np.dtype) -> float | int: + """Cast ``nodata`` to the source dtype, defaulting to lazycogs' zero fill. + + rust-warp uses ``NaN`` as the implicit fill for floating-point arrays when + ``nodata`` is omitted. lazycogs has historically treated ``nodata=None`` as + a request for zero fill regardless of dtype, so normalize that explicitly + before dispatch. + """ if nodata is None: - return None + return np.array([0]).astype(dtype, casting="unsafe")[0].item() return np.array([nodata]).astype(dtype, casting="unsafe")[0].item() diff --git a/tests/test_backend.py b/tests/test_backend.py index 0e1ce66..98e29fd 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -238,6 +238,29 @@ async def _fake_read_chunk_async(*args, **kwargs): assert result.shape == (2, 1, 4) +def test_multiband_raw_getitem_forwards_resampling(wgs84): + """The selected resampling mode reaches read_chunk_async unchanged.""" + bands = ["B01", "B02"] + multi = _make_multiband_array(wgs84, bands) + multi.resampling = "cubic" + fake_items = [ + {"id": "item-1", "assets": {b: {"href": f"s3://b/{b}.tif"} for b in bands}}, + ] + fake_chunk = {b: np.zeros((1, 1, 4), dtype=np.float32) for b in bands} + + with ( + patch("rustac.DuckdbClient.search", return_value=fake_items), + patch( + "lazycogs._backend.read_chunk_async", + new_callable=AsyncMock, + return_value=fake_chunk, + ) as read_chunk_async_mock, + ): + multi._sync_getitem((slice(0, 2), 0, slice(0, 1), slice(0, 4))) + + assert read_chunk_async_mock.await_args.kwargs["resampling"] == "cubic" + + def test_multiband_raw_getitem_squeeze_band(wgs84): """Integer band index squeezes the band dimension.""" multi = _make_multiband_array(wgs84, ["B01", "B02"]) diff --git a/tests/test_chunk_reader.py b/tests/test_chunk_reader.py index 895a40e..2767e3e 100644 --- a/tests/test_chunk_reader.py +++ b/tests/test_chunk_reader.py @@ -286,6 +286,7 @@ def test_apply_bands_with_warp_cache_same_grid_bypasses_reproject_tile(): crs, dst_width=4, dst_height=4, + resampling="cubic", ) reproject_tile_mock.assert_not_called() @@ -311,9 +312,12 @@ def _spy_compute_warp_map(*args, **kwargs): warp_map_calls.append(True) return result - with patch( - "lazycogs._reproject.compute_warp_map", - side_effect=_spy_compute_warp_map, + with ( + patch("lazycogs._reproject._DEFAULT_REPROJECT_BACKEND", "legacy"), + patch( + "lazycogs._reproject.compute_warp_map", + side_effect=_spy_compute_warp_map, + ), ): results = _apply_bands_with_warp_cache( [("B01", raster_a, crs, None), ("B02", raster_b, crs, None)], @@ -350,9 +354,12 @@ def _spy_compute_warp_map(*args, **kwargs): warp_map_calls.append(True) return result - with patch( - "lazycogs._reproject.compute_warp_map", - side_effect=_spy_compute_warp_map, + with ( + patch("lazycogs._reproject._DEFAULT_REPROJECT_BACKEND", "legacy"), + patch( + "lazycogs._reproject.compute_warp_map", + side_effect=_spy_compute_warp_map, + ), ): results = _apply_bands_with_warp_cache( [("B01", raster_a, crs, None), ("B02", raster_b, crs, None)], @@ -385,9 +392,12 @@ def _spy_compute_warp_map(*args, **kwargs): warp_map_calls.append(True) return result - with patch( - "lazycogs._reproject.compute_warp_map", - side_effect=_spy_compute_warp_map, + with ( + patch("lazycogs._reproject._DEFAULT_REPROJECT_BACKEND", "legacy"), + patch( + "lazycogs._reproject.compute_warp_map", + side_effect=_spy_compute_warp_map, + ), ): _apply_bands_with_warp_cache( [("B01", raster, crs, None)], @@ -411,30 +421,29 @@ def _spy_compute_warp_map(*args, **kwargs): assert len(shared_cache) == 1 -def test_apply_bands_with_warp_cache_uses_rust_backend_when_selected(): - """The chunk-reader seam can dispatch to rust-warp without caller changes.""" +def test_apply_bands_with_warp_cache_uses_rust_backend_for_nearest_by_default(): + """The chunk-reader seam now defaults all reprojection methods to rust-warp.""" crs = CRS.from_epsg(4326) src_transform = Affine(1.0, 0.0, 0.5, 0.0, -1.0, 4.0) dst_transform = Affine(1.0, 0.0, 0.0, 0.0, -1.0, 4.0) raster = _make_raster(src_transform, 1.0) expected = np.full((1, 4, 4), 7.0, dtype=np.float32) - with ( - patch("lazycogs._reproject._DEFAULT_REPROJECT_BACKEND", "rust-warp"), - patch( - "lazycogs._reproject.reproject_array_rust_warp", - return_value=expected, - ) as rust_backend_mock, - ): + with patch( + "lazycogs._reproject.reproject_array_rust_warp", + return_value=expected, + ) as rust_backend_mock: results = _apply_bands_with_warp_cache( [("B01", raster, crs, None)], dst_transform, crs, dst_width=4, dst_height=4, + resampling="nearest", ) rust_backend_mock.assert_called_once() + assert rust_backend_mock.call_args.kwargs["resampling"] == "nearest" np.testing.assert_array_equal(results["B01"][0], expected) diff --git a/tests/test_core.py b/tests/test_core.py index b371f81..298e4a7 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -12,7 +12,13 @@ import lazycogs from lazycogs._backend import MultiBandStacBackendArray -from lazycogs._core import _build_time_steps, _smoketest_store +from lazycogs._core import ( + DEFAULT_RESAMPLING, + SUPPORTED_RESAMPLING, + _build_time_steps, + _smoketest_store, +) +from lazycogs._reproject import ResamplingMethod from lazycogs._temporal import _DayGrouper, _FixedDayGrouper, _MonthGrouper @@ -265,6 +271,67 @@ def test_open_invalid_time_period_raises(): ) +@pytest.mark.parametrize("resampling", SUPPORTED_RESAMPLING) +def test_open_accepts_supported_resampling_values(tmp_path, resampling): + """open() accepts each currently supported resampling value.""" + parquet = tmp_path / "items.parquet" + parquet.write_bytes(b"") + + with ( + patch("lazycogs._core._discover_bands", return_value=["B04"]), + patch("lazycogs._core._smoketest_store"), + patch( + "lazycogs._core._build_time_steps", + return_value=(["2023-01-15"], [np.datetime64("2023-01-15", "D")]), + ), + ): + da = lazycogs.open( + str(parquet), + bbox=(0.0, 0.0, 10.0, 10.0), + crs="EPSG:4326", + resolution=1.0, + resampling=resampling, + ) + + assert da.attrs["_stac_backend"].resampling is resampling + + +def test_open_accepts_resampling_enum(tmp_path): + """open() accepts ``ResamplingMethod`` enum values directly.""" + parquet = tmp_path / "items.parquet" + parquet.write_bytes(b"") + + with ( + patch("lazycogs._core._discover_bands", return_value=["B04"]), + patch("lazycogs._core._smoketest_store"), + patch( + "lazycogs._core._build_time_steps", + return_value=(["2023-01-15"], [np.datetime64("2023-01-15", "D")]), + ), + ): + da = lazycogs.open( + str(parquet), + bbox=(0.0, 0.0, 10.0, 10.0), + crs="EPSG:4326", + resolution=1.0, + resampling=ResamplingMethod.CUBIC, + ) + + assert da.attrs["_stac_backend"].resampling is ResamplingMethod.CUBIC + + +def test_open_rejects_unknown_resampling(): + """open() validates resampling once at API entry.""" + with pytest.raises(ValueError, match="Unsupported resampling"): + lazycogs.open( + "items.parquet", + bbox=(-93.5, 44.5, -93.0, 45.0), + crs="EPSG:4326", + resolution=0.0001, + resampling="lanczos", + ) + + def test_open_works_inside_running_event_loop(tmp_path): """open() does not raise RuntimeError when called inside a running event loop.""" @@ -357,6 +424,7 @@ def test_open_sets_expected_dataarray_attributes(tmp_path): # Internal bookkeeping attributes assert isinstance(da.attrs["_stac_backend"], MultiBandStacBackendArray) + assert da.attrs["_stac_backend"].resampling == DEFAULT_RESAMPLING assert da.attrs["_stac_time_coords"].dtype == np.dtype("datetime64[D]") diff --git a/tests/test_reproject.py b/tests/test_reproject.py index 7c34823..518b4d2 100644 --- a/tests/test_reproject.py +++ b/tests/test_reproject.py @@ -1,5 +1,7 @@ """Tests for _reproject: reproject_array, compute_warp_map, apply_warp_map.""" +from unittest.mock import patch + import numpy as np import pytest from affine import Affine @@ -7,6 +9,7 @@ from lazycogs._reproject import ( ReprojectRequest, + ResamplingMethod, WarpMap, apply_warp_map, compute_warp_map, @@ -99,6 +102,38 @@ def test_reproject_tile_matches_legacy_wrapper(wgs84): ) +def test_reproject_tile_defaults_to_rust_warp_backend(wgs84): + """The default backend selection now routes nearest through rust-warp.""" + src_transform = _make_transform(0.0, 2.0, 1.0) + dst_transform = _make_transform(0.0, 2.0, 0.5) + data = np.arange(4, dtype=np.float32).reshape(1, 2, 2) + calls: list[ResamplingMethod] = [] + + def _fake_rust_warp(**kwargs): + calls.append(kwargs["resampling"]) + return np.zeros((1, 4, 4), dtype=np.float32) + + with patch( + "lazycogs._reproject.reproject_array_rust_warp", + side_effect=_fake_rust_warp, + ): + out = reproject_tile( + ReprojectRequest( + data=data, + src_transform=src_transform, + src_crs=wgs84, + dst_transform=dst_transform, + dst_crs=wgs84, + dst_width=4, + dst_height=4, + resampling=ResamplingMethod.NEAREST, + ), + ) + + assert calls == [ResamplingMethod.NEAREST] + assert out.shape == (1, 4, 4) + + def test_affine_to_rust_warp_uses_six_value_rasterio_order(): """Affine conversion emits the 6-tuple rust-warp expects.""" transform = Affine(2.0, 0.5, 10.0, -0.25, -3.0, 20.0) From be6d4839c085331c60843cc87084b6f0c5d7c4f3 Mon Sep 17 00:00:00 2001 From: hrodmn Date: Wed, 13 May 2026 14:57:27 -0500 Subject: [PATCH 5/7] chore: add some reprojection benchmarks, compare rust-warp to rasterio --- CONTRIBUTING.md | 2 +- scripts/format_benchmark_comparison.py | 102 ++++++- scripts/prepare_benchmark_data.py | 144 ++++++--- tests/benchmarks/bench_pipeline.py | 145 +++++++++- tests/benchmarks/conftest.py | 58 ++++ tests/conftest.py | 142 ++++++--- tests/test_format_benchmark_comparison.py | 59 ++++ tests/test_prepare_benchmark_data.py | 83 ++++++ tests/test_rasterio_parity.py | 338 ++++++++++++++++------ 9 files changed, 878 insertions(+), 195 deletions(-) create mode 100644 tests/test_format_benchmark_comparison.py create mode 100644 tests/test_prepare_benchmark_data.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8b60bb8..28adadf 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -34,7 +34,7 @@ Benchmarks live in `tests/benchmarks/` and are excluded from the default test ru uv run python scripts/prepare_benchmark_data.py ``` -This queries the Element84 Earth Search STAC API, downloads a small set of COG assets to `.benchmark_data/`, and writes local parquet index files. Pass `--overwrite` to re-download if needed. +This queries the Element84 Earth Search STAC API, downloads a small set of COG assets to `.benchmark_data/`, and writes local parquet index files. The parquet files are always refreshed to point at the current checkout's local `.benchmark_data/cogs/` paths, so rerunning the script fixes stale `file://` HREFs after moving the repo. Pass `--overwrite` to force re-downloading the raw query and COG files. Once the data is in place: diff --git a/scripts/format_benchmark_comparison.py b/scripts/format_benchmark_comparison.py index f4272d5..c20cab4 100755 --- a/scripts/format_benchmark_comparison.py +++ b/scripts/format_benchmark_comparison.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Format a pytest-benchmark comparison as a GitHub-flavored markdown table. +"""Format a pytest-benchmark comparison as GitHub-flavored markdown. Usage: uv run python scripts/format_benchmark_comparison.py \ @@ -7,6 +7,8 @@ --pr '.benchmarks/**/*_pr-*.json' """ +from __future__ import annotations + import argparse import json import logging @@ -16,6 +18,8 @@ logger = logging.getLogger(__name__) REGRESSION_THRESHOLD_PCT = 10 +_SMALL_WINDOW_LABEL = "Small-window reprojection microbenchmarks" +_END_TO_END_LABEL = "End-to-end benchmarks" def find_file(pattern: str) -> Path: @@ -38,31 +42,97 @@ def _ms(seconds: float) -> str: return f"{seconds * 1000:.1f}" -def generate_report(baseline: dict[str, dict], pr: dict[str, dict]) -> str: - """Generate a markdown benchmark comparison table.""" - rows = [] - for name in sorted(baseline): - if name not in pr: - continue - base_mean = baseline[name]["mean"] - pr_mean = pr[name]["mean"] +def _classify_benchmark(name: str) -> str: + """Return the report section label for a benchmark name.""" + if "small_window" in name: + return _SMALL_WINDOW_LABEL + return _END_TO_END_LABEL + + +def _comparison_row(name: str, baseline_stats: dict, pr_stats: dict) -> str: + """Return one markdown table row for a benchmark present in both runs.""" + base_mean = baseline_stats["mean"] + pr_mean = pr_stats["mean"] + + if base_mean == 0: + pct_display = "n/a" + flag = "" + else: pct = (pr_mean - base_mean) / base_mean * 100 sign = "+" if pct >= 0 else "" + pct_display = f"{sign}{pct:.1f}%" flag = " :warning:" if pct > REGRESSION_THRESHOLD_PCT else "" - base_ms, pr_ms = _ms(base_mean), _ms(pr_mean) - row = f"| `{name}` | {base_ms} | {pr_ms} | {sign}{pct:.1f}%{flag} |" - rows.append(row) - table = "\n".join( + base_ms, pr_ms = _ms(base_mean), _ms(pr_mean) + return f"| `{name}` | {base_ms} | {pr_ms} | {pct_display}{flag} |" + + +def _render_comparison_section( + heading: str, + names: list[str], + baseline: dict[str, dict], + pr: dict[str, dict], +) -> str: + """Render one benchmark comparison table section.""" + rows = [_comparison_row(name, baseline[name], pr[name]) for name in names] + return "\n".join( [ + f"### {heading}", + "", "| Test | Baseline (ms) | PR (ms) | Change |", "|------|:-------------:|:-------:|-------:|", *rows, ], ) - return ( - f"\n## Benchmark Comparison\n\n{table}\n" - ) + + +def _render_name_list_section(heading: str, names: list[str]) -> str: + """Render a markdown bullet list of benchmark names.""" + lines = [f"## {heading}", "", *[f"- `{name}`" for name in names]] + return "\n".join(lines) + + +def generate_report(baseline: dict[str, dict], pr: dict[str, dict]) -> str: + """Generate a markdown benchmark comparison report.""" + shared_names = sorted(set(baseline) & set(pr)) + new_names = sorted(set(pr) - set(baseline)) + missing_names = sorted(set(baseline) - set(pr)) + + shared_sections: list[str] = [] + for heading in (_END_TO_END_LABEL, _SMALL_WINDOW_LABEL): + names = [name for name in shared_names if _classify_benchmark(name) == heading] + if names: + shared_sections.append( + _render_comparison_section(heading, names, baseline, pr), + ) + + body_parts = [ + "", + "## Benchmark Comparison", + "", + ] + if shared_sections: + body_parts.extend(shared_sections) + else: + body_parts.append("No benchmarks were present in both runs.") + + if new_names: + body_parts.extend( + ["", _render_name_list_section("New benchmarks in PR", new_names)], + ) + if missing_names: + body_parts.extend( + [ + "", + _render_name_list_section( + "Benchmarks missing from PR", + missing_names, + ), + ], + ) + + body_parts.append("") + return "\n".join(body_parts) def main() -> None: diff --git a/scripts/prepare_benchmark_data.py b/scripts/prepare_benchmark_data.py index 7d3c77d..af13405 100755 --- a/scripts/prepare_benchmark_data.py +++ b/scripts/prepare_benchmark_data.py @@ -3,7 +3,7 @@ Queries the Element84 Earth Search STAC API for Sentinel-2 items over western Colorado, downloads the selected band assets to a local directory, then writes -a new parquet file with hrefs pointing to the local files. Also synthesises an +a new parquet file with hrefs pointing to the local files. Also synthesises an expanded parquet with 12 monthly time steps (cloned from the real items, same asset hrefs) for concurrency benchmarks. @@ -12,6 +12,11 @@ benchmark_items.parquet parquet index with file:// hrefs expanded_benchmark_items.parquet 12 synthetic time steps, same COG files +The parquet files are always rewritten to point at this checkout's local +``.benchmark_data/cogs`` paths. That keeps benchmark HREFs valid after moving +or renaming the repository directory. ``--overwrite`` only forces re-downloads +of the raw STAC query and local COG files. + Usage: uv run python scripts/prepare_benchmark_data.py uv run python scripts/prepare_benchmark_data.py --overwrite @@ -27,7 +32,7 @@ from urllib.parse import urlparse import rustac -from obstore.store import from_url +from obstore.store import LocalStore, from_url logger = logging.getLogger(__name__) @@ -39,6 +44,7 @@ # Red (10m) + Narrow NIR (20m) — sufficient for NDVI benchmarks BANDS = ["red", "nir08"] LIMIT = 10 +DOWNLOAD_CONCURRENCY = 8 # 12 monthly dates used for the expanded concurrency benchmark dataset. # One per month from 2024-01 through 2024-12, anchored to the 15th so dates @@ -48,25 +54,85 @@ DATA_DIR = Path(__file__).parents[1] / ".benchmark_data" -def _download(href: str, dest: Path) -> None: - """Download a cloud object to a local file using obstore.""" - dest.parent.mkdir(parents=True, exist_ok=True) +def _local_asset_path(cog_dir: Path, item_id: str, band: str) -> Path: + """Return the local benchmark asset path for one item band.""" + return cog_dir / item_id / f"{band}.tif" + + +def _remote_store_key(href: str) -> tuple[str, str]: + """Return the obstore root URL and object key for an asset HREF.""" parsed = urlparse(href) - root_url = f"{parsed.scheme}://{parsed.netloc}" - path = parsed.path.lstrip("/") - kwargs = {"skip_signature": True} - store = from_url(root_url, **kwargs) - logger.info("Downloading %s", href) - result = store.get(path) - dest.write_bytes(result.bytes()) - logger.info("Wrote %s (%.1f MB)", dest, dest.stat().st_size / 1_048_576) + return f"{parsed.scheme}://{parsed.netloc}", parsed.path.lstrip("/") + + +async def _download( + href: str, + dest: Path, + *, + local_store: LocalStore, + remote_stores: dict[str, object], + semaphore: asyncio.Semaphore, +) -> None: + """Download a cloud object to a local file using obstore's async APIs.""" + root_url, remote_key = _remote_store_key(href) + if root_url not in remote_stores: + remote_stores[root_url] = from_url(root_url, skip_signature=True) + local_key = str(dest.relative_to(Path(local_store.prefix))) + + async with semaphore: + logger.info("Downloading %s", href) + result = await remote_stores[root_url].get_async(remote_key) + await local_store.put_async(local_key, result) + size = await asyncio.to_thread(lambda: dest.stat().st_size) + logger.info("Wrote %s (%.1f MB)", dest, size / 1_048_576) + + +async def _localize_item_assets( + item: dict, + cog_dir: Path, + *, + overwrite: bool, + local_store: LocalStore, + remote_stores: dict[str, object], + semaphore: asyncio.Semaphore, +) -> dict: + """Return ``item`` with benchmark assets rewritten to local ``file://`` HREFs.""" + item_id = item["id"] + local_assets = {} + download_tasks = [] + + for band in BANDS: + if band not in item.get("assets", {}): + logger.warning("Item %s has no asset %r; skipping.", item_id, band) + continue + href = item["assets"][band]["href"] + local_path = _local_asset_path(cog_dir, item_id, band) + if overwrite or not local_path.exists(): + download_tasks.append( + _download( + href, + local_path, + local_store=local_store, + remote_stores=remote_stores, + semaphore=semaphore, + ), + ) + local_assets[band] = { + **item["assets"][band], + "href": local_path.as_uri(), + } + + if download_tasks: + await asyncio.gather(*download_tasks) + + return {**item, "assets": local_assets} def _expand_items(source_items: list[dict], dates: list[str]) -> list[dict]: """Clone source_items across synthetic dates by round-robin assignment. - Each clone keeps the original geometry, bbox, and asset hrefs. Only the - ``id`` and ``properties.datetime`` are changed. The result has one item + Each clone keeps the original geometry, bbox, and asset hrefs. Only the + ``id`` and ``properties.datetime`` are changed. The result has one item per date, suitable for building a multi-time-step benchmark parquet without downloading additional data. @@ -111,37 +177,35 @@ async def main(*, overwrite: bool = False) -> None: items: list[dict] = rustac.search_sync(str(raw_parquet), use_duckdb=True) logger.info("Found %d items", len(items)) - local_items = [] - for item in items: - item_id = item["id"] - local_assets = {} - for band in BANDS: - if band not in item.get("assets", {}): - logger.warning("Item %s has no asset %r; skipping.", item_id, band) - continue - href = item["assets"][band]["href"] - local_path = cog_dir / item_id / f"{band}.tif" - if overwrite or not local_path.exists(): - _download(href, local_path) - local_assets[band] = { - **item["assets"][band], - "href": local_path.as_uri(), - } - local_items.append({**item, "assets": local_assets}) + local_store = LocalStore(prefix=str(cog_dir), mkdir=True) + remote_stores: dict[str, object] = {} + semaphore = asyncio.Semaphore(DOWNLOAD_CONCURRENCY) + local_items = await asyncio.gather( + *[ + _localize_item_assets( + item, + cog_dir, + overwrite=overwrite, + local_store=local_store, + remote_stores=remote_stores, + semaphore=semaphore, + ) + for item in items + ], + ) out_parquet = DATA_DIR / "benchmark_items.parquet" rustac.write_sync(str(out_parquet), local_items) logger.info("Wrote benchmark parquet: %s", out_parquet) expanded_parquet = DATA_DIR / "expanded_benchmark_items.parquet" - if overwrite or not expanded_parquet.exists(): - expanded = _expand_items(local_items, SYNTHETIC_DATES) - rustac.write_sync(str(expanded_parquet), expanded) - logger.info( - "Wrote expanded benchmark parquet (%d time steps): %s", - len(SYNTHETIC_DATES), - expanded_parquet, - ) + expanded = _expand_items(local_items, SYNTHETIC_DATES) + rustac.write_sync(str(expanded_parquet), expanded) + logger.info( + "Wrote expanded benchmark parquet (%d time steps): %s", + len(SYNTHETIC_DATES), + expanded_parquet, + ) logger.info( "Run benchmarks with: uv run pytest tests/benchmarks/ --benchmark-enable", diff --git a/tests/benchmarks/bench_pipeline.py b/tests/benchmarks/bench_pipeline.py index 83259f3..485c611 100644 --- a/tests/benchmarks/bench_pipeline.py +++ b/tests/benchmarks/bench_pipeline.py @@ -1,16 +1,24 @@ -"""End-to-end benchmarks using the public lazycogs.open() API. +"""End-to-end and micro benchmarks for lazycogs reprojection. -These benchmarks require local benchmark data. See scripts/prepare_benchmark_data.py. +These benchmarks require local benchmark data only for the public ``open()`` +benchmarks. The direct reprojection micro-benchmarks use synthetic arrays and +run without any prepared dataset. Run with: uv run pytest tests/benchmarks/ --benchmark-enable uv run pytest tests/benchmarks/ --benchmark-enable --benchmark-save= """ +from __future__ import annotations + +import numpy as np import pytest +from affine import Affine +from pyproj import CRS import lazycogs from lazycogs import FirstMethod, MedianMethod, MosaicMethodBase, set_reproject_workers +from lazycogs._reproject import ReprojectRequest, ResamplingMethod, reproject_tile from .conftest import ( BENCHMARK_BBOX, @@ -23,8 +31,116 @@ ) +def _benchmark_request( + *, + same_grid: bool = False, + dst_resolution: float = 20.0, + bands: int = 3, + size: int = 64, +) -> ReprojectRequest: + """Build a representative small-window reprojection request.""" + src_crs = CRS.from_epsg(32632) + src_transform = Affine(10.0, 0.0, 500_000.0, 0.0, -10.0, 5_600_000.0) + data = np.arange(bands * size * size, dtype=np.float32).reshape(bands, size, size) + + if same_grid: + dst_transform = src_transform + dst_width = size + dst_height = size + else: + dst_transform = Affine( + dst_resolution, + 0.0, + 500_320.0, + 0.0, + -dst_resolution, + 5_599_680.0, + ) + dst_width = size + dst_height = size + + return ReprojectRequest( + data=data, + src_transform=src_transform, + src_crs=src_crs, + dst_transform=dst_transform, + dst_crs=src_crs, + dst_width=dst_width, + dst_height=dst_height, + nodata=-9999.0, + ) + + +@pytest.mark.benchmark +@pytest.mark.parametrize("backend", ["legacy", "rust-warp"]) +def test_small_window_nearest_backend_comparison(benchmark, backend: str) -> None: + """Compare legacy vs rust-warp on the same representative small window.""" + request = _benchmark_request(dst_resolution=20.0) + + benchmark(reproject_tile, request, backend=backend) + + @pytest.mark.benchmark -def test_open_overhead(benchmark, benchmark_parquet: str) -> None: +@pytest.mark.parametrize( + ("label", "reproject_request", "resampling"), + [ + pytest.param( + "same_grid_noop", + _benchmark_request(same_grid=True), + ResamplingMethod.NEAREST, + id="same_grid_noop", + ), + pytest.param( + "nearest", + _benchmark_request(dst_resolution=20.0), + ResamplingMethod.NEAREST, + id="nearest", + ), + pytest.param( + "bilinear", + _benchmark_request(dst_resolution=15.0), + ResamplingMethod.BILINEAR, + id="bilinear", + ), + pytest.param( + "cubic", + _benchmark_request(dst_resolution=15.0), + ResamplingMethod.CUBIC, + id="cubic", + ), + ], +) +def test_small_window_reprojection_modes( + benchmark, + label: str, + reproject_request: ReprojectRequest, + resampling: ResamplingMethod, +) -> None: + """Show the cost gap between no-op, nearest, and interpolating modes.""" + benchmark.extra_info["mode"] = label + benchmark( + reproject_tile, + ReprojectRequest( + data=reproject_request.data, + src_transform=reproject_request.src_transform, + src_crs=reproject_request.src_crs, + dst_transform=reproject_request.dst_transform, + dst_crs=reproject_request.dst_crs, + dst_width=reproject_request.dst_width, + dst_height=reproject_request.dst_height, + nodata=reproject_request.nodata, + resampling=resampling, + ), + backend="rust-warp", + ) + + +@pytest.mark.benchmark +def test_open_overhead( + benchmark, + benchmark_parquet: str, + benchmark_open_kwargs: dict[str, object], +) -> None: """Phase 0: time the open() call without triggering any COG reads. Measures parquet queries, band discovery, time-step building, and grid @@ -36,11 +152,16 @@ def test_open_overhead(benchmark, benchmark_parquet: str) -> None: bbox=BENCHMARK_BBOX, crs=BENCHMARK_CRS, resolution=60.0, + **benchmark_open_kwargs, ) @pytest.mark.benchmark -def test_full_compute(benchmark, benchmark_parquet: str) -> None: +def test_full_compute( + benchmark, + benchmark_parquet: str, + benchmark_open_kwargs: dict[str, object], +) -> None: """Full pipeline: open + .compute() including local COG I/O.""" def run() -> object: @@ -49,6 +170,7 @@ def run() -> object: bbox=BENCHMARK_BBOX, crs=BENCHMARK_CRS, resolution=60.0, + **benchmark_open_kwargs, ) return da.compute() @@ -60,6 +182,7 @@ def run() -> object: def test_mosaic_method( benchmark, benchmark_parquet: str, + benchmark_open_kwargs: dict[str, object], method: type[MosaicMethodBase], ) -> None: """Compare mosaic strategy cost end-to-end.""" @@ -72,6 +195,7 @@ def run() -> object: resolution=60.0, time_period="P1M", mosaic_method=method, + **benchmark_open_kwargs, ) return da.compute() @@ -83,6 +207,7 @@ def run() -> object: def test_reproject_workers( benchmark, expanded_benchmark_parquet: str, + benchmark_open_kwargs: dict[str, object], n_workers: int, ) -> None: """Measure throughput as reprojection thread count varies. @@ -102,6 +227,7 @@ def run() -> object: resolution=60.0, time_period="P1M", chunks={"time": 1}, + **benchmark_open_kwargs, ) return da.compute() @@ -113,7 +239,11 @@ def run() -> object: @pytest.mark.benchmark -def test_native_crs_resolution(benchmark, benchmark_parquet: str) -> None: +def test_native_crs_resolution( + benchmark, + benchmark_parquet: str, + benchmark_open_kwargs: dict[str, object], +) -> None: """Full pipeline using the assets' native CRS and resolution. Requests data in EPSG:32612 at 10 m — exactly the source COG projection and @@ -128,6 +258,7 @@ def run() -> object: bbox=BENCHMARK_NATIVE_BBOX, crs=BENCHMARK_NATIVE_CRS, resolution=BENCHMARK_NATIVE_RESOLUTION, + **benchmark_open_kwargs, ) return da.compute() @@ -143,6 +274,7 @@ def run() -> object: def test_time_step_parallelism( benchmark, expanded_benchmark_parquet: str, + benchmark_open_kwargs: dict[str, object], chunks: dict | None, ) -> None: """Compare native time-step thread pool vs Dask across 24 time steps. @@ -161,6 +293,7 @@ def run() -> object: resolution=60.0, time_period="P1M", chunks=chunks, + **benchmark_open_kwargs, ) return da.compute() @@ -176,6 +309,7 @@ def run() -> object: def test_band_access_pattern( benchmark, expanded_benchmark_parquet: str, + benchmark_open_kwargs: dict[str, object], bands: list[str], ) -> None: """Compare single-band vs multi-band compute cost. @@ -195,6 +329,7 @@ def run() -> object: time_period="P1M", bands=bands, chunks={"time": 1}, + **benchmark_open_kwargs, ) return da.compute() diff --git a/tests/benchmarks/conftest.py b/tests/benchmarks/conftest.py index 6fb7103..b6b3d2d 100644 --- a/tests/benchmarks/conftest.py +++ b/tests/benchmarks/conftest.py @@ -4,10 +4,13 @@ """ from pathlib import Path +from urllib.parse import urlparse import pytest +from rustac import DuckdbClient _DATA_DIR = Path(__file__).parents[2] / ".benchmark_data" +_COG_DIR = _DATA_DIR / "cogs" _PARQUET = _DATA_DIR / "benchmark_items.parquet" _EXPANDED_PARQUET = _DATA_DIR / "expanded_benchmark_items.parquet" @@ -33,6 +36,53 @@ BENCHMARK_NATIVE_RESOLUTION = 10.0 # red band native pixel size +def _benchmark_path_from_href(href: str) -> str: + """Return the local file path for a benchmark asset HREF. + + Benchmark parquet files may outlive a local checkout move, leaving absolute + ``file://`` HREFs that still point at an older workspace root. When that + happens, remap the asset path back into this checkout's ``.benchmark_data`` + directory using the stable ``cogs//.tif`` suffix. + """ + parsed = urlparse(href) + if parsed.scheme != "file": + return parsed.path.lstrip("/") + + path = Path(parsed.path) + if path.exists(): + return str(path) + + try: + cogs_index = path.parts.index("cogs") + except ValueError: + return str(path) + + remapped = _DATA_DIR.joinpath(*path.parts[cogs_index:]) + return str(remapped) + + +def _assert_benchmark_assets_available(parquet_path: Path) -> None: + """Skip benchmark tests when the local benchmark assets are unavailable.""" + client = DuckdbClient() + items = client.search(str(parquet_path), max_items=1, include=["assets"]) + if not items: + pytest.skip(f"Benchmark parquet {parquet_path} contains no STAC items.") + + assets = items[0].get("assets", {}) + for band in BENCHMARK_MULTIBAND: + href = assets.get(band, {}).get("href") + if not href: + continue + if Path(_benchmark_path_from_href(href)).exists(): + return + + pytest.skip( + "Benchmark parquet references local COG paths that are unavailable in " + "this checkout. Re-run `uv run python scripts/prepare_benchmark_data.py` " + "to refresh the benchmark dataset.", + ) + + @pytest.fixture(scope="session") def benchmark_parquet() -> str: """Path to the local benchmark parquet file. @@ -44,6 +94,7 @@ def benchmark_parquet() -> str: "Benchmark data not found. " "Run `uv run python scripts/prepare_benchmark_data.py` first.", ) + _assert_benchmark_assets_available(_PARQUET) return str(_PARQUET) @@ -58,4 +109,11 @@ def expanded_benchmark_parquet() -> str: "Expanded benchmark data not found. " "Run `uv run python scripts/prepare_benchmark_data.py` first.", ) + _assert_benchmark_assets_available(_EXPANDED_PARQUET) return str(_EXPANDED_PARQUET) + + +@pytest.fixture(scope="session") +def benchmark_open_kwargs() -> dict[str, object]: + """Common ``lazycogs.open()`` kwargs for local benchmark datasets.""" + return {"path_from_href": _benchmark_path_from_href} diff --git a/tests/conftest.py b/tests/conftest.py index 0b1061a..47e50c8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,66 +14,112 @@ from pyproj import CRS +def _write_synthetic_cog( + cog_path: Path, + *, + data: np.ndarray, + transform: Affine, + crs: CRS, + nodata: float, + overview_resampling: rasterio.enums.Resampling = rasterio.enums.Resampling.nearest, +) -> Path: + """Write ``data`` to a tiled GeoTIFF with built overviews. + + The file is written with the two-step recipe required by ``async_geotiff``: + first create a temporary GeoTIFF and build overviews, then copy it to a + tiled output while preserving the overview IFDs. + """ + with tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as tmp: + tmp_path = Path(tmp.name) + + try: + with rasterio.open( + tmp_path, + "w", + driver="GTiff", + height=data.shape[0], + width=data.shape[1], + count=1, + dtype=data.dtype, + crs=crs.to_wkt(), + transform=transform, + nodata=nodata, + ) as dst: + dst.write(data[np.newaxis]) + + with rasterio.open(tmp_path, "r+") as dst: + dst.build_overviews([2, 4, 8, 16], overview_resampling) + dst.update_tags( + ns="rio_overview", + resampling=overview_resampling.name, + ) + + rasterio.shutil.copy( + str(tmp_path), + str(cog_path), + driver="GTiff", + copy_src_overviews=True, + tiled=True, + blockxsize=64, + blockysize=64, + ) + finally: + tmp_path.unlink(missing_ok=True) + + return cog_path + + @pytest.fixture(scope="session") def synthetic_cog(tmp_path_factory) -> Path: - """Write a small synthetic COG with four overview levels to a temp file. - - Properties: - - Native resolution: 10 m, 320 x 320 pixels - - CRS: UTM zone 32N (EPSG:32632) - - Origin: 500 000 E, 5 600 000 N - - Overview shrink factors: [2, 4, 8, 16] → resolutions 20, 40, 80, 160 m - - Pixel values: unique uint16 per pixel (col + row * width), so every - sampling position returns a deterministic, distinct value that lets - tests distinguish which source pixel was sampled. - - Nodata: 0 (pixels shifted by 1 to avoid accidental nodata) - - The file is written using the standard two-step COG recipe so that both - the full-resolution IFD and all overview IFDs are tiled (required by - async_geotiff). - """ + """Write a synthetic nearest-neighbor parity COG to a temp file.""" cog_path = tmp_path_factory.mktemp("cog") / "synthetic.tif" native_res = 10.0 size = 2048 minx, maxy = 500_000.0, 5_600_000.0 transform = Affine(native_res, 0.0, minx, 0.0, -native_res, maxy) - crs_wkt = CRS.from_epsg(32632).to_wkt() + crs = CRS.from_epsg(32632) rows, cols = np.meshgrid(np.arange(size), np.arange(size), indexing="ij") data = ((cols + rows * size) % 65535 + 1).astype(np.uint16) - # Step 1: write to a temporary stripped GeoTIFF and build overviews. - with tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as tmp: - tmp_path = Path(tmp.name) - - with rasterio.open( - tmp_path, - "w", - driver="GTiff", - height=size, - width=size, - count=1, - dtype="uint16", - crs=crs_wkt, + return _write_synthetic_cog( + cog_path, + data=data, transform=transform, + crs=crs, nodata=0, - ) as dst: - dst.write(data[np.newaxis]) - - with rasterio.open(tmp_path, "r+") as dst: - dst.build_overviews([2, 4, 8, 16], rasterio.enums.Resampling.nearest) - dst.update_tags(ns="rio_overview", resampling="nearest") - - # Step 2: copy to a tiled COG so async_geotiff can read all IFDs. - rasterio.shutil.copy( - str(tmp_path), - str(cog_path), - driver="GTiff", - copy_src_overviews=True, - tiled=True, - blockxsize=64, - blockysize=64, + overview_resampling=rasterio.enums.Resampling.nearest, ) - tmp_path.unlink() - return cog_path + +@pytest.fixture(scope="session") +def continuous_synthetic_cog(tmp_path_factory) -> Path: + """Write a smooth float32 COG for interpolation parity tests.""" + cog_path = tmp_path_factory.mktemp("cog") / "continuous.tif" + native_res = 10.0 + size = 1024 + minx, maxy = 500_000.0, 5_600_000.0 + transform = Affine(native_res, 0.0, minx, 0.0, -native_res, maxy) + crs = CRS.from_epsg(32632) + + rows, cols = np.meshgrid( + np.arange(size, dtype=np.float32), + np.arange(size, dtype=np.float32), + indexing="ij", + ) + data = ( + cols * np.float32(0.5) + + rows * np.float32(1.25) + + np.sin(cols / np.float32(32.0)) * np.float32(5.0) + + np.cos(rows / np.float32(40.0)) * np.float32(7.0) + + np.float32(1000.0) + ).astype(np.float32) + + return _write_synthetic_cog( + cog_path, + data=data, + transform=transform, + crs=crs, + nodata=np.float32(-9999.0), + overview_resampling=rasterio.enums.Resampling.average, + ) diff --git a/tests/test_format_benchmark_comparison.py b/tests/test_format_benchmark_comparison.py new file mode 100644 index 0000000..ccea2b6 --- /dev/null +++ b/tests/test_format_benchmark_comparison.py @@ -0,0 +1,59 @@ +"""Tests for the benchmark comparison formatter script.""" + +from __future__ import annotations + +import importlib.util +from pathlib import Path + +_SCRIPT_PATH = Path(__file__).parents[1] / "scripts" / "format_benchmark_comparison.py" +_SPEC = importlib.util.spec_from_file_location( + "format_benchmark_comparison", + _SCRIPT_PATH, +) +assert _SPEC is not None +assert _SPEC.loader is not None +format_benchmark_comparison = importlib.util.module_from_spec(_SPEC) +_SPEC.loader.exec_module(format_benchmark_comparison) + + +def test_generate_report_includes_sections_for_shared_new_and_missing() -> None: + """The report should surface changed, added, and removed benchmarks.""" + baseline = { + "test_open_overhead": {"mean": 0.010}, + "test_small_window_reprojection_modes[nearest]": {"mean": 0.002}, + "test_removed_benchmark": {"mean": 0.005}, + } + pr = { + "test_open_overhead": {"mean": 0.012}, + "test_small_window_reprojection_modes[nearest]": {"mean": 0.001}, + "test_small_window_reprojection_modes[cubic]": {"mean": 0.003}, + "test_new_end_to_end_benchmark": {"mean": 0.020}, + } + + report = format_benchmark_comparison.generate_report(baseline, pr) + + assert "## Benchmark Comparison" in report + assert "### End-to-end benchmarks" in report + assert "### Small-window reprojection microbenchmarks" in report + assert "| `test_open_overhead` | 10.0 | 12.0 | +20.0% :warning: |" in report + assert ( + "| `test_small_window_reprojection_modes[nearest]` | 2.0 | 1.0 | -50.0% |" + in report + ) + assert "## New benchmarks in PR" in report + assert "`test_small_window_reprojection_modes[cubic]`" in report + assert "`test_new_end_to_end_benchmark`" in report + assert "## Benchmarks missing from PR" in report + assert "`test_removed_benchmark`" in report + + +def test_generate_report_handles_empty_shared_benchmarks() -> None: + """The report should still render when only added or removed tests exist.""" + report = format_benchmark_comparison.generate_report( + baseline={"test_removed": {"mean": 0.005}}, + pr={"test_added": {"mean": 0.007}}, + ) + + assert "No benchmarks were present in both runs." in report + assert "`test_added`" in report + assert "`test_removed`" in report diff --git a/tests/test_prepare_benchmark_data.py b/tests/test_prepare_benchmark_data.py new file mode 100644 index 0000000..ee10dfd --- /dev/null +++ b/tests/test_prepare_benchmark_data.py @@ -0,0 +1,83 @@ +"""Tests for the benchmark data preparation script.""" + +from __future__ import annotations + +import asyncio +import importlib.util +from pathlib import Path + +_SCRIPT_PATH = Path(__file__).parents[1] / "scripts" / "prepare_benchmark_data.py" +_SPEC = importlib.util.spec_from_file_location("prepare_benchmark_data", _SCRIPT_PATH) +assert _SPEC is not None +assert _SPEC.loader is not None +prepare_benchmark_data = importlib.util.module_from_spec(_SPEC) +_SPEC.loader.exec_module(prepare_benchmark_data) + + +def test_main_rewrites_expanded_parquet_without_overwrite( + tmp_path, + monkeypatch, +) -> None: + """Running the script refreshes expanded parquet HREFs even when it exists.""" + monkeypatch.setattr(prepare_benchmark_data, "DATA_DIR", tmp_path) + + raw_parquet = tmp_path / "raw_items.parquet" + raw_parquet.write_text("placeholder") + expanded_parquet = tmp_path / "expanded_benchmark_items.parquet" + expanded_parquet.write_text("stale") + + source_items = [ + { + "id": "item-001", + "assets": { + "red": {"href": "https://example.com/red.tif"}, + "nir08": {"href": "https://example.com/nir08.tif"}, + }, + "properties": {"datetime": "2025-07-04T00:00:00Z"}, + }, + ] + writes: dict[str, list[dict]] = {} + + async def _fake_search_to(*args, **kwargs) -> None: + raise AssertionError( + "search_to should not run when raw_items.parquet already exists", + ) + + def _fake_search_sync(path: str, **kwargs) -> list[dict]: + assert path == str(raw_parquet) + assert kwargs == {"use_duckdb": True} + return source_items + + async def _fake_download(href: str, dest: Path, **kwargs) -> None: + await asyncio.to_thread(dest.parent.mkdir, parents=True, exist_ok=True) + await asyncio.to_thread(dest.write_text, f"downloaded from {href}") + + def _fake_write_sync(path: str, items: list[dict]) -> None: + writes[path] = items + Path(path).write_text("rewritten") + + monkeypatch.setattr(prepare_benchmark_data.rustac, "search_to", _fake_search_to) + monkeypatch.setattr( + prepare_benchmark_data.rustac, + "search_sync", + _fake_search_sync, + ) + monkeypatch.setattr(prepare_benchmark_data, "_download", _fake_download) + monkeypatch.setattr(prepare_benchmark_data.rustac, "write_sync", _fake_write_sync) + + asyncio.run(prepare_benchmark_data.main(overwrite=False)) + + benchmark_parquet = tmp_path / "benchmark_items.parquet" + assert str(benchmark_parquet) in writes + assert str(expanded_parquet) in writes + + expanded_items = writes[str(expanded_parquet)] + assert len(expanded_items) == len(prepare_benchmark_data.SYNTHETIC_DATES) + assert expanded_items[0]["id"] == "synthetic-0000" + assert expanded_items[0]["properties"]["datetime"] == "2024-01-15T12:00:00Z" + + expected_red_href = (tmp_path / "cogs" / "item-001" / "red.tif").as_uri() + assert ( + writes[str(benchmark_parquet)][0]["assets"]["red"]["href"] == expected_red_href + ) + assert expanded_items[0]["assets"]["red"]["href"] == expected_red_href diff --git a/tests/test_rasterio_parity.py b/tests/test_rasterio_parity.py index 9f3a191..8adaede 100644 --- a/tests/test_rasterio_parity.py +++ b/tests/test_rasterio_parity.py @@ -1,20 +1,4 @@ -"""Parity tests: lazycogs read pipeline vs rasterio nearest-neighbor. - -Verifies that lazycogs produces bit-identical output to rasterio for nearest- -neighbor resampling, at resolutions that bracket each overview boundary. -Overview levels for the synthetic COG are 20, 40, 80, and 160 m, so -resolutions just below, at, and just above each boundary are all covered. - -Two CRS scenarios are tested: -- Same CRS: destination grid is in the same UTM 32N CRS as the source file. - This exercises pure pixel-mapping without any coordinate transform. -- Cross CRS: destination grid is in UTM zone 33N (EPSG:32633), an adjacent - zone that shares metric units and a similar pixel scale (~10 m) but requires - a genuine coordinate transform. Using a same-unit projection avoids the - large scale-ratio issues that arise with degree-based CRS (e.g. WGS84), - which cause floating-point boundary sensitivity unrelated to overview - selection. -""" +"""Rasterio/GDAL reference tests for lazycogs reprojection behavior.""" from __future__ import annotations @@ -31,12 +15,15 @@ from pyproj import CRS from lazycogs._chunk_reader import _native_window, _select_overview -from lazycogs._reproject import _get_transformer, apply_warp_map, compute_warp_map +from lazycogs._reproject import ( + ReprojectRequest, + ResamplingMethod, + _get_transformer, + reproject_tile, +) from lazycogs._store import resolve -# Resolutions chosen to bracket each overview boundary (20, 40, 80, 160 m). -# Includes native resolution, between-level values, and one well above all overviews. -_RESOLUTIONS = [ +_NEAREST_RESOLUTIONS = [ 10, 14, 19, @@ -56,10 +43,22 @@ 161, 300, ] - +_INTERPOLATING_RESOLUTIONS = [14, 21, 41, 81, 161] _CHUNK_SIZE = 64 _CENTER_UTM_X = 510_240.0 _CENTER_UTM_Y = 5_589_760.0 +_INTERPOLATING_METHODS = [ + pytest.param( + ResamplingMethod.BILINEAR, + rasterio.enums.Resampling.bilinear, + id="bilinear", + ), + pytest.param( + ResamplingMethod.CUBIC, + rasterio.enums.Resampling.cubic, + id="cubic", + ), +] def _chunk_affine(resolution: float, center_x: float, center_y: float) -> Affine: @@ -71,8 +70,11 @@ async def _read_lazycogs( href: str, chunk_affine: Affine, dst_crs: CRS, + *, + resampling: ResamplingMethod = ResamplingMethod.NEAREST, + backend: str | None = None, ) -> np.ndarray: - """Run the lazycogs read pipeline and return the output array.""" + """Run the lazycogs tile read + reprojection path for one chunk.""" store, path = resolve(href) geotiff = await GeoTIFF.open(path, store=store) src_crs = geotiff.crs @@ -109,15 +111,20 @@ async def _read_lazycogs( assert window is not None, "test chunk must overlap COG extent" raster = await reader.read(window=window) - warp_map = compute_warp_map( - raster.transform, - src_crs, - chunk_affine, - dst_crs, - _CHUNK_SIZE, - _CHUNK_SIZE, + return reproject_tile( + ReprojectRequest( + data=raster.data, + src_transform=raster.transform, + src_crs=src_crs, + dst_transform=chunk_affine, + dst_crs=dst_crs, + dst_width=_CHUNK_SIZE, + dst_height=_CHUNK_SIZE, + nodata=geotiff.nodata, + resampling=resampling, + ), + backend=backend, ) - return apply_warp_map(raster.data, warp_map, geotiff.nodata) def _odc_overview_level( @@ -143,10 +150,14 @@ def _read_rasterio( chunk_affine: Affine, dst_crs: CRS, native_res: float, + *, + resampling: rasterio.enums.Resampling, ) -> np.ndarray: - """Run rasterio nearest-neighbor reproject at the odc-stac-selected overview.""" + """Run rasterio/GDAL reproject at the odc-stac-selected overview.""" with rasterio.open(path) as src: src_crs_obj = CRS.from_user_input(src.crs.to_wkt()) + src_dtype = np.dtype(src.dtypes[0]) + src_nodata = src.nodata same_crs = dst_crs.equals(src_crs_obj) target_res_native = abs(chunk_affine.a) @@ -161,26 +172,19 @@ def _read_rasterio( ovr_level = _odc_overview_level(path, target_res_native, native_res) with rasterio.open(path, overview_level=ovr_level) as src: - out = np.zeros((1, _CHUNK_SIZE, _CHUNK_SIZE), dtype=np.float32) + out = np.zeros((1, _CHUNK_SIZE, _CHUNK_SIZE), dtype=src_dtype) rasterio.warp.reproject( source=rasterio.band(src, 1), destination=out, src_transform=src.transform, src_crs=src.crs, - dst_transform=Affine( - chunk_affine.a, - chunk_affine.b, - chunk_affine.c, - chunk_affine.d, - chunk_affine.e, - chunk_affine.f, - ), + dst_transform=chunk_affine, dst_crs=dst_crs.to_wkt(), - resampling=rasterio.enums.Resampling.nearest, - src_nodata=0, - dst_nodata=0, + resampling=resampling, + src_nodata=src_nodata, + dst_nodata=src_nodata, ) - return out.astype(np.uint16) + return out def _href(path: Path) -> str: @@ -191,79 +195,243 @@ def _assert_parity( lazycogs_out: np.ndarray, rasterio_out: np.ndarray, label: str, + *, max_differing_pixels: int = 0, max_abs_diff: int = 0, ) -> None: - """Assert that the two outputs are pixel-identical within the given tolerances. - - ``max_differing_pixels`` and ``max_abs_diff`` may both be nonzero only for - the cross-CRS test, where a handful of destination pixel centres can land - within floating-point precision of a source pixel boundary and lazycogs - (pyproj) and GDAL round to opposite sides. These boundary pixels never - indicate an overview selection error; they differ by at most one source - overview pixel's value. The tolerances here are deliberately tight so that - any systematic regression (wrong overview level, large pixel-mapping error) - still trips the assertion. - """ + """Assert parity with integer-valued outputs under explicit tolerances.""" diff = lazycogs_out.astype(np.int32) - rasterio_out.astype(np.int32) n_diff = int(np.count_nonzero(diff)) actual_max = int(np.abs(diff).max()) if n_diff else 0 msg = ( f"{label}: {n_diff}/{lazycogs_out.size} pixels differ " - f"(allowed ≤{max_differing_pixels}); " - f"max abs diff = {actual_max} (allowed ≤{max_abs_diff})" + f"(allowed ≤{max_differing_pixels}); max abs diff = {actual_max} " + f"(allowed ≤{max_abs_diff})" ) assert n_diff <= max_differing_pixels, msg assert actual_max <= max_abs_diff, msg -@pytest.mark.parametrize("resolution", _RESOLUTIONS) +def _assert_interpolating_reference( + lazycogs_out: np.ndarray, + rasterio_out: np.ndarray, + label: str, + *, + nodata: float, + atol: float, + max_nodata_mismatch: int = 0, + rtol: float = 1e-6, +) -> None: + """Assert float-valued interpolation output stays close to rasterio.""" + lazy_valid = ~np.isclose(lazycogs_out, nodata) + rasterio_valid = ~np.isclose(rasterio_out, nodata) + nodata_mismatch = int(np.count_nonzero(lazy_valid != rasterio_valid)) + assert nodata_mismatch <= max_nodata_mismatch, ( + f"{label} had {nodata_mismatch} nodata-mask mismatches " + f"(allowed ≤{max_nodata_mismatch})" + ) + + shared_valid = lazy_valid & rasterio_valid + assert np.any(shared_valid), f"{label} produced no overlapping valid pixels" + + try: + np.testing.assert_allclose( + lazycogs_out[shared_valid], + rasterio_out[shared_valid], + atol=atol, + rtol=rtol, + ) + except AssertionError as exc: + raise AssertionError( + f"{label} exceeded tolerance atol={atol}, rtol={rtol}", + ) from exc + + +@pytest.mark.parametrize("resolution", _NEAREST_RESOLUTIONS) def test_parity_same_crs(synthetic_cog: Path, resolution: int) -> None: - """lazycogs matches rasterio for same-CRS reads at all overview levels.""" + """rust-warp nearest matches rasterio for same-CRS reads at overview boundaries.""" dst_crs = CRS.from_epsg(32632) affine = _chunk_affine(resolution, _CENTER_UTM_X, _CENTER_UTM_Y) lc_out = asyncio.run(_read_lazycogs(_href(synthetic_cog), affine, dst_crs)) - rio_out = _read_rasterio(synthetic_cog, affine, dst_crs, native_res=10.0) + rio_out = _read_rasterio( + synthetic_cog, + affine, + dst_crs, + native_res=10.0, + resampling=rasterio.enums.Resampling.nearest, + ) + + _assert_parity(lc_out, rio_out, f"same_crs res={resolution}") + + +@pytest.mark.parametrize("resolution", _NEAREST_RESOLUTIONS) +def test_parity_cross_crs(synthetic_cog: Path, resolution: int) -> None: + """rust-warp nearest stays within tight rasterio tolerances cross-CRS.""" + src_crs = CRS.from_epsg(32632) + dst_crs = CRS.from_epsg(3035) + t = _get_transformer(src_crs, dst_crs) + cx_laea, cy_laea = t.transform(_CENTER_UTM_X, _CENTER_UTM_Y) + affine = _chunk_affine(float(resolution), cx_laea, cy_laea) + + lc_out = asyncio.run(_read_lazycogs(_href(synthetic_cog), affine, dst_crs)) + rio_out = _read_rasterio( + synthetic_cog, + affine, + dst_crs, + native_res=10.0, + resampling=rasterio.enums.Resampling.nearest, + ) _assert_parity( lc_out, rio_out, - f"same_crs res={resolution}", - max_differing_pixels=0, - max_abs_diff=0, + f"cross_crs res={resolution}m", + max_differing_pixels=3, + max_abs_diff=2048 * 16 + 1, ) -@pytest.mark.parametrize("resolution", _RESOLUTIONS) -def test_parity_cross_crs(synthetic_cog: Path, resolution: int) -> None: - """lazycogs matches rasterio/nearest for cross-CRS reads at all overview boundaries. +@pytest.mark.parametrize("resolution", [20, 60, 160]) +def test_nearest_legacy_matches_rust_warp_same_crs( + synthetic_cog: Path, + resolution: int, +) -> None: + """Migration-window A/B checks keep nearest same-CRS behavior aligned.""" + dst_crs = CRS.from_epsg(32632) + affine = _chunk_affine(float(resolution), _CENTER_UTM_X, _CENTER_UTM_Y) + + legacy_out = asyncio.run( + _read_lazycogs( + _href(synthetic_cog), + affine, + dst_crs, + backend="legacy", + ), + ) + rust_out = asyncio.run( + _read_lazycogs( + _href(synthetic_cog), + affine, + dst_crs, + backend="rust-warp", + ), + ) + + np.testing.assert_array_equal(rust_out, legacy_out) - Destination CRS is UTM zone 33N (EPSG:32633). The source COG is in zone - 32N, so a real coordinate transform is required, but both CRS share metric - units and a similar pixel scale. This avoids the large scale-ratio issues - of degree-based projections while still exercising the cross-CRS code path. - """ + +@pytest.mark.parametrize("resolution", [20, 60, 160]) +def test_nearest_legacy_matches_rust_warp_cross_crs( + synthetic_cog: Path, + resolution: int, +) -> None: + """Migration-window A/B checks keep nearest cross-CRS behavior aligned.""" src_crs = CRS.from_epsg(32632) - dst_crs = CRS.from_epsg( - 3035, - ) # ETRS89 / LAEA Europe — same-unit, low distortion near COG + dst_crs = CRS.from_epsg(3035) t = _get_transformer(src_crs, dst_crs) cx_laea, cy_laea = t.transform(_CENTER_UTM_X, _CENTER_UTM_Y) affine = _chunk_affine(float(resolution), cx_laea, cy_laea) - lc_out = asyncio.run(_read_lazycogs(_href(synthetic_cog), affine, dst_crs)) - rio_out = _read_rasterio(synthetic_cog, affine, dst_crs, native_res=10.0) + legacy_out = asyncio.run( + _read_lazycogs( + _href(synthetic_cog), + affine, + dst_crs, + backend="legacy", + ), + ) + rust_out = asyncio.run( + _read_lazycogs( + _href(synthetic_cog), + affine, + dst_crs, + backend="rust-warp", + ), + ) - # Allow ≤ 3 pixels to differ by at most 1 overview-row's worth of value. - # These are floating-point boundary pixels where the destination centre lands - # within a ULP of a source pixel edge and pyproj/GDAL round to opposite sides. - # A systematic regression (wrong overview level, large pixel mapping error) - # would produce far more differing pixels or a much larger diff. _assert_parity( + rust_out, + legacy_out, + f"legacy_vs_rust cross_crs res={resolution}m", + max_differing_pixels=3, + max_abs_diff=2048 * 16 + 1, + ) + + +@pytest.mark.parametrize(("resampling", "rasterio_resampling"), _INTERPOLATING_METHODS) +@pytest.mark.parametrize("resolution", _INTERPOLATING_RESOLUTIONS) +def test_interpolating_parity_same_crs( + continuous_synthetic_cog: Path, + resampling: ResamplingMethod, + rasterio_resampling: rasterio.enums.Resampling, + resolution: int, +) -> None: + """Interpolating kernels stay close to rasterio on smooth same-CRS data.""" + dst_crs = CRS.from_epsg(32632) + affine = _chunk_affine(float(resolution), _CENTER_UTM_X, _CENTER_UTM_Y) + + lc_out = asyncio.run( + _read_lazycogs( + _href(continuous_synthetic_cog), + affine, + dst_crs, + resampling=resampling, + ), + ) + rio_out = _read_rasterio( + continuous_synthetic_cog, + affine, + dst_crs, + native_res=10.0, + resampling=rasterio_resampling, + ) + + atol = 1e-3 if resampling is ResamplingMethod.BILINEAR else 2e-1 + _assert_interpolating_reference( lc_out, rio_out, - f"cross_crs res={resolution}m", - max_differing_pixels=3, - max_abs_diff=2048 * 16 + 1, # 1 row in the coarsest (16x) overview + f"same_crs {resampling} res={resolution}", + nodata=-9999.0, + atol=atol, + ) + + +@pytest.mark.parametrize(("resampling", "rasterio_resampling"), _INTERPOLATING_METHODS) +def test_interpolating_parity_cross_crs( + continuous_synthetic_cog: Path, + resampling: ResamplingMethod, + rasterio_resampling: rasterio.enums.Resampling, +) -> None: + """Interpolating kernels stay close to rasterio on smooth cross-CRS data.""" + src_crs = CRS.from_epsg(32632) + dst_crs = CRS.from_epsg(3035) + t = _get_transformer(src_crs, dst_crs) + cx_laea, cy_laea = t.transform(_CENTER_UTM_X, _CENTER_UTM_Y) + affine = _chunk_affine(60.0, cx_laea, cy_laea) + + lc_out = asyncio.run( + _read_lazycogs( + _href(continuous_synthetic_cog), + affine, + dst_crs, + resampling=resampling, + ), + ) + rio_out = _read_rasterio( + continuous_synthetic_cog, + affine, + dst_crs, + native_res=10.0, + resampling=rasterio_resampling, + ) + + atol = 1e-2 if resampling is ResamplingMethod.BILINEAR else 2e-1 + _assert_interpolating_reference( + lc_out, + rio_out, + f"cross_crs {resampling}", + nodata=-9999.0, + atol=atol, + max_nodata_mismatch=24, ) From 11c8dd47e58973783e02051f96bdce61bcedfc3a Mon Sep 17 00:00:00 2001 From: hrodmn Date: Thu, 14 May 2026 04:36:53 -0500 Subject: [PATCH 6/7] refactor: consolidate all reprojection code into _warp.py --- ARCHITECTURE.md | 30 +-- README.md | 9 +- src/lazycogs/__init__.py | 2 +- src/lazycogs/_backend.py | 26 +- src/lazycogs/_chunk_reader.py | 61 ++--- src/lazycogs/_core.py | 38 +-- src/lazycogs/_executor.py | 7 +- src/lazycogs/_explain.py | 11 +- src/lazycogs/_reproject.py | 325 ----------------------- src/lazycogs/{_rust_warp.py => _warp.py} | 85 +++++- tests/benchmarks/bench_pipeline.py | 14 +- tests/integration_test.py | 65 ++--- tests/test_backend.py | 7 +- tests/test_chunk_reader.py | 169 ++---------- tests/test_core.py | 8 +- tests/test_explain.py | 70 ++++- tests/test_rasterio_parity.py | 73 +---- tests/test_reproject.py | 250 ++++++----------- 18 files changed, 360 insertions(+), 890 deletions(-) delete mode 100644 src/lazycogs/_reproject.py rename src/lazycogs/{_rust_warp.py => _warp.py} (64%) diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 7d6cc55..c50f17f 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -1,6 +1,6 @@ # Architecture: lazycogs -lazycogs turns a geoparquet STAC item index into a lazy `(band, time, y, x)` xarray DataArray backed by Cloud-Optimized GeoTIFFs. It requires no GDAL. All raster I/O is done through `async-geotiff` (Rust-backed), spatial queries go through DuckDB via `rustac`, and reprojection now flows through a backend-neutral dispatcher in `_reproject.py`. Public reprojection methods currently route through the rust-warp adapter, while the legacy pyproj + numpy path remains available only as internal migration scaffolding. +lazycogs turns a geoparquet STAC item index into a lazy `(band, time, y, x)` xarray DataArray backed by Cloud-Optimized GeoTIFFs. It requires no GDAL. All raster I/O is done through `async-geotiff` (Rust-backed), spatial queries go through DuckDB via `rustac`, and reprojection flows through a small dispatcher in `_warp.py` that routes all public resampling modes to rust-warp. ## Why parquet, not a STAC API URL @@ -30,7 +30,7 @@ src/lazycogs/ _executor.py Per-chunk reprojection thread pool configuration. Exposes set_reproject_workers() and get_max_workers(); the actual pool is created per event loop in _backend.py. _explain.py Dry-run read estimator. Registers the da.lazycogs.explain() xarray accessor. _grid.py Compute output affine transform and dimensions from bbox + resolution. - _reproject.py Backend-neutral reprojection dispatcher; legacy nearest-neighbor backend still uses pyproj Transformer + numpy fancy indexing. + _warp.py Reprojection request types and rust-warp dispatch. _storage_ext.py STAC Storage Extension metadata parsing (version detection, kwargs extraction for v1 and v2). _store.py Resolve cloud HREFs into obstore Store instances (or route through a user-supplied store) with a thread-local cache; store_for() factory for constructing stores from parquet STAC files. _temporal.py Temporal grouping strategies (day, week, month, year, fixed-day-count). @@ -127,19 +127,9 @@ If the chunk bbox falls entirely outside the source image after clamping, `_nati ### 4. Reprojection dispatch and current backend state -`_chunk_reader.py` now builds a `ReprojectRequest` and calls `reproject_tile()` in `_reproject.py` rather than reaching directly into warp-map helpers. That gives lazycogs a clean seam for swapping reprojection engines while keeping chunk orchestration unchanged. +`_chunk_reader.py` builds a `ReprojectRequest` and calls `reproject_tile()` in `_warp.py`. The dispatcher always short-circuits exact same-grid reads. After that, all public reprojection requests route through rust-warp for `nearest`, `bilinear`, and `cubic`. -The dispatcher always short-circuits exact same-grid reads. After that, public reprojection requests currently route to rust-warp for `nearest`, `bilinear`, and `cubic` alike. - -The legacy nearest-neighbor path is still present as internal migration scaffolding and still warps the source tile onto the destination chunk grid without GDAL: - -1. Build a meshgrid of destination pixel-centre coordinates. -2. Transform all coordinates from `dst_crs` to `src_crs` in one vectorised `Transformer.transform()` call. -3. Apply the inverse source affine transform to get fractional source pixel indices. -4. `np.floor` rounds to the nearest-neighbor sample; numpy fancy indexing populates the output array. -5. Out-of-bounds pixels get the nodata fill value. - -Public resampling is validated at `open()` time. The currently supported values are `nearest`, `bilinear`, and `cubic`. Same-grid reads bypass reprojection regardless of the selected resampling mode. +Public resampling is validated at `open()` time and must be passed as a `ResamplingMethod` enum. Same-grid reads bypass reprojection regardless of the selected resampling mode. ## Concurrency model @@ -147,15 +137,15 @@ There are four nested layers of concurrency in a chunk read. **Dask (chunk level).** When a dask-backed DataArray is computed, dask dispatches each chunk task to a worker thread. Each worker thread calls `_sync_getitem()` in `_backend.py`, which calls `_run_coroutine(_async_getitem(...))` to drive all time steps from the persistent per-thread background loop. Worker threads are independent — they share no state except the thread-local store cache in `_store.py` and the thread-local DuckDB clients in `_backend.py`. -**asyncio (time + item level).** A single event loop call (via `_run_coroutine`) handles the entire chunk. Inside `_async_getitem`, `asyncio.gather` fans out one `_run_one_date` coroutine per time step, so all time steps are in flight concurrently within the same event loop. DuckDB queries run on `_DUCKDB_EXECUTOR` (a dedicated two-thread pool) via `run_in_executor`, yielding the event loop during each query. DuckDB's internal mutex serialises actual DB access, so queries are safe but not parallel on a single `DuckdbClient`. Once a query returns, its mosaic coroutine proceeds immediately and COG reads for all time steps overlap in the event loop's I/O layer. Because all time steps share a single event loop and therefore a single bounded reprojection executor, the reprojection thread count stays at `get_max_workers()` regardless of how many time steps are in the chunk (no thread explosion). The `warp_cache` is shared across coroutines: `compute_warp_map` is deterministic, so concurrent writes are safe. +**asyncio (time + item level).** A single event loop call (via `_run_coroutine`) handles the entire chunk. Inside `_async_getitem`, `asyncio.gather` fans out one `_run_one_date` coroutine per time step, so all time steps are in flight concurrently within the same event loop. DuckDB queries run on `_DUCKDB_EXECUTOR` (a dedicated two-thread pool) via `run_in_executor`, yielding the event loop during each query. DuckDB's internal mutex serialises actual DB access, so queries are safe but not parallel on a single `DuckdbClient`. Once a query returns, its mosaic coroutine proceeds immediately and COG reads for all time steps overlap in the event loop's I/O layer. Because all time steps share a single event loop and therefore a single bounded reprojection executor, the reprojection thread count stays at `get_max_workers()` regardless of how many time steps are in the chunk (no thread explosion). **asyncio (item level).** Inside a time step's event loop, `read_chunk_async` launches one `_read_item_band()` task per overlapping item up front, with an `asyncio.Semaphore(max_concurrent_reads)` (configurable via `open()`, default 32) capping how many reads run concurrently. Tasks complete in I/O arrival order, but results are buffered by their original list index and drained into the mosaic in source-list order. This preserves the sort contract for `FirstMethod` — items are fed strictly in the order returned by DuckDB (i.e. the caller's `sortby` order) regardless of which COGs arrive first over the network, while all concurrent I/O remains in flight. COG header reads and tile fetches from `async-geotiff` are all awaitable, so the event loop multiplexes them without blocking. Early exit is preserved: once the mosaic method signals completion, remaining tasks are cancelled in a `finally` block, and items still waiting on the semaphore never start. -**Thread pool (CPU work per item).** `_apply_bands_with_warp_cache` is synchronous CPU-bound work that processes all bands for one item together. `_read_item_band` dispatches it via `loop.run_in_executor(None, ...)` — one executor call per item — so the event loop stays free to process other items' tile reads while reprojections run on threads. Because the call is coarse-grained (all bands per item) and GIL-releasing (`pyproj` and numpy both release during heavy inner loops), offloading to the thread pool gives real CPU parallelism without excessive submission overhead. +**Thread pool (CPU work per item).** `_reproject_bands` is synchronous CPU-bound work that processes all bands for one item together. `_read_item_band` dispatches it via `loop.run_in_executor(None, ...)` — one executor call per item — so the event loop stays free to process other items' tile reads while reprojections run on threads. Because the call is coarse-grained and rust-warp does the heavy inner loops off the event loop, offloading to the thread pool gives real CPU parallelism without excessive submission overhead. **Why threads, not a process pool.** `pyproj.Transformer.transform()` and numpy's fancy-indexing both release the GIL during their heavy inner loops. Threads therefore give real CPU parallelism here — not just interleaving — without the overhead of process spawning and array pickling that a `ProcessPoolExecutor` would require. -**Why reprojection is memory-bandwidth-bound, not compute-bound.** `compute_warp_map` builds two meshgrids the size of the output chunk, transforms all coordinates in one vectorised call, and produces large index arrays. `apply_warp_map` samples the source array with random-access fancy indexing (`out[:, valid] = data[:, row_idx[valid], col_idx[valid]]`), which produces near-constant cache misses. Both phases are dominated by memory latency and bandwidth rather than arithmetic. In practice this means CPU utilisation is low (threads stall waiting for memory), and adding more than 4 concurrent reprojection threads provides no throughput benefit — they saturate the memory bus instead. +**Why reprojection is still memory-bandwidth-sensitive.** Even with rust-warp handling the reprojection kernel, chunk reprojection moves large raster windows through memory and writes full destination tiles. In practice this means throughput still stops improving after a small number of concurrent reprojection threads, so the default executor cap remains conservative. **Bounded per-loop executor.** Rather than using Python's default `min(32, cpu_count + 4)` thread count, `_get_or_create_background_loop()` installs a bounded `ThreadPoolExecutor` (default `min(os.cpu_count(), 4)`) as the default executor on each background loop it creates, before any coroutines run. This caps thread count per loop while preserving per-loop isolation: each dask task has its own independent pool and does not queue behind other tasks. The executor is shut down when the background loop thread exits. Call `lazycogs.set_reproject_workers(n)` to change the per-loop bound (see `_executor.py`). @@ -225,7 +215,7 @@ These are copied from `rio-tiler` (MIT licence, zero GDAL imports) to avoid pull Combining `time_period` with a mosaic method is the idiomatic way to produce temporal composites. Setting `time_period="P1W"` groups every STAC item within the same ISO calendar week into a single time step. When a chunk is read, -`async_mosaic_chunk` feeds all items for that week to the mosaic method in +`read_chunk_async()` feeds all items for that week to the mosaic method in order. With `FirstMethod` (the default), reading stops as soon as every output pixel has a valid (non-nodata) value — the remaining items in the week are never fetched. @@ -268,8 +258,8 @@ When the store root does not align with the URL structure of the asset HREFs — | `arro3-core` | Zero-copy Arrow table output from DuckDB queries (installed via `rustac[arrow]`) | | `async-geotiff` | Async COG header reads and windowed tile reads (Rust, no GDAL) | | `obstore` | Cloud object store abstraction layer for async-geotiff | -| `rust-warp` | Experimental reprojection backend dependency, currently sourced from GitHub during integration work | -| `pyproj` | CRS transforms: bbox reprojection, target-resolution estimation, legacy warp backend | +| `rust-warp` | Reprojection backend, currently sourced from GitHub | +| `pyproj` | CRS transforms: bbox reprojection and target-resolution estimation | | `xarray` | DataArray / Dataset assembly, `BackendArray` / `LazilyIndexedArray` protocol | | `rasterix` | CRS-aware `RasterIndex` for lazy spatial coordinates | | `xproj` | CRS accessor and alignment for xarray Flexible Indexes | diff --git a/README.md b/README.md index e17a050..596efea 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ One constraint worth naming: lazycogs only reads Cloud Optimized GeoTIFFs. If yo | STAC search + spatial indexing | `rustac` (DuckDB + geoparquet) | | COG I/O | `async-geotiff` (Rust, no GDAL) | | Cloud storage | `obstore` | -| Reprojection | backend-neutral seam in `lazycogs`; all public resampling methods currently route through `rust-warp` | +| Reprojection | `rust-warp` via `lazycogs.ResamplingMethod` | | Lazy dataset construction | xarray `BackendEntrypoint` + `LazilyIndexedArray` | ## Installation @@ -33,7 +33,7 @@ One constraint worth naming: lazycogs only reads Cloud Optimized GeoTIFFs. If yo pip install lazycogs ``` -Current development work also pins `rust-warp` from GitHub via uv while the reprojection migration is in progress. +`rust-warp` is currently pinned from GitHub via uv pending a stable release flow for this dependency. ## Coordinate convention @@ -75,11 +75,10 @@ da = lazycogs.open( bbox=dst_bbox, crs=dst_crs, resolution=10.0, - resampling="nearest", # also supports "bilinear" and "cubic" + resampling=lazycogs.ResamplingMethod.NEAREST, ) -# Or use the enum if you prefer: -# lazycogs.open(..., resampling=lazycogs.ResamplingMethod.CUBIC) +# Other supported modes: ResamplingMethod.BILINEAR and ResamplingMethod.CUBIC ``` ### Async loading diff --git a/src/lazycogs/__init__.py b/src/lazycogs/__init__.py index 41d4777..87098ae 100644 --- a/src/lazycogs/__init__.py +++ b/src/lazycogs/__init__.py @@ -21,8 +21,8 @@ MosaicMethodBase, StdevMethod, ) -from lazycogs._reproject import ResamplingMethod from lazycogs._store import store_for +from lazycogs._warp import ResamplingMethod __all__ = [ "DEFAULT_RESAMPLING", diff --git a/src/lazycogs/_backend.py b/src/lazycogs/_backend.py index 02388b6..48e4bcb 100644 --- a/src/lazycogs/_backend.py +++ b/src/lazycogs/_backend.py @@ -20,7 +20,7 @@ _DUCKDB_EXECUTOR, _run_coroutine, ) -from lazycogs._reproject import ResamplingMethod +from lazycogs._warp import ResamplingMethod logger = logging.getLogger(__name__) @@ -28,6 +28,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from obstore.store import ObjectStore from rustac import DuckdbClient from lazycogs._mosaic_methods import MosaicMethodBase @@ -41,11 +42,6 @@ class _ChunkReadPlan: ``_read_chunk_all_dates`` and ``_run_one_date``. Frozen to make the read-only intent explicit. - Note: ``warp_cache`` is a mutable dict despite the frozen dataclass. This - is intentional — concurrent writes from ``asyncio.gather`` coroutines are - safe because ``compute_warp_map`` is deterministic (a duplicate write - simply overwrites an identical value). - Attributes: duckdb_client: ``DuckdbClient`` instance used for STAC queries. parquet_path: Path to the geoparquet file or hive-partitioned directory. @@ -65,7 +61,6 @@ class _ChunkReadPlan: resampling: Reprojection resampling method for this chunk. store: Pre-configured obstore ``ObjectStore`` instance, or ``None``. max_concurrent_reads: Maximum concurrent COG reads per chunk. - warp_cache: Shared warp map cache across time steps. path_fn: Optional callable extracting an object path from an asset HREF. """ @@ -86,9 +81,8 @@ class _ChunkReadPlan: nodata: float | None mosaic_method_cls: type[MosaicMethodBase] | None resampling: ResamplingMethod - store: Any | None + store: ObjectStore | None max_concurrent_reads: int - warp_cache: dict path_fn: Callable[[str], str] | None @@ -265,7 +259,6 @@ async def _run_one_date( resampling=plan.resampling, store=plan.store, max_concurrent_reads=plan.max_concurrent_reads, - warp_cache=plan.warp_cache, path_fn=plan.path_fn, ) logger.debug( @@ -310,9 +303,9 @@ class MultiBandStacBackendArray(BackendArray): One instance is created at ``open()`` time. No pixel I/O happens until ``__getitem__`` is called inside a dask task. Reads all selected bands together per time step via - :func:`~lazycogs._chunk_reader.async_mosaic_chunk`, issuing a - single DuckDB query per time step and sharing reprojection warp maps across - bands that have identical source geometry. + :func:`~lazycogs._chunk_reader.read_chunk_async`, issuing a + single DuckDB query per time step and reprojecting all selected bands for + an item onto the same destination chunk. Attributes: parquet_path: Path to the geoparquet file or hive-partitioned directory @@ -375,7 +368,7 @@ class MultiBandStacBackendArray(BackendArray): nodata: float | None mosaic_method_cls: type[MosaicMethodBase] | None = field(default=None) resampling: ResamplingMethod = field(default=ResamplingMethod.NEAREST) - store: Any | None = field(default=None) + store: ObjectStore | None = field(default=None) max_concurrent_reads: int = field(default=32) path_from_href: Callable[[str], str] | None = field(default=None) shape: tuple[int, ...] = field(init=False) @@ -527,8 +520,8 @@ async def _async_getitem(self, key: tuple[Any, ...]) -> np.ndarray: Single source of truth for chunk reads. Reads all selected bands together per time step via :func:`~lazycogs._chunk_reader.read_chunk_async`, issuing a single - DuckDB query per time step and sharing reprojection warp maps across - bands that have identical source geometry. + DuckDB query per time step and reprojecting all selected bands for an + item onto the same destination chunk. Args: key: A tuple of ``int | slice`` objects for the @@ -568,7 +561,6 @@ async def _async_getitem(self, key: tuple[Any, ...]) -> np.ndarray: resampling=self.resampling, store=self.store, max_concurrent_reads=self.max_concurrent_reads, - warp_cache={}, path_fn=self.path_from_href, ) diff --git a/src/lazycogs/_chunk_reader.py b/src/lazycogs/_chunk_reader.py index 0d4a46c..541fc02 100644 --- a/src/lazycogs/_chunk_reader.py +++ b/src/lazycogs/_chunk_reader.py @@ -14,8 +14,13 @@ from lazycogs._executor import _run_coroutine from lazycogs._mosaic_methods import FirstMethod, MosaicMethodBase -from lazycogs._reproject import ReprojectRequest, _get_transformer, reproject_tile from lazycogs._store import resolve as _resolve_store +from lazycogs._warp import ( + ReprojectRequest, + ResamplingMethod, + _get_transformer, + reproject_tile, +) if TYPE_CHECKING: from collections.abc import Callable @@ -31,7 +36,7 @@ class _ChunkContext: """Immutable per-chunk parameters shared across all item reads. - Built once per chunk in async_mosaic_chunk and passed through to all + Built once per chunk in ``read_chunk_async`` and passed through to all internal helpers. Frozen to prevent accidental mutation across concurrent coroutines. """ @@ -41,10 +46,9 @@ class _ChunkContext: chunk_width: int chunk_height: int nodata: float | None - resampling: str + resampling: ResamplingMethod store: ObjectStore | None path_fn: Callable[[str], str] | None - warp_cache: dict[object, object] | None def _log_batch_failure( @@ -265,20 +269,15 @@ async def _open_and_window( return geotiff, reader, window, path -def _apply_bands_with_warp_cache( +def _reproject_bands( band_rasters: list[tuple[str, RasterArray, CRS, float | None]], dst_transform: Affine, dst_crs: CRS, dst_width: int, dst_height: int, - resampling: str = "nearest", - warp_cache: dict[object, object] | None = None, + resampling: ResamplingMethod = ResamplingMethod.NEAREST, ) -> dict[str, tuple[np.ndarray, float | None]]: - """Reproject multiple band rasters through the backend-neutral interface. - - The optional ``warp_cache`` is currently forwarded to the legacy backend so - repeated source geometries can still reuse precomputed mappings during the - migration away from warp-map-specific call sites. + """Reproject multiple band rasters onto one destination grid. Args: band_rasters: List of ``(band_name, raster, src_crs, effective_nodata)`` @@ -289,25 +288,14 @@ def _apply_bands_with_warp_cache( dst_width: Width of the destination grid in pixels. dst_height: Height of the destination grid in pixels. resampling: Reprojection resampling method. - warp_cache: Optional migration-time cache shared across calls. Returns: ``dict`` mapping band name to ``(reprojected_array, effective_nodata)``. """ - cache = warp_cache if warp_cache is not None else {} results: dict[str, tuple[np.ndarray, float | None]] = {} for band, raster, src_crs, effective_nodata in band_rasters: - if ( - src_crs.equals(dst_crs) - and raster.transform == dst_transform - and raster.data.shape[1] == dst_height - and raster.data.shape[2] == dst_width - ): - results[band] = (raster.data, effective_nodata) - continue - results[band] = ( reproject_tile( ReprojectRequest( @@ -321,7 +309,6 @@ def _apply_bands_with_warp_cache( nodata=effective_nodata, resampling=resampling, ), - warp_cache=cache, ), effective_nodata, ) @@ -334,13 +321,12 @@ async def _read_item_band( bands: list[str], ctx: _ChunkContext, ) -> dict[str, tuple[np.ndarray, float | None]] | None: - """Read and reproject multiple bands from one STAC item, sharing warp maps. + """Read and reproject multiple bands from one STAC item. Opens all band COGs concurrently, computes per-band windows independently (so bands with different native resolutions or extents are handled correctly), reads all windows concurrently, then dispatches a single thread-executor call - that applies warp maps with caching: bands sharing the same source CRS and - window transform reuse the same warp map. + that reprojects all bands onto the destination grid. Args: item: STAC item dict containing an ``assets`` key. @@ -428,18 +414,16 @@ async def _read_band( for band, raster in read_results ] - # Compute warp maps and apply, sharing maps across bands with identical geometry. loop = asyncio.get_running_loop() return await loop.run_in_executor( None, - lambda: _apply_bands_with_warp_cache( + lambda: _reproject_bands( band_rasters, ctx.chunk_affine, ctx.dst_crs, ctx.chunk_width, ctx.chunk_height, ctx.resampling, - ctx.warp_cache, ), ) @@ -510,17 +494,15 @@ async def read_chunk_async( chunk_height: int, nodata: float | None = None, mosaic_method_cls: type[MosaicMethodBase] | None = None, - resampling: str = "nearest", + resampling: ResamplingMethod = ResamplingMethod.NEAREST, store: ObjectStore | None = None, max_concurrent_reads: int = 32, - warp_cache: dict | None = None, path_fn: Callable[[str], str] | None = None, ) -> dict[str, np.ndarray]: """Read, reproject, and mosaic multiple bands from a list of STAC items. - Processes all requested bands together per item so that bands sharing the - same source geometry compute the reprojection warp map only once (via - :func:`_apply_bands_with_warp_cache`). + Processes all requested bands together per item so they can be reprojected + in one thread-executor call. Items are processed in batches of ``max_concurrent_reads``. When all per-band mosaic methods signal completion, remaining batches are skipped. @@ -538,8 +520,6 @@ async def read_chunk_async( resampling: Reprojection resampling method. store: Optional pre-configured obstore ``ObjectStore`` instance. max_concurrent_reads: Maximum number of COG reads to run concurrently. - warp_cache: Optional cache shared across calls for reusing warp maps - from earlier time steps. path_fn: Optional callable that takes an asset HREF and returns the object path to use with *store*. Forwarded to :func:`_read_item_band`. @@ -562,7 +542,6 @@ async def read_chunk_async( resampling=resampling, store=store, path_fn=path_fn, - warp_cache=warp_cache, ) semaphore = asyncio.Semaphore(max_concurrent_reads) @@ -619,10 +598,9 @@ def read_chunk( chunk_height: int, nodata: float | None = None, mosaic_method_cls: type[MosaicMethodBase] | None = None, - resampling: str = "nearest", + resampling: ResamplingMethod = ResamplingMethod.NEAREST, store: ObjectStore | None = None, max_concurrent_reads: int = 32, - warp_cache: dict | None = None, path_fn: Callable[[str], str] | None = None, ) -> dict[str, np.ndarray]: """Run :func:`read_chunk_async` on the persistent per-thread background loop. @@ -642,8 +620,6 @@ def read_chunk( resampling: Reprojection resampling method. store: Optional pre-configured obstore ``ObjectStore`` instance. max_concurrent_reads: Maximum number of COG reads to run concurrently. - warp_cache: Optional cache shared across calls for reusing warp maps - from earlier time steps. path_fn: Optional callable that takes an asset HREF and returns the object path to use with *store*. @@ -666,7 +642,6 @@ def read_chunk( resampling=resampling, store=store, max_concurrent_reads=max_concurrent_reads, - warp_cache=warp_cache, path_fn=path_fn, ), ) diff --git a/src/lazycogs/_core.py b/src/lazycogs/_core.py index c368169..d5ea17e 100644 --- a/src/lazycogs/_core.py +++ b/src/lazycogs/_core.py @@ -17,9 +17,9 @@ from lazycogs._cql2 import _extract_filter_fields, _sortby_fields from lazycogs._grid import compute_output_grid from lazycogs._mosaic_methods import FirstMethod, MosaicMethodBase -from lazycogs._reproject import ResamplingMethod from lazycogs._store import resolve from lazycogs._temporal import _TemporalGrouper, grouper_from_period +from lazycogs._warp import ResamplingMethod if TYPE_CHECKING: from collections.abc import Callable @@ -255,26 +255,28 @@ def _build_time_steps( return filter_strings, time_coords -def _validate_resampling(resampling: str | ResamplingMethod) -> ResamplingMethod: - """Return ``resampling`` as a supported enum, else raise ``ValueError``. +def _validate_resampling(resampling: ResamplingMethod) -> ResamplingMethod: + """Return ``resampling`` when it is a supported enum value. Args: - resampling: User-provided resampling method name. + resampling: User-provided resampling enum value. Returns: Normalized resampling enum value. Raises: - ValueError: If *resampling* is not currently supported. + TypeError: If *resampling* is not a ``ResamplingMethod`` value. """ - try: - return ResamplingMethod(resampling) - except ValueError as exc: - supported = ", ".join(SUPPORTED_RESAMPLING) - raise ValueError( - f"Unsupported resampling {resampling!r}. Supported values: {supported}.", - ) from exc + if isinstance(resampling, ResamplingMethod): + return resampling + + supported = ", ".join(SUPPORTED_RESAMPLING) + raise TypeError( + "resampling must be a ResamplingMethod value. " + f"Got {type(resampling).__name__!r}: {resampling!r}. " + f"Supported values: {supported}.", + ) def _build_dataarray( @@ -302,8 +304,7 @@ def _build_dataarray( ) -> DataArray: """Assemble the lazy DataArray from pre-computed parameters. - This is the shared implementation used by both :func:`open` and - the STAC search completes. + Internal helper used by :func:`open` after the STAC search completes. Args: parquet_path: Path to a geoparquet file or hive-partitioned directory. @@ -468,7 +469,7 @@ def open( # noqa: A001 nodata: float | None = None, dtype: str | np.dtype | None = None, mosaic_method: type[MosaicMethodBase] | None = None, - resampling: str | ResamplingMethod = DEFAULT_RESAMPLING, + resampling: ResamplingMethod = DEFAULT_RESAMPLING, time_period: str = "P1D", store: ObjectStore | None = None, max_concurrent_reads: int = 32, @@ -507,9 +508,10 @@ def open( # noqa: A001 mosaic_method: Mosaic method class (not instance) to use. Defaults to :class:`~lazycogs._mosaic_methods.FirstMethod`. resampling: Reprojection resampling method. Supported values are - ``"nearest"`` (default), ``"bilinear"``, and ``"cubic"``. - Validation happens at open time so unsupported values fail before - any chunk reads begin. + :attr:`~lazycogs.ResamplingMethod.NEAREST` (default), + :attr:`~lazycogs.ResamplingMethod.BILINEAR`, and + :attr:`~lazycogs.ResamplingMethod.CUBIC`. Validation happens at + open time so unsupported values fail before any chunk reads begin. time_period: ISO 8601 duration string controlling how items are grouped into time steps. Supported forms: ``PnD`` (days), ``P1W`` (ISO calendar week), ``P1M`` (calendar month), ``P1Y`` diff --git a/src/lazycogs/_executor.py b/src/lazycogs/_executor.py index 07e3c4b..9040549 100644 --- a/src/lazycogs/_executor.py +++ b/src/lazycogs/_executor.py @@ -9,8 +9,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from asyncio.futures import Future - from types import CoroutineType + from collections.abc import Coroutine config: dict[str, int | None] = {"max_workers": None} @@ -28,7 +27,7 @@ def _default_workers() -> int: """Return the default worker count: CPUs up to a cap of 4. - Reprojection (pyproj + numpy) is memory-bandwidth-bound, not compute-bound. + Reprojection is memory-bandwidth-bound, not compute-bound. Benchmarks show diminishing returns beyond 4 concurrent threads because they saturate the memory bus rather than adding CPU throughput. Keep the default conservative. @@ -109,7 +108,7 @@ def _get_or_create_background_loop() -> asyncio.AbstractEventLoop: return loop -def _run_coroutine(coro: CoroutineType) -> Future: +def _run_coroutine[T](coro: Coroutine[object, object, T]) -> T: """Run an async coroutine from sync code. Submits the coroutine to a persistent per-thread background event loop, diff --git a/src/lazycogs/_explain.py b/src/lazycogs/_explain.py index 217a300..6be2866 100644 --- a/src/lazycogs/_explain.py +++ b/src/lazycogs/_explain.py @@ -16,9 +16,10 @@ from lazycogs._chunk_reader import _ChunkContext, _open_and_window from lazycogs._executor import _DUCKDB_EXECUTOR, _run_coroutine +from lazycogs._warp import ResamplingMethod if TYPE_CHECKING: - from collections.abc import Iterator + from collections.abc import Callable, Iterator from obstore.store import ObjectStore @@ -448,6 +449,7 @@ async def _inspect_item_async( chunk_width: int, chunk_height: int, store: ObjectStore | None = None, + path_fn: Callable[[str], str] | None = None, ) -> CogRead | None: """Open a COG header and compute the overview level and read window. @@ -461,6 +463,8 @@ async def _inspect_item_async( chunk_width: Chunk width in pixels. chunk_height: Chunk height in pixels. store: Optional pre-configured obstore ``ObjectStore``. + path_fn: Optional callable that maps an asset HREF to the object path + used with ``store``. Returns: A :class:`CogRead` with all header fields populated, or ``None`` if @@ -473,9 +477,9 @@ async def _inspect_item_async( chunk_width=chunk_width, chunk_height=chunk_height, nodata=None, + resampling=ResamplingMethod.NEAREST, store=store, - path_fn=None, - warp_cache=None, + path_fn=path_fn, ) opened = await _open_and_window(item, band, ctx) if opened is None: @@ -634,6 +638,7 @@ async def _explain_one_tile( actual_w, actual_h, backend.store, + backend.path_from_href, ) for item in items ], diff --git a/src/lazycogs/_reproject.py b/src/lazycogs/_reproject.py deleted file mode 100644 index 77fe87b..0000000 --- a/src/lazycogs/_reproject.py +++ /dev/null @@ -1,325 +0,0 @@ -"""Reproject raster arrays using a backend-neutral request interface.""" - -from __future__ import annotations - -import functools -from dataclasses import dataclass -from enum import StrEnum -from typing import TYPE_CHECKING, Literal - -import numpy as np -from pyproj import CRS, Transformer - -from lazycogs._rust_warp import reproject_array_rust_warp - -if TYPE_CHECKING: - from affine import Affine - - -class ResamplingMethod(StrEnum): - """Supported public reprojection resampling methods.""" - - NEAREST = "nearest" - BILINEAR = "bilinear" - CUBIC = "cubic" - - -@functools.lru_cache(maxsize=256) -def _get_transformer(src_crs: CRS, dst_crs: CRS) -> Transformer: - """Return a cached ``Transformer`` for a CRS pair. - - ``Transformer.from_crs`` involves PROJ database lookups and pipeline - initialisation. The same (src_crs, dst_crs) pair recurs for every item - in a collection, so caching avoids recreating the same object hundreds of - times per chunk read. ``pyproj.CRS`` is hashable via its WKT - representation, and ``Transformer`` is thread-safe from PROJ 6+. - - Args: - src_crs: Source CRS. - dst_crs: Destination CRS. - - Returns: - A ``Transformer`` that maps ``src_crs`` → ``dst_crs``. - - """ - return Transformer.from_crs(src_crs, dst_crs, always_xy=True) - - -@dataclass(frozen=True) -class ReprojectRequest: - """All inputs required to reproject one source tile. - - Attributes: - data: Source data with shape ``(bands, src_h, src_w)``. - src_transform: Affine transform of the source array. - src_crs: CRS of the source array. - dst_transform: Affine transform of the destination grid. - dst_crs: CRS of the destination grid. - dst_width: Destination width in pixels. - dst_height: Destination height in pixels. - nodata: Fill value for pixels that fall outside the source extent. - resampling: Requested resampling method. - - """ - - data: np.ndarray - src_transform: Affine - src_crs: CRS - dst_transform: Affine - dst_crs: CRS - dst_width: int - dst_height: int - nodata: float | None = None - resampling: ResamplingMethod = ResamplingMethod.NEAREST - - -@dataclass -class WarpMap: - """Precomputed pixel-coordinate mapping from a destination grid to a source grid. - - Stores the source column and row index for every destination pixel centre, - computed by a single vectorised ``Transformer.transform`` call. The ``valid`` - mask is not stored here; ``apply_warp_map`` derives it from the actual source - array shape so the same ``WarpMap`` can be reused across bands that share the - same source CRS and window transform but may have slightly different window - dimensions due to rounding. - - Attributes: - src_col_idx: Source column indices, shape ``(dst_height, dst_width)``, - dtype ``intp``. May contain out-of-bounds values for pixels that - map outside the source extent. - src_row_idx: Source row indices, shape ``(dst_height, dst_width)``, - dtype ``intp``. - - """ - - src_col_idx: np.ndarray - src_row_idx: np.ndarray - - -def compute_warp_map( - src_transform: Affine, - src_crs: CRS, - dst_transform: Affine, - dst_crs: CRS, - dst_width: int, - dst_height: int, -) -> WarpMap: - """Build a pixel-coordinate mapping from destination grid to source grid. - - Transforms every destination pixel centre into the source CRS with a single - vectorised ``Transformer.transform`` call, then converts to fractional source - pixel coordinates. The result can be reused across multiple bands that share - the same source CRS and window transform via :func:`apply_warp_map`. - - Args: - src_transform: Affine transform of the source array (window transform). - src_crs: CRS of the source array. - dst_transform: Affine transform of the destination grid. - dst_crs: CRS of the destination grid. - dst_width: Width of the destination grid in pixels. - dst_height: Height of the destination grid in pixels. - - Returns: - :class:`WarpMap` with ``src_col_idx`` and ``src_row_idx`` arrays of - shape ``(dst_height, dst_width)``. - - """ - col_idx = np.arange(dst_width) - row_idx = np.arange(dst_height) - col_grid, row_grid = np.meshgrid(col_idx, row_idx) - - dst_xs = dst_transform.c + (col_grid + 0.5) * dst_transform.a - dst_ys = dst_transform.f + (row_grid + 0.5) * dst_transform.e - - transformer = _get_transformer(dst_crs, src_crs) - src_xs, src_ys = transformer.transform(dst_xs.ravel(), dst_ys.ravel()) - - # ~src_transform maps (x, y) → (col_frac, row_frac). - inv = ~src_transform - frac_cols = (inv.a * src_xs + inv.b * src_ys + inv.c).reshape(dst_height, dst_width) - frac_rows = (inv.d * src_xs + inv.e * src_ys + inv.f).reshape(dst_height, dst_width) - - return WarpMap( - src_col_idx=np.floor(frac_cols).astype(np.intp), - src_row_idx=np.floor(frac_rows).astype(np.intp), - ) - - -def apply_warp_map( - data: np.ndarray, - warp_map: WarpMap, - nodata: float | None = None, -) -> np.ndarray: - """Sample a source array using a precomputed :class:`WarpMap`. - - The valid mask is derived from ``data.shape`` at call time so the same - ``warp_map`` can be safely applied to bands with slightly different window - dimensions. - - Args: - data: Source data with shape ``(bands, src_h, src_w)``. - warp_map: Pixel-coordinate mapping from destination to source. - nodata: Fill value for destination pixels that fall outside the source - extent, or ``None`` to use zero. - - Returns: - Array with shape ``(bands, dst_height, dst_width)`` and the same dtype - as ``data``. - - """ - bands, src_h, src_w = data.shape - dst_height, dst_width = warp_map.src_col_idx.shape - fill = nodata if nodata is not None else 0 - - valid = ( - (warp_map.src_col_idx >= 0) - & (warp_map.src_col_idx < src_w) - & (warp_map.src_row_idx >= 0) - & (warp_map.src_row_idx < src_h) - ) - - out = np.full((bands, dst_height, dst_width), fill, dtype=data.dtype) - out[:, valid] = data[:, warp_map.src_row_idx[valid], warp_map.src_col_idx[valid]] - return out - - -def _same_grid(request: ReprojectRequest) -> bool: - """Return ``True`` when reprojection is an exact no-op.""" - return ( - request.src_crs.equals(request.dst_crs) - and request.src_transform == request.dst_transform - and request.data.shape[1] == request.dst_height - and request.data.shape[2] == request.dst_width - ) - - -def _legacy_cache_key(request: ReprojectRequest) -> tuple[tuple[float, ...], CRS]: - """Return the cache key for the legacy nearest-neighbor backend.""" - return (tuple(request.src_transform), request.src_crs) - - -def _reproject_tile_legacy( - request: ReprojectRequest, - warp_map: WarpMap | None = None, -) -> np.ndarray: - """Reproject a tile with the legacy pyproj/numpy nearest backend.""" - if request.resampling != "nearest": - raise ValueError( - "The legacy reprojection backend only supports resampling='nearest'.", - ) - - resolved_warp_map = warp_map or compute_warp_map( - request.src_transform, - request.src_crs, - request.dst_transform, - request.dst_crs, - request.dst_width, - request.dst_height, - ) - return apply_warp_map(request.data, resolved_warp_map, request.nodata) - - -_DEFAULT_REPROJECT_BACKEND: Literal["legacy", "rust-warp"] = "rust-warp" - - -def reproject_tile( - request: ReprojectRequest, - *, - backend: Literal["legacy", "rust-warp"] | None = None, - warp_cache: dict[tuple[tuple[float, ...], CRS], WarpMap] | None = None, -) -> np.ndarray: - """Reproject one ``(bands, y, x)`` source tile onto a destination grid. - - Args: - request: Reprojection inputs for one tile. - backend: Internal backend selector. When ``None``, the module's - current default backend is used. - warp_cache: Optional cache for the legacy backend's precomputed warp - maps, keyed by source transform and CRS. Ignored by rust-warp. - - Returns: - Reprojected array with shape ``(bands, dst_height, dst_width)``. - - """ - if _same_grid(request): - return request.data - - resolved_backend = _DEFAULT_REPROJECT_BACKEND if backend is None else backend - if resolved_backend == "rust-warp": - return reproject_array_rust_warp( - data=request.data, - src_transform=request.src_transform, - src_crs=request.src_crs, - dst_transform=request.dst_transform, - dst_crs=request.dst_crs, - dst_width=request.dst_width, - dst_height=request.dst_height, - nodata=request.nodata, - resampling=request.resampling, - ) - if resolved_backend != "legacy": - raise ValueError(f"Unsupported reprojection backend: {resolved_backend}") - - warp_map: WarpMap | None = None - if warp_cache is not None: - cache_key = _legacy_cache_key(request) - warp_map = warp_cache.get(cache_key) - if warp_map is None: - warp_map = compute_warp_map( - request.src_transform, - request.src_crs, - request.dst_transform, - request.dst_crs, - request.dst_width, - request.dst_height, - ) - warp_cache[cache_key] = warp_map - - return _reproject_tile_legacy(request, warp_map) - - -def reproject_array( - data: np.ndarray, - src_transform: Affine, - src_crs: CRS, - dst_transform: Affine, - dst_crs: CRS, - dst_width: int, - dst_height: int, - nodata: float | None = None, -) -> np.ndarray: - """Reproject a raster array using nearest-neighbor sampling. - - Convenience wrapper around :class:`ReprojectRequest` and - :func:`reproject_tile`. Use :func:`reproject_tile` directly for the new - backend-neutral path. - - Args: - data: Source data with shape ``(bands, src_h, src_w)``. - src_transform: Affine transform of the source array. - src_crs: CRS of the source array. - dst_transform: Affine transform of the destination grid. - dst_crs: CRS of the destination grid. - dst_width: Width of the output array in pixels. - dst_height: Height of the output array in pixels. - nodata: Value to use for destination pixels that fall outside the - source extent, or ``None`` to use zero. - - Returns: - Reprojected array with shape ``(bands, dst_height, dst_width)`` and - the same dtype as ``data``. - - """ - return reproject_tile( - ReprojectRequest( - data=data, - src_transform=src_transform, - src_crs=src_crs, - dst_transform=dst_transform, - dst_crs=dst_crs, - dst_width=dst_width, - dst_height=dst_height, - nodata=nodata, - ), - ) diff --git a/src/lazycogs/_rust_warp.py b/src/lazycogs/_warp.py similarity index 64% rename from src/lazycogs/_rust_warp.py rename to src/lazycogs/_warp.py index 581676e..7026cbe 100644 --- a/src/lazycogs/_rust_warp.py +++ b/src/lazycogs/_warp.py @@ -1,15 +1,62 @@ -"""Private adapter for rust-warp's low-level reprojection API.""" +"""Raster reprojection helpers backed by rust-warp.""" from __future__ import annotations +import functools +from dataclasses import dataclass +from enum import StrEnum from typing import TYPE_CHECKING import numpy as np import rust_warp +from pyproj import CRS, Transformer if TYPE_CHECKING: from affine import Affine - from pyproj import CRS + + +class ResamplingMethod(StrEnum): + """Supported public reprojection resampling methods.""" + + NEAREST = "nearest" + BILINEAR = "bilinear" + CUBIC = "cubic" + + +@functools.lru_cache(maxsize=256) +def _get_transformer(src_crs: CRS, dst_crs: CRS) -> Transformer: + """Return a cached ``Transformer`` for a CRS pair. + + ``Transformer.from_crs`` involves PROJ database lookups and pipeline + initialisation. The same (src_crs, dst_crs) pair recurs for every item in a + collection, so caching avoids recreating the same object hundreds of times + per chunk read. + + Args: + src_crs: Source CRS. + dst_crs: Destination CRS. + + Returns: + A ``Transformer`` that maps ``src_crs`` → ``dst_crs``. + + """ + return Transformer.from_crs(src_crs, dst_crs, always_xy=True) + + +@dataclass(frozen=True) +class ReprojectRequest: + """All inputs required to reproject one source tile.""" + + data: np.ndarray + src_transform: Affine + src_crs: CRS + dst_transform: Affine + dst_crs: CRS + dst_width: int + dst_height: int + nodata: float | None = None + resampling: ResamplingMethod = ResamplingMethod.NEAREST + _SUPPORTED_DTYPES = frozenset( { @@ -24,6 +71,16 @@ _EXPECTED_ARRAY_NDIM = 3 +def _same_grid(request: ReprojectRequest) -> bool: + """Return ``True`` when reprojection is an exact no-op.""" + return ( + request.src_crs.equals(request.dst_crs) + and request.src_transform == request.dst_transform + and request.data.shape[1] == request.dst_height + and request.data.shape[2] == request.dst_width + ) + + def _affine_to_rust_warp( transform: Affine, ) -> tuple[float, float, float, float, float, float]: @@ -80,7 +137,7 @@ def _normalize_nodata(nodata: float | None, dtype: np.dtype) -> float | int: return np.array([nodata]).astype(dtype, casting="unsafe")[0].item() -def reproject_array_rust_warp( +def reproject_array( data: np.ndarray, src_transform: Affine, src_crs: CRS, @@ -88,8 +145,8 @@ def reproject_array_rust_warp( dst_crs: CRS, dst_width: int, dst_height: int, + resampling: ResamplingMethod, nodata: float | None = None, - resampling: str = "nearest", ) -> np.ndarray: """Reproject a ``(bands, y, x)`` array via rust-warp's 2D kernel. @@ -101,8 +158,8 @@ def reproject_array_rust_warp( dst_crs: CRS of the destination grid. dst_width: Destination width in pixels. dst_height: Destination height in pixels. + resampling: rust-warp resampling method. nodata: Fill value for pixels outside the source extent. - resampling: rust-warp resampling method name. Returns: Reprojected array with shape ``(bands, dst_height, dst_width)``. @@ -139,3 +196,21 @@ def reproject_array_rust_warp( for band in data ] return np.stack(reprojected_bands, axis=0) + + +def reproject_tile(request: ReprojectRequest) -> np.ndarray: + """Reproject one ``(bands, y, x)`` source tile onto a destination grid.""" + if _same_grid(request): + return request.data + + return reproject_array( + data=request.data, + src_transform=request.src_transform, + src_crs=request.src_crs, + dst_transform=request.dst_transform, + dst_crs=request.dst_crs, + dst_width=request.dst_width, + dst_height=request.dst_height, + nodata=request.nodata, + resampling=request.resampling, + ) diff --git a/tests/benchmarks/bench_pipeline.py b/tests/benchmarks/bench_pipeline.py index 485c611..f2d2fda 100644 --- a/tests/benchmarks/bench_pipeline.py +++ b/tests/benchmarks/bench_pipeline.py @@ -18,7 +18,7 @@ import lazycogs from lazycogs import FirstMethod, MedianMethod, MosaicMethodBase, set_reproject_workers -from lazycogs._reproject import ReprojectRequest, ResamplingMethod, reproject_tile +from lazycogs._warp import ReprojectRequest, ResamplingMethod, reproject_tile from .conftest import ( BENCHMARK_BBOX, @@ -72,12 +72,11 @@ def _benchmark_request( @pytest.mark.benchmark -@pytest.mark.parametrize("backend", ["legacy", "rust-warp"]) -def test_small_window_nearest_backend_comparison(benchmark, backend: str) -> None: - """Compare legacy vs rust-warp on the same representative small window.""" +def test_small_window_nearest_reprojection(benchmark) -> None: + """Benchmark the representative nearest-neighbor rust-warp path.""" request = _benchmark_request(dst_resolution=20.0) - benchmark(reproject_tile, request, backend=backend) + benchmark(reproject_tile, request) @pytest.mark.benchmark @@ -131,7 +130,6 @@ def test_small_window_reprojection_modes( nodata=reproject_request.nodata, resampling=resampling, ), - backend="rust-warp", ) @@ -315,8 +313,8 @@ def test_band_access_pattern( """Compare single-band vs multi-band compute cost. Uses the expanded 12-time-step dataset with ``chunks={"time": 1}`` so each - time step is a concurrent dask task. Multi-band reads share a single - ``rustac.search_sync`` query and reuse reprojection warp maps across bands; + time step is a concurrent dask task. Multi-band reads share a single + ``rustac.search_sync`` query and reproject all requested bands in one pass; this benchmark quantifies that gain under concurrent load. """ diff --git a/tests/integration_test.py b/tests/integration_test.py index 8289c94..59af4cc 100644 --- a/tests/integration_test.py +++ b/tests/integration_test.py @@ -11,7 +11,6 @@ from pyproj import Transformer import lazycogs -from lazycogs import _reproject logging.basicConfig(level="WARN") logging.getLogger("lazycogs").setLevel("DEBUG") @@ -81,30 +80,13 @@ def measure(label: str): ) -@contextlib.contextmanager -def _reproject_backend(backend: str): - """Temporarily force the internal reprojection backend for this script.""" - previous_backend = _reproject._DEFAULT_REPROJECT_BACKEND - _reproject._DEFAULT_REPROJECT_BACKEND = backend - try: - yield - finally: - _reproject._DEFAULT_REPROJECT_BACKEND = previous_backend - - def _parse_args() -> argparse.Namespace: """Parse command-line options for the integration script.""" parser = argparse.ArgumentParser() - parser.add_argument( - "--reproject-backend", - choices=["legacy", "rust-warp"], - default="legacy", - help="Internal reprojection backend to exercise during the run.", - ) return parser.parse_args() -async def run(reproject_backend: str): +async def run(): dst_crs = "epsg:5070" dst_bbox = (-700_000, 2_220_000, 600_000, 2_930_000) @@ -135,33 +117,30 @@ async def run(reproject_backend: str): limit=limit, ) - # --- daily time steps --- - with _reproject_backend(reproject_backend): - logger.warning("using reprojection backend: %s", reproject_backend) - store = lazycogs.store_for(str(items_parquet), skip_signature=True) - da = lazycogs.open( - str(items_parquet), - crs=dst_crs, - bbox=dst_bbox, - resolution=100, - time_period="P1D", - bands=["red", "green", "blue"], - dtype="int16", - store=store, - ) - logger.warning("daily array: %s", da) + store = lazycogs.store_for(str(items_parquet), skip_signature=True) + da = lazycogs.open( + str(items_parquet), + crs=dst_crs, + bbox=dst_bbox, + resolution=100, + time_period="P1D", + bands=["red", "green", "blue"], + dtype="int16", + store=store, + ) + logger.warning("daily array: %s", da) - with measure("daily point"): - _ = da.sel(x=299965, y=2653947, method="nearest").compute() + with measure("daily point"): + _ = da.sel(x=299965, y=2653947, method="nearest").compute() - subset = da.sel( - x=slice(100_000, 400_000), - y=slice(2_800_000, 2_600_000), - ) - with measure("daily spatial subset isel(time=1)"): - _ = subset.isel(time=1).load() + subset = da.sel( + x=slice(100_000, 400_000), + y=slice(2_800_000, 2_600_000), + ) + with measure("daily spatial subset isel(time=1)"): + _ = subset.isel(time=1).load() if __name__ == "__main__": args = _parse_args() - asyncio.run(run(args.reproject_backend)) + asyncio.run(run()) diff --git a/tests/test_backend.py b/tests/test_backend.py index 98e29fd..67b872e 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -12,6 +12,7 @@ from lazycogs._backend import MultiBandStacBackendArray from lazycogs._mosaic_methods import FirstMethod +from lazycogs._warp import ResamplingMethod @pytest.fixture @@ -242,7 +243,7 @@ def test_multiband_raw_getitem_forwards_resampling(wgs84): """The selected resampling mode reaches read_chunk_async unchanged.""" bands = ["B01", "B02"] multi = _make_multiband_array(wgs84, bands) - multi.resampling = "cubic" + multi.resampling = ResamplingMethod.CUBIC fake_items = [ {"id": "item-1", "assets": {b: {"href": f"s3://b/{b}.tif"} for b in bands}}, ] @@ -258,7 +259,9 @@ def test_multiband_raw_getitem_forwards_resampling(wgs84): ): multi._sync_getitem((slice(0, 2), 0, slice(0, 1), slice(0, 4))) - assert read_chunk_async_mock.await_args.kwargs["resampling"] == "cubic" + assert ( + read_chunk_async_mock.await_args.kwargs["resampling"] is ResamplingMethod.CUBIC + ) def test_multiband_raw_getitem_squeeze_band(wgs84): diff --git a/tests/test_chunk_reader.py b/tests/test_chunk_reader.py index 2767e3e..5ea3e73 100644 --- a/tests/test_chunk_reader.py +++ b/tests/test_chunk_reader.py @@ -10,13 +10,14 @@ from pyproj import CRS from lazycogs._chunk_reader import ( - _apply_bands_with_warp_cache, _drain_in_order, _native_window, + _reproject_bands, _select_overview, read_chunk_async, ) from lazycogs._mosaic_methods import FirstMethod +from lazycogs._warp import ResamplingMethod # --------------------------------------------------------------------------- # Helpers @@ -262,7 +263,7 @@ def is_done() -> bool: # --------------------------------------------------------------------------- -# _apply_bands_with_warp_cache +# _reproject_bands # --------------------------------------------------------------------------- @@ -273,156 +274,26 @@ def _make_raster(transform: Affine, value: float, h: int = 4, w: int = 4) -> Mag return raster -def test_apply_bands_with_warp_cache_same_grid_bypasses_reproject_tile(): - """Same-grid reads return directly without invoking the reprojection backend.""" +def test_reproject_bands_same_grid_returns_original_array(): + """Same-grid reads return the original array through ``reproject_tile``.""" crs = CRS.from_epsg(4326) transform = Affine(1.0, 0.0, 0.0, 0.0, -1.0, 4.0) raster = _make_raster(transform, 1.0) - with patch("lazycogs._chunk_reader.reproject_tile") as reproject_tile_mock: - results = _apply_bands_with_warp_cache( - [("B01", raster, crs, None)], - transform, - crs, - dst_width=4, - dst_height=4, - resampling="cubic", - ) - - reproject_tile_mock.assert_not_called() - np.testing.assert_array_equal(results["B01"][0], raster.data) - - -def test_apply_bands_with_warp_cache_shared_geometry(): - """Bands with the same transform/CRS share a single warp map computation.""" - crs = CRS.from_epsg(4326) - # Offset the source transform slightly so it differs from dst_transform and - # the warp path (rather than the identity fast-path) is exercised. - transform = Affine(1.0, 0.0, 0.5, 0.0, -1.0, 4.0) - dst_transform = Affine(1.0, 0.0, 0.0, 0.0, -1.0, 4.0) - - raster_a = _make_raster(transform, 1.0) - raster_b = _make_raster(transform, 2.0) - - warp_map_calls = [] - from lazycogs._reproject import compute_warp_map as real_compute_warp_map - - def _spy_compute_warp_map(*args, **kwargs): - result = real_compute_warp_map(*args, **kwargs) - warp_map_calls.append(True) - return result - - with ( - patch("lazycogs._reproject._DEFAULT_REPROJECT_BACKEND", "legacy"), - patch( - "lazycogs._reproject.compute_warp_map", - side_effect=_spy_compute_warp_map, - ), - ): - results = _apply_bands_with_warp_cache( - [("B01", raster_a, crs, None), ("B02", raster_b, crs, None)], - dst_transform, - crs, - dst_width=4, - dst_height=4, - ) - - # Same transform → warp map computed exactly once. - assert len(warp_map_calls) == 1 - assert set(results) == {"B01", "B02"} - np.testing.assert_array_equal(results["B01"][0], 1.0) - np.testing.assert_array_equal(results["B02"][0], 2.0) - - -def test_apply_bands_with_warp_cache_different_geometry(): - """Bands with different transforms each compute their own warp map.""" - crs = CRS.from_epsg(4326) - # Both source transforms differ from dst_transform so the warp path is - # exercised for each band (fast-path is not triggered). - transform_a = Affine(1.0, 0.0, 0.5, 0.0, -1.0, 4.0) - transform_b = Affine(2.0, 0.0, 0.0, 0.0, -2.0, 8.0) - dst_transform = Affine(1.0, 0.0, 0.0, 0.0, -1.0, 4.0) - - raster_a = _make_raster(transform_a, 1.0) - raster_b = _make_raster(transform_b, 2.0, h=2, w=2) - - warp_map_calls = [] - from lazycogs._reproject import compute_warp_map as real_compute_warp_map - - def _spy_compute_warp_map(*args, **kwargs): - result = real_compute_warp_map(*args, **kwargs) - warp_map_calls.append(True) - return result - - with ( - patch("lazycogs._reproject._DEFAULT_REPROJECT_BACKEND", "legacy"), - patch( - "lazycogs._reproject.compute_warp_map", - side_effect=_spy_compute_warp_map, - ), - ): - results = _apply_bands_with_warp_cache( - [("B01", raster_a, crs, None), ("B02", raster_b, crs, None)], - dst_transform, - crs, - dst_width=4, - dst_height=4, - ) - - # Different transforms → two separate warp map computations. - assert len(warp_map_calls) == 2 - assert set(results) == {"B01", "B02"} - - -def test_apply_bands_with_warp_cache_shared_across_calls(): - """A shared cache reuses warp maps across separate calls (e.g. time steps).""" - crs = CRS.from_epsg(4326) - # Source transform offset from dst so the warp path is exercised. - transform = Affine(1.0, 0.0, 0.5, 0.0, -1.0, 4.0) - dst_transform = Affine(1.0, 0.0, 0.0, 0.0, -1.0, 4.0) - - raster = _make_raster(transform, 1.0) - shared_cache: dict = {} - - warp_map_calls = [] - from lazycogs._reproject import compute_warp_map as real_compute_warp_map - - def _spy_compute_warp_map(*args, **kwargs): - result = real_compute_warp_map(*args, **kwargs) - warp_map_calls.append(True) - return result - - with ( - patch("lazycogs._reproject._DEFAULT_REPROJECT_BACKEND", "legacy"), - patch( - "lazycogs._reproject.compute_warp_map", - side_effect=_spy_compute_warp_map, - ), - ): - _apply_bands_with_warp_cache( - [("B01", raster, crs, None)], - dst_transform, - crs, - dst_width=4, - dst_height=4, - warp_cache=shared_cache, - ) - _apply_bands_with_warp_cache( - [("B01", raster, crs, None)], - dst_transform, - crs, - dst_width=4, - dst_height=4, - warp_cache=shared_cache, - ) + results = _reproject_bands( + [("B01", raster, crs, None)], + transform, + crs, + dst_width=4, + dst_height=4, + resampling=ResamplingMethod.CUBIC, + ) - # Warp map computed only once despite two separate calls. - assert len(warp_map_calls) == 1 - assert len(shared_cache) == 1 + assert results["B01"][0] is raster.data -def test_apply_bands_with_warp_cache_uses_rust_backend_for_nearest_by_default(): - """The chunk-reader seam now defaults all reprojection methods to rust-warp.""" +def test_reproject_bands_forwards_resampling_to_rust_backend(): + """Non-trivial reprojection forwards the enum unchanged to rust-warp.""" crs = CRS.from_epsg(4326) src_transform = Affine(1.0, 0.0, 0.5, 0.0, -1.0, 4.0) dst_transform = Affine(1.0, 0.0, 0.0, 0.0, -1.0, 4.0) @@ -430,20 +301,20 @@ def test_apply_bands_with_warp_cache_uses_rust_backend_for_nearest_by_default(): expected = np.full((1, 4, 4), 7.0, dtype=np.float32) with patch( - "lazycogs._reproject.reproject_array_rust_warp", + "lazycogs._warp.reproject_array", return_value=expected, ) as rust_backend_mock: - results = _apply_bands_with_warp_cache( + results = _reproject_bands( [("B01", raster, crs, None)], dst_transform, crs, dst_width=4, dst_height=4, - resampling="nearest", + resampling=ResamplingMethod.NEAREST, ) rust_backend_mock.assert_called_once() - assert rust_backend_mock.call_args.kwargs["resampling"] == "nearest" + assert rust_backend_mock.call_args.kwargs["resampling"] is ResamplingMethod.NEAREST np.testing.assert_array_equal(results["B01"][0], expected) diff --git a/tests/test_core.py b/tests/test_core.py index 298e4a7..a361e51 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -18,8 +18,8 @@ _build_time_steps, _smoketest_store, ) -from lazycogs._reproject import ResamplingMethod from lazycogs._temporal import _DayGrouper, _FixedDayGrouper, _MonthGrouper +from lazycogs._warp import ResamplingMethod def _items_to_arrow(items: list[dict]) -> rustac.DuckdbClient: @@ -320,9 +320,9 @@ def test_open_accepts_resampling_enum(tmp_path): assert da.attrs["_stac_backend"].resampling is ResamplingMethod.CUBIC -def test_open_rejects_unknown_resampling(): - """open() validates resampling once at API entry.""" - with pytest.raises(ValueError, match="Unsupported resampling"): +def test_open_rejects_string_resampling(): + """open() requires the public resampling enum at API entry.""" + with pytest.raises(TypeError, match="resampling must be a ResamplingMethod"): lazycogs.open( "items.parquet", bbox=(-93.5, 44.5, -93.0, 45.0), diff --git a/tests/test_explain.py b/tests/test_explain.py index 5dd61f7..9f3d868 100644 --- a/tests/test_explain.py +++ b/tests/test_explain.py @@ -2,7 +2,8 @@ from __future__ import annotations -from unittest.mock import patch +from collections.abc import Callable +from unittest.mock import AsyncMock, patch import numpy as np import pytest @@ -22,6 +23,7 @@ _roi_pixel_offsets, ) from lazycogs._mosaic_methods import FirstMethod +from lazycogs._warp import ResamplingMethod # --------------------------------------------------------------------------- # Fixtures @@ -46,6 +48,8 @@ def _make_backend( dst_width: int = 10, dst_height: int = 10, affine: Affine | None = None, + resampling: ResamplingMethod = ResamplingMethod.NEAREST, + path_from_href: Callable[[str], str] | None = None, ) -> MultiBandStacBackendArray: """Return a minimal MultiBandStacBackendArray for unit testing.""" if dates is None: @@ -70,6 +74,8 @@ def _make_backend( dtype=np.dtype("float32"), nodata=-9999.0, mosaic_method_cls=FirstMethod, + resampling=resampling, + path_from_href=path_from_href, ) @@ -81,6 +87,8 @@ def _make_da_with_backends( width: int = 10, height: int = 10, affine: Affine | None = None, + resampling: ResamplingMethod = ResamplingMethod.NEAREST, + path_from_href: Callable[[str], str] | None = None, ) -> xr.DataArray: """Return a minimal DataArray with stac_cog explain attrs attached.""" if affine is None: @@ -94,6 +102,8 @@ def _make_da_with_backends( dst_width=width, dst_height=height, affine=affine, + resampling=resampling, + path_from_href=path_from_href, ) time_coord = np.array(time_coords, dtype="datetime64[D]") @@ -622,3 +632,61 @@ def test_accessor_explain_query_count_not_multiplied_by_bands(wgs84): assert plan.total_chunk_reads == 6 # but only 2 DuckDB queries (T=2 x S=1), not 6 (B*T*S) assert mock_search.call_count == 2 + + +def test_accessor_explain_fetch_headers_uses_backend_read_context(wgs84): + """fetch_headers=True reuses backend path resolution settings.""" + + def path_fn(href: str) -> str: + return href.removeprefix("s3://bucket/") + + dates = ["2023-01-01/2023-01-01"] + time_coords = [np.datetime64("2023-01-01", "D")] + da = _make_da_with_backends( + wgs84, + dates=dates, + time_coords=time_coords, + bands=["red"], + width=4, + height=4, + resampling=ResamplingMethod.BILINEAR, + path_from_href=path_fn, + ) + + class _FakeGeoTIFF: + def __init__(self) -> None: + self.overviews: list[object] = [] + self.transform = Affine(1.0, 0.0, 0.0, 0.0, -1.0, 4.0) + + class _FakeWindow: + col_off = 1 + row_off = 2 + width = 3 + height = 4 + + async def _fake_open_and_window(item: dict, band: str, ctx: object): + assert item["id"] == "item-0" + assert band == "red" + assert ctx.resampling == ResamplingMethod.NEAREST + assert ctx.path_fn is path_fn + geotiff = _FakeGeoTIFF() + return geotiff, geotiff, _FakeWindow(), "item-0.tif" + + with ( + patch("rustac.DuckdbClient.search") as mock_search, + patch( + "lazycogs._explain._open_and_window", + new=AsyncMock(side_effect=_fake_open_and_window), + ), + ): + mock_search.return_value = _fake_items("red", 1) + plan = da.lazycogs.explain(fetch_headers=True) + + assert mock_search.call_count == 1 + assert plan.total_chunk_reads == 1 + assert plan.total_cog_reads == 1 + read = plan.chunk_reads[0].cog_reads[0] + assert read.window_col_off == 1 + assert read.window_row_off == 2 + assert read.window_width == 3 + assert read.window_height == 4 diff --git a/tests/test_rasterio_parity.py b/tests/test_rasterio_parity.py index 8adaede..48735fc 100644 --- a/tests/test_rasterio_parity.py +++ b/tests/test_rasterio_parity.py @@ -15,13 +15,13 @@ from pyproj import CRS from lazycogs._chunk_reader import _native_window, _select_overview -from lazycogs._reproject import ( +from lazycogs._store import resolve +from lazycogs._warp import ( ReprojectRequest, ResamplingMethod, _get_transformer, reproject_tile, ) -from lazycogs._store import resolve _NEAREST_RESOLUTIONS = [ 10, @@ -72,7 +72,6 @@ async def _read_lazycogs( dst_crs: CRS, *, resampling: ResamplingMethod = ResamplingMethod.NEAREST, - backend: str | None = None, ) -> np.ndarray: """Run the lazycogs tile read + reprojection path for one chunk.""" store, path = resolve(href) @@ -123,7 +122,6 @@ async def _read_lazycogs( nodata=geotiff.nodata, resampling=resampling, ), - backend=backend, ) @@ -292,73 +290,6 @@ def test_parity_cross_crs(synthetic_cog: Path, resolution: int) -> None: ) -@pytest.mark.parametrize("resolution", [20, 60, 160]) -def test_nearest_legacy_matches_rust_warp_same_crs( - synthetic_cog: Path, - resolution: int, -) -> None: - """Migration-window A/B checks keep nearest same-CRS behavior aligned.""" - dst_crs = CRS.from_epsg(32632) - affine = _chunk_affine(float(resolution), _CENTER_UTM_X, _CENTER_UTM_Y) - - legacy_out = asyncio.run( - _read_lazycogs( - _href(synthetic_cog), - affine, - dst_crs, - backend="legacy", - ), - ) - rust_out = asyncio.run( - _read_lazycogs( - _href(synthetic_cog), - affine, - dst_crs, - backend="rust-warp", - ), - ) - - np.testing.assert_array_equal(rust_out, legacy_out) - - -@pytest.mark.parametrize("resolution", [20, 60, 160]) -def test_nearest_legacy_matches_rust_warp_cross_crs( - synthetic_cog: Path, - resolution: int, -) -> None: - """Migration-window A/B checks keep nearest cross-CRS behavior aligned.""" - src_crs = CRS.from_epsg(32632) - dst_crs = CRS.from_epsg(3035) - t = _get_transformer(src_crs, dst_crs) - cx_laea, cy_laea = t.transform(_CENTER_UTM_X, _CENTER_UTM_Y) - affine = _chunk_affine(float(resolution), cx_laea, cy_laea) - - legacy_out = asyncio.run( - _read_lazycogs( - _href(synthetic_cog), - affine, - dst_crs, - backend="legacy", - ), - ) - rust_out = asyncio.run( - _read_lazycogs( - _href(synthetic_cog), - affine, - dst_crs, - backend="rust-warp", - ), - ) - - _assert_parity( - rust_out, - legacy_out, - f"legacy_vs_rust cross_crs res={resolution}m", - max_differing_pixels=3, - max_abs_diff=2048 * 16 + 1, - ) - - @pytest.mark.parametrize(("resampling", "rasterio_resampling"), _INTERPOLATING_METHODS) @pytest.mark.parametrize("resolution", _INTERPOLATING_RESOLUTIONS) def test_interpolating_parity_same_crs( diff --git a/tests/test_reproject.py b/tests/test_reproject.py index 518b4d2..c975e96 100644 --- a/tests/test_reproject.py +++ b/tests/test_reproject.py @@ -1,4 +1,4 @@ -"""Tests for _reproject: reproject_array, compute_warp_map, apply_warp_map.""" +"""Tests for rust-warp-backed reprojection.""" from unittest.mock import patch @@ -7,16 +7,13 @@ from affine import Affine from pyproj import CRS -from lazycogs._reproject import ( +from lazycogs._warp import ( ReprojectRequest, ResamplingMethod, - WarpMap, - apply_warp_map, - compute_warp_map, - reproject_array, + _affine_to_rust_warp, + _normalize_crs, reproject_tile, ) -from lazycogs._rust_warp import _affine_to_rust_warp, _normalize_crs @pytest.fixture @@ -33,11 +30,37 @@ def _make_transform(minx: float, maxy: float, res: float) -> Affine: return Affine(res, 0.0, minx, 0.0, -res, maxy) +def _request( + data: np.ndarray, + src_transform: Affine, + src_crs: CRS, + dst_transform: Affine, + dst_crs: CRS, + dst_width: int, + dst_height: int, + *, + nodata: float | None = None, + resampling: ResamplingMethod = ResamplingMethod.NEAREST, +) -> ReprojectRequest: + """Build a ``ReprojectRequest`` for test cases.""" + return ReprojectRequest( + data=data, + src_transform=src_transform, + src_crs=src_crs, + dst_transform=dst_transform, + dst_crs=dst_crs, + dst_width=dst_width, + dst_height=dst_height, + nodata=nodata, + resampling=resampling, + ) + + def test_identity_same_crs_same_transform(wgs84): """Reprojecting to the identical grid returns the same values.""" transform = _make_transform(0.0, 3.0, 1.0) data = np.arange(9, dtype=np.float32).reshape(1, 3, 3) - out = reproject_array(data, transform, wgs84, transform, wgs84, 3, 3) + out = reproject_tile(_request(data, transform, wgs84, transform, wgs84, 3, 3)) np.testing.assert_array_equal(out, data) @@ -46,12 +69,14 @@ def test_output_shape(wgs84): src_transform = _make_transform(0.0, 2.0, 1.0) dst_transform = _make_transform(0.0, 4.0, 2.0) data = np.ones((2, 2, 2), dtype=np.float32) - out = reproject_array(data, src_transform, wgs84, dst_transform, wgs84, 1, 2) + out = reproject_tile( + _request(data, src_transform, wgs84, dst_transform, wgs84, 1, 2), + ) assert out.shape == (2, 2, 1) def test_reproject_tile_same_grid_returns_original_array(wgs84): - """The backend-neutral dispatcher short-circuits exact same-grid reads.""" + """Exact same-grid reads short-circuit without calling rust-warp.""" transform = _make_transform(0.0, 3.0, 1.0) data = np.arange(9, dtype=np.float32).reshape(1, 3, 3) @@ -70,40 +95,8 @@ def test_reproject_tile_same_grid_returns_original_array(wgs84): assert out is data -def test_reproject_tile_matches_legacy_wrapper(wgs84): - """The backend-neutral path preserves current nearest-neighbor behavior.""" - src_transform = _make_transform(0.0, 2.0, 1.0) - dst_transform = _make_transform(0.0, 4.0, 2.0) - data = np.ones((2, 2, 2), dtype=np.float32) - - request = ReprojectRequest( - data=data, - src_transform=src_transform, - src_crs=wgs84, - dst_transform=dst_transform, - dst_crs=wgs84, - dst_width=1, - dst_height=2, - nodata=-9999.0, - ) - - np.testing.assert_array_equal( - reproject_tile(request), - reproject_array( - data, - src_transform, - wgs84, - dst_transform, - wgs84, - 1, - 2, - nodata=-9999.0, - ), - ) - - -def test_reproject_tile_defaults_to_rust_warp_backend(wgs84): - """The default backend selection now routes nearest through rust-warp.""" +def test_reproject_tile_delegates_to_rust_warp(wgs84): + """Non-trivial reprojection calls the rust-warp adapter with the enum.""" src_transform = _make_transform(0.0, 2.0, 1.0) dst_transform = _make_transform(0.0, 2.0, 0.5) data = np.arange(4, dtype=np.float32).reshape(1, 2, 2) @@ -114,7 +107,7 @@ def _fake_rust_warp(**kwargs): return np.zeros((1, 4, 4), dtype=np.float32) with patch( - "lazycogs._reproject.reproject_array_rust_warp", + "lazycogs._warp.reproject_array", side_effect=_fake_rust_warp, ): out = reproject_tile( @@ -180,7 +173,6 @@ def test_reproject_tile_rust_warp_supports_expected_dtypes(wgs84, dtype): dst_height=4, nodata=-1.0, ), - backend="rust-warp", ) assert out.shape == (1, 4, 4) @@ -204,7 +196,6 @@ def test_reproject_tile_rust_warp_rejects_unsupported_dtype(wgs84): dst_width=4, dst_height=4, ), - backend="rust-warp", ) @@ -231,7 +222,6 @@ def test_reproject_tile_rust_warp_preserves_band_order(wgs84): dst_height=4, nodata=-9999.0, ), - backend="rust-warp", ) assert out.shape == (3, 4, 4) @@ -242,19 +232,20 @@ def test_reproject_tile_rust_warp_preserves_band_order(wgs84): def test_out_of_bounds_pixels_get_nodata(wgs84): """Destination pixels outside the source extent are filled with nodata.""" - src_transform = _make_transform(5.0, 5.0, 1.0) # covers x=5..8, y=2..5 + src_transform = _make_transform(5.0, 5.0, 1.0) data = np.ones((1, 3, 3), dtype=np.float32) - # Destination covers x=0..3, entirely outside source dst_transform = _make_transform(0.0, 3.0, 1.0) - out = reproject_array( - data, - src_transform, - wgs84, - dst_transform, - wgs84, - 3, - 3, - nodata=-9999.0, + out = reproject_tile( + _request( + data, + src_transform, + wgs84, + dst_transform, + wgs84, + 3, + 3, + nodata=-9999.0, + ), ) np.testing.assert_array_equal(out, -9999.0) @@ -264,7 +255,9 @@ def test_out_of_bounds_default_fill_is_zero(wgs84): src_transform = _make_transform(100.0, 100.0, 1.0) data = np.ones((1, 2, 2), dtype=np.float32) dst_transform = _make_transform(0.0, 2.0, 1.0) - out = reproject_array(data, src_transform, wgs84, dst_transform, wgs84, 2, 2) + out = reproject_tile( + _request(data, src_transform, wgs84, dst_transform, wgs84, 2, 2), + ) np.testing.assert_array_equal(out, 0.0) @@ -273,47 +266,38 @@ def test_dtype_preserved(wgs84): transform = _make_transform(0.0, 2.0, 1.0) for dtype in (np.uint8, np.int16, np.float64): data = np.zeros((1, 2, 2), dtype=dtype) - out = reproject_array(data, transform, wgs84, transform, wgs84, 2, 2) + out = reproject_tile(_request(data, transform, wgs84, transform, wgs84, 2, 2)) assert out.dtype == dtype def test_multiband_preserved(wgs84): """All bands are reprojected independently.""" transform = _make_transform(0.0, 2.0, 1.0) - data = np.stack( - [np.ones((2, 2), dtype=np.float32) * b for b in range(4)], - ) # shape (4, 2, 2) - out = reproject_array(data, transform, wgs84, transform, wgs84, 2, 2) + data = np.stack([np.ones((2, 2), dtype=np.float32) * b for b in range(4)]) + out = reproject_tile(_request(data, transform, wgs84, transform, wgs84, 2, 2)) assert out.shape == (4, 2, 2) for b in range(4): np.testing.assert_array_equal(out[b], b) def test_cross_crs_reproject(wgs84, utm32n): - """Reprojecting between WGS84 and UTM preserves values at matched pixels. - - We project a uniform field so that the exact pixel mapping doesn't matter — - every source pixel has the same value, so any valid sample should match. - """ - # UTM 32N chunk near central Europe: ~10 km at 1000 m resolution + """Reprojecting between WGS84 and UTM preserves values at matched pixels.""" utm_transform = _make_transform(500_000.0, 5_550_000.0, 1000.0) data = np.full((1, 10, 10), 42.0, dtype=np.float32) - - # Destination grid in WGS84, centred over the UTM source extent - # (which maps to roughly lon 9.0-9.14, lat 50.01-50.10) wgs84_transform = _make_transform(9.0, 50.1, 0.01) - out = reproject_array( - data, - utm_transform, - utm32n, - wgs84_transform, - wgs84, - 5, - 5, - nodata=0.0, + out = reproject_tile( + _request( + data, + utm_transform, + utm32n, + wgs84_transform, + wgs84, + 5, + 5, + nodata=0.0, + ), ) - # Any pixel that mapped back to a valid source location should be 42. valid_pixels = out[out != 0.0] assert len(valid_pixels) > 0 np.testing.assert_array_equal(valid_pixels, 42.0) @@ -321,96 +305,20 @@ def test_cross_crs_reproject(wgs84, utm32n): def test_partial_overlap_nodata(wgs84): """Pixels that fall outside the source extent use nodata; overlapping ones copy.""" - # 4x1 source strip along x=0..4 src_transform = _make_transform(0.0, 1.0, 1.0) data = np.full((1, 1, 4), 7.0, dtype=np.float32) - - # Destination covers x=2..6 — right half overlaps, left half does not dst_transform = _make_transform(2.0, 1.0, 1.0) - out = reproject_array( - data, - src_transform, - wgs84, - dst_transform, - wgs84, - 4, - 1, - nodata=-1.0, + out = reproject_tile( + _request( + data, + src_transform, + wgs84, + dst_transform, + wgs84, + 4, + 1, + nodata=-1.0, + ), ) - # x=2 and x=3 overlap source (values 7); x=4 and x=5 are outside np.testing.assert_array_equal(out[0, 0, :2], 7.0) np.testing.assert_array_equal(out[0, 0, 2:], -1.0) - - -# --------------------------------------------------------------------------- -# compute_warp_map / apply_warp_map -# --------------------------------------------------------------------------- - - -def test_compute_warp_map_returns_correct_shape(wgs84): - """WarpMap arrays have shape (dst_height, dst_width).""" - transform = _make_transform(0.0, 4.0, 1.0) - wm = compute_warp_map(transform, wgs84, transform, wgs84, dst_width=4, dst_height=3) - assert isinstance(wm, WarpMap) - assert wm.src_col_idx.shape == (3, 4) - assert wm.src_row_idx.shape == (3, 4) - - -def test_apply_warp_map_matches_reproject_array(wgs84): - """apply_warp_map with a precomputed map matches reproject_array.""" - src_transform = _make_transform(0.0, 3.0, 1.0) - dst_transform = _make_transform(0.0, 3.0, 1.0) - data = np.arange(9, dtype=np.float32).reshape(1, 3, 3) - - wm = compute_warp_map(src_transform, wgs84, dst_transform, wgs84, 3, 3) - out_warp = apply_warp_map(data, wm, nodata=0.0) - out_reproject = reproject_array( - data, - src_transform, - wgs84, - dst_transform, - wgs84, - 3, - 3, - nodata=0.0, - ) - np.testing.assert_array_equal(out_warp, out_reproject) - - -def test_apply_warp_map_reused_across_bands(wgs84): - """A single WarpMap applied to two bands matches reproject_array per band.""" - transform = _make_transform(0.0, 2.0, 1.0) - band_a = np.full((1, 2, 2), 1.0, dtype=np.float32) - band_b = np.full((1, 2, 2), 2.0, dtype=np.float32) - - wm = compute_warp_map(transform, wgs84, transform, wgs84, 2, 2) - - out_a = apply_warp_map(band_a, wm) - out_b = apply_warp_map(band_b, wm) - - np.testing.assert_array_equal( - out_a, - reproject_array(band_a, transform, wgs84, transform, wgs84, 2, 2), - ) - np.testing.assert_array_equal( - out_b, - reproject_array(band_b, transform, wgs84, transform, wgs84, 2, 2), - ) - - -def test_apply_warp_map_different_src_dimensions(wgs84): - """apply_warp_map derives valid mask from actual data shape, not stored metadata.""" - # Compute warp map for a 4x4 source extent. - src_transform = _make_transform(0.0, 4.0, 1.0) - dst_transform = _make_transform(0.0, 4.0, 1.0) - wm = compute_warp_map(src_transform, wgs84, dst_transform, wgs84, 4, 4) - - # Apply to a 3x3 source array — pixels that map to row/col >= 3 should use nodata. - data_small = np.ones((1, 3, 3), dtype=np.float32) - out = apply_warp_map(data_small, wm, nodata=-1.0) - - # Top-left 3x3 destination pixels map into the valid 3x3 source. - np.testing.assert_array_equal(out[0, :3, :3], 1.0) - # Bottom row and right column of destination map outside the 3x3 source. - np.testing.assert_array_equal(out[0, 3, :], -1.0) - np.testing.assert_array_equal(out[0, :, 3], -1.0) From 1095e8e53f4512354ca73ddda3a1e445da63c113 Mon Sep 17 00:00:00 2001 From: hrodmn Date: Thu, 14 May 2026 05:26:18 -0500 Subject: [PATCH 7/7] test: add another integration test, include profiling in benchmarks --- scripts/format_benchmark_comparison.py | 102 +++++- tests/benchmarks/_profiling.py | 145 ++++++++ tests/benchmarks/bench_pipeline.py | 72 ++-- tests/integration_monthly_test.py | 391 ++++++++++++++++++++++ tests/test_format_benchmark_comparison.py | 94 +++++- 5 files changed, 757 insertions(+), 47 deletions(-) create mode 100644 tests/benchmarks/_profiling.py create mode 100644 tests/integration_monthly_test.py diff --git a/scripts/format_benchmark_comparison.py b/scripts/format_benchmark_comparison.py index c20cab4..57ee654 100755 --- a/scripts/format_benchmark_comparison.py +++ b/scripts/format_benchmark_comparison.py @@ -20,6 +20,11 @@ REGRESSION_THRESHOLD_PCT = 10 _SMALL_WINDOW_LABEL = "Small-window reprojection microbenchmarks" _END_TO_END_LABEL = "End-to-end benchmarks" +_RESOURCE_FIELDS = ( + ("profile_peak_rss_mb", "Peak RSS (MB)"), + ("profile_cpu_total_s", "CPU total (s)"), + ("profile_cpu_per_wall", "CPU/wall"), +) def find_file(pattern: str) -> Path: @@ -31,10 +36,16 @@ def find_file(pattern: str) -> Path: def load_benchmarks(path: Path) -> dict[str, dict]: - """Load benchmark stats keyed by test name from a pytest-benchmark JSON file.""" + """Load benchmark records keyed by test name from a pytest-benchmark JSON file.""" with path.open() as f: data = json.load(f) - return {b["name"]: b["stats"] for b in data["benchmarks"]} + return { + benchmark["name"]: { + "stats": benchmark["stats"], + "extra_info": benchmark.get("extra_info", {}), + } + for benchmark in data["benchmarks"] + } def _ms(seconds: float) -> str: @@ -49,24 +60,58 @@ def _classify_benchmark(name: str) -> str: return _END_TO_END_LABEL -def _comparison_row(name: str, baseline_stats: dict, pr_stats: dict) -> str: +def _change_display(baseline_value: float, pr_value: float) -> str: + """Return a signed percent-change display string.""" + if baseline_value == 0: + return "n/a" + pct = (pr_value - baseline_value) / baseline_value * 100 + sign = "+" if pct >= 0 else "" + return f"{sign}{pct:.1f}%" + + +def _comparison_row(name: str, baseline_record: dict, pr_record: dict) -> str: """Return one markdown table row for a benchmark present in both runs.""" - base_mean = baseline_stats["mean"] - pr_mean = pr_stats["mean"] + base_mean = baseline_record["stats"]["mean"] + pr_mean = pr_record["stats"]["mean"] if base_mean == 0: pct_display = "n/a" flag = "" else: pct = (pr_mean - base_mean) / base_mean * 100 - sign = "+" if pct >= 0 else "" - pct_display = f"{sign}{pct:.1f}%" + pct_display = _change_display(base_mean, pr_mean) flag = " :warning:" if pct > REGRESSION_THRESHOLD_PCT else "" base_ms, pr_ms = _ms(base_mean), _ms(pr_mean) return f"| `{name}` | {base_ms} | {pr_ms} | {pct_display}{flag} |" +def _resource_row( + name: str, + field: str, + baseline_record: dict, + pr_record: dict, +) -> str | None: + """Return one markdown row for a shared resource metric, if present.""" + baseline_extra = baseline_record.get("extra_info", {}) + pr_extra = pr_record.get("extra_info", {}) + if field not in baseline_extra or field not in pr_extra: + return None + + baseline_value = baseline_extra[field] + pr_value = pr_extra[field] + if not isinstance(baseline_value, int | float) or not isinstance( + pr_value, + int | float, + ): + return None + + return ( + f"| `{name}` | {baseline_value:.1f} | {pr_value:.1f} | " + f"{_change_display(float(baseline_value), float(pr_value))} |" + ) + + def _render_comparison_section( heading: str, names: list[str], @@ -86,6 +131,40 @@ def _render_comparison_section( ) +def _render_resource_section( + heading: str, + names: list[str], + baseline: dict[str, dict], + pr: dict[str, dict], +) -> str | None: + """Render resource-profile tables for one benchmark section.""" + field_sections: list[str] = [] + for field, label in _RESOURCE_FIELDS: + rows = [ + row + for name in names + if (row := _resource_row(name, field, baseline[name], pr[name])) is not None + ] + if not rows: + continue + field_sections.append( + "\n".join( + [ + f"#### Resource profile: {label}", + "", + f"| Test | Baseline {label} | PR {label} | Change |", + "|------|------------------:|-----------:|-------:|", + *rows, + ], + ), + ) + + if not field_sections: + return None + + return "\n\n".join([f"### {heading} resource profile", *field_sections]) + + def _render_name_list_section(heading: str, names: list[str]) -> str: """Render a markdown bullet list of benchmark names.""" lines = [f"## {heading}", "", *[f"- `{name}`" for name in names]] @@ -99,12 +178,16 @@ def generate_report(baseline: dict[str, dict], pr: dict[str, dict]) -> str: missing_names = sorted(set(baseline) - set(pr)) shared_sections: list[str] = [] + resource_sections: list[str] = [] for heading in (_END_TO_END_LABEL, _SMALL_WINDOW_LABEL): names = [name for name in shared_names if _classify_benchmark(name) == heading] if names: shared_sections.append( _render_comparison_section(heading, names, baseline, pr), ) + resource_section = _render_resource_section(heading, names, baseline, pr) + if resource_section is not None: + resource_sections.append(resource_section) body_parts = [ "", @@ -116,6 +199,11 @@ def generate_report(baseline: dict[str, dict], pr: dict[str, dict]) -> str: else: body_parts.append("No benchmarks were present in both runs.") + if resource_sections: + body_parts.extend( + ["", "## Resource Profile Comparison", "", *resource_sections], + ) + if new_names: body_parts.extend( ["", _render_name_list_section("New benchmarks in PR", new_names)], diff --git a/tests/benchmarks/_profiling.py b/tests/benchmarks/_profiling.py new file mode 100644 index 0000000..278060d --- /dev/null +++ b/tests/benchmarks/_profiling.py @@ -0,0 +1,145 @@ +"""Resource profiling helpers for benchmark tests.""" + +from __future__ import annotations + +import os +import resource +import threading +import time +from collections.abc import Callable +from dataclasses import dataclass +from pathlib import Path +from typing import Self + + +@dataclass +class ResourceProfile: + """Resource profile captured during one benchmark run.""" + + wall_time_s: float + cpu_user_s: float + cpu_system_s: float + cpu_total_s: float + peak_rss_mb: float + rss_before_mb: float + rss_after_mb: float + max_cpu_pct_of_wall: float + sample_count: int + + +class ResourceSampler: + """Sample process RSS and CPU usage while a block runs.""" + + def __init__(self, interval_s: float = 0.01) -> None: + self.interval_s = interval_s + self._stop = threading.Event() + self._thread = threading.Thread(target=self._run, daemon=True) + self._peak_rss_mb = 0.0 + self._max_cpu_pct = 0.0 + self._sample_count = 0 + self._rss_before_mb = 0.0 + self._rss_after_mb = 0.0 + self._wall_start = 0.0 + self._wall_end = 0.0 + self._cpu_start: os.times_result | None = None + self._cpu_end: os.times_result | None = None + + def __enter__(self) -> Self: + self._rss_before_mb = _rss_mb() + self._wall_start = time.perf_counter() + self._cpu_start = os.times() + self._peak_rss_mb = self._rss_before_mb + self._thread.start() + return self + + def __exit__(self, exc_type, exc, tb) -> None: + self._stop.set() + self._thread.join() + self._cpu_end = os.times() + self._wall_end = time.perf_counter() + self._rss_after_mb = _rss_mb() + self._peak_rss_mb = max(self._peak_rss_mb, self._rss_after_mb, _peak_rss_mb()) + + def _run(self) -> None: + last_wall = time.perf_counter() + last_cpu = os.times() + while not self._stop.wait(self.interval_s): + now_wall = time.perf_counter() + now_cpu = os.times() + self._sample_count += 1 + self._peak_rss_mb = max(self._peak_rss_mb, _rss_mb(), _peak_rss_mb()) + wall_delta = now_wall - last_wall + cpu_delta = (now_cpu.user - last_cpu.user) + ( + now_cpu.system - last_cpu.system + ) + if wall_delta > 0: + self._max_cpu_pct = max( + self._max_cpu_pct, + 100.0 * cpu_delta / wall_delta, + ) + last_wall = now_wall + last_cpu = now_cpu + + def profile(self) -> ResourceProfile: + """Return the aggregated resource profile.""" + if self._cpu_start is None or self._cpu_end is None: + raise RuntimeError( + "ResourceSampler.profile() called before sampling finished.", + ) + cpu_user_s = self._cpu_end.user - self._cpu_start.user + cpu_system_s = self._cpu_end.system - self._cpu_start.system + wall_time_s = self._wall_end - self._wall_start + return ResourceProfile( + wall_time_s=wall_time_s, + cpu_user_s=cpu_user_s, + cpu_system_s=cpu_system_s, + cpu_total_s=cpu_user_s + cpu_system_s, + peak_rss_mb=self._peak_rss_mb, + rss_before_mb=self._rss_before_mb, + rss_after_mb=self._rss_after_mb, + max_cpu_pct_of_wall=self._max_cpu_pct, + sample_count=self._sample_count, + ) + + +def _rss_mb() -> float: + """Return current RSS of this process in MB on Linux.""" + with Path("/proc/self/status").open() as file_handle: + for line in file_handle: + if line.startswith("VmRSS:"): + return int(line.split()[1]) / 1024 + return float("nan") + + +def _peak_rss_mb() -> float: + """Return peak RSS of this process in MB on Linux.""" + usage = resource.getrusage(resource.RUSAGE_SELF) + return usage.ru_maxrss / 1024 + + +def add_resource_profile(benchmark, run_once: Callable[[], object]) -> None: + """Attach a one-shot resource profile to ``benchmark.extra_info``.""" + with ResourceSampler() as resource_sampler: + run_once() + profile = resource_sampler.profile() + benchmark.extra_info.update( + { + "profile_wall_s": round(profile.wall_time_s, 4), + "profile_cpu_user_s": round(profile.cpu_user_s, 4), + "profile_cpu_system_s": round(profile.cpu_system_s, 4), + "profile_cpu_total_s": round(profile.cpu_total_s, 4), + "profile_cpu_per_wall": round( + ( + profile.cpu_total_s / profile.wall_time_s + if profile.wall_time_s + else 0.0 + ), + 4, + ), + "profile_peak_rss_mb": round(profile.peak_rss_mb, 1), + "profile_rss_before_mb": round(profile.rss_before_mb, 1), + "profile_rss_after_mb": round(profile.rss_after_mb, 1), + "profile_peak_cpu_pct": round(profile.max_cpu_pct_of_wall, 1), + "profile_sample_count": profile.sample_count, + }, + ) diff --git a/tests/benchmarks/bench_pipeline.py b/tests/benchmarks/bench_pipeline.py index f2d2fda..39842bd 100644 --- a/tests/benchmarks/bench_pipeline.py +++ b/tests/benchmarks/bench_pipeline.py @@ -20,6 +20,7 @@ from lazycogs import FirstMethod, MedianMethod, MosaicMethodBase, set_reproject_workers from lazycogs._warp import ReprojectRequest, ResamplingMethod, reproject_tile +from ._profiling import add_resource_profile from .conftest import ( BENCHMARK_BBOX, BENCHMARK_CRS, @@ -71,12 +72,17 @@ def _benchmark_request( ) +def _profile_then_benchmark(benchmark, run): + """Attach one-shot resource info, then run the pytest benchmark.""" + add_resource_profile(benchmark, run) + benchmark(run) + + @pytest.mark.benchmark def test_small_window_nearest_reprojection(benchmark) -> None: """Benchmark the representative nearest-neighbor rust-warp path.""" request = _benchmark_request(dst_resolution=20.0) - - benchmark(reproject_tile, request) + _profile_then_benchmark(benchmark, lambda: reproject_tile(request)) @pytest.mark.benchmark @@ -117,20 +123,23 @@ def test_small_window_reprojection_modes( ) -> None: """Show the cost gap between no-op, nearest, and interpolating modes.""" benchmark.extra_info["mode"] = label - benchmark( - reproject_tile, - ReprojectRequest( - data=reproject_request.data, - src_transform=reproject_request.src_transform, - src_crs=reproject_request.src_crs, - dst_transform=reproject_request.dst_transform, - dst_crs=reproject_request.dst_crs, - dst_width=reproject_request.dst_width, - dst_height=reproject_request.dst_height, - nodata=reproject_request.nodata, - resampling=resampling, - ), - ) + + def run() -> np.ndarray: + return reproject_tile( + ReprojectRequest( + data=reproject_request.data, + src_transform=reproject_request.src_transform, + src_crs=reproject_request.src_crs, + dst_transform=reproject_request.dst_transform, + dst_crs=reproject_request.dst_crs, + dst_width=reproject_request.dst_width, + dst_height=reproject_request.dst_height, + nodata=reproject_request.nodata, + resampling=resampling, + ), + ) + + _profile_then_benchmark(benchmark, run) @pytest.mark.benchmark @@ -144,14 +153,17 @@ def test_open_overhead( Measures parquet queries, band discovery, time-step building, and grid computation. """ - benchmark( - lazycogs.open, - benchmark_parquet, - bbox=BENCHMARK_BBOX, - crs=BENCHMARK_CRS, - resolution=60.0, - **benchmark_open_kwargs, - ) + + def run() -> object: + return lazycogs.open( + benchmark_parquet, + bbox=BENCHMARK_BBOX, + crs=BENCHMARK_CRS, + resolution=60.0, + **benchmark_open_kwargs, + ) + + _profile_then_benchmark(benchmark, run) @pytest.mark.benchmark @@ -172,7 +184,7 @@ def run() -> object: ) return da.compute() - benchmark(run) + _profile_then_benchmark(benchmark, run) @pytest.mark.benchmark @@ -197,7 +209,7 @@ def run() -> object: ) return da.compute() - benchmark(run) + _profile_then_benchmark(benchmark, run) @pytest.mark.benchmark @@ -230,7 +242,7 @@ def run() -> object: return da.compute() try: - benchmark(run) + _profile_then_benchmark(benchmark, run) finally: # Reset to default so other benchmarks are not affected. set_reproject_workers(min(__import__("os").cpu_count() or 4, 4)) @@ -260,7 +272,7 @@ def run() -> object: ) return da.compute() - benchmark(run) + _profile_then_benchmark(benchmark, run) @pytest.mark.benchmark @@ -295,7 +307,7 @@ def run() -> object: ) return da.compute() - benchmark(run) + _profile_then_benchmark(benchmark, run) @pytest.mark.benchmark @@ -331,4 +343,4 @@ def run() -> object: ) return da.compute() - benchmark(run) + _profile_then_benchmark(benchmark, run) diff --git a/tests/integration_monthly_test.py b/tests/integration_monthly_test.py new file mode 100644 index 0000000..2703a0c --- /dev/null +++ b/tests/integration_monthly_test.py @@ -0,0 +1,391 @@ +import argparse +import asyncio +import contextlib +import hashlib +import json +import logging +import os +import resource +import threading +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Self + +import rustac +from async_geotiff import GeoTIFF, Overview +from pyproj import Transformer + +import lazycogs +from lazycogs import _backend, _chunk_reader + +logging.basicConfig(level="WARN") +logging.getLogger("lazycogs").setLevel("DEBUG") +logger = logging.getLogger(__name__) + + +@dataclass +class PhaseStats: + """Aggregate timing and count information for one phase.""" + + calls: int = 0 + total_s: float = 0.0 + + def add(self, elapsed_s: float) -> None: + """Record one timed call.""" + self.calls += 1 + self.total_s += elapsed_s + + +@dataclass +class LoadBreakdown: + """Timing breakdown for the monthly load path.""" + + duckdb_search: PhaseStats + geotiff_open: PhaseStats + geotiff_read: PhaseStats + overview_read: PhaseStats + reproject: PhaseStats + + def log_summary(self, total_load_s: float) -> None: + """Log a compact summary of where load time was spent.""" + parts = [ + ("duckdb_search", self.duckdb_search), + ("geotiff_open", self.geotiff_open), + ("geotiff_read", self.geotiff_read), + ("overview_read", self.overview_read), + ("reproject", self.reproject), + ] + + logger.warning("[monthly load] wall_time=%.2fs", total_load_s) + logger.warning( + "[monthly load] phase timings below are cumulative across " + "concurrent calls, so they can exceed wall time.", + ) + for name, stats in parts: + avg_ms = 1000.0 * stats.total_s / stats.calls if stats.calls else 0.0 + logger.warning( + "[monthly load] %-14s cumulative=%.2fs calls=%d avg=%.2fms", + name, + stats.total_s, + stats.calls, + avg_ms, + ) + + +@dataclass +class ResourceProfile: + """Resource profile captured during the monthly load.""" + + wall_time_s: float + cpu_user_s: float + cpu_system_s: float + cpu_total_s: float + peak_rss_mb: float + rss_before_mb: float + rss_after_mb: float + max_cpu_pct_of_wall: float + sample_count: int + + def log_summary(self) -> None: + """Log CPU and memory usage observed during the load.""" + logger.warning( + "[monthly load resources] wall=%.2fs cpu_user=%.2fs cpu_system=%.2fs " + "cpu_total=%.2fs cpu/wall=%.2fx peak_rss=%.0fMB rss_before=%.0fMB " + "rss_after=%.0fMB peak_cpu_pct_of_wall=%.0f%% samples=%d", + self.wall_time_s, + self.cpu_user_s, + self.cpu_system_s, + self.cpu_total_s, + self.cpu_total_s / self.wall_time_s if self.wall_time_s else 0.0, + self.peak_rss_mb, + self.rss_before_mb, + self.rss_after_mb, + self.max_cpu_pct_of_wall, + self.sample_count, + ) + + +class ResourceSampler: + """Sample process RSS and CPU usage while a block runs.""" + + def __init__(self, interval_s: float = 0.2) -> None: + self.interval_s = interval_s + self._stop = threading.Event() + self._thread = threading.Thread(target=self._run, daemon=True) + self._peak_rss_mb = 0.0 + self._max_cpu_pct = 0.0 + self._sample_count = 0 + self._rss_before_mb = 0.0 + self._rss_after_mb = 0.0 + self._wall_start = 0.0 + self._wall_end = 0.0 + self._cpu_start: os.times_result | None = None + self._cpu_end: os.times_result | None = None + + def __enter__(self) -> Self: + self._rss_before_mb = _rss_mb() + self._wall_start = time.perf_counter() + self._cpu_start = os.times() + self._peak_rss_mb = self._rss_before_mb + self._thread.start() + return self + + def __exit__(self, exc_type, exc, tb) -> None: + self._stop.set() + self._thread.join() + self._cpu_end = os.times() + self._wall_end = time.perf_counter() + self._rss_after_mb = _rss_mb() + self._peak_rss_mb = max(self._peak_rss_mb, self._rss_after_mb, _peak_rss_mb()) + + def _run(self) -> None: + last_wall = time.perf_counter() + last_cpu = os.times() + while not self._stop.wait(self.interval_s): + now_wall = time.perf_counter() + now_cpu = os.times() + self._sample_count += 1 + self._peak_rss_mb = max(self._peak_rss_mb, _rss_mb(), _peak_rss_mb()) + wall_delta = now_wall - last_wall + cpu_delta = (now_cpu.user - last_cpu.user) + ( + now_cpu.system - last_cpu.system + ) + if wall_delta > 0: + self._max_cpu_pct = max( + self._max_cpu_pct, + 100.0 * cpu_delta / wall_delta, + ) + last_wall = now_wall + last_cpu = now_cpu + + def profile(self) -> ResourceProfile: + """Return the aggregated resource profile.""" + if self._cpu_start is None or self._cpu_end is None: + raise RuntimeError( + "ResourceSampler.profile() called before sampling finished.", + ) + cpu_user_s = self._cpu_end.user - self._cpu_start.user + cpu_system_s = self._cpu_end.system - self._cpu_start.system + wall_time_s = self._wall_end - self._wall_start + return ResourceProfile( + wall_time_s=wall_time_s, + cpu_user_s=cpu_user_s, + cpu_system_s=cpu_system_s, + cpu_total_s=cpu_user_s + cpu_system_s, + peak_rss_mb=self._peak_rss_mb, + rss_before_mb=self._rss_before_mb, + rss_after_mb=self._rss_after_mb, + max_cpu_pct_of_wall=self._max_cpu_pct, + sample_count=self._sample_count, + ) + + +def _parquet_path( + href: str, + collections: list[str], + datetime: str, + bbox: list[float], + limit: int, +) -> Path: + """Return a cache path for a STAC search derived from its parameters.""" + params = { + "href": href, + "collections": sorted(collections), + "datetime": datetime, + "bbox": [round(v, 6) for v in bbox], + "limit": limit, + } + digest = hashlib.sha256(json.dumps(params, sort_keys=True).encode()).hexdigest()[ + :12 + ] + return Path(f"/tmp/stac_{digest}.parquet") + + +def _rss_mb() -> float: + """Return current RSS of this process in MB on Linux.""" + with Path("/proc/self/status").open() as file_handle: + for line in file_handle: + if line.startswith("VmRSS:"): + return int(line.split()[1]) / 1024 + return float("nan") + + +def _peak_rss_mb() -> float: + """Return peak RSS of this process in MB on Linux.""" + usage = resource.getrusage(resource.RUSAGE_SELF) + return usage.ru_maxrss / 1024 + + +@contextlib.contextmanager +def measure(label: str): + """Log wall time and RSS change for a block.""" + rss_before = _rss_mb() + t0 = time.perf_counter() + yield + elapsed = time.perf_counter() - t0 + rss_after = _rss_mb() + logger.warning( + "[%s] time=%.2fs rss_before=%.0fMB rss_after=%.0fMB delta=%+.0fMB", + label, + elapsed, + rss_before, + rss_after, + rss_after - rss_before, + ) + + +@contextlib.contextmanager +def instrument_monthly_load() -> LoadBreakdown: + """Patch internal helpers to measure monthly load sub-phases.""" + search_stats = PhaseStats() + geotiff_open_stats = PhaseStats() + geotiff_read_stats = PhaseStats() + overview_read_stats = PhaseStats() + reproject_stats = PhaseStats() + + original_search = _backend._search_items_sync + reproject_helper_name = ( + "_reproject_bands" + if hasattr(_chunk_reader, "_reproject_bands") + else "_apply_bands_with_warp_cache" + ) + original_reproject = getattr(_chunk_reader, reproject_helper_name) + original_geotiff_open = GeoTIFF.open + original_geotiff_read = GeoTIFF.read + original_overview_read = Overview.read + + def timed_search(*args, **kwargs): + t0 = time.perf_counter() + result = original_search(*args, **kwargs) + search_stats.add(time.perf_counter() - t0) + return result + + def timed_reproject(*args, **kwargs): + t0 = time.perf_counter() + result = original_reproject(*args, **kwargs) + reproject_stats.add(time.perf_counter() - t0) + return result + + async def timed_geotiff_open(cls, *args, **kwargs): + t0 = time.perf_counter() + result = await original_geotiff_open(*args, **kwargs) + geotiff_open_stats.add(time.perf_counter() - t0) + return result + + async def timed_geotiff_read(self, *args, **kwargs): + t0 = time.perf_counter() + result = await original_geotiff_read(self, *args, **kwargs) + geotiff_read_stats.add(time.perf_counter() - t0) + return result + + async def timed_overview_read(self, *args, **kwargs): + t0 = time.perf_counter() + result = await original_overview_read(self, *args, **kwargs) + overview_read_stats.add(time.perf_counter() - t0) + return result + + _backend._search_items_sync = timed_search + setattr(_chunk_reader, reproject_helper_name, timed_reproject) + GeoTIFF.open = classmethod(timed_geotiff_open) + GeoTIFF.read = timed_geotiff_read + Overview.read = timed_overview_read + + try: + yield LoadBreakdown( + duckdb_search=search_stats, + geotiff_open=geotiff_open_stats, + geotiff_read=geotiff_read_stats, + overview_read=overview_read_stats, + reproject=reproject_stats, + ) + finally: + _backend._search_items_sync = original_search + setattr(_chunk_reader, reproject_helper_name, original_reproject) + GeoTIFF.open = original_geotiff_open + GeoTIFF.read = original_geotiff_read + Overview.read = original_overview_read + + +def _parse_args() -> argparse.Namespace: + """Parse command-line options for the integration script.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--skip-explain", + action="store_true", + help="Skip the explain() dry run before loading data.", + ) + return parser.parse_args() + + +async def run(*, skip_explain: bool = False) -> None: + """Run the monthly southwest scenario outside Jupyter.""" + dst_crs = "epsg:3310" + dst_bbox = (-444_000, -609_000, 681_000, 500_000) + + stac_href = "https://earth-search.aws.element84.com/v1" + collections = ["sentinel-2-c1-l2a"] + datetime = "2025-03-01/2025-06-30" + limit = 100 + + transformer = Transformer.from_crs(dst_crs, "epsg:4326", always_xy=True) + bbox_4326 = list(transformer.transform_bounds(*dst_bbox)) + + items_parquet = _parquet_path( + href=stac_href, + collections=collections, + datetime=datetime, + bbox=bbox_4326, + limit=limit, + ) + logger.warning("cache: %s", items_parquet) + + if not items_parquet.exists(): + with measure("search_to parquet cache"): + await rustac.search_to( + str(items_parquet), + href=stac_href, + collections=collections, + datetime=datetime, + bbox=bbox_4326, + limit=limit, + ) + + store = lazycogs.store_for(str(items_parquet), skip_signature=True) + + with measure("monthly open"): + ca_monthly = lazycogs.open( + str(items_parquet), + crs=dst_crs, + bbox=dst_bbox, + resolution=300, + time_period="P1M", + bands=["red", "green", "blue"], + dtype="int16", + filter="eo:cloud_cover < 50", + sortby="eo:cloud_cover", + store=store, + ) + logger.warning("monthly array: %s", ca_monthly) + + ca_may = ca_monthly.chunk(x=1024, y=1024).sel(time="2025-05-01") + logger.warning("monthly may chunked array: %s", ca_may) + + if not skip_explain: + with measure("monthly explain"): + plan = ca_may.lazycogs.explain(fetch_headers=True) + logger.warning("\n%s", plan.summary()) + + with instrument_monthly_load() as breakdown, ResourceSampler() as resource_sampler: + t0 = time.perf_counter() + loaded = await ca_may.load_async() + total_load_s = time.perf_counter() - t0 + + logger.warning("loaded shape: %s", loaded.shape) + breakdown.log_summary(total_load_s) + resource_sampler.profile().log_summary() + + +if __name__ == "__main__": + args = _parse_args() + asyncio.run(run(skip_explain=args.skip_explain)) diff --git a/tests/test_format_benchmark_comparison.py b/tests/test_format_benchmark_comparison.py index ccea2b6..57db179 100644 --- a/tests/test_format_benchmark_comparison.py +++ b/tests/test_format_benchmark_comparison.py @@ -17,17 +17,54 @@ def test_generate_report_includes_sections_for_shared_new_and_missing() -> None: - """The report should surface changed, added, and removed benchmarks.""" + """The report should surface changed, added, removed, and resource-profile data.""" baseline = { - "test_open_overhead": {"mean": 0.010}, - "test_small_window_reprojection_modes[nearest]": {"mean": 0.002}, - "test_removed_benchmark": {"mean": 0.005}, + "test_open_overhead": { + "stats": {"mean": 0.010}, + "extra_info": { + "profile_peak_rss_mb": 100.0, + "profile_cpu_total_s": 1.5, + "profile_cpu_per_wall": 1.2, + }, + }, + "test_small_window_reprojection_modes[nearest]": { + "stats": {"mean": 0.002}, + "extra_info": { + "profile_peak_rss_mb": 50.0, + "profile_cpu_total_s": 0.2, + "profile_cpu_per_wall": 0.8, + }, + }, + "test_removed_benchmark": { + "stats": {"mean": 0.005}, + "extra_info": {}, + }, } pr = { - "test_open_overhead": {"mean": 0.012}, - "test_small_window_reprojection_modes[nearest]": {"mean": 0.001}, - "test_small_window_reprojection_modes[cubic]": {"mean": 0.003}, - "test_new_end_to_end_benchmark": {"mean": 0.020}, + "test_open_overhead": { + "stats": {"mean": 0.012}, + "extra_info": { + "profile_peak_rss_mb": 120.0, + "profile_cpu_total_s": 1.1, + "profile_cpu_per_wall": 0.9, + }, + }, + "test_small_window_reprojection_modes[nearest]": { + "stats": {"mean": 0.001}, + "extra_info": { + "profile_peak_rss_mb": 40.0, + "profile_cpu_total_s": 0.1, + "profile_cpu_per_wall": 0.4, + }, + }, + "test_small_window_reprojection_modes[cubic]": { + "stats": {"mean": 0.003}, + "extra_info": {}, + }, + "test_new_end_to_end_benchmark": { + "stats": {"mean": 0.020}, + "extra_info": {}, + }, } report = format_benchmark_comparison.generate_report(baseline, pr) @@ -40,6 +77,16 @@ def test_generate_report_includes_sections_for_shared_new_and_missing() -> None: "| `test_small_window_reprojection_modes[nearest]` | 2.0 | 1.0 | -50.0% |" in report ) + assert "## Resource Profile Comparison" in report + assert "#### Resource profile: Peak RSS (MB)" in report + assert "| `test_open_overhead` | 100.0 | 120.0 | +20.0% |" in report + assert "#### Resource profile: CPU total (s)" in report + assert "| `test_open_overhead` | 1.5 | 1.1 | -26.7% |" in report + assert "#### Resource profile: CPU/wall" in report + assert ( + "| `test_small_window_reprojection_modes[nearest]` | 0.8 | 0.4 | -50.0% |" + in report + ) assert "## New benchmarks in PR" in report assert "`test_small_window_reprojection_modes[cubic]`" in report assert "`test_new_end_to_end_benchmark`" in report @@ -50,10 +97,37 @@ def test_generate_report_includes_sections_for_shared_new_and_missing() -> None: def test_generate_report_handles_empty_shared_benchmarks() -> None: """The report should still render when only added or removed tests exist.""" report = format_benchmark_comparison.generate_report( - baseline={"test_removed": {"mean": 0.005}}, - pr={"test_added": {"mean": 0.007}}, + baseline={"test_removed": {"stats": {"mean": 0.005}, "extra_info": {}}}, + pr={"test_added": {"stats": {"mean": 0.007}, "extra_info": {}}}, ) assert "No benchmarks were present in both runs." in report assert "`test_added`" in report assert "`test_removed`" in report + + +def test_load_benchmarks_keeps_stats_and_extra_info(tmp_path: Path) -> None: + """The loader should preserve profiling metadata from pytest-benchmark JSON.""" + benchmark_json = tmp_path / "bench.json" + benchmark_json.write_text( + """ +{ + "benchmarks": [ + { + "name": "test_full_compute", + "stats": {"mean": 0.123}, + "extra_info": {"profile_peak_rss_mb": 256.0} + } + ] +} +""".strip(), + ) + + loaded = format_benchmark_comparison.load_benchmarks(benchmark_json) + + assert loaded == { + "test_full_compute": { + "stats": {"mean": 0.123}, + "extra_info": {"profile_peak_rss_mb": 256.0}, + }, + }