diff --git a/.gitignore b/.gitignore index 3859d4b3..1126bc9b 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,4 @@ mlruns *data_out* site* .venv +*.zip diff --git a/gridfm_graphkit/__main__.py b/gridfm_graphkit/__main__.py index b693089c..ca8f2c04 100644 --- a/gridfm_graphkit/__main__.py +++ b/gridfm_graphkit/__main__.py @@ -1,4 +1,5 @@ import argparse +import logging from datetime import datetime from gridfm_graphkit.cli import main_cli, benchmark_cli @@ -65,6 +66,17 @@ def main(): prog="gridfm_graphkit", description="gridfm-graphkit CLI", ) + parser.add_argument( + "--log_level", + type=str, + default="WARNING", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help=( + "Python logging level for the gridfm_graphkit package. " + "Use DEBUG to see performance trace messages (dataset init, " + "split, normalizer fit, wrapper cache timings)." + ), + ) subparsers = parser.add_subparsers(dest="command", required=True) exp_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") @@ -120,6 +132,12 @@ def main(): default=None, help="Override data.workers from the YAML config. Use 0 to debug worker crashes.", ) + train_parser.add_argument( + "--batch_size", + type=int, + default=None, + help="Override training.batch_size from the YAML config.", + ) train_parser.add_argument( "--dataset_wrapper_cache_dir", type=str, @@ -173,6 +191,12 @@ def main(): default=None, help="Override data.workers from the YAML config. 
Use 0 to debug worker crashes.", ) + finetune_parser.add_argument( + "--batch_size", + type=int, + default=None, + help="Override training.batch_size from the YAML config.", + ) finetune_parser.add_argument( "--dataset_wrapper_cache_dir", type=str, @@ -341,6 +365,14 @@ def main(): args = parser.parse_args() + logging.basicConfig( + level=getattr(logging, args.log_level), + format="%(asctime)s %(name)s %(levelname)s %(message)s", + ) + # Ensure the gridfm_graphkit package logger respects the chosen level even + # if third-party libraries have already configured the root logger. + logging.getLogger("gridfm_graphkit").setLevel(getattr(logging, args.log_level)) + if args.command == "benchmark": benchmark_cli(args) else: diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index 0ffd1364..0bdd89ac 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -155,6 +155,9 @@ def main_cli(args): if getattr(args, "tf32", False): torch.set_float32_matmul_precision("high") # enables TF32 on Ampere+ GPUs + # MLflow file-store requires the directory to exist before DDP workers start. 
+ os.makedirs(args.log_dir, exist_ok=True) + logger = MLFlowLogger( save_dir=args.log_dir, experiment_name=args.exp_name, @@ -177,6 +180,10 @@ def main_cli(args): if num_workers_override is not None: config_args.data.workers = num_workers_override + batch_size_override = getattr(args, "batch_size", None) + if batch_size_override is not None: + config_args.training.batch_size = batch_size_override + _load_plugins(getattr(args, "plugins", [])) _validate_dataset_wrapper(dataset_wrapper) @@ -228,7 +235,9 @@ def main_cli(args): "ddp", "ddp_find_unused_parameters_true", ): - _strategy = DDPStrategy(find_unused_parameters=True) + _strategy = DDPStrategy( + find_unused_parameters=(_strategy == "ddp_find_unused_parameters_true"), + ) trainer = L.Trainer( logger=logger, @@ -257,6 +266,13 @@ def main_cli(args): print( f"[performance] last epoch it/s : {epoch_timer.last_epoch_iters_per_sec:.2f}", ) + val_loss = trainer.callback_metrics.get("Validation loss") + if val_loss is not None: + try: + val_loss = val_loss.item() + except AttributeError: + pass + print(f"[performance] Validation loss : {val_loss}") if args.command != "predict": # Reuse the fit trainer when coming from train/finetune so that @@ -275,20 +291,7 @@ def main_cli(args): **trainer_kwargs, profiler=profiler, ) - test_results = test_trainer.test(model=model, datamodule=litGrid) - if report_performance: - # test_results[0] may be empty when metrics are routed to the logger - # only; fall back to trainer.callback_metrics which always has them. 
- metrics = ( - test_results[0] - if test_results and test_results[0] - else dict(test_trainer.callback_metrics) - ) - if metrics: - first_metric, first_value = next(iter(metrics.items())) - print(f"[performance] {first_metric} : {first_value}") - else: - print("[performance] no test metrics available") + test_trainer.test(model=model, datamodule=litGrid) artifacts_dir = None is_rank0 = ( diff --git a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py index e5374970..ea59949f 100644 --- a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py @@ -1,4 +1,7 @@ +import csv import json +import logging +import time import torch import os from torch_geometric.loader import DataLoader @@ -23,7 +26,9 @@ import lightning as L from pathlib import Path from typing import List -from lightning.pytorch.loggers import MLFlowLogger +from gridfm_graphkit.utils.mlflow_artifact_utils import artifact_context + +logger = logging.getLogger(__name__) class LitGridHeteroDataModule(L.LightningDataModule): @@ -134,15 +139,27 @@ def setup(self, stage: str): print(f"Setup already done for stage={stage}, skipping...") return + # Timing telemetry: list of dicts, one row per network / phase. + # Uploaded as a CSV artifact and emitted as MLflow metrics at the end. + _timing_rows: list[dict] = [] + # Load pre-fitted normalizer stats if provided (e.g. 
from a training run) saved_stats = None if self.normalizer_stats_path is not None: + _t0 = time.perf_counter() saved_stats = torch.load( self.normalizer_stats_path, map_location="cpu", weights_only=True, ) print(f"Loaded normalizer stats from {self.normalizer_stats_path}") + _timing_rows.append( + { + "network": "_global", + "phase": "load_normalizer_stats_s", + "value": time.perf_counter() - _t0, + }, + ) for i, network in enumerate(self.args.data.networks): data_normalizer = load_normalizer(args=self.args) @@ -153,6 +170,7 @@ def setup(self, stage: str): is_distributed = dist.is_available() and dist.is_initialized() + _t0 = time.perf_counter() if not is_distributed or dist.get_rank() == 0: dataset = HeteroGridDatasetDisk( root=data_path_network, @@ -170,6 +188,18 @@ def setup(self, stage: str): data_normalizer=data_normalizer, transform=get_task_transforms(args=self.args), ) + _timing_rows.append( + { + "network": network, + "phase": "dataset_init_s", + "value": time.perf_counter() - _t0, + }, + ) + logger.debug( + "[perf] %s dataset_init_s: %.3f", + network, + _timing_rows[-1]["value"], + ) self.datasets.append(dataset) @@ -231,13 +261,13 @@ def setup(self, stage: str): dataset = Subset(dataset, subset_indices) - + if self.dataset_wrapper is not None: wrapper_cls = DATASET_WRAPPER_REGISTRY.get(self.dataset_wrapper) dataset = wrapper_cls(dataset, cache_dir=self.dataset_wrapper_cache_dir) - # Random seed set before every split, same as above + _t0 = time.perf_counter() np.random.seed(self.args.seed) if self.split_by_load_scenario_idx: train_dataset, val_dataset, test_dataset = ( @@ -256,6 +286,18 @@ def setup(self, stage: str): self.args.data.val_ratio, self.args.data.test_ratio, ) + _timing_rows.append( + { + "network": network, + "phase": "dataset_split_s", + "value": time.perf_counter() - _t0, + }, + ) + logger.debug( + "[perf] %s dataset_split_s: %.3f", + network, + _timing_rows[-1]["value"], + ) # Extract scenario IDs for each split train_scenario_ids = 
self._extract_scenario_ids( @@ -280,6 +322,7 @@ def setup(self, stage: str): and network in saved_stats and data_normalizer.fit_strategy == "fit_on_train" ) + _t0 = time.perf_counter() if use_saved: print(f"Restoring normalizer for {network} from saved stats") data_normalizer.fit_from_dict(saved_stats[network]) @@ -294,11 +337,36 @@ def setup(self, stage: str): num_scenarios, saved_stats, ) + _timing_rows.append( + { + "network": network, + "phase": "normalizer_fit_s", + "value": time.perf_counter() - _t0, + }, + ) + logger.debug( + "[perf] %s normalizer_fit_s: %.3f", + network, + _timing_rows[-1]["value"], + ) # Populate the wrapper cache now that the normalizer is fitted, # so transform() has BaseMVA set when __getitem__ is called. if self.dataset_wrapper is not None and hasattr(dataset, "_setup_cache"): + _t0 = time.perf_counter() dataset._setup_cache() + _timing_rows.append( + { + "network": network, + "phase": "wrapper_cache_s", + "value": time.perf_counter() - _t0, + }, + ) + logger.debug( + "[perf] %s wrapper_cache_s: %.3f", + network, + _timing_rows[-1]["value"], + ) self.train_datasets.append(train_dataset) self.val_datasets.append(val_dataset) @@ -311,27 +379,37 @@ def setup(self, stage: str): self.val_dataset_multi = ConcatDataset(self.val_datasets) self._is_setup_done = True - # Save scenario splits (rank 0 only in DDP) + # Emit setup timings to Logger (MLflow) and save as a CSV artifact. + # Only rank-0 writes; trainer / logger may not be attached in unit tests. 
@rank_zero_only
def on_fit_start(self):
    """Persist the normalizer statistics as run artifacts when fitting starts.

    Decorated with ``rank_zero_only``, so only one process writes the files.
    Two files are written into the ``stats`` artifact sub-directory handed
    out by ``artifact_context``:

    * ``normalization_stats.txt`` - human-readable dump, one section per
      entry of ``self.data_normalizers``.
    * ``normalizer_stats.pt`` - dict keyed by network name, written with
      ``torch.save``.

    Nothing is written when no logger is attached to the module, because
    there is then no artifact location to write to.
    """
    if self.logger is None:
        # Without a logger there is no save_dir / run to attach artifacts to;
        # skip instead of failing with an AttributeError on ``None``.
        return

    # Collect the stats once so the text dump and the .pt file are built from
    # the same objects (get_stats() was previously called twice per normalizer).
    networks = self.args.data.networks
    named_stats = [
        (networks[i], normalizer.get_stats())
        for i, normalizer in enumerate(self.data_normalizers)
    ]

    with artifact_context(self.logger, "stats") as log_dir:
        # Human-readable log
        log_stats_path = os.path.join(log_dir, "normalization_stats.txt")
        with open(log_stats_path, "w") as log_file:
            for network, stats in named_stats:
                log_file.write(
                    f"Data Normalizer {network} stats:\n{stats}\n\n",
                )

        # Machine-loadable stats (single file, keyed by network name)
        torch.save(dict(named_stats), os.path.join(log_dir, "normalizer_stats.pt"))
