Skip to content

Commit 093b9ef

Browse files
authored
Merge pull request #246 from rom1504/sample_per_sec
Sample per sec
2 parents ae0e6a9 + 8866785 commit 093b9ef

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

train_dalle.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import argparse
22
from pathlib import Path
3+
import time
34

45
import torch
56
import wandb # Quit early if user doesn't have wandb installed.
@@ -381,6 +382,8 @@ def save_model(path):
381382
if data_sampler:
382383
data_sampler.set_epoch(epoch)
383384
for i, (text, images) in enumerate(distr_dl):
385+
if i % 10 == 0 and distr_backend.is_root_worker():
386+
t = time.time()
384387
if args.fp16:
385388
images = images.half()
386389
text, images = map(lambda t: t.cuda(), (text, images))
@@ -433,6 +436,11 @@ def save_model(path):
433436
log['image'] = wandb.Image(image, caption=decoded_text)
434437

435438

439+
if i % 10 == 9 and distr_backend.is_root_worker():
440+
sample_per_sec = BATCH_SIZE * 10 / (time.time() - t)
441+
log["sample_per_sec"] = sample_per_sec
442+
print(epoch, i, f'sample_per_sec - {sample_per_sec}')
443+
436444
if distr_backend.is_root_worker():
437445
wandb.log(log)
438446

0 commit comments

Comments
 (0)