Skip to content

Fix compilation of Jacobian with for loops and vmap #332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 47 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
6718c9e
Fix compilation
rmoyard Oct 25, 2023
105c18f
Undo
rmoyard Oct 25, 2023
a30633e
Merge branch 'main' into fix_forloop_enzyme
rmoyard Mar 18, 2024
7fb4c87
Merge branch 'main' into fix_forloop_enzyme
rmoyard Mar 19, 2024
1ec7b7d
Merge branch 'main' into fix_forloop_enzyme
rmoyard Mar 19, 2024
655d889
Add integration test
rmoyard Mar 19, 2024
e488748
Black
rmoyard Mar 19, 2024
bc3da9b
Update jax
rmoyard Mar 19, 2024
e3a44ec
Apply suggestions from code review
rmoyard Mar 19, 2024
02a49ab
Readd qjit
rmoyard Mar 19, 2024
b0439ba
Rename test variables
rmoyard Mar 19, 2024
6e8681f
Update atol
rmoyard Mar 19, 2024
45dbf25
Update mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp
rmoyard Mar 19, 2024
c4f30b2
Merge branch 'main' into fix_forloop_enzyme
rmoyard Mar 19, 2024
7499729
Merge branch 'main' into fix_forloop_enzyme
rmoyard Jul 2, 2024
c0e824c
Enzyme version
rmoyard Jul 3, 2024
bdd144d
Update Enzyme
rmoyard Jul 3, 2024
103dd3e
Merge branch 'main' into update-enzyme-130
rmoyard Jul 4, 2024
644808f
Merge branch 'main' into fix_forloop_enzyme
rmoyard Jul 4, 2024
d0d00f1
Merge branch 'main' of https://github.com/PennyLaneAI/catalyst into u…
rmoyard Jul 4, 2024
8afc5d7
Add no free
rmoyard Jul 4, 2024
66fe89d
Merge branch 'update-enzyme-130' of https://github.com/PennyLaneAI/ca…
rmoyard Jul 4, 2024
e0b12a5
Merge branch 'update-enzyme-130' into fix_forloop_enzyme
rmoyard Jul 4, 2024
345b024
Merge branch 'main' into fix_forloop_enzyme
rmoyard Jul 9, 2024
2bce302
Merge branch 'fix_forloop_enzyme' of https://github.com/PennyLaneAI/c…
rmoyard Jul 9, 2024
820b7b2
Merge branch 'main' into fix_forloop_enzyme
rmoyard Jul 16, 2024
ef437bf
Merge branch 'main' into fix_forloop_enzyme
rmoyard Jul 16, 2024
b7fd2e1
Merge branch 'main' into fix_forloop_enzyme
rmoyard Jul 19, 2024
0e653b2
Merge branch 'fix_forloop_enzyme' of https://github.com/PennyLaneAI/c…
rmoyard Jul 19, 2024
c930618
Merge branch 'main' into fix_forloop_enzyme
rmoyard Jul 23, 2024
9a8191b
Merge branch 'fix_forloop_enzyme' of https://github.com/PennyLaneAI/c…
rmoyard Jul 23, 2024
2fdb9a3
Merge branch 'main' into fix_forloop_enzyme
rmoyard Jul 23, 2024
811170f
Merge branch 'fix_forloop_enzyme' of https://github.com/PennyLaneAI/c…
rmoyard Jul 23, 2024
07fd2cc
Black
rmoyard Jul 23, 2024
ed96284
Xfail test
rmoyard Jul 23, 2024
2813cd3
Add test
rmoyard Jul 23, 2024
a53b113
Changelog
rmoyard Jul 23, 2024
c208cd3
update
rmoyard Jul 23, 2024
7862ae2
Merge branch 'main' into fix_forloop_enzyme
rmoyard Jul 30, 2024
a3f1947
Merge branch 'main' into fix_forloop_enzyme
rmoyard Aug 2, 2024
510c348
Merge branch 'main' into fix_forloop_enzyme
rmoyard Aug 2, 2024
0a419d1
Merge branch 'main' into fix_forloop_enzyme
rmoyard Aug 2, 2024
f735df0
Merge
rmoyard Aug 30, 2024
5105eec
Merge
rmoyard Aug 30, 2024
238b280
Black
rmoyard Aug 30, 2024
ec39190
Move test in gradient
rmoyard Aug 30, 2024
c409819
Merge branch 'main' into fix_forloop_enzyme
rmoyard Sep 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,25 @@

<h3>Bug fixes</h3>

* Circuits with preprocessing functions outside qnodes can now be differentiated.
[(#332)](https://github.com/PennyLaneAI/catalyst/pull/332)

```python
@qml.qnode(qml.device("lightning.qubit", wires=1))
def f(y):
qml.RX(y, wires=0)
return qml.expval(qml.PauliZ(0))

@catalyst.qjit
def g(x):
return catalyst.grad(lambda y: f(jnp.cos(y)) ** 2)(x)
```

```pycon
>>> g(0.4)
0.3751720385067584
```

<h3>Internal changes</h3>

* Remove the `MemMemCpyOptPass` in llvm O2 (applied for Enzyme), this reduces bugs when
Expand Down
75 changes: 74 additions & 1 deletion frontend/test/pytest/test_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
pure_callback,
qjit,
value_and_grad,
vmap,
)

# pylint: disable=too-many-lines
Expand Down Expand Up @@ -1444,7 +1445,25 @@ def interpreted(x):
assert np.allclose(compiled(inp), interpreted(inp))


@pytest.mark.xfail(reason="Needs PR #332")
@pytest.mark.parametrize("inp", [(1.0), (2.0), (3.0), (4.0)])
def test_preprocessing_outside_qnode(inp, backend):
"""Test the preprocessing outside qnode."""

@qml.qnode(qml.device(backend, wires=1))
def f(y):
qml.RX(y, wires=0)
return qml.expval(qml.PauliZ(0))

@qjit
def g(x):
return grad(lambda y: f(jnp.cos(y)) ** 2)(x)

def h(x):
return jax.grad(lambda y: f(jnp.cos(y)) ** 2)(x)

assert np.allclose(g(inp), h(inp))


def test_gradient_slice(backend):
"""Test the differentation when the qnode generates memref with non identity layout."""
n_wires = 5
Expand Down Expand Up @@ -1484,6 +1503,60 @@ def my_model(data, weights, bias):
jax_res = jax.jacobian(my_model, argnums=1)(data, params["weights"], params["bias"])
assert np.allclose(cat_res, jax_res)

@pytest.mark.xfail(reason="Vmap yields wrong results when differentiated")
def test_vmap_worflow_derivation(backend):
"""Check the gradient of a vmap workflow"""
n_wires = 5
data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3

targets = jnp.array([-0.2, 0.4, 0.35, 0.2], dtype=jax.numpy.float64)

dev = qml.device(backend, wires=n_wires)

@qml.qnode(dev, diff_method="adjoint")
def circuit(data, weights):
"""Quantum circuit ansatz"""

@for_loop(0, n_wires, 1)
def data_embedding(i):
qml.RY(data[i], wires=i)

data_embedding()

@for_loop(0, n_wires, 1)
def ansatz(i):
qml.RX(weights[i, 0], wires=i)
qml.RY(weights[i, 1], wires=i)
qml.RX(weights[i, 2], wires=i)
qml.CNOT(wires=[i, (i + 1) % n_wires])

ansatz()

return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

circuit = vmap(circuit, in_axes=(1, None))

def my_model(data, weights, bias):
return circuit(data, weights) + bias

def loss_fn(params, data, targets):
predictions = my_model(data, params["weights"], params["bias"])
loss = jnp.sum((targets - predictions) ** 2 / len(data))
return loss

weights = jnp.ones([n_wires, 3])
bias = jnp.array(0.0, dtype=jax.numpy.float64)
params = {"weights": weights, "bias": bias}

results_enzyme = qjit(grad(loss_fn))(params, data, targets)
results_jax = jax.grad(loss_fn)(params, data, targets)

data_enzyme, pytree_enzyme = tree_flatten(results_enzyme)
data_jax, pytree_fd = tree_flatten(results_jax)

assert pytree_enzyme == pytree_fd
assert jnp.allclose(data_enzyme[0], data_jax[0])
assert jnp.allclose(data_enzyme[1], data_jax[1])

@pytest.mark.parametrize(
"gate,state", ((qml.BasisState, np.array([1])), (qml.StatePrep, np.array([0, 1])))
Expand Down
5 changes: 4 additions & 1 deletion frontend/test/pytest/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
import jax.numpy as jnp
import pennylane as qml
import pytest
from jax.tree_util import tree_flatten

Check notice on line 21 in frontend/test/pytest/test_vmap.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/test_vmap.py#L21

Unused tree_flatten imported from jax.tree_util (unused-import)

from catalyst import qjit, vmap
from catalyst import for_loop, grad, qjit, vmap

Check notice on line 23 in frontend/test/pytest/test_vmap.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/test_vmap.py#L23

Unused for_loop imported from catalyst (unused-import)

Check notice on line 23 in frontend/test/pytest/test_vmap.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/test_vmap.py#L23

Unused grad imported from catalyst (unused-import)

# pylint: disable=too-many-public-methods


class TestVectorizeMap:
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ static FailureOr<func::FuncOp> cloneCallee(PatternRewriter &rewriter, Operation
funcOp.walk([&](func::CallOp callOp) {
if (callOp.getCallee() == qnode.getName()) {
PatternRewriter::InsertionGuard insertionGuard(rewriter);
rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
// TODO: optimize the placement of the param count call (e.g. loop hoisting)
rewriter.setInsertionPoint(callOp);
Value paramCount =
rewriter
.create<func::CallOp>(loc, paramCountFn, callOp.getArgOperands())
Expand Down
Loading