sns -from lightning.pytorch.loggers import MLFlowLogger +from gridfm_graphkit.utils.mlflow_artifact_utils import artifact_write_ctx import numpy as np import os import pandas as pd @@ -321,15 +321,7 @@ def on_test_end(self): self.test_outputs.clear() return - if isinstance(self.logger, MLFlowLogger): - artifact_dir = os.path.join( - self.logger.save_dir, - self.logger.experiment_id, - self.logger.run_id, - "artifacts", - ) - else: - artifact_dir = self.logger.save_dir + artifact_dir, _upload = artifact_write_ctx(self.logger) final_metrics = self.trainer.callback_metrics grouped_metrics = {} @@ -508,6 +500,7 @@ def on_test_end(self): ), ) + _upload() self.test_outputs.clear() def predict_step(self, batch, batch_idx, dataloader_idx=0): diff --git a/gridfm_graphkit/tasks/pf_task.py b/gridfm_graphkit/tasks/pf_task.py index 2c2478ee..537aaf55 100644 --- a/gridfm_graphkit/tasks/pf_task.py +++ b/gridfm_graphkit/tasks/pf_task.py @@ -29,7 +29,7 @@ ComputeNodeInjection, ComputeNodeResiduals, ) -from lightning.pytorch.loggers import MLFlowLogger +from gridfm_graphkit.utils.mlflow_artifact_utils import artifact_write_ctx import os import pandas as pd @@ -244,15 +244,7 @@ def on_test_end(self): self.test_outputs.clear() # clear the test outputs for other ranks return - if isinstance(self.logger, MLFlowLogger): - artifact_dir = os.path.join( - self.logger.save_dir, - self.logger.experiment_id, - self.logger.run_id, - "artifacts", - ) - else: - artifact_dir = self.logger.save_dir + artifact_dir, _upload = artifact_write_ctx(self.logger) final_metrics = self.trainer.callback_metrics grouped_metrics = {} @@ -349,6 +341,7 @@ def on_test_end(self): prefix=dataset_name, ) + _upload() self.test_outputs.clear() def predict_step(self, batch, batch_idx, dataloader_idx=0): diff --git a/gridfm_graphkit/tasks/reconstruction_tasks.py b/gridfm_graphkit/tasks/reconstruction_tasks.py index 45975aee..c29148c3 100644 --- a/gridfm_graphkit/tasks/reconstruction_tasks.py +++ 
b/gridfm_graphkit/tasks/reconstruction_tasks.py @@ -62,6 +62,12 @@ def shared_step(self, batch): return output, loss_dict def training_step(self, batch): + import time + + # Measure step execution time (forward + loss). _batch_start_time is + # set by BaseTask.on_train_batch_start just before this call. + step_start = getattr(self, "_batch_start_time", None) + _, loss_dict = self.shared_step(batch) current_lr = self.optimizer.param_groups[0]["lr"] metrics = {} @@ -79,6 +85,32 @@ def training_step(self, batch): on_step=True, ) + # Per-step throughput metrics – logged every step for fine-grained + # profiling; appear as time-series curves in MLflow. + if step_start is not None: + step_elapsed = time.perf_counter() - step_start + if step_elapsed > 0: + self.log( + "perf/step_time_ms", + step_elapsed * 1000.0, + batch_size=batch.num_graphs, + sync_dist=False, + on_epoch=False, + on_step=True, + prog_bar=False, + logger=True, + ) + self.log( + "perf/step_it_s", + 1.0 / step_elapsed, + batch_size=batch.num_graphs, + sync_dist=False, + on_epoch=False, + on_step=True, + prog_bar=False, + logger=True, + ) + return loss_dict["loss"] def validation_step(self, batch, batch_idx): diff --git a/gridfm_graphkit/tasks/se_task.py b/gridfm_graphkit/tasks/se_task.py index 36667ad2..b7e2f00a 100644 --- a/gridfm_graphkit/tasks/se_task.py +++ b/gridfm_graphkit/tasks/se_task.py @@ -20,7 +20,7 @@ from pytorch_lightning.utilities import rank_zero_only import torch from torch_scatter import scatter_add -from lightning.pytorch.loggers import MLFlowLogger +from gridfm_graphkit.utils.mlflow_artifact_utils import artifact_write_ctx import os @@ -121,15 +121,7 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): @rank_zero_only def on_test_end(self): - if isinstance(self.logger, MLFlowLogger): - artifact_dir = os.path.join( - self.logger.save_dir, - self.logger.experiment_id, - self.logger.run_id, - "artifacts", - ) - else: - artifact_dir = self.logger.save_dir + artifact_dir, _upload = 
artifact_write_ctx(self.logger) if self.args.verbose: for dataset_idx, outputs in self.test_outputs.items(): @@ -181,6 +173,7 @@ def on_test_end(self): ylabel="Measured", ) + _upload() self.test_outputs.clear() def predict_step(self, batch, batch_idx, dataloader_idx=0): diff --git a/gridfm_graphkit/training/callbacks.py b/gridfm_graphkit/training/callbacks.py index ba7a4049..06da20a1 100644 --- a/gridfm_graphkit/training/callbacks.py +++ b/gridfm_graphkit/training/callbacks.py @@ -1,19 +1,28 @@ from lightning.pytorch.callbacks import Callback from pytorch_lightning.utilities.rank_zero import rank_zero_only -from lightning.pytorch.loggers import MLFlowLogger import os import time import torch +from gridfm_graphkit.utils.mlflow_artifact_utils import artifact_context class EpochTimerCallback(Callback): - """Records wall-clock duration and iteration rate of every training epoch.""" + """Records wall-clock duration and iteration rate of every training epoch. + + Logs the following metrics to the Lightning logger (e.g. 
MLflow) so that + throughput and timing are tracked as time-series over the full training run: + + * ``perf/train_epoch_time_s`` – wall-clock seconds for the epoch + * ``perf/train_it_s`` – training batches per second + * ``perf/val_epoch_time_s`` – wall-clock seconds for the validation loop + """ def __init__(self): self.epoch_times: list[float] = [] self._epoch_start: float | None = None self._batch_count: int = 0 self._last_batch_count: int = 0 + self._val_start: float | None = None def on_train_epoch_start(self, trainer, pl_module): self._epoch_start = time.perf_counter() @@ -24,10 +33,49 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): def on_train_epoch_end(self, trainer, pl_module): if self._epoch_start is not None: - self.epoch_times.append(time.perf_counter() - self._epoch_start) + elapsed = time.perf_counter() - self._epoch_start + self.epoch_times.append(elapsed) self._last_batch_count = self._batch_count self._epoch_start = None + # Log epoch timing + throughput to MLflow (one point per epoch) + pl_module.log( + "perf/train_epoch_time_s", + elapsed, + on_epoch=True, + on_step=False, + logger=True, + prog_bar=False, + sync_dist=False, + ) + if elapsed > 0 and self._batch_count > 0: + pl_module.log( + "perf/train_it_s", + self._batch_count / elapsed, + on_epoch=True, + on_step=False, + logger=True, + prog_bar=False, + sync_dist=False, + ) + + def on_validation_epoch_start(self, trainer, pl_module): + self._val_start = time.perf_counter() + + def on_validation_epoch_end(self, trainer, pl_module): + if self._val_start is not None: + val_elapsed = time.perf_counter() - self._val_start + self._val_start = None + pl_module.log( + "perf/val_epoch_time_s", + val_elapsed, + on_epoch=True, + on_step=False, + logger=True, + prog_bar=False, + sync_dist=False, + ) + @property def last_epoch_time(self) -> float | None: return self.epoch_times[-1] if self.epoch_times else None @@ -74,21 +122,9 @@ def on_validation_end(self, trainer, 
pl_module): ): self.best_score = current - # Determine artifact directory + # Save the model's state_dict. + # artifact_context handles both local and remote MLflow servers. logger = trainer.logger - if isinstance(logger, MLFlowLogger): - model_dir = os.path.join( - logger.save_dir, - logger.experiment_id, - logger.run_id, - "artifacts", - "model", - ) - else: - model_dir = os.path.join(logger.save_dir, "model") - - os.makedirs(model_dir, exist_ok=True) - - # Save the model's state_dict - model_path = os.path.join(model_dir, self.filename) - torch.save(self._canonical_state_dict(pl_module), model_path) + with artifact_context(logger, "model") as model_dir: + model_path = os.path.join(model_dir, self.filename) + torch.save(self._canonical_state_dict(pl_module), model_path) diff --git a/gridfm_graphkit/utils/mlflow_artifact_utils.py b/gridfm_graphkit/utils/mlflow_artifact_utils.py new file mode 100644 index 00000000..cb4362a9 --- /dev/null +++ b/gridfm_graphkit/utils/mlflow_artifact_utils.py @@ -0,0 +1,102 @@ +""" +Utilities for writing MLflow artifacts that work with both local and remote +MLflow tracking servers. + +Two helpers are provided: + +``artifact_context`` – context manager:: + + with artifact_context(logger, "stats") as local_dir: + torch.save(my_data, os.path.join(local_dir, "data.pt")) + +``artifact_write_ctx`` – imperative style (no indentation change required):: + + artifact_dir, _upload = artifact_write_ctx(logger) + # … write files into artifact_dir … + _upload() # uploads to MLflow; no-op for other loggers + +For MLflow loggers both helpers write to a temporary directory then call +``MlflowClient.log_artifacts``, which works for **local and remote** tracking +servers alike. For other loggers files are written directly under +``logger.save_dir``. 
@contextmanager
def artifact_context(logger, artifact_subpath: str = ""):
    """Yield a local directory whose contents become run artifacts on exit.

    * **MLflow logger** - a temporary directory is yielded; when the ``with``
      body finishes without raising, its contents are uploaded through
      ``logger.experiment.log_artifacts`` under ``artifact_subpath`` (or the
      artifact root when the sub-path is empty). The temporary directory is
      removed afterwards in every case. If the body raises, nothing is
      uploaded and the exception propagates.
    * **Other loggers** - the yielded directory is
      ``os.path.join(logger.save_dir, artifact_subpath)`` (or
      ``logger.save_dir`` for an empty sub-path), created if missing; files
      are written in place and nothing happens on exit.

    Parameters
    ----------
    logger:
        The Lightning logger attached to the trainer.
    artifact_subpath:
        Sub-directory within the artifact store to place the files under
        (e.g. ``"stats"``, ``"model"``). An empty string targets the root.
    """
    if not isinstance(logger, MLFlowLogger):
        target_dir = (
            os.path.join(logger.save_dir, artifact_subpath)
            if artifact_subpath
            else logger.save_dir
        )
        os.makedirs(target_dir, exist_ok=True)
        yield target_dir
        return

    with tempfile.TemporaryDirectory() as staging_dir:
        yield staging_dir
        # Only reached when the caller's body completed without raising.
        logger.experiment.log_artifacts(
            logger.run_id,
            staging_dir,
            artifact_subpath if artifact_subpath else None,
        )


