
Implementing a custom lifted transformation: jax.experimental.ode.odeint #1496

Open
@jatentaki

Description of the model to be implemented

My request concerns not a model but an example of how to lift new transformations. The design docs on lifted transformations explain the rationale behind lifting, but fall short of explaining how to implement one "in the wild" (or perhaps I fall short of understanding them). My specific scenario requires calling jax.experimental.ode.odeint inside Module.__call__, and I have not managed to understand the implementation of lifted vmap and the others well enough to translate it to odeint. For starters, I don't fully understand which transformation makes this operation illegal: is it the custom VJP? I see that a lifted custom_vjp exists, but I don't know how it relates to the problem I am facing: should I re-wrap _odeint and redefine the derivatives myself?

I suggest odeint as the basis for expanding the documentation, but I would be entirely satisfied with "one-off" help (which I can later try to turn into a tutorial and contribute back).

A minimal example

Here I created a Colab notebook with a minimal example of what I would like to get working.

A rationale for the example

This project implements a continuous normalizing flow (CNF) in Flax. It defines the integrand as an nn.Module and then sidesteps the fact that it cannot integrate within a Module by integrating "manually" in loss_fn and solve_dynamics (effectively duplicating code). I want to use a CNF as part of a bigger model (a conditional CNF, or CCNF, which "owns" both the CNF and the model that generates the conditional embeddings), and calling odeint manually every time I call my CCNF quickly becomes burdensome.

Note

There are more issues with the CNF code than this one; for example, the author was apparently unaware of the method= keyword of Module.apply, hence the complexity of the CNF and Neg_CNF classes. I will probably make a PR that cleans this up, but I consider that issue completely orthogonal to the example I am requesting here.

Labels: Priority: P1 - soon (response within 5 business days, resolution within 30 days; assignee required)