Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ def setup(self, opt):

# Wrap networks with DDP after loading
if dist.is_initialized():
# Check if using syncbatch normalization for DDP
if self.opt.norm == "syncbatch":
raise ValueError(f"For distributed training, opt.norm must be 'syncbatch' or 'inst', but got '{self.opt.norm}'. " "Please set --norm syncbatch for multi-GPU training.")
# Plain BatchNorm does not synchronize stats across processes.
if self.opt.norm == "batch":
raise ValueError("For distributed training, --norm batch is not supported. Please use --norm syncbatch or --norm instance.")

net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[self.device.index])
# Sync all processes after DDP wrapping
Expand Down
2 changes: 1 addition & 1 deletion models/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def lambda_rule(epoch):
elif opt.lr_policy == "cosine":
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
else:
return NotImplementedError("learning rate policy [%s] is not implemented", opt.lr_policy)
raise NotImplementedError(f"learning rate policy [{opt.lr_policy}] is not implemented")
return scheduler


Expand Down
87 changes: 87 additions & 0 deletions tests/test_reliability_regressions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from types import SimpleNamespace

import pytest
import torch

from models import networks
from models.base_model import BaseModel
import util.visualizer as visualizer_module


def test_get_scheduler_invalid_policy_raises():
param = torch.nn.Parameter(torch.tensor(1.0))
optimizer = torch.optim.SGD([param], lr=0.01)
opt = SimpleNamespace(lr_policy="invalid", epoch_count=1, n_epochs=1, n_epochs_decay=1, lr_decay_iters=1)

with pytest.raises(NotImplementedError, match=r"learning rate policy \[invalid\] is not implemented"):
networks.get_scheduler(optimizer, opt)


class DummyModel(BaseModel):
@staticmethod
def modify_commandline_options(parser, is_train):
return parser

def __init__(self, opt):
super().__init__(opt)
self.model_names = ["G"]
self.netG = torch.nn.Conv2d(3, 3, kernel_size=1)

def set_input(self, input_data):
pass

def forward(self):
pass

def optimize_parameters(self):
pass


def test_ddp_rejects_plain_batch_norm(monkeypatch, tmp_path):
monkeypatch.setattr("models.base_model.networks.init_net", lambda net, init_type, init_gain: net)
monkeypatch.setattr("models.base_model.dist.is_initialized", lambda: True)

opt = SimpleNamespace(
isTrain=True,
checkpoints_dir=str(tmp_path),
name="ddp_norm_guard",
device=torch.device("cpu"),
preprocess="resize_and_crop",
init_type="normal",
init_gain=0.02,
continue_train=False,
norm="batch",
verbose=False,
)

model = DummyModel(opt)
with pytest.raises(ValueError, match="--norm batch is not supported"):
model.setup(opt)


def _make_visualizer_opt(tmp_path, use_wandb):
checkpoints_dir = tmp_path / "checkpoints"
(checkpoints_dir / "exp").mkdir(parents=True, exist_ok=True)
return SimpleNamespace(
isTrain=True,
no_html=True,
display_winsize=256,
name="exp",
use_wandb=use_wandb,
checkpoints_dir=str(checkpoints_dir),
wandb_project_name="test",
)


def test_visualizer_allows_disabled_wandb_when_missing(monkeypatch, tmp_path):
monkeypatch.setattr(visualizer_module, "wandb", None)
opt = _make_visualizer_opt(tmp_path, use_wandb=False)
visualizer = visualizer_module.Visualizer(opt)
assert visualizer.use_wandb is False


def test_visualizer_raises_when_wandb_requested_but_missing(monkeypatch, tmp_path):
monkeypatch.setattr(visualizer_module, "wandb", None)
opt = _make_visualizer_opt(tmp_path, use_wandb=True)
with pytest.raises(ImportError, match="wandb package cannot be found"):
visualizer_module.Visualizer(opt)
8 changes: 7 additions & 1 deletion util/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@
import time
from . import util, html
from pathlib import Path
import wandb
import os
import torch.distributed as dist

try:
import wandb
except ImportError:
wandb = None


def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
"""Save images to the disk.
Expand Down Expand Up @@ -63,6 +67,8 @@ def __init__(self, opt):

# Initialize wandb if enabled
if self.use_wandb:
if wandb is None:
raise ImportError('wandb package cannot be found. Install "wandb" or run without --use_wandb.')
# Only initialize wandb on main process (rank 0)
if not dist.is_initialized() or dist.get_rank() == 0:
self.wandb_project_name = getattr(opt, "wandb_project_name", "CycleGAN-and-pix2pix")
Expand Down