Skip to content
Open
105 changes: 91 additions & 14 deletions keras/src/backend/openvino/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from openvino import Type

from keras.src import backend
from keras.src.backend.openvino.core import OPENVINO_DTYPES
from keras.src.backend.openvino.core import OpenVINOKerasTensor
from keras.src.backend.openvino.core import get_ov_output

Expand All @@ -16,6 +17,23 @@ def relu6(x):
return OpenVINOKerasTensor(ov_opset.clamp(x, 0.0, 6.0).output(0))


def celu(x, alpha=1.0):
x = get_ov_output(x)
const_zero = get_ov_output(0.0, x.get_element_type())
const_alpha = get_ov_output(alpha, x.get_element_type())
const_one = get_ov_output(1.0, x.get_element_type())
exp_x_div_alpha = ov_opset.exp(ov_opset.divide(x, const_alpha)).output(0)
negative_branch = ov_opset.multiply(
const_alpha, ov_opset.subtract(exp_x_div_alpha, const_one)
)

celu_x = ov_opset.add(
ov_opset.maximum(x, const_zero).output(0),
ov_opset.minimum(negative_branch, const_zero).output(0),
)
return OpenVINOKerasTensor(celu_x.output(0))


def sigmoid(x):
x = get_ov_output(x)
return OpenVINOKerasTensor(ov_opset.sigmoid(x).output(0))
Expand All @@ -38,14 +56,14 @@ def softsign(x):

def silu(x):
x = get_ov_output(x)
return OpenVINOKerasTensor(
ov_opset.multiply(x, ov_opset.sigmoid(x)).output(0)
)
return OpenVINOKerasTensor(ov_opset.swish(x).output(0))


def log_sigmoid(x):
raise NotImplementedError(
"`log_sigmoid` is not supported with openvino backend"
x = get_ov_output(x)
neg_x = ov_opset.negative(x)
return OpenVINOKerasTensor(
ov_opset.negative(ov_opset.softplus(neg_x)).output(0)
)


Expand Down Expand Up @@ -128,8 +146,8 @@ def max_pool(
padding="valid",
data_format=None,
):
raise NotImplementedError(
"`max_pool` is not supported with openvino backend"
return _pool(
inputs, pool_size, ov_opset.max_pool, strides, padding, data_format
)


Expand All @@ -140,11 +158,52 @@ def average_pool(
padding="valid",
data_format=None,
):
raise NotImplementedError(
"`average_pool` is not supported with openvino backend"
return _pool(
inputs,
pool_size,
ov_opset.avg_pool,
strides,
padding,
data_format,
exclude_pad=True,
)


def _pool(
inputs,
pool_size,
pooling_func,
strides=None,
padding="valid",
data_format=None,
**kwargs,
):
data_format = backend.standardize_data_format(data_format)
inputs = get_ov_output(inputs)

num_spatial_dims = inputs.get_partial_shape().rank.get_length() - 2
if isinstance(pool_size, int):
pool_size = [pool_size] * num_spatial_dims

if strides is None:
strides = pool_size

strides = _adjust_strides_dilation(strides, num_spatial_dims)
pad_mode, pads_begin, pads_end = _adjust_padding(padding)
inputs = _adjust_input(inputs, num_spatial_dims, data_format)
pool_kwargs = {
"kernel_shape": pool_size,
"strides": strides,
"auto_pad": pad_mode,
"pads_begin": pads_begin,
"pads_end": pads_end,
**kwargs,
}
pooled = pooling_func(inputs, **pool_kwargs).output(0)
adjusted_pooled = _adjust_outputs(pooled, num_spatial_dims, data_format)
return OpenVINOKerasTensor(adjusted_pooled)


def _adjust_strides_dilation(
x,
num_spatial_dims,
Expand Down Expand Up @@ -374,9 +433,22 @@ def conv_transpose(


def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
raise NotImplementedError(
"`one_hot` is not supported with openvino backend"
)
if sparse:
raise ValueError("`sparse=True` is not supported with openvino backend")
x = get_ov_output(x)
if dtype is None:
dtype = backend.floatx()
ov_dtype = OPENVINO_DTYPES[dtype]
on_value = get_ov_output(1, ov_dtype)
off_value = get_ov_output(0, ov_dtype)
one_hot_encoded = ov_opset.one_hot(
x,
depth=num_classes,
axis=axis,
on_value=on_value,
off_value=off_value,
).output(0)
return OpenVINOKerasTensor(one_hot_encoded)


def multi_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
Expand Down Expand Up @@ -465,9 +537,14 @@ def batch_normalization(


def ctc_loss(target, output, target_length, output_length, mask_index=0):
raise NotImplementedError(
"`ctc_loss` is not supported with openvino backend"
target = get_ov_output(target)
output = get_ov_output(output)
target_length = get_ov_output(target_length)
output_length = get_ov_output(output_length)
ctc_loss_ = ov_opset.ctc_loss(
output, output_length, target, target_length, blank_index=mask_index
)
return OpenVINOKerasTensor(ctc_loss_.output(0))


def ctc_decode(
Expand Down