Fix type inference for overlayed methods #1168
Open
AstitvaAggarwal wants to merge 4 commits into
Conversation
Force-pushed from 4ecd824 to ac5a35f
Collaborator
Yes, that seems like exactly the fix needed.
sunxd3 reviewed May 8, 2026
end

function OverlayAwareNativeInterpreter(world::UInt)
    OverlayAwareNativeInterpreter(CC.NativeInterpreter(world), CC.InferenceResult[])
Collaborator
`inf_cache` is allocated and needs some caution — any potential issue here?
Member
Author
There shouldn't be an issue: `OverlayAwareNativeInterpreter` is constructed as a local variable when called and is never stored anywhere. `inf_cache` starts empty each time, so there are no stale results to return, and `cache_owner` returns `nothing`, so nothing is committed to the global cache.
This PR fixes incorrect gradients when `@mooncake_overlay` changes a primitive's return type.

When a primitive has a `@mooncake_overlay` that changes its return type, downstream call sites in the same function were getting the wrong type inferred. The root cause was that the primitive inference path used `NativeInterpreter` directly, which doesn't know about Mooncake's overlay method table. So even though the overlay exists and runs at runtime, the type inference step would see the original method body and infer the original return type. Anything that dispatched on that return type downstream would then compile against the wrong specialisation.
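As a minimal standalone sketch of the underlying Julia behaviour (not Mooncake's actual code — the table name `toy_overlay_table` and the function `original` are hypothetical stand-ins), an overlay placed in a `Base.Experimental` method table is invisible to ordinary inference, which consults only the standard method table:

```julia
using Base.Experimental: @MethodTable, @overlay

# A toy overlay table standing in for Mooncake's; name is made up.
@MethodTable toy_overlay_table

original(x) = 1.0                           # original method returns Float64
@overlay toy_overlay_table original(x) = 1  # overlay changes the return type to Int

# Plain inference only sees the standard method table, so it still
# reports the original return type, mirroring the bug described above:
Base.return_types(original, Tuple{Float64})  # -> [Float64], not [Int]
```

An interpreter that resolves calls through the overlay table instead would infer `Int` here, which is the behaviour the fix below restores for overlayed primitives.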
The fix introduces `OverlayAwareNativeInterpreter`, which infers every primitive using `NativeInterpreter` (with the standard Julia method table) and redirects inference to Mooncake's overlay table when it encounters an overlayed method. Using this in the primitive path means inference sees the overlay body and gets the correct return type.

The `OverlayAwareNativeInterpreter` wrapper has no `abstract_call_gf_by_type` override, so calls inside primitive bodies still go through the default path and we don't recurse back into `MooncakeInterpreter` machinery — the base-case guarantee from PR #1115 is preserved.

Also added a regression test: a primitive overlay switches `_OverlayTypeA` to `_OverlayTypeB`, and a downstream function has separate `rrule!!`s for each type, giving gradients of 1x vs 2x. Without the fix the gradient is 1.0; with it, 2.0.

CI Summary — GitHub Actions
Documentation Preview
Mooncake.jl documentation for PR #1168 is available at:
https://chalk-lab.github.io/Mooncake.jl/previews/PR1168/
Performance
Performance Ratio:
Ratio of the time taken to compute the gradient to the time taken to compute the function.
Warning: results are very approximate! See here for more context.