-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathmatmul_generator.py
73 lines (63 loc) · 2.05 KB
/
matmul_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import re
def get_higher_order_element_type(element_type):
    """Return the element type of the same kind with twice the bit width.

    Used to pick a wider intermediate type, e.g. ``"i32"`` -> ``"i64"`` and
    ``"f16"`` -> ``"f32"``. Only integer (``i<N>``) and float (``f<N>``)
    type names with an ASCII decimal bit width are supported.

    Args:
        element_type: Element type name such as ``"i8"`` or ``"f32"``.

    Returns:
        The type name with the same one-letter prefix and double the width.

    Raises:
        ValueError: If ``element_type`` is not of the form ``i<N>``/``f<N>``.
    """
    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O``, which let unsupported types fall through and return None.
    # ``[0-9]`` (not ``\d`` / ``str.isdigit``) keeps the width strictly ASCII so
    # ``int()`` below can never fail on exotic Unicode digits.
    parsed = re.fullmatch(r"([if])([0-9]+)", element_type)
    if parsed is None:
        raise ValueError(f"support for {element_type} is missing")
    kind, bit_width = parsed.group(1), int(parsed.group(2))
    return f"{kind}{bit_width * 2}"
def generate_matmul_test(
    output_fn,
    input_fn,
    m,
    n,
    k,
    lhs_rhs_type,
    acc_type,
    b=0,
    m0=0,
    n0=0,
    k0=0,
    trunci_scale=None,
    trunci_shift=None,
):
    """
    Generate an MLIR file (output_fn) from the template file (input_fn).

    Every ``${KEY}`` placeholder in the template whose KEY is one of the
    substitution keys built below is replaced by ``str()`` of its value;
    all other text (including unrecognised ``${...}`` placeholders) is copied
    through unchanged, line by line.

    Args:
        output_fn: Path of the file to write (truncated if it exists).
        input_fn: Path of the template file to read.
        m, n, k: Matmul sizes, substituted for ``${M}``, ``${N}``, ``${K}``.
        lhs_rhs_type: Operand element type, substituted for ``${TYPE1}``.
        acc_type: Accumulator element type, substituted for ``${TYPE2}``.
            A leading ``"i"`` selects integer arith ops (addi/muli/extsi and
            a ``0`` zero); anything else selects the float variants.
        b: Batch size (``${B}``); only used for batch matmul templates.
        m0, n0, k0: matmul4d inner dim sizes (``${M0}``/``${N0}``/``${K0}``).
            When one is non-zero, the matching outer size
            (``${M1}``/``${N1}``/``${K1}``) is ``total // inner``; otherwise
            that outer key is not defined at all.
        trunci_scale, trunci_shift: Values for ``${TRUNCI_SCALE}`` and
            ``${TRUNCI_SHIFT}``; if left as None they are substituted as the
            literal text ``None``.

    Any exception raised by ``get_higher_order_element_type`` for
    ``acc_type`` propagates to the caller.
    """
    # Only used for Matmul+Trunc via scaling. Computed before inspecting
    # acc_type[0] so invalid acc_type values fail in the same place as before.
    mul_result_type = get_higher_order_element_type(acc_type)
    acc_is_int = acc_type[0] == "i"

    replace = {
        "M": m,
        "N": n,
        "K": k,
        "TYPE1": lhs_rhs_type,
        "TYPE2": acc_type,
        "TYPE_MUL_RESULT": mul_result_type,
        "ZERO": 0 if acc_is_int else 0.0,
        "ADD": "arith.addi" if acc_is_int else "arith.addf",
        "MUL": "arith.muli" if acc_is_int else "arith.mulf",
        "EXT": "arith.extsi" if acc_is_int else "arith.extf",
        # This is only used for batch matmul.
        "B": b,
        "TRUNCI_SCALE": trunci_scale,
        "TRUNCI_SHIFT": trunci_shift,
        # m0, n0, k0 are only used for matmul4d as inner dim sizes.
        "M0": m0,
        "N0": n0,
        "K0": k0,
    }
    # matmul4d outer dim sizes are `total_size // inner_dim_size`. Integer
    # division avoids the float round-trip of `int(total / inner)`.
    for outer_key, total, inner in (("M1", m, m0), ("N1", n, n0), ("K1", k, k0)):
        if inner != 0:
            replace[outer_key] = total // inner

    # Each alternative is a full, escaped `${KEY}` token (closing brace
    # included), so e.g. `${M}` can never match a prefix of `${M0}`.
    placeholder = re.compile(
        "|".join(re.escape("${" + key + "}") for key in replace)
    )

    def _substitute(found):
        # Strip the leading "${" and trailing "}" to recover the dict key.
        return str(replace[found.group(0)[2:-1]])

    # `with` guarantees both files are closed (and output flushed) even if a
    # substitution or write raises; input is opened first, as before, so a
    # missing template does not create/truncate the output file.
    with open(input_fn, "r", encoding="utf-8") as in_file, open(
        output_fn, "w", encoding="utf-8"
    ) as out_file:
        for line in in_file:
            out_file.write(placeholder.sub(_substitute, line))