Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,22 @@ jobs:
ruff check . --output-format=github
ruff format --check .

type-check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- name: Set up Python 3.13
uses: actions/setup-python@v6
with:
python-version: "3.13"
cache: "pip"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[type-check]
- name: Type-check with mypy
run: mypy

test:
needs: lint
runs-on: ${{ matrix.os }}
Expand Down
19 changes: 13 additions & 6 deletions dataretrieval/nadp.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

"""

from __future__ import annotations

import io
import re
import warnings
Expand All @@ -45,7 +47,7 @@
)


def _warn_deprecated():
def _warn_deprecated() -> None:
warnings.warn(_DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=3)


Expand Down Expand Up @@ -74,19 +76,19 @@ def _warn_deprecated():
class NADP_ZipFile(zipfile.ZipFile):
"""Extend zipfile.ZipFile for working on data from NADP"""

def tif_name(self):
def tif_name(self) -> str:
"""Get the name of the tif file in the zip file."""
filenames = self.namelist()
r = re.compile(".*tif$")
tif_list = list(filter(r.match, filenames))
return tif_list[0]

def tif(self):
def tif(self) -> bytes:
"""Read the tif file in the zip file."""
return self.read(self.tif_name())


def get_annual_MDN_map(measurement_type, year, path):
def get_annual_MDN_map(measurement_type: str, year: str, path: str) -> str:
"""Download a MDN map from NDAP.

This function looks for a zip file containing gridded information at:
Expand Down Expand Up @@ -135,7 +137,12 @@ def get_annual_MDN_map(measurement_type, year, path):
return str(path)


def get_annual_NTN_map(measurement_type, measurement=None, year=None, path="."):
def get_annual_NTN_map(
measurement_type: str,
measurement: str | None = None,
year: str | None = None,
path: str = ".",
) -> str:
"""Download a NTN map from NDAP.

This function looks for a zip file containing gridded information at:
Expand Down Expand Up @@ -193,7 +200,7 @@ def get_annual_NTN_map(measurement_type, measurement=None, year=None, path="."):
return str(path)


def get_zip(url, filename):
def get_zip(url: str, filename: str) -> NADP_ZipFile:
"""Gets a ZipFile at url and returns it

Parameters
Expand Down
54 changes: 35 additions & 19 deletions dataretrieval/nldi.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from json import JSONDecodeError
from typing import Literal
from typing import Any, Literal, cast

from dataretrieval.utils import query

Expand All @@ -16,13 +16,17 @@
_VALID_NAVIGATION_MODES = ("UM", "DM", "UT", "DD")


def _query_nldi(url, query_params, error_message):
def _query_nldi(
url: str,
query_params: dict[str, str],
error_message: str,
) -> dict[str, Any] | list[Any]:
# A helper function to query the NLDI API
response = query(url, payload=query_params)
if response.status_code != 200:
raise ValueError(f"{error_message}. Error reason: {response.reason_phrase}")

response_data = {}
response_data: dict[str, Any] | list[Any] = {}
try:
response_data = response.json()
except JSONDecodeError:
Expand All @@ -32,7 +36,7 @@ def _query_nldi(url, query_params, error_message):
return response_data


def _features_to_gdf(feature_collection: dict) -> gpd.GeoDataFrame:
def _features_to_gdf(feature_collection: dict[str, Any]) -> gpd.GeoDataFrame:
"""Build a GeoDataFrame from an NLDI FeatureCollection, tolerating empties.

NLDI can legitimately return no features (e.g. a feature with nothing
Expand All @@ -56,7 +60,7 @@ def get_flowlines(
stop_comid: int | None = None,
trim_start: bool = False,
as_json: bool = False,
) -> gpd.GeoDataFrame | dict:
) -> gpd.GeoDataFrame | dict[str, Any]:
"""Gets the flowlines for the specified navigation either by comid or feature
source in WGS84 lat/long coordinates as GeoDataFrame containing a polyline geometry.

