@@ -1141,8 +1141,9 @@ def _generate_elf_launcher(constants, signature, kernel_name):
11411141"""
11421142
11431143
1144- def compile_module (launcher_src , kernel_placeholder_name , output_format = "xclbin" ,
1145- actual_sizes = None ):
1144+ def compile_module (
1145+ launcher_src , kernel_placeholder_name , output_format = "xclbin" , actual_sizes = None
1146+ ):
11461147 py_version = sys .version_info
11471148 if platform .system () == "Windows" :
11481149 py_include_dir = os .path .join (sys .base_prefix , "include" )
@@ -1186,8 +1187,9 @@ def launch(
11861187 air_proj_path = _get_air_project_path ()
11871188 os .makedirs (air_proj_path , exist_ok = True )
11881189 Path (os .path .join (air_proj_path , "asm_src.mlir" )).write_bytes (asm_src )
1189- air_output = _ttshared_to_air (asm_src , gridX , gridY , gridZ ,
1190- actual_sizes = actual_sizes )
1190+ air_output = _ttshared_to_air (
1191+ asm_src , gridX , gridY , gridZ , actual_sizes = actual_sizes
1192+ )
11911193 with open (Path (os .path .join (air_proj_path , "asm_air_output.mlir" )), "w" ) as f :
11921194 f .write (str (air_output ))
11931195
@@ -1408,10 +1410,16 @@ def __init__(self, src, metadata):
14081410 n_val = raw_constants .get ((n_idx ,))
14091411 if m_val is not None and n_val is not None :
14101412 # Check if BLOCK_SIZE_M/N are available to determine alignment
1411- bsm_idx = (arg_names .index ("BLOCK_SIZE_M" )
1412- if "BLOCK_SIZE_M" in arg_names else None )
1413- bsn_idx = (arg_names .index ("BLOCK_SIZE_N" )
1414- if "BLOCK_SIZE_N" in arg_names else None )
1413+ bsm_idx = (
1414+ arg_names .index ("BLOCK_SIZE_M" )
1415+ if "BLOCK_SIZE_M" in arg_names
1416+ else None
1417+ )
1418+ bsn_idx = (
1419+ arg_names .index ("BLOCK_SIZE_N" )
1420+ if "BLOCK_SIZE_N" in arg_names
1421+ else None
1422+ )
14151423 bsm = raw_constants .get ((bsm_idx ,)) if bsm_idx is not None else None
14161424 bsn = raw_constants .get ((bsn_idx ,)) if bsn_idx is not None else None
14171425 needs_padding = True
@@ -1423,7 +1431,9 @@ def __init__(self, src, metadata):
14231431 # Later KERNEL_NAME_PLACEHOLDER will be used to assign the kernel name
14241432 # in the following launch function.
14251433 self .launch = compile_module (
1426- launcher_src , kernel_placeholder_name , self .output_format ,
1434+ launcher_src ,
1435+ kernel_placeholder_name ,
1436+ self .output_format ,
14271437 actual_sizes = actual_sizes ,
14281438 )
14291439
0 commit comments