Description of the model to be implemented
My request concerns not a model but an example of how to lift a new transformation. The design docs on lifted transformations explain the rationale behind lifting, but fall short of explaining how to go about implementing one "in the wild" (or perhaps I fall short of understanding them). My specific scenario requires `jax.experimental.ode.odeint` in `Module.__call__`, and I have not managed to understand the implementation of lifted `vmap` and others well enough to translate it to `odeint`. For starters, I don't fully understand which transformation makes this operation illegal: is it the custom VJP? I see that a lifted `custom_vjp` exists, but I don't know how it relates to the problem I am facing: should I re-wrap `_odeint` and redefine the derivatives myself?

I am suggesting the case of `odeint` as the basis for an expansion of the documentation, but I would be entirely satisfied to receive "one-off" help (which I can later try to turn into a tutorial to contribute back).
A minimal example
Here I created a colab with a minimal example of what I would like to get working.
A rationale for the example
This project implements a continuous normalizing flow (CNF) in Flax. It defines the integrand as an `nn.Module` and then sidesteps the issue of not being able to integrate within a `Module` by integrating "manually" in `loss_fn` and `solve_dynamics` (thus effectively duplicating code). I, on the other hand, want to use a CNF as part of a bigger model (a conditional CNF, or CCNF, which "owns" both the CNF and the model that generates the conditional embeddings), and having to call `odeint` manually every time I call my CCNF quickly becomes very burdensome.
Note
There are more issues with the CNF code than just this. For example, the author was apparently not aware of the `Module.apply(method=)` keyword, hence the complexity of the `CNF` and `Neg_CNF` classes. I will probably open a PR with a cleanup in that regard, but I consider that issue completely orthogonal to my example request here.