
Commit 946c773

Merge pull request #54 from kaylode/dev
A quick fix
2 parents 3717229 + f4d5400 commit 946c773

File tree: 13 files changed (+411 lines added, -78 lines removed)

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -6,3 +6,4 @@ runs/
 catboost_info/
 .cache
 /build/
+credentials.json

theseus/base/callbacks/loss_logging_callback.py

Lines changed: 61 additions & 0 deletions
@@ -52,6 +52,13 @@ def setup(
         else:
             self.params["valloader_length"] = None

+        testloader = pl_module.datamodule.testloader
+        if testloader is not None:
+            batch_size = testloader.batch_size
+            self.params["testloader_length"] = len(testloader)
+        else:
+            self.params["testloader_length"] = None
+
         if self.print_interval is None:
             self.print_interval = self.auto_get_print_interval(pl_module)
         LOGGER.text(
@@ -70,6 +77,8 @@ def auto_get_print_interval(
             self.params["trainloader_length"]
             if self.params["trainloader_length"] is not None
             else self.params["valloader_length"]
+            if self.params["valloader_length"] is not None
+            else self.params["testloader_length"]
         )
         print_interval = max(int(train_fraction * num_iterations_per_epoch), 1)
         return print_interval
@@ -249,4 +258,56 @@ def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
             for k, v in self.running_loss.items()
         ]

+        self.running_loss = {}
+        LOGGER.log(log_dict)
+
+    def on_test_epoch_start(self, *args, **kwargs):
+        """
+        Before the main test loop
+        """
+        return self.on_validation_epoch_start(*args, **kwargs)
+
+    def on_test_batch_end(self, *args, **kwargs):
+        """
+        After finishing a batch
+        """
+        return self.on_validation_batch_end(*args, **kwargs)
+
+    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
+        """
+        After finishing testing
+        """
+
+        iters = trainer.global_step
+        num_iterations = self.params["num_iterations"]
+        epoch_time = time.time() - self.running_time
+
+        # Log loss
+        for key in self.running_loss.keys():
+            self.running_loss[key] = np.round(np.mean(self.running_loss[key]), 5)
+        loss_string = (
+            "{}".format(self.running_loss)[1:-1].replace("'", "").replace(",", " ||")
+        )
+        LOGGER.text(
+            "[{}|{}] || {} || Time: {:10.4f} (it/s)".format(
+                iters,
+                num_iterations,
+                loss_string,
+                self.params["testloader_length"] / epoch_time,
+            ),
+            level=LoggerObserver.INFO,
+        )
+
+        # Call other loggers
+        log_dict = [
+            {
+                "tag": f"Test/{k} Loss",
+                "value": v,
+                "type": LoggerObserver.SCALAR,
+                "kwargs": {"step": iters},
+            }
+            for k, v in self.running_loss.items()
+        ]
+
         LOGGER.log(log_dict)
+        self.running_loss = {}
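
For context, a minimal standalone sketch (not the repo's code; the dict values below are made up) of the fallback chain that auto_get_print_interval now uses, preferring the train loader length, then val, then test:

    # Chained conditional expression: first non-None loader length wins.
    params = {"trainloader_length": None, "valloader_length": None, "testloader_length": 200}
    num_iterations_per_epoch = (
        params["trainloader_length"]
        if params["trainloader_length"] is not None
        else params["valloader_length"]
        if params["valloader_length"] is not None
        else params["testloader_length"]
    )
    train_fraction = 0.1  # hypothetical fraction; the real value comes from the callback
    print_interval = max(int(train_fraction * num_iterations_per_epoch), 1)
    print(print_interval)  # 20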

theseus/base/callbacks/metric_logging_callback.py

Lines changed: 49 additions & 10 deletions
@@ -21,12 +21,8 @@ class MetricLoggerCallback(Callback):
     def __init__(self, save_json: bool = True, **kwargs) -> None:
         super().__init__()
         self.save_json = save_json
-        if self.save_json:
-            self.save_dir = kwargs.get("save_dir", None)
-            if self.save_dir is not None:
-                self.save_dir = osp.join(self.save_dir, "Validation")
-                os.makedirs(self.save_dir, exist_ok=True)
-        self.output_dict = []
+        self.save_dir = kwargs.get("save_dir", None)
+        self.output_dict = []

     def on_validation_end(
         self, trainer: pl.Trainer, pl_module: pl.LightningModule
@@ -64,14 +60,57 @@ def on_validation_end(

         LOGGER.log(log_dict)

+    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        """
+        After finishing testing
+        """
+        iters = trainer.global_step
+        metric_dict = pl_module.metric_dict
+
+        # Save json
+        if self.save_json:
+            item = {}
+            for metric, score in metric_dict.items():
+                if isinstance(score, (int, float)):
+                    item[metric] = float(f"{score:.5f}")
+            if len(item.keys()) > 0:
+                item["iters"] = iters
+                self.output_dict.append(item)
+
+        # Log metric
+        metric_string = ""
+        for metric, score in metric_dict.items():
+            if isinstance(score, (int, float)):
+                metric_string += metric + ": " + f"{score:.5f}" + " | "
+        metric_string += "\n"

+        LOGGER.text(metric_string, level=LoggerObserver.INFO)
+
+        # Call other loggers
+        log_dict = [
+            {"tag": f"Test/{k}", "value": v, "kwargs": {"step": iters}}
+            for k, v in metric_dict.items()
+        ]
+
+        LOGGER.log(log_dict)
+
     def teardown(
         self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
     ) -> None:
         """
         After finish everything
         """
+
         if self.save_json:
-            save_json = osp.join(self.save_dir, "metrics.json")
-            if len(self.output_dict) > 0:
-                with open(save_json, "w") as f:
-                    json.dump(self.output_dict, f)
+            if self.save_dir is not None:
+                save_dir = osp.join(self.save_dir, stage.capitalize())
+                os.makedirs(save_dir, exist_ok=True)
+                save_json = osp.join(save_dir, "metrics.json")
+                if len(self.output_dict) > 0:
+                    with open(save_json, "w") as f:
+                        json.dump(
+                            self.output_dict,
+                            f,
+                            indent=4,
+                            default=lambda x: "<not serializable>",
+                        )
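
Note that teardown now writes metrics.json into a per-stage folder derived from stage.capitalize() (e.g. save_dir/Test/metrics.json) rather than a hard-coded Validation folder. A quick illustration of the new json.dump arguments, with made-up metrics data: indent=4 pretty-prints the file, and the default hook substitutes a placeholder string for any value json cannot serialize instead of raising TypeError:

    import json

    metrics = [{"accuracy": 0.91234, "iters": 100, "report": object()}]  # hypothetical data
    print(json.dumps(metrics, indent=4, default=lambda x: "<not serializable>"))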

theseus/base/callbacks/timer_callback.py

Lines changed: 26 additions & 0 deletions
@@ -92,3 +92,29 @@ def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
             "================================================================",
             LoggerObserver.INFO,
         )
+
+    def on_test_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
+        """
+        Before the main test loop
+        """
+        self.test_epoch_start_time = time.time()
+        LOGGER.text(
+            "=============================TEST EVALUATION===================================",
+            LoggerObserver.INFO,
+        )
+
+    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
+        """
+        After finishing testing
+        """
+
+        running_time = time.time() - self.test_epoch_start_time
+        h, m, s = seconds_to_hours(running_time)
+        LOGGER.text(
+            f"Test evaluation epoch running time: {h} hours, {m} minutes and {s} seconds",
+            level=LoggerObserver.INFO,
+        )
+        LOGGER.text(
+            "================================================================",
+            LoggerObserver.INFO,
+        )
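
seconds_to_hours is a helper imported elsewhere in theseus; a hypothetical equivalent, just to show what the log line above consumes:

    def seconds_to_hours(seconds: float):
        # Hypothetical reimplementation: split a duration into whole hours, minutes, seconds.
        minutes, secs = divmod(int(seconds), 60)
        hours, minutes = divmod(minutes, 60)
        return hours, minutes, secs

    h, m, s = seconds_to_hours(4321.9)
    print(f"{h} hours, {m} minutes and {s} seconds")  # 1 hours, 12 minutes and 1 seconds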

theseus/base/callbacks/wandb_callback.py

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ def __init__(
         self.run_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
         self.run_name = osp.basename(save_dir)

+        self.config_dict = OmegaConf.to_container(self.config_dict, resolve=True)
         if self.resume is None:
             self.id = wandblogger.util.generate_id()
         else:
theseus/base/models/wrapper.py

Lines changed: 19 additions & 0 deletions
@@ -64,6 +64,19 @@ def on_validation_epoch_end(self) -> None:
             batch_size=self.datamodule.valloader.batch_size,
         )

+    def on_test_epoch_end(self) -> None:
+        self.metric_dict = {}
+        if self.metrics is not None:
+            for metric in self.metrics:
+                self.metric_dict.update(metric.value())
+                metric.reset()
+
+        self.log_dict(
+            self.metric_dict,
+            prog_bar=True,
+            batch_size=self.datamodule.testloader.batch_size,
+        )
+
     def _forward(self, batch: Dict, metrics: List[Any] = None):
         """
         Forward the batch through models, losses and metrics
@@ -94,6 +107,12 @@ def validation_step(self, batch, batch_idx):
         self.log_dict(outputs["loss_dict"], prog_bar=True, on_step=True, on_epoch=False)
         return outputs

+    def test_step(self, batch, batch_idx):
+        # this is the test loop
+        outputs = self._forward(batch, metrics=self.metrics)
+        self.log_dict(outputs["loss_dict"], prog_bar=True, on_step=True, on_epoch=False)
+        return outputs
+
     def predict_step(self, batch, batch_idx=None):
         pred = self.model.get_prediction(batch)
         return pred
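
With test_step and on_test_epoch_end defined on the wrapper, a standard Lightning test run drives them. A self-contained sketch (toy module and data, not theseus code) assuming pytorch_lightning is installed:

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl

    class TinyModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(4, 1)

        def test_step(self, batch, batch_idx):
            # Called once per test batch by trainer.test()
            x, y = batch
            loss = nn.functional.mse_loss(self.layer(x), y)
            self.log("test_loss", loss, prog_bar=True, batch_size=x.shape[0])
            return loss

        def on_test_epoch_end(self):
            # Called after all test batches, analogous to the metric aggregation above
            print("test epoch finished")

    loader = DataLoader(TensorDataset(torch.randn(8, 4), torch.randn(8, 1)), batch_size=4)
    trainer = pl.Trainer(logger=False, enable_checkpointing=False)
    trainer.test(TinyModule(), dataloaders=loader)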
