Hey everyone! I've been working on a rewrite of the bulk of luminal to be search-based. Most of the time has been spent on developing an IR flexible enough to represent all desired kernels and thinking about how to search through them.
The thinking is that Luminal is currently about 50% heuristic pattern matchers and 50% greedy elementwise fusion. All of it is too brittle and too handwritten. To date no one (with the possible exception of tinygrad) has built a true search-based compiler. Mirage came closest, but they had to heavily constrain their search space to find anything interesting. I think the current state of Luminal is a great jumping-off point for building a proper search-based ML framework. So the goals for "Luminal 2.0" (not really 2.0, just a name I've been using for the major rewrite) are:
- Write a good codegen to go IR -> kernels
- Write the IR in egglog to search over it
- Write a profiling function that feeds dummy data through candidate kernels to measure their perf (see the sketch after this list)
- Run a search to get a single optimized transformer layer (this will take a long time to run; the point is to prove it can be done, not to do it fast. The goal is to "discover" flash attention.)
- Write a heuristic for byte-movement minimization to approximate performance before running the profiler function
- Add MCTS to egglog for tree search instead of saturation search
- Run a search to get an optimized full transformer model
- Rewrite the unit tests
- Add a Python frontend library
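For the profiling step, here's a minimal sketch of the kind of harness I have in mind, assuming a compiled kernel can be wrapped as a plain callable over buffers. `profile_kernel` and its signature are hypothetical, not anything that exists in the codebase yet:
```rust
use std::time::Instant;

/// Hypothetical harness: time a compiled kernel over dummy data and
/// return average seconds per run. The real interface in luminal_2
/// may look quite different.
fn profile_kernel<F: FnMut(&[f32], &mut [f32])>(
    mut kernel: F,
    input_len: usize,
    output_len: usize,
    iters: u32,
) -> f64 {
    // Dummy data: the contents don't matter for measuring perf.
    let input = vec![1.0f32; input_len];
    let mut output = vec![0.0f32; output_len];
    kernel(&input, &mut output); // warmup run
    let start = Instant::now();
    for _ in 0..iters {
        kernel(&input, &mut output);
    }
    start.elapsed().as_secs_f64() / iters as f64
}
```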
These will likely take a while, but I'm increasingly convinced this is the right direction, and it's a direction only accessible to Luminal because of its design (few primops, static graph, etc.).
Currently I'm starting to add experiments on the next branch under the luminal_2 crate.
Here's a quick look at what I think the IR will be:
```
; To do Tensor(4).exp()
(4)[($0, 1)] -> 1 {
exp($0)
}
```
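Under my reading of the block syntax (spelled out below), this first example is just an elementwise loop. As a hedged Rust equivalent:
```rust
// My reading of the exp IR above: one block of size 4, with input $0
// advancing by stride 1 and the output advancing by stride 1.
fn exp_4(input: &[f32; 4], output: &mut [f32; 4]) {
    for i in 0..4 {
        output[i] = input[i].exp();
    }
}
```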
```
; To do a standard matmul
(8)[($0, 8), ($1, 0)] -> 8 {
(8)[($0, 0), ($1, 1)] -> 1 {
(8)[($0, 1), ($1, 8)] -> 1 {
mul($0, $1)
}
sum_reduce(~0, 8, 1)
}
}
```
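To sanity-check my reading of the strides: treating $0 as A and $1 as B, both 8x8 row-major, I believe the block above corresponds to this plain loop nest (my decoding, not an official semantics):
```rust
// Outer block (8): i; A advances by 8 per step, B by 0, output by 8.
// Middle block (8): j; A advances by 0, B by 1, output by 1.
// Inner block (8): k; A advances by 1, B by 8; sum_reduce accumulates.
fn matmul_8x8(a: &[f32; 64], b: &[f32; 64], c: &mut [f32; 64]) {
    for i in 0..8 {
        for j in 0..8 {
            let mut acc = 0.0;
            for k in 0..8 {
                acc += a[i * 8 + k] * b[k * 8 + j];
            }
            c[i * 8 + j] = acc;
        }
    }
}
```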
```
; Do a tiled matmul
(4)[($0, 16), ($1, 0)] -> 16 {
(4)[($0, 0), ($1, 2)] -> 2 {
; Do a stack of 4 2x2 matmuls (saved to intermediate structure (4, 2, 2) (4, 2, 1))
; The idea is that this outer loop (4) gets accumulated over in a (2,2) matrix.
(4)[($0, 2), ($1, 16)] -> 4 {
(2)[($0, 8), ($1, 0)] -> 2 {
(2)[($0, 0), ($1, 1)] -> 1 {
(2)[($0, 1), ($1, 8)] -> 1 {
mul($0, $1)
}
sum_reduce(~0, 2, 1)
}
}
}
; Sum reduce the stack into one 2x2
(2)[(~0, 2)] -> 8 {
(2)[($0, 1)] -> 1 {
sum_reduce(~0, 4, 4)
}
}
}
}
```
Essentially, the idea behind this is that we have blocks with the syntax:
```
(block_size)[(input_a, input_a_stride), ...] -> output_stride {
    block_body
}
```
This lets us express very flexible things like tiled matmuls, various stride-interleaving structures, etc., all while representing them fairly consistently and without requiring an overly complex codegen pass.
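For concreteness, here's one hypothetical way this grammar could map onto a Rust data type. To be clear, this is my own sketch of the structure, not the actual luminal_2 IR types:
```rust
/// $N is the Nth input to the enclosing block; ~N is an
/// accumulator/intermediate, as in sum_reduce(~0, ...).
enum Input {
    Arg(usize),
    Acc(usize),
}

/// One hypothetical encoding of the block IR sketched above.
enum Expr {
    /// (size)[(input, stride), ...] -> output_stride { body }
    Block {
        size: usize,
        inputs: Vec<(Input, usize)>,
        output_stride: usize,
        body: Vec<Expr>,
    },
    /// Elementwise ops appearing in the examples.
    Exp(Input),
    Mul(Input, Input),
    /// sum_reduce(acc, size, stride) as written in the examples.
    SumReduce { acc: Input, size: usize, stride: usize },
}
```
Under this encoding, the Tensor(4).exp() example would be `Expr::Block { size: 4, inputs: vec![(Input::Arg(0), 1)], output_stride: 1, body: vec![Expr::Exp(Input::Arg(0))] }`.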
I'll do a longer writeup with a more fleshed-out spec once the codegen and egglog search functions are up and running.
I'd love to hear any thoughts!