Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
5258f59
util
naomi-simumba Mar 6, 2026
58e7acd
util
naomi-simumba Mar 6, 2026
9c12f56
configurable optimizer
naomi-simumba Mar 6, 2026
00031c3
new losses
naomi-simumba Mar 6, 2026
65009df
baseline models
naomi-simumba Mar 6, 2026
4d0da4c
style
naomi-simumba Mar 6, 2026
8e46608
new globals
naomi-simumba Mar 7, 2026
5e3955c
update sgared step
naomi-simumba Mar 11, 2026
fbe838a
rename loss
naomi-simumba Mar 13, 2026
86b4afd
cleaned branch to keep only the following:
albanpuech Apr 2, 2026
b138800
changed Qg loss to compute it on PV + REF buses, not only buses with …
albanpuech Apr 2, 2026
520ea22
Merge remote-tracking branch 'origin/main' into opf_integration
albanpuech Apr 7, 2026
ee69279
fix dataset to make load_scenarios.pt optional (when we use opfdata w…
albanpuech Apr 9, 2026
e595872
fix loading of weight of compiled model
albanpuech Apr 9, 2026
0921939
change back to Qg penalization computed only on buses with violations
albanpuech Apr 9, 2026
c23b77a
add config files for Ola
albanpuech Apr 9, 2026
78a0890
clear tes outputs for other ranks after collecting their output + add…
albanpuech Apr 15, 2026
959435d
added predict and dc baseline for opf
albanpuech Apr 15, 2026
4c85499
added todo about gen and bus idx
albanpuech Apr 15, 2026
215b5a1
fixed comment in normalizer
albanpuech Apr 15, 2026
5b5dd38
added comment when creating training dataloader
albanpuech Apr 15, 2026
3f460db
fixed cli to handle predictions for opf
albanpuech Apr 15, 2026
a321733
fixed config files
albanpuech Apr 16, 2026
8b85319
added compute_dc_ac_metrics to train parser
albanpuech Apr 17, 2026
d17511f
added mps compatibility moving stuff to cpu when using torch_scatter
albanpuech Apr 17, 2026
4d67747
fix loc warning and column drop
albanpuech Apr 17, 2026
c25ee09
add transfer_batch_to_device to cast to float 32 for mps support
albanpuech Apr 17, 2026
ff6d165
reformat and cache integration test
albanpuech Apr 20, 2026
06f7c0f
add description of metrics for OPF
albanpuech Apr 20, 2026
623cc09
added compute_dc_ac_metrics for finetuning
albanpuech Apr 20, 2026
95e124c
removed old config files
albanpuech Apr 20, 2026
9c65f95
delete outdated docs
albanpuech Apr 20, 2026
92cf0de
update yaml config docs
albanpuech Apr 20, 2026
a5aacba
load_scenarios is now only accessed when data.split_by_load_scenario_…
albanpuech Apr 20, 2026
b7fad4f
Merge remote-tracking branch 'origin/main' into opf_integration_with_…
albanpuech Apr 20, 2026
3020f4a
add docstring
albanpuech Apr 21, 2026
c49b5fc
Revert "reformat and cache integration test"
albanpuech Apr 21, 2026
b016e7e
added lsf support
albanpuech Apr 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 93 additions & 39 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ gridfm_graphkit <command> [OPTIONS]
Available commands:

* `train` - Train a new model from scratch
* `finetune` Fine-tune an existing pre-trained model
* `evaluate` Evaluate model performance on a dataset
* `predict` Run inference and save predictions
* `finetune` - Fine-tune an existing pre-trained model
* `evaluate` - Evaluate model performance on a dataset
* `predict` - Run inference and save predictions

---

Expand All @@ -74,13 +74,22 @@ gridfm_graphkit train --config path/to/config.yaml

### Arguments

| Argument | Type | Description | Default |
| ---------------- | ------ | ---------------------------------------------------------------- | ------- |
| `--config` | `str` | **Required**. Path to the training configuration YAML file. | `None` |
| `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
| `--run_name` | `str` | MLflow run name. | `run` |
| `--log_dir` | `str` | MLflow tracking/logging directory. | `mlruns` |
| `--data_path` | `str` | Root dataset directory. | `data` |
| Argument | Type | Description | Default |
| -------- | ---- | ----------- | ------- |
| `--config` | `str` | **Required**. Path to the training configuration YAML file. | `None` |
| `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
| `--run_name` | `str` | MLflow run name. | `run` |
| `--log_dir` | `str` | MLflow tracking/logging directory. | `mlruns` |
| `--data_path` | `str` | Root dataset directory. | `data` |
| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` |
| `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` |
| `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` |
| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` |
| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` |
| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` |
| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` |
| `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` |

### Examples

Expand All @@ -100,14 +109,23 @@ gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model

### Arguments

| Argument | Type | Description | Default |
| -------------- | ----- | ----------------------------------------------- | --------- |
| `--config` | `str` | **Required**. Fine-tuning configuration file. | `None` |
| `--model_path` | `str` | **Required**. Path to a pre-trained model state dict. | `None` |
| `--exp_name` | `str` | MLflow experiment name. | timestamp |
| `--run_name` | `str` | MLflow run name. | `run` |
| `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
| `--data_path` | `str` | Root dataset directory. | `data` |
| Argument | Type | Description | Default |
| -------- | ---- | ----------- | ------- |
| `--config` | `str` | **Required**. Fine-tuning configuration file. | `None` |
| `--model_path` | `str` | **Required**. Path to a pre-trained model state dict. | `None` |
| `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
| `--run_name` | `str` | MLflow run name. | `run` |
| `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
| `--data_path` | `str` | Root dataset directory. | `data` |
| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` |
| `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` |
| `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` |
| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` |
| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` |
| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` |
| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` |
| `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` |


---
Expand All @@ -120,17 +138,25 @@ gridfm_graphkit evaluate --config path/to/eval.yaml --model_path path/to/model.p

### Arguments

| Argument | Type | Description | Default |
| --------------------- | ----- | ------------------------------------------------------------------------------------------------------------- | --------- |
| `--config` | `str` | **Required**. Path to evaluation config. | `None` |
| `--model_path` | `str` | Path to the trained model state dict. | `None` |
| `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics instead of re-fitting on the current data split. | `None` |
| `--exp_name` | `str` | MLflow experiment name. | timestamp |
| `--run_name` | `str` | MLflow run name. | `run` |
| `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
| `--data_path` | `str` | Dataset directory. | `data` |
| `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` |
| `--save_output` | `flag` | Save predictions as `<grid_name>_predictions.parquet` under MLflow artifacts (`.../artifacts/test`). | `False` |
| Argument | Type | Description | Default |
| -------- | ---- | ----------- | ------- |
| `--config` | `str` | **Required**. Path to evaluation config. | `None` |
| `--model_path` | `str` | Path to the trained model state dict. | `None` |
| `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics instead of re-fitting on current split. | `None` |
| `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
| `--run_name` | `str` | MLflow run name. | `run` |
| `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
| `--data_path` | `str` | Dataset directory. | `data` |
| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` |
| `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` |
| `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` |
| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` |
| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` |
| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` |
| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` |
| `--compute_dc_ac_metrics` | `flag` | Compute ground-truth AC/DC power balance metrics on the test split. | `False` |
| `--save_output` | `flag` | Save predictions as `<grid_name>_predictions.parquet` under MLflow artifacts (`.../artifacts/test`). | `False` |

### Example with saved normalizer stats

Expand All @@ -156,16 +182,44 @@ gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model.

### Arguments

| Argument | Type | Description | Default |
| --------------------- | ----- | ------------------------------------------------------------------------------------------------------------- | --------- |
| `--config` | `str` | **Required**. Path to prediction config file. | `None` |
| `--model_path` | `str` | Path to the trained model state dict. | `None` |
| `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics. | `None` |
| `--exp_name` | `str` | MLflow experiment name. | timestamp |
| `--run_name` | `str` | MLflow run name. | `run` |
| `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
| `--data_path` | `str` | Dataset directory. | `data` |
| `--output_path` | `str` | Directory where predictions are saved as `<grid_name>_predictions.parquet`. | `data` |
| Argument | Type | Description | Default |
| -------- | ---- | ----------- | ------- |
| `--config` | `str` | **Required**. Path to prediction config file. | `None` |
| `--model_path` | `str` | Path to trained model state dict. Optional; may be defined in config. | `None` |
| `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics. | `None` |
| `--exp_name` | `str` | MLflow experiment name. | `timestamp` |
| `--run_name` | `str` | MLflow run name. | `run` |
| `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
| `--data_path` | `str` | Dataset directory. | `data` |
| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | `[]` |
| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | `None` |
| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | `None` |
| `--output_path` | `str` | Directory where predictions are saved as `<grid_name>_predictions.parquet`. | `data` |
| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | `None` |
| `--bfloat16` | `flag` | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | `False` |
| `--tf32` | `flag` | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | `False` |
| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | `None` |

---

## Benchmarking Dataloader Throughput

```bash
gridfm_graphkit benchmark --config path/to/config.yaml
```

### Arguments

| Argument | Type | Description | Default |
| -------- | ---- | ----------- | ------- |
| `--config` | `str` | **Required**. Path to configuration YAML file. | `None` |
| `--data_path` | `str` | Root dataset directory. | `data` |
| `--epochs` | `int` | Number of epochs to iterate through the train dataloader. | `3` |
| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | `None` |
| `--dataset_wrapper_cache_dir` | `str` | Directory for dataset wrapper disk cache. | `None` |
| `--num_workers` | `int` | Override `data.workers` from YAML. | `None` |
| `--plugins` | `list[str]` | Python packages to import for plugin registration. | `[]` |

Use built-in help for full command details:

Expand Down
3 changes: 0 additions & 3 deletions docs/datasets/data_modules.md

This file was deleted.

57 changes: 0 additions & 57 deletions docs/datasets/data_normalization.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/datasets/powergrid.md

This file was deleted.

19 changes: 0 additions & 19 deletions docs/datasets/transforms.md

This file was deleted.

42 changes: 14 additions & 28 deletions docs/install/installation.md
Original file line number Diff line number Diff line change
@@ -1,46 +1,32 @@
# Installation

You can install `gridfm-graphkit` directly from PyPI:
The steps below mirror the [README](https://github.com/gridfm/gridfm-graphkit/blob/main/README.md#installation). Run them from the root of a local clone or source checkout of the repository.

```bash
pip install gridfm-graphkit
```

For GPU support and compatibility with PyTorch Geometric's scatter operations, install PyTorch (and optionally CUDA) first, then install the matching `torch-scatter` wheel. See [PyTorch and torch-scatter](#pytorch-and-torch-scatter-optional) below.

---

## Development Setup

To contribute or develop locally, clone the repository and install in editable mode. Use Python 3.10, 3.11, or 3.12 (3.12 is recommended).
Create and activate a virtual environment (make sure you use the right python version = 3.10, 3.11 or 3.12. I highly recommend 3.12)

```bash
git clone git@github.com:gridfm/gridfm-graphkit.git
cd gridfm-graphkit
python -m venv venv
source venv/bin/activate
pip install -e .
```

### PyTorch and torch-scatter (optional)
Install gridfm-graphkit in editable mode

If you need GPU acceleration or PyTorch Geometric scatter ops (used by the library), install PyTorch and the matching `torch-scatter` wheel:
```bash
pip install -e .
```

1. Install PyTorch (see [pytorch.org](https://pytorch.org/) for your platform and CUDA version).
Get PyTorch + CUDA version for torch-scatter

2. Get your Torch + CUDA version string:
```bash
TORCH_CUDA_VERSION=$(python -c "import torch; print(torch.__version__ + ('+cpu' if torch.version.cuda is None else ''))")
```
```bash
TORCH_CUDA_VERSION=$(python -c "import torch; print(torch.__version__ + ('+cpu' if torch.version.cuda is None else ''))")
```

3. Install the correct `torch-scatter` wheel:
```bash
pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH_CUDA_VERSION}.html
```
Install the correct torch-scatter wheel

---
```bash
pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH_CUDA_VERSION}.html
```

## Optional extras

For documentation generation and unit testing, install with the optional `dev` and `test` extras:

Expand Down
37 changes: 0 additions & 37 deletions docs/models/models.md

This file was deleted.

Loading
Loading