We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f90952a commit 7f16b2cCopy full SHA for 7f16b2c
k_diffusion/utils.py
@@ -42,7 +42,10 @@ def append_dims(x, target_dims):
42
dims_to_append = target_dims - x.ndim
43
if dims_to_append < 0:
44
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
45
- return x[(...,) + (None,) * dims_to_append]
+ expanded = x[(...,) + (None,) * dims_to_append]
46
+ # MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
47
+ # https://github.com/pytorch/pytorch/issues/84364
48
+ return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
49
50
51
def n_params(module):
0 commit comments