2
2
3
3
import operator
4
4
import textwrap
5
+ from dataclasses import dataclass
5
6
6
7
from numba .core import ir
7
8
8
9
10
def get_constant(var, ssa_ir):
    """Return the constant value that defines ``var``.

    Args:
        var: IR variable; only its ``name`` attribute is read.
        ssa_ir: IR object whose ``_definitions`` mapping goes from a variable
            name to the list of values defining it.

    Returns:
        The ``value`` of the first definition recorded for ``var``.

    Raises:
        AssertionError: if ``var`` has no entry in ``ssa_ir._definitions``.
        ValueError: if the first definition is not an ``ir.Const``.
    """
    assert var.name in ssa_ir._definitions
    # Only the first recorded definition is inspected.
    vardef = ssa_ir._definitions[var.name][0]
    if not isinstance(vardef, ir.Const):
        raise ValueError("expected constant variable")
    return vardef.value
17
+
18
+
19
class HeaderInfo:
    """Fields extracted from the header block of a range-based loop.

    Attributes set by ``__init__``:
        phi_var: ``target`` of the fourth instruction of the header block.
        body_id: ``truebr`` of the fifth (last) instruction, i.e. the block
            taken when the branch condition holds.
        next_id: ``falsebr`` of the fifth (last) instruction, i.e. the block
            taken when the branch condition does not hold.
    """

    def __init__(self, header_block):
        """Read the loop variable and branch targets from ``header_block``.

        Raises:
            AssertionError: if the block does not hold exactly five
                instructions, the only header shape handled here.
        """
        body = header_block.body
        assert len(body) == 5
        self.phi_var = body[3].target
        branch = body[4]
        self.body_id = branch.truebr
        self.next_id = branch.falsebr
27
+
28
+
29
class RangeArgs:
    """Constant ``start``/``stop``/``step`` values of a ``range(...)`` call.

    Follows the builtin ``range`` argument convention: ``range(stop)``,
    ``range(start, stop)`` and ``range(start, stop, step)``.
    """

    def __init__(self, range_call, ssa_ir):
        """Resolve the call arguments to constants.

        Args:
            range_call: assignment whose ``value.args`` holds the arguments
                of the range call.
            ssa_ir: IR object passed through to ``get_constant``.

        Raises:
            ValueError: propagated from ``get_constant`` when an argument is
                not defined by a constant.
        """
        args = range_call.value.args
        self.start = 0
        self.step = 1
        self.stop = get_constant(args[0], ssa_ir)
        if len(args) > 1:
            # With two or more arguments the first one is the start, not the
            # stop; previously it was dropped and start stayed at 0.
            self.start = self.stop
            self.stop = get_constant(args[1], ssa_ir)
        if len(args) > 2:
            self.step = get_constant(args[2], ssa_ir)
40
+
41
+
42
@dataclass
class Loop:
    """Everything the emitter needs to print one range-based loop."""

    # Id of the loop header block (the target of the jump that follows the
    # range call).
    header_id: int
    # Loop variable and branch targets read from the header block.
    header: HeaderInfo
    # Constant bounds and step of the range call. The field name shadows the
    # builtin inside this class only; it is kept because callers read
    # ``loop.range``.
    range: RangeArgs
    # Non-temporary variables assigned in the loop body. These are IR
    # variables (callers read ``.name`` on them), not strings.
    inits: list[ir.Var]
48
+
49
+
50
def build_loop_from_call(index, body, blocks, ssa_ir):
    """Build a ``Loop`` from the range call found at ``body[index]``.

    Args:
        index: position of the range-call assignment inside ``body``.
        body: instruction list of the block containing the range call.
        blocks: mapping from block id to block for the whole function.
        ssa_ir: IR object used to resolve the constant range arguments.

    Returns:
        A ``Loop`` describing the header block, the range arguments and the
        non-temporary variables assigned in the loop body.

    Raises:
        AssertionError: if the instructions following the call, or the first
            instruction of the loop body, do not have the expected shape.
        NotImplementedError: if the loop body assigns more than one
            non-temporary variable.
    """
    # The call must be followed by two assignments and a jump to the header.
    # These checks used to be written as `type(x == T)`, which is always
    # truthy and therefore verified nothing.
    assert isinstance(body[index + 1], ir.Assign)
    assert isinstance(body[index + 2], ir.Assign)
    assert isinstance(body[index + 3], ir.Jump)

    header_id = body[index + 3].target
    header = HeaderInfo(blocks[header_id])
    range_args = RangeArgs(body[index], ssa_ir)

    # Loop body must start with assigning the local iter var
    loop_body = blocks[header.body_id].body
    assert loop_body[0].value == header.phi_var

    # Every non-temporary assignment target after the first instruction is
    # treated as a loop-carried value.
    inits = [
        instr.target
        for instr in loop_body[1:]
        if isinstance(instr, ir.Assign) and not instr.target.is_temp
    ]
    if len(inits) > 1:
        raise NotImplementedError("Multiple iter_args not supported")

    return Loop(header_id, header, range_args, inits)
