Skip to content

Commit f4f47af

Merge pull request #463 from datamol-io/logging — "Improved logging"

2 parents: 41a1172 + ef5db7f · commit f4f47af

File tree

3 files changed

+6
-8
lines changed

3 files changed

+6
-8
lines changed

graphium/config/_loader.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
 # Lightning
 from lightning import Trainer
-from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
+from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
 from lightning.pytorch.loggers import Logger, WandbLogger
 from loguru import logger

@@ -415,6 +415,11 @@ def load_trainer(
     if "model_checkpoint" in cfg_trainer.keys():
         callbacks.append(ModelCheckpoint(**cfg_trainer["model_checkpoint"]))

+    if "learning_rate_monitor" in cfg_trainer.keys():
+        callbacks.append(LearningRateMonitor(**cfg_trainer["learning_rate_monitor"]))
+    else:
+        callbacks.append(LearningRateMonitor())
+
     # Define the logger parameters
     wandb_cfg = config["constants"].get("wandb")
     if wandb_cfg is not None:

graphium/trainer/predictor.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -618,11 +618,6 @@ def on_validation_epoch_end(self) -> None:
         concatenated_metrics_logs = self.task_epoch_summary.concatenate_metrics_logs(metrics_logs)
         concatenated_metrics_logs["val/mean_time"] = torch.tensor(self.mean_val_time_tracker.mean_value)
         concatenated_metrics_logs["val/mean_tput"] = self.mean_val_tput_tracker.mean_value
-
-        if hasattr(self.optimizers(), "param_groups"):
-            lr = self.optimizers().param_groups[0]["lr"]
-            concatenated_metrics_logs["lr"] = torch.tensor(lr)
-
         concatenated_metrics_logs["n_epochs"] = torch.tensor(self.current_epoch, dtype=torch.float32)
         self.log_dict(concatenated_metrics_logs)

         # Save yaml file with the per-task metrics summaries

graphium/trainer/predictor_summaries.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,8 +248,6 @@ def get_metrics_logs(self) -> Dict[str, Any]:
         metric_logs[self.metric_log_name(self.task_name, "median_target", self.step_name)] = nan_median(
             targets
         )
-        if torch.cuda.is_available():
-            metric_logs[f"gpu_allocated_GB"] = torch.tensor(torch.cuda.memory_allocated() / (2**30))

         # Specify which metrics to use
         metrics_to_use = self.metrics

0 commit comments

Comments
 (0)