Commit 36a8ec4

Add Cursor documentation (#721)
1 parent 3541352 commit 36a8ec4

File tree

2 files changed

+378
-6
lines changed

docs/API.md

+6-6
@@ -8,17 +8,17 @@
88

99
## Procedure Object Methods
1010

11+
The following are methods on Exo Procedures (functions decorated with `@proc` or `@instr`).
12+
1113
### Inspection Operations
1214

1315
- `.name()`: Returns the procedure name.
1416
- `.is_instr()`: Returns `True` if the procedure has a hardware instruction string.
1517
- `.get_instr()`: Returns the hardware instruction string.
16-
- `.args()`: Returns cursors to procedure arguments.
17-
- `.body()`: Returns a BlockCursor selecting the entire body of the Procedure.
18-
- `.find(pattern, many=False)`: Finds a cursor for the given pattern. If `many=True`, returns a list of all cursors matching the pattern.
19-
- `.find_loop(loop_pattern, many=False)`: Finds a cursor pointing to a loop. Similar to `proc.find(...)`, but if the supplied pattern is of the form 'name' or 'name #n', it will be auto-expanded to `for name in _:_`.
20-
- `.find_alloc_or_arg(pattern)`: Finds an allocation or argument cursor.
21-
- `.find_all(pattern)`: Finds a list of all cursors matching the pattern.
18+
19+
### Obtaining Cursors
20+
21+
Cursors can be obtained by querying patterns on a procedure. All cursor-related documentation is in [Cursors.md](Cursors.md).
2222

2323
### Compilation Operations
2424

docs/Cursors.md

+372
@@ -0,0 +1,372 @@
1+
# Cursors
2+
3+
This document covers how to use cursors to navigate procedures, point to locations within them, and forward cursors from one procedure to another.
4+
Throughout this document:
5+
- `p` refers to an Exo `Procedure` object
6+
- `c` refers to an Exo `Cursor` object
7+
8+
## Obtaining Cursors
9+
10+
### From Procedures
11+
An Exo `Procedure` provides methods to obtain `Cursor`s:
12+
13+
- `p.args()`: Returns cursors to the procedure's arguments.
14+
- `p.body()`: Returns a `BlockCursor` selecting the entire body of the procedure.
15+
- `p.find(pattern, many=False)`: Finds cursor(s) matching the given `pattern` string:
16+
- If `many=False` (default), returns the first matching cursor.
17+
- If `many=True`, returns a list of all matching cursors.
18+
- `p.find_loop(loop_pattern, many=False)`: Finds cursor(s) to a loop, expanding shorthand patterns:
19+
- `"name"` or `"name #n"` are expanded to `for name in _:_`
20+
- Works like `p.find()`, returning the first match by default unless `many=True`
21+
- `p.find_alloc_or_arg(buf_name)`: Finds an allocation or argument cursor, expanding the name to `buf_name: _`.
22+
- `p.find_all(pattern)`: Shorthand for `p.find(pattern, many=True)`, returning all matching cursors.
23+
24+
### From Cursors
25+
A `Cursor` provides a similar method to find sub-cursors within its sub-AST:
26+
27+
- `c.find(pattern, many=False)`: Finds cursor(s) matching the `pattern` within the cursor's sub-AST.
28+
- Like `p.find()`, returns the first match by default unless `many=True`.
29+
30+
### Pattern Language
31+
The `pattern` argument is a string using the following special syntax:
32+
33+
- `_` is a wildcard matching any statement or expression
34+
- `#n` at the end selects the (`n+1`)th match instead of the first (matches are counted from 0)
35+
- Ex. `"for i in _:_ #2"` matches the 3rd `i` loop
36+
- `;` separates consecutive statements, so `s1 ; s2` matches a sequence of statements
37+
38+
Example patterns:
39+
- `"for i in _:_"` matches a `for i in seq(0, n):...` loop
40+
- `"if i == 0:_"` or `"if _:_"` match `if` statements
41+
- `"a : i8"` or `"a : _"` match an allocation of a buffer `a`
42+
- `"a = 3.0"` or `"a = _"` match an assignment to `a`
43+
- `"a += 3.0"` or `"a += _"` match a reduction on `a`
44+
- `"a = 3.0 ; b = 2.0"` matches a block with those two statements
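The `#n` suffix and the `find_loop` shorthand can be sketched in plain Python. This is a minimal illustration, not Exo's parser: the helper names `split_match_index` and `expand_loop_shorthand` are hypothetical, and carrying the `#n` suffix over to the expanded pattern is an assumption.

```python
import re


def split_match_index(pattern: str):
    """Split a trailing '#n' off a pattern; n defaults to 0 (first match)."""
    m = re.fullmatch(r"(.*?)\s*#(\d+)\s*", pattern)
    if m is None:
        return pattern.strip(), 0
    return m.group(1), int(m.group(2))


def expand_loop_shorthand(pattern: str) -> str:
    """Expand 'name' or 'name #n' to a loop pattern (hypothetical helper)."""
    body, n = split_match_index(pattern)
    if re.fullmatch(r"[A-Za-z_]\w*", body):
        # A bare identifier is treated as a loop variable name.
        body = f"for {body} in _:_"
    # Assumption: the '#n' selector is kept on the expanded pattern.
    return f"{body} #{n}" if n else body


print(split_match_index("for i in _:_ #2"))
print(expand_loop_shorthand("i #1"))
```

With this model, `"for i in _:_ #2"` splits into the loop pattern and index 2, which selects the third match.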
45+
46+
## Cursor Types
47+
48+
Exo defines the following `Cursor` types:
49+
50+
- `StmtCursor`: Cursor to a specific Exo IR statement
51+
- `GapCursor`: Cursor to the space between statements, anchored to (before or after) a statement
52+
- `BlockCursor`: Cursor to a block (sequence) of statements
53+
- `ArgCursor`: Cursor to a procedure argument (no navigation)
54+
- `InvalidCursor`: Special cursor type for invalid cursors
55+
56+
## Common Cursor Methods
57+
58+
All `Cursor` types provide these common methods:
59+
60+
- `c.parent()`: Returns `StmtCursor` to the parent node in Exo IR
61+
- Raises `InvalidCursorError` if at the root with no parent
62+
- `c.proc()`: Returns the `Procedure` this cursor is pointing to
63+
- `c.find(pattern, many=False)`: Finds cursors by pattern-matching within `c`'s sub-AST
64+
65+
## Statement Cursor Navigation
66+
67+
A `StmtCursor` (pointing to one IR statement) provides these navigation methods.
68+
69+
- `c.next()`: Returns `StmtCursor` to next statement
70+
- `c.prev()`: Returns `StmtCursor` to previous statement
71+
- `c.before()`: Returns `GapCursor` to space immediately before this statement
72+
- `c.after()`: Returns `GapCursor` to space immediately after this statement
73+
- `c.as_block()`: Returns a `BlockCursor` containing only this one statement
74+
- `c.expand(...)`: Shorthand for `c.as_block().expand(...)`
75+
- `c.body()`: Returns a `BlockCursor` to the body. Only works on `ForCursor` and `IfCursor`.
76+
- `c.orelse()`: Returns a `BlockCursor` to the orelse branch. Works only on `IfCursor`.
77+
78+
`c.next()` / `c.prev()` return an `InvalidCursor` when there is no next/previous statement.
79+
`c.before()` / `c.after()` return anchored `GapCursor`s that move with their anchor statements.
80+
81+
Examples:
82+
```
s1    <- c
s2    <- c.next()

s1    <- c.prev()
s2    <- c

s1
      <- c.before()
s2    <- c

s1
s2    <- c
      <- c.after()
```
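The navigation rules above can be modeled with a small stand-alone sketch. `ToyStmtCursor` and `ToyGap` are hypothetical classes, not Exo types, and `None` stands in for an `InvalidCursor`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyGap:
    anchor: "ToyStmtCursor"  # statement the gap is anchored to
    side: str                # "before" or "after"


@dataclass(frozen=True)
class ToyStmtCursor:
    block: tuple  # statements of the enclosing block
    idx: int      # position of the selected statement

    def stmt(self):
        return self.block[self.idx]

    def next(self):
        i = self.idx + 1
        # None plays the role of an invalid cursor in this toy model.
        return ToyStmtCursor(self.block, i) if i < len(self.block) else None

    def prev(self):
        i = self.idx - 1
        return ToyStmtCursor(self.block, i) if i >= 0 else None

    def before(self):
        return ToyGap(self, "before")

    def after(self):
        return ToyGap(self, "after")


c = ToyStmtCursor(("s1", "s2"), 0)
print(c.next().stmt(), c.prev(), c.after().side)
```

Because each gap stores its anchor statement rather than a position, it naturally follows that statement, which mirrors the anchoring behavior described above.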
97+
98+
## Other Cursor Navigation
99+
100+
- `GapCursor.anchor()`: Returns `StmtCursor` to the statement this gap is anchored to
101+
102+
- `BlockCursor.expand(delta_lo=None, delta_hi=None)`: Returns an expanded block cursor
103+
- `delta_lo`/`delta_hi` specify statements to add at start/end; `None` means expand fully
104+
- Ex. in `s1; s2; s3`, if `c` is a `BlockCursor` pointing to `s1; s2`, then `c.expand(0, 1)` returns a new `BlockCursor` pointing to `s1; s2; s3`
105+
- `BlockCursor.before()`: Returns `GapCursor` before block's first statement
106+
- `BlockCursor.after()`: Returns `GapCursor` after block's last statement
107+
- `BlockCursor[pt]`: Returns a `StmtCursor` to the (`pt+1`)th statement within the `BlockCursor` (e.g. `c[0]` returns `s1` when `c` is pointing to `s1;s2;...`)
108+
- `BlockCursor[lo:hi]`: Returns a `BlockCursor` covering statements `lo` through `hi-1` (e.g. `c[0:2]` returns `s1;s2` when `c` is pointing to `s1;s2;...`)
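The block operations above can be sketched as a half-open range over a statement list. `ToyBlock` is a hypothetical model, not Exo's `BlockCursor`; in particular, clamping an over-large `delta_lo`/`delta_hi` to the enclosing block is an assumption of this sketch.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ToyBlock:
    stmts: tuple  # all statements of the enclosing block
    lo: int       # index of the first selected statement
    hi: int       # one past the last selected statement

    def selected(self):
        return self.stmts[self.lo:self.hi]

    def expand(self, delta_lo: Optional[int] = None,
               delta_hi: Optional[int] = None):
        # None means "expand fully" in that direction.
        lo = 0 if delta_lo is None else max(0, self.lo - delta_lo)
        n = len(self.stmts)
        hi = n if delta_hi is None else min(n, self.hi + delta_hi)
        return ToyBlock(self.stmts, lo, hi)

    def __getitem__(self, key):
        if isinstance(key, slice):
            lo, hi, step = key.indices(self.hi - self.lo)
            assert step == 1, "only contiguous slices are modeled"
            return ToyBlock(self.stmts, self.lo + lo, self.lo + hi)
        return self.selected()[key]


c = ToyBlock(("s1", "s2", "s3"), 0, 2)  # points to s1; s2
print(c.expand(0, 1).selected())
print(c[0], c[0:2].selected())
```

Indices passed to `[]` are relative to the selected block, not to the enclosing statement list, which matches the `c[0]` example above.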
109+
110+
## Cursor inspection
111+
112+
`StmtCursor`s wrap the underlying Exo IR object and can be inspected.
113+
- Ex. check cursor type with `isinstance(c, PC.AllocCursor)`
114+
115+
`StmtCursor`s are one of the following types.
116+
117+
#### `ArgCursor`
118+
119+
Represents a cursor pointing to a procedure argument of the form:
120+
```
name : type @ mem
```
123+
124+
Methods:
125+
- `name() -> str`: Returns the name of the argument.
126+
- `mem() -> Memory`: Returns the memory location of the argument.
127+
- `is_tensor() -> bool`: Checks if the argument is a tensor.
128+
- `shape() -> ExprListCursor`: Returns a cursor to the shape expression list.
129+
- `type() -> API.ExoType`: Returns the type of the argument.
130+
131+
#### `AssignCursor`
132+
133+
Represents a cursor pointing to an assignment statement of the form:
134+
```
name[idx] = rhs
```
137+
138+
Methods:
139+
- `name() -> str`: Returns the name of the variable being assigned to.
140+
- `idx() -> ExprListCursor`: Returns a cursor to the index expression list.
141+
- `rhs() -> ExprCursor`: Returns a cursor to the right-hand side expression.
142+
- `type() -> API.ExoType`: Returns the type of the assignment.
143+
144+
#### `ReduceCursor`
145+
146+
Represents a cursor pointing to a reduction statement of the form:
147+
```
name[idx] += rhs
```
150+
151+
Methods:
152+
- `name() -> str`: Returns the name of the variable being reduced.
153+
- `idx() -> ExprListCursor`: Returns a cursor to the index expression list.
154+
- `rhs() -> ExprCursor`: Returns a cursor to the right-hand side expression.
155+
156+
157+
#### `AssignConfigCursor`
158+
159+
Represents a cursor pointing to a configuration assignment statement of the form:
160+
```
config.field = rhs
```
163+
164+
Methods:
165+
- `config() -> Config`: Returns the configuration object.
166+
- `field() -> str`: Returns the name of the configuration field being assigned to.
167+
- `rhs() -> ExprCursor`: Returns a cursor to the right-hand side expression.
168+
169+
#### `PassCursor`
170+
171+
Represents a cursor pointing to a no-op statement:
172+
```
pass
```
175+
176+
#### `IfCursor`
177+
178+
Represents a cursor pointing to an if statement of the form:
179+
```
if condition:
    body
```
183+
or
184+
```
if condition:
    body
else:
    orelse
```
190+
191+
Methods:
192+
- `cond() -> ExprCursor`: Returns a cursor to the if condition expression.
193+
- `body() -> BlockCursor`: Returns a cursor to the if body block.
194+
- `orelse() -> BlockCursor | InvalidCursor`: Returns a cursor to the else block if present, otherwise returns an invalid cursor.
195+
196+
#### `ForCursor`
197+
198+
Represents a cursor pointing to a loop statement of the form:
199+
```
for name in seq(0, hi):
    body
```
203+
204+
Methods:
205+
- `name() -> str`: Returns the loop variable name.
206+
- `lo() -> ExprCursor`: Returns a cursor to the lower bound expression (defaults to 0).
207+
- `hi() -> ExprCursor`: Returns a cursor to the upper bound expression.
208+
- `body() -> BlockCursor`: Returns a cursor to the loop body block.
209+
210+
211+
#### `AllocCursor`
212+
213+
Represents a cursor pointing to a buffer allocation statement of the form:
214+
```
name : type @ mem
```
217+
218+
Methods:
219+
- `name() -> str`: Returns the name of the allocated buffer.
220+
- `mem() -> Memory`: Returns the memory location of the buffer.
221+
- `is_tensor() -> bool`: Checks if the allocated buffer is a tensor.
222+
- `shape() -> ExprListCursor`: Returns a cursor to the shape expression list.
223+
- `type() -> API.ExoType`: Returns the type of the allocated buffer.
224+
225+
226+
#### `CallCursor`
227+
228+
Represents a cursor pointing to a sub-procedure call statement of the form:
229+
```
subproc(args)
```
232+
233+
Methods:
234+
- `subproc()`: Returns the called sub-procedure.
235+
- `args() -> ExprListCursor`: Returns a cursor to the argument expression list.
236+
237+
238+
#### `WindowStmtCursor`
239+
240+
Represents a cursor pointing to a window declaration statement of the form:
241+
```
name = winexpr
```
244+
245+
Methods:
246+
- `name() -> str`: Returns the name of the window.
247+
- `winexpr() -> ExprCursor`: Returns a cursor to the window expression.
248+
249+
250+
## ExoType
251+
252+
The `ExoType` enumeration represents the various user-facing data and control types. It is a wrapper around Exo IR types.
253+
254+
- `F16`: Represents a 16-bit floating-point type.
255+
- `F32`: Represents a 32-bit floating-point type.
256+
- `F64`: Represents a 64-bit floating-point type.
257+
- `UI8`: Represents an 8-bit unsigned integer type.
258+
- `I8`: Represents an 8-bit signed integer type.
259+
- `UI16`: Represents a 16-bit unsigned integer type.
260+
- `I32`: Represents a 32-bit signed integer type.
261+
- `R`: Represents a generic numeric type.
262+
- `Index`: Represents an index type.
263+
- `Bool`: Represents a boolean type.
264+
- `Size`: Represents a size type.
265+
- `Int`: Represents a generic integer type.
266+
- `Stride`: Represents a stride type.
267+
268+
The `ExoType` provides the following utility methods:
269+
270+
#### `is_indexable()`
271+
272+
Returns `True` if the `ExoType` is one of the indexable types, which include:
273+
- `ExoType.Index`
274+
- `ExoType.Size`
275+
- `ExoType.Int`
276+
- `ExoType.Stride`
277+
278+
#### `is_numeric()`
279+
280+
Returns `True` if the `ExoType` is one of the numeric types, which include:
281+
- `ExoType.F16`
282+
- `ExoType.F32`
283+
- `ExoType.F64`
284+
- `ExoType.I8`
285+
- `ExoType.UI8`
286+
- `ExoType.UI16`
287+
- `ExoType.I32`
288+
- `ExoType.R`
289+
290+
#### `is_bool()`
291+
292+
Returns `True` if the `ExoType` is the boolean type (`ExoType.Bool`).
293+
294+
295+
## Cursor Forwarding
296+
297+
When a procedure `p1` is transformed into a new procedure `p2` by applying scheduling primitives, any cursors pointing into `p1` need to be updated to point to the corresponding locations in `p2`. This process is called *cursor forwarding*.
298+
299+
To forward a cursor `c1` from `p1` to `p2`, you can use the `forward` method on the new procedure:
300+
```python
c2 = p2.forward(c1)
```
303+
304+
### How Forwarding Works
305+
306+
Internally, each scheduling primitive returns a *forwarding function* that maps AST locations in the input procedure to locations in the output procedure.
307+
308+
When you call `p2.forward(c1)`, Exo composes the forwarding functions for all the scheduling steps between `c1.proc()` (the procedure `c1` points into, in this case `p1`) and `p2` (the final procedure). This composition produces a single function that can map `c1` from its original procedure to the corresponding location in `p2`.
309+
310+
Here's the actual implementation of the forwarding in `src/exo/API.py`:
311+
312+
```python
def forward(self, cur: C.Cursor):
    p = self
    fwds = []
    while p is not None and p is not cur.proc():
        fwds.append(p._forward)
        p = p._provenance_eq_Procedure

    ir = cur._impl
    for fn in reversed(fwds):
        ir = fn(ir)

    return C.lift_cursor(ir, self)
```
326+
327+
The key steps are:
328+
329+
1. Collect the forwarding functions (`p._forward`) for all procedures between `cur.proc()` and `self` (the final procedure).
330+
2. Get the underlying Exo IR for the input cursor (`cur._impl`).
331+
3. Apply the forwarding functions in the reverse of the order they were collected (that is, earliest scheduling step first) to map the IR node to its final location.
332+
4. Lift the mapped IR node back into a cursor in the final procedure.
333+
334+
So in summary, `p.forward(c)` computes and applies the composite forwarding function to map cursor `c` from its original procedure to the corresponding location in procedure `p`.
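The composition described above can be reproduced with a toy model in which statements are list entries and a cursor is a `(procedure, index)` pair. `ToyProc`, its `parent` and `fwd` attributes, and the use of `None` for an invalid location are all assumptions made for illustration; they are not Exo's data structures.

```python
class ToyProc:
    def __init__(self, stmts, parent=None, fwd=None):
        self.stmts = list(stmts)
        self.parent = parent  # procedure this one was derived from
        self.fwd = fwd        # maps an index in parent to an index here

    def forward(self, cur):
        proc, idx = cur
        # Walk the provenance chain back to the cursor's procedure.
        fwds, p = [], self
        while p is not None and p is not proc:
            fwds.append(p.fwd)
            p = p.parent
        if p is None:
            raise ValueError("cursor does not come from an ancestor")
        # Apply the earliest step first; None models an invalid cursor.
        for fn in reversed(fwds):
            if idx is None:
                break
            idx = fn(idx)
        return (self, idx)


p0 = ToyProc(["a", "b", "c"])
# Step 1 deletes "a": index 0 becomes invalid, the rest shift down.
p1 = ToyProc(["b", "c"], p0, lambda i: None if i == 0 else i - 1)
# Step 2 swaps the two remaining statements.
p2 = ToyProc(["c", "b"], p1, lambda i: 1 - i)

_, idx_c = p2.forward((p0, 2))  # cursor to "c" in p0
_, idx_a = p2.forward((p0, 0))  # cursor to the deleted "a"
print(p2.stmts[idx_c], idx_a)
```

Forwarding the cursor to `"c"` follows it through both steps, while the cursor to the deleted statement ends up invalid.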
335+
336+
Note that a forwarding function can return an invalid cursor, and this is expected. For example, when a rewrite removes a statement, cursors pointing to that statement are forwarded to an invalid cursor.
337+
338+
### Implicit and Explicit Cursor Forwarding in Scheduling Primitives
339+
340+
Scheduling primitives, such as `lift_alloc` and `expand_dim`, operate on a target procedure, which is passed as the first argument. When passing cursors to these primitives, the cursors should be forwarded to point to the target procedure.
341+
342+
Consider the following example:
343+
```python
c = p0.find("x : _")
p1 = lift_alloc(p0, c)
p2 = expand_dim(p1, p1.forward(c), ...)
```
348+
349+
In the call to `expand_dim`, the cursor `c` is explicitly forwarded to `p1` using `p1.forward(c)`. This is necessary because `c` was originally obtained from `p0`, and it needs to be adjusted to point to the correct location in `p1`.
350+
351+
However, the scheduling primitives support *implicit forwarding* of cursors. This means that all the cursors passed to these primitives will be automatically forwarded to point to the first argument procedure. The above code can be simplified as follows:
352+
353+
```python
c = p0.find("x : _")
p1 = lift_alloc(p0, c)
p2 = expand_dim(p1, c, ...)  # implicit forwarding!
```
358+
359+
In this case, `c` is implicitly forwarded to `p1` within the `expand_dim` primitive, eliminating the need for explicit forwarding.
360+
361+
#### Limitations of Implicit Forwarding
362+
363+
It is important to note that implicit forwarding does not work when navigation is applied to a forwarded cursor. Consider the following example:
364+
365+
```python
c = p0.find("x : _")
p1 = lift_alloc(p0, c)
p2 = reorder_scope(p1, p1.forward(c).next(), ...)
```
370+
371+
In this code, the navigation `.next()` is applied to the forwarded cursor `p1.forward(c)`. Attempting to change `p1.forward(c).next()` to `p1.forward(c.next())` will result in incorrect behavior. This is because navigation and forwarding are *not commutative*.
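A toy model makes the non-commutativity concrete. Here statements are list entries, the rewrite moves `x` to the front of the block (a loose stand-in for lifting an allocation), and the helper names are hypothetical rather than part of Exo.

```python
p0 = ["a", "x", "b"]
p1 = ["x", "a", "b"]  # the rewrite moved "x" to the front


def forward(i):
    """Map an index in p0 to the index of the same statement in p1."""
    return p1.index(p0[i])


def next_stmt(proc, i):
    """Index of the following statement, or None if there is none."""
    return i + 1 if i + 1 < len(proc) else None


def forward_then_next(i):
    return next_stmt(p1, forward(i))


def next_then_forward(i):
    j = next_stmt(p0, i)
    return None if j is None else forward(j)


x = p0.index("x")
print(p1[forward_then_next(x)], p1[next_then_forward(x)])
```

Navigating after forwarding lands on the statement that follows `x` in the new procedure, whereas forwarding after navigating tracks the statement that followed `x` in the old one. The two results differ whenever a rewrite changes what surrounds the cursor.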
372+
