Skip to content

Commit 201026a

Browse files
authored
Merge pull request #178 from gautierronan/master
Allow complex numbers in odeint
2 parents a71cdcf + 52566f9 commit 201026a

File tree

3 files changed

+44
-9
lines changed

3 files changed

+44
-9
lines changed

test/models/test_ode.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,4 +258,40 @@ def forward(self, t, x, u, v, z, args={}):
258258
t_eval, sol2 = odeprob(x0, t_span=torch.linspace(0, 5, 10))
259259

260260
assert (sol1==sol2).all()
261-
grad(sol2.sum(), x0)
261+
grad(sol2.sum(), x0)
262+
263+
264+
def test_complex_ode():
265+
"""Test odeint for complex numbers with a simple complex-valued ODE, corresponding
266+
to Rabi oscillations of quantum two-level system."""
267+
class Rabi(nn.Module):
268+
def __init__(self, omega):
269+
super().__init__()
270+
self.sx = torch.tensor([[0, 1], [1, 0]], dtype=torch.complex128)
271+
self.omega = omega
272+
return
273+
def forward(self, t, x):
274+
dx = -1.0j * self.omega * self.sx @ x
275+
dx += dx.adjoint()
276+
return dx
277+
278+
# Odeint parameters
279+
omega = torch.randn(1)
280+
rabi = Rabi(omega)
281+
tspan = torch.linspace(0., 2., 10)
282+
283+
# Random initial state
284+
x0 = torch.rand(2, 2, dtype=torch.complex128)
285+
x0 = 0.5 * (x0 + x0.adjoint()) / torch.real(x0.trace())
286+
# Solve the ODE problem
287+
t_eval, sol = odeint(f=rabi, x=x0, t_span=tspan, solver="dopri5", atol=1e-8, rtol=1e-6)
288+
289+
# Expected solution
290+
sx = torch.tensor([[0, 1], [1, 0]], dtype=torch.complex128)
291+
si = torch.tensor([[1, 0], [0, 1]], dtype=torch.complex128)
292+
U_t = torch.cos(omega * t_eval)[:, None, None] * si
293+
U_t += -1j * torch.sin(omega * t_eval)[:, None, None] * sx
294+
sol_exp = U_t @ x0 @ U_t.adjoint()
295+
296+
# Check result
297+
assert torch.allclose(sol, sol_exp, rtol=1e-5, atol=1e-5)

torchdyn/numerics/odeint.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,6 @@ def odeint(f:Callable, x:Tensor, t_span:Union[List, Tensor], solver:Union[str, n
6565
x, t_span = solver.sync_device_dtype(x, t_span)
6666
stepping_class = solver.stepping_class
6767

68-
# instantiate save_at tensor
69-
if len(save_at) == 0: save_at = t_span
70-
if not isinstance(save_at, torch.Tensor):
71-
save_at = torch.tensor(save_at)
72-
7368
# instantiate the interpolator similar to the solver steps above
7469
if isinstance(solver, Tsitouras45):
7570
if verbose: warn("Running interpolation not yet implemented for `tsit5`")
@@ -87,6 +82,10 @@ def odeint(f:Callable, x:Tensor, t_span:Union[List, Tensor], solver:Union[str, n
8782
if stepping_class == 'fixed':
8883
if atol != odeint.__defaults__[0] or rtol != odeint.__defaults__[1]:
8984
warn("Setting tolerances has no effect on fixed-step methods")
85+
# instantiate save_at tensor
86+
if len(save_at) == 0: save_at = t_span
87+
if not isinstance(save_at, torch.Tensor):
88+
save_at = torch.tensor(save_at)
9089
return _fixed_odeint(f_, x, t_span, solver, save_at=save_at, args=args)
9190
elif stepping_class == 'adaptive':
9291
t = t_span[0]

torchdyn/numerics/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,23 +31,23 @@ def norm_(aug_state):
3131

3232

3333
def hairer_norm(tensor):
34-
return tensor.pow(2).mean().sqrt()
34+
return tensor.abs().pow(2).mean().sqrt()
3535

3636

3737
def init_step(f, f0, x0, t0, order, atol, rtol):
3838
scale = atol + torch.abs(x0) * rtol
3939
d0, d1 = hairer_norm(x0 / scale), hairer_norm(f0 / scale)
4040

4141
if d0 < 1e-5 or d1 < 1e-5:
42-
h0 = torch.tensor(1e-6, dtype=x0.dtype, device=x0.device)
42+
h0 = torch.tensor(1e-6, dtype=t0.dtype, device=t0.device)
4343
else:
4444
h0 = 0.01 * d0 / d1
4545

4646
x_new = x0 + h0 * f0
4747
f_new = f(t0 + h0, x_new)
4848
d2 = hairer_norm((f_new - f0) / scale) / h0
4949
if d1 <= 1e-15 and d2 <= 1e-15:
50-
h1 = torch.max(torch.tensor(1e-6, dtype=x0.dtype, device=x0.device), h0 * 1e-3)
50+
h1 = torch.max(torch.tensor(1e-6, dtype=t0.dtype, device=t0.device), h0 * 1e-3)
5151
else:
5252
h1 = (0.01 / max(d1, d2)) ** (1. / float(order + 1))
5353
dt = torch.min(100 * h0, h1).to(t0)

0 commit comments

Comments
 (0)