Add meta device support for grouped mma #5472
Conversation
Review updated until commit 09c86f7

Changes walkthrough 📝

PR Reviewer Guide 🔍
Here are some key observations to aid the review process:

!test
Greptile Overview
Greptile Summary
Adds meta device support for grouped matrix multiplication operations in nvFuser.
Key changes:
- Implements a fast path in `GroupedMmaOp::evaluate()` that handles meta tensors without executing the actual computation
- Correctly computes output shapes for all three supported input dimension combinations: [2D, 2D] -> [g, m, n], [3D, 2D] -> [m, n], and [2D, 3D] -> [m, n]
- Properly handles rfactor dimensions via `unsqueeze` when needed
- Preserves the output dtype using `data_type_to_aten(out()->dtype())`
- Adds comprehensive test coverage with 28 parameterized test cases (2×2 + 6×2 + 2×6 combinations) testing various memory layouts, including transposed and permuted tensors
 
Confidence Score: 5/5
- This PR is safe to merge with minimal risk
 - The implementation correctly mirrors the non-meta execution path logic for output shape calculation, handles all edge cases (rfactor dimensions, dtype conversion), and includes thorough test coverage across multiple memory layout configurations. The code follows established patterns from other meta device implementations in the codebase.
 - No files require special attention
 
Important Files Changed
File Analysis
| Filename | Score | Overview | 
|---|---|---|
| csrc/ir/nodes.cpp | 5/5 | Added meta device fast path for GroupedMmaOp evaluation that correctly computes output shapes for all three supported input configurations | 
| tests/cpp/test_meta.cpp | 5/5 | Comprehensive parameterized tests covering all three grouped MMA input configurations with various memory layouts | 
Sequence Diagram
sequenceDiagram
    participant User
    participant ExpressionEvaluator
    participant GroupedMmaOp
    participant MetaPath
    participant CUDAPath
    
    User->>ExpressionEvaluator: evaluate(fusion_output)
    ExpressionEvaluator->>GroupedMmaOp: evaluate(inputs)
    
    GroupedMmaOp->>GroupedMmaOp: Check if inputs are tensors
    GroupedMmaOp->>GroupedMmaOp: Check if any tensor is_meta()
    
    alt Meta Device Path
        GroupedMmaOp->>MetaPath: Handle meta tensors
        MetaPath->>MetaPath: Get num_groups from offsets
        MetaPath->>MetaPath: Determine output shape based on input dims
        Note over MetaPath: Case 1: [2,2] -> [g,m,n]<br/>Case 2: [3,2] -> [m,n]<br/>Case 3: [2,3] -> [m,n]
        MetaPath->>MetaPath: Create meta tensor with at::empty()
        MetaPath->>MetaPath: Handle rfactor dimension if needed
        MetaPath-->>GroupedMmaOp: Return meta tensor
    else CUDA Device Path
        GroupedMmaOp->>CUDAPath: Handle CUDA tensors
        CUDAPath->>CUDAPath: Split inputs into groups
        CUDAPath->>CUDAPath: Allocate output with at::empty()
        CUDAPath->>CUDAPath: Perform matmul for each group
        CUDAPath->>CUDAPath: Convert dtype if needed
        CUDAPath->>CUDAPath: Handle rfactor dimension if needed
        CUDAPath-->>GroupedMmaOp: Return result tensor
    end
    
    GroupedMmaOp-->>ExpressionEvaluator: Return output tensor
    ExpressionEvaluator-->>User: Return evaluated result
    2 files reviewed, no comments
```cpp
auto options = mat1_meta_check.options()
                   .device(c10::Device(c10::kMeta))
                   .dtype(data_type_to_aten(out()->dtype()));
at::Tensor result = at::empty(result_sizes, options);
```
Do strides matter when you create a meta tensor? If not, why does test_meta.cpp check various permutations?
Excellent question. Yes, strides do matter, in a strange way. Meta tensors do have strides, and according to https://docs.pytorch.org/docs/stable/meta.html:
In some cases, not all device types (e.g., CPU and CUDA) have exactly the same output metadata for an operation; we typically prefer representing the CUDA behavior faithfully in this situation.
We can treat a meta tensor as a tool for shape and stride inference for CUDA tensors.
The strange part is that, when implementing a kernel for an op, people often only think about how to write that kernel for contiguous tensors, to simplify the problem. So in PyTorch, many ops begin by simply calling `.contiguous()` on all operands. In that case, the strides of the input tensors do not matter: whatever the input strides are, the inputs are converted to contiguous first.
For this case, I intentionally ignored the input stride to match the behavior of CUDA device.
Co-authored-by: Jingyue Wu <[email protected]>
Greptile Overview
Greptile Summary
This PR adds meta device support for GroupedMmaOp operations and corresponding comprehensive test coverage. The changes include:
- Minor refactoring: renamed `fusion_ptr` to `fusion` in existing scan tests for consistency
- Added a meta device fast path in `GroupedMmaOp::evaluate()` (csrc/ir/nodes.cpp) that handles meta tensors without actual computation
- Added extensive parameterized tests for three GroupedMma configurations:
  - 2D × 2D matrices → 3D output
  - 3D × 2D matrices → 2D output
  - 2D × 3D matrices → 2D output
