We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent b54c37c commit 3e976efCopy full SHA for 3e976ef
k_diffusion/utils.py
@@ -40,7 +40,10 @@ def append_dims(x, target_dims):
40
dims_to_append = target_dims - x.ndim
41
if dims_to_append < 0:
42
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
43
- return x[(...,) + (None,) * dims_to_append]
+ expanded = x[(...,) + (None,) * dims_to_append]
44
+ # MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
45
+ # https://github.com/pytorch/pytorch/issues/84364
46
+ return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
47
48
49
def n_params(module):
0 commit comments