diff --git a/README.md b/README.md index ad85f056..753fdb7e 100644 --- a/README.md +++ b/README.md @@ -60,9 +60,9 @@ gridfm_graphkit [OPTIONS] Available commands: * `train` - Train a new model from scratch -* `finetune` – Fine-tune an existing pre-trained model -* `evaluate` – Evaluate model performance on a dataset -* `predict` – Run inference and save predictions +* `finetune` - Fine-tune an existing pre-trained model +* `evaluate` - Evaluate model performance on a dataset +* `predict` - Run inference and save predictions --- @@ -74,13 +74,22 @@ gridfm_graphkit train --config path/to/config.yaml ### Arguments -| Argument | Type | Description | Default | -| ---------------- | ------ | ---------------------------------------------------------------- | ------- | -| `--config` | `str` | **Required**. Path to the training configuration YAML file. | `None` | -| `--exp_name` | `str` | MLflow experiment name. | `timestamp` | -| `--run_name` | `str` | MLflow run name. | `run` | -| `--log_dir` | `str` | MLflow tracking/logging directory. | `mlruns` | -| `--data_path` | `str` | Root dataset directory. | `data` | +| Argument | Type | Description | Default | +| -------- | ---- | ----------- | ------- | +| `--config` | `str` | **Required**. Path to the training configuration YAML file. | `None` | +| `--exp_name` | `str` | MLflow experiment name. | `timestamp` | +| `--run_name` | `str` | MLflow run name. | `run` | +| `--log_dir` | `str` | MLflow tracking/logging directory. | `mlruns` | +| `--data_path` | `str` | Root dataset directory. | `data` | +| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` | +| `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` | +| `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` | +| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` | +| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` | +| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` | +| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` | +| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` | +| `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` | ### Examples @@ -100,14 +109,23 @@ gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model ### Arguments -| Argument | Type | Description | Default | -| -------------- | ----- | ----------------------------------------------- | --------- | -| `--config` | `str` | **Required**. Fine-tuning configuration file. | `None` | -| `--model_path` | `str` | **Required**. Path to a pre-trained model state dict. | `None` | -| `--exp_name` | `str` | MLflow experiment name. | timestamp | -| `--run_name` | `str` | MLflow run name. | `run` | -| `--log_dir` | `str` | MLflow logging directory. | `mlruns` | -| `--data_path` | `str` | Root dataset directory. | `data` | +| Argument | Type | Description | Default | +| -------- | ---- | ----------- | ------- | +| `--config` | `str` | **Required**. Fine-tuning configuration file. | `None` | +| `--model_path` | `str` | **Required**. Path to a pre-trained model state dict. | `None` | +| `--exp_name` | `str` | MLflow experiment name. | `timestamp` | +| `--run_name` | `str` | MLflow run name. | `run` | +| `--log_dir` | `str` | MLflow logging directory. | `mlruns` | +| `--data_path` | `str` | Root dataset directory. | `data` | +| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` | +| `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` | +| `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` | +| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` | +| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` | +| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` | +| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` | +| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` | +| `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` | --- @@ -120,17 +138,25 @@ gridfm_graphkit evaluate --config path/to/eval.yaml --model_path path/to/model.p ### Arguments -| Argument | Type | Description | Default | -| --------------------- | ----- | ------------------------------------------------------------------------------------------------------------- | --------- | -| `--config` | `str` | **Required**. Path to evaluation config. | `None` | -| `--model_path` | `str` | Path to the trained model state dict. | `None` | -| `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics instead of re-fitting on the current data split. | `None` | -| `--exp_name` | `str` | MLflow experiment name. | timestamp | -| `--run_name` | `str` | MLflow run name. | `run` | -| `--log_dir` | `str` | MLflow logging directory. | `mlruns` | -| `--data_path` | `str` | Dataset directory. | `data` | -| `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` | -| `--save_output` | `flag` | Save predictions as `_predictions.parquet` under MLflow artifacts (`.../artifacts/test`). | `False` | +| Argument | Type | Description | Default | +| -------- | ---- | ----------- | ------- | +| `--config` | `str` | **Required**. Path to evaluation config. | `None` | +| `--model_path` | `str` | Path to the trained model state dict. | `None` | +| `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics instead of re-fitting on current split. | `None` | +| `--exp_name` | `str` | MLflow experiment name. | `timestamp` | +| `--run_name` | `str` | MLflow run name. | `run` | +| `--log_dir` | `str` | MLflow logging directory. | `mlruns` | +| `--data_path` | `str` | Dataset directory. | `data` | +| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` | +| `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` | +| `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` | +| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` | +| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` | +| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` | +| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` | +| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` | +| `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` | +| `--save_output` | `flag` | Save predictions as `_predictions.parquet` under MLflow artifacts (`.../artifacts/test`). | `False` | ### Example with saved normalizer stats @@ -156,16 +182,44 @@ gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model. ### Arguments -| Argument | Type | Description | Default | -| --------------------- | ----- | ------------------------------------------------------------------------------------------------------------- | --------- | -| `--config` | `str` | **Required**. Path to prediction config file. | `None` | -| `--model_path` | `str` | Path to the trained model state dict. | `None` | -| `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics. | `None` | -| `--exp_name` | `str` | MLflow experiment name. | timestamp | -| `--run_name` | `str` | MLflow run name. | `run` | -| `--log_dir` | `str` | MLflow logging directory. | `mlruns` | -| `--data_path` | `str` | Dataset directory. | `data` | -| `--output_path` | `str` | Directory where predictions are saved as `_predictions.parquet`. | `data` | +| Argument | Type | Description | Default | +| -------- | ---- | ----------- | ------- | +| `--config` | `str` | **Required**. Path to prediction config file. | `None` | +| `--model_path` | `str` | Path to trained model state dict. Optional; may be defined in config. | `None` | +| `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics. | `None` | +| `--exp_name` | `str` | MLflow experiment name. | `timestamp` | +| `--run_name` | `str` | MLflow run name. | `run` | +| `--log_dir` | `str` | MLflow logging directory. | `mlruns` | +| `--data_path` | `str` | Dataset directory. | `data` | +| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` | +| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` | +| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` | +| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` | +| `--output_path` | `str` | Directory where predictions are saved as `_predictions.parquet`. | `data` | +| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` | +| `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` | +| `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` | +| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` | + +--- + +## Benchmarking Dataloader Throughput + +```bash +gridfm_graphkit benchmark --config path/to/config.yaml +``` + +### Arguments + +| Argument | Type | Description | Default | +| -------- | ---- | ----------- | ------- | +| `--config` | `str` | **Required**. Path to configuration YAML file. | `None` | +| `--data_path` | `str` | Root dataset directory. | `data` | +| `--epochs` | `int` | Number of epochs to iterate through the train dataloader. | `3` | +| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` | +| `--dataset_wrapper_cache_dir` | `str` | Directory for dataset wrapper disk cache. | `None` | +| `--num_workers` | `int` | Override `data.workers` from YAML. | `None` | +| `--plugins` | `list[str]` | Python packages to import for plugin registration. | `[]` | Use built-in help for full command details: diff --git a/docs/datasets/data_modules.md b/docs/datasets/data_modules.md deleted file mode 100644 index a5e4dff4..00000000 --- a/docs/datasets/data_modules.md +++ /dev/null @@ -1,3 +0,0 @@ -# LitGridHeteroDataModule - -::: gridfm_graphkit.datasets.hetero_powergrid_datamodule.LitGridHeteroDataModule diff --git a/docs/datasets/data_normalization.md b/docs/datasets/data_normalization.md deleted file mode 100644 index 1747fd17..00000000 --- a/docs/datasets/data_normalization.md +++ /dev/null @@ -1,57 +0,0 @@ -# Data Normalization - - - -Normalization improves neural network training by ensuring features are well-scaled, preventing issues like exploding gradients and slow convergence. In power grids, where variables like voltage and power span wide ranges, normalization is essential. -The `gridfm-graphkit` package offers normalization methods based on the per-unit (p.u.) system: - -- [`BaseMVA Normalization`](#heterodatamvanormalizer) -- [`Per-Sample BaseMVA Normalization`](#heterodatapersamplemvanormalizer) - -Each of these strategies implements a unified interface and can be used interchangeably depending on the learning task and data characteristics. - -> Users can create their own custom normalizers by extending the base [`Normalizer`](#normalizer) class to suit specific needs. - - ---- - -## Available Normalizers - -### `Normalizer` - -::: gridfm_graphkit.datasets.normalizers.Normalizer - ---- - -### `HeteroDataMVANormalizer` - -::: gridfm_graphkit.datasets.normalizers.HeteroDataMVANormalizer - ---- - -### `HeteroDataPerSampleMVANormalizer` - -::: gridfm_graphkit.datasets.normalizers.HeteroDataPerSampleMVANormalizer - ---- - -## Usage Workflow - -Example: - -```python -from gridfm_graphkit.datasets.normalizers import HeteroDataMVANormalizer -from torch_geometric.data import HeteroData - -# Create normalizer -normalizer = HeteroDataMVANormalizer(args) - -# Fit on training data -params = normalizer.fit(data_path, scenario_ids) - -# Transform data -normalizer.transform(hetero_data) - -# Inverse transform to restore original scale -normalizer.inverse_transform(hetero_data) -``` diff --git a/docs/datasets/powergrid.md b/docs/datasets/powergrid.md deleted file mode 100644 index 1f983a53..00000000 --- a/docs/datasets/powergrid.md +++ /dev/null @@ -1,3 +0,0 @@ -## `HeteroGridDatasetDisk` - -::: gridfm_graphkit.datasets.powergrid_hetero_dataset.HeteroGridDatasetDisk diff --git a/docs/datasets/transforms.md b/docs/datasets/transforms.md deleted file mode 100644 index 0dcf981b..00000000 --- a/docs/datasets/transforms.md +++ /dev/null @@ -1,19 +0,0 @@ -# Transforms - -> Each transformation class inherits from [`BaseTransform`](https://pytorch-geometric.readthedocs.io/en/latest/modules/transforms.html#torch_geometric.transforms.BaseTransform) provided by [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/). - -### `RemoveInactiveGenerators` - -::: gridfm_graphkit.datasets.transforms.RemoveInactiveGenerators - -### `RemoveInactiveBranches` - -::: gridfm_graphkit.datasets.transforms.RemoveInactiveBranches - -### `ApplyMasking` - -::: gridfm_graphkit.datasets.transforms.ApplyMasking - -### `LoadGridParamsFromPath` - -::: gridfm_graphkit.datasets.transforms.LoadGridParamsFromPath diff --git a/docs/install/installation.md b/docs/install/installation.md index 89ee4bea..c65ab752 100644 --- a/docs/install/installation.md +++ b/docs/install/installation.md @@ -1,46 +1,32 @@ # Installation -You can install `gridfm-graphkit` directly from PyPI: +The steps below mirror the [README](https://github.com/gridfm/gridfm-graphkit/blob/main/README.md#installation). Run them from the root of a local clone or source checkout of the repository. -```bash -pip install gridfm-graphkit -``` - -For GPU support and compatibility with PyTorch Geometric's scatter operations, install PyTorch (and optionally CUDA) first, then install the matching `torch-scatter` wheel. See [PyTorch and torch-scatter](#pytorch-and-torch-scatter-optional) below. - ---- - -## Development Setup - -To contribute or develop locally, clone the repository and install in editable mode. Use Python 3.10, 3.11, or 3.12 (3.12 is recommended). +Create and activate a virtual environment (make sure you use the right python version = 3.10, 3.11 or 3.12. I highly recommend 3.12) ```bash -git clone git@github.com:gridfm/gridfm-graphkit.git -cd gridfm-graphkit python -m venv venv source venv/bin/activate -pip install -e . ``` -### PyTorch and torch-scatter (optional) +Install gridfm-graphkit in editable mode -If you need GPU acceleration or PyTorch Geometric scatter ops (used by the library), install PyTorch and the matching `torch-scatter` wheel: +```bash +pip install -e . +``` -1. Install PyTorch (see [pytorch.org](https://pytorch.org/) for your platform and CUDA version). +Get PyTorch + CUDA version for torch-scatter -2. Get your Torch + CUDA version string: - ```bash - TORCH_CUDA_VERSION=$(python -c "import torch; print(torch.__version__ + ('+cpu' if torch.version.cuda is None else ''))") - ``` +```bash +TORCH_CUDA_VERSION=$(python -c "import torch; print(torch.__version__ + ('+cpu' if torch.version.cuda is None else ''))") +``` -3. Install the correct `torch-scatter` wheel: - ```bash - pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH_CUDA_VERSION}.html - ``` +Install the correct torch-scatter wheel ---- +```bash +pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH_CUDA_VERSION}.html +``` -## Optional extras For documentation generation and unit testing, install with the optional `dev` and `test` extras: diff --git a/docs/models/models.md b/docs/models/models.md deleted file mode 100644 index 7c8c5c6b..00000000 --- a/docs/models/models.md +++ /dev/null @@ -1,37 +0,0 @@ -# Models - -### `GNS_heterogeneous` - -::: gridfm_graphkit.models.gnn_heterogeneous_gns.GNS_heterogeneous - ---- - -## Physics Decoders - -### `PhysicsDecoderOPF` - -::: gridfm_graphkit.models.utils.PhysicsDecoderOPF - -### `PhysicsDecoderPF` - -::: gridfm_graphkit.models.utils.PhysicsDecoderPF - -### `PhysicsDecoderSE` - -::: gridfm_graphkit.models.utils.PhysicsDecoderSE - ---- - -## Utility Modules - -### `ComputeBranchFlow` - -::: gridfm_graphkit.models.utils.ComputeBranchFlow - -### `ComputeNodeInjection` - -::: gridfm_graphkit.models.utils.ComputeNodeInjection - -### `ComputeNodeResiduals` - -::: gridfm_graphkit.models.utils.ComputeNodeResiduals diff --git a/docs/quick_start/quick_start.md b/docs/quick_start/quick_start.md index 575da88a..319a7396 100644 --- a/docs/quick_start/quick_start.md +++ b/docs/quick_start/quick_start.md @@ -8,10 +8,11 @@ gridfm_graphkit [OPTIONS] Available commands: -* `train` – Train a new model from scratch -* `finetune` – Fine-tune an existing pre-trained model -* `evaluate` – Evaluate model performance on a dataset -* `predict` – Run inference and save predictions +* `train` - Train a new model from scratch +* `finetune` - Fine-tune an existing pre-trained model +* `evaluate` - Evaluate model performance on a dataset +* `predict` - Run inference and save predictions +* `benchmark` - Benchmark train-dataloader iteration speed --- @@ -23,13 +24,22 @@ gridfm_graphkit train --config path/to/config.yaml ### Arguments -| Argument | Type | Description | Default | -| ---------------- | ------ | ---------------------------------------------------------------- | ------- | -| `--config` | `str` | **Required**. Path to the training configuration YAML file. | `None` | -| `--exp_name` | `str` | MLflow experiment name. | `timestamp` | -| `--run_name` | `str` | MLflow run name. | `run` | -| `--log_dir` | `str` | MLflow tracking/logging directory. | `mlruns` | -| `--data_path` | `str` | Root dataset directory. | `data` | +| Argument | Type | Description | Default | +| -------- | ---- | ----------- | ------- | +| `--config` | `str` | **Required**. Path to the training configuration YAML file. | `None` | +| `--exp_name` | `str` | MLflow experiment name. | `timestamp` | +| `--run_name` | `str` | MLflow run name. | `run` | +| `--log_dir` | `str` | MLflow tracking/logging directory. | `mlruns` | +| `--data_path` | `str` | Root dataset directory. | `data` | +| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` | +| `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` | +| `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` | +| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` | +| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` | +| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` | +| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` | +| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` | +| `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` | ### Examples @@ -49,14 +59,23 @@ gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model ### Arguments -| Argument | Type | Description | Default | -| -------------- | ----- | ----------------------------------------------- | --------- | -| `--config` | `str` | **Required**. Fine-tuning configuration file. | `None` | -| `--model_path` | `str` | **Required**. Path to a pre-trained model state dict. | `None` | -| `--exp_name` | `str` | MLflow experiment name. | timestamp | -| `--run_name` | `str` | MLflow run name. | `run` | -| `--log_dir` | `str` | MLflow logging directory. | `mlruns` | -| `--data_path` | `str` | Root dataset directory. | `data` | +| Argument | Type | Description | Default | +| -------- | ---- | ----------- | ------- | +| `--config` | `str` | **Required**. Fine-tuning configuration file. | `None` | +| `--model_path` | `str` | **Required**. Path to a pre-trained model state dict. | `None` | +| `--exp_name` | `str` | MLflow experiment name. | `timestamp` | +| `--run_name` | `str` | MLflow run name. | `run` | +| `--log_dir` | `str` | MLflow logging directory. | `mlruns` | +| `--data_path` | `str` | Root dataset directory. | `data` | +| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` | +| `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` | +| `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` | +| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` | +| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` | +| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` | +| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` | +| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` | +| `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` | --- @@ -69,17 +88,25 @@ gridfm_graphkit evaluate --config path/to/eval.yaml --model_path path/to/model.p ### Arguments -| Argument | Type | Description | Default | -| --------------------- | ----- | ------------------------------------------------------------------------------------------------------------- | --------- | -| `--config` | `str` | **Required**. Path to evaluation config. | `None` | -| `--model_path` | `str` | Path to the trained model state dict. | `None` | -| `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics instead of re-fitting on the current data split. | `None` | -| `--exp_name` | `str` | MLflow experiment name. | timestamp | -| `--run_name` | `str` | MLflow run name. | `run` | -| `--log_dir` | `str` | MLflow logging directory. | `mlruns` | -| `--data_path` | `str` | Dataset directory. | `data` | -| `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` | -| `--save_output` | `flag` | Save predictions as `_predictions.parquet` under MLflow artifacts (`.../artifacts/test`). | `False` | +| Argument | Type | Description | Default | +| -------- | ---- | ----------- | ------- | +| `--config` | `str` | **Required**. Path to evaluation config. | `None` | +| `--model_path` | `str` | Path to the trained model state dict. | `None` | +| `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics instead of re-fitting on current split. | `None` | +| `--exp_name` | `str` | MLflow experiment name. | `timestamp` | +| `--run_name` | `str` | MLflow run name. | `run` | +| `--log_dir` | `str` | MLflow logging directory. | `mlruns` | +| `--data_path` | `str` | Dataset directory. | `data` | +| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` | +| `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` | +| `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` | +| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` | +| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` | +| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` | +| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` | +| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` | +| `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` | +| `--save_output` | `flag` | Save predictions as `_predictions.parquet` under MLflow artifacts (`.../artifacts/test`). | `False` | ### Example with saved normalizer stats @@ -105,16 +132,44 @@ gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model. ### Arguments -| Argument | Type | Description | Default | -| --------------------- | ----- | ------------------------------------------------------------------------------------------------------------- | --------- | -| `--config` | `str` | **Required**. Path to prediction config file. | `None` | -| `--model_path` | `str` | Path to the trained model state dict. | `None` | -| `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics. | `None` | -| `--exp_name` | `str` | MLflow experiment name. | timestamp | -| `--run_name` | `str` | MLflow run name. | `run` | -| `--log_dir` | `str` | MLflow logging directory. | `mlruns` | -| `--data_path` | `str` | Dataset directory. | `data` | -| `--output_path` | `str` | Directory where predictions are saved as `_predictions.parquet`. | `data` | +| Argument | Type | Description | Default | +| -------- | ---- | ----------- | ------- | +| `--config` | `str` | **Required**. Path to prediction config file. | `None` | +| `--model_path` | `str` | Path to trained model state dict. Optional; may be defined in config. | `None` | +| `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics. | `None` | +| `--exp_name` | `str` | MLflow experiment name. | `timestamp` | +| `--run_name` | `str` | MLflow run name. | `run` | +| `--log_dir` | `str` | MLflow logging directory. | `mlruns` | +| `--data_path` | `str` | Dataset directory. | `data` | +| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` | +| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` | +| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` | +| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` | +| `--output_path` | `str` | Directory where predictions are saved as `_predictions.parquet`. | `data` | +| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` | +| `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` | +| `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` | +| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` | + +--- + +## Benchmarking Dataloader Throughput + +```bash +gridfm_graphkit benchmark --config path/to/config.yaml +``` + +### Arguments + +| Argument | Type | Description | Default | +| -------- | ---- | ----------- | ------- | +| `--config` | `str` | **Required**. Path to configuration YAML file. | `None` | +| `--data_path` | `str` | Root dataset directory. | `data` | +| `--epochs` | `int` | Number of epochs to iterate through the train dataloader. | `3` | +| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` | +| `--dataset_wrapper_cache_dir` | `str` | Directory for dataset wrapper disk cache. | `None` | +| `--num_workers` | `int` | Override `data.workers` from YAML. | `None` | +| `--plugins` | `list[str]` | Python packages to import for plugin registration. | `[]` | Use built-in help for full command details: diff --git a/docs/quick_start/yaml_config.md b/docs/quick_start/yaml_config.md index 5976f73c..96c2fa62 100644 --- a/docs/quick_start/yaml_config.md +++ b/docs/quick_start/yaml_config.md @@ -1,170 +1,206 @@ -# The YAML configuration file +# YAML configuration reference -Every experiment in **`gridfm-graphkit`** is defined through a single YAML configuration file. -This file specifies which networks to load, how to normalize the data, which model architecture to build, and how different stages of the workflow should be executed. +Every experiment is driven by one YAML file in `examples/config/`. -Rather than modifying the source code, you simply adjust the YAML file to describe your experiment. This approach makes results reproducible and easy to share: all the important details are stored in one place. -The configuration is divided into sections (`data`, `model`, `training`, `optimizer`, etc.), with each section grouping related options. -We will explain these fields one by one and show how to use them effectively. +## Full example (current style) -For ready-to-use examples, check the folder [**`examples/config/`**](https://github.com/gridfm/gridfm-graphkit/tree/main/examples/config), which contains valid configuration files you can adapt for your own experiments. +```yaml +task: + task_name: OptimalPowerFlow +data: + baseMVA: 100 + mask_value: 0.0 + normalization: HeteroDataMVANormalizer + networks: + - case14_ieee + scenarios: + - 300000 + workers: 32 + split_by_load_scenario_idx: false + split_from_existing_files: "/dccstor/gridfm/march_opf_exp/opfdata_olay_splits/" +model: + attention_head: 8 + edge_dim: 10 + hidden_size: 48 + input_bus_dim: 15 + input_gen_dim: 6 + output_bus_dim: 2 + output_gen_dim: 1 + num_layers: 12 + type: GNS_heterogeneous +optimizer: + beta1: 0.9 + beta2: 0.999 + learning_rate: 0.0005 + lr_decay: 0.7 + lr_patience: 5 +training: + batch_size: 64 + epochs: 200 + loss_weights: [0.1, 0.1, 0.75, 0.001] + losses: [LayeredWeightedPhysics, MaskedGenMSE, MaskedBusMSE, QgViolationPenalty] + loss_args: + - base_weight: 0.5 + - {} + - {} + - {} + accelerator: auto + devices: auto + strategy: auto +seed: 0 +verbose: true +callbacks: + patience: 100 + tol: 0 +``` --- -## Data +## Top-level keys -The `data` section defines **which networks and scenarios to use**, as well as **how to prepare and mask the input features**. +- `task`: task-specific settings (`OptimalPowerFlow` or `PowerFlow`). +- `data`: dataset selection, normalization, splits, and loading behavior. +- `model`: model architecture and dimensions. +- `optimizer`: optimizer and scheduler parameters. +- `training`: epochs, loss composition, and accelerator strategy. +- `callbacks`: early stopping behavior. +- `seed`: random seed used for reproducible shuffling/splits. +- `verbose`: enables extra outputs (for example additional test plots/log artifacts). -Example: +--- -```yaml -data: - networks: ["case300_ieee", "case30_ieee"] - scenarios: [8500, 4000] - normalization: baseMVAnorm - baseMVA: 100 - mask_type: rnd - mask_value: 0.0 - mask_ratio: 0.5 - mask_dim: 6 - learn_mask: false - val_ratio: 0.1 - test_ratio: 0.1 - workers: 4 -``` +## `task` section + +### `task.task_name` -**Key fields:** - -- **`networks`**: List of network topologies (e.g., IEEE test cases) used. -- **`scenarios`**: Number of scenarios (samples) for each network. -- **`normalization`**: Method to scale features. Options: - - `minmax`: scale between min and max. - - `standard`: zero mean, unit variance. - - `baseMVAnorm`: divide by base MVA value (see `baseMVA`). - - `identity`: no normalization. -- **`baseMVA`**: Base MVA value from the case file (default: 100). -- **`mask_type`**: Defines how input features are masked: - * `rnd` = random masking (controlled by `mask_ratio` and `mask_dim`). - * `pf` = power flow problem setup. - * `opf` = optimal power flow setup. - * `none` = no masking. -* **`mask_value`**: Numerical value used to mask inputs (default: 0.0). -* **`mask_ratio`**: Probability of masking a feature (only used when `mask_type=rnd`). -* **`mask_dim`**: Number of features that can masked (default: the first 6 → Pd, Qd, Pg, Qg, Vm, Va). -* **`learn_mask`**: If true, the mask value becomes learnable. -* **`val_ratio` / `test_ratio`**: Fractions of the dataset used for validation and testing. -* **`workers`**: Number of data-loading workers +Task name registered in the framework: + +- `OptimalPowerFlow` +- `PowerFlow` --- -## Model +## `data` section -The `model` section specifies the neural network architecture and its hyperparameters. +### `data.networks` -Example: +List of dataset folders under your data root. +Examples: `case14_ieee`, `case118_ieee`, `case2000_goc`, `Texas2k_case1_2016summerpeak`. -```yaml -model: - type: GPSconv - input_dim: 9 - output_dim: 6 - edge_dim: 2 - pe_dim: 20 - num_layers: 6 - hidden_size: 256 - attention_head: 8 - dropout: 0.0 -``` +### `data.scenarios` -**Key fields:** +List of scenario counts, one value per network in `data.networks`. +Example: with two networks, use two scenario entries in matching order. -* **`type`**: Model architecture (e.g., `"GPSconv"`). -* **`input_dim`**: Input feature dimension (default: 9 → Pd, Qd, Pg, Qg, Vm, Va, PQ, PV, REF). -* **`output_dim`**: Output feature dimension (default: 6 → Pd, Qd, Pg, Qg, Vm, Va). -* **`edge_dim`**: Dimension of edge features (default: 2 → G, B). -* **`pe_dim`**: Size of positional encoding (e.g., random walk length). -* **`num_layers`**: Number of layers in the network. -* **`hidden_size`**: Width of hidden layers. -* **`attention_head`**: Number of attention heads. -* **`dropout`**: Dropout probability (default: 0.0). +### `data.normalization` ---- +Normalizer class name: -## Training +- `HeteroDataMVANormalizer`: fit one normalization scale from training data. +- `HeteroDataPerSampleMVANormalizer`: fit per-scenario scales across the selected dataset. -The `training` section defines how the model is optimized and which loss functions are used. +### `data.baseMVA` -Example: +Base MVA reference value (default in examples: `100`). Used by normalizers for per-unit scaling. -```yaml -training: - batch_size: 16 - epochs: 100 - losses: ["MaskedMSE", "PBE"] - loss_weights: [0.01, 0.99] - accelerator: auto - devices: auto - strategy: auto -``` +### `data.mask_value` -**Key fields:** - -* **`batch_size`**: Number of samples per training batch. -* **`epochs`**: Number of training epochs. -* **`losses`**: List of losses to combine. Options: - * `MSE` = Mean Squared Error. - * `MaskedMSE` = Masked Mean Squared Error. - * `SCE` = Scaled Cosine Error. - * `PBE` = Power Balance Equation loss. -* **`loss_weights`**: Relative weights applied to each loss term. -* **`accelerator`**: Device type used for training (cpu, gpu, mps, or auto). -* **`devices`**: Number of devices (GPUs/CPUs) to use (or auto) -* **`strategy`**: Training strategy (e.g., ddp for distributed data parallel, or auto). - -!!! note - On macOS, using accelerator: `cpu` is often the most stable choice. ---- +Fill value used when masking unavailable measurements/features (examples use `0.0`). -## Optimizer +### `data.test_ratio`, `data.val_ratio` -Defines the optimizer and learning rate scheduling. +Fractions for validation and test splits when split files are not supplied. -Example: +### `data.workers` -```yaml -optimizer: - learning_rate: 0.0001 - beta1: 0.9 - beta2: 0.999 - lr_decay: 0.5 - lr_patience: 3 -``` +Number of dataloader workers. + +### `data.split_by_load_scenario_idx` -**Key fields:** +- `true`: split train/val/test by load scenario identifiers. +- `false`: perform standard random split. -* **`learning_rate`**: Initial learning rate. -* **`beta1`**, **`beta2`**: Adam optimizer parameters (defaults: 0.9, 0.999). -* **`lr_decay`**: Factor to decay the learning rate. -* **`lr_patience`**: Number of epochs to wait before reducing the LR. +### `data.split_from_existing_files` + +Optional path to precomputed split files. When provided: + +- split IDs are loaded from this folder, +- `data.scenarios` is ignored for split construction, +- do **not** combine with `split_by_load_scenario_idx: true`. + +## `model` section + +Current configs use the heterogeneous GNS model: + +- `type`: model registry name (examples use `GNS_heterogeneous`). +- `input_bus_dim`: bus-node input feature dimension. +- `input_gen_dim`: generator-node input feature dimension. +- `output_bus_dim`: bus-node output dimension. +- `output_gen_dim`: generator-node output dimension. +- `edge_dim`: edge feature dimension. +- `hidden_size`: hidden feature width. +- `num_layers`: number of stacked message-passing layers. +- `attention_head`: attention head count per layer. --- -## Callbacks +## `training` section -Callbacks add additional behavior during training, such as early stopping. +### Core training controls -Example: +- `batch_size`: mini-batch size. +- `epochs`: number of epochs. +- `accelerator`: Lightning accelerator (`auto`, `mps`, `cpu`, `gpu`, etc.). +- `devices`: Lightning device selection (`auto`, integer, list). +- `strategy`: Lightning strategy (`auto`, `ddp`, etc.). -```yaml -callbacks: - patience: 100 - tol: 0 -``` +### Multi-loss configuration + +- `losses`: list of registered loss names. +- `loss_weights`: scalar weight per loss. +- `loss_args`: list of argument objects matching `losses` by position. + +All three lists must be aligned (same length and same order). + +Registered loss names in current code: -**Key fields:** +- `LayeredWeightedPhysics` +- `MaskedGenMSE` +- `MaskedBusMSE` +- `QgViolationPenalty` +- `LossPerDim` +- `MaskedMSE` +- `MSE` -* **`patience`**: Number of epochs to wait before early stopping. -* **`tol`**: Minimum improvement required in validation loss to reset patience. +Common `loss_args` patterns: + +- `LayeredWeightedPhysics`: `{base_weight: }` +- `LossPerDim`: `{dim: VM|VA|P_in|Q_in, loss_str: MAE|MSE}` +- `MaskedGenMSE`, `MaskedBusMSE`, `QgViolationPenalty`, `MaskedMSE`, `MSE`: `{}` --- + +## `optimizer` section + +- `learning_rate`: initial learning rate. +- `beta1`, `beta2`: Adam betas. +- `lr_decay`: scheduler decay factor (e.g., ReduceLROnPlateau factor). +- `lr_patience`: epochs to wait before applying LR decay. + +--- + +## `callbacks` section + +- `patience`: early stopping patience (epochs without sufficient improvement). +- `tol`: minimum required improvement threshold to reset patience. + +--- + +## Practical validation checklist + +Before launching a run, verify: + +- `len(data.networks) == len(data.scenarios)`. +- `len(training.losses) == len(training.loss_weights) == len(training.loss_args)`. +- `split_by_load_scenario_idx` and `split_from_existing_files` are not both active. diff --git a/docs/tasks/base_task.md b/docs/tasks/base_task.md deleted file mode 100644 index 1153a2d0..00000000 --- a/docs/tasks/base_task.md +++ /dev/null @@ -1,216 +0,0 @@ -# Base Task - -The `BaseTask` class is an abstract base class that provides the foundation for all task implementations in GridFM-GraphKit. It extends PyTorch Lightning's `LightningModule` and defines the common interface and shared functionality for training, validation, and testing. - -## Overview - -`BaseTask` serves as the parent class for all task-specific implementations, providing: - -- **Abstract method definitions**: Enforces implementation of core methods in subclasses -- **Optimizer configuration**: Sets up AdamW optimizer with learning rate scheduling -- **Normalization statistics logging**: Saves normalization parameters for reproducibility -- **Hyperparameter management**: Automatically saves hyperparameters for experiment tracking - -## BaseTask Class - -::: gridfm_graphkit.tasks.base_task.BaseTask - options: - show_root_heading: true - show_source: true - members: - - __init__ - - forward - - training_step - - validation_step - - test_step - - predict_step - - on_fit_start - - configure_optimizers - -## Methods - -### `__init__(args, data_normalizers)` - -Initialize the base task with configuration and normalizers. - -**Parameters:** - -- `args` (NestedNamespace): Experiment configuration containing all hyperparameters -- `data_normalizers` (list): List of normalizer objects, one per dataset - -**Attributes Set:** - -- `self.args`: Stores the configuration -- `self.data_normalizers`: Stores the normalizers -- Automatically calls `save_hyperparameters()` for experiment tracking - ---- - -### `forward(*args, **kwargs)` (Abstract) - -Defines the forward pass through the model. Must be implemented by subclasses. - -**Returns:** - -- Model output (structure depends on task implementation) - ---- - -### `training_step(batch)` (Abstract) - -Executes one training step. Must be implemented by subclasses. - -**Parameters:** - -- `batch`: A batch of data from the training dataloader - -**Returns:** - -- Loss tensor for backpropagation - ---- - -### `validation_step(batch, batch_idx)` (Abstract) - -Executes one validation step. Must be implemented by subclasses. - -**Parameters:** - -- `batch`: A batch of data from the validation dataloader -- `batch_idx` (int): Index of the current batch - -**Returns:** - -- Loss tensor or metrics dictionary - ---- - -### `test_step(batch, batch_idx, dataloader_idx=0)` (Abstract) - -Executes one test step. Must be implemented by subclasses. - -**Parameters:** - -- `batch`: A batch of data from the test dataloader -- `batch_idx` (int): Index of the current batch -- `dataloader_idx` (int): Index of the dataloader (for multiple test datasets) - -**Returns:** - -- Metrics dictionary or None - ---- - -### `predict_step(batch, batch_idx, dataloader_idx=0)` (Abstract) - -Executes one prediction step. Must be implemented by subclasses. - -**Parameters:** - -- `batch`: A batch of data from the prediction dataloader -- `batch_idx` (int): Index of the current batch -- `dataloader_idx` (int): Index of the dataloader - -**Returns:** - -- Predictions dictionary - ---- - -### `on_fit_start()` - -Called at the beginning of training. Saves normalization statistics to disk. - -**Behavior:** - -- Creates a `stats` directory in the logging directory -- Saves human-readable normalization statistics to `normalization_stats.txt` -- Saves machine-loadable statistics to `normalizer_stats.pt` (PyTorch format) -- Only executes on rank 0 in distributed training (via `@rank_zero_only` decorator) - -**Output Files:** - -1. **`normalization_stats.txt`**: Human-readable text file with statistics for each dataset -2. **`normalizer_stats.pt`**: PyTorch file containing a dictionary keyed by network name - ---- - -### `configure_optimizers()` - -Configures the optimizer and learning rate scheduler. - -**Optimizer:** - -- **Type**: AdamW -- **Learning Rate**: From `args.optimizer.learning_rate` -- **Betas**: From `args.optimizer.beta1` and `args.optimizer.beta2` - -**Scheduler:** - -- **Type**: ReduceLROnPlateau -- **Mode**: Minimize -- **Factor**: From `args.optimizer.lr_decay` -- **Patience**: From `args.optimizer.lr_patience` -- **Monitored Metric**: "Validation loss" - -**Returns:** - -- Dictionary with optimizer and lr_scheduler configuration - -## Usage - -`BaseTask` is not used directly. Instead, create a subclass that implements all abstract methods: - -```python -from gridfm_graphkit.tasks.base_task import BaseTask - -class MyCustomTask(BaseTask): - def __init__(self, args, data_normalizers): - super().__init__(args, data_normalizers) - # Initialize task-specific components - - def forward(self, x_dict, edge_index_dict, edge_attr_dict, mask_dict): - # Implement forward pass - pass - - def training_step(self, batch): - # Implement training logic - pass - - def validation_step(self, batch, batch_idx): - # Implement validation logic - pass - - def test_step(self, batch, batch_idx, dataloader_idx=0): - # Implement test logic - pass - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - # Implement prediction logic - pass -``` - -## Configuration Example - -The base task uses the following configuration sections: - -```yaml -optimizer: - learning_rate: 0.001 - beta1: 0.9 - beta2: 0.999 - lr_decay: 0.7 - lr_patience: 5 - -data: - networks: - - case14_ieee - - case118_ieee -``` - -## Related - -- [Reconstruction Task](reconstruction_task.md): Base class for reconstruction tasks -- [Power Flow Task](power_flow.md): Concrete implementation for power flow -- [Optimal Power Flow Task](optimal_power_flow.md): Concrete implementation for OPF -- [State Estimation Task](state_estimation.md): Concrete implementation for state estimation diff --git a/docs/tasks/feature_reconstruction.md b/docs/tasks/feature_reconstruction.md deleted file mode 100644 index fbde3eae..00000000 --- a/docs/tasks/feature_reconstruction.md +++ /dev/null @@ -1,185 +0,0 @@ -# Task Classes Overview - -GridFM-GraphKit provides a hierarchical task system for power grid analysis. All tasks inherit from a common base class and share core functionality while implementing domain-specific logic. - -## Task Hierarchy - -``` -BaseTask (Abstract) - └── ReconstructionTask - ├── PowerFlowTask - ├── OptimalPowerFlowTask - └── StateEstimationTask -``` - -## Available Task Classes - -### Base Classes - -- **[BaseTask](base_task.md)**: Abstract base class providing common functionality for all tasks - - Optimizer configuration - - Learning rate scheduling - - Normalization statistics logging - - Abstract method definitions - -- **[ReconstructionTask](reconstruction_task.md)**: Base class for feature reconstruction tasks - - Model integration - - Loss function handling - - Shared training/validation logic - - Test output management - -### Concrete Task Implementations - -- **[PowerFlowTask](power_flow.md)**: Power flow analysis - - Computes voltage profiles and power flows - - Physics-based validation with Power Balance Error (PBE) - - Separate metrics for PQ, PV, and REF buses - - Detailed per-bus predictions - -- **[OptimalPowerFlowTask](optimal_power_flow.md)**: Optimal power flow with economic optimization - - Minimizes generation costs - - Tracks optimality gap - - Monitors constraint violations (thermal, voltage, angle) - - Evaluates reactive power limits - -- **[StateEstimationTask](state_estimation.md)**: State estimation from noisy measurements - - Handles measurement noise and outliers - - Separate evaluation for outliers, masked values, and clean measurements - - Correlation analysis between predictions, measurements, and targets - -## Quick Reference - -### Method Overview - -All task classes implement the following core methods: - -| Method | Purpose | Implemented In | -|--------|---------|----------------| -| `__init__` | Initialize task with config and normalizers | All classes | -| `forward` | Forward pass through model | ReconstructionTask+ | -| `training_step` | Execute one training step | ReconstructionTask+ | -| `validation_step` | Execute one validation step | ReconstructionTask+ | -| `test_step` | Execute one test step | Concrete tasks | -| `predict_step` | Execute one prediction step | Concrete tasks | -| `on_fit_start` | Save normalization stats before training | BaseTask | -| `on_test_end` | Generate reports and plots after testing | Concrete tasks | -| `configure_optimizers` | Setup optimizer and scheduler | BaseTask | - -### Task Selection - -Tasks are automatically selected based on your YAML configuration: - -```yaml -task: - task_name: PowerFlow # or OptimalPowerFlow, StateEstimation -``` - -The task registry automatically instantiates the correct task class based on the `task_name` field. - -## Common Features - -All tasks share these features: - -### 1. Distributed Training Support -- Multi-GPU training with proper metric synchronization -- Rank 0 handles logging and file I/O -- Automatic gathering of test outputs across ranks - -### 2. Comprehensive Logging -- Training and validation metrics logged to MLflow or TensorBoard -- Automatic hyperparameter tracking -- Normalization statistics saved for reproducibility - -### 3. Test Outputs -- CSV reports with detailed metrics -- Visualization plots (when `verbose=True`) -- Per-dataset analysis for multiple test sets - -### 4. Physics-Based Evaluation -- Power balance error computation -- Branch flow calculations -- Residual analysis by bus type - -## Configuration - -### Basic Configuration - -```yaml -task: - task_name: PowerFlow - verbose: true - -training: - batch_size: 64 - epochs: 100 - losses: ["MaskedMSE", "PBE"] - loss_weights: [0.01, 0.99] - -optimizer: - learning_rate: 0.001 - beta1: 0.9 - beta2: 0.999 - lr_decay: 0.7 - lr_patience: 5 -``` - -### Task-Specific Options - -Each task may have additional configuration options. See the individual task documentation for details: - -- [Power Flow Configuration](power_flow.md#configuration-example) -- [Optimal Power Flow Configuration](optimal_power_flow.md#configuration-example) -- [State Estimation Configuration](state_estimation.md#configuration-example) - -## Creating Custom Tasks - -To create a custom task, extend `ReconstructionTask` or `BaseTask`: - -```python -from gridfm_graphkit.tasks.reconstruction_tasks import ReconstructionTask -from gridfm_graphkit.io.registries import TASK_REGISTRY - -@TASK_REGISTRY.register("MyCustomTask") -class MyCustomTask(ReconstructionTask): - def __init__(self, args, data_normalizers): - super().__init__(args, data_normalizers) - # Add custom initialization - - def test_step(self, batch, batch_idx, dataloader_idx=0): - # Implement custom test logic - output, loss_dict = self.shared_step(batch) - - # Add custom metrics - custom_metric = self.compute_custom_metric(output, batch) - loss_dict["Custom Metric"] = custom_metric - - # Log metrics - for metric, value in loss_dict.items(): - self.log(f"{dataset_name}/{metric}", value) - - return loss_dict["loss"] - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - # Implement custom prediction logic - output, _ = self.shared_step(batch) - return {"predictions": output} - - def on_test_end(self): - # Custom analysis and visualization - # Generate reports, plots, etc. - super().on_test_end() -``` - -Then use it in your configuration: - -```yaml -task: - task_name: MyCustomTask -``` - -## Related Documentation - -- [Loss Functions](../training/loss.md): Available loss functions and their configuration -- [Data Modules](../datasets/data_modules.md): Data loading and preprocessing -- [Models](../models/models.md): Available model architectures -- [Quick Start Guide](../quick_start/quick_start.md): Getting started with training diff --git a/docs/tasks/optimal_power_flow.md b/docs/tasks/optimal_power_flow.md deleted file mode 100644 index 3d13a57f..00000000 --- a/docs/tasks/optimal_power_flow.md +++ /dev/null @@ -1,12 +0,0 @@ -# Optimal Power Flow Task - -## OptimalPowerFlowTask Class - -::: gridfm_graphkit.tasks.opf_task.OptimalPowerFlowTask - options: - show_root_heading: true - show_source: true - members: - - __init__ - - test_step - - on_test_end diff --git a/docs/tasks/power_flow.md b/docs/tasks/power_flow.md deleted file mode 100644 index 8912a267..00000000 --- a/docs/tasks/power_flow.md +++ /dev/null @@ -1,12 +0,0 @@ -# Power Flow Task - -## PowerFlowTask Class - -::: gridfm_graphkit.tasks.pf_task.PowerFlowTask - options: - show_root_heading: true - show_source: true - members: - - __init__ - - test_step - - on_test_end diff --git a/docs/tasks/reconstruction_task.md b/docs/tasks/reconstruction_task.md deleted file mode 100644 index 54e9e5a8..00000000 --- a/docs/tasks/reconstruction_task.md +++ /dev/null @@ -1,293 +0,0 @@ -# Reconstruction Task - -The `ReconstructionTask` class is a concrete implementation of `BaseTask` that provides the foundation for node feature reconstruction on power grid graphs. It wraps a GridFM model and defines the training, validation, and testing logic for reconstructing masked node features. - -## Overview - -`ReconstructionTask` serves as the base class for all reconstruction-based tasks in GridFM-GraphKit, including: - -- Power Flow (PF) -- Optimal Power Flow (OPF) -- State Estimation (SE) - -It provides: - -- **Model integration**: Loads and wraps the GridFM model -- **Loss function handling**: Configures and applies loss functions -- **Shared training logic**: Common training and validation steps -- **Test output management**: Collects and manages test outputs for analysis - -## ReconstructionTask Class - -::: gridfm_graphkit.tasks.reconstruction_tasks.ReconstructionTask - options: - show_root_heading: true - show_source: true - members: - - __init__ - - forward - - shared_step - - training_step - - validation_step - - on_test_end - -## Methods - -### `__init__(args, data_normalizers)` - -Initialize the reconstruction task with model, loss function, and configuration. - -**Parameters:** - -- `args` (NestedNamespace): Experiment configuration with fields like: - - `training.batch_size`: Batch size for training - - `optimizer.*`: Optimizer configuration - - `model.*`: Model architecture configuration - - `training.losses`: List of loss functions to use - - `data.networks`: List of network names -- `data_normalizers` (list): One normalizer per dataset for feature normalization/denormalization - -**Attributes Set:** - -- `self.model`: GridFM model loaded via `load_model()` -- `self.loss_fn`: Loss function resolved from configuration via `get_loss_function()` -- `self.batch_size`: Training batch size -- `self.test_outputs`: Dictionary to store test outputs per dataset (keyed by dataloader index) - -**Example:** - -```python -task = ReconstructionTask(args, data_normalizers) -``` - ---- - -### `forward(x_dict, edge_index_dict, edge_attr_dict, mask_dict)` - -Forward pass through the model. - -**Parameters:** - -- `x_dict` (dict): Node features dictionary with keys like `"bus"`, `"gen"` -- `edge_index_dict` (dict): Edge indices dictionary for heterogeneous edges -- `edge_attr_dict` (dict): Edge attributes dictionary -- `mask_dict` (dict): Masking dictionary indicating which features are masked - -**Returns:** - -- Model output dictionary with predicted node features - -**Example:** - -```python -output = task.forward( - x_dict=batch.x_dict, - edge_index_dict=batch.edge_index_dict, - edge_attr_dict=batch.edge_attr_dict, - mask_dict=batch.mask_dict -) -``` - ---- - -### `shared_step(batch)` - -Common logic for training and validation steps. - -**Parameters:** - -- `batch`: A batch from the dataloader containing: - - `x_dict`: Input node features - - `y_dict`: Target node features - - `edge_index_dict`: Edge connectivity - - `edge_attr_dict`: Edge attributes - - `mask_dict`: Feature masks - -**Returns:** - -- `output` (dict): Model predictions -- `loss_dict` (dict): Dictionary containing: - - `"loss"`: Total loss value - - Additional loss components (if applicable) - -**Behavior:** - -1. Performs forward pass through the model -2. Computes loss using the configured loss function -3. Returns both predictions and loss dictionary - -**Example:** - -```python -output, loss_dict = task.shared_step(batch) -total_loss = loss_dict["loss"] -``` - ---- - -### `training_step(batch)` - -Execute one training step. - -**Parameters:** - -- `batch`: Training batch from dataloader - -**Returns:** - -- Loss tensor for backpropagation - -**Logged Metrics:** - -- `"Training Loss"`: Total training loss -- `"Learning Rate"`: Current learning rate - -**Logging Configuration:** - -- `batch_size`: Number of graphs in batch -- `sync_dist=False`: No synchronization across GPUs during training -- `on_epoch=False`: Log per step, not per epoch -- `on_step=True`: Log at each training step -- `prog_bar=False`: Don't show in progress bar -- `logger=True`: Send to logger (e.g., MLflow) - ---- - -### `validation_step(batch, batch_idx)` - -Execute one validation step. - -**Parameters:** - -- `batch`: Validation batch from dataloader -- `batch_idx` (int): Index of the current batch - -**Returns:** - -- Loss tensor - -**Logged Metrics:** - -- `"Validation loss"`: Total validation loss -- Additional loss components (if multiple losses are used) - -**Logging Configuration:** - -- `batch_size`: Number of graphs in batch -- `sync_dist=True`: Synchronize metrics across GPUs -- `on_epoch=True`: Aggregate and log at epoch end -- `on_step=False`: Don't log individual steps -- `logger=True`: Send to logger - -**Note:** The validation loss is monitored by the learning rate scheduler for automatic learning rate reduction. - ---- - -### `on_test_end()` - -Called at the end of testing. Clears stored test outputs. - -**Behavior:** - -- Clears the `self.test_outputs` dictionary -- Only executes on rank 0 in distributed training (via `@rank_zero_only` decorator) -- Subclasses typically override this to add custom analysis, plotting, and CSV generation - -**Note:** This is a minimal implementation. Task-specific subclasses (PowerFlowTask, OptimalPowerFlowTask, StateEstimationTask) override this method to: - -- Generate detailed metrics CSV files -- Create visualization plots -- Save analysis results - ---- - -## Usage - -`ReconstructionTask` can be used directly for simple reconstruction tasks, but is typically subclassed for specific power system tasks: - -```python -from gridfm_graphkit.tasks.reconstruction_tasks import ReconstructionTask - -# Direct usage (simple reconstruction) -task = ReconstructionTask(args, data_normalizers) - -# Or create a subclass for custom behavior -class CustomReconstructionTask(ReconstructionTask): - def test_step(self, batch, batch_idx, dataloader_idx=0): - # Custom test logic - output, loss_dict = self.shared_step(batch) - # Add custom metrics - return loss_dict["loss"] - - def on_test_end(self): - # Custom analysis and visualization - super().on_test_end() -``` - -## Configuration Example - -```yaml -task: - task_name: Reconstruction # Or PowerFlow, OptimalPowerFlow, StateEstimation - -model: - type: GNS_heterogeneous - hidden_size: 48 - num_layers: 12 - attention_head: 8 - -training: - batch_size: 64 - epochs: 100 - losses: - - MaskedMSE - loss_weights: - - 1.0 - -optimizer: - learning_rate: 0.001 - beta1: 0.9 - beta2: 0.999 - lr_decay: 0.7 - lr_patience: 5 -``` - -## Loss Functions - -The reconstruction task supports various loss functions configured via the YAML file: - -- **MaskedMSE**: Mean squared error on masked features only -- **MaskedBusMSE**: MSE specifically for bus node features -- **LayeredWeightedPhysics**: Physics-based loss with layer-wise weighting -- **PBE**: Power Balance Error loss - -Multiple losses can be combined with weights: - -```yaml -training: - losses: - - LayeredWeightedPhysics - - MaskedBusMSE - loss_weights: - - 0.1 - - 0.9 - loss_args: - - base_weight: 0.5 - - {} -``` - -## Subclasses - -The following task classes extend `ReconstructionTask`: - -- **[PowerFlowTask](power_flow.md)**: Adds power flow-specific metrics and physics validation -- **[OptimalPowerFlowTask](optimal_power_flow.md)**: Adds economic optimization metrics and constraint violation tracking -- **[StateEstimationTask](state_estimation.md)**: Adds measurement-based estimation and outlier handling - -## Related - -- [Base Task](base_task.md): Abstract base class for all tasks -- [Power Flow Task](power_flow.md): Power flow analysis implementation -- [Optimal Power Flow Task](optimal_power_flow.md): OPF optimization implementation -- [State Estimation Task](state_estimation.md): State estimation implementation -- [Loss Functions](../training/loss.md): Available loss functions diff --git a/docs/tasks/state_estimation.md b/docs/tasks/state_estimation.md deleted file mode 100644 index c3adbbcb..00000000 --- a/docs/tasks/state_estimation.md +++ /dev/null @@ -1,11 +0,0 @@ -# State Estimation Task - -::: gridfm_graphkit.tasks.se_task.StateEstimationTask - options: - show_root_heading: true - show_source: true - members: - - __init__ - - test_step - - on_test_end - - predict_step diff --git a/docs/training/loss.md b/docs/training/loss.md deleted file mode 100644 index de56d4ba..00000000 --- a/docs/training/loss.md +++ /dev/null @@ -1,47 +0,0 @@ -# Loss Functions - -## Base Loss - -::: gridfm_graphkit.training.loss.BaseLoss - ---- - -## Mean Squared Error Loss - -::: gridfm_graphkit.training.loss.MSELoss - ---- - -## Masked Mean Squared Error Loss - -::: gridfm_graphkit.training.loss.MaskedMSELoss - ---- - -## Masked Generator MSE Loss - -::: gridfm_graphkit.training.loss.MaskedGenMSE - ---- - -## Masked Bus MSE Loss - -::: gridfm_graphkit.training.loss.MaskedBusMSE - ---- - -## Mixed Loss - -::: gridfm_graphkit.training.loss.MixedLoss - ---- - -## Layered Weighted Physics Loss - -::: gridfm_graphkit.training.loss.LayeredWeightedPhysicsLoss - ---- - -## Loss Per Dimension - -::: gridfm_graphkit.training.loss.LossPerDim diff --git a/docs/tutorials/contingency_analysis.md b/docs/tutorials/contingency_analysis.md deleted file mode 100644 index db31962a..00000000 --- a/docs/tutorials/contingency_analysis.md +++ /dev/null @@ -1,326 +0,0 @@ -👉 [Link](https://github.com/gridfm/gridfm-graphkit/tree/main/examples/notebooks) to the tutorial notebooks on Github -👉 [Link](https://colab.research.google.com/github/gridfm/gridfm-graphkit/blob/main/examples/notebooks/Tutorial_contingency_analisys.ipynb) to the tutorial on Google Colab - ---- -Contingency analysis is a critical process in power system operations used to assess the impact of potential failures (e.g., line outages) on grid stability and reliability. It helps operators prepare for unexpected events by simulating scenarios such as N-1 or N-2 contingencies, where one or more components are removed from service. This analysis ensures that the grid can continue to operate within safe limits even under stressed conditions. - ---- - -## Dataset Generation and Model Evaluation - -The dataset used in this study originates from the Texas transmission grid, which includes approximately 2,000 nodes. Using the contingency mode of the `gridfm-datakit`, we simulated N-2 contingencies by removing up to two transmission lines at a time. For each scenario, we first solved the optimal power flow (OPF) problem to determine the generation dispatch. Then, we applied the contingency by removing lines and re-solved the power flow to observse the resulting grid state. - -This process generated around 100,000 unique scenarios. Our model, **GridFMv0.1**, was fine-tuned on this dataset to predict power flow outcomes. For demonstration purposes, we selected a subsample of 10 scenarios. The `gridfm-datakit` also computed DC power flow results, enabling a comparison between GridFMv0.1 predictions and traditional DC power flow estimates, specifically in terms of line loading accuracy. - -All predictions are benchmarked against the ground truth obtained from AC power flow simulations. Additionally, we analyze bus voltage violations, which GridFM can predict but are not captured by the DC solver, highlighting GridFM’s enhanced capabilities in modeling grid behavior. - -```python -import sys - -if "google.colab" in sys.modules: - try: - !git clone https://github.com/gridfm/gridfm-graphkit.git - %cd /content/gridfm-graphkit - !pip install . - %cd examples/notebooks/ - except Exception as e: - print(f"Failed to start Google Collab setup, due to {e}") -``` - -```python -from gridfm_graphkit.datasets.postprocessing import ( - compute_branch_currents_kA, - compute_loading, -) -from gridfm_graphkit.datasets.postprocessing import create_admittance_matrix -from gridfm_graphkit.utils.utils import compute_cm_metrics -from gridfm_graphkit.utils.visualization import ( - plot_mass_correlation_density, - plot_cm, - plot_loading_predictions, - plot_mass_correlation_density_voltage, -) - -import os -from tqdm import tqdm -import matplotlib.pyplot as plt -from sklearn.metrics import f1_score -import numpy as np -import pandas as pd -``` - -## Load Data - -We load both the ground truth and predicted values of the power flow solution. The predictions are generated using the `gridfm-graphkit` CLI: - -```bash -gridfm-graphkit predict ... -``` - -We then merge the datasets using `scenario` and `bus` as keys, allowing us to align the predicted and actual values for each grid state and bus. - -```python -root_pred_folder = "../data/contingency_texas/" -prediction_dir = "prediction_gridfm01" -label_plot = "GridFM_v0.1 Fine-tuned" - -pf_node_GT = pd.read_csv(os.path.join(root_pred_folder, "pf_node_10_examples.csv")) -pg_node_predicted = pd.read_csv( - os.path.join(root_pred_folder, "predictions_10_examples.csv") -) - -branch_idx_removed = pd.read_csv("{}branch_idx_removed.csv".format(root_pred_folder)) -edge_params = pd.read_csv("{}edge_params.csv".format(root_pred_folder)) -bus_params = pd.read_csv("{}bus_params.csv".format(root_pred_folder)) - -pf_node = pg_node_predicted.merge(pf_node_GT, on=["scenario", "bus"], how="left") -``` - -## Create Admittance matrix - -```python -sn_mva = 100 -Yf, Yt, Vf_base_kV, Vt_base_kV = create_admittance_matrix( - bus_params, edge_params, sn_mva -) -rate_a = edge_params["rate_a"] -``` - -## Correct voltage predictions for GridFM and DC - -```python -pf_node["Vm_pred_corrected"] = pf_node["VM_pred"] -pf_node["Va_pred_corrected"] = pf_node["VA_pred"] - -pf_node.loc[pf_node.PV == 1, "Vm_pred_corrected"] = pf_node.loc[pf_node.PV == 1, "Vm"] -pf_node.loc[pf_node.REF == 1, "Va_pred_corrected"] = pf_node.loc[pf_node.REF == 1, "Va"] - -pf_node["Vm_dc_corrected"] = pf_node["Vm_dc"] -pf_node["Va_dc_corrected"] = pf_node["Va_dc"] - -pf_node.loc[pf_node.PV == 1, "Vm_dc_corrected"] = pf_node.loc[pf_node.PV == 1, "Vm"] -pf_node.loc[pf_node.REF == 1, "Va_dc_corrected"] = pf_node.loc[pf_node.REF == 1, "Va"] -``` - -## Compute branch current and line loading - -```python -loadings = [] -loadings_pred = [] -loadings_dc = [] - -for scenario_idx in tqdm(pf_node.scenario.unique()): - pf_node_scenario = pf_node[pf_node.scenario == scenario_idx] - branch_idx_removed_scenario = ( - branch_idx_removed[branch_idx_removed.scenario == scenario_idx] - .iloc[:, 1:] - .values - ) - # remove nan - branch_idx_removed_scenario = branch_idx_removed_scenario[ - ~np.isnan(branch_idx_removed_scenario) - ].astype(np.int32) - V_true = pf_node_scenario["Vm"].values * np.exp( - 1j * pf_node_scenario["Va"].values * np.pi / 180 - ) - V_pred = pf_node_scenario["Vm_pred_corrected"].values * np.exp( - 1j * pf_node_scenario["Va_pred_corrected"].values * np.pi / 180 - ) - V_dc = pf_node_scenario["Vm_dc_corrected"].values * np.exp( - 1j * pf_node_scenario["Va_dc_corrected"].values * np.pi / 180 - ) - If_true, It_true = compute_branch_currents_kA( - Yf, Yt, V_true, Vf_base_kV, Vt_base_kV, sn_mva - ) - If_pred, It_pred = compute_branch_currents_kA( - Yf, Yt, V_pred, Vf_base_kV, Vt_base_kV, sn_mva - ) - If_dc, It_dc = compute_branch_currents_kA( - Yf, Yt, V_dc, Vf_base_kV, Vt_base_kV, sn_mva - ) - - loading_true = compute_loading(If_true, It_true, Vf_base_kV, Vt_base_kV, rate_a) - loading_pred = compute_loading(If_pred, It_pred, Vf_base_kV, Vt_base_kV, rate_a) - loading_dc = compute_loading(If_dc, It_dc, Vf_base_kV, Vt_base_kV, rate_a) - - # remove the branches that are removed from loading - loading_true[branch_idx_removed_scenario] = -1 - loading_pred[branch_idx_removed_scenario] = -1 - loading_dc[branch_idx_removed_scenario] = -1 - - loadings.append(loading_true) - loadings_pred.append(loading_pred) - loadings_dc.append(loading_dc) - - -loadings = np.array(loadings) -loadings_pred = np.array(loadings_pred) -loadings_dc = np.array(loadings_dc) -removed_lines = loadings == -1 -removed_lines_pred = loadings_pred == -1 -removed_lines_dc = loadings_dc == -1 - - -# assert the same lines are removed -assert (removed_lines == removed_lines_pred).all() -assert (removed_lines == removed_lines_dc).all() - -# assert the same number of lines are removed -assert removed_lines.sum() == removed_lines_pred.sum() -assert removed_lines.sum() == removed_lines_dc.sum() - -overloadings = loadings[not removed_lines] > 1.0 -overloadings_pred = loadings_pred[not removed_lines] > 1.0 -overloadings_dc = loadings_dc[not removed_lines] > 1.0 -``` - -## Histogram of true line loadings - -```python -plt.hist(loadings[not removed_lines], bins=100) -plt.xlabel("Line Loadings") -plt.ylabel("Frequency") -# log scale -plt.savefig(f"loadings_histogram_{prediction_dir}.png") -plt.show() -``` -

