diff --git a/src/spotforecast2/model_selection/boundary.py b/src/spotforecast2/model_selection/boundary.py index bd40373e..58e3f1e0 100644 --- a/src/spotforecast2/model_selection/boundary.py +++ b/src/spotforecast2/model_selection/boundary.py @@ -63,6 +63,26 @@ _logger = logging.getLogger(__name__) +def _position_flag(pos: float, warn_frac: float) -> str: + """Classify a normalized boundary position as near-upper, near-lower, or interior. + + Args: + pos: Position of the tuned value inside its dimension, normalized to + ``[0, 1]`` in the dimension's own scale. + warn_frac: Fraction of the range defining the "near-boundary" zone at + each end. + + Returns: + ``"> upper"``, ``"< lower"``, or ``""`` (interior). + """ + upper_zone = 1.0 - warn_frac + if pos > upper_zone: + return "> upper" + if pos < warn_frac: + return "< lower" + return "" + + def report_boundary_positions( params: Mapping[str, float | int], search_space: Mapping[str, Any], @@ -158,11 +178,7 @@ def report_boundary_positions( ) else: pos = (float(val) - low) / (high - low) - flag = ( - "> upper" - if pos > 1 - warn_frac - else ("< lower" if pos < warn_frac else "") - ) + flag = _position_flag(pos, warn_frac) log.info( " bound %-18s = %-11.5g in [%g, %g]%s pos=%.2f%s", key, @@ -266,11 +282,7 @@ def boundary_report( ) else: pos = (float(val) - low) / (high - low) - flag = ( - "> upper" - if pos > 1 - warn_frac - else ("< lower" if pos < warn_frac else "") - ) + flag = _position_flag(pos, warn_frac) rows.append( { "param": name.replace("estimator__", ""), @@ -298,7 +310,9 @@ def boundary_report( "Use report_boundary_positions() instead if you have unprefixed keys.", numeric_dims, ) - return pd.DataFrame(columns=["param", "low", "high", "value", "scale", "position", "flag"]) + return pd.DataFrame( + columns=["param", "low", "high", "value", "scale", "position", "flag"] + ) return df.sort_values("position", ascending=False).reset_index(drop=True) @@ -389,11 +403,7 @@ def suggest_bounds( """ report = boundary_report(best_params, search_space, warn_frac=warn_frac) - flagged = { - row["param"]: row["flag"] - for _, row in report.iterrows() - if row["flag"] - } + flagged = {row["param"]: row["flag"] for _, row in report.iterrows() if row["flag"]} out: dict[str, Any] = {} for name, spec in search_space.items(): if not (isinstance(spec, tuple) and len(spec) in (2, 3)): diff --git a/tests/test_exog_providers_reexport.py b/tests/test_exog_providers_reexport.py index 30a74322..38481c96 100644 --- a/tests/test_exog_providers_reexport.py +++ b/tests/test_exog_providers_reexport.py @@ -30,6 +30,8 @@ "include_entsoe_renewable_forecast", "include_entsoe_net_load", "include_entsoe_day_ahead_price", + "include_football_match_window", + "include_energy_saving_window", } @@ -46,7 +48,9 @@ def test_registry_reexported_from_preprocessing(): build_providers_from_config, ) - assert set(EXOG_PROVIDER_REGISTRY) == EXPECTED_FLAGS + # Superset, not equality: additive provider releases in sf2-safe must not + # break sf2's suite; removals of expected flags still fail. + assert EXPECTED_FLAGS <= set(EXOG_PROVIDER_REGISTRY) assert issubclass(CovidInfectionRateProvider, ExogFeatureProvider) assert callable(build_providers) and callable(build_providers_from_config) # silence unused-import linters for the re-export assertions diff --git a/tests/test_multitask.py b/tests/test_multitask.py index 1fa26a79..235ff13b 100644 --- a/tests/test_multitask.py +++ b/tests/test_multitask.py @@ -814,8 +814,6 @@ class TestCacheHomeResolution: """BaseTask must resolve cache_home=None to get_cache_home().""" def test_cache_home_none_resolves_to_default(self, tmp_path, monkeypatch): - import logging - from spotforecast2_safe.data.fetch_data import get_cache_home # Redirect the package default so the test never touches the diff --git a/uv.lock b/uv.lock index 6f4d8a47..13d02a26 100644 --- a/uv.lock +++ b/uv.lock @@ -3491,7 +3491,7 @@ wheels = [ [[package]] name = "spotforecast2" -version = "8.0.0" +version = "8.1.0" source = { editable = "." } dependencies = [ { name = "astral" }, @@ -3590,7 +3590,7 @@ dev = [ [[package]] name = "spotforecast2-safe" -version = "22.1.0" +version = "22.2.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "astral" }, @@ -3607,9 +3607,9 @@ dependencies = [ { name = "statsmodels" }, { name = "tqdm" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f1/7b/9f7dcc944bde99ae1fbf4c329b95598bed4bb2c1c5ca04b879574bd42027/spotforecast2_safe-22.1.0.tar.gz", hash = "sha256:32666cc9bf320c18d146a24a1e803abfc1de821cf63da2f226e009808f52187c", size = 20654907, upload-time = "2026-06-12T07:21:36.071Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5d/fd/685a4d9797d467ec646c3ffc75b8d6327ff770a7540b47c5a0300f23aac5/spotforecast2_safe-22.2.0.tar.gz", hash = "sha256:fd458aea0a6421cc8229cdbc2314b9a0863508771c286ac2537c8d3221eb1362", size = 20660329, upload-time = "2026-06-12T11:53:21.696Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d2/a3/a12c5d4fa72350cf6741f53d1d3123385a50117b0592fd7bc7f59038018c/spotforecast2_safe-22.1.0-py3-none-any.whl", hash = "sha256:e8f66ea353c56b64752ae5d44f01e0143ed725c1724108b54e87a25b9e9456df", size = 20723105, upload-time = "2026-06-12T07:21:31.953Z" }, + { url = "https://files.pythonhosted.org/packages/30/1b/24262db44be056e62680e4ef28e8423bd977abeac68b42c15d045f6941e3/spotforecast2_safe-22.2.0-py3-none-any.whl", hash = "sha256:1b79b7d132024103da23ccfb4e8057583e3d04906574e633c78f97e163301da9", size = 20729694, upload-time = "2026-06-12T11:53:19.078Z" }, ] [[package]]