@@ -109,7 +109,7 @@ def _generate_test_code(
109109 """生成单测文件代码"""
110110 # 判断是否为accuracy_error,需要生成CPU和目标设备的对比测试
111111 is_accuracy_error = error_info .get ("error_type" ) == "accuracy_error"
112-
112+
113113 code_lines = [
114114 "import sys" ,
115115 "import os" ,
@@ -136,7 +136,7 @@ def _generate_test_code(
136136 f"api_config = APIConfig({ repr (api_config_str )} )" ,
137137 "" ,
138138 ]
139-
139+
140140 # 如果是accuracy_error,需要生成对比测试,先不设置设备
141141 if not is_accuracy_error :
142142 code_lines .append (f"# 设置目标设备" )
@@ -315,7 +315,7 @@ def _generate_test_code(
315315 kwarg_vars [key ] = f"kwarg_{ key } _non_tensor"
316316
317317 code_lines .append ("" )
318-
318+
319319 # 如果是accuracy_error,直接使用测试类来运行对比测试
320320 if is_accuracy_error :
321321 code_lines .append ("# 使用APITestCustomDeviceVSCPU类来运行CPU与目标设备的对比测试" )
@@ -345,7 +345,7 @@ def _generate_test_code(
345345 code_lines .append (" import traceback" )
346346 code_lines .append (" traceback.print_exc()" )
347347 code_lines .append (" raise" )
348-
348+
349349 else :
350350 # 原有的单设备测试代码
351351 code_lines .append ("# 执行API调用" )
@@ -378,7 +378,9 @@ def _generate_test_code(
378378
379379 if is_tensor_method and tensor_var :
380380 if api_call_parts :
381- api_call = f" output = { tensor_var } .{ method_name } (" + ", " .join (api_call_parts ) + ")"
381+ api_call = (
382+ f" output = { tensor_var } .{ method_name } (" + ", " .join (api_call_parts ) + ")"
383+ )
382384 else :
383385 api_call = f" output = { tensor_var } .{ method_name } ()"
384386 else :
0 commit comments