def artifact_write_ctx(logger) -> Tuple[str, Callable[[], None]]:
    """Imperative alternative to :func:`artifact_context`.

    Returns a ``(local_dir, upload_fn)`` tuple. Write all artifacts into
    ``local_dir`` (creating sub-directories as needed), then call
    ``upload_fn()`` to commit them.

    * **MLflow logger** - ``local_dir`` is a fresh temporary directory;
      ``upload_fn`` uploads its contents to the artifact root via
      ``logger.experiment.log_artifacts`` and then removes the directory.
      The directory is removed even when the upload raises, so a failed
      upload no longer leaves a stray temp directory behind. The directory
      is *not* removed if ``upload_fn`` is never called, so callers should
      invoke it on every exit path.
    * **Other loggers** - ``local_dir`` is ``logger.save_dir`` (created if
      missing); ``upload_fn`` is a no-op.

    Parameters
    ----------
    logger:
        The Lightning logger attached to the trainer.
    """
    if not isinstance(logger, MLFlowLogger):
        local_dir = logger.save_dir
        os.makedirs(local_dir, exist_ok=True)
        return local_dir, lambda: None

    staging_dir = tempfile.mkdtemp()

    def _upload() -> None:
        try:
            logger.experiment.log_artifacts(logger.run_id, staging_dir)
        finally:
            # Clean up regardless of the upload outcome; the upload error,
            # if any, still propagates to the caller.
            shutil.rmtree(staging_dir, ignore_errors=True)

    return staging_dir, _upload
