|
| 1 | +# Cursors |
| 2 | + |
| 3 | +This documentation covers how to use cursors to navigate, point-to, and apply forwarding on procedures. |
| 4 | +Throughout this document: |
| 5 | +- `p` refers to an Exo `Procedure` object |
| 6 | +- `c` refers to an Exo `Cursor` object |
| 7 | + |
| 8 | +## Obtaining Cursors |
| 9 | + |
| 10 | +### From Procedures |
| 11 | +An Exo `Procedure` provides methods to obtain `Cursor`s: |
| 12 | + |
| 13 | +- `p.args()`: Returns cursors to the procedure's arguments. |
| 14 | +- `p.body()`: Returns a `BlockCursor` selecting the entire body of the procedure. |
| 15 | +- `p.find(pattern, many=False)`: Finds cursor(s) matching the given `pattern` string: |
| 16 | + - If `many=False` (default), returns the first matching cursor. |
| 17 | + - If `many=True`, returns a list of all matching cursors. |
| 18 | +- `p.find_loop(loop_pattern, many=False)`: Finds cursor(s) to a loop, expanding shorthand patterns: |
| 19 | + - `"name"` or `"name #n"` are expanded to `for name in _:_` |
| 20 | + - Works like `p.find()`, returning the first match by default unless `many=True` |
| 21 | +- `p.find_alloc_or_arg(buf_name)`: Finds an allocation or argument cursor, expanding the name to `buf_name: _`. |
| 22 | +- `p.find_all(pattern)`: Shorthand for `p.find(pattern, many=True)`, returning all matching cursors. |
| 23 | + |
| 24 | +### From Cursors |
| 25 | +A `Cursor` provides a similar method to find sub-cursors within its sub-AST: |
| 26 | + |
| 27 | +- `c.find(pattern, many=False)`: Finds cursor(s) matching the `pattern` within the cursor's sub-AST. |
| 28 | + - Like `p.find()`, returns the first match by default unless `many=True`. |
| 29 | + |
| 30 | +### Pattern Language |
| 31 | +The `pattern` argument is a string using the following special syntax: |
| 32 | + |
| 33 | +- `_` is a wildcard matching any statement or expression |
| 34 | +- `#n` at the end selects the `n+1`th match instead of the first |
| 35 | + - Ex. `"for i in _:_ #2"` matches the 3rd `i` loop |
| 36 | +- `;` is a sequence of statements |
| 37 | + |
| 38 | +Example patterns: |
| 39 | +- `"for i in _:_"` matches a `for i in seq(0, n):...` loop |
| 40 | +- `"if i == 0:_"` or `"if _:_"` match `if` statements |
| 41 | +- `"a : i8"` or `"a : _"` match an allocation of a buffer `a` |
| 42 | +- `"a = 3.0"` or `"a = _"` match an assignment to `a` |
| 43 | +- `"a += 3.0"` or `"a += _"` match a reduction on `a` |
| 44 | +- `"a = 3.0 ; b = 2.0"` matches a block with those two statements |
| 45 | + |
| 46 | +## Cursor Types |
| 47 | + |
| 48 | +Exo defines the following `Cursor` types: |
| 49 | + |
| 50 | +- `StmtCursor`: Cursor to a specific Exo IR statement |
| 51 | +- `GapCursor`: Cursor to the space between statements, anchored to (before or after) a statement |
| 52 | +- `BlockCursor`: Cursor to a block (sequence) of statements |
| 53 | +- `ArgCursor`: Cursor to a procedure argument (no navigation) |
| 54 | +- `InvalidCursor`: Special cursor type for invalid cursors |
| 55 | + |
| 56 | +## Common Cursor Methods |
| 57 | + |
| 58 | +All `Cursor` types provide these common methods: |
| 59 | + |
| 60 | +- `c.parent()`: Returns `StmtCursor` to the parent node in Exo IR |
| 61 | + - Raises `InvalidCursorError` if at the root with no parent |
| 62 | +- `c.proc()`: Returns the `Procedure` this cursor is pointing to |
| 63 | +- `c.find(pattern, many=False)`: Finds cursors by pattern-match within `c`s sub-AST |
| 64 | + |
| 65 | +## Statement Cursor Navigation |
| 66 | + |
| 67 | +A `StmtCursor` (pointing to one IR statement) provides these navigation methods. |
| 68 | + |
| 69 | +- `c.next()`: Returns `StmtCursor` to next statement |
| 70 | +- `c.prev()`: Returns `StmtCursor` to previous statement |
| 71 | +- `c.before()`: Returns `GapCursor` to space immediately before this statement |
| 72 | +- `c.after()`: Returns `GapCursor` to space immediately after this statement |
| 73 | +- `c.as_block()`: Returns a `BlockCursor` containing only this one statement |
| 74 | +- `c.expand()`: Shorthand for `stmt_cursor.as_block().expand(...)` |
| 75 | +- `c.body()`: Returns a `BlockCursor` to the body. Only works on `ForCursor` and `IfCursor`. |
| 76 | +- `c.orelse()`: Returns a `BlockCursor` to the orelse branch. Works only on `IfCursor`. |
| 77 | + |
| 78 | +`c.next()` / `c.prev()` return an `InvalidCursor` when there is no next/previous statement. |
| 79 | +`c.before()` / `c.after()` return anchored `GapCursor`s that move with their anchor statements. |
| 80 | + |
| 81 | +Examples: |
| 82 | +``` |
| 83 | +s1 <- c |
| 84 | +s2 <- c.next() |
| 85 | +
|
| 86 | +s1 <- c.prev() |
| 87 | +s2 <- c |
| 88 | +
|
| 89 | +s1 |
| 90 | + <- c.before() |
| 91 | +s2 <- c |
| 92 | +
|
| 93 | +s1 |
| 94 | +s2 <- c |
| 95 | + <- c.after() |
| 96 | +``` |
| 97 | + |
| 98 | +## Other Cursor Navigation |
| 99 | + |
| 100 | +- `GapCursor.anchor()`: Returns `StmtCursor` to the statement this gap is anchored to |
| 101 | + |
| 102 | +- `BlockCursor.expand(delta_lo=None, delta_hi=None)`: Returns an expanded block cursor |
| 103 | + - `delta_lo`/`delta_hi` specify statements to add at start/end; `None` means expand fully |
| 104 | + - Ex. in `s1; s2; s3`, if `c` is a `BlockCursor` pointing `s1; s2`, then `c.expand(0, 1)` returns a new `BlockCursor` pointing `s1; s2; s3` |
| 105 | +- `BlockCursor.before()`: Returns `GapCursor` before block's first statement |
| 106 | +- `BlockCursor.after()`: Returns `GapCursor` after block's last statement |
| 107 | +- `BlockCursor[pt]`: Returns a `pt+1`th `StmtCursor` within the BlockCursor (e.g. `c[0]` returns `s1` when `c` is pointing to `s1;s2;...`) |
| 108 | +- `BlockCursor[lo:hi]`: Returns a slice of `BlockCursor` from `lo` to `hi-1`. (e.g. `c[0:2]` returns `s1;s2` when `c` is pointing to `s2;s2;...`) |
| 109 | + |
| 110 | +## Cursor inspection |
| 111 | + |
| 112 | +`StmtCursor`s wrap the underlying Exo IR object and can be inspected. |
| 113 | + - Ex. check cursor type with `isinstance(c, PC.AllocCursor)` |
| 114 | + |
| 115 | +`StmtCursor`s are one of the following types. |
| 116 | + |
| 117 | +#### `ArgCursor` |
| 118 | + |
| 119 | +Represents a cursor pointing to a procedure argument of the form: |
| 120 | +``` |
| 121 | +name : type @ mem |
| 122 | +``` |
| 123 | + |
| 124 | +Methods: |
| 125 | +- `name() -> str`: Returns the name of the argument. |
| 126 | +- `mem() -> Memory`: Returns the memory location of the argument. |
| 127 | +- `is_tensor() -> bool`: Checks if the argument is a tensor. |
| 128 | +- `shape() -> ExprListCursor`: Returns a cursor to the shape expression list. |
| 129 | +- `type() -> API.ExoType`: Returns the type of the argument. |
| 130 | + |
| 131 | +#### `AssignCursor` |
| 132 | + |
| 133 | +Represents a cursor pointing to an assignment statement of the form: |
| 134 | +``` |
| 135 | +name[idx] = rhs |
| 136 | +``` |
| 137 | + |
| 138 | +Methods: |
| 139 | +- `name() -> str`: Returns the name of the variable being assigned to. |
| 140 | +- `idx() -> ExprListCursor`: Returns a cursor to the index expression list. |
| 141 | +- `rhs() -> ExprCursor`: Returns a cursor to the right-hand side expression. |
| 142 | +- `type() -> API.ExoType`: Returns the type of the assignment. |
| 143 | + |
| 144 | +#### `ReduceCursor` |
| 145 | + |
| 146 | +Represents a cursor pointing to a reduction statement of the form: |
| 147 | +``` |
| 148 | +name[idx] += rhs |
| 149 | +``` |
| 150 | + |
| 151 | +Methods: |
| 152 | +- `name() -> str`: Returns the name of the variable being reduced. |
| 153 | +- `idx() -> ExprListCursor`: Returns a cursor to the index expression list. |
| 154 | +- `rhs() -> ExprCursor`: Returns a cursor to the right-hand side expression. |
| 155 | + |
| 156 | + |
| 157 | +#### `AssignConfigCursor` |
| 158 | + |
| 159 | +Represents a cursor pointing to a configuration assignment statement of the form: |
| 160 | +``` |
| 161 | +config.field = rhs |
| 162 | +``` |
| 163 | + |
| 164 | +Methods: |
| 165 | +- `config() -> Config`: Returns the configuration object. |
| 166 | +- `field() -> str`: Returns the name of the configuration field being assigned to. |
| 167 | +- `rhs() -> ExprCursor`: Returns a cursor to the right-hand side expression. |
| 168 | + |
| 169 | +#### `PassCursor` |
| 170 | + |
| 171 | +Represents a cursor pointing to a no-op statement: |
| 172 | +``` |
| 173 | +pass |
| 174 | +``` |
| 175 | + |
| 176 | +#### `IfCursor` |
| 177 | + |
| 178 | +Represents a cursor pointing to an if statement of the form: |
| 179 | +``` |
| 180 | +if condition: |
| 181 | + body |
| 182 | +``` |
| 183 | +or |
| 184 | +``` |
| 185 | +if condition: |
| 186 | + body |
| 187 | +else: |
| 188 | + orelse |
| 189 | +``` |
| 190 | + |
| 191 | +Methods: |
| 192 | +- `cond() -> ExprCursor`: Returns a cursor to the if condition expression. |
| 193 | +- `body() -> BlockCursor`: Returns a cursor to the if body block. |
| 194 | +- `orelse() -> BlockCursor | InvalidCursor`: Returns a cursor to the else block if present, otherwise returns an invalid cursor. |
| 195 | + |
| 196 | +#### `ForCursor` |
| 197 | + |
| 198 | +Represents a cursor pointing to a loop statement of the form: |
| 199 | +``` |
| 200 | +for name in seq(0, hi): |
| 201 | + body |
| 202 | +``` |
| 203 | + |
| 204 | +Methods: |
| 205 | +- `name() -> str`: Returns the loop variable name. |
| 206 | +- `lo() -> ExprCursor`: Returns a cursor to the lower bound expression (defaults to 0). |
| 207 | +- `hi() -> ExprCursor`: Returns a cursor to the upper bound expression. |
| 208 | +- `body() -> BlockCursor`: Returns a cursor to the loop body block. |
| 209 | + |
| 210 | + |
| 211 | +#### `AllocCursor` |
| 212 | + |
| 213 | +Represents a cursor pointing to a buffer allocation statement of the form: |
| 214 | +``` |
| 215 | +name : type @ mem |
| 216 | +``` |
| 217 | + |
| 218 | +Methods: |
| 219 | +- `name() -> str`: Returns the name of the allocated buffer. |
| 220 | +- `mem() -> Memory`: Returns the memory location of the buffer. |
| 221 | +- `is_tensor() -> bool`: Checks if the allocated buffer is a tensor. |
| 222 | +- `shape() -> ExprListCursor`: Returns a cursor to the shape expression list. |
| 223 | +- `type() -> API.ExoType`: Returns the type of the allocated buffer. |
| 224 | + |
| 225 | + |
| 226 | +#### `CallCursor` |
| 227 | + |
| 228 | +Represents a cursor pointing to a sub-procedure call statement of the form: |
| 229 | +``` |
| 230 | +subproc(args) |
| 231 | +``` |
| 232 | + |
| 233 | +Methods: |
| 234 | +- `subproc()`: Returns the called sub-procedure. |
| 235 | +- `args() -> ExprListCursor`: Returns a cursor to the argument expression list. |
| 236 | + |
| 237 | + |
| 238 | +#### `WindowStmtCursor` |
| 239 | + |
| 240 | +Represents a cursor pointing to a window declaration statement of the form: |
| 241 | +``` |
| 242 | +name = winexpr |
| 243 | +``` |
| 244 | + |
| 245 | +Methods: |
| 246 | +- `name() -> str`: Returns the name of the window. |
| 247 | +- `winexpr() -> ExprCursor`: Returns a cursor to the window expression. |
| 248 | + |
| 249 | + |
| 250 | +## ExoType |
| 251 | + |
| 252 | +The `ExoType` enumeration represents user-facing various data and control types. It is a wrapper around Exo IR types. |
| 253 | + |
| 254 | +- `F16`: Represents a 16-bit floating-point type. |
| 255 | +- `F32`: Represents a 32-bit floating-point type. |
| 256 | +- `F64`: Represents a 64-bit floating-point type. |
| 257 | +- `UI8`: Represents an 8-bit unsigned integer type. |
| 258 | +- `I8`: Represents an 8-bit signed integer type. |
| 259 | +- `UI16`: Represents a 16-bit unsigned integer type. |
| 260 | +- `I32`: Represents a 32-bit signed integer type. |
| 261 | +- `R`: Represents a generic numeric type. |
| 262 | +- `Index`: Represents an index type. |
| 263 | +- `Bool`: Represents a boolean type. |
| 264 | +- `Size`: Represents a size type. |
| 265 | +- `Int`: Represents a generic integer type. |
| 266 | +- `Stride`: Represents a stride type. |
| 267 | + |
| 268 | +The `ExoType` provides the following utility methods: |
| 269 | + |
| 270 | +#### `is_indexable()` |
| 271 | + |
| 272 | +Returns `True` if the `ExoType` is one of the indexable types, which include: |
| 273 | +- `ExoType.Index` |
| 274 | +- `ExoType.Size` |
| 275 | +- `ExoType.Int` |
| 276 | +- `ExoType.Stride` |
| 277 | + |
| 278 | +#### `is_numeric()` |
| 279 | + |
| 280 | +Returns `True` if the `ExoType` is one of the numeric types, which include: |
| 281 | +- `ExoType.F16` |
| 282 | +- `ExoType.F32` |
| 283 | +- `ExoType.F64` |
| 284 | +- `ExoType.I8` |
| 285 | +- `ExoType.UI8` |
| 286 | +- `ExoType.UI16` |
| 287 | +- `ExoType.I32` |
| 288 | +- `ExoType.R` |
| 289 | + |
| 290 | +#### `is_bool()` |
| 291 | + |
| 292 | +Returns `True` if the `ExoType` is the boolean type (`ExoType.Bool`). |
| 293 | + |
| 294 | + |
| 295 | +## Cursor Forwarding |
| 296 | + |
| 297 | +When a procedure `p1` is transformed into a new procedure `p2` by applying scheduling primitives, any cursors pointing into `p1` need to be updated to point to the corresponding locations in `p2`. This process is called *cursor forwarding*. |
| 298 | + |
| 299 | +To forward a cursor `c1` from `p1` to `p2`, you can use the `forward` method on the new procedure: |
| 300 | +```python |
| 301 | +c2 = p2.forward(c1) |
| 302 | +``` |
| 303 | + |
| 304 | +### How Forwarding Works |
| 305 | + |
| 306 | +Internally, each scheduling primitive returns a *forwarding function* that maps AST locations in the input procedure to locations in the output procedure. |
| 307 | + |
| 308 | +When you call `p2.forward(c1)`, Exo composes the forwarding functions for all the scheduling steps between `c1.proc()` (the procedure `c1` points into, in this case `p1`) and `p2` (the final procedure). This composition produces a single function that can map `c1` from its original procedure to the corresponding location in `p2`. |
| 309 | + |
| 310 | +Here's the actual implementation of the forwarding in `src/exo/API.py`: |
| 311 | + |
| 312 | +```python |
| 313 | +def forward(self, cur: C.Cursor): |
| 314 | + p = self |
| 315 | + fwds = [] |
| 316 | + while p is not None and p is not cur.proc(): |
| 317 | + fwds.append(p._forward) |
| 318 | + p = p._provenance_eq_Procedure |
| 319 | + |
| 320 | + ir = cur._impl |
| 321 | + for fn in reversed(fwds): |
| 322 | + ir = fn(ir) |
| 323 | + |
| 324 | + return C.lift_cursor(ir, self) |
| 325 | +``` |
| 326 | + |
| 327 | +The key steps are: |
| 328 | + |
| 329 | +1. Collect the forwarding functions (`p._forward`) for all procedures between `cur.proc()` and `self` (the final procedure). |
| 330 | +2. Get the underlying Exo IR for the input cursor (`cur._impl`). |
| 331 | +3. Apply the forwarding functions in reverse order to map the IR node to its final location. |
| 332 | +4. Lift the mapped IR node back into a cursor in the final procedure. |
| 333 | + |
| 334 | +So in summary, `p.forward(c)` computes and applies the composite forwarding function to map cursor `c` from its original procedure to the corresponding location in procedure `p`. |
| 335 | + |
| 336 | +Note that a forwarding function can return an invalid cursor, and that is expected. For example, when a statement cease to exist by a rewrite, cursors pointing to the statement will be forwarded to an invalid cursor. |
| 337 | + |
| 338 | +### Implicit and Explicit Cursor Forwarding in Scheduling Primitives |
| 339 | + |
| 340 | +Scheduling primitives, such as `lift_alloc` and `expand_dim`, operate on a target procedure, which is passed as the first argument. When passing cursors to these primitives, the cursors should be forwarded to point to the target procedure. |
| 341 | + |
| 342 | +Consider the following example: |
| 343 | +```python |
| 344 | +c = p0.find("x : _") |
| 345 | +p1 = lift_alloc(p0, c) |
| 346 | +p2 = expand_dim(p1, p1.forward(c), ...) |
| 347 | +``` |
| 348 | + |
| 349 | +In the call to `expand_dim`, the cursor `c` is explicitly forwarded to `p1` using `p1.forward(c)`. This is necessary because `c` was originally obtained from `p0`, and it needs to be adjusted to point to the correct location in `p1`. |
| 350 | + |
| 351 | +However, the scheduling primitives support *implicit forwarding* of cursors. This means that all the cursors passed to these primitives will be automatically forwarded to point to the first argument procedure. The above code can be simplified as follows: |
| 352 | + |
| 353 | +```python |
| 354 | +c = p0.find("x : _") |
| 355 | +p1 = lift_alloc(p0, c) |
| 356 | +p2 = expand_dim(p1, c, ...) # implicit forwarding! |
| 357 | +``` |
| 358 | + |
| 359 | +In this case, `c` is implicitly forwarded to `p1` within the `expand_dim` primitive, eliminating the need for explicit forwarding. |
| 360 | + |
| 361 | +#### Limitations of Implicit Forwarding |
| 362 | + |
| 363 | +It is important to note that implicit forwarding does not work when navigation is applied to a forwarded cursor. Consider the following example: |
| 364 | + |
| 365 | +```python |
| 366 | +c = p0.find("x : _") |
| 367 | +p1 = lift_alloc(p0, c) |
| 368 | +p2 = reorder_scope(p1, p1.forward(c).next(), ...) |
| 369 | +``` |
| 370 | + |
| 371 | +In this code, the navigation `.next()` is applied to the forwarded cursor `p1.forward(c)`. Attempting to change `p1.forward(c).next()` to `p1.forward(c.next())` will result in incorrect behavior. This is because navigation and forwarding are *not commutative*. |
| 372 | + |
0 commit comments