forked from act-compiler/act
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathQKV_2.py
More file actions
130 lines (112 loc) · 4.21 KB
/
QKV_2.py
File metadata and controls
130 lines (112 loc) · 4.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""QKV Accelerator ISA Definition for Tutorial 2"""
from taidl import Accelerator
# The accelerator being described; every data model and instruction below is
# registered on this object.
qkv_2 = Accelerator("QKV_2")
# Define Data Models
# Three bf16 buffers, registered in the order d1, d2, d3.  Judging by the
# load/store semantics (rows handled as bf16[n, 64]), each row appears to hold
# 64 bf16 values, with 128 rows for d1/d3 and 64 rows for d2 -- presumably the
# positional arguments are (name, row-count shape, row shape, dtype); confirm
# against the TAIDL `add_data_model` documentation.
for _buf_name, _buf_rows in (("d1", 128), ("d2", 64), ("d3", 128)):
    qkv_2.add_data_model(_buf_name, [_buf_rows], [64], "bf16")
# Drop the loop temporaries so they do not linger as module globals.
del _buf_name, _buf_rows
# Load instructions
# load_rm_<dst>: copy `n` rows from d0 into buffer <dst> (d1, d2 or d3).
#   inputs : d0 at @a.addr_in, `@c.n * 128` u8 elements (64 bf16 x 2 bytes/row)
#   outputs: <dst> at @a.addr_out, `@c.n` rows
# Semantics: the raw bytes are reshaped to u8[n, 64, 2] and bit-cast so each
# byte pair becomes one bf16, yielding a bf16[n, 64] tile.
# The three variants differ only in their destination buffer, so they share a
# single HLO template; "$OPNAME" is replaced by the instruction name.  They are
# registered in the same order as before: d1, d2, d3.
# NOTE(review): d0 is not declared in this file; presumably it is a built-in
# byte-addressed memory provided by TAIDL -- confirm.
_load_rm_semantics = """
ENTRY $OPNAME {
%In1 = u8[`@c.n * 128`] parameter(0);
%a = u8[`@c.n`,64,2] reshape(%In1);
ROOT %Out0 = bf16[`@c.n`,64] bitcast_convert(%a);
}
"""
for _dst in ("d1", "d2", "d3"):
    _opname = "load_rm_" + _dst
    instr = qkv_2.add_instruction(_opname, ["n"], ["addr_in", "addr_out"])
    instr.set_inputs([["d0", ["@a.addr_in"], ["@c.n * 128"]]])
    instr.set_outputs([[_dst, ["@a.addr_out"], ["@c.n"]]])
    instr.add_semantics(_load_rm_semantics.replace("$OPNAME", _opname))
# `instr` intentionally stays bound (to load_rm_d3); only temporaries go.
del _dst, _opname, _load_rm_semantics
# store_rm_d2: write `n` rows of buffer d2 out to d0 -- the reverse direction
# of the load_rm_* instructions.  Only d2 has a store instruction here.
#   inputs : d2 at @a.addr_in, `@c.n` rows
#   outputs: d0 at @a.addr_out, `@c.n * 128` u8 elements
# Semantics: bf16[n, 64] is bit-cast to u8[n, 64, 2] (two bytes per value) and
# flattened to n * 128 bytes.
instr = qkv_2.add_instruction("store_rm_d2", ["n"], ["addr_in", "addr_out"])
instr.set_inputs([["d2", ["@a.addr_in"], ["@c.n"]]])
instr.set_outputs([["d0", ["@a.addr_out"], ["@c.n * 128"]]])
instr.add_semantics("""
ENTRY store_rm_d2 {
%In1 = bf16[`@c.n`,64] parameter(0);
%a = u8[`@c.n`,64,2] bitcast_convert(%In1);
ROOT %Out0 = u8[`@c.n*128`] reshape(%a);
}
""")
# Compute instructions
# gemm_<lhs>_<rhs>: fixed-size 64x64 bf16 matrix multiply, Out = In1 @ In2
# (dot contracting dim 1 of In1 with dim 0 of In2).  No `n` attribute.
#   inputs : <lhs> at @a.addr_1 (64 rows), <rhs> at @a.addr_2 (64 rows)
#   outputs: d2 at @a.addr_out (64 rows)
# Both variants share identical semantics and differ only in which buffer holds
# the left operand.  "$OPNAME" in the template is replaced by the instruction
# name; registration order is unchanged (gemm_d1_d3, then gemm_d3_d3).
_gemm_semantics = """
ENTRY $OPNAME {
%In1 = bf16[64,64] parameter(0);
%In2 = bf16[64,64] parameter(1);
ROOT %Out0 = bf16[64,64] dot(%In1, %In2), lhs_contracting_dims={1}, rhs_contracting_dims={0};
}
"""
for _lhs, _rhs in (("d1", "d3"), ("d3", "d3")):
    _opname = "gemm_" + _lhs + "_" + _rhs
    instr = qkv_2.add_instruction(_opname, [], ["addr_1", "addr_2", "addr_out"])
    instr.set_inputs([[_lhs, ["@a.addr_1"], ["64"]], [_rhs, ["@a.addr_2"], ["64"]]])
    instr.set_outputs([["d2", ["@a.addr_out"], ["64"]]])
    instr.add_semantics(_gemm_semantics.replace("$OPNAME", _opname))
# `instr` intentionally stays bound (to gemm_d3_d3); only temporaries go.
del _lhs, _rhs, _opname, _gemm_semantics
# softmax: row-wise softmax over `n` rows of d2, computed in place (input and
# output both address d2 at @a.addr).
# Semantics: exp(x), sum along dimension 1 (the 64 columns) into a bf16[n]
# vector, broadcast the per-row sums back to [n, 64], then divide.
# NOTE(review): there is no max-subtraction before exponential(), so large
# inputs can overflow exp; presumably this mirrors the modelled hardware --
# confirm before changing, since these semantics presumably drive the
# generated oracle.
instr = qkv_2.add_instruction("softmax", ["n"], ["addr"])
instr.set_inputs([["d2", ["@a.addr"], ["@c.n"]]])
instr.set_outputs([["d2", ["@a.addr"], ["@c.n"]]])
instr.add_semantics("""
ENTRY softmax {
%In1 = bf16[`@c.n`,64] parameter(0);
%a = bf16[`@c.n`,64] exponential(%In1);
%reduced = bf16[`@c.n`] reduce_add(%a), dimensions={1};
%b = bf16[`@c.n`,64] broadcast(%reduced), dimensions={0};
ROOT %Out0 = bf16[`@c.n`,64] divide(%a, %b);
}
""")
# copy_d2_<dst>: copy `n` rows from d2 (@a.addr_in) unchanged into <dst>
# (@a.addr_out), for <dst> in d1 and d3.  Both variants share one HLO
# template; "$OPNAME" is replaced by the instruction name.  Registration order
# is unchanged: copy_d2_d1, then copy_d2_d3.
_copy_semantics = """
ENTRY $OPNAME {
%In1 = bf16[`@c.n`,64] parameter(0);
ROOT %Out0 = bf16[`@c.n`,64] copy(%In1);
}
"""
for _dst in ("d1", "d3"):
    _opname = "copy_d2_" + _dst
    instr = qkv_2.add_instruction(_opname, ["n"], ["addr_in", "addr_out"])
    instr.set_inputs([["d2", ["@a.addr_in"], ["@c.n"]]])
    instr.set_outputs([[_dst, ["@a.addr_out"], ["@c.n"]]])
    instr.add_semantics(_copy_semantics.replace("$OPNAME", _opname))
# `instr` intentionally stays bound (to copy_d2_d3); only temporaries go.
del _dst, _opname, _copy_semantics
# transpose_d1_d3: read `n` rows of d1 (bf16[n, 64]) and write the transpose,
# bf16[64, n], to d3 as 64 rows starting at @a.addr_out.
# NOTE(review): d3 is declared with a [64] row shape, while each output row
# here has `n` elements, so this presumably only lines up when n == 64 --
# confirm against the TAIDL shape rules.
instr = qkv_2.add_instruction("transpose_d1_d3", ["n"], ["addr_in", "addr_out"])
instr.set_inputs([["d1", ["@a.addr_in"], ["@c.n"]]])
instr.set_outputs([["d3", ["@a.addr_out"], ["64"]]])
instr.add_semantics("""
ENTRY transpose_d1_d3 {
%In1 = bf16[`@c.n`,64] parameter(0);
ROOT %Out0 = bf16[64,`@c.n`] transpose(%In1), dimensions={1,0};
}
""")
# Generate programming APIs and test oracle (functional simulator)
# NOTE: both generators are invoked at module level, i.e. every time this file
# is executed or imported; they must run after all instructions above are
# registered.
qkv_2.generate_oracle()
# Generate compiler backend
qkv_2.generate_backend()