Skip to content
Open
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ mlruns
*data_out*
site*
.venv
*.zip
32 changes: 32 additions & 0 deletions gridfm_graphkit/__main__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import logging
from datetime import datetime
from gridfm_graphkit.cli import main_cli, benchmark_cli

Expand Down Expand Up @@ -65,6 +66,17 @@ def main():
prog="gridfm_graphkit",
description="gridfm-graphkit CLI",
)
parser.add_argument(
"--log_level",
type=str,
default="WARNING",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help=(
"Python logging level for the gridfm_graphkit package. "
"Use DEBUG to see performance trace messages (dataset init, "
"split, normalizer fit, wrapper cache timings)."
),
)
subparsers = parser.add_subparsers(dest="command", required=True)
exp_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

Expand Down Expand Up @@ -120,6 +132,12 @@ def main():
default=None,
help="Override data.workers from the YAML config. Use 0 to debug worker crashes.",
)
train_parser.add_argument(
"--batch_size",
type=int,
default=None,
help="Override training.batch_size from the YAML config.",
)
train_parser.add_argument(
"--dataset_wrapper_cache_dir",
type=str,
Expand Down Expand Up @@ -173,6 +191,12 @@ def main():
default=None,
help="Override data.workers from the YAML config. Use 0 to debug worker crashes.",
)
finetune_parser.add_argument(
"--batch_size",
type=int,
default=None,
help="Override training.batch_size from the YAML config.",
)
finetune_parser.add_argument(
"--dataset_wrapper_cache_dir",
type=str,
Expand Down Expand Up @@ -341,6 +365,14 @@ def main():

args = parser.parse_args()

logging.basicConfig(
level=getattr(logging, args.log_level),
format="%(asctime)s %(name)s %(levelname)s %(message)s",
)
# Ensure the gridfm_graphkit package logger respects the chosen level even
# if third-party libraries have already configured the root logger.
logging.getLogger("gridfm_graphkit").setLevel(getattr(logging, args.log_level))

if args.command == "benchmark":
benchmark_cli(args)
else:
Expand Down
33 changes: 18 additions & 15 deletions gridfm_graphkit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ def main_cli(args):
if getattr(args, "tf32", False):
torch.set_float32_matmul_precision("high") # enables TF32 on Ampere+ GPUs

# MLflow file-store requires the directory to exist before DDP workers start.
os.makedirs(args.log_dir, exist_ok=True)

logger = MLFlowLogger(
save_dir=args.log_dir,
experiment_name=args.exp_name,
Expand All @@ -177,6 +180,10 @@ def main_cli(args):
if num_workers_override is not None:
config_args.data.workers = num_workers_override

batch_size_override = getattr(args, "batch_size", None)
if batch_size_override is not None:
config_args.training.batch_size = batch_size_override

_load_plugins(getattr(args, "plugins", []))
_validate_dataset_wrapper(dataset_wrapper)

Expand Down Expand Up @@ -228,7 +235,9 @@ def main_cli(args):
"ddp",
"ddp_find_unused_parameters_true",
):
_strategy = DDPStrategy(find_unused_parameters=True)
_strategy = DDPStrategy(
find_unused_parameters=(_strategy == "ddp_find_unused_parameters_true"),
)

trainer = L.Trainer(
logger=logger,
Expand Down Expand Up @@ -257,6 +266,13 @@ def main_cli(args):
print(
f"[performance] last epoch it/s : {epoch_timer.last_epoch_iters_per_sec:.2f}",
)
val_loss = trainer.callback_metrics.get("Validation loss")
if val_loss is not None:
try:
val_loss = val_loss.item()
except AttributeError:
pass
print(f"[performance] Validation loss : {val_loss}")

if args.command != "predict":
# Reuse the fit trainer when coming from train/finetune so that
Expand All @@ -275,20 +291,7 @@ def main_cli(args):
**trainer_kwargs,
profiler=profiler,
)
test_results = test_trainer.test(model=model, datamodule=litGrid)
if report_performance:
# test_results[0] may be empty when metrics are routed to the logger
# only; fall back to trainer.callback_metrics which always has them.
metrics = (
test_results[0]
if test_results and test_results[0]
else dict(test_trainer.callback_metrics)
)
if metrics:
first_metric, first_value = next(iter(metrics.items()))
print(f"[performance] {first_metric} : {first_value}")
else:
print("[performance] no test metrics available")
test_trainer.test(model=model, datamodule=litGrid)

artifacts_dir = None
is_rank0 = (
Expand Down
120 changes: 99 additions & 21 deletions gridfm_graphkit/datasets/hetero_powergrid_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import csv
import json
import logging
import time
import torch
import os
from torch_geometric.loader import DataLoader
Expand All @@ -23,7 +26,9 @@
import lightning as L
from pathlib import Path
from typing import List
from lightning.pytorch.loggers import MLFlowLogger
from gridfm_graphkit.utils.mlflow_artifact_utils import artifact_context

logger = logging.getLogger(__name__)


class LitGridHeteroDataModule(L.LightningDataModule):
Expand Down Expand Up @@ -134,15 +139,27 @@ def setup(self, stage: str):
print(f"Setup already done for stage={stage}, skipping...")
return

# Timing telemetry: list of dicts, one row per network / phase.
# Uploaded as a CSV artifact and emitted as MLflow metrics at the end.
_timing_rows: list[dict] = []

# Load pre-fitted normalizer stats if provided (e.g. from a training run)
saved_stats = None
if self.normalizer_stats_path is not None:
_t0 = time.perf_counter()
saved_stats = torch.load(
self.normalizer_stats_path,
map_location="cpu",
weights_only=True,
)
print(f"Loaded normalizer stats from {self.normalizer_stats_path}")
_timing_rows.append(
{
"network": "_global",
"phase": "load_normalizer_stats_s",
"value": time.perf_counter() - _t0,
},
)

for i, network in enumerate(self.args.data.networks):
data_normalizer = load_normalizer(args=self.args)
Expand All @@ -153,6 +170,7 @@ def setup(self, stage: str):

is_distributed = dist.is_available() and dist.is_initialized()

_t0 = time.perf_counter()
if not is_distributed or dist.get_rank() == 0:
dataset = HeteroGridDatasetDisk(
root=data_path_network,
Expand All @@ -170,6 +188,18 @@ def setup(self, stage: str):
data_normalizer=data_normalizer,
transform=get_task_transforms(args=self.args),
)
_timing_rows.append(
{
"network": network,
"phase": "dataset_init_s",
"value": time.perf_counter() - _t0,
},
)
logger.debug(
"[perf] %s dataset_init_s: %.3f",
network,
_timing_rows[-1]["value"],
)

self.datasets.append(dataset)

Expand Down Expand Up @@ -231,13 +261,13 @@ def setup(self, stage: str):


dataset = Subset(dataset, subset_indices)

if self.dataset_wrapper is not None:
wrapper_cls = DATASET_WRAPPER_REGISTRY.get(self.dataset_wrapper)
dataset = wrapper_cls(dataset, cache_dir=self.dataset_wrapper_cache_dir)


# Random seed set before every split, same as above
_t0 = time.perf_counter()
np.random.seed(self.args.seed)
if self.split_by_load_scenario_idx:
train_dataset, val_dataset, test_dataset = (
Expand All @@ -256,6 +286,18 @@ def setup(self, stage: str):
self.args.data.val_ratio,
self.args.data.test_ratio,
)
_timing_rows.append(
{
"network": network,
"phase": "dataset_split_s",
"value": time.perf_counter() - _t0,
},
)
logger.debug(
"[perf] %s dataset_split_s: %.3f",
network,
_timing_rows[-1]["value"],
)

# Extract scenario IDs for each split
train_scenario_ids = self._extract_scenario_ids(
Expand All @@ -280,6 +322,7 @@ def setup(self, stage: str):
and network in saved_stats
and data_normalizer.fit_strategy == "fit_on_train"
)
_t0 = time.perf_counter()
if use_saved:
print(f"Restoring normalizer for {network} from saved stats")
data_normalizer.fit_from_dict(saved_stats[network])
Expand All @@ -294,11 +337,36 @@ def setup(self, stage: str):
num_scenarios,
saved_stats,
)
_timing_rows.append(
{
"network": network,
"phase": "normalizer_fit_s",
"value": time.perf_counter() - _t0,
},
)
logger.debug(
"[perf] %s normalizer_fit_s: %.3f",
network,
_timing_rows[-1]["value"],
)

# Populate the wrapper cache now that the normalizer is fitted,
# so transform() has BaseMVA set when __getitem__ is called.
if self.dataset_wrapper is not None and hasattr(dataset, "_setup_cache"):
_t0 = time.perf_counter()
dataset._setup_cache()
_timing_rows.append(
{
"network": network,
"phase": "wrapper_cache_s",
"value": time.perf_counter() - _t0,
},
)
logger.debug(
"[perf] %s wrapper_cache_s: %.3f",
network,
_timing_rows[-1]["value"],
)

self.train_datasets.append(train_dataset)
self.val_datasets.append(val_dataset)
Expand All @@ -311,27 +379,37 @@ def setup(self, stage: str):
self.val_dataset_multi = ConcatDataset(self.val_datasets)
self._is_setup_done = True

# Save scenario splits (rank 0 only in DDP)
# Emit setup timings to Logger (MLflow) and save as a CSV artifact.
# Only rank-0 writes; trainer / logger may not be attached in unit tests.
is_rank0 = (
not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0
)
if (
is_rank0
and self.trainer is not None
and getattr(self.trainer, "logger", None) is not None
):
logger = self.trainer.logger
if isinstance(logger, MLFlowLogger):
log_dir = os.path.join(
logger.save_dir,
logger.experiment_id,
logger.run_id,
"artifacts",
"stats",
)
else:
log_dir = os.path.join(logger.save_dir, "stats")
self.save_scenario_splits(log_dir)
if is_rank0 and self.trainer is not None:
_logger = getattr(self.trainer, "logger", None)
if _logger is not None:
if _timing_rows:
# Emit each timing as an MLflow metric (step=0 → setup phase)
mlflow_metrics = {
f"perf/setup/{r['network']}/{r['phase']}": r["value"]
for r in _timing_rows
}
try:
_logger.log_metrics(mlflow_metrics, step=0)
except Exception:
pass # non-MLflow loggers may not support log_metrics directly

# Save timing CSV as an artifact
with artifact_context(_logger, "stats") as _art_dir:
csv_path = os.path.join(_art_dir, "setup_timings.csv")
with open(csv_path, "w", newline="") as _fh:
writer = csv.DictWriter(
_fh,
fieldnames=["network", "phase", "value"],
)
writer.writeheader()
writer.writerows(_timing_rows)

# Save scenario splits (always, independent of timing)

@staticmethod
def _fit_normalizer(
Expand Down
Loading
Loading