A Python toolkit that uses JAX to model and optimise quantum processes with indefinite causal order.
ICOJax is an open-source Python toolkit that uses JAX to model, validate, and optimise quantum processes with indefinite causal order. It bundles fast, GPU- and TPU-ready linear-algebra primitives (e.g. Kronecker products, partial traces, and permutations of Hilbert-space subsystems) together with purpose-built projectors that map arbitrary matrices onto the convex set of valid process matrices and quantum channels.

On top of these building blocks, it implements a differentiable optimisation engine, powered by Optax and Jaxopt, that jointly learns Alice's and Bob's instruments and the underlying process matrix. It parametrises them through a deep neural network so as to maximise the violation of user-supplied causal inequalities. Because every step is JIT-compiled and fully differentiable, ICOJax achieves speed-ups of more than 100× over NumPy baselines while remaining hardware-agnostic.

In addition, because it can optimise nonlinear functions, ICOJax can explore classes of process matrices that traditional techniques such as semidefinite programming cannot handle. One example is optimising over rank-1 process matrices, which satisfy physical properties of interest: the rank-1 condition is non-convex, so it cannot be imposed directly in a semidefinite program.
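The subsystem primitives mentioned above (partial traces and subsystem permutations on Kronecker-product spaces) can be sketched in plain JAX as below. This is a minimal illustration, not ICOJax's actual API: the names `partial_trace` and `permute_systems`, and the conventions that local dimensions are passed as a static tuple and subsystems are indexed from 0, are assumptions made for the example.

```python
import math
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper: `dims` and `traced` must be tuples so jit can treat them as static.
@partial(jax.jit, static_argnums=(1, 2))
def partial_trace(rho, dims, traced):
    """Trace out the subsystems listed in `traced` from a matrix acting on
    a tensor product of spaces with local dimensions `dims`."""
    t = rho.reshape(dims + dims)  # ket axes 0..n-1, bra axes n..2n-1
    # Trace the highest-indexed subsystems first so lower axis positions stay valid.
    for k in sorted(traced, reverse=True):
        m = t.ndim // 2  # number of subsystems still present
        t = jnp.trace(t, axis1=k, axis2=k + m)
    d_keep = math.prod(d for k, d in enumerate(dims) if k not in traced)
    return t.reshape(d_keep, d_keep)


# Hypothetical helper: reorder subsystems so that new subsystem j is old subsystem perm[j].
@partial(jax.jit, static_argnums=(1, 2))
def permute_systems(rho, dims, perm):
    n = len(dims)
    t = rho.reshape(dims + dims)
    # Apply the same permutation to the ket axes and to the bra axes.
    t = jnp.transpose(t, tuple(perm) + tuple(n + p for p in perm))
    d = math.prod(dims)
    return t.reshape(d, d)
```

For a product state `jnp.kron(A, B)` on a qubit and a qutrit, `partial_trace(rho, (2, 3), (1,))` recovers `A` (scaled by the trace of `B`), and `permute_systems(rho, (2, 3), (1, 0))` returns `jnp.kron(B, A)`. Both functions only reshape, transpose, and trace, so they compile to a handful of XLA operations and remain differentiable.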
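For the bipartite case, one standard way to enforce the linear constraints defining valid process matrices is the projector L_V from the process-matrix literature (e.g. Araújo et al., 2015). It is built from trace-and-replace maps _X W = Tr_X(W) ⊗ 1_X / d_X:

L_V(W) = _{B_O}W + _{A_O}W - _{A_O B_O}W - _{B_I B_O}W + _{A_O B_I B_O}W - _{A_I A_O}W + _{A_I A_O B_O}W

The sketch below assumes the subsystem order A_I ⊗ A_O ⊗ B_I ⊗ B_O and uses hypothetical function names; it is not necessarily how ICOJax implements its projectors. It covers only the linear part: a valid process matrix must also be positive semidefinite with trace d_{A_O} d_{B_O}, so a complete treatment has to combine it with those conditions (for instance through alternating projections, which is likewise an assumption here).

```python
import math
from functools import partial

import jax
import jax.numpy as jnp


def trace_and_replace(W, dims, systems):
    """Apply the map _X: trace out each subsystem in `systems` and replace it by 1/d."""
    n = len(dims)
    t = W.reshape(tuple(dims) * 2)
    for k in systems:
        d = dims[k]
        # Trace out subsystem k (ket axis k, bra axis k + n) and divide by its dimension.
        tr = jnp.trace(t, axis1=k, axis2=k + n) / d
        # Re-insert the two removed axes as singleton dimensions.
        tr = jnp.expand_dims(jnp.expand_dims(tr, k), k + n)
        # Broadcast against an identity on subsystem k, i.e. tensor with 1_k.
        eye_shape = [1] * (2 * n)
        eye_shape[k] = eye_shape[k + n] = d
        t = tr * jnp.eye(d, dtype=W.dtype).reshape(eye_shape)
    D = math.prod(dims)
    return t.reshape(D, D)


# Assumed subsystem labels for a bipartite process matrix on A_I ⊗ A_O ⊗ B_I ⊗ B_O.
A_I, A_O, B_I, B_O = 0, 1, 2, 3


@partial(jax.jit, static_argnums=1)
def project_valid_subspace(W, dims):
    """Hypothetical name: the linear projector L_V onto the valid-process-matrix subspace."""
    P = lambda *systems: trace_and_replace(W, dims, systems)
    return (P(B_O) + P(A_O) - P(A_O, B_O) - P(B_I, B_O)
            + P(A_O, B_I, B_O) - P(A_I, A_O) + P(A_I, A_O, B_O))
```

Because the trace-and-replace maps commute and are idempotent, L_V is itself a projector: applying it twice gives the same result as applying it once. The identity matrix, which is a valid process matrix up to normalisation, is left unchanged.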