77
+
78
+
79
def is_range_call(instr, ssa_ir):
    """Return True if ``instr`` assigns the result of a call to ``range``.

    Args:
        instr: IR instruction to inspect.
        ssa_ir: IR object whose ``_definitions`` mapping is used to look up
            the definition of the called function variable.

    Raises:
        AssertionError: if the called variable has no recorded definition or
            has more than one.
    """
    if not isinstance(instr, ir.Assign) or not isinstance(instr.value, ir.Expr):
        return False
    if instr.value.op != "call":
        return False
    func = instr.value.func

    assert func.name in ssa_ir._definitions
    func_def_list = ssa_ir._definitions[func.name]
    assert len(func_def_list) == 1
    func_def = func_def_list[0]

    # Only a global named "range" counts; any other callee is rejected.
    if not isinstance(func_def, ir.Global):
        return False
    return func_def.name == "range"
95
+
96
+
9
97
class TextualMlirEmitter :
10
98
11
99
def __init__(self, ssa_ir):
    """Set up emitter state for one function's IR.

    Args:
        ssa_ir: the IR to emit; stored as-is on the instance.
    """
    self.ssa_ir = ssa_ir
    self.temp_var_id = 0
    self.numba_names_to_ssa_var_names = {}
    self.globals_map = {}
    # Range-call assignment target -> Loop built for that call.
    self.loops = {}
    # Block ids already emitted; used as a set (values are always True).
    self.printed_blocks = {}
    # Block ids printed without a "^bbN:" label; used as a set.
    self.omit_block_header = {}
16
107
17
108
def emit (self ):
18
109
func_name = self .ssa_ir .func_id .func_name
@@ -25,38 +116,60 @@ def emit(self):
25
116
# TODO(#1162): get inferred or explicit return types
26
117
return_types_str = "i64"
27
118
28
- body = self .emit_body ()
119
+ body = self .emit_blocks ()
29
120
30
121
mlir_func = f"""func.func @{ func_name } ({ args_str } ) -> ({ return_types_str } ) {{
31
122
{ textwrap .indent (body , ' ' )}
32
123
}}
33
124
"""
34
125
return mlir_func
35
126
36
- def emit_body (self ):
127
def emit_blocks(self):
    """Emit every block of the function and return the joined text.

    Works in two passes over the blocks in ascending id order:

    1. Find range calls, build a ``Loop`` for each and record it in
       ``self.loops`` keyed by the call's assignment target. The block the
       loop header branches to on a false condition is marked as not
       needing a label.
    2. Print each block that has not already been emitted (blocks consumed
       while emitting a loop are skipped), prefixing a ``^bbN:`` label
       unless the block is listed in ``self.omit_block_header``.
    """
    blocks = self.ssa_ir.blocks
    ordered_blocks = sorted(blocks.items())

    # collect loops and block header needs
    # The first block is the function entry and never gets a label.
    self.omit_block_header[ordered_blocks[0][0]] = True
    for block_id, block in ordered_blocks:
        for i, instr in enumerate(block.body):
            # Detect a range call
            if is_range_call(instr, self.ssa_ir):
                loop = build_loop_from_call(i, block.body, blocks, self.ssa_ir)
                self.loops[instr.target] = loop
                self.omit_block_header[loop.header.next_id] = True

    # print blocks
    str_blocks = []
    for block_id, block in ordered_blocks:
        # emit_block may mark further blocks as printed while emitting a
        # loop, so this check has to happen per iteration.
        if block_id in self.printed_blocks:
            continue
        if block_id in self.omit_block_header:
            block_header = ""
        else:
            block_header = f"^bb{block_id}:\n"
        # NOTE(review): indent width assumed to be two spaces -- confirm.
        str_blocks.append(
            block_header + textwrap.indent(self.emit_block(block), "  ")
        )
        self.printed_blocks[block_id] = True

    return "\n".join(str_blocks)
59
156
157
def emit_block(self, block):
    """Emit the instructions of ``block`` and return them newline-joined.

    When an assignment whose target is registered in ``self.loops`` is
    reached, the whole loop is emitted in its place and the remaining
    instructions of the block are skipped. Instructions for which
    ``emit_instruction`` returns an empty result are left out.
    """
    instructions = []
    for instr in block.body:
        if isinstance(instr, ir.Assign) and instr.target in self.loops:
            # We hit a range call for a loop
            instructions.append(self.emit_loop(instr.target))
            # Exit instructions, should be the end of the block
            break
        result = self.emit_instruction(instr)
        if result:
            instructions.append(result)
    return "\n".join(instructions)
