diff --git a/main.py b/main.py index 5f07db6..87dc550 100644 --- a/main.py +++ b/main.py @@ -42,7 +42,19 @@ def main(args): cudnn.benchmark = True - data_loader, class_mask = build_continual_dataloader(args) + if args.distributed: + if utils.is_main_process(): + # prepare datasets on main process first to avoid race condition on downloading datasets + data_loader, class_mask = build_continual_dataloader(args) + + # wait until the main process complete + torch.distributed.barrier() + + # let other processes prepare datasets + if not utils.is_main_process(): + data_loader, class_mask = build_continual_dataloader(args) + else: + data_loader, class_mask = build_continual_dataloader(args) print(f"Creating original model: {args.model}") original_model = create_model(