Skip to content

Commit

Permalink
Add TensorBoard support
Browse files Browse the repository at this point in the history
  • Loading branch information
exx8 committed Mar 14, 2023
1 parent cd3ee78 commit 2bdd443
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
7 changes: 7 additions & 0 deletions timm/utils/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def update_summary(
lr=None,
write_header=False,
log_wandb=False,
tensorboard_writer=False,
):
rowd = OrderedDict(epoch=epoch)
rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
Expand All @@ -43,6 +44,12 @@ def update_summary(
rowd['lr'] = lr
if log_wandb:
wandb.log(rowd)
if tensorboard_writer:
import torch
for k, v in rowd.items():
if isinstance(v, float):
tensorboard_writer.add_scalar(k, v, epoch)

with open(filename, mode='a') as cf:
dw = csv.DictWriter(cf, fieldnames=rowd.keys())
if write_header: # first iteration (epoch == 1 can't be used)
Expand Down
22 changes: 19 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,12 @@
has_functorch = True
except ImportError as e:
has_functorch = False

# Optional TensorBoard support: SummaryWriter is importable only when the
# `tensorboard` package is installed alongside torch.
try:
    from torch.utils.tensorboard import SummaryWriter
    has_tensorboard = True
except ImportError:
    # tensorboard is not installed; TensorBoard metric logging is disabled
    # (main() warns if --log-tensorboard is requested without it).
    has_tensorboard = False

# torch.compile() is available from PyTorch 2.0 onward.
has_compile = hasattr(torch, 'compile')


Expand Down Expand Up @@ -347,8 +352,8 @@
help='use the multi-epochs-loader to save time at the beginning of every epoch')
group.add_argument('--log-wandb', action='store_true', default=False,
help='log training and validation metrics to wandb')


group.add_argument('--log-tensorboard', default='', type=str, metavar='PATH',
help='log training and validation metrics to TensorBoard')
def _parse_args():
# Do we have a config file to parse?
args_config, remaining = config_parser.parse_known_args()
Expand Down Expand Up @@ -726,6 +731,16 @@ def main():
"You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")

if utils.is_primary(args) and args.log_tensorboard:
if has_tensorboard:
writer = SummaryWriter(args.log_tensorboard)
else:
_logger.warning(
"You've requested to log metrics to tensorboard but package not found. "
"Metrics not being logged to tensorboard, try `pip install tensorboard`")



# setup learning rate schedule and starting epoch
updates_per_epoch = len(loader_train)
lr_scheduler, num_epochs = create_scheduler_v2(
Expand Down Expand Up @@ -809,6 +824,7 @@ def main():
lr=sum(lrs) / len(lrs),
write_header=best_metric is None,
log_wandb=args.log_wandb and has_wandb,
tensorboard_writer=writer if writer is not None and has_tensorboard else False,
)

if saver is not None:
Expand Down

0 comments on commit 2bdd443

Please sign in to comment.