Skip to content

Commit cfedc54

Browse files
authored
Merge pull request #58 from cyber-physical-systems-group/experiment/model-composer
Experiment/model composer
2 parents d18947e + 6277e83 commit cfedc54

File tree

14 files changed

+427
-40
lines changed

14 files changed

+427
-40
lines changed
Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,30 @@ To run the experiment, use the following command (assuming W&B is installed and
99
sources section):
1010

1111
```bash
12-
python -m examples.prediction.run
12+
python -m examples.core.prediction
1313
```
1414

15+
# Simulation
16+
17+
This example uses `pydentification` to run training with a simulation experiment on an example benchmark.
18+
Only a single experiment is run and registered to W&B.
19+
20+
### Running
21+
22+
To run the experiment, use the following command (assuming W&B is installed and logged in, for more details go to
23+
sources section):
24+
25+
```bash
26+
python -m examples.core.simulation
27+
```
28+
29+
### Experiment
30+
31+
This is an example of fully reproducible sweeps, storing a snapshot of the code used to run the experiment in a ZIP archive,
32+
alongside a stand-alone function to re-create the model, its init parameters in JSON, and the model weights in safetensors format.
33+
34+
It is based on prediction example.
35+
1536
### Sources
1637

1738
* [https://docs.wandb.ai/guides/sweeps](https://docs.wandb.ai/guides/sweeps)

examples/core/experiment.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# isort: skip_file
2+
import os
3+
4+
import lightning.pytorch as pl
5+
import pandas as pd
6+
import torch
7+
8+
import wandb
9+
10+
from pydentification.experiment.storage.wrapper import dump
11+
from pydentification.experiment.storage.models import save_lightning
12+
from pydentification.experiment.storage.code import save_code_snapshot
13+
from pydentification.experiment.storage.sync import save_to_wandb
14+
from pydentification.training.lightning.prediction import LightningPredictionTrainingModule
15+
from pydentification.data.datamodules.prediction import PredictionDataModule
16+
17+
18+
def input_fn(parameters: dict):
19+
data = pd.read_csv("data/lorenz.csv") # assume dataset exists and has ~100 000 samples with 3 columns: x, y, z
20+
return PredictionDataModule(
21+
data[["x", "y", "z"]].values,
22+
test_size=30_000, # 30% assuming 100 000 sample
23+
batch_size=32,
24+
validation_size=0.1, # 10% of the training set, which is 70% of the whole dataset
25+
n_backward_time_steps=parameters["n_input_time_steps"], # sweep parameter, which can be changed between runs
26+
n_forward_time_steps=parameters["n_output_time_steps"],
27+
n_workers=4,
28+
)
29+
30+
31+
# pass parameterless lambda function dynamically returning path to the decorator, after W&B run is initialized
@dump(path=lambda: f"outputs/{wandb.run.id}", param_store="both")  # noqa
def model_fn(hidden_dim: int = 64):
    """Build a small MLP mapping 3 state variables to 3 predicted state variables."""
    layers = [
        torch.nn.Linear(3, hidden_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden_dim, hidden_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden_dim, 3),
    ]
    return torch.nn.Sequential(*layers)
41+
42+
43+
def trainer_fn(model, parameters: dict):
    """
    Wrap the raw torch module in a lightning training module and build its trainer.

    :param model: torch module returned by `model_fn`
    :param parameters: sweep configuration; currently unused here, kept for interface symmetry
    :return: tuple of (lightning training module, lightning trainer)
    """
    adamw = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
    training_module = LightningPredictionTrainingModule(module=model, optimizer=adamw)
    # hard stop after 4 hours of training, checked once per epoch
    time_limit = pl.callbacks.Timer(duration="00:04:00:00", interval="epoch")

    trainer = pl.Trainer(
        max_epochs=3,  # just an example
        precision=64,
        accelerator="cpu",
        devices=1,
        callbacks=[time_limit],
    )

    return training_module, trainer
57+
58+
59+
def train_fn(model, trainer, dm):
    """Fit `model` with the lightning `trainer` on datamodule `dm`, returning both unchanged."""
    trainer.fit(model, datamodule=dm)
    return model, trainer
63+
64+
65+
def run_single_experiment():
    """
    Run one sweep trial: build data, model and trainer from the W&B config, train, and store
    the trained weights, hyperparameters and a code snapshot under `outputs/{run.id}`.

    :raises ValueError: wrapping any failure, so the agent reports errors even though W&B
        uses multiprocessing internally (which can lose exception information).
    """
    with wandb.init(reinit=True):
        # prepare directories and library code snapshot
        os.makedirs(f"outputs/{wandb.run.id}/models", exist_ok=True)
        os.makedirs(f"outputs/{wandb.run.id}/code", exist_ok=True)
        save_code_snapshot(name="code", source_dir="pydentification", target_dir=f"outputs/{wandb.run.id}/code")

        parameters = dict(wandb.config)  # cast to dict is needed to serialize the parameters
        try:
            print(f"Starting experiment with {wandb.run.id}")
            dm = input_fn(parameters)
            # forward sweep-controlled model hyperparameters; previously `model_fn()` was always
            # called with defaults, so sweeping over hidden_dim had no effect on the model
            model_kwargs = {key: parameters[key] for key in ("hidden_dim",) if key in parameters}
            model = model_fn(**model_kwargs)
            model, trainer = trainer_fn(model, parameters)
            model, trainer = train_fn(model, trainer, dm)

            # store trained model and send it to W&B
            save_lightning(f"outputs/{wandb.run.id}/models", model=model, method="safetensors", save_hparams=True)
            save_to_wandb(f"outputs/{wandb.run.id}")  # save all files in the directory to W&B
        except Exception as e:
            print(e)  # print traceback, since W&B uses multiprocessing, which can lose information about exception
            raise ValueError("Experiment failed.") from e
86+
87+
88+
if __name__ == "__main__":
    # W&B sweep configs must follow the schema {"method": ..., "parameters": {name: {"values": [...]}}};
    # the previous bare {"hidden_di": [32, 64, 128]} mapping was not a valid sweep config and the
    # parameter key had a typo ("hidden_di" instead of "hidden_dim" used by model_fn)
    sweep_config = {
        "method": "grid",
        "parameters": {"hidden_dim": {"values": [32, 64, 128]}},
    }
    sweep_id = wandb.sweep(sweep_config, project="test")
    wandb.agent(sweep_id, function=run_single_experiment, count=3, project="test")
Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# isort: skip_file
2-
import os
32
from datetime import timedelta
43

54
import lightning.pytorch as pl
@@ -9,6 +8,7 @@
98

109
from pydentification.data.datamodules.prediction import PredictionDataModule
1110
from pydentification.experiment.reporters import report_metrics, report_prediction_plot, report_trainable_parameters
11+
from pydentification.experiment.storage.models import save_lightning
1212
from pydentification.metrics import regression_metrics
1313
from pydentification.models.networks.transformer import (
1414
CausalDelayLineFeedforward,
@@ -19,10 +19,9 @@
1919

2020

2121
def input_fn(parameters: dict):
22-
df = pd.read_csv("dataset.csv") # assume dataset exists and has ~100 000 samples with 3 columns: x, y, z
23-
22+
data = pd.read_csv("data/lorenz.csv") # assume dataset exists and has ~100 000 samples with 3 columns: x, y, z
2423
return PredictionDataModule(
25-
df[["x", "y", "z"]],
24+
data[["x", "y", "z"]].values,
2625
test_size=30_000, # 30% assuming 100 000 sample
2726
batch_size=32,
2827
validation_size=0.1, # 10% of the training set, which is 70% of the whole dataset
@@ -83,10 +82,10 @@ def trainer_fn(model, parameters: dict):
8382
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=20, verbose=True)
8483
# callbacks for stopping the training early, with 4 hour timeout and patience of 50 epochs (with 20 for reducing LR)
8584
timer = pl.callbacks.Timer(duration="00:04:00:00", interval="epoch")
86-
stopping = pl.callbacks.EarlyStopping(monitor="training/validation_loss", patience=50, mode="min", verbose=True)
85+
stopping = pl.callbacks.EarlyStopping(monitor="trainer/validation_loss", patience=50, mode="min", verbose=True)
8786
# checkpointing the model every 100 epochs and every hour to single directory
88-
path = f"models/{wandb.run.id}"
89-
epoch_checkpoint = pl.callbacks.ModelCheckpoint(dirpath=path, monitor="validation/loss", every_n_epochs=100)
87+
path = f"outputs/models/{wandb.run.id}"
88+
epoch_checkpoint = pl.callbacks.ModelCheckpoint(dirpath=path, monitor="trainer/validation_loss", every_n_epochs=100)
9089
time_checkpoint = pl.callbacks.ModelCheckpoint(dirpath=path, train_time_interval=timedelta(hours=1))
9190

9291
# wrap model in training class with auto-regression training defined
@@ -146,10 +145,8 @@ def run_single_experiment():
146145
model, trainer = train_fn(model, trainer, dm)
147146
report_fn(model, dm, auto_regression_scales=[16, 32, 128]) # sample of regression scales
148147
# store trained model and send it to W&B
149-
os.makedirs(f"models/{wandb.run.id}", exist_ok=True)
150-
path = f"models/{wandb.run.id}/trained-model.pt"
151-
torch.save(model, path)
152-
wandb.save(path)
148+
save_lightning(name=wandb.run.id, model=model, method="safetensors", save_hparams=True)
149+
wandb.save(wandb.run.id)
153150
except Exception as e:
154151
print(e) # print traceback, since W&B uses multiprocessing, which can lose information about exception
155152
raise ValueError("Experiment failed.") from e
@@ -201,4 +198,4 @@ def run_single_experiment():
201198

202199
if __name__ == "__main__":
203200
sweep_id = wandb.sweep(SWEEP_CONFIG, project="test") # change project name
204-
wandb.agent(sweep_id, function=run_single_experiment, count=10, project="test")
201+
wandb.agent(sweep_id, function=run_single_experiment, count=1, project="test")
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ def model_fn():
6161
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=50, verbose=True)
6262

6363
timer = pl.callbacks.Timer(duration="00:04:00:00", interval="epoch") # 4 hours
64-
stopping = pl.callbacks.EarlyStopping(monitor="training/validation_loss", patience=50, mode="min", verbose=True)
64+
stopping = pl.callbacks.EarlyStopping(monitor="trainer/validation_loss", patience=50, mode="min", verbose=True)
6565

6666
path = f"models/{wandb.run.id}"
67-
epoch_checkpoint = pl.callbacks.ModelCheckpoint(dirpath=path, monitor="training/validation_loss", every_n_epochs=10)
67+
epoch_checkpoint = pl.callbacks.ModelCheckpoint(dirpath=path, monitor="trainer/validation_loss", every_n_epochs=10)
6868

6969
model = LightningSimulationTrainingModule(transformer, optimizer, lr_scheduler, loss=torch.nn.MSELoss())
7070

pydentification/data/datamodules/prediction.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ def __init__(
8484
"""
8585
super().__init__()
8686

87+
if not isinstance(states, np.ndarray):
88+
raise TypeError(f"States must be numpy given as numpy array, got {type(states)}!")
89+
8790
self.states = states
8891

8992
self.test_size = test_size

pydentification/data/datamodules/simulation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ def __init__(
5353
batch_size: int = 32,
5454
dtype: torch.dtype = torch.float32,
5555
):
56+
if not isinstance(inputs, np.ndarray) or not isinstance(outputs, np.ndarray):
57+
raise TypeError(f"Inputs and outputs must be numpy arrays! Got {type(inputs)} and {type(outputs)}!")
58+
5659
self.inputs = inputs
5760
self.outputs = outputs
5861

File renamed without changes.

pydentification/experiment/dumper/code.py renamed to pydentification/experiment/storage/code.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,6 @@
33
import uuid
44
from pathlib import Path
55

6-
PYTHON_EXTENSIONS = frozenset({".py", ".json", ".txt", ".md", ".yaml", ".yml", ".toml", ".ini"})
7-
DEFAULT_FORBIDDEN_PREFIX = frozenset({"venv", ".ipynb_checkpoints", "__pycache__", ".git", ".pytest_cache"})
8-
96

107
def _load_gitignore() -> set[str]:
118
"""Load .gitignore from default name and root directory as set"""
@@ -21,7 +18,7 @@ def not_comment(line: str) -> bool:
2118
return set(filter(not_comment, f.read().splitlines()))
2219

2320

24-
def _skip_subdir(current: Path, archive_path: Path, forbidden_paths: frozenset[str]) -> bool:
21+
def _skip_subdir(current: Path, archive_path: Path, forbidden_paths: set[str]) -> bool:
2522
# prevent copying the temp directory, where the archive with source code is build
2623
if str(archive_path.absolute()) == current:
2724
return True
@@ -34,38 +31,57 @@ def _skip_subdir(current: Path, archive_path: Path, forbidden_paths: frozenset[s
3431
return False
3532

3633

37-
def save_code_snapshot(
    name: str,
    source_dir: str | Path,
    target_dir: str | Path,
    filter_prefix: set[str] = frozenset({"venv", ".ipynb_checkpoints", "__pycache__", ".git", ".pytest_cache"}),
    accept_suffix: set[str] = frozenset({".py", ".json", ".txt", ".md", ".yaml", ".yml", ".toml", ".ini"}),
    use_gitignore: bool = True,
):
    """
    Save only text-based files in a ZIP archive, excluding binary data files.

    :param name: name of the archive file
    :param source_dir: path to the directory with source code
    :param target_dir: path to the directory where the archive will be saved
    :param filter_prefix: set of prefixes to exclude from the archive
    :param accept_suffix: set of suffixes to include in the archive
    :param use_gitignore: whether to use .gitignore file in the source directory for filter_prefix
    """
    if isinstance(source_dir, str):
        source_dir = Path(source_dir)

    if isinstance(target_dir, str):
        target_dir = Path(target_dir)

    source_dir = Path(source_dir).resolve()  # ensure absolute path
    snapshot_path = target_dir / name
    temp_dir = target_dir / str(uuid.uuid4())  # create temp dir with unique name for copying files

    if temp_dir.exists():
        shutil.rmtree(temp_dir)

    temp_dir.mkdir(parents=True, exist_ok=True)

    if use_gitignore:
        # union rebinds the local name, so the frozenset defaults are never mutated
        filter_prefix |= _load_gitignore()  # union with .gitignore, if present

    for root, dirs, files in os.walk(source_dir):
        root_path = Path(root)

        if _skip_subdir(root_path, temp_dir, filter_prefix):
            dirs.clear()  # prevent descending into this directory
            continue  # skip to the next directory

        for file in files:
            source_path = root_path / file
            if source_path.suffix in accept_suffix:
                # keep the source directory name as the top-level folder inside the archive,
                # but stay independent of the working directory: `relative_to(os.getcwd())`
                # raised ValueError whenever source_dir was not located below the CWD
                relative_path = Path(source_dir.name) / source_path.relative_to(source_dir)
                dest_path = temp_dir / relative_path

                dest_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(source_path, dest_path)  # copy from the absolute path, not CWD-relative

    shutil.make_archive(str(snapshot_path), format="zip", root_dir=temp_dir)  # archive the directory
    shutil.rmtree(temp_dir)
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import importlib.util
2+
import json
3+
import os
4+
import shutil
5+
import sys
6+
import zipfile
7+
from pathlib import Path
8+
from typing import Any, Callable
9+
10+
11+
class ReplaceSourceCode:
    """
    ContextManager over-writing imports with given `path` to ZIP with source code created by
    `pydentification.experiment.storage.code.save_code_snapshot`.

    Code is extracted to a temporary directory and added to the `sys.path` for the duration of the context and removed
    afterward on exit. The source code needs to be unique directory to avoid conflicts with other imports.
    """

    def __init__(self, path: Path):
        # `source_path` is the extraction target: the archive path with its suffix stripped
        self.path = path
        self.source_path = path.with_suffix("")

    def __enter__(self):
        # refuse to extract over an existing directory, so imports never resolve against stale files
        if self.source_path.exists():
            raise FileExistsError(f"Can't overwrite {self.source_path.stem}!")

        # named `archive` to avoid shadowing the builtin `zip`
        with zipfile.ZipFile(self.path, "r") as archive:
            archive.extractall(str(self.source_path))

        sys.path.append(str(self.source_path))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # always undo the sys.path change and remove the extracted files, even on error
        sys.path.remove(str(self.source_path))
        shutil.rmtree(self.source_path)
37+
38+
39+
def _import_function_from_path(module_path: str, function_name: str) -> Callable:
40+
"""Dynamically imports a function from a Python file given the file path and function name."""
41+
module_name = os.path.basename(module_path).replace(".py", "")
42+
spec = importlib.util.spec_from_file_location(module_name, module_path)
43+
44+
module = importlib.util.module_from_spec(spec)
45+
sys.modules[module_name] = module
46+
47+
spec.loader.exec_module(module)
48+
function = getattr(module, function_name)
49+
50+
return function
51+
52+
53+
def _load_model_and_parameters(path: str | Path, name: str, parameters: dict[str, Any]) -> Any:
    """Import the model-building function `name` from the file at `path` and call it with `parameters`."""
    builder = _import_function_from_path(path, name)
    return builder(**parameters)
56+
57+
58+
def compose_model(
    path: str | Path,
    name: str = "model_fn",
    parameters: str | Path | None = None,
    source: str | Path | None = None,
):
    """
    Compose model from dump, which will contain model generating function, JSON with its parameters and source code
    for module definitions (ZIP of entire `pydentification`).

    :param path: filesystem Path to the model generating function, which will be imported by `import_function_from_path`
    :param name: name of the function to be imported, default is `model_fn`
    :param parameters: filesystem Path to the JSON file with parameters, if None, empty dictionary will be used
    :param source: filesystem Path to the ZIP file with source code
                   if None imports are attempted from the current working directory.
    """
    if isinstance(source, str):
        source = Path(source)

    # load init parameters from JSON, defaulting to no parameters at all
    if parameters is None:
        init_parameters = {}
    else:
        with open(parameters, "r") as f:
            init_parameters = json.load(f)

    if source is None:
        # import against whatever is already importable (current working directory)
        return _load_model_and_parameters(path, name, init_parameters)

    # temporarily swap imports to the archived source code for the duration of model creation
    with ReplaceSourceCode(source):
        return _load_model_and_parameters(path, name, init_parameters)

0 commit comments

Comments
 (0)