172
+
60
173
def emit_instruction (self , instr ):
61
174
match instr :
62
175
case ir .Assign ():
@@ -65,6 +178,10 @@ def emit_instruction(self, instr):
65
178
return self .emit_branch (instr )
66
179
case ir .Return ():
67
180
return self .emit_return (instr )
181
+ case ir .Jump ():
182
+ # TODO ignore
183
+ assert instr .is_terminator
184
+ return
68
185
raise NotImplementedError ("Unsupported instruction: " + str (instr ))
69
186
70
187
def get_or_create_name (self , var ):
@@ -125,6 +242,10 @@ def emit_assign(self, assign):
125
242
case ir .Global ():
126
243
self .globals_map [assign .target .name ] = assign .value .name
127
244
return ""
245
+ case ir .Var ():
246
+ # FIXME: keep track of loop var, and assign
247
+ self .forward_name (from_var = assign .target , to_var = assign .value )
248
+ return ""
128
249
raise NotImplementedError ()
129
250
130
251
def emit_expr (self , expr ):
@@ -134,16 +255,16 @@ def emit_expr(self, expr):
134
255
raise NotImplementedError ()
135
256
elif expr .op == "unary" :
136
257
raise NotImplementedError ()
137
-
138
- # these are all things numba has hooks for upstream, but we didn't implement
139
- # in the prototype
140
-
141
258
elif expr .op == "pair_first" :
142
259
raise NotImplementedError ()
143
260
elif expr .op == "pair_second" :
144
261
raise NotImplementedError ()
145
262
elif expr .op in ("getiter" , "iternext" ):
146
263
raise NotImplementedError ()
264
+
265
+ # these are all things numba has hooks for upstream, but we didn't implement
266
+ # in the prototype
267
+
147
268
elif expr .op == "exhaust_iter" :
148
269
raise NotImplementedError ()
149
270
elif expr .op == "getattr" :
@@ -191,6 +312,49 @@ def emit_branch(self, branch):
191
312
condvar = self .get_name (branch .cond )
192
313
return f"cf.cond_br { condvar } , ^bb{ branch .truebr } , ^bb{ branch .falsebr } "
193
314
315
def emit_loop(self, target):
    """Emit the loop registered under ``target`` as an ``affine.for`` op.

    Args:
        target: assignment target of the range call; key into ``self.loops``.

    Returns:
        The text of the ``affine.for`` operation including its body.

    Raises:
        NotImplementedError: if the loop body itself contains a registered
            range call (nested loop).

    Side effects: marks the loop's header and body blocks as printed, and
    rebinds the SSA name of the loop-carried variable (if any) first to the
    iter_arg inside the loop and then to the loop result afterwards.
    """
    # Note right now loops that use the %i value directly will need an
    # index_cast to an i64 element.
    loop = self.loops[target]
    resultvar = self.get_or_create_name(target)
    itvar = self.get_or_create_name(loop.header.phi_var)  # create var for i
    for_str = (
        f"affine.for {itvar} = {loop.range.start} to {loop.range.stop} step"
        f" {loop.range.step}"
    )

    if len(loop.inits) == 1:
        # Note: we must generalize to inits > 1
        init_val = self.get_name(loop.inits[0])
        # Within the loop, forward the name for the init val to a new temp var
        iter_arg = "itarg"
        self.numba_names_to_ssa_var_names[loop.inits[0].name] = iter_arg
        for_str = (
            f"{resultvar} = {for_str} iter_args(%{iter_arg} = {init_val}) ->"
            " (i64)"
        )
    # The header block is consumed by the loop whether or not there is an
    # iter_arg, so it must never be printed on its own.
    self.printed_blocks[loop.header_id] = True

    # TODO(#1412): support nested loops.
    loop_block = self.ssa_ir.blocks[loop.header.body_id]
    for instr in loop_block.body:
        if isinstance(instr, ir.Assign) and instr.target in self.loops:
            raise NotImplementedError("Nested loops are not supported")

    body_str = self.emit_block(loop_block)
    if len(loop.inits) == 1:
        # Yield the iter arg
        yield_var = self.get_name(loop.inits[0])
        yield_str = f"affine.yield {yield_var} : i64"
        body_str += "\n" + yield_str

    # After we emit the body, we need to update the printed block map and
    # replace all uses of the iterarg after the block with the result of the
    # for loop.
    if loop.inits:
        self.forward_name(loop.inits[0], target)
    self.printed_blocks[loop.header.body_id] = True

    # NOTE(review): body indent width assumed to be two spaces -- confirm.
    result = for_str + " {\n" + textwrap.indent(body_str, "  ") + "\n}"
    return result
357
+
194
358
def emit_return (self , ret ):
195
359
var = self .get_name (ret .value )
196
360
# TODO(#1162): replace i64 with inferred or explicit return type
0 commit comments