
Commit 4a4c251

borisfom and KumoLiu authored
Removed CPU randn() from schedulers (#8145)
Fixes performance issues due to an extra CPU/GPU sync: https://nvbugswb.nvidia.com/NvBugs5/SWBug.aspx?bugid=4904446&cmtNo=

---------

Signed-off-by: Boris Fomitchev <[email protected]>
Signed-off-by: YunLiu <[email protected]>
Co-authored-by: YunLiu <[email protected]>
1 parent 5002fd9 commit 4a4c251

3 files changed: +8 -4 lines changed


Diff for: monai/networks/schedulers/ddim.py (+1 -1)

@@ -220,7 +220,7 @@ def step(
         if eta > 0:
             # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
             device: torch.device = torch.device(model_output.device if torch.is_tensor(model_output) else "cpu")
-            noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
+            noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator, device=device)
             variance = self._get_variance(timestep, prev_timestep) ** 0.5 * eta * noise
 
             pred_prev_sample = pred_prev_sample + variance

Diff for: monai/networks/schedulers/ddpm.py (+6 -2)

@@ -241,8 +241,12 @@ def step(
         variance = 0
         if timestep > 0:
             noise = torch.randn(
-                model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
-            ).to(model_output.device)
+                model_output.size(),
+                dtype=model_output.dtype,
+                layout=model_output.layout,
+                generator=generator,
+                device=model_output.device,
+            )
             variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise
 
         pred_prev_sample = pred_prev_sample + variance
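
Note on the pattern used in both scheduler changes: the sketch below illustrates it outside MONAI, under the assumption that the caller's generator lives on the same device as `model_output`. `sample_noise_like` is a hypothetical helper name that does not appear in this commit. Because `torch.randn` now receives `device=` directly, PyTorch expects any generator passed in to be on that same device, so a CPU generator paired with a CUDA tensor would likely need to be recreated on the GPU.

```python
from typing import Optional

import torch


def sample_noise_like(reference: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Hypothetical helper (not part of MONAI): draw standard normal noise matching `reference`."""
    # The tensor is allocated directly on reference.device, so there is no
    # CPU allocation followed by a host-to-device copy via .to(device).
    # Assumption: if a generator is given, it was created on reference.device.
    return torch.randn(
        reference.shape,
        dtype=reference.dtype,
        generator=generator,
        device=reference.device,
    )


# Use the GPU when one is available; the sketch also runs on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_output = torch.zeros(2, 3, 8, 8, device=device)

# The generator is created on the same device as the tensor it will fill.
generator = torch.Generator(device=device).manual_seed(0)
noise = sample_noise_like(model_output, generator)

# Re-seeding the same generator reproduces the same noise.
generator.manual_seed(0)
noise_again = sample_noise_like(model_output, generator)
```

Passing `generator=None` falls back to the default generator of the target device, which matches how the schedulers behave when no generator is supplied.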

Diff for: requirements-dev.txt (+1 -1)

@@ -22,7 +22,7 @@ isort>=5.1
 ruff
 pytype>=2020.6.1; platform_system != "Windows"
 types-setuptools
-mypy>=1.5.0
+mypy>=1.5.0, <1.12.0
 ninja
 torchvision
 psutil
