Description
Context
The OpenVINO component responsible for supporting JAX/Flax models is called the JAX Frontend (JAX FE). JAX FE converts a JAX/Flax model, represented as a ClosedJaxpr graph object with operations from the jax.lax opset, into OpenVINO IR containing operations from the OpenVINO opset.
To infer JAX/Flax models that contain the jax.lax.iota operation with OpenVINO, JAX FE needs to be extended to support this operation.
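For reference, jax.lax.iota(dtype, size) returns the 1-D array [0, 1, ..., size - 1], and jax.lax.broadcasted_iota(dtype, shape, dimension) fills a tensor of the given shape with the index along one axis. In the jaxprs we have looked at, both appear as an iota equation carrying dtype, shape and dimension parameters (newer JAX versions also add sharding). Treat these parameter names as an assumption and confirm them with jax.make_jaxpr for your JAX version. The plain C++ sketch below has no OpenVINO dependency, and iota_reference and iota_decomposed are hypothetical helper names. It states the semantics directly and pairs them with one possible decomposition into Range, Reshape and Broadcast (NumPy rules) steps that should produce identical values. It is an illustration under these assumptions, not the required loader design.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Row-major strides, e.g. shape {2, 3, 4} -> strides {12, 4, 1}.
std::vector<std::size_t> row_major_strides(const std::vector<std::size_t>& shape) {
    std::vector<std::size_t> strides(shape.size(), 1);
    for (std::size_t i = shape.size(); i-- > 1;) {
        strides[i - 1] = strides[i] * shape[i];
    }
    return strides;
}

std::size_t num_elements(const std::vector<std::size_t>& shape) {
    std::size_t n = 1;
    for (std::size_t d : shape) n *= d;
    return n;
}

void check_dimension(const std::vector<std::size_t>& shape, std::size_t dimension) {
    if (dimension >= shape.size()) throw std::invalid_argument("iota: dimension out of range");
}

// Direct semantics: out[i_0, ..., i_{r-1}] = i_dimension.
std::vector<int64_t> iota_reference(const std::vector<std::size_t>& shape, std::size_t dimension) {
    check_dimension(shape, dimension);
    const std::vector<std::size_t> strides = row_major_strides(shape);
    std::vector<int64_t> out(num_elements(shape));
    for (std::size_t flat = 0; flat < out.size(); ++flat) {
        // Recover the index along `dimension` from the flat row-major offset.
        out[flat] = static_cast<int64_t>((flat / strides[dimension]) % shape[dimension]);
    }
    return out;
}

// Possible decomposition: Range -> Reshape -> Broadcast (NumPy broadcasting).
std::vector<int64_t> iota_decomposed(const std::vector<std::size_t>& shape, std::size_t dimension) {
    check_dimension(shape, dimension);
    // Range: 1-D tensor [0, 1, ..., n - 1] with n = shape[dimension].
    std::vector<int64_t> range(shape[dimension]);
    for (std::size_t k = 0; k < range.size(); ++k) range[k] = static_cast<int64_t>(k);
    // Reshape: the same data viewed with shape [1, ..., n, ..., 1].
    std::vector<std::size_t> src_shape(shape.size(), 1);
    src_shape[dimension] = shape[dimension];
    const std::vector<std::size_t> src_strides = row_major_strides(src_shape);
    // Broadcast: each output element reads the source element at the same
    // multi-index, with the index clamped to 0 on axes where the source has size 1.
    const std::vector<std::size_t> out_strides = row_major_strides(shape);
    std::vector<int64_t> out(num_elements(shape));
    for (std::size_t flat = 0; flat < out.size(); ++flat) {
        std::size_t src = 0;
        for (std::size_t axis = 0; axis < shape.size(); ++axis) {
            const std::size_t idx = (flat / out_strides[axis]) % shape[axis];
            src += (src_shape[axis] == 1 ? 0 : idx) * src_strides[axis];
        }
        out[flat] = range[src];
    }
    return out;
}
```

For example, iota_decomposed({2, 3}, 1) should match iota_reference({2, 3}, 1), which is {0, 1, 2, 0, 1, 2} in row-major order. In an actual loader, the three steps could map to v4::Range (whose output type argument can carry the requested dtype), v1::Reshape and v3::Broadcast; this mapping is a suggestion, not a confirmed implementation.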
What needs to be done?
To support the jax.lax.iota operation, implement the corresponding loader in the JAX FE op directory and register it in the dictionary of loaders. Each loader is responsible for the conversion (or decomposition) of one type of JAX operation.
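The loader/dictionary pattern can be pictured with the toy model below. Everything in it is hypothetical: ToyNodeContext, ToyLoader, toy_loader_table and toy_convert_op are simplified stand-ins, not the JAX FE's actual NodeContext, OutputVector or registration code, and the loaders return OpenVINO op names as strings instead of building a graph. It only illustrates the idea of one loader per JAX operation, looked up by the primitive name found in the jaxpr.

```cpp
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical, simplified stand-ins for the frontend's NodeContext and
// OutputVector: a "graph" here is just the list of emitted OpenVINO op names.
struct ToyNodeContext {
    std::string op_type;  // name of the jaxpr primitive, e.g. "iota"
};
using ToyOutputVector = std::vector<std::string>;
using ToyLoader = std::function<ToyOutputVector(const ToyNodeContext&)>;

// One loader per JAX operation type.
ToyOutputVector toy_translate_reshape(const ToyNodeContext&) {
    return {"Reshape"};  // the optional Transpose is omitted for brevity
}

ToyOutputVector toy_translate_iota(const ToyNodeContext&) {
    // One possible decomposition (an assumption, not a required design).
    return {"Range", "Reshape", "Broadcast"};
}

// The "dictionary of loaders": primitive name -> loader.
const std::map<std::string, ToyLoader>& toy_loader_table() {
    static const std::map<std::string, ToyLoader> table = {
        {"reshape", toy_translate_reshape},
        {"iota", toy_translate_iota},  // the new entry this task adds
    };
    return table;
}

// Dispatch one jaxpr equation to its loader; unsupported ops fail loudly.
ToyOutputVector toy_convert_op(const ToyNodeContext& context) {
    const auto& table = toy_loader_table();
    const auto it = table.find(context.op_type);
    if (it == table.end()) {
        throw std::runtime_error("unsupported JAX operation: " + context.op_type);
    }
    return it->second(context);
}
```

In the real frontend, the equivalent step is adding an entry for iota next to the existing one for reshape, pointing at your new loader (for example, a function named translate_iota). Searching the JAX FE sources for where translate_reshape is registered should lead to the right place.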
Here is an example of a loader implementation for the jax.lax.reshape operation:
OutputVector translate_reshape(const NodeContext& context) {
    // jax.lax.reshape takes exactly one tensor input.
    num_inputs_check(context, 1, 1);
    Output<Node> input = context.get_input(0);
    // Target shape of the result.
    auto new_sizes = context.const_named_param<std::vector<int64_t>>("new_sizes");
    if (context.has_param("dimensions")) {
        auto dimensions = context.const_named_param<std::vector<int64_t>>("dimensions");
        // Optional permutation: transpose the input first.
        auto permutation_node = std::make_shared<v0::Constant>(element::i64, Shape{dimensions.size()}, dimensions);
        input = std::make_shared<v1::Transpose>(input, permutation_node);
    }
    // Reshape to new_sizes; special_zero = false, so the values are taken literally.
    auto new_shape_node = std::make_shared<v0::Constant>(element::i64, Shape{new_sizes.size()}, new_sizes);
    Output<Node> res = std::make_shared<v1::Reshape>(input, new_shape_node, false);
    return {res};
}
In this example, translate_reshape expresses jax.lax.reshape using the OpenVINO opset. According to the JAX documentation, jax.lax.reshape performs an optional transposition followed by tensor reshaping, so the resulting decomposition contains the OpenVINO Transpose and Reshape operations. The conversion reads two constant parameters: dimensions, the permutation applied to the input tensor by Transpose (only when this parameter is present), and new_sizes, the target shape of the result passed to Reshape.
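To make the Transpose + Reshape decomposition concrete, the plain C++ sketch below computes the same result on a row-major buffer. The Tensor struct and helper names are hypothetical and there is no OpenVINO dependency. It applies an optional transpose by dimensions, where output axis k takes input axis dimensions[k] as in OpenVINO Transpose, followed by a reshape to new_sizes that leaves the data order unchanged, which corresponds to the special_zero = false argument in the loader.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <vector>

// Hypothetical minimal tensor: shape plus row-major data.
struct Tensor {
    std::vector<std::size_t> shape;
    std::vector<int64_t> data;
};

std::vector<std::size_t> strides_of(const std::vector<std::size_t>& shape) {
    std::vector<std::size_t> strides(shape.size(), 1);
    for (std::size_t i = shape.size(); i-- > 1;) strides[i - 1] = strides[i] * shape[i];
    return strides;
}

// Transpose: output axis k is input axis perm[k] (OpenVINO Transpose convention).
// Assumes perm is a valid permutation of 0..rank-1.
Tensor transpose(const Tensor& in, const std::vector<std::size_t>& perm) {
    if (perm.size() != in.shape.size()) throw std::invalid_argument("transpose: rank mismatch");
    Tensor out;
    for (std::size_t axis : perm) out.shape.push_back(in.shape.at(axis));
    out.data.resize(in.data.size());
    const auto in_strides = strides_of(in.shape);
    const auto out_strides = strides_of(out.shape);
    for (std::size_t flat = 0; flat < out.data.size(); ++flat) {
        std::size_t src = 0;
        for (std::size_t k = 0; k < perm.size(); ++k) {
            const std::size_t idx = (flat / out_strides[k]) % out.shape[k];
            src += idx * in_strides[perm[k]];
        }
        out.data[flat] = in.data[src];
    }
    return out;
}

// Reshape with special_zero = false: the row-major data is unchanged and only
// the shape metadata changes, so the element counts must match.
Tensor reshape(Tensor in, const std::vector<std::size_t>& new_sizes) {
    std::size_t count = 1;
    for (std::size_t d : new_sizes) count *= d;
    if (count != in.data.size()) throw std::invalid_argument("reshape: element count mismatch");
    in.shape = new_sizes;
    return in;
}

// Mirrors translate_reshape: Transpose only when `dimensions` is given, then Reshape.
Tensor lax_reshape_reference(const Tensor& in,
                             const std::vector<std::size_t>& new_sizes,
                             const std::optional<std::vector<std::size_t>>& dimensions) {
    return reshape(dimensions ? transpose(in, *dimensions) : in, new_sizes);
}
```

For a 2x3 input holding 0..5, reshaping to {6} with dimensions = {1, 0} gives {0, 3, 1, 4, 2, 5}, whereas omitting dimensions keeps {0, 1, 2, 3, 4, 5}. The difference is exactly the Transpose node that translate_reshape inserts only when the dimensions parameter is present.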
Once the loader is implemented, write the corresponding layer tests in test_iota.py and put the file in the layer_tests/jax_tests directory. Here is an example of how to run a layer test:
export TEST_DEVICE=CPU
export JAX_TRACE_MODE=JAXPR
export
cd openvino/tests/layer_tests/jax_tests
pytest test_reshape.py
Example Pull Requests
- [JAX FE] Support concatenate, integer_pow, reduce ops #26288
- [JAX FE] Support operations for Vision Transformer model #26254
Resources
- What is OpenVINO?
- How to Build OpenVINO
- Contribution guide - start here!
- Intel DevHub Discord channel - engage in discussions, ask questions and talk to OpenVINO developers
Contact points
- @openvinotoolkit/openvino-jax-frontend-maintainers
- @rkazants in GitHub and Discord
Ticket
No response