diff --git a/deeplens/optics/wave.py b/deeplens/optics/wave.py index ecca83c..d0f40aa 100644 --- a/deeplens/optics/wave.py +++ b/deeplens/optics/wave.py @@ -377,7 +377,11 @@ def AngularSpectrumMethod(u, z, wvln, ps, n=1.0, padding=True): torch.linspace(0.5 / ps, -0.5 / ps, Himg, device=u.device), indexing="xy", ) - square_root = torch.sqrt(1 - wvln_mm**2 * (fx**2 + fy**2)) + # Use complex128 as the squareroot operator to avoid nan values. + operator = 1 - wvln_mm ** 2 * (fx ** 2 + fy ** 2) + operator = operator.to(torch.complex128) + square_root = torch.sqrt(operator) + H = torch.exp(1j * k * z * square_root) H = ifftshift(H)