- Tests verify all memory layout permutations (contiguous and transposed for 2D; all 6 permutations for 3D)
- Each test validates meta tensor properties: device type, dtype, sizes, and strides match reference CUDA execution

The implementation correctly handles shape inference for all three matrix multiplication patterns and properly propagates tensor metadata.
Confidence Score: 5/5
- This PR is safe to merge with minimal risk
 - The changes are well-structured test additions with a minor variable rename refactor. The new tests provide comprehensive coverage of different memory layouts for grouped matrix multiplication operations. The implementation follows existing patterns, uses proper parameterized testing, and validates all critical properties (device type, dtype, sizes, strides). No logical errors or potential runtime issues detected.
 - No files require special attention
 
Important Files Changed
File Analysis
| Filename | Score | Overview | 
|---|---|---|
| tests/cpp/test_meta.cpp | 5/5 | Refactored variable names from fusion_ptr to fusion and added comprehensive meta device tests for GroupedMmaOp with various memory layouts | 
Sequence Diagram
sequenceDiagram
    participant Test as Test Case
    participant Fusion as Fusion Object
    participant EE_CUDA as ExpressionEvaluator (CUDA)
    participant EE_Meta as ExpressionEvaluator (Meta)
    participant GroupedMma as GroupedMmaOp
    
    Test->>Fusion: Create fusion with grouped_mm
    Test->>Fusion: Add inputs (mat1, mat2, offsets)
    Test->>Fusion: Add output (result.tv)
    
    Test->>EE_CUDA: bind() real CUDA tensors
    Test->>EE_CUDA: evaluate() fusion output
    EE_CUDA->>GroupedMma: evaluate() with CUDA tensors
    GroupedMma-->>EE_CUDA: Return computed result
    EE_CUDA-->>Test: real_out (CUDA tensor)
    
    Test->>EE_Meta: bind() meta tensors (empty_strided)
    Test->>EE_Meta: evaluate() fusion output
    EE_Meta->>GroupedMma: evaluate() with meta tensors
    Note over GroupedMma: Fast path: detect meta device<br/>compute output shape/strides<br/>return empty meta tensor
    GroupedMma-->>EE_Meta: Return meta result
    EE_Meta-->>Test: meta_out (meta tensor)
    
    Test->>Test: Verify meta_out properties:<br/>- is_meta() == true<br/>- dtype matches<br/>- sizes match real_out<br/>- strides match real_out
    1 file reviewed, no comments
Co-authored-by: Jingyue Wu <[email protected]>
Greptile Overview
Greptile Summary
Added meta device support for grouped matrix multiplication operations (GroupedMmaOp). The implementation adds a fast path in GroupedMmaOp::evaluate() that handles meta tensors by computing output shapes without executing actual operations, enabling shape inference without GPU memory allocation.
Key Changes:
- Implementation (`csrc/ir/nodes.cpp`): Added a meta tensor fast path that computes output shapes for the three supported input configurations: (2D, 2D), (3D, 2D), and (2D, 3D)
- Tests (`tests/cpp/test_meta.cpp`):
  - Refactored existing tests to use cleaner variable naming (`fusion_ptr` → `fusion`)
  - Added comprehensive parameterized test suites covering all three input dimension combinations
  - Each test suite validates various memory layouts (contiguous, transposed, and all 6 permutations for 3D tensors)
  - Total of 2×2 + 6×2 + 2×6 = 28 new parameterized test cases

The implementation correctly handles shape inference for grouped matrix multiplications and validates that meta tensors maintain proper shapes, strides, and dtypes matching actual CUDA execution.
Confidence Score: 5/5
- This PR is safe to merge with minimal risk
 - The changes are well-structured with comprehensive test coverage. The implementation follows existing patterns in the codebase, adds proper error handling, and the logic for computing output shapes is straightforward and correct. The refactoring is minimal and low-risk (variable renaming only). All three input dimension combinations are thoroughly tested across multiple memory layouts.
 - No files require special attention
 
