From 5893fc5764744f31fcebe0773941a7e8289beea1 Mon Sep 17 00:00:00 2001 From: bartzbeielstein <32470350+bartzbeielstein@users.noreply.github.com> Date: Mon, 8 Jun 2026 20:41:04 +0200 Subject: [PATCH 1/2] feat(tasks)!: remove the spotforecast-safe-n2o1-cov-df console task MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The n-to-1-with-covariates task was a thin CLI wrapper whose real logic lives in spotforecast2_safe.multitask (BaseTask/MultiTask/runner.run) — retained as public API and exercised by the downstream spotforecast2 package. Removing the wrapper drops no library capability and clears the lone isinstance(config, ConfigMulti) guard that blocked ConfigEntsoe from inheriting ConfigMulti. Removed: the module, the `spotforecast-safe-n2o1-cov-df` console script (pyproject [project.scripts]), the `n2o1_cov_df_main` export from tasks/__init__, its dedicated tests (test_cli_n2o1, test_task_safe_n_to_1_with_covariates, and the TestTaskSafeN2O1CovDf class), and its docs/quartodoc entries. No downstream consumer (spotforecast2, bart26k-lecture) imports the module or calls the script. BREAKING CHANGE: the `spotforecast-safe-n2o1-cov-df` console entry point and the `spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe` module (incl. `run_pipeline`, `main`, and the `n2o1_cov_df_main` export) are removed. Use `spotforecast2_safe.multitask.MultiTask` / `multitask.runner.run` directly. Co-Authored-By: Claude Opus 4.8 (1M context) --- _quarto.yml | 3 - docs/reference/index.qmd | 3 +- ...e_n_to_1_with_covariates_and_dataframe.qmd | 157 ------- docs/safe/spotforecast2-safe.qmd | 6 +- docs/tasks/tasks.qmd | 62 --- .../n2n_predict_with_covariates_explained.qmd | 10 +- pyproject.toml | 1 - src/spotforecast2_safe/tasks/__init__.py | 3 +- ...fe_n_to_1_with_covariates_and_dataframe.py | 413 ------------------ tests/test_cli_n2o1.py | 134 ------ .../test_task_safe_n_to_1_with_covariates.py | 327 -------------- tests/test_tasks.py | 55 --- 12 files changed, 10 insertions(+), 1164 deletions(-) delete mode 100644 docs/reference/tasks.task_safe_n_to_1_with_covariates_and_dataframe.qmd delete mode 100644 src/spotforecast2_safe/tasks/task_safe_n_to_1_with_covariates_and_dataframe.py delete mode 100644 tests/test_cli_n2o1.py delete mode 100644 tests/test_task_safe_n_to_1_with_covariates.py diff --git a/_quarto.yml b/_quarto.yml index d13a47e2f..d40b8892f 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -421,8 +421,6 @@ website: contents: - text: "task_safe_demo" file: docs/reference/tasks.task_safe_demo.qmd - - text: "task_safe_n_to_1_with_covariates_and_dataframe" - file: docs/reference/tasks.task_safe_n_to_1_with_covariates_and_dataframe.qmd - section: "Processing Guides" contents: @@ -791,4 +789,3 @@ quartodoc: Executable tasks for demonstration and production pipelines. contents: - tasks.task_safe_demo - - tasks.task_safe_n_to_1_with_covariates_and_dataframe diff --git a/docs/reference/index.qmd b/docs/reference/index.qmd index abc2424d5..7aa40d9a4 100644 --- a/docs/reference/index.qmd +++ b/docs/reference/index.qmd @@ -292,5 +292,4 @@ Executable tasks for demonstration and production pipelines. | | | | --- | --- | -| [tasks.task_safe_demo](tasks.task_safe_demo.qmd#spotforecast2_safe.tasks.task_safe_demo) | Task demo: compare baseline, covariate, and custom LightGBM forecasts against ground truth. | -| [tasks.task_safe_n_to_1_with_covariates_and_dataframe](tasks.task_safe_n_to_1_with_covariates_and_dataframe.qmd#spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe) | Thin ConfigMulti-driven entry point for N-to-1 forecasting. | \ No newline at end of file +| [tasks.task_safe_demo](tasks.task_safe_demo.qmd#spotforecast2_safe.tasks.task_safe_demo) | Task demo: compare baseline, covariate, and custom LightGBM forecasts against ground truth. | \ No newline at end of file diff --git a/docs/reference/tasks.task_safe_n_to_1_with_covariates_and_dataframe.qmd b/docs/reference/tasks.task_safe_n_to_1_with_covariates_and_dataframe.qmd deleted file mode 100644 index c287bd2ff..000000000 --- a/docs/reference/tasks.task_safe_n_to_1_with_covariates_and_dataframe.qmd +++ /dev/null @@ -1,157 +0,0 @@ -# tasks.task_safe_n_to_1_with_covariates_and_dataframe { #spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe } - -`tasks.task_safe_n_to_1_with_covariates_and_dataframe` - -Thin ConfigMulti-driven entry point for N-to-1 forecasting. - -This module is a single-call wrapper around the ``multitask`` pipeline. -It delegates all heavy lifting to -``spotforecast2_safe.multitask.runner.run`` with ``task="lazy"`` and -returns the forecast DataFrame directly. - -``run_pipeline`` requires an explicit ``ConfigMulti`` instance. Outlier -``bounds`` and aggregation ``agg_weights`` are domain-specific calibrations -and must be supplied by the caller on ``ConfigMulti``; there are no -dataset-specific presets. Input data must always be passed explicitly via -the ``dataframe`` argument. The CLI flag ``--weights`` maps to -``ConfigMulti.agg_weights``; the flag ``--train_ratio`` derives -``train_size`` from the extent of the bundled ``demo10.csv`` (Python API -callers supply ``train_size`` explicitly on ``ConfigMulti``). - -CLI entry point: ``spotforecast-safe-n2o1-cov-df`` - -## Functions - -| Name | Description | -| --- | --- | -| [main](#spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe.main) | CLI entry point for the N-to-1 forecasting pipeline. | -| [run_pipeline](#spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe.run_pipeline) | Execute the N-to-1 forecasting pipeline and return the forecast DataFrame. | - -### main { #spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe.main } - -```python -tasks.task_safe_n_to_1_with_covariates_and_dataframe.main(argv=None) -``` - -CLI entry point for the N-to-1 forecasting pipeline. - -Parses command-line arguments, builds a ``ConfigMulti`` via -``_build_config_from_cli``, and delegates to ``run_pipeline``. -Prints the forecast head to stdout. - -When ``argv`` is ``None``, ``sys.argv[1:]`` is used. Pass an explicit -list of strings to invoke programmatically with a specific argv (useful -for testing without touching ``sys.argv``). - -#### Parameters {.doc-section .doc-section-parameters} - -| Name | Type | Description | Default | -|--------|------------------------------------------------------------------------|--------------------------------------------------------|-----------| -| argv | [Optional](`typing.Optional`)\[[List](`typing.List`)\[[str](`str`)\]\] | Argument list. ``None`` means read from ``sys.argv``. | `None` | - -#### Examples {.doc-section .doc-section-examples} - -```{python} -import sys -from io import StringIO -from spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe import main - -# Capture help text to verify the CLI is wired correctly without -# triggering a full training run. -buf = StringIO() -old_stdout, sys.stdout = sys.stdout, buf -try: - main(["--help"]) -except SystemExit as exc: - assert exc.code == 0, f"--help exited with non-zero code: {exc.code}" -finally: - sys.stdout = old_stdout - -help_text = buf.getvalue() -assert "forecast_horizon" in help_text -assert "lags" in help_text -print(help_text[:120]) -``` - -### run_pipeline { #spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe.run_pipeline } - -```python -tasks.task_safe_n_to_1_with_covariates_and_dataframe.run_pipeline( - config=None, - *, - dataframe=None, - data_test=None, - project_name='demo10', - cache_home=None, - show_progress=False, -) -``` - -Execute the N-to-1 forecasting pipeline and return the forecast DataFrame. - -Execution is delegated to ``spotforecast2_safe.multitask.runner.run`` -with ``task="lazy"``. A ``ConfigMulti`` instance must be supplied -explicitly; there is no implicit fallback. Outlier ``bounds`` and -aggregation ``agg_weights`` are domain-specific calibrations and must -be provided by the caller — no preset values are substituted. Input -data must likewise be supplied via ``dataframe``; auto-loading is not -performed. - -#### Parameters {.doc-section .doc-section-parameters} - -| Name | Type | Description | Default | -|---------------|------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------|------------| -| config | [Optional](`typing.Optional`)\[[ConfigMulti](`spotforecast2_safe.configurator.config_multi.ConfigMulti`)\] | A ``ConfigMulti`` instance. Must not be ``None``; passing ``None`` raises ``ValueError``. | `None` | -| dataframe | [Optional](`typing.Optional`)\[[pd](`pandas`).[DataFrame](`pandas.DataFrame`)\] | Input time-series DataFrame. Must contain a datetime column matching ``config.index_name`` and at least one numeric target column. | `None` | -| data_test | [Optional](`typing.Optional`)\[[pd](`pandas`).[DataFrame](`pandas.DataFrame`)\] | Ground-truth DataFrame covering the prediction horizon. Optional; passed through to the runner for metric computation. | `None` | -| project_name | [str](`str`) | Cache subdirectory and model-file identifier. Defaults to ``"demo10"``. | `'demo10'` | -| cache_home | [Optional](`typing.Optional`)\[[Path](`pathlib.Path`)\] | Cache directory override. When ``None``, the package default (``~/.spotforecast2_cache/``) is used. | `None` | -| show_progress | [bool](`bool`) | Whether to emit progress messages during pipeline execution. | `False` | - -#### Returns {.doc-section .doc-section-returns} - -| Name | Type | Description | -|--------|------------------------------------------------|----------------------------------------------------------------| -| | [pd](`pandas`).[DataFrame](`pandas.DataFrame`) | DataFrame with a ``"forecast"`` column indexed by the forecast | -| | [pd](`pandas`).[DataFrame](`pandas.DataFrame`) | horizon timestamps. | - -#### Raises {.doc-section .doc-section-raises} - -| Name | Type | Description | -|--------|----------------------------|---------------------------------------------------------------------------------------------------------------------------------------| -| | [ValueError](`ValueError`) | If ``config`` is ``None``, or if the supplied ``config`` or ``dataframe`` is invalid (propagated from ``runner.run`` / ``BaseTask``). | -| | [TypeError](`TypeError`) | If ``config`` is not ``None`` and not a ``ConfigMulti`` instance. | - -#### Examples {.doc-section .doc-section-examples} - -```{python} -import tempfile -import numpy as np -import pandas as pd -from spotforecast2_safe.configurator.config_multi import ConfigMulti -from spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe import run_pipeline - -rng = np.random.default_rng(0) -n = 500 -idx = pd.date_range("2020-01-01", periods=n, freq="h", tz="UTC") -df = pd.DataFrame( - rng.uniform(0, 100, size=(n, 3)), - index=idx, - columns=["A", "B", "C"], -) -df.index.name = "DateTime" - -with tempfile.TemporaryDirectory() as tmp: - cfg = ConfigMulti( - predict_size=4, - train_size=pd.Timedelta(days=14), - use_exogenous_features=False, - use_outlier_detection=False, - imputation_method="linear", - agg_weights=[1.0, 1.0, -1.0], - ) - result = run_pipeline(config=cfg, dataframe=df, project_name="doctest", cache_home=tmp) - -print(type(result)) -print(len(result)) -``` \ No newline at end of file diff --git a/docs/safe/spotforecast2-safe.qmd b/docs/safe/spotforecast2-safe.qmd index df44cec4d..02798054b 100644 --- a/docs/safe/spotforecast2-safe.qmd +++ b/docs/safe/spotforecast2-safe.qmd @@ -11,7 +11,7 @@ In safety-critical environments, reducing the "dead code" and unnecessary depend --- ## Positive List (Retained Components) -The following files are essential for the execution of the primary workflows: `task_safe_demo.py` and `task_safe_n_to_1_with_covariates_and_dataframe.py`. +The following files are essential for the execution of the primary workflow `task_safe_demo.py`. (The `task_safe_n_to_1_with_covariates_and_dataframe.py` console task was removed in 20.0.0; the `multitask` pipeline it wrapped — `BaseTask`/`MultiTask`/`runner.run` — is retained as public API and is exercised by the downstream `spotforecast2` package.) ### Orchestration & Pipelines - `src/spotforecast2_safe/processing/n2n_predict_with_covariates.py` @@ -74,13 +74,13 @@ To maintain a green build and avoid import errors, the following non-essential t The resulting `spotforecast2_safe` project is a hardened version of the original, with $0$ unreachable code paths for the specified tasks and $100\%$ test coverage on the remaining logic. ## Essential Classes and Functions (Positive List) -The following classes and functions (including internal helpers) are strictly required for the execution of `task_safe_demo.py` and `task_safe_n_to_1_with_covariates_and_dataframe.py`: +The following classes and functions (including internal helpers) are strictly required for the execution of `task_safe_demo.py` and the public `multitask` pipeline: ### Orchestration & Processing - `agg_predict` (Function) - `n2n_predict` (Function) - `n2n_predict_with_covariates` (Function) — used by `task_safe_demo` -- `run_pipeline` (Function, `task_safe_n_to_1_with_covariates_and_dataframe`) — delegates to `multitask.runner.run` +- `run` (Function, `multitask.runner`) — public single-call pipeline entry point ### Multitask Pipeline (n-to-1 task) - `BaseTask` (Class) — shared prepare/detect/impute/exog steps diff --git a/docs/tasks/tasks.qmd b/docs/tasks/tasks.qmd index 18503ed3f..2afeb9b60 100644 --- a/docs/tasks/tasks.qmd +++ b/docs/tasks/tasks.qmd @@ -7,7 +7,6 @@ | Command | Description | |---------|-------------| | `spotforecast-safe-demo` | Demo task comparing baseline, covariate, and custom LightGBM forecasts | -| `spotforecast-safe-n2o1-cov-df` | N-to-1 forecasting with exogenous covariates and DataFrame input | --- @@ -49,63 +48,6 @@ uv run spotforecast-safe-demo --logging true --- -## N-to-1 with Covariates and DataFrame - -The `spotforecast-safe-n2o1-cov-df` command runs the `ConfigMulti`-driven -multitask pipeline. It delegates to `spotforecast2_safe.multitask.runner.run` -with `task="lazy"`. All pipeline parameters are centralised in a `ConfigMulti` -object; aggregation weights come from `ConfigMulti.agg_weights` (replacing the -former hard-coded `DEFAULT_WEIGHTS` constant). - -See the `run_pipeline` and `_build_config_from_cli` API reference for a full -description of the pipeline stages and `ConfigMulti` field mappings. - -### Features - -- `ConfigMulti` as single source of truth — horizon, training window, outlier - policy, feature flags, and aggregation weights in one object. -- Multitask pipeline stages: prepare_data, detect_outliers, impute, - build_exogenous_features, run(task="lazy"). -- Model persistence: fitted models are saved under `cache_home/models/` and - can be reloaded with `task="predict"`. - -### Usage - -```bash -# Run with custom settings (a ConfigMulti and explicit data path are required) -uv run spotforecast-safe-n2o1-cov-df - -# Custom forecast horizon -uv run spotforecast-safe-n2o1-cov-df --forecast_horizon 48 - -# Enable holiday features and verbose output -uv run spotforecast-safe-n2o1-cov-df --include_holiday_features true --verbose true - -# Custom aggregation weights -uv run spotforecast-safe-n2o1-cov-df --weights 1.0 1.0 -1.0 - -# Specify cache directory for models and logs -uv run spotforecast-safe-n2o1-cov-df --log_dir ~/my_cache -``` - -### Parameters - -| Parameter | Default | Description | -|-----------|---------|-------------| -| `--forecast_horizon` | 24 | Number of steps ahead to forecast (`predict_size`) | -| `--lags` | 24 | Lag depth N; expands to `lags_consider=range(1, N+1)` | -| `--train_ratio` | 0.8 | Fraction of data used for training | -| `--contamination` | 0.01 | Outlier contamination parameter | -| `--window_size` | 72 | Rolling window size for imputation (hours) | -| `--include_holiday_features` | false | Enable holiday indicator features | -| `--include_holiday_adjacency_features` | false | Enable Brückentag features | -| `--poly_features_degree` | 1 | Polynomial interaction degree (1 = off) | -| `--weights` | None | Space-separated aggregation weights (`agg_weights`) | -| `--verbose` | false | Enable detailed output | -| `--log_dir` | None | Cache directory for models and logs (`cache_home`) | - ---- - ## Configuration All tasks use sensible defaults but can be customized via: @@ -117,7 +59,6 @@ All tasks use sensible defaults but can be customized via: ```bash # View available options for any command uv run spotforecast-safe-demo --help -uv run spotforecast-safe-n2o1-cov-df --help ``` --- @@ -142,9 +83,6 @@ Safety-critical tasks support comprehensive logging: ```bash # Enable logging to default directory uv run spotforecast-safe-demo --logging true - -# Specify custom log directory -uv run spotforecast-safe-n2o1-cov-df --log_dir /var/log/spotforecast ``` Log files include: diff --git a/docs/tutorials/n2n_predict_with_covariates_explained.qmd b/docs/tutorials/n2n_predict_with_covariates_explained.qmd index cbb494c75..255ea7ed1 100644 --- a/docs/tutorials/n2n_predict_with_covariates_explained.qmd +++ b/docs/tutorials/n2n_predict_with_covariates_explained.qmd @@ -714,11 +714,11 @@ convention is used with the eleven-element vector aggregation in which the first, second, fifth, seventh, eighth, ninth, and eleventh columns are added and the remaining columns are subtracted. -The n2n pipeline remains available and is used by `task_safe_demo`. The n-to-1 -task (`tasks/task_safe_n_to_1_with_covariates_and_dataframe.py`) now runs on -`spotforecast2_safe.multitask`; see the -[API reference](../reference/tasks.task_safe_n_to_1_with_covariates_and_dataframe.qmd) -for its `ConfigMulti`-driven workflow and `agg_weights` configuration. +The n2n pipeline remains available and is used by `task_safe_demo`. The +`ConfigMulti`-driven, `agg_weights`-configurable n-to-1 workflow it describes now +runs directly on `spotforecast2_safe.multitask` (`MultiTask` / +`multitask.runner.run`); the former `spotforecast-safe-n2o1-cov-df` console task +was removed in 20.0.0. ```{python} from spotforecast2_safe.processing.agg_predict import agg_predict diff --git a/pyproject.toml b/pyproject.toml index f7d4ebbed..921871195 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,7 +80,6 @@ dev = [ [project.scripts] spotforecast-safe-demo = "spotforecast2_safe.tasks.task_safe_demo:main" -spotforecast-safe-n2o1-cov-df = "spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe:main" [tool.uv] # Force uv lock to resolve dependencies for every platform we ship on, diff --git a/src/spotforecast2_safe/tasks/__init__.py b/src/spotforecast2_safe/tasks/__init__.py index 62d71de12..77c723198 100644 --- a/src/spotforecast2_safe/tasks/__init__.py +++ b/src/spotforecast2_safe/tasks/__init__.py @@ -2,6 +2,5 @@ # SPDX-License-Identifier: AGPL-3.0-or-later from .task_safe_demo import main as demo_main -from .task_safe_n_to_1_with_covariates_and_dataframe import main as n2o1_cov_df_main -__all__ = ["demo_main", "n2o1_cov_df_main"] +__all__ = ["demo_main"] diff --git a/src/spotforecast2_safe/tasks/task_safe_n_to_1_with_covariates_and_dataframe.py b/src/spotforecast2_safe/tasks/task_safe_n_to_1_with_covariates_and_dataframe.py deleted file mode 100644 index febb40b02..000000000 --- a/src/spotforecast2_safe/tasks/task_safe_n_to_1_with_covariates_and_dataframe.py +++ /dev/null @@ -1,413 +0,0 @@ -# SPDX-FileCopyrightText: 2026 bartzbeielstein -# SPDX-License-Identifier: AGPL-3.0-or-later - -"""Thin ConfigMulti-driven entry point for N-to-1 forecasting. - -This module is a single-call wrapper around the ``multitask`` pipeline. -It delegates all heavy lifting to -``spotforecast2_safe.multitask.runner.run`` with ``task="lazy"`` and -returns the forecast DataFrame directly. - -``run_pipeline`` requires an explicit ``ConfigMulti`` instance. Outlier -``bounds`` and aggregation ``agg_weights`` are domain-specific calibrations -and must be supplied by the caller on ``ConfigMulti``; there are no -dataset-specific presets. Input data must always be passed explicitly via -the ``dataframe`` argument. The CLI flag ``--weights`` maps to -``ConfigMulti.agg_weights``; the flag ``--train_ratio`` derives -``train_size`` from the extent of the bundled ``demo10.csv`` (Python API -callers supply ``train_size`` explicitly on ``ConfigMulti``). - -CLI entry point: ``spotforecast-safe-n2o1-cov-df`` -""" - -import argparse -import sys -from pathlib import Path -from typing import List, Optional - -import pandas as pd - -from spotforecast2_safe.configurator.config_multi import ConfigMulti -from spotforecast2_safe.data.fetch_data import get_package_data_home -from spotforecast2_safe.multitask.runner import run -from spotforecast2_safe.utils.parse import parse_bool - -# --------------------------------------------------------------------------- -# Argument parser (shared by CLI entry-point and _build_config_from_cli) -# --------------------------------------------------------------------------- - - -def _make_arg_parser() -> argparse.ArgumentParser: - """Return the argument parser for the CLI entry point.""" - parser = argparse.ArgumentParser( - description=( - "Run the safety-critical N-to-1 forecasting pipeline " - "(ConfigMulti-driven, multitask runner)." - ), - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - # Forecast parameters - parser.add_argument( - "--forecast_horizon", - type=int, - default=24, - help="Number of steps ahead to forecast.", - ) - parser.add_argument( - "--lags", - type=int, - default=24, - help="Lag depth N; expands to lags_consider=range(1, N+1).", - ) - parser.add_argument( - "--train_ratio", - type=float, - default=0.8, - help="Fraction of data used for training [0, 1].", - ) - parser.add_argument( - "--contamination", - type=float, - default=0.01, - help="Outlier contamination parameter [0, 0.5).", - ) - parser.add_argument( - "--window_size", - type=int, - default=72, - help="Rolling window size for weighted imputation (hours).", - ) - - # Location parameters - parser.add_argument( - "--latitude", - type=float, - default=51.5136, - help="Location latitude for solar/weather features.", - ) - parser.add_argument( - "--longitude", - type=float, - default=7.4653, - help="Location longitude for solar/weather features.", - ) - parser.add_argument( - "--timezone", - type=str, - default="UTC", - help="IANA timezone string.", - ) - parser.add_argument( - "--country_code", - type=str, - default="DE", - help="ISO 3166-1 alpha-2 country code for holidays.", - ) - parser.add_argument( - "--state", - type=str, - default="NW", - help="ISO 3166-2 subdivision code for regional holidays.", - ) - - # Feature engineering flags - parser.add_argument( - "--include_weather_windows", - type=parse_bool, - default=False, - help="Enable rolling weather-window features.", - ) - parser.add_argument( - "--include_holiday_features", - type=parse_bool, - default=False, - help="Enable public-holiday indicator features.", - ) - parser.add_argument( - "--include_holiday_adjacency_features", - type=parse_bool, - default=False, - help="Enable Brückentag and before/after-holiday features.", - ) - parser.add_argument( - "--poly_features_degree", - type=int, - default=1, - help="Polynomial-interaction degree (1 = off).", - ) - parser.add_argument( - "--max_poly_features", - type=int, - default=10, - help="Cap on kept polynomial columns (top-K by mutual information).", - ) - - # Execution controls - parser.add_argument( - "--verbose", - type=parse_bool, - default=False, - help="Enable verbose pipeline output.", - ) - parser.add_argument( - "--weights", - type=float, - nargs="+", - default=None, - help="Space-separated aggregation weights (one per target column).", - ) - parser.add_argument( - "--log_dir", - type=str, - default=None, - help="Cache directory for models and logs (maps to cache_home).", - ) - - return parser - - -def _build_config_from_cli(args: argparse.Namespace) -> ConfigMulti: - """Translate parsed CLI arguments into a ``ConfigMulti`` instance. - - CLI flag to ``ConfigMulti`` field mapping: - - - ``--forecast_horizon N`` -> ``predict_size=N`` - - ``--lags N`` -> ``lags_consider=list(range(1, N+1))`` - - ``--train_ratio R`` -> ``train_size`` derived from demo10 data extent - - ``--contamination`` -> ``contamination`` - - ``--window_size`` -> ``window_size`` - - ``--latitude`` -> ``latitude`` - - ``--longitude`` -> ``longitude`` - - ``--timezone`` -> ``timezone`` - - ``--country_code`` -> ``country_code`` - - ``--state`` -> ``state`` - - ``--include_weather_windows`` -> ``include_weather_windows`` - - ``--include_holiday_features`` -> ``include_holiday_features`` - - ``--include_holiday_adjacency_features`` -> ``include_holiday_adjacency_features`` - - ``--poly_features_degree`` -> ``poly_features_degree`` - - ``--max_poly_features`` -> ``max_poly_features`` - - ``--verbose`` -> ``verbose`` - - ``--weights w1 w2 ...`` -> ``agg_weights=[w1, w2, ...]`` - - ``--log_dir PATH`` -> forwarded as ``cache_home`` to ``main()`` - - ``--lags N`` expands to ``lags_consider=list(range(1, N+1))`` because - ``default_lgbm_forecaster_factory`` reads ``config.lags_consider[-1]`` - to set the lag depth of ``ForecasterRecursive``, preserving the old - ``ForecasterRecursive(lags=N)`` behaviour. - - ``--train_ratio R`` is translated into ``train_size`` as a - ``pd.Timedelta`` derived from the demo10 dataset extent. - - Args: - args: Parsed ``argparse.Namespace`` from ``_make_arg_parser()``. - - Returns: - A new ``ConfigMulti`` instance with the CLI arguments applied. - """ - # --lags N -> lags_consider=list(range(1, N+1)) - lags_consider = list(range(1, args.lags + 1)) - - # --train_ratio R -> train_size derived from demo10 extent - data_home = get_package_data_home() - try: - _df = pd.read_csv(data_home / "demo10.csv", index_col=0, parse_dates=True) - except OSError as exc: - # Fail-safe: never silently substitute a train_size unrelated to - # the data the pipeline will actually load. - raise FileNotFoundError( - "Cannot derive train_size from --train_ratio: the bundled " - f"demo10.csv could not be read from {data_home}: {exc}" - ) from exc - first_ts = pd.to_datetime(_df.index.min(), utc=True) - last_ts = pd.to_datetime(_df.index.max(), utc=True) - total_span = last_ts - first_ts - train_size = pd.Timedelta( - seconds=int(total_span.total_seconds() * args.train_ratio) - ) - - return ConfigMulti( - predict_size=args.forecast_horizon, - lags_consider=lags_consider, - train_size=train_size, - contamination=args.contamination, - window_size=args.window_size, - latitude=args.latitude, - longitude=args.longitude, - timezone=args.timezone, - country_code=args.country_code, - state=args.state, - include_weather_windows=args.include_weather_windows, - include_holiday_features=args.include_holiday_features, - include_holiday_adjacency_features=args.include_holiday_adjacency_features, - poly_features_degree=args.poly_features_degree, - max_poly_features=args.max_poly_features, - verbose=args.verbose, - agg_weights=args.weights, - ) - - -def run_pipeline( - config: Optional[ConfigMulti] = None, - *, - dataframe: Optional[pd.DataFrame] = None, - data_test: Optional[pd.DataFrame] = None, - project_name: str = "demo10", - cache_home: Optional[Path] = None, - show_progress: bool = False, -) -> pd.DataFrame: - """Execute the N-to-1 forecasting pipeline and return the forecast DataFrame. - - Execution is delegated to ``spotforecast2_safe.multitask.runner.run`` - with ``task="lazy"``. A ``ConfigMulti`` instance must be supplied - explicitly; there is no implicit fallback. Outlier ``bounds`` and - aggregation ``agg_weights`` are domain-specific calibrations and must - be provided by the caller — no preset values are substituted. Input - data must likewise be supplied via ``dataframe``; auto-loading is not - performed. - - Args: - config: A ``ConfigMulti`` instance. Must not be ``None``; passing - ``None`` raises ``ValueError``. - dataframe: Input time-series DataFrame. Must contain a datetime - column matching ``config.index_name`` and at least one numeric - target column. - data_test: Ground-truth DataFrame covering the prediction horizon. - Optional; passed through to the runner for metric computation. - project_name: Cache subdirectory and model-file identifier. - Defaults to ``"demo10"``. - cache_home: Cache directory override. When ``None``, the package - default (``~/.spotforecast2_cache/``) is used. - show_progress: Whether to emit progress messages during pipeline - execution. - - Returns: - DataFrame with a ``"forecast"`` column indexed by the forecast - horizon timestamps. - - Raises: - ValueError: If ``config`` is ``None``, or if the supplied - ``config`` or ``dataframe`` is invalid (propagated from - ``runner.run`` / ``BaseTask``). - TypeError: If ``config`` is not ``None`` and not a ``ConfigMulti`` - instance. - - Examples: - ```{python} - import tempfile - import numpy as np - import pandas as pd - from spotforecast2_safe.configurator.config_multi import ConfigMulti - from spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe import run_pipeline - - rng = np.random.default_rng(0) - n = 500 - idx = pd.date_range("2020-01-01", periods=n, freq="h", tz="UTC") - df = pd.DataFrame( - rng.uniform(0, 100, size=(n, 3)), - index=idx, - columns=["A", "B", "C"], - ) - df.index.name = "DateTime" - - with tempfile.TemporaryDirectory() as tmp: - cfg = ConfigMulti( - predict_size=4, - train_size=pd.Timedelta(days=14), - use_exogenous_features=False, - use_outlier_detection=False, - imputation_method="linear", - agg_weights=[1.0, 1.0, -1.0], - ) - result = run_pipeline(config=cfg, dataframe=df, project_name="doctest", cache_home=tmp) - - print(type(result)) - print(len(result)) - ``` - """ - if config is not None and not isinstance(config, ConfigMulti): - raise TypeError( - f"config must be a ConfigMulti instance or None; " - f"got {type(config).__name__!r}." - ) - - if config is None: - raise ValueError( - "config is required: build a ConfigMulti and pass it explicitly, e.g.\n" - " ConfigMulti(\n" - " predict_size=24,\n" - " agg_weights=[...], # one weight per target column; None = equal 1/n\n" - " bounds=[...], # one (lower, upper) per target; None = no clipping\n" - " )\n" - "Outlier `bounds` and aggregation `agg_weights` are domain-specific " - "calibrations and are never defaulted to demo-dataset values." - ) - - return run( - config, - task="lazy", - dataframe=dataframe, - data_test=data_test, - project_name=project_name, - cache_home=cache_home, - show_progress=show_progress, - ) - - -def main(argv: Optional[List[str]] = None) -> None: - """CLI entry point for the N-to-1 forecasting pipeline. - - Parses command-line arguments, builds a ``ConfigMulti`` via - ``_build_config_from_cli``, and delegates to ``run_pipeline``. - Prints the forecast head to stdout. - - When ``argv`` is ``None``, ``sys.argv[1:]`` is used. Pass an explicit - list of strings to invoke programmatically with a specific argv (useful - for testing without touching ``sys.argv``). - - Args: - argv: Argument list. ``None`` means read from ``sys.argv``. - - Examples: - ```{python} - import sys - from io import StringIO - from spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe import main - - # Capture help text to verify the CLI is wired correctly without - # triggering a full training run. - buf = StringIO() - old_stdout, sys.stdout = sys.stdout, buf - try: - main(["--help"]) - except SystemExit as exc: - assert exc.code == 0, f"--help exited with non-zero code: {exc.code}" - finally: - sys.stdout = old_stdout - - help_text = buf.getvalue() - assert "forecast_horizon" in help_text - assert "lags" in help_text - print(help_text[:120]) - ``` - """ - parser = _make_arg_parser() - args = parser.parse_args(argv) - - cache_home_path = Path(args.log_dir) if args.log_dir else None - cfg = _build_config_from_cli(args) - - try: - forecast = run_pipeline(config=cfg, cache_home=cache_home_path) - print("\nForecast head:") - print(forecast.head()) - except KeyboardInterrupt: - print("\nShutdown requested by user.") - sys.exit(0) - except Exception as exc: - print(f"\nCritical failure: {exc}") - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/tests/test_cli_n2o1.py b/tests/test_cli_n2o1.py deleted file mode 100644 index 43d167d5e..000000000 --- a/tests/test_cli_n2o1.py +++ /dev/null @@ -1,134 +0,0 @@ -# SPDX-FileCopyrightText: 2026 bartzbeielstein -# SPDX-License-Identifier: AGPL-3.0-or-later - -"""CLI smoke tests for the spotforecast-safe-n2o1-cov-df entry point. - -These tests verify that: -- ``--help`` exits with code 0 (argparse contract). -- ``main(["--help"])`` exits with code 0 (programmatic argv path). -- An invalid flag value causes argparse to exit with code 2. -- ``main([])`` with a patched ``run_pipeline`` completes without error. - -The tests do NOT run the full pipeline (that requires model training and -cached data); they exercise only the CLI argument-parsing layer. The single -real end-to-end test lives in -``tests/test_task_safe_n_to_1_with_covariates.py::TestRunPipelineEndToEnd``. -""" - -import subprocess -import sys -from pathlib import Path -from unittest.mock import patch - -import pandas as pd -import pytest - -MOD = "spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe" - - -class TestCliHelp: - """``--help`` must produce a usage message and exit 0.""" - - def test_help_via_subprocess(self): - """Invoke the CLI via subprocess to exercise the full console-script path.""" - result = subprocess.run( - [ - sys.executable, - "-m", - "spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe", - "--help", - ], - capture_output=True, - text=True, - ) - assert result.returncode == 0 - assert "forecast_horizon" in result.stdout - - def test_help_via_main_argv(self): - """``main(["--help"])`` exits 0 without touching the pipeline.""" - from spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe import ( - main, - ) - - with pytest.raises(SystemExit) as exc_info: - main(["--help"]) - assert exc_info.value.code == 0 - - def test_invalid_flag_exits_2(self): - """An unrecognised flag must cause argparse to exit with code 2.""" - from spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe import ( - main, - ) - - with pytest.raises(SystemExit) as exc_info: - main(["--no_such_flag"]) - assert exc_info.value.code == 2 - - -class TestCliArgvParsing: - """CLI flag parsing routes to the correct ``ConfigMulti`` fields.""" - - @patch(f"{MOD}.run_pipeline") - def test_empty_argv_runs_without_error(self, mock_run): - """``main([])`` uses all defaults and calls ``run_pipeline`` once.""" - from spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe import ( - main, - ) - - mock_run.return_value = pd.DataFrame({"forecast": [1.0, 2.0]}) - main([]) - mock_run.assert_called_once() - - @patch(f"{MOD}.run_pipeline") - def test_forecast_horizon_flag(self, mock_run): - from spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe import ( - main, - ) - - mock_run.return_value = pd.DataFrame({"forecast": [1.0]}) - main(["--forecast_horizon", "6"]) - _, kwargs = mock_run.call_args - assert kwargs["config"].predict_size == 6 - - @patch(f"{MOD}.run_pipeline") - def test_lags_flag_expands_range(self, mock_run): - from spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe import ( - main, - ) - - mock_run.return_value = pd.DataFrame({"forecast": [1.0]}) - main(["--lags", "5"]) - _, kwargs = mock_run.call_args - assert kwargs["config"].lags_consider == list(range(1, 6)) - - @patch(f"{MOD}.run_pipeline") - def test_weights_flag(self, mock_run): - from spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe import ( - main, - ) - - mock_run.return_value = pd.DataFrame({"forecast": [1.0]}) - main(["--weights", "1.0", "-1.0", "0.5"]) - _, kwargs = mock_run.call_args - assert kwargs["config"].agg_weights == [1.0, -1.0, 0.5] - - @patch(f"{MOD}.run_pipeline") - def test_log_dir_becomes_cache_home(self, mock_run, tmp_path: Path): - from spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe import ( - main, - ) - - mock_run.return_value = pd.DataFrame({"forecast": [1.0]}) - main(["--log_dir", str(tmp_path)]) - _, kwargs = mock_run.call_args - assert kwargs["cache_home"] == tmp_path - - @patch(f"{MOD}.run_pipeline", side_effect=RuntimeError("boom")) - def test_pipeline_error_exits_1(self, _mock_run): - from spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe import ( - main, - ) - - with pytest.raises(SystemExit) as exc_info: - main([]) - assert exc_info.value.code == 1 diff --git a/tests/test_task_safe_n_to_1_with_covariates.py b/tests/test_task_safe_n_to_1_with_covariates.py deleted file mode 100644 index 763d01830..000000000 --- a/tests/test_task_safe_n_to_1_with_covariates.py +++ /dev/null @@ -1,327 +0,0 @@ -# SPDX-FileCopyrightText: 2026 bartzbeielstein -# SPDX-License-Identifier: AGPL-3.0-or-later - -"""Tests for the N-to-1-with-covariates task module (ConfigMulti-driven). - -These tests exercise the thin entry-point in -``spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe`` -by verifying: - -- ``run_pipeline`` delegates to ``runner.run`` with the right arguments. -- ``run_pipeline`` with ``config=None`` raises ``ValueError`` with a message - mentioning "explicitly" (no demo10 preset is substituted). -- ``_build_config_from_cli`` maps CLI flags to ``ConfigMulti`` fields correctly. -- A real end-to-end run on synthetic data returns a non-empty forecast DataFrame. -- Fail-safe: invalid input raises ``TypeError``/``ValueError``. -""" - -import argparse -from pathlib import Path -from unittest.mock import patch - -import numpy as np -import pandas as pd -import pytest - -from spotforecast2_safe.configurator.config_multi import ConfigMulti -from spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe import ( - _build_config_from_cli, - main, - run_pipeline, -) - -MOD = "spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe" - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _make_synthetic_df(n: int = 500, n_cols: int = 3) -> pd.DataFrame: - """Return a small hourly DataFrame with a DatetimeIndex suitable for the pipeline. - - The pipeline's ``prepare_data`` step expects a DataFrame with a - ``DatetimeIndex`` (or a ``DatetimeIndex`` named ``index_name``). Passing - a pre-reset DataFrame (with "DateTime" as a regular column) causes a - ``ValueError`` when ``reset_index`` is called internally. - """ - rng = np.random.default_rng(0) - idx = pd.date_range("2020-01-01", periods=n, freq="h", tz="UTC") - data = rng.uniform(0, 100, size=(n, n_cols)) - df = pd.DataFrame(data, index=idx, columns=[f"C{i}" for i in range(n_cols)]) - df.index.name = "DateTime" - return df - - -def _minimal_config(n_cols: int = 3, predict_size: int = 4) -> ConfigMulti: - """Return a fast, offline-safe ``ConfigMulti`` for synthetic-data tests.""" - return ConfigMulti( - predict_size=predict_size, - train_size=pd.Timedelta(days=14), - use_exogenous_features=False, - use_outlier_detection=False, - imputation_method="linear", - agg_weights=[1.0] * n_cols, - ) - - -# --------------------------------------------------------------------------- -# run_pipeline: routing tests (runner.run patched) -# --------------------------------------------------------------------------- - - -class TestRunPipelineRouting: - """``run_pipeline`` must delegate correctly to ``runner.run``.""" - - @patch(f"{MOD}.run") - def test_run_called_with_task_lazy(self, mock_run): - mock_run.return_value = pd.DataFrame({"forecast": [1.0, 2.0, 3.0, 4.0]}) - cfg = _minimal_config() - df = _make_synthetic_df() - - run_pipeline(config=cfg, dataframe=df, project_name="proj", cache_home="/tmp") - - mock_run.assert_called_once() - _, kwargs = mock_run.call_args - assert kwargs["task"] == "lazy" - assert kwargs["project_name"] == "proj" - assert kwargs["cache_home"] == "/tmp" - assert kwargs["dataframe"] is df - - def test_none_config_raises_value_error(self): - """When config=None, run_pipeline raises ValueError mentioning 'explicitly'.""" - with pytest.raises(ValueError, match="explicitly"): - run_pipeline(config=None) - - @patch(f"{MOD}.run") - def test_returns_dataframe_from_runner(self, mock_run): - expected = pd.DataFrame( - {"forecast": [10.0, 20.0, 30.0, 40.0]}, - index=pd.date_range("2020-01-01", periods=4, freq="h"), - ) - mock_run.return_value = expected - cfg = _minimal_config() - df = _make_synthetic_df() - - result = run_pipeline(config=cfg, dataframe=df) - - assert result is expected - - -# --------------------------------------------------------------------------- -# run_pipeline: fail-safe behaviour -# --------------------------------------------------------------------------- - - -class TestRunPipelineFailSafe: - """Invalid input must raise explicitly; nothing is silently swallowed.""" - - def test_wrong_config_type_raises_type_error(self): - with pytest.raises(TypeError, match="ConfigMulti"): - run_pipeline(config="not_a_config") - - def test_wrong_config_type_int_raises_type_error(self): - with pytest.raises(TypeError, match="ConfigMulti"): - run_pipeline(config=42) - - @patch(f"{MOD}.run", side_effect=ValueError("bad data")) - def test_pipeline_failure_propagates(self, _mock_run): - cfg = _minimal_config() - df = _make_synthetic_df() - with pytest.raises(ValueError, match="bad data"): - run_pipeline(config=cfg, dataframe=df) - - -# --------------------------------------------------------------------------- -# run_pipeline: real end-to-end on synthetic data -# --------------------------------------------------------------------------- - - -class TestRunPipelineEndToEnd: - """At least one test exercises the real pipeline on synthetic data.""" - - def test_returns_non_empty_forecast_of_correct_length(self, tmp_path: Path): - """Full real pipeline run (no mocks) on a small synthetic DataFrame.""" - n = 500 - predict_size = 4 - n_cols = 3 - df = _make_synthetic_df(n=n, n_cols=n_cols) - cfg = _minimal_config(n_cols=n_cols, predict_size=predict_size) - - result = run_pipeline( - config=cfg, - dataframe=df, - project_name="e2e_test", - cache_home=str(tmp_path), - ) - - assert isinstance(result, pd.DataFrame) - assert not result.empty - assert len(result) == predict_size - assert "forecast" in result.columns - - -# --------------------------------------------------------------------------- -# _build_config_from_cli: flag mapping -# --------------------------------------------------------------------------- - - -class TestBuildConfigFromCli: - """``_build_config_from_cli`` must map every CLI flag to the right ConfigMulti field.""" - - def _args(self, **kwargs) -> argparse.Namespace: - defaults = dict( - forecast_horizon=24, - lags=24, - train_ratio=0.8, - contamination=0.01, - window_size=72, - latitude=51.5136, - longitude=7.4653, - timezone="UTC", - country_code="DE", - state="NW", - include_weather_windows=False, - include_holiday_features=False, - include_holiday_adjacency_features=False, - poly_features_degree=1, - max_poly_features=10, - verbose=False, - weights=None, - log_dir=None, - ) - defaults.update(kwargs) - return argparse.Namespace(**defaults) - - def test_predict_size_mapped(self): - cfg = _build_config_from_cli(self._args(forecast_horizon=48)) - assert cfg.predict_size == 48 - - def test_lags_expanded_to_range(self): - cfg = _build_config_from_cli(self._args(lags=5)) - assert cfg.lags_consider == list(range(1, 6)) - assert cfg.lags_consider[-1] == 5 - - def test_lags_last_matches_n(self): - """lags_consider[-1] == N preserves old ForecasterRecursive(lags=N) behaviour.""" - for n in (12, 24, 48): - cfg = _build_config_from_cli(self._args(lags=n)) - assert cfg.lags_consider[-1] == n - - def test_train_ratio_yields_timedelta(self): - cfg = _build_config_from_cli(self._args(train_ratio=0.5)) - assert isinstance(cfg.train_size, pd.Timedelta) - assert cfg.train_size > pd.Timedelta(0) - - def test_train_ratio_proportional(self): - """Larger ratio -> larger train_size.""" - cfg_large = _build_config_from_cli(self._args(train_ratio=0.9)) - cfg_small = _build_config_from_cli(self._args(train_ratio=0.5)) - assert cfg_large.train_size > cfg_small.train_size - - def test_contamination_mapped(self): - cfg = _build_config_from_cli(self._args(contamination=0.05)) - assert cfg.contamination == 0.05 - - def test_window_size_mapped(self): - cfg = _build_config_from_cli(self._args(window_size=48)) - assert cfg.window_size == 48 - - def test_latitude_mapped(self): - cfg = _build_config_from_cli(self._args(latitude=48.1)) - assert cfg.latitude == 48.1 - - def test_longitude_mapped(self): - cfg = _build_config_from_cli(self._args(longitude=11.6)) - assert cfg.longitude == 11.6 - - def test_timezone_mapped(self): - cfg = _build_config_from_cli(self._args(timezone="Europe/Berlin")) - assert cfg.timezone == "Europe/Berlin" - - def test_country_code_mapped(self): - cfg = _build_config_from_cli(self._args(country_code="FR")) - assert cfg.country_code == "FR" - - def test_state_mapped(self): - cfg = _build_config_from_cli(self._args(state="BY")) - assert cfg.state == "BY" - - def test_include_weather_windows_mapped(self): - cfg = _build_config_from_cli(self._args(include_weather_windows=True)) - assert cfg.include_weather_windows is True - - def test_include_holiday_features_mapped(self): - cfg = _build_config_from_cli(self._args(include_holiday_features=True)) - assert cfg.include_holiday_features is True - - def test_include_holiday_adjacency_features_mapped(self): - cfg = _build_config_from_cli( - self._args(include_holiday_adjacency_features=True) - ) - assert cfg.include_holiday_adjacency_features is True - - def test_poly_features_degree_mapped(self): - cfg = _build_config_from_cli(self._args(poly_features_degree=2)) - assert cfg.poly_features_degree == 2 - - def test_max_poly_features_mapped(self): - cfg = _build_config_from_cli(self._args(max_poly_features=5)) - assert cfg.max_poly_features == 5 - - def test_verbose_mapped(self): - cfg = _build_config_from_cli(self._args(verbose=True)) - assert cfg.verbose is True - - def test_weights_mapped_to_agg_weights(self): - weights = [1.0, -1.0, 0.5] - cfg = _build_config_from_cli(self._args(weights=weights)) - assert cfg.agg_weights == weights - - def test_weights_none_stays_none(self): - cfg = _build_config_from_cli(self._args(weights=None)) - assert cfg.agg_weights is None - - -# --------------------------------------------------------------------------- -# main: CLI argv parsing -# --------------------------------------------------------------------------- - - -class TestMainCliParsing: - """``main()`` must parse argv, build config, and call run_pipeline.""" - - @patch(f"{MOD}.run_pipeline") - def test_help_exits_0(self, _mock_run): - with pytest.raises(SystemExit) as exc_info: - main(["--help"]) - assert exc_info.value.code == 0 - - @patch(f"{MOD}.run_pipeline") - def test_main_with_no_argv_uses_defaults(self, mock_run): - mock_run.return_value = pd.DataFrame({"forecast": [1.0]}) - # Calling with empty argv list should not raise - main([]) - mock_run.assert_called_once() - - @patch(f"{MOD}.run_pipeline") - def test_forecast_horizon_flag_forwarded(self, mock_run): - mock_run.return_value = pd.DataFrame({"forecast": [1.0]}) - main(["--forecast_horizon", "48"]) - _, kwargs = mock_run.call_args - # The config passed to run_pipeline should have predict_size=48 - assert kwargs["config"].predict_size == 48 - - @patch(f"{MOD}.run_pipeline") - def test_log_dir_becomes_cache_home(self, mock_run, tmp_path: Path): - mock_run.return_value = pd.DataFrame({"forecast": [1.0]}) - main(["--log_dir", str(tmp_path)]) - _, kwargs = mock_run.call_args - assert kwargs["cache_home"] == tmp_path - - @patch(f"{MOD}.run_pipeline", side_effect=RuntimeError("boom")) - def test_pipeline_failure_exits_1(self, _mock_run): - with pytest.raises(SystemExit) as exc_info: - main([]) - assert exc_info.value.code == 1 diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 4e797d28a..d6a863786 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -56,60 +56,5 @@ def test_main_returns_zero_on_success(self): self.assertEqual(result, 0) -class TestTaskSafeN2O1CovDf(unittest.TestCase): - """Tests for task_safe_n_to_1_with_covariates_and_dataframe.py (ConfigMulti-driven).""" - - def test_main_exits_0_on_help(self): - """main(["--help"]) must exit with code 0.""" - from spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe import ( - main, - ) - - with self.assertRaises(SystemExit) as ctx: - main(["--help"]) - self.assertEqual(ctx.exception.code, 0) - - @patch( - "spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe.run_pipeline" - ) - def test_main_with_empty_argv_calls_run_pipeline(self, mock_run): - """main([]) builds a config from defaults and delegates to run_pipeline.""" - from spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe import ( - main, - ) - - mock_run.return_value = pd.DataFrame({"forecast": [1.0, 2.0, 3.0]}) - main([]) - mock_run.assert_called_once() - - @patch( - "spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe.run_pipeline" - ) - def test_main_forecast_horizon_flag(self, mock_run): - """--forecast_horizon flag is translated to config.predict_size.""" - from spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe import ( - main, - ) - - mock_run.return_value = pd.DataFrame({"forecast": [1.0]}) - main(["--forecast_horizon", "48"]) - _, kwargs = mock_run.call_args - self.assertEqual(kwargs["config"].predict_size, 48) - - @patch( - "spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe.run_pipeline", - side_effect=RuntimeError("pipeline error"), - ) - def test_main_exits_1_on_failure(self, _mock_run): - """main exits with code 1 when run_pipeline raises.""" - from spotforecast2_safe.tasks.task_safe_n_to_1_with_covariates_and_dataframe import ( - main, - ) - - with self.assertRaises(SystemExit) as ctx: - main([]) - self.assertEqual(ctx.exception.code, 1) - - if __name__ == "__main__": unittest.main() From defdca73e300d9fd956652400ecd8981cec8cac6 Mon Sep 17 00:00:00 2001 From: bartzbeielstein <32470350+bartzbeielstein@users.noreply.github.com> Date: Mon, 8 Jun 2026 20:41:20 +0200 Subject: [PATCH 2/2] refactor(configurator): ConfigEntsoe inherits ConfigMulti MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ConfigEntsoe was an independent dataclass duplicating ConfigMulti's entire field set, so every new ConfigMulti feature flag had to be hand-mirrored or the ENTSO-E pipeline raised TypeError (the parity gap, patched reactively in 19.4.0). Make ConfigEntsoe a subclass: it now declares only its two genuine differences — the index_name default override ("Time (UTC)") and the ENTSO-E-only retrain_max_age field — and inherits all 73 shared fields plus get_params / set_params / __post_init__(validate_config). The parity gap is closed structurally: any flag added to ConfigMulti appears on ConfigEntsoe automatically. Verified byte-identical: ConfigEntsoe() defaults are unchanged vs the previous standalone class (only sanctioned diffs: index_name default + retrain_max_age). The hand-maintained parity test is replaced by a structural invariant (set(ConfigMulti._PARAM_NAMES) <= set(ConfigEntsoe._PARAM_NAMES); the only extra is retrain_max_age) plus default-equality and isinstance assertions. Enabled by the n2o1 task removal, which dropped the isinstance(config, ConfigMulti) guard that subclassing would otherwise have subverted. Co-Authored-By: Claude Opus 4.8 (1M context) --- ...onfigurator.config_entsoe.ConfigEntsoe.qmd | 221 ++------- .../configurator/config_entsoe.py | 453 ++---------------- tests/test_config_entsoe_feature_parity.py | 72 +-- 3 files changed, 132 insertions(+), 614 deletions(-) diff --git a/docs/reference/configurator.config_entsoe.ConfigEntsoe.qmd b/docs/reference/configurator.config_entsoe.ConfigEntsoe.qmd index c161ebcea..eb921989b 100644 --- a/docs/reference/configurator.config_entsoe.ConfigEntsoe.qmd +++ b/docs/reference/configurator.config_entsoe.ConfigEntsoe.qmd @@ -66,7 +66,6 @@ configurator.config_entsoe.ConfigEntsoe( exog_max_gap_hours=0, exog_max_tail_gap_hours=0, exog_provider_window='full', - retrain_max_age=(lambda: pd.Timedelta(days=7))(), target_qc_range_mw=None, target_qc_step_mw=None, target_qc_window_days=None, @@ -76,203 +75,59 @@ configurator.config_entsoe.ConfigEntsoe( target_qc_deviation_mw=None, target_qc_deviation_ref=None, target_qc_deviation_slots=2, + retrain_max_age=(lambda: pd.Timedelta(days=7))(), ) ``` Configuration for the ENTSO-E forecasting pipeline. -Single-target counterpart to ``ConfigMulti``. Used by the ENTSO-E CLI -(``spotforecast2.tasks.task_entsoe``) and any other single-target pipeline -routed through ``spotforecast2.multitask.runner.run(config_cls=ConfigEntsoe)``. +Single-target counterpart to `ConfigMulti`, used by the ENTSO-E CLI +(``spotforecast2.tasks.task_entsoe``) and any single-target pipeline routed +through ``spotforecast2.multitask.runner.run(config_cls=ConfigEntsoe)``. -``country_code`` is the canonical ISO 3166-1 alpha-2 country-code -attribute used by both API queries and the multitask ``PipelineConfig`` -protocol. +``ConfigEntsoe`` **inherits every field and method of `ConfigMulti`** — so any +feature flag added to ``ConfigMulti`` is available here automatically (this +is what closes the historical feature-flag parity gap structurally, rather +than via a hand-maintained mirror). It differs from ``ConfigMulti`` in +exactly two ways: -## Parameters {.doc-section .doc-section-parameters} +- ``index_name`` defaults to ``"Time (UTC)"`` (the ENTSO-E CSV time column) + instead of ``"DateTime"``. +- it adds ``retrain_max_age`` — the maximum age of a previously trained model + before retraining is required (consumed by + `spotforecast2_safe.manager.trainer.should_retrain`). -| Name | Type | Description | Default | -|------------------------------------|------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------| -| country_code | [str](`str`) | ISO 3166-1 alpha-2 country code (e.g. ``"DE"``). | `'DE'` | -| periods | [Optional](`typing.Optional`)\[[List](`typing.List`)\[[Period](`spotforecast2_safe.data.Period`)\]\] | Cyclical feature encodings. | `default_periods()` | -| lags_consider | [Optional](`typing.Optional`)\[[List](`typing.List`)\[[int](`int`)\]\] | Lag values for autoregressive features. | `(lambda: list(range(1, 24)))()` | -| train_size | [Optional](`typing.Optional`)\[[pd](`pandas`).[Timedelta](`pandas.Timedelta`)\] | Training window. | `(lambda: pd.Timedelta(days=(3 * 365)))()` | -| end_train_default | [str](`str`) | Default end-of-training timestamp (ISO). | `'2025-12-31 00:00+00:00'` | -| delta_val | [Optional](`typing.Optional`)\[[pd](`pandas`).[Timedelta](`pandas.Timedelta`)\] | Validation window. | `(lambda: pd.Timedelta(hours=(24 * 7 * 10)))()` | -| predict_size | [int](`int`) | Prediction horizon in hours. | `24` | -| cv_block_size | [int](`int`) \| None | Cross-validation test-block width in hours. Defaults to ``None``, meaning the CV uses ``predict_size``. Set to a fixed value (e.g. ``24``) to decouple the cross-validation horizon from a render-dependent live ``predict_size``. | `None` | -| refit_size | [int](`int`) | Refit cadence in days. | `7` | -| random_state | [int](`int`) | Random seed. | `314159` | -| n_hyperparameters_trials | [int](`int`) | Hyperparameter-tuning trial budget. | `20` | -| data_filename | [str](`str`) | Path to the merged interim CSV. | `'interim/energy_load.csv'` | -| targets | [Optional](`typing.Optional`)\[[List](`typing.List`)\[[str](`str`)\]\] | Active target column names. ``None`` until set after data loading. For ENTSO-E this is typically ``["Actual Load"]``. | `None` | -| use_outlier_detection | [bool](`bool`) | Apply IsolationForest-based outlier removal. Defaults to ``True``. | `True` | -| contamination | [float](`float`) | IsolationForest contamination fraction. | `0.01` | -| imputation_method | [str](`str`) | Gap-filling strategy. | `'weighted'` | -| window_size | [int](`int`) | Rolling window for weighted imputation. Also the LightGBM rolling-mean feature window in the ENTSO-E factories. | `72` | -| imputation_window_size | [Optional](`typing.Optional`)\[[int](`int`)\] | Width of the gap-penalty zone (in hours) around each imputed value for the ``"weighted"`` strategy. When ``None`` (default), falls back to ``window_size``, so existing behaviour is unchanged. Set this to decouple the imputation penalty zone from the rolling-feature window. | `None` | -| use_exogenous_features | [bool](`bool`) | Build weather/calendar/holiday features. | `True` | -| latitude | [float](`float`) | Location latitude. | `51.5136` | -| longitude | [float](`float`) | Location longitude. | `7.4653` | -| timezone | [str](`str`) | IANA timezone string. | `'UTC'` | -| state | [str](`str`) | Subdivision code for regional holidays. | `'NW'` | -| include_weather_windows | [bool](`bool`) | Weather-window feature toggle. | `False` | -| include_holiday_features | [bool](`bool`) | Holiday feature toggle. | `False` | -| include_holiday_adjacency_features | [bool](`bool`) | Brückentag and before/after-holiday indicator toggle. Defaults to ``False``. | `False` | -| poly_features_degree | [int](`int`) | Polynomial-interaction degree passed to the feature builder. ``1`` (default) generates no interactions; ``2`` adds pairwise bilinear terms; ``3+`` higher order. | `1` | -| max_poly_features | [int](`int`) | Cap on polynomial interaction columns. When more than this many ``poly_*`` columns are generated, only the top ``max_poly_features`` ranked by mutual information with the target are kept (``<= 0`` disables the cap). Defaults to ``10``. | `10` | -| poly_mi_n_jobs | [Optional](`typing.Optional`)\[[int](`int`)\] | Parallel jobs for the mutual-information ranking that enforces ``max_poly_features``. ``-1`` (default) uses all cores; ``None`` runs single-threaded. Parallelism does not change the selection. | `-1` | -| poly_mi_sample_size | [Optional](`typing.Optional`)\[[int](`int`)\] | Row cap for that ranking; longer series are scored on a reproducible random subsample of this size (seeded by ``random_state``), which can change which borderline columns make the top K. ``None`` scores every row (the pre-15.8 behaviour). Defaults to ``4000``. | `4000` | -| include_covid_infection_rate | [bool](`bool`) | Append the bundled German national COVID-19 7-day incidence (RKI) as an exogenous level regressor. Defaults to ``False``. | `False` | -| include_entsoe_forecast_load | [bool](`bool`) | Append the ENTSO-E day-ahead Forecasted Load as a near-oracle exogenous prior. Defaults to ``False``. | `False` | -| include_entsoe_renewable_forecast | [bool](`bool`) | Append the ENTSO-E day-ahead wind and solar generation forecast. Defaults to ``False``. | `False` | -| include_entsoe_net_load | [bool](`bool`) | Append the ENTSO-E day-ahead net load (Forecasted Load minus wind/solar forecast). Defaults to ``False``. | `False` | -| include_entsoe_day_ahead_price | [bool](`bool`) | Append the ENTSO-E day-ahead spot price (DE/LU). Defaults to ``False``. | `False` | -| index_name | [str](`str`) | Datetime column name when the DataFrame index is reset. ENTSO-E CSVs use ``"Time (UTC)"``; defaults to that. | `'Time (UTC)'` | -| bounds | [Optional](`typing.Optional`)\[[List](`typing.List`)\[[tuple](`tuple`)\]\] | Per-column outlier bounds. For single-target ENTSO-E this is typically ``None`` or a single ``[(lower, upper)]`` entry. | `None` | -| verbose | [bool](`bool`) | Verbose pipeline output. | `False` | -| cache_home | [Optional](`typing.Optional`)\[[Any](`typing.Any`)\] | Cache directory override. | `None` | -| n_trials_optuna | [int](`int`) | Optuna Bayesian-search trial budget. | `15` | -| n_trials_spotoptim | [int](`int`) | SpotOptim surrogate-search trial budget. | `10` | -| n_initial_spotoptim | [int](`int`) | SpotOptim initial random evaluations. | `5` | -| n_jobs_spotoptim | [Optional](`typing.Optional`)\[[int](`int`)\] | Worker count for SpotOptim's parallel (steady-state) evaluation. ``None`` (default) runs sequentially; ``-1`` uses all CPU cores; a positive integer pins the worker count. Parallel tuning is faster but, being steady-state, changes the search trajectory, so the tuned result is not bit-identical to a sequential run even with a fixed ``random_state``. | `None` | -| warm_start_lags | [bool](`bool`) | Seed the SpotOptim search with ``lags_consider``. | `False` | -| task | [str](`str`) | Active prediction task name. | `'lazy'` | -| agg_weights | [Optional](`typing.Optional`)\[[List](`typing.List`)\[[float](`float`)\]\] | Per-target aggregation weights. For single-target use this is typically ``[1.0]`` or ``None``. | `None` | -| forecaster_factory | [Optional](`typing.Optional`)\[[Any](`typing.Any`)\] | Callable ``factory(config, *, weight_func, target) -> forecaster`` consumed by ``BaseTask.create_forecaster``. ``None`` falls back to the default LightGBM factory. | `None` | -| data_loader | [Optional](`typing.Optional`)\[[Any](`typing.Any`)\] | Callable ``data_loader(config)`` returning a pandas DataFrame. Invoked by ``BaseTask.prepare_data`` when no DataFrame is supplied — the ENTSO-E pipeline hook for ``download_new_data`` / ``merge_build_manual``. | `None` | -| test_data_loader | [Optional](`typing.Optional`)\[[Any](`typing.Any`)\] | Callable ``test_data_loader(config)`` returning a pandas DataFrame with ground-truth values for the prediction horizon. Invoked by ``BaseTask.prepare_data`` when no test DataFrame is supplied; the returned frame populates ``test_actual`` and ``metrics_future`` in the prediction package. | `None` | -| auto_save_models | [bool](`bool`) | Whether ``BaseTask._run_strategy`` should persist fitted forecasters to ``/models/`` after every training run. Defaults to ``True``. | `True` | -| data_frame_name | [str](`str`) | Identifier for the active dataset. Used by ``BaseTask`` to name cache subdirectories, model files, and the per-dataset log file. Defaults to ``"default"``. | `'default'` | -| number_folds | [int](`int`) | Number of folds used by ``BaseTask.cv_ts`` when building the ``TimeSeriesSplit`` cross-validation splitter for tuning tasks. Defaults to ``10``. | `10` | -| on_weather_failure | [Literal](`typing.Literal`)\[\'raise\', \'skip\'\] | Policy for handling Open-Meteo fetch failures inside ``BaseTask.build_exogenous_features``. ``"raise"`` (default) aborts the pipeline with a ``WeatherFetchError`` and preserves the safety-critical fail-safe semantics. ``"skip"`` logs a warning and continues with empty weather features so the rest of the pipeline can run without the Open-Meteo dependency. | `'raise'` | -| on_exog_provider_failure | [Literal](`typing.Literal`)\[\'raise\', \'skip\'\] | Policy for an exogenous-provider failure inside ``ExogBuilder.build``. ``"raise"`` (default) propagates the ``ExogProviderError`` (fail-safe); ``"skip"`` logs a warning and omits that provider's columns. | `'raise'` | -| exog_max_gap_hours | [int](`int`) | Maximum length, in hours, of a contiguous run of missing exogenous-provider values healed before the provider is rejected. Interior gaps are time-interpolated; leading/trailing edge gaps are back-/forward-filled. ``0`` (default) keeps the strict fail-safe (any gap raises). Healed runs are logged with count and span. Only already-published day-ahead vintages are involved, so healing is leakage-clean (CR-3). | `0` | -| exog_max_tail_gap_hours | [int](`int`) | Extended healing budget, in hours, applied exclusively to the trailing-edge NaN run (the run containing the last index timestamp). The effective tail budget is ``max(exog_max_gap_hours, exog_max_tail_gap_hours)``. The canonical use case is the ENTSO-E day-ahead publication frontier: the last published vintage is zero-order-held forward to the forecast horizon without touching interior gaps (CR-3-clean). When ``exog_max_tail_gap_hours <= exog_max_gap_hours`` the parameter is inert (the interior budget already covers the tail) and a warning is logged. Defaults to ``0``. | `0` | -| exog_provider_window | [Literal](`typing.Literal`)\[\'full\', \'train\'\] | Span the exogenous providers are validated against. ``"full"`` (default) requires coverage of the entire ``data_start``→``cov_end`` request, matching prior behaviour. ``"train"`` validates only the consumed window ``[start_train_ts, cov_end]``, tolerating missing values before the training window. Honoured by the MultiTask pipeline; the forecaster-wrapper path currently always validates the full span. | `'full'` | -| retrain_max_age | [pd](`pandas`).[Timedelta](`pandas.Timedelta`) | Maximum age of a previously trained model before retraining is required. Consumed by ``spotforecast2_safe.manager.trainer.should_retrain`` to gate scheduled retraining workflows. Defaults to ``Timedelta(days=7)``. | `(lambda: pd.Timedelta(days=7))()` | +See `ConfigMulti` for the full field reference (training/validation windows, +feature toggles, exogenous-provider flags, target-corruption knobs, …). -## Attributes {.doc-section .doc-section-attributes} +## Parameters {.doc-section .doc-section-parameters} -| Name | Type | Description | -|--------------------|----------------------------------------------------|--------------------------------------------------------------------------------------------| -| country_code | [str](`str`) | ISO country code used for API queries and holiday feature generation. | -| auto_save_models | [bool](`bool`) | Whether to auto-persist fitted forecasters after each training run. | -| data_frame_name | [str](`str`) | Active-dataset identifier used for cache and log-file naming. | -| number_folds | [int](`int`) | Cross-validation fold count for tuning tasks. | -| on_weather_failure | [Literal](`typing.Literal`)\[\'raise\', \'skip\'\] | Open-Meteo fetch-failure policy: ``"raise"`` aborts, ``"skip"`` continues without weather. | +| Name | Type | Description | Default | +|-----------------|------------------------------------------------|--------------------------------------------------------------------------------------------------|------------------------------------| +| index_name | [str](`str`) | Datetime column name used when resetting the index. Defaults to ``"Time (UTC)"``. | `'Time (UTC)'` | +| retrain_max_age | [pd](`pandas`).[Timedelta](`pandas.Timedelta`) | Maximum age of a trained model before a retrain is forced. Defaults to ``pd.Timedelta(days=7)``. | `(lambda: pd.Timedelta(days=7))()` | ## Examples {.doc-section .doc-section-examples} ```{python} import pandas as pd - from spotforecast2_safe.configurator.config_entsoe import ConfigEntsoe - -# Use default configuration -config = ConfigEntsoe() -print(config.country_code) -print(config.predict_size) -print(config.random_state) - -# Create custom configuration -custom_config = ConfigEntsoe( - country_code="FR", - predict_size=48, - cv_block_size=24, - random_state=42, +from spotforecast2_safe.configurator.config_multi import ConfigMulti + +config = ConfigEntsoe(country_code="DE") +# ENTSO-E-specific defaults: +print("index_name:", config.index_name) +print("retrain_max_age:", config.retrain_max_age) +assert config.index_name == "Time (UTC)" +assert config.retrain_max_age == pd.Timedelta(days=7) + +# Inherits the full ConfigMulti surface, incl. the opt-in feature flags: +assert isinstance(config, ConfigMulti) +config = ConfigEntsoe( + include_ephemeris_features=True, + include_day_type_features=True, + include_degree_hours=True, ) -print(custom_config.country_code) -print(custom_config.predict_size) -print(custom_config.cv_block_size) - -# Verify training window -assert config.train_size == pd.Timedelta(days=3 * 365) - -# Check default periods -print(len(config.periods)) -print(config.periods[0].name) -``` - -## Methods - -| Name | Description | -| --- | --- | -| [get_params](#spotforecast2_safe.configurator.config_entsoe.ConfigEntsoe.get_params) | Get parameters for this configuration object. | -| [set_params](#spotforecast2_safe.configurator.config_entsoe.ConfigEntsoe.set_params) | Set the parameters of this configuration object. | - -### get_params { #spotforecast2_safe.configurator.config_entsoe.ConfigEntsoe.get_params } - -```python -configurator.config_entsoe.ConfigEntsoe.get_params(deep=True) -``` - -Get parameters for this configuration object. - -#### Parameters {.doc-section .doc-section-parameters} - -| Name | Type | Description | Default | -|--------|----------------|-----------------------------------------------------------------------------------------------------------|-----------| -| deep | [bool](`bool`) | If True, will return the parameters for this configuration and contained sub-objects that are estimators. | `True` | - -#### Returns {.doc-section .doc-section-returns} - -| Name | Type | Description | -|--------|-----------------------------------------------------------|-------------------------------------------------------| -| params | [Dict](`typing.Dict`)\[[str](`str`), [object](`object`)\] | Dictionary of parameter names mapped to their values. | - -#### Examples {.doc-section .doc-section-examples} - -```{python} -from spotforecast2_safe.configurator.config_entsoe import ConfigEntsoe - -config = ConfigEntsoe(country_code="FR") -p = config.get_params() -print(p["country_code"]) -print(p["predict_size"]) -assert p["country_code"] == "FR" -assert p["predict_size"] == 24 -assert p["cv_block_size"] is None -``` - -### set_params { #spotforecast2_safe.configurator.config_entsoe.ConfigEntsoe.set_params } - -```python -configurator.config_entsoe.ConfigEntsoe.set_params(params=None, **kwargs) -``` - -Set the parameters of this configuration object. - -#### Parameters {.doc-section .doc-section-parameters} - -| Name | Type | Description | Default | -|----------|-----------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------|-----------| -| params | [Dict](`typing.Dict`)\[[str](`str`), [object](`object`)\] | Optional dictionary of parameter names mapped to their new values. | `None` | -| **kwargs | [object](`object`) | Additional parameter names mapped to their new values. It supports configuring nested 'Period' objects using the `periods____` notation. | `{}` | - -#### Returns {.doc-section .doc-section-returns} - -| Name | Type | Description | -|--------------|------------------------------------------------------------------------------|--------------------------------------------------------------------------------| -| ConfigEntsoe | [ConfigEntsoe](`spotforecast2_safe.configurator.config_entsoe.ConfigEntsoe`) | The configuration instance with updated parameters (supports method chaining). | - -#### Examples {.doc-section .doc-section-examples} - -```{python} -from spotforecast2_safe.configurator.config_entsoe import ConfigEntsoe - -config = ConfigEntsoe() - -# Flat parameter setting -config.set_params(country_code="FR", predict_size=48) -print(config.country_code) -print(config.predict_size) -assert config.country_code == "FR" -assert config.predict_size == 48 - -# Deep parameter setting for nested Period objects -config.set_params(periods__daily__n_periods=24) -daily_n = next(p.n_periods for p in config.periods if p.name == "daily") -print(daily_n) -assert daily_n == 24 +print("ephemeris:", config.include_ephemeris_features) +print("predict_size:", config.predict_size) ``` \ No newline at end of file diff --git a/src/spotforecast2_safe/configurator/config_entsoe.py b/src/spotforecast2_safe/configurator/config_entsoe.py index fa45d8958..29182baf1 100644 --- a/src/spotforecast2_safe/configurator/config_entsoe.py +++ b/src/spotforecast2_safe/configurator/config_entsoe.py @@ -1,434 +1,79 @@ # SPDX-FileCopyrightText: 2026 bartzbeielstein # SPDX-License-Identifier: AGPL-3.0-or-later -"""Configuration for ENTSO-E task pipeline.""" +"""Configuration for the ENTSO-E (single-target) task pipeline.""" from dataclasses import dataclass, field, fields -from typing import Any, Dict, List, Literal, Optional import pandas as pd -from spotforecast2_safe.configurator._base_config import ( - apply_set_params, - build_get_params, - default_periods, - validate_config, -) -from spotforecast2_safe.data import Period +from spotforecast2_safe.configurator.config_multi import ConfigMulti @dataclass -class ConfigEntsoe: +class ConfigEntsoe(ConfigMulti): """Configuration for the ENTSO-E forecasting pipeline. - Single-target counterpart to ``ConfigMulti``. Used by the ENTSO-E CLI - (``spotforecast2.tasks.task_entsoe``) and any other single-target pipeline - routed through ``spotforecast2.multitask.runner.run(config_cls=ConfigEntsoe)``. + Single-target counterpart to `ConfigMulti`, used by the ENTSO-E CLI + (``spotforecast2.tasks.task_entsoe``) and any single-target pipeline routed + through ``spotforecast2.multitask.runner.run(config_cls=ConfigEntsoe)``. - ``country_code`` is the canonical ISO 3166-1 alpha-2 country-code - attribute used by both API queries and the multitask ``PipelineConfig`` - protocol. + ``ConfigEntsoe`` **inherits every field and method of `ConfigMulti`** — so any + feature flag added to ``ConfigMulti`` is available here automatically (this + is what closes the historical feature-flag parity gap structurally, rather + than via a hand-maintained mirror). It differs from ``ConfigMulti`` in + exactly two ways: - Args: - country_code (str): ISO 3166-1 alpha-2 country code (e.g. ``"DE"``). - periods (Optional[List[Period]]): Cyclical feature encodings. - lags_consider (Optional[List[int]]): Lag values for autoregressive features. - train_size (Optional[pd.Timedelta]): Training window. - end_train_default (str): Default end-of-training timestamp (ISO). - delta_val (Optional[pd.Timedelta]): Validation window. - predict_size (int): Prediction horizon in hours. - cv_block_size (int | None): Cross-validation test-block width in - hours. Defaults to ``None``, meaning the CV uses - ``predict_size``. Set to a fixed value (e.g. ``24``) to - decouple the cross-validation horizon from a render-dependent - live ``predict_size``. - refit_size (int): Refit cadence in days. - random_state (int): Random seed. - n_hyperparameters_trials (int): Hyperparameter-tuning trial budget. - data_filename (str): Path to the merged interim CSV. - targets (Optional[List[str]]): Active target column names. ``None`` - until set after data loading. For ENTSO-E this is typically - ``["Actual Load"]``. - use_outlier_detection (bool): Apply IsolationForest-based outlier - removal. Defaults to ``True``. - contamination (float): IsolationForest contamination fraction. - imputation_method (str): Gap-filling strategy. - window_size (int): Rolling window for weighted imputation. Also the - LightGBM rolling-mean feature window in the ENTSO-E factories. - imputation_window_size (Optional[int]): Width of the gap-penalty zone - (in hours) around each imputed value for the ``"weighted"`` - strategy. When ``None`` (default), falls back to ``window_size``, - so existing behaviour is unchanged. Set this to decouple the - imputation penalty zone from the rolling-feature window. - use_exogenous_features (bool): Build weather/calendar/holiday features. - latitude (float): Location latitude. - longitude (float): Location longitude. - timezone (str): IANA timezone string. - state (str): Subdivision code for regional holidays. - include_weather_windows (bool): Weather-window feature toggle. - include_holiday_features (bool): Holiday feature toggle. - include_holiday_adjacency_features (bool): Brückentag and - before/after-holiday indicator toggle. Defaults to ``False``. - poly_features_degree (int): Polynomial-interaction degree passed to - the feature builder. ``1`` (default) generates no interactions; - ``2`` adds pairwise bilinear terms; ``3+`` higher order. - max_poly_features (int): Cap on polynomial interaction columns. When - more than this many ``poly_*`` columns are generated, only the - top ``max_poly_features`` ranked by mutual information with the - target are kept (``<= 0`` disables the cap). Defaults to ``10``. - poly_mi_n_jobs (Optional[int]): Parallel jobs for the - mutual-information ranking that enforces ``max_poly_features``. - ``-1`` (default) uses all cores; ``None`` runs single-threaded. - Parallelism does not change the selection. - poly_mi_sample_size (Optional[int]): Row cap for that ranking; longer - series are scored on a reproducible random subsample of this size - (seeded by ``random_state``), which can change which borderline - columns make the top K. ``None`` scores every row (the pre-15.8 - behaviour). Defaults to ``4000``. - include_covid_infection_rate (bool): Append the bundled German national - COVID-19 7-day incidence (RKI) as an exogenous level regressor. - Defaults to ``False``. - include_entsoe_forecast_load (bool): Append the ENTSO-E day-ahead - Forecasted Load as a near-oracle exogenous prior. Defaults to - ``False``. - include_entsoe_renewable_forecast (bool): Append the ENTSO-E day-ahead - wind and solar generation forecast. Defaults to ``False``. - include_entsoe_net_load (bool): Append the ENTSO-E day-ahead net load - (Forecasted Load minus wind/solar forecast). Defaults to ``False``. - include_entsoe_day_ahead_price (bool): Append the ENTSO-E day-ahead - spot price (DE/LU). Defaults to ``False``. - index_name (str): Datetime column name when the DataFrame index is - reset. ENTSO-E CSVs use ``"Time (UTC)"``; defaults to that. - bounds (Optional[List[tuple]]): Per-column outlier bounds. For - single-target ENTSO-E this is typically ``None`` or a single - ``[(lower, upper)]`` entry. - verbose (bool): Verbose pipeline output. - cache_home (Optional[Any]): Cache directory override. - n_trials_optuna (int): Optuna Bayesian-search trial budget. - n_trials_spotoptim (int): SpotOptim surrogate-search trial budget. - n_initial_spotoptim (int): SpotOptim initial random evaluations. - n_jobs_spotoptim (Optional[int]): Worker count for SpotOptim's parallel - (steady-state) evaluation. ``None`` (default) runs sequentially; - ``-1`` uses all CPU cores; a positive integer pins the worker count. - Parallel tuning is faster but, being steady-state, changes the search - trajectory, so the tuned result is not bit-identical to a sequential - run even with a fixed ``random_state``. - warm_start_lags (bool): Seed the SpotOptim search with ``lags_consider``. - task (str): Active prediction task name. - agg_weights (Optional[List[float]]): Per-target aggregation weights. - For single-target use this is typically ``[1.0]`` or ``None``. - forecaster_factory (Optional[Any]): Callable - ``factory(config, *, weight_func, target) -> forecaster`` - consumed by ``BaseTask.create_forecaster``. ``None`` falls back - to the default LightGBM factory. - data_loader (Optional[Any]): Callable ``data_loader(config)`` returning - a pandas DataFrame. Invoked by ``BaseTask.prepare_data`` when no - DataFrame is supplied — the ENTSO-E pipeline hook for - ``download_new_data`` / ``merge_build_manual``. - test_data_loader (Optional[Any]): Callable ``test_data_loader(config)`` - returning a pandas DataFrame with ground-truth values for the - prediction horizon. Invoked by ``BaseTask.prepare_data`` when no - test DataFrame is supplied; the returned frame populates - ``test_actual`` and ``metrics_future`` in the prediction package. - auto_save_models (bool): Whether ``BaseTask._run_strategy`` should - persist fitted forecasters to ``/models/`` after every - training run. Defaults to ``True``. - data_frame_name (str): Identifier for the active dataset. Used by - ``BaseTask`` to name cache subdirectories, model files, and the - per-dataset log file. Defaults to ``"default"``. - number_folds (int): Number of folds used by ``BaseTask.cv_ts`` when - building the ``TimeSeriesSplit`` cross-validation splitter for - tuning tasks. Defaults to ``10``. - on_weather_failure (Literal["raise", "skip"]): Policy for handling - Open-Meteo fetch failures inside - ``BaseTask.build_exogenous_features``. ``"raise"`` (default) - aborts the pipeline with a ``WeatherFetchError`` and preserves - the safety-critical fail-safe semantics. ``"skip"`` logs a - warning and continues with empty weather features so the rest - of the pipeline can run without the Open-Meteo dependency. - on_exog_provider_failure (Literal["raise", "skip"]): Policy for an - exogenous-provider failure inside ``ExogBuilder.build``. ``"raise"`` - (default) propagates the ``ExogProviderError`` (fail-safe); - ``"skip"`` logs a warning and omits that provider's columns. - exog_max_gap_hours (int): Maximum length, in hours, of a contiguous run - of missing exogenous-provider values healed before the provider is - rejected. Interior gaps are time-interpolated; leading/trailing edge - gaps are back-/forward-filled. ``0`` (default) keeps the strict - fail-safe (any gap raises). Healed runs are logged with count and - span. Only already-published day-ahead vintages are involved, so - healing is leakage-clean (CR-3). - exog_max_tail_gap_hours (int): Extended healing budget, in hours, - applied exclusively to the trailing-edge NaN run (the run - containing the last index timestamp). The effective tail budget is - ``max(exog_max_gap_hours, exog_max_tail_gap_hours)``. The canonical - use case is the ENTSO-E day-ahead publication frontier: the last - published vintage is zero-order-held forward to the forecast horizon - without touching interior gaps (CR-3-clean). When - ``exog_max_tail_gap_hours <= exog_max_gap_hours`` the parameter is - inert (the interior budget already covers the tail) and a warning is - logged. Defaults to ``0``. - exog_provider_window (Literal["full", "train"]): Span the exogenous - providers are validated against. ``"full"`` (default) requires - coverage of the entire ``data_start``→``cov_end`` request, matching - prior behaviour. ``"train"`` validates only the consumed window - ``[start_train_ts, cov_end]``, tolerating missing values before the - training window. Honoured by the MultiTask pipeline; the - forecaster-wrapper path currently always validates the full span. - retrain_max_age (pd.Timedelta): Maximum age of a previously trained - model before retraining is required. Consumed by - ``spotforecast2_safe.manager.trainer.should_retrain`` to gate - scheduled retraining workflows. Defaults to ``Timedelta(days=7)``. + - ``index_name`` defaults to ``"Time (UTC)"`` (the ENTSO-E CSV time column) + instead of ``"DateTime"``. + - it adds ``retrain_max_age`` — the maximum age of a previously trained model + before retraining is required (consumed by + `spotforecast2_safe.manager.trainer.should_retrain`). + + See `ConfigMulti` for the full field reference (training/validation windows, + feature toggles, exogenous-provider flags, target-corruption knobs, …). - Attributes: - country_code (str): ISO country code used for API queries and - holiday feature generation. - auto_save_models (bool): Whether to auto-persist fitted forecasters - after each training run. - data_frame_name (str): Active-dataset identifier used for cache and - log-file naming. - number_folds (int): Cross-validation fold count for tuning tasks. - on_weather_failure (Literal["raise", "skip"]): Open-Meteo fetch-failure - policy: ``"raise"`` aborts, ``"skip"`` continues without weather. + Args: + index_name (str): Datetime column name used when resetting the index. + Defaults to ``"Time (UTC)"``. + retrain_max_age (pd.Timedelta): Maximum age of a trained model before a + retrain is forced. Defaults to ``pd.Timedelta(days=7)``. Examples: ```{python} import pandas as pd - from spotforecast2_safe.configurator.config_entsoe import ConfigEntsoe - - # Use default configuration - config = ConfigEntsoe() - print(config.country_code) - print(config.predict_size) - print(config.random_state) - - # Create custom configuration - custom_config = ConfigEntsoe( - country_code="FR", - predict_size=48, - cv_block_size=24, - random_state=42, + from spotforecast2_safe.configurator.config_multi import ConfigMulti + + config = ConfigEntsoe(country_code="DE") + # ENTSO-E-specific defaults: + print("index_name:", config.index_name) + print("retrain_max_age:", config.retrain_max_age) + assert config.index_name == "Time (UTC)" + assert config.retrain_max_age == pd.Timedelta(days=7) + + # Inherits the full ConfigMulti surface, incl. the opt-in feature flags: + assert isinstance(config, ConfigMulti) + config = ConfigEntsoe( + include_ephemeris_features=True, + include_day_type_features=True, + include_degree_hours=True, ) - print(custom_config.country_code) - print(custom_config.predict_size) - print(custom_config.cv_block_size) - - # Verify training window - assert config.train_size == pd.Timedelta(days=3 * 365) - - # Check default periods - print(len(config.periods)) - print(config.periods[0].name) + print("ephemeris:", config.include_ephemeris_features) + print("predict_size:", config.predict_size) ``` """ - country_code: str = "DE" - periods: List[Period] = field(default_factory=default_periods) - lags_consider: List[int] = field(default_factory=lambda: list(range(1, 24))) - train_size: pd.Timedelta = field(default_factory=lambda: pd.Timedelta(days=3 * 365)) - end_train_default: str = "2025-12-31 00:00+00:00" - delta_val: pd.Timedelta = field( - default_factory=lambda: pd.Timedelta(hours=24 * 7 * 10) - ) - predict_size: int = 24 - # Cross-validation test-block width (hours). ``None`` defers to - # ``predict_size``; the actual CV-split logic lives in the sibling - # ``spotforecast2`` package (``BaseTask.cv_ts``). - cv_block_size: Optional[int] = None - refit_size: int = 7 - random_state: int = 314159 - n_hyperparameters_trials: int = 20 - data_filename: str = "interim/energy_load.csv" - targets: Optional[List[str]] = None - # Outlier detection - use_outlier_detection: bool = True - contamination: float = 0.01 - # Imputation - imputation_method: str = "weighted" - window_size: int = 72 - imputation_window_size: Optional[int] = None - # Exogenous features - use_exogenous_features: bool = True - latitude: float = 51.5136 - longitude: float = 7.4653 - timezone: str = "UTC" - state: str = "NW" - # Feature selection toggles - include_weather_windows: bool = False - include_holiday_features: bool = False - include_holiday_adjacency_features: bool = False - # Global / derived weather and calendar refinements (parity with ConfigMulti; - # consumed by spotforecast2.multitask.base.build_exogenous_features). All - # default off → byte-identical to the single-point baseline. - # ``use_population_weighted_weather`` samples the fixed German load-centre - # registry and combines cities by population weight; - # ``include_degree_hours`` adds heating/cooling degree-hours; - # ``include_apparent_temperature`` adds apparent temperature + dew point; - # ``include_ephemeris_features`` adds continuous solar geometry - # (solar_elevation, daylight_duration_h, signed sunrise/sunset-relative time); - # ``include_day_type_features`` adds is_workday + a day_type class. - use_population_weighted_weather: bool = False - include_degree_hours: bool = False - include_apparent_temperature: bool = False - degree_hours_base_heating: float = 15.0 - degree_hours_base_cooling: float = 22.0 - include_ephemeris_features: bool = False - include_day_type_features: bool = False - poly_features_degree: int = 1 - max_poly_features: int = 10 - poly_mi_n_jobs: Optional[int] = -1 - poly_mi_sample_size: Optional[int] = 4000 - # Provider-based exogenous toggles, each gated by a registry flag in - # ``spotforecast2_safe.preprocessing.exog_providers``. - include_covid_infection_rate: bool = False - include_entsoe_forecast_load: bool = False - include_entsoe_renewable_forecast: bool = False - include_entsoe_net_load: bool = False - include_entsoe_day_ahead_price: bool = False - # Data source and index (ENTSO-E CSVs use "Time (UTC)") + # Default override (re-declaring an inherited field keeps its position and + # only changes the default). index_name: str = "Time (UTC)" - # Per-column outlier bounds - bounds: Optional[List[tuple]] = None - # Verbosity and caching - verbose: bool = False - cache_home: Optional[Any] = None - # Hyperparameter tuning trial budgets - n_trials_optuna: int = 15 - n_trials_spotoptim: int = 10 - n_initial_spotoptim: int = 5 - # SpotOptim parallel-evaluation worker count (None=serial, -1=all cores); - # consumed by spotforecast2.multitask.strategies.SpotOptimStrategy - n_jobs_spotoptim: Optional[int] = None - # Seed the SpotOptim search with ``lags_consider`` (consumed by - # spotforecast2.multitask.strategies.SpotOptimStrategy) - warm_start_lags: bool = False - # Active task - task: str = "lazy" - # Aggregation weights (single-target uses [1.0] or None) - agg_weights: Optional[List[float]] = None - # Forecaster factory hook (consumed by spotforecast2.multitask.base): - # ``factory(config, *, weight_func, target) -> forecaster``. - forecaster_factory: Optional[Any] = None - # Data-loader hook (consumed by ``BaseTask.prepare_data``): - # ``data_loader(config) -> pd.DataFrame``, invoked when no DataFrame is - # supplied. - data_loader: Optional[Any] = None - # Test-data-loader hook (consumed by ``BaseTask.prepare_data``): mirrors - # ``data_loader`` for the test/ground-truth slice. - test_data_loader: Optional[Any] = None - # Persistence policy and active-dataset name (consumed by - # spotforecast2.multitask.base). - auto_save_models: bool = True - data_frame_name: str = "default" - # Cross-validation fold count (consumed by spotforecast2.multitask.base.cv_ts) - number_folds: int = 10 - # Weather-fetch failure policy (consumed by - # spotforecast2.multitask.base.build_exogenous_features) - on_weather_failure: Literal["raise", "skip"] = "raise" - # Exog-provider failure policy (consumed by - # preprocessing.exog_builder.ExogBuilder) - on_exog_provider_failure: Literal["raise", "skip"] = "raise" - # Gap-healing budget for exog providers (0 = strict fail-safe) - exog_max_gap_hours: int = 0 - # Extended trailing-edge healing budget (0 = same as exog_max_gap_hours) - exog_max_tail_gap_hours: int = 0 - # Validation window for exog providers ("full" or "train") - exog_provider_window: Literal["full", "train"] = "full" - # Maximum age of a previously trained model before retraining is required - # (consumed by spotforecast2_safe.manager.trainer.should_retrain) + # ENTSO-E-only field (appended after the inherited ConfigMulti fields). retrain_max_age: pd.Timedelta = field(default_factory=lambda: pd.Timedelta(days=7)) - # Target-side corruption detector knobs. Detector active only when - # target_qc_window_days AND at least one of target_qc_range_mw / - # target_qc_step_mw / target_qc_deviation_mw are set. Defaults are all - # None / off, so the pipeline is byte-identical to the pre-feature baseline. - # Recommended episode policy: "truncate" (auto-extends predict_size). - # "heal" under the default anchor_zone_hours=168 with a <=7-day QC window - # never engages (refusal by design). The deviation rule (dropout-only, vs a - # published reference column such as "Forecasted Load") catches corruption - # that stays below the dynamics thresholds; when enabling it, scope - # ``targets`` to the actuals so heal/truncate leave the reference intact. - target_qc_range_mw: Optional[float] = None - target_qc_step_mw: Optional[float] = None - target_qc_window_days: Optional[int] = None - target_corruption_policy: str = "abort" - target_max_heal_hours: int = 0 - target_anchor_zone_hours: int = 168 - target_qc_deviation_mw: Optional[float] = None - target_qc_deviation_ref: Optional[str] = None - target_qc_deviation_slots: int = 2 - - def __post_init__(self) -> None: - """Reject clearly-invalid hyperparameter values (fail-safe).""" - validate_config(self) - - def get_params(self, deep: bool = True) -> Dict[str, object]: - """ - Get parameters for this configuration object. - - Args: - deep: If True, will return the parameters for this configuration and - contained sub-objects that are estimators. - - Returns: - params: Dictionary of parameter names mapped to their values. - - Examples: - ```{python} - from spotforecast2_safe.configurator.config_entsoe import ConfigEntsoe - - config = ConfigEntsoe(country_code="FR") - p = config.get_params() - print(p["country_code"]) - print(p["predict_size"]) - assert p["country_code"] == "FR" - assert p["predict_size"] == 24 - assert p["cv_block_size"] is None - ``` - """ - return build_get_params(self, [f.name for f in fields(self)], deep) - - def set_params( - self, params: Dict[str, object] = None, **kwargs: object - ) -> "ConfigEntsoe": - """ - Set the parameters of this configuration object. - - Args: - params: Optional dictionary of parameter names mapped to their - new values. - **kwargs: Additional parameter names mapped to their new values. - It supports configuring nested 'Period' objects using the - `periods____` notation. - - Returns: - ConfigEntsoe: The configuration instance with updated - parameters (supports method chaining). - - Examples: - ```{python} - from spotforecast2_safe.configurator.config_entsoe import ConfigEntsoe - - config = ConfigEntsoe() - - # Flat parameter setting - config.set_params(country_code="FR", predict_size=48) - print(config.country_code) - print(config.predict_size) - assert config.country_code == "FR" - assert config.predict_size == 48 - - # Deep parameter setting for nested Period objects - config.set_params(periods__daily__n_periods=24) - daily_n = next(p.n_periods for p in config.periods if p.name == "daily") - print(daily_n) - assert daily_n == 24 - ``` - """ - return apply_set_params(self, params, **kwargs) -# ``_PARAM_NAMES`` is derived from the dataclass fields (declaration order) so it -# can never drift from the actual fields; consumers and tests still read it as a -# class attribute. Set after the class body because ``fields()`` needs the -# finished dataclass. +# ``_PARAM_NAMES`` is derived from the dataclass fields (declaration order, which +# for a subclass is the base fields followed by the subclass-only fields) so it +# can never drift from the actual fields; consumers and tests read it as a class +# attribute. Set after the class body because ``fields()`` needs the finished +# dataclass. ConfigEntsoe._PARAM_NAMES = tuple(f.name for f in fields(ConfigEntsoe)) diff --git a/tests/test_config_entsoe_feature_parity.py b/tests/test_config_entsoe_feature_parity.py index 2e30b5b19..e78b308ff 100644 --- a/tests/test_config_entsoe_feature_parity.py +++ b/tests/test_config_entsoe_feature_parity.py @@ -1,61 +1,79 @@ # SPDX-FileCopyrightText: 2026 bartzbeielstein # SPDX-License-Identifier: AGPL-3.0-or-later -"""ConfigEntsoe must carry the same opt-in feature flags as ConfigMulti. +"""ConfigEntsoe inherits ConfigMulti — the parity gap is closed structurally. -ConfigEntsoe is an independent dataclass (not a ConfigMulti subclass), so new -feature toggles added to ConfigMulti must be mirrored here or the ENTSO-E -single-target pipeline silently cannot enable them. +Since 20.0.0 ``ConfigEntsoe`` is a subclass of ``ConfigMulti``, so every feature +flag (and every other field) on ``ConfigMulti`` is present on ``ConfigEntsoe`` +automatically. These tests lock in that relationship and the two — and only two +— sanctioned differences (``index_name`` default + the ``retrain_max_age`` +field), so a future asymmetric edit fails loudly. """ +import pandas as pd + from spotforecast2_safe.configurator.config_entsoe import ConfigEntsoe from spotforecast2_safe.configurator.config_multi import ConfigMulti -NEW_FEATURE_FLAGS = ( +# Feature flags that previously had to be hand-mirrored (regression guard). +INHERITED_FEATURE_FLAGS = ( "use_population_weighted_weather", "include_degree_hours", "include_apparent_temperature", "include_ephemeris_features", "include_day_type_features", ) -NEW_FLOAT_FIELDS = ("degree_hours_base_heating", "degree_hours_base_cooling") -class TestFeatureFlagParity: - def test_flags_present_and_default_off(self): +class TestStructuralParity: + def test_entsoe_is_configmulti_subclass(self): + assert issubclass(ConfigEntsoe, ConfigMulti) + assert isinstance(ConfigEntsoe(), ConfigMulti) + + def test_entsoe_is_superset_plus_only_known_extra(self): + multi = set(ConfigMulti._PARAM_NAMES) + entsoe = set(ConfigEntsoe._PARAM_NAMES) + # No ConfigMulti field can be missing from ConfigEntsoe. + assert multi <= entsoe + # The only sanctioned ENTSO-E-specific field. + assert entsoe - multi == {"retrain_max_age"} + + def test_feature_flags_inherited(self): cfg = ConfigEntsoe() - for flag in NEW_FEATURE_FLAGS: + for flag in INHERITED_FEATURE_FLAGS: assert hasattr(cfg, flag), flag assert getattr(cfg, flag) is False, flag - def test_float_defaults_match_configmulti(self): - entsoe, multi = ConfigEntsoe(), ConfigMulti() - for fld in NEW_FLOAT_FIELDS: - assert getattr(entsoe, fld) == getattr(multi, fld), fld - - def test_flags_in_param_names(self): - for flag in NEW_FEATURE_FLAGS + NEW_FLOAT_FIELDS: - assert flag in ConfigEntsoe._PARAM_NAMES, flag - - def test_constructs_with_flags_enabled(self): + def test_constructs_with_inherited_flags(self): cfg = ConfigEntsoe( use_population_weighted_weather=True, include_degree_hours=True, include_apparent_temperature=True, include_ephemeris_features=True, include_day_type_features=True, - degree_hours_base_heating=16.0, - degree_hours_base_cooling=21.0, ) assert cfg.include_ephemeris_features is True assert cfg.include_day_type_features is True - assert cfg.degree_hours_base_heating == 16.0 - def test_configmulti_and_entsoe_agree_on_flag_set(self): - multi = set(ConfigMulti._PARAM_NAMES) - entsoe = set(ConfigEntsoe._PARAM_NAMES) - for flag in NEW_FEATURE_FLAGS + NEW_FLOAT_FIELDS: - assert flag in multi and flag in entsoe, flag + +class TestSanctionedDifferences: + def test_index_name_default_override(self): + assert ConfigMulti().index_name == "DateTime" + assert ConfigEntsoe().index_name == "Time (UTC)" + + def test_retrain_max_age(self): + assert ConfigEntsoe().retrain_max_age == pd.Timedelta(days=7) + assert not hasattr(ConfigMulti(), "retrain_max_age") + + def test_shared_field_defaults_identical(self): + """Every shared field has the same default on both classes (only + index_name is allowed to differ).""" + multi, entsoe = ConfigMulti(), ConfigEntsoe() + shared = set(ConfigMulti._PARAM_NAMES) & set(ConfigEntsoe._PARAM_NAMES) + diffs = { + k for k in shared if repr(getattr(multi, k)) != repr(getattr(entsoe, k)) + } + assert diffs == {"index_name"}, diffs if __name__ == "__main__":