diff --git a/src/spotforecast2/plots/diagnostics.py b/src/spotforecast2/plots/diagnostics.py index f5d68b34..8214252b 100644 --- a/src/spotforecast2/plots/diagnostics.py +++ b/src/spotforecast2/plots/diagnostics.py @@ -185,9 +185,9 @@ def plot_feature_importance_by_family( print(type(fig).__name__) ``` """ - ranking = sorted( - zip(names, importances), key=lambda kv: kv[1], reverse=True - )[:top_n] + ranking = sorted(zip(names, importances), key=lambda kv: kv[1], reverse=True)[ + :top_n + ] labels = [n for n, _ in ranking][::-1] values = [v for _, v in ranking][::-1] colors = [_FAMILY_COLOR.get(feature_family(n), "#888888") for n in labels] @@ -195,13 +195,9 @@ def plot_feature_importance_by_family( fig, ax = plt.subplots(figsize=(8.5, 5.5)) ax.barh(labels, values, color=colors) ax.set_xlabel("split count (feature importance)") - ax.set_title( - f"Top-{top_n} feature importances (coloured by family; lags in red)" - ) + ax.set_title(f"Top-{top_n} feature importances (coloured by family; lags in red)") ax.grid(True, axis="x", color="#E5E5E5", linewidth=0.5) - handles = [ - plt.Rectangle((0, 0), 1, 1, color=c) for c in _FAMILY_COLOR.values() - ] + handles = [plt.Rectangle((0, 0), 1, 1, color=c) for c in _FAMILY_COLOR.values()] ax.legend(handles, _FAMILY_COLOR.keys(), fontsize=7, loc="lower right") for spine in ("top", "right"): ax.spines[spine].set_visible(False) diff --git a/tests/test_model_selection_boundary.py b/tests/test_model_selection_boundary.py index 1386bdc3..fb27619b 100644 --- a/tests/test_model_selection_boundary.py +++ b/tests/test_model_selection_boundary.py @@ -13,7 +13,6 @@ suggest_bounds, ) - # --------------------------------------------------------------------------- # Shared fixtures / helpers # --------------------------------------------------------------------------- @@ -38,6 +37,7 @@ # report_boundary_positions # --------------------------------------------------------------------------- + class TestReportBoundaryPositions: def test_interior_no_flag(self): """An optimum well inside all bounds returns an empty list.""" @@ -72,6 +72,7 @@ def test_log10_dim_near_lower(self): # learning_rate = 0.006 in (0.005, 0.3) log10: # pos = (log10(0.006) - log10(0.005)) / (log10(0.3) - log10(0.005)) import math + low, high, val = 0.005, 0.3, 0.006 pos = (math.log10(val) - math.log10(low)) / (math.log10(high) - math.log10(low)) assert pos < 0.10 # confirm this is a near-lower case @@ -131,9 +132,7 @@ def test_logger_injection(self, caplog): with caplog.at_level(logging.INFO, logger="test_custom_boundary_logger"): report_boundary_positions(params, space, logger=custom_logger) # The custom logger should have emitted at least one INFO message - assert any( - r.name == "test_custom_boundary_logger" for r in caplog.records - ) + assert any(r.name == "test_custom_boundary_logger" for r in caplog.records) def test_multiple_flags(self): """Multiple near-boundary dims all appear in the returned list.""" @@ -157,6 +156,7 @@ def test_log_invalid_val_skipped(self): # boundary_report (DataFrame form) # --------------------------------------------------------------------------- + class TestBoundaryReport: def test_returns_dataframe(self): best = {"estimator__reg_alpha": 9.89, "estimator__learning_rate": 0.069} @@ -165,7 +165,15 @@ def test_returns_dataframe(self): "estimator__learning_rate": (0.005, 0.3, "log10"), } df = boundary_report(best, space) - assert set(df.columns) == {"param", "low", "high", "value", "scale", "position", "flag"} + assert set(df.columns) == { + "param", + "low", + "high", + "value", + "scale", + "position", + "flag", + } def test_flagged_near_upper(self): best = {"estimator__reg_alpha": 9.89} @@ -191,8 +199,8 @@ def test_categorical_skipped(self): def test_sorted_descending_position(self): best = { - "estimator__reg_alpha": 9.89, # near upper → high position - "estimator__num_leaves": 300, # interior + "estimator__reg_alpha": 9.89, # near upper → high position + "estimator__num_leaves": 300, # interior } space = { "estimator__reg_alpha": (0.001, 10.0), @@ -213,6 +221,7 @@ def test_prefix_stripped_in_param_column(self): # suggest_bounds # --------------------------------------------------------------------------- + class TestSuggestBounds: def test_interior_unchanged(self): best = {"estimator__num_leaves": 300} @@ -277,7 +286,9 @@ def test_all_keys_preserved(self): # The near-upper-boundary integer dim must have been widened upward assert new_space["estimator__num_leaves"][1] > 1024 # Interior log10 dim is unchanged - assert new_space["estimator__learning_rate"] == space["estimator__learning_rate"] + assert ( + new_space["estimator__learning_rate"] == space["estimator__learning_rate"] + ) def test_widen_factor_parameter(self): """Different widen_factor values produce different bounds.""" diff --git a/tests/test_plot_with_outliers_bounds.py b/tests/test_plot_with_outliers_bounds.py index 5dd9e984..c889ab66 100644 --- a/tests/test_plot_with_outliers_bounds.py +++ b/tests/test_plot_with_outliers_bounds.py @@ -323,9 +323,9 @@ def test_no_targets_kwarg_emits_no_deprecation_warning(self): for w in caught if issubclass(w.category, (DeprecationWarning, FutureWarning)) ] - assert deprecation_warnings == [], ( - f"Unexpected deprecation warnings: {[str(w.message) for w in deprecation_warnings]}" - ) + assert ( + deprecation_warnings == [] + ), f"Unexpected deprecation warnings: {[str(w.message) for w in deprecation_warnings]}" def test_explicit_targets_wins_over_config_targets(self): """When explicit targets= differs from config.targets, explicit wins.""" diff --git a/tests/test_plots_diagnostics.py b/tests/test_plots_diagnostics.py index d8b4fdb5..647c1dba 100644 --- a/tests/test_plots_diagnostics.py +++ b/tests/test_plots_diagnostics.py @@ -26,7 +26,6 @@ plot_shap_summary, ) - # --------------------------------------------------------------------------- # Teardown # --------------------------------------------------------------------------- @@ -174,8 +173,14 @@ def test_importance_plot_top_n_respected(): def test_importance_plot_family_colors_distinct(): """Known feature names from different families get different bar colors.""" - names = ["lag_1", "holiday_DE", "poly_hour_2", "window_mean_72", - "sin_hour", "wind_speed"] + names = [ + "lag_1", + "holiday_DE", + "poly_hour_2", + "window_mean_72", + "sin_hour", + "wind_speed", + ] scores = [100, 90, 80, 70, 60, 50] fig = plot_feature_importance_by_family(names, scores) ax = fig.axes[0] @@ -205,9 +210,7 @@ def test_shap_summary_returns_figure(): from lightgbm import LGBMRegressor rng = np.random.default_rng(7) - X = pd.DataFrame( - rng.standard_normal((120, 4)), columns=["f0", "f1", "f2", "f3"] - ) + X = pd.DataFrame(rng.standard_normal((120, 4)), columns=["f0", "f1", "f2", "f3"]) y = X["f0"] * 2 + rng.standard_normal(120) * 0.1 est = LGBMRegressor(n_estimators=10, verbose=-1, random_state=0) est.fit(X, y) @@ -228,9 +231,7 @@ def test_shap_summary_subsamples_max_samples(): from lightgbm import LGBMRegressor rng = np.random.default_rng(9) - X = pd.DataFrame( - rng.standard_normal((500, 3)), columns=["a", "b", "c"] - ) + X = pd.DataFrame(rng.standard_normal((500, 3)), columns=["a", "b", "c"]) y = X["a"] + rng.standard_normal(500) * 0.2 est = LGBMRegressor(n_estimators=10, verbose=-1, random_state=0) est.fit(X, y) @@ -254,8 +255,7 @@ def test_forecast_vs_reference_returns_figure_with_overlap(): forecast = _make_forecast(24) reference = _make_forecast(24, seed=1) # same index -> full overlap fig = plot_forecast_vs_reference( - forecast, reference, - forecast_label="team_4", reference_label="ENTSO-E" + forecast, reference, forecast_label="team_4", reference_label="ENTSO-E" ) assert isinstance(fig, Figure) @@ -297,9 +297,7 @@ def test_forecast_vs_reference_unit_scale_applied(caplog): forecast = _make_forecast(24) reference = _make_forecast(24, seed=5) with caplog.at_level(logging.INFO, logger="spotforecast2.plots.diagnostics"): - fig = plot_forecast_vs_reference( - forecast, reference, unit_scale=1.0, unit="MW" - ) + fig = plot_forecast_vs_reference(forecast, reference, unit_scale=1.0, unit="MW") ax = fig.axes[0] # y-axis upper limit should be in the MW range (>1000), not GW (<100) ymin, ymax = ax.get_ylim() @@ -339,6 +337,5 @@ def test_forecast_vs_reference_mad_logged_in_display_unit(caplog): assert match, f"Could not parse GW value from log message: {msg!r}" logged_value = float(match.group(1)) assert logged_value == pytest.approx(1.0, abs=0.05), ( - f"Logged MAD should be 1.0 GW, got {logged_value}. " - f"Full message: {msg!r}" + f"Logged MAD should be 1.0 GW, got {logged_value}. " f"Full message: {msg!r}" ) diff --git a/tests/test_stats_select_pacf_lags.py b/tests/test_stats_select_pacf_lags.py index 7cf687d3..ba2ed379 100644 --- a/tests/test_stats_select_pacf_lags.py +++ b/tests/test_stats_select_pacf_lags.py @@ -11,11 +11,11 @@ from spotforecast2.stats.autocorrelation import select_pacf_lags - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- + def _make_daily_ar_series(n_days: int = 120, seed: int = 42) -> pd.Series: """AR(1) with moderate coefficient for general lag selection tests.""" rng = np.random.default_rng(seed) @@ -40,6 +40,7 @@ def _make_ar24_series(n_days: int = 200, seed: int = 42) -> pd.Series: # Happy-path tests # --------------------------------------------------------------------------- + class TestSelectPacfLagsHappyPath: def test_returns_list_of_ints(self): series = _make_daily_ar_series() @@ -89,6 +90,7 @@ def test_n_lags_limits_search(self): # Degenerate / edge-case tests # --------------------------------------------------------------------------- + class TestSelectPacfLagsDegenerate: def test_constant_series_fallback_returned(self): series = pd.Series([1.0] * 50) @@ -139,6 +141,7 @@ def test_fallback_none_is_default(self): # Parameter variation # --------------------------------------------------------------------------- + class TestSelectPacfLagsParams: def test_default_n_lags_200_top_k_8(self): """Default parameters work on a long series without error.""" diff --git a/uv.lock b/uv.lock index 13d02a26..c91a1374 100644 --- a/uv.lock +++ b/uv.lock @@ -3491,7 +3491,7 @@ wheels = [ [[package]] name = "spotforecast2" -version = "8.1.0" +version = "8.1.1" source = { editable = "." } dependencies = [ { name = "astral" },