
Commit c66e485: add qjit to original demo
1 parent 710b99a

1 file changed: +23 -12 lines changed

demonstrations/tutorial_eqnn_force_field.py (+23 -12)
@@ -143,6 +143,10 @@
 import matplotlib.pyplot as plt
 import sklearn
 
+######################################################################
+# To speed up the computation, we also import Catalyst, a just-in-time (JIT) compiler for PennyLane quantum programs.
+import catalyst
+
 ######################################################################
 # Let us construct Pauli matrices, which are used to build the Hamiltonian.
 X = np.array([[0, 1], [1, 0]])
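Aside (not part of the diff): ``catalyst.qjit`` can compile purely classical JAX code as well as quantum programs, so a minimal smoke test of the new import might look like the sketch below. The first call triggers compilation; subsequent calls reuse the compiled binary.

import jax.numpy as jnp
import catalyst

# Minimal sketch: qjit-compile a purely classical function.
@catalyst.qjit
def f(x):
    return jnp.sin(x) ** 2

print(f(0.3))  # ~0.0873; the first call compiles, later calls are fast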
@@ -301,10 +305,13 @@ def noise_layer(epsilon, wires):
 #################################
 
 
-dev = qml.device("default.qubit", wires=num_qubits)
+######################################################################
+# To speed up the computation, we use Catalyst to compile our quantum program, and we run it on the
+# lightning.qubit backend instead of the default.qubit backend.
+dev = qml.device("lightning.qubit", wires=num_qubits)
 
 
-@qml.qnode(dev, interface="jax")
+@qml.qnode(dev)
 def vqlm(data, params):
 
     weights = params["params"]["weights"]
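For reference, the pattern this hunk introduces looks roughly like the standalone sketch below (the two-wire toy circuit is illustrative, not from the demo). Catalyst targets the Lightning simulator backends rather than ``default.qubit``, which is what motivates the device swap, and under ``@catalyst.qjit`` the QNode no longer needs an explicit ``interface="jax"`` argument.

import pennylane as qml
import catalyst

dev = qml.device("lightning.qubit", wires=2)

@catalyst.qjit
@qml.qnode(dev)  # no interface kwarg needed under qjit
def circuit(theta):
    qml.RX(theta, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(1))

print(circuit(0.5))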
@@ -396,25 +403,27 @@ def vqlm(data, params):
 )
 
 #################################
-# We will know define the cost function and how to train the model using Jax. We will use the mean-square-error loss function.
-# To speed up the computation, we use the decorator ``@jax.jit`` to do just-in-time compilation for this execution. This means the first execution will typically take a little longer with the
-# benefit that all following executions will be significantly faster, see the `Jax docs on jitting <https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html>`_.
+# We will now define the cost function and how to train the model using Jax. We will use the mean-squared-error loss function.
+# To speed up the computation, we use the decorator ``@catalyst.qjit`` to do just-in-time compilation for this execution. This means the first execution will typically take a little longer, with the
+# benefit that all following executions will be significantly faster; see the `Catalyst documentation <https://docs.pennylane.ai/projects/catalyst/en/stable/index.html>`_.
 
 #################################
 from jax.example_libraries import optimizers
 
 # We vectorize the model over the data points
-vec_vqlm = jax.vmap(vqlm, (0, None), 0)
+vec_vqlm = catalyst.vmap(
+    vqlm,
+    in_axes=(0, {"params": {"alphas": None, "epsilon": None, "weights": None}}),
+    out_axes=0,
+)
 
 
 # Mean-squared-error loss function
-@jax.jit
 def mse_loss(predictions, targets):
     return jnp.mean(0.5 * (predictions - targets) ** 2)
 
 
 # Make prediction and compute the loss
-@jax.jit
 def cost(weights, loss_data):
     data, E_target, F_target = loss_data
     E_pred = vec_vqlm(data, weights)
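The ``in_axes`` pytree handed to ``catalyst.vmap`` mirrors the structure of the ``params`` argument: ``0`` batches the data along its leading axis, while ``None`` broadcasts every parameter leaf unchanged. A standalone sketch of that behaviour with a toy stand-in for ``vqlm`` (function, names, and shapes are illustrative):

import jax.numpy as jnp
import catalyst

def model(x, params):
    # Toy stand-in for vqlm: one data point in, one scalar out.
    return jnp.dot(x, params["w"]) + params["b"]

# Map over axis 0 of x; broadcast the params dict (every leaf marked None).
batched = catalyst.vmap(model, in_axes=(0, {"w": None, "b": None}), out_axes=0)

@catalyst.qjit
def run(x, params):
    return batched(x, params)

x = jnp.ones((5, 3))                  # batch of five 3-dimensional points
params = {"w": jnp.ones(3), "b": jnp.array(0.1)}
print(run(x, params))                 # five identical predictions of 3.1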
@@ -424,17 +433,19 @@ def cost(weights, loss_data):
 
 
 # Perform one training step
-@jax.jit
+# This function will be called repeatedly, so we qjit it to amortize the compilation cost over many runs.
+@catalyst.qjit
 def train_step(step_i, opt_state, loss_data):
 
     net_params = get_params(opt_state)
-    loss, grads = jax.value_and_grad(cost, argnums=0)(net_params, loss_data)
-
+    loss = cost(net_params, loss_data)
+    grads = catalyst.grad(cost, method="fd", h=1e-13, argnums=0)(net_params, loss_data)
     return loss, opt_update(step_i, grads, opt_state)
 
 
 # Return prediction and loss at inference times, e.g. for testing
-@jax.jit
+# This function is also called repeatedly, so we qjit it.
+@catalyst.qjit
 def inference(loss_data, opt_state):
 
     data, E_target, F_target = loss_data
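The training step replaces ``jax.value_and_grad`` with two calls: a plain forward pass for the loss, and ``catalyst.grad`` with ``method="fd"`` (finite differences, step size ``h``) for the gradient. Below is a standalone sketch of the same pattern with a toy cost function; the larger step size here is an illustrative choice, since the commit's ``h=1e-13`` effectively assumes 64-bit floats.

import jax.numpy as jnp
import catalyst

def cost_fn(w):
    # Toy stand-in for the demo's mean-squared-error cost.
    return jnp.sum(w ** 2)

@catalyst.qjit
def loss_and_grad(w):
    loss = cost_fn(w)
    # Finite differences: grad_i is approximately (f(w + h*e_i) - f(w)) / h.
    grads = catalyst.grad(cost_fn, method="fd", h=1e-6, argnums=0)(w)
    return loss, grads

print(loss_and_grad(jnp.array([1.0, 2.0])))  # loss 5.0, grads ~ [2., 4.]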
