-
Notifications
You must be signed in to change notification settings - Fork 19.6k
Support numpy.prod operation #21188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Support numpy.prod operation #21188
Changes from 3 commits
e720362
3449d9c
5b1cd25
5342879
a198e04
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1117,8 +1117,67 @@ def pad(x, pad_width, mode="constant", constant_values=None): | |
) | ||
|
||
|
||
""" | ||
Helper Function to convert the string dtype to ov type | ||
""" | ||
|
||
|
||
def string_to_ov_type(dtype_str): | ||
from openvino.runtime import Type | ||
|
||
mapping = { | ||
"bool": Type.boolean, | ||
"int8": Type.i8, | ||
"int16": Type.i16, | ||
"int32": Type.i32, | ||
"int64": Type.i64, | ||
"uint8": Type.u8, | ||
"uint16": Type.u16, | ||
"uint32": Type.u32, | ||
"uint64": Type.u64, | ||
"float16": Type.f16, | ||
"float32": Type.f32, | ||
"float64": Type.f64, | ||
} | ||
return mapping[dtype_str] | ||
|
||
|
||
def prod(x, axis=None, keepdims=False, dtype=None): | ||
raise NotImplementedError("`prod` is not supported with openvino backend") | ||
if axis == () or axis == []: | ||
return x | ||
|
||
x = get_ov_output(x) | ||
x_type = x.get_element_type() | ||
|
||
# Promote dtype if not explicitly specified | ||
if dtype is None: | ||
if x_type == Type.boolean: | ||
promoted_dtype = Type.i32 | ||
elif x_type in (Type.i8, Type.i16): | ||
promoted_dtype = Type.i32 | ||
elif x_type in (Type.u8, Type.u16): | ||
promoted_dtype = Type.u32 | ||
else: | ||
promoted_dtype = x_type | ||
else: | ||
promoted_dtype = string_to_ov_type(dtype) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we don't need any type promotion here. Please remove There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I removed the dtype promotion as you suggested (like uint8 → uint32, etc.), but now a few CI tests are failing because they still expect the promoted dtype (e.g., jnp.prod(uint8) → uint32, but OpenVINO now stays uint8).Could you please let me know that should i remove it or not or waht other ways should i try it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
# Cast to promoted dtype if necessary | ||
if x_type != promoted_dtype: | ||
x = ov_opset.convert(x, promoted_dtype).output(0) | ||
|
||
if axis is None: | ||
flatten_shape = ov_opset.constant([-1], Type.i32).output(0) | ||
x = ov_opset.reshape(x, flatten_shape, False).output(0) | ||
axis = 0 | ||
|
||
if isinstance(axis, tuple): | ||
axis = list(axis) | ||
axis = ov_opset.constant(axis, Type.i32).output(0) | ||
|
||
return OpenVINOKerasTensor( | ||
ov_opset.reduce_prod(x, axis, keepdims).output(0) | ||
) | ||
|
||
|
||
def quantile(x, q, axis=None, method="linear", keepdims=False): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ torch-xla==2.6.0;sys_platform != 'darwin' | |
# Jax. | ||
# Pinned to 0.5.0 on CPU. JAX 0.5.1 requires Tensorflow 2.19 for saved_model_test. | ||
# Note that we test against the latest JAX on GPU. | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. revert this change |
||
jax[cpu]==0.5.0 | ||
flax | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please re-use this dictionary: https://github.com/keras-team/keras/blob/master/keras/src/backend/openvino/core.py#L25
Yours is not needed