def download_dataset(
    gdrive_file_id: str = "1in6tbkV4VTy3zQ5HJFvG40EOXAUg7UL9",
    zip_path: str = "integrationtests/case14_ieee_dataset.zip",
    extract_dir: str = ".",
) -> None:
    """Download the pre-built dataset from Google Drive and extract it.

    Parameters
    ----------
    gdrive_file_id:
        Google Drive file id of the dataset archive.
    zip_path:
        Local path the archive is downloaded to.
    extract_dir:
        Directory the archive is extracted into (defaults to the current
        working directory, i.e. the repo root when tests run from there).

    Raises
    ------
    RuntimeError
        If the download did not produce a file, or the downloaded file is
        not a zip archive.
    """
    os.makedirs(os.path.dirname(zip_path) or ".", exist_ok=True)

    print(f"Downloading dataset (file id={gdrive_file_id}) from Google Drive...")
    downloaded = gdown.download(id=gdrive_file_id, output=zip_path, quiet=False)
    # Fail here with a clear message rather than later with an opaque
    # BadZipFile / FileNotFoundError from the extraction step.
    if downloaded is None or not os.path.isfile(zip_path):
        raise RuntimeError(
            f"Download of Google Drive file {gdrive_file_id} failed; "
            f"{zip_path} was not created.",
        )
    if not zipfile.is_zipfile(zip_path):
        raise RuntimeError(
            f"{zip_path} is not a valid zip archive; the download may have "
            "returned an error page instead of the dataset.",
        )

    print(f"Extracting {zip_path}...")
    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall(extract_dir)

    print("Dataset ready.")
