Skip to content

Commit 5e6d837

Browse files
Merge pull request #57 from gridfm/improve_tests
Add data downloading functionality and gdown dependency for test setup
2 parents a5e6449 + 15a9820 commit 5e6d837

6 files changed

Lines changed: 381 additions & 67 deletions

File tree

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ Install gridfm-graphkit in editable mode
3333
pip install -e .
3434
```
3535

36+
**`torch-scatter` is a required dependency.** It cannot be bundled in `pyproject.toml` because the correct wheel depends on your PyTorch and CUDA versions, so it must be installed separately.
37+
3638
Get PyTorch + CUDA version for torch-scatter
3739
```bash
3840
TORCH_CUDA_VERSION=$(python -c "import torch; print(torch.__version__ + ('+cpu' if torch.version.cuda is None else ''))")

docs/install/installation.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ Install gridfm-graphkit in editable mode
1515
pip install -e .
1616
```
1717

18+
**`torch-scatter` is a required dependency.** It cannot be bundled in `pyproject.toml` because the correct wheel depends on your PyTorch and CUDA versions, so it must be installed separately.
19+
1820
Get PyTorch + CUDA version for torch-scatter
1921

2022
```bash

integrationtests/conftest.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import pytest
2+
3+
4+
def pytest_addoption(parser):
5+
parser.addoption(
6+
"--calibrate",
7+
type=int,
8+
default=0,
9+
help="Run training N times to collect metric mean/std for range calibration. "
10+
"Skips metric range assertions. Example: pytest --calibrate 5",
11+
)
12+
parser.addoption(
13+
"--ci",
14+
type=float,
15+
default=0.995,
16+
help="Confidence interval level for calibration stats (default 0.995). "
17+
"Example: pytest --calibrate 5 -s --ci 0.995",
18+
)
19+
20+
21+
@pytest.fixture
22+
def calibrate_runs(request):
23+
"""Number of calibration runs requested via --calibrate (0 = normal test mode)."""
24+
return request.config.getoption("--calibrate")
25+
26+
27+
@pytest.fixture
28+
def ci_level(request):
29+
"""Confidence interval level requested via --ci (default 0.995)."""
30+
return request.config.getoption("--ci")
31+
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import urllib.request
2+
import yaml
3+
import subprocess
4+
5+
6+
def execute_and_live_output(cmd) -> None:
7+
subprocess.run(cmd, text=True, shell=True, check=True)
8+
9+
10+
def _base_config() -> dict:
11+
"""
12+
Download the default config from gridfm-datakit and apply common test parameters.
13+
"""
14+
config_url = (
15+
"https://raw.githubusercontent.com/gridfm/gridfm-datakit/refs/heads/main"
16+
"/scripts/config/default.yaml"
17+
)
18+
19+
print(f"Downloading config from {config_url}...")
20+
with urllib.request.urlopen(config_url) as response:
21+
config_content = response.read().decode("utf-8")
22+
23+
config = yaml.safe_load(config_content)
24+
25+
config["network"]["name"] = "case14_ieee"
26+
config["load"]["scenarios"] = 10000
27+
config["topology_perturbation"]["n_topology_variants"] = 2
28+
29+
return config
30+
31+
32+
def generate_pf_test_data(config_path: str = "integrationtests/default_pf.yaml") -> None:
33+
"""
34+
Generate power-flow (PF) test data for case14_ieee with 10 000 scenarios
35+
and 2 topology variants.
36+
"""
37+
config = _base_config()
38+
39+
with open(config_path, "w") as f:
40+
yaml.dump(config, f, default_flow_style=False, sort_keys=False)
41+
42+
print(f"PF config written to {config_path}")
43+
print(f" network.name : {config['network']['name']}")
44+
print(f" load.scenarios : {config['load']['scenarios']}")
45+
print(f" topology_perturbation.n_topology_variants: {config['topology_perturbation']['n_topology_variants']}")
46+
47+
execute_and_live_output(f"gridfm_datakit generate {config_path}")
48+
49+
50+
def generate_opf_test_data(config_path: str = "integrationtests/default_opf.yaml") -> None:
51+
"""
52+
Generate optimal power-flow (OPF) test data for case14_ieee with 10 000 scenarios
53+
and 2 topology variants.
54+
"""
55+
config = _base_config()
56+
config["settings"]["mode"] = "opf"
57+
58+
with open(config_path, "w") as f:
59+
yaml.dump(config, f, default_flow_style=False, sort_keys=False)
60+
61+
print(f"OPF config written to {config_path}")
62+
print(f" network.name : {config['network']['name']}")
63+
print(f" load.scenarios : {config['load']['scenarios']}")
64+
print(f" topology_perturbation.n_topology_variants: {config['topology_perturbation']['n_topology_variants']}")
65+
print(f" settings.mode : {config['settings']['mode']}")
66+
67+
execute_and_live_output(f"gridfm_datakit generate {config_path}")
68+
69+
70+
if __name__ == "__main__":
71+
#generate_pf_test_data()
72+
generate_opf_test_data()

0 commit comments

Comments
 (0)