diff --git a/models/base_model.py b/models/base_model.py index e4aff7ee321..5498db8fb0e 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -114,9 +114,9 @@ def setup(self, opt): # Wrap networks with DDP after loading if dist.is_initialized(): - # Check if using syncbatch normalization for DDP - if self.opt.norm == "syncbatch": - raise ValueError(f"For distributed training, opt.norm must be 'syncbatch' or 'inst', but got '{self.opt.norm}'. " "Please set --norm syncbatch for multi-GPU training.") + # Check if using syncbatch or instance normalization for DDP + if self.opt.norm != "syncbatch" and self.opt.norm != "instance": + raise ValueError(f"For distributed training, opt.norm must be 'syncbatch' or 'instance', but got '{self.opt.norm}'. " "Please set --norm syncbatch or --norm instance for multi-GPU training.") net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[self.device.index]) # Sync all processes after DDP wrapping