diff --git a/.gitignore b/.gitignore index 83c6ce6..c0e496a 100644 --- a/.gitignore +++ b/.gitignore @@ -130,4 +130,4 @@ cython_debug/ .DS_Store # pixi environments -.pixi +.pixi \ No newline at end of file diff --git a/README.md b/README.md index a270156..4948e83 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ _An Xarray extension for Google Earth Engine._ +Xee bridges the gap between Google Earth Engine's massive data catalog and the scientific Python ecosystem. It provides a custom Xarray backend that allows you to open any `ee.ImageCollection` as if it were a local `xarray.Dataset`. Data is loaded lazily and in parallel, enabling you to work with petabyte-scale archives of satellite and climate data using the power and flexibility of Xarray and its integrations with libraries like Dask. + [![image](https://img.shields.io/pypi/v/xee.svg)](https://pypi.python.org/pypi/xee) [![image](https://static.pepy.tech/badge/xee)](https://pepy.tech/project/xee) [![Conda @@ -32,85 +34,206 @@ Then, authenticate Earth Engine: earthengine authenticate --quiet ``` -Now, in your Python environment, make the following imports: +Now, in your Python environment, make the following imports and initialize the Earth Engine client with your project ID. Using the high-volume API endpoint is recommended. ```python import ee -import xarray +import xarray as xr +from xee import helpers +import shapely + +ee.Initialize( + project='PROJECT-ID', # Replace with your project ID + opt_url='https://earthengine-highvolume.googleapis.com' +) ``` -Next, specify your EE-registered cloud project ID and initialize the EE client -with the high volume API: +### Specifying the Output Grid + +To open a dataset, you must specify the desired output pixel grid. The `xee.helpers` module simplifies this process by providing several convenient workflows, summarized below. 
+ +| Goal | Method | When to Use | +| :--- | :--- | :--- | +| **Match Source Grid** | Use `helpers.extract_grid_params()` to get the parameters from an EE object. | When you want the data in its original, default projection and scale. | +| **Fit Area to a Shape** | Use `helpers.fit_geometry()` with the `geometry` and `grid_shape` arguments. | When you need a consistent output array size (e.g., for ML models) and the exact pixel size is less important. | +| **Fit Area to a Scale** | Use `helpers.fit_geometry()` with the `geometry` and `grid_scale` arguments. | When the specific resolution (e.g., 30 meters, 0.01 degrees) is critical for your analysis. | +| **Manual Override** | Pass `crs`, `crs_transform`, and `shape_2d` directly to `xr.open_dataset`. | For advanced cases where you already have an exact grid definition. | + +> **Important Note on Units:** All grid parameter values must be in the units of the specified Coordinate Reference System (`crs`). +> * For a geographic CRS like `'EPSG:4326'`, the units are in **degrees**. +> * For a projected CRS like `'EPSG:32610'` (UTM), the units are in **meters**. +> This applies to the translation values in `crs_transform` and the pixel sizes in `grid_scale`. + +### Usage Examples + +Here are common workflows for opening datasets with `xee`, corresponding to the methods in the table above. + +#### Match Source Grid + +This is the simplest case, using `helpers.extract_grid_params` to match the dataset's default grid. ```python -ee.Initialize( - project='my-project-id' - opt_url='https://earthengine-highvolume.googleapis.com') +ic = ee.ImageCollection('ECMWF/ERA5_LAND/MONTHLY_AGGR') +grid_params = helpers.extract_grid_params(ic) +ds = xr.open_dataset(ic, engine='ee', **grid_params) ``` -Open any Earth Engine ImageCollection by specifying the Xarray engine as `'ee'`: +#### Fit Area to a Shape + +Define a grid over an area of interest by specifying the number of pixels. 
`helpers.fit_geometry` will calculate the correct `crs_transform`. ```python -ds = xarray.open_dataset('ee://ECMWF/ERA5_LAND/HOURLY', engine='ee') +aoi = shapely.geometry.box(113.33, -43.63, 153.56, -10.66) # Australia +grid_params = helpers.fit_geometry( + geometry=aoi, + grid_crs='EPSG:4326', + grid_shape=(256, 256) +) + +ds = xr.open_dataset('ee://ECMWF/ERA5_LAND/MONTHLY_AGGR', engine='ee', **grid_params) ``` -Open all bands in a specific projection (not the Xee default): +#### Fit Area to a Scale (Resolution) + +> **A Note on `grid_scale` and Y-Scale Orientation** +> When using `fit_geometry` with `grid_scale`, you are defining both the pixel size and the grid's orientation via the sign of the y-scale. +> * A **negative `y_scale`** (e.g., `(10000, -10000)`) is the standard for "north-up" satellite and aerial imagery, creating a grid with a **top-left** origin. +> * A **positive `y_scale`** (e.g., `(10000, 10000)`) is used by some datasets and creates a grid with a **bottom-left** origin. +> You may need to inspect your source dataset's projection information to determine the correct sign to use. If you use `grid_shape`, a standard negative y-scale is assumed. + +The following example defines a grid over an area by specifying the pixel size in meters. `fit_geometry` will reproject the geometry and calculate the correct `shape_2d`. 
```python -ds = xarray.open_dataset('ee://ECMWF/ERA5_LAND/HOURLY', engine='ee', - crs='EPSG:4326', scale=0.25) +aoi = shapely.geometry.box(113.33, -43.63, 153.56, -10.66) # Australia +grid_params = helpers.fit_geometry( + geometry=aoi, + geometry_crs='EPSG:4326', # CRS of the input geometry + grid_crs='EPSG:32662', # Target CRS in meters (Plate Carrée) + grid_scale=(10000, -10000) # Define a 10km pixel size +) + +ds = xr.open_dataset('ee://ECMWF/ERA5_LAND/MONTHLY_AGGR', engine='ee', **grid_params) ``` -Open an ImageCollection (maybe, with EE-side filtering or processing): +#### Open a Custom Region at Source Resolution + +This workflow is ideal for analyzing a specific area while maintaining the dataset's original resolution. ```python -ic = ee.ImageCollection('ECMWF/ERA5_LAND/HOURLY').filterDate( - '1992-10-05', '1993-03-31') -ds = xarray.open_dataset(ic, engine='ee', crs='EPSG:4326', scale=0.25) +# 1. Get the original grid parameters from the target ImageCollection +ic = ee.ImageCollection('ECMWF/ERA5_LAND/MONTHLY_AGGR') +source_params = helpers.extract_grid_params(ic) + +# 2. Extract the source CRS and scale +source_crs = source_params['crs'] +source_transform = source_params['crs_transform'] +source_scale = (source_transform[0], source_transform[4]) # (x_scale, y_scale) + +# 3. Use the source parameters to fit the grid to a specific geometry +aoi = shapely.geometry.box(113.33, -43.63, 153.56, -10.66) # Australia +final_grid_params = helpers.fit_geometry( + geometry=aoi, + geometry_crs='EPSG:4326', + grid_crs=source_crs, + grid_scale=source_scale +) + +# 4. Open the dataset with the final, combined parameters +ds = xr.open_dataset(ic, engine='ee', **final_grid_params) ``` -Open an ImageCollection with a specific EE projection or geometry: +#### Manual Override + +For use cases where you know the exact grid parameters, you can provide them directly. 
```python -ic = ee.ImageCollection('ECMWF/ERA5_LAND/HOURLY').filterDate( - '1992-10-05', '1993-03-31') -leg1 = ee.Geometry.Rectangle(113.33, -43.63, 153.56, -10.66) -ds = xarray.open_dataset( - ic, +# Manually define a 512x512 pixel grid with 0.1-degree pixels in EPSG:4326 +manual_crs = 'EPSG:4326' +manual_transform = (0.1, 0, -180.05, 0, -0.1, 90.05) # Values are in degrees +manual_shape = (512, 512) + +ds = xr.open_dataset( + 'ee://ECMWF/ERA5_LAND/MONTHLY_AGGR', engine='ee', - projection=ic.first().select(0).projection(), - geometry=leg1 + crs=manual_crs, + crs_transform=manual_transform, + shape_2d=manual_shape, ) ``` -Open multiple ImageCollections into one `xarray.Dataset`, all with the same -projection: +#### Open a Pre-Processed ImageCollection + +A key feature of Xee is its ability to open a computed `ee.ImageCollection`. This allows you to leverage Earth Engine's powerful server-side processing for tasks like filtering, band selection, and calculations before loading the data into Xarray. 
```python -ds = xarray.open_mfdataset( - ['ee://ECMWF/ERA5_LAND/HOURLY', 'ee://NASA/GDDP-CMIP6'], - engine='ee', crs='EPSG:4326', scale=0.25) +# Define an AOI as a shapely object for the helper function +sf_aoi_shapely = shapely.geometry.Point(-122.4, 37.7).buffer(0.2) +# Create an ee.Geometry from the shapely object for server-side filtering +coords = list(sf_aoi_shapely.exterior.coords) +sf_aoi_ee = ee.Geometry.Polygon(coords) + +# Define a function to calculate NDVI and add it as a band +def add_ndvi(image): + # Landsat 9 SR bands: NIR = B5, Red = B4 + ndvi = image.normalizedDifference(['SR_B5', 'SR_B4']).rename('NDVI') + return image.addBands(ndvi) + +# Build the pre-processed collection +processed_collection = (ee.ImageCollection('LANDSAT/LC09/C02/T1_L2') + .filterDate('2024-06-01', '2024-09-01') + .filterBounds(sf_aoi_ee) + .map(add_ndvi) + .select(['NDVI'])) + +# Define the output grid using a helper +grid_params = helpers.fit_geometry( + geometry=sf_aoi_shapely, + grid_crs='EPSG:32610', # Target CRS in meters (UTM Zone 10N) + grid_scale=(30, -30) # Use Landsat's 30m resolution +) + +# Open the fully processed collection +ds = xr.open_dataset(processed_collection, engine='ee', **grid_params) ``` -Open a single Image by passing it to an ImageCollection: +#### Open a Single Image + +The `helpers` functions work the same way for a single `ee.Image`. ```python -i = ee.ImageCollection(ee.Image('LANDSAT/LC08/C02/T1_TOA/LC08_044034_20140318')) -ds = xarray.open_dataset(i, engine='ee') +img = ee.Image('ECMWF/ERA5_LAND/MONTHLY_AGGR/202501') +grid_params = helpers.extract_grid_params(img) +ds = xr.open_dataset(img, engine='ee', **grid_params) +``` + +#### Visualize a Single Time Slice + +Once you have your `xarray.Dataset`, you can visualize a single time slice of a variable to verify the results. This requires the `matplotlib` library, which is an optional dependency. 
+ +If you don't have it installed, you can add it with pip: + +```shell +pip install matplotlib ``` -Open any Earth Engine ImageCollection to match an existing transform: +Xarray's plotting functions expect dimensions in `(y, x)` order for 2D plots. Since the data is in `(x, y)` order, we use `.transpose()` to swap the axes for correct visualization. ```python -raster = rioxarray.open_rasterio(...) # assume crs + transform is set -ds = xr.open_dataset( - 'ee://ECMWF/ERA5_LAND/HOURLY', - engine='ee', - geometry=tuple(raster.rio.bounds()), # must be in EPSG:4326 - projection=ee.Projection( - crs=str(raster.rio.crs), transform=raster.rio.transform()[:6] - ), + +# First, open a dataset using one of the methods above +aoi = shapely.geometry.box(113.33, -43.63, 153.56, -10.66) # Australia +grid_params = helpers.fit_geometry( + geometry=aoi, + grid_crs='EPSG:4326', + grid_shape=(256, 256) ) +ds = xr.open_dataset('ECMWF/ERA5_LAND/MONTHLY_AGGR', engine='ee', **grid_params) + +# Select the 2m air temperature for the first time step +temp_slice = ds['temperature_2m'].isel(time=0) + +# Transpose from (x, y) to (y, x) for correct plotting orientation and plot +temp_slice.transpose('y', 'x').plot() ``` See [examples](https://github.com/google/Xee/tree/main/examples) or diff --git a/pyproject.toml b/pyproject.toml index 48cd94a..6c956fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "xee" dynamic = ["version"] description = "A Google Earth Engine extension for Xarray." 
readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.8,<3.13" license = {text = "Apache-2.0"} authors = [ {name = "Google LLC", email = "noreply@google.com"}, @@ -28,6 +28,7 @@ dependencies = [ "earthengine-api>=0.1.374", "pyproj", "affine", + "shapely", ] [project.entry-points."xarray.backends"] @@ -65,5 +66,8 @@ preview = true pyink-indentation = 2 pyink-use-majority-quotes = true +[tool.setuptools] +packages = ["xee"] + [tool.setuptools_scm] fallback_version = "9999" diff --git a/xee/ext.py b/xee/ext.py index c6890ad..1187b1a 100644 --- a/xee/ext.py +++ b/xee/ext.py @@ -34,8 +34,6 @@ import affine import numpy as np import pandas as pd -import pyproj -from pyproj.crs import CRS import xarray from xarray import backends from xarray.backends import common @@ -61,6 +59,12 @@ # data as a single chunk. Chunks = Union[int, Dict[Any, Any], Literal['auto'], None] +# Types for type hints +CrsType = str +TransformType = Union[ + Tuple[float, float, float, float, float, float], affine.Affine +] +ShapeType = Tuple[int, int] _BUILTIN_DTYPES = { 'int': np.int32, @@ -78,6 +82,14 @@ # value was chosen by trial and error. _TO_LIST_WARNING_LIMIT = 10000 +EE_AFFINE_TRANSFORM_FIELDS = [ + 'scaleX', + 'shearX', + 'translateX', + 'shearY', + 'scaleY', + 'translateY' +] # Used in ext_test.py. 
def _check_request_limit(chunks: Dict[str, int], dtype_size: int, limit: int): @@ -122,13 +134,6 @@ class EarthEngineStore(common.AbstractDataStore): 'm': 10_000, } - DIMENSION_NAMES: Dict[str, Tuple[str, str]] = { - 'degree': ('lon', 'lat'), - 'metre': ('X', 'Y'), - 'meter': ('X', 'Y'), - 'm': ('X', 'Y'), - } - DEFAULT_MASK_VALUE = np.iinfo(np.int32).max ATTRS_VALID_TYPES = ( @@ -146,13 +151,12 @@ class EarthEngineStore(common.AbstractDataStore): def open( cls, image_collection: ee.ImageCollection, + crs: CrsType, + crs_transform: TransformType, + shape_2d: ShapeType, mode: Literal['r'] = 'r', chunk_store: Chunks = None, n_images: int = -1, - crs: Optional[str] = None, - scale: Optional[float] = None, - projection: Optional[ee.Projection] = None, - geometry: ee.Geometry | Tuple[float, float, float, float] | None = None, primary_dim_name: Optional[str] = None, primary_dim_property: Optional[str] = None, mask_value: Optional[float] = None, @@ -170,12 +174,11 @@ def open( return cls( image_collection, + crs=crs, + crs_transform=crs_transform, + shape_2d=shape_2d, chunks=chunk_store, n_images=n_images, - crs=crs, - scale=scale, - projection=projection, - geometry=geometry, primary_dim_name=primary_dim_name, primary_dim_property=primary_dim_property, mask_value=mask_value, @@ -190,12 +193,11 @@ def open( def __init__( self, image_collection: ee.ImageCollection, + crs: CrsType, + crs_transform: TransformType, + shape_2d: ShapeType, chunks: Chunks = None, n_images: int = -1, - crs: Optional[str] = None, - scale: Union[float, int, None] = None, - projection: Optional[ee.Projection] = None, - geometry: ee.Geometry | Tuple[float, float, float, float] | None = None, primary_dim_name: Optional[str] = None, primary_dim_property: Optional[str] = None, mask_value: Optional[float] = None, @@ -206,6 +208,21 @@ def __init__( getitem_kwargs: Optional[Dict[str, int]] = None, fast_time_slicing: bool = False, ): + # Ensure crs_transform is a tuple and create the affine.Affine object. 
+ if isinstance(crs_transform, affine.Affine): + self.affine_transform = crs_transform + crs_transform = ( + crs_transform.a, + crs_transform.b, + crs_transform.c, + crs_transform.d, + crs_transform.e, + crs_transform.f, + ) + elif not isinstance(crs_transform, tuple): + raise TypeError('crs_transform must be an affine.Affine object or a tuple.') + else: + self.affine_transform = affine.Affine(*crs_transform) self.ee_init_kwargs = ee_init_kwargs self.ee_init_if_necessary = ee_init_if_necessary self.fast_time_slicing = fast_time_slicing @@ -221,8 +238,10 @@ def __init__( if n_images != -1: self.image_collection = image_collection.limit(n_images) - self.projection = projection - self.geometry = geometry + self.crs = crs + self.crs_transform = crs_transform + self.shape_2d = shape_2d + self.primary_dim_name = primary_dim_name or 'time' self.primary_dim_property = primary_dim_property or 'system:time_start' @@ -231,36 +250,10 @@ def __init__( # Metadata should apply to all imgs. self._img_info: types.ImageInfo = self.get_info['first'] - proj = self.get_info.get('projection', {}) - - self.crs_arg = crs or proj.get('crs', proj.get('wkt', 'EPSG:4326')) - self.crs = CRS(self.crs_arg) - - is_crs_geographic = self.crs.is_geographic - # Gets the unit i.e. meter, degree etc. - self.scale_units = 'degree' if is_crs_geographic else 'meter' - # Get the dimensions name based on the CRS (scale units). - self.dimension_names = self.DIMENSION_NAMES.get( - self.scale_units, ('X', 'Y') - ) - x_dim_name, y_dim_name = self.dimension_names - self._props.update( - coordinates=f'{self.primary_dim_name} {x_dim_name} {y_dim_name}', - crs=self.crs_arg, - ) + self.dimension_names = ('x', 'y') self._props = self._make_attrs_valid(self._props) - # Scale in the projection's units. Typically, either meters or degrees. - # If we use the default CRS i.e. EPSG:3857, the units is in meters. 
- default_scale = self.SCALE_UNITS.get(self.scale_units, 1) - if scale is None: - scale = default_scale - default_transform = affine.Affine.scale(scale, scale) - - transform = affine.Affine(*proj.get('transform', default_transform)[:6]) - self.scale_x, self.scale_y = transform.a, transform.e - self.scale = np.sqrt(np.abs(transform.determinant)) - - self.bounds = self._determine_bounds(geometry=geometry) + self.scale_x, self.scale_y = crs_transform[0], crs_transform[4] + self.scale = np.sqrt(np.abs(self.affine_transform.determinant)) max_dtype = self._max_itemsize() @@ -288,20 +281,6 @@ def get_info(self) -> Dict[str, Any]: ('first', self.image_collection.first()), ] - if isinstance(self.projection, ee.Projection): - rpcs.append(('projection', self.projection)) - - if isinstance(self.geometry, ee.Geometry): - rpcs.append(('bounds', self.geometry.bounds(1, proj=self.projection))) - else: - rpcs.append( - ( - 'bounds', - self.image_collection.first() - .geometry() - .bounds(1, proj=self.projection), - ) - ) # TODO(#29, #30): This RPC call takes the longest time to compute. This # requires a full scan of the images in the collection, which happens on the @@ -416,11 +395,6 @@ def _assign_preferred_chunks(self) -> Chunks: chunks[y_dim_name] = self.chunks['height'] return chunks - def transform(self, xs: float, ys: float) -> Tuple[float, float]: - transformer = pyproj.Transformer.from_crs( - self.crs.geodetic_crs, self.crs, always_xy=True - ) - return transformer.transform(xs, ys) def project(self, bbox: types.BBox) -> types.Grid: """Translate a bounding box (pixel space) to a grid (projection space). @@ -436,41 +410,23 @@ def project(self, bbox: types.BBox) -> types.Grid: appropriate region of data to return according to the Array's configured projection and scale. 
""" - x_min, y_min, x_max, y_max = self.bounds x_start, y_start, x_end, y_end = bbox - width = x_end - x_start - height = y_end - y_start - - # Find the actual coordinates of the first or last point of the bounding box - # (bbox) based on the positive and negative scale in the actual Earth Engine - # (EE) image. Since EE bounding boxes can be flipped (negative scale), we - # cannot determine the origin (transform translation) simply by identifying - # the min and max extents. Instead, we calculate the translation by - # considering the direction of scaling (positive or negative) along both - # the x and y axes. - translate_x = self.scale_x * x_start + ( - x_min if self.scale_x > 0 else x_max - ) - translate_y = self.scale_y * y_start + ( - y_min if self.scale_y > 0 else y_max - ) + + # Translate the crs_transform to the origin of the bounding box + transform_grid_cell = affine.Affine.translation( + xoff=x_start * self.affine_transform.a, + yoff=y_start * self.affine_transform.e + ) * self.affine_transform return { # The size of the bounding box. The affine transform and project will be # applied, so we can think in terms of pixels. 'dimensions': { - 'width': width, - 'height': height, + 'width': x_end - x_start, + 'height': y_end - y_start, }, - 'affineTransform': { - 'translateX': translate_x, - 'translateY': translate_y, - # Define the scale for each pixel (e.g. the number of meters between - # each value). 
- 'scaleX': self.scale_x, - 'scaleY': self.scale_y, - }, - 'crsCode': self.crs_arg, + 'affineTransform': dict(zip(EE_AFFINE_TRANSFORM_FIELDS, transform_grid_cell)), + 'crsCode': self.crs, } def image_to_array( @@ -576,10 +532,8 @@ def open_store_variable(self, name: str) -> xarray.Variable: encoding = { 'source': attrs['id'], 'scale_factor': arr.scale, - 'scale_units': self.scale_units, 'dtype': data.dtype, 'preferred_chunks': self.preferred_chunks, - 'bounds': arr.bounds, } return xarray.Variable(dimensions, data, attrs, encoding) @@ -606,74 +560,6 @@ def _get_primary_coordinates(self) -> List[Any]: ] return primary_coords - def _get_tile_from_ee( - self, tile_and_band: Tuple[Tuple[int, int, int], str] - ) -> Tuple[int, np.ndarray[Any, np.dtype]]: - """Get a numpy array from EE for a specific bounding box (a 'tile').""" - (tile_index, tile_coords_start, tile_coords_end), band_id = tile_and_band - bbox = self.project( - (tile_coords_start, 0, tile_coords_end, 1) - if band_id == 'x' - else (0, tile_coords_start, 1, tile_coords_end) - ) - target_image = ee.Image.pixelCoordinates(ee.Projection(self.crs_arg)) - return tile_index, self.image_to_array( - target_image, grid=bbox, dtype=np.float64, bandIds=[band_id] - ) - - def _process_coordinate_data( - self, - tile_count: int, - tile_size: int, - end_point: int, - coordinate_type: str, - ) -> np.ndarray: - """Process coordinate data using multithreading for longitude or latitude.""" - data = [ - (i, tile_size * i, min(tile_size * (i + 1), end_point)) - for i in range(tile_count) - ] - tiles = [None] * tile_count - with concurrent.futures.ThreadPoolExecutor(**self.executor_kwargs) as pool: - for i, arr in pool.map( - self._get_tile_from_ee, - list(zip(data, itertools.cycle([coordinate_type]))), - ): - tiles[i] = arr.flatten() - return np.concatenate(tiles) - - def _determine_bounds( - self, - geometry: ee.Geometry | Tuple[float, float, float, float] | None = None, - ) -> Tuple[float, float, float, float]: - if geometry 
is None: - try: - x_min_0, y_min_0, x_max_0, y_max_0 = self.crs.area_of_use.bounds - except AttributeError: - # `area_of_use` is probably `None`. Parse the geometry from the first - # image instead (calculated in self.get_info()) - x_min_0, y_min_0, x_max_0, y_max_0 = _ee_bounds_to_bounds( - self.get_info['bounds'] - ) - elif isinstance(geometry, ee.Geometry): - x_min_0, y_min_0, x_max_0, y_max_0 = _ee_bounds_to_bounds( - self.get_info['bounds'] - ) - elif isinstance(geometry, Sequence): - if len(geometry) != 4: - raise ValueError( - 'geometry must be a tuple or list of length 4, or a ee.Geometry, ' - f'but got {geometry!r}' - ) - x_min_0, y_min_0, x_max_0, y_max_0 = geometry - else: - raise ValueError( - 'geometry must be a tuple or list of length 4, a ee.Geometry, or' - f' None but got {type(geometry)}' - ) - x_min, y_min = self.transform(x_min_0, y_min_0) - x_max, y_max = self.transform(x_max_0, y_max_0) - return x_min, y_min, x_max, y_max def get_variables(self) -> utils.Frozen[str, xarray.Variable]: vars_ = [(name, self.open_store_variable(name)) for name in self._bands()] @@ -690,26 +576,10 @@ def get_variables(self) -> utils.Frozen[str, xarray.Variable]: f'ImageCollection due to: {e}.' ) - if isinstance(self.chunks, dict): - # when the value of self.chunks = 'auto' or user-defined. - width_chunk = self.chunks['width'] - height_chunk = self.chunks['height'] - else: - # when the value of self.chunks = -1. 
- width_chunk = v0.shape[1] - height_chunk = v0.shape[2] - - lon_total_tile = math.ceil(v0.shape[1] / width_chunk) - lon = self._process_coordinate_data( - lon_total_tile, width_chunk, v0.shape[1], 'x' - ) - lat_total_tile = math.ceil(v0.shape[2] / height_chunk) - lat = self._process_coordinate_data( - lat_total_tile, height_chunk, v0.shape[2], 'y' - ) - - width_coord = np.squeeze(lon) - height_coord = np.squeeze(lat) + x_scale, _, x_translate, _, y_scale, y_translate = self.crs_transform + width, height = self.shape_2d + width_coord = np.array([x_translate + x_scale / 2 + ix * x_scale for ix in range(width)]) + height_coord = np.array([y_translate + y_scale / 2 + iy * y_scale for iy in range(height)]) # Make sure there's at least a single point in each dimension. if width_coord.ndim == 0: @@ -782,19 +652,13 @@ def __init__(self, variable_name: str, ee_store: EarthEngineStore): self.store = ee_store self.scale = ee_store.scale - self.bounds = ee_store.bounds # It looks like different bands have different dimensions & transforms! # Can we get this into consistent dimensions? self._info = ee_store._band_attrs(variable_name) self.dtype = np.dtype(np.float32) - x_min, y_min, x_max, y_max = self.bounds - # Make sure the size is at least 1x1. 
- x_size = max(1, int(np.round((x_max - x_min) / np.abs(self.store.scale_x)))) - y_size = max(1, int(np.round((y_max - y_min) / np.abs(self.store.scale_y)))) - - self.shape = (ee_store.n_images, x_size, y_size) + self.shape = (ee_store.n_images, ) + ee_store.shape_2d self._apparent_chunks = {k: 1 for k in self.store.PREFERRED_CHUNKS.keys()} if isinstance(self.store.chunks, dict): self._apparent_chunks = self.store.chunks.copy() @@ -1014,6 +878,9 @@ def guess_can_open( def open_dataset( self, filename_or_obj: Union[str, os.PathLike[Any], ee.ImageCollection], + crs: CrsType, + crs_transform: TransformType, + shape_2d: ShapeType, drop_variables: Optional[Tuple[str, ...]] = None, io_chunks: Optional[Any] = None, n_images: int = -1, @@ -1023,10 +890,6 @@ def open_dataset( use_cftime: Optional[bool] = None, concat_characters: bool = True, decode_coords: bool = True, - crs: Optional[str] = None, - scale: Union[float, int, None] = None, - projection: Optional[ee.Projection] = None, - geometry: ee.Geometry | Tuple[float, float, float, float] | None = None, primary_dim_name: Optional[str] = None, primary_dim_property: Optional[str] = None, ee_mask_value: Optional[float] = None, @@ -1042,6 +905,12 @@ def open_dataset( Args: filename_or_obj: An asset ID for an ImageCollection, or an ee.ImageCollection object. + crs: The coordinate reference system (a CRS code or WKT + string). This defines the frame of reference to coalesce all variables + upon opening. + crs_transform: Transform matrix describing the grid origin and scale + relative to the CRS. + shape_2d: Dimensions of the pixel grid in the form (width, height). drop_variables (optional): Variables or bands to drop before opening. io_chunks (optional): Specifies the chunking strategy for loading data from EE. By default, this automatically calculates optional chunks based @@ -1076,22 +945,6 @@ def open_dataset( or individual variables as coordinate variables. 
- "all": Set variables referred to in ``'grid_mapping'``, ``'bounds'`` and other attributes as coordinate variables. - crs (optional): The coordinate reference system (a CRS code or WKT - string). This defines the frame of reference to coalesce all variables - upon opening. By default, data is opened with `EPSG:4326'. - scale (optional): The scale in the `crs` or `projection`'s units of - measure -- either meters or degrees. This defines the scale that all - data is represented in upon opening. By default, the scale is 1° when - the CRS is in degrees or 10,000 when in meters. - projection (optional): Specify an `ee.Projection` object to define the - `scale` and `crs` (or other coordinate reference system) with which to - coalesce all variables upon opening. By default, the scale and reference - system is set by the the `crs` and `scale` arguments. - geometry (optional): Specify an `ee.Geometry` to define the regional - bounds when opening the data or a bbox specifying [x_min, y_min, x_max, - y_max] in EPSG:4326. When not set, the bounds are defined by - the CRS's 'area_of_use` boundaries. If those aren't present, the bounds - are derived from the geometry of the first image of the collection. primary_dim_name (optional): Override the name of the primary dimension of the output Dataset. By default, the name is 'time'. 
primary_dim_property (optional): Override the `ee.Image` property for @@ -1135,12 +988,11 @@ def open_dataset( store = EarthEngineStore.open( collection, + crs=crs, + crs_transform=crs_transform, + shape_2d=shape_2d, chunk_store=io_chunks, n_images=n_images, - crs=crs, - scale=scale, - projection=projection, - geometry=geometry, primary_dim_name=primary_dim_name, primary_dim_property=primary_dim_property, mask_value=ee_mask_value, diff --git a/xee/ext_integration_test.py b/xee/ext_integration_test.py index ad1414a..d9eca92 100644 --- a/xee/ext_integration_test.py +++ b/xee/ext_integration_test.py @@ -21,9 +21,11 @@ from absl.testing import absltest from google.auth import identity_pool import numpy as np +import shapely import xarray as xr from xarray.core import indexing import xee +from xee import helpers import ee @@ -40,6 +42,12 @@ 'https://www.googleapis.com/auth/earthengine', ] +# Define grid parameters for tests +_TEST_GRID_PARAMS = { + 'crs': 'EPSG:4326', + 'crs_transform': (1.0, 0, -180.0, 0, -1.0, 90.0), + 'shape_2d': (360, 180) +} def _read_identity_pool_creds() -> identity_pool.Credentials: credentials_path = os.environ[_CREDENTIALS_PATH_KEY] @@ -53,6 +61,7 @@ def init_ee_for_tests(): init_params = { 'opt_url': ee.data.HIGH_VOLUME_API_BASE_URL, } + if _CREDENTIALS_PATH_KEY in os.environ: credentials = _read_identity_pool_creds() init_params['credentials'] = credentials @@ -71,16 +80,19 @@ def setUp(self): ), n_images=64, getitem_kwargs={'max_retries': 10, 'initial_delay': 1500}, + **_TEST_GRID_PARAMS, ) self.store_with_neg_mask_value = xee.EarthEngineStore( ee.ImageCollection('LANDSAT/LC08/C02/T1').filterDate( '2017-01-01', '2017-01-03' ), + **_TEST_GRID_PARAMS, n_images=64, mask_value=-9999, ) self.lnglat_store = xee.EarthEngineStore( ee.ImageCollection.fromImages([ee.Image.pixelLonLat()]), + **_TEST_GRID_PARAMS, chunks={'index': 256, 'width': 512, 'height': 512}, n_images=64, ) @@ -88,13 +100,15 @@ def setUp(self): 
ee.ImageCollection('GRIDMET/DROUGHT').filterDate( '2020-03-30', '2020-04-01' ), + **_TEST_GRID_PARAMS, n_images=64, getitem_kwargs={'max_retries': 9}, ) self.all_img_store = xee.EarthEngineStore( ee.ImageCollection('LANDSAT/LC08/C02/T1').filterDate( '2017-01-01', '2017-01-03' - ) + ), + **_TEST_GRID_PARAMS, ) def test_creates_lat_long_array(self): @@ -270,32 +284,6 @@ def __getitem__(self, params): self.assertEqual(getter.count, 3) - def test_geometry_bounds_with_and_without_projection(self): - image = ( - ee.ImageCollection('LANDSAT/LC08/C02/T1') - .filterDate('2017-01-01', '2017-01-03') - .first() - ) - point = ee.Geometry.Point(-40.2414893624401, 105.48790177216375) - distance = 311.5 - scale = 5000 - projection = ee.Projection('EPSG:4326', [1, 0, 0, 0, -1, 0]).atScale(scale) - image = image.reproject(projection) - - geometry = point.buffer(distance, proj=projection).bounds(proj=projection) - - data_store = xee.EarthEngineStore( - ee.ImageCollection(image), - projection=image.projection(), - geometry=geometry, - ) - data_store_bounds = data_store.get_info['bounds'] - - self.assertNotEqual(geometry.bounds().getInfo(), data_store_bounds) - self.assertEqual( - geometry.bounds(1, proj=projection).getInfo(), data_store_bounds - ) - def test_getitem_kwargs(self): arr = xee.EarthEngineBackendArray('B4', self.store) self.assertEqual(arr.store.getitem_kwargs['initial_delay'], 1500) @@ -336,67 +324,60 @@ def test_guess_can_open__image_collection(self): self.assertFalse(self.entry.guess_can_open('WRI/GPPD/power_plants')) def test_open_dataset__sanity_check(self): + """Test opening a simple image collection in geographic coordinates.""" + n_images, width, height = 3, 4, 5 ds = self.entry.open_dataset( - pathlib.Path('LANDSAT') / 'LC08' / 'C02' / 'T1', - drop_variables=tuple(f'B{i}' for i in range(3, 12)), - n_images=3, - projection=ee.Projection('EPSG:4326', [25, 0, 0, 0, -25, 0]), + pathlib.Path('ECMWF') / 'ERA5' / 'MONTHLY', + n_images=n_images, + crs='EPSG:4326', + 
crs_transform=(12.0, 0, -180.0, 0, -25.0, 90.0), + shape_2d=(width, height), ) - self.assertEqual(dict(ds.sizes), {'time': 3, 'lon': 14, 'lat': 7}) + self.assertEqual(dict(ds.sizes), {'time': 3, 'x': width, 'y': height}) self.assertNotEmpty(dict(ds.coords)) self.assertEqual( - list(ds.data_vars.keys()), - [f'B{i}' for i in range(1, 3)] - + ['QA_PIXEL', 'QA_RADSAT', 'SAA', 'SZA', 'VAA', 'VZA'], - ) + list(ds.data_vars.keys()), + [ + 'mean_2m_air_temperature', + 'minimum_2m_air_temperature', + 'maximum_2m_air_temperature', + 'dewpoint_2m_temperature', + 'total_precipitation', + 'surface_pressure', + 'mean_sea_level_pressure', + 'u_component_of_wind_10m', + 'v_component_of_wind_10m' + ] + ) + # Loop through the data variables. for v in ds.values(): self.assertIsNotNone(v.data) self.assertFalse(v.isnull().all(), 'All values are null!') - self.assertEqual(v.shape, (3, 14, 7)) + self.assertEqual(v.shape, (n_images, width, height)) - def test_open_dataset__sanity_check_with_negative_scale(self): - ds = self.entry.open_dataset( - pathlib.Path('LANDSAT') / 'LC08' / 'C02' / 'T1', - drop_variables=tuple(f'B{i}' for i in range(3, 12)), - scale=-25.0, # in degrees - n_images=3, - ) - self.assertEqual(dict(ds.sizes), {'time': 3, 'lon': 14, 'lat': 7}) - self.assertNotEmpty(dict(ds.coords)) - self.assertEqual( - list(ds.data_vars.keys()), - [f'B{i}' for i in range(1, 3)] - + ['QA_PIXEL', 'QA_RADSAT', 'SAA', 'SZA', 'VAA', 'VZA'], - ) - for v in ds.values(): - self.assertIsNotNone(v.data) - self.assertTrue(v.isnull().all(), 'All values must be null!') - self.assertEqual(v.shape, (3, 14, 7)) def test_open_dataset__n_images(self): ds = self.entry.open_dataset( pathlib.Path('LANDSAT') / 'LC08' / 'C02' / 'T1', drop_variables=tuple(f'B{i}' for i in range(3, 12)), n_images=1, - scale=25.0, # in degrees + **_TEST_GRID_PARAMS ) - self.assertLen(ds.time, 1) def test_open_dataset_image_to_imagecollection(self): """Ensure that opening an ee.Image is the same as opening a single image 
ee.ImageCollection.""" img = ee.Image('CGIAR/SRTM90_V4') ic = ee.ImageCollection(img) - ds1 = xr.open_dataset(img, engine='ee') - ds2 = xr.open_dataset(ic, engine='ee') + ds1 = xr.open_dataset(img, engine='ee', **_TEST_GRID_PARAMS) + ds2 = xr.open_dataset(ic, engine='ee', **_TEST_GRID_PARAMS) self.assertTrue(ds1.identical(ds2)) def test_can_chunk__opened_dataset(self): ds = xr.open_dataset( 'NASA/GPM_L3/IMERG_V07', - crs='EPSG:4326', - scale=0.25, engine=xee.EarthEngineBackendEntrypoint, + **_TEST_GRID_PARAMS ).isel(time=slice(0, 1)) try: @@ -404,40 +385,48 @@ def test_can_chunk__opened_dataset(self): except ValueError: self.fail('Chunking failed.') - def test_honors_geometry(self): - ic = ee.ImageCollection('ECMWF/ERA5_LAND/HOURLY').filterDate( - '1992-10-05', '1993-03-31' - ) - leg1 = ee.Geometry.Rectangle(113.33, -43.63, 153.56, -10.66) + + def test_honors_geometry_simple_utm(self): + """Test that a non-geographic projection can be used.""" + ic = ee.ImageCollection([ + ee.Image('LANDSAT/LC09/C02/T1_L2/LC09_043034_20211116').select(0) + .addBands(ee.Image.pixelLonLat()), + ]) + min_x, max_x = 10, 12 + min_y, max_y = -4, 0 + width = max_x - min_x + height = max_y - min_y ds = xr.open_dataset( ic, engine=xee.EarthEngineBackendEntrypoint, - geometry=leg1, - ) - standard_ds = xr.open_dataset( - ic, - engine=xee.EarthEngineBackendEntrypoint, + crs='EPSG:32610', + crs_transform=(30, 0, 448485+103000, 0, -30, 4263915-84000), # Origin over SF + shape_2d=(width, height), ) - self.assertEqual(ds.sizes, {'time': 4248, 'lon': 40, 'lat': 35}) - self.assertNotEqual(ds.sizes, standard_ds.sizes) - - def test_honors_projection(self): - ic = ee.ImageCollection('ECMWF/ERA5_LAND/HOURLY').filterDate( - '1992-10-05', '1993-03-31' + self.assertEqual(ds.sizes, {'time': 1, 'x': width, 'y': height}) + np.testing.assert_allclose( + ds['latitude'].values, + np.array([[ + [37.764977, 37.764706, 37.764435, 37.764164], + [37.764973, 37.7647 , 37.76443 , 37.764164] + ]]) ) - ds = 
xr.open_dataset( - ic, - engine=xee.EarthEngineBackendEntrypoint, - projection=ic.first().select(0).projection(), + np.testing.assert_allclose( + ds['longitude'].values, + np.array([[ + [-122.41528, -122.41529, -122.41529, -122.41529], + [-122.41495, -122.41495, -122.41495, -122.41495] + ]]) ) - standard_ds = xr.open_dataset( - ic, - engine=xee.EarthEngineBackendEntrypoint, + np.testing.assert_allclose( + ds['SR_B1'].values, + np.array([[ + [14332., 13622., 12058., 11264.], + [12254., 10379., 10701., 11150.] + ]]) ) - self.assertEqual(ds.sizes, {'time': 4248, 'lon': 3600, 'lat': 1800}) - self.assertNotEqual(ds.sizes, standard_ds.sizes) @absltest.skipIf(_SKIP_RASTERIO_TESTS, 'rioxarray module not loaded') def test_expected_precise_transform(self): @@ -467,32 +456,33 @@ def test_expected_precise_transform(self): xee_dataset = xr.open_dataset( ee.ImageCollection(ic), engine='ee', - geometry=tuple(raster.rio.bounds()), - projection=ee.Projection( - crs=str(raster.rio.crs), transform=raster.rio.transform()[:6] - ), - ).rename({'lon': 'x', 'lat': 'y'}) + crs=str(raster.rio.crs), + crs_transform=raster.rio.transform()[:6], + shape_2d=data.shape + ) self.assertNotEqual(abs(x_res), abs(y_res)) - np.testing.assert_equal( + np.testing.assert_allclose( np.array(xee_dataset.rio.transform()), np.array(raster.rio.transform()), ) def test_parses_ee_url(self): - ds = self.entry.open_dataset( - 'ee://LANDSAT/LC08/C02/T1', - drop_variables=tuple(f'B{i}' for i in range(3, 12)), - scale=25.0, # in degrees - n_images=3, + """Test the ee: URL parsing.""" + n_images, width, height = 3, 10, 20 + test_params = { + 'n_images': n_images, + 'crs': 'EPSG:4326', + 'crs_transform': (12.0, 0, -180.0, 0, -25.0, 90.0), + 'shape_2d': (width, height) + } + ds1 = self.entry.open_dataset('ee://LANDSAT/LC08/C02/T1', **test_params) + ds2 = self.entry.open_dataset('ee:LANDSAT/LC08/C02/T1', **test_params) + self.assertEqual(dict(ds1.sizes), {'time': n_images, 'x': width, 'y': height}) + 
self.assertEqual(dict(ds2.sizes), {'time': n_images, 'x': width, 'y': height}) + np.testing.assert_allclose( + ds1['B1'].compute().values, + ds2['B1'].compute().values ) - self.assertEqual(dict(ds.sizes), {'time': 3, 'lon': 14, 'lat': 7}) - ds = self.entry.open_dataset( - 'ee:LANDSAT/LC08/C02/T1', - drop_variables=tuple(f'B{i}' for i in range(3, 12)), - scale=25.0, # in degrees - n_images=3, - ) - self.assertEqual(dict(ds.sizes), {'time': 3, 'lon': 14, 'lat': 7}) def test_data_sanity_check(self): # This simple test uncovered a bug with the default definition of `scale`. @@ -502,6 +492,7 @@ def test_data_sanity_check(self): 'ECMWF/ERA5_LAND/HOURLY', engine=xee.EarthEngineBackendEntrypoint, n_images=1, + **_TEST_GRID_PARAMS ) temperature_2m = era5.isel(time=0).temperature_2m self.assertNotEqual(temperature_2m.min(), 0.0) @@ -511,8 +502,8 @@ def test_validate_band_attrs(self): ds = self.entry.open_dataset( 'ee:LANDSAT/LC08/C02/T1', drop_variables=tuple(f'B{i}' for i in range(3, 12)), - scale=25.0, # in degrees n_images=3, + **_TEST_GRID_PARAMS ) valid_types = (str, int, float, complex, np.ndarray, np.number, list, tuple) @@ -543,8 +534,9 @@ def test_fast_time_slicing(self): params = dict( filename_or_obj=fake_collection, engine=xee.EarthEngineBackendEntrypoint, - geometry=ee.Geometry.BBox(-83.86, 41.13, -76.83, 46.15), - projection=first.projection().atScale(100000), + crs='EPSG:4326', + crs_transform=(1, 0, -100, 0, 1, 50), + shape_2d=(3, 4), ) # With slow slicing, the returned data should include the modified image. @@ -554,7 +546,7 @@ def test_fast_time_slicing(self): # With fast slicing, the returned data should include the original image. 
fast_slicing = xr.open_dataset(**params, fast_time_slicing=True) - fast_slicing_data = getattr(fast_slicing[dict(time=0)], band).as_numpy() + fast_slicing_data = getattr(fast_slicing[dict(time=0)], band).as_numpy() self.assertTrue(np.all(fast_slicing_data > 0)) @absltest.skipIf(_SKIP_RASTERIO_TESTS, 'rioxarray module not loaded') @@ -564,26 +556,31 @@ def test_write_projected_dataset_to_raster(self): with tempfile.TemporaryDirectory() as temp_dir: temp_file = os.path.join(temp_dir, 'test.tif') - crs = 'epsg:32610' + crs = 'EPSG:32610' proj = ee.Projection(crs) - point = ee.Geometry.Point([-122.44, 37.78]) - geom = point.buffer(1024).bounds() + + point = shapely.geometry.Point(-122.44, 37.78) + ee_point = ee.Geometry.Point(list(point.coords)[0]) + # Create a collection of 10 low-cloud images intersecting a point. col = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED') - col = col.filterBounds(point) + col = col.filterBounds(ee_point) col = col.filter(ee.Filter.lte('CLOUDY_PIXEL_PERCENTAGE', 5)) col = col.limit(10) + grid_dict = helpers.fit_geometry( + geometry=point.buffer(0.1), + grid_crs=crs, + grid_scale=(100, -100) + ) + ds = xr.open_dataset( col, engine=xee.EarthEngineBackendEntrypoint, - crs=crs, - geometry=geom, - projection=ee.Projection('EPSG:4326', [10, 0, 0, 0, -10, 0]), + **grid_dict ) - ds = ds.isel(time=0).transpose('Y', 'X') - ds.rio.set_spatial_dims(x_dim='X', y_dim='Y', inplace=True) + ds = ds.isel(time=0).transpose('y', 'x') ds.rio.write_crs(crs, inplace=True) ds.rio.reproject(crs, inplace=True) ds.rio.to_raster(temp_file) @@ -591,44 +588,97 @@ def test_write_projected_dataset_to_raster(self): with rasterio.open(temp_file) as raster: # see https://gis.stackexchange.com/a/407755 for evenOdd explanation bbox = ee.Geometry.Rectangle(raster.bounds, proj=proj, evenOdd=False) - intersects = bbox.intersects(point, 1, proj=proj) + intersects = bbox.intersects(ee_point, 1, proj=proj) self.assertTrue(intersects.getInfo()) - 
@absltest.skipIf(_SKIP_RASTERIO_TESTS, 'rioxarray module not loaded') - def test_write_dataset_to_raster(self): - # ensure that a dataset written to a raster intersects with the point used - # to create the initial image collection - with tempfile.TemporaryDirectory() as temp_dir: - temp_file = os.path.join(temp_dir, 'test.tif') - crs = 'EPSG:4326' - proj = ee.Projection(crs) - point = ee.Geometry.Point([-122.44, 37.78]) - geom = point.buffer(1024).bounds() +class GridHelpersTest(absltest.TestCase): + """Test grid helper functions that require GEE access.""" - col = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED') - col = col.filterBounds(point) - col = col.filter(ee.Filter.lte('CLOUDY_PIXEL_PERCENTAGE', 5)) - col = col.limit(10) + def setUp(self): + super().setUp() + init_ee_for_tests() + self.entry = xee.EarthEngineBackendEntrypoint() + + def test_extract_grid_params_from_image(self): + img = ee.Image('LANDSAT/LT05/C02/T1_TOA/LT05_031034_20110619') + grid_params = helpers.extract_grid_params(img) + self.assertEqual(grid_params['shape_2d'], (7881, 6981)) + self.assertEqual(grid_params['crs'], 'EPSG:32613') + np.allclose(grid_params['crs_transform'], [30, 0, 643185, 0, -30, 4255815]) - ds = xr.open_dataset( - col, - engine=xee.EarthEngineBackendEntrypoint, - geometry=geom, - projection=ee.Projection('EPSG:4326', [0.0025, 0, 0, 0, -0.0025, 0]), - ) + def test_extract_grid_params_from_image_collection(self): + dem = ee.ImageCollection('COPERNICUS/DEM/GLO30'); + grid_params = helpers.extract_grid_params(dem) + self.assertEqual(grid_params['shape_2d'], (3601, 3601)) + self.assertEqual(grid_params['crs'], 'EPSG:4326') + np.allclose(grid_params['crs_transform'], [0.000278, 0, 5.999861, 0, -0.000278, 1.000139]) - ds = ds.isel(time=0).transpose('lat', 'lon') - ds.rio.set_spatial_dims(x_dim='lon', y_dim='lat', inplace=True) - ds.rio.write_crs(crs, inplace=True) - ds.rio.reproject(crs, inplace=True) - ds.rio.to_raster(temp_file) + def 
test_extract_grid_params_from_invalid_object(self): + with self.assertRaises(TypeError): + helpers.extract_grid_params('a string object') - with rasterio.open(temp_file) as raster: - # see https://gis.stackexchange.com/a/407755 for evenOdd explanation - bbox = ee.Geometry.Rectangle(raster.bounds, proj=proj, evenOdd=False) - intersects = bbox.intersects(point, 1, proj=proj) - self.assertTrue(intersects.getInfo()) + +class ReadmeCodeTest(absltest.TestCase): + """Tests a copy of code contained in the Xee README.""" + + def setUp(self): + super().setUp() + init_ee_for_tests() + self.entry = xee.EarthEngineBackendEntrypoint() + + def test_extract_projection_from_image(self): + + ic = ee.ImageCollection('ECMWF/ERA5_LAND/HOURLY').filterDate('1992-10-05', '1993-03-31') + grid_params = helpers.extract_grid_params(ic) + + # Open any Earth Engine ImageCollection by specifying the Xarray engine as 'ee': + ds = xr.open_dataset( + 'ee://ECMWF/ERA5_LAND/HOURLY', + engine='ee', + **grid_params + ) + + # Open all bands in a specific projection: + ds = xr.open_dataset( + 'ee://ECMWF/ERA5_LAND/HOURLY', + engine='ee', + crs='EPSG:32610', + crs_transform=(30, 0, 448485 + 103000, 0, -30, 4263915 - 84000), # In San Francisco, California + shape_2d=(64, 64), + ) + + # Open an ImageCollection (maybe, with EE-side filtering or processing): + ds = xr.open_dataset( + ic, + engine='ee', + crs='EPSG:32610', + crs_transform=(30, 0, 448485 + 103000, 0, -30, 4263915 - 84000), # In San Francisco, California + shape_2d=(64, 64), + ) + + # Open an ImageCollection with a specific EE projection or geometry: + + grid_params = helpers.fit_geometry( + geometry=shapely.geometry.box(113.33, -43.63, 153.56, -10.66), + grid_crs='EPSG:4326', + grid_shape=(256, 256) + ) + + ds = xr.open_dataset( + ic, + engine='ee', + **grid_params + ) + + # Open a single Image: + img = ee.Image('LANDSAT/LC08/C02/T1_TOA/LC08_044034_20140318') + grid_params = helpers.extract_grid_params(img) + ds = xr.open_dataset( + img, + 
engine='ee', + **grid_params + ) if __name__ == '__main__': diff --git a/xee/ext_test.py b/xee/ext_test.py index a873d2f..32bda17 100644 --- a/xee/ext_test.py +++ b/xee/ext_test.py @@ -3,8 +3,12 @@ from absl.testing import absltest from absl.testing import parameterized import numpy as np +import affine +import shapely +from unittest import mock import xee from xee import ext +from xee import helpers class EEStoreStandardDatatypesTest(parameterized.TestCase): @@ -96,6 +100,108 @@ def test_exceeding_byte_limit__raises_error(self): with self.assertRaises(ValueError): ext._check_request_limit(chunks, dtype_size, xee.REQUEST_BYTE_LIMIT) + @mock.patch( + 'xee.ext.EarthEngineStore.get_info', + new_callable=mock.PropertyMock, + ) + def test_init_with_affine_transform(self, mock_get_info): + """Test that an affine.Affine object can be passed for crs_transform.""" + mock_get_info.return_value = { + 'size': 1, + 'props': {}, + 'first': { + 'bands': [{ + 'id': 'b1', + 'data_type': {'type': 'PixelType', 'precision': 'float'} + }] + }, + } + transform_tuple = (1.0, 0.0, -180.0, 0.0, -1.0, 90.0) + transform_affine = affine.Affine(*transform_tuple) + + store = xee.EarthEngineStore( + image_collection=mock.MagicMock(), + crs='EPSG:4326', + crs_transform=transform_affine, + shape_2d=(360, 180), + ) + + self.assertIsInstance(store.crs_transform, tuple) + self.assertEqual(store.crs_transform, transform_tuple) + self.assertEqual(store.scale_x, 1.0) + self.assertEqual(store.scale_y, -1.0) + self.assertEqual(store.scale, 1.0) + + @mock.patch( + 'xee.ext.EarthEngineStore.get_info', + new_callable=mock.PropertyMock, + ) + def test_project(self, mock_get_info): + """Test that the project method correctly calculates the grid.""" + mock_get_info.return_value = { + 'size': 1, + 'props': {}, + 'first': { + 'bands': [{ + 'id': 'b1', + 'data_type': {'type': 'PixelType', 'precision': 'float'} + }] + }, + } + transform_tuple = (0.25, 0.0, -180.0, 0.0, -0.5, 90.0) + store = xee.EarthEngineStore( + 
image_collection=mock.MagicMock(), + crs='EPSG:4326', + crs_transform=transform_tuple, + shape_2d=(1440, 720), + ) + + bbox = (10, 20, 30, 40) # x_start, y_start, x_end, y_end + grid = store.project(bbox) + + self.assertEqual(grid['dimensions']['width'], 20) + self.assertEqual(grid['dimensions']['height'], 20) + self.assertEqual(grid['crsCode'], 'EPSG:4326') + # Check that the translation is correct: c + (x_start * a), f + (y_start * e) + self.assertAlmostEqual(grid['affineTransform']['translateX'], -180.0 + (10 * 0.25)) + self.assertAlmostEqual(grid['affineTransform']['translateY'], 90.0 + (20 * -0.5)) + + @mock.patch( + 'xee.ext.EarthEngineStore.get_info', + new_callable=mock.PropertyMock, + ) + def test_init_with_tuple_transform(self, mock_get_info): + """Test that a tuple object can be passed for crs_transform.""" + # (Setup the mock_get_info.return_value just like in the other test) + mock_get_info.return_value = { + 'size': 1, 'props': {}, + 'first': {'bands': [{'id': 'b1', 'data_type': {'type': 'PixelType', 'precision': 'float'}}]} + } + transform_tuple = (1.0, 0.0, -180.0, 0.0, -1.0, 90.0) + + # Pass the tuple directly + store = xee.EarthEngineStore( + image_collection=mock.MagicMock(), + crs='EPSG:4326', + crs_transform=transform_tuple, + shape_2d=(360, 180), + ) + + # Assert that the tuple was stored correctly + self.assertEqual(store.crs_transform, transform_tuple) + + def test_init_with_invalid_transform_type(self): + """Test that a TypeError is raised for invalid crs_transform types.""" + with self.assertRaises(TypeError): + # Pass a list, which is an invalid type + invalid_transform = [1.0, 0.0, -180.0, 0.0, -1.0, 90.0] + xee.EarthEngineStore( + image_collection=mock.MagicMock(), + crs='EPSG:4326', + crs_transform=invalid_transform, + shape_2d=(360, 180), + ) + class ParseEEInitKwargsTest(absltest.TestCase): @@ -141,5 +247,146 @@ def test_parse_ee_init_kwargs__credentials_and_credentials_function(self): ) +class GridHelpersTest(absltest.TestCase): + 
"""Test grid helper functions that do not require GEE access.""" + + def test_set_scale(self): + """Test that the scale values of the CRS transform can be updated.""" + crs_transform = [1, 0, 100, 0, 5, 200] + scaling = (123, 456) + crs_transform_new = helpers.set_scale(crs_transform, scaling) + np.testing.assert_allclose( + crs_transform_new, + [123, 0, 100, 0, 456, 200] + ) + + + def test_fit_geometry_specify_scale(self): + """Test generating grid parameters to match a geometry, specifying the scale.""" + grid_dict = helpers.fit_geometry( + geometry=shapely.Polygon([(10.1, 10.1), + (10.1, 10.9), + (11.9, 10.1)]), + grid_crs='EPSG:4326', + grid_scale=(0.5, -0.5), + ) + self.assertEqual( + grid_dict['crs_transform'], + (0.5, 0.0, 10.0, 0.0, -0.5, 11.0), + ) + self.assertEqual( + grid_dict['shape_2d'], + (4, 2) + ) + + + def test_fit_geometry_specify_scale_scalar_fails(self): + """Test that a scalar grid_scale raises a TypeError.""" + with self.assertRaises(TypeError): + helpers.fit_geometry( + geometry=shapely.Polygon( + [(10.1, 10.1), (10.1, 10.9), (11.9, 10.1)] + ), + grid_crs='EPSG:4326', + grid_scale=0.5, # A scalar should fail + ) + + def test_fit_geometry_specify_scale_positive_y(self): + """Test fit_geometry with an explicit positive y-scale.""" + grid_dict = helpers.fit_geometry( + geometry=shapely.Polygon( + [(10.1, 10.1), (10.1, 10.9), (11.9, 10.1)] + ), + grid_crs='EPSG:4326', + grid_scale=(0.5, 0.5), # Note the positive y-scale + ) + # The transform should now reflect the positive y-scale. 
+ self.assertEqual( + grid_dict['crs_transform'], (0.5, 0.0, 10.0, 0.0, 0.5, 11.0) + ) + self.assertEqual( + grid_dict['shape_2d'], (4, 2) + ) + + + def test_fit_geometry_specify_scale_utm(self): + """Test generating grid parameters to match a UTM geometry, specifying the scale.""" + grid_dict = helpers.fit_geometry( + geometry=shapely.geometry.box(551000, 4179000, 552000, 4180000), # over San Francisco + geometry_crs='EPSG:32610', + grid_crs='EPSG:4326', + grid_scale=(0.01, -0.01), + ) + self.assertEqual( + grid_dict['crs_transform'], + (0.01, 0.0, -122.43, 0.0, -0.01, 37.77) + ) + self.assertEqual( + grid_dict['shape_2d'], + (3, 2) + ) + + + def test_fit_geometry_specify_shape(self): + """Test generating grid parameters to match a geometry, specifying the shape.""" + grid_dict = helpers.fit_geometry( + geometry=shapely.Polygon([(10.0, 2.0), + (10.0, 3.0), + (12.0, 2.0)]), + grid_crs='EPSG:4326', + grid_shape=(4, 2) + ) + np.testing.assert_allclose( + grid_dict['crs_transform'], + (0.5, 0, 10, 0, -0.5, 3), + rtol=1e-4, + ) + + def test_fit_geometry_value_error(self): + """Test that a ValueError is raised for invalid scale/shape combinations.""" + geom = shapely.geometry.box(0, 0, 1, 1) # Use a valid polygon + # Test when both grid_scale and grid_shape are provided + with self.assertRaisesRegex( + ValueError, "Exactly one of 'grid_scale' or 'grid_shape' must be" + ): + helpers.fit_geometry( + geometry=geom, grid_scale=(0.1, -0.1), grid_shape=(10, 10) + ) + + # Test when neither grid_scale nor grid_shape are provided + with self.assertRaisesRegex( + ValueError, "Exactly one of 'grid_scale' or 'grid_shape' must be" + ): + helpers.fit_geometry(geometry=geom) + + def test_fit_geometry_with_buffer(self): + """Test that the buffer parameter correctly expands the grid.""" + grid_dict = helpers.fit_geometry( + geometry=shapely.Point(10.5, 10.5), + buffer=0.5, # Creates a 1x1 degree box around the point + grid_crs='EPSG:4326', + grid_shape=(10, 10), + ) + # The origin 
should be at (10.0, 11.0) for a 1x1 box centered at 10.5, 10.5 + self.assertAlmostEqual(grid_dict['crs_transform'][2], 10.0) + self.assertAlmostEqual(grid_dict['crs_transform'][5], 11.0) + self.assertEqual(grid_dict['shape_2d'], (10, 10)) + + def test_fit_geometry_with_rounding(self): + """Test that grid_scale_digits correctly rounds the scale.""" + grid_dict = helpers.fit_geometry( + geometry=shapely.Polygon( + [(0, 0), (0, 1.001), (1.001, 1.001), (1.001, 0)] + ), + grid_crs='EPSG:4326', + grid_shape=(10, 10), + grid_scale_digits=2, # Round scale to 2 decimal places + ) + # x_scale = 1.001 / 10 = 0.1001, rounded to 0.1 + # y_scale = -1.001 / 10 = -0.1001, rounded to -0.1 + self.assertAlmostEqual(grid_dict['crs_transform'][0], 0.1) + self.assertAlmostEqual(grid_dict['crs_transform'][4], -0.1) + + if __name__ == '__main__': absltest.main() diff --git a/xee/helpers.py b/xee/helpers.py new file mode 100644 index 0000000..1249756 --- /dev/null +++ b/xee/helpers.py @@ -0,0 +1,130 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ==============================================================================
"""Helper functions for grid parameters."""
import math
from typing import Optional, Tuple, TypedDict, Union

import affine
import ee
from pyproj import Transformer
import shapely
from shapely.ops import transform


TransformType = Tuple[float, float, float, float, float, float]
ShapeType = Tuple[int, int]
ScalingType = Tuple[float, float]


class PixelGridParams(TypedDict):
  """Keyword arguments that describe an output pixel grid."""

  # Coordinate reference system identifier, e.g. 'EPSG:4326'.
  crs: str
  # Six affine coefficients, in the order accepted by `affine.Affine`.
  crs_transform: TransformType
  # Grid size as (x, y) pixel counts.
  shape_2d: ShapeType


def set_scale(
    crs_transform: TransformType,
    scaling: ScalingType,
) -> list:
  """Return a copy of the CRS transform with its scale terms replaced.

  Args:
    crs_transform: Six affine coefficients ordered (x_scale, x_shear,
      x_translation, y_shear, y_scale, y_translation). Tuples and lists are
      both accepted; the argument itself is never modified.
    scaling: The new (x_scale, y_scale) pair.

  Returns:
    A new list holding the six updated coefficients.

  Raises:
    TypeError: If `scaling` is not a tuple of length 2.
  """
  if not (isinstance(scaling, tuple) and len(scaling) == 2):
    raise TypeError(f'Expected a tuple of length 2 for scaling, got {scaling}')
  x_scale, y_scale = scaling

  # Work on a copy: assigning into the argument would both mutate the caller's
  # list and fail outright for tuples, which is the declared TransformType.
  coefficients = list(crs_transform)
  coefficients[0] = x_scale
  coefficients[4] = y_scale

  # Round-trip through Affine so the coefficients are validated as a transform.
  return list(affine.Affine(*coefficients))[:6]


def fit_geometry(
    geometry: shapely.geometry.base.BaseGeometry,
    *,
    geometry_crs: str = 'EPSG:4326',
    buffer: float = 0,
    grid_crs: str = 'EPSG:4326',
    grid_scale: Optional[ScalingType] = None,
    grid_scale_digits: Optional[int] = None,
    grid_shape: Optional[ShapeType] = None,
) -> PixelGridParams:
  """Return grid parameters that fit the geometry.

  Exactly one of `grid_scale` or `grid_shape` must be given.

  Args:
    geometry: The shapely geometry to fit.
    geometry_crs: CRS in which `geometry` is expressed.
    buffer: Distance by which the reprojected geometry is buffered before its
      bounds are taken, in units of `grid_crs`. Only positive values are
      applied.
    grid_crs: CRS of the output grid.
    grid_scale: (x_scale, y_scale) pixel sizes in units of `grid_crs`.
    grid_scale_digits: If given, the number of decimal digits to which both
      scale values are rounded.
    grid_shape: (x, y) pixel counts. The scale is derived from the geometry
      bounds, with a negative y scale.

  Returns:
    A dict with keys 'crs', 'crs_transform' (six affine coefficients) and
    'shape_2d' ((x, y) pixel counts).

  Raises:
    ValueError: If both or neither of `grid_scale` and `grid_shape` are given,
      or if a resulting scale value is zero.
    TypeError: If `grid_scale` is not a tuple of length 2.
  """
  if (grid_scale is None) == (grid_shape is None):
    raise ValueError("Exactly one of 'grid_scale' or 'grid_shape' must be specified.")

  transformer = Transformer.from_crs(
      crs_from=geometry_crs, crs_to=grid_crs, always_xy=True
  )
  reprojected_geometry = transform(transformer.transform, geometry)
  if buffer and buffer > 0:
    buffered_geom = shapely.buffer(reprojected_geometry, buffer)
  else:
    buffered_geom = reprojected_geometry
  x_min, y_min, x_max, y_max = buffered_geom.bounds

  scale_given = grid_scale is not None
  if scale_given:
    if not (isinstance(grid_scale, tuple) and len(grid_scale) == 2):
      raise TypeError(f'Expected a tuple of length 2 for grid_scale, got {grid_scale}')
    x_scale, y_scale = grid_scale
  else:
    x_shape, y_shape = grid_shape
    x_scale = (x_max - x_min) / x_shape
    y_scale = -(y_max - y_min) / y_shape

  # Compare against None so that grid_scale_digits=0 (round to whole units) is
  # honoured rather than silently skipped.
  if grid_scale_digits is not None:
    x_scale = round(x_scale, grid_scale_digits)
    y_scale = round(y_scale, grid_scale_digits)

  if x_scale == 0 or y_scale == 0:
    raise ValueError(
        f'Grid scale must be non-zero, got ({x_scale}, {y_scale}).'
    )

  y_step = abs(y_scale)

  if scale_given:
    # Count pixels between the snapped grid edges. This uses the final
    # (possibly rounded) scale so the shape agrees with the returned transform.
    x_shape = int(math.ceil(x_max / x_scale) - math.floor(x_min / x_scale))
    y_shape = int(math.ceil(y_max / y_step) - math.floor(y_min / y_step))

  # Snap the origin outwards to a whole multiple of the pixel size.
  grid_x_min = math.floor(x_min / x_scale) * x_scale
  grid_y_max = math.ceil(y_max / y_step) * y_step

  affine_transform = (
      affine.Affine.translation(grid_x_min, grid_y_max)
      * affine.Affine.scale(x_scale, y_scale)
  )

  crs_transform = affine_transform[:6]

  return dict(
      crs=grid_crs,
      crs_transform=crs_transform,
      shape_2d=(x_shape, y_shape)
  )


def extract_grid_params(
    ee_obj: Union[ee.Image, ee.ImageCollection]
) -> PixelGridParams:
  """Extract the pixel grid parameters from an ee.Image or ee.ImageCollection.

  For a collection, the first image is used. The parameters are read from the
  first band of the image via `getInfo()`.

  Args:
    ee_obj: The Earth Engine image or image collection to inspect.

  Returns:
    A dict with keys 'crs', 'crs_transform' and 'shape_2d', taken from the
    band's 'crs', 'crs_transform' and 'dimensions' entries.

  Raises:
    TypeError: If `ee_obj` is neither an ee.Image nor an ee.ImageCollection.
  """
  if isinstance(ee_obj, ee.Image):
    img_obj = ee_obj
  elif isinstance(ee_obj, ee.ImageCollection):
    img_obj = ee_obj.first()
  else:
    raise TypeError(f'Expected ee.Image or ee.ImageCollection, got {type(ee_obj)}')

  first_band_info = img_obj.select(0).getInfo()['bands'][0]

  return dict(
      crs=first_band_info['crs'],
      crs_transform=tuple(first_band_info['crs_transform']),
      shape_2d=tuple(first_band_info['dimensions'])
  )