Description
TEP - Krylov Subspace Time Evolution
Author
@refraction-ray
Status
Draft
Created
2025-03-27
Abstract
This TEP proposes adding a Krylov subspace method for the time evolution of quantum states under a given Hamiltonian. The method approximates the exact evolution without forming the full matrix exponential. The new function, krylov_evol, will be designed to integrate seamlessly with TensorCircuit's machine learning backends, supporting automatic differentiation (AD), vectorized mapping (vmap), just-in-time (JIT) compilation, and GPU acceleration. Its interface will mirror the existing hamiltonian_evol function for ease of use.
Motivation and Scope
Simulating the time evolution of quantum systems is a fundamental task in quantum physics and quantum computing. TensorCircuit currently provides hamiltonian_evol which relies on computing the full matrix exponential (expm) of the Hamiltonian. While exact, this approach scales poorly with the system size N (Hilbert space dimension d=2^N), typically requiring O(d³) computational complexity and O(d²) memory, making it infeasible for systems beyond a small number of qubits.
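To make the O(d²) memory figure concrete, the following back-of-the-envelope sketch counts the bytes needed to hold a single dense d x d propagator. It assumes complex128 entries (16 bytes each); the helper name dense_propagator_bytes is hypothetical and only serves this illustration.

```python
def dense_propagator_bytes(n_qubits: int, bytes_per_entry: int = 16) -> int:
    """Bytes needed to store one dense d x d matrix with d = 2**n_qubits."""
    d = 2 ** n_qubits
    return d * d * bytes_per_entry


# Print the storage cost for a few system sizes (complex128 assumed).
for n in (10, 14, 20):
    gib = dense_propagator_bytes(n) / 2**30
    print(f"N={n:2d}  d={2**n:>8d}  dense matrix: {gib:.3g} GiB")
```

Under these assumptions, N = 14 already requires 4 GiB for one matrix, before counting any workspace used by expm itself.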
Krylov subspace methods offer a powerful alternative. They approximate the action of the matrix exponential on a vector, expm(A)v, by projecting the problem onto a low-dimensional Krylov subspace K_m(A, v) = span{v, Av, A²v, ..., A^(m-1)v}, where m << d. The core idea is that for many relevant initial states and Hamiltonians, the evolved state psi(t) lies close to this subspace. The evolution is then approximated by exponentiating a much smaller m x m matrix obtained from the projection of H onto the Krylov basis. Common algorithms for constructing this basis include Lanczos (for Hermitian matrices) and Arnoldi (for general matrices).
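The projection idea can be checked numerically with a small NumPy sketch. It is only an illustration under stated assumptions: the Hamiltonian is a random Hermitian matrix scaled to unit spectral norm, the basis is obtained by a QR factorization of the explicit Krylov matrix [v, Av, ..., A^(m-1)v] (a literal reading of the span definition, not the Lanczos iteration proposed here), and the names exact_evolve and krylov_evolve are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
d, t = 64, 1.0

# Random Hermitian stand-in for a Hamiltonian, scaled to unit spectral norm.
a = rng.normal(size=(d, d)) + 1j * rng.normal(size=(d, d))
h = (a + a.conj().T) / 2
h /= np.linalg.norm(h, 2)

psi0 = rng.normal(size=d) + 1j * rng.normal(size=d)
psi0 /= np.linalg.norm(psi0)


def exact_evolve(h, psi, t):
    """exp(-i h t) @ psi via a full eigendecomposition of Hermitian h."""
    w, u = np.linalg.eigh(h)
    return u @ (np.exp(-1j * w * t) * (u.conj().T @ psi))


def krylov_evolve(h, psi, t, m):
    """Evolve inside K_m(h, psi) using an explicitly built Krylov matrix."""
    cols = [psi]
    for _ in range(m - 1):
        cols.append(h @ cols[-1])
    q, _ = np.linalg.qr(np.stack(cols, axis=1))  # orthonormal basis of K_m
    h_small = q.conj().T @ h @ q                 # m x m projection of h
    return q @ exact_evolve(h_small, q.conj().T @ psi, t)


reference = exact_evolve(h, psi0, t)
for m in (2, 4, 6, 10):
    err = np.linalg.norm(krylov_evolve(h, psi0, t, m) - reference)
    print(f"m={m:2d}  error={err:.2e}")
```

Building the Krylov matrix explicitly becomes ill-conditioned as m grows, which is one reason an iteration such as Lanczos or Arnoldi is preferred in practice.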
Scope:
- Implement a function krylov_evol with an interface similar to hamiltonian_evol.
- The core algorithm will likely be based on the Lanczos iteration to build the Krylov basis.
- The implementation must use TensorCircuit's backend primitives (e.g., TensorFlow, JAX, PyTorch via tensornetwork) to ensure compatibility with AD, vmap, JIT, and GPU execution.
- Numerical stability (e.g., using re-orthogonalization) should be considered during implementation.
Usage and Impact
krylov_results = tc.experimental.krylov_evol(
h=h,
psi0=psi0,
tlist=np.arange(0, 10, 0.1),
m=30,
callback=compute_expectation
)
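The callback in the call above is left undefined. As a purely hypothetical illustration, it could be a function that maps a state vector to an observable. The sketch below assumes the callback receives the state as a flat array of length 2^N and that qubit 0 is the most significant index; neither assumption is fixed by this TEP.

```python
import numpy as np


def compute_expectation(psi):
    """Hypothetical callback: <Z> on qubit 0, taken as the leading tensor index."""
    psi = np.asarray(psi)
    probs = np.abs(psi) ** 2
    half = psi.shape[-1] // 2
    # Z is +1 on the first half of basis states (qubit 0 = 0) and -1 on the second.
    return probs[..., :half].sum(-1) - probs[..., half:].sum(-1)
```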
Impact: This function is a first step toward a broader set of numerical solvers for Rydberg analog simulation.
Backward compatibility
No issue.
Related Work
- qutip
- quimb
Implementation
- Core Krylov Basis Generation:
- Implement the Lanczos (for Hermitian H) or Arnoldi (general case, potentially useful for non-Hermitian effective Hamiltonians or Liouvillians later) iteration using backend-compatible operations.
- Utilize constructs like jax.lax.scan, tf.scan, or equivalent PyTorch mechanisms for the iterative process to ensure JIT compatibility.
- Input: Hamiltonian H, initial vector psi0, dimension m.
- Output: Orthonormal basis V_m (matrix [v_0, v_1, ..., v_{m-1}]) and the projected matrix T_m (tridiagonal for Lanczos, upper Hessenberg for Arnoldi).
- Address numerical stability: Implement Gram-Schmidt orthogonalization carefully.
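The basis-generation step above can be sketched in plain NumPy as follows. This is a reference sketch rather than the proposed implementation: it uses a Python loop instead of a backend scan, assumes a dense Hermitian h, applies full re-orthogonalization (two classical Gram-Schmidt passes) as one possible stability measure, and does not handle breakdown when a residual norm vanishes. The function name lanczos is hypothetical.

```python
import numpy as np


def lanczos(h, psi0, m):
    """Return (V, T) with V of shape (d, m), orthonormal columns, and
    T = V^dagger h V real symmetric tridiagonal of shape (m, m)."""
    d = psi0.shape[0]
    v = np.zeros((d, m), dtype=complex)
    alpha = np.zeros(m)       # diagonal of T
    beta = np.zeros(m - 1)    # off-diagonal of T
    v[:, 0] = psi0 / np.linalg.norm(psi0)
    for j in range(m):
        w = h @ v[:, j]
        alpha[j] = np.vdot(v[:, j], w).real
        # Full re-orthogonalization against all basis vectors built so far,
        # done twice to suppress the loss of orthogonality in floating point.
        for _ in range(2):
            w = w - v[:, : j + 1] @ (v[:, : j + 1].conj().T @ w)
        if j + 1 < m:
            beta[j] = np.linalg.norm(w)  # assumed nonzero (no breakdown handling)
            v[:, j + 1] = w / beta[j]
    t = np.diag(alpha) + np.diag(beta, 1) + np.diag(beta, -1)
    return v, t
```

Full re-orthogonalization costs O(d m²) instead of O(d m) for the plain three-term recurrence, which is usually acceptable because m is small.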
- Subspace Evolution:
- Compute the matrix exponential of the small (m x m) projected matrix T_m using the backend's expm function (tf.linalg.expm, jax.scipy.linalg.expm, torch.linalg.matrix_exp). Since m is small, this is computationally cheap.
- Compute exp(-i * H * t) @ psi0 ≈ V_m @ expm(-i * T_m * t) @ e_1 * norm(psi0), where e_1 is the first standard basis vector [1, 0, ..., 0]^T in the m-dimensional subspace. Note that psi0 = V_m @ e_1 * norm(psi0) because v_0 = psi0 / norm(psi0).
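A short derivation shows where the approximation above comes from. It relies on a standard property of the Lanczos and Arnoldi bases in exact arithmetic: powers of H applied to psi0 are reproduced exactly by powers of T_m up to order m-1.

```latex
\begin{aligned}
\psi_0 &= \lVert \psi_0 \rVert \, V_m e_1, \\
H^k \psi_0 &= \lVert \psi_0 \rVert \, V_m T_m^k e_1 \qquad (0 \le k \le m-1), \\
e^{-iHt}\psi_0 = \sum_{k \ge 0} \frac{(-it)^k}{k!} H^k \psi_0
  &\approx \lVert \psi_0 \rVert \, V_m \Big( \sum_{k \ge 0} \frac{(-it)^k}{k!} T_m^k \Big) e_1
  = \lVert \psi_0 \rVert \, V_m \, e^{-i T_m t} \, e_1 .
\end{aligned}
```

The two series agree term by term through order t^(m-1), so the error is controlled by higher-order terms and shrinks as m grows or t decreases.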
- Time Stepping and Callback:
- Loop or scan over the provided tlist. For each t:
- Calculate expm(-i * T_m * t).
- Compute the approximate state psi(t) = V_m @ (expm(-i * T_m * t) @ e_1) * norm(psi0). Ensure correct handling of shapes and potential batch dimensions inherited from psi0.
- If a callback function is provided, call it with psi(t) and store the result.
- Structure the time stepping to be efficient and compatible with vmap and JIT (e.g., potentially using jax.lax.map or equivalent).
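One way to make the time stepping vectorizable is sketched below in NumPy. Note that it swaps in a different technique from the per-step expm call described above: assuming T_m is Hermitian (the Lanczos case), it diagonalizes T_m once and then obtains all times in tlist from phase factors. The helper name evolve_in_subspace and the norm0 argument are hypothetical.

```python
import numpy as np


def evolve_in_subspace(v, t_m, tlist, norm0=1.0, callback=None):
    """Hypothetical helper: psi(t) = norm0 * V @ expm(-1j * T_m * t) @ e_1 for all t.

    v: (d, m) orthonormal Krylov basis; t_m: (m, m) Hermitian projected matrix.
    """
    w, s = np.linalg.eigh(t_m)               # T_m = S diag(w) S^dagger
    c0 = s[0, :].conj()                      # S^dagger e_1
    phases = np.exp(-1j * np.outer(np.asarray(tlist), w))  # shape (n_t, m)
    coeffs = (phases * c0) @ s.T             # row k holds expm(-1j * T_m * t_k) @ e_1
    states = norm0 * (coeffs @ v.T)          # shape (n_t, d); row k is psi(t_k)
    if callback is None:
        return states
    # Apply the callback to each state and stack the results along the time axis.
    return np.stack([callback(psi) for psi in states])
```

In a backend implementation the Python list comprehension would presumably be replaced by a mapped or scanned call so that the callback is traced once.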
- Function Interface:
- Create the krylov_evol(h, psi0, tlist, m, callback=None) function signature.
- Include checks for input validity (e.g., m > 0, shapes).
- Handle potential batch dimension in psi0. vmap should ideally handle batching over parameters in h or multiple tlists and psi0.
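The input checks listed above might look like the following sketch. It is an assumption-laden illustration, not the final interface: the helper name check_krylov_inputs is hypothetical, h is assumed to be a dense square matrix, and psi0 is assumed to be either a single vector of shape (d,) or a batch of shape (batch, d).

```python
import numpy as np


def check_krylov_inputs(h, psi0, m):
    """Validate shapes and m; return psi0 as a (batch, d) array and a batch flag."""
    h = np.asarray(h)
    psi0 = np.asarray(psi0)
    if h.ndim != 2 or h.shape[0] != h.shape[1]:
        raise ValueError(f"h must be a square matrix, got shape {h.shape}")
    d = h.shape[0]
    if psi0.ndim not in (1, 2) or psi0.shape[-1] != d:
        raise ValueError(f"psi0 must have shape ({d},) or (batch, {d}), got {psi0.shape}")
    if int(m) != m or not 1 <= m <= d:
        raise ValueError(f"m must be an integer with 1 <= m <= {d}, got {m}")
    batched = psi0.ndim == 2
    return np.atleast_2d(psi0), batched
```

All of these checks depend only on shapes and on m, so they would remain usable under JIT as long as m is treated as a static argument.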
- Testing:
- Accuracy tests: Compare results against hamiltonian_evol for small systems where expm is feasible. Verify convergence as m increases.
- Backend compatibility tests: Run tests across TensorFlow, JAX, and PyTorch backends.
- AD tests: Verify gradient computation w.r.t. parameters in h and psi0.
- vmap tests: Verify vectorization over a batch dimension in psi0.
- JIT tests: Ensure the function can be JIT-compiled (especially important for JAX).
- GPU tests: Confirm execution on GPU hardware.