Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions _freeze/docs/tasks/task_multi/execute-results/html.json

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Binary file not shown.
Binary file not shown.
24 changes: 0 additions & 24 deletions docs/tasks/task_multi.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -206,29 +206,6 @@ instead of silent dropping) and to plotting:
`MultiTask.plot_with_outliers()` raises `NotImplementedError` because no
plotting library is permitted in this package.

## The n-to-1 task entry point

[`run_pipeline`](../reference/tasks.task_safe_n_to_1_with_covariates_and_dataframe.qmd) from
`task_safe_n_to_1_with_covariates_and_dataframe`
wraps exactly this pipeline with `task="lazy"` — one call from config and
DataFrame to combined forecast:

```{python}
from spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe import (
run_pipeline,
)

forecast = run_pipeline(config=cfg, dataframe=df, cache_home=cache)
forecast.head(3)
```

The matching console script accepts the same knobs as flags:

```bash
uv run spotforecast-safe-n2o1-cov-df --forecast_horizon 24 --lags 24 \
--include_holiday_features true
```

## Scaling up from the toy example

For a real run, switch the feature machinery on instead of off:
Expand Down Expand Up @@ -263,5 +240,4 @@ design, absent.

## Where to go next

- API reference: [`run_pipeline`](../reference/tasks.task_safe_n_to_1_with_covariates_and_dataframe.qmd) — the CLI-facing wrapper around this pipeline.
- API reference: [`MultiTask`](../reference/multitask.multi.MultiTask.qmd), [`BaseTask`](../reference/multitask.base.BaseTask.qmd), [`ConfigMulti`](../reference/configurator.config_multi.ConfigMulti.qmd), [`runner.run`](../reference/multitask.runner.run.qmd).
26 changes: 10 additions & 16 deletions tests/preprocessing/test_target_corruption.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,14 +650,11 @@ def test_fall_back_no_raise(self):
)
vals = [BASE_MW] * len(idx)
df = pd.DataFrame({"load": vals}, index=idx)
try:
mask = detect_target_corruption(
df, targets=["load"], range_mw=5_000, step_mw=8_000, window_days=7
)
except Exception as exc:
pytest.fail(
f"detect_target_corruption raised on fall-back DST index: {exc}"
)
# A raise here fails the test (with a full traceback); that is exactly
# the "must not raise" guarantee this case is asserting.
mask = detect_target_corruption(
df, targets=["load"], range_mw=5_000, step_mw=8_000, window_days=7
)
assert not mask.any(), "Clean DST week must produce no flags."

def test_fall_back_dropout_is_flagged(self):
Expand Down Expand Up @@ -694,14 +691,11 @@ def test_spring_forward_no_raise(self):
)
vals = [BASE_MW] * len(idx)
df = pd.DataFrame({"load": vals}, index=idx)
try:
mask = detect_target_corruption(
df, targets=["load"], range_mw=5_000, step_mw=8_000, window_days=7
)
except Exception as exc:
pytest.fail(
f"detect_target_corruption raised on spring-forward DST index: {exc}"
)
# A raise here fails the test (with a full traceback); that is exactly
# the "must not raise" guarantee this case is asserting.
mask = detect_target_corruption(
df, targets=["load"], range_mw=5_000, step_mw=8_000, window_days=7
)
assert not mask.any(), "Clean spring-forward DST week must produce no flags."


Expand Down
18 changes: 7 additions & 11 deletions tests/test_entsoe_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@

import spotforecast2_safe.data.entsoe_loader as entsoe_loader
from spotforecast2_safe.configurator import ConfigEntsoe
from spotforecast2_safe.data.entsoe_loader import (
entsoe_data_loader,
entsoe_test_data_loader,
)


def _write_interim_csv(path, start: str, periods: int, tz: str | None = "UTC"):
Expand All @@ -29,7 +25,7 @@ def test_absolute_path_loads_full_frame(self, tmp_path):

config = ConfigEntsoe()
config.data_filename = str(csv_path)
df = entsoe_data_loader(config)
df = entsoe_loader.entsoe_data_loader(config)

assert df.shape == (48, 1)
assert df.index.name == "Time (UTC)"
Expand All @@ -41,7 +37,7 @@ def test_relative_path_resolves_against_data_home(self, tmp_path, monkeypatch):

config = ConfigEntsoe()
config.data_filename = "energy_load.csv"
df = entsoe_data_loader(config)
df = entsoe_loader.entsoe_data_loader(config)

assert df.shape == (24, 1)

Expand All @@ -50,7 +46,7 @@ def test_missing_file_raises_with_cli_hint(self, tmp_path):
config.data_filename = str(tmp_path / "does_not_exist.csv")

with pytest.raises(FileNotFoundError, match="spotforecast2-entsoe"):
entsoe_data_loader(config)
entsoe_loader.entsoe_data_loader(config)


class TestEntsoeTestDataLoader:
Expand All @@ -66,7 +62,7 @@ def test_slices_predict_size_steps_after_end_train(self, tmp_path):
_write_interim_csv(csv_path, "2025-12-29 00:00", 120)
config = self._config(csv_path, "2025-12-31 00:00+00:00")

test_df = entsoe_test_data_loader(config)
test_df = entsoe_loader.entsoe_test_data_loader(config)

assert test_df.shape == (24, 1)
assert test_df.index[0] == pd.Timestamp("2025-12-31 01:00", tz="UTC")
Expand All @@ -77,7 +73,7 @@ def test_naive_end_train_is_localized_to_utc(self, tmp_path):
_write_interim_csv(csv_path, "2025-12-29 00:00", 120)
config = self._config(csv_path, "2025-12-31 00:00") # no tz marker

test_df = entsoe_test_data_loader(config)
test_df = entsoe_loader.entsoe_test_data_loader(config)

assert test_df.shape == (24, 1)
assert test_df.index[0] == pd.Timestamp("2025-12-31 01:00", tz="UTC")
Expand All @@ -87,7 +83,7 @@ def test_naive_csv_index_is_supported(self, tmp_path):
_write_interim_csv(csv_path, "2025-12-29 00:00", 120, tz=None)
config = self._config(csv_path, "2025-12-31 00:00+00:00")

test_df = entsoe_test_data_loader(config)
test_df = entsoe_loader.entsoe_test_data_loader(config)

assert test_df.shape == (24, 1)
assert test_df.index[0] == pd.Timestamp("2025-12-31 01:00")
Expand All @@ -98,6 +94,6 @@ def test_window_shorter_when_data_runs_out(self, tmp_path):
_write_interim_csv(csv_path, "2025-12-29 00:00", 60) # ends 12-31 11:00
config = self._config(csv_path, "2025-12-31 00:00+00:00")

test_df = entsoe_test_data_loader(config)
test_df = entsoe_loader.entsoe_test_data_loader(config)

assert len(test_df) == 11 # only the rows that exist after the cutoff
Loading