From 0f93f0483e78e6dace51814164d6ee4a377f7354 Mon Sep 17 00:00:00 2001 From: Romeo Kienzler Date: Thu, 23 Apr 2026 22:36:35 +0200 Subject: [PATCH 01/11] add batch_size param Signed-off-by: Romeo Kienzler --- gridfm_graphkit/__main__.py | 30 ++++++++++++++++++++++++++++++ gridfm_graphkit/cli.py | 9 +++++++++ 2 files changed, 39 insertions(+) diff --git a/gridfm_graphkit/__main__.py b/gridfm_graphkit/__main__.py index b693089c..607b15e7 100644 --- a/gridfm_graphkit/__main__.py +++ b/gridfm_graphkit/__main__.py @@ -120,6 +120,12 @@ def main(): default=None, help="Override data.workers from the YAML config. Use 0 to debug worker crashes.", ) + train_parser.add_argument( + "--batch_size", + type=int, + default=None, + help="Override training.batch_size from the YAML config.", + ) train_parser.add_argument( "--dataset_wrapper_cache_dir", type=str, @@ -173,6 +179,12 @@ def main(): default=None, help="Override data.workers from the YAML config. Use 0 to debug worker crashes.", ) + finetune_parser.add_argument( + "--batch_size", + type=int, + default=None, + help="Override training.batch_size from the YAML config.", + ) finetune_parser.add_argument( "--dataset_wrapper_cache_dir", type=str, @@ -235,6 +247,12 @@ def main(): default=None, help="Override data.workers from the YAML config. Use 0 to debug worker crashes.", ) + evaluate_parser.add_argument( + "--batch_size", + type=int, + default=None, + help="Override training.batch_size from the YAML config.", + ) evaluate_parser.add_argument( "--dataset_wrapper_cache_dir", type=str, @@ -284,6 +302,12 @@ def main(): default=None, help="Override data.workers from the YAML config. Use 0 to debug worker crashes.", ) + predict_parser.add_argument( + "--batch_size", + type=int, + default=None, + help="Override training.batch_size from the YAML config.", + ) predict_parser.add_argument( "--dataset_wrapper_cache_dir", type=str, @@ -332,6 +356,12 @@ def main(): default=None, help="Override data.workers from the YAML config.", ) + benchmark_parser.add_argument( + "--batch_size", + type=int, + default=None, + help="Override training.batch_size from the YAML config.", + ) benchmark_parser.add_argument( "--plugins", nargs="*", diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index 0ffd1364..ca1167c0 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -80,6 +80,10 @@ def benchmark_cli(args): if num_workers_override is not None: config_args.data.workers = num_workers_override + batch_size_override = getattr(args, "batch_size", None) + if batch_size_override is not None: + config_args.training.batch_size = batch_size_override + _load_plugins(getattr(args, "plugins", [])) dataset_wrapper = getattr(args, "dataset_wrapper", None) @@ -177,6 +181,11 @@ def main_cli(args): if num_workers_override is not None: config_args.data.workers = num_workers_override + # CLI --batch_size overrides the YAML value + batch_size_override = getattr(args, "batch_size", None) + if batch_size_override is not None: + config_args.training.batch_size = batch_size_override + _load_plugins(getattr(args, "plugins", [])) _validate_dataset_wrapper(dataset_wrapper) From e0c4397aa6b8fdf386c36fbe7ca5fa37a0339b5d Mon Sep 17 00:00:00 2001 From: Romeo Kienzler Date: Thu, 23 Apr 2026 22:38:52 +0200 Subject: [PATCH 02/11] report validation loss Signed-off-by: Romeo Kienzler --- gridfm_graphkit/cli.py | 45 +++++++++++++++++------------------------- 1 file changed, 18 insertions(+), 27 deletions(-) diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index ca1167c0..f18b1fe4 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -253,19 +253,23 @@ def main_cli(args): ) if args.command == "train" or args.command == "finetune": trainer.fit(model=model, datamodule=litGrid) - if ( - report_performance - and epoch_timer is not None - and epoch_timer.last_epoch_time is not None - ): - print(f"[performance] last epoch time : {epoch_timer.last_epoch_time:.3f}s") - if ( - epoch_timer.last_epoch_iters_per_sec is not None - and epoch_timer._last_batch_count > 0 - ): - print( - f"[performance] last epoch it/s : {epoch_timer.last_epoch_iters_per_sec:.2f}", - ) + if report_performance: + # Validation loss + val_loss = trainer.callback_metrics.get("Validation loss") + if val_loss is not None: + print(f"[performance] Validation loss : {float(val_loss):.6f}") + else: + print("[performance] Validation loss : not available") + # Epoch timing + if epoch_timer is not None and epoch_timer.last_epoch_time is not None: + print(f"[performance] last epoch time : {epoch_timer.last_epoch_time:.3f}s") + if ( + epoch_timer.last_epoch_iters_per_sec is not None + and epoch_timer._last_batch_count > 0 + ): + print( + f"[performance] last epoch it/s : {epoch_timer.last_epoch_iters_per_sec:.2f}", + ) if args.command != "predict": # Reuse the fit trainer when coming from train/finetune so that @@ -284,20 +288,7 @@ def main_cli(args): **trainer_kwargs, profiler=profiler, ) - test_results = test_trainer.test(model=model, datamodule=litGrid) - if report_performance: - # test_results[0] may be empty when metrics are routed to the logger - # only; fall back to trainer.callback_metrics which always has them. - metrics = ( - test_results[0] - if test_results and test_results[0] - else dict(test_trainer.callback_metrics) - ) - if metrics: - first_metric, first_value = next(iter(metrics.items())) - print(f"[performance] {first_metric} : {first_value}") - else: - print("[performance] no test metrics available") + test_trainer.test(model=model, datamodule=litGrid) artifacts_dir = None is_rank0 = ( From 8efc82a45bfd65d726e1642de7ac1ad2d167781e Mon Sep 17 00:00:00 2001 From: Romeo Kienzler Date: Fri, 24 Apr 2026 21:08:31 +0200 Subject: [PATCH 03/11] enhance DDP strategy configuration for single and multi-device training Signed-off-by: Romeo Kienzler --- gridfm_graphkit/cli.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index f18b1fe4..77279005 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -237,7 +237,10 @@ def main_cli(args): "ddp", "ddp_find_unused_parameters_true", ): - _strategy = DDPStrategy(find_unused_parameters=True) + _num_devices = config_args.training.devices + _single_device = _num_devices == 1 or _num_devices == "1" + _pg_backend = "gloo" if _single_device else "nccl" + _strategy = DDPStrategy(find_unused_parameters=True, process_group_backend=_pg_backend) trainer = L.Trainer( logger=logger, From 591dc73559e2f238a23f0146bd399684b6ecda45 Mon Sep 17 00:00:00 2001 From: Romeo Kienzler Date: Fri, 24 Apr 2026 21:16:10 +0200 Subject: [PATCH 04/11] improve DDP strategy handling for single and multi-device training Signed-off-by: Romeo Kienzler --- gridfm_graphkit/cli.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index 77279005..bd7f867f 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -238,9 +238,16 @@ def main_cli(args): "ddp_find_unused_parameters_true", ): _num_devices = config_args.training.devices - _single_device = _num_devices == 1 or _num_devices == "1" - _pg_backend = "gloo" if _single_device else "nccl" - _strategy = DDPStrategy(find_unused_parameters=True, process_group_backend=_pg_backend) + _single_device = _num_devices in (1, "1", [1]) or ( + isinstance(_num_devices, list) and len(_num_devices) == 1 + ) + if _single_device: + # No need for DDP with a single device; use auto to avoid NCCL init + _strategy = "auto" + else: + # For multi-device, prefer gloo to avoid NCCL socket interface issues; + # set NCCL_SOCKET_IFNAME=lo in env if nccl is needed instead. + _strategy = DDPStrategy(find_unused_parameters=True, process_group_backend="gloo") trainer = L.Trainer( logger=logger, From 345e6facf55f25ac1a878bb17b0c388881d4cae7 Mon Sep 17 00:00:00 2001 From: Romeo Kienzler Date: Fri, 24 Apr 2026 21:32:04 +0200 Subject: [PATCH 05/11] set MASTER_PORT and MASTER_ADDR for DDP strategy to avoid socket issues Signed-off-by: Romeo Kienzler --- gridfm_graphkit/cli.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index bd7f867f..89b8ea7d 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -8,6 +8,7 @@ import importlib import numpy as np import os +import socket import time import yaml import torch @@ -245,8 +246,15 @@ def main_cli(args): # No need for DDP with a single device; use auto to avoid NCCL init _strategy = "auto" else: - # For multi-device, prefer gloo to avoid NCCL socket interface issues; - # set NCCL_SOCKET_IFNAME=lo in env if nccl is needed instead. + # Ensure MASTER_PORT is set to a free port to avoid EADDRINUSE + if not os.environ.get("MASTER_PORT"): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as _s: + _s.bind(("", 0)) + _s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + os.environ["MASTER_PORT"] = str(_s.getsockname()[1]) + if not os.environ.get("MASTER_ADDR"): + os.environ["MASTER_ADDR"] = "127.0.0.1" + # Use gloo backend to avoid NCCL socket interface issues _strategy = DDPStrategy(find_unused_parameters=True, process_group_backend="gloo") trainer = L.Trainer( From 9c83c13be067802d7974faf69fd778b5e27f31ee Mon Sep 17 00:00:00 2001 From: Romeo Kienzler Date: Mon, 27 Apr 2026 12:07:35 +0200 Subject: [PATCH 06/11] refactor multiprocessing context handling for DataLoader and set 'fork' globally on Linux Signed-off-by: Romeo Kienzler --- .../datasets/hetero_powergrid_datamodule.py | 20 ++++++------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py index e5374970..c361ee5b 100644 --- a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py @@ -425,20 +425,12 @@ def _dataloader_kwargs(self): pin_memory=torch.cuda.is_available(), persistent_workers=num_workers > 0, ) - # Use 'fork' on Linux. It avoids the forkserver intermediary pipe which - # is fragile when the process has many threads (e.g. OpenBLAS). In - # container environments (Kubernetes) fork works correctly. On - # traditional HPC systems with strict fd-passing restrictions the - # original 'forkserver' may be needed, but the pipe truncation it - # produces under thread pressure is worse than the ancdata warning. - if ( - num_workers > 0 - and torch.multiprocessing.get_start_method(allow_none=True) != "spawn" - ): - import platform - - if platform.system() == "Linux": - kwargs["multiprocessing_context"] = "fork" + # Use 'spawn' for DataLoader workers: CUDA is already initialized by + # Lightning before DataLoaders are created, and forking a process that + # has an active CUDA context corrupts the child's GPU state. 'spawn' + # starts a clean interpreter without inheriting the CUDA context. + if num_workers > 0: + kwargs["multiprocessing_context"] = "spawn" return kwargs def train_dataloader(self): From 4338a3fd0e09aef355e9a1826aa4df8922771b5f Mon Sep 17 00:00:00 2001 From: Romeo Kienzler Date: Mon, 27 Apr 2026 13:52:01 +0200 Subject: [PATCH 07/11] improve port assignment in set_env to avoid EADDRINUSE errors Signed-off-by: Romeo Kienzler --- gridfm_graphkit/__main__.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/gridfm_graphkit/__main__.py b/gridfm_graphkit/__main__.py index 607b15e7..a21fd038 100644 --- a/gridfm_graphkit/__main__.py +++ b/gridfm_graphkit/__main__.py @@ -5,6 +5,7 @@ import subprocess import os +import socket def is_lsf(): return ( @@ -46,7 +47,18 @@ def set_env(): os.environ["MASTER_ADDR"] = HOST_LIST[ 0 ] # Sets the MasterNode to thefirst node on the list of hosts - os.environ["MASTER_PORT"] = "5" + LSB_JOBID[-5:-1] + # Derive a candidate port from job ID, but verify it's free; fall back to + # a dynamically allocated free port to avoid EADDRINUSE when multiple jobs + # share overlapping job ID suffixes. + candidate_port = int("5" + LSB_JOBID[-5:-1]) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as _s: + try: + _s.bind(("", candidate_port)) + free_port = candidate_port + except OSError: + _s.bind(("", 0)) + free_port = _s.getsockname()[1] + os.environ["MASTER_PORT"] = str(free_port) os.environ["NODE_RANK"] = str( HOST_LIST.index(os.environ["HOSTNAME"]), ) # Uses the list index for node rank, master node rank must be 0 From be30f7d084649dbd39514aafe341510d2f621718 Mon Sep 17 00:00:00 2001 From: Romeo Kienzler Date: Mon, 27 Apr 2026 19:05:14 +0200 Subject: [PATCH 08/11] refactor set_env to derive MASTER_PORT using CRC32 of job ID for better collision avoidance Signed-off-by: Romeo Kienzler --- gridfm_graphkit/__main__.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/gridfm_graphkit/__main__.py b/gridfm_graphkit/__main__.py index a21fd038..8edcf667 100644 --- a/gridfm_graphkit/__main__.py +++ b/gridfm_graphkit/__main__.py @@ -47,18 +47,12 @@ def set_env(): os.environ["MASTER_ADDR"] = HOST_LIST[ 0 ] # Sets the MasterNode to thefirst node on the list of hosts - # Derive a candidate port from job ID, but verify it's free; fall back to - # a dynamically allocated free port to avoid EADDRINUSE when multiple jobs - # share overlapping job ID suffixes. - candidate_port = int("5" + LSB_JOBID[-5:-1]) - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as _s: - try: - _s.bind(("", candidate_port)) - free_port = candidate_port - except OSError: - _s.bind(("", 0)) - free_port = _s.getsockname()[1] - os.environ["MASTER_PORT"] = str(free_port) + # Derive port deterministically from job ID so all ranks agree on the same + # value. Use a wider hash (CRC32 of the full job ID) mapped into the + # ephemeral range [49152, 65535] to minimise collisions with other jobs. + import binascii + crc = binascii.crc32(LSB_JOBID.encode()) & 0xFFFFFFFF + os.environ["MASTER_PORT"] = str(49152 + (crc % 16383)) os.environ["NODE_RANK"] = str( HOST_LIST.index(os.environ["HOSTNAME"]), ) # Uses the list index for node rank, master node rank must be 0 From ae83ff69782f59e3f002b77e6bfccbfb006e537a Mon Sep 17 00:00:00 2001 From: Romeo Kienzler Date: Tue, 5 May 2026 13:02:18 +0200 Subject: [PATCH 09/11] add support for custom multiprocessing start method in CLI arguments Signed-off-by: Romeo Kienzler --- gridfm_graphkit/__main__.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/gridfm_graphkit/__main__.py b/gridfm_graphkit/__main__.py index 602ed00a..4c240af0 100644 --- a/gridfm_graphkit/__main__.py +++ b/gridfm_graphkit/__main__.py @@ -1,6 +1,7 @@ import argparse from datetime import datetime from gridfm_graphkit.cli import main_cli, benchmark_cli +import torch.multiprocessing import subprocess @@ -97,6 +98,12 @@ def main(): default=False, help="Enable TF32 on Ampere+ GPUs via torch.set_float32_matmul_precision('high').", ) + _start_method_kwargs = dict( + type=str, + default="spawn", + choices=["spawn", "fork", "forkserver"], + help="Multiprocessing start method for dataloader workers (default: spawn). Use 'fork' for faster startup on CPU-only machines.", + ) # ---- TRAIN SUBCOMMAND ---- train_parser = subparsers.add_parser("train", help="Run training") @@ -155,6 +162,7 @@ def main(): action="store_true", help="Print the last training epoch time and a single test metric to stdout.", ) + train_parser.add_argument("--start-method", dest="start_method", **_start_method_kwargs) # ---- FINETUNE SUBCOMMAND ---- finetune_parser = subparsers.add_parser("finetune", help="Run fine-tuning") @@ -214,6 +222,7 @@ def main(): action="store_true", help="Print the last training epoch time and a single test metric to stdout.", ) + finetune_parser.add_argument("--start-method", dest="start_method", **_start_method_kwargs) # ---- EVALUATE SUBCOMMAND ---- evaluate_parser = subparsers.add_parser( @@ -286,6 +295,7 @@ def main(): "--save_output", action="store_true", ) + evaluate_parser.add_argument("--start-method", dest="start_method", **_start_method_kwargs) # ---- PREDICT SUBCOMMAND ---- predict_parser = subparsers.add_parser("predict", help="Run prediction") @@ -342,6 +352,7 @@ def main(): default=None, choices=["simple", "advanced", "pytorch"], ) + predict_parser.add_argument("--start-method", dest="start_method", **_start_method_kwargs) # ---- BENCHMARK SUBCOMMAND ---- benchmark_parser = subparsers.add_parser( @@ -389,6 +400,9 @@ def main(): args = parser.parse_args() + start_method = getattr(args, "start_method", "spawn") + torch.multiprocessing.set_start_method(start_method, force=True) + if args.command == "benchmark": benchmark_cli(args) else: From d68345e0e75710c0d9bc17b3023705b8efa407e6 Mon Sep 17 00:00:00 2001 From: Romeo Kienzler Date: Tue, 5 May 2026 13:51:04 +0200 Subject: [PATCH 10/11] remove duplicate batch_size argument from evaluate and predict parsers in main function Signed-off-by: Romeo Kienzler --- gridfm_graphkit/__main__.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/gridfm_graphkit/__main__.py b/gridfm_graphkit/__main__.py index 4c240af0..02879a58 100644 --- a/gridfm_graphkit/__main__.py +++ b/gridfm_graphkit/__main__.py @@ -268,12 +268,6 @@ def main(): default=None, help="Override data.workers from the YAML config. Use 0 to debug worker crashes.", ) - evaluate_parser.add_argument( - "--batch_size", - type=int, - default=None, - help="Override training.batch_size from the YAML config.", - ) evaluate_parser.add_argument( "--dataset_wrapper_cache_dir", type=str, @@ -330,12 +324,6 @@ def main(): default=None, help="Override data.workers from the YAML config. Use 0 to debug worker crashes.", ) - predict_parser.add_argument( - "--batch_size", - type=int, - default=None, - help="Override training.batch_size from the YAML config.", - ) predict_parser.add_argument( "--dataset_wrapper_cache_dir", type=str, From e0e9ffbbc9f2c9f281c5e98e374c76fa8906376a Mon Sep 17 00:00:00 2001 From: Romeo Kienzler Date: Tue, 5 May 2026 16:40:55 +0200 Subject: [PATCH 11/11] apply precommit hooks Signed-off-by: Romeo Kienzler --- docs/quick_start/yaml_config.md | 4 +- gridfm_graphkit/__main__.py | 20 +++- gridfm_graphkit/cli.py | 27 +++-- .../datasets/hetero_powergrid_datamodule.py | 29 ++++-- gridfm_graphkit/datasets/masking.py | 2 + gridfm_graphkit/datasets/normalizers.py | 20 +++- .../datasets/powergrid_hetero_dataset.py | 18 +++- gridfm_graphkit/datasets/task_transforms.py | 3 + gridfm_graphkit/datasets/transforms.py | 1 + gridfm_graphkit/datasets/utils.py | 6 +- gridfm_graphkit/io/registries.py | 1 + gridfm_graphkit/models/utils.py | 3 + gridfm_graphkit/tasks/opf_ac_dc_baseline.py | 98 +++++++++++++------ gridfm_graphkit/tasks/opf_task.py | 36 ++++--- gridfm_graphkit/tasks/pf_task.py | 52 ++++++---- gridfm_graphkit/tasks/se_task.py | 1 + gridfm_graphkit/training/callbacks.py | 1 + gridfm_graphkit/training/loss.py | 19 ++-- 18 files changed, 237 insertions(+), 104 deletions(-) diff --git a/docs/quick_start/yaml_config.md b/docs/quick_start/yaml_config.md index 96c2fa62..568d6f67 100644 --- a/docs/quick_start/yaml_config.md +++ b/docs/quick_start/yaml_config.md @@ -85,12 +85,12 @@ Task name registered in the framework: ### `data.networks` -List of dataset folders under your data root. +List of dataset folders under your data root. Examples: `case14_ieee`, `case118_ieee`, `case2000_goc`, `Texas2k_case1_2016summerpeak`. ### `data.scenarios` -List of scenario counts, one value per network in `data.networks`. +List of scenario counts, one value per network in `data.networks`. Example: with two networks, use two scenario entries in matching order. ### `data.normalization` diff --git a/gridfm_graphkit/__main__.py b/gridfm_graphkit/__main__.py index 02879a58..c898b761 100644 --- a/gridfm_graphkit/__main__.py +++ b/gridfm_graphkit/__main__.py @@ -8,6 +8,7 @@ import os import socket + def is_lsf(): return ( os.environ.get("LSB_JOBID") is not None @@ -15,6 +16,7 @@ def is_lsf(): and "LSF_ENVDIR" in os.environ # strong LSF indicator ) + def fix_infiniband(): """Configure NCCL to skip Ethernet-only IB ports on this host.""" ibv = subprocess.run("ibv_devinfo", stdout=subprocess.PIPE, stderr=subprocess.PIPE) @@ -52,6 +54,7 @@ def set_env(): # value. Use a wider hash (CRC32 of the full job ID) mapped into the # ephemeral range [49152, 65535] to minimise collisions with other jobs. import binascii + crc = binascii.crc32(LSB_JOBID.encode()) & 0xFFFFFFFF os.environ["MASTER_PORT"] = str(49152 + (crc % 16383)) os.environ["NODE_RANK"] = str( @@ -62,6 +65,7 @@ def set_env(): ) os.environ["NCCL_IB_CUDA_SUPPORT"] = "1" # Force use of infiniband + def main(): """Parse CLI arguments and dispatch to the selected GridFM subcommand.""" if is_lsf(): @@ -162,7 +166,9 @@ def main(): action="store_true", help="Print the last training epoch time and a single test metric to stdout.", ) - train_parser.add_argument("--start-method", dest="start_method", **_start_method_kwargs) + train_parser.add_argument( + "--start-method", dest="start_method", **_start_method_kwargs, + ) # ---- FINETUNE SUBCOMMAND ---- finetune_parser = subparsers.add_parser("finetune", help="Run fine-tuning") @@ -222,7 +228,9 @@ def main(): action="store_true", help="Print the last training epoch time and a single test metric to stdout.", ) - finetune_parser.add_argument("--start-method", dest="start_method", **_start_method_kwargs) + finetune_parser.add_argument( + "--start-method", dest="start_method", **_start_method_kwargs, + ) # ---- EVALUATE SUBCOMMAND ---- evaluate_parser = subparsers.add_parser( @@ -289,7 +297,9 @@ def main(): "--save_output", action="store_true", ) - evaluate_parser.add_argument("--start-method", dest="start_method", **_start_method_kwargs) + evaluate_parser.add_argument( + "--start-method", dest="start_method", **_start_method_kwargs, + ) # ---- PREDICT SUBCOMMAND ---- predict_parser = subparsers.add_parser("predict", help="Run prediction") @@ -340,7 +350,9 @@ def main(): default=None, choices=["simple", "advanced", "pytorch"], ) - predict_parser.add_argument("--start-method", dest="start_method", **_start_method_kwargs) + predict_parser.add_argument( + "--start-method", dest="start_method", **_start_method_kwargs, + ) # ---- BENCHMARK SUBCOMMAND ---- benchmark_parser = subparsers.add_parser( diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index 457895a6..caec34be 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -25,7 +25,9 @@ import lightning as L -def _normalize_loaded_state_dict_keys(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: +def _normalize_loaded_state_dict_keys( + state_dict: dict[str, torch.Tensor], +) -> dict[str, torch.Tensor]: """Map legacy torch.compile checkpoint keys to the canonical model namespace.""" has_compiled_prefix = any(key.startswith("model._orig_mod.") for key in state_dict) if not has_compiled_prefix: @@ -235,12 +237,17 @@ def main_cli(args): _accelerator = config_args.training.accelerator _strategy = config_args.training.strategy # if mps is available and accelerator is auto, explicitely set accelerator to mps to select the right strategy in the next block - if _accelerator == "auto" and torch.backends.mps.is_available(): + if _accelerator == "auto" and torch.backends.mps.is_available(): _accelerator = "mps" - if _accelerator not in ("mps", "cpu") and isinstance(_strategy, str) and _strategy in ( - "auto", - "ddp", - "ddp_find_unused_parameters_true", + if ( + _accelerator not in ("mps", "cpu") + and isinstance(_strategy, str) + and _strategy + in ( + "auto", + "ddp", + "ddp_find_unused_parameters_true", + ) ): _num_devices = config_args.training.devices _single_device = _num_devices in (1, "1", [1]) or ( @@ -259,7 +266,9 @@ def main_cli(args): if not os.environ.get("MASTER_ADDR"): os.environ["MASTER_ADDR"] = "127.0.0.1" # Use gloo backend to avoid NCCL socket interface issues - _strategy = DDPStrategy(find_unused_parameters=True, process_group_backend="gloo") + _strategy = DDPStrategy( + find_unused_parameters=True, process_group_backend="gloo", + ) trainer = L.Trainer( logger=logger, @@ -284,7 +293,9 @@ def main_cli(args): print("[performance] Validation loss : not available") # Epoch timing if epoch_timer is not None and epoch_timer.last_epoch_time is not None: - print(f"[performance] last epoch time : {epoch_timer.last_epoch_time:.3f}s") + print( + f"[performance] last epoch time : {epoch_timer.last_epoch_time:.3f}s", + ) if ( epoch_timer.last_epoch_iters_per_sec is not None and epoch_timer._last_batch_count > 0 diff --git a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py index c361ee5b..c1da0b9d 100644 --- a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py @@ -121,13 +121,18 @@ def __init__( self._is_setup_done = False if self.split_by_load_scenario_idx: - assert self.split_from_existing_files is None, " either `split_by_load_scenario_idx` or `split_from_existing_files` may be used, not both" + assert self.split_from_existing_files is None, ( + " either `split_by_load_scenario_idx` or `split_from_existing_files` may be used, not both" + ) if self.split_from_existing_files is not None: - assert isinstance(self.split_from_existing_files, str), "`split_from_existing_files` must be an existing folder in string format" + assert isinstance(self.split_from_existing_files, str), ( + "`split_from_existing_files` must be an existing folder in string format" + ) self.split_from_existing_files = Path(self.split_from_existing_files) - assert self.split_from_existing_files.is_dir(), "`split_from_existing_files` must be an existing folder in string format" - + assert self.split_from_existing_files.is_dir(), ( + "`split_from_existing_files` must be an existing folder in string format" + ) def setup(self, stage: str): if self._is_setup_done: @@ -184,7 +189,6 @@ def setup(self, stage: str): # Create a subset all_indices = list(range(len(dataset))) - if self.split_from_existing_files is not None: warnings.warn( "`data.scenarios` is ignored when `split_from_existing_files` is set; " @@ -229,13 +233,13 @@ def setup(self, stage: str): # load_scenario for each scenario in the subset load_scenarios = dataset.load_scenarios[subset_indices] - dataset = Subset(dataset, subset_indices) - + if self.dataset_wrapper is not None: wrapper_cls = DATASET_WRAPPER_REGISTRY.get(self.dataset_wrapper) - dataset = wrapper_cls(dataset, cache_dir=self.dataset_wrapper_cache_dir) - + dataset = wrapper_cls( + dataset, cache_dir=self.dataset_wrapper_cache_dir, + ) # Random seed set before every split, same as above np.random.seed(self.args.seed) @@ -434,7 +438,12 @@ def _dataloader_kwargs(self): return kwargs def train_dataloader(self): - print("creating train dataloader for rank ", dist.get_rank() if dist.is_available() and dist.is_initialized() else "not distributed") + print( + "creating train dataloader for rank ", + dist.get_rank() + if dist.is_available() and dist.is_initialized() + else "not distributed", + ) return DataLoader( self.train_dataset_multi, batch_size=self.batch_size, diff --git a/gridfm_graphkit/datasets/masking.py b/gridfm_graphkit/datasets/masking.py index df2f657c..c91303d7 100644 --- a/gridfm_graphkit/datasets/masking.py +++ b/gridfm_graphkit/datasets/masking.py @@ -158,6 +158,7 @@ def forward(self, data): class BusToGenBroadcaster(MessagePassing): """Broadcast per-bus values to connected generators via graph propagation.""" + def __init__(self, aggr="add"): super().__init__(aggr=aggr) @@ -176,6 +177,7 @@ def message(self, x_j): class SimulateMeasurements(BaseTransform): """Add configurable noise/outliers and masks to simulate measured quantities.""" + def __init__(self, args): super().__init__() self.measurements = args.task.measurements diff --git a/gridfm_graphkit/datasets/normalizers.py b/gridfm_graphkit/datasets/normalizers.py index eb5652d7..66fd35ad 100644 --- a/gridfm_graphkit/datasets/normalizers.py +++ b/gridfm_graphkit/datasets/normalizers.py @@ -228,8 +228,12 @@ def transform(self, data: HeteroData): data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MIN] *= torch.pi / 180.0 data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MAX] *= torch.pi / 180.0 data.edge_attr_dict[("bus", "connects", "bus")][:, RATE_A] /= self.baseMVA - data.baseMVA = torch.tensor(self.baseMVA, dtype=data.x_dict["bus"].dtype) # # needs to be float32 for MPS - data.is_normalized = torch.tensor(True, dtype=torch.bool) # needs to be bool for MPS + data.baseMVA = torch.tensor( + self.baseMVA, dtype=data.x_dict["bus"].dtype, + ) # # needs to be float32 for MPS + data.is_normalized = torch.tensor( + True, dtype=torch.bool, + ) # needs to be bool for MPS def inverse_transform(self, data: HeteroData): if self.baseMVA is None or self.baseMVA == 0: @@ -299,7 +303,9 @@ def inverse_transform(self, data: HeteroData): data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MAX] *= 180.0 / torch.pi data.edge_attr_dict[("bus", "connects", "bus")][:, RATE_A] *= self.baseMVA - data.is_normalized = torch.tensor(False, dtype=torch.bool) # needs to be bool for MPS + data.is_normalized = torch.tensor( + False, dtype=torch.bool, + ) # needs to be bool for MPS def inverse_output(self, output, batch): bus_output = output["bus"] @@ -510,7 +516,9 @@ def transform(self, data: HeteroData): data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MIN] *= torch.pi / 180.0 data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MAX] *= torch.pi / 180.0 data.edge_attr_dict[("bus", "connects", "bus")][:, RATE_A] /= e_b - data.is_normalized = torch.tensor(True, dtype=torch.bool) # needs to be bool for MPS + data.is_normalized = torch.tensor( + True, dtype=torch.bool, + ) # needs to be bool for MPS def inverse_transform(self, data: HeteroData): """Undo per-unit normalization (multiply by baseMVA, inverse log1p for cost coeffs).""" @@ -573,7 +581,9 @@ def inverse_transform(self, data: HeteroData): data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MAX] *= 180.0 / torch.pi data.edge_attr_dict[("bus", "connects", "bus")][:, RATE_A] *= e_b - data.is_normalized = torch.tensor(False, dtype=torch.bool) # needs to be bool for MPS + data.is_normalized = torch.tensor( + False, dtype=torch.bool, + ) # needs to be bool for MPS def inverse_output(self, output, batch): """ diff --git a/gridfm_graphkit/datasets/powergrid_hetero_dataset.py b/gridfm_graphkit/datasets/powergrid_hetero_dataset.py index 82f57a57..39b9bf23 100644 --- a/gridfm_graphkit/datasets/powergrid_hetero_dataset.py +++ b/gridfm_graphkit/datasets/powergrid_hetero_dataset.py @@ -73,9 +73,13 @@ def process(self): ) if "load_scenario_idx" in bus_data.columns: load_scenarios = torch.tensor( - bus_data.groupby("scenario", sort=True)["load_scenario_idx"].first().values, + bus_data.groupby("scenario", sort=True)["load_scenario_idx"] + .first() + .values, + ) + torch.save( + load_scenarios, osp.join(self.processed_dir, "load_scenarios.pt"), ) - torch.save(load_scenarios, osp.join(self.processed_dir, "load_scenarios.pt")) agg_gen = ( gen_data.groupby(["scenario", "bus"])[["min_q_mvar", "max_q_mvar"]] @@ -136,7 +140,9 @@ def process(self): ] + common_branch_features # Group by scenario - bus_groups = bus_data.groupby("scenario") # Groupby preserves the order of rows within each group. + bus_groups = bus_data.groupby( + "scenario", + ) # Groupby preserves the order of rows within each group. # https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.groupby.html gen_groups = gen_data.groupby("scenario") branch_groups = branch_data.groupby("scenario") @@ -159,8 +165,10 @@ def process(self): # Bus nodes bus_df = bus_groups.get_group(scenario) # assert that the buses are in increasing order - assert (bus_df["bus"].values == torch.arange(len(bus_df))).all(), "Buses are not in increasing order" - #todo: we should remove this assert and store the bus idx in the tensors + assert (bus_df["bus"].values == torch.arange(len(bus_df))).all(), ( + "Buses are not in increasing order" + ) + # todo: we should remove this assert and store the bus idx in the tensors # right now we need the increasing order for e.g. the predict step that uses torch.arange(n_nodes) to index the buses. data["bus"].x = torch.tensor(bus_df[bus_features].values, dtype=torch.float) diff --git a/gridfm_graphkit/datasets/task_transforms.py b/gridfm_graphkit/datasets/task_transforms.py index dffb66cb..20a5b798 100644 --- a/gridfm_graphkit/datasets/task_transforms.py +++ b/gridfm_graphkit/datasets/task_transforms.py @@ -16,6 +16,7 @@ @TRANSFORM_REGISTRY.register("PowerFlow") class PowerFlowTransforms(Compose): """Compose preprocessing and masking transforms for PowerFlow datasets.""" + def __init__(self, args): transforms = [] @@ -31,6 +32,7 @@ def __init__(self, args): @TRANSFORM_REGISTRY.register("OptimalPowerFlow") class OptimalPowerFlowTransforms(Compose): """Compose preprocessing and masking transforms for OptimalPowerFlow datasets.""" + def __init__(self, args): transforms = [] @@ -46,6 +48,7 @@ def __init__(self, args): @TRANSFORM_REGISTRY.register("StateEstimation") class StateEstimationTransforms(Compose): """Compose preprocessing and measurement transforms for StateEstimation datasets.""" + def __init__(self, args): transforms = [] diff --git a/gridfm_graphkit/datasets/transforms.py b/gridfm_graphkit/datasets/transforms.py index c6891dc2..f58a1fec 100644 --- a/gridfm_graphkit/datasets/transforms.py +++ b/gridfm_graphkit/datasets/transforms.py @@ -97,6 +97,7 @@ def forward(self, data): class LoadGridParamsFromPath(BaseTransform): """Inject static grid parameters from a saved grid template into each sample.""" + def __init__(self, args): super().__init__() self.grid_path = args.task.grid_path diff --git a/gridfm_graphkit/datasets/utils.py b/gridfm_graphkit/datasets/utils.py index 65b34f4e..8d16de92 100644 --- a/gridfm_graphkit/datasets/utils.py +++ b/gridfm_graphkit/datasets/utils.py @@ -114,8 +114,8 @@ def split_from_existing_files( split_dataset = Subset(dataset, split_indices) output.append(split_dataset) split_indices = list(split_indices) - print(f'{split=} {len(split_indices)=}') - indices[split]=[int(t.item()) for t in split_indices] + print(f"{split=} {len(split_indices)=}") + indices[split] = [int(t.item()) for t in split_indices] output = tuple(output) - return output, indices \ No newline at end of file + return output, indices diff --git a/gridfm_graphkit/io/registries.py b/gridfm_graphkit/io/registries.py index 65d596a9..c26f07b9 100644 --- a/gridfm_graphkit/io/registries.py +++ b/gridfm_graphkit/io/registries.py @@ -1,5 +1,6 @@ class Registry: """Simple name-to-object registry with decorator-based registration.""" + def __init__(self, name: str): self._name = name self._registry = {} diff --git a/gridfm_graphkit/models/utils.py b/gridfm_graphkit/models/utils.py index bc4b9bfa..41b75f20 100644 --- a/gridfm_graphkit/models/utils.py +++ b/gridfm_graphkit/models/utils.py @@ -82,6 +82,7 @@ def compute_shunt_power(bus_data_pred, bus_data_orig): @PHYSICS_DECODER_REGISTRY.register("OptimalPowerFlow") class PhysicsDecoderOPF(nn.Module): """Map network outputs to OPF-consistent bus states using physics constraints.""" + def forward(self, P_in, Q_in, bus_data_pred, bus_data_orig, agg_bus, mask_dict): mask_pv = mask_dict["PV"] mask_ref = mask_dict["REF"] @@ -117,6 +118,7 @@ def forward(self, P_in, Q_in, bus_data_pred, bus_data_orig, agg_bus, mask_dict): @PHYSICS_DECODER_REGISTRY.register("PowerFlow") class PhysicsDecoderPF(nn.Module): """Map network outputs to PF-consistent bus states using physics constraints.""" + def forward(self, P_in, Q_in, bus_data_pred, bus_data_orig, agg_bus, mask_dict): """ PF decoder: @@ -165,6 +167,7 @@ def forward(self, P_in, Q_in, bus_data_pred, bus_data_orig, agg_bus, mask_dict): @PHYSICS_DECODER_REGISTRY.register("StateEstimation") class PhysicsDecoderSE(nn.Module): """Map network outputs to SE targets via bus power-balance relations.""" + def forward(self, P_in, Q_in, bus_data_pred, bus_data_orig, agg_bus, mask_dict): p_shunt, q_shunt = compute_shunt_power(bus_data_pred, bus_data_orig) Vm_out = bus_data_pred[:, VM_OUT] diff --git a/gridfm_graphkit/tasks/opf_ac_dc_baseline.py b/gridfm_graphkit/tasks/opf_ac_dc_baseline.py index 7dcd3338..40cb8258 100644 --- a/gridfm_graphkit/tasks/opf_ac_dc_baseline.py +++ b/gridfm_graphkit/tasks/opf_ac_dc_baseline.py @@ -65,7 +65,9 @@ def _load_test_data(data_dir: str, test_scenario_ids: list[int]): bus_df = bus_df[bus_df["scenario"].isin(test_set)].reset_index(drop=True) gen_df = gen_df[gen_df["scenario"].isin(test_set)].reset_index(drop=True) branch_df = branch_df[branch_df["scenario"].isin(test_set)].reset_index(drop=True) - runtime_df = runtime_df[runtime_df["scenario"].isin(test_set)].reset_index(drop=True) + runtime_df = runtime_df[runtime_df["scenario"].isin(test_set)].reset_index( + drop=True, + ) print( f" Loaded {len(bus_df)} bus rows, {len(gen_df)} gen rows, " @@ -85,8 +87,12 @@ def _compute_optimality_gap(gen_df: pd.DataFrame) -> dict: pg_ac = gen_df["p_mw"].to_numpy(dtype=float) pg_dc = gen_df["p_mw_dc"].to_numpy(dtype=float) g = gen_df.copy() - g["cost_ac"] = (c0 + c1 * pg_ac + c2 * pg_ac * pg_ac) * g["in_service"] # all is already in MW - g["cost_dc"] = (c0 + c1 * pg_dc + c2 * pg_dc * pg_dc) * g["in_service"] # all is already in MW + g["cost_ac"] = (c0 + c1 * pg_ac + c2 * pg_ac * pg_ac) * g[ + "in_service" + ] # all is already in MW + g["cost_dc"] = (c0 + c1 * pg_dc + c2 * pg_dc * pg_dc) * g[ + "in_service" + ] # all is already in MW per_scenario = g.groupby("scenario")[["cost_ac", "cost_dc"]].sum() cost_ac = per_scenario["cost_ac"].to_numpy(dtype=float) cost_dc = per_scenario["cost_dc"].to_numpy(dtype=float) @@ -113,28 +119,47 @@ def _compute_pg_violations(gen_df: pd.DataFrame) -> dict: def _compute_qg_violations_ac(bus_df: pd.DataFrame, gen_df: pd.DataFrame) -> dict: """Compute AC reactive-power limit violations for PV/REF buses.""" - # opf_task style on bus Qg; AC only + # opf_task style on bus Qg; AC only bus = bus_df.copy() qg = bus["Qg"].to_numpy(dtype=float) # complain if max_q_mvar == min_q_mvar for some gens of gen_df - assert (gen_df["max_q_mvar"] == gen_df["min_q_mvar"]).any() == False, "max_q_mvar == min_q_mvar for some gens of gen_df" + assert (gen_df["max_q_mvar"] == gen_df["min_q_mvar"]).any() == False, ( + "max_q_mvar == min_q_mvar for some gens of gen_df" + ) agg_gen = ( - gen_df.groupby(["scenario", "bus"])[["min_q_mvar", "max_q_mvar"]] - .sum() - .reset_index()) + gen_df.groupby(["scenario", "bus"])[["min_q_mvar", "max_q_mvar"]] + .sum() + .reset_index() + ) bus = bus.merge(agg_gen, on=["scenario", "bus"], how="left") - assert bus[bus["PV"]==1]["min_q_mvar"].isna().sum() == 0, "PV buses have no min_q_mvar" - assert bus[bus["PV"]==1]["max_q_mvar"].isna().sum() == 0, "PV buses have no max_q_mvar" - assert bus[bus["REF"]==1]["min_q_mvar"].isna().sum() == 0, "REF buses have no min_q_mvar" - assert bus[bus["REF"]==1]["max_q_mvar"].isna().sum() == 0, "REF buses have no max_q_mvar" - bus["qg_violation_amount"] = np.maximum(qg - bus["max_q_mvar"], 0.0) + np.maximum(bus["min_q_mvar"] - qg, 0.0) + assert bus[bus["PV"] == 1]["min_q_mvar"].isna().sum() == 0, ( + "PV buses have no min_q_mvar" + ) + assert bus[bus["PV"] == 1]["max_q_mvar"].isna().sum() == 0, ( + "PV buses have no max_q_mvar" + ) + assert bus[bus["REF"] == 1]["min_q_mvar"].isna().sum() == 0, ( + "REF buses have no min_q_mvar" + ) + assert bus[bus["REF"] == 1]["max_q_mvar"].isna().sum() == 0, ( + "REF buses have no max_q_mvar" + ) + bus["qg_violation_amount"] = np.maximum(qg - bus["max_q_mvar"], 0.0) + np.maximum( + bus["min_q_mvar"] - qg, 0.0, + ) pv = bus[bus["PV"] == 1] ref = bus[bus["REF"] == 1] pv_ref = bus[(bus["PV"] == 1) | (bus["REF"] == 1)] return { - "AC Mean Qg violation PV buses": float(np.nanmean(pv["qg_violation_amount"].to_numpy(dtype=float))), - "AC Mean Qg violation REF buses": float(np.nanmean(ref["qg_violation_amount"].to_numpy(dtype=float))), - "AC Mean Qg violation": float(np.nanmean(pv_ref["qg_violation_amount"].to_numpy(dtype=float))), + "AC Mean Qg violation PV buses": float( + np.nanmean(pv["qg_violation_amount"].to_numpy(dtype=float)), + ), + "AC Mean Qg violation REF buses": float( + np.nanmean(ref["qg_violation_amount"].to_numpy(dtype=float)), + ), + "AC Mean Qg violation": float( + np.nanmean(pv_ref["qg_violation_amount"].to_numpy(dtype=float)), + ), } @@ -142,13 +167,21 @@ def _compute_branch_violations(branch_df: pd.DataFrame, bus_df: pd.DataFrame) -> """Compute AC/DC branch thermal and angle-limit violation statistics.""" rate = branch_df["rate_a"].to_numpy(dtype=float) ac_from = np.sqrt( - branch_df["pf"].to_numpy(dtype=float) ** 2 + branch_df["qf"].to_numpy(dtype=float) ** 2, + branch_df["pf"].to_numpy(dtype=float) ** 2 + + branch_df["qf"].to_numpy(dtype=float) ** 2, ) ac_to = np.sqrt( - branch_df["pt"].to_numpy(dtype=float) ** 2 + branch_df["qt"].to_numpy(dtype=float) ** 2, + branch_df["pt"].to_numpy(dtype=float) ** 2 + + branch_df["qt"].to_numpy(dtype=float) ** 2, + ) + dc_from = np.sqrt( + branch_df["pf_dc_computed"].to_numpy(dtype=float) ** 2 + + branch_df["qf_dc_computed"].to_numpy(dtype=float) ** 2, + ) # reactive part is needed here + dc_to = np.sqrt( + branch_df["pt_dc_computed"].to_numpy(dtype=float) ** 2 + + branch_df["qt_dc_computed"].to_numpy(dtype=float) ** 2, ) - dc_from = np.sqrt(branch_df["pf_dc_computed"].to_numpy(dtype=float) ** 2 + branch_df["qf_dc_computed"].to_numpy(dtype=float) ** 2) # reactive part is needed here - dc_to = np.sqrt(branch_df["pt_dc_computed"].to_numpy(dtype=float) ** 2 + branch_df["qt_dc_computed"].to_numpy(dtype=float) ** 2) ac_thermal_from = np.maximum(ac_from - rate, 0.0) ac_thermal_to = np.maximum(ac_to - rate, 0.0) @@ -157,7 +190,7 @@ def _compute_branch_violations(branch_df: pd.DataFrame, bus_df: pd.DataFrame) -> bus_angles = bus_df[["scenario", "bus", "Va", "Va_dc"]] # convert to radians - bus_angles.loc[:, "Va"] = bus_angles["Va"] * np.pi / 180.0 + bus_angles.loc[:, "Va"] = bus_angles["Va"] * np.pi / 180.0 bus_angles.loc[:, "Va_dc"] = bus_angles["Va_dc"] * np.pi / 180.0 from_angles = bus_angles.rename( columns={"bus": "from_bus", "Va": "Va_from", "Va_dc": "Va_dc_from"}, @@ -167,10 +200,12 @@ def _compute_branch_violations(branch_df: pd.DataFrame, bus_df: pd.DataFrame) -> ) br = branch_df.merge(from_angles, on=["scenario", "from_bus"], how="left") br = br.merge(to_angles, on=["scenario", "to_bus"], how="left") - + # AC angle ac_angle_diff = br["Va_from"] - br["Va_to"] - ac_angle_diff = (ac_angle_diff + np.pi) % (2 * np.pi) - np.pi # wrap to [-pi, pi] + ac_angle_diff = (ac_angle_diff + np.pi) % ( + 2 * np.pi + ) - np.pi # wrap to [-pi, pi] ac_angle_excess_low = np.maximum(br["ang_min"] - ac_angle_diff, 0.0) ac_angle_excess_high = np.maximum(ac_angle_diff - br["ang_max"], 0.0) mean_ac_angle_violation = np.mean(ac_angle_excess_low + ac_angle_excess_high) @@ -182,12 +217,20 @@ def _compute_branch_violations(branch_df: pd.DataFrame, bus_df: pd.DataFrame) -> mean_dc_angle_violation = np.mean(dc_angle_excess_low + dc_angle_excess_high) return { - "AC Mean branch thermal violation from (MVA)": float(np.nanmean(ac_thermal_from)), + "AC Mean branch thermal violation from (MVA)": float( + np.nanmean(ac_thermal_from), + ), "AC Mean branch thermal violation to (MVA)": float(np.nanmean(ac_thermal_to)), - "AC Mean branch angle difference violation (radians)": float(mean_ac_angle_violation), - "DC Mean branch thermal violation from (MVA)": float(np.nanmean(dc_thermal_from)), + "AC Mean branch angle difference violation (radians)": float( + mean_ac_angle_violation, + ), + "DC Mean branch thermal violation from (MVA)": float( + np.nanmean(dc_thermal_from), + ), "DC Mean branch thermal violation to (MVA)": float(np.nanmean(dc_thermal_to)), - "DC Mean branch angle difference violation (radians)": float(mean_dc_angle_violation), + "DC Mean branch angle difference violation (radians)": float( + mean_dc_angle_violation, + ), } @@ -258,7 +301,6 @@ def compute_opf_ac_dc_metrics( branch_df["pt_dc_computed"] = pt_dc branch_df["qf_dc_computed"] = qf_dc branch_df["qt_dc_computed"] = qt_dc - opf_extra = {} opf_extra.update(_compute_optimality_gap(gen_df)) diff --git a/gridfm_graphkit/tasks/opf_task.py b/gridfm_graphkit/tasks/opf_task.py index dbb1baab..8268a25b 100644 --- a/gridfm_graphkit/tasks/opf_task.py +++ b/gridfm_graphkit/tasks/opf_task.py @@ -87,14 +87,18 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): c2 = batch.x_dict["gen"][:, C2_H] target_pg = batch.y_dict["gen"].squeeze() pred_pg = output["gen"].squeeze() - gen_cost_gt = (c0 + c1 * target_pg + c2 * target_pg**2) # assumes all branches are on! - gen_cost_pred = (c0 + c1 * pred_pg + c2 * pred_pg**2) # assumes all branches are on! + gen_cost_gt = ( + c0 + c1 * target_pg + c2 * target_pg**2 + ) # assumes all branches are on! + gen_cost_pred = ( + c0 + c1 * pred_pg + c2 * pred_pg**2 + ) # assumes all branches are on! gen_batch = batch.batch_dict["gen"] # shape: [N_gen_total] cost_gt = scatter_add(gen_cost_gt, gen_batch, dim=0) cost_pred = scatter_add(gen_cost_pred, gen_batch, dim=0) - + optimality_gap = torch.mean(torch.abs((cost_pred - cost_gt) / cost_gt * 100)) agg_gen_on_bus = scatter_add( @@ -138,14 +142,16 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): bus_angles = output["bus"][:, VA_OUT] # in degrees from_bus = bus_edge_index[0] to_bus = bus_edge_index[1] - angle_diff = bus_angles[from_bus] - bus_angles[to_bus] # keep sign - angle_diff = (angle_diff + torch.pi) % (2 * torch.pi) - torch.pi # wrap to [-pi, pi] + angle_diff = bus_angles[from_bus] - bus_angles[to_bus] # keep sign + angle_diff = (angle_diff + torch.pi) % ( + 2 * torch.pi + ) - torch.pi # wrap to [-pi, pi] angle_excess_low = F.relu(angle_min - angle_diff) angle_excess_high = F.relu(angle_diff - angle_max) branch_angle_violation_mean = torch.mean( - angle_excess_low + angle_excess_high - ) # mean of the abs violation + angle_excess_low + angle_excess_high, + ) # mean of the abs violation P_in, Q_in = node_injection_layer(Pft, Qft, bus_edge_index, num_bus) residual_P, residual_Q = node_residuals_layer( @@ -174,8 +180,8 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): mean_Qg_violation_PV = Qg_violation_amount[mask_PV].mean() mean_Qg_violation_REF = Qg_violation_amount[mask_REF].mean() - mask_PV_REF = mask_PV | mask_REF # PV or REF buses - mean_Qg_violation = Qg_violation_amount[mask_PV_REF].mean() # + mask_PV_REF = mask_PV | mask_REF # PV or REF buses + mean_Qg_violation = Qg_violation_amount[mask_PV_REF].mean() # if self.args.verbose: mean_res_P_PQ, max_res_P_PQ = residual_stats_by_type( @@ -270,7 +276,9 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): loss_dict["Branch voltage angle difference violations"] = ( branch_angle_violation_mean ) - loss_dict["Mean Qg violation PV buses"] = mean_Qg_violation_PV # mean of the abs violation over the entire batch (all oines in the batch). + loss_dict["Mean Qg violation PV buses"] = ( + mean_Qg_violation_PV # mean of the abs violation over the entire batch (all oines in the batch). + ) # this is then overaged over all the batches and gives same weight to all batches despite them possibly having varying number of branches loss_dict["Mean Qg violation REF buses"] = mean_Qg_violation_REF loss_dict["Mean Qg violation"] = mean_Qg_violation @@ -372,7 +380,9 @@ def on_test_end(self): "Branch thermal violation from", " ", ) - branch_thermal_violation_to = metrics.get("Branch thermal violation to", " ") + branch_thermal_violation_to = metrics.get( + "Branch thermal violation to", " ", + ) branch_angle_violation = metrics.get( "Branch voltage angle difference violations", " ", @@ -543,9 +553,9 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): local_bus_idx = torch.cat( [ torch.arange(c, device=bus_batch.device) - for c in torch.bincount(bus_batch) + for c in torch.bincount(bus_batch) ], - ) # this works because the order of the buses is preserved by the groupby in the dataset wrapper and datakit data has buses in increasing order. + ) # this works because the order of the buses is preserved by the groupby in the dataset wrapper and datakit data has buses in increasing order. bus_x = batch.x_dict["bus"] bus_y = batch.y_dict["bus"] diff --git a/gridfm_graphkit/tasks/pf_task.py b/gridfm_graphkit/tasks/pf_task.py index 948a25e0..ad62be17 100644 --- a/gridfm_graphkit/tasks/pf_task.py +++ b/gridfm_graphkit/tasks/pf_task.py @@ -245,7 +245,7 @@ def on_test_end(self): # Only rank 0 proceeds with logging, CSV writing, and plotting if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0: - self.test_outputs.clear() # clear the test outputs for other ranks + self.test_outputs.clear() # clear the test outputs for other ranks return if isinstance(self.logger, MLFlowLogger): @@ -356,22 +356,38 @@ def on_test_end(self): self.test_outputs.clear() def predict_step(self, batch, batch_idx, dataloader_idx=0): - output, _ = self.shared_step(batch) # get the predicted output from the model - - self.data_normalizers[dataloader_idx].inverse_transform(batch) # normalize the batch data back to the original scale - self.data_normalizers[dataloader_idx].inverse_output(output, batch) # inverse transform the predicted output back to the original scale - - branch_flow_layer = ComputeBranchFlow() # layer to compute the branch flows - node_injection_layer = ComputeNodeInjection() # layer to compute the node injections - node_residuals_layer = ComputeNodeResiduals() # layer to compute the node residuals - - num_bus = batch.x_dict["bus"].size(0) # number of buses in the batch - bus_edge_index = batch.edge_index_dict[("bus", "connects", "bus")] # from and to buses - bus_edge_attr = batch.edge_attr_dict[("bus", "connects", "bus")] # edge attributes (admittance) of the bus connections - - Pft, Qft = branch_flow_layer(output["bus"], bus_edge_index, bus_edge_attr) # compute the branch flows - P_in, Q_in = node_injection_layer(Pft, Qft, bus_edge_index, num_bus) # compute the node injections - residual_P, residual_Q = node_residuals_layer( # compute the node residuals + output, _ = self.shared_step(batch) # get the predicted output from the model + + self.data_normalizers[dataloader_idx].inverse_transform( + batch, + ) # normalize the batch data back to the original scale + self.data_normalizers[dataloader_idx].inverse_output( + output, batch, + ) # inverse transform the predicted output back to the original scale + + branch_flow_layer = ComputeBranchFlow() # layer to compute the branch flows + node_injection_layer = ( + ComputeNodeInjection() + ) # layer to compute the node injections + node_residuals_layer = ( + ComputeNodeResiduals() + ) # layer to compute the node residuals + + num_bus = batch.x_dict["bus"].size(0) # number of buses in the batch + bus_edge_index = batch.edge_index_dict[ + ("bus", "connects", "bus") + ] # from and to buses + bus_edge_attr = batch.edge_attr_dict[ + ("bus", "connects", "bus") + ] # edge attributes (admittance) of the bus connections + + Pft, Qft = branch_flow_layer( + output["bus"], bus_edge_index, bus_edge_attr, + ) # compute the branch flows + P_in, Q_in = node_injection_layer( + Pft, Qft, bus_edge_index, num_bus, + ) # compute the node injections + residual_P, residual_Q = node_residuals_layer( # compute the node residuals P_in, Q_in, output["bus"], @@ -388,7 +404,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): torch.arange(c, device=bus_batch.device) for c in torch.bincount(bus_batch) ], - ) # this is based on the assumptions that the buses within a graph are ordered and indexed as 0 ... n_nodes-1. + ) # this is based on the assumptions that the buses within a graph are ordered and indexed as 0 ... n_nodes-1. # todo: we should remove this assert and store the bus idx in the tensors # right now we need the increasing order and we have an assert in the dataset to check it. bus_x = batch.x_dict["bus"] diff --git a/gridfm_graphkit/tasks/se_task.py b/gridfm_graphkit/tasks/se_task.py index 36667ad2..78aa0fa5 100644 --- a/gridfm_graphkit/tasks/se_task.py +++ b/gridfm_graphkit/tasks/se_task.py @@ -27,6 +27,7 @@ @TASK_REGISTRY.register("StateEstimation") class StateEstimationTask(ReconstructionTask): """State-estimation task with evaluation plots for masked and noisy measurements.""" + def __init__(self, args, data_normalizers): super().__init__(args, data_normalizers) diff --git a/gridfm_graphkit/training/callbacks.py b/gridfm_graphkit/training/callbacks.py index ba7a4049..116aa4f0 100644 --- a/gridfm_graphkit/training/callbacks.py +++ b/gridfm_graphkit/training/callbacks.py @@ -42,6 +42,7 @@ def last_epoch_iters_per_sec(self) -> float | None: class SaveBestModelStateDict(Callback): """Persist the best model state_dict according to a monitored validation metric.""" + def __init__( self, monitor: str, diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index a0521fc2..e7936fc0 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -20,7 +20,7 @@ # Generator feature indices PG_H, # Qg Limits - MIN_QG_H, + MIN_QG_H, MAX_QG_H, ) @@ -85,6 +85,7 @@ def forward( @LOSS_REGISTRY.register("MaskedGenMSE") class MaskedGenMSE(torch.nn.Module): """Compute MSE on generator targets restricted to generator mask entries.""" + def __init__(self, loss_args, args): super().__init__() self.reduction = "mean" @@ -110,6 +111,7 @@ def forward( @LOSS_REGISTRY.register("MaskedBusMSE") class MaskedBusMSE(torch.nn.Module): """Compute MSE on selected bus targets, respecting task-specific output columns.""" + def __init__(self, loss_args, args): super().__init__() self.reduction = "mean" @@ -242,6 +244,7 @@ def forward( @LOSS_REGISTRY.register("LayeredWeightedPhysics") class LayeredWeightedPhysicsLoss(BaseLoss): """Combine intermediate physics residuals using normalized geometric weights.""" + def __init__(self, loss_args, args) -> None: super().__init__() self.base_weight = loss_args.base_weight @@ -283,6 +286,7 @@ def forward( @LOSS_REGISTRY.register("LossPerDim") class LossPerDim(BaseLoss): """Compute MAE/MSE for one named physical dimension of bus outputs.""" + def __init__(self, loss_args, args): super(LossPerDim, self).__init__() self.reduction = "mean" @@ -362,8 +366,8 @@ def forward( Qg_max = x_dict["bus"][:, MAX_QG_H] Qg_min = x_dict["bus"][:, MIN_QG_H] - max_penalty_mask = (Qg_pred > Qg_max) - min_penalty_mask = (Qg_pred < Qg_min) + max_penalty_mask = Qg_pred > Qg_max + min_penalty_mask = Qg_pred < Qg_min mask_PQ = mask["PQ"] # PQ buses mask_PV = mask["PV"] # PV buses @@ -376,13 +380,13 @@ def forward( Qg_over = Qg_over[max_penalty_mask].mean() Qg_under = Qg_under[min_penalty_mask].mean() - - if Qg_over!=Qg_over: # replacing nan with 0 + + if Qg_over != Qg_over: # replacing nan with 0 Qg_over = 0.0 - if Qg_under!=Qg_under: # replacing nan with 0 + if Qg_under != Qg_under: # replacing nan with 0 Qg_under = 0.0 - penalty_loss = Qg_over + Qg_under + penalty_loss = Qg_over + Qg_under loss += penalty_loss try: @@ -391,4 +395,3 @@ def forward( output = {"loss": loss, "Qg Violation Penalty loss": loss} return output -