1 change: 1 addition & 0 deletions .gitignore
@@ -6,3 +6,4 @@ runs/
catboost_info/
.cache
/build/
credentials.json
61 changes: 61 additions & 0 deletions theseus/base/callbacks/loss_logging_callback.py
@@ -52,6 +52,13 @@ def setup(
else:
self.params["valloader_length"] = None

testloader = pl_module.datamodule.testloader
if testloader is not None:
batch_size = testloader.batch_size
self.params["testloader_length"] = len(testloader)
else:
self.params["testloader_length"] = None

if self.print_interval is None:
self.print_interval = self.auto_get_print_interval(pl_module)
LOGGER.text(
@@ -70,6 +77,8 @@ def auto_get_print_interval(
self.params["trainloader_length"]
if self.params["trainloader_length"] is not None
else self.params["valloader_length"]
if self.params["valloader_length"] is not None
else self.params["testloader_length"]
)
print_interval = max(int(train_fraction * num_iterations_per_epoch), 1)
return print_interval
@@ -249,4 +258,56 @@ def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
for k, v in self.running_loss.items()
]

self.running_loss = {}
LOGGER.log(log_dict)

def on_test_epoch_start(self, *args, **kwargs):
"""
Before the main test loop
"""
return self.on_validation_epoch_start(*args, **kwargs)

def on_test_batch_end(self, *args, **kwargs):
"""
After finishing a test batch
"""
return self.on_validation_batch_end(*args, **kwargs)

def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
"""
After finishing the test loop
"""

iters = trainer.global_step
num_iterations = self.params["num_iterations"]
epoch_time = time.time() - self.running_time

# Log loss
for key in self.running_loss.keys():
self.running_loss[key] = np.round(np.mean(self.running_loss[key]), 5)
loss_string = (
"{}".format(self.running_loss)[1:-1].replace("'", "").replace(",", " ||")
)
LOGGER.text(
"[{}|{}] || {} || Time: {:10.4f} (it/s)".format(
iters,
num_iterations,
loss_string,
self.params["testloader_length"] / epoch_time,
),
level=LoggerObserver.INFO,
)

# Call other loggers
log_dict = [
{
"tag": f"Test/{k} Loss",
"value": v,
"type": LoggerObserver.SCALAR,
"kwargs": {"step": iters},
}
for k, v in self.running_loss.items()
]

LOGGER.log(log_dict)
self.running_loss = {}
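
The new test hooks delegate to the validation handlers for epoch start and batch end, while on_test_end reports throughput as testloader_length / epoch_time iterations per second and formats the averaged losses by stripping the braces and quotes from the dict's repr. A small illustration of that formatting, using made-up loss names and values:

# Illustration only: the loss names and values are made up.
running_loss = {"CE": 0.12345, "Focal": 0.67891}
loss_string = "{}".format(running_loss)[1:-1].replace("'", "").replace(",", " ||")
print(loss_string)  # CE: 0.12345 || Focal: 0.67891
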
59 changes: 49 additions & 10 deletions theseus/base/callbacks/metric_logging_callback.py
@@ -21,12 +21,8 @@ class MetricLoggerCallback(Callback):
def __init__(self, save_json: bool = True, **kwargs) -> None:
super().__init__()
self.save_json = save_json
if self.save_json:
self.save_dir = kwargs.get("save_dir", None)
if self.save_dir is not None:
self.save_dir = osp.join(self.save_dir, "Validation")
os.makedirs(self.save_dir, exist_ok=True)
self.output_dict = []
self.save_dir = kwargs.get("save_dir", None)
self.output_dict = []

def on_validation_end(
self, trainer: pl.Trainer, pl_module: pl.LightningModule
@@ -64,14 +60,57 @@ def on_validation_end(

LOGGER.log(log_dict)

def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
"""
After finishing the test loop
"""
iters = trainer.global_step
metric_dict = pl_module.metric_dict

# Save json
if self.save_json:
item = {}
for metric, score in metric_dict.items():
if isinstance(score, (int, float)):
item[metric] = float(f"{score:.5f}")
if len(item.keys()) > 0:
item["iters"] = iters
self.output_dict.append(item)

# Log metric
metric_string = ""
for metric, score in metric_dict.items():
if isinstance(score, (int, float)):
metric_string += metric + ": " + f"{score:.5f}" + " | "
metric_string += "\n"

LOGGER.text(metric_string, level=LoggerObserver.INFO)

# Call other loggers
log_dict = [
{"tag": f"Test/{k}", "value": v, "kwargs": {"step": iters}}
for k, v in metric_dict.items()
]

LOGGER.log(log_dict)

def teardown(
self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
) -> None:
"""
After everything has finished
"""

if self.save_json:
save_json = osp.join(self.save_dir, "metrics.json")
if len(self.output_dict) > 0:
with open(save_json, "w") as f:
json.dump(self.output_dict, f)
if self.save_dir is not None:
save_dir = osp.join(self.save_dir, stage.capitalize())
os.makedirs(save_dir, exist_ok=True)
save_json = osp.join(save_dir, "metrics.json")
if len(self.output_dict) > 0:
with open(save_json, "w") as f:
json.dump(
self.output_dict,
f,
indent=4,
default=lambda x: "<not serializable>",
)
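
The default= hook passed to json.dump is only invoked for values the JSON encoder cannot handle, so plain ints and floats are written normally while anything else (a tensor, a numpy array) is replaced by the placeholder string instead of raising TypeError. A brief sketch with hypothetical metric values:

import json
import numpy as np

# Hypothetical metric record; only the array is non-serializable.
record = {"bl_acc": 0.91234, "iters": 2580, "confusion_matrix": np.eye(2)}
print(json.dumps(record, indent=4, default=lambda x: "<not serializable>"))
# bl_acc and iters are written as numbers; the numpy array becomes "<not serializable>".
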
26 changes: 26 additions & 0 deletions theseus/base/callbacks/timer_callback.py
@@ -92,3 +92,29 @@ def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
"================================================================",
LoggerObserver.INFO,
)

def on_test_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
"""
Before the main test loop
"""
self.test_epoch_start_time = time.time()
LOGGER.text(
"=============================TEST EVALUATION===================================",
LoggerObserver.INFO,
)

def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
"""
After finishing the test loop
"""

running_time = time.time() - self.test_epoch_start_time
h, m, s = seconds_to_hours(running_time)
LOGGER.text(
f"Test evaluation epoch running time: {h} hours, {m} minutes and {s} seconds",
level=LoggerObserver.INFO,
)
LOGGER.text(
"================================================================",
LoggerObserver.INFO,
)
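
seconds_to_hours is an existing theseus utility that is not shown in this diff; based on how its return value is used above, it is assumed to split a duration into whole hours, minutes and seconds, roughly like this sketch:

# Assumed behaviour only; the real helper lives elsewhere in theseus.
def seconds_to_hours(seconds: float):
    minutes, seconds = divmod(seconds, 60)
    hours, minutes = divmod(minutes, 60)
    return int(hours), int(minutes), int(seconds)

print(seconds_to_hours(3725.4))  # (1, 2, 5)
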
1 change: 1 addition & 0 deletions theseus/base/callbacks/wandb_callback.py
@@ -77,6 +77,7 @@ def __init__(
self.run_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
self.run_name = osp.basename(save_dir)

self.config_dict = OmegaConf.to_container(self.config_dict, resolve=True)
if self.resume is None:
self.id = wandblogger.util.generate_id()
else:
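
The added line converts the OmegaConf config into plain Python dicts and lists before it is handed to wandb, since DictConfig objects are not directly serializable; resolve=True also expands any ${...} interpolations. A small illustration with a hypothetical config:

from omegaconf import OmegaConf

# Hypothetical config with one interpolated value.
cfg = OmegaConf.create({"trainer": {"lr": 0.001}, "logger": {"lr": "${trainer.lr}"}})
plain = OmegaConf.to_container(cfg, resolve=True)
print(type(plain).__name__, plain)  # dict {'trainer': {'lr': 0.001}, 'logger': {'lr': 0.001}}
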
19 changes: 19 additions & 0 deletions theseus/base/models/wrapper.py
@@ -64,6 +64,19 @@ def on_validation_epoch_end(self) -> None:
batch_size=self.datamodule.valloader.batch_size,
)

def on_test_epoch_end(self) -> None:
self.metric_dict = {}
if self.metrics is not None:
for metric in self.metrics:
self.metric_dict.update(metric.value())
metric.reset()

self.log_dict(
self.metric_dict,
prog_bar=True,
batch_size=self.datamodule.testloader.batch_size,
)

def _forward(self, batch: Dict, metrics: List[Any] = None):
"""
Forward the batch through models, losses and metrics
@@ -94,6 +107,12 @@ def validation_step(self, batch, batch_idx):
self.log_dict(outputs["loss_dict"], prog_bar=True, on_step=True, on_epoch=False)
return outputs

def test_step(self, batch, batch_idx):
# this is the test loop
outputs = self._forward(batch, metrics=self.metrics)
self.log_dict(outputs["loss_dict"], prog_bar=True, on_step=True, on_epoch=False)
return outputs

def predict_step(self, batch, batch_idx=None):
pred = self.model.get_prediction(batch)
return pred
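
test_step mirrors validation_step: it forwards the batch through the model, losses and metrics and logs the per-batch loss dict, and on_test_epoch_end then aggregates and resets the metrics into metric_dict. A minimal sketch of how these hooks are driven, assuming the Lightning import name and an already-built wrapper and callback list (the exact theseus factory calls are not part of this diff):

import pytorch_lightning as pl  # or `lightning.pytorch`, depending on the installed version

# `pl_module` is assumed to be the theseus model wrapper with its datamodule attached,
# and `callbacks` the loss/metric/timer callbacks shown above.
trainer = pl.Trainer(callbacks=callbacks, logger=False)
trainer.test(pl_module, datamodule=pl_module.datamodule)
# Per batch: test_step() computes losses/metrics and logs the loss dict.
# At epoch end: on_test_epoch_end() logs metric_dict, then the callbacks'
# on_test_end()/teardown() write the console logs and metrics.json.
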