-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathresnet18_int8_schedule.txt
More file actions
93 lines (93 loc) · 9.25 KB
/
resnet18_int8_schedule.txt
File metadata and controls
93 lines (93 loc) · 9.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
// Lowered TIR for one fused int8 kernel: zero-guarded input tile staging, a 3x3 __dp4a
// accumulation over 16 input-channel steps, then a fixed-point requantize + add + max(.,0)
// epilogue written to T_relu.
// Handles (per buffer_map / buffers below):
//   placeholder_4 -> placeholder   (int32, [32, 16, 56, 56, 4])  read in the epilogue via placeholder_10
//   placeholder_5 -> placeholder_2 (int8,  [32, 16, 56, 56, 4])  read as the staged input via placeholder_9
//   placeholder_6 -> placeholder_1 (int8,  [16, 16, 3, 3, 4, 4]) read as the staged weights via placeholder_11
//   placeholder_7 -> placeholder_3 (int32, [1, 16, 1, 1, 4])     read in the epilogue via placeholder_8
//   T_relu_1      -> T_relu        (int32, [32, 16, 56, 56, 4])  written via T_relu_2
// NOTE(review): no storage_sync calls appear between the shared-memory stores and loads in
// this dump; presumably it was captured before synchronization was inserted -- confirm.
primfn(placeholder_4: handle, placeholder_5: handle, placeholder_6: handle, placeholder_7: handle, T_relu_1: handle) -> ()
attr = {"global_symbol": "fused_cast_fixed_point_multiply_nn_conv2d_cast_fixed_point_multiply_add_cast_fix_1476761249106760241__3", "tir.noalias": True}
buffers = {placeholder_3: Buffer(placeholder_8: handle, int32, [1, 16, 1, 1, 4], []),
T_relu: Buffer(T_relu_2: handle, int32, [32, 16, 56, 56, 4], []),
placeholder_2: Buffer(placeholder_9: handle, int8, [32, 16, 56, 56, 4], []),
placeholder: Buffer(placeholder_10: handle, int32, [32, 16, 56, 56, 4], []),
placeholder_1: Buffer(placeholder_11: handle, int8, [16, 16, 3, 3, 4, 4], [])}
buffer_map = {placeholder_4: placeholder, placeholder_6: placeholder_1, placeholder_5: placeholder_2, T_relu_1: T_relu, placeholder_7: placeholder_3} {
// Launch extents: blockIdx (z, y, x) = (32, 1, 49); threadIdx (z, y, x) = (1, 16, 8).
attr [IterVar(blockIdx.z: int32, (nullptr), "ThreadIndex", "blockIdx.z")] "thread_extent" = 32;
// Per-thread accumulators: 32 int32 values in "local" scope, indexed as (oh*4 + oc_block).
attr [compute: handle] "storage_scope" = "local";
allocate(compute, int32, [32]);
// Staged input tile in "shared" scope: 200 int8x4 vectors (800 int8), used below as two
// halves selected by an offset of 0 or 400.
attr [pad_data.shared: handle] "storage_scope" = "shared";
allocate(pad_data.shared, int8x4, [200]);
// Staged weights in "shared" scope: 576 int8x4 vectors (2304 int8).
attr [placeholder.shared: handle] "storage_scope" = "shared";
allocate(placeholder.shared, int8x4, [576]);
attr [IterVar(blockIdx.y: int32, (nullptr), "ThreadIndex", "blockIdx.y")] "thread_extent" = 1;
attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 49;
attr [IterVar(threadIdx.z: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
attr [IterVar(threadIdx.y: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 16;
attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 8 {
// Zero all 8*4 accumulators before the reduction.
for (oh.init: int32, 0, 8) "unroll" {
for (oc_block.init: int32, 0, 4) "unroll" {
compute[((oh.init*4) + oc_block.init)] = 0
}
}
// Prologue input load into the first half of pad_data.shared (offsets 0..399).
// Only the first 100 linear thread ids (threadIdx.y_1*8 + threadIdx.x_1) take part; each
// stores one 4-lane vector. The id is split as (floordiv(id, 10), floormod(id, 10)) and
// added to (floordiv(blockIdx.x, 7)*8, floormod(blockIdx.x, 7)*8); when either sum is
// outside [1, 57) the stored value is broadcast(0i8, 4) instead of a load.
// The -228 in the source index equals 224 + 4, i.e. one step of each of the two
// coefficients (224 and 4) applied to those sums, matching the lower bound of 1 in the guard.
attr [IterVar(threadIdx.z_1: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
attr [IterVar(threadIdx.y_1: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 16;
attr [IterVar(threadIdx.x_1: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 8;
if (((threadIdx.y_1*8) + threadIdx.x_1) < 100) {
if (threadIdx.y_1 < 13) {
pad_data.shared[ramp(((threadIdx.y_1*32) + (threadIdx.x_1*4)), 1, 4)] = @tir.if_then_else(((((1 <= ((floordiv(blockIdx.x, 7)*8) + floordiv(((threadIdx.y_1*8) + threadIdx.x_1), 10))) && (((floordiv(blockIdx.x, 7)*8) + floordiv(((threadIdx.y_1*8) + threadIdx.x_1), 10)) < 57)) && (1 <= ((floormod(blockIdx.x, 7)*8) + floormod(((threadIdx.y_1*8) + threadIdx.x_1), 10)))) && (((floormod(blockIdx.x, 7)*8) + floormod(((threadIdx.y_1*8) + threadIdx.x_1), 10)) < 57)), (int8x4*)placeholder_9[ramp(((((((blockIdx.z*200704) + (floordiv(blockIdx.x, 7)*1792)) + (floordiv(((threadIdx.y_1*8) + threadIdx.x_1), 10)*224)) + (floormod(blockIdx.x, 7)*32)) + (floormod(((threadIdx.y_1*8) + threadIdx.x_1), 10)*4)) - 228), 1, 4)], broadcast(0i8, 4), dtype=int8x4)
}
}
// Main reduction: 15 iterations. Each iteration (a) stores the input tile for step
// ic_chunk.outer.outer + 1 into the other half of pad_data.shared, (b) loads the weights
// for the current step, and (c) accumulates using the half selected by the current step.
for (ic_chunk.outer.outer: int32, 0, 15) {
// (a) Same guarded tile load as the prologue, but the destination half is
// floormod(ic_chunk.outer.outer + 1, 2)*400 and the source index adds
// ic_chunk.outer.outer*12544 + 12316, where 12316 = 12544 - 228.
attr [pad_data.shared] "double_buffer_write" = 1;
attr [IterVar(threadIdx.z_1, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
attr [IterVar(threadIdx.y_1, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 16;
attr [IterVar(threadIdx.x_1, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 8;
if (((threadIdx.y_1*8) + threadIdx.x_1) < 100) {
if (threadIdx.y_1 < 13) {
pad_data.shared[ramp((((floormod((ic_chunk.outer.outer + 1), 2)*400) + (threadIdx.y_1*32)) + (threadIdx.x_1*4)), 1, 4)] = @tir.if_then_else(((((1 <= ((floordiv(blockIdx.x, 7)*8) + floordiv(((threadIdx.y_1*8) + threadIdx.x_1), 10))) && (((floordiv(blockIdx.x, 7)*8) + floordiv(((threadIdx.y_1*8) + threadIdx.x_1), 10)) < 57)) && (1 <= ((floormod(blockIdx.x, 7)*8) + floormod(((threadIdx.y_1*8) + threadIdx.x_1), 10)))) && (((floormod(blockIdx.x, 7)*8) + floormod(((threadIdx.y_1*8) + threadIdx.x_1), 10)) < 57)), (int8x4*)placeholder_9[ramp((((((((blockIdx.z*200704) + (ic_chunk.outer.outer*12544)) + (floordiv(blockIdx.x, 7)*1792)) + (floordiv(((threadIdx.y_1*8) + threadIdx.x_1), 10)*224)) + (floormod(blockIdx.x, 7)*32)) + (floormod(((threadIdx.y_1*8) + threadIdx.x_1), 10)*4)) + 12316), 1, 4)], broadcast(0i8, 4), dtype=int8x4)
}
}
// (b) Weight load: 5 unrolled passes over the 16x8 threads, one 4-lane vector per thread
// per pass, with the three nested guards limiting the stores to the 576 allocated vectors.
// The source index selects the current step through ic_chunk.outer.outer*144.
for (ax0.ax1.fused.ax2.fused.ax3.fused.ax4.fused.ax5.outer.fused.outer.outer.outer: int32, 0, 5) "unroll" {
attr [IterVar(threadIdx.z_2: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
attr [IterVar(threadIdx.y_2: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 16;
attr [IterVar(threadIdx.x_2: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 8;
if ((((ax0.ax1.fused.ax2.fused.ax3.fused.ax4.fused.ax5.outer.fused.outer.outer.outer*32) + (threadIdx.y_2*2)) + floordiv(threadIdx.x_2, 4)) < 144) {
if ((((ax0.ax1.fused.ax2.fused.ax3.fused.ax4.fused.ax5.outer.fused.outer.outer.outer*128) + (threadIdx.y_2*8)) + threadIdx.x_2) < 576) {
if (((ax0.ax1.fused.ax2.fused.ax3.fused.ax4.fused.ax5.outer.fused.outer.outer.outer*16) + threadIdx.y_2) < 72) {
placeholder.shared[ramp((((ax0.ax1.fused.ax2.fused.ax3.fused.ax4.fused.ax5.outer.fused.outer.outer.outer*512) + (threadIdx.y_2*32)) + (threadIdx.x_2*4)), 1, 4)] = (int8x4*)placeholder_11[ramp(((((floordiv((((ax0.ax1.fused.ax2.fused.ax3.fused.ax4.fused.ax5.outer.fused.outer.outer.outer*32) + (threadIdx.y_2*2)) + floordiv(threadIdx.x_2, 4)), 9)*2304) + (ic_chunk.outer.outer*144)) + (floormod((((ax0.ax1.fused.ax2.fused.ax3.fused.ax4.fused.ax5.outer.fused.outer.outer.outer*32) + (threadIdx.y_2*2)) + floordiv(threadIdx.x_2, 4)), 9)*16)) + (floormod(threadIdx.x_2, 4)*4)), 1, 4)]
}
}
}
}
// (c) Fully unrolled 3 (kw) x 3 (kh) x 8 (oh) x 4 (oc_block) accumulation. Each step calls
// the extern "__dp4a" on one 4-lane input vector, one 4-lane weight vector and the running
// accumulator, and stores the result back to the same accumulator slot. The input half
// read here is floormod(ic_chunk.outer.outer, 2)*400, the one not written in (a).
for (kw.inner: int32, 0, 3) "unroll" {
for (kh.inner: int32, 0, 3) "unroll" {
for (oh: int32, 0, 8) "unroll" {
for (oc_block: int32, 0, 4) "unroll" {
compute[((oh*4) + oc_block)] = @tir.call_pure_extern("__dp4a", (int8x4*)pad_data.shared[ramp((((((floormod(ic_chunk.outer.outer, 2)*400) + (oh*40)) + (kh.inner*40)) + (threadIdx.x*4)) + (kw.inner*4)), 1, 4)], (int8x4*)placeholder.shared[ramp(((((threadIdx.y*144) + (kh.inner*48)) + (kw.inner*16)) + (oc_block*4)), 1, 4)], (int32*)compute[((oh*4) + oc_block)], dtype=int32)
}
}
}
}
}
// Final reduction step, peeled out of the loop: no further input tile is loaded. The weight
// source index uses the constant 2160 (= 15*144) in place of ic_chunk.outer.outer*144.
for (ax0.ax1.fused.ax2.fused.ax3.fused.ax4.fused.ax5.outer.fused.outer.outer.outer_1: int32, 0, 5) "unroll" {
attr [IterVar(threadIdx.z_2, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
attr [IterVar(threadIdx.y_2, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 16;
attr [IterVar(threadIdx.x_2, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 8;
if ((((ax0.ax1.fused.ax2.fused.ax3.fused.ax4.fused.ax5.outer.fused.outer.outer.outer_1*32) + (threadIdx.y_2*2)) + floordiv(threadIdx.x_2, 4)) < 144) {
if ((((ax0.ax1.fused.ax2.fused.ax3.fused.ax4.fused.ax5.outer.fused.outer.outer.outer_1*128) + (threadIdx.y_2*8)) + threadIdx.x_2) < 576) {
if (((ax0.ax1.fused.ax2.fused.ax3.fused.ax4.fused.ax5.outer.fused.outer.outer.outer_1*16) + threadIdx.y_2) < 72) {
placeholder.shared[ramp((((ax0.ax1.fused.ax2.fused.ax3.fused.ax4.fused.ax5.outer.fused.outer.outer.outer_1*512) + (threadIdx.y_2*32)) + (threadIdx.x_2*4)), 1, 4)] = (int8x4*)placeholder_11[ramp(((((floordiv((((ax0.ax1.fused.ax2.fused.ax3.fused.ax4.fused.ax5.outer.fused.outer.outer.outer_1*32) + (threadIdx.y_2*2)) + floordiv(threadIdx.x_2, 4)), 9)*2304) + (floormod((((ax0.ax1.fused.ax2.fused.ax3.fused.ax4.fused.ax5.outer.fused.outer.outer.outer_1*32) + (threadIdx.y_2*2)) + floordiv(threadIdx.x_2, 4)), 9)*16)) + (floormod(threadIdx.x_2, 4)*4)) + 2160), 1, 4)]
}
}
}
}
// Same unrolled __dp4a accumulation as in the loop, reading the input half at the constant
// offset 400 (the half written by the last loop iteration, floormod(14 + 1, 2)*400).
for (kw.inner_1: int32, 0, 3) "unroll" {
for (kh.inner_1: int32, 0, 3) "unroll" {
for (oh_1: int32, 0, 8) "unroll" {
for (oc_block_1: int32, 0, 4) "unroll" {
compute[((oh_1*4) + oc_block_1)] = @tir.call_pure_extern("__dp4a", (int8x4*)pad_data.shared[ramp((((((oh_1*40) + (kh.inner_1*40)) + (threadIdx.x*4)) + (kw.inner_1*4)) + 400), 1, 4)], (int8x4*)placeholder.shared[ramp(((((threadIdx.y*144) + (kh.inner_1*48)) + (kw.inner_1*16)) + (oc_block_1*4)), 1, 4)], (int32*)compute[((oh_1*4) + oc_block_1)], dtype=int32)
}
}
}
}
// Epilogue: each thread writes its 8x4 results. Per element:
//   a = q_multiply_shift(placeholder_10[same index as the output], 1807194190, 31, -1)
//   b = q_multiply_shift(q_multiply_shift(compute[..], 1574726727, 31, 16)
//                        + placeholder_8[threadIdx.y*4 + ax4], 1137624986, 31, 1)
//   T_relu_2[..] = max(a + b, 0)
// NOTE(review): placeholder_10 looks like a skip/residual input and placeholder_8 like a
// per-channel bias -- inferred from the fused op name and shapes only; confirm against the graph.
for (ax2.inner.inner.inner: int32, 0, 8) "unroll" {
for (ax4: int32, 0, 4) "unroll" {
T_relu_2[(((((((blockIdx.z*200704) + (threadIdx.y*12544)) + (floordiv(blockIdx.x, 7)*1792)) + (ax2.inner.inner.inner*224)) + (floormod(blockIdx.x, 7)*32)) + (threadIdx.x*4)) + ax4)] = max((@tir.q_multiply_shift((int32*)placeholder_10[(((((((blockIdx.z*200704) + (threadIdx.y*12544)) + (floordiv(blockIdx.x, 7)*1792)) + (ax2.inner.inner.inner*224)) + (floormod(blockIdx.x, 7)*32)) + (threadIdx.x*4)) + ax4)], 1807194190, 31, -1, dtype=int32) + @tir.q_multiply_shift((@tir.q_multiply_shift((int32*)compute[((ax2.inner.inner.inner*4) + ax4)], 1574726727, 31, 16, dtype=int32) + (int32*)placeholder_8[((threadIdx.y*4) + ax4)]), 1137624986, 31, 1, dtype=int32)), 0)
}
}
}
}