We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent ddbd188 commit 55f3789Copy full SHA for 55f3789
ignite/metrics/gan/fid.py
@@ -31,13 +31,13 @@ def fid_score(
31
except ImportError:
32
raise ModuleNotFoundError("fid_score requires scipy to be installed.")
33
34
- mu1, mu2 = mu1.cpu(), mu2.cpu()
35
- sigma1, sigma2 = sigma1.cpu(), sigma2.cpu()
+ mu1, mu2 = mu1.detach().cpu(), mu2.detach().cpu()
+ sigma1, sigma2 = sigma1.detach().cpu(), sigma2.detach().cpu()
36
37
diff = mu1 - mu2
38
39
# Product might be almost singular
40
- covmean, _ = scipy.linalg.sqrtm(sigma1.mm(sigma2), disp=False)
+ covmean, _ = scipy.linalg.sqrtm(sigma1.mm(sigma2).numpy(), disp=False)
41
# Numerical error might give slight imaginary component
42
if np.iscomplexobj(covmean):
43
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
0 commit comments