Conversation

@jacobhinkle (Collaborator)

No description provided.

@github-actions

Description

  • Support epilogue inputs like bias and beta in CUTLASS kernel

  • Fix input/output ordering and dtype handling in Sm90 compute

  • Add proper argument passing for EVT nodes with scalar and tensor inputs

  • Implement Sm90AuxLoad and Sm90SrcFetch for bias tensor handling


Changes walkthrough 📝

Relevant files
Enhancement
codegen.cpp
Move and add fusion position helper functions                       

csrc/cutlass/codegen.cpp

  • Added fusionInputPosition and fusionOutputPosition helper functions
  • Moved helper functions from gemm.cpp to enable reuse
  • Includes fusion and IR headers for Val and Fusion access
  • +16/-2   
evt.cpp
Support epilogue inputs with proper EVT node arguments

csrc/cutlass/evt.cpp

  • Added getPointerCode to generate input/output pointer casting
  • Implemented makeAuxLoadNode for Sm90AuxLoad node creation
  • Enhanced argument handling with key-value pairs in EVT nodes
  • Added input validation for alpha, beta, and bias contiguity
  • +120/-42

gemm.cpp
Enable bias support in CUTLASS GEMM kernel

csrc/cutlass/gemm.cpp

  • Added bias tensor handling in kernel configuration
  • Introduced EpilogueTileShape and EpilogueScheduleType
  • Set ElementC based on bias dtype when present
  • Updated argument passing to include bias pointer
  • +50/-20

codegen.h
Declare fusion position utility functions

csrc/cutlass/codegen.h

  • Declared fusionInputPosition and fusionOutputPosition
  • Added Val forward declaration
  • Header now supports new helper functions
  • +9/-0

evt.h
Update EVT node to support multiple arguments

csrc/cutlass/evt.h

  • Replaced argument with arguments vector of key-value pairs
  • Supports multiple named arguments per EVT node
  • Enables proper scalar and tensor argument passing
  • +2/-2

cutlass.h
Add epilogue tiling parameters

csrc/scheduler/cutlass.h

  • Added epilogue_tile parameter for tiling control
  • Introduced epilogue_stages for circular buffering
  • Default tile size set to 64x64
  • +7/-0

Bug fix

cutlass_compiled_kernel.cpp
Remove fixed argument count assumption

csrc/runtime/cutlass_compiled_kernel.cpp

  • Removed outdated argument count validation
  • Generalized tensor argument handling
  • Prepares for dynamic input ordering
  • +0/-10

Tests

test_cutlass_scheduler.cpp
Add test for bias epilogue in CUTLASS

tests/cpp/test_cutlass_scheduler.cpp

  • Added test for bias + beta epilogue with ReLU
  • Created reference fusion for validation
  • Tests proper ordering and computation
  • Skips on unsupported GPU architectures
  • +154/-4

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Possible Issue

The function getPointerCode computes the index for fusion outputs as fusion_->inputs().size() + fusionOutputPosition(fusion_, tv), but it does not validate that the result lies within the bounds of the combined inputs-and-outputs vector. If the output position is wrong, or outputs are not ordered as assumed, the generated code could perform an out-of-bounds access.

    std::string getPointerCode(TensorView* tv) {
      int64_t index = -1;
      if (tv->isFusionInput()) {
        index = fusionInputPosition(fusion_, tv);
      } else if (tv->isFusionOutput()) {
        index = fusion_->inputs().size() + fusionOutputPosition(fusion_, tv);
      } else {
        NVF_THROW(
            "Cannot get pointer for TV ",
            tv->toString(),
            " which is not a fusion input or output");
      }
      return "static_cast<" + dtypeToCutlass(tv->dtype()) + "*>(inputs.at(" +
          std::to_string(index) + ").data_ptr)";
    }
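One possible mitigation, sketched here with a hypothetical FakeFusion stand-in rather than the PR's actual Fusion API: compute the flat argument index and assert it is in range before emitting the cast.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for the fusion's input and output lists.
struct FakeFusion {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Sketch: compute the flat argument index for an output and validate it
// against the combined inputs+outputs size before generating any code.
std::string pointerCodeForOutput(
    const FakeFusion& fusion,
    int64_t output_position,
    const std::string& dtype) {
  const int64_t index =
      static_cast<int64_t>(fusion.inputs.size()) + output_position;
  const int64_t total =
      static_cast<int64_t>(fusion.inputs.size() + fusion.outputs.size());
  // Guard against emitting an out-of-bounds access in the generated code.
  assert(index >= 0 && index < total && "argument index out of range");
  return "static_cast<" + dtype + "*>(inputs.at(" + std::to_string(index) +
      ").data_ptr)";
}
```

Validating at code-generation time turns a silent out-of-bounds read in the compiled kernel into an immediate, debuggable failure.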
Possible Issue

The code adds params.epilogue_tile dimensions to the KernelTraits struct, but uses hardcoded _64, _64 values for EpilogueTileShape. If params.epilogue_tile.m or .n is not 64, the epilogue tiling will be incorrect; the shape should be derived from the parameters to keep them consistent.

    using EpilogueTileShape = Shape<_64, _64>;
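Because CUTLASS tile shapes are compile-time types, honoring a runtime params.epilogue_tile would require dispatching the runtime value to a fixed set of instantiations. A minimal sketch of that dispatch pattern, using a plain template in place of cute::Shape (all names here are illustrative, not the PR's code):

```cpp
#include <stdexcept>
#include <utility>

// Stand-in for a compile-time epilogue tile shape such as
// cute::Shape<cute::Int<M>, cute::Int<N>>.
template <int M, int N>
struct EpilogueTile {
  static constexpr int m = M;
  static constexpr int n = N;
};

// Dispatch a runtime (m, n) request to one of the supported
// compile-time tile shapes; reject anything else.
template <typename Fn>
auto dispatchEpilogueTile(int m, int n, Fn&& fn) {
  if (m == 64 && n == 64) {
    return fn(EpilogueTile<64, 64>{});
  }
  if (m == 64 && n == 32) {
    return fn(EpilogueTile<64, 32>{});
  }
  throw std::runtime_error("unsupported epilogue tile shape");
}
```

The callback receives the chosen shape as a type, so the rest of the kernel traits can be instantiated from it; unsupported sizes fail loudly instead of silently tiling wrong.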
Performance Issue

The check_input lambda validates the contiguity of the alpha, beta, and bias tensors. Note that the lambda is defined once and invoked three times, and these checks run during code generation rather than at kernel runtime, so their overhead is negligible; the more useful review question is whether the contiguity requirement is correct for all three inputs.

    auto check_input = [](TensorView* inp) {
      if (inp == nullptr) {
        // Allow null
        return;
      }
      // Check that input is contiguous
      const std::vector<std::optional<bool>>& contig = inp->getContiguity();
      NVF_ERROR(
          std::all_of(
              contig.begin(),
              contig.end(),
              [](const std::optional<bool>& c) {
                return !c.has_value() || c.value();
              }),
          "Expected all inputs to ScaledMmaOp to be contiguous but found ",
          inp->toString());
    };
    check_input(alpha);
    check_input(beta);
    check_input(bias);
    

jacobhinkle added a commit that referenced this pull request Oct 30, 2025

Previously, we supported a single `TensorView*` argument for each EVT node. This PR instead allows an argument list, provided as a simple list of key-value string pairs. This added flexibility is needed to support EVT nodes that require multiple parameters.

This is needed for #5441 and #5440
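The key-value argument list described in the commit message might be modeled roughly as follows; this is a sketch of the idea, not the actual definition in csrc/cutlass/evt.h.

```cpp
#include <string>
#include <utility>
#include <vector>

// Sketch: an EVT node carrying named arguments as key-value string
// pairs, rather than a single TensorView* argument.
struct EvtNode {
  std::string name;
  std::vector<std::pair<std::string, std::string>> arguments;
};

// Render the arguments as a brace-enclosed list of "key: value" entries,
// e.g. "{beta: beta_ptr, bias: bias_ptr}".
std::string renderArguments(const EvtNode& node) {
  std::string out = "{";
  for (size_t i = 0; i < node.arguments.size(); ++i) {
    if (i > 0) {
      out += ", ";
    }
    out += node.arguments[i].first + ": " + node.arguments[i].second;
  }
  return out + "}";
}
```

Using string pairs keeps the node representation uniform: scalar parameters and tensor pointers are both just named entries, so nodes like Sm90AuxLoad that need several parameters fit the same scheme as single-argument nodes.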