-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathbest_hyperparameters.py
executable file
·65 lines (57 loc) · 1.93 KB
/
best_hyperparameters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#!/usr/bin/env python
"""Retrieve hyperparameters from a W&B sweep."""
import argparse
import logging
import shlex

import wandb
from yoyodyne import defaults
# Config keys that are always skipped when printing flags; expand as needed.
FLAGS_TO_IGNORE = frozenset(
    {
        "eval_metrics",
        "local_run_dir",
        "n_model_params",
    }
)
def _is_omittable(key: str, value: object) -> bool:
    """Returns True if a config entry need not be passed as a flag.

    Args:
        key: the W&B config key.
        value: the value recorded for that key in the run's config.

    Returns:
        True if the key/value pair should be left out of the printed flags.
    """
    # Explicitly ignored flags.
    if key in FLAGS_TO_IGNORE:
        return True
    # Keys with "/" are for redundant parameters in the scheduler.
    if "/" in key:
        return True
    # Key/value pairs that match the module-level default are redundant.
    key_upper = key.upper()
    if hasattr(defaults, key_upper) and getattr(defaults, key_upper) == value:
        return True
    # Keys ending in "_cls", "_idx", and "vocab_size" are set automatically.
    if key.endswith(("_cls", "_idx", "vocab_size")):
        return True
    # None values are defaults definitionally.
    return value is None


def main(args: argparse.Namespace) -> None:
    """Prints the non-default flags of the best run of a W&B sweep.

    The flags are written to stdout on a single line as shell-quoted
    `--key value` pairs, sorted by key.

    Args:
        args: parsed command-line arguments providing `entity`, `project`,
            and `sweep_id`.

    Raises:
        SystemExit: if the sweep has no runs.
    """
    api = wandb.Api()
    sweep_path = f"{args.entity}/{args.project}/{args.sweep_id}"
    sweep = api.sweep(sweep_path)
    best_run = sweep.best_run()
    # `Sweep.best_run` returns None when the sweep has no runs; fail with a
    # readable message rather than an AttributeError on `.url` below.
    if best_run is None:
        raise SystemExit(f"No runs found for sweep {sweep_path}")
    logging.info("Best run URL: %s", best_run.url)
    # Sorting for stability. A distinct name avoids shadowing the `args`
    # parameter.
    flags = [
        (key, value)
        for key, value in sorted(best_run.config.items())
        if not _is_omittable(key, value)
    ]
    # Quote values so the printed line can be pasted into a shell even if a
    # value contains spaces or shell metacharacters (e.g., a list).
    # NOTE(review): booleans print as e.g. "--flag True"; confirm the target
    # CLI accepts that form.
    print(
        " ".join(
            f"--{key} {shlex.quote(str(value))}" for key, value in flags
        )
    )
if __name__ == "__main__":
    # Configure logging first so the best-run URL logged by `main` is shown.
    logging.basicConfig(format="%(levelname)s: %(message)s", level="INFO")
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument(
        "--entity", required=True, help="The entity scope for the project."
    )
    cli.add_argument(
        "--project",
        required=True,
        help="The project of the sweep.",
    )
    cli.add_argument(
        "--sweep_id",
        required=True,
        help="ID for the sweep.",
    )
    parsed = cli.parse_args()
    main(parsed)