- 
                Notifications
    
You must be signed in to change notification settings  - Fork 3.7k
 
[Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(5) #18417
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
          Summary of ChangesHello @tlopex, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request integrates support for decomposed operators within the PyTorch frontend for Relax. It updates the expected Intermediate Representation (IR) in several test cases to accurately reflect how operators like  Highlights
 Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either  
 Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a  Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
  | 
    
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request updates tests for the PyTorch frontend to support decomposed operators by enabling run_ep_decomposition=True. The changes involve updating the expected Relax IR for several operators like stack, tile, fill, masked_fill, and type casting ops to match the output of PyTorch's decomposition pass. Most changes correctly reflect more efficient or accurate decompositions. However, I've pointed out an inconsistency in the decomposition of torch.stack for different axes and suggested a more canonical approach for better robustness and consistency.
| lv: R.Tensor((4, 3), dtype="float32") = R.concat((x, y), axis=0) | ||
| lv1: R.Tensor((2, 2, 3), dtype="float32") = R.reshape(lv, R.shape([2, 2, 3])) | ||
| gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv1,) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The decomposition of torch.stack with axis=0 into concat and reshape is valid for contiguous tensors, but it's not as canonical as using expand_dims and concat. The decomposition for axis=-1 in Expected3 uses expand_dims and concat, which is a more direct translation of the stack operation's semantics of inserting a new dimension. For consistency and clarity, it would be better to use the same expand_dims and concat approach for axis=0. This would make the decomposition logic more robust and easier to understand across different axes.
For example:
lv: R.Tensor((1, 2, 3), dtype="float32") = R.expand_dims(x, axis=0)
lv1: R.Tensor((1, 2, 3), dtype="float32") = R.expand_dims(y, axis=0)
lv2: R.Tensor((2, 2, 3), dtype="float32") = R.concat((lv, lv1), axis=0)
gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv2,)| lv: R.Tensor((2, 6), dtype="float32") = R.concat((x, y), axis=1) | ||
| lv1: R.Tensor((2, 2, 3), dtype="float32") = R.reshape(lv, R.shape([2, 2, 3])) | ||
| gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv1,) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to the axis=0 case, the decomposition of torch.stack with axis=1 into concat and reshape is less canonical than using expand_dims and concat. Using a consistent decomposition strategy across all axes would improve the robustness and readability of the translated IR.
For example:
lv: R.Tensor((2, 1, 3), dtype="float32") = R.expand_dims(x, axis=1)
lv1: R.Tensor((2, 1, 3), dtype="float32") = R.expand_dims(y, axis=1)
lv2: R.Tensor((2, 2, 3), dtype="float32") = R.concat((lv, lv1), axis=1)
gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv2,)
This pr fixes ops like
tile,transposeand so on.