Skip to content

Commit 4d258ba

Browse files
olegshyshkov and Google-ML-Automation
authored and committed
1 parent 6b9682a commit 4d258ba

1 file changed

Lines changed: 2 additions & 10 deletions

File tree

jax_triton/triton_lib.py

Lines changed: 2 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -180,7 +180,6 @@ class CompilationResult:
180180
binary: str
181181
name: str
182182
shared_mem_bytes: int
183-
cluster_dims: tuple
184183
ttgir: str | None
185184
llir: str | None
186185

@@ -261,14 +260,12 @@ def compile_ttir_to_ptx_inplace(
261260
if cuda_options.debug:
262261
print(ptx)
263262
name = metadata["name"]
264-
cluster_dims = metadata["cluster_dims"]
265263
ttgir = str(ttgir) if _JAX_TRITON_DUMP_DIR else None
266264
llir = str(llir) if _JAX_TRITON_DUMP_DIR else None
267265
return CompilationResult(
268266
binary=ptx,
269267
name=name,
270268
shared_mem_bytes=shared_mem_bytes,
271-
cluster_dims=cluster_dims,
272269
ttgir=ttgir,
273270
llir=llir,
274271
)
@@ -306,9 +303,6 @@ def compile_ttir_to_hsaco_inplace(
306303
name = metadata["name"]
307304
ttgir = str(ttgir) if _JAX_TRITON_DUMP_DIR else None
308305
llir = str(llir) if _JAX_TRITON_DUMP_DIR else None
309-
# cluster dims are NOT useful on hip backend.
310-
# We just fill up with some value for API compatibility
311-
cluster_dims = (0, 0, 0)
312306
# Instead of passing hsaco which are "bytes", we first write
313307
# to a file and then pass the "string" path. This is needed because
314308
# nanobind doesn't automatically convert between bytes and string.
@@ -320,7 +314,6 @@ def compile_ttir_to_hsaco_inplace(
320314
binary=hsaco_path,
321315
name=name,
322316
shared_mem_bytes=shared_mem_bytes,
323-
cluster_dims=cluster_dims,
324317
ttgir=ttgir,
325318
llir=llir,
326319
)
@@ -469,18 +462,17 @@ def get_or_create_triton_kernel(
469462
) as f:
470463
f.write(
471464
f"{kernel_name}: shared_mem_bytes:"
472-
f" {compilation_result.shared_mem_bytes}, cluster_dims:"
473-
f" {compilation_result.cluster_dims}\n"
465+
f" {compilation_result.shared_mem_bytes}\n"
474466
)
475467

476468
kernel = triton_kernel_call_lib.TritonKernel(
477469
kernel_name,
478470
num_warps,
471+
num_ctas,
479472
compilation_result.shared_mem_bytes,
480473
compilation_result.binary,
481474
ttir,
482475
compute_capability,
483-
*compilation_result.cluster_dims,
484476
)
485477

486478
_COMPILED_KERNEL_CACHE[cache_key] = kernel

0 commit comments

Comments (0)