Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/quick_start/yaml_config.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,12 @@ Task name registered in the framework:

### `data.networks`

List of dataset folders under your data root.
List of dataset folders under your data root.
Examples: `case14_ieee`, `case118_ieee`, `case2000_goc`, `Texas2k_case1_2016summerpeak`.

### `data.scenarios`

List of scenario counts, one value per network in `data.networks`.
List of scenario counts, one value per network in `data.networks`.
Example: with two networks, use two scenario entries in matching order.

### `data.normalization`
Expand Down
52 changes: 51 additions & 1 deletion gridfm_graphkit/__main__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import argparse
from datetime import datetime
from gridfm_graphkit.cli import main_cli, benchmark_cli
import torch.multiprocessing


import subprocess
import os
import socket


def is_lsf():
return (
Expand All @@ -13,6 +16,7 @@ def is_lsf():
and "LSF_ENVDIR" in os.environ # strong LSF indicator
)


def fix_infiniband():
"""Configure NCCL to skip Ethernet-only IB ports on this host."""
ibv = subprocess.run("ibv_devinfo", stdout=subprocess.PIPE, stderr=subprocess.PIPE)
Expand Down Expand Up @@ -46,7 +50,13 @@ def set_env():
os.environ["MASTER_ADDR"] = HOST_LIST[
0
] # Sets the MasterNode to thefirst node on the list of hosts
os.environ["MASTER_PORT"] = "5" + LSB_JOBID[-5:-1]
# Derive port deterministically from job ID so all ranks agree on the same
# value. Use a wider hash (CRC32 of the full job ID) mapped into the
# ephemeral range [49152, 65535] to minimise collisions with other jobs.
import binascii

crc = binascii.crc32(LSB_JOBID.encode()) & 0xFFFFFFFF
os.environ["MASTER_PORT"] = str(49152 + (crc % 16383))
os.environ["NODE_RANK"] = str(
HOST_LIST.index(os.environ["HOSTNAME"]),
) # Uses the list index for node rank, master node rank must be 0
Expand All @@ -55,6 +65,7 @@ def set_env():
)
os.environ["NCCL_IB_CUDA_SUPPORT"] = "1" # Force use of infiniband


def main():
"""Parse CLI arguments and dispatch to the selected GridFM subcommand."""
if is_lsf():
Expand Down Expand Up @@ -91,6 +102,12 @@ def main():
default=False,
help="Enable TF32 on Ampere+ GPUs via torch.set_float32_matmul_precision('high').",
)
_start_method_kwargs = dict(
type=str,
default="spawn",
choices=["spawn", "fork", "forkserver"],
help="Multiprocessing start method for dataloader workers (default: spawn). Use 'fork' for faster startup on CPU-only machines.",
)

# ---- TRAIN SUBCOMMAND ----
train_parser = subparsers.add_parser("train", help="Run training")
Expand Down Expand Up @@ -120,6 +137,12 @@ def main():
default=None,
help="Override data.workers from the YAML config. Use 0 to debug worker crashes.",
)
train_parser.add_argument(
"--batch_size",
type=int,
default=None,
help="Override training.batch_size from the YAML config.",
)
train_parser.add_argument(
"--dataset_wrapper_cache_dir",
type=str,
Expand All @@ -143,6 +166,9 @@ def main():
action="store_true",
help="Print the last training epoch time and a single test metric to stdout.",
)
train_parser.add_argument(
"--start-method", dest="start_method", **_start_method_kwargs,
)

# ---- FINETUNE SUBCOMMAND ----
finetune_parser = subparsers.add_parser("finetune", help="Run fine-tuning")
Expand Down Expand Up @@ -173,6 +199,12 @@ def main():
default=None,
help="Override data.workers from the YAML config. Use 0 to debug worker crashes.",
)
finetune_parser.add_argument(
"--batch_size",
type=int,
default=None,
help="Override training.batch_size from the YAML config.",
)
finetune_parser.add_argument(
"--dataset_wrapper_cache_dir",
type=str,
Expand All @@ -196,6 +228,9 @@ def main():
action="store_true",
help="Print the last training epoch time and a single test metric to stdout.",
)
finetune_parser.add_argument(
"--start-method", dest="start_method", **_start_method_kwargs,
)

# ---- EVALUATE SUBCOMMAND ----
evaluate_parser = subparsers.add_parser(
Expand Down Expand Up @@ -262,6 +297,9 @@ def main():
"--save_output",
action="store_true",
)
evaluate_parser.add_argument(
"--start-method", dest="start_method", **_start_method_kwargs,
)

# ---- PREDICT SUBCOMMAND ----
predict_parser = subparsers.add_parser("predict", help="Run prediction")
Expand Down Expand Up @@ -312,6 +350,9 @@ def main():
default=None,
choices=["simple", "advanced", "pytorch"],
)
predict_parser.add_argument(
"--start-method", dest="start_method", **_start_method_kwargs,
)

