GridFM

Train, finetune and interact with a foundation model for the electric power grid.
https://github.com/gridfm/gridfm-graphkit

Category: Energy Systems
Sub Category: Grid Analysis and Planning

Last synced: about 20 hours ago
Repository metadata

README.md

This library is brought to you by the GridFM team to train, finetune and interact with a foundation model for the electric power grid.


Installation

Create and activate a virtual environment. Use Python 3.10, 3.11 or 3.12; 3.12 is recommended.

python -m venv venv
source venv/bin/activate

Install gridfm-graphkit in editable mode

pip install -e .

torch-scatter is a required dependency. It cannot be bundled in pyproject.toml because the correct wheel depends on your PyTorch and CUDA versions, so it must be installed separately.

Get PyTorch + CUDA version for torch-scatter

TORCH_CUDA_VERSION=$(python -c "import torch; print(torch.__version__ + ('+cpu' if torch.version.cuda is None else ''))")

Install the correct torch-scatter wheel

pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH_CUDA_VERSION}.html
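The two shell steps above can also be pictured in Python. The following is a minimal sketch, not part of gridfm-graphkit: `torch_scatter_index_url` is a hypothetical helper name, and the check for an existing `+` suffix is an addition in this sketch (the shell one-liner does not perform it) so that a build already tagged `+cpu` is not tagged twice.

```python
from typing import Optional


def torch_scatter_index_url(torch_version: str, cuda_version: Optional[str]) -> str:
    """Build the PyG wheel index URL for a given PyTorch build (hypothetical helper)."""
    # CPU-only builds report no CUDA version; like the shell one-liner, tag them "+cpu".
    # The '"+" not in' guard is an assumption added here to avoid "+cpu+cpu".
    if cuda_version is None and "+" not in torch_version:
        torch_version += "+cpu"
    return f"https://data.pyg.org/whl/torch-{torch_version}.html"


# Example inputs; in a real environment these would come from
# torch.__version__ and torch.version.cuda.
print(torch_scatter_index_url("2.7.1", None))
print(torch_scatter_index_url("2.7.1+cu126", "12.6"))
```

The resulting URL is what `pip install torch-scatter -f <url>` expects as its wheel index.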

For documentation generation and unit testing, install with the optional dev and test extras:

pip install -e ".[dev,test]"

CLI commands

Interface to train, fine-tune, evaluate, and run inference on GridFM models using YAML configs and MLflow tracking.

gridfm_graphkit <command> [OPTIONS]

Available commands:

  • train - Train a new model from scratch
  • finetune - Fine-tune an existing pre-trained model
  • evaluate - Evaluate model performance on a dataset
  • predict - Run inference and save predictions
  • benchmark - Benchmark dataloader throughput

Training Models

gridfm_graphkit train --config path/to/config.yaml

Arguments

| Argument | Type | Description | Default |
|---|---|---|---|
| --config | str | Required. Path to the training configuration YAML file. | None |
| --exp_name | str | MLflow experiment name. | timestamp |
| --run_name | str | MLflow run name. | run |
| --log_dir | str | MLflow tracking/logging directory. | mlruns |
| --data_path | str | Root dataset directory. | data |
| --compile [MODE] | str | Enable torch.compile mode. Valid values: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs. If the flag is passed without a value, the mode is default. | None |
| --bfloat16 | flag | Cast model to torch.bfloat16 (model.to(torch.bfloat16)). | False |
| --tf32 | flag | Enable TF32 on Ampere+ GPUs via torch.set_float32_matmul_precision("high"). | False |
| --dataset_wrapper | str | Registered dataset wrapper name (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset. | None |
| --plugins | list[str] | Python packages to import for plugin registration, e.g. gridfm_graphkit_ee. | [] |
| --num_workers | int | Override data.workers from YAML. Use 0 to debug worker crashes. | None |
| --dataset_wrapper_cache_dir | str | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | None |
| --profiler | str | Enable Lightning profiler (simple, advanced, pytorch). | None |
| --compute_dc_ac_metrics | flag | Compute ground-truth AC/DC power balance metrics on the test split. | False |
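
The three states of `--compile` (absent, passed bare, passed with a mode) and the boolean flags listed above can be illustrated with a small `argparse` sketch. This is an assumption about how such options could be declared, not gridfm-graphkit's actual parser; `build_parser` is a hypothetical name and only a few of the documented arguments are included.

```python
import argparse

COMPILE_MODES = ["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"]


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring a subset of the documented train arguments."""
    parser = argparse.ArgumentParser(prog="gridfm_graphkit train")
    parser.add_argument("--config", required=True)
    # nargs="?" makes the value optional:
    #   flag absent -> default (None), bare flag -> const ("default").
    parser.add_argument("--compile", nargs="?", const="default", default=None,
                        choices=COMPILE_MODES)
    # Plain flags: False unless present on the command line.
    parser.add_argument("--bfloat16", action="store_true")
    parser.add_argument("--tf32", action="store_true")
    # None means "keep data.workers from the YAML config".
    parser.add_argument("--num_workers", type=int, default=None)
    return parser


parser = build_parser()
for argv in (
    ["--config", "cfg.yaml"],
    ["--config", "cfg.yaml", "--compile"],
    ["--config", "cfg.yaml", "--compile", "max-autotune", "--tf32", "--num_workers", "0"],
):
    print(parser.parse_args(argv))
```

Keeping `None` as the default for `--compile` and `--num_workers` lets the caller distinguish "not specified" from an explicit value, which is what the Default column in the table suggests.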

Examples

Standard Training:

gridfm_graphkit train --config examples/config/case30_ieee_base.yaml --data_path examples/data

Fine-Tuning Models

gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model.pt

Arguments

| Argument | Type | Description | Default |
|---|---|---|---|
| --config | str | Required. Fine-tuning configuration file. | None |
| --model_path | str | Required. Path to a pre-trained model state dict. | None |
| --exp_name | str | MLflow experiment name. | timestamp |
| --run_name | str | MLflow run name. | run |
| --log_dir | str | MLflow logging directory. | mlruns |
| --data_path | str | Root dataset directory. | data |
| --compile [MODE] | str | Enable torch.compile mode. Valid values: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs. If the flag is passed without a value, the mode is default. | None |
| --bfloat16 | flag | Cast model to torch.bfloat16 (model.to(torch.bfloat16)). | False |
| --tf32 | flag | Enable TF32 on Ampere+ GPUs via torch.set_float32_matmul_precision("high"). | False |
| --dataset_wrapper | str | Registered dataset wrapper name (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset. | None |
| --plugins | list[str] | Python packages to import for plugin registration, e.g. gridfm_graphkit_ee. | [] |
| --num_workers | int | Override data.workers from YAML. Use 0 to debug worker crashes. | None |
| --dataset_wrapper_cache_dir | str | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | None |
| --profiler | str | Enable Lightning profiler (simple, advanced, pytorch). | None |
| --compute_dc_ac_metrics | flag | Compute ground-truth AC/DC power balance metrics on the test split. | False |

Evaluating Models

gridfm_graphkit evaluate --config path/to/eval.yaml --model_path path/to/model.pt

Arguments

| Argument | Type | Description | Default |
|---|---|---|---|
| --config | str | Required. Path to evaluation config. | None |
| --model_path | str | Path to the trained model state dict. | None |
| --normalizer_stats | str | Path to normalizer_stats.pt from a training run. Restores fit_on_train normalizers from saved statistics instead of re-fitting on the current split. | None |
| --exp_name | str | MLflow experiment name. | timestamp |
| --run_name | str | MLflow run name. | run |
| --log_dir | str | MLflow logging directory. | mlruns |
| --data_path | str | Dataset directory. | data |
| --compile [MODE] | str | Enable torch.compile mode. Valid values: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs. If the flag is passed without a value, the mode is default. | None |
| --bfloat16 | flag | Cast model to torch.bfloat16 (model.to(torch.bfloat16)). | False |
| --tf32 | flag | Enable TF32 on Ampere+ GPUs via torch.set_float32_matmul_precision("high"). | False |
| --dataset_wrapper | str | Registered dataset wrapper name (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset. | None |
| --plugins | list[str] | Python packages to import for plugin registration, e.g. gridfm_graphkit_ee. | [] |
| --num_workers | int | Override data.workers from YAML. Use 0 to debug worker crashes. | None |
| --dataset_wrapper_cache_dir | str | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | None |
| --profiler | str | Enable Lightning profiler (simple, advanced, pytorch). | None |
| --compute_dc_ac_metrics | flag | Compute ground-truth AC/DC power balance metrics on the test split. | False |
| --save_output | flag | Save predictions as <grid_name>_predictions.parquet under MLflow artifacts (.../artifacts/test). | False |

Example with saved normalizer stats

When evaluating a model on a dataset, you can pass the normalizer statistics from the original training run to ensure the same normalization parameters are used:

gridfm_graphkit evaluate \
  --config examples/config/HGNS_PF_datakit_case118.yaml \
  --model_path mlruns/<experiment_id>/<run_id>/artifacts/model/best_model_state_dict.pt \
  --normalizer_stats mlruns/<experiment_id>/<run_id>/artifacts/stats/normalizer_stats.pt \
  --data_path data

Note: The --normalizer_stats flag only affects normalizers with fit_strategy = "fit_on_train" (e.g. HeteroDataMVANormalizer). Per-sample normalizers (HeteroDataPerSampleMVANormalizer) always recompute their statistics from the current dataset regardless of this flag.
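
To see why restoring training statistics matters, consider the toy standardizer below. It is a minimal sketch under stated assumptions: `FitOnTrainScaler` is a hypothetical stand-in, not the library's HeteroDataMVANormalizer, and plain mean/standard-deviation scaling is assumed purely for illustration. The point is only that statistics fitted on the training data and later restored give a different (and training-consistent) transform than statistics re-fitted on the evaluation data.

```python
from statistics import mean, pstdev


class FitOnTrainScaler:
    """Hypothetical fit-on-train standardizer (illustration only)."""

    def __init__(self):
        self.mean = None
        self.std = None

    def fit(self, values):
        # Statistics are computed once, on the training split.
        self.mean = mean(values)
        self.std = pstdev(values) or 1.0  # avoid division by zero
        return self

    def state_dict(self):
        # What a training run would save for later reuse.
        return {"mean": self.mean, "std": self.std}

    def load_state_dict(self, stats):
        # Restore saved statistics instead of re-fitting.
        self.mean, self.std = stats["mean"], stats["std"]
        return self

    def transform(self, values):
        return [(v - self.mean) / self.std for v in values]


train_values = [10.0, 20.0, 30.0]     # made-up training data
eval_values = [100.0, 200.0, 300.0]   # made-up evaluation data

saved_stats = FitOnTrainScaler().fit(train_values).state_dict()
restored = FitOnTrainScaler().load_state_dict(saved_stats).transform(eval_values)
refit = FitOnTrainScaler().fit(eval_values).transform(eval_values)

print("restored from training stats:", restored)
print("re-fitted on eval data:      ", refit)
```

With restored statistics the model sees inputs on the scale it was trained on; re-fitting silently recenters the evaluation data.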


Running Predictions

gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model.pt

Arguments

| Argument | Type | Description | Default |
|---|---|---|---|
| --config | str | Required. Path to prediction config file. | None |
| --model_path | str | Path to trained model state dict. Optional; may be defined in config. | None |
| --normalizer_stats | str | Path to normalizer_stats.pt from a training run. Restores fit_on_train normalizers from saved statistics. | None |
| --exp_name | str | MLflow experiment name. | timestamp |
| --run_name | str | MLflow run name. | run |
| --log_dir | str | MLflow logging directory. | mlruns |
| --data_path | str | Dataset directory. | data |
| --dataset_wrapper | str | Registered dataset wrapper name (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset. | None |
| --plugins | list[str] | Python packages to import for plugin registration, e.g. gridfm_graphkit_ee. | [] |
| --num_workers | int | Override data.workers from YAML. Use 0 to debug worker crashes. | None |
| --dataset_wrapper_cache_dir | str | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | None |
| --output_path | str | Directory where predictions are saved as <grid_name>_predictions.parquet. | data |
| --compile [MODE] | str | Enable torch.compile mode. Valid values: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs. If the flag is passed without a value, the mode is default. | None |
| --bfloat16 | flag | Cast model to torch.bfloat16 (model.to(torch.bfloat16)). | False |
| --tf32 | flag | Enable TF32 on Ampere+ GPUs via torch.set_float32_matmul_precision("high"). | False |
| --profiler | str | Enable Lightning profiler (simple, advanced, pytorch). | None |
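
Since predictions are written as `<grid_name>_predictions.parquet` in the output directory, downstream scripts can locate them by that naming pattern. The sketch below uses only the standard library; `find_prediction_files` is a hypothetical helper and the grid names in the demo are made up. It does not read the parquet contents, whose columns are not described here.

```python
import tempfile
from pathlib import Path

SUFFIX = "_predictions.parquet"


def find_prediction_files(output_dir):
    """Map grid name -> file path for each <grid_name>_predictions.parquet found."""
    return {
        path.name[: -len(SUFFIX)]: path
        for path in sorted(Path(output_dir).glob(f"*{SUFFIX}"))
    }


# Demo with empty placeholder files standing in for real prediction output.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("gridA_predictions.parquet", "gridB_predictions.parquet", "notes.txt"):
        (Path(tmp) / name).touch()
    for grid, path in find_prediction_files(tmp).items():
        print(grid, "->", path.name)
```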

Benchmarking Dataloader Throughput

gridfm_graphkit benchmark --config path/to/config.yaml

Arguments

| Argument | Type | Description | Default |
|---|---|---|---|
| --config | str | Required. Path to configuration YAML file. | None |
| --data_path | str | Root dataset directory. | data |
| --epochs | int | Number of epochs to iterate through the train dataloader. | 3 |
| --dataset_wrapper | str | Registered dataset wrapper name (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset. | None |
| --dataset_wrapper_cache_dir | str | Directory for dataset wrapper disk cache. | None |
| --num_workers | int | Override data.workers from YAML. | None |
| --plugins | list[str] | Python packages to import for plugin registration. | [] |
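
Conceptually, a dataloader throughput benchmark iterates the loader for a number of epochs and divides the batch count by the elapsed time. The sketch below shows that idea with a plain Python list standing in for a dataloader; `benchmark_loader` is a hypothetical function and is not the implementation behind `gridfm_graphkit benchmark`.

```python
import time


def benchmark_loader(loader, epochs=3):
    """Iterate `loader` for `epochs` passes; return (batches_seen, batches_per_second)."""
    batches = 0
    start = time.perf_counter()
    for _ in range(epochs):
        for _batch in loader:
            batches += 1  # a real benchmark could also move the batch to the device
    elapsed = time.perf_counter() - start
    rate = (batches / elapsed) if elapsed > 0 else float("inf")
    return batches, rate


# Any re-iterable object works; here 100 fake "batches" of 32 items each.
fake_loader = [list(range(32)) for _ in range(100)]
batches, rate = benchmark_loader(fake_loader, epochs=3)
print(f"{batches} batches, {rate:.0f} batches/s")
```

Running several epochs, as the `--epochs` default of 3 suggests, helps separate first-pass costs such as cache population from steady-state throughput.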

Use built-in help for full command details:

gridfm_graphkit --help
gridfm_graphkit <command> --help

Running Tests

Unit and Integration Tests

Install the test dependencies first (if not already done):

pip install -e ".[dev,test]"

Run the full unit test suite:

pytest ./tests

Run the base set integration tests:

pytest ./integrationtests/test_base_set.py

Running Base Set Tests on an LSF Cluster (GPU)

To submit the base set integration tests as an interactive LSF job with GPU access, use bsub. Adjust the paths to match your environment:

bsub -gpu "num=1" \
     -n 16 \
     -R "rusage[mem=32GB] span[hosts=1]" \
     -Is \
     -J gridfm_base_set_tests \
     /bin/bash -c "
       cd /path/to/gridfm-graphkit && \
       export PATH=/path/to/cuda/bin:\$PATH \
               CUDA_HOME=/path/to/cuda \
               LD_LIBRARY_PATH=/path/to/cuda/lib64:\$LD_LIBRARY_PATH && \
       source /path/to/venv/bin/activate && \
       pytest ./integrationtests/test_base_set.py
     "

Key bsub options used above:

| Option | Description |
|---|---|
| -gpu "num=1" | Request 1 GPU |
| -n 16 | Request 16 CPU slots |
| -R "rusage[mem=32GB] span[hosts=1]" | Reserve 32 GB of memory on a single host |
| -Is | Run as an interactive shell session |
| -J <job_name> | Assign a name to the job |

Concrete example (adapt paths to your cluster setup):

bsub -gpu "num=1" -n 16 -R "rusage[mem=32GB] span[hosts=1]" -Is -J hpo_trial_190 /bin/bash -c "cd /dccstor/terratorch/users/rkie/gitco/gridfm-graphkit && export PATH=/opt/share/cuda-12.8.1/bin:\$PATH CUDA_HOME=/opt/share/cuda-12.8.1 LD_LIBRARY_PATH=/opt/share/cuda-12.8.1/lib64:\$LD_LIBRARY_PATH && source /u/rkie/venvs/venv_gridfm-graphkit/bin/activate && pytest ./integrationtests/test_base_set.py"


Committers metadata

Last synced: 3 days ago

Total Commits: 353
Total Committers: 17
Avg Commits per committer: 20.765
Development Distribution Score (DDS): 0.581

Commits in past year: 333
Committers in past year: 13
Avg Commits per committer in past year: 25.615
Development Distribution Score (DDS) in past year: 0.556

| Name | Email | Commits |
|---|---|---|
| Romeo Kienzler | r****1@i****m | 148 |
| Alban Puech | n****p@g****m | 57 |
| MatteoMazzonelli | m****1@i****m | 53 |
| Mangaliso Mngomomezulu | 6****M | 20 |
| Matteo Mazzonelli | M****i@i****m | 12 |
| Mangaliso Mngomezulu | m****u@M****l | 10 |
| Naomi Simumba | 7****a | 9 |
| François Mirallès | 4****r | 9 |
| Celia Cintas | c****s@i****m | 9 |
| Mangaliso Mngomezulu | m****u@m****m | 6 |
| Hector Maeso Garcia | h****a@i****m | 6 |
| Alban Puech | A****1@i****m | 6 |
| “Mangaliso-M” | “****u@g****” | 4 |
| Etienne Eben Vos | E****s@i****m | 1 |
| Jonas Weiss | J****E@z****m | 1 |
| Hector Maeso Garcia | H****a@i****m | 1 |
| cs5807 | m****s@i****a | 1 |
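
The all-time summary figures above can be re-derived from the listed counts. The sketch below assumes the Development Distribution Score is defined as one minus the top committer's share of all commits; that definition is an assumption here, though it does reproduce the listed all-time value. The past-year DDS cannot be checked the same way because per-committer past-year counts are not listed.

```python
# Counts taken from the committers metadata above.
total_commits = 353
total_committers = 17
top_committer_commits = 148  # largest single entry in the committer list

avg_commits = total_commits / total_committers
# Assumed definition: DDS = 1 - (top committer's commits / total commits)
dds = 1 - top_committer_commits / total_commits

past_year_avg = 333 / 13  # past-year commits / past-year committers

print(round(avg_commits, 3), round(dds, 3), round(past_year_avg, 3))
```

Rounded to three decimals, these match the listed 20.765, 0.581 and 25.615.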


Issue and Pull Request metadata

Last synced: 3 days ago

Total issues: 5
Total pull requests: 51
Average time to close issues: about 2 months
Average time to close pull requests: 9 days
Total issue authors: 4
Total pull request authors: 11
Average comments per issue: 0.2
Average comments per pull request: 0.29
Merged pull request: 32
Bot issues: 0
Bot pull requests: 0

Past year issues: 5
Past year pull requests: 51
Past year average time to close issues: about 2 months
Past year average time to close pull requests: 9 days
Past year issue authors: 4
Past year pull request authors: 11
Past year average comments per issue: 0.2
Past year average comments per pull request: 0.29
Past year merged pull request: 32
Past year bot issues: 0
Past year bot pull requests: 0

More stats: https://issues.ecosyste.ms/repositories/lookup?url=https://github.com/gridfm/gridfm-graphkit

Top Issue Authors

  • romeokienzler (2)
  • TavoIREQ (1)
  • janu000 (1)
  • mellson (1)

Top Pull Request Authors

  • romeokienzler (13)
  • albanpuech (11)
  • MatteoMazzonelli (9)
  • frmir (4)
  • mkisuule (3)
  • Mangaliso-M (3)
  • ttolhurst (3)
  • naomi-simumba (2)
  • celiacintas (1)
  • Tamaragov (1)
  • emmanuelbadmus (1)

Top Issue Labels

  • bug (1)

Top Pull Request Labels

  • documentation (6)
  • enhancement (4)

Package metadata

pypi.org: gridfm-graphkit

Grid Foundation Model

  • Documentation: https://gridfm-graphkit.readthedocs.io/
  • Licenses: Apache-2.0
  • Latest release: 0.0.6 (published 9 months ago)
  • Last Synced: 2026-06-10T11:08:49.905Z (3 days ago)
  • Versions: 7
  • Dependent Packages: 0
  • Dependent Repositories: 0
  • Downloads: 226 Last month
  • Rankings:
    • Dependent packages count: 8.944%
    • Average: 29.672%
    • Dependent repos count: 50.399%
  • Maintainers (2)

Dependencies

.github/workflows/ci-build.yaml actions
  • actions/checkout v4 composite
  • actions/setup-python v4 composite
.github/workflows/deploy_docs.yaml actions
  • actions/cache v4 composite
  • actions/checkout v4 composite
  • actions/setup-python v4 composite
pyproject.toml pypi
  • mlflow >=3.1.0
  • nbformat >=5.10.4
  • networkx >=3.4.2
  • numpy >=2.2.6
  • pandas >=2.3.0
  • plotly >=6.1.2
  • pyyaml >=6.0.2
  • torch >=2.7.1
  • torch-geometric >=2.6.1
  • torchaudio >=2.7.1
  • torchvision >=0.22.1
.github/workflows/release.yaml actions
  • actions/checkout v4 composite
  • actions/setup-python v4 composite
  • pypa/gh-action-pypi-publish release/v1 composite

Score: 12.805853755280662