Skip to content

Commit 6af5453

Browse files
[MI35X] Fix meta mha 950 error to make guard return bool func itself (ROCm#1170)
* Make bool guard return func itself * make all return func * write querysol fake
1 parent ddcc4ea commit 6af5453

2 files changed

Lines changed: 25 additions & 4 deletions

File tree

aiter/jit/utils/torch_guard.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,10 @@ def fake_impl(dummy_tensor, *args, **kwargs):
106106
if gen_fake is not None:
107107
return gen_fake(*args, **kwargs)
108108
return func(*args, **kwargs)
109-
default_values = {int: 0, bool: True, float: 0.0}
110109

111-
default_value = default_values.get(return_annotation, None)
112-
return out, default_value
110+
if gen_fake is not None:
111+
return out, gen_fake(*args, **kwargs)
112+
return out, func(*args, **kwargs)
113113

114114
if is_torch_equal_or_newer("2.8.0"):
115115
tags = ()

aiter/tuned_gemm.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,28 @@ def query_sol_core(
9696
return solution_idx
9797

9898

99-
@torch_compile_guard()
99+
def query_sol_fake(
100+
m: int, n: int, k: int, bias: bool, dtype: str, otype: str, scaleAB: bool = False
101+
) -> int:
102+
global solids, solMap, soltype
103+
# soltype = None
104+
solution_idx = 0
105+
cu_count = get_cu_num()
106+
if dtype in [dtypes.fp16, dtypes.bf16] and k % 8 == 0:
107+
if (
108+
((m == 1 and n <= 2 * cu_count) or (m > 1 and m <= 4 and n <= cu_count))
109+
and k <= 9216
110+
or (m > 4 and m <= 8 and n <= cu_count)
111+
and k <= 5120
112+
or (m > 8 and m <= 16 and n <= cu_count)
113+
and k <= 256
114+
):
115+
soltype, solution_idx = 3, 2
116+
117+
return solution_idx
118+
119+
120+
@torch_compile_guard(gen_fake=query_sol_fake)
100121
def query_sol(
101122
m: int, n: int, k: int, bias: bool, dtype: str, otype: str, scaleAB: bool = False
102123
) -> int:

0 commit comments

Comments
 (0)