# ---- BENCHMARK SUBCOMMAND ----
benchmark_parser = subparsers.add_parser(
Expand Down Expand Up @@ -344,6 +385,12 @@ def main():
default=None,
help="Override data.workers from the YAML config.",
)
benchmark_parser.add_argument(
"--batch_size",
type=int,
default=None,
help="Override training.batch_size from the YAML config.",
)
benchmark_parser.add_argument(
"--plugins",
nargs="*",
Expand All @@ -353,6 +400,9 @@ def main():

args = parser.parse_args()

start_method = getattr(args, "start_method", "spawn")
torch.multiprocessing.set_start_method(start_method, force=True)

if args.command == "benchmark":
benchmark_cli(args)
else:
Expand Down
91 changes: 58 additions & 33 deletions gridfm_graphkit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import importlib
import numpy as np
import os
import socket
import time
import yaml
import torch
Expand All @@ -24,7 +25,9 @@
import lightning as L


def _normalize_loaded_state_dict_keys(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
def _normalize_loaded_state_dict_keys(
state_dict: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
"""Map legacy torch.compile checkpoint keys to the canonical model namespace."""
has_compiled_prefix = any(key.startswith("model._orig_mod.") for key in state_dict)
if not has_compiled_prefix:
Expand Down Expand Up @@ -80,6 +83,10 @@
if num_workers_override is not None:
config_args.data.workers = num_workers_override

batch_size_override = getattr(args, "batch_size", None)
if batch_size_override is not None:
config_args.training.batch_size = batch_size_override

_load_plugins(getattr(args, "plugins", []))

dataset_wrapper = getattr(args, "dataset_wrapper", None)
Expand Down Expand Up @@ -177,6 +184,7 @@
if num_workers_override is not None:
config_args.data.workers = num_workers_override

# CLI --batch_size overrides the YAML value
batch_size_override = getattr(args, "batch_size", None)
if batch_size_override is not None:
config_args.training.batch_size = batch_size_override
Expand Down Expand Up @@ -229,14 +237,38 @@
_accelerator = config_args.training.accelerator
_strategy = config_args.training.strategy
# if mps is available and accelerator is auto, explicitely set accelerator to mps to select the right strategy in the next block
if _accelerator == "auto" and torch.backends.mps.is_available():
if _accelerator == "auto" and torch.backends.mps.is_available():
_accelerator = "mps"
if _accelerator not in ("mps", "cpu") and isinstance(_strategy, str) and _strategy in (
"auto",
"ddp",
"ddp_find_unused_parameters_true",
): # when using mps, we don't want to use ddp.
_strategy = DDPStrategy(find_unused_parameters=True)
if (
_accelerator not in ("mps", "cpu")
and isinstance(_strategy, str)
and _strategy
in (
"auto",
"ddp",
"ddp_find_unused_parameters_true",
)
):
_num_devices = config_args.training.devices
_single_device = _num_devices in (1, "1", [1]) or (
isinstance(_num_devices, list) and len(_num_devices) == 1
)
if _single_device:
# No need for DDP with a single device; use auto to avoid NCCL init
_strategy = "auto"
else:
# Ensure MASTER_PORT is set to a free port to avoid EADDRINUSE
if not os.environ.get("MASTER_PORT"):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as _s:
_s.bind(("", 0))

Check warning

Code scanning / CodeQL

Binding a socket to all network interfaces Medium

Binding a socket to all interfaces (using
''
) is a security risk.
Binding a socket to all interfaces (using
''
) is a security risk.
Binding a socket to all interfaces (using
''
) is a security risk.
Binding a socket to all interfaces (using
''
) is a security risk.
_s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
os.environ["MASTER_PORT"] = str(_s.getsockname()[1])
if not os.environ.get("MASTER_ADDR"):
os.environ["MASTER_ADDR"] = "127.0.0.1"
# Use gloo backend to avoid NCCL socket interface issues
_strategy = DDPStrategy(
find_unused_parameters=True, process_group_backend="gloo",
)

trainer = L.Trainer(
logger=logger,
Expand All @@ -252,19 +284,25 @@
)
if args.command == "train" or args.command == "finetune":
trainer.fit(model=model, datamodule=litGrid)
if (
report_performance
and epoch_timer is not None
and epoch_timer.last_epoch_time is not None
):
print(f"[performance] last epoch time : {epoch_timer.last_epoch_time:.3f}s")
if (
epoch_timer.last_epoch_iters_per_sec is not None
and epoch_timer._last_batch_count > 0
):
if report_performance:
# Validation loss
val_loss = trainer.callback_metrics.get("Validation loss")
if val_loss is not None:
print(f"[performance] Validation loss : {float(val_loss):.6f}")
else:
print("[performance] Validation loss : not available")
# Epoch timing
if epoch_timer is not None and epoch_timer.last_epoch_time is not None:
print(
f"[performance] last epoch it/s : {epoch_timer.last_epoch_iters_per_sec:.2f}",
f"[performance] last epoch time : {epoch_timer.last_epoch_time:.3f}s",
)
if (
epoch_timer.last_epoch_iters_per_sec is not None
and epoch_timer._last_batch_count > 0
):
print(
f"[performance] last epoch it/s : {epoch_timer.last_epoch_iters_per_sec:.2f}",
)

if args.command != "predict":
# Reuse the fit trainer when coming from train/finetune so that
Expand All @@ -283,20 +321,7 @@
**trainer_kwargs,
profiler=profiler,
)
test_results = test_trainer.test(model=model, datamodule=litGrid)
if report_performance:
# test_results[0] may be empty when metrics are routed to the logger
# only; fall back to trainer.callback_metrics which always has them.
metrics = (
test_results[0]
if test_results and test_results[0]
else dict(test_trainer.callback_metrics)
)
if metrics:
first_metric, first_value = next(iter(metrics.items()))
print(f"[performance] {first_metric} : {first_value}")
else:
print("[performance] no test metrics available")
test_trainer.test(model=model, datamodule=litGrid)

artifacts_dir = None
is_rank0 = (
Expand Down
49 changes: 25 additions & 24 deletions gridfm_graphkit/datasets/hetero_powergrid_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,18 @@ def __init__(
self._is_setup_done = False

if self.split_by_load_scenario_idx:
assert self.split_from_existing_files is None, " either `split_by_load_scenario_idx` or `split_from_existing_files` may be used, not both"
assert self.split_from_existing_files is None, (
" either `split_by_load_scenario_idx` or `split_from_existing_files` may be used, not both"
)

if self.split_from_existing_files is not None:
assert isinstance(self.split_from_existing_files, str), "`split_from_existing_files` must be an existing folder in string format"
assert isinstance(self.split_from_existing_files, str), (
"`split_from_existing_files` must be an existing folder in string format"
)
self.split_from_existing_files = Path(self.split_from_existing_files)
assert self.split_from_existing_files.is_dir(), "`split_from_existing_files` must be an existing folder in string format"

assert self.split_from_existing_files.is_dir(), (
"`split_from_existing_files` must be an existing folder in string format"
)

def setup(self, stage: str):
if self._is_setup_done:
Expand Down Expand Up @@ -184,7 +189,6 @@ def setup(self, stage: str):
# Create a subset
all_indices = list(range(len(dataset)))


if self.split_from_existing_files is not None:
warnings.warn(
"`data.scenarios` is ignored when `split_from_existing_files` is set; "
Expand Down Expand Up @@ -229,13 +233,13 @@ def setup(self, stage: str):
# load_scenario for each scenario in the subset
load_scenarios = dataset.load_scenarios[subset_indices]


dataset = Subset(dataset, subset_indices)

if self.dataset_wrapper is not None:
wrapper_cls = DATASET_WRAPPER_REGISTRY.get(self.dataset_wrapper)
dataset = wrapper_cls(dataset, cache_dir=self.dataset_wrapper_cache_dir)

dataset = wrapper_cls(
dataset, cache_dir=self.dataset_wrapper_cache_dir,
)

# Random seed set before every split, same as above
np.random.seed(self.args.seed)
Expand Down Expand Up @@ -425,24 +429,21 @@ def _dataloader_kwargs(self):
pin_memory=torch.cuda.is_available(),
persistent_workers=num_workers > 0,
)
# Use 'fork' on Linux. It avoids the forkserver intermediary pipe which
# is fragile when the process has many threads (e.g. OpenBLAS). In
# container environments (Kubernetes) fork works correctly. On
# traditional HPC systems with strict fd-passing restrictions the
# original 'forkserver' may be needed, but the pipe truncation it
# produces under thread pressure is worse than the ancdata warning.
if (
num_workers > 0
and torch.multiprocessing.get_start_method(allow_none=True) != "spawn"
):
import platform

if platform.system() == "Linux":
kwargs["multiprocessing_context"] = "fork"
# Use 'spawn' for DataLoader workers: CUDA is already initialized by
# Lightning before DataLoaders are created, and forking a process that
# has an active CUDA context corrupts the child's GPU state. 'spawn'
# starts a clean interpreter without inheriting the CUDA context.
if num_workers > 0:
kwargs["multiprocessing_context"] = "spawn"
return kwargs

def train_dataloader(self):
print("creating train dataloader for rank ", dist.get_rank() if dist.is_available() and dist.is_initialized() else "not distributed")
print(
"creating train dataloader for rank ",
dist.get_rank()
if dist.is_available() and dist.is_initialized()
else "not distributed",
)
return DataLoader(
self.train_dataset_multi,
batch_size=self.batch_size,
Expand Down
Loading
Loading