diff --git a/gridfm_graphkit PR-VLDloss.zip b/gridfm_graphkit PR-VLDloss.zip new file mode 100644 index 00000000..244841a8 Binary files /dev/null and b/gridfm_graphkit PR-VLDloss.zip differ diff --git a/gridfm_graphkit/datasets/__init__.py b/gridfm_graphkit/datasets/__init__.py index fee635f5..739b00a2 100644 --- a/gridfm_graphkit/datasets/__init__.py +++ b/gridfm_graphkit/datasets/__init__.py @@ -5,6 +5,10 @@ PowerFlowTransforms, OptimalPowerFlowTransforms, StateEstimationTransforms, + ############## + VoltageLossDetectionTransforms + ############# + ) __all__ = [ @@ -12,4 +16,7 @@ "PowerFlowTransforms", "OptimalPowerFlowTransforms", "StateEstimationTransforms", + ################### + "VoltageLossDetectionTransforms" + ################## ] diff --git a/gridfm_graphkit/datasets/globals.py b/gridfm_graphkit/datasets/globals.py index ab3c7e3d..8bf5690a 100644 --- a/gridfm_graphkit/datasets/globals.py +++ b/gridfm_graphkit/datasets/globals.py @@ -17,6 +17,10 @@ BS = 13 # Shunt susceptance (p.u.) 
VN_KV = 14 # Nominal voltage +##ADDITIONAL INPUT FEATURES OF BUS STATUS +BUS_BASE_STATUS_H = 15 # bus ON/OFF status in the pre-contingency(base topology) +BUS_CONT_H = 16 # bus contingency to be applied + # ========================= # === OUTPUT FEATURE INDICES == # ========================= @@ -26,6 +30,10 @@ QG_OUT = 3 PG_OUT_GEN = 0 +##ADDITIONAL OUTPUT FEATURE OF BUS STATUS +PHYSICAL_BUS_DIM = 4 # Physical bus outputs predicted by the model for VLD tasks +BUS_STATUS_TARGET = 5 # post-contingency energized/de-energized target +BUS_STATUS_LOGIT_OUT = 4 # # Extra model output columns for VLD tasks # ================================ # === GENERATOR FEATURE INDICES == @@ -52,3 +60,8 @@ ANG_MAX = 8 # Angle max (deg) RATE_A = 9 # Thermal limit B_ON = 10 # Branch on/off + +##ADDITIONAL INPUT FEATURES OF BUS STATUS +BRANCH_BASE_STATUS_E = 11 # branch ON/OFF status in the pre-contingency(base topology) +BRANCH_CONT_E = 12 # branch contingency to be applied + diff --git a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py index e5374970..54db39c7 100644 --- a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py @@ -14,14 +14,12 @@ from gridfm_graphkit.datasets.utils import ( split_dataset, split_dataset_by_load_scenario_idx, - split_from_existing_files, ) from gridfm_graphkit.datasets.powergrid_hetero_dataset import HeteroGridDatasetDisk import numpy as np import random import warnings import lightning as L -from pathlib import Path from typing import List from lightning.pytorch.loggers import MLFlowLogger @@ -103,11 +101,6 @@ def __init__( "split_by_load_scenario_idx", False, ) - self.split_from_existing_files = getattr( - args.data, - "split_from_existing_files", - None, - ) self.args = args self.normalizer_stats_path = normalizer_stats_path self.data_normalizers = [] @@ -120,15 +113,6 @@ def __init__( self.test_scenario_ids: List[List[int]] = 
[] self._is_setup_done = False - if self.split_by_load_scenario_idx: - assert self.split_from_existing_files is None, " either `split_by_load_scenario_idx` or `split_from_existing_files` may be used, not both" - - if self.split_from_existing_files is not None: - assert isinstance(self.split_from_existing_files, str), "`split_from_existing_files` must be an existing folder in string format" - self.split_from_existing_files = Path(self.split_from_existing_files) - assert self.split_from_existing_files.is_dir(), "`split_from_existing_files` must be an existing folder in string format" - - def setup(self, stage: str): if self._is_setup_done: print(f"Setup already done for stage={stage}, skipping...") @@ -183,94 +167,54 @@ def setup(self, stage: str): # Create a subset all_indices = list(range(len(dataset))) + # Random seed set before every shuffle for reproducibility in case the power grid datasets are analyzed in a different order + random.seed(self.args.seed) + random.shuffle(all_indices) + subset_indices = all_indices[:num_scenarios] + # load_scenario for each scenario in the subset + load_scenarios = dataset.load_scenarios[subset_indices] - if self.split_from_existing_files is not None: - warnings.warn( - "`data.scenarios` is ignored when `split_from_existing_files` is set; " - "train/val/test scenario ids are loaded from the provided split files.", - ) + dataset = Subset(dataset, subset_indices) - if self.dataset_wrapper is not None: - wrapper_cls = DATASET_WRAPPER_REGISTRY.get(self.dataset_wrapper) - dataset = wrapper_cls( - dataset, - cache_dir=self.dataset_wrapper_cache_dir, - ) + if self.dataset_wrapper is not None: + wrapper_cls = DATASET_WRAPPER_REGISTRY.get(self.dataset_wrapper) + dataset = wrapper_cls(dataset, cache_dir=self.dataset_wrapper_cache_dir) - (train_dataset, val_dataset, test_dataset), subset_indices = ( - split_from_existing_files( - dataset, - self.split_from_existing_files, - ) - ) - train_scenario_ids = subset_indices["train"] - 
val_scenario_ids = subset_indices["val"] - test_scenario_ids = subset_indices["test"] - num_scenarios = int( - np.unique( - train_scenario_ids + val_scenario_ids + test_scenario_ids, - ).shape[0], - ) - else: - # Random seed set before every shuffle for reproducibility in case the power grid datasets are analyzed in a different order - random.seed(self.args.seed) - random.shuffle(all_indices) - subset_indices = all_indices[:num_scenarios] - - load_scenarios = None - if self.split_by_load_scenario_idx: - if not hasattr(dataset, "load_scenarios"): - raise ValueError( - "`data.split_by_load_scenario_idx=true` requires " - "`load_scenario_idx` in raw bus data so " - "`processed/load_scenarios.pt` can be created.", - ) - # load_scenario for each scenario in the subset - load_scenarios = dataset.load_scenarios[subset_indices] - - - dataset = Subset(dataset, subset_indices) - - if self.dataset_wrapper is not None: - wrapper_cls = DATASET_WRAPPER_REGISTRY.get(self.dataset_wrapper) - dataset = wrapper_cls(dataset, cache_dir=self.dataset_wrapper_cache_dir) - - - # Random seed set before every split, same as above - np.random.seed(self.args.seed) - if self.split_by_load_scenario_idx: - train_dataset, val_dataset, test_dataset = ( - split_dataset_by_load_scenario_idx( - dataset, - self.data_dir, - load_scenarios, - self.args.data.val_ratio, - self.args.data.test_ratio, - ) - ) - else: - train_dataset, val_dataset, test_dataset = split_dataset( + # Random seed set before every split, same as above + np.random.seed(self.args.seed) + if self.split_by_load_scenario_idx: + train_dataset, val_dataset, test_dataset = ( + split_dataset_by_load_scenario_idx( dataset, self.data_dir, + load_scenarios, self.args.data.val_ratio, self.args.data.test_ratio, ) - - # Extract scenario IDs for each split - train_scenario_ids = self._extract_scenario_ids( - train_dataset, - subset_indices, - ) - val_scenario_ids = self._extract_scenario_ids( - val_dataset, - subset_indices, ) - test_scenario_ids 
= self._extract_scenario_ids( - test_dataset, - subset_indices, + else: + train_dataset, val_dataset, test_dataset = split_dataset( + dataset, + self.data_dir, + self.args.data.val_ratio, + self.args.data.test_ratio, ) + # Extract scenario IDs for each split + train_scenario_ids = self._extract_scenario_ids( + train_dataset, + subset_indices, + ) + val_scenario_ids = self._extract_scenario_ids( + val_dataset, + subset_indices, + ) + test_scenario_ids = self._extract_scenario_ids( + test_dataset, + subset_indices, + ) + # Fit normalizer: restore from saved stats only for fit_on_train # normalizers (global baseMVA must match the model's training run). # fit_on_dataset normalizers compute per-scenario stats and must @@ -425,12 +369,11 @@ def _dataloader_kwargs(self): pin_memory=torch.cuda.is_available(), persistent_workers=num_workers > 0, ) - # Use 'fork' on Linux. It avoids the forkserver intermediary pipe which - # is fragile when the process has many threads (e.g. OpenBLAS). In - # container environments (Kubernetes) fork works correctly. On - # traditional HPC systems with strict fd-passing restrictions the - # original 'forkserver' may be needed, but the pipe truncation it - # produces under thread pressure is worse than the ancdata warning. + # On Linux some HPC environments restrict passing open file descriptors + # via Unix socket ancillary data (SCM_RIGHTS), which causes + # "received 0 items of ancdata" with the default 'fork' start method. + # 'forkserver' avoids fd-passing by having a dedicated server process + # that re-opens shared memory objects by name instead. 
class AddVLDHeteroMask(BaseTransform):
    """Attach boolean masks for the VoltageLossDetection task to a sample.

    ``forward`` builds one all-False boolean tensor per entity (bus features,
    generator features, bus-bus edge attributes), switches selected columns
    to True, and stores the result in ``data.mask_dict`` together with the
    per-bus PQ / PV / REF indicator vectors.

    Columns that are never set here stay False; that includes any feature
    columns not named below (e.g. status/topology columns appended for VLD).
    """

    def __init__(self):
        super().__init__()

    def forward(self, data):
        """Build ``data.mask_dict`` and return the same ``data`` object."""
        bus_feats = data.x_dict["bus"]
        gen_feats = data.x_dict["gen"]

        # Bus-type indicators read from the bus feature matrix.
        is_pq = bus_feats[:, PQ_H] == 1
        is_pv = bus_feats[:, PV_H] == 1
        is_ref = bus_feats[:, REF_H] == 1

        bus_mask = torch.zeros_like(bus_feats, dtype=torch.bool)
        gen_mask = torch.zeros_like(gen_feats, dtype=torch.bool)

        # Columns set to True for every bus / every generator.
        for col in (MIN_VM_H, MAX_VM_H, MIN_QG_H, MAX_QG_H, VN_KV):
            bus_mask[:, col] = True
        for col in (MIN_PG, MAX_PG, C0_H, C1_H, C2_H):
            gen_mask[:, col] = True

        # Columns set to True only on rows of a given bus type.
        per_bus_type = (
            (is_pq, (VM_H, VA_H)),
            (is_pv, (VA_H, QG_H)),
            (is_ref, (VM_H, QG_H)),
        )
        for rows, cols in per_bus_type:
            for col in cols:
                bus_mask[rows, col] = True

        # Generators whose connected bus is a REF bus get PG_H set to True.
        gen_idx, bus_idx = data.edge_index_dict[("gen", "connected_to", "bus")]
        gens_on_ref = gen_idx[is_ref[bus_idx]]
        gen_mask[gens_on_ref, PG_H] = True

        branch_mask = torch.zeros_like(
            data.edge_attr_dict[("bus", "connects", "bus")],
            dtype=torch.bool,
        )
        for col in (P_E, Q_E, ANG_MIN, ANG_MAX, RATE_A):
            branch_mask[:, col] = True

        data.mask_dict = {
            "bus": bus_mask,
            "gen": gen_mask,
            "branch": branch_mask,
            "PQ": is_pq,
            "PV": is_pv,
            "REF": is_ref,
        }
        return data
+299,7 @@ def inverse_transform(self, data: HeteroData): data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MAX] *= 180.0 / torch.pi data.edge_attr_dict[("bus", "connects", "bus")][:, RATE_A] *= self.baseMVA - data.is_normalized = torch.tensor(False, dtype=torch.bool) # needs to be bool for MPS + data.is_normalized = False def inverse_output(self, output, batch): bus_output = output["bus"] @@ -510,10 +510,10 @@ def transform(self, data: HeteroData): data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MIN] *= torch.pi / 180.0 data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MAX] *= torch.pi / 180.0 data.edge_attr_dict[("bus", "connects", "bus")][:, RATE_A] /= e_b - data.is_normalized = torch.tensor(True, dtype=torch.bool) # needs to be bool for MPS + data.is_normalized = True def inverse_transform(self, data: HeteroData): - """Undo per-unit normalization (multiply by baseMVA, inverse log1p for cost coeffs).""" + """Undo per-unit normalization (multiply by baseMVA, rad->deg, inverse log1p for cost coeffs).""" if self._baseMVA_lookup is None: raise ValueError("Normalizer not fitted or lookups not loaded") if not data.is_normalized.all(): @@ -573,7 +573,7 @@ def inverse_transform(self, data: HeteroData): data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MAX] *= 180.0 / torch.pi data.edge_attr_dict[("bus", "connects", "bus")][:, RATE_A] *= e_b - data.is_normalized = torch.tensor(False, dtype=torch.bool) # needs to be bool for MPS + data.is_normalized = False def inverse_output(self, output, batch): """ diff --git a/gridfm_graphkit/datasets/powergrid_hetero_dataset.py b/gridfm_graphkit/datasets/powergrid_hetero_dataset.py index 82f57a57..b1be1b00 100644 --- a/gridfm_graphkit/datasets/powergrid_hetero_dataset.py +++ b/gridfm_graphkit/datasets/powergrid_hetero_dataset.py @@ -10,6 +10,19 @@ from torch_geometric.data import HeteroData from gridfm_graphkit.datasets.globals import VA_H, PG_H +##################### +def _optional_cols(df, cols, default_value): 
+ """ + Return a list of columns guaranteed to exist in df. + Missing columns are created and filled with the corresponding default value. + """ + if len(default_value) < len(cols): + raise ValueError(f"default_value has length {len(default_value)} but cols has length {len(cols)}") + for i, col in enumerate(cols): + if col not in df.columns: + df[col] = default_value[i] + return cols +##################### class HeteroGridDatasetDisk(Dataset): """ @@ -55,6 +68,7 @@ def processed_done_file(self): @property def processed_file_names(self): return [ + "load_scenarios.pt", self.processed_done_file, ] @@ -71,11 +85,11 @@ def process(self): bus_data["scenario"].min() == 0 and bus_data["scenario"].max() == len(bus_data["scenario"].unique()) - 1 ) - if "load_scenario_idx" in bus_data.columns: - load_scenarios = torch.tensor( - bus_data.groupby("scenario", sort=True)["load_scenario_idx"].first().values, - ) - torch.save(load_scenarios, osp.join(self.processed_dir, "load_scenarios.pt")) + + load_scenarios = torch.tensor( + bus_data.groupby("scenario", sort=True)["load_scenario_idx"].first().values, + ) + torch.save(load_scenarios, osp.join(self.processed_dir, "load_scenarios.pt")) agg_gen = ( gen_data.groupby(["scenario", "bus"])[["min_q_mvar", "max_q_mvar"]] @@ -107,6 +121,12 @@ def process(self): "vn_kv", ] + ##################### + # Optional VLD bus input columns + vld_bus_input_features = _optional_cols(bus_data,["bus_base_status", "bus_contingency"],default_value=[1.0, 0.0]) + bus_features = bus_features + vld_bus_input_features + ##################### + gen_features = [ "p_mw", "min_p_mw", @@ -118,6 +138,13 @@ def process(self): ] common_branch_features = ["tap", "ang_min", "ang_max", "rate_a", "br_status"] + + ##################### + # Optional VLD branch topology columns + vld_branch_features = _optional_cols(branch_data,["branch_base_status", "branch_contingency"], default_value=[1.0, 0.0]) + common_branch_features = common_branch_features + vld_branch_features + 
##################### + forward_branch_features = [ "pf", "qf", @@ -135,9 +162,13 @@ def process(self): "Ytf_i", ] + common_branch_features + ##################### + # Optional VLD target column on bus.y + vld_bus_target_features = _optional_cols(bus_data,["bus_status_target"],default_value=[1.0]) + ##################### + # Group by scenario - bus_groups = bus_data.groupby("scenario") # Groupby preserves the order of rows within each group. - # https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.groupby.html + bus_groups = bus_data.groupby("scenario") gen_groups = gen_data.groupby("scenario") branch_groups = branch_data.groupby("scenario") @@ -158,19 +189,27 @@ def process(self): # Bus nodes bus_df = bus_groups.get_group(scenario) - # assert that the buses are in increasing order - assert (bus_df["bus"].values == torch.arange(len(bus_df))).all(), "Buses are not in increasing order" - #todo: we should remove this assert and store the bus idx in the tensors - # right now we need the increasing order for e.g. the predict step that uses torch.arange(n_nodes) to index the buses. data["bus"].x = torch.tensor(bus_df[bus_features].values, dtype=torch.float) # Generator nodes gen_df = gen_groups.get_group(scenario).reset_index() data["gen"].x = torch.tensor(gen_df[gen_features].values, dtype=torch.float) gen_df["gen_index"] = gen_df.index # Use actual index as generator ID - # todo: change this to instead use the generator id as the index - data["bus"].y = data["bus"].x[:, : (VA_H + 1)].clone() + ##################### + #data["bus"].y = data["bus"].x[:, : (VA_H + 1)].clone() + # Append the optional VLD target column. 
@TRANSFORM_REGISTRY.register("VoltageLossDetection")
class VoltageLossDetectionTransforms(Compose):
    """Transform pipeline registered under the name ``"VoltageLossDetection"``.

    The composed steps are, in order:

    1. ``LoadGridParamsFromPathVLD`` (only when ``args.task`` has a
       ``grid_path`` attribute),
    2. ``RemoveInactiveBranchesKeepTopology``,
    3. ``RemoveInactiveGenerators``,
    4. ``AddVLDHeteroMask``,
    5. ``ApplyMasking``.

    Args:
        args: Experiment configuration; ``args.task.grid_path`` is optional.
    """

    def __init__(self, args):
        transforms = []

        if hasattr(args.task, "grid_path"):
            # LoadGridParamsFromPathVLD is not part of this module's import
            # list, so referencing it directly raised NameError whenever a
            # grid_path was configured. Import it where it is needed.
            from gridfm_graphkit.datasets.transforms import (
                LoadGridParamsFromPathVLD,
            )

            transforms.append(LoadGridParamsFromPathVLD(args))

        transforms.append(RemoveInactiveBranchesKeepTopology())
        transforms.append(RemoveInactiveGenerators())
        transforms.append(AddVLDHeteroMask())
        transforms.append(ApplyMasking(args=args))

        # Pass the list of transforms to Compose
        super().__init__(transforms)
