Skip to content

[JAX FE] : Implement jax.lax.iota operation #28221

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/frontends/jax/src/op/iota.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/jax/node_context.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/constant.hpp"

#include "utils.hpp"

namespace ov {
namespace frontend {
namespace jax {
namespace op {

using namespace ov::op;

OutputVector translate_iota(const NodeContext& context) {
num_inputs_check(context, 2, 2);
auto dtype = context.const_named_param<element::Type>("dtype");
auto size = context.const_named_param<int64_t>("size");
auto start = v0::Constant::create(ov::element::i64, Shape{}, {0});
auto step = v0::Constant::create(ov::element::i64, Shape{}, {1});
auto stop = v0::Constant::create(ov::element::i64, Shape{}, {size});
auto res = std::make_shared<ov::op::v4::Range>(start, stop, step, dtype);
return {res};

};

} // namespace op
} // namespace jax
} // namespace frontend
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/jax/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ OP_CONVERTER(translate_copy);
OP_CONVERTER(translate_dot_general);
OP_CONVERTER(translate_erfc);
OP_CONVERTER(translate_integer_pow);
OP_CONVERTER(translate_iota);
OP_T_CONVERTER(translate_reduce_op);
OP_CONVERTER(translate_reduce_window_max);
OP_CONVERTER(translate_reduce_window_sum);
Expand Down Expand Up @@ -79,6 +80,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_jaxpr() {
{"ge", op::translate_binary_op<v1::GreaterEqual>},
{"gt", op::translate_binary_op<v1::Greater>},
{"integer_pow", op::translate_integer_pow},
{"iota", op::translate_iota},
{"lt", op::translate_binary_op<v1::Less>},
{"le", op::translate_binary_op<v1::LessEqual>},
{"max", op::translate_1to1_match_2_inputs<v1::Maximum>},
Expand Down
36 changes: 36 additions & 0 deletions tests/layer_tests/jax_tests/test_iota.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import jax
import numpy as np
import pytest
from jax import numpy as jnp

from jax_layer_test_class import JaxLayerTest

rng = np.random.default_rng(5402)


class TestIota(JaxLayerTest):
def _prepare_input(self):
return (self.input_type, self.input_shape)

def create_model(self, input_shape, input_type):
self.input_shape = input_shape
self.input_type = input_type

def jax_iota(dtype, size):
return jax.lax.iota(dtype, size)

return jax_iota, None, 'iota'


@pytest.mark.parametrize("input_shape", [1,2,3])
@pytest.mark.parametrize("input_type", ["jnp.float32"])
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_jax_fe
def test_iota(self, ie_device, precision, ir_version, input_shape, input_type):
self._test(*self.create_model(input_shape, input_type),
ie_device, precision,
ir_version)
Loading