Important Files Changed
File Analysis
| Filename | Score | Overview | 
|---|---|---|
| tests/cpp/test_meta.cpp | 5/5 | Added comprehensive parameterized tests for grouped MMA operations with various memory layouts, refactored variable names for clarity | 
Sequence Diagram
sequenceDiagram
    participant Test as Test Framework
    participant Fusion as Fusion Object
    participant EE_CUDA as ExpressionEvaluator (CUDA)
    participant EE_Meta as ExpressionEvaluator (Meta)
    participant GroupedMma as GroupedMmaOp::evaluate()
    
    Test->>Fusion: Create fusion with grouped_mm op
    Fusion->>Fusion: addInput(mat1, mat2, offsets)
    Fusion->>Fusion: addOutput(result)
    
    Note over Test,EE_CUDA: CUDA Path (for reference)
    Test->>EE_CUDA: bind(inputs, CUDA tensors)
    EE_CUDA->>GroupedMma: evaluate(inputs)
    GroupedMma->>GroupedMma: Execute actual GEMM operation
    GroupedMma-->>EE_CUDA: Return result tensor
    EE_CUDA-->>Test: Return CUDA output
    
    Note over Test,EE_Meta: Meta Path (new feature)
    Test->>EE_Meta: bind(inputs, Meta tensors)
    EE_Meta->>GroupedMma: evaluate(inputs)
    GroupedMma->>GroupedMma: Check if any input is_meta()
    GroupedMma->>GroupedMma: Determine output shape based on input dims
    alt mat1=2D, mat2=2D
        GroupedMma->>GroupedMma: result_sizes = [num_groups, m, n]
    else mat1=3D, mat2=2D
        GroupedMma->>GroupedMma: result_sizes = [m, n]
    else mat1=2D, mat2=3D
        GroupedMma->>GroupedMma: result_sizes = [m, n]
    end
    GroupedMma->>GroupedMma: Create empty meta tensor
    GroupedMma-->>EE_Meta: Return meta tensor
    EE_Meta-->>Test: Return meta output
    
    Test->>Test: Verify meta tensor properties
    Test->>Test: Compare sizes/strides with CUDA output
    1 file reviewed, no comments
Greptile Overview
Greptile Summary
Added meta device support for GroupedMmaOp by implementing a fast path in the evaluate() method. When any input tensor is on the meta device, the code determines output shape based on input dimensions and creates an appropriately sized tensor without performing actual computation.
Key changes:
- Added an early-exit meta device path in `GroupedMmaOp::evaluate()` (csrc/ir/nodes.cpp:5685-5721)
- Handles three input configurations: [2D, 2D] → [g, m, n], [3D, 2D] → [m, n], [2D, 3D] → [m, n]
- Applies the rfactor device dimension unsqueeze when needed
- Added comprehensive parameterized tests covering 2×2 + 6×2 + 2×6 = 28 memory format combinations

Potential issue:
The meta path uses at::empty(), which creates contiguous tensors, while MatmulOp::evaluate() uses inferShapeOfOutput() to preserve strides. The tests explicitly verify stride matching, which may fail for non-contiguous inputs.
Confidence Score: 3/5
- PR has good test coverage but contains a logic issue that may cause stride-related test failures
- The implementation correctly handles shape computation for all three input dimension combinations and includes rfactor handling. However, the meta device path uses `at::empty()` instead of `inferShapeOfOutput()` to determine strides, which likely won't preserve non-contiguous memory layouts. Since the tests explicitly check stride equality, this will probably fail for transposed/permuted inputs.
- csrc/ir/nodes.cpp - stride handling in the meta device path needs to match the MatmulOp pattern
 
Important Files Changed
File Analysis
| Filename | Score | Overview | 
|---|---|---|
| csrc/ir/nodes.cpp | 3/5 | Added meta device fast path for GroupedMmaOp::evaluate(). Issue: uses at::empty() instead of inferShapeOfOutput() which may not preserve strides correctly. | 
| tests/cpp/test_meta.cpp | 5/5 | Added comprehensive parameterized tests for GroupedMmaOp with various memory formats (2D/3D tensors). Tests verify that meta device evaluation produces correct sizes and strides. | 
Sequence Diagram
sequenceDiagram
    participant Test as Test Code
    participant EE as ExpressionEvaluator
    participant GMO as GroupedMmaOp::evaluate()
    participant AT as ATen Library
    
    Test->>EE: bind meta tensors
    Test->>EE: evaluate(fusion->outputs())
    EE->>GMO: evaluate(ee, inputs)
    
    alt Meta Device Fast Path
        GMO->>GMO: Check if any input is meta
        GMO->>GMO: Determine output shape based on input dims
        Note over GMO: [2,2]->[g,m,n]<br/>[3,2]->[m,n]<br/>[2,3]->[m,n]
        GMO->>AT: at::empty(result_sizes, options)
        GMO->>GMO: Check rfactor_did_idx
        opt rfactor_did_idx != -1
            GMO->>AT: result.unsqueeze(rfactor_did_idx)
        end
        GMO-->>EE: return meta tensor
    else Regular CUDA Path
        GMO->>AT: Execute grouped matmul
        GMO->>GMO: Apply rfactor unsqueeze
        GMO-->>EE: return result tensor
    end
    
    EE-->>Test: tensor result
    Test->>Test: Verify sizes and strides match
    2 files reviewed, 1 comment
