Skip to content

Commit c52d71e

Browse files
[OpenVINO Backend] support ops.split (#21296)
1 parent 8f67bab commit c52d71e

File tree

3 files changed

+49
-4
lines changed

3 files changed

+49
-4
lines changed

keras/src/backend/openvino/core.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,8 +450,13 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
450450
def convert_to_numpy(x):
451451
if isinstance(x, np.ndarray):
452452
return x
453-
elif isinstance(x, (int, float, list, tuple)):
453+
elif isinstance(x, (int, float)):
454454
return np.array(x)
455+
elif isinstance(x, (list, tuple)):
456+
x_new = []
457+
for elem in x:
458+
x_new.append(convert_to_numpy(elem))
459+
return np.array(x_new)
455460
elif np.isscalar(x):
456461
return x
457462
elif isinstance(x, ov.Tensor):

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ NumpyDtypeTest::test_round
4848
NumpyDtypeTest::test_searchsorted
4949
NumpyDtypeTest::test_signbit
5050
NumpyDtypeTest::test_sort
51-
NumpyDtypeTest::test_split
5251
NumpyDtypeTest::test_sqrt
5352
NumpyDtypeTest::test_std
5453
NumpyDtypeTest::test_subtract
@@ -112,7 +111,6 @@ NumpyOneInputOpsCorrectnessTest::test_signbit
112111
NumpyOneInputOpsCorrectnessTest::test_size
113112
NumpyOneInputOpsCorrectnessTest::test_slogdet
114113
NumpyOneInputOpsCorrectnessTest::test_sort
115-
NumpyOneInputOpsCorrectnessTest::test_split
116114
NumpyOneInputOpsCorrectnessTest::test_sqrt_int32
117115
NumpyOneInputOpsCorrectnessTest::test_squeeze
118116
NumpyOneInputOpsCorrectnessTest::test_std

keras/src/backend/openvino/numpy.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1340,7 +1340,49 @@ def sort(x, axis=-1):
13401340

13411341

13421342
def split(x, indices_or_sections, axis=0):
1343-
raise NotImplementedError("`split` is not supported with openvino backend")
1343+
x = get_ov_output(x)
1344+
axis_tensor = ov_opset.constant(axis, dtype=Type.i32).output(0)
1345+
1346+
shape_tensor = ov_opset.shape_of(x)
1347+
axis_i32 = ov_opset.constant([axis], dtype=Type.i32)
1348+
dim_at_axis_tensor = ov_opset.gather(
1349+
shape_tensor, axis_i32, ov_opset.constant(0, dtype=Type.i32)
1350+
)
1351+
1352+
if isinstance(indices_or_sections, int):
1353+
num_splits = indices_or_sections
1354+
splits = ov_opset.split(x, axis_tensor, num_splits=num_splits)
1355+
result = []
1356+
for i in range(num_splits):
1357+
result.append(OpenVINOKerasTensor(splits.output(i)))
1358+
return result
1359+
1360+
if isinstance(indices_or_sections, (list, tuple, np.ndarray)):
1361+
indices = list(indices_or_sections)
1362+
split_lengths = []
1363+
split_lengths.append(indices[0])
1364+
for i in range(1, len(indices)):
1365+
split_lengths.append(indices[i] - indices[i - 1])
1366+
1367+
last_index_tensor = ov_opset.constant(indices[-1], dtype=Type.i64)
1368+
remaining_length_tensor = ov_opset.subtract(
1369+
dim_at_axis_tensor, last_index_tensor
1370+
)
1371+
1372+
length_parts = []
1373+
length_parts.append(ov_opset.constant(split_lengths, dtype=Type.i64))
1374+
length_parts.append(remaining_length_tensor)
1375+
length_tensor = ov_opset.concat(length_parts, axis=0)
1376+
1377+
splits = ov_opset.variadic_split(x, axis_tensor, length_tensor)
1378+
result = []
1379+
for i in range(len(split_lengths) + 1):
1380+
result.append(OpenVINOKerasTensor(splits.output(i)))
1381+
return result
1382+
1383+
raise TypeError(
1384+
f"unsupported type of indices_or_sections: {type(indices_or_sections)}"
1385+
)
13441386

13451387

13461388
def stack(x, axis=0):

0 commit comments

Comments
 (0)