BufferizationStage IR amplification for runtime-coefficient Hamiltonians under @qjit #2759

@yjmaxpayne

Summary

When qml.TrotterProduct is used with JAX-traced runtime coefficients inside @qjit, the BufferizationStage of the MLIR compilation pipeline amplifies the IR size by roughly 9.9x for the H2 reference workload described below. This is the dominant contributor to extreme compiler memory consumption (~19 GB compiler-subprocess RSS), even for small molecules.

With fixed (Python float) coefficients, BufferizationStage shows ~1.0x amplification (constant-folding eliminates redundant IR). With runtime (JAX array) coefficients, the symbolic representation prevents constant-folding; for an H2 reference workload the IR enters this stage at ~6.1 MB (up from ~3.9 MB at the pipeline input) and leaves it at ~60 MB. The subsequent MLIRToLLVMDialectConversion then amplifies a further ~4.5x to ~273 MB, and the final LLVM IR is ~241 MB, for a total IR growth of ~62x (3.9 MB → 241 MB).

This makes runtime-parameterized Hamiltonians impractical for systems beyond the smallest molecules: H2 (15 Pauli terms) already consumes ~21 GB total peak; H3O+ (193 terms) extrapolates well beyond commodity 32 GB hardware.

The runtime-coefficient code path is the workaround community members are pointed to (cf. #1464) and is what allows electrostatic-embedding workflows (Monte Carlo solvation, VQE-style coefficient updates) to avoid full per-step recompilation. The memory cost at this single stage currently caps how large that workaround can scale.

Related: #1464 (qml.dot with runtime coefficients), #1507 (Hamiltonian aggregate-type lowering)

Minimal reproducible example

import numpy as np
import pennylane as qml
from jax import numpy as jnp

# --- Hamiltonian construction (numpy path, gives 15-term H2) ---
# qml.qchem.molecular_hamiltonian expects atomic units (Bohr).
# Note: passing the geometry as a JAX array currently disables term aggregation
# in molecular_hamiltonian and yields 47 SProd operands instead of 15. We use
# numpy here to keep the workload aligned with the measurements below.
ANGSTROM_TO_BOHR = 1.8897259886
symbols = ["H", "H"]
coords_angstrom = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.74]])
coords_bohr = (coords_angstrom * ANGSTROM_TO_BOHR).flatten()

H, n_qubits = qml.qchem.molecular_hamiltonian(
    symbols=symbols,
    coordinates=coords_bohr,
    mapping="jordan_wigner",
    active_electrons=2,
    active_orbitals=2,
)
assert len(H.operands) == 15  # H2 STO-3G / JW with active space (2e, 2o)

# coeffs is a JAX array, so it becomes a runtime tracer at compile time —
# this is the configuration that triggers the BufferizationStage amplification.
coeffs = jnp.array([float(op.scalar) for op in H.operands if isinstance(op, qml.ops.SProd)])
ops = [op.base for op in H.operands if isinstance(op, qml.ops.SProd)]

n_est = 4
n_trotter = 10
total_wires = n_qubits + n_est
dev = qml.device("lightning.qubit", wires=total_wires)

@qml.qjit(keep_intermediate=True)  # dump per-stage IR for diagnosis
@qml.qnode(dev)
def qpe_circuit(runtime_coeffs):
    H_rt = qml.dot(runtime_coeffs, ops)  # `ops` captured as compile-time closure
    for k in range(n_est):
        t = 2 ** (n_est - 1 - k)
        qml.ctrl(
            qml.adjoint(
                qml.TrotterProduct(
                    H_rt, time=t, n=n_trotter, order=2,
                    check_hermitian=False,  # required to admit JAX-tracer coeffs
                )
            ),
            control=n_qubits + k,
        )
    return qml.probs(wires=range(n_qubits, total_wires))

# Compilation consumes ~19 GB compiler-subprocess RSS (~21 GB total peak)
# for just 15 Pauli terms. Per-stage IR is written under the keep_intermediate
# directory; BufferizationStage shows the dominant ~9.9x growth.
result = qpe_circuit(coeffs)

Measured data

Environment: PennyLane 0.44.0, Catalyst 0.14.0, JAX 0.7.1, Linux x86_64, 30.4 GB RAM.
Measurement note: numbers below were collected with keep_intermediate=True to obtain per-stage IR sizes; production runs without keep_intermediate show similar end-to-end memory and time.

Per-stage IR size (H2, n_est=4, n_trotter=10, 15 Pauli terms)

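| Compilation stage | H_fixed (KB) | H_dynamic (KB) | Fixed amp. (vs. previous stage) | Dynamic amp. (vs. previous stage) |
|---|---|---|---|---|
| mlir (input) | 706 | 3,860 | - | - |
| QuantumCompilationStage | 1,629 | 3,782 | 2.3x | 1.0x |
| HLOLoweringStage | 696 | 6,101 | 0.4x | 1.6x |
| BufferizationStage | 696 | 60,320 | 1.0x | 9.9x |
| MLIRToLLVMDialectConversion | 2,487 | 273,442 | 3.6x | 4.5x |
| LLVMIRTranslation | 1,820 | 240,654 | 0.7x | 0.9x |
| Total (input → final) | 706 → 1,820 | 3,860 → 240,654 | 2.6x | 62.4x |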
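
For anyone re-checking these figures against their own keep_intermediate dumps, here is a minimal sketch of how the amplification columns can be recomputed. It assumes each factor is the ratio of a stage's output size to the previous stage's output size, which is what the listed numbers are consistent with. The helper name `stage_amplification` is hypothetical and not part of Catalyst.

```python
# Per-stage IR sizes in KB for the dynamic (runtime-coefficient) case,
# copied from the table above, in pipeline order.
DYNAMIC_KB = [
    ("mlir (input)", 3860),
    ("QuantumCompilationStage", 3782),
    ("HLOLoweringStage", 6101),
    ("BufferizationStage", 60320),
    ("MLIRToLLVMDialectConversion", 273442),
    ("LLVMIRTranslation", 240654),
]


def stage_amplification(sizes):
    """Return {stage: size / previous stage size} for every stage after the first.

    Hypothetical helper for illustration only.
    """
    return {
        name: size / prev_size
        for (_, prev_size), (name, size) in zip(sizes, sizes[1:])
    }


amp = stage_amplification(DYNAMIC_KB)
# End-to-end growth: final LLVM IR size relative to the pipeline input.
total = DYNAMIC_KB[-1][1] / DYNAMIC_KB[0][1]

for name, factor in amp.items():
    print(f"{name:30s} {factor:5.1f}x")
print(f"{'Total (input -> final)':30s} {total:5.1f}x")
```

Swapping in the H_fixed column gives the fixed-coefficient factors the same way. Small differences in the last digit relative to the table are possible because the table sizes are rounded to whole KB.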

End-to-end compilation metrics

| Metric | H_fixed | H_dynamic | Ratio |
|---|---|---|---|
| Compile time | 27 s | 306 s | 11.2x |
| Python RSS | 1.2 GB | 2.2 GB | 1.9x |
| Compiler subprocess RSS | 0.9 GB | 19.1 GB | 21.4x |
| Total peak (parent+child) | 2.0 GB | 21.3 GB | 10.3x |
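
One possible way to collect comparable parent and child peak-RSS numbers on a Unix system is the standard-library `resource` module. This is a sketch, not necessarily how the figures above were obtained. It assumes the compiler runs as a child process that the Python process waits on, and it uses a throwaway subprocess as a stand-in for the compile step. `peak_rss_gb` is a hypothetical helper name.

```python
import resource
import subprocess
import sys


def peak_rss_gb():
    """Return (parent_peak, largest_child_peak) RSS in GB via getrusage.

    Hypothetical helper for illustration only.
    """
    # ru_maxrss is reported in kilobytes on Linux and in bytes on macOS.
    scale = 1024 if sys.platform.startswith("linux") else 1
    own = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * scale
    # RUSAGE_CHILDREN covers only children that have terminated and been
    # waited for; ru_maxrss is the peak of the largest such child.
    kids = resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss * scale
    return own / 1e9, kids / 1e9


# Stand-in for the compile step: a child process that allocates ~50 MB.
# In a real measurement this would be replaced by the call that triggers
# compilation, e.g. qpe_circuit(coeffs).
subprocess.run([sys.executable, "-c", "x = bytearray(50_000_000)"], check=True)

parent_gb, child_gb = peak_rss_gb()
print(f"parent peak RSS: {parent_gb:.2f} GB, largest child peak RSS: {child_gb:.2f} GB")
```

Note that the two peaks need not occur at the same moment, so their sum is only an upper bound on the combined peak.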

System scaling (dynamic mode)

| System | Pauli terms | IR ops (lower bound) | Memory | Status |
|---|---|---|---|---|
| H2 (4-bit) | 15 | 600 | 21.3 GB | Verified |
| H2 (8-bit) | 15 | 1,200 | OOM @ n_trotter ≥ 7 | Verified |
| H3O+ | 193 | ~7,720 | OOM (extrapolated) | Not attempted |
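
The "IR ops (lower bound)" values are consistent with counting one operation per Pauli term per Trotter step per estimation qubit. That formula is inferred from the listed numbers rather than stated explicitly, and the H3O+ row is assumed to use n_est=4 and n_trotter=10. A minimal sketch under those assumptions, with a user-side guard that refuses to compile above a chosen budget (both function names and the budget value are hypothetical, not a Catalyst API):

```python
def ir_ops_lower_bound(n_terms, n_trotter, n_est):
    """Lower bound on emitted operations: one per Pauli term, per Trotter
    step, per estimation qubit. Inferred formula; real counts are higher
    (second-order Trotterization, control and adjoint decompositions)."""
    return n_terms * n_trotter * n_est


def check_compile_budget(n_terms, n_trotter, n_est, max_ops):
    """Raise before compilation if the lower bound already exceeds max_ops.

    max_ops is a user-chosen threshold, not a value provided by Catalyst.
    """
    ops = ir_ops_lower_bound(n_terms, n_trotter, n_est)
    if ops > max_ops:
        raise MemoryError(
            f"estimated >= {ops} ops exceeds budget of {max_ops}; "
            "compilation is likely to run out of memory"
        )
    return ops


h2_4bit = ir_ops_lower_bound(n_terms=15, n_trotter=10, n_est=4)
h2_8bit = ir_ops_lower_bound(n_terms=15, n_trotter=10, n_est=8)
h3o_plus = ir_ops_lower_bound(n_terms=193, n_trotter=10, n_est=4)
print(h2_4bit, h2_8bit, h3o_plus)

# Example guard: 600 ops compiled (at ~21 GB) on the 30.4 GB test machine,
# so a budget of 600 lets H2 (4-bit) through and rejects the larger cases.
check_compile_budget(15, 10, 4, max_ops=600)
```

A guard like this only needs the term count and circuit parameters, so it can run before any tracing or lowering starts.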

Impact

This is currently a blocking issue for any quantum-chemistry workflow that requires runtime-parameterized Hamiltonians under Catalyst:

  1. QM/MM workflows: Monte Carlo solvation dynamics requires updating Hamiltonian coefficients at each step (electrostatic embedding). Without runtime parameterization the alternative is full per-step recompilation (~67 s/step in our setup).
  2. VQE / variational methods: any workflow where Hamiltonian coefficients depend on classical parameters.
  3. System-size ceiling: even H2 (15 Pauli terms) reaches ~21 GB of peak memory at compile time. H3O+ (193 terms) has not been attempted and is expected, by extrapolation, to be infeasible on commodity hardware.

Suggested directions

  1. Sparse / structured IR representation: detect that runtime coefficients multiply structurally-identical operator blocks and factor them at BufferizationStage rather than expanding each symbolically.
  2. Lazy lowering: defer constant-folding-dependent optimizations to JIT time when actual coefficient values are available.
  3. Pre-compilation IR-size estimate: expose an API that estimates IR size before compilation so users can guard against OOM.
  4. Operator fusion: combine repeated TrotterProduct blocks that share structure but differ only in the time parameter.
