Skip to content

Commit 6d1f789

Browse files
committed
Call tf_keras for keras 3.x.
Signed-off-by: Jay Zhang <[email protected]>
1 parent 9bcef5e commit 6d1f789

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

tests/keras2onnx_unit_tests/conftest.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@
1313
K = keras.backend
1414

1515

16+
def is_keras_3():
17+
if hasattr(keras, '__version__'):
18+
return keras.__version__.startswith("3.")
19+
20+
return False
21+
1622
@pytest.fixture(scope='function')
1723
def runner():
1824
np.random.seed(42)
@@ -25,10 +31,15 @@ def runner():
2531
def runner_func(*args, **kwargs):
2632
return run_onnx_runtime(*args, model_files, **kwargs)
2733

28-
# Ensure Keras layer naming is reset for each function
29-
K.reset_uids()
30-
# Reset the TensorFlow session to avoid resource leaking between tests
31-
K.clear_session()
34+
if is_keras_3():
35+
import tf_keras
36+
tf_keras.backend.reset_uids()
37+
tf_keras.backend.clear_session()
38+
else:
39+
# Ensure Keras layer naming is reset for each function
40+
K.reset_uids()
41+
# Reset the TensorFlow session to avoid resource leaking between tests
42+
K.clear_session()
3243

3344
# Provide wrapped run_onnx_runtime function
3445
yield runner_func

0 commit comments

Comments
 (0)