Skip to content

Commit 6a4f19e

Browse files
committed
[experiment](feat) Move run and sweep entrypoints to single file and simplify logic
1 parent 3edd234 commit 6a4f19e

File tree

3 files changed

+107
-108
lines changed

3 files changed

+107
-108
lines changed
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import logging
2+
from functools import partial
3+
from typing import Any
4+
5+
import wandb
6+
import yaml
7+
8+
from .context import RuntimeContext
9+
from .parameters import left_dict_join, prepare_config_for_sweep
10+
11+
12+
def run_training(
13+
runtime: RuntimeContext,
14+
project_name: str,
15+
dataset_config: dict[str, Any],
16+
training_config: dict[str, Any],
17+
model_config: dict[str, Any],
18+
):
19+
"""
20+
This function is used to run a single training experiment with given configuration. It contains the main
21+
experimentation logic and parameter passing.
22+
"""
23+
for config in (dataset_config, model_config, training_config):
24+
if isinstance(config, dict): # prevents logging parameters twice in sweep mode
25+
wandb.log(config)
26+
27+
try:
28+
# merge static and dynamic parameters of dataset
29+
data_parameters = left_dict_join(training_config, model_config)
30+
# experiment flow
31+
dm = runtime.input_fn(dataset_config, data_parameters)
32+
model, trainer = runtime.model_fn(project_name, training_config, model_config)
33+
model, trainer = runtime.train_fn(model, trainer, dm)
34+
runtime.report_fn(model, trainer, dm)
35+
runtime.save_fn(wandb.run.id, model)
36+
37+
except Exception as e:
38+
logging.exception(e) # log traceback, W&B can sometimes lose information
39+
raise ValueError("Experiment failed.") from e
40+
41+
42+
def run_sweep_step(
43+
runtime: RuntimeContext, project_name: str, dataset_config: dict[str, Any], experiment_config: dict[str, Any]
44+
):
45+
with wandb.init(reinit=True):
46+
parameters = wandb.config # dynamically generated model settings by W&B sweep
47+
48+
run_training(
49+
runtime=runtime,
50+
project_name=project_name,
51+
dataset_config=dataset_config,
52+
model_config=parameters,
53+
training_config=experiment_config["training"],
54+
)
55+
56+
57+
def run(data: str, experiment: str, runtime: RuntimeContext):
58+
"""
59+
Run single experiment with given configuration.
60+
61+
:param data: dataset configuration
62+
:param experiment: experiment configuration
63+
:param runtime: runtime context with code executing the training and all preparations
64+
"""
65+
dataset_config = yaml.safe_load(open(data))
66+
experiment_config = yaml.safe_load(open(experiment))
67+
68+
project = experiment_config["general"]["project"]
69+
name = experiment_config["general"]["name"]
70+
71+
with wandb.init(project=project, name=name):
72+
model_config = experiment_config["model"]
73+
training_config = experiment_config["training"]
74+
75+
run_training(
76+
runtime=runtime,
77+
project_name=project,
78+
dataset_config=dataset_config,
79+
model_config=model_config,
80+
training_config=training_config,
81+
)
82+
83+
84+
def sweep(data: str, experiment: str, runtime: RuntimeContext):
85+
"""
86+
Run a sweep experiment with given configuration.
87+
88+
:param data: dataset configuration
89+
:param experiment: experiment configuration
90+
:param runtime: runtime context with code executing the training and all preparations
91+
"""
92+
dataset_config = yaml.safe_load(open(data))
93+
experiment_config = yaml.safe_load(open(experiment))
94+
95+
sweep_parameters = left_dict_join(experiment_config["sweep_parameters"], experiment_config["model"])
96+
sweep_config = prepare_config_for_sweep(experiment_config["sweep"], sweep_parameters)
97+
sweep_id = wandb.sweep(sweep_config, project=experiment_config["general"]["project"])
98+
99+
run_sweep_fn = partial(
100+
run_sweep_step, runtime, experiment_config["general"]["project"], dataset_config, experiment_config
101+
)
102+
wandb.agent(
103+
sweep_id,
104+
function=run_sweep_fn,
105+
count=experiment_config["general"]["n_runs"],
106+
project=experiment_config["general"]["project"],
107+
)

pydentification/experiment/run.py

Lines changed: 0 additions & 47 deletions
This file was deleted.

pydentification/experiment/sweep.py

Lines changed: 0 additions & 61 deletions
This file was deleted.

0 commit comments

Comments
 (0)