
Fix Type infering for overlayed methods #1168

Open
AstitvaAggarwal wants to merge 4 commits into main from fix/overlay-type-change-primitive-inference

Conversation

Member

@AstitvaAggarwal AstitvaAggarwal commented May 6, 2026

This PR fixes incorrect gradients when @mooncake_overlay changes a primitive's return type:

When a primitive has a @mooncake_overlay that changes its return type, downstream call sites in the same function had the wrong return type inferred. The root cause was that the primitive inference path used NativeInterpreter directly, which doesn't know about Mooncake's overlay method table.
So even though the overlay exists and runs at runtime, type inference would see the original method body and infer the original return type. Anything downstream that dispatched on that return type would then compile against the wrong specialisation.
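To make the failure mode concrete, a hypothetical repro looks roughly like the following (all names here are invented for illustration; they are not from the PR):

```julia
widen(x::Float64) = x            # original primitive: returns Float64
# @mooncake_overlay widen(x::Float64) = Float32(x)   # overlay: now returns Float32

g(x) = widen(x) + 1              # downstream call site in the same function

# Inference through a plain NativeInterpreter only sees the original body,
# so it infers widen(::Float64) -> Float64 even though the overlay returns
# Float32 at runtime, and g is compiled against the wrong specialisation.
```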

The fix introduces OverlayAwareNativeInterpreter, which infers every primitive using NativeInterpreter (with the standard Julia method table) but redirects method lookup to Mooncake's overlay table when it encounters an overlayed method. Using this in the primitive path means inference sees the overlay body and infers the correct return type.
The OverlayAwareNativeInterpreter wrapper has no abstract_call_gf_by_type override, so calls inside primitive bodies still go through the default path and we don't recurse back into the MooncakeInterpreter machinery; the base-case guarantee from PR #1115 is preserved.
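For reference, a wrapper along these lines can be sketched against the Core.Compiler AbstractInterpreter interface as follows. This is a minimal sketch, not Mooncake's actual implementation: the exact set of delegated methods and the name mooncake_method_table are assumptions.

```julia
const CC = Core.Compiler

struct OverlayAwareNativeInterpreter <: CC.AbstractInterpreter
    interp::CC.NativeInterpreter
    inf_cache::Vector{CC.InferenceResult}
end

# Delegate the standard AbstractInterpreter interface to the wrapped
# NativeInterpreter, keeping a fresh, local inference cache.
CC.InferenceParams(i::OverlayAwareNativeInterpreter) = CC.InferenceParams(i.interp)
CC.OptimizationParams(i::OverlayAwareNativeInterpreter) = CC.OptimizationParams(i.interp)
CC.get_inference_cache(i::OverlayAwareNativeInterpreter) = i.inf_cache
CC.get_world_counter(i::OverlayAwareNativeInterpreter) = CC.get_world_counter(i.interp)
CC.cache_owner(::OverlayAwareNativeInterpreter) = nothing  # never commit to the global cache

# The key difference: consult Mooncake's overlay table during method lookup.
# (`mooncake_method_table` is a placeholder for the table that @mooncake_overlay
# populates; the real name inside Mooncake may differ.)
CC.method_table(i::OverlayAwareNativeInterpreter) =
    CC.OverlayMethodTable(CC.get_world_counter(i.interp), mooncake_method_table)
```

Note there is deliberately no abstract_call_gf_by_type method here, so calls inside primitive bodies fall through to the default inference path.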

Also added a regression test: a primitive overlay switches _OverlayTypeA to _OverlayTypeB, and a downstream function has separate rrule!!s for each type giving gradients of 1x vs 2x. Without the fix the gradient is 1.0; with it, 2.0.
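The regression test follows roughly this shape (a hedged sketch; the actual test in the PR may name things differently):

```julia
struct _OverlayTypeA end
struct _OverlayTypeB end

make_marker(x) = _OverlayTypeA()                      # primitive's original return type
# @mooncake_overlay make_marker(x) = _OverlayTypeB() # overlay switches the type

# Separate rrule!!s for each marker type give gradients of 1x vs 2x; in
# primal terms this corresponds to:
scale(::_OverlayTypeA, x) = x          # rrule!! gradient: 1.0
scale(::_OverlayTypeB, x) = 2x         # rrule!! gradient: 2.0

f(x) = scale(make_marker(x), x)
# Without the fix, inference sees _OverlayTypeA and the gradient comes out
# as 1.0; with overlay-aware inference the call site dispatches on
# _OverlayTypeB and the gradient is 2.0.
```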

CI Summary — GitHub Actions

Documentation Preview

Mooncake.jl documentation for PR #1168 is available at:
https://chalk-lab.github.io/Mooncake.jl/previews/PR1168/

Performance

Performance ratio: the time taken to compute the gradient divided by the time taken to compute the function.
Warning: results are very approximate!

┌───────────────────────┬──────────┬──────────┬─────────────┬─────────┬─────────────┬────────┐
│                 Label │   Primal │ Mooncake │ MooncakeFwd │  Zygote │ ReverseDiff │ Enzyme │
│                String │   String │   String │      String │  String │      String │ String │
├───────────────────────┼──────────┼──────────┼─────────────┼─────────┼─────────────┼────────┤
│              sum_1000 │ 181.0 ns │     1.55 │        1.61 │   0.718 │         3.6 │   7.02 │
│             _sum_1000 │  1.07 μs │     6.41 │        1.05 │  4290.0 │        43.0 │   1.08 │
│          sum_sin_1000 │  7.43 μs │     2.46 │        1.11 │    1.63 │        10.9 │   1.75 │
│         _sum_sin_1000 │  4.79 μs │     3.75 │        2.55 │   347.0 │        16.9 │   2.99 │
│              kron_sum │ 189.0 μs │     13.2 │        3.33 │    7.33 │       468.0 │   22.8 │
│         kron_view_sum │ 261.0 μs │     12.8 │        5.31 │    28.3 │       454.0 │   14.0 │
│ naive_map_sin_cos_exp │  2.26 μs │     2.79 │        1.53 │ missing │        8.32 │   2.14 │
│       map_sin_cos_exp │  2.16 μs │     3.37 │        1.67 │    1.59 │        7.32 │   2.72 │
│ broadcast_sin_cos_exp │   2.3 μs │     2.93 │        1.57 │     4.2 │        1.39 │   2.07 │
│            simple_mlp │ 457.0 μs │      3.8 │         2.2 │    1.64 │        7.13 │   2.41 │
│                gp_lml │ 165.0 μs │     11.6 │        2.61 │    4.79 │     missing │   5.55 │
│    large_single_block │ 471.0 ns │     5.27 │        1.93 │  4220.0 │        31.4 │   2.06 │
└───────────────────────┴──────────┴──────────┴─────────────┴─────────┴─────────────┴────────┘

AstitvaAggarwal force-pushed the fix/overlay-type-change-primitive-inference branch from 4ecd824 to ac5a35f on May 7, 2026 at 10:44
@ChrisRackauckas
Collaborator

Yes that seems like exactly the fix needed.

end

function OverlayAwareNativeInterpreter(world::UInt)
    OverlayAwareNativeInterpreter(CC.NativeInterpreter(world), CC.InferenceResult[])
end
Collaborator
inf_cache is allocated and needs some caution. Any potential issue here?

Member Author

@AstitvaAggarwal AstitvaAggarwal May 8, 2026


There shouldn't be an issue: OverlayAwareNativeInterpreter is constructed as a local variable when called and is never stored anywhere. inf_cache starts empty each time, so there are no stale results to return, and cache_owner returns nothing, so nothing is committed to the global cache.
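The lifetime argument above can be sketched as follows (a hypothetical call site; the function name and body are illustrative, not Mooncake's actual code):

```julia
function infer_primitive_return(sig::Type, world::UInt)
    # Fresh interpreter per call: inf_cache starts empty, so no stale results.
    interp = OverlayAwareNativeInterpreter(world)
    # ... run type inference for `sig` using `interp` ...
    # `interp` is a local; it and its inf_cache become garbage after this call,
    # and cache_owner(interp) === nothing means no results were committed to
    # the global code cache.
end
```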


Labels

wontfix This will not be worked on
