Skip to content

Commit 3e976ef

Browse files
committed
MPS workaround for inf values stemming from pytorch/pytorch#84364
1 parent b54c37c commit 3e976ef

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

k_diffusion/utils.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@ def append_dims(x, target_dims):
4040
dims_to_append = target_dims - x.ndim
4141
if dims_to_append < 0:
4242
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
43-
return x[(...,) + (None,) * dims_to_append]
43+
expanded = x[(...,) + (None,) * dims_to_append]
44+
# MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
45+
# https://github.com/pytorch/pytorch/issues/84364
46+
return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
4447

4548

4649
def n_params(module):

0 commit comments

Comments
 (0)