class LoadGridParamsFromPathVLD(BaseTransform):
    """Overwrite selected sample fields with values from a stored grid file.

    The grid is loaded once from ``args.task.grid_path`` at construction.
    On every call, the four bus-bus edge-attribute columns
    ``YFF_TT_R, YFF_TT_I, YFT_TF_R, YFT_TF_I`` and the generator ``G_ON``
    column of the sample are replaced by the corresponding columns of a
    copy of that grid.
    """

    def __init__(self, args):
        super().__init__()
        self.grid_path = args.task.grid_path
        stored = torch.load(self.grid_path, weights_only=True)
        self.grid_data = HeteroData.from_dict(stored)
        self.normalizer = HeteroDataMVANormalizer(args)
        self.normalizer.vn_kv_max = 1

    def forward(self, data):
        """Copy the grid columns into ``data`` and return ``data``."""
        # Work on a copy so the stored grid is never modified.
        grid_data = deepcopy(self.grid_data)

        if hasattr(data, "is_normalized"):
            # The sample carries a normalization flag: normalize the grid
            # copy with the sample's baseMVA before copying values over.
            self.normalizer.baseMVA = data.baseMVA
            self.normalizer.transform(grid_data)

        branch_type = ("bus", "connects", "bus")
        cols = [YFF_TT_R, YFF_TT_I, YFT_TF_R, YFT_TF_I]
        data[branch_type].edge_attr[:, cols] = grid_data[branch_type].edge_attr[:, cols]

        data["gen"].x[:, G_ON] = grid_data["gen"].x[:, G_ON]
        return data