```cpp
auto options = mat1_meta.options()
                   .device(c10::Device(c10::kMeta))
                   .dtype(data_type_to_aten(out()->dtype()));
at::Tensor result = at::empty(result_sizes, options);
```
logic: using at::empty() creates a contiguous tensor that may not match expected strides. MatmulOp::evaluate at line 4831 uses inferShapeOfOutput(out(), ee) to get correct sizes/strides, then creates the output with at::detail::empty_strided_meta() or at::empty_strided(). The tests verify stride matching (line 246, 319, 392 in test_meta.cpp), so this may fail for non-contiguous layouts.
Greptile Overview
Greptile Summary
This PR adds comprehensive test coverage for grouped MMA meta device support, testing three different input dimension combinations (2Dx2D, 3Dx2D, 2Dx3D) with various memory layouts.
Key changes:
- Minor refactoring: renamed `fusion_ptr` to `fusion` in existing scan tests for consistency
- Added parameterized test infrastructure with `MemoryFormat2D` and `MemoryFormat3D` enums to test non-contiguous input layouts
- Created helper functions `createTensor2D` and `createTensor3D` to generate tensors with specific memory layouts (transposed, permuted)
- Implemented three test suites covering all grouped MMA input/output dimension combinations:
  - MetaTestGroupedMma2D2D: mat1=[m,k] × mat2=[k,n] → out=[g,m,n]
  - MetaTestGroupedMma3D2D: mat1=[g,m,k] × mat2=[k,n] → out=[m,n]
  - MetaTestGroupedMma2D3D: mat1=[m,k] × mat2=[g,k,n] → out=[m,n]
- Each test verifies that meta device evaluation produces matching sizes, strides, and dtypes compared to CUDA execution
 
The tests validate that the meta device path in GroupedMmaOp::evaluate (added in commit 83923b4) correctly handles non-contiguous input tensors and produces output tensors with correct stride information.
Confidence Score: 5/5
- This PR is safe to merge - adds only test coverage with no functional changes to production code
 - This PR exclusively adds test code to verify existing meta device support for grouped MMA operations. The tests are well-structured with parameterized testing covering multiple memory layout combinations. The minor refactoring (fusion_ptr → fusion) improves code consistency. No production code is modified, eliminating risk of introducing bugs.
 - No files require special attention
 
Important Files Changed
File Analysis
| Filename | Score | Overview | 
|---|---|---|
| tests/cpp/test_meta.cpp | 5/5 | Adds comprehensive parameterized tests for grouped MMA meta device support with various memory formats (2D/3D combinations with different strides). Also refactors existing scan tests to use consistent naming. | 
Sequence Diagram
sequenceDiagram
    participant Test as Test Suite
    participant Fusion as Fusion Graph
    participant EE_CUDA as ExpressionEvaluator (CUDA)
    participant EE_Meta as ExpressionEvaluator (Meta)
    participant GMM as GroupedMmaOp
    participant CUDA as CUDA Runtime
    
    Test->>Fusion: Create fusion with grouped_mm op
    Fusion->>Fusion: addInput(mat1, mat2, offsets)
    Fusion->>Fusion: addOutput(result)
    
    Note over Test,EE_CUDA: CUDA Path (Reference)
    Test->>Test: Create input tensors with specific memory formats
    Test->>EE_CUDA: bind(inputs, real_tensors)
    EE_CUDA->>GMM: evaluate(inputs)
    GMM->>CUDA: Execute grouped MM
    CUDA-->>GMM: Return result tensor
    GMM-->>EE_CUDA: Return PolymorphicValue
    EE_CUDA-->>Test: real_out (with strides)
    
    Note over Test,EE_Meta: Meta Path (Under Test)
    Test->>Test: Create meta tensors with at::empty_strided
    Test->>EE_Meta: bind(inputs, meta_tensors)
    EE_Meta->>GMM: evaluate(inputs)
    GMM->>GMM: Detect meta device inputs
    GMM->>GMM: Compute result_sizes based on input dims
    GMM->>GMM: Create result with at::empty(sizes, meta_options)
    GMM-->>EE_Meta: Return meta tensor
    EE_Meta-->>Test: meta_out
    
    Test->>Test: Verify: sizes match
    Test->>Test: Verify: strides match
    Test->>Test: Verify: dtype matches
    Test->>Test: Verify: is_meta() == true
    1 file reviewed, no comments
!test
No description provided.