Expand Down Expand Up @@ -116,7 +120,7 @@ def get_flowlines(
else:
err_msg = f"Error getting flowlines for comid '{comid}'"

feature_collection = _query_nldi(url, query_params, err_msg)
feature_collection = cast("dict[str, Any]", _query_nldi(url, query_params, err_msg))
if as_json:
return feature_collection
gdf = _features_to_gdf(feature_collection)
Expand All @@ -129,7 +133,7 @@ def get_basin(
simplified: bool = True,
split_catchment: bool = False,
as_json: bool = False,
) -> gpd.GeoDataFrame | dict:
) -> gpd.GeoDataFrame | dict[str, Any]:
"""Gets the aggregated basin for the specified feature in WGS84 lat/lon
as GeoDataFrame or as JSON conatining a polygon geometry.

Expand Down Expand Up @@ -162,14 +166,17 @@ def get_basin(
raise ValueError("feature_id is required")

url = f"{NLDI_API_BASE_URL}/{feature_source}/{feature_id}/basin"
simplified = str(simplified).lower()
split_catchment = str(split_catchment).lower()
query_params = {"simplified": simplified, "splitCatchment": split_catchment}
simplified_str = str(simplified).lower()
split_catchment_str = str(split_catchment).lower()
query_params = {
"simplified": simplified_str,
"splitCatchment": split_catchment_str,
}
err_msg = (
f"Error getting basin for feature source '{feature_source}' and "
f"feature_id '{feature_id}'"
)
feature_collection = _query_nldi(url, query_params, err_msg)
feature_collection = cast("dict[str, Any]", _query_nldi(url, query_params, err_msg))
if as_json:
return feature_collection
gdf = _features_to_gdf(feature_collection)
Expand All @@ -187,7 +194,7 @@ def get_features(
long: float | None = None,
stop_comid: int | None = None,
as_json: bool = False,
) -> gpd.GeoDataFrame | dict:
) -> gpd.GeoDataFrame | dict[str, Any]:
"""Gets all features found along the specified navigation either by
comid or feature source as points in WGS84 lat/long coordinates - a GeoDataFrame
containing a point geometry.
Expand Down Expand Up @@ -285,7 +292,7 @@ def get_features(
query_params = {}
err_msg = _features_err_msg(feature_source, feature_id, comid, data_source)

feature_collection = _query_nldi(url, query_params, err_msg)
feature_collection = cast("dict[str, Any]", _query_nldi(url, query_params, err_msg))
if as_json:
return feature_collection
gdf = _features_to_gdf(feature_collection)
Expand Down Expand Up @@ -321,7 +328,7 @@ def get_features_by_data_source(data_source: str) -> gpd.GeoDataFrame:
_validate_data_source(data_source)
url = f"{NLDI_API_BASE_URL}/{data_source}"
err_msg = f"Error getting features for data source '{data_source}'"
feature_collection = _query_nldi(url, {}, err_msg)
feature_collection = cast("dict[str, Any]", _query_nldi(url, {}, err_msg))
gdf = _features_to_gdf(feature_collection)
return gdf

Expand All @@ -336,7 +343,7 @@ def search(
lat: float | None = None,
long: float | None = None,
distance: int = 50,
) -> dict:
) -> dict[str, Any]:
"""Searches for the specified feature in NLDI and returns the results
as a dictionary.

Expand Down Expand Up @@ -408,7 +415,7 @@ def search(
if (lat is None) != (long is None):
raise ValueError("Both lat and long are required")

find = find.lower()
find = cast(Literal["basin", "flowlines", "features"], find.lower())
if find not in ("basin", "flowlines", "features"):
raise ValueError(
f"Invalid value for find: {find} - allowed values are:"
Expand All @@ -428,6 +435,10 @@ def search(
return get_features(lat=lat, long=long, as_json=True)

if find == "basin":
if feature_source is None or feature_id is None:
raise ValueError(
"feature_source and feature_id are required to find a basin"
)
return get_basin(
feature_source=feature_source, feature_id=feature_id, as_json=True
)
Expand Down Expand Up @@ -458,7 +469,7 @@ def search(
)


def _validate_data_source(data_source: str):
def _validate_data_source(data_source: str) -> None:
# A helper function to validate user specified data source/feature source

global _AVAILABLE_DATA_SOURCES
Expand Down Expand Up @@ -487,7 +498,12 @@ def _validate_data_source(data_source: str):
raise ValueError(err_msg)


def _features_err_msg(feature_source, feature_id, comid, data_source) -> str:
def _features_err_msg(
feature_source: str | None,
feature_id: str | None,
comid: int | None,
data_source: str | None,
) -> str:
if feature_source is not None:
return (
f"Error getting features for feature source '{feature_source}'"
Expand All @@ -512,7 +528,7 @@ def _validate_navigation_mode(navigation_mode: str | None) -> str:

def _validate_feature_source_comid(
feature_source: str | None, feature_id: str | None, comid: int | None
):
) -> None:
if feature_source is not None and feature_id is None:
raise ValueError("feature_id is required if feature_source is provided")
if feature_id is not None and feature_source is None:
Expand Down
Loading
Loading