
Fix compilation of Jacobian with for loops and vmap #332

Open
wants to merge 47 commits into base: main
Changes from 8 commits
Commits (47):
6718c9e  Fix compilation (rmoyard, Oct 25, 2023)
105c18f  Undo (rmoyard, Oct 25, 2023)
a30633e  Merge branch 'main' into fix_forloop_enzyme (rmoyard, Mar 18, 2024)
7fb4c87  Merge branch 'main' into fix_forloop_enzyme (rmoyard, Mar 19, 2024)
1ec7b7d  Merge branch 'main' into fix_forloop_enzyme (rmoyard, Mar 19, 2024)
655d889  Add integration test (rmoyard, Mar 19, 2024)
e488748  Black (rmoyard, Mar 19, 2024)
bc3da9b  Update jax (rmoyard, Mar 19, 2024)
e3a44ec  Apply suggestions from code review (rmoyard, Mar 19, 2024)
02a49ab  Readd qjit (rmoyard, Mar 19, 2024)
b0439ba  Rename test variables (rmoyard, Mar 19, 2024)
6e8681f  Update atol (rmoyard, Mar 19, 2024)
45dbf25  Update mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp (rmoyard, Mar 19, 2024)
c4f30b2  Merge branch 'main' into fix_forloop_enzyme (rmoyard, Mar 19, 2024)
7499729  Merge branch 'main' into fix_forloop_enzyme (rmoyard, Jul 2, 2024)
c0e824c  Enzyme version (rmoyard, Jul 3, 2024)
bdd144d  Update Enzyme (rmoyard, Jul 3, 2024)
103dd3e  Merge branch 'main' into update-enzyme-130 (rmoyard, Jul 4, 2024)
644808f  Merge branch 'main' into fix_forloop_enzyme (rmoyard, Jul 4, 2024)
d0d00f1  Merge branch 'main' of https://github.com/PennyLaneAI/catalyst into u… (rmoyard, Jul 4, 2024)
8afc5d7  Add no free (rmoyard, Jul 4, 2024)
66fe89d  Merge branch 'update-enzyme-130' of https://github.com/PennyLaneAI/ca… (rmoyard, Jul 4, 2024)
e0b12a5  Merge branch 'update-enzyme-130' into fix_forloop_enzyme (rmoyard, Jul 4, 2024)
345b024  Merge branch 'main' into fix_forloop_enzyme (rmoyard, Jul 9, 2024)
2bce302  Merge branch 'fix_forloop_enzyme' of https://github.com/PennyLaneAI/c… (rmoyard, Jul 9, 2024)
820b7b2  Merge branch 'main' into fix_forloop_enzyme (rmoyard, Jul 16, 2024)
ef437bf  Merge branch 'main' into fix_forloop_enzyme (rmoyard, Jul 16, 2024)
b7fd2e1  Merge branch 'main' into fix_forloop_enzyme (rmoyard, Jul 19, 2024)
0e653b2  Merge branch 'fix_forloop_enzyme' of https://github.com/PennyLaneAI/c… (rmoyard, Jul 19, 2024)
c930618  Merge branch 'main' into fix_forloop_enzyme (rmoyard, Jul 23, 2024)
9a8191b  Merge branch 'fix_forloop_enzyme' of https://github.com/PennyLaneAI/c… (rmoyard, Jul 23, 2024)
2fdb9a3  Merge branch 'main' into fix_forloop_enzyme (rmoyard, Jul 23, 2024)
811170f  Merge branch 'fix_forloop_enzyme' of https://github.com/PennyLaneAI/c… (rmoyard, Jul 23, 2024)
07fd2cc  Black (rmoyard, Jul 23, 2024)
ed96284  Xfail test (rmoyard, Jul 23, 2024)
2813cd3  Add test (rmoyard, Jul 23, 2024)
a53b113  Changelog (rmoyard, Jul 23, 2024)
c208cd3  update (rmoyard, Jul 23, 2024)
7862ae2  Merge branch 'main' into fix_forloop_enzyme (rmoyard, Jul 30, 2024)
a3f1947  Merge branch 'main' into fix_forloop_enzyme (rmoyard, Aug 2, 2024)
510c348  Merge branch 'main' into fix_forloop_enzyme (rmoyard, Aug 2, 2024)
0a419d1  Merge branch 'main' into fix_forloop_enzyme (rmoyard, Aug 2, 2024)
f735df0  Merge (rmoyard, Aug 30, 2024)
5105eec  Merge (rmoyard, Aug 30, 2024)
238b280  Black (rmoyard, Aug 30, 2024)
ec39190  Move test in gradient (rmoyard, Aug 30, 2024)
c409819  Merge branch 'main' into fix_forloop_enzyme (rmoyard, Sep 13, 2024)
1 change: 1 addition & 0 deletions frontend/catalyst/jit.py
@@ -384,6 +384,7 @@ def compute_jvp(self, primals, tangents):
        # Optimization: Do not compute Jacobians for arguments which do not participate in
        # differentiation.
        argnums = []
        tangents, _ = tree_flatten(tangents)
        for idx, tangent in enumerate(tangents):
            if not isinstance(tangent, jax.custom_derivatives.SymbolicZero):
                argnums.append(idx)
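A minimal sketch, outside of Catalyst and with a hypothetical tangent pytree, of why the flattening above matters: the tangents handed to compute_jvp can form a pytree (for example a dict of parameters, like the one used in the new test below), so enumerating the container directly would walk its top-level entries rather than its leaves. Flattening first makes the collected argnums line up one-to-one with the flattened differentiable arguments.

import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten

# Hypothetical tangents pytree; in compute_jvp its structure mirrors the primals.
tangents = {"bias": jnp.array(0.0), "weights": jnp.ones((5, 3))}

# Without tree_flatten, enumerate(tangents) would iterate over the dict keys.
leaves, _ = tree_flatten(tangents)

argnums = [
    idx
    for idx, tangent in enumerate(leaves)
    if not isinstance(tangent, jax.custom_derivatives.SymbolicZero)
]
print(argnums)  # [0, 1]: both leaves participate in differentiation
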
57 changes: 56 additions & 1 deletion frontend/test/pytest/test_vmap.py
@@ -18,8 +18,9 @@
import jax.numpy as jnp
import pennylane as qml
import pytest
from jax.tree_util import tree_flatten

from catalyst import qjit, vmap
from catalyst import for_loop, grad, qjit, vmap


class TestVectorizeMap:
@@ -752,3 +753,57 @@ def circuit(x):
            match="Invalid batch size; it must be a non-zero integer, but got 0.",
        ):
            qjit(workflow)(x)

    def test_vmap_worflow_derivation(self, backend):
        """Check the gradient of a vmap workflow"""
        n_wires = 5
        data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3
        targets = jnp.array([-0.2, 0.4, 0.35, 0.2])

        dev = qml.device(backend, wires=n_wires)

        @qml.qnode(dev)
        def circuit(data, weights):
            """Quantum circuit ansatz"""

            @for_loop(0, n_wires, 1)
            def data_embedding(i):
                qml.RY(data[i], wires=i)

            data_embedding()

            @for_loop(0, n_wires, 1)
            def ansatz(i):
                qml.RX(weights[i, 0], wires=i)
                qml.RY(weights[i, 1], wires=i)
                qml.RX(weights[i, 2], wires=i)
                qml.CNOT(wires=[i, (i + 1) % n_wires])

            ansatz()

            return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

        circuit = qml.qjit(vmap(circuit, in_axes=(1, None)), autograph=False)

        def my_model(data, weights, bias):
            return circuit(data, weights) + bias

        @qml.qjit
        def loss_fn(params, data, targets):
            predictions = my_model(data, params["weights"], params["bias"])
            loss = jnp.sum((targets - predictions) ** 2 / len(data))
            return loss

        weights = jnp.ones([n_wires, 3])
        bias = jnp.array(0.0)
        params = {"weights": weights, "bias": bias}

        results_enzyme = qml.qjit(grad(loss_fn))(params, data, targets)
        results_fd = jax.grad(loss_fn)(params, data, targets)

        data_enzyme, pytree_enzyme = tree_flatten(results_enzyme)
        data_fd, pytree_fd = tree_flatten(results_fd)

        assert pytree_enzyme == pytree_fd
        assert jnp.allclose(data_enzyme[0], data_fd[0])
        assert jnp.allclose(data_enzyme[1], data_fd[1])
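The test compiles the vmapped circuit and the loss with qjit, differentiates the loss with Catalyst's grad and with jax.grad, and checks that the resulting pytree structures and values agree. Assuming the suite's usual pytest setup and backend fixture (for example lightning.qubit), just this test can be selected with something like:

pytest frontend/test/pytest/test_vmap.py -k test_vmap_worflow_derivation
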
mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp
@@ -183,7 +183,7 @@ FailureOr<func::FuncOp> HybridGradientLowering::cloneCallee(PatternRewriter &rew
    funcOp.walk([&](func::CallOp callOp) {
        if (callOp.getCallee() == qnode.getName()) {
            PatternRewriter::InsertionGuard insertionGuard(rewriter);
            rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
            rewriter.setInsertionPoint(callOp);
            Value paramCount =
                rewriter
                    .create<func::CallOp>(loc, paramCountFn, callOp.getArgOperands())
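For context on this one-line change: the parameter-count call is created with callOp.getArgOperands() as its operands, so emitting it immediately before the QNode call, rather than at the start of the cloned function body, appears to keep those operands defined at the insertion point even when the call sits inside lowered control flow such as a for_loop. That is presumably the compilation failure this pull request fixes for the vmap-plus-for-loop Jacobian case.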