""" config_path = "examples/config/HGNS_PF_datakit_case14.yaml" @@ -60,7 +51,7 @@ def prepare_training_config(): with open(config_path, "w") as f: yaml.dump(config, f, default_flow_style=False, sort_keys=False) - print(f"Training config updated: epochs set to {config['training']['epochs']}") + print(f"Training config updated: epochs={config['training']['epochs']}") return config_path @@ -82,15 +73,15 @@ def cleanup_test_artifacts(): if os.path.exists(backup_config): shutil.move(backup_config, training_config) - # Remove downloaded config - config_file = "integrationtests/default.yaml" - if os.path.exists(config_file): - os.remove(config_file) + # Remove downloaded zip + zip_file = "integrationtests/case14_ieee_dataset.zip" + if os.path.exists(zip_file): + os.remove(zip_file) # Remove generated directories - for d in ["data_out", "logs"]: - if os.path.exists(d): - shutil.rmtree(d, ignore_errors=True) + #for d in ["data_out", "logs"]: + # if os.path.exists(d): + # shutil.rmtree(d, ignore_errors=True) def test_train(cleanup_test_artifacts): @@ -106,13 +97,10 @@ def test_train(cleanup_test_artifacts): data_dir = "data_out" if not os.path.exists(data_dir) or not os.listdir(data_dir): - print("Data directory not found or empty, generating data...") - - config_path = prepare_config() - - execute_and_live_output(f"gridfm_datakit generate {config_path}") + print("Data directory not found or empty, downloading pre-built dataset...") + download_dataset() else: - print(f"Data directory '{data_dir}' already exists, skipping generation.") + print(f"Data directory '{data_dir}' already exists, skipping download.") training_config_path = prepare_training_config() diff --git a/pyproject.toml b/pyproject.toml index 2b6c523c..f9739c82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,7 +68,8 @@ dev = [ test = [ "pytest", - "pytest-cov" + "pytest-cov", + "gdown>=5.2.0" ] [project.scripts]