Subset]: - """Build train/val/test subsets from split index files. - - Expects `train.pt`, `val.pt`, and `test.pt` inside `splits_folder`. - Returns both the dataset subsets and the raw scenario ids per split. - """ - output = [] - - indices = {} - - for split in ["train", "val", "test"]: - split_file = splits_folder / f"{split}.pt" - assert split_file.is_file(), f"{str(split_file)} does not exist" - split_indices = torch.load(str(split_file), weights_only=True) - split_dataset = Subset(dataset, split_indices) - output.append(split_dataset) - split_indices = list(split_indices) - print(f'{split=} {len(split_indices)=}') - indices[split]=[int(t.item()) for t in split_indices] - - output = tuple(output) - return output, indices \ No newline at end of file diff --git a/gridfm_graphkit/io/registries.py b/gridfm_graphkit/io/registries.py index 65d596a9..32feb20a 100644 --- a/gridfm_graphkit/io/registries.py +++ b/gridfm_graphkit/io/registries.py @@ -1,5 +1,4 @@ class Registry: - """Simple name-to-object registry with decorator-based registration.""" def __init__(self, name: str): self._name = name self._registry = {} diff --git a/gridfm_graphkit/models/__init__.py b/gridfm_graphkit/models/__init__.py index f8245352..956f9213 100644 --- a/gridfm_graphkit/models/__init__.py +++ b/gridfm_graphkit/models/__init__.py @@ -3,6 +3,9 @@ PhysicsDecoderOPF, PhysicsDecoderPF, PhysicsDecoderSE, + ############### + PhysicsDecoderVLD + ############## ) __all__ = [ @@ -10,4 +13,7 @@ "PhysicsDecoderOPF", "PhysicsDecoderPF", "PhysicsDecoderSE", + ############### + "PhysicsDecoderVLD", + ############## ] diff --git a/gridfm_graphkit/models/gnn_heterogeneous_gns.py b/gridfm_graphkit/models/gnn_heterogeneous_gns.py index 10735603..93366db7 100644 --- a/gridfm_graphkit/models/gnn_heterogeneous_gns.py +++ b/gridfm_graphkit/models/gnn_heterogeneous_gns.py @@ -19,6 +19,9 @@ # Output feature indices VM_OUT, PG_OUT_GEN, + ############## + BUS_STATUS_LOGIT_OUT, + ############## # Generator feature 
indices PG_H, MIN_PG, @@ -49,6 +52,9 @@ def __init__(self, args) -> None: self.heads = args.model.attention_head self.task = args.task.task_name self.dropout = getattr(args.model, "dropout", 0.0) + #################### + self.is_vld_task = self.task == "VoltageLossDetection" + #################### # projections for each node type self.input_proj_bus = nn.Sequential( @@ -129,6 +135,15 @@ def __init__(self, args) -> None: nn.Linear(self.hidden_dim, self.output_bus_dim), ) + ############################# + self.bus_status_head = nn.Sequential( + nn.Linear(self.hidden_dim * self.heads, self.hidden_dim), + nn.LayerNorm(self.hidden_dim), + nn.LeakyReLU(), + nn.Linear(self.hidden_dim, 1), + ) + ############################ + self.mlp_gen = nn.Sequential( nn.Linear(self.hidden_dim * self.heads, self.hidden_dim), nn.LayerNorm(self.hidden_dim), @@ -156,6 +171,9 @@ def forward(self, x_dict, edge_index_dict, edge_attr_dict, mask_dict): """ self.layer_residuals = {} + ##################### + self.latest_x_dict = x_dict + #################### # 1) initial projections h_bus = self.input_proj_bus(x_dict["bus"]) # [num_bus, hidden_dim] @@ -199,6 +217,10 @@ def forward(self, x_dict, edge_index_dict, edge_attr_dict, mask_dict): bus_temp = self.mlp_bus(h_bus) # [Nb, 2] -> Vm, Va gen_temp = self.mlp_gen(h_gen) # [Ng, 1] -> Pg + ####################### + status_logit = self.bus_status_head(h_bus) + ####################### + if self.task == "StateEstimation": if i == self.num_layers - 1: Pft, Qft = self.branch_flow_layer( @@ -278,4 +300,15 @@ def forward(self, x_dict, edge_index_dict, edge_attr_dict, mask_dict): ).mean() h_bus = h_bus + self.physics_mlp(bus_residuals) - return {"bus": output_temp, "gen": gen_temp} + ####################### + #return {"bus": output_temp, "gen": gen_temp} + if self.is_vld_task: + output_bus = torch.cat([output_temp, status_logit], dim=1) + else: + output_bus = output_temp + + return {"bus": output_bus, "gen": gen_temp} + ##################### + + + diff 
@PHYSICS_DECODER_REGISTRY.register("VoltageLossDetection")
class PhysicsDecoderVLD(nn.Module):
    """Decoder registered for the ``"VoltageLossDetection"`` task.

    Produces a four-column bus tensor stacked as ``[Vm, Va, Pg, Qg]``:

    - ``Vm`` / ``Va`` are taken directly from ``bus_data_pred``.
    - ``Qg`` is ``Q_in + Qd - q_shunt`` on PV and REF buses, zero elsewhere.
    - ``Pg`` is ``agg_bus`` on PV buses, ``P_in + Pd - p_shunt`` on REF
      buses, and zero elsewhere.
    """

    def forward(self, P_in, Q_in, bus_data_pred, bus_data_orig, agg_bus, mask_dict):
        is_pv = mask_dict["PV"]
        is_ref = mask_dict["REF"]

        p_shunt, q_shunt = compute_shunt_power(bus_data_pred, bus_data_orig)

        # Reactive generation: balance on PV/REF buses, zero on the rest.
        q_balance = Q_in + bus_data_orig[:, QD_H] - q_shunt
        qg = torch.where(is_pv | is_ref, q_balance, torch.zeros_like(Q_in))

        # Active generation: balance on REF, aggregated value on PV, else zero.
        p_balance = P_in + bus_data_orig[:, PD_H] - p_shunt
        pg_ref_only = torch.where(is_ref, p_balance, torch.zeros_like(P_in))
        pg = torch.where(is_pv, agg_bus, pg_ref_only)

        columns = [
            bus_data_pred[:, VM_OUT],
            bus_data_pred[:, VA_OUT],
            pg,
            qg,
        ]
        return torch.stack(columns, dim=1)
"OptimalPowerFlowTask", "StateEstimationTask", "VoltageLossDetectionTask" ] +############# \ No newline at end of file diff --git a/gridfm_graphkit/tasks/base_task.py b/gridfm_graphkit/tasks/base_task.py index fc2b95e3..90c8f7b5 100644 --- a/gridfm_graphkit/tasks/base_task.py +++ b/gridfm_graphkit/tasks/base_task.py @@ -20,20 +20,6 @@ def __init__(self, args, data_normalizers): self.data_normalizers = data_normalizers self.save_hyperparameters() - def transfer_batch_to_device(self, batch, device, dataloader_idx): - """Pre-cast float64 tensors before moving batches onto MPS. - - PyTorch MPS does not support float64 tensors. Some PyG metadata fields can - get collated as float64 even when model inputs are float32, so coerce them - first and then delegate to Lightning's standard device transfer. - """ - if getattr(device, "type", None) == "mps" and hasattr(batch, "stores"): - for store in batch.stores: - for key, val in store.items(): - if isinstance(val, torch.Tensor) and val.dtype == torch.float64: - store[key] = val.to(torch.float32) - return super().transfer_batch_to_device(batch, device, dataloader_idx) - def on_after_batch_transfer(self, batch, dataloader_idx: int): """Cast float tensors in HeteroData batches to the model's parameter dtype. 
diff --git a/gridfm_graphkit/tasks/compute_ac_dc_metrics.py b/gridfm_graphkit/tasks/compute_ac_dc_metrics.py new file mode 100644 index 00000000..8dcfc8c0 --- /dev/null +++ b/gridfm_graphkit/tasks/compute_ac_dc_metrics.py @@ -0,0 +1,227 @@ +"""Compute AC/DC power balance residuals and runtime statistics on test splits.""" + +import json +import os +import numpy as np +import pandas as pd +from gridfm_datakit.utils.power_balance import ( + compute_branch_powers_vectorized, + compute_bus_balance, +) + +N_SCENARIO_PER_PARTITION = 200 +NUM_PROCESSES = 64 + + +def _load_test_data(data_dir: str, test_scenario_ids: list[int]): + partitions = sorted(set(s // N_SCENARIO_PER_PARTITION for s in test_scenario_ids)) + test_set = set(test_scenario_ids) + partition_filter = [("scenario_partition", "in", partitions)] + + bus_df = pd.read_parquet( + os.path.join(data_dir, "bus_data.parquet"), + filters=partition_filter, + ) + branch_df = pd.read_parquet( + os.path.join(data_dir, "branch_data.parquet"), + filters=partition_filter, + ) + runtime_df = pd.read_parquet( + os.path.join(data_dir, "runtime_data.parquet"), + filters=partition_filter, + ) + + bus_df = bus_df[bus_df["scenario"].isin(test_set)].reset_index(drop=True) + branch_df = branch_df[branch_df["scenario"].isin(test_set)].reset_index(drop=True) + runtime_df = runtime_df[runtime_df["scenario"].isin(test_set)].reset_index( + drop=True, + ) + + print( + f" Loaded {len(bus_df)} bus rows, {len(branch_df)} branch rows, " + f"{len(runtime_df)} runtime rows for {len(test_set)} test scenarios", + ) + return bus_df, branch_df, runtime_df + + +def _compute_residual_stats(balance_df: pd.DataFrame, dc: bool) -> dict: + grouped = balance_df.groupby("scenario") + + if dc: + P_mis = balance_df["P_mis_dc"].to_numpy() + nan_scenarios = int(grouped["P_mis_dc"].apply(lambda x: x.isna().all()).sum()) + return { + "Avg. active res. 
(MW)": float(np.nanmean(np.abs(P_mis))), + "DC NaN scenarios": nan_scenarios, + } + + P_mis = balance_df["P_mis_ac"].to_numpy() + Q_mis = balance_df["Q_mis_ac"].to_numpy() + pbe = np.sqrt(P_mis**2 + Q_mis**2) + + pbe_per_scenario_mean = grouped.apply( + lambda g: np.nanmean( + np.sqrt(g["P_mis_ac"].to_numpy() ** 2 + g["Q_mis_ac"].to_numpy() ** 2), + ), + include_groups=False, + ) + + return { + "Avg. active res. (MW)": float(np.nanmean(np.abs(P_mis))), + "Avg. reactive res. (MVar)": float(np.nanmean(np.abs(Q_mis))), + "PBE Mean": float(np.nanmean(pbe_per_scenario_mean)), + "PBE Max": float(np.nanmax(pbe)), + } + + +def _compute_runtime_stats(runtime_df: pd.DataFrame) -> dict: + results = {} + for mode in ["ac", "dc"]: + if mode not in runtime_df.columns: + continue + + rt_ms = runtime_df[mode].to_numpy(dtype=float) * 1000.0 / NUM_PROCESSES + valid = rt_ms[~np.isnan(rt_ms)] + + results[f"runtime_{mode}_mean_ms_with_{NUM_PROCESSES}_cores"] = float( + np.mean(valid), + ) + results[f"runtime_{mode}_median_ms_with_{NUM_PROCESSES}_cores"] = float( + np.median(valid), + ) + results[f"runtime_{mode}_std_ms_with_{NUM_PROCESSES}_cores"] = float( + np.std(valid), + ) + results[f"runtime_{mode}_max_ms_with_{NUM_PROCESSES}_cores"] = float( + np.max(valid), + ) + + return results + + +def compute_ac_dc_metrics( + artifacts_dir: str, + data_dir: str, + grid_name: str, + sn_mva: float, +) -> bool: + """Compute AC/DC ground-truth power balance and runtime metrics, save results. + + Saves: + - Aggregated metrics (CSV) + - AC per-bus residuals (Parquet) + - DC per-bus residuals (Parquet) + + Returns: + True if metrics were computed, False if splits JSON was not found. 
+ """ + + splits_json = os.path.join( + artifacts_dir, + "stats", + f"{grid_name}_scenario_splits.json", + ) + if not os.path.exists(splits_json): + print(f" Skipping: no splits JSON found at {splits_json}") + return False + + with open(splits_json) as f: + test_ids = json.load(f)["test"] + + print(f" Test split: {len(test_ids)} scenarios") + + bus_df, branch_df, runtime_df = _load_test_data(data_dir, test_ids) + + # ========================= + # AC residuals + # ========================= + print(" Computing AC power balance...") + balance_ac = compute_bus_balance( + bus_df, + branch_df, + branch_df[["pf", "qf", "pt", "qt"]], + dc=False, + sn_mva=sn_mva, + ) + + ac_stats = _compute_residual_stats(balance_ac, dc=False) + + # ========================= + # DC residuals + # ========================= + print(" Computing DC power balance...") + pf_dc, _, pt_dc, _ = compute_branch_powers_vectorized( + branch_df, + bus_df, + dc=True, + sn_mva=sn_mva, + ) + + balance_dc = compute_bus_balance( + bus_df, + branch_df, + pd.DataFrame( + {"pf_dc": pf_dc, "pt_dc": pt_dc}, + index=branch_df.index, + ), + dc=True, + sn_mva=sn_mva, + ) + + dc_stats = _compute_residual_stats(balance_dc, dc=True) + + # ========================= + # Save per-bus residuals (PARQUET) + # ========================= + out_dir = os.path.join(artifacts_dir, "test") + os.makedirs(out_dir, exist_ok=True) + + # AC: active + reactive + ac_bus_residuals = ( + balance_ac[["scenario", "bus", "P_mis_ac", "Q_mis_ac"]] + .copy() + .rename( + columns={ + "P_mis_ac": "active res. (MW)", + "Q_mis_ac": "reactive res. (MVar)", + }, + ) + ) + ac_residuals_path = os.path.join(out_dir, f"{grid_name}_ac_bus_residuals.parquet") + ac_bus_residuals.to_parquet(ac_residuals_path, index=False) + print(f" AC per-bus residuals saved to {ac_residuals_path}") + + # DC: active only + dc_bus_residuals = ( + balance_dc[["scenario", "bus", "P_mis_dc"]] + .copy() + .rename( + columns={ + "P_mis_dc": "DC active res. 
(MW)", + }, + ) + ) + + dc_residuals_path = os.path.join(out_dir, f"{grid_name}_dc_bus_residuals.parquet") + dc_bus_residuals.to_parquet(dc_residuals_path, index=False) + print(f" DC per-bus residuals saved to {dc_residuals_path}") + + # ========================= + # Save aggregated metrics (CSV) + # ========================= + runtime_stats = _compute_runtime_stats(runtime_df) + + rows = [] + for key, val in ac_stats.items(): + rows.append({"Metric": f"AC {key}", "Value": val}) + for key, val in dc_stats.items(): + rows.append({"Metric": f"DC {key}", "Value": val}) + for key, val in runtime_stats.items(): + rows.append({"Metric": key, "Value": val}) + + metrics_path = os.path.join(out_dir, f"{grid_name}_ac_dc_metrics.csv") + pd.DataFrame(rows).to_csv(metrics_path, index=False) + + print(f" Aggregated metrics saved to {metrics_path}") + + return True diff --git a/gridfm_graphkit/tasks/opf_task.py b/gridfm_graphkit/tasks/opf_task.py index dbb1baab..06d938df 100644 --- a/gridfm_graphkit/tasks/opf_task.py +++ b/gridfm_graphkit/tasks/opf_task.py @@ -1,12 +1,8 @@ from gridfm_graphkit.datasets.globals import ( # Bus feature indices - PD_H, - QD_H, QG_H, VM_H, VA_H, - MIN_VM_H, - MAX_VM_H, MIN_QG_H, MAX_QG_H, # Output feature indices @@ -16,8 +12,6 @@ QG_OUT, # Generator feature indices PG_H, - MIN_PG, - MAX_PG, C0_H, C1_H, C2_H, @@ -34,8 +28,8 @@ plot_residuals_histograms, residual_stats_by_type, ) +from pytorch_lightning.utilities import rank_zero_only import torch -import torch.distributed as dist import torch.nn.functional as F from torch_scatter import scatter_add from gridfm_graphkit.models.utils import ( @@ -87,14 +81,14 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): c2 = batch.x_dict["gen"][:, C2_H] target_pg = batch.y_dict["gen"].squeeze() pred_pg = output["gen"].squeeze() - gen_cost_gt = (c0 + c1 * target_pg + c2 * target_pg**2) # assumes all branches are on! - gen_cost_pred = (c0 + c1 * pred_pg + c2 * pred_pg**2) # assumes all branches are on! 
+ gen_cost_gt = c0 + c1 * target_pg + c2 * target_pg**2 + gen_cost_pred = c0 + c1 * pred_pg + c2 * pred_pg**2 gen_batch = batch.batch_dict["gen"] # shape: [N_gen_total] cost_gt = scatter_add(gen_cost_gt, gen_batch, dim=0) cost_pred = scatter_add(gen_cost_pred, gen_batch, dim=0) - + optimality_gap = torch.mean(torch.abs((cost_pred - cost_gt) / cost_gt * 100)) agg_gen_on_bus = scatter_add( @@ -118,7 +112,7 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): # output["bus"] = target Pft, Qft = branch_flow_layer(output["bus"], bus_edge_index, bus_edge_attr) - # Compute branch thermal limits violations + # Compute branch termal limits violations Sft = torch.sqrt(Pft**2 + Qft**2) # apparent power flow per branch branch_thermal_limits = bus_edge_attr[:, RATE_A] branch_thermal_excess = F.relu(Sft - branch_thermal_limits) @@ -138,14 +132,13 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): bus_angles = output["bus"][:, VA_OUT] # in degrees from_bus = bus_edge_index[0] to_bus = bus_edge_index[1] - angle_diff = bus_angles[from_bus] - bus_angles[to_bus] # keep sign - angle_diff = (angle_diff + torch.pi) % (2 * torch.pi) - torch.pi # wrap to [-pi, pi] - angle_excess_low = F.relu(angle_min - angle_diff) - angle_excess_high = F.relu(angle_diff - angle_max) + angle_diff = torch.abs(bus_angles[from_bus] - bus_angles[to_bus]) - branch_angle_violation_mean = torch.mean( - angle_excess_low + angle_excess_high - ) # mean of the abs violation + angle_excess_low = F.relu(angle_min - angle_diff) # violation if too small + angle_excess_high = F.relu(angle_diff - angle_max) # violation if too large + branch_angle_violation_mean = ( + torch.mean(angle_excess_low + angle_excess_high) * 180.0 / torch.pi + ) P_in, Q_in = node_injection_layer(Pft, Qft, bus_edge_index, num_bus) residual_P, residual_Q = node_residuals_layer( @@ -174,8 +167,6 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): mean_Qg_violation_PV = Qg_violation_amount[mask_PV].mean() mean_Qg_violation_REF 
= Qg_violation_amount[mask_REF].mean() - mask_PV_REF = mask_PV | mask_REF # PV or REF buses - mean_Qg_violation = Qg_violation_amount[mask_PV_REF].mean() # if self.args.verbose: mean_res_P_PQ, max_res_P_PQ = residual_stats_by_type( @@ -270,10 +261,8 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): loss_dict["Branch voltage angle difference violations"] = ( branch_angle_violation_mean ) - loss_dict["Mean Qg violation PV buses"] = mean_Qg_violation_PV # mean of the abs violation over the entire batch (all oines in the batch). - # this is then overaged over all the batches and gives same weight to all batches despite them possibly having varying number of branches + loss_dict["Mean Qg violation PV buses"] = mean_Qg_violation_PV loss_dict["Mean Qg violation REF buses"] = mean_Qg_violation_REF - loss_dict["Mean Qg violation"] = mean_Qg_violation loss_dict["MSE PQ nodes - PG"] = mse_PQ[PG_OUT] loss_dict["MSE PV nodes - PG"] = mse_PV[PG_OUT] @@ -304,25 +293,8 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): ) return + @rank_zero_only def on_test_end(self): - # In DDP, gather verbose test outputs from all ranks to rank 0 - # so that plots and detailed analysis cover the full test set. 
- if self.args.verbose and dist.is_available() and dist.is_initialized(): - world_size = dist.get_world_size() - gathered = [None] * world_size if dist.get_rank() == 0 else None - dist.gather_object(self.test_outputs, gathered, dst=0) - if dist.get_rank() == 0: - merged = {i: [] for i in range(len(self.args.data.networks))} - for rank_data in gathered: - for dl_idx, batches in rank_data.items(): - merged[dl_idx].extend(batches) - self.test_outputs = merged - - # Only rank 0 proceeds with logging, CSV writing, and plotting - if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0: - self.test_outputs.clear() - return - if isinstance(self.logger, MLFlowLogger): artifact_dir = os.path.join( self.logger.save_dir, @@ -369,10 +341,10 @@ def on_test_end(self): rmse_gen = metrics.get("MSE PG", 0) ** 0.5 optimality_gap = metrics.get("Opt gap", " ") branch_thermal_violation_from = metrics.get( - "Branch thermal violation from", + "Branch termal violation from", " ", ) - branch_thermal_violation_to = metrics.get("Branch thermal violation to", " ") + branch_thermal_violation_to = metrics.get("Branch termal violation to", " ") branch_angle_violation = metrics.get( "Branch voltage angle difference violations", " ", @@ -382,7 +354,6 @@ def on_test_end(self): "Mean Qg violation REF buses", " ", ) - mean_qg_violation = metrics.get("Mean Qg violation", " ") # --- Main RMSE metrics file --- data_main = { @@ -401,12 +372,11 @@ def on_test_end(self): "Avg. reactive res. 
(MVar)", "RMSE PG generators (MW)", "Mean optimality gap (%)", - "Mean branch thermal violation from (MVA)", - "Mean branch thermal violation to (MVA)", + "Mean branch termal violation from (MVA)", + "Mean branch termal violation to (MVA)", "Mean branch angle difference violation (radians)", "Mean Qg violation PV buses", "Mean Qg violation REF buses", - "Mean Qg violation", ], "Value": [ avg_active_res, @@ -418,7 +388,6 @@ def on_test_end(self): branch_angle_violation, mean_qg_violation_PV_buses, mean_qg_violation_REF_buses, - mean_qg_violation, ], } df_residuals = pd.DataFrame(data_residuals) @@ -513,100 +482,4 @@ def on_test_end(self): self.test_outputs.clear() def predict_step(self, batch, batch_idx, dataloader_idx=0): - output, _ = self.shared_step(batch) - - self.data_normalizers[dataloader_idx].inverse_transform(batch) - self.data_normalizers[dataloader_idx].inverse_output(output, batch) - - branch_flow_layer = ComputeBranchFlow() - node_injection_layer = ComputeNodeInjection() - node_residuals_layer = ComputeNodeResiduals() - - num_bus = batch.x_dict["bus"].size(0) - bus_edge_index = batch.edge_index_dict[("bus", "connects", "bus")] - bus_edge_attr = batch.edge_attr_dict[("bus", "connects", "bus")] - - Pft, Qft = branch_flow_layer(output["bus"], bus_edge_index, bus_edge_attr) - P_in, Q_in = node_injection_layer(Pft, Qft, bus_edge_index, num_bus) - residual_P, residual_Q = node_residuals_layer( - P_in, - Q_in, - output["bus"], - batch.x_dict["bus"], - ) - residual_P = torch.abs(residual_P) - residual_Q = torch.abs(residual_Q) - residual_mva = torch.sqrt(residual_P**2 + residual_Q**2) - - bus_batch = batch.batch_dict["bus"] - scenario_ids = batch["scenario_id"][bus_batch] - local_bus_idx = torch.cat( - [ - torch.arange(c, device=bus_batch.device) - for c in torch.bincount(bus_batch) - ], - ) # this works because the order of the buses is preserved by the groupby in the dataset wrapper and datakit data has buses in increasing order. 
- - bus_x = batch.x_dict["bus"] - bus_y = batch.y_dict["bus"] - mask_PQ = batch.mask_dict["PQ"] - mask_PV = batch.mask_dict["PV"] - mask_REF = batch.mask_dict["REF"] - - _, gen_to_bus_index = batch.edge_index_dict[("gen", "connected_to", "bus")] - agg_gen_on_bus = scatter_add( - batch.y_dict["gen"], - gen_to_bus_index, - dim=0, - dim_size=num_bus, - ) - gen_batch = batch.batch_dict["gen"] - gen_scenario_ids = batch["scenario_id"][gen_batch] - local_gen_idx = torch.cat( - [ - torch.arange(c, device=gen_batch.device) - for c in torch.bincount(gen_batch) - ], - ) - gen_x = batch.x_dict["gen"] - gen_target = batch.y_dict["gen"].reshape(-1) - gen_pred = output["gen"].reshape(-1) - - return { - "bus": { - "scenario": scenario_ids.cpu().numpy(), - "bus": local_bus_idx.cpu().numpy(), - "Pd": bus_x[:, PD_H].cpu().numpy(), - "Qd": bus_x[:, QD_H].cpu().numpy(), - "Vm_min": bus_x[:, MIN_VM_H].cpu().numpy(), - "Vm_max": bus_x[:, MAX_VM_H].cpu().numpy(), - "Qg_min": bus_x[:, MIN_QG_H].cpu().numpy(), - "Qg_max": bus_x[:, MAX_QG_H].cpu().numpy(), - "Vm_target": bus_y[:, VM_H].cpu().numpy(), - "Va_target": bus_y[:, VA_H].cpu().numpy(), - "Pg_target": agg_gen_on_bus.squeeze().cpu().numpy(), - "Qg_target": bus_y[:, QG_H].cpu().numpy(), - "PQ": mask_PQ.cpu().numpy().astype(int), - "PV": mask_PV.cpu().numpy().astype(int), - "REF": mask_REF.cpu().numpy().astype(int), - "Vm_pred": output["bus"][:, VM_OUT].detach().cpu().numpy(), - "Va_pred": output["bus"][:, VA_OUT].detach().cpu().numpy(), - "Pg_pred": output["bus"][:, PG_OUT].detach().cpu().numpy(), - "Qg_pred": output["bus"][:, QG_OUT].detach().cpu().numpy(), - "active res. (MW)": residual_P.detach().cpu().numpy(), - "reactive res. 
(MVar)": residual_Q.detach().cpu().numpy(), - "PBE": residual_mva.detach().cpu().numpy(), - }, - "gen": { - "scenario": gen_scenario_ids.cpu().numpy(), - "idx": local_gen_idx.cpu().numpy(), - "bus": local_bus_idx[gen_to_bus_index].cpu().numpy(), - "p_mw_target": gen_target.cpu().numpy(), - "p_mw_pred": gen_pred.detach().cpu().numpy(), - "min_p_mw": gen_x[:, MIN_PG].cpu().numpy(), - "max_p_mw": gen_x[:, MAX_PG].cpu().numpy(), - "cp0_eur": gen_x[:, C0_H].cpu().numpy(), - "cp1_eur_per_mw": gen_x[:, C1_H].cpu().numpy(), - "cp2_eur_per_mw2": gen_x[:, C2_H].cpu().numpy(), - }, - } + raise NotImplementedError diff --git a/gridfm_graphkit/tasks/pf_task.py b/gridfm_graphkit/tasks/pf_task.py index 948a25e0..cdc9d646 100644 --- a/gridfm_graphkit/tasks/pf_task.py +++ b/gridfm_graphkit/tasks/pf_task.py @@ -5,10 +5,6 @@ QG_H, VM_H, VA_H, - MIN_VM_H, - MAX_VM_H, - MIN_QG_H, - MAX_QG_H, # Output feature indices VM_OUT, VA_OUT, @@ -245,7 +241,6 @@ def on_test_end(self): # Only rank 0 proceeds with logging, CSV writing, and plotting if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0: - self.test_outputs.clear() # clear the test outputs for other ranks return if isinstance(self.logger, MLFlowLogger): @@ -356,22 +351,22 @@ def on_test_end(self): self.test_outputs.clear() def predict_step(self, batch, batch_idx, dataloader_idx=0): - output, _ = self.shared_step(batch) # get the predicted output from the model + output, _ = self.shared_step(batch) - self.data_normalizers[dataloader_idx].inverse_transform(batch) # normalize the batch data back to the original scale - self.data_normalizers[dataloader_idx].inverse_output(output, batch) # inverse transform the predicted output back to the original scale + self.data_normalizers[dataloader_idx].inverse_transform(batch) + self.data_normalizers[dataloader_idx].inverse_output(output, batch) - branch_flow_layer = ComputeBranchFlow() # layer to compute the branch flows - node_injection_layer = ComputeNodeInjection() # layer 
to compute the node injections - node_residuals_layer = ComputeNodeResiduals() # layer to compute the node residuals + branch_flow_layer = ComputeBranchFlow() + node_injection_layer = ComputeNodeInjection() + node_residuals_layer = ComputeNodeResiduals() - num_bus = batch.x_dict["bus"].size(0) # number of buses in the batch - bus_edge_index = batch.edge_index_dict[("bus", "connects", "bus")] # from and to buses - bus_edge_attr = batch.edge_attr_dict[("bus", "connects", "bus")] # edge attributes (admittance) of the bus connections + num_bus = batch.x_dict["bus"].size(0) + bus_edge_index = batch.edge_index_dict[("bus", "connects", "bus")] + bus_edge_attr = batch.edge_attr_dict[("bus", "connects", "bus")] - Pft, Qft = branch_flow_layer(output["bus"], bus_edge_index, bus_edge_attr) # compute the branch flows - P_in, Q_in = node_injection_layer(Pft, Qft, bus_edge_index, num_bus) # compute the node injections - residual_P, residual_Q = node_residuals_layer( # compute the node residuals + Pft, Qft = branch_flow_layer(output["bus"], bus_edge_index, bus_edge_attr) + P_in, Q_in = node_injection_layer(Pft, Qft, bus_edge_index, num_bus) + residual_P, residual_Q = node_residuals_layer( P_in, Q_in, output["bus"], @@ -388,9 +383,8 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): torch.arange(c, device=bus_batch.device) for c in torch.bincount(bus_batch) ], - ) # this is based on the assumptions that the buses within a graph are ordered and indexed as 0 ... n_nodes-1. - # todo: we should remove this assert and store the bus idx in the tensors - # right now we need the increasing order and we have an assert in the dataset to check it. 
+ ) + bus_x = batch.x_dict["bus"] bus_y = batch.y_dict["bus"] mask_PQ = batch.mask_dict["PQ"] @@ -408,23 +402,19 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): return { "scenario": scenario_ids.cpu().numpy(), "bus": local_bus_idx.cpu().numpy(), - "Pd": bus_x[:, PD_H].cpu().numpy(), - "Qd": bus_x[:, QD_H].cpu().numpy(), - "Vm_min": bus_x[:, MIN_VM_H].cpu().numpy(), - "Vm_max": bus_x[:, MAX_VM_H].cpu().numpy(), - "Qg_min": bus_x[:, MIN_QG_H].cpu().numpy(), - "Qg_max": bus_x[:, MAX_QG_H].cpu().numpy(), - "Vm_target": bus_y[:, VM_H].cpu().numpy(), - "Va_target": bus_y[:, VA_H].cpu().numpy(), - "Pg_target": agg_gen_on_bus.squeeze().cpu().numpy(), - "Qg_target": bus_y[:, QG_H].cpu().numpy(), - "PQ": mask_PQ.cpu().numpy().astype(int), - "PV": mask_PV.cpu().numpy().astype(int), - "REF": mask_REF.cpu().numpy().astype(int), - "Vm_pred": output["bus"][:, VM_OUT].detach().cpu().numpy(), - "Va_pred": output["bus"][:, VA_OUT].detach().cpu().numpy(), - "Pg_pred": output["bus"][:, PG_OUT].detach().cpu().numpy(), - "Qg_pred": output["bus"][:, QG_OUT].detach().cpu().numpy(), + "pd_mw": bus_x[:, PD_H].cpu().numpy(), + "qd_mvar": bus_x[:, QD_H].cpu().numpy(), + "vm_pu_target": bus_y[:, VM_H].cpu().numpy(), + "va_target": bus_y[:, VA_H].cpu().numpy(), + "pg_mw_target": agg_gen_on_bus.squeeze().cpu().numpy(), + "qg_mvar_target": bus_y[:, QG_H].cpu().numpy(), + "is_pq": mask_PQ.cpu().numpy().astype(int), + "is_pv": mask_PV.cpu().numpy().astype(int), + "is_ref": mask_REF.cpu().numpy().astype(int), + "vm_pu": output["bus"][:, VM_OUT].detach().cpu().numpy(), + "va": output["bus"][:, VA_OUT].detach().cpu().numpy(), + "pg_mw": output["bus"][:, PG_OUT].detach().cpu().numpy(), + "qg_mvar": output["bus"][:, QG_OUT].detach().cpu().numpy(), "active res. (MW)": residual_P.detach().cpu().numpy(), "reactive res. 
(MVar)": residual_Q.detach().cpu().numpy(), "PBE": residual_mva.detach().cpu().numpy(), diff --git a/gridfm_graphkit/tasks/reconstruction_tasks.py b/gridfm_graphkit/tasks/reconstruction_tasks.py index 45975aee..8742646b 100644 --- a/gridfm_graphkit/tasks/reconstruction_tasks.py +++ b/gridfm_graphkit/tasks/reconstruction_tasks.py @@ -57,7 +57,6 @@ def shared_step(self, batch): batch.edge_attr_dict, batch.mask_dict, model=self.model, - x_dict=batch.x_dict, ) return output, loss_dict diff --git a/gridfm_graphkit/tasks/se_task.py b/gridfm_graphkit/tasks/se_task.py index 36667ad2..5e45182d 100644 --- a/gridfm_graphkit/tasks/se_task.py +++ b/gridfm_graphkit/tasks/se_task.py @@ -26,7 +26,6 @@ @TASK_REGISTRY.register("StateEstimation") class StateEstimationTask(ReconstructionTask): - """State-estimation task with evaluation plots for masked and noisy measurements.""" def __init__(self, args, data_normalizers): super().__init__(args, data_normalizers) diff --git a/gridfm_graphkit/tasks/utils.py b/gridfm_graphkit/tasks/utils.py index 273d79f5..c77a9953 100644 --- a/gridfm_graphkit/tasks/utils.py +++ b/gridfm_graphkit/tasks/utils.py @@ -7,25 +7,10 @@ def residual_stats_by_type(residual, mask, bus_batch): - """Return per-graph mean and max absolute residuals for a masked bus subset.""" residual_masked = residual[mask] batch_masked = bus_batch[mask] - abs_residual = torch.abs(residual_masked) - - # torch_scatter on MPS can dispatch into a CPU-only path for scatter_max. - # Compute the grouped stats on CPU and move the results back so verbose - # evaluation works without changing the torch/torch_scatter stack. 
- if abs_residual.device.type == "mps": - abs_residual_cpu = abs_residual.cpu() - batch_masked_cpu = batch_masked.cpu() - mean_res = scatter_mean(abs_residual_cpu, batch_masked_cpu, dim=0).to( - abs_residual.device, - ) - max_res, _ = scatter_max(abs_residual_cpu, batch_masked_cpu, dim=0) - max_res = max_res.to(abs_residual.device) - else: - mean_res = scatter_mean(abs_residual, batch_masked, dim=0) - max_res, _ = scatter_max(abs_residual, batch_masked, dim=0) + mean_res = scatter_mean(torch.abs(residual_masked), batch_masked, dim=0) + max_res, _ = scatter_max(torch.abs(residual_masked), batch_masked, dim=0) return mean_res, max_res @@ -45,27 +30,19 @@ def plot_residuals_histograms(outputs, dataset_name, plot_dir): for stat_key, title in stats: # Gather all data first to compute common bin edges - all_data = ( - torch.cat( - [ - torch.cat([d[f"{stat_key}_{bus_type}"] for d in outputs]) - for bus_type in bus_types - ], - ) - .float() - .numpy() - ) + all_data = torch.cat( + [ + torch.cat([d[f"{stat_key}_{bus_type}"] for d in outputs]) + for bus_type in bus_types + ], + ).numpy() # Define bins across the entire data range bins = np.linspace(all_data.min(), all_data.max(), 61) # 30 bins of equal width plt.figure(figsize=(10, 6)) for bus_type, color in zip(bus_types, colors): - data = ( - torch.cat([d[f"{stat_key}_{bus_type}"] for d in outputs]) - .float() - .numpy() - ) + data = torch.cat([d[f"{stat_key}_{bus_type}"] for d in outputs]).numpy() plt.hist(data, bins=bins, alpha=0.6, label=bus_type, color=color) plt.title(f"{title} per Bus Type in {dataset_name}") diff --git a/gridfm_graphkit/tasks/vld_task.py b/gridfm_graphkit/tasks/vld_task.py new file mode 100644 index 00000000..2d5ebd3e --- /dev/null +++ b/gridfm_graphkit/tasks/vld_task.py @@ -0,0 +1,158 @@ +import os +import torch +import torch.distributed as dist +import pandas as pd +from lightning.pytorch.loggers import MLFlowLogger + +from gridfm_graphkit.io.registries import TASK_REGISTRY +from 
gridfm_graphkit.tasks.reconstruction_tasks import ReconstructionTask +from gridfm_graphkit.datasets.globals import ( + VM_OUT, + VM_H, + BUS_STATUS_TARGET, + BUS_STATUS_LOGIT_OUT, +) + +@TASK_REGISTRY.register("VoltageLossDetection") +class VoltageLossDetectionTask(ReconstructionTask): + """ + Topology-aware voltage loss detection task. + + Uses the standard ReconstructionTask training/validation flow and adds + VLD-specific test/predict metrics for bus status and Vm behavior. + """ + + def __init__(self, args, data_normalizers): + super().__init__(args, data_normalizers) + + def test_step(self, batch, batch_idx, dataloader_idx=0): + output, loss_dict = self.shared_step(batch) + dataset_name = self.args.data.networks[dataloader_idx] + + bus_pred = output["bus"] + bus_target = batch.y_dict["bus"] + + status_prob = torch.sigmoid(bus_pred[:, BUS_STATUS_LOGIT_OUT]) + status_pred = (status_prob >= 0.5).float() + status_true = bus_target[:, BUS_STATUS_TARGET].float() + + vm_pred = bus_pred[:, VM_OUT] + vm_true = bus_target[:, VM_H] + + status_acc = (status_pred == status_true).float().mean() + + off_mask = status_true < 0.5 + on_mask = status_true >= 0.5 + + off_vm_mae = ( + vm_pred[off_mask].abs().mean() + if off_mask.any() + else torch.tensor(0.0, device=vm_pred.device) + ) + on_vm_rmse = ( + torch.sqrt(torch.mean((vm_pred[on_mask] - vm_true[on_mask]) ** 2)) + if on_mask.any() + else torch.tensor(0.0, device=vm_pred.device) + ) + + loss_dict["Status Accuracy"] = status_acc.detach() + loss_dict["OFF Vm MAE"] = off_vm_mae.detach() + loss_dict["ON Vm RMSE"] = on_vm_rmse.detach() + + loss_dict["Test loss"] = loss_dict.pop("loss").detach() + + for metric, value in loss_dict.items(): + metric_name = f"{dataset_name}/{metric}" + self.log( + metric_name, + value, + batch_size=batch.num_graphs, + add_dataloader_idx=False, + sync_dist=True, + logger=False, + ) + + self.test_outputs[dataloader_idx].append( + { + "dataset": dataset_name, + "status_prob": 
status_prob.detach().cpu(), + "status_pred": status_pred.detach().cpu(), + "status_true": status_true.detach().cpu(), + "vm_pred": vm_pred.detach().cpu(), + "vm_true": vm_true.detach().cpu(), + } + ) + + def on_test_end(self): + if dist.is_available() and dist.is_initialized(): + world_size = dist.get_world_size() + gathered = [None] * world_size if dist.get_rank() == 0 else None + dist.gather_object(self.test_outputs, gathered, dst=0) + if dist.get_rank() == 0: + merged = {i: [] for i in range(len(self.args.data.networks))} + for rank_data in gathered: + for dl_idx, batches in rank_data.items(): + merged[dl_idx].extend(batches) + self.test_outputs = merged + + if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0: + return + + if isinstance(self.logger, MLFlowLogger): + artifact_dir = os.path.join( + self.logger.save_dir, + self.logger.experiment_id, + self.logger.run_id, + "artifacts", + ) + else: + artifact_dir = self.logger.save_dir + + test_dir = os.path.join(artifact_dir, "test") + os.makedirs(test_dir, exist_ok=True) + + for dataset_idx, outputs in self.test_outputs.items(): + if not outputs: + continue + + dataset_name = self.args.data.networks[dataset_idx] + status_prob = torch.cat([o["status_prob"] for o in outputs]).numpy() + status_pred = torch.cat([o["status_pred"] for o in outputs]).numpy() + status_true = torch.cat([o["status_true"] for o in outputs]).numpy() + vm_pred = torch.cat([o["vm_pred"] for o in outputs]).numpy() + vm_true = torch.cat([o["vm_true"] for o in outputs]).numpy() + + df = pd.DataFrame( + { + "status_prob": status_prob, + "status_pred": status_pred, + "status_true": status_true, + "vm_pred": vm_pred, + "vm_true": vm_true, + } + ) + df.to_csv(os.path.join(test_dir, f"{dataset_name}_vld_predictions.csv"), index=False) + + self.test_outputs.clear() + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + output, _ = self.shared_step(batch) + + bus_pred = output["bus"] + status_prob = 
torch.sigmoid(bus_pred[:, BUS_STATUS_LOGIT_OUT]) + + bus_batch = batch.batch_dict["bus"] + scenario_ids = batch["scenario_id"][bus_batch] + + local_bus_idx = torch.cat( + [torch.arange(c, device=bus_batch.device) for c in torch.bincount(bus_batch)] + ) + + return { + "scenario": scenario_ids.cpu().numpy(), + "bus": local_bus_idx.cpu().numpy(), + "vm_pred": bus_pred[:, VM_OUT].detach().cpu().numpy(), + "status_prob": status_prob.detach().cpu().numpy(), + "status_pred": (status_prob >= 0.5).float().detach().cpu().numpy(), + "status_true": batch.y_dict["bus"][:, BUS_STATUS_TARGET].detach().cpu().numpy(), + } \ No newline at end of file diff --git a/gridfm_graphkit/training/__init__.py b/gridfm_graphkit/training/__init__.py index 15452eec..146b834f 100644 --- a/gridfm_graphkit/training/__init__.py +++ b/gridfm_graphkit/training/__init__.py @@ -4,6 +4,7 @@ LayeredWeightedPhysicsLoss, MaskedBusMSE, MaskedGenMSE, + VLDTopologyLoss ) __all__ = [ @@ -13,4 +14,5 @@ "MaskedBusMSE", "MaskedGenMSE", "LossPerDim", + "VLDTopologyLoss", ] diff --git a/gridfm_graphkit/training/callbacks.py b/gridfm_graphkit/training/callbacks.py index ba7a4049..e755133f 100644 --- a/gridfm_graphkit/training/callbacks.py +++ b/gridfm_graphkit/training/callbacks.py @@ -2,46 +2,10 @@ from pytorch_lightning.utilities.rank_zero import rank_zero_only from lightning.pytorch.loggers import MLFlowLogger import os -import time import torch -class EpochTimerCallback(Callback): - """Records wall-clock duration and iteration rate of every training epoch.""" - - def __init__(self): - self.epoch_times: list[float] = [] - self._epoch_start: float | None = None - self._batch_count: int = 0 - self._last_batch_count: int = 0 - - def on_train_epoch_start(self, trainer, pl_module): - self._epoch_start = time.perf_counter() - self._batch_count = 0 - - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - self._batch_count += 1 - - def on_train_epoch_end(self, trainer, pl_module): - if 
self._epoch_start is not None: - self.epoch_times.append(time.perf_counter() - self._epoch_start) - self._last_batch_count = self._batch_count - self._epoch_start = None - - @property - def last_epoch_time(self) -> float | None: - return self.epoch_times[-1] if self.epoch_times else None - - @property - def last_epoch_iters_per_sec(self) -> float | None: - t = self.last_epoch_time - if t is None or t == 0 or self._last_batch_count == 0: - return None - return self._last_batch_count / t - - class SaveBestModelStateDict(Callback): - """Persist the best model state_dict according to a monitored validation metric.""" def __init__( self, monitor: str, @@ -53,15 +17,6 @@ def __init__( self.filename = filename self.best_score = float("inf") if mode == "min" else -float("inf") - @staticmethod - def _canonical_state_dict(pl_module): - """Return a state dict with compile wrappers removed from key names.""" - state_dict = pl_module.state_dict() - return { - key.replace("model._orig_mod.", "model."): value - for key, value in state_dict.items() - } - @rank_zero_only def on_validation_end(self, trainer, pl_module): current = trainer.callback_metrics.get(self.monitor) @@ -91,4 +46,4 @@ def on_validation_end(self, trainer, pl_module): # Save the model's state_dict model_path = os.path.join(model_dir, self.filename) - torch.save(self._canonical_state_dict(pl_module), model_path) + torch.save(pl_module.state_dict(), model_path) diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index a0521fc2..d7f8257c 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -19,9 +19,16 @@ PG_OUT, # Generator feature indices PG_H, - # Qg Limits - MIN_QG_H, - MAX_QG_H, + ##################### + ## Indices of features of the VLD task + # Bus feature indices + BUS_BASE_STATUS_H, + BUS_CONT_H, + B_ON, + # Branch feature indices + BRANCH_BASE_STATUS_E, + BRANCH_CONT_E, + ##################### ) @@ -39,7 +46,6 @@ def forward( edge_attr=None, 
mask=None, model=None, - x_dict=None, ): """ Compute the loss. @@ -76,7 +82,6 @@ def forward( edge_attr=None, mask=None, model=None, - x_dict=None, ): loss = F.mse_loss(pred[mask], target[mask], reduction=self.reduction) return {"loss": loss, "Masked MSE loss": loss.detach()} @@ -84,7 +89,6 @@ def forward( @LOSS_REGISTRY.register("MaskedGenMSE") class MaskedGenMSE(torch.nn.Module): - """Compute MSE on generator targets restricted to generator mask entries.""" def __init__(self, loss_args, args): super().__init__() self.reduction = "mean" @@ -97,7 +101,6 @@ def forward( edge_attr, mask_dict, model=None, - x_dict=None, ): loss = F.mse_loss( pred_dict["gen"][mask_dict["gen"][:, : (PG_H + 1)]], @@ -109,7 +112,6 @@ def forward( @LOSS_REGISTRY.register("MaskedBusMSE") class MaskedBusMSE(torch.nn.Module): - """Compute MSE on selected bus targets, respecting task-specific output columns.""" def __init__(self, loss_args, args): super().__init__() self.reduction = "mean" @@ -123,7 +125,6 @@ def forward( edge_attr, mask_dict, model=None, - x_dict=None, ): if self.args.task == "OptimalPowerFlow": pred_cols = [VM_OUT, VA_OUT, QG_OUT] @@ -161,7 +162,6 @@ def forward( edge_attr=None, mask=None, model=None, - x_dict=None, ): loss = F.mse_loss(pred, target, reduction=self.reduction) return {"loss": loss, "MSE loss": loss.detach()} @@ -195,7 +195,6 @@ def forward( edge_attr=None, mask=None, model=None, - x_dict=None, ): """ Compute the weighted sum of all specified losses. 
@@ -222,7 +221,6 @@ def forward( edge_attr, mask, model, - x_dict, ) # Assume each loss function returns a dictionary with a "loss" key @@ -239,9 +237,10 @@ def forward( return loss_details + + @LOSS_REGISTRY.register("LayeredWeightedPhysics") class LayeredWeightedPhysicsLoss(BaseLoss): - """Combine intermediate physics residuals using normalized geometric weights.""" def __init__(self, loss_args, args) -> None: super().__init__() self.base_weight = loss_args.base_weight @@ -254,7 +253,6 @@ def forward( edge_attr=None, mask=None, model=None, - x_dict=None, ): total_loss = 0.0 loss_details = {} @@ -282,7 +280,6 @@ def forward( @LOSS_REGISTRY.register("LossPerDim") class LossPerDim(BaseLoss): - """Compute MAE/MSE for one named physical dimension of bus outputs.""" def __init__(self, loss_args, args): super(LossPerDim, self).__init__() self.reduction = "mean" @@ -306,7 +303,6 @@ def forward( edge_attr, mask_dict, model=None, - x_dict=None, ): if self.dim == "VM": temp_pred = pred_dict["bus"][:, VM_OUT] @@ -339,56 +335,166 @@ def forward( f"MAE loss {self.dim}": mae_loss.detach(), } +####################### +@LOSS_REGISTRY.register("VLDTopologyLoss") +class VLDTopologyLoss(BaseLoss): + """ + Topology-first voltage loss detection objective. 
-@LOSS_REGISTRY.register("QgViolationPenalty") -class QgViolationPenaltyLoss(BaseLoss): - """Standard Mean Squared Error loss.""" + Expected bus tensor layouts: + pred_dict["bus"] : [Vm, Va, Pg, Qg, status_logit] + target_dict["bus"] : [Pd, Qd, Qg, Vm, Va, bus_status_target] + x_dict["bus"] : standard bus features + [bus_base_status, bus_contingency] + edge_attr : standard edge attrs + [branch_base_status, branch_contingency] + """ def __init__(self, loss_args, args): super().__init__() + self.args = args + self.input_state_threshold = getattr(loss_args, "input_state_threshold", 0.5) + self.prediction_threshold = getattr(loss_args, "prediction_threshold", 0.5) + self.topology_weight = getattr(loss_args, "topology_weight", 1.0) + self.target_anchor_weight = getattr(loss_args, "target_anchor_weight", 0.25) + self.off_vm_weight = getattr(loss_args, "off_vm_weight", 1.0) + self.on_vm_weight = getattr(loss_args, "on_vm_weight", 0.5) + self.unreachable_l1_weight = getattr(loss_args, "unreachable_l1_weight", 1.0) + self.topology_confidence_gamma = getattr(loss_args, "topology_confidence_gamma", 10.0) + + @staticmethod + def _build_graph_reachability( + num_bus, + edge_index, + edge_attr, + bus_x, + device, + threshold=0.5, + ): + """ + Build hard reachability labels from base-status and contingency indicators. + + A bus is considered initially available if base_status is on and it is not + directly hit by contingency. A branch is traversable if base branch status + is on and it is not hit by contingency. 
+ """ + bus_base = (bus_x[:, BUS_BASE_STATUS_H] > threshold) + bus_hit = (bus_x[:, BUS_CONT_H] > threshold) + bus_available = bus_base & (~bus_hit) + + src, dst = edge_index + branch_base = (edge_attr[:, BRANCH_BASE_STATUS_E] > threshold) + branch_hit = (edge_attr[:, BRANCH_CONT_E] > threshold) + branch_available = branch_base & (~branch_hit) + + reachable = torch.zeros(num_bus, dtype=torch.bool, device=device) + + # Seeds: all buses that remain available after direct contingency. + seed_nodes = torch.where(bus_available)[0] + if seed_nodes.numel() == 0: + return reachable.float(), bus_available.float() + + reachable[seed_nodes] = True + + changed = True + while changed: + prev = reachable.clone() + active_edges = branch_available & reachable[src] + reachable[dst[active_edges]] = True + changed = not torch.equal(prev, reachable) + + return reachable.float(), bus_available.float() def forward( - self, - pred, - target, - edge_index=None, - edge_attr=None, - mask=None, - model=None, - x_dict=None, + self, + pred_dict, + target_dict, + edge_index_dict, + edge_attr_dict, + mask_dict, + model=None, ): - # --- Qg limit violation mask --- - Qg_pred = pred["bus"][:, QG_OUT] - Qg_max = x_dict["bus"][:, MAX_QG_H] - Qg_min = x_dict["bus"][:, MIN_QG_H] + bus_pred = pred_dict["bus"] + bus_target = target_dict["bus"] + bus_x = model.latest_x_dict["bus"] if hasattr(model, "latest_x_dict") else None + + if bus_x is None: + raise RuntimeError( + "VLDTopologyLoss requires model.latest_x_dict['bus']. " + "Store x_dict on the model inside the forward pass." + ) - max_penalty_mask = (Qg_pred > Qg_max) - min_penalty_mask = (Qg_pred < Qg_min) + if bus_pred.size(1) <= BUS_STATUS_LOGIT_OUT: + raise ValueError( + "VLDTopologyLoss expects bus predictions to include a status logit " + f"at column {BUS_STATUS_LOGIT_OUT}." 
+ ) - mask_PQ = mask["PQ"] # PQ buses - mask_PV = mask["PV"] # PV buses - mask_REF = mask["REF"] # Reference buses + edge_index = edge_index_dict[("bus", "connects", "bus")] + edge_attr = edge_attr_dict[("bus", "connects", "bus")] - loss = 0.0 - # where there are violations, compute penalty loss - Qg_over = F.relu(Qg_pred - Qg_max) # amount above max limit - Qg_under = F.relu(Qg_min - Qg_pred) # amount below min limit + num_bus = bus_pred.size(0) + device = bus_pred.device - Qg_over = Qg_over[max_penalty_mask].mean() - Qg_under = Qg_under[min_penalty_mask].mean() - - if Qg_over!=Qg_over: # replacing nan with 0 - Qg_over = 0.0 - if Qg_under!=Qg_under: # replacing nan with 0 - Qg_under = 0.0 + topo_target, bus_available = self._build_graph_reachability( + num_bus=num_bus, + edge_index=edge_index, + edge_attr=edge_attr, + bus_x=bus_x, + device=device, + threshold=self.input_state_threshold, + ) - penalty_loss = Qg_over + Qg_under - loss += penalty_loss + status_logit = bus_pred[:, BUS_STATUS_LOGIT_OUT] + status_prob = torch.sigmoid(status_logit) - try: - output = {"loss": loss, "Qg Violation Penalty loss": loss.detach()} - except: - output = {"loss": loss, "Qg Violation Penalty loss": loss} + target_status = bus_target[:, BUS_STATUS_TARGET].float() + target_vm = bus_target[:, VM_H].float() + pred_vm = bus_pred[:, VM_OUT].float() - return output + topology_confidence = torch.exp( + -self.topology_confidence_gamma * torch.abs(bus_available - topo_target) + ) + + topology_bce_raw = F.binary_cross_entropy_with_logits( + status_logit, + topo_target, + reduction="none", + ) + topology_bce = (topology_confidence * topology_bce_raw).mean() + target_anchor_bce = F.binary_cross_entropy_with_logits( + status_logit, + target_status, + reduction="mean", + ) + + unreachable_mask = (topo_target < 0.5).float() + reachable_mask = (topo_target >= 0.5).float() + + off_vm_l2 = ((pred_vm ** 2) * unreachable_mask).sum() / unreachable_mask.sum().clamp_min(1.0) + off_vm_l1 = (pred_vm.abs() * 
unreachable_mask).sum() / unreachable_mask.sum().clamp_min(1.0) + off_vm_loss = off_vm_l2 + self.unreachable_l1_weight * off_vm_l1 + + on_vm_sq = ((pred_vm - target_vm) ** 2) * reachable_mask * status_prob.detach() + on_vm_loss = on_vm_sq.sum() / (reachable_mask * status_prob.detach()).sum().clamp_min(1.0) + + total_loss = ( + self.topology_weight * topology_bce + + self.target_anchor_weight * target_anchor_bce + + self.off_vm_weight * off_vm_loss + + self.on_vm_weight * on_vm_loss + ) + + pred_status = (status_prob >= self.prediction_threshold).float() + topo_acc = (pred_status == topo_target).float().mean() + target_acc = (pred_status == target_status).float().mean() + + return { + "loss": total_loss, + "VLD Topology BCE": topology_bce.detach(), + "VLD Target Anchor BCE": target_anchor_bce.detach(), + "VLD Off Vm Loss": off_vm_loss.detach(), + "VLD On Vm Loss": on_vm_loss.detach(), + "VLD Topology Accuracy": topo_acc.detach(), + "VLD Target Accuracy": target_acc.detach(), + } +####################### \ No newline at end of file diff --git a/gridfm_graphkit/utils/visualization.py b/gridfm_graphkit/utils/visualization.py index 3a8151c8..276d403b 100644 --- a/gridfm_graphkit/utils/visualization.py +++ b/gridfm_graphkit/utils/visualization.py @@ -11,7 +11,6 @@ def visualize_error(data_point, output, node_normalizer): - """Plot node-wise active power residuals on the grid topology.""" loss = PBELoss(visualization=True) loss_dict = loss(