Skip to content

Commit 1d7f83d

Browse files
committed
Fix MPS
1 parent f41b1e3 commit 1d7f83d

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

ignite/metrics/gan/fid.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ def fid_score(
3131
except ImportError:
3232
raise ModuleNotFoundError("fid_score requires scipy to be installed.")
3333

34-
mu1, mu2 = mu1.cpu(), mu2.cpu()
35-
sigma1, sigma2 = sigma1.cpu(), sigma2.cpu()
34+
mu1, mu2 = mu1.detach().cpu(), mu2.detach().cpu()
35+
sigma1, sigma2 = sigma1.detach().cpu(), sigma2.detach().cpu()
3636

3737
diff = mu1 - mu2
3838

3939
# Product might be almost singular
40-
covmean, _ = scipy.linalg.sqrtm(sigma1.mm(sigma2), disp=False)
40+
covmean, _ = scipy.linalg.sqrtm(sigma1.mm(sigma2).numpy(), disp=False)
4141
# Numerical error might give slight imaginary component
4242
if np.iscomplexobj(covmean):
4343
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):

0 commit comments

Comments
 (0)