Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions src/spotforecast2/plots/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,23 +185,19 @@ def plot_feature_importance_by_family(
print(type(fig).__name__)
```
"""
ranking = sorted(
zip(names, importances), key=lambda kv: kv[1], reverse=True
)[:top_n]
ranking = sorted(zip(names, importances), key=lambda kv: kv[1], reverse=True)[
:top_n
]
labels = [n for n, _ in ranking][::-1]
values = [v for _, v in ranking][::-1]
colors = [_FAMILY_COLOR.get(feature_family(n), "#888888") for n in labels]

fig, ax = plt.subplots(figsize=(8.5, 5.5))
ax.barh(labels, values, color=colors)
ax.set_xlabel("split count (feature importance)")
ax.set_title(
f"Top-{top_n} feature importances (coloured by family; lags in red)"
)
ax.set_title(f"Top-{top_n} feature importances (coloured by family; lags in red)")
ax.grid(True, axis="x", color="#E5E5E5", linewidth=0.5)
handles = [
plt.Rectangle((0, 0), 1, 1, color=c) for c in _FAMILY_COLOR.values()
]
handles = [plt.Rectangle((0, 0), 1, 1, color=c) for c in _FAMILY_COLOR.values()]
ax.legend(handles, _FAMILY_COLOR.keys(), fontsize=7, loc="lower right")
for spine in ("top", "right"):
ax.spines[spine].set_visible(False)
Expand Down
27 changes: 19 additions & 8 deletions tests/test_model_selection_boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
suggest_bounds,
)


# ---------------------------------------------------------------------------
# Shared fixtures / helpers
# ---------------------------------------------------------------------------
Expand All @@ -38,6 +37,7 @@
# report_boundary_positions
# ---------------------------------------------------------------------------


class TestReportBoundaryPositions:
def test_interior_no_flag(self):
"""An optimum well inside all bounds returns an empty list."""
Expand Down Expand Up @@ -72,6 +72,7 @@ def test_log10_dim_near_lower(self):
# learning_rate = 0.006 in (0.005, 0.3) log10:
# pos = (log10(0.006) - log10(0.005)) / (log10(0.3) - log10(0.005))
import math

low, high, val = 0.005, 0.3, 0.006
pos = (math.log10(val) - math.log10(low)) / (math.log10(high) - math.log10(low))
assert pos < 0.10 # confirm this is a near-lower case
Expand Down Expand Up @@ -131,9 +132,7 @@ def test_logger_injection(self, caplog):
with caplog.at_level(logging.INFO, logger="test_custom_boundary_logger"):
report_boundary_positions(params, space, logger=custom_logger)
# The custom logger should have emitted at least one INFO message
assert any(
r.name == "test_custom_boundary_logger" for r in caplog.records
)
assert any(r.name == "test_custom_boundary_logger" for r in caplog.records)

def test_multiple_flags(self):
"""Multiple near-boundary dims all appear in the returned list."""
Expand All @@ -157,6 +156,7 @@ def test_log_invalid_val_skipped(self):
# boundary_report (DataFrame form)
# ---------------------------------------------------------------------------


class TestBoundaryReport:
def test_returns_dataframe(self):
best = {"estimator__reg_alpha": 9.89, "estimator__learning_rate": 0.069}
Expand All @@ -165,7 +165,15 @@ def test_returns_dataframe(self):
"estimator__learning_rate": (0.005, 0.3, "log10"),
}
df = boundary_report(best, space)
assert set(df.columns) == {"param", "low", "high", "value", "scale", "position", "flag"}
assert set(df.columns) == {
"param",
"low",
"high",
"value",
"scale",
"position",
"flag",
}

def test_flagged_near_upper(self):
best = {"estimator__reg_alpha": 9.89}
Expand All @@ -191,8 +199,8 @@ def test_categorical_skipped(self):

def test_sorted_descending_position(self):
best = {
"estimator__reg_alpha": 9.89, # near upper → high position
"estimator__num_leaves": 300, # interior
"estimator__reg_alpha": 9.89, # near upper → high position
"estimator__num_leaves": 300, # interior
}
space = {
"estimator__reg_alpha": (0.001, 10.0),
Expand All @@ -213,6 +221,7 @@ def test_prefix_stripped_in_param_column(self):
# suggest_bounds
# ---------------------------------------------------------------------------


class TestSuggestBounds:
def test_interior_unchanged(self):
best = {"estimator__num_leaves": 300}
Expand Down Expand Up @@ -277,7 +286,9 @@ def test_all_keys_preserved(self):
# The near-upper-boundary integer dim must have been widened upward
assert new_space["estimator__num_leaves"][1] > 1024
# Interior log10 dim is unchanged
assert new_space["estimator__learning_rate"] == space["estimator__learning_rate"]
assert (
new_space["estimator__learning_rate"] == space["estimator__learning_rate"]
)

def test_widen_factor_parameter(self):
"""Different widen_factor values produce different bounds."""
Expand Down
6 changes: 3 additions & 3 deletions tests/test_plot_with_outliers_bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,9 @@ def test_no_targets_kwarg_emits_no_deprecation_warning(self):
for w in caught
if issubclass(w.category, (DeprecationWarning, FutureWarning))
]
assert deprecation_warnings == [], (
f"Unexpected deprecation warnings: {[str(w.message) for w in deprecation_warnings]}"
)
assert (
deprecation_warnings == []
), f"Unexpected deprecation warnings: {[str(w.message) for w in deprecation_warnings]}"

def test_explicit_targets_wins_over_config_targets(self):
"""When explicit targets= differs from config.targets, explicit wins."""
Expand Down
29 changes: 13 additions & 16 deletions tests/test_plots_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
plot_shap_summary,
)


# ---------------------------------------------------------------------------
# Teardown
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -174,8 +173,14 @@ def test_importance_plot_top_n_respected():

def test_importance_plot_family_colors_distinct():
"""Known feature names from different families get different bar colors."""
names = ["lag_1", "holiday_DE", "poly_hour_2", "window_mean_72",
"sin_hour", "wind_speed"]
names = [
"lag_1",
"holiday_DE",
"poly_hour_2",
"window_mean_72",
"sin_hour",
"wind_speed",
]
scores = [100, 90, 80, 70, 60, 50]
fig = plot_feature_importance_by_family(names, scores)
ax = fig.axes[0]
Expand Down Expand Up @@ -205,9 +210,7 @@ def test_shap_summary_returns_figure():
from lightgbm import LGBMRegressor

rng = np.random.default_rng(7)
X = pd.DataFrame(
rng.standard_normal((120, 4)), columns=["f0", "f1", "f2", "f3"]
)
X = pd.DataFrame(rng.standard_normal((120, 4)), columns=["f0", "f1", "f2", "f3"])
y = X["f0"] * 2 + rng.standard_normal(120) * 0.1
est = LGBMRegressor(n_estimators=10, verbose=-1, random_state=0)
est.fit(X, y)
Expand All @@ -228,9 +231,7 @@ def test_shap_summary_subsamples_max_samples():
from lightgbm import LGBMRegressor

rng = np.random.default_rng(9)
X = pd.DataFrame(
rng.standard_normal((500, 3)), columns=["a", "b", "c"]
)
X = pd.DataFrame(rng.standard_normal((500, 3)), columns=["a", "b", "c"])
y = X["a"] + rng.standard_normal(500) * 0.2
est = LGBMRegressor(n_estimators=10, verbose=-1, random_state=0)
est.fit(X, y)
Expand All @@ -254,8 +255,7 @@ def test_forecast_vs_reference_returns_figure_with_overlap():
forecast = _make_forecast(24)
reference = _make_forecast(24, seed=1) # same index -> full overlap
fig = plot_forecast_vs_reference(
forecast, reference,
forecast_label="team_4", reference_label="ENTSO-E"
forecast, reference, forecast_label="team_4", reference_label="ENTSO-E"
)
assert isinstance(fig, Figure)

Expand Down Expand Up @@ -297,9 +297,7 @@ def test_forecast_vs_reference_unit_scale_applied(caplog):
forecast = _make_forecast(24)
reference = _make_forecast(24, seed=5)
with caplog.at_level(logging.INFO, logger="spotforecast2.plots.diagnostics"):
fig = plot_forecast_vs_reference(
forecast, reference, unit_scale=1.0, unit="MW"
)
fig = plot_forecast_vs_reference(forecast, reference, unit_scale=1.0, unit="MW")
ax = fig.axes[0]
# y-axis upper limit should be in the MW range (>1000), not GW (<100)
ymin, ymax = ax.get_ylim()
Expand Down Expand Up @@ -339,6 +337,5 @@ def test_forecast_vs_reference_mad_logged_in_display_unit(caplog):
assert match, f"Could not parse GW value from log message: {msg!r}"
logged_value = float(match.group(1))
assert logged_value == pytest.approx(1.0, abs=0.05), (
f"Logged MAD should be 1.0 GW, got {logged_value}. "
f"Full message: {msg!r}"
f"Logged MAD should be 1.0 GW, got {logged_value}. " f"Full message: {msg!r}"
)
5 changes: 4 additions & 1 deletion tests/test_stats_select_pacf_lags.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@

from spotforecast2.stats.autocorrelation import select_pacf_lags


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_daily_ar_series(n_days: int = 120, seed: int = 42) -> pd.Series:
"""AR(1) with moderate coefficient for general lag selection tests."""
rng = np.random.default_rng(seed)
Expand All @@ -40,6 +40,7 @@ def _make_ar24_series(n_days: int = 200, seed: int = 42) -> pd.Series:
# Happy-path tests
# ---------------------------------------------------------------------------


class TestSelectPacfLagsHappyPath:
def test_returns_list_of_ints(self):
series = _make_daily_ar_series()
Expand Down Expand Up @@ -89,6 +90,7 @@ def test_n_lags_limits_search(self):
# Degenerate / edge-case tests
# ---------------------------------------------------------------------------


class TestSelectPacfLagsDegenerate:
def test_constant_series_fallback_returned(self):
series = pd.Series([1.0] * 50)
Expand Down Expand Up @@ -139,6 +141,7 @@ def test_fallback_none_is_default(self):
# Parameter variation
# ---------------------------------------------------------------------------


class TestSelectPacfLagsParams:
def test_default_n_lags_200_top_k_8(self):
"""Default parameters work on a long series without error."""
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading