GridFM
Train, finetune and interact with a foundation model for the electric power grid.
https://github.com/gridfm/gridfm-graphkit
Category: Energy Systems
Sub Category: Grid Analysis and Planning
Last synced: about 20 hours ago
Repository metadata
- Host: GitHub
- URL: https://github.com/gridfm/gridfm-graphkit
- Owner: gridfm
- License: apache-2.0
- Created: 2025-06-23T14:51:29.000Z (12 months ago)
- Default Branch: main
- Last Pushed: 2026-06-03T13:13:40.000Z (10 days ago)
- Last Synced: 2026-06-04T07:04:00.018Z (9 days ago)
- Language: Python
- Homepage: https://gridfm.github.io/gridfm-graphkit/
- Size: 144 MB
- Stars: 79
- Watchers: 1
- Forks: 20
- Open Issues: 15
- Releases: 7
Metadata Files:
- Readme: README.md
- Contributing: CONTRIBUTING.md
- License: LICENSE
- Code of conduct: CODE_OF_CONDUCT.md
- Support: SUPPORT.md
- Governance: GOVERNANCE.md
README.md
This library is brought to you by the GridFM team to train, finetune and interact with a foundation model for the electric power grid.
Installation
Create and activate a virtual environment. Use Python 3.10, 3.11, or 3.12; 3.12 is highly recommended.
python -m venv venv
source venv/bin/activate
Install gridfm-graphkit in editable mode
pip install -e .
torch-scatter is a required dependency. It cannot be bundled in pyproject.toml because the correct wheel depends on your PyTorch and CUDA versions, so it must be installed separately.
Get PyTorch + CUDA version for torch-scatter
TORCH_CUDA_VERSION=$(python -c "import torch; print(torch.__version__ + ('+cpu' if torch.version.cuda is None else ''))")
Install the correct torch-scatter wheel
pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH_CUDA_VERSION}.html
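The two steps above can also be expressed in Python, which makes the resulting URL easy to inspect before installing. The sketch below is an illustration rather than part of gridfm-graphkit: the helper name wheel_index_url is hypothetical, and it adds one guard the shell one-liner does not have, namely that +cpu is only appended when the version string carries no local suffix yet (some CPU builds of PyTorch already report one).

```python
from typing import Optional

PYG_WHEEL_INDEX = "https://data.pyg.org/whl/torch-{tag}.html"


def wheel_index_url(torch_version: str, cuda_version: Optional[str]) -> str:
    """Build the torch-scatter wheel index URL for a given PyTorch build.

    torch_version corresponds to torch.__version__ (e.g. "2.7.1+cu126"),
    cuda_version to torch.version.cuda (None for CPU-only builds).
    """
    tag = torch_version
    # Assumed guard: append "+cpu" only if there is no local suffix yet.
    if cuda_version is None and "+" not in tag:
        tag += "+cpu"
    return PYG_WHEEL_INDEX.format(tag=tag)


# With PyTorch installed, one would call:
#   import torch
#   print(wheel_index_url(torch.__version__, torch.version.cuda))
print(wheel_index_url("2.7.1", None))
print(wheel_index_url("2.7.1+cu126", "12.6"))
```

The printed URL is the value to pass to pip install torch-scatter -f.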
For documentation generation and unit testing, install with the optional dev and test extras:
pip install -e ".[dev,test]"
CLI commands
Interface to train, fine-tune, evaluate, and run inference on GridFM models using YAML configs and MLflow tracking.
gridfm_graphkit <command> [OPTIONS]
Available commands:
- train - Train a new model from scratch
- finetune - Fine-tune an existing pre-trained model
- evaluate - Evaluate model performance on a dataset
- predict - Run inference and save predictions
Training Models
gridfm_graphkit train --config path/to/config.yaml
Arguments
| Argument | Type | Description | Default |
|---|---|---|---|
| --config | str | Required. Path to the training configuration YAML file. | None |
| --exp_name | str | MLflow experiment name. | timestamp |
| --run_name | str | MLflow run name. | run |
| --log_dir | str | MLflow tracking/logging directory. | mlruns |
| --data_path | str | Root dataset directory. | data |
| --compile [MODE] | str | Enable torch.compile mode. Valid values: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs. If flag is passed without a value, mode is default. | None |
| --bfloat16 | flag | Cast model to torch.bfloat16 (model.to(torch.bfloat16)). | False |
| --tf32 | flag | Enable TF32 on Ampere+ GPUs via torch.set_float32_matmul_precision("high"). | False |
| --dataset_wrapper | str | Registered dataset wrapper name (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset. | None |
| --plugins | list[str] | Python packages to import for plugin registration, e.g. gridfm_graphkit_ee. | [] |
| --num_workers | int | Override data.workers from YAML. Use 0 to debug worker crashes. | None |
| --dataset_wrapper_cache_dir | str | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | None |
| --profiler | str | Enable Lightning profiler (simple, advanced, pytorch). | None |
| --compute_dc_ac_metrics | flag | Compute ground-truth AC/DC power balance metrics on the test split. | False |
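The --dataset_wrapper_cache_dir behaviour described above (load the cache when present, otherwise populate once and save) follows a common load-or-populate pattern. The following is a generic sketch of that pattern only; the function name load_or_populate, the pickle format, and the file name cache.pkl are assumptions for illustration and say nothing about how gridfm-graphkit stores its cache.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Callable, List


def load_or_populate(cache_dir: Path, populate: Callable[[], List[dict]]) -> List[dict]:
    """Return cached samples if a cache file exists, else build and save them."""
    cache_file = cache_dir / "cache.pkl"  # hypothetical file name
    if cache_file.exists():
        with cache_file.open("rb") as fh:
            return pickle.load(fh)
    samples = populate()  # expensive step, runs only when no cache exists
    cache_dir.mkdir(parents=True, exist_ok=True)
    with cache_file.open("wb") as fh:
        pickle.dump(samples, fh)
    return samples


calls = []


def expensive_populate() -> List[dict]:
    calls.append(1)  # count how often the data is actually rebuilt
    return [{"bus": i, "load_mw": 10.0 * i} for i in range(3)]


with tempfile.TemporaryDirectory() as tmp:
    first = load_or_populate(Path(tmp), expensive_populate)
    second = load_or_populate(Path(tmp), expensive_populate)  # read from disk
    print(len(calls), first == second)
```

The second call never invokes the populate function, which is the saving a disk cache is meant to provide across runs.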
Examples
Standard Training:
gridfm_graphkit train --config examples/config/case30_ieee_base.yaml --data_path examples/data
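When launching many runs (for example from a sweep script), it can help to assemble the command in Python. The sketch below only builds the argument list from flags documented in the arguments table; the helper name build_train_command is hypothetical, and nothing is executed unless the list is passed to subprocess.run.

```python
import shlex
from typing import List, Optional


def build_train_command(
    config: str,
    data_path: str = "data",
    exp_name: Optional[str] = None,
    run_name: Optional[str] = None,
    num_workers: Optional[int] = None,
    tf32: bool = False,
) -> List[str]:
    """Assemble a `gridfm_graphkit train` argv list from documented flags."""
    cmd = ["gridfm_graphkit", "train", "--config", config, "--data_path", data_path]
    if exp_name is not None:
        cmd += ["--exp_name", exp_name]
    if run_name is not None:
        cmd += ["--run_name", run_name]
    if num_workers is not None:
        cmd += ["--num_workers", str(num_workers)]
    if tf32:
        cmd.append("--tf32")
    return cmd


cmd = build_train_command(
    "examples/config/case30_ieee_base.yaml",
    data_path="examples/data",
    num_workers=0,  # 0 is the documented setting for debugging worker crashes
)
print(shlex.join(cmd))
# To actually run it: subprocess.run(cmd, check=True)
```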
Fine-Tuning Models
gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model.pt
Arguments
| Argument | Type | Description | Default |
|---|---|---|---|
| --config | str | Required. Fine-tuning configuration file. | None |
| --model_path | str | Required. Path to a pre-trained model state dict. | None |
| --exp_name | str | MLflow experiment name. | timestamp |
| --run_name | str | MLflow run name. | run |
| --log_dir | str | MLflow logging directory. | mlruns |
| --data_path | str | Root dataset directory. | data |
| --compile [MODE] | str | Enable torch.compile mode. Valid values: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs. If flag is passed without a value, mode is default. | None |
| --bfloat16 | flag | Cast model to torch.bfloat16 (model.to(torch.bfloat16)). | False |
| --tf32 | flag | Enable TF32 on Ampere+ GPUs via torch.set_float32_matmul_precision("high"). | False |
| --dataset_wrapper | str | Registered dataset wrapper name (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset. | None |
| --plugins | list[str] | Python packages to import for plugin registration, e.g. gridfm_graphkit_ee. | [] |
| --num_workers | int | Override data.workers from YAML. Use 0 to debug worker crashes. | None |
| --dataset_wrapper_cache_dir | str | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | None |
| --profiler | str | Enable Lightning profiler (simple, advanced, pytorch). | None |
| --compute_dc_ac_metrics | flag | Compute ground-truth AC/DC power balance metrics on the test split. | False |
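The --dataset_wrapper and --plugins options suggest a registry pattern: importing a plugin package gives it the chance to register names that can then be selected on the command line. The sketch below illustrates that general pattern with standard-library tools only. It is not gridfm-graphkit's implementation; EXAMPLE_REGISTRY, register_wrapper, import_plugins, and PassthroughDataset are hypothetical names.

```python
import importlib
from typing import Callable, Dict, Iterable, List

# Hypothetical stand-in for a registry such as DATASET_WRAPPER_REGISTRY.
EXAMPLE_REGISTRY: Dict[str, type] = {}


def register_wrapper(name: str) -> Callable[[type], type]:
    """Class decorator that records a wrapper class under a string name."""
    def decorator(cls: type) -> type:
        EXAMPLE_REGISTRY[name] = cls
        return cls
    return decorator


def import_plugins(packages: Iterable[str]) -> List[str]:
    """Import each package so that its module-level registrations run."""
    imported = []
    for package in packages:
        importlib.import_module(package)  # raises ImportError if missing
        imported.append(package)
    return imported


@register_wrapper("PassthroughDataset")
class PassthroughDataset:
    """Toy wrapper that returns items of the wrapped dataset unchanged."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, index):
        return self.dataset[index]


import_plugins(["json"])  # "json" stands in for a real plugin package
wrapper_cls = EXAMPLE_REGISTRY["PassthroughDataset"]
print(wrapper_cls([10, 20, 30])[1])
```

Looking a class up by its registered string name is what lets a command-line value select code that lives in a separately installed package.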
Evaluating Models
gridfm_graphkit evaluate --config path/to/eval.yaml --model_path path/to/model.pt
Arguments
| Argument | Type | Description | Default |
|---|---|---|---|
| --config | str | Required. Path to evaluation config. | None |
| --model_path | str | Path to the trained model state dict. | None |
| --normalizer_stats | str | Path to normalizer_stats.pt from a training run. Restores fit_on_train normalizers from saved statistics instead of re-fitting on current split. | None |
| --exp_name | str | MLflow experiment name. | timestamp |
| --run_name | str | MLflow run name. | run |
| --log_dir | str | MLflow logging directory. | mlruns |
| --data_path | str | Dataset directory. | data |
| --compile [MODE] | str | Enable torch.compile mode. Valid values: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs. If flag is passed without a value, mode is default. | None |
| --bfloat16 | flag | Cast model to torch.bfloat16 (model.to(torch.bfloat16)). | False |
| --tf32 | flag | Enable TF32 on Ampere+ GPUs via torch.set_float32_matmul_precision("high"). | False |
| --dataset_wrapper | str | Registered dataset wrapper name (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset. | None |
| --plugins | list[str] | Python packages to import for plugin registration, e.g. gridfm_graphkit_ee. | [] |
| --num_workers | int | Override data.workers from YAML. Use 0 to debug worker crashes. | None |
| --dataset_wrapper_cache_dir | str | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | None |
| --profiler | str | Enable Lightning profiler (simple, advanced, pytorch). | None |
| --compute_dc_ac_metrics | flag | Compute ground-truth AC/DC power balance metrics on the test split. | False |
| --save_output | flag | Save predictions as <grid_name>_predictions.parquet under MLflow artifacts (.../artifacts/test). | False |
Example with saved normalizer stats
When evaluating a model on a dataset, you can pass the normalizer statistics from the original training run to ensure the same normalization parameters are used:
gridfm_graphkit evaluate \
--config examples/config/HGNS_PF_datakit_case118.yaml \
--model_path mlruns/<experiment_id>/<run_id>/artifacts/model/best_model_state_dict.pt \
--normalizer_stats mlruns/<experiment_id>/<run_id>/artifacts/stats/normalizer_stats.pt \
--data_path data
Note: The --normalizer_stats flag only affects normalizers with fit_strategy = "fit_on_train" (e.g. HeteroDataMVANormalizer). Per-sample normalizers (HeteroDataPerSampleMVANormalizer) always recompute their statistics from the current dataset regardless of this flag.
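The distinction in the note can be illustrated with a toy example. The functions below are simplified stand-ins, not the library's HeteroDataMVANormalizer or HeteroDataPerSampleMVANormalizer: one scales every sample by a statistic fitted once on training data (which is what saved statistics would restore), the other recomputes its scale from each sample. Using the maximum absolute value as the statistic is an assumption made for illustration.

```python
from typing import List


def fit_scale(samples: List[List[float]]) -> float:
    """Fit one global scale (max absolute value) over a whole split."""
    return max(abs(v) for sample in samples for v in sample)


def normalize_with_fitted_scale(sample: List[float], scale: float) -> List[float]:
    """fit_on_train style: reuse a scale that was fitted (or restored) earlier."""
    return [v / scale for v in sample]


def normalize_per_sample(sample: List[float]) -> List[float]:
    """Per-sample style: the scale is always recomputed from the sample itself."""
    scale = max(abs(v) for v in sample)
    return [v / scale for v in sample]


train = [[100.0, 50.0], [200.0, 25.0]]
test_sample = [50.0, 25.0]

saved_scale = fit_scale(train)          # what a stats file would preserve: 200.0
refit_scale = fit_scale([test_sample])  # re-fitting on the new split: 50.0

print(normalize_with_fitted_scale(test_sample, saved_scale))  # [0.25, 0.125]
print(normalize_with_fitted_scale(test_sample, refit_scale))  # [1.0, 0.5]
print(normalize_per_sample(test_sample))                      # [1.0, 0.5]
```

Re-fitting on the evaluation split changes the inputs the model sees, which is why restoring the training statistics matters for fit_on_train normalizers and has no effect on per-sample ones.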
Running Predictions
gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model.pt
Arguments
| Argument | Type | Description | Default |
|---|---|---|---|
| --config | str | Required. Path to prediction config file. | None |
| --model_path | str | Path to trained model state dict. Optional; may be defined in config. | None |
| --normalizer_stats | str | Path to normalizer_stats.pt from a training run. Restores fit_on_train normalizers from saved statistics. | None |
| --exp_name | str | MLflow experiment name. | timestamp |
| --run_name | str | MLflow run name. | run |
| --log_dir | str | MLflow logging directory. | mlruns |
| --data_path | str | Dataset directory. | data |
| --dataset_wrapper | str | Registered dataset wrapper name (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset. | None |
| --plugins | list[str] | Python packages to import for plugin registration, e.g. gridfm_graphkit_ee. | [] |
| --num_workers | int | Override data.workers from YAML. Use 0 to debug worker crashes. | None |
| --dataset_wrapper_cache_dir | str | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | None |
| --output_path | str | Directory where predictions are saved as <grid_name>_predictions.parquet. | data |
| --compile [MODE] | str | Enable torch.compile mode. Valid values: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs. If flag is passed without a value, mode is default. | None |
| --bfloat16 | flag | Cast model to torch.bfloat16 (model.to(torch.bfloat16)). | False |
| --tf32 | flag | Enable TF32 on Ampere+ GPUs via torch.set_float32_matmul_precision("high"). | False |
| --profiler | str | Enable Lightning profiler (simple, advanced, pytorch). | None |
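According to the --output_path description, predictions are saved as <grid_name>_predictions.parquet files. The sketch below collects such files by grid name using only the standard library; the helper name find_prediction_files and the grid names are hypothetical, and because the column layout of the Parquet files is not documented here, the example stops at locating them.

```python
import tempfile
from pathlib import Path
from typing import Dict

SUFFIX = "_predictions.parquet"


def find_prediction_files(output_path: Path) -> Dict[str, Path]:
    """Map each grid name to its <grid_name>_predictions.parquet file."""
    return {
        path.name[: -len(SUFFIX)]: path
        for path in sorted(output_path.glob(f"*{SUFFIX}"))
    }


# Demonstration with empty placeholder files in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    out_dir = Path(tmp)
    for grid in ("case30_ieee", "case118_ieee"):  # hypothetical grid names
        (out_dir / f"{grid}{SUFFIX}").touch()
    (out_dir / "notes.txt").touch()  # unrelated file, ignored by the glob
    found = find_prediction_files(out_dir)
    grid_names = sorted(found)
    print(grid_names)

# A real file could then be loaded with pandas.read_parquet(found[name]),
# which requires pandas plus a Parquet engine such as pyarrow.
```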
Benchmarking Dataloader Throughput
gridfm_graphkit benchmark --config path/to/config.yaml
Arguments
| Argument | Type | Description | Default |
|---|---|---|---|
| --config | str | Required. Path to configuration YAML file. | None |
| --data_path | str | Root dataset directory. | data |
| --epochs | int | Number of epochs to iterate through the train dataloader. | 3 |
| --dataset_wrapper | str | Registered dataset wrapper name (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset. | None |
| --dataset_wrapper_cache_dir | str | Directory for dataset wrapper disk cache. | None |
| --num_workers | int | Override data.workers from YAML. | None |
| --plugins | list[str] | Python packages to import for plugin registration. | [] |
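Conceptually, a dataloader benchmark iterates over the loader for a number of epochs and reports how quickly batches arrive. The sketch below shows that idea for any iterable of batches. It is a generic illustration with hypothetical names (benchmark_loader, ThroughputResult), not the code behind gridfm_graphkit benchmark; its default of three epochs simply mirrors the --epochs default above.

```python
import time
from dataclasses import dataclass
from typing import Iterable, Sized


@dataclass
class ThroughputResult:
    batches: int
    samples: int
    seconds: float

    @property
    def samples_per_second(self) -> float:
        # Guard against a zero-length interval on very fast loops.
        return self.samples / self.seconds if self.seconds > 0 else float("inf")


def benchmark_loader(loader: Iterable[Sized], epochs: int = 3) -> ThroughputResult:
    """Iterate over `loader` for `epochs` passes and time the whole loop."""
    batches = samples = 0
    start = time.perf_counter()
    for _ in range(epochs):
        for batch in loader:
            batches += 1
            samples += len(batch)
    elapsed = time.perf_counter() - start
    return ThroughputResult(batches, samples, elapsed)


# A list of lists stands in for a real dataloader: 4 batches of 8 samples.
toy_loader = [[0] * 8 for _ in range(4)]
result = benchmark_loader(toy_loader, epochs=3)
print(result.batches, result.samples)
print(f"{result.samples_per_second:.0f} samples/s")
```

Timing only the iteration, with no model in the loop, isolates data-loading cost, which is useful when comparing dataset wrappers or worker counts.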
Use built-in help for full command details:
gridfm_graphkit --help
gridfm_graphkit <command> --help
Running Tests
Unit and Integration Tests
Install the test dependencies first (if not already done):
pip install -e ".[dev,test]"
Run the full unit test suite:
pytest ./tests
Run the base set integration tests:
pytest ./integrationtests/test_base_set.py
Running Base Set Tests on an LSF Cluster (GPU)
To submit the base set integration tests as an interactive LSF job with GPU access, use bsub. Adjust the paths to match your environment:
bsub -gpu "num=1" \
-n 16 \
-R "rusage[mem=32GB] span[hosts=1]" \
-Is \
-J gridfm_base_set_tests \
/bin/bash -c "
cd /path/to/gridfm-graphkit && \
export PATH=/path/to/cuda/bin:\$PATH \
CUDA_HOME=/path/to/cuda \
LD_LIBRARY_PATH=/path/to/cuda/lib64:\$LD_LIBRARY_PATH && \
source /path/to/venv/bin/activate && \
pytest ./integrationtests/test_base_set.py
"
Key bsub options used above:
| Option | Description |
|---|---|
| -gpu "num=1" | Request 1 GPU |
| -n 16 | Request 16 CPU slots |
| -R "rusage[mem=32GB] span[hosts=1]" | Reserve 32 GB of memory on a single host |
| -Is | Run as an interactive shell session |
| -J <job_name> | Assign a name to the job |
Concrete example (adapt paths to your cluster setup):
bsub -gpu "num=1" -n 16 -R "rusage[mem=32GB] span[hosts=1]" -Is -J hpo_trial_190 /bin/bash -c "cd /dccstor/terratorch/users/rkie/gitco/gridfm-graphkit && export PATH=/opt/share/cuda-12.8.1/bin:\$PATH CUDA_HOME=/opt/share/cuda-12.8.1 LD_LIBRARY_PATH=/opt/share/cuda-12.8.1/lib64:\$LD_LIBRARY_PATH && source /u/rkie/venvs/venv_gridfm-graphkit/bin/activate && pytest ./integrationtests/test_base_set.py"
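Because the one-line form is hard to read and easy to mis-quote, the same submission can be assembled as an argument list in Python. The sketch below is an illustration under assumptions: build_bsub_command is a hypothetical helper, the paths are the placeholders from the template above, and the list is only printed here. Passing it to subprocess.run without a local shell means $PATH needs no backslash escaping.

```python
import shlex
from typing import List


def build_bsub_command(
    repo_dir: str,
    cuda_home: str,
    venv_dir: str,
    job_name: str = "gridfm_base_set_tests",
) -> List[str]:
    """Build the interactive bsub invocation shown above as an argv list."""
    # Inner command executed by /bin/bash -c on the compute host.
    inner = " && ".join([
        f"cd {shlex.quote(repo_dir)}",
        f"export PATH={shlex.quote(cuda_home)}/bin:$PATH "
        f"CUDA_HOME={shlex.quote(cuda_home)} "
        f"LD_LIBRARY_PATH={shlex.quote(cuda_home)}/lib64:$LD_LIBRARY_PATH",
        f"source {shlex.quote(venv_dir)}/bin/activate",
        "pytest ./integrationtests/test_base_set.py",
    ])
    return [
        "bsub", "-gpu", "num=1", "-n", "16",
        "-R", "rusage[mem=32GB] span[hosts=1]",
        "-Is", "-J", job_name,
        "/bin/bash", "-c", inner,
    ]


bsub_cmd = build_bsub_command(
    "/path/to/gridfm-graphkit", "/path/to/cuda", "/path/to/venv"
)
print(shlex.join(bsub_cmd))
# To submit for real: subprocess.run(bsub_cmd, check=True)
```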
Owner metadata
- Name: gridfm
- Login: gridfm
- Kind: organization
- Icon url: https://avatars.githubusercontent.com/u/12138362?v=4
- Repositories: 1
- Last synced at: 2023-03-01T14:30:36.533Z
- Profile URL: https://github.com/gridfm
GitHub Events
Total
- Release event: 4
- Delete event: 13
- Member event: 3
- Pull request event: 24
- Fork event: 4
- Issues event: 5
- Watch event: 27
- Issue comment event: 1
- Push event: 144
- Pull request review comment event: 13
- Pull request review event: 12
- Create event: 36
Last Year
- Release event: 4
- Delete event: 13
- Member event: 3
- Pull request event: 24
- Fork event: 4
- Issues event: 5
- Watch event: 27
- Issue comment event: 1
- Push event: 144
- Pull request review comment event: 13
- Pull request review event: 12
- Create event: 36
Committers metadata
Last synced: 3 days ago
Total Commits: 353
Total Committers: 17
Avg Commits per committer: 20.765
Development Distribution Score (DDS): 0.581
Commits in past year: 333
Committers in past year: 13
Avg Commits per committer in past year: 25.615
Development Distribution Score (DDS) in past year: 0.556
| Name | Email | Commits |
|---|---|---|
| Romeo Kienzler | r****1@i****m | 148 |
| Alban Puech | n****p@g****m | 57 |
| MatteoMazzonelli | m****1@i****m | 53 |
| Mangaliso Mngomomezulu | 6****M | 20 |
| Matteo Mazzonelli | M****i@i****m | 12 |
| Mangaliso Mngomezulu | m****u@M****l | 10 |
| Naomi Simumba | 7****a | 9 |
| François Mirallès | 4****r | 9 |
| Celia Cintas | c****s@i****m | 9 |
| Mangaliso Mngomezulu | m****u@m****m | 6 |
| Hector Maeso Garcia | h****a@i****m | 6 |
| Alban Puech | A****1@i****m | 6 |
| “Mangaliso-M” | “****u@g****” | 4 |
| Etienne Eben Vos | E****s@i****m | 1 |
| Jonas Weiss | J****E@z****m | 1 |
| Hector Maeso Garcia | H****a@i****m | 1 |
| cs5807 | m****s@i****a | 1 |
Committer domains:
- ibm.com: 8
- ireq.ca: 1
- zurich.ibm.com: 1
- gmail.com”: 1
- macbookpro.witsjuta.za.ibm.com: 1
Issue and Pull Request metadata
Last synced: 3 days ago
Total issues: 5
Total pull requests: 51
Average time to close issues: about 2 months
Average time to close pull requests: 9 days
Total issue authors: 4
Total pull request authors: 11
Average comments per issue: 0.2
Average comments per pull request: 0.29
Merged pull request: 32
Bot issues: 0
Bot pull requests: 0
Past year issues: 5
Past year pull requests: 51
Past year average time to close issues: about 2 months
Past year average time to close pull requests: 9 days
Past year issue authors: 4
Past year pull request authors: 11
Past year average comments per issue: 0.2
Past year average comments per pull request: 0.29
Past year merged pull request: 32
Past year bot issues: 0
Past year bot pull requests: 0
Top Issue Authors
- romeokienzler (2)
- TavoIREQ (1)
- janu000 (1)
- mellson (1)
Top Pull Request Authors
- romeokienzler (13)
- albanpuech (11)
- MatteoMazzonelli (9)
- frmir (4)
- mkisuule (3)
- Mangaliso-M (3)
- ttolhurst (3)
- naomi-simumba (2)
- celiacintas (1)
- Tamaragov (1)
- emmanuelbadmus (1)
Top Issue Labels
- bug (1)
Top Pull Request Labels
- documentation (6)
- enhancement (4)
Package metadata
- Total packages: 1
- Total downloads:
  - pypi: 226 last-month
- Total dependent packages: 0
- Total dependent repositories: 0
- Total versions: 7
- Total maintainers: 2
pypi.org: gridfm-graphkit
Grid Foundation Model
- Documentation: https://gridfm-graphkit.readthedocs.io/
- Licenses: Apache-2.0
- Latest release: 0.0.6 (published 9 months ago)
- Last Synced: 2026-06-10T11:08:49.905Z (3 days ago)
- Versions: 7
- Dependent Packages: 0
- Dependent Repositories: 0
- Downloads: 226 Last month
- Rankings:
  - Dependent packages count: 8.944%
  - Average: 29.672%
  - Dependent repos count: 50.399%
- Maintainers (2)
Dependencies
- actions/checkout v4 composite
- actions/setup-python v4 composite
- actions/cache v4 composite
- actions/checkout v4 composite
- actions/setup-python v4 composite
- mlflow >=3.1.0
- nbformat >=5.10.4
- networkx >=3.4.2
- numpy >=2.2.6
- pandas >=2.3.0
- plotly >=6.1.2
- pyyaml >=6.0.2
- torch >=2.7.1
- torch-geometric >=2.6.1
- torchaudio >=2.7.1
- torchvision >=0.22.1
- actions/checkout v4 composite
- actions/setup-python v4 composite
- pypa/gh-action-pypi-publish release/v1 composite
Score: 12.805853755280662