File tree 1 file changed +15
-4
lines changed
tests/keras2onnx_unit_tests
1 file changed +15
-4
lines changed Original file line number Diff line number Diff line change 13
13
K = keras .backend
14
14
15
15
16
+ def is_keras_3 ():
17
+ if hasattr (keras , '__version__' ):
18
+ return keras .__version__ .startswith ("3." )
19
+
20
+ return False
21
+
16
22
@pytest .fixture (scope = 'function' )
17
23
def runner ():
18
24
np .random .seed (42 )
@@ -25,10 +31,15 @@ def runner():
25
31
def runner_func (* args , ** kwargs ):
26
32
return run_onnx_runtime (* args , model_files , ** kwargs )
27
33
28
- # Ensure Keras layer naming is reset for each function
29
- K .reset_uids ()
30
- # Reset the TensorFlow session to avoid resource leaking between tests
31
- K .clear_session ()
34
+ if is_keras_3 ():
35
+ import tf_keras
36
+ tf_keras .backend .reset_uids ()
37
+ tf_keras .backend .clear_session ()
38
+ else :
39
+ # Ensure Keras layer naming is reset for each function
40
+ K .reset_uids ()
41
+ # Reset the TensorFlow session to avoid resource leaking between tests
42
+ K .clear_session ()
32
43
33
44
# Provide wrapped run_onnx_runtime function
34
45
yield runner_func
You can’t perform that action at this time.
0 commit comments