Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
b66a49d
Add data downloading functionality and gdown dependency for test setup
romeokienzler Apr 29, 2026
8b99b82
Refactor test setup by removing unused config preparation function an…
romeokienzler Apr 29, 2026
9be32c4
Add test data generation functions for power-flow and optimal power-f…
romeokienzler Apr 29, 2026
a76b1cc
Add OPF integration test with separate data/log dirs and metric valid…
romeokienzler Apr 29, 2026
bba356c
fix dateset
romeokienzler Apr 30, 2026
2856323
fix test data generator
romeokienzler Apr 30, 2026
e57abc0
Update Google Drive file ID for OPF data download in test_train_opf
romeokienzler May 6, 2026
c2c0684
Add calibration option and enhance training tests with metrics collec…
romeokienzler May 7, 2026
2b45d8b
Refactor calibration stats printing to include confidence intervals a…
romeokienzler May 8, 2026
c2a5255
Update test assertions to reflect 95% confidence intervals for PBE Me…
romeokienzler May 8, 2026
68f4960
Enhance calibration functionality by adding confidence interval optio…
romeokienzler May 8, 2026
e55cac9
Update calibration stats and tests to reflect 99.5% confidence intervals
romeokienzler May 9, 2026
143f81a
Implement retry logic for metric validation in training tests
romeokienzler May 11, 2026
20e75de
Add note about `torch-scatter` as a required dependency in installati…
romeokienzler May 15, 2026
d1a795b
Update default confidence interval to 0.995 in calibration options an…
romeokienzler May 18, 2026
2021827
Rename test function to clarify focus on power flow integration and u…
romeokienzler May 18, 2026
1176d99
Update test assertions to reflect 95% confidence interval for PBE Mea…
romeokienzler May 18, 2026
df99154
Update PBE Mean value assertions to reflect new 95% confidence interval
romeokienzler May 18, 2026
15a9820
Merge branch 'main' into improve_tests
romeokienzler May 19, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ Install gridfm-graphkit in editable mode
pip install -e .
```

**`torch-scatter` is a required dependency.** It cannot be bundled in `pyproject.toml` because the correct wheel depends on your PyTorch and CUDA versions, so it must be installed separately.

Get PyTorch + CUDA version for torch-scatter
```bash
TORCH_CUDA_VERSION=$(python -c "import torch; print(torch.__version__ + ('+cpu' if torch.version.cuda is None else ''))")
Expand Down
2 changes: 2 additions & 0 deletions docs/install/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ Install gridfm-graphkit in editable mode
pip install -e .
```

**`torch-scatter` is a required dependency.** It cannot be bundled in `pyproject.toml` because the correct wheel depends on your PyTorch and CUDA versions, so it must be installed separately.

Get PyTorch + CUDA version for torch-scatter

```bash
Expand Down
31 changes: 31 additions & 0 deletions integrationtests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest


def pytest_addoption(parser):
parser.addoption(
"--calibrate",
type=int,
default=0,
help="Run training N times to collect metric mean/std for range calibration. "
"Skips metric range assertions. Example: pytest --calibrate 5",
)
parser.addoption(
"--ci",
type=float,
default=0.995,
help="Confidence interval level for calibration stats (default 0.995). "
"Example: pytest --calibrate 5 -s --ci 0.995",
)


@pytest.fixture
def calibrate_runs(request):
"""Number of calibration runs requested via --calibrate (0 = normal test mode)."""
return request.config.getoption("--calibrate")


@pytest.fixture
def ci_level(request):
"""Confidence interval level requested via --ci (default 0.995)."""
return request.config.getoption("--ci")

72 changes: 72 additions & 0 deletions integrationtests/generate_test_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import urllib.request
import yaml
import subprocess


def execute_and_live_output(cmd) -> None:
subprocess.run(cmd, text=True, shell=True, check=True)


def _base_config() -> dict:
"""
Download the default config from gridfm-datakit and apply common test parameters.
"""
config_url = (
"https://raw.githubusercontent.com/gridfm/gridfm-datakit/refs/heads/main"
"/scripts/config/default.yaml"
)

print(f"Downloading config from {config_url}...")
with urllib.request.urlopen(config_url) as response:
config_content = response.read().decode("utf-8")

config = yaml.safe_load(config_content)

config["network"]["name"] = "case14_ieee"
config["load"]["scenarios"] = 10000
config["topology_perturbation"]["n_topology_variants"] = 2

return config


def generate_pf_test_data(config_path: str = "integrationtests/default_pf.yaml") -> None:
"""
Generate power-flow (PF) test data for case14_ieee with 10 000 scenarios
and 2 topology variants.
"""
config = _base_config()

with open(config_path, "w") as f:
yaml.dump(config, f, default_flow_style=False, sort_keys=False)

print(f"PF config written to {config_path}")
print(f" network.name : {config['network']['name']}")
print(f" load.scenarios : {config['load']['scenarios']}")
print(f" topology_perturbation.n_topology_variants: {config['topology_perturbation']['n_topology_variants']}")

execute_and_live_output(f"gridfm_datakit generate {config_path}")


def generate_opf_test_data(config_path: str = "integrationtests/default_opf.yaml") -> None:
"""
Generate optimal power-flow (OPF) test data for case14_ieee with 10 000 scenarios
and 2 topology variants.
"""
config = _base_config()
config["settings"]["mode"] = "opf"

with open(config_path, "w") as f:
yaml.dump(config, f, default_flow_style=False, sort_keys=False)

print(f"OPF config written to {config_path}")
print(f" network.name : {config['network']['name']}")
print(f" load.scenarios : {config['load']['scenarios']}")
print(f" topology_perturbation.n_topology_variants: {config['topology_perturbation']['n_topology_variants']}")
print(f" settings.mode : {config['settings']['mode']}")

execute_and_live_output(f"gridfm_datakit generate {config_path}")


if __name__ == "__main__":
#generate_pf_test_data()
generate_opf_test_data()
Loading
Loading