From cd26afc9ebf124ffed79c992bf86dd5592672fe0 Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Sun, 31 May 2026 21:17:54 -0400 Subject: [PATCH 1/3] chore(typing): set up mypy and fix the type errors it surfaces The package ships a ``py.typed`` marker (advertising itself as typed to downstream users) but nothing type-checked it. Add mypy and get a clean run. Setup: - [tool.mypy] in pyproject.toml: a lenient first-pass config (ignore_missing_imports, target python_version 3.9), scoped to the dataretrieval package. - mypy<2 added to the [test] extra (<2 so it can still target 3.9). - a type-check job in the CI workflow, parallel to the ruff lint job. Fixes (mypy went from 78 errors to 0 on the tracked package): - HTTPX_DEFAULTS annotated dict[str, Any] so **-splatting it into httpx.get / httpx.AsyncClient type-checks -- cleared ~55 errors across 7 call sites at once. - utils.py gains `from __future__ import annotations`: mypy (targeting 3.9) caught that the new `str | None` annotations there would be a runtime error on 3.9, because this module -- unlike the rest of the package -- lacked the future import. - BaseMetadata.comment annotated `str | None` (was inferred `None`, which rejected every subclass that assigns a comment string). - _format_api_dates: accept Sequence[str | None] (covariant) so a list[str] caller type-checks, and build the formatted list with an early return so the final join sees list[str]. - _as_str_list: delegate to _normalize_str_iterable then wrap, so the declared return type list[str] | None holds. - _next_req_url: declare next_host / cur_host as `str | None`. - ratings._search: build the query dict in a non-Optional local before aliasing it to the loop's `params` (which toggles to None per page). - nldi: drop the bool->str / Literal->str variable reuse; guard the basin branch so feature_source / feature_id are non-None before get_basin. - chunking: narrow the optional filter before _is_chunkable; fix a stale `# type: ignore` error code. The fixes are annotations and type-narrowing guards. The only runtime-visible change is that nldi.search() now raises a clear ValueError up front when a basin search is missing feature_source/feature_id, where the same condition previously raised deeper inside get_basin. 259 tests pass across the affected suites. Co-Authored-By: Claude Opus 4.8 (1M context) --- .github/workflows/python-package.yml | 16 ++++++++++++++++ dataretrieval/nldi.py | 17 ++++++++++++----- dataretrieval/utils.py | 11 +++++++++-- dataretrieval/waterdata/chunking.py | 4 ++-- dataretrieval/waterdata/ratings.py | 11 +++++++---- dataretrieval/waterdata/utils.py | 24 ++++++++++++++---------- pyproject.toml | 11 +++++++++++ 7 files changed, 71 insertions(+), 23 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 0ab3d142..175dcd06 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -26,6 +26,22 @@ jobs: ruff check . --output-format=github ruff format --check . + type-check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - name: Set up Python 3.13 + uses: actions/setup-python@v6 + with: + python-version: "3.13" + cache: "pip" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[test] + - name: Type-check with mypy + run: mypy + test: needs: lint runs-on: ${{ matrix.os }} diff --git a/dataretrieval/nldi.py b/dataretrieval/nldi.py index 8d61fcc2..0a6cfb74 100644 --- a/dataretrieval/nldi.py +++ b/dataretrieval/nldi.py @@ -1,7 +1,7 @@ from __future__ import annotations from json import JSONDecodeError -from typing import Literal +from typing import Literal, cast from dataretrieval.utils import query @@ -162,9 +162,12 @@ def get_basin( raise ValueError("feature_id is required") url = f"{NLDI_API_BASE_URL}/{feature_source}/{feature_id}/basin" - simplified = str(simplified).lower() - split_catchment = str(split_catchment).lower() - query_params = {"simplified": simplified, "splitCatchment": split_catchment} + simplified_str = str(simplified).lower() + split_catchment_str = str(split_catchment).lower() + query_params = { + "simplified": simplified_str, + "splitCatchment": split_catchment_str, + } err_msg = ( f"Error getting basin for feature source '{feature_source}' and " f"feature_id '{feature_id}'" @@ -408,7 +411,7 @@ def search( if (lat is None) != (long is None): raise ValueError("Both lat and long are required") - find = find.lower() + find = cast(Literal["basin", "flowlines", "features"], find.lower()) if find not in ("basin", "flowlines", "features"): raise ValueError( f"Invalid value for find: {find} - allowed values are:" @@ -428,6 +431,10 @@ def search( return get_features(lat=lat, long=long, as_json=True) if find == "basin": + if feature_source is None or feature_id is None: + raise ValueError( + "feature_source and feature_id are required to find a basin" + ) return get_basin( feature_source=feature_source, feature_id=feature_id, as_json=True ) diff --git a/dataretrieval/utils.py b/dataretrieval/utils.py index 7bb03a69..ab5efc32 100644 --- a/dataretrieval/utils.py +++ b/dataretrieval/utils.py @@ -2,8 +2,11 @@ Useful utilities for data munging. """ +from __future__ import annotations + import warnings from collections.abc import Iterable +from typing import Any import httpx import pandas as pd @@ -11,7 +14,10 @@ import dataretrieval from dataretrieval.codes import tz -HTTPX_DEFAULTS = { +# Typed as ``dict[str, Any]`` (not the inferred ``dict[str, object]``) so that +# splatting it as ``**HTTPX_DEFAULTS`` into ``httpx.get`` / ``httpx.AsyncClient`` +# type-checks: the values are a heterogeneous bag of httpx keyword arguments. +HTTPX_DEFAULTS: dict[str, Any] = { "follow_redirects": True, "timeout": httpx.Timeout(60.0, connect=10.0), } @@ -190,6 +196,7 @@ def _attach_datetime_columns(df: pd.DataFrame) -> pd.DataFrame: # Concat in one shot — per-column assignment on a wide CSV-derived # frame triggers pandas' fragmentation PerformanceWarning. df = pd.concat([df, pd.DataFrame(new_columns, index=df.index)], axis=1) + sort_key: str | None if "Activity_StartDateTime" in df.columns: sort_key = "Activity_StartDateTime" elif "ActivityStartDateTime" in df.columns: @@ -234,7 +241,7 @@ def __init__(self, response) -> None: self.url = str(response.url) self.query_time = response.elapsed self.header = response.headers - self.comment = None + self.comment: str | None = None # # not sure what statistic_info is # self.statistic_info = None diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index ab079070..ad7b99c1 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -681,7 +681,7 @@ def _set_response_url(response: httpx.Response, url: str | httpx.URL) -> None: same ``.request``. """ try: - response.url = url # type: ignore[misc] + response.url = url # type: ignore[misc, assignment] except AttributeError: target = httpx.URL(str(url)) try: @@ -800,7 +800,7 @@ def _extract_axes(args: dict[str, Any]) -> list[_Axis]: axes.append(_Axis(arg_key=key, atoms=tuple(value), joiner=_LIST_SEP)) filter_expr = args.get("filter") - if _is_chunkable(filter_expr, args.get("filter_lang")): + if filter_expr is not None and _is_chunkable(filter_expr, args.get("filter_lang")): _check_numeric_filter_pitfall(filter_expr) clauses = _split_top_level_or(filter_expr) if len(clauses) >= 2: diff --git a/dataretrieval/waterdata/ratings.py b/dataretrieval/waterdata/ratings.py index ed242612..c0f870c1 100644 --- a/dataretrieval/waterdata/ratings.py +++ b/dataretrieval/waterdata/ratings.py @@ -246,15 +246,18 @@ def _search( STAC ``next`` link is followed until exhausted so a result set larger than one page isn't silently truncated. """ - params: dict[str, Any] | None = {"limit": min(limit, 10000)} + query_params: dict[str, Any] = {"limit": min(limit, 10000)} if filter_str is not None: - params["filter"] = filter_str + query_params["filter"] = filter_str if time_str is not None: - params["datetime"] = time_str + query_params["datetime"] = time_str if bbox is not None: - params["bbox"] = ",".join(map(str, bbox)) + query_params["bbox"] = ",".join(map(str, bbox)) url: str | None = f"{STAC_URL}/search" + # ``params`` is sent only on the first request; each STAC ``next`` link + # already carries the query, so it is reset to None inside the loop. + params: dict[str, Any] | None = query_params features: list[dict[str, Any]] = [] while url is not None: response = httpx.get( diff --git a/dataretrieval/waterdata/utils.py b/dataretrieval/waterdata/utils.py index ad1b3afd..26a45773 100644 --- a/dataretrieval/waterdata/utils.py +++ b/dataretrieval/waterdata/utils.py @@ -14,6 +14,7 @@ Iterable, Iterator, Mapping, + Sequence, ) from contextlib import asynccontextmanager, contextmanager from contextvars import ContextVar @@ -251,7 +252,7 @@ def _format_one(dt, *, date: bool) -> str | None: def _format_api_dates( - datetime_input: str | list[str | None] | None, date: bool = False + datetime_input: str | Sequence[str | None] | None, date: bool = False ) -> str | None: """ Formats date or datetime input(s) for use with an API. @@ -330,11 +331,13 @@ def _format_api_dates( if _DURATION_RE.match(single) or "/" in single: return single - # Half-bounded ranges: NA endpoints render as ".."; any unparseable non-NA # element invalidates the range. - formatted = [_format_one(dt, date=date) for dt in datetime_input] - if any(f is None for f in formatted): - return None + formatted: list[str] = [] + for dt in datetime_input: + one = _format_one(dt, date=date) + if one is None: + return None + formatted.append(one) return "/".join(formatted) @@ -823,6 +826,8 @@ def _next_req_url( # body might supply. Guarded against mock-shaped ``resp.url`` # attributes (tests sometimes set strings or ``MagicMock``) # by falling open when host extraction isn't reliable. + next_host: str | None + cur_host: str | None try: next_host = httpx.URL(href).host resp_url = ( @@ -1915,11 +1920,10 @@ def _as_str_list( ``",".join(...)`` doesn't iterate it character-by-character — and materializes any other iterable via :func:`_normalize_str_iterable`. """ - return ( - [value] - if isinstance(value, str) - else _normalize_str_iterable(value, param_name) - ) + normalized = _normalize_str_iterable(value, param_name) + if isinstance(normalized, str): + return [normalized] + return normalized def _check_monitoring_location_id( diff --git a/pyproject.toml b/pyproject.toml index 62ac7478..9af5f6a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ test = [ "coverage", "pytest-httpx", "ruff", + "mypy<2", # <2 so it can still target Python 3.9 (the project's floor) ] doc = [ "docutils<0.22", @@ -102,3 +103,13 @@ skip-magic-trailing-comma = false line-ending = "auto" docstring-code-format = true docstring-code-line-length = 72 + +[tool.mypy] +# Incremental-adoption config. ``check_untyped_defs`` type-checks the bodies of +# not-yet-annotated functions; what still keeps this short of ``--strict`` is +# that annotations aren't *required* (``disallow_untyped_defs`` is off) and +# untyped third-party libs are treated as ``Any`` (``ignore_missing_imports``). +python_version = "3.9" # the project's minimum supported version +files = ["dataretrieval"] +ignore_missing_imports = true +check_untyped_defs = true From d9bec5e2d6cb6b84d8038053062beff7aa7a9e55 Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Sun, 31 May 2026 21:44:41 -0400 Subject: [PATCH 2/3] chore(typing): annotate the package and enable mypy --strict Builds on the lenient baseline: annotate the 71 previously-unannotated functions, resolve the remaining strict findings, and flip the mypy config to ``strict = true`` (keeping only ``ignore_missing_imports`` for untyped third-party libraries). - nwis: type the ``@_deprecated`` decorator with a signature-preserving ``TypeVar`` (3.9-safe; ParamSpec would need 3.10), which un-erases the decorated public getters; annotate the getters and helpers; defunct stubs that only ``raise`` are typed ``-> NoReturn``. - wqp / nldi / nadp / streamstats / samples / utils / waterdata{utils,api, nearest,ratings,chunking}: precise parameter and return annotations (``tuple[pd.DataFrame, ]`` for the getters), parameterized bare ``dict`` generics, and a few justified ``cast``s where a callee's union return is statically wider than the single concrete type a caller is guaranteed (NLDI FeatureCollection endpoints, format-keyed StreamStats). - __init__: ``get_ratings`` (nwis vs waterdata) and ``what_sites`` (nwis vs wqp) collide across the star-imports; the modern definitions win by import order. mypy can't model last-binding-wins, so the two re-binding mismatches are silenced with an explanatory comment. (The collision itself is a pre-existing public-namespace question, left as-is here.) Annotations only -- no runtime behavior change. ``mypy --strict`` is clean and the full test suite (469 passed, 2 skipped) is unchanged. Co-Authored-By: Claude Opus 4.8 (1M context) --- dataretrieval/nadp.py | 19 +++++--- dataretrieval/nldi.py | 39 ++++++++++------ dataretrieval/nwis.py | 71 +++++++++++++++++------------ dataretrieval/samples.py | 8 +++- dataretrieval/streamstats.py | 45 ++++++++++-------- dataretrieval/utils.py | 23 ++++++---- dataretrieval/waterdata/api.py | 8 ++-- dataretrieval/waterdata/chunking.py | 9 +++- dataretrieval/waterdata/nearest.py | 16 +++---- dataretrieval/waterdata/utils.py | 29 ++++++++---- dataretrieval/wqp.py | 71 +++++++++++++++++------------ pyproject.toml | 11 +++-- 12 files changed, 214 insertions(+), 135 deletions(-) diff --git a/dataretrieval/nadp.py b/dataretrieval/nadp.py index 3d1ee442..d6b26381 100644 --- a/dataretrieval/nadp.py +++ b/dataretrieval/nadp.py @@ -29,6 +29,8 @@ """ +from __future__ import annotations + import io import re import warnings @@ -45,7 +47,7 @@ ) -def _warn_deprecated(): +def _warn_deprecated() -> None: warnings.warn(_DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=3) @@ -74,19 +76,19 @@ def _warn_deprecated(): class NADP_ZipFile(zipfile.ZipFile): """Extend zipfile.ZipFile for working on data from NADP""" - def tif_name(self): + def tif_name(self) -> str: """Get the name of the tif file in the zip file.""" filenames = self.namelist() r = re.compile(".*tif$") tif_list = list(filter(r.match, filenames)) return tif_list[0] - def tif(self): + def tif(self) -> bytes: """Read the tif file in the zip file.""" return self.read(self.tif_name()) -def get_annual_MDN_map(measurement_type, year, path): +def get_annual_MDN_map(measurement_type: str, year: str, path: str) -> str: """Download a MDN map from NDAP. This function looks for a zip file containing gridded information at: @@ -135,7 +137,12 @@ def get_annual_MDN_map(measurement_type, year, path): return str(path) -def get_annual_NTN_map(measurement_type, measurement=None, year=None, path="."): +def get_annual_NTN_map( + measurement_type: str, + measurement: str | None = None, + year: str | None = None, + path: str = ".", +) -> str: """Download a NTN map from NDAP. This function looks for a zip file containing gridded information at: @@ -193,7 +200,7 @@ def get_annual_NTN_map(measurement_type, measurement=None, year=None, path="."): return str(path) -def get_zip(url, filename): +def get_zip(url: str, filename: str) -> NADP_ZipFile: """Gets a ZipFile at url and returns it Parameters diff --git a/dataretrieval/nldi.py b/dataretrieval/nldi.py index 0a6cfb74..a03aa1e6 100644 --- a/dataretrieval/nldi.py +++ b/dataretrieval/nldi.py @@ -1,7 +1,7 @@ from __future__ import annotations from json import JSONDecodeError -from typing import Literal, cast +from typing import Any, Literal, cast from dataretrieval.utils import query @@ -16,13 +16,17 @@ _VALID_NAVIGATION_MODES = ("UM", "DM", "UT", "DD") -def _query_nldi(url, query_params, error_message): +def _query_nldi( + url: str, + query_params: dict[str, str], + error_message: str, +) -> dict[str, Any] | list[Any]: # A helper function to query the NLDI API response = query(url, payload=query_params) if response.status_code != 200: raise ValueError(f"{error_message}. Error reason: {response.reason_phrase}") - response_data = {} + response_data: dict[str, Any] | list[Any] = {} try: response_data = response.json() except JSONDecodeError: @@ -32,7 +36,7 @@ def _query_nldi(url, query_params, error_message): return response_data -def _features_to_gdf(feature_collection: dict) -> gpd.GeoDataFrame: +def _features_to_gdf(feature_collection: dict[str, Any]) -> gpd.GeoDataFrame: """Build a GeoDataFrame from an NLDI FeatureCollection, tolerating empties. NLDI can legitimately return no features (e.g. a feature with nothing @@ -56,7 +60,7 @@ def get_flowlines( stop_comid: int | None = None, trim_start: bool = False, as_json: bool = False, -) -> gpd.GeoDataFrame | dict: +) -> gpd.GeoDataFrame | dict[str, Any]: """Gets the flowlines for the specified navigation either by comid or feature source in WGS84 lat/long coordinates as GeoDataFrame containing a polyline geometry. @@ -116,7 +120,7 @@ def get_flowlines( else: err_msg = f"Error getting flowlines for comid '{comid}'" - feature_collection = _query_nldi(url, query_params, err_msg) + feature_collection = cast("dict[str, Any]", _query_nldi(url, query_params, err_msg)) if as_json: return feature_collection gdf = _features_to_gdf(feature_collection) @@ -129,7 +133,7 @@ def get_basin( simplified: bool = True, split_catchment: bool = False, as_json: bool = False, -) -> gpd.GeoDataFrame | dict: +) -> gpd.GeoDataFrame | dict[str, Any]: """Gets the aggregated basin for the specified feature in WGS84 lat/lon as GeoDataFrame or as JSON conatining a polygon geometry. @@ -172,7 +176,7 @@ def get_basin( f"Error getting basin for feature source '{feature_source}' and " f"feature_id '{feature_id}'" ) - feature_collection = _query_nldi(url, query_params, err_msg) + feature_collection = cast("dict[str, Any]", _query_nldi(url, query_params, err_msg)) if as_json: return feature_collection gdf = _features_to_gdf(feature_collection) @@ -190,7 +194,7 @@ def get_features( long: float | None = None, stop_comid: int | None = None, as_json: bool = False, -) -> gpd.GeoDataFrame | dict: +) -> gpd.GeoDataFrame | dict[str, Any]: """Gets all features found along the specified navigation either by comid or feature source as points in WGS84 lat/long coordinates - a GeoDataFrame containing a point geometry. @@ -288,7 +292,7 @@ def get_features( query_params = {} err_msg = _features_err_msg(feature_source, feature_id, comid, data_source) - feature_collection = _query_nldi(url, query_params, err_msg) + feature_collection = cast("dict[str, Any]", _query_nldi(url, query_params, err_msg)) if as_json: return feature_collection gdf = _features_to_gdf(feature_collection) @@ -324,7 +328,7 @@ def get_features_by_data_source(data_source: str) -> gpd.GeoDataFrame: _validate_data_source(data_source) url = f"{NLDI_API_BASE_URL}/{data_source}" err_msg = f"Error getting features for data source '{data_source}'" - feature_collection = _query_nldi(url, {}, err_msg) + feature_collection = cast("dict[str, Any]", _query_nldi(url, {}, err_msg)) gdf = _features_to_gdf(feature_collection) return gdf @@ -339,7 +343,7 @@ def search( lat: float | None = None, long: float | None = None, distance: int = 50, -) -> dict: +) -> dict[str, Any]: """Searches for the specified feature in NLDI and returns the results as a dictionary. @@ -465,7 +469,7 @@ def search( ) -def _validate_data_source(data_source: str): +def _validate_data_source(data_source: str) -> None: # A helper function to validate user specified data source/feature source global _AVAILABLE_DATA_SOURCES @@ -494,7 +498,12 @@ def _validate_data_source(data_source: str): raise ValueError(err_msg) -def _features_err_msg(feature_source, feature_id, comid, data_source) -> str: +def _features_err_msg( + feature_source: str | None, + feature_id: str | None, + comid: int | None, + data_source: str | None, +) -> str: if feature_source is not None: return ( f"Error getting features for feature source '{feature_source}'" @@ -519,7 +528,7 @@ def _validate_navigation_mode(navigation_mode: str | None) -> str: def _validate_feature_source_comid( feature_source: str | None, feature_id: str | None, comid: int | None -): +) -> None: if feature_source is not None and feature_id is None: raise ValueError("feature_id is required if feature_source is provided") if feature_id is not None and feature_source is None: diff --git a/dataretrieval/nwis.py b/dataretrieval/nwis.py index 1372caa7..fafd0a08 100644 --- a/dataretrieval/nwis.py +++ b/dataretrieval/nwis.py @@ -9,7 +9,9 @@ import functools import threading import warnings +from collections.abc import Callable from json import JSONDecodeError +from typing import Any, NoReturn, TypeVar, cast import httpx import pandas as pd @@ -24,6 +26,8 @@ except ImportError: gpd = None +F = TypeVar("F", bound=Callable[..., Any]) + WATERDATA_BASE_URL = "https://nwis.waterdata.usgs.gov/" WATERDATA_URL = WATERDATA_BASE_URL + "nwis/" WATERSERVICE_URL = "https://waterservices.usgs.gov/nwis/" @@ -75,7 +79,7 @@ def _warn_deprecated(func_name: str) -> None: ) -def _deprecated(func): +def _deprecated(func: F) -> F: """Mark an nwis function as deprecated. Wrappers like ``get_record`` -> ``get_iv`` -> ``query_waterservices`` would @@ -89,7 +93,7 @@ def _deprecated(func): ) @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: if getattr(_deprecation_state, "active", False): return func(*args, **kwargs) _deprecation_state.active = True @@ -99,7 +103,7 @@ def wrapper(*args, **kwargs): finally: _deprecation_state.active = False - return wrapper + return cast(F, wrapper) def _parse_json_or_raise(response: httpx.Response) -> pd.DataFrame: @@ -123,7 +127,7 @@ def _parse_json_or_raise(response: httpx.Response) -> pd.DataFrame: def format_response( - df: pd.DataFrame, service: str | None = None, **kwargs + df: pd.DataFrame, service: str | None = None, **kwargs: Any ) -> pd.DataFrame: """Setup index for response from query. @@ -197,14 +201,14 @@ def preformat_peaks_response(df: pd.DataFrame) -> pd.DataFrame: return df -def get_qwdata(**kwargs): +def get_qwdata(**kwargs: Any) -> NoReturn: """Defunct: use ``waterdata.get_samples()``.""" raise NameError( "`nwis.get_qwdata` has been replaced with `waterdata.get_samples()`." ) -def get_discharge_measurements(**kwargs): +def get_discharge_measurements(**kwargs: Any) -> NoReturn: """Defunct: use ``waterdata.get_field_measurements()``.""" raise NameError( "`nwis.get_discharge_measurements` has been replaced " @@ -219,8 +223,8 @@ def get_discharge_peaks( end: str | None = None, multi_index: bool = True, ssl_check: bool = True, - **kwargs, -) -> tuple[pd.DataFrame, BaseMetadata]: + **kwargs: Any, +) -> tuple[pd.DataFrame, NWIS_Metadata]: """ Get discharge peaks from the waterdata service. @@ -285,7 +289,7 @@ def get_discharge_peaks( ) -def get_gwlevels(**kwargs): +def get_gwlevels(**kwargs: Any) -> NoReturn: """Defunct: use ``waterdata.get_continuous()``, ``waterdata.get_daily()``, or ``waterdata.get_field_measurements()``.""" raise NameError( @@ -298,8 +302,8 @@ def get_gwlevels(**kwargs): @_deprecated def get_stats( - sites: list[str] | str | None = None, ssl_check: bool = True, **kwargs -) -> tuple[pd.DataFrame, BaseMetadata]: + sites: list[str] | str | None = None, ssl_check: bool = True, **kwargs: Any +) -> tuple[pd.DataFrame, NWIS_Metadata]: """ Queries water services statistics information. @@ -359,7 +363,9 @@ def get_stats( @_deprecated -def query_waterdata(service: str, ssl_check: bool = True, **kwargs) -> httpx.Response: +def query_waterdata( + service: str, ssl_check: bool = True, **kwargs: Any +) -> httpx.Response: """ Queries waterdata. @@ -404,7 +410,7 @@ def query_waterdata(service: str, ssl_check: bool = True, **kwargs) -> httpx.Res @_deprecated def query_waterservices( - service: str, ssl_check: bool = True, **kwargs + service: str, ssl_check: bool = True, **kwargs: Any ) -> httpx.Response: """ Queries waterservices.usgs.gov @@ -473,8 +479,8 @@ def get_dv( end: str | None = None, multi_index: bool = True, ssl_check: bool = True, - **kwargs, -) -> tuple[pd.DataFrame, BaseMetadata]: + **kwargs: Any, +) -> tuple[pd.DataFrame, NWIS_Metadata]: """ Get daily values data from NWIS and return it as a ``pandas.DataFrame``. @@ -539,7 +545,9 @@ def get_dv( @_deprecated -def get_info(ssl_check: bool = True, **kwargs) -> tuple[pd.DataFrame, BaseMetadata]: +def get_info( + ssl_check: bool = True, **kwargs: Any +) -> tuple[pd.DataFrame, NWIS_Metadata]: """ Get site description information from NWIS. @@ -661,8 +669,8 @@ def get_iv( end: str | None = None, multi_index: bool = True, ssl_check: bool = True, - **kwargs, -) -> tuple[pd.DataFrame, BaseMetadata]: + **kwargs: Any, +) -> tuple[pd.DataFrame, NWIS_Metadata]: """Get instantaneous values data from NWIS and return it as a DataFrame. .. note:: @@ -725,7 +733,7 @@ def get_iv( return format_response(df, **kwargs), NWIS_Metadata(response, **kwargs) -def get_pmcodes(**kwargs): +def get_pmcodes(**kwargs: Any) -> NoReturn: """Defunct: use ``get_reference_table(collection='parameter-codes')``.""" raise NameError( "`nwis.get_pmcodes` has been replaced " @@ -733,7 +741,7 @@ def get_pmcodes(**kwargs): ) -def get_water_use(**kwargs): +def get_water_use(**kwargs: Any) -> NoReturn: """Defunct: no current replacement.""" raise NameError("`nwis.get_water_use` is defunct.") @@ -743,8 +751,8 @@ def get_ratings( site: str | None = None, file_type: str = "base", ssl_check: bool = True, - **kwargs, -) -> tuple[pd.DataFrame, BaseMetadata]: + **kwargs: Any, +) -> tuple[pd.DataFrame, NWIS_Metadata]: """ Rating table for an active USGS streamgage retrieval. @@ -797,7 +805,9 @@ def get_ratings( @_deprecated -def what_sites(ssl_check: bool = True, **kwargs) -> tuple[pd.DataFrame, BaseMetadata]: +def what_sites( + ssl_check: bool = True, **kwargs: Any +) -> tuple[pd.DataFrame, NWIS_Metadata]: """ Search NWIS for sites within a region with specific data. @@ -847,7 +857,7 @@ def get_record( state: str | None = None, service: str = "iv", ssl_check: bool = True, - **kwargs, + **kwargs: Any, ) -> pd.DataFrame: """ Get data from NWIS and return it as a ``pandas.DataFrame``. @@ -985,7 +995,10 @@ def get_record( return df elif service == "ratings": - df, _ = get_ratings(site=sites, ssl_check=ssl_check, **kwargs) + # the ratings service is single-site; get_ratings takes a scalar site + df, _ = get_ratings( + site=cast("str | None", sites), ssl_check=ssl_check, **kwargs + ) return df elif service == "stat": @@ -996,7 +1009,7 @@ def get_record( raise TypeError(f"{service} service not yet implemented") -def _read_json(json): +def _read_json(json: dict[str, Any]) -> pd.DataFrame: """ Reads a NWIS Water Services formatted JSON into a ``pandas.DataFrame``. @@ -1092,7 +1105,7 @@ def _read_json(json): return merged_df -def _read_rdb(rdb): +def _read_rdb(rdb: str) -> pd.DataFrame: """Parse an NWIS RDB response and apply NWIS-specific post-processing. Thin wrapper around :func:`dataretrieval.rdb.read_rdb` that adds the @@ -1102,7 +1115,7 @@ def _read_rdb(rdb): return format_response(read_rdb(rdb, dtypes=_NWIS_RDB_DTYPES)) -def _check_sites_value_types(sites): +def _check_sites_value_types(sites: list[str] | str | None) -> None: if sites and not isinstance(sites, list) and not isinstance(sites, str): raise TypeError("sites must be a string or a list of strings") @@ -1128,7 +1141,7 @@ class NWIS_Metadata(BaseMetadata): """ - def __init__(self, response, **parameters) -> None: + def __init__(self, response: httpx.Response, **parameters: Any) -> None: """Generates a standard set of metadata informed by the response with specific metadata for NWIS data. diff --git a/dataretrieval/samples.py b/dataretrieval/samples.py index 2259969c..025fa76e 100644 --- a/dataretrieval/samples.py +++ b/dataretrieval/samples.py @@ -7,9 +7,15 @@ from __future__ import annotations import warnings +from typing import TYPE_CHECKING, Any +if TYPE_CHECKING: + import pandas as pd -def get_usgs_samples(**kwargs): + from dataretrieval.utils import BaseMetadata + + +def get_usgs_samples(**kwargs: Any) -> tuple[pd.DataFrame, BaseMetadata]: """Deprecated: use ``waterdata.get_samples()`` instead. All keyword arguments are forwarded directly to diff --git a/dataretrieval/streamstats.py b/dataretrieval/streamstats.py index 6737d54c..039f292b 100644 --- a/dataretrieval/streamstats.py +++ b/dataretrieval/streamstats.py @@ -5,14 +5,17 @@ """ +from __future__ import annotations + import json +from typing import Any, cast import httpx from dataretrieval.utils import HTTPX_DEFAULTS -def download_workspace(workspaceID, format=""): +def download_workspace(workspaceID: str, format: str = "") -> httpx.Response: """Function to download streamstats workspace. Parameters @@ -46,7 +49,7 @@ def download_workspace(workspaceID, format=""): # return -def get_sample_watershed(): +def get_sample_watershed() -> Watershed: """Sample function to get a watershed object for a location in NY. Makes the function call :obj:`dataretrieval.streamstats.get_watershed` @@ -60,20 +63,23 @@ def get_sample_watershed(): from the streamstats JSON object. """ - return get_watershed("NY", -74.524, 43.939, format="object") + return cast( + "Watershed", + get_watershed("NY", -74.524, 43.939, format="object"), + ) def get_watershed( - rcode, - xlocation, - ylocation, - crs=4326, - includeparameters=True, - includeflowtypes=False, - includefeatures=True, - simplify=True, - format="geojson", -): + rcode: str, + xlocation: float, + ylocation: float, + crs: int | str = 4326, + includeparameters: bool = True, + includeflowtypes: bool = False, + includefeatures: bool = True, + simplify: bool = True, + format: str = "geojson", +) -> httpx.Response | Watershed: """Get watershed object based on location **Streamstats documentation:** @@ -115,7 +121,7 @@ def get_watershed( from the streamstats JSON object. """ - payload = { + payload: dict[str, str | int | float | bool] = { "rcode": rcode, "xlocation": xlocation, "ylocation": ylocation, @@ -170,14 +176,17 @@ class Watershed: :obj:`dataretrieval.streamstats.download_workspace`. """ - def __init__(self, rcode, xlocation, ylocation): + def __init__(self, rcode: str, xlocation: float, ylocation: float) -> None: """Delineate the watershed at ``(xlocation, ylocation)`` and parse the response onto this instance.""" - response = get_watershed(rcode, xlocation, ylocation, format="geojson") + response = cast( + httpx.Response, + get_watershed(rcode, xlocation, ylocation, format="geojson"), + ) self._populate(json.loads(response.text)) @classmethod - def from_streamstats_json(cls, streamstats_json) -> "Watershed": + def from_streamstats_json(cls, streamstats_json: dict[str, Any]) -> Watershed: """Create a :class:`Watershed` from an already-parsed StreamStats JSON payload, without issuing a new request. @@ -190,7 +199,7 @@ class state. self._populate(streamstats_json) return self - def _populate(self, streamstats_json) -> None: + def _populate(self, streamstats_json: dict[str, Any]) -> None: """Extract watershed fields from a StreamStats JSON payload onto this instance.""" self.watershed_point = streamstats_json["featurecollection"][0]["feature"] diff --git a/dataretrieval/utils.py b/dataretrieval/utils.py index ab5efc32..f9766ee6 100644 --- a/dataretrieval/utils.py +++ b/dataretrieval/utils.py @@ -23,7 +23,7 @@ } -def to_str(listlike, delimiter=","): +def to_str(listlike: object, delimiter: str = ",") -> str | None: """Translates list-like objects into strings. Parameters @@ -60,7 +60,9 @@ def to_str(listlike, delimiter=","): return None -def format_datetime(df, date_field, time_field, tz_field): +def format_datetime( + df: pd.DataFrame, date_field: str, time_field: str, tz_field: str +) -> pd.DataFrame: """Creates a datetime field from separate date, time, and time zone fields. @@ -222,7 +224,7 @@ class BaseMetadata: """ - def __init__(self, response) -> None: + def __init__(self, response: httpx.Response) -> None: """Generates a standard set of metadata informed by the response. Parameters @@ -251,13 +253,13 @@ def __init__(self, response) -> None: # These properties are to be set by `nwis` or `wqp`-specific metadata classes. @property - def site_info(self): + def site_info(self) -> Any: raise NotImplementedError( "site_info must be implemented by utils.BaseMetadata children" ) @property - def variable_info(self): + def variable_info(self) -> Any: raise NotImplementedError( "variable_info must be implemented by utils.BaseMetadata children" ) @@ -285,7 +287,12 @@ def _url_too_long_error(detail: str) -> ValueError: ) -def query(url, payload, delimiter=",", ssl_check=True): +def query( + url: str, + payload: dict[str, Any], + delimiter: str = ",", + ssl_check: bool = True, +) -> httpx.Response: """Send a query. Wrapper for httpx.get that handles errors, converts listed @@ -354,10 +361,10 @@ def query(url, payload, delimiter=",", ssl_check=True): class NoSitesError(Exception): """Custom error class used when selection criteria returns no sites/data.""" - def __init__(self, url): + def __init__(self, url: httpx.URL) -> None: self.url = url - def __str__(self): + def __str__(self) -> str: return ( "No sites/data found using the selection criteria specified in " f"url: {self.url}" diff --git a/dataretrieval/waterdata/api.py b/dataretrieval/waterdata/api.py index 1b609324..3144bf80 100644 --- a/dataretrieval/waterdata/api.py +++ b/dataretrieval/waterdata/api.py @@ -10,7 +10,7 @@ import logging from collections.abc import Iterable from io import StringIO -from typing import get_args +from typing import Any, get_args from urllib.parse import quote import httpx @@ -2018,7 +2018,7 @@ def get_peaks( def get_reference_table( collection: str, limit: int | None = None, - query: dict | None = None, + query: dict[str, Any] | None = None, max_rows: int | None = None, ) -> tuple[pd.DataFrame, BaseMetadata]: """Get metadata reference tables for the USGS Water Data API. @@ -2140,7 +2140,7 @@ def get_codes(code_service: CODE_SERVICES) -> tuple[pd.DataFrame, BaseMetadata]: def _get_samples_csv( - url: str, params: dict, ssl_check: bool + url: str, params: dict[str, Any], ssl_check: bool ) -> tuple[pd.DataFrame, httpx.Response]: """Issue a Samples CSV request and parse the body into a DataFrame. @@ -2852,7 +2852,7 @@ def get_channel( def get_cql( service: WATERDATA_SERVICES, - cql: str | dict, + cql: str | dict[str, Any], *, properties: str | Iterable[str] | None = None, bbox: list[float] | None = None, diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index ad7b99c1..1f49ed97 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -59,7 +59,7 @@ from contextvars import ContextVar from dataclasses import dataclass from datetime import timedelta -from typing import Any, ClassVar +from typing import Any, ClassVar, cast from urllib.parse import quote_plus import httpx @@ -1560,7 +1560,12 @@ def resume(self) -> tuple[pd.DataFrame, Any]: """ concurrency = _read_concurrency_env() with start_blocking_portal() as portal: - return portal.call(functools.partial(self._run, concurrency)) + # ``portal.call`` returns ``Any`` because ``functools.partial`` + # erases ``_run``'s return type; restore the declared tuple. + return cast( + "tuple[pd.DataFrame, Any]", + portal.call(functools.partial(self._run, concurrency)), + ) async def _run(self, max_concurrent: int | None) -> tuple[pd.DataFrame, Any]: """ diff --git a/dataretrieval/waterdata/nearest.py b/dataretrieval/waterdata/nearest.py index 12aad61c..39a80332 100644 --- a/dataretrieval/waterdata/nearest.py +++ b/dataretrieval/waterdata/nearest.py @@ -6,7 +6,7 @@ from __future__ import annotations from collections.abc import Iterable -from typing import Literal, get_args +from typing import Any, Literal, get_args import pandas as pd @@ -18,13 +18,13 @@ def get_nearest_continuous( - targets, + targets: Iterable[Any], monitoring_location_id: str | Iterable[str] | None = None, parameter_code: str | Iterable[str] | None = None, *, window: str | pd.Timedelta = "PT7M30S", on_tie: OnTie = "first", - **kwargs, + **kwargs: Any, ) -> tuple[pd.DataFrame, BaseMetadata]: """For each target timestamp, return the nearest continuous observation. @@ -138,13 +138,13 @@ def get_nearest_continuous( ... ) """ _check_nearest_kwargs(kwargs, on_tie) - targets = pd.DatetimeIndex(pd.to_datetime(targets, utc=True)) + target_index = pd.DatetimeIndex(pd.to_datetime(targets, utc=True)) window_td = pd.Timedelta(window) - if len(targets) == 0: + if len(target_index) == 0: raise ValueError("targets must contain at least one timestamp") - filter_expr = _build_window_or_filter(targets, window_td) + filter_expr = _build_window_or_filter(target_index, window_td) df, md = get_continuous( monitoring_location_id=monitoring_location_id, parameter_code=parameter_code, @@ -165,7 +165,7 @@ def get_nearest_continuous( selected = [ row for _, site_df in site_groups - for target in targets + for target in target_index if (row := _pick_nearest_row(site_df, target, window_td, on_tie)) is not None ] if not selected: @@ -173,7 +173,7 @@ def get_nearest_continuous( return pd.DataFrame(selected).reset_index(drop=True), md -def _check_nearest_kwargs(kwargs: dict, on_tie: OnTie) -> None: +def _check_nearest_kwargs(kwargs: dict[str, Any], on_tie: OnTie) -> None: """Reject kwargs the helper owns; validate ``on_tie``.""" for forbidden in ("time", "filter", "filter_lang"): if forbidden in kwargs: diff --git a/dataretrieval/waterdata/utils.py b/dataretrieval/waterdata/utils.py index 26a45773..5581d086 100644 --- a/dataretrieval/waterdata/utils.py +++ b/dataretrieval/waterdata/utils.py @@ -19,7 +19,7 @@ from contextlib import asynccontextmanager, contextmanager from contextvars import ContextVar from datetime import datetime, timedelta -from typing import Any, TypeVar, get_args +from typing import Any, TypeVar, cast, get_args from zoneinfo import ZoneInfo import httpx @@ -98,7 +98,7 @@ } -def _switch_arg_id(ls: dict[str, Any], id_name: str, service: str): +def _switch_arg_id(ls: dict[str, Any], id_name: str, service: str) -> dict[str, Any]: """ Switch argument id from its package-specific identifier to the standardized "id" key that the API recognizes. @@ -143,7 +143,9 @@ def _switch_arg_id(ls: dict[str, Any], id_name: str, service: str): return ls -def _switch_properties_id(properties: list[str] | None, id_name: str, service: str): +def _switch_properties_id( + properties: list[str] | None, id_name: str, service: str +) -> list[str]: """ Switch properties id from its package-specific identifier to the standardized "id" key that the API recognizes. @@ -234,7 +236,7 @@ def _parse_datetime(value: str) -> datetime | None: return None -def _format_one(dt, *, date: bool) -> str | None: +def _format_one(dt: str | None, *, date: bool) -> str | None: """Format a single datetime element for inclusion in the API time arg.""" if pd.isna(dt) or dt == "" or dt is None: return ".." @@ -374,7 +376,7 @@ def _cql2_param(args: dict[str, Any]) -> str: return json.dumps(query, separators=(",", ":")) -def _default_headers(): +def _default_headers() -> dict[str, str]: """ Generate default HTTP headers for API requests. @@ -397,7 +399,9 @@ def _default_headers(): return headers -def _check_ogc_requests(endpoint: str = "daily", req_type: str = "queryables"): +def _check_ogc_requests( + endpoint: str = "daily", req_type: str = "queryables" +) -> dict[str, Any]: """ Sends an HTTP GET request to the specified OGC endpoint and request type, returning the JSON response. @@ -429,10 +433,12 @@ def _check_ogc_requests(endpoint: str = "daily", req_type: str = "queryables"): url = f"{OGC_API_URL}/collections/{endpoint}/{req_type}" resp = httpx.get(url, headers=_default_headers(), **HTTPX_DEFAULTS) _raise_for_non_200(resp) - return resp.json() + # ``Response.json`` is typed ``Any``; the OGC queryables/schema endpoints + # return a JSON object, and callers index it as a dict. + return cast("dict[str, Any]", resp.json()) -def _error_body(resp: httpx.Response): +def _error_body(resp: httpx.Response) -> str: """ Build an informative error message from an HTTP response. @@ -629,7 +635,7 @@ def _construct_api_requests( bbox: list[float] | None = None, limit: int | None = None, skip_geometry: bool = False, - **kwargs, + **kwargs: Any, ) -> httpx.Request: """ Constructs an HTTP request object for the specified water data API service. @@ -843,7 +849,10 @@ def _next_req_url( f"Refusing to follow cross-host next-page URL: " f"{next_host} != {cur_host}" ) - return href + # ``href`` comes from the JSON ``links`` array (typed ``Any``); the + # ``not href`` guard above already excluded empty/None, and it is a + # URL string (passed to ``httpx.URL`` above). + return cast("str", href) return None diff --git a/dataretrieval/wqp.py b/dataretrieval/wqp.py index ff01c46a..1ca0098a 100644 --- a/dataretrieval/wqp.py +++ b/dataretrieval/wqp.py @@ -13,13 +13,14 @@ import warnings from io import StringIO -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pandas as pd from .utils import BaseMetadata, _attach_datetime_columns, query if TYPE_CHECKING: + import httpx from pandas import DataFrame @@ -67,9 +68,9 @@ def _read_wqp_csv(text: str) -> DataFrame: def get_results( - ssl_check=True, - legacy=True, - **kwargs, + ssl_check: bool = True, + legacy: bool = True, + **kwargs: Any, ) -> tuple[DataFrame, WQP_Metadata]: """Query the WQP for results. @@ -185,9 +186,9 @@ def get_results( def what_sites( - ssl_check=True, - legacy=True, - **kwargs, + ssl_check: bool = True, + legacy: bool = True, + **kwargs: Any, ) -> tuple[DataFrame, WQP_Metadata]: """Search WQP for sites within a region with specific data. @@ -240,9 +241,9 @@ def what_sites( def what_organizations( - ssl_check=True, - legacy=True, - **kwargs, + ssl_check: bool = True, + legacy: bool = True, + **kwargs: Any, ) -> tuple[DataFrame, WQP_Metadata]: """Search WQP for organizations within a region with specific data. @@ -290,7 +291,11 @@ def what_organizations( return df, WQP_Metadata(response, **kwargs) -def what_projects(ssl_check=True, legacy=True, **kwargs): +def what_projects( + ssl_check: bool = True, + legacy: bool = True, + **kwargs: Any, +) -> tuple[DataFrame, WQP_Metadata]: """Search WQP for projects within a region with specific data. Any WQP API parameter can be passed as a keyword argument to this function. @@ -338,9 +343,9 @@ def what_projects(ssl_check=True, legacy=True, **kwargs): def what_activities( - ssl_check=True, - legacy=True, - **kwargs, + ssl_check: bool = True, + legacy: bool = True, + **kwargs: Any, ) -> tuple[DataFrame, WQP_Metadata]: """Search WQP for activities within a region with specific data. @@ -402,9 +407,9 @@ def what_activities( def what_detection_limits( - ssl_check=True, - legacy=True, - **kwargs, + ssl_check: bool = True, + legacy: bool = True, + **kwargs: Any, ) -> tuple[DataFrame, WQP_Metadata]: """Search WQP for result detection limits within a region with specific data. @@ -460,9 +465,9 @@ def what_detection_limits( def what_habitat_metrics( - ssl_check=True, - legacy=True, - **kwargs, + ssl_check: bool = True, + legacy: bool = True, + **kwargs: Any, ) -> tuple[DataFrame, WQP_Metadata]: """Search WQP for habitat metrics within a region with specific data. @@ -510,7 +515,11 @@ def what_habitat_metrics( return df, WQP_Metadata(response, **kwargs) -def what_project_weights(ssl_check=True, legacy=True, **kwargs): +def what_project_weights( + ssl_check: bool = True, + legacy: bool = True, + **kwargs: Any, +) -> tuple[DataFrame, WQP_Metadata]: """Search WQP for project weights within a region with specific data. Any WQP API parameter can be passed as a keyword argument to this function. @@ -562,7 +571,11 @@ def what_project_weights(ssl_check=True, legacy=True, **kwargs): return df, WQP_Metadata(response, **kwargs) -def what_activity_metrics(ssl_check=True, legacy=True, **kwargs): +def what_activity_metrics( + ssl_check: bool = True, + legacy: bool = True, + **kwargs: Any, +) -> tuple[DataFrame, WQP_Metadata]: """Search WQP for activity metrics within a region with specific data. Any WQP API parameter can be passed as a keyword argument to this function. @@ -614,7 +627,7 @@ def what_activity_metrics(ssl_check=True, legacy=True, **kwargs): return df, WQP_Metadata(response, **kwargs) -def wqp_url(service): +def wqp_url(service: str) -> str: """Construct the WQP URL for a given service.""" base_url = "https://www.waterqualitydata.us/data/" @@ -628,7 +641,7 @@ def wqp_url(service): return f"{base_url}{service}/Search?" -def wqx3_url(service): +def wqx3_url(service: str) -> str: """Construct the WQP URL for a given WQX 3.0 service.""" base_url = "https://www.waterqualitydata.us/wqx3/" @@ -659,7 +672,7 @@ class WQP_Metadata(BaseMetadata): Site information (via ``what_sites``) if the query included a ``siteid``. """ - def __init__(self, response, **parameters) -> None: + def __init__(self, response: httpx.Response, **parameters: Any) -> None: """Generates a standard set of metadata informed by the response with specific metadata for WQP data. @@ -703,7 +716,7 @@ def site_info(self) -> tuple[DataFrame, WQP_Metadata] | None: return what_sites(siteid=siteid) -def _check_kwargs(kwargs): +def _check_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]: """Private function to check kwargs for unsupported parameters.""" mimetype = kwargs.get("mimeType") if mimetype == "geojson": @@ -716,7 +729,7 @@ def _check_kwargs(kwargs): return kwargs -def _warn_wqx3_use(): +def _warn_wqx3_use() -> None: message = ( "Support for the WQX3.0 profiles is experimental. " "Queries may be slow or fail intermittently." @@ -724,7 +737,7 @@ def _warn_wqx3_use(): warnings.warn(message, UserWarning, stacklevel=2) -def _warn_legacy_use(): +def _warn_legacy_use() -> None: message = ( "This function call will return the legacy WQX format, " "which means USGS data have not been updated since March 2024. " @@ -735,7 +748,7 @@ def _warn_legacy_use(): warnings.warn(message, DeprecationWarning, stacklevel=2) -def _warn_wqx3_unavailable(): +def _warn_wqx3_unavailable() -> None: # stacklevel=3: warn -> _warn_wqx3_unavailable -> _legacy_only_url -> what_* warnings.warn( "WQX3.0 profile not available, returning legacy profile.", diff --git a/pyproject.toml b/pyproject.toml index 9af5f6a1..5d018a46 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,11 +105,12 @@ docstring-code-format = true docstring-code-line-length = 72 [tool.mypy] -# Incremental-adoption config. ``check_untyped_defs`` type-checks the bodies of -# not-yet-annotated functions; what still keeps this short of ``--strict`` is -# that annotations aren't *required* (``disallow_untyped_defs`` is off) and -# untyped third-party libs are treated as ``Any`` (``ignore_missing_imports``). +# The package is fully annotated and passes ``mypy --strict``. The one +# remaining relaxation is ``ignore_missing_imports``: untyped third-party +# libraries (pandas, geopandas, anyio) are treated as ``Any`` instead of +# requiring stub packages. Dropping that — via pandas-stubs/types-requests and +# per-module overrides — can follow. python_version = "3.9" # the project's minimum supported version files = ["dataretrieval"] +strict = true ignore_missing_imports = true -check_untyped_defs = true From 8096a54c2915a7d485fd187f76e2e660fe1bf437 Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Mon, 1 Jun 2026 08:31:43 -0400 Subject: [PATCH 3/3] ci: install only mypy in the type-check job, not the full test stack MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The type-check job ran ``pip install .[test]``, pulling pytest/coverage/ruff/ pytest-httpx just to run mypy. Add a minimal ``type-check`` extra (just ``mypy<2``, pinned once; ``test`` self-references it) and install ``.[type-check]`` instead — the job's install drops from 26 to 16 packages. mypy --strict is unchanged (0 errors). Co-Authored-By: Claude Opus 4.8 (1M context) --- .github/workflows/python-package.yml | 2 +- pyproject.toml | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 175dcd06..d5b7a2d2 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -38,7 +38,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install .[test] + pip install .[type-check] - name: Type-check with mypy run: mypy diff --git a/pyproject.toml b/pyproject.toml index 5d018a46..57f60161 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,11 @@ packages = ["dataretrieval", "dataretrieval.codes"] dataretrieval = ["py.typed"] [project.optional-dependencies] +# Minimal set the CI ``type-check`` job installs — just mypy + the package, +# not the whole test stack. +type-check = [ + "mypy<2", # <2 so it can still target Python 3.9 (the project's floor) +] test = [ "pytest > 5.0.0", "pytest-cov[all]", @@ -39,7 +44,7 @@ test = [ "coverage", "pytest-httpx", "ruff", - "mypy<2", # <2 so it can still target Python 3.9 (the project's floor) + "dataretrieval[type-check]", # mypy, pinned once in the type-check extra ] doc = [ "docutils<0.22",