Skip to content

Commit 893fbda

Browse files
authored
Fix training pipeline and cleanup logs (#142)
* Fix OmegaConf resolve
* Fix bug with benchmark logging in train.py
* Fix logging errors in train
* Fix cleanup of logfiles
* Add log cleanup interval option (default 2 minutes) to train.py
* Fix formatting
* Black formatting
1 parent d518111 commit 893fbda

File tree

2 files changed

+51
-32
lines changed

2 files changed

+51
-32
lines changed

gluefactory/train.py

Lines changed: 21 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -213,7 +213,7 @@ def write_image_summaries(writer, name, figures, step):
213213
for k, fig in figs.items():
214214
writer.add_figure(f"{name}/{i}_{k}", fig, step)
215215
else:
216-
for k, fig in figs.items():
216+
for k, fig in figures.items():
217217
writer.add_figure(f"{name}/{k}", fig, step)
218218

219219

@@ -414,16 +414,19 @@ def trace_handler(p):
414414
):
415415
for bname, eval_conf in conf.get("benchmarks", {}).items():
416416
logger.info(f"Running eval on {bname}")
417-
results, figures, _ = run_benchmark(
417+
summaries, figures, _ = run_benchmark(
418418
bname,
419419
eval_conf,
420420
settings.EVAL_PATH / bname / args.experiment / str(epoch),
421421
model.eval(),
422422
)
423-
logger.info(str(results))
424-
write_dict_summaries(writer, f"test/{bname}", results, epoch)
423+
str_summaries = [
424+
f"{k} {v:.3E}" for k, v in summaries.items() if isinstance(v, float)
425+
]
426+
logger.info(f'[{bname}] {{{", ".join(str_summaries)}}}')
427+
write_dict_summaries(writer, f"test/{bname}", summaries, epoch)
425428
write_image_summaries(writer, f"figures/{bname}", figures, epoch)
426-
del results, figures
429+
del summaries, figures
427430

428431
# set the seed
429432
set_seed(conf.train.seed + epoch)
@@ -572,7 +575,7 @@ def trace_handler(p):
572575
loss_fn,
573576
conf.train,
574577
rank,
575-
pbar=(rank == -1),
578+
pbar=(rank == 0),
576579
)
577580

578581
if rank == 0:
@@ -615,7 +618,7 @@ def trace_handler(p):
615618
loss_fn,
616619
conf.train,
617620
rank,
618-
pbar=(rank == -1),
621+
pbar=(rank == 0),
619622
)
620623
best_eval = results[conf.train.best_key]
621624
best_eval = save_experiment(
@@ -659,7 +662,9 @@ def trace_handler(p):
659662

660663
def main_worker(rank, conf, output_dir, args):
661664
if rank == 0:
662-
with capture_outputs(output_dir / "log.txt"):
665+
with capture_outputs(
666+
output_dir / "log.txt", cleanup_interval=args.cleanup_interval
667+
):
663668
training(rank, conf, output_dir, args)
664669
else:
665670
training(rank, conf, output_dir, args)
@@ -682,6 +687,11 @@ def main_worker(rank, conf, output_dir, args):
682687
type=str,
683688
choices=["default", "reduce-overhead", "max-autotune"],
684689
)
690+
parser.add_argument(
691+
"--cleanup_interval",
692+
default=120, # Cleanup log files every 120 seconds.
693+
type=int,
694+
)
685695
parser.add_argument("--overfit", action="store_true")
686696
parser.add_argument("--restore", action="store_true")
687697
parser.add_argument("--distributed", action="store_true")
@@ -700,7 +710,9 @@ def main_worker(rank, conf, output_dir, args):
700710

701711
conf = OmegaConf.from_cli(args.dotlist)
702712
if args.conf:
703-
conf = OmegaConf.merge(OmegaConf.resolve(OmegaConf.load(args.conf)), conf)
713+
yaml_conf = OmegaConf.load(args.conf)
714+
OmegaConf.resolve(yaml_conf)
715+
conf = OmegaConf.merge(yaml_conf, conf)
704716
elif args.restore:
705717
restore_conf = OmegaConf.load(output_dir / "config.yaml")
706718
conf = OmegaConf.merge(restore_conf, conf)

gluefactory/utils/stdout_capturing.py

Lines changed: 30 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111
import subprocess
1212
import sys
1313
from contextlib import contextmanager
14-
from threading import Timer
1514

1615

1716
def apply_backspaces_and_linefeeds(text):
@@ -61,14 +60,36 @@ def flush():
6160
pass # unsupported
6261

6362

63+
def cleanup(filename):
64+
with open(str(filename), "r", newline="") as target:
65+
text = target.read()
66+
text = apply_backspaces_and_linefeeds(text)
67+
with open(str(filename), "w") as target:
68+
target.write(text)
69+
70+
6471
# Duplicate stdout and stderr to a file. Inspired by:
6572
# http://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/
6673
# http://stackoverflow.com/a/651718/1388435
6774
# http://stackoverflow.com/a/22434262/1388435
6875
@contextmanager
69-
def capture_outputs(filename):
76+
def capture_outputs(filename, cleanup_interval=None):
7077
"""Duplicate stdout and stderr to a file on the file descriptor level."""
71-
with open(str(filename), "a+") as target:
78+
79+
if cleanup_interval is not None:
80+
from threading import Timer
81+
82+
class RepeatTimer(Timer):
83+
def run(self):
84+
while not self.finished.wait(self.interval):
85+
self.function(*self.args, **self.kwargs)
86+
87+
timer = RepeatTimer(cleanup_interval, lambda: cleanup(filename))
88+
timer.start()
89+
else:
90+
timer = None
91+
92+
with open(str(filename), mode="a+", newline="") as target:
7293
original_stdout_fd = 1
7394
original_stderr_fd = 2
7495
target_fd = target.fileno()
@@ -109,26 +130,12 @@ def capture_outputs(filename):
109130
os.dup2(saved_stdout_fd, original_stdout_fd)
110131
os.dup2(saved_stderr_fd, original_stderr_fd)
111132

112-
# wait for completion of the tee processes with timeout
113-
# implemented using a timer because timeout support is py3 only
114-
def kill_tees():
115-
tee_stdout.kill()
116-
tee_stderr.kill()
117-
118-
tee_timer = Timer(1, kill_tees)
119-
try:
120-
tee_timer.start()
121-
tee_stdout.wait()
122-
tee_stderr.wait()
123-
finally:
124-
tee_timer.cancel()
125-
133+
tee_stdout.wait(timeout=1)
134+
tee_stderr.wait(timeout=1)
126135
os.close(saved_stdout_fd)
127136
os.close(saved_stderr_fd)
128137

129-
# Cleanup log file
130-
with open(str(filename), "r") as target:
131-
text = target.read()
132-
text = apply_backspaces_and_linefeeds(text)
133-
with open(str(filename), "w") as target:
134-
target.write(text)
138+
if timer is not None:
139+
timer.cancel()
140+
141+
cleanup(filename)

0 commit comments

Comments
 (0)