
Function for setting and resetting global changes during compilation #913

Open
@erick-xanadu

Description


There are several places where we modify JAX during compilation.

To list a few:

```python
import jax.interpreters.partial_eval

# Required for JAX tracer objects as PennyLane wires.
# pylint: disable=unnecessary-lambda
setattr(jax.interpreters.partial_eval.DynamicJaxprTracer, "__hash__", lambda x: id(x))
```

```python
import jax

# This flag cannot be set in ``QJIT.get_mlir()`` because values created before
# that function is called must be consistent with the JAX configuration value.
jax.config.update("jax_enable_x64", True)
```

Various patchers (see `jax_extras` and `jax_transient_config`).

We also have a global context, `EvaluationContext`, that records whether we are executing normally or tracing with JAX.

Callbacks break the assumption that, once we start tracing, we never return to the Python environment. We should have a function that saves the configuration before tracing, changes it however we need, restores the original configuration while a callback runs, and re-applies the tracing configuration once the callback scope exits.
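One possible shape for such a function, sketched under stated assumptions: an object that installs a set of patches for the tracing scope and temporarily undoes them inside a callback. `GlobalPatcher`, `tracing()`, `suspended()`, and `FakeTracer` are hypothetical names, not Catalyst APIs. The sketch only handles plain attribute patches; real JAX flags such as `jax_enable_x64` would have to be saved and restored through `jax.config.update` rather than `setattr`, and the `EvaluationContext` state is not covered.

```python
import contextlib
import types

_MISSING = object()  # sentinel: the attribute was not stored directly on the target


class GlobalPatcher:
    """Hypothetical helper, not an existing Catalyst or JAX API.

    Installs a set of attribute patches for the tracing scope and can
    temporarily undo them, for example while a Python callback runs.
    """

    def __init__(self, patches):
        # Each patch is (target, attribute_name, new_value).
        # Assumptions: every target has a __dict__ (classes, modules, plain objects)
        # and each (target, attribute_name) pair appears at most once.
        self._patches = list(patches)
        self._originals = []

    def _apply(self):
        # Record what is stored directly on each target, then install the patch.
        self._originals = [
            (target, name, vars(target).get(name, _MISSING))
            for target, name, _ in self._patches
        ]
        for target, name, value in self._patches:
            setattr(target, name, value)

    def _restore(self):
        # Put back the recorded originals, deleting attributes that did not exist.
        for target, name, original in reversed(self._originals):
            if original is _MISSING:
                delattr(target, name)
            else:
                setattr(target, name, original)

    @contextlib.contextmanager
    def tracing(self):
        """Install the patches for the whole tracing scope."""
        self._apply()
        try:
            yield
        finally:
            self._restore()

    @contextlib.contextmanager
    def suspended(self):
        """Inside a callback: restore the original globals, re-apply them on exit."""
        self._restore()
        try:
            yield
        finally:
            self._apply()


# Usage: FakeTracer stands in for DynamicJaxprTracer, `config` for a global flag.
class FakeTracer:
    __hash__ = None  # unhashable, like a JAX tracer


config = types.SimpleNamespace(enable_x64=False)
patcher = GlobalPatcher([
    (FakeTracer, "__hash__", lambda self: id(self)),
    (config, "enable_x64", True),
])

tracer = FakeTracer()
with patcher.tracing():
    print(hash(tracer) == id(tracer), config.enable_x64)  # True True
    with patcher.suspended():  # e.g. while a Python callback runs
        print(config.enable_x64)  # False
    print(config.enable_x64)  # True
print(config.enable_x64)  # False
```

Keeping apply and restore in one object means the tracing path and the callback path cannot drift apart: whatever `tracing()` installs is exactly what `suspended()` removes and later puts back.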

Note: instead of modifying `jax.interpreters.partial_eval.DynamicJaxprTracer` to add a hash, could we change the PennyLane wire utilities to check whether a wire is a `jax.interpreters.partial_eval.DynamicJaxprTracer` and use its `id` as the hash? That would avoid modifying JAX itself.
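A minimal sketch of that alternative, assuming the wire-handling code can check the tracer type before hashing. `wire_key`, `wire_indices`, `TRACER_TYPES`, and `FakeTracer` are hypothetical names for illustration, not PennyLane functions; `FakeTracer` stands in for `DynamicJaxprTracer` so the sketch runs without JAX.

```python
class FakeTracer:
    """Stand-in for jax.interpreters.partial_eval.DynamicJaxprTracer (unhashable)."""

    __hash__ = None


# Assumption: in real code this tuple would hold the JAX tracer class.
TRACER_TYPES = (FakeTracer,)


def wire_key(wire):
    """Hypothetical: a hashable key for a wire, using identity for tracers."""
    if isinstance(wire, TRACER_TYPES):
        # Tracers are keyed by identity; the tag keeps ids from colliding with int wires.
        return ("tracer", id(wire))
    return ("value", wire)


def wire_indices(wires):
    """Number distinct wires in order of first appearance, tracers by identity."""
    # The `wires` list keeps every tracer alive, so an id() cannot be reused mid-loop.
    positions = {}
    for wire in wires:
        positions.setdefault(wire_key(wire), len(positions))
    return [positions[wire_key(wire)] for wire in wires]


t0, t1 = FakeTracer(), FakeTracer()
print(wire_indices([0, t0, 1, t0, t1]))  # [0, 1, 2, 1, 3]
```

Here the tracer class itself is never touched: the identity-based key lives entirely in the wire utility, which is the point of the suggestion.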

Metadata

Assignees: No one assigned
Labels: enhancement (New feature or request)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
