Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion keras/src/backend/openvino/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,32 @@ def cast(x, dtype):


def cond(pred, true_fn, false_fn):
raise NotImplementedError("`cond` is not supported with openvino backend")
true_val = true_fn()
false_val = false_fn()
Comment on lines +853 to +854
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm... so the whole point of cond is that you should never have to evaluate the value of both true_fn and false_fn.

But I suppose this is a limitation of OpenVino, right?


if true_val is None:
return None
Comment on lines +856 to +857
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this really needed?


if isinstance(pred, bool):
pred_ov = ov_opset.constant(pred, Type.boolean).output(0)
else:
pred_ov = get_ov_output(pred)
if pred_ov.get_element_type() != Type.boolean:
pred_ov = ov_opset.convert(pred_ov, Type.boolean).output(0)

def _select(t, f):
t_ov, f_ov = align_operand_types(
get_ov_output(t), get_ov_output(f), "cond"
)
return OpenVINOKerasTensor(
ov_opset.select(pred_ov, t_ov, f_ov).output(0)
)

if isinstance(true_val, (list, tuple)):
return type(true_val)(
_select(t, f) for t, f in zip(true_val, false_val)
)
return _select(true_val, false_val)


def vectorized_map(function, elements):
Expand Down
16 changes: 0 additions & 16 deletions keras/src/backend/openvino/excluded_concrete_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,12 @@ CoreOpsCallsTests::test_scan_basic_call
CoreOpsCallsTests::test_switch_basic_call
CoreOpsCallsTests::test_unstack_basic_functionality
CoreOpsCorrectnessTest::test_associative_scan
CoreOpsCorrectnessTest::test_cond
CoreOpsCorrectnessTest::test_fori_loop
CoreOpsCorrectnessTest::test_map
CoreOpsCorrectnessTest::test_scan
CoreOpsCorrectnessTest::test_switch
CoreOpsCorrectnessTest::test_unstack
CoreOpsCorrectnessTest::test_vectorized_map
CosineDecayRestartsTest::test_alpha
CosineDecayRestartsTest::test_decay
CosineDecayRestartsTest::test_float64
CosineDecayRestartsTest::test_mmul
CosineDecayRestartsTest::test_tmul
CosineDecayTest::test_warmup
CosineDecayTest::test_warmup_decay
Comment on lines -65 to -71
Copy link
Contributor Author

@goyaladitya05 goyaladitya05 Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These have no use case for OpenVINO backend (it being inference-only), these will never be used. But after the implementation of cond, the tests are passing.

After this change, the tests will run in the CI (since they do not have @pytest.mark.requires_trainable_backend decorator) and pass, since they are mathematical tests.

Should I keep them excluded or enable them ?
Or maybe we could have @pytest.mark.requires_trainable_backend decorators for LR schedule releted tests, since they will be used in training only.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If they pass, let's run them.

The @pytest.mark.requires_trainable_backend was really a way to say "requires auto differentiation (gradients)" support. I think it's fine to still run tests that are training related but don't use gradients like LR schedules.

CTCTest::test_correctness
CTCTest::test_dtype_arg
DenseTest::test_dense_quantize_config_int4
Expand Down Expand Up @@ -162,10 +154,6 @@ GRUTest::test_masking
GRUTest::test_pass_initial_state
GRUTest::test_pass_return_state
GRUTest::test_statefulness
HingeTest::test_dtype_arg
HingeTest::test_unweighted
HingeTest::test_weighted
HingeTest::test_zero_weighted
HistogramTest::test_histogram_predict_jit_compile_false
HistogramTest::test_histogram_predict_jit_compile_true
ImageDatasetFromDirectoryTest::test_image_dataset_from_directory_binary_grain
Expand Down Expand Up @@ -414,10 +402,6 @@ SparseCategoricalCrossentropyTest::test_sample_weighted
SparseCategoricalCrossentropyTest::test_scalar_weighted
SparseCategoricalCrossentropyTest::test_unweighted
SpectralNormalizationTest::test_apply_layer
SquaredHingeTest::test_dtype_arg
SquaredHingeTest::test_unweighted
SquaredHingeTest::test_weighted
SquaredHingeTest::test_zero_weighted
StackedRNNTest::test_correctness_single_state_stack
StackedRNNTest::test_correctness_two_states_stack
StackedRNNTest::test_return_state_stacked_lstm_cell
Expand Down
Loading