@@ -180,7 +180,6 @@ class CompilationResult:
180180 binary : str
181181 name : str
182182 shared_mem_bytes : int
183- cluster_dims : tuple
184183 ttgir : str | None
185184 llir : str | None
186185
@@ -261,14 +260,12 @@ def compile_ttir_to_ptx_inplace(
261260 if cuda_options .debug :
262261 print (ptx )
263262 name = metadata ["name" ]
264- cluster_dims = metadata ["cluster_dims" ]
265263 ttgir = str (ttgir ) if _JAX_TRITON_DUMP_DIR else None
266264 llir = str (llir ) if _JAX_TRITON_DUMP_DIR else None
267265 return CompilationResult (
268266 binary = ptx ,
269267 name = name ,
270268 shared_mem_bytes = shared_mem_bytes ,
271- cluster_dims = cluster_dims ,
272269 ttgir = ttgir ,
273270 llir = llir ,
274271 )
@@ -306,9 +303,6 @@ def compile_ttir_to_hsaco_inplace(
306303 name = metadata ["name" ]
307304 ttgir = str (ttgir ) if _JAX_TRITON_DUMP_DIR else None
308305 llir = str (llir ) if _JAX_TRITON_DUMP_DIR else None
309- # cluster dims are NOT useful on hip backend.
310- # We just fill up with some value for API compatibility
311- cluster_dims = (0 , 0 , 0 )
312306 # Instead of passing hsaco which are "bytes", we first write
313307 # to a file and then pass the "string" path. This is needed because
314308 # nanobind doesn't automatically convert between bytes and string.
@@ -320,7 +314,6 @@ def compile_ttir_to_hsaco_inplace(
320314 binary = hsaco_path ,
321315 name = name ,
322316 shared_mem_bytes = shared_mem_bytes ,
323- cluster_dims = cluster_dims ,
324317 ttgir = ttgir ,
325318 llir = llir ,
326319 )
@@ -469,18 +462,17 @@ def get_or_create_triton_kernel(
469462 ) as f :
470463 f .write (
471464 f"{ kernel_name } : shared_mem_bytes:"
472- f" { compilation_result .shared_mem_bytes } , cluster_dims:"
473- f" { compilation_result .cluster_dims } \n "
465+ f" { compilation_result .shared_mem_bytes } \n "
474466 )
475467
476468 kernel = triton_kernel_call_lib .TritonKernel (
477469 kernel_name ,
478470 num_warps ,
471+ num_ctas ,
479472 compilation_result .shared_mem_bytes ,
480473 compilation_result .binary ,
481474 ttir ,
482475 compute_capability ,
483- * compilation_result .cluster_dims ,
484476 )
485477
486478 _COMPILED_KERNEL_CACHE [cache_key ] = kernel
0 commit comments