Description
There are several places where we modify JAX during compilation.
Just to list some:
```python
import jax

# Required for JAX tracer objects as PennyLane wires.
# pylint: disable=unnecessary-lambda
setattr(jax.interpreters.partial_eval.DynamicJaxprTracer, "__hash__", lambda x: id(x))
```

```python
import jax

# This flag cannot be set in ``QJIT.get_mlir()`` because values created before
# that function is called must be consistent with the JAX configuration value.
jax.config.update("jax_enable_x64", True)
```
We also use patchers (see `jax_extras` and `jax_transient_config`).
We also have a global context, `EvaluationContext`, that records whether or not we are currently tracing with JAX.
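A minimal sketch of such a global context, assuming a simple stack of modes in which the innermost entry wins, so a nested scope can temporarily switch back to plain Python execution. `Mode` and `EvaluationContextSketch` are hypothetical names; this does not reproduce Catalyst's actual `EvaluationContext` API.

```python
import contextlib
from enum import Enum, auto
from typing import Iterator, List


class Mode(Enum):
    INTERPRETATION = auto()  # plain Python execution
    TRACING = auto()         # inside a JAX trace


class EvaluationContextSketch:
    """Hypothetical global stack of evaluation modes (not thread-safe)."""

    _stack: List[Mode] = []

    @classmethod
    @contextlib.contextmanager
    def enter(cls, mode: Mode) -> Iterator[None]:
        # Push the mode for the duration of the `with` block, then pop it,
        # even if the body raises.
        cls._stack.append(mode)
        try:
            yield
        finally:
            cls._stack.pop()

    @classmethod
    def is_tracing(cls) -> bool:
        # Only the innermost entry matters, so a nested INTERPRETATION scope
        # reports "not tracing" even while an outer trace is still active.
        return bool(cls._stack) and cls._stack[-1] is Mode.TRACING
```

With a stack rather than a single flag, leaving a nested scope automatically restores whatever mode the enclosing scope had.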
Callbacks break the assumption that once we start tracing we never return to the Python environment. We should have a function that saves the configuration before tracing, changes it however we need, restores the original configuration during callbacks, and re-applies the tracing configuration once we exit the callback scope.
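A minimal sketch of such a function, written in plain Python so it runs without JAX. `FlagConfig` stands in for `jax.config` (only its `update(name, value)` call shape is mirrored), `FakeTracer` stands in for `DynamicJaxprTracer`, and `TransientState` with its `tracing()`/`callback()` scopes is a hypothetical helper, not an existing Catalyst or JAX API.

```python
import contextlib
from typing import Any, Dict, Iterator, List, Tuple

_ABSENT = object()  # sentinel: attribute was not defined directly on the class


class FlagConfig:
    """Stand-in for a global flag store such as jax.config (assumption)."""

    def __init__(self, **flags: Any) -> None:
        self.values: Dict[str, Any] = dict(flags)

    def update(self, name: str, value: Any) -> None:
        self.values[name] = value


class FakeTracer:
    """Stand-in for DynamicJaxprTracer (assumption): unhashable by default."""

    __hash__ = None  # type: ignore[assignment]


class TransientState:
    """Hypothetical helper: apply tracing-time patches, undo them in callbacks."""

    def __init__(self, config: FlagConfig, flags: Dict[str, Any],
                 class_patches: List[Tuple[type, str, Any]]) -> None:
        self.config = config
        self.flags = flags                  # flag name -> tracing-time value
        self.class_patches = class_patches  # (class, attribute, tracing-time value)
        self._saved_flags: Dict[str, Any] = {}
        self._saved_attrs: List[Tuple[type, str, Any]] = []

    def _apply(self) -> None:
        # Save the current values (flags must already exist in the config),
        # then install the tracing-time ones.
        self._saved_flags = {k: self.config.values[k] for k in self.flags}
        self._saved_attrs = [(cls, name, cls.__dict__.get(name, _ABSENT))
                             for cls, name, _ in self.class_patches]
        for name, value in self.flags.items():
            self.config.update(name, value)
        for cls, name, value in self.class_patches:
            setattr(cls, name, value)

    def _restore(self) -> None:
        # Put back exactly what was there before the last _apply().
        for name, value in self._saved_flags.items():
            self.config.update(name, value)
        for cls, name, old in reversed(self._saved_attrs):
            if old is _ABSENT:
                delattr(cls, name)  # fall back to the inherited attribute
            else:
                setattr(cls, name, old)

    @contextlib.contextmanager
    def tracing(self) -> Iterator[None]:
        self._apply()
        try:
            yield
        finally:
            self._restore()

    @contextlib.contextmanager
    def callback(self) -> Iterator[None]:
        # Leave the tracing configuration while user Python code runs...
        self._restore()
        try:
            yield
        finally:
            # ...and re-enter it when the callback returns.
            self._apply()
```

For example, `TransientState(cfg, {"enable_x64": True}, [(FakeTracer, "__hash__", lambda x: id(x))])` makes tracers hashable and enables the flag inside `tracing()`, while code inside a nested `callback()` sees the original, unhashable class and the original flag value.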
Note: instead of modifying `jax.interpreters.partial_eval.DynamicJaxprTracer` to add a hash, could we change PennyLane's wire utilities to check whether a wire is a `jax.interpreters.partial_eval.DynamicJaxprTracer` and, if so, use its `id` as its hash? That would avoid modifying JAX itself.
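A minimal sketch of that alternative, assuming the wire utilities can route every wire through a single keying function before using it in a dict or set. `FakeTracer` stands in for `DynamicJaxprTracer`, and `is_tracer`, `IdentityKey`, and `wire_key` are hypothetical names, not existing PennyLane functions.

```python
from typing import Any, Hashable


class FakeTracer:
    """Stand-in for DynamicJaxprTracer (assumption): unhashable by default."""

    __hash__ = None  # type: ignore[assignment]


def is_tracer(wire: Any) -> bool:
    # Hypothetical check; real code would use isinstance(wire, DynamicJaxprTracer).
    return isinstance(wire, FakeTracer)


class IdentityKey:
    """Wraps an object so it hashes by id() and compares by identity."""

    __slots__ = ("obj",)

    def __init__(self, obj: Any) -> None:
        self.obj = obj

    def __hash__(self) -> int:
        return id(self.obj)

    def __eq__(self, other: object) -> bool:
        # Compare with `is` so the wrapped object's own __eq__ is never called.
        return isinstance(other, IdentityKey) and other.obj is self.obj

    def __repr__(self) -> str:
        return f"IdentityKey({self.obj!r})"


def wire_key(wire: Any) -> Hashable:
    """Hypothetical hook: tracers are keyed by identity, other wires by value."""
    return IdentityKey(wire) if is_tracer(wire) else wire
```

Identity semantics match the `id`-based patch: two distinct tracer objects are always distinct wires. Because the wrapper compares with `is`, it also never invokes the tracer's own `__eq__`, which JAX overloads to build a traced comparison rather than return a plain `bool`.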