- True loading -
-

- - -## Predicted vs True line loading - -```python -# Valid lines -valid_mask = not removed_lines - -# Extract valid values -true_vals = loadings[valid_mask] -gfm_vals = loadings_pred[valid_mask] -dc_vals = loadings_dc[valid_mask] -``` - -```python -plot_mass_correlation_density(true_vals, gfm_vals, prediction_dir, label_plot) -``` -

- Loading gridfm -
-

- -```python -plot_mass_correlation_density(true_vals, dc_vals, "DC", "DC Solver") -``` -

- Loading DC -
-

- -```python -plot_cm(TN_gridfm, FP_gridfm, FN_gridfm, TP_gridfm, prediction_dir, label_plot) -``` -

- Confusion gridfm -
-

- -```python -plot_cm(TN_dc, FP_dc, FN_dc, TP_dc, "DC", "DC Solver") -``` -

- Confusion DC -
-

- - -```python -# Histograms of loadings -plot_loading_predictions( - loadings_pred[not removed_lines], - loadings_dc[not removed_lines], - loadings[not removed_lines], - prediction_dir, - label_plot, -) -``` -

- Loading predictions -
-

- - -```python -# create df from loadings -loadings_df = pd.DataFrame(loadings) -loadings_df.columns = [f"branch_{i}" for i in range(loadings_df.shape[1])] - -loadings_pred_df = pd.DataFrame(loadings_pred) -loadings_pred_df.columns = [f"branch_{i}" for i in range(loadings_pred_df.shape[1])] - -loadings_df["scenario"] = pf_node["scenario"].unique() -loadings_pred_df["scenario"] = pf_node["scenario"].unique() - -# make bar plot of wrongly classified loadings for different bins -bins = np.arange(0, 2.2, 0.2) -mse_pred = [] -mse_dc = [] -for i in range(len(bins) - 1): - idx_in_bins = (loadings[not removed_lines] > bins[i]) & ( - loadings[not removed_lines] < bins[i + 1] - ) - mse_pred.append( - np.mean( - ( - loadings_pred[not removed_lines][idx_in_bins] - - loadings[not removed_lines][idx_in_bins] - ) - ** 2 - ) - ) - mse_dc.append( - np.mean( - ( - loadings_dc[not removed_lines][idx_in_bins] - - loadings[not removed_lines][idx_in_bins] - ) - ** 2 - ) - ) - - -# labels -labels = [f"{bins[i]:.1f}-{bins[i + 1]:.1f}" for i in range(len(bins) - 1)] -plt.bar(labels, mse_pred, label=label_plot, alpha=0.5) -plt.bar(labels, mse_dc, label="DC", alpha=0.5) -plt.legend() -plt.xlabel("Loadings") -plt.ylabel("MSE") -# y log scale -plt.yscale("log") -# rotate x labels -plt.xticks(rotation=45) -plt.savefig(f"loading_mse_{prediction_dir}.png") -plt.show() -``` - -

- MSE loading -
-

- - -## Voltage violations - -```python -# merge bus_params["vmax"] and bus_params["vmin"] with pf_node on bus_idx -pf_node = pd.merge(pf_node, bus_params[["bus", "vmax", "vmin"]], on="bus", how="left") - -plot_mass_correlation_density_voltage(pf_node, prediction_dir, label_plot) -``` - -

- Correlation voltage -
-

diff --git a/docs/tutorials/feature_reconstruction.md b/docs/tutorials/feature_reconstruction.md deleted file mode 100644 index dff2fafe..00000000 --- a/docs/tutorials/feature_reconstruction.md +++ /dev/null @@ -1,201 +0,0 @@ -👉 [Link](https://github.com/gridfm/gridfm-graphkit/tree/main/examples/notebooks) to the tutorial notebooks on Github -👉 [Link](https://colab.research.google.com/github/gridfm/gridfm-graphkit/blob/main/examples/notebooks/Tutorial_reconstruction_visualization.ipynb) to the tutorial on Google Colab - ---- -This notebook demonstrates the state reconstruction capabilities of **GridFM-v0.2**, a graph-based neural network model for transmission grids. We focus on the IEEE case30 network, a standard benchmark with 30 buses, chosen for its compact size and suitability for visualization. - -The dataset includes **1,023 load scenarios**, each representing a different operating condition of the grid. For each scenario, the model reconstructs the six key features of the power flow solution that are masked: - -- Active Power Demand (MW) -- Reactive Power Demand (MVar) -- Active Power Generated (MW) -- Reactive Power Generated (MVar) -- Voltage Magnitude (p.u.) -- Voltage Angle (degrees) - -```python -import sys - -if "google.colab" in sys.modules: - try: - !git clone https://github.com/gridfm/gridfm-graphkit.git - %cd /content/gridfm-graphkit - !pip install . - %cd examples/notebooks/ - except Exception as e: - print(f"Failed to start Google Collab setup, due to {e}") -``` - -```python -from gridfm_graphkit.datasets.powergrid_datamodule import LitGridDataModule -from gridfm_graphkit.io.param_handler import NestedNamespace -from gridfm_graphkit.tasks.feature_reconstruction_task import FeatureReconstructionTask -from gridfm_graphkit.utils.visualization import visualize_error, visualize_quantity_heatmap -from gridfm_graphkit.datasets.globals import PD, QD, PG, QG, VM, VA - -import yaml -import torch -import numpy as np -import random -``` - -## Load YAML configuration file - -```python -config_path = "../config/case30_ieee_base.yaml" -with open(config_path) as f: - config_dict = yaml.safe_load(f) - -config_args = NestedNamespace(**config_dict) -torch.manual_seed(config_args.seed) -random.seed(config_args.seed) -np.random.seed(config_args.seed) -``` - -## Initialize the DataModule - -```python -data_module = LitGridDataModule(config_args, "../data") -data_module.setup("test") -test_loader = data_module.test_dataloader() -``` - -## Load the pre-trained model GridFM-v0.2 - -```python -model = FeatureReconstructionTask( - config_args, data_module.node_normalizers, data_module.edge_normalizers -) -state_dict = torch.load("../models/GridFM_v0_2.pth") -model.load_state_dict(state_dict) -``` - -## Perform inference, batch size is equal to 1 for further visualization purposes - -```python -batch = next(iter(test_loader[0])) - -model.eval() -with torch.no_grad(): - output = model( - x=batch.x, - pe=batch.pe, - edge_index=batch.edge_index, - edge_attr=batch.edge_attr, - batch=batch.batch, - mask=batch.mask, - ) -``` - -## Visualize Nodal Active Power Residuals - -```python -visualize_error(batch, output, data_module.node_normalizers[0]) -``` - -

- Active Residual -
-

- - -## Visualize the state reconstruction capability of gridFM-v0.2 for each feature: -- Active Power Demand (MW) -- Reactive Power Demand (MVar) -- Active Power Generated (MW) -- Reactive Power Generated (MVar) -- Voltage Magnitude (p.u.) -- Voltage Angle (degrees) - -```python -visualize_quantity_heatmap( - batch, - output, - PD, - "Active Power Demand", - "MW", - data_module.node_normalizers[0], -) -``` -

- Active demand -
-

- - -```python -visualize_quantity_heatmap( - batch, - output, - QD, - "Reactive Power Demand", - "MVar", - data_module.node_normalizers[0], -) -``` -

- Reactive demand -
-

- - - -```python -visualize_quantity_heatmap( - batch, - output, - PG, - "Active Power Generated", - "MW", - data_module.node_normalizers[0], -) -``` -

- Active generated -
-

- -```python -visualize_quantity_heatmap( - batch, - output, - QG, - "Reactive Power Generated", - "MVar", - data_module.node_normalizers[0], -) -``` -

- Reactive generated -
-

- -```python -visualize_quantity_heatmap( - batch, - output, - VM, - "Voltage magnitude", - "p.u.", - data_module.node_normalizers[0], -) -``` -

- Voltage magnitude -
-

- -```python -visualize_quantity_heatmap( - batch, - output, - VA, - "Voltage Angle", - "degrees", - data_module.node_normalizers[0], -) -``` -

- Voltage angle -
-

diff --git a/docs/tutorials/figs/MSE_loading.png b/docs/tutorials/figs/MSE_loading.png deleted file mode 100644 index ce0dd191..00000000 Binary files a/docs/tutorials/figs/MSE_loading.png and /dev/null differ diff --git a/docs/tutorials/figs/active_demand.png b/docs/tutorials/figs/active_demand.png deleted file mode 100644 index d0340286..00000000 Binary files a/docs/tutorials/figs/active_demand.png and /dev/null differ diff --git a/docs/tutorials/figs/active_generated.png b/docs/tutorials/figs/active_generated.png deleted file mode 100644 index 797f9f71..00000000 Binary files a/docs/tutorials/figs/active_generated.png and /dev/null differ diff --git a/docs/tutorials/figs/active_residuals.png b/docs/tutorials/figs/active_residuals.png deleted file mode 100644 index b2c28772..00000000 Binary files a/docs/tutorials/figs/active_residuals.png and /dev/null differ diff --git a/docs/tutorials/figs/confusion_DC.png b/docs/tutorials/figs/confusion_DC.png deleted file mode 100644 index 8e197a8b..00000000 Binary files a/docs/tutorials/figs/confusion_DC.png and /dev/null differ diff --git a/docs/tutorials/figs/confusion_gridfm.png b/docs/tutorials/figs/confusion_gridfm.png deleted file mode 100644 index 7479ed0c..00000000 Binary files a/docs/tutorials/figs/confusion_gridfm.png and /dev/null differ diff --git a/docs/tutorials/figs/correlation_voltage.png b/docs/tutorials/figs/correlation_voltage.png deleted file mode 100644 index 6b53369e..00000000 Binary files a/docs/tutorials/figs/correlation_voltage.png and /dev/null differ diff --git a/docs/tutorials/figs/hist_true_loadings.png b/docs/tutorials/figs/hist_true_loadings.png deleted file mode 100644 index 0c7f503a..00000000 Binary files a/docs/tutorials/figs/hist_true_loadings.png and /dev/null differ diff --git a/docs/tutorials/figs/loading_DC.png b/docs/tutorials/figs/loading_DC.png deleted file mode 100644 index 0bac8375..00000000 Binary files a/docs/tutorials/figs/loading_DC.png and /dev/null differ diff --git a/docs/tutorials/figs/loading_gridfm.png b/docs/tutorials/figs/loading_gridfm.png deleted file mode 100644 index f9b23ec5..00000000 Binary files a/docs/tutorials/figs/loading_gridfm.png and /dev/null differ diff --git a/docs/tutorials/figs/loading_predictions.png b/docs/tutorials/figs/loading_predictions.png deleted file mode 100644 index 210b97cd..00000000 Binary files a/docs/tutorials/figs/loading_predictions.png and /dev/null differ diff --git a/docs/tutorials/figs/reactive_demand.png b/docs/tutorials/figs/reactive_demand.png deleted file mode 100644 index 57261410..00000000 Binary files a/docs/tutorials/figs/reactive_demand.png and /dev/null differ diff --git a/docs/tutorials/figs/reactive_generated.png b/docs/tutorials/figs/reactive_generated.png deleted file mode 100644 index 7709390d..00000000 Binary files a/docs/tutorials/figs/reactive_generated.png and /dev/null differ diff --git a/docs/tutorials/figs/voltage_angle.png b/docs/tutorials/figs/voltage_angle.png deleted file mode 100644 index 7cf74b6f..00000000 Binary files a/docs/tutorials/figs/voltage_angle.png and /dev/null differ diff --git a/docs/tutorials/figs/voltage_magnitude.png b/docs/tutorials/figs/voltage_magnitude.png deleted file mode 100644 index 92112def..00000000 Binary files a/docs/tutorials/figs/voltage_magnitude.png and /dev/null differ diff --git a/examples/config/HGNS_OPFData_case2000.yaml b/examples/config/HGNS_OPFData_case2000.yaml deleted file mode 100644 index 3ec576e5..00000000 --- a/examples/config/HGNS_OPFData_case2000.yaml +++ /dev/null @@ -1,52 +0,0 @@ -callbacks: - patience: 100 - tol: 0 -task: - task_name: OptimalPowerFlow -data: - baseMVA: 100 - mask_value: 0.0 - normalization: HeteroDataMVANormalizer - networks: - - case2000_goc - scenarios: - - 300000 - test_ratio: 0.1 - val_ratio: 0.1 - workers: 16 -model: - attention_head: 8 - edge_dim: 10 - hidden_size: 48 - input_bus_dim: 15 - input_gen_dim: 6 - output_bus_dim: 2 - output_gen_dim: 1 - num_layers: 12 - type: GNS_heterogeneous -optimizer: - beta1: 0.9 - beta2: 0.999 - learning_rate: 0.0005 - lr_decay: 0.7 - lr_patience: 5 -seed: 0 -training: - batch_size: 16 - epochs: 200 - loss_weights: - - 0.1 - - 0.1 - - 0.8 - losses: - - LayeredWeightedPhysics - - MaskedGenMSE - - MaskedBusMSE - loss_args: - - base_weight: 0.5 - - {} - - {} - accelerator: auto - devices: auto - strategy: auto -verbose: true diff --git a/examples/config/HGNS_OPFData_case118.yaml b/examples/config/HGNS_OPF_Ola_case118.yaml similarity index 80% rename from examples/config/HGNS_OPFData_case118.yaml rename to examples/config/HGNS_OPF_Ola_case118.yaml index 3f57ac90..66206256 100644 --- a/examples/config/HGNS_OPFData_case118.yaml +++ b/examples/config/HGNS_OPF_Ola_case118.yaml @@ -1,6 +1,3 @@ -callbacks: - patience: 100 - tol: 0 task: task_name: OptimalPowerFlow data: @@ -11,9 +8,9 @@ data: - case118_ieee scenarios: - 300000 - test_ratio: 0.1 - val_ratio: 0.1 - workers: 16 + workers: 32 + split_by_load_scenario_idx: false + split_from_existing_files: "/dccstor/gridfm/march_opf_exp/opfdata_olay_splits/" model: attention_head: 8 edge_dim: 10 @@ -30,23 +27,30 @@ optimizer: learning_rate: 0.0005 lr_decay: 0.7 lr_patience: 5 -seed: 0 training: batch_size: 64 epochs: 200 loss_weights: - 0.1 - 0.1 - - 0.8 + - 0.75 + - 0.001 losses: - LayeredWeightedPhysics - MaskedGenMSE - MaskedBusMSE + - QgViolationPenalty + loss_args: - base_weight: 0.5 - {} - {} + - {} accelerator: auto devices: auto strategy: auto +seed: 0 verbose: true +callbacks: + patience: 100 + tol: 0 diff --git a/examples/config/HGNS_OPFData_case14.yaml b/examples/config/HGNS_OPF_Ola_case14.yaml similarity index 81% rename from examples/config/HGNS_OPFData_case14.yaml rename to examples/config/HGNS_OPF_Ola_case14.yaml index 714995dc..a2353297 100644 --- a/examples/config/HGNS_OPFData_case14.yaml +++ b/examples/config/HGNS_OPF_Ola_case14.yaml @@ -8,9 +8,9 @@ data: - case14_ieee scenarios: - 300000 - test_ratio: 0.1 - val_ratio: 0.1 workers: 32 + split_by_load_scenario_idx: false + split_from_existing_files: "/dccstor/gridfm/march_opf_exp/opfdata_olay_splits/" model: attention_head: 8 edge_dim: 10 @@ -27,25 +27,29 @@ optimizer: learning_rate: 0.0005 lr_decay: 0.7 lr_patience: 5 -seed: 0 training: batch_size: 64 epochs: 200 loss_weights: - 0.1 - 0.1 - - 0.8 + - 0.75 + - 0.001 losses: - LayeredWeightedPhysics - MaskedGenMSE - MaskedBusMSE + - QgViolationPenalty + loss_args: - base_weight: 0.5 - {} - {} + - {} accelerator: auto devices: auto strategy: auto +seed: 0 verbose: true callbacks: patience: 100 diff --git a/examples/config/HGNS_OPFData_case30.yaml b/examples/config/HGNS_OPF_Ola_case2000.yaml similarity index 80% rename from examples/config/HGNS_OPFData_case30.yaml rename to examples/config/HGNS_OPF_Ola_case2000.yaml index f8bba01e..a0bdd900 100644 --- a/examples/config/HGNS_OPFData_case30.yaml +++ b/examples/config/HGNS_OPF_Ola_case2000.yaml @@ -1,6 +1,3 @@ -callbacks: - patience: 100 - tol: 0 task: task_name: OptimalPowerFlow data: @@ -8,12 +5,12 @@ data: mask_value: 0.0 normalization: HeteroDataMVANormalizer networks: - - case30_ieee + - case2000_goc scenarios: - 300000 - test_ratio: 0.1 - val_ratio: 0.1 workers: 32 + split_by_load_scenario_idx: false + split_from_existing_files: "/dccstor/gridfm/march_opf_exp/opfdata_olay_splits/" model: attention_head: 8 edge_dim: 10 @@ -30,23 +27,30 @@ optimizer: learning_rate: 0.0005 lr_decay: 0.7 lr_patience: 5 -seed: 0 training: batch_size: 64 epochs: 200 loss_weights: - 0.1 - 0.1 - - 0.8 + - 0.75 + - 0.001 losses: - LayeredWeightedPhysics - MaskedGenMSE - MaskedBusMSE + - QgViolationPenalty + loss_args: - base_weight: 0.5 - {} - {} + - {} accelerator: auto devices: auto strategy: auto +seed: 0 verbose: true +callbacks: + patience: 100 + tol: 0 diff --git a/examples/config/HGNS_OPFData_case500.yaml b/examples/config/HGNS_OPF_Ola_case500.yaml similarity index 80% rename from examples/config/HGNS_OPFData_case500.yaml rename to examples/config/HGNS_OPF_Ola_case500.yaml index 442eac78..9cdb27f4 100644 --- a/examples/config/HGNS_OPFData_case500.yaml +++ b/examples/config/HGNS_OPF_Ola_case500.yaml @@ -1,6 +1,3 @@ -callbacks: - patience: 100 - tol: 0 task: task_name: OptimalPowerFlow data: @@ -11,9 +8,9 @@ data: - case500_goc scenarios: - 300000 - test_ratio: 0.1 - val_ratio: 0.1 - workers: 16 + workers: 32 + split_by_load_scenario_idx: false + split_from_existing_files: "/dccstor/gridfm/march_opf_exp/opfdata_olay_splits/" model: attention_head: 8 edge_dim: 10 @@ -30,23 +27,30 @@ optimizer: learning_rate: 0.0005 lr_decay: 0.7 lr_patience: 5 -seed: 0 training: batch_size: 16 epochs: 200 loss_weights: - 0.1 - 0.1 - - 0.8 + - 0.75 + - 0.001 losses: - LayeredWeightedPhysics - MaskedGenMSE - MaskedBusMSE + - QgViolationPenalty + loss_args: - base_weight: 0.5 - {} - {} + - {} accelerator: auto devices: auto strategy: auto +seed: 0 verbose: true +callbacks: + patience: 100 + tol: 0 diff --git a/examples/config/HGNS_OPFData_case57.yaml b/examples/config/HGNS_OPF_Ola_case57.yaml similarity index 81% rename from examples/config/HGNS_OPFData_case57.yaml rename to examples/config/HGNS_OPF_Ola_case57.yaml index 024efa8c..74af6377 100644 --- a/examples/config/HGNS_OPFData_case57.yaml +++ b/examples/config/HGNS_OPF_Ola_case57.yaml @@ -1,6 +1,3 @@ -callbacks: - patience: 100 - tol: 0 task: task_name: OptimalPowerFlow data: @@ -11,9 +8,9 @@ data: - case57_ieee scenarios: - 300000 - test_ratio: 0.1 - val_ratio: 0.1 workers: 32 + split_by_load_scenario_idx: false + split_from_existing_files: "/dccstor/gridfm/march_opf_exp/opfdata_olay_splits/" model: attention_head: 8 edge_dim: 10 @@ -30,23 +27,30 @@ optimizer: learning_rate: 0.0005 lr_decay: 0.7 lr_patience: 5 -seed: 0 training: batch_size: 64 epochs: 200 loss_weights: - 0.1 - 0.1 - - 0.8 + - 0.75 + - 0.001 losses: - LayeredWeightedPhysics - MaskedGenMSE - MaskedBusMSE + - QgViolationPenalty + loss_args: - base_weight: 0.5 - {} - {} + - {} accelerator: auto devices: auto strategy: auto +seed: 0 verbose: true +callbacks: + patience: 100 + tol: 0 diff --git a/examples/config/HGNS_OPF_datakit_case14.yaml b/examples/config/HGNS_OPF_datakit_case14.yaml index 4ed0aa7b..c4a8c532 100644 --- a/examples/config/HGNS_OPF_datakit_case14.yaml +++ b/examples/config/HGNS_OPF_datakit_case14.yaml @@ -34,15 +34,18 @@ training: loss_weights: - 0.1 - 0.1 - - 0.8 + - 0.75 + - 0.001 losses: - LayeredWeightedPhysics - MaskedGenMSE - MaskedBusMSE + - QgViolationPenalty loss_args: - base_weight: 0.5 - {} - {} + - {} accelerator: auto devices: auto strategy: auto diff --git a/examples/config/HGNS_PF_pfdelta_case118.yaml b/examples/config/HGNS_PF_pfdelta_case118.yaml deleted file mode 100644 index 42202176..00000000 --- a/examples/config/HGNS_PF_pfdelta_case118.yaml +++ /dev/null @@ -1,54 +0,0 @@ -callbacks: - patience: 100 - tol: 0 -task: - task_name: PowerFlow -data: - baseMVA: 100 - mask_value: 0.0 - normalization: HeteroDataMVANormalizer - normalizationMVA: 240 - networks: - - case118_ieee_N_1 - - case118_ieee_N_2 - - case118_ieee_N - scenarios: - - 20000 - - 20000 - - 20000 - test_ratio: 0.1 - val_ratio: 0.1 - workers: 16 -model: - attention_head: 8 - edge_dim: 10 - hidden_size: 48 - input_bus_dim: 15 - input_gen_dim: 6 - output_bus_dim: 2 - output_gen_dim: 1 - num_layers: 12 - type: GNS_heterogeneous -optimizer: - beta1: 0.9 - beta2: 0.999 - learning_rate: 0.0005 - lr_decay: 0.7 - lr_patience: 5 -seed: 0 -training: - batch_size: 64 - epochs: 300 - loss_weights: - - 0.1 - - 0.9 - losses: - - LayeredWeightedPhysics - - MaskedBusMSE - loss_args: - - base_weight: 0.5 - - {} - accelerator: auto - devices: auto - strategy: auto -verbose: true diff --git a/examples/config/HGNS_SE_datakit_case118.yaml b/examples/config/HGNS_SE_datakit_case118.yaml deleted file mode 100644 index 6c97a7e8..00000000 --- a/examples/config/HGNS_SE_datakit_case118.yaml +++ /dev/null @@ -1,75 +0,0 @@ -callbacks: - patience: 100 - tol: 0 -task: - task_name: StateEstimation - noise_type: Gaussian - measurements: - power_inj: - mask_ratio: 0.2 - outlier_ratio: 0.1 - std: 0.02 - power_flow: - mask_ratio: 0.2 - outlier_ratio: 0.1 - std: 0.02 - vm: - mask_ratio: 0.2 - outlier_ratio: 0.1 - std: 0.02 - relative_measurement: true -data: - baseMVA: 100 - mask_value: 0.0 - normalization: HeteroDataMVANormalizer - networks: - - case118_ieee - scenarios: - - 100000 - split_by_load_scenario_idx: true - test_ratio: 0.1 - val_ratio: 0.1 - workers: 32 -model: - attention_head: 8 - edge_dim: 10 - hidden_size: 48 - input_bus_dim: 15 - input_gen_dim: 6 - output_bus_dim: 2 - output_gen_dim: 1 - num_layers: 12 - type: GNS_heterogeneous -optimizer: - beta1: 0.9 - beta2: 0.999 - learning_rate: 0.0005 - lr_decay: 0.7 - lr_patience: 5 -seed: 0 -training: - batch_size: 32 - epochs: 200 - losses: - - LossPerDim - - LossPerDim - - LossPerDim - - LossPerDim - loss_weights: - - 0.35 - - 0.35 - - 0.1 - - 0.1 - loss_args: - - dim: VM - loss_str: MAE - - dim: VA - loss_str: MAE - - dim: P_in - loss_str: MAE - - dim: Q_in - loss_str: MAE - accelerator: auto - devices: auto - strategy: auto -verbose: true diff --git a/examples/config/HGNS_SE_datakit_case14.yaml b/examples/config/HGNS_SE_datakit_case14.yaml deleted file mode 100644 index 09ad20af..00000000 --- a/examples/config/HGNS_SE_datakit_case14.yaml +++ /dev/null @@ -1,75 +0,0 @@ -callbacks: - patience: 100 - tol: 0 -task: - task_name: StateEstimation - noise_type: Gaussian - measurements: - power_inj: - mask_ratio: 0.2 - outlier_ratio: 0.1 - std: 0.02 - power_flow: - mask_ratio: 0.2 - outlier_ratio: 0.1 - std: 0.02 - vm: - mask_ratio: 0.2 - outlier_ratio: 0.1 - std: 0.02 - relative_measurement: true -data: - baseMVA: 100 - mask_value: 0.0 - normalization: HeteroDataMVANormalizer - networks: - - case14_ieee - scenarios: - - 100000 - split_by_load_scenario_idx: true - test_ratio: 0.1 - val_ratio: 0.1 - workers: 32 -model: - attention_head: 8 - edge_dim: 10 - hidden_size: 48 - input_bus_dim: 15 - input_gen_dim: 6 - output_bus_dim: 2 - output_gen_dim: 1 - num_layers: 12 - type: GNS_heterogeneous -optimizer: - beta1: 0.9 - beta2: 0.999 - learning_rate: 0.0005 - lr_decay: 0.7 - lr_patience: 5 -seed: 0 -training: - batch_size: 32 - epochs: 200 - losses: - - LossPerDim - - LossPerDim - - LossPerDim - - LossPerDim - loss_weights: - - 0.35 - - 0.35 - - 0.1 - - 0.1 - loss_args: - - dim: VM - loss_str: MAE - - dim: VA - loss_str: MAE - - dim: P_in - loss_str: MAE - - dim: Q_in - loss_str: MAE - accelerator: auto - devices: auto - strategy: auto -verbose: true diff --git a/gridfm_graphkit/__main__.py b/gridfm_graphkit/__main__.py index 5b173a61..b693089c 100644 --- a/gridfm_graphkit/__main__.py +++ b/gridfm_graphkit/__main__.py @@ -3,7 +3,64 @@ from gridfm_graphkit.cli import main_cli, benchmark_cli +import subprocess +import os + +def is_lsf(): + return ( + os.environ.get("LSB_JOBID") is not None + and os.environ.get("LSB_MCPU_HOSTS") is not None + and "LSF_ENVDIR" in os.environ # strong LSF indicator + ) + +def fix_infiniband(): + """Configure NCCL to skip Ethernet-only IB ports on this host.""" + ibv = subprocess.run("ibv_devinfo", stdout=subprocess.PIPE, stderr=subprocess.PIPE) + lines = ibv.stdout.decode("utf-8").split("\n") + exclude = "" + for line in lines: + if "hca_id:" in line: + name = line.split(":")[1].strip() + if "\tport:" in line: + port = line.split(":")[1].strip() + if "link_layer:" in line and "Ethernet" in line: + exclude = exclude + f"{name}:{port}," + + if exclude: + exclude = "^" + exclude[:-1] + os.environ["NCCL_IB_HCA"] = exclude + + +def set_env(): + """Populate distributed-training environment variables from LSF metadata.""" + # print("Using " + str(torch.cuda.device_count()) + " GPUs---------------------------------------------------------------------") + LSB_MCPU_HOSTS = os.environ[ + "LSB_MCPU_HOSTS" + ].split( + " ", + ) # Parses Node list set by LSF, in format hostname proceeded by number of cores requested + HOST_LIST = LSB_MCPU_HOSTS[::2] # Strips the cores per node items in the list + LSB_JOBID = os.environ[ + "LSB_JOBID" + ] # Parses Node list set by LSF, in format hostname proceeded by number of cores requested + os.environ["MASTER_ADDR"] = HOST_LIST[ + 0 + ] # Sets the MasterNode to thefirst node on the list of hosts + os.environ["MASTER_PORT"] = "5" + LSB_JOBID[-5:-1] + os.environ["NODE_RANK"] = str( + HOST_LIST.index(os.environ["HOSTNAME"]), + ) # Uses the list index for node rank, master node rank must be 0 + os.environ["NCCL_SOCKET_IFNAME"] = ( + "ib,bond" # avoids using docker of loopback interface + ) + os.environ["NCCL_IB_CUDA_SUPPORT"] = "1" # Force use of infiniband + def main(): + """Parse CLI arguments and dispatch to the selected GridFM subcommand.""" + if is_lsf(): + print("Using LSF") + set_env() + fix_infiniband() parser = argparse.ArgumentParser( prog="gridfm_graphkit", description="gridfm-graphkit CLI", @@ -76,12 +133,18 @@ def main(): choices=["simple", "advanced", "pytorch"], help="Enable Lightning profiler: 'simple', 'advanced', or 'pytorch'.", ) + train_parser.add_argument( + "--compute_dc_ac_metrics", + action="store_true", + ) train_parser.add_argument( "--report-performance", dest="report_performance", action="store_true", help="Print the last training epoch time and a single test metric to stdout.", ) + + # ---- FINETUNE SUBCOMMAND ---- finetune_parser = subparsers.add_parser("finetune", help="Run fine-tuning") finetune_parser.add_argument("--config", type=str, required=True) finetune_parser.add_argument("--model_path", type=str, required=True) @@ -123,6 +186,10 @@ def main(): choices=["simple", "advanced", "pytorch"], help="Enable Lightning profiler: 'simple', 'advanced', or 'pytorch'.", ) + finetune_parser.add_argument( + "--compute_dc_ac_metrics", + action="store_true", + ) finetune_parser.add_argument( "--report-performance", dest="report_performance", diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index ed2e9b34..0ffd1364 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -15,7 +15,8 @@ import pandas as pd from gridfm_graphkit.io.param_handler import get_task -from gridfm_graphkit.tasks.compute_ac_dc_metrics import compute_ac_dc_metrics +from gridfm_graphkit.tasks.opf_ac_dc_baseline import compute_opf_ac_dc_metrics +from gridfm_graphkit.tasks.pf_ac_dc_baseline import compute_pf_ac_dc_metrics from lightning.pytorch.callbacks.early_stopping import EarlyStopping from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint from lightning.pytorch.loggers import MLFlowLogger @@ -23,6 +24,17 @@ import lightning as L +def _normalize_loaded_state_dict_keys(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Map legacy torch.compile checkpoint keys to the canonical model namespace.""" + has_compiled_prefix = any(key.startswith("model._orig_mod.") for key in state_dict) + if not has_compiled_prefix: + return state_dict + return { + key.replace("model._orig_mod.", "model."): value + for key, value in state_dict.items() + } + + def _load_plugins(plugins: list[str]) -> None: """Import plugin packages so their registry decorators fire.""" for plugin_pkg in plugins: @@ -35,6 +47,15 @@ def _load_plugins(plugins: list[str]) -> None: ) from e +def _predictions_to_dataframe(predictions: list[dict[str, np.ndarray]]) -> pd.DataFrame: + """Convert a list of prediction batch dicts into one concatenated DataFrame.""" + rows = {key: [] for key in predictions[0].keys()} + for batch in predictions: + for key in rows: + rows[key].append(batch[key]) + return pd.DataFrame({key: np.concatenate(vals) for key, vals in rows.items()}) + + def _validate_dataset_wrapper(name: str | None) -> None: """Raise a helpful error if *name* is not registered in DATASET_WRAPPER_REGISTRY.""" if name is None: @@ -104,6 +125,7 @@ def benchmark_cli(args): def get_training_callbacks(args): + """Build the standard callback stack used for train/finetune runs.""" early_stop_callback = EarlyStopping( monitor="Validation loss", min_delta=args.callbacks.tol, @@ -129,6 +151,7 @@ def get_training_callbacks(args): def main_cli(args): + """Run a GridFM CLI command using config-driven datamodule and trainer setup.""" if getattr(args, "tf32", False): torch.set_float32_matmul_precision("high") # enables TF32 on Ampere+ GPUs @@ -168,6 +191,7 @@ def main_cli(args): if args.command != "train": print(f"Loading model weights from {args.model_path}") state_dict = torch.load(args.model_path, map_location="cpu") + state_dict = _normalize_loaded_state_dict_keys(state_dict) model.load_state_dict(state_dict) precision = "bf16-true" if getattr(args, "bfloat16", False) else None @@ -279,12 +303,20 @@ def main_cli(args): ) compute_dc_ac = getattr(args, "compute_dc_ac_metrics", False) + task_type = {"optimalpowerflow": "opf", "powerflow": "pf"}.get( + str(getattr(getattr(config_args, "task", None), "task_name", "")).lower(), + ) if is_rank0 and compute_dc_ac: sn_mva = config_args.data.baseMVA for grid_name in config_args.data.networks: raw_dir = os.path.join(args.data_path, grid_name, "raw") print(f"\nComputing ground-truth AC/DC metrics for {grid_name}...") - compute_ac_dc_metrics(artifacts_dir, raw_dir, grid_name, sn_mva) + if task_type == "opf": + compute_opf_ac_dc_metrics(artifacts_dir, raw_dir, grid_name, sn_mva) + elif task_type == "pf": + compute_pf_ac_dc_metrics(artifacts_dir, raw_dir, grid_name, sn_mva) + else: + raise ValueError(f"Invalid task: {task_type}") save_output = getattr(args, "save_output", False) or args.command == "predict" if is_rank0 and save_output: @@ -305,19 +337,27 @@ def main_cli(args): ) predictions = predict_trainer.predict(model=model, datamodule=litGrid) - rows = {key: [] for key in predictions[0].keys()} - for batch in predictions: - for key in rows: - rows[key].append(batch[key]) - - df = pd.DataFrame({key: np.concatenate(vals) for key, vals in rows.items()}) - grid_name = config_args.data.networks[0] if args.command == "predict": output_dir = args.output_path else: output_dir = os.path.join(artifacts_dir, "test") os.makedirs(output_dir, exist_ok=True) - out_path = os.path.join(output_dir, f"{grid_name}_predictions.parquet") - df.to_parquet(out_path, index=False) - print(f"Saved predictions to {out_path}") + first_prediction = predictions[0] + if any(isinstance(value, dict) for value in first_prediction.values()): + for table_name in first_prediction: + df = _predictions_to_dataframe( + [batch[table_name] for batch in predictions], + ) + suffix = "" if table_name == "bus" else f"_{table_name}" + out_path = os.path.join( + output_dir, + f"{grid_name}{suffix}_predictions.parquet", + ) + df.to_parquet(out_path, index=False) + print(f"Saved {table_name} predictions to {out_path}") + else: + df = _predictions_to_dataframe(predictions) + out_path = os.path.join(output_dir, f"{grid_name}_predictions.parquet") + df.to_parquet(out_path, index=False) + print(f"Saved predictions to {out_path}") diff --git a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py index 0efb1731..e5374970 100644 --- a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py @@ -14,12 +14,14 @@ from gridfm_graphkit.datasets.utils import ( split_dataset, split_dataset_by_load_scenario_idx, + split_from_existing_files, ) from gridfm_graphkit.datasets.powergrid_hetero_dataset import HeteroGridDatasetDisk import numpy as np import random import warnings import lightning as L +from pathlib import Path from typing import List from lightning.pytorch.loggers import MLFlowLogger @@ -101,6 +103,11 @@ def __init__( "split_by_load_scenario_idx", False, ) + self.split_from_existing_files = getattr( + args.data, + "split_from_existing_files", + None, + ) self.args = args self.normalizer_stats_path = normalizer_stats_path self.data_normalizers = [] @@ -113,6 +120,15 @@ def __init__( self.test_scenario_ids: List[List[int]] = [] self._is_setup_done = False + if self.split_by_load_scenario_idx: + assert self.split_from_existing_files is None, " either `split_by_load_scenario_idx` or `split_from_existing_files` may be used, not both" + + if self.split_from_existing_files is not None: + assert isinstance(self.split_from_existing_files, str), "`split_from_existing_files` must be an existing folder in string format" + self.split_from_existing_files = Path(self.split_from_existing_files) + assert self.split_from_existing_files.is_dir(), "`split_from_existing_files` must be an existing folder in string format" + + def setup(self, stage: str): if self._is_setup_done: print(f"Setup already done for stage={stage}, skipping...") @@ -167,53 +183,93 @@ def setup(self, stage: str): # Create a subset all_indices = list(range(len(dataset))) - # Random seed set before every shuffle for reproducibility in case the power grid datasets are analyzed in a different order - random.seed(self.args.seed) - random.shuffle(all_indices) - subset_indices = all_indices[:num_scenarios] - # load_scenario for each scenario in the subset - load_scenarios = dataset.load_scenarios[subset_indices] - dataset = Subset(dataset, subset_indices) + if self.split_from_existing_files is not None: + warnings.warn( + "`data.scenarios` is ignored when `split_from_existing_files` is set; " + "train/val/test scenario ids are loaded from the provided split files.", + ) - if self.dataset_wrapper is not None: - wrapper_cls = DATASET_WRAPPER_REGISTRY.get(self.dataset_wrapper) - dataset = wrapper_cls(dataset, cache_dir=self.dataset_wrapper_cache_dir) + if self.dataset_wrapper is not None: + wrapper_cls = DATASET_WRAPPER_REGISTRY.get(self.dataset_wrapper) + dataset = wrapper_cls( + dataset, + cache_dir=self.dataset_wrapper_cache_dir, + ) - # Random seed set before every split, same as above - np.random.seed(self.args.seed) - if self.split_by_load_scenario_idx: - train_dataset, val_dataset, test_dataset = ( - split_dataset_by_load_scenario_idx( + (train_dataset, val_dataset, test_dataset), subset_indices = ( + split_from_existing_files( + dataset, + self.split_from_existing_files, + ) + ) + train_scenario_ids = subset_indices["train"] + val_scenario_ids = subset_indices["val"] + test_scenario_ids = subset_indices["test"] + num_scenarios = int( + np.unique( + train_scenario_ids + val_scenario_ids + test_scenario_ids, + ).shape[0], + ) + else: + # Random seed set before every shuffle for reproducibility in case the power grid datasets are analyzed in a different order + random.seed(self.args.seed) + random.shuffle(all_indices) + subset_indices = all_indices[:num_scenarios] + + load_scenarios = None + if self.split_by_load_scenario_idx: + if not hasattr(dataset, "load_scenarios"): + raise ValueError( + "`data.split_by_load_scenario_idx=true` requires " + "`load_scenario_idx` in raw bus data so " + "`processed/load_scenarios.pt` can be created.", + ) + # load_scenario for each scenario in the subset + load_scenarios = dataset.load_scenarios[subset_indices] + + + dataset = Subset(dataset, subset_indices) + + if self.dataset_wrapper is not None: + wrapper_cls = DATASET_WRAPPER_REGISTRY.get(self.dataset_wrapper) + dataset = wrapper_cls(dataset, cache_dir=self.dataset_wrapper_cache_dir) + + + # Random seed set before every split, same as above + np.random.seed(self.args.seed) + if self.split_by_load_scenario_idx: + train_dataset, val_dataset, test_dataset = ( + split_dataset_by_load_scenario_idx( + dataset, + self.data_dir, + load_scenarios, + self.args.data.val_ratio, + self.args.data.test_ratio, + ) + ) + else: + train_dataset, val_dataset, test_dataset = split_dataset( dataset, self.data_dir, - load_scenarios, self.args.data.val_ratio, self.args.data.test_ratio, ) + + # Extract scenario IDs for each split + train_scenario_ids = self._extract_scenario_ids( + train_dataset, + subset_indices, ) - else: - train_dataset, val_dataset, test_dataset = split_dataset( - dataset, - self.data_dir, - self.args.data.val_ratio, - self.args.data.test_ratio, + val_scenario_ids = self._extract_scenario_ids( + val_dataset, + subset_indices, + ) + test_scenario_ids = self._extract_scenario_ids( + test_dataset, + subset_indices, ) - - # Extract scenario IDs for each split - train_scenario_ids = self._extract_scenario_ids( - train_dataset, - subset_indices, - ) - val_scenario_ids = self._extract_scenario_ids( - val_dataset, - subset_indices, - ) - test_scenario_ids = self._extract_scenario_ids( - test_dataset, - subset_indices, - ) # Fit normalizer: restore from saved stats only for fit_on_train # normalizers (global baseMVA must match the model's training run). @@ -386,6 +442,7 @@ def _dataloader_kwargs(self): return kwargs def train_dataloader(self): + print("creating train dataloader for rank ", dist.get_rank() if dist.is_available() and dist.is_initialized() else "not distributed") return DataLoader( self.train_dataset_multi, batch_size=self.batch_size, diff --git a/gridfm_graphkit/datasets/masking.py b/gridfm_graphkit/datasets/masking.py index b615e0cb..df2f657c 100644 --- a/gridfm_graphkit/datasets/masking.py +++ b/gridfm_graphkit/datasets/masking.py @@ -157,6 +157,7 @@ def forward(self, data): class BusToGenBroadcaster(MessagePassing): + """Broadcast per-bus values to connected generators via graph propagation.""" def __init__(self, aggr="add"): super().__init__(aggr=aggr) @@ -174,6 +175,7 @@ def message(self, x_j): class SimulateMeasurements(BaseTransform): + """Add configurable noise/outliers and masks to simulate measured quantities.""" def __init__(self, args): super().__init__() self.measurements = args.task.measurements diff --git a/gridfm_graphkit/datasets/normalizers.py b/gridfm_graphkit/datasets/normalizers.py index 11601a66..eb5652d7 100644 --- a/gridfm_graphkit/datasets/normalizers.py +++ b/gridfm_graphkit/datasets/normalizers.py @@ -228,8 +228,8 @@ def transform(self, data: HeteroData): data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MIN] *= torch.pi / 180.0 data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MAX] *= torch.pi / 180.0 data.edge_attr_dict[("bus", "connects", "bus")][:, RATE_A] /= self.baseMVA - data.baseMVA = self.baseMVA - data.is_normalized = True + data.baseMVA = torch.tensor(self.baseMVA, dtype=data.x_dict["bus"].dtype) # # needs to be float32 for MPS + data.is_normalized = torch.tensor(True, dtype=torch.bool) # needs to be bool for MPS def inverse_transform(self, data: HeteroData): if self.baseMVA is None or self.baseMVA == 0: @@ -299,7 +299,7 @@ def inverse_transform(self, data: HeteroData): data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MAX] *= 180.0 / torch.pi data.edge_attr_dict[("bus", "connects", "bus")][:, RATE_A] *= self.baseMVA - data.is_normalized = False + data.is_normalized = torch.tensor(False, dtype=torch.bool) # needs to be bool for MPS def inverse_output(self, output, batch): bus_output = output["bus"] @@ -510,10 +510,10 @@ def transform(self, data: HeteroData): data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MIN] *= torch.pi / 180.0 data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MAX] *= torch.pi / 180.0 data.edge_attr_dict[("bus", "connects", "bus")][:, RATE_A] /= e_b - data.is_normalized = True + data.is_normalized = torch.tensor(True, dtype=torch.bool) # needs to be bool for MPS def inverse_transform(self, data: HeteroData): - """Undo per-unit normalization (multiply by baseMVA, rad->deg, inverse log1p for cost coeffs).""" + """Undo per-unit normalization (multiply by baseMVA, inverse log1p for cost coeffs).""" if self._baseMVA_lookup is None: raise ValueError("Normalizer not fitted or lookups not loaded") if not data.is_normalized.all(): @@ -573,7 +573,7 @@ def inverse_transform(self, data: HeteroData): data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MAX] *= 180.0 / torch.pi data.edge_attr_dict[("bus", "connects", "bus")][:, RATE_A] *= e_b - data.is_normalized = False + data.is_normalized = torch.tensor(False, dtype=torch.bool) # needs to be bool for MPS def inverse_output(self, output, batch): """ diff --git a/gridfm_graphkit/datasets/powergrid_hetero_dataset.py b/gridfm_graphkit/datasets/powergrid_hetero_dataset.py index 32d3e5ee..82f57a57 100644 --- a/gridfm_graphkit/datasets/powergrid_hetero_dataset.py +++ b/gridfm_graphkit/datasets/powergrid_hetero_dataset.py @@ -55,7 +55,6 @@ def processed_done_file(self): @property def processed_file_names(self): return [ - "load_scenarios.pt", self.processed_done_file, ] @@ -72,11 +71,11 @@ def process(self): bus_data["scenario"].min() == 0 and bus_data["scenario"].max() == len(bus_data["scenario"].unique()) - 1 ) - - load_scenarios = torch.tensor( - bus_data.groupby("scenario", sort=True)["load_scenario_idx"].first().values, - ) - torch.save(load_scenarios, osp.join(self.processed_dir, "load_scenarios.pt")) + if "load_scenario_idx" in bus_data.columns: + load_scenarios = torch.tensor( + bus_data.groupby("scenario", sort=True)["load_scenario_idx"].first().values, + ) + torch.save(load_scenarios, osp.join(self.processed_dir, "load_scenarios.pt")) agg_gen = ( gen_data.groupby(["scenario", "bus"])[["min_q_mvar", "max_q_mvar"]] @@ -137,7 +136,8 @@ def process(self): ] + common_branch_features # Group by scenario - bus_groups = bus_data.groupby("scenario") + bus_groups = bus_data.groupby("scenario") # Groupby preserves the order of rows within each group. + # https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.groupby.html gen_groups = gen_data.groupby("scenario") branch_groups = branch_data.groupby("scenario") @@ -158,12 +158,17 @@ def process(self): # Bus nodes bus_df = bus_groups.get_group(scenario) + # assert that the buses are in increasing order + assert (bus_df["bus"].values == torch.arange(len(bus_df))).all(), "Buses are not in increasing order" + #todo: we should remove this assert and store the bus idx in the tensors + # right now we need the increasing order for e.g. the predict step that uses torch.arange(n_nodes) to index the buses. data["bus"].x = torch.tensor(bus_df[bus_features].values, dtype=torch.float) # Generator nodes gen_df = gen_groups.get_group(scenario).reset_index() data["gen"].x = torch.tensor(gen_df[gen_features].values, dtype=torch.float) gen_df["gen_index"] = gen_df.index # Use actual index as generator ID + # todo: change this to instead use the generator id as the index data["bus"].y = data["bus"].x[:, : (VA_H + 1)].clone() data["gen"].y = data["gen"].x[:, : (PG_H + 1)].clone() diff --git a/gridfm_graphkit/datasets/task_transforms.py b/gridfm_graphkit/datasets/task_transforms.py index eaaca66c..dffb66cb 100644 --- a/gridfm_graphkit/datasets/task_transforms.py +++ b/gridfm_graphkit/datasets/task_transforms.py @@ -15,6 +15,7 @@ @TRANSFORM_REGISTRY.register("PowerFlow") class PowerFlowTransforms(Compose): + """Compose preprocessing and masking transforms for PowerFlow datasets.""" def __init__(self, args): transforms = [] @@ -29,6 +30,7 @@ def __init__(self, args): @TRANSFORM_REGISTRY.register("OptimalPowerFlow") class OptimalPowerFlowTransforms(Compose): + """Compose preprocessing and masking transforms for OptimalPowerFlow datasets.""" def __init__(self, args): transforms = [] @@ -43,6 +45,7 @@ def __init__(self, args): @TRANSFORM_REGISTRY.register("StateEstimation") class StateEstimationTransforms(Compose): + """Compose preprocessing and measurement transforms for StateEstimation datasets.""" def __init__(self, args): transforms = [] diff --git a/gridfm_graphkit/datasets/transforms.py b/gridfm_graphkit/datasets/transforms.py index 2f730337..c6891dc2 100644 --- a/gridfm_graphkit/datasets/transforms.py +++ b/gridfm_graphkit/datasets/transforms.py @@ -96,6 +96,7 @@ def forward(self, data): class LoadGridParamsFromPath(BaseTransform): + """Inject static grid parameters from a saved grid template into each sample.""" def __init__(self, args): super().__init__() self.grid_path = args.task.grid_path diff --git a/gridfm_graphkit/datasets/utils.py b/gridfm_graphkit/datasets/utils.py index f330d496..65b34f4e 100644 --- a/gridfm_graphkit/datasets/utils.py +++ b/gridfm_graphkit/datasets/utils.py @@ -3,6 +3,7 @@ from typing import Tuple from torch import Tensor import torch +from pathlib import Path def split_dataset( @@ -58,6 +59,7 @@ def split_dataset_by_load_scenario_idx( val_ratio: float = 0.1, test_ratio: float = 0.1, ) -> Tuple[Subset, Subset, Subset]: + """Split dataset by unique load-scenario IDs to avoid scenario leakage.""" if val_ratio + test_ratio >= 1: raise ValueError("The sum of val_ratio and test_ratio must be less than 1.") @@ -90,3 +92,30 @@ def split_dataset_by_load_scenario_idx( test_dataset = Subset(dataset, test_indices) return train_dataset, val_dataset, test_dataset + + +def split_from_existing_files( + dataset, + splits_folder: Path, +) -> Tuple[Subset, Subset, Subset]: + """Build train/val/test subsets from split index files. + + Expects `train.pt`, `val.pt`, and `test.pt` inside `splits_folder`. + Returns both the dataset subsets and the raw scenario ids per split. + """ + output = [] + + indices = {} + + for split in ["train", "val", "test"]: + split_file = splits_folder / f"{split}.pt" + assert split_file.is_file(), f"{str(split_file)} does not exist" + split_indices = torch.load(str(split_file), weights_only=True) + split_dataset = Subset(dataset, split_indices) + output.append(split_dataset) + split_indices = list(split_indices) + print(f'{split=} {len(split_indices)=}') + indices[split]=[int(t.item()) for t in split_indices] + + output = tuple(output) + return output, indices \ No newline at end of file diff --git a/gridfm_graphkit/io/registries.py b/gridfm_graphkit/io/registries.py index 32feb20a..65d596a9 100644 --- a/gridfm_graphkit/io/registries.py +++ b/gridfm_graphkit/io/registries.py @@ -1,4 +1,5 @@ class Registry: + """Simple name-to-object registry with decorator-based registration.""" def __init__(self, name: str): self._name = name self._registry = {} diff --git a/gridfm_graphkit/models/utils.py b/gridfm_graphkit/models/utils.py index ea4ecafe..bc4b9bfa 100644 --- a/gridfm_graphkit/models/utils.py +++ b/gridfm_graphkit/models/utils.py @@ -73,6 +73,7 @@ def forward(self, Pft, Qft, edge_index, num_bus): def compute_shunt_power(bus_data_pred, bus_data_orig): + """Compute active/reactive shunt power contributions per bus.""" p_shunt = -bus_data_orig[:, GS] * bus_data_pred[:, VM_OUT] ** 2 q_shunt = bus_data_orig[:, BS] * bus_data_pred[:, VM_OUT] ** 2 return p_shunt, q_shunt @@ -80,6 +81,7 @@ def compute_shunt_power(bus_data_pred, bus_data_orig): @PHYSICS_DECODER_REGISTRY.register("OptimalPowerFlow") class PhysicsDecoderOPF(nn.Module): + """Map network outputs to OPF-consistent bus states using physics constraints.""" def forward(self, P_in, Q_in, bus_data_pred, bus_data_orig, agg_bus, mask_dict): mask_pv = mask_dict["PV"] mask_ref = mask_dict["REF"] @@ -114,6 +116,7 @@ def forward(self, P_in, Q_in, bus_data_pred, bus_data_orig, agg_bus, mask_dict): @PHYSICS_DECODER_REGISTRY.register("PowerFlow") class PhysicsDecoderPF(nn.Module): + """Map network outputs to PF-consistent bus states using physics constraints.""" def forward(self, P_in, Q_in, bus_data_pred, bus_data_orig, agg_bus, mask_dict): """ PF decoder: @@ -161,6 +164,7 @@ def forward(self, P_in, Q_in, bus_data_pred, bus_data_orig, agg_bus, mask_dict): @PHYSICS_DECODER_REGISTRY.register("StateEstimation") class PhysicsDecoderSE(nn.Module): + """Map network outputs to SE targets via bus power-balance relations.""" def forward(self, P_in, Q_in, bus_data_pred, bus_data_orig, agg_bus, mask_dict): p_shunt, q_shunt = compute_shunt_power(bus_data_pred, bus_data_orig) Vm_out = bus_data_pred[:, VM_OUT] @@ -184,4 +188,5 @@ def forward(self, P_in, Q_in, bus_data_pred, bus_data_orig): def bound_with_sigmoid(pred, low, high): + """Squash unconstrained predictions into [low, high] with a sigmoid map.""" return low + (high - low) * torch.sigmoid(pred) diff --git a/gridfm_graphkit/tasks/base_task.py b/gridfm_graphkit/tasks/base_task.py index 90c8f7b5..fc2b95e3 100644 --- a/gridfm_graphkit/tasks/base_task.py +++ b/gridfm_graphkit/tasks/base_task.py @@ -20,6 +20,20 @@ def __init__(self, args, data_normalizers): self.data_normalizers = data_normalizers self.save_hyperparameters() + def transfer_batch_to_device(self, batch, device, dataloader_idx): + """Pre-cast float64 tensors before moving batches onto MPS. + + PyTorch MPS does not support float64 tensors. Some PyG metadata fields can + get collated as float64 even when model inputs are float32, so coerce them + first and then delegate to Lightning's standard device transfer. + """ + if getattr(device, "type", None) == "mps" and hasattr(batch, "stores"): + for store in batch.stores: + for key, val in store.items(): + if isinstance(val, torch.Tensor) and val.dtype == torch.float64: + store[key] = val.to(torch.float32) + return super().transfer_batch_to_device(batch, device, dataloader_idx) + def on_after_batch_transfer(self, batch, dataloader_idx: int): """Cast float tensors in HeteroData batches to the model's parameter dtype. diff --git a/gridfm_graphkit/tasks/opf_ac_dc_baseline.py b/gridfm_graphkit/tasks/opf_ac_dc_baseline.py new file mode 100644 index 00000000..8f08917c --- /dev/null +++ b/gridfm_graphkit/tasks/opf_ac_dc_baseline.py @@ -0,0 +1,309 @@ +"""Compute AC/DC OPF baseline metrics on test splits. + +Uses the same AC/DC power-balance and residual aggregation as +:mod:`gridfm_graphkit.tasks.pf_ac_dc_baseline` (via shared helpers). + +Adds OPF-style inequality metrics. Compared to +:mod:`gridfm_graphkit.tasks.opf_task` ``test_step``: + +- **Residuals / runtime**: Same formulas as the PF baseline (ground-truth + ``compute_bus_balance`` on parquet), not the neural ``ComputeNodeResiduals`` + in the task. +- **Optimality**: The task logs ``Opt gap`` = mean(|cost_pred − cost_gt| / cost_gt) + per scenario (model vs label). Here **DC Mean optimality gap (%)** is the mean + over scenarios of |cost_dc − cost_ac| / cost_ac × 100 with scenario totals from + ``p_mw`` vs ``p_mw_dc`` (DC solution vs AC reference) +- **Branch thermal / angle / Qg**: Same relu-style violations and flat means as + the task; **Pg bound** violations are baseline-only (not logged in ``opf_task``). +""" + +import json +import os + +import numpy as np +import pandas as pd +from gridfm_datakit.utils.power_balance import ( + compute_branch_powers_vectorized, + compute_bus_balance, +) +from gridfm_graphkit.tasks.pf_ac_dc_baseline import ( + N_SCENARIO_PER_PARTITION, + NUM_PROCESSES, + _compute_residual_stats, + _compute_runtime_stats, +) + + +def _load_test_data(data_dir: str, test_scenario_ids: list[int]): + """Load OPF test-split bus/gen/branch/runtime tables from partitioned parquet.""" + partitions = sorted(set(s // N_SCENARIO_PER_PARTITION for s in test_scenario_ids)) + test_set = set(test_scenario_ids) + partition_filter = [("scenario_partition", "in", partitions)] + + bus_df = pd.read_parquet( + os.path.join(data_dir, "bus_data.parquet"), + filters=partition_filter, + ) + gen_df = pd.read_parquet( + os.path.join(data_dir, "gen_data.parquet"), + filters=partition_filter, + ) + branch_df = pd.read_parquet( + os.path.join(data_dir, "branch_data.parquet"), + filters=partition_filter, + ) + branch_df = branch_df.drop(columns=["pf_dc", "pt_dc"], axis=1) + runtime_df = pd.read_parquet( + os.path.join(data_dir, "runtime_data.parquet"), + filters=partition_filter, + ) + + bus_df = bus_df[bus_df["scenario"].isin(test_set)].reset_index(drop=True) + gen_df = gen_df[gen_df["scenario"].isin(test_set)].reset_index(drop=True) + branch_df = branch_df[branch_df["scenario"].isin(test_set)].reset_index(drop=True) + runtime_df = runtime_df[runtime_df["scenario"].isin(test_set)].reset_index(drop=True) + + print( + f" Loaded {len(bus_df)} bus rows, {len(gen_df)} gen rows, " + f"{len(branch_df)} branch rows, {len(runtime_df)} runtime rows " + f"for {len(test_set)} test scenarios", + ) + return bus_df, gen_df, branch_df, runtime_df + + +def _compute_optimality_gap(gen_df: pd.DataFrame) -> dict: + """Compute mean AC/DC scenario-level optimality gap from generator costs.""" + # Same aggregation as opf_task scatter_add + mean over graphs, but compares + # scenario DC cost vs AC cost (not model pred vs GT). + c0 = gen_df["cp0_eur"].to_numpy(dtype=float) + c1 = gen_df["cp1_eur_per_mw"].to_numpy(dtype=float) + c2 = gen_df["cp2_eur_per_mw2"].to_numpy(dtype=float) + pg_ac = gen_df["p_mw"].to_numpy(dtype=float) + pg_dc = gen_df["p_mw_dc"].to_numpy(dtype=float) + g = gen_df.copy() + g["cost_ac"] = c0 + c1 * pg_ac + c2 * pg_ac * pg_ac # all is already in MW + g["cost_dc"] = c0 + c1 * pg_dc + c2 * pg_dc * pg_dc # all is already in MW + per_scenario = g.groupby("scenario")[["cost_ac", "cost_dc"]].sum() + cost_ac = per_scenario["cost_ac"].to_numpy(dtype=float) + cost_dc = per_scenario["cost_dc"].to_numpy(dtype=float) + gap_pct = np.abs((cost_dc - cost_ac) / cost_ac * 100.0) + return { + "AC Mean optimality gap (%)": 0.0, + "DC Mean optimality gap (%)": float(np.nanmean(gap_pct)), + } + + +def _compute_pg_violations(gen_df: pd.DataFrame) -> dict: + """Compute mean AC/DC generator active-power bound violations.""" + min_p = gen_df["min_p_mw"].to_numpy(dtype=float) + max_p = gen_df["max_p_mw"].to_numpy(dtype=float) + pg_ac = gen_df["p_mw"].to_numpy(dtype=float) + pg_dc = gen_df["p_mw_dc"].to_numpy(dtype=float) + viol_ac = np.maximum(pg_ac - max_p, 0.0) + np.maximum(min_p - pg_ac, 0.0) + viol_dc = np.maximum(pg_dc - max_p, 0.0) + np.maximum(min_p - pg_dc, 0.0) + return { + "AC Mean Pg bound violation (MW)": float(np.nanmean(viol_ac)), + "DC Mean Pg bound violation (MW)": float(np.nanmean(viol_dc)), + } + + +def _compute_qg_violations_ac(bus_df: pd.DataFrame, gen_df: pd.DataFrame) -> dict: + """Compute AC reactive-power limit violations for PV/REF buses.""" + # opf_task style on bus Qg; AC only + bus = bus_df.copy() + qg = bus["Qg"].to_numpy(dtype=float) + # complain if max_q_mvar == min_q_mvar for some gens of gen_df + assert (gen_df["max_q_mvar"] == gen_df["min_q_mvar"]).any() == False, "max_q_mvar == min_q_mvar for some gens of gen_df" + agg_gen = ( + gen_df.groupby(["scenario", "bus"])[["min_q_mvar", "max_q_mvar"]] + .sum() + .reset_index()) + bus = bus.merge(agg_gen, on=["scenario", "bus"], how="left") + assert bus[bus["PV"]==1]["min_q_mvar"].isna().sum() == 0, "PV buses have no min_q_mvar" + assert bus[bus["PV"]==1]["max_q_mvar"].isna().sum() == 0, "PV buses have no max_q_mvar" + assert bus[bus["REF"]==1]["min_q_mvar"].isna().sum() == 0, "REF buses have no min_q_mvar" + assert bus[bus["REF"]==1]["max_q_mvar"].isna().sum() == 0, "REF buses have no max_q_mvar" + bus["qg_violation_amount"] = np.maximum(qg - bus["max_q_mvar"], 0.0) + np.maximum(bus["min_q_mvar"] - qg, 0.0) + pv = bus[bus["PV"] == 1] + ref = bus[bus["REF"] == 1] + pv_ref = bus[(bus["PV"] == 1) | (bus["REF"] == 1)] + return { + "AC Mean Qg violation PV buses": float(np.nanmean(pv["qg_violation_amount"].to_numpy(dtype=float))), + "AC Mean Qg violation REF buses": float(np.nanmean(ref["qg_violation_amount"].to_numpy(dtype=float))), + "AC Mean Qg violation": float(np.nanmean(pv_ref["qg_violation_amount"].to_numpy(dtype=float))), + } + + +def _compute_branch_violations(branch_df: pd.DataFrame, bus_df: pd.DataFrame) -> dict: + """Compute AC/DC branch thermal and angle-limit violation statistics.""" + rate = branch_df["rate_a"].to_numpy(dtype=float) + ac_from = np.sqrt( + branch_df["pf"].to_numpy(dtype=float) ** 2 + branch_df["qf"].to_numpy(dtype=float) ** 2, + ) + ac_to = np.sqrt( + branch_df["pt"].to_numpy(dtype=float) ** 2 + branch_df["qt"].to_numpy(dtype=float) ** 2, + ) + dc_from = np.abs(branch_df["pf_dc_computed"].to_numpy(dtype=float)) + dc_to = np.abs(branch_df["pt_dc_computed"].to_numpy(dtype=float)) + + ac_thermal_from = np.maximum(ac_from - rate, 0.0) + ac_thermal_to = np.maximum(ac_to - rate, 0.0) + dc_thermal_from = np.maximum(dc_from - rate, 0.0) + dc_thermal_to = np.maximum(dc_to - rate, 0.0) + + bus_angles = bus_df[["scenario", "bus", "Va", "Va_dc"]] + # convert to radians + bus_angles.loc[:, "Va"] = bus_angles["Va"] * np.pi / 180.0 + bus_angles.loc[:, "Va_dc"] = bus_angles["Va_dc"] * np.pi / 180.0 + from_angles = bus_angles.rename( + columns={"bus": "from_bus", "Va": "Va_from", "Va_dc": "Va_dc_from"}, + ) + to_angles = bus_angles.rename( + columns={"bus": "to_bus", "Va": "Va_to", "Va_dc": "Va_dc_to"}, + ) + br = branch_df.merge(from_angles, on=["scenario", "from_bus"], how="left") + br = br.merge(to_angles, on=["scenario", "to_bus"], how="left") + + # AC angle + ac_angle_diff = br["Va_from"] - br["Va_to"] + ac_angle_diff = (ac_angle_diff + np.pi) % (2 * np.pi) - np.pi # wrap to [-pi, pi] + ac_angle_excess_low = np.maximum(br["ang_min"] - ac_angle_diff, 0.0) + ac_angle_excess_high = np.maximum(ac_angle_diff - br["ang_max"], 0.0) + mean_ac_angle_violation = np.mean(ac_angle_excess_low + ac_angle_excess_high) + # DC angle + dc_angle_diff = br["Va_dc_from"] - br["Va_dc_to"] + dc_angle_diff = (dc_angle_diff + np.pi) % (2 * np.pi) - np.pi + dc_angle_excess_low = np.maximum(br["ang_min"] - dc_angle_diff, 0.0) + dc_angle_excess_high = np.maximum(dc_angle_diff - br["ang_max"], 0.0) + mean_dc_angle_violation = np.mean(dc_angle_excess_low + dc_angle_excess_high) + + return { + "AC Mean branch thermal violation from (MVA)": float(np.nanmean(ac_thermal_from)), + "AC Mean branch thermal violation to (MVA)": float(np.nanmean(ac_thermal_to)), + "AC Mean branch angle difference violation (radians)": float(mean_ac_angle_violation), + "DC Mean branch thermal violation from (MVA)": float(np.nanmean(dc_thermal_from)), + "DC Mean branch thermal violation to (MVA)": float(np.nanmean(dc_thermal_to)), + "DC Mean branch angle difference violation (radians)": float(mean_dc_angle_violation), + } + + +def compute_opf_ac_dc_metrics( + artifacts_dir: str, + data_dir: str, + grid_name: str, + sn_mva: float, +) -> bool: + """Compute AC/DC OPF baseline metrics (PF metrics + OPF inequalities), save results. + + Saves: + - Aggregated metrics (CSV) + - AC per-bus residuals (Parquet) + - DC per-bus residuals (Parquet) + + Returns: + True if metrics were computed, False if splits JSON was not found. + """ + + splits_json = os.path.join( + artifacts_dir, + "stats", + f"{grid_name}_scenario_splits.json", + ) + if not os.path.exists(splits_json): + print(f" Skipping: no splits JSON found at {splits_json}") + return False + + with open(splits_json) as f: + test_ids = json.load(f)["test"] + + print(f" Test split: {len(test_ids)} scenarios") + + bus_df, gen_df, branch_df, runtime_df = _load_test_data(data_dir, test_ids) + + print(" Computing AC power balance...") + balance_ac = compute_bus_balance( + bus_df, + branch_df, + branch_df[["pf", "qf", "pt", "qt"]], + dc=False, + sn_mva=sn_mva, + ) + ac_stats = _compute_residual_stats(balance_ac, dc=False) + + print(" Computing DC power balance...") + pf_dc, _, pt_dc, _ = compute_branch_powers_vectorized( + branch_df, + bus_df, + dc=True, + sn_mva=sn_mva, + ) + balance_dc = compute_bus_balance( + bus_df, + branch_df, + pd.DataFrame( + {"pf_dc": pf_dc, "pt_dc": pt_dc}, + index=branch_df.index, + ), + dc=True, + sn_mva=sn_mva, + ) + dc_stats = _compute_residual_stats(balance_dc, dc=True) + + branch_df = branch_df.copy() + branch_df["pf_dc_computed"] = pf_dc + branch_df["pt_dc_computed"] = pt_dc + + + opf_extra = {} + opf_extra.update(_compute_optimality_gap(gen_df)) + opf_extra.update(_compute_branch_violations(branch_df, bus_df)) + opf_extra.update(_compute_pg_violations(gen_df)) + opf_extra.update(_compute_qg_violations_ac(bus_df, gen_df)) + + out_dir = os.path.join(artifacts_dir, "test") + os.makedirs(out_dir, exist_ok=True) + + ac_bus_residuals = ( + balance_ac[["scenario", "bus", "P_mis_ac", "Q_mis_ac"]] + .copy() + .rename( + columns={ + "P_mis_ac": "active res. (MW)", + "Q_mis_ac": "reactive res. (MVar)", + }, + ) + ) + ac_residuals_path = os.path.join(out_dir, f"{grid_name}_ac_bus_residuals.parquet") + ac_bus_residuals.to_parquet(ac_residuals_path, index=False) + print(f" AC per-bus residuals saved to {ac_residuals_path}") + + dc_bus_residuals = ( + balance_dc[["scenario", "bus", "P_mis_dc"]] + .copy() + .rename( + columns={ + "P_mis_dc": "DC active res. (MW)", + }, + ) + ) + dc_residuals_path = os.path.join(out_dir, f"{grid_name}_dc_bus_residuals.parquet") + dc_bus_residuals.to_parquet(dc_residuals_path, index=False) + print(f" DC per-bus residuals saved to {dc_residuals_path}") + + runtime_stats = _compute_runtime_stats(runtime_df) + + rows = [] + for key, val in ac_stats.items(): + rows.append({"Metric": f"AC {key}", "Value": val}) + for key, val in dc_stats.items(): + rows.append({"Metric": f"DC {key}", "Value": val}) + for key, val in opf_extra.items(): + rows.append({"Metric": key, "Value": val}) + for key, val in runtime_stats.items(): + rows.append({"Metric": key, "Value": val}) + + metrics_path = os.path.join(out_dir, f"{grid_name}_opf_ac_dc_metrics.csv") + pd.DataFrame(rows).to_csv(metrics_path, index=False) + print(f" Aggregated OPF AC/DC metrics saved to {metrics_path}") + + return True diff --git a/gridfm_graphkit/tasks/opf_task.py b/gridfm_graphkit/tasks/opf_task.py index 06d938df..7df1344c 100644 --- a/gridfm_graphkit/tasks/opf_task.py +++ b/gridfm_graphkit/tasks/opf_task.py @@ -1,5 +1,7 @@ from gridfm_graphkit.datasets.globals import ( # Bus feature indices + PD_H, + QD_H, QG_H, VM_H, VA_H, @@ -12,6 +14,8 @@ QG_OUT, # Generator feature indices PG_H, + MIN_PG, + MAX_PG, C0_H, C1_H, C2_H, @@ -28,8 +32,8 @@ plot_residuals_histograms, residual_stats_by_type, ) -from pytorch_lightning.utilities import rank_zero_only import torch +import torch.distributed as dist import torch.nn.functional as F from torch_scatter import scatter_add from gridfm_graphkit.models.utils import ( @@ -112,7 +116,7 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): # output["bus"] = target Pft, Qft = branch_flow_layer(output["bus"], bus_edge_index, bus_edge_attr) - # Compute branch termal limits violations + # Compute branch thermal limits violations Sft = torch.sqrt(Pft**2 + Qft**2) # apparent power flow per branch branch_thermal_limits = bus_edge_attr[:, RATE_A] branch_thermal_excess = F.relu(Sft - branch_thermal_limits) @@ -132,13 +136,14 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): bus_angles = output["bus"][:, VA_OUT] # in degrees from_bus = bus_edge_index[0] to_bus = bus_edge_index[1] - angle_diff = torch.abs(bus_angles[from_bus] - bus_angles[to_bus]) + angle_diff = bus_angles[from_bus] - bus_angles[to_bus] # keep sign + angle_diff = (angle_diff + torch.pi) % (2 * torch.pi) - torch.pi # wrap to [-pi, pi] + angle_excess_low = F.relu(angle_min - angle_diff) + angle_excess_high = F.relu(angle_diff - angle_max) - angle_excess_low = F.relu(angle_min - angle_diff) # violation if too small - angle_excess_high = F.relu(angle_diff - angle_max) # violation if too large - branch_angle_violation_mean = ( - torch.mean(angle_excess_low + angle_excess_high) * 180.0 / torch.pi - ) + branch_angle_violation_mean = torch.mean( + angle_excess_low + angle_excess_high + ) # mean of the abs violation P_in, Q_in = node_injection_layer(Pft, Qft, bus_edge_index, num_bus) residual_P, residual_Q = node_residuals_layer( @@ -167,6 +172,8 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): mean_Qg_violation_PV = Qg_violation_amount[mask_PV].mean() mean_Qg_violation_REF = Qg_violation_amount[mask_REF].mean() + mask_PV_REF = mask_PV | mask_REF # PV or REF buses + mean_Qg_violation = Qg_violation_amount[mask_PV_REF].mean() # if self.args.verbose: mean_res_P_PQ, max_res_P_PQ = residual_stats_by_type( @@ -261,8 +268,10 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): loss_dict["Branch voltage angle difference violations"] = ( branch_angle_violation_mean ) - loss_dict["Mean Qg violation PV buses"] = mean_Qg_violation_PV + loss_dict["Mean Qg violation PV buses"] = mean_Qg_violation_PV # mean of the abs violation over the entire batch (all oines in the batch). + # this is then overaged over all the batches and gives same weight to all batches despite them possibly having varying number of branches loss_dict["Mean Qg violation REF buses"] = mean_Qg_violation_REF + loss_dict["Mean Qg violation"] = mean_Qg_violation loss_dict["MSE PQ nodes - PG"] = mse_PQ[PG_OUT] loss_dict["MSE PV nodes - PG"] = mse_PV[PG_OUT] @@ -293,8 +302,25 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): ) return - @rank_zero_only def on_test_end(self): + # In DDP, gather verbose test outputs from all ranks to rank 0 + # so that plots and detailed analysis cover the full test set. + if self.args.verbose and dist.is_available() and dist.is_initialized(): + world_size = dist.get_world_size() + gathered = [None] * world_size if dist.get_rank() == 0 else None + dist.gather_object(self.test_outputs, gathered, dst=0) + if dist.get_rank() == 0: + merged = {i: [] for i in range(len(self.args.data.networks))} + for rank_data in gathered: + for dl_idx, batches in rank_data.items(): + merged[dl_idx].extend(batches) + self.test_outputs = merged + + # Only rank 0 proceeds with logging, CSV writing, and plotting + if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0: + self.test_outputs.clear() + return + if isinstance(self.logger, MLFlowLogger): artifact_dir = os.path.join( self.logger.save_dir, @@ -341,10 +367,10 @@ def on_test_end(self): rmse_gen = metrics.get("MSE PG", 0) ** 0.5 optimality_gap = metrics.get("Opt gap", " ") branch_thermal_violation_from = metrics.get( - "Branch termal violation from", + "Branch thermal violation from", " ", ) - branch_thermal_violation_to = metrics.get("Branch termal violation to", " ") + branch_thermal_violation_to = metrics.get("Branch thermal violation to", " ") branch_angle_violation = metrics.get( "Branch voltage angle difference violations", " ", @@ -354,6 +380,7 @@ def on_test_end(self): "Mean Qg violation REF buses", " ", ) + mean_qg_violation = metrics.get("Mean Qg violation", " ") # --- Main RMSE metrics file --- data_main = { @@ -372,11 +399,12 @@ def on_test_end(self): "Avg. reactive res. (MVar)", "RMSE PG generators (MW)", "Mean optimality gap (%)", - "Mean branch termal violation from (MVA)", - "Mean branch termal violation to (MVA)", + "Mean branch thermal violation from (MVA)", + "Mean branch thermal violation to (MVA)", "Mean branch angle difference violation (radians)", "Mean Qg violation PV buses", "Mean Qg violation REF buses", + "Mean Qg violation", ], "Value": [ avg_active_res, @@ -388,6 +416,7 @@ def on_test_end(self): branch_angle_violation, mean_qg_violation_PV_buses, mean_qg_violation_REF_buses, + mean_qg_violation, ], } df_residuals = pd.DataFrame(data_residuals) @@ -482,4 +511,93 @@ def on_test_end(self): self.test_outputs.clear() def predict_step(self, batch, batch_idx, dataloader_idx=0): - raise NotImplementedError + output, _ = self.shared_step(batch) + + self.data_normalizers[dataloader_idx].inverse_transform(batch) + self.data_normalizers[dataloader_idx].inverse_output(output, batch) + + branch_flow_layer = ComputeBranchFlow() + node_injection_layer = ComputeNodeInjection() + node_residuals_layer = ComputeNodeResiduals() + + num_bus = batch.x_dict["bus"].size(0) + bus_edge_index = batch.edge_index_dict[("bus", "connects", "bus")] + bus_edge_attr = batch.edge_attr_dict[("bus", "connects", "bus")] + + Pft, Qft = branch_flow_layer(output["bus"], bus_edge_index, bus_edge_attr) + P_in, Q_in = node_injection_layer(Pft, Qft, bus_edge_index, num_bus) + residual_P, residual_Q = node_residuals_layer( + P_in, + Q_in, + output["bus"], + batch.x_dict["bus"], + ) + residual_P = torch.abs(residual_P) + residual_Q = torch.abs(residual_Q) + residual_mva = torch.sqrt(residual_P**2 + residual_Q**2) + + bus_batch = batch.batch_dict["bus"] + scenario_ids = batch["scenario_id"][bus_batch] + local_bus_idx = torch.cat( + [ + torch.arange(c, device=bus_batch.device) + for c in torch.bincount(bus_batch) + ], + ) # this works because the order of the buses is preserved by the groupby in the dataset wrapper and datakit data has buses in increasing order. + + bus_x = batch.x_dict["bus"] + bus_y = batch.y_dict["bus"] + mask_PQ = batch.mask_dict["PQ"] + mask_PV = batch.mask_dict["PV"] + mask_REF = batch.mask_dict["REF"] + + _, gen_to_bus_index = batch.edge_index_dict[("gen", "connected_to", "bus")] + agg_gen_on_bus = scatter_add( + batch.y_dict["gen"], + gen_to_bus_index, + dim=0, + dim_size=num_bus, + ) + gen_batch = batch.batch_dict["gen"] + gen_scenario_ids = batch["scenario_id"][gen_batch] + local_gen_idx = torch.cat( + [ + torch.arange(c, device=gen_batch.device) + for c in torch.bincount(gen_batch) + ], + ) + gen_x = batch.x_dict["gen"] + gen_target = batch.y_dict["gen"].reshape(-1) + gen_pred = output["gen"].reshape(-1) + + return { + "bus": { + "scenario": scenario_ids.cpu().numpy(), + "bus": local_bus_idx.cpu().numpy(), + "pd_mw": bus_x[:, PD_H].cpu().numpy(), + "qd_mvar": bus_x[:, QD_H].cpu().numpy(), + "vm_pu_target": bus_y[:, VM_H].cpu().numpy(), + "va_target": bus_y[:, VA_H].cpu().numpy(), + "pg_mw_target": agg_gen_on_bus.squeeze().cpu().numpy(), + "qg_mvar_target": bus_y[:, QG_H].cpu().numpy(), + "is_pq": mask_PQ.cpu().numpy().astype(int), + "is_pv": mask_PV.cpu().numpy().astype(int), + "is_ref": mask_REF.cpu().numpy().astype(int), + "vm_pu": output["bus"][:, VM_OUT].detach().cpu().numpy(), + "va": output["bus"][:, VA_OUT].detach().cpu().numpy(), + "pg_mw": output["bus"][:, PG_OUT].detach().cpu().numpy(), + "qg_mvar": output["bus"][:, QG_OUT].detach().cpu().numpy(), + "active res. (MW)": residual_P.detach().cpu().numpy(), + "reactive res. (MVar)": residual_Q.detach().cpu().numpy(), + "PBE": residual_mva.detach().cpu().numpy(), + }, + "gen": { + "scenario": gen_scenario_ids.cpu().numpy(), + "gen": local_gen_idx.cpu().numpy(), + "connected_bus": local_bus_idx[gen_to_bus_index].cpu().numpy(), + "pg_mw_target": gen_target.cpu().numpy(), + "pg_mw": gen_pred.detach().cpu().numpy(), + "min_pg_mw": gen_x[:, MIN_PG].cpu().numpy(), + "max_pg_mw": gen_x[:, MAX_PG].cpu().numpy(), + }, + } diff --git a/gridfm_graphkit/tasks/compute_ac_dc_metrics.py b/gridfm_graphkit/tasks/pf_ac_dc_baseline.py similarity index 96% rename from gridfm_graphkit/tasks/compute_ac_dc_metrics.py rename to gridfm_graphkit/tasks/pf_ac_dc_baseline.py index 8dcfc8c0..00da4512 100644 --- a/gridfm_graphkit/tasks/compute_ac_dc_metrics.py +++ b/gridfm_graphkit/tasks/pf_ac_dc_baseline.py @@ -14,6 +14,7 @@ def _load_test_data(data_dir: str, test_scenario_ids: list[int]): + """Load PF test-split bus/branch/runtime tables from partitioned parquet.""" partitions = sorted(set(s // N_SCENARIO_PER_PARTITION for s in test_scenario_ids)) test_set = set(test_scenario_ids) partition_filter = [("scenario_partition", "in", partitions)] @@ -45,6 +46,7 @@ def _load_test_data(data_dir: str, test_scenario_ids: list[int]): def _compute_residual_stats(balance_df: pd.DataFrame, dc: bool) -> dict: + """Aggregate AC or DC residual statistics from per-bus balance outputs.""" grouped = balance_df.groupby("scenario") if dc: @@ -75,6 +77,7 @@ def _compute_residual_stats(balance_df: pd.DataFrame, dc: bool) -> dict: def _compute_runtime_stats(runtime_df: pd.DataFrame) -> dict: + """Compute summary statistics for AC/DC runtime columns (milliseconds).""" results = {} for mode in ["ac", "dc"]: if mode not in runtime_df.columns: @@ -99,7 +102,7 @@ def _compute_runtime_stats(runtime_df: pd.DataFrame) -> dict: return results -def compute_ac_dc_metrics( +def compute_pf_ac_dc_metrics( artifacts_dir: str, data_dir: str, grid_name: str, diff --git a/gridfm_graphkit/tasks/pf_task.py b/gridfm_graphkit/tasks/pf_task.py index cdc9d646..2c2478ee 100644 --- a/gridfm_graphkit/tasks/pf_task.py +++ b/gridfm_graphkit/tasks/pf_task.py @@ -241,6 +241,7 @@ def on_test_end(self): # Only rank 0 proceeds with logging, CSV writing, and plotting if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0: + self.test_outputs.clear() # clear the test outputs for other ranks return if isinstance(self.logger, MLFlowLogger): @@ -351,22 +352,22 @@ def on_test_end(self): self.test_outputs.clear() def predict_step(self, batch, batch_idx, dataloader_idx=0): - output, _ = self.shared_step(batch) + output, _ = self.shared_step(batch) # get the predicted output from the model - self.data_normalizers[dataloader_idx].inverse_transform(batch) - self.data_normalizers[dataloader_idx].inverse_output(output, batch) + self.data_normalizers[dataloader_idx].inverse_transform(batch) # normalize the batch data back to the original scale + self.data_normalizers[dataloader_idx].inverse_output(output, batch) # inverse transform the predicted output back to the original scale - branch_flow_layer = ComputeBranchFlow() - node_injection_layer = ComputeNodeInjection() - node_residuals_layer = ComputeNodeResiduals() + branch_flow_layer = ComputeBranchFlow() # layer to compute the branch flows + node_injection_layer = ComputeNodeInjection() # layer to compute the node injections + node_residuals_layer = ComputeNodeResiduals() # layer to compute the node residuals - num_bus = batch.x_dict["bus"].size(0) - bus_edge_index = batch.edge_index_dict[("bus", "connects", "bus")] - bus_edge_attr = batch.edge_attr_dict[("bus", "connects", "bus")] + num_bus = batch.x_dict["bus"].size(0) # number of buses in the batch + bus_edge_index = batch.edge_index_dict[("bus", "connects", "bus")] # from and to buses + bus_edge_attr = batch.edge_attr_dict[("bus", "connects", "bus")] # edge attributes (admittance) of the bus connections - Pft, Qft = branch_flow_layer(output["bus"], bus_edge_index, bus_edge_attr) - P_in, Q_in = node_injection_layer(Pft, Qft, bus_edge_index, num_bus) - residual_P, residual_Q = node_residuals_layer( + Pft, Qft = branch_flow_layer(output["bus"], bus_edge_index, bus_edge_attr) # compute the branch flows + P_in, Q_in = node_injection_layer(Pft, Qft, bus_edge_index, num_bus) # compute the node injections + residual_P, residual_Q = node_residuals_layer( # compute the node residuals P_in, Q_in, output["bus"], @@ -383,8 +384,9 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): torch.arange(c, device=bus_batch.device) for c in torch.bincount(bus_batch) ], - ) - + ) # this is based on the assumptions that the buses within a graph are ordered and indexed as 0 ... n_nodes-1. + # todo: we should remove this assert and store the bus idx in the tensors + # right now we need the increasing order and we have an assert in the dataset to check it. bus_x = batch.x_dict["bus"] bus_y = batch.y_dict["bus"] mask_PQ = batch.mask_dict["PQ"] @@ -402,20 +404,20 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): return { "scenario": scenario_ids.cpu().numpy(), "bus": local_bus_idx.cpu().numpy(), - "pd_mw": bus_x[:, PD_H].cpu().numpy(), - "qd_mvar": bus_x[:, QD_H].cpu().numpy(), - "vm_pu_target": bus_y[:, VM_H].cpu().numpy(), - "va_target": bus_y[:, VA_H].cpu().numpy(), - "pg_mw_target": agg_gen_on_bus.squeeze().cpu().numpy(), - "qg_mvar_target": bus_y[:, QG_H].cpu().numpy(), - "is_pq": mask_PQ.cpu().numpy().astype(int), - "is_pv": mask_PV.cpu().numpy().astype(int), - "is_ref": mask_REF.cpu().numpy().astype(int), - "vm_pu": output["bus"][:, VM_OUT].detach().cpu().numpy(), - "va": output["bus"][:, VA_OUT].detach().cpu().numpy(), - "pg_mw": output["bus"][:, PG_OUT].detach().cpu().numpy(), - "qg_mvar": output["bus"][:, QG_OUT].detach().cpu().numpy(), - "active res. (MW)": residual_P.detach().cpu().numpy(), - "reactive res. (MVar)": residual_Q.detach().cpu().numpy(), - "PBE": residual_mva.detach().cpu().numpy(), + "pd_mw": bus_x[:, PD_H].cpu().numpy(), # from original input + "qd_mvar": bus_x[:, QD_H].cpu().numpy(), # from original input + "vm_pu_target": bus_y[:, VM_H].cpu().numpy(), # from original input + "va_target": bus_y[:, VA_H].cpu().numpy(), # from original input + "pg_mw_target": agg_gen_on_bus.squeeze().cpu().numpy(), # from original input + "qg_mvar_target": bus_y[:, QG_H].cpu().numpy(), # from original input + "is_pq": mask_PQ.cpu().numpy().astype(int), # from original input + "is_pv": mask_PV.cpu().numpy().astype(int), # from original input + "is_ref": mask_REF.cpu().numpy().astype(int), # from original input + "vm_pu": output["bus"][:, VM_OUT].detach().cpu().numpy(), # predicted output + "va": output["bus"][:, VA_OUT].detach().cpu().numpy(), # predicted output + "pg_mw": output["bus"][:, PG_OUT].detach().cpu().numpy(), # predicted output + "qg_mvar": output["bus"][:, QG_OUT].detach().cpu().numpy(), # predicted output + "active res. (MW)": residual_P.detach().cpu().numpy(), # predicted output + "reactive res. (MVar)": residual_Q.detach().cpu().numpy(), # predicted output + "PBE": residual_mva.detach().cpu().numpy(), # predicted output } diff --git a/gridfm_graphkit/tasks/reconstruction_tasks.py b/gridfm_graphkit/tasks/reconstruction_tasks.py index 8742646b..45975aee 100644 --- a/gridfm_graphkit/tasks/reconstruction_tasks.py +++ b/gridfm_graphkit/tasks/reconstruction_tasks.py @@ -57,6 +57,7 @@ def shared_step(self, batch): batch.edge_attr_dict, batch.mask_dict, model=self.model, + x_dict=batch.x_dict, ) return output, loss_dict diff --git a/gridfm_graphkit/tasks/se_task.py b/gridfm_graphkit/tasks/se_task.py index 5e45182d..36667ad2 100644 --- a/gridfm_graphkit/tasks/se_task.py +++ b/gridfm_graphkit/tasks/se_task.py @@ -26,6 +26,7 @@ @TASK_REGISTRY.register("StateEstimation") class StateEstimationTask(ReconstructionTask): + """State-estimation task with evaluation plots for masked and noisy measurements.""" def __init__(self, args, data_normalizers): super().__init__(args, data_normalizers) diff --git a/gridfm_graphkit/tasks/utils.py b/gridfm_graphkit/tasks/utils.py index d874eff9..273d79f5 100644 --- a/gridfm_graphkit/tasks/utils.py +++ b/gridfm_graphkit/tasks/utils.py @@ -7,10 +7,25 @@ def residual_stats_by_type(residual, mask, bus_batch): + """Return per-graph mean and max absolute residuals for a masked bus subset.""" residual_masked = residual[mask] batch_masked = bus_batch[mask] - mean_res = scatter_mean(torch.abs(residual_masked), batch_masked, dim=0) - max_res, _ = scatter_max(torch.abs(residual_masked), batch_masked, dim=0) + abs_residual = torch.abs(residual_masked) + + # torch_scatter on MPS can dispatch into a CPU-only path for scatter_max. + # Compute the grouped stats on CPU and move the results back so verbose + # evaluation works without changing the torch/torch_scatter stack. + if abs_residual.device.type == "mps": + abs_residual_cpu = abs_residual.cpu() + batch_masked_cpu = batch_masked.cpu() + mean_res = scatter_mean(abs_residual_cpu, batch_masked_cpu, dim=0).to( + abs_residual.device, + ) + max_res, _ = scatter_max(abs_residual_cpu, batch_masked_cpu, dim=0) + max_res = max_res.to(abs_residual.device) + else: + mean_res = scatter_mean(abs_residual, batch_masked, dim=0) + max_res, _ = scatter_max(abs_residual, batch_masked, dim=0) return mean_res, max_res diff --git a/gridfm_graphkit/training/callbacks.py b/gridfm_graphkit/training/callbacks.py index df8ee247..ba7a4049 100644 --- a/gridfm_graphkit/training/callbacks.py +++ b/gridfm_graphkit/training/callbacks.py @@ -41,6 +41,7 @@ def last_epoch_iters_per_sec(self) -> float | None: class SaveBestModelStateDict(Callback): + """Persist the best model state_dict according to a monitored validation metric.""" def __init__( self, monitor: str, @@ -52,6 +53,15 @@ def __init__( self.filename = filename self.best_score = float("inf") if mode == "min" else -float("inf") + @staticmethod + def _canonical_state_dict(pl_module): + """Return a state dict with compile wrappers removed from key names.""" + state_dict = pl_module.state_dict() + return { + key.replace("model._orig_mod.", "model."): value + for key, value in state_dict.items() + } + @rank_zero_only def on_validation_end(self, trainer, pl_module): current = trainer.callback_metrics.get(self.monitor) @@ -81,4 +91,4 @@ def on_validation_end(self, trainer, pl_module): # Save the model's state_dict model_path = os.path.join(model_dir, self.filename) - torch.save(pl_module.state_dict(), model_path) + torch.save(self._canonical_state_dict(pl_module), model_path) diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index d253d2b3..a0521fc2 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -19,6 +19,9 @@ PG_OUT, # Generator feature indices PG_H, + # Qg Limits + MIN_QG_H, + MAX_QG_H, ) @@ -36,6 +39,7 @@ def forward( edge_attr=None, mask=None, model=None, + x_dict=None, ): """ Compute the loss. @@ -72,6 +76,7 @@ def forward( edge_attr=None, mask=None, model=None, + x_dict=None, ): loss = F.mse_loss(pred[mask], target[mask], reduction=self.reduction) return {"loss": loss, "Masked MSE loss": loss.detach()} @@ -79,6 +84,7 @@ def forward( @LOSS_REGISTRY.register("MaskedGenMSE") class MaskedGenMSE(torch.nn.Module): + """Compute MSE on generator targets restricted to generator mask entries.""" def __init__(self, loss_args, args): super().__init__() self.reduction = "mean" @@ -91,6 +97,7 @@ def forward( edge_attr, mask_dict, model=None, + x_dict=None, ): loss = F.mse_loss( pred_dict["gen"][mask_dict["gen"][:, : (PG_H + 1)]], @@ -102,6 +109,7 @@ def forward( @LOSS_REGISTRY.register("MaskedBusMSE") class MaskedBusMSE(torch.nn.Module): + """Compute MSE on selected bus targets, respecting task-specific output columns.""" def __init__(self, loss_args, args): super().__init__() self.reduction = "mean" @@ -115,6 +123,7 @@ def forward( edge_attr, mask_dict, model=None, + x_dict=None, ): if self.args.task == "OptimalPowerFlow": pred_cols = [VM_OUT, VA_OUT, QG_OUT] @@ -152,6 +161,7 @@ def forward( edge_attr=None, mask=None, model=None, + x_dict=None, ): loss = F.mse_loss(pred, target, reduction=self.reduction) return {"loss": loss, "MSE loss": loss.detach()} @@ -185,6 +195,7 @@ def forward( edge_attr=None, mask=None, model=None, + x_dict=None, ): """ Compute the weighted sum of all specified losses. @@ -211,6 +222,7 @@ def forward( edge_attr, mask, model, + x_dict, ) # Assume each loss function returns a dictionary with a "loss" key @@ -229,6 +241,7 @@ def forward( @LOSS_REGISTRY.register("LayeredWeightedPhysics") class LayeredWeightedPhysicsLoss(BaseLoss): + """Combine intermediate physics residuals using normalized geometric weights.""" def __init__(self, loss_args, args) -> None: super().__init__() self.base_weight = loss_args.base_weight @@ -241,6 +254,7 @@ def forward( edge_attr=None, mask=None, model=None, + x_dict=None, ): total_loss = 0.0 loss_details = {} @@ -268,6 +282,7 @@ def forward( @LOSS_REGISTRY.register("LossPerDim") class LossPerDim(BaseLoss): + """Compute MAE/MSE for one named physical dimension of bus outputs.""" def __init__(self, loss_args, args): super(LossPerDim, self).__init__() self.reduction = "mean" @@ -291,6 +306,7 @@ def forward( edge_attr, mask_dict, model=None, + x_dict=None, ): if self.dim == "VM": temp_pred = pred_dict["bus"][:, VM_OUT] @@ -322,3 +338,57 @@ def forward( f"MSE loss {self.dim}": mse_loss.detach(), f"MAE loss {self.dim}": mae_loss.detach(), } + + +@LOSS_REGISTRY.register("QgViolationPenalty") +class QgViolationPenaltyLoss(BaseLoss): + """Standard Mean Squared Error loss.""" + + def __init__(self, loss_args, args): + super().__init__() + + def forward( + self, + pred, + target, + edge_index=None, + edge_attr=None, + mask=None, + model=None, + x_dict=None, + ): + # --- Qg limit violation mask --- + Qg_pred = pred["bus"][:, QG_OUT] + Qg_max = x_dict["bus"][:, MAX_QG_H] + Qg_min = x_dict["bus"][:, MIN_QG_H] + + max_penalty_mask = (Qg_pred > Qg_max) + min_penalty_mask = (Qg_pred < Qg_min) + + mask_PQ = mask["PQ"] # PQ buses + mask_PV = mask["PV"] # PV buses + mask_REF = mask["REF"] # Reference buses + + loss = 0.0 + # where there are violations, compute penalty loss + Qg_over = F.relu(Qg_pred - Qg_max) # amount above max limit + Qg_under = F.relu(Qg_min - Qg_pred) # amount below min limit + + Qg_over = Qg_over[max_penalty_mask].mean() + Qg_under = Qg_under[min_penalty_mask].mean() + + if Qg_over!=Qg_over: # replacing nan with 0 + Qg_over = 0.0 + if Qg_under!=Qg_under: # replacing nan with 0 + Qg_under = 0.0 + + penalty_loss = Qg_over + Qg_under + loss += penalty_loss + + try: + output = {"loss": loss, "Qg Violation Penalty loss": loss.detach()} + except: + output = {"loss": loss, "Qg Violation Penalty loss": loss} + + return output + diff --git a/gridfm_graphkit/utils/visualization.py b/gridfm_graphkit/utils/visualization.py index 276d403b..3a8151c8 100644 --- a/gridfm_graphkit/utils/visualization.py +++ b/gridfm_graphkit/utils/visualization.py @@ -11,6 +11,7 @@ def visualize_error(data_point, output, node_normalizer): + """Plot node-wise active power residuals on the grid topology.""" loss = PBELoss(visualization=True) loss_dict = loss( diff --git a/integrationtests/test_base_set.py b/integrationtests/test_base_set.py index de603a61..90da468a 100644 --- a/integrationtests/test_base_set.py +++ b/integrationtests/test_base_set.py @@ -70,7 +70,7 @@ def cleanup_test_artifacts(): """ Backup modified files and remove generated artifacts after the test. """ - training_config = "examples/config/HGNS_PF_datakit_case14.yaml" + training_config = " " backup_config = training_config + ".bak" if os.path.exists(training_config): diff --git a/mkdocs.yml b/mkdocs.yml index 6581214c..c4717e04 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -9,26 +9,6 @@ nav: - Getting started: - CLI commands: quick_start/quick_start.md - YAML configuration file: quick_start/yaml_config.md - - Tutorials: - - Visualizing predictions of GridFM: tutorials/feature_reconstruction.md - - Contingency analysis: tutorials/contingency_analysis.md - - Components: - - Datasets: - - Data normalization: datasets/data_normalization.md - - Power Grid datasets: datasets/powergrid.md - - Data Modules: datasets/data_modules.md - - Transforms: datasets/transforms.md - - Tasks: - - Overview: tasks/feature_reconstruction.md - - Base Task: tasks/base_task.md - - Reconstruction Task: tasks/reconstruction_task.md - - Power Flow Task: tasks/power_flow.md - - Optimal Power Flow Task: tasks/optimal_power_flow.md - - State Estimation Task: tasks/state_estimation.md - - Models: models/models.md - - Training: - - Losses: training/loss.md - theme: name: material