Commit bfe28f8

Merge pull request #39 from cyber-physical-systems-group/experiment/configs
Experiment/configs
2 parents 64777c3 + bfab7bd commit bfe28f8

File tree: 11 files changed, +386 −20 lines
Lines changed: 106 additions & 19 deletions

# Experiment

This directory contains experiment utils, including entrypoints, which can be used to run W&B experiments. They are not an integral part of the library, so they need additional code defining the experiment settings in order to run.

## How to Use

To use the entrypoints and utils provided here, `RuntimeContext` needs to be implemented, which is used to parametrize the experiment. The interface is given by a stateless class, so it can be defined as a single namespace, which is then passed to the entrypoint function.
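As a sketch, a hypothetical `RuntimeContext` implementation might look like the following. Lightning objects are replaced with plain placeholders so only the shape of the interface is visible; all names here are illustrative and not part of the library:

```python
from typing import Any


class MyContext:
    """Hypothetical stateless namespace implementing the RuntimeContext interface."""

    @staticmethod
    def input_fn(config: dict[str, Any], parameters: dict[str, Any]):
        # real code would construct and return a pl.LightningDataModule here
        return {"path": config["path"], "batch_size": parameters["batch_size"]}

    @staticmethod
    def model_fn(name: str, config: dict[str, Any], parameters: dict[str, Any]):
        # real code would return (pl.LightningModule, pl.Trainer)
        return {"model_name": parameters["model_name"]}, {"project": name}

    @staticmethod
    def train_fn(model, trainer, dm):
        # real code would call trainer.fit(model, datamodule=dm)
        return model, trainer

    @staticmethod
    def report_fn(model, trainer, dm):
        pass  # real code would run predictions and log metrics to W&B

    @staticmethod
    def save_fn(name: str, model):
        pass  # real code would save the model to disk and upload it to W&B
```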

To run an experiment (assuming no sweep for now), the following code needs to be implemented. This additionally allows using the `click` library to pass the config files as options.

```python
import click

from pydentification.experiment.entrypoints import run

from src import runtime  # assume this is project-specific code


@click.command()
@click.option("--data", type=click.Path(exists=True), required=True)
@click.option("--experiment", type=click.Path(exists=True), required=True)
def main(data, experiment):
    run(data=data, experiment=experiment, runtime=runtime)


if __name__ == "__main__":
    main()
```

To run, simply execute the script with flags passing the configs:

```bash
python main.py --data data.yaml --experiment experiment.yaml
```
## Parametrization

Each training run is parametrized by five functions, which take two configurations. The functions are used to define the input, the model architecture, the training logic, the reporting to W&B and the storage logic. For the last three, we provide useful defaults in the `defaults` package. The functions need to be implemented as part of a single namespace, called `context`, which is passed to the entrypoint for running an experiment or sweep. The interface for `context` is given by `RuntimeContext`.

Additionally, two config files can be used, one for the data settings and one for the model parameters. They are used to abstract dataset loading and model construction away from the code, in order to quickly iterate through different configurations. Otherwise, the `pydentification` library can be used as a collection of standalone components, which can be useful for various projects related to neural system identification.
### Functions

* `input_fn` - takes the configuration files and returns a `pl.LightningDataModule` subclass, typically one of the data-modules provided by `pydentification`.
* `model_fn` - takes the configuration files and returns `pl.LightningModule` and `pl.Trainer` instances.
* `train_fn` - takes the model, trainer and data-module and executes the training code, typically inside the `Trainer`.
* `report_fn` - takes the model, trainer and data-module; it should run predictions and store relevant metrics in the W&B dashboard.
* `save_fn` - takes the run name (given by W&B `id` or `name`) and the model, and saves the model to disk.
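Conceptually, the entrypoint wires these five functions together in the order listed above. The following is a simplified sketch of that call order only (not the actual entrypoint code, which additionally handles W&B logging and error reporting):

```python
def run_once(context, dataset_config, training_config, model_config, run_name):
    # merge static training settings with dynamic model parameters for the data-module
    dm = context.input_fn(dataset_config, {**model_config, **training_config})
    model, trainer = context.model_fn(run_name, training_config, model_config)
    model, trainer = context.train_fn(model, trainer, dm)
    context.report_fn(model, trainer, dm)
    context.save_fn(run_name, model)
    return model
```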
### Configurations

The entrypoints (both training and sweep) are parametrized by two configs: one for the data settings and the other for the experiment and model, which contains hyperparameters and training settings.

The data config is stored in `YAML` and passed to the `input_fn` function. Not all parameters of the data-module are stored in the data config, only the static values. An example config looks as follows:

```yaml
name: Dataset
path: data/dataset.csv
test_size: 10000
input_columns: [x]
output_columns: [y]
```
The experiment config looks as follows:

```yaml
general:
  project: project-name
  n_runs: 1
  name: placeholder
training:
  n_epochs: 10
  patience: 1
  timeout: "00:00:01:00"
  batch_size: 32
  shift: 1
  validation_size: 0.1
model:
  model_name: MLP
  # generic parameter convention
  n_input_time_steps: 64
  n_output_time_steps: 1
  n_input_state_variables: 1
  n_output_state_variables: 1
  # neural network parameters
  n_hidden_layers: 2
  activation: relu
  n_hidden_time_steps: 32
  n_hidden_state_variables: 4
```
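After loading with `yaml.safe_load`, the experiment config is a plain nested dictionary. For illustration, here is the same kind of structure written as a Python literal, together with the lookups the entrypoints perform on it (a sketch mirroring the YAML above, not library code):

```python
# mirrors the experiment YAML above, as yaml.safe_load would return it
experiment_config = {
    "general": {"project": "project-name", "n_runs": 1, "name": "placeholder"},
    "training": {"n_epochs": 10, "patience": 1, "batch_size": 32, "validation_size": 0.1},
    "model": {"model_name": "MLP", "n_hidden_layers": 2, "activation": "relu"},
}

# the entrypoints read project metadata and pass the sub-configs on
project = experiment_config["general"]["project"]
model_config = experiment_config["model"]
training_config = experiment_config["training"]
```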
To use a sweep, add the following sections to the experiment config:

```yaml
sweep:
  name: sweep
  method: random
  metric: {name: test/root_mean_squared_error, goal: minimize}
sweep_parameters:
  # neural network
  n_hidden_layers: [1, 2, 3, 4, 5]
  n_hidden_time_steps: [32, 16, 8]
  n_hidden_state_variables: [1, 4, 8, 16]
  activation: [leaky_relu, relu, gelu, sigmoid, tanh]
```
Lines changed: 68 additions & 0 deletions

```python
from abc import ABC, abstractmethod
from typing import Any

import lightning.pytorch as pl


class RuntimeContext(ABC):
    """
    This interface defines the runtime context needed for experiment execution by the provided entrypoints.
    It can be used to define a custom experiment execution flow.

    The interface can be implemented as a module or a namespace.
    """

    @staticmethod
    @abstractmethod
    def input_fn(config: dict[str, Any], parameters: dict[str, Any]) -> pl.LightningDataModule:
        """
        :param config: static dataset configuration
        :param parameters: dynamic experiment configuration, for example delay-line length for dynamical systems

        :return: LightningDataModule instance, which is used to load and prepare data for training
        """
        ...

    @staticmethod
    @abstractmethod
    def model_fn(
        name: str, config: dict[str, Any], parameters: dict[str, Any]
    ) -> tuple[pl.LightningModule, pl.Trainer]:
        """
        :param name: name of the W&B project, will be used for logging with callbacks
        :param config: static configuration, for example timeout, validation-size etc.
        :param parameters: dynamic experiment configuration, for example model settings or batch-size

        :return: LightningModule and Trainer instances ready to be used in train_fn
        """
        ...

    @staticmethod
    @abstractmethod
    def train_fn(
        model: pl.LightningModule, trainer: pl.Trainer, dm: pl.LightningDataModule
    ) -> tuple[pl.LightningModule, pl.Trainer]:
        """
        :param model: LightningModule instance, returned from model_fn
        :param trainer: Trainer instance, returned from model_fn
        :param dm: LightningDataModule instance, returned from input_fn

        :return: trained model and trainer
        """
        ...

    @staticmethod
    @abstractmethod
    def report_fn(model: pl.LightningModule, trainer: pl.Trainer, dm: pl.LightningDataModule):
        """
        :param model: LightningModule instance, returned from train_fn (needs to be trained)
        :param trainer: Trainer instance, returned from train_fn, can be used for easier prediction
        :param dm: LightningDataModule instance, returned from input_fn, used for prediction on test data
        """
        ...

    @staticmethod
    @abstractmethod
    def save_fn(name: str, model: pl.LightningModule):
        """
        :param name: name of the run, returned from wandb.run.id or config
        :param model: trained LightningModule instance to be saved, returned from train_fn
        """
        ...
```

pydentification/experiment/defaults/__init__.py

Whitespace-only changes.
Lines changed: 18 additions & 0 deletions

```python
import lightning.pytorch as pl
import torch

from pydentification.experiment.reporters import report_metrics, report_prediction_plot, report_trainable_parameters
from pydentification.metrics import regression_metrics


def report_fn(model: pl.LightningModule, trainer: pl.Trainer, dm: pl.LightningDataModule) -> None:
    """Logs the experiment results to W&B"""
    y_hat = trainer.predict(model, datamodule=dm)
    y_pred = torch.cat(y_hat).numpy()
    y_true = torch.cat([y for _, y in dm.test_dataloader()]).numpy()

    metrics = regression_metrics(y_pred=y_pred.flatten(), y_true=y_true.flatten())  # type: ignore

    report_metrics(metrics, prefix="test")  # type: ignore
    report_trainable_parameters(model, prefix="config")
    report_prediction_plot(predictions=y_pred, targets=y_true, prefix="test")
```
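`regression_metrics` comes from `pydentification.metrics` and its exact output is not shown in this diff, but the metric named in the sweep config, `root_mean_squared_error`, can be illustrated with a pure-Python sketch over flattened sequences (the helper name here is a hypothetical stand-in, not the library function):

```python
import math


def root_mean_squared_error(y_pred, y_true):
    # hypothetical stand-in for one of the metrics computed by regression_metrics
    assert len(y_pred) == len(y_true), "predictions and targets must have equal length"
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(y_pred, y_true)) / len(y_true))
```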
Lines changed: 11 additions & 0 deletions

```python
import os

import torch
import wandb


def save_fn(name: str, model: torch.nn.Module):
    path = f"models/{name}/trained-model.pt"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(model, path)
    wandb.save(path)
```
Lines changed: 12 additions & 0 deletions

```python
import lightning.pytorch as pl


def train_fn(
    model: pl.LightningModule, trainer: pl.Trainer, dm: pl.LightningDataModule
) -> tuple[pl.LightningModule, pl.Trainer]:
    """
    Runs training using pl.Trainer and pl.LightningModule
    with the given LightningDataModule; returns both model and trainer
    """
    trainer.fit(model, datamodule=dm)
    return model, trainer
```
Lines changed: 107 additions & 0 deletions

```python
import logging
from functools import partial
from typing import Any

import wandb
import yaml

from .context import RuntimeContext
from .parameters import left_dict_join, prepare_config_for_sweep


def run_training(
    runtime: RuntimeContext,
    project_name: str,
    dataset_config: dict[str, Any],
    training_config: dict[str, Any],
    model_config: dict[str, Any],
):
    """
    This function is used to run a single training experiment with the given configuration. It contains the main
    experimentation logic and parameter passing.
    """
    for config in (dataset_config, model_config, training_config):
        if isinstance(config, dict):  # prevents logging parameters twice in sweep mode
            wandb.log(config)

    try:
        # merge static and dynamic parameters of the dataset
        data_parameters = left_dict_join(training_config, model_config)
        # experiment flow
        dm = runtime.input_fn(dataset_config, data_parameters)
        model, trainer = runtime.model_fn(project_name, training_config, model_config)
        model, trainer = runtime.train_fn(model, trainer, dm)
        runtime.report_fn(model, trainer, dm)
        runtime.save_fn(wandb.run.id, model)

    except Exception as e:
        logging.exception(e)  # log the traceback, since W&B can sometimes lose information
        raise ValueError("Experiment failed.") from e


def run_sweep_step(
    runtime: RuntimeContext, project_name: str, dataset_config: dict[str, Any], experiment_config: dict[str, Any]
):
    with wandb.init(reinit=True):
        parameters = wandb.config  # model settings dynamically generated by the W&B sweep

        run_training(
            runtime=runtime,
            project_name=project_name,
            dataset_config=dataset_config,
            model_config=parameters,
            training_config=experiment_config["training"],
        )


def run(data: str, experiment: str, runtime: RuntimeContext):
    """
    Run a single experiment with the given configuration.

    :param data: path to the dataset configuration
    :param experiment: path to the experiment configuration
    :param runtime: runtime context with the code executing the training and all preparations
    """
    with open(data) as f:
        dataset_config = yaml.safe_load(f)
    with open(experiment) as f:
        experiment_config = yaml.safe_load(f)

    project = experiment_config["general"]["project"]
    name = experiment_config["general"]["name"]

    with wandb.init(project=project, name=name):
        model_config = experiment_config["model"]
        training_config = experiment_config["training"]

        run_training(
            runtime=runtime,
            project_name=project,
            dataset_config=dataset_config,
            model_config=model_config,
            training_config=training_config,
        )


def sweep(data: str, experiment: str, runtime: RuntimeContext):
    """
    Run a sweep experiment with the given configuration.

    :param data: path to the dataset configuration
    :param experiment: path to the experiment configuration
    :param runtime: runtime context with the code executing the training and all preparations
    """
    with open(data) as f:
        dataset_config = yaml.safe_load(f)
    with open(experiment) as f:
        experiment_config = yaml.safe_load(f)

    sweep_parameters = left_dict_join(experiment_config["sweep_parameters"], experiment_config["model"])
    sweep_config = prepare_config_for_sweep(experiment_config["sweep"], sweep_parameters)
    sweep_id = wandb.sweep(sweep_config, project=experiment_config["general"]["project"])

    run_sweep_fn = partial(
        run_sweep_step, runtime, experiment_config["general"]["project"], dataset_config, experiment_config
    )
    wandb.agent(
        sweep_id,
        function=run_sweep_fn,
        count=experiment_config["general"]["n_runs"],
        project=experiment_config["general"]["project"],
    )
```
Lines changed: 25 additions & 0 deletions

```python
from itertools import chain
from typing import Any


def left_dict_join(main: dict, other: dict) -> dict:
    """Merges two dictionaries into a single one, where values from main take precedence on duplicate keys"""
    return dict(chain(other.items(), main.items()))


def prepare_config_for_sweep(config: dict[str, Any], parameters: dict[str, Any]) -> dict[str, Any]:
    """
    Prepares the W&B config for running a sweep, based on two distinct configs

    :param config: general sweep config with values such as name, method or metric,
                   for more details see: https://docs.wandb.ai/guides/sweeps/define-sweep-configuration
    :param parameters: parameters to sweep over, each given as a list

    :return: configuration dictionary ready for the sweep
    """
    parameters = {
        key: {"values": values if isinstance(values, list) else [values]} for key, values in parameters.items()
    }
    config.update({"parameters": parameters})

    return config
```
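The behaviour of these two helpers can be illustrated on small inputs. The one-liners are restated here so the snippet is self-contained:

```python
from itertools import chain
from typing import Any


def left_dict_join(main: dict, other: dict) -> dict:
    # values from main win on duplicate keys, since later items override earlier ones in dict()
    return dict(chain(other.items(), main.items()))


def prepare_config_for_sweep(config: dict[str, Any], parameters: dict[str, Any]) -> dict[str, Any]:
    # scalars are wrapped into single-element lists, so W&B treats them as fixed values
    parameters = {
        key: {"values": values if isinstance(values, list) else [values]} for key, values in parameters.items()
    }
    config.update({"parameters": parameters})
    return config


merged = left_dict_join({"a": 1}, {"a": 2, "b": 3})
sweep_config = prepare_config_for_sweep({"method": "random"}, {"n_hidden_layers": [1, 2], "activation": "relu"})
```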
