-
Notifications
You must be signed in to change notification settings - Fork 19.7k
[OpenVINO] Implement cond operator and enable tests for Hinge losses #22409
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -850,7 +850,32 @@ def cast(x, dtype): | |
|
|
||
|
|
||
| def cond(pred, true_fn, false_fn): | ||
| raise NotImplementedError("`cond` is not supported with openvino backend") | ||
| true_val = true_fn() | ||
| false_val = false_fn() | ||
|
Comment on lines
+853
to
+854
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm... so the whole point of But I suppose this is a limitation of OpenVino, right? |
||
|
|
||
| if true_val is None: | ||
| return None | ||
|
Comment on lines
+856
to
+857
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this really needed? |
||
|
|
||
| if isinstance(pred, bool): | ||
| pred_ov = ov_opset.constant(pred, Type.boolean).output(0) | ||
| else: | ||
| pred_ov = get_ov_output(pred) | ||
| if pred_ov.get_element_type() != Type.boolean: | ||
| pred_ov = ov_opset.convert(pred_ov, Type.boolean).output(0) | ||
|
|
||
| def _select(t, f): | ||
| t_ov, f_ov = align_operand_types( | ||
| get_ov_output(t), get_ov_output(f), "cond" | ||
| ) | ||
| return OpenVINOKerasTensor( | ||
| ov_opset.select(pred_ov, t_ov, f_ov).output(0) | ||
| ) | ||
|
|
||
| if isinstance(true_val, (list, tuple)): | ||
| return type(true_val)( | ||
| _select(t, f) for t, f in zip(true_val, false_val) | ||
| ) | ||
| return _select(true_val, false_val) | ||
|
|
||
|
|
||
| def vectorized_map(function, elements): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -55,20 +55,12 @@ CoreOpsCallsTests::test_scan_basic_call | |
| CoreOpsCallsTests::test_switch_basic_call | ||
| CoreOpsCallsTests::test_unstack_basic_functionality | ||
| CoreOpsCorrectnessTest::test_associative_scan | ||
| CoreOpsCorrectnessTest::test_cond | ||
| CoreOpsCorrectnessTest::test_fori_loop | ||
| CoreOpsCorrectnessTest::test_map | ||
| CoreOpsCorrectnessTest::test_scan | ||
| CoreOpsCorrectnessTest::test_switch | ||
| CoreOpsCorrectnessTest::test_unstack | ||
| CoreOpsCorrectnessTest::test_vectorized_map | ||
| CosineDecayRestartsTest::test_alpha | ||
| CosineDecayRestartsTest::test_decay | ||
| CosineDecayRestartsTest::test_float64 | ||
| CosineDecayRestartsTest::test_mmul | ||
| CosineDecayRestartsTest::test_tmul | ||
| CosineDecayTest::test_warmup | ||
| CosineDecayTest::test_warmup_decay | ||
|
Comment on lines
-65
to
-71
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These have no use case for OpenVINO backend (it being inference-only), these will never be used. But after the implementation of cond, the tests are passing. After this change, the tests will run in the CI (since they do not have Should I keep them excluded or enable them ?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If they pass, let's run them. The |
||
| CTCTest::test_correctness | ||
| CTCTest::test_dtype_arg | ||
| DenseTest::test_dense_quantize_config_int4 | ||
|
|
@@ -162,10 +154,6 @@ GRUTest::test_masking | |
| GRUTest::test_pass_initial_state | ||
| GRUTest::test_pass_return_state | ||
| GRUTest::test_statefulness | ||
| HingeTest::test_dtype_arg | ||
| HingeTest::test_unweighted | ||
| HingeTest::test_weighted | ||
| HingeTest::test_zero_weighted | ||
| HistogramTest::test_histogram_predict_jit_compile_false | ||
| HistogramTest::test_histogram_predict_jit_compile_true | ||
| ImageDatasetFromDirectoryTest::test_image_dataset_from_directory_binary_grain | ||
|
|
@@ -414,10 +402,6 @@ SparseCategoricalCrossentropyTest::test_sample_weighted | |
| SparseCategoricalCrossentropyTest::test_scalar_weighted | ||
| SparseCategoricalCrossentropyTest::test_unweighted | ||
| SpectralNormalizationTest::test_apply_layer | ||
| SquaredHingeTest::test_dtype_arg | ||
| SquaredHingeTest::test_unweighted | ||
| SquaredHingeTest::test_weighted | ||
| SquaredHingeTest::test_zero_weighted | ||
| StackedRNNTest::test_correctness_single_state_stack | ||
| StackedRNNTest::test_correctness_two_states_stack | ||
| StackedRNNTest::test_return_state_stacked_lstm_cell | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.