From 6718c9e5d328e35d53dd0ba7f2358b308b67ac78 Mon Sep 17 00:00:00 2001 From: rmoyard Date: Tue, 24 Oct 2023 21:53:43 -0400 Subject: [PATCH 01/20] Fix compilation --- mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp b/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp index 5b7a1d293..54fbbfe3f 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp @@ -170,7 +170,7 @@ FailureOr HybridGradientLowering::cloneCallee(PatternRewriter &rew // argument, we need to insert a call to the parameter count function at the // location of the grad op. PatternRewriter::InsertionGuard insertionGuard(rewriter); - rewriter.setInsertionPoint(gradOp); + rewriter.setInsertionPointAfterValue(gradOp.getArgOperands().back()); Value paramCount = rewriter.create(loc, paramCountFn, gradOp.getArgOperands()) @@ -187,7 +187,7 @@ FailureOr HybridGradientLowering::cloneCallee(PatternRewriter &rew funcOp.walk([&](func::CallOp callOp) { if (callOp.getCallee() == qnode.getName()) { PatternRewriter::InsertionGuard insertionGuard(rewriter); - rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front()); + rewriter.setInsertionPoint(callOp); Value paramCount = rewriter .create(loc, paramCountFn, callOp.getArgOperands()) From 105c18fb9dc20a8b6638057792fa03e1b5245682 Mon Sep 17 00:00:00 2001 From: rmoyard Date: Tue, 24 Oct 2023 22:00:58 -0400 Subject: [PATCH 02/20] Undo --- mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp b/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp index 54fbbfe3f..d7d53f7eb 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp @@ -170,7 +170,7 @@ FailureOr HybridGradientLowering::cloneCallee(PatternRewriter &rew // argument, we need to insert a call to the parameter count function at the // location of the grad op. 
PatternRewriter::InsertionGuard insertionGuard(rewriter); - rewriter.setInsertionPointAfterValue(gradOp.getArgOperands().back()); + rewriter.setInsertionPoint(gradOp); Value paramCount = rewriter.create(loc, paramCountFn, gradOp.getArgOperands()) From 655d889dd2c1694d19a57031b9889b366234a82a Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Tue, 19 Mar 2024 10:42:24 -0400 Subject: [PATCH 03/20] Add integration test --- frontend/test/pytest/test_vmap.py | 61 ++++++++++++++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/frontend/test/pytest/test_vmap.py b/frontend/test/pytest/test_vmap.py index 0d5b31c4b..5ba1b68fe 100644 --- a/frontend/test/pytest/test_vmap.py +++ b/frontend/test/pytest/test_vmap.py @@ -18,8 +18,9 @@ import jax.numpy as jnp import pennylane as qml import pytest +from jax.tree_util import tree_flatten -from catalyst import qjit, vmap +from catalyst import for_loop, grad, qjit, vmap class TestVectorizeMap: @@ -752,3 +753,61 @@ def circuit(x): match="Invalid batch size; it must be a non-zero integer, but got 0.", ): qjit(workflow)(x) + + def test_vmap_worflow_derivation(self, backend): + """Check the gradient of a vmap workflow""" + n_wires = 5 + data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3 + targets = jnp.array([-0.2, 0.4, 0.35, 0.2]) + + dev = qml.device(backend, wires=n_wires) + + + @qml.qnode(dev) + def circuit(data, weights): + """Quantum circuit ansatz""" + + @for_loop(0, n_wires, 1) + def data_embedding(i): + qml.RY(data[i], wires=i) + + data_embedding() + + @for_loop(0, n_wires, 1) + def ansatz(i): + qml.RX(weights[i, 0], wires=i) + qml.RY(weights[i, 1], wires=i) + qml.RX(weights[i, 2], wires=i) + qml.CNOT(wires=[i, (i + 1) % n_wires]) + + ansatz() + + return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)])) + + circuit = qml.qjit(vmap(circuit, in_axes=(1, None)), autograph=False) + + + + def my_model(data, weights, bias): + return circuit(data, weights) + bias + + @qml.qjit + def loss_fn(params, data, targets): + predictions = my_model(data, params["weights"], params["bias"]) + loss = jnp.sum((targets - predictions) ** 2 / len(data)) + return loss + + + weights = jnp.ones([n_wires, 3]) + bias = jnp.array(0.) 
+ params = {"weights": weights, "bias": bias} + + results_enzyme = qml.qjit(grad(loss_fn))(params, data, targets) + results_fd = qml.qjit(grad(loss_fn, method="fd"))(params, data, targets) + + data_enzyme, pytree_enzyme = tree_flatten(results_enzyme) + data_fd, pytree_fd = tree_flatten(results_fd) + + assert pytree_enzyme == pytree_fd + assert jnp.allclose(data_enzyme[0], data_fd[0], atol=1e-1) + assert jnp.allclose(data_enzyme[1], data_fd[1], atol=1e-1) \ No newline at end of file From e488748b51cf2e8b44af4cc902bf91301fb11f7c Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Tue, 19 Mar 2024 10:43:12 -0400 Subject: [PATCH 04/20] Black --- frontend/test/pytest/test_vmap.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/frontend/test/pytest/test_vmap.py b/frontend/test/pytest/test_vmap.py index 5ba1b68fe..505ba6c45 100644 --- a/frontend/test/pytest/test_vmap.py +++ b/frontend/test/pytest/test_vmap.py @@ -753,7 +753,7 @@ def circuit(x): match="Invalid batch size; it must be a non-zero integer, but got 0.", ): qjit(workflow)(x) - + def test_vmap_worflow_derivation(self, backend): """Check the gradient of a vmap workflow""" n_wires = 5 @@ -761,7 +761,6 @@ def test_vmap_worflow_derivation(self, backend): targets = jnp.array([-0.2, 0.4, 0.35, 0.2]) dev = qml.device(backend, wires=n_wires) - @qml.qnode(dev) def circuit(data, weights): @@ -785,8 +784,6 @@ def ansatz(i): return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)])) circuit = qml.qjit(vmap(circuit, in_axes=(1, None)), autograph=False) - - def my_model(data, weights, bias): return circuit(data, weights) + bias @@ -797,9 +794,8 @@ def loss_fn(params, data, targets): loss = jnp.sum((targets - predictions) ** 2 / len(data)) return loss - weights = jnp.ones([n_wires, 3]) - bias = jnp.array(0.) + bias = jnp.array(0.0) params = {"weights": weights, "bias": bias} results_enzyme = qml.qjit(grad(loss_fn))(params, data, targets) @@ -810,4 +806,4 @@ def loss_fn(params, data, targets): assert pytree_enzyme == pytree_fd assert jnp.allclose(data_enzyme[0], data_fd[0], atol=1e-1) - assert jnp.allclose(data_enzyme[1], data_fd[1], atol=1e-1) \ No newline at end of file + assert jnp.allclose(data_enzyme[1], data_fd[1], atol=1e-1) From bc3da9b6df87663894e93a54532cfa6d6fec7516 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Tue, 19 Mar 2024 11:00:42 -0400 Subject: [PATCH 05/20] Update jax --- frontend/catalyst/jit.py | 1 + frontend/test/pytest/test_vmap.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py index f7e70201b..5f39af6e5 100644 --- a/frontend/catalyst/jit.py +++ b/frontend/catalyst/jit.py @@ -384,6 +384,7 @@ def compute_jvp(self, primals, tangents): # Optimization: Do not compute Jacobians for arguments which do not participate in # differentiation. 
argnums = [] + tangents, _ = tree_flatten(tangents) for idx, tangent in enumerate(tangents): if not isinstance(tangent, jax.custom_derivatives.SymbolicZero): argnums.append(idx) diff --git a/frontend/test/pytest/test_vmap.py b/frontend/test/pytest/test_vmap.py index 505ba6c45..b06112557 100644 --- a/frontend/test/pytest/test_vmap.py +++ b/frontend/test/pytest/test_vmap.py @@ -799,11 +799,11 @@ def loss_fn(params, data, targets): params = {"weights": weights, "bias": bias} results_enzyme = qml.qjit(grad(loss_fn))(params, data, targets) - results_fd = qml.qjit(grad(loss_fn, method="fd"))(params, data, targets) + results_fd = jax.grad(loss_fn)(params, data, targets) data_enzyme, pytree_enzyme = tree_flatten(results_enzyme) data_fd, pytree_fd = tree_flatten(results_fd) assert pytree_enzyme == pytree_fd - assert jnp.allclose(data_enzyme[0], data_fd[0], atol=1e-1) - assert jnp.allclose(data_enzyme[1], data_fd[1], atol=1e-1) + assert jnp.allclose(data_enzyme[0], data_fd[0]) + assert jnp.allclose(data_enzyme[1], data_fd[1]) From e3a44ec5cdeb385345c1fa0b3275b74dc3a83e45 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Tue, 19 Mar 2024 11:36:18 -0400 Subject: [PATCH 06/20] Apply suggestions from code review Co-authored-by: David Ittah --- frontend/test/pytest/test_vmap.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/frontend/test/pytest/test_vmap.py b/frontend/test/pytest/test_vmap.py index b06112557..6bb04ffdd 100644 --- a/frontend/test/pytest/test_vmap.py +++ b/frontend/test/pytest/test_vmap.py @@ -783,12 +783,11 @@ def ansatz(i): return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)])) - circuit = qml.qjit(vmap(circuit, in_axes=(1, None)), autograph=False) + circuit = vmap(circuit, in_axes=(1, None)) def my_model(data, weights, bias): return circuit(data, weights) + bias - @qml.qjit def loss_fn(params, data, targets): predictions = my_model(data, params["weights"], params["bias"]) loss = jnp.sum((targets - predictions) ** 2 / len(data)) @@ -798,7 +797,7 @@ def loss_fn(params, data, targets): bias = jnp.array(0.0) params = {"weights": weights, "bias": bias} - results_enzyme = qml.qjit(grad(loss_fn))(params, data, targets) + results_enzyme = qjit(grad(loss_fn))(params, data, targets) results_fd = jax.grad(loss_fn)(params, data, targets) data_enzyme, pytree_enzyme = tree_flatten(results_enzyme) From 02a49ab7fe92294e829b03dc9850189194e5f015 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Tue, 19 Mar 2024 11:59:06 -0400 Subject: [PATCH 07/20] Readd qjit --- frontend/test/pytest/test_vmap.py | 1 + 1 file changed, 1 insertion(+) diff --git a/frontend/test/pytest/test_vmap.py b/frontend/test/pytest/test_vmap.py index 6bb04ffdd..55496513f 100644 --- a/frontend/test/pytest/test_vmap.py +++ b/frontend/test/pytest/test_vmap.py @@ -788,6 +788,7 @@ def ansatz(i): def my_model(data, weights, bias): return circuit(data, weights) + bias + @qjit def loss_fn(params, data, targets): predictions = my_model(data, params["weights"], params["bias"]) loss = jnp.sum((targets - predictions) ** 2 / len(data)) From b0439ba1e5233ed4de732d1a05ef7c83468bb3a0 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Tue, 19 Mar 2024 12:09:24 -0400 Subject: [PATCH 08/20] Rename test variables --- frontend/test/pytest/test_vmap.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/frontend/test/pytest/test_vmap.py b/frontend/test/pytest/test_vmap.py index 55496513f..ec18d316c 100644 --- a/frontend/test/pytest/test_vmap.py +++ b/frontend/test/pytest/test_vmap.py @@ -762,7 
+762,7 @@ def test_vmap_worflow_derivation(self, backend): dev = qml.device(backend, wires=n_wires) - @qml.qnode(dev) + @qml.qnode(dev, diff_method="adjoint") def circuit(data, weights): """Quantum circuit ansatz""" @@ -799,10 +799,10 @@ def loss_fn(params, data, targets): params = {"weights": weights, "bias": bias} results_enzyme = qjit(grad(loss_fn))(params, data, targets) - results_fd = jax.grad(loss_fn)(params, data, targets) + results_jax = jax.grad(loss_fn)(params, data, targets) data_enzyme, pytree_enzyme = tree_flatten(results_enzyme) - data_fd, pytree_fd = tree_flatten(results_fd) + data_fd, pytree_fd = tree_flatten(results_jax) assert pytree_enzyme == pytree_fd assert jnp.allclose(data_enzyme[0], data_fd[0]) From 6e8681fbb3bd9dd15ace3e0e7eef2332283319ae Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Tue, 19 Mar 2024 12:32:28 -0400 Subject: [PATCH 09/20] Update atol --- frontend/test/pytest/test_vmap.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/frontend/test/pytest/test_vmap.py b/frontend/test/pytest/test_vmap.py index ec18d316c..292a78bcd 100644 --- a/frontend/test/pytest/test_vmap.py +++ b/frontend/test/pytest/test_vmap.py @@ -758,7 +758,8 @@ def test_vmap_worflow_derivation(self, backend): """Check the gradient of a vmap workflow""" n_wires = 5 data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3 - targets = jnp.array([-0.2, 0.4, 0.35, 0.2]) + + targets = jnp.array([-0.2, 0.4, 0.35, 0.2], dtype=jax.numpy.float64) dev = qml.device(backend, wires=n_wires) @@ -788,22 +789,21 @@ def ansatz(i): def my_model(data, weights, bias): return circuit(data, weights) + bias - @qjit def loss_fn(params, data, targets): predictions = my_model(data, params["weights"], params["bias"]) loss = jnp.sum((targets - predictions) ** 2 / len(data)) return loss weights = jnp.ones([n_wires, 3]) - bias = jnp.array(0.0) + bias = jnp.array(0.0, dtype=jax.numpy.float64) params = {"weights": weights, "bias": bias} results_enzyme = qjit(grad(loss_fn))(params, data, targets) results_jax = jax.grad(loss_fn)(params, data, targets) data_enzyme, pytree_enzyme = tree_flatten(results_enzyme) - data_fd, pytree_fd = tree_flatten(results_jax) + data_jax, pytree_fd = tree_flatten(results_jax) assert pytree_enzyme == pytree_fd - assert jnp.allclose(data_enzyme[0], data_fd[0]) - assert jnp.allclose(data_enzyme[1], data_fd[1]) + assert jnp.allclose(data_enzyme[0], data_jax[0]) + assert jnp.allclose(data_enzyme[1], data_jax[1], atol=8e-2) From 45dbf25931dd4d783655e019517be5fd8556c790 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Tue, 19 Mar 2024 16:00:00 -0400 Subject: [PATCH 10/20] Update mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp Co-authored-by: David Ittah --- mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp b/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp index cf6c29133..88d74ff4f 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp @@ -183,6 +183,7 @@ FailureOr HybridGradientLowering::cloneCallee(PatternRewriter &rew funcOp.walk([&](func::CallOp callOp) { if (callOp.getCallee() == qnode.getName()) { PatternRewriter::InsertionGuard insertionGuard(rewriter); + # TODO: optimize the placement of the param count call (e.g. 
loop hoisting) rewriter.setInsertionPoint(callOp); Value paramCount = rewriter From c0e824ca4c3cc905890684d64c73255e8a0e187c Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Wed, 3 Jul 2024 16:32:23 -0400 Subject: [PATCH 11/20] Enzyme version --- .dep-versions | 2 +- mlir/Enzyme | 2 +- mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.dep-versions b/.dep-versions index afe00502e..2195172d1 100644 --- a/.dep-versions +++ b/.dep-versions @@ -2,7 +2,7 @@ jax=0.4.23 mhlo=4611968a5f6818e6bdfb82217b9e836e0400bba9 llvm=cd9a641613eddf25d4b25eaa96b2c393d401d42c -enzyme=1beb98b51442d50652eaa3ffb9574f4720d611f1 +enzyme=v0.0.130 # Always remove custom PL/LQ versions before release. pennylane=d90137dd8f6af46699653deda5f839c27701769f diff --git a/mlir/Enzyme b/mlir/Enzyme index 1beb98b51..b53704d21 160000 --- a/mlir/Enzyme +++ b/mlir/Enzyme @@ -1 +1 @@ -Subproject commit 1beb98b51442d50652eaa3ffb9574f4720d611f1 +Subproject commit b53704d21839b1e1e0af54d65471afc11cb6b9ee diff --git a/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp b/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp index e361a705e..683ae6942 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp @@ -195,7 +195,7 @@ static FailureOr cloneCallee(PatternRewriter &rewriter, Operation funcOp.walk([&](func::CallOp callOp) { if (callOp.getCallee() == qnode.getName()) { PatternRewriter::InsertionGuard insertionGuard(rewriter); - # TODO: optimize the placement of the param count call (e.g. loop hoisting) + // TODO: optimize the placement of the param count call (e.g. loop hoisting) rewriter.setInsertionPoint(callOp); Value paramCount = rewriter From bdd144d6de5c121950f08d389b0b9549569a53d5 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Wed, 3 Jul 2024 16:37:07 -0400 Subject: [PATCH 12/20] Update Enzyme --- .dep-versions | 2 +- mlir/Enzyme | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.dep-versions b/.dep-versions index afe00502e..2195172d1 100644 --- a/.dep-versions +++ b/.dep-versions @@ -2,7 +2,7 @@ jax=0.4.23 mhlo=4611968a5f6818e6bdfb82217b9e836e0400bba9 llvm=cd9a641613eddf25d4b25eaa96b2c393d401d42c -enzyme=1beb98b51442d50652eaa3ffb9574f4720d611f1 +enzyme=v0.0.130 # Always remove custom PL/LQ versions before release. 
pennylane=d90137dd8f6af46699653deda5f839c27701769f diff --git a/mlir/Enzyme b/mlir/Enzyme index 1beb98b51..b53704d21 160000 --- a/mlir/Enzyme +++ b/mlir/Enzyme @@ -1 +1 @@ -Subproject commit 1beb98b51442d50652eaa3ffb9574f4720d611f1 +Subproject commit b53704d21839b1e1e0af54d65471afc11cb6b9ee From 8afc5d752f4574053b1b76da43ed36fe4a92febd Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Thu, 4 Jul 2024 13:33:24 -0400 Subject: [PATCH 13/20] Add no free --- mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp | 5 ++++- mlir/test/Catalyst/ConversionTest.mlir | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp b/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp index 2b0ab6f63..3479be8fa 100644 --- a/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp +++ b/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp @@ -461,7 +461,10 @@ struct DefineCallbackOpPattern : public OpConversionPattern { LLVM::LLVMFuncOp customCallFnOp = mlir::LLVM::lookupOrCreateFn( mod, "__catalyst_inactive_callback", {/*args=*/i64, i64, i64}, /*ret_type=*/voidType, isVarArg); - + SmallVector passthroughs; + auto keyAttr = StringAttr::get(ctx, "nofree"); + passthroughs.push_back(keyAttr); + customCallFnOp.setPassthroughAttr(ArrayAttr::get(ctx, passthroughs)); // TODO: remove redundant alloca+store since ultimately we'll receive struct* for (auto arg : op.getArguments()) { Type structTy = typeConverter->convertType(arg.getType()); diff --git a/mlir/test/Catalyst/ConversionTest.mlir b/mlir/test/Catalyst/ConversionTest.mlir index 01c0c4d5e..975d4da59 100644 --- a/mlir/test/Catalyst/ConversionTest.mlir +++ b/mlir/test/Catalyst/ConversionTest.mlir @@ -134,6 +134,7 @@ module @test0 { // CHECK-LABEL: @test1 module @test1 { catalyst.callback @callback_1(memref, memref) attributes {argc = 1 : i64, id = 1 : i64, resc = 1 : i64} + // CHECK: __catalyst_inactive_callback(i64, i64, i64, ...) attributes {passthrough = ["nofree"]} // CHECK-LABEL: func.func private @foo( // CHECK-SAME: [[arg0:%.+]]: tensor // CHECK-SAME:) From 07fd2cc983bc243510e8dedc465357e998865eb0 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Tue, 23 Jul 2024 13:17:29 -0400 Subject: [PATCH 14/20] Black --- frontend/catalyst/jit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py index c78e4916e..08f242445 100644 --- a/frontend/catalyst/jit.py +++ b/frontend/catalyst/jit.py @@ -319,7 +319,7 @@ def f( have dynamic shape ``n``. Passing a sequence of dictionaries: - + .. 
code-block:: python abstracted_axes=({}, {0: 'n'}, {1: 'm', 0: 'n'}) From ed962843e36d8a5764ee06f591042d38e9e7e885 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Tue, 23 Jul 2024 13:20:34 -0400 Subject: [PATCH 15/20] Xfail test --- frontend/test/pytest/test_vmap.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/frontend/test/pytest/test_vmap.py b/frontend/test/pytest/test_vmap.py index 74877eb5f..35a9583fe 100644 --- a/frontend/test/pytest/test_vmap.py +++ b/frontend/test/pytest/test_vmap.py @@ -22,6 +22,8 @@ from catalyst import for_loop, grad, qjit, vmap +# pylint: disable=too-many-public-methods + class TestVectorizeMap: """Test QJIT compatibility with JAX vectorization.""" @@ -754,6 +756,7 @@ def circuit(x): ): qjit(workflow)(x) + @pytest.xfail("wrong results for vmap") def test_vmap_worflow_derivation(self, backend): """Check the gradient of a vmap workflow""" n_wires = 5 @@ -806,7 +809,7 @@ def loss_fn(params, data, targets): assert pytree_enzyme == pytree_fd assert jnp.allclose(data_enzyme[0], data_jax[0]) - assert jnp.allclose(data_enzyme[1], data_jax[1], atol=8e-2) + assert jnp.allclose(data_enzyme[1], data_jax[1]) def test_vmap_usage_patterns(self, backend): """Test usage patterns of catalyst.vmap.""" From 2813cd39776051f5a995c8470d7d9597a1d5f544 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Tue, 23 Jul 2024 13:31:39 -0400 Subject: [PATCH 16/20] Add test --- frontend/test/pytest/test_gradient.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/frontend/test/pytest/test_gradient.py b/frontend/test/pytest/test_gradient.py index 914ae2c88..fcc09d704 100644 --- a/frontend/test/pytest/test_gradient.py +++ b/frontend/test/pytest/test_gradient.py @@ -1307,6 +1307,25 @@ def interpreted(x): assert np.allclose(compiled(inp), interpreted(inp)) +@pytest.mark.parametrize("inp", [(1.0), (2.0), (3.0), (4.0)]) +def test_preprocessing_outside_qnode(inp, backend): + """Test the preprocessing outside qnode.""" + + @qml.qnode(qml.device(backend, wires=1)) + def f(y): + qml.RX(y, wires=0) + return qml.expval(qml.PauliZ(0)) + + @qjit + def g(x): + return grad(lambda y: f(jnp.cos(y)) ** 2)(x) + + def h(x): + return jax.grad(lambda y: f(jnp.cos(y)) ** 2)(x) + + assert np.allclose(g(inp), h(inp)) + + class TestGradientErrors: """Test errors when an operation which does not have a valid gradient is reachable from the grad op""" From a53b11361e37e6e3d50ea572e9b6ea7c30c88493 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Tue, 23 Jul 2024 13:55:31 -0400 Subject: [PATCH 17/20] Changelog --- doc/changelog.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/doc/changelog.md b/doc/changelog.md index 7405eac51..c1713a475 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -127,6 +127,25 @@

<h3>Bug fixes</h3>

+* Circuits with preprocessing functions outside qnodes can now be differentiated. + [(#332)](https://github.com/PennyLaneAI/catalyst/pull/332) + + ```python + @qml.qnode(qml.device("lightning.qubit", wires=1)) + def f(y): + qml.RX(y, wires=0) + return qml.expval(qml.PauliZ(0)) + + @catalyst.qjit + def g(x): + return catalyst.grad(lambda y: f(jnp.cos(y)) ** 2)(x) + ``` + + ```pycon + >>> g(0.4) + 0.3751720385067584 + ``` + * Static arguments can now be passed through a QNode when specified with the `static_argnums` keyword argument. [(#932)](https://github.com/PennyLaneAI/catalyst/pull/932) From c208cd3e651c66b041fb1038221a16bc4772269e Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Tue, 23 Jul 2024 14:14:13 -0400 Subject: [PATCH 18/20] update --- frontend/test/pytest/test_vmap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/test_vmap.py b/frontend/test/pytest/test_vmap.py index 35a9583fe..37476559c 100644 --- a/frontend/test/pytest/test_vmap.py +++ b/frontend/test/pytest/test_vmap.py @@ -756,7 +756,7 @@ def circuit(x): ): qjit(workflow)(x) - @pytest.xfail("wrong results for vmap") + @pytest.mark.xfail(reason="Vmap yields wrong results when differentiated") def test_vmap_worflow_derivation(self, backend): """Check the gradient of a vmap workflow""" n_wires = 5 From 238b28038f97a9ede12844eb7579529ca7ce15ad Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Fri, 30 Aug 2024 09:33:46 -0400 Subject: [PATCH 19/20] Black --- frontend/test/pytest/test_gradient.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/frontend/test/pytest/test_gradient.py b/frontend/test/pytest/test_gradient.py index 7c195d029..d758c0780 100644 --- a/frontend/test/pytest/test_gradient.py +++ b/frontend/test/pytest/test_gradient.py @@ -1443,6 +1443,7 @@ def interpreted(x): assert np.allclose(compiled(inp), interpreted(inp)) + @pytest.mark.parametrize("inp", [(1.0), (2.0), (3.0), (4.0)]) def test_preprocessing_outside_qnode(inp, backend): """Test the preprocessing outside qnode.""" @@ -1461,6 +1462,7 @@ def h(x): assert np.allclose(g(inp), h(inp)) + def test_gradient_slice(backend): """Test the differentation when the qnode generates memref with non identity layout.""" n_wires = 5 From ec39190644f9b53f057fd014205c1ae58cb98ead Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Fri, 30 Aug 2024 09:52:53 -0400 Subject: [PATCH 20/20] Move test in gradient --- frontend/test/pytest/test_gradient.py | 57 +++++++++++++++++++++++++++ frontend/test/pytest/test_vmap.py | 55 -------------------------- 2 files changed, 57 insertions(+), 55 deletions(-) diff --git a/frontend/test/pytest/test_gradient.py b/frontend/test/pytest/test_gradient.py index d758c0780..be457f9bd 100644 --- a/frontend/test/pytest/test_gradient.py +++ b/frontend/test/pytest/test_gradient.py @@ -34,6 +34,7 @@ pure_callback, qjit, value_and_grad, + vmap, ) # pylint: disable=too-many-lines @@ -1503,6 +1504,62 @@ def my_model(data, weights, bias): assert np.allclose(cat_res, jax_res) +@pytest.mark.xfail(reason="Vmap yields wrong results when differentiated") +def test_vmap_worflow_derivation(backend): + """Check the gradient of a vmap workflow""" + n_wires = 5 + data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3 + + targets = jnp.array([-0.2, 0.4, 0.35, 0.2], dtype=jax.numpy.float64) + + dev = qml.device(backend, wires=n_wires) + + @qml.qnode(dev, diff_method="adjoint") + def circuit(data, weights): + """Quantum circuit ansatz""" + + @for_loop(0, n_wires, 1) + def data_embedding(i): + qml.RY(data[i], wires=i) 
+ + data_embedding() + + @for_loop(0, n_wires, 1) + def ansatz(i): + qml.RX(weights[i, 0], wires=i) + qml.RY(weights[i, 1], wires=i) + qml.RX(weights[i, 2], wires=i) + qml.CNOT(wires=[i, (i + 1) % n_wires]) + + ansatz() + + return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)])) + + circuit = vmap(circuit, in_axes=(1, None)) + + def my_model(data, weights, bias): + return circuit(data, weights) + bias + + def loss_fn(params, data, targets): + predictions = my_model(data, params["weights"], params["bias"]) + loss = jnp.sum((targets - predictions) ** 2 / len(data)) + return loss + + weights = jnp.ones([n_wires, 3]) + bias = jnp.array(0.0, dtype=jax.numpy.float64) + params = {"weights": weights, "bias": bias} + + results_enzyme = qjit(grad(loss_fn))(params, data, targets) + results_jax = jax.grad(loss_fn)(params, data, targets) + + data_enzyme, pytree_enzyme = tree_flatten(results_enzyme) + data_jax, pytree_fd = tree_flatten(results_jax) + + assert pytree_enzyme == pytree_fd + assert jnp.allclose(data_enzyme[0], data_jax[0]) + assert jnp.allclose(data_enzyme[1], data_jax[1]) + + class TestGradientErrors: """Test errors when an operation which does not have a valid gradient is reachable from the grad op""" diff --git a/frontend/test/pytest/test_vmap.py b/frontend/test/pytest/test_vmap.py index 37476559c..1ec3cc8ff 100644 --- a/frontend/test/pytest/test_vmap.py +++ b/frontend/test/pytest/test_vmap.py @@ -756,61 +756,6 @@ def circuit(x): ): qjit(workflow)(x) - @pytest.mark.xfail(reason="Vmap yields wrong results when differentiated") - def test_vmap_worflow_derivation(self, backend): - """Check the gradient of a vmap workflow""" - n_wires = 5 - data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3 - - targets = jnp.array([-0.2, 0.4, 0.35, 0.2], dtype=jax.numpy.float64) - - dev = qml.device(backend, wires=n_wires) - - @qml.qnode(dev, diff_method="adjoint") - def circuit(data, weights): - """Quantum circuit ansatz""" - - @for_loop(0, n_wires, 1) - def data_embedding(i): - qml.RY(data[i], wires=i) - - data_embedding() - - @for_loop(0, n_wires, 1) - def ansatz(i): - qml.RX(weights[i, 0], wires=i) - qml.RY(weights[i, 1], wires=i) - qml.RX(weights[i, 2], wires=i) - qml.CNOT(wires=[i, (i + 1) % n_wires]) - - ansatz() - - return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)])) - - circuit = vmap(circuit, in_axes=(1, None)) - - def my_model(data, weights, bias): - return circuit(data, weights) + bias - - def loss_fn(params, data, targets): - predictions = my_model(data, params["weights"], params["bias"]) - loss = jnp.sum((targets - predictions) ** 2 / len(data)) - return loss - - weights = jnp.ones([n_wires, 3]) - bias = jnp.array(0.0, dtype=jax.numpy.float64) - params = {"weights": weights, "bias": bias} - - results_enzyme = qjit(grad(loss_fn))(params, data, targets) - results_jax = jax.grad(loss_fn)(params, data, targets) - - data_enzyme, pytree_enzyme = tree_flatten(results_enzyme) - data_jax, pytree_fd = tree_flatten(results_jax) - - assert pytree_enzyme == pytree_fd - assert jnp.allclose(data_enzyme[0], data_jax[0]) - assert jnp.allclose(data_enzyme[1], data_jax[1]) - def test_vmap_usage_patterns(self, backend): """Test usage patterns of catalyst.vmap."""
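
The one-line change to `compute_jvp` in PATCH 05 (`tangents, _ = tree_flatten(tangents)`) is easy to miss in the diff. Below is a minimal sketch of the idea, assuming pytree-valued parameters such as the `{"weights": ..., "bias": ...}` dictionary used in the tests; the helper name `active_argnums` is illustrative only and is not an API from the patch.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten


def active_argnums(tangents):
    """Indices of flattened tangent leaves that take part in differentiation."""
    # Flatten the (possibly nested) tangent pytree so that leaf positions line up
    # with the flattened primal arguments before filtering out symbolic zeros.
    leaves, _ = tree_flatten(tangents)
    return [
        idx
        for idx, tangent in enumerate(leaves)
        if not isinstance(tangent, jax.custom_derivatives.SymbolicZero)
    ]
```

With ordinary array leaves and no symbolic zeros, every flattened position is reported as active:

```pycon
>>> active_argnums({"weights": jnp.ones([5, 3]), "bias": jnp.array(0.0)})
[0, 1]
```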