-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathevaluator.py
More file actions
80 lines (67 loc) · 2.77 KB
/
evaluator.py
File metadata and controls
80 lines (67 loc) · 2.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import logging
from pathlib import Path
import hydra
from lightning import Callback, Trainer
from omegaconf import DictConfig, OmegaConf
from ice_station_zebra.data_loaders import ZebraDataModule
from ice_station_zebra.models import ZebraModel
from ice_station_zebra.utils import get_timestamp
logger = logging.getLogger(__name__)
class ZebraEvaluator:
    """A wrapper for PyTorch evaluation.

    Restores a trained model from a Lightning checkpoint, rebuilds the data
    module and any configured callbacks/loggers, and runs ``Trainer.test``.
    """

    def __init__(self, config: DictConfig, checkpoint_path: Path) -> None:
        """Initialize the Zebra evaluator.

        Args:
            config: Full run configuration. Reads ``model._target_``,
                ``evaluate.callbacks`` (optional), ``loggers`` (optional) and
                ``train.trainer``.
            checkpoint_path: Path to an existing Lightning checkpoint file.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist.
        """
        # Verify the checkpoint path (guard clause: fail fast before any loading)
        if not checkpoint_path.exists():
            msg = f"Checkpoint file {checkpoint_path} does not exist."
            raise FileNotFoundError(msg)
        logger.debug("Loaded checkpoint from %s.", checkpoint_path)

        # Load the model configuration saved alongside the checkpoint directory.
        # If it is missing, fall back to the `_target_` already in `config`.
        config_path = checkpoint_path.parent.parent / "model_config.yaml"
        try:
            ckpt_config = OmegaConf.load(config_path)
            # Fix: previously this message interpolated the config *contents*
            # while claiming to name the source; log the path instead.
            logger.debug("Loaded checkpoint config from %s.", config_path)
            config["model"]["_target_"] = ckpt_config["model"]["_target_"]  # type: ignore[index]
        except (NotADirectoryError, FileNotFoundError):
            # Deliberate best-effort: keep the `_target_` from the given config.
            logger.debug("Could not find model configuration file at %s.", config_path)

        # Load the model from checkpoint
        model_cls: type[ZebraModel] = hydra.utils.get_class(config["model"]["_target_"])
        self.model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint_path)

        # Load inputs into a data module
        self.data_module = ZebraDataModule(config)

        # Instantiate evaluation callbacks (bind each instance instead of
        # re-indexing the list with callbacks[-1] after every append)
        callbacks: list[Callback] = []
        for callback_cfg in config["evaluate"].get("callbacks", {}).values():
            callback = hydra.utils.instantiate(callback_cfg)
            callbacks.append(callback)
            logger.debug("Adding evaluation callback %s.", type(callback).__name__)

        # Construct lightning loggers, forcing evaluation-run metadata
        lightning_loggers = [
            hydra.utils.instantiate(
                dict(**logger_config)
                | {
                    "job_type": "evaluate",
                    "name": f"{self.model.name}-{get_timestamp()}",
                    "project": "leaderboard",
                },
            )
            for logger_config in config.get("loggers", {}).values()
        ]

        # Construct the trainer; keys from `train.trainer` take precedence
        # over the callbacks/logger defaults assembled above.
        self.trainer: Trainer = hydra.utils.instantiate(
            dict(
                {
                    "callbacks": callbacks,
                    "logger": lightning_loggers,
                },
                **config["train"]["trainer"],
            )
        )

    def evaluate(self) -> None:
        """Run the test loop over the evaluation data module."""
        self.trainer.test(
            model=self.model,
            datamodule=self.data_module,
        )