Honor mooncake_overlay in primitive inference#1170
Conversation
`MooncakeInterpreter` previously routed every primitive call site through `NativeInterpreter` for `CallMeta`, but `NativeInterpreter` is overlay-blind: when a primitive has a `@mooncake_overlay` that changes its return type, the inferred type at the call site was wrong and downstream dispatch compiled against the original type. Detect overlay matches via `Method.external_mt` and route them through `@invoke` instead, keeping the `NativeInterpreter` fast path for non-overlay primitives. Fixes #1169. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
0eddb03 to
848c6a3
Compare
|
On this issue, I slightly lean towards #1168 because indeed it's more general and the complexity diff w.r.t this PR is not too bad. A minor issue with the current code is the use of |
Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Signed-off-by: Hong Ge <3279477+yebai@users.noreply.github.com>
Both `any_matches_primitive` and `any_matches_overlay` need the same pre-1.12 / 1.12+ unwrap of an `applicable` entry to a `MethodMatch`. Use a single inline `match = VERSION < v"1.12-" ? app : app.match` in each so the version skew sits in one place per function. Also tighten the overlay check to compare directly against `mooncake_method_table` rather than `external_mt !== nothing`, so an unrelated downstream overlay table cannot trip the overlay path. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
| @@ -204,12 +219,19 @@ end | |||
|
|
|||
| function any_matches_primitive(applicable, C, M, world) | |||
There was a problem hiding this comment.
Note that there is already any_matches_primitive defined, which accesses Julia internals in similar ways.
|
Thanks @yebai for the follow-up. After some thought and investigation, I now lean toward this PR over #1168 for the scope of issue #1169. Use of Before merging, there is one particular case I want to think through. Maybe it is not that important, but I would like to be clear. Comparing on four versions of Mooncake: on using Mooncake
struct A end
struct B end
helper(::A) = A()
Mooncake.@mooncake_overlay helper(::A) = B()
primitive_wrapper(x::A) = helper(x)
Mooncake.@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{
typeof(primitive_wrapper), A
}
caller() = primitive_wrapper(A())
caller_arg(x::A) = primitive_wrapper(x)interp = Mooncake.MooncakeInterpreter(Mooncake.DefaultCtx, Mooncake.ReverseMode)
for sig in (
Tuple{typeof(primitive_wrapper), A},
Tuple{typeof(caller)},
Tuple{typeof(caller_arg), A},
)
ir, rt = Base.code_ircode_by_type(sig; interp)[1]
println(sig, " => ", rt)
endResults:
This result indicates that #1115 introduced a behavior change, and #1168 reverted it. @AstitvaAggarwal and I had an offline chat and we'll dig deeper on this. |
MooncakeInterpreterpreviously routed every primitive call site throughNativeInterpreterforCallMeta, butNativeInterpreteris overlay-blind: when a primitive has a@mooncake_overlaythat changes its return type, the inferred type at the call site was wrong, and downstream dispatch compiled against the original type. Detect overlay matches viaMethod.external_mtand route them through@invokeinstead, keeping theNativeInterpreterfast path for non-overlay primitives.Fixes #1169. Alternative to #1168.
CI Summary — GitHub Actions
Documentation Preview
Mooncake.jl documentation for PR #1170 is available at:
https://chalk-lab.github.io/Mooncake.jl/previews/PR1170/
Performance
Performance Ratio:
Ratio of time to compute gradient and time to compute function.
Warning: results are very approximate! See here for more context.