diff --git a/models/base_model.py b/models/base_model.py
index e4aff7ee321..65bddbb4a4c 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -114,9 +114,9 @@ def setup(self, opt):
 
         # Wrap networks with DDP after loading
         if dist.is_initialized():
-            # Check if using syncbatch normalization for DDP
-            if self.opt.norm == "syncbatch":
-                raise ValueError(f"For distributed training, opt.norm must be 'syncbatch' or 'inst', but got '{self.opt.norm}'. " "Please set --norm syncbatch for multi-GPU training.")
+            # Plain BatchNorm does not synchronize stats across processes.
+            if self.opt.norm == "batch":
+                raise ValueError("For distributed training, --norm batch is not supported. Please use --norm syncbatch or --norm instance.")
             net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[self.device.index])
 
         # Sync all processes after DDP wrapping
diff --git a/models/networks.py b/models/networks.py
index ec1686f6492..96d53e904ed 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -67,7 +67,7 @@ def lambda_rule(epoch):
     elif opt.lr_policy == "cosine":
         scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
     else:
-        return NotImplementedError("learning rate policy [%s] is not implemented", opt.lr_policy)
+        raise NotImplementedError(f"learning rate policy [{opt.lr_policy}] is not implemented")
     return scheduler
 
 
diff --git a/tests/test_reliability_regressions.py b/tests/test_reliability_regressions.py
new file mode 100644
index 00000000000..538251679fe
--- /dev/null
+++ b/tests/test_reliability_regressions.py
@@ -0,0 +1,87 @@
+from types import SimpleNamespace
+
+import pytest
+import torch
+
+from models import networks
+from models.base_model import BaseModel
+import util.visualizer as visualizer_module
+
+
+def test_get_scheduler_invalid_policy_raises():
+    param = torch.nn.Parameter(torch.tensor(1.0))
+    optimizer = torch.optim.SGD([param], lr=0.01)
+    opt = SimpleNamespace(lr_policy="invalid", epoch_count=1, n_epochs=1, n_epochs_decay=1, lr_decay_iters=1)
+
+    with pytest.raises(NotImplementedError, match=r"learning rate policy \[invalid\] is not implemented"):
+        networks.get_scheduler(optimizer, opt)
+
+
+class DummyModel(BaseModel):
+    @staticmethod
+    def modify_commandline_options(parser, is_train):
+        return parser
+
+    def __init__(self, opt):
+        super().__init__(opt)
+        self.model_names = ["G"]
+        self.netG = torch.nn.Conv2d(3, 3, kernel_size=1)
+
+    def set_input(self, input_data):
+        pass
+
+    def forward(self):
+        pass
+
+    def optimize_parameters(self):
+        pass
+
+
+def test_ddp_rejects_plain_batch_norm(monkeypatch, tmp_path):
+    monkeypatch.setattr("models.base_model.networks.init_net", lambda net, init_type, init_gain: net)
+    monkeypatch.setattr("models.base_model.dist.is_initialized", lambda: True)
+
+    opt = SimpleNamespace(
+        isTrain=True,
+        checkpoints_dir=str(tmp_path),
+        name="ddp_norm_guard",
+        device=torch.device("cpu"),
+        preprocess="resize_and_crop",
+        init_type="normal",
+        init_gain=0.02,
+        continue_train=False,
+        norm="batch",
+        verbose=False,
+    )
+
+    model = DummyModel(opt)
+    with pytest.raises(ValueError, match="--norm batch is not supported"):
+        model.setup(opt)
+
+
+def _make_visualizer_opt(tmp_path, use_wandb):
+    checkpoints_dir = tmp_path / "checkpoints"
+    (checkpoints_dir / "exp").mkdir(parents=True, exist_ok=True)
+    return SimpleNamespace(
+        isTrain=True,
+        no_html=True,
+        display_winsize=256,
+        name="exp",
+        use_wandb=use_wandb,
+        checkpoints_dir=str(checkpoints_dir),
+        wandb_project_name="test",
+    )
+
+
+def test_visualizer_allows_disabled_wandb_when_missing(monkeypatch, tmp_path):
+    monkeypatch.setattr(visualizer_module, "wandb", None)
+    opt = _make_visualizer_opt(tmp_path, use_wandb=False)
+    visualizer = visualizer_module.Visualizer(opt)
+    assert visualizer.use_wandb is False
+
+
+def test_visualizer_raises_when_wandb_requested_but_missing(monkeypatch, tmp_path):
+    monkeypatch.setattr(visualizer_module, "wandb", None)
+    opt = _make_visualizer_opt(tmp_path, use_wandb=True)
+    with pytest.raises(ImportError, match="wandb package cannot be found"):
+        visualizer_module.Visualizer(opt)
diff --git a/util/visualizer.py b/util/visualizer.py
index a3573faec92..559433018b9 100644
--- a/util/visualizer.py
+++ b/util/visualizer.py
@@ -4,10 +4,14 @@
 import time
 from . import util, html
 from pathlib import Path
-import wandb
 import os
 import torch.distributed as dist
 
+try:
+    import wandb
+except ImportError:
+    wandb = None
+
 
 def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
     """Save images to the disk.
@@ -63,6 +67,8 @@ def __init__(self, opt):
 
         # Initialize wandb if enabled
         if self.use_wandb:
+            if wandb is None:
+                raise ImportError('wandb package cannot be found. Install "wandb" or run without --use_wandb.')
             # Only initialize wandb on main process (rank 0)
             if not dist.is_initialized() or dist.get_rank() == 0:
                 self.wandb_project_name = getattr(opt, "wandb_project_name", "CycleGAN-and-pix2pix")