Skip to content

Honor mooncake_overlay in primitive inference#1170

Open
yebai wants to merge 3 commits into
mainfrom
fix/1169-overlay-primitive-callmeta
Open

Honor mooncake_overlay in primitive inference#1170
yebai wants to merge 3 commits into
mainfrom
fix/1169-overlay-primitive-callmeta

Conversation

@yebai
Copy link
Copy Markdown
Member

@yebai yebai commented May 7, 2026

MooncakeInterpreter previously routed every primitive call site through NativeInterpreter for CallMeta, but NativeInterpreter is overlay-blind: when a primitive has a @mooncake_overlay that changes its return type, the inferred type at the call site was wrong, and downstream dispatch compiled against the original type. Detect overlay matches via Method.external_mt and route them through @invoke instead, keeping the NativeInterpreter fast path for non-overlay primitives.

Fixes #1169. Alternative to #1168.

CI Summary — GitHub Actions

Documentation Preview

Mooncake.jl documentation for PR #1170 is available at:
https://chalk-lab.github.io/Mooncake.jl/previews/PR1170/

Performance

Performance Ratio:
Ratio of time to compute gradient and time to compute function.
Warning: results are very approximate! See here for more context.

┌───────────────────────┬──────────┬──────────┬─────────────┬─────────┬─────────────┬────────┐
│                 Label │   Primal │ Mooncake │ MooncakeFwd │  Zygote │ ReverseDiff │ Enzyme │
│                String │   String │   String │      String │  String │      String │ String │
├───────────────────────┼──────────┼──────────┼─────────────┼─────────┼─────────────┼────────┤
│              sum_1000 │ 180.0 ns │     1.56 │        1.56 │   0.722 │        3.56 │   7.01 │
│             _sum_1000 │   1.1 μs │      6.0 │        1.02 │  4610.0 │        37.2 │   1.06 │
│          sum_sin_1000 │  7.42 μs │     2.46 │        1.12 │    1.64 │        11.0 │   1.75 │
│         _sum_sin_1000 │  4.56 μs │     3.83 │        2.69 │   345.0 │        18.1 │   3.11 │
│              kron_sum │ 188.0 μs │     13.3 │        3.35 │    7.39 │       544.0 │   19.9 │
│         kron_view_sum │ 265.0 μs │     12.6 │        5.27 │    26.7 │       449.0 │   13.9 │
│ naive_map_sin_cos_exp │  2.29 μs │      2.7 │        1.51 │ missing │        8.16 │   2.09 │
│       map_sin_cos_exp │  2.16 μs │     3.42 │         1.6 │    1.52 │        7.33 │   2.75 │
│ broadcast_sin_cos_exp │  2.26 μs │     3.02 │        1.54 │    4.35 │        1.43 │   2.14 │
│            simple_mlp │ 357.0 μs │     5.02 │        2.83 │    2.27 │        9.95 │    3.1 │
│                gp_lml │ 164.0 μs │     11.9 │        2.62 │    5.69 │     missing │   5.65 │
│    large_single_block │ 471.0 ns │     5.19 │        1.94 │  4180.0 │        31.9 │   2.08 │
└───────────────────────┴──────────┴──────────┴─────────────┴─────────┴─────────────┴────────┘

@yebai yebai requested a review from sunxd3 May 7, 2026 22:28
`MooncakeInterpreter` previously routed every primitive call site through
`NativeInterpreter` for `CallMeta`, but `NativeInterpreter` is overlay-blind:
when a primitive has a `@mooncake_overlay` that changes its return type, the
inferred type at the call site was wrong and downstream dispatch compiled
against the original type. Detect overlay matches via `Method.external_mt`
and route them through `@invoke` instead, keeping the `NativeInterpreter`
fast path for non-overlay primitives. Fixes #1169.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@yebai yebai force-pushed the fix/1169-overlay-primitive-callmeta branch from 0eddb03 to 848c6a3 Compare May 7, 2026 23:33
@sunxd3
Copy link
Copy Markdown
Collaborator

sunxd3 commented May 8, 2026

On this issue, I slightly lean towards #1168 because indeed it's more general and the complexity diff w.r.t this PR is not too bad. A minor issue with the current code is the use of external_mt — not blocking, but less ideal in terms of interfacing with julia internals.

Comment thread src/interpreter/abstract_interpretation.jl Outdated
Comment thread src/interpreter/abstract_interpretation.jl Outdated
yebai and others added 2 commits May 8, 2026 14:32
Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com>
Signed-off-by: Hong Ge <3279477+yebai@users.noreply.github.com>
Both `any_matches_primitive` and `any_matches_overlay` need the same
pre-1.12 / 1.12+ unwrap of an `applicable` entry to a `MethodMatch`.
Use a single inline `match = VERSION < v"1.12-" ? app : app.match`
in each so the version skew sits in one place per function. Also
tighten the overlay check to compare directly against
`mooncake_method_table` rather than `external_mt !== nothing`, so an
unrelated downstream overlay table cannot trip the overlay path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -204,12 +219,19 @@ end

function any_matches_primitive(applicable, C, M, world)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that there is already any_matches_primitive defined, which accesses Julia internals in similar ways.

@sunxd3
Copy link
Copy Markdown
Collaborator

sunxd3 commented May 8, 2026

Thanks @yebai for the follow-up.

After some thought and investigation, I now lean toward this PR over #1168 for the scope of issue #1169. Use of external_mt, IMO, is fine and a small price to pay; the alternatives seem worse as far as I can tell.

Before merging, there is one particular case I want to think through. Maybe it is not that important, but I would like to be clear.


Comparing on four versions of Mooncake:

pre-#1115                 4e2df06c0  c7c30f259^
main                      b79d67e1e
#1168                     ac5a35f72
#1170                     a0f11465e

on

using Mooncake

struct A end
struct B end

helper(::A) = A()
Mooncake.@mooncake_overlay helper(::A) = B()

primitive_wrapper(x::A) = helper(x)
Mooncake.@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{
    typeof(primitive_wrapper), A
}

caller() = primitive_wrapper(A())
caller_arg(x::A) = primitive_wrapper(x)
interp = Mooncake.MooncakeInterpreter(Mooncake.DefaultCtx, Mooncake.ReverseMode)

for sig in (
    Tuple{typeof(primitive_wrapper), A},
    Tuple{typeof(caller)},
    Tuple{typeof(caller_arg), A},
)
    ir, rt = Base.code_ircode_by_type(sig; interp)[1]
    println(sig, " => ", rt)
end

Results:

Version primitive_wrapper(::A) caller() caller_arg(::A)
pre-#1115 4e2df06c0 B B B
#1168 ac5a35f72 B B B
main b79d67e1e B A A
#1170 a0f11465e B A A

This result indicates that #1115 introduced a behavior change, and #1168 reverted it. @AstitvaAggarwal and I had an offline chat and we'll dig deeper on this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

@mooncake_overlay usage for or within primitives are silently ignored

2 participants