diff --git a/tools/generate_op_accuracy_tests.py b/tools/generate_op_accuracy_tests.py index 84ca22d86..0ffc8b9a4 100644 --- a/tools/generate_op_accuracy_tests.py +++ b/tools/generate_op_accuracy_tests.py @@ -289,6 +289,7 @@ def _build_code_from_aten_ttnn_graphs(aten_graph, ttnn_graph, output_nodes): # comment out signature if not the first graph graph_code = [forward_signature] if len(output_nodes) == 0 else [" # " + forward_signature] graph_code.append(" device = ttnn.open_device(device_id=0, l1_small_size=16384)") + graph_code.append(""" inf = float("inf")""") for node in aten_all_nodes: if node.op == "output": output_nodes.append(node.args[0])