Conversation

@mtsokol (Member) commented Dec 21, 2025

Hi @willow-ahrens,

I'm in the middle of reviewing #251, and I'm working on the dense-levels kernel for the matmul operation.

The first thing I'd like to clear up is the mismatch between the matmul and sddmm kernels generated by the optimizer now and before the refactor.

Unfortunately we didn't test this thoroughly enough (e.g. we only tested the output of running a kernel instead of the actual generated kernel, even where comparing kernels would have made sense).

So first, for A @ B, the raw logic is:

A = Table(BufferizedNDArray(shape=(np.int64(2), np.int64(3))), ["i", "i_2"])
A_2 = Table(
    BufferizedNDArray(shape=(np.int64(3), np.int64(2))), ["i_3", "i_4"]
)
A_3 = Reorder(
    MapJoin(
        mul,
        (
            Reorder(
                Relabel(
                    Reorder(Relabel(A, ["i_5", "i_6"]), ["i_5", "i_6", "i_7"]),
                    ["i_8", "i_9", "i_10"],
                ),
                ["i_8", "i_9"],
            ),
            Reorder(
                Relabel(
                    Reorder(
                        Relabel(A_2, ["i_11", "i_12"]),
                        ["i_13", "i_11", "i_12"],
                    ),
                    ["i_14", "i_9", "i_15"],
                ),
                ["i_9", "i_15"],
            ),
        ),
    ),
    ["i_8", "i_9", "i_15"],
)
A_4 = Aggregate(add, 0.0, Relabel(A_3, ["i_16", "i_17", "i_18"]), ["i_17"])

A_5 = A_4
return ("A_5",)

After logic optimization we get:

A = Reorder(A_11, ["i", "i_2"])
A_2 = Reorder(A_12, ["i_3", "i_4"])
A_3 = Aggregate(
    overwrite,
    0,
    Reorder(
        MapJoin(
            mul,
            ("Relabel(A, ['i_8', 'i_9'])", "Relabel(A_2, ['i_9', 'i_15'])"),
        ),
        ["i_8", "i_9", "i_15"],
    ),
    [],
)
A0 = Aggregate(
    add,
    0.0,
    Reorder(Relabel(A_3, ["i_16", "i_17", "i_18"]), ["i_16", "i_17", "i_18"]),
    ["i_17"],
)
A_4 = Reorder(A0, ["i_16", "i_18"])
A_5 = Reorder(A_4, ["i_16", "i_18"])
return ("A_5",)

which results in a separate loop(loop(...)) nest for each input table that only performs an overwrite.

Can we get rid of the Aggregate with overwrite?
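
For illustration, a minimal sketch of the kind of simplification I have in mind (the attribute names op/arg/idxs are my assumption about the Logic IR node shape, not its actual API, and the node classes are assumed to be in scope): an Aggregate whose operator is overwrite and whose reduction index list is empty reduces nothing, so it could simply be replaced by its argument.

def elide_trivial_overwrite(node):
    # Sketch only: assumes the Logic IR classes/operators (Aggregate, overwrite)
    # are imported, and that Aggregate exposes `op`, `arg`, and `idxs`.
    # Aggregate(overwrite, init, arg, []) performs no reduction, so the extra
    # materialization (and its copy loops) can be dropped in favor of `arg`.
    if isinstance(node, Aggregate) and node.op == overwrite and not node.idxs:
        return node.arg
    return node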

Here is an expandable section with the notation prgm from NotationCompiler:

Details
def main(#A_11#7: BufferizedNDArray_i64_shape_i64_i64_strides_i64_i64, #A_12#11: BufferizedNDArray_i64_shape_i64_i64_strides_i64_i64, #A#15: BufferizedNDArray_i64_shape_i64_i64_strides_i64_i64, #A_2#19: BufferizedNDArray_i64_shape_i64_i64_strides_i64_i64, #A_3#23: BufferizedNDArray_i64_shape_i64_i64_i64_strides_i64_i64_i64, ##A#0#28: BufferizedNDArray_f64_shape_i64_i64_strides_i64_i64, #A_4#32: BufferizedNDArray_f64_shape_i64_i64_strides_i64_i64, #A_5#36: BufferizedNDArray_f64_shape_i64_i64_strides_i64_i64) -> tuple(BufferizedNDArray_f64_shape_i64_i64_strides_i64_i64):
    #_A_11#8: BufferizedNDArray_i64_shape_i64_i64_strides_i64_i64 = unpack(#A_11#7)
    #A_11_dim_0#9: int64 = #_A_11#8.shape[0]
    #A_11_dim_1#10: int64 = #_A_11#8.shape[1]
    #_A_12#12: BufferizedNDArray_i64_shape_i64_i64_strides_i64_i64 = unpack(#A_12#11)
    #A_12_dim_0#13: int64 = #_A_12#12.shape[0]
    #A_12_dim_1#14: int64 = #_A_12#12.shape[1]
    #_A#16: BufferizedNDArray_i64_shape_i64_i64_strides_i64_i64 = unpack(#A#15)
    #A_dim_0#17: int64 = #_A#16.shape[0]
    #A_dim_1#18: int64 = #_A#16.shape[1]
    #_A_2#20: BufferizedNDArray_i64_shape_i64_i64_strides_i64_i64 = unpack(#A_2#19)
    #A_2_dim_0#21: int64 = #_A_2#20.shape[0]
    #A_2_dim_1#22: int64 = #_A_2#20.shape[1]
    #_A_3#24: BufferizedNDArray_i64_shape_i64_i64_i64_strides_i64_i64_i64 = unpack(#A_3#23)
    #A_3_dim_0#25: int64 = #_A_3#24.shape[0]
    #A_3_dim_1#26: int64 = #_A_3#24.shape[1]
    #A_3_dim_2#27: int64 = #_A_3#24.shape[2]
    #_#A#0#29: BufferizedNDArray_f64_shape_i64_i64_strides_i64_i64 = unpack(##A#0#28)
    ##A#0_dim_0#30: int64 = #_#A#0#29.shape[0]
    ##A#0_dim_1#31: int64 = #_#A#0#29.shape[1]
    #_A_4#33: BufferizedNDArray_f64_shape_i64_i64_strides_i64_i64 = unpack(#A_4#32)
    #A_4_dim_0#34: int64 = #_A_4#33.shape[0]
    #A_4_dim_1#35: int64 = #_A_4#33.shape[1]
    #_A_5#37: BufferizedNDArray_f64_shape_i64_i64_strides_i64_i64 = unpack(#A_5#36)
    #A_5_dim_0#38: int64 = #_A_5#37.shape[0]
    #A_5_dim_1#39: int64 = #_A_5#37.shape[1]
    declare(#_A#16, 0, overwrite, [])
    loop(#i#40, Extent(0, #A_11_dim_0#9)):
        loop(#i_2#41, Extent(0, #A_11_dim_1#10)):
            increment(update(#_A#16, ['#i#40', '#i_2#41'], overwrite), unwrap(read(#_A_11#8, ['#i#40', '#i_2#41'])))
    freeze(#_A#16, overwrite)
    declare(#_A_2#20, 0, overwrite, [])
    loop(#i_3#42, Extent(0, #A_12_dim_0#13)):
        loop(#i_4#43, Extent(0, #A_12_dim_1#14)):
            increment(update(#_A_2#20, ['#i_3#42', '#i_4#43'], overwrite), unwrap(read(#_A_12#12, ['#i_3#42', '#i_4#43'])))
    freeze(#_A_2#20, overwrite)
    declare(#_A_3#24, 0, overwrite, [])
    loop(#i_8#44, Extent(0, #A_dim_0#17)):
        loop(#i_9#45, Extent(0, #A_2_dim_0#21)):
            loop(#i_15#46, Extent(0, #A_2_dim_1#22)):
                increment(update(#_A_3#24, ['#i_8#44', '#i_9#45', '#i_15#46'], overwrite), mul(unwrap(read(#_A#16, ['#i_8#44', '#i_9#45'])), unwrap(read(#_A_2#20, ['#i_9#45', '#i_15#46']))))
    freeze(#_A_3#24, overwrite)
    declare(#_#A#0#29, 0.0, add, [])
    loop(#i_16#47, Extent(0, #A_3_dim_0#25)):
        loop(#i_17#48, Extent(0, #A_3_dim_1#26)):
            loop(#i_18#49, Extent(0, #A_3_dim_2#27)):
                increment(update(#_#A#0#29, ['#i_16#47', '#i_18#49'], add), unwrap(read(#_A_3#24, ['#i_16#47', '#i_17#48', '#i_18#49'])))
    freeze(#_#A#0#29, add)
    declare(#_A_4#33, 0.0, overwrite, [])
    loop(#i_16#50, Extent(0, ##A#0_dim_0#30)):
        loop(#i_18#51, Extent(0, ##A#0_dim_1#31)):
            increment(update(#_A_4#33, ['#i_16#50', '#i_18#51'], overwrite), unwrap(read(#_#A#0#29, ['#i_16#50', '#i_18#51'])))
    freeze(#_A_4#33, overwrite)
    declare(#_A_5#37, 0.0, overwrite, [])
    loop(#i_16#52, Extent(0, #A_4_dim_0#34)):
        loop(#i_18#53, Extent(0, #A_4_dim_1#35)):
            increment(update(#_A_5#37, ['#i_16#52', '#i_18#53'], overwrite), unwrap(read(#_A_4#33, ['#i_16#52', '#i_18#53'])))
    freeze(#_A_5#37, overwrite)
    repack(#_A_11#8, #A_11#7)
    repack(#_A_12#12, #A_12#11)
    repack(#_A#16, #A#15)
    repack(#_A_2#20, #A_2#19)
    repack(#_A_3#24, #A_3#23)
    repack(#_#A#0#29, ##A#0#28)
    repack(#_A_4#33, #A_4#32)
    repack(#_A_5#37, #A_5#36)
    return make_tuple(#A_5#36)
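
For comparison, this is the computation the whole program boils down to; a hand-written sketch (my own reference, not generated output) of the single fused loop nest one would hope the optimizer to produce, without the intermediate overwrite copies:

import numpy as np

def matmul_reference(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # Single accumulation loop nest, corresponding to the add-Aggregate above;
    # no per-input copy loops.
    out = np.zeros((a.shape[0], b.shape[1]))
    for i in range(a.shape[0]):          # i_8 / i_16
        for j in range(a.shape[1]):      # i_9 / i_17 (contraction index)
            for k in range(b.shape[1]):  # i_15 / i_18
                out[i, k] += a[i, j] * b[j, k]
    return out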

Apart from this issue, in terms of debugging, how about passing an optional debug_ctx down the compilation stack, where passes/stages can "log" arbitrary information, such as the prgm before and after applying transformations?
This should make us more productive: no more hunting for the right place to add print statements when debugging a kernel.
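
To make that concrete, a minimal sketch of what I mean by debug_ctx (names and structure here are a strawman, not a concrete API proposal):

from dataclasses import dataclass, field


@dataclass
class DebugContext:
    # Strawman: each pass records the prgm/plan it produced under its own name,
    # so any intermediate stage can be inspected after compilation.
    stages: list = field(default_factory=list)

    def log(self, stage: str, prgm) -> None:
        self.stages.append((stage, prgm))

    def dump(self) -> None:
        for stage, prgm in self.stages:
            print(f"=== {stage} ===\n{prgm}\n")

Each pass would take an optional debug_ctx=None argument and, when one is given, call debug_ctx.log(...) before and after its transformation.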

@mtsokol marked this pull request as ready for review December 21, 2025 12:22
@willow-ahrens (Member)

I removed the optimization passes, so there are a lot of messy plans now. I'm adding them back in a separate PR, which should clear up the extra loop(loop(...)), etc.

After I add the optimizations back in, we'll add a few reference outputs so that we can keep tabs on any optimizer regressions.
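
A reference-output test could be as simple as comparing the printed optimized plan against a stored snapshot; a sketch, where optimize_logic and build_matmul_logic are hypothetical stand-ins for whatever entry points end up being exposed:

from pathlib import Path

def test_matmul_plan_matches_reference():
    # Hypothetical sketch: the snapshot file is regenerated deliberately
    # whenever the optimizer is intentionally changed, so unintended
    # regressions show up as a plain text diff.
    plan = optimize_logic(build_matmul_logic())
    reference = Path("tests/reference/matmul_plan.txt").read_text()
    assert str(plan) == reference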

@willow-ahrens (Member)

Regarding the debugging, I'm not sure I want to complicate any of the existing passes. Is there a lower-overhead way to instrument the compiler stages?

@mtsokol (Member, Author) commented Dec 22, 2025

Reference outputs SGTM. I added a separate pass by calling the class from optimize.py (together with push_fields from standardize.py; I guess the one from optimize.py can be removed), and with propagate_map_queries_backward I got this plan:

A0 = Table(0, [])
A1 = Aggregate(
    add,
    0.0,
    Reorder(
        MapJoin(
            overwrite,
            Reorder(A0, []),
            Reorder(
                MapJoin(
                    mul,
                    Reorder(Relabel(A2, ['i0', 'i1']), ['i0', 'i1']),
                    Reorder(Relabel(A3, ['i1', 'i2']), ['i1', 'i2']),
                ),
                ['i0', 'i1', 'i2'],
            ),
        ),
        ['i0', 'i1', 'i2'],
    ),
    ['i1'],
)
A4 = Reorder(A1, ['i0', 'i2'])
return ('A4',)

Are these double MapJoins intended? I think previously we had InitWrite here.

One thing: the init 0 in Aggregate(overwrite) gets transformed into a MapJoin, and 0 becomes one of its arguments. As it's a Literal it lacks a fields method, so in propagate_map_queries_backward we can wrap it in a Table with empty fields for Aggregate(op, init, args, ()) => MapJoin(op, (init, *args)).
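
A sketch of what I mean (the attribute names op/init/args/idxs are my guess at the IR shape, and the node classes are assumed to be in scope):

def aggregate_to_mapjoin(node):
    # Sketch of Aggregate(op, init, args, ()) => MapJoin(op, (init, *args)),
    # where `args` is assumed to be a tuple of argument nodes.
    # A Literal init has no `fields` method, so wrap it in a Table with empty
    # fields before it becomes a MapJoin argument.
    if isinstance(node, Aggregate) and node.idxs == ():
        init = node.init
        if isinstance(init, Literal):
            init = Table(init, ())
        return MapJoin(node.op, (init, *node.args))
    return node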


About debugging: I will think about it. Each time I get stuck I spend some time searching for where to print the plan (i.e. after which pass), so it could probably be instrumented (storing the plan after each pass so you can quickly fetch and print it) and documented. What is your way of developing a pass?

We could have proper log statements that are enabled with a flag or env var for printing plans.

@mtsokol (Member, Author) commented Dec 22, 2025

For more readable error logs in GitHub CI we could have "smoke tests": a few exhaustive tests that run first, after which the whole test suite is executed. That way, if something is broken in a PR, e.g. in Logic, you don't get 2000+ errors from test_interface.py.
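
A sketch of the mechanism (the marker name and test below are hypothetical): mark a handful of end-to-end tests and have CI run `pytest -m smoke` as a separate, fast job before the full suite.

import pytest

# Hypothetical example; the `smoke` marker would be registered in the pytest
# configuration. If a Logic pass is broken, this small job fails with a few
# readable errors instead of thousands from test_interface.py.
@pytest.mark.smoke
def test_matmul_end_to_end_smoke():
    ...  # placeholder for an actual end-to-end matmul check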

@willow-ahrens (Member) commented Dec 22, 2025

I think the right call is probably to add proper Python logging of plans, with verbosity levels, using the standard logging module. E.g. in c_codegen we have

logger = logging.getLogger(__name__)

and I think there's a way to access the log.
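
For reference, a minimal sketch of how that could be surfaced on the caller side (the logger name below is a placeholder, not the package's actual name):

import logging

# Enable DEBUG output for the compiler's loggers so that plans logged via
# logger.debug(...) show up on stderr during a run or a failing test.
logging.basicConfig(format="%(name)s:%(levelname)s: %(message)s")
logging.getLogger("finch").setLevel(logging.DEBUG)  # placeholder logger name

In tests, pytest's caplog fixture can capture the same records.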

@willow-ahrens (Member)

I want to fix optimize.py, but I'm working towards #277 first because it changes the Logic IR significantly.
