Skip to content

Conversation

@rdspring1
Copy link
Collaborator

@rdspring1 rdspring1 commented Oct 31, 2025

This PR creates a simple hash function for Fusion. It will be used to create LRU Cache for mapping Fusion definitions to FusionExecutorCache.

PR Stack

@rdspring1 rdspring1 added the Direct Bindings Python extension with direct mapping to NvFuser CPP objects. label Oct 31, 2025
@github-actions
Copy link

github-actions bot commented Oct 31, 2025

Review updated until commit b6cf1fb

Description

  • Implement hash function for Fusion class

  • Add hashing support for IR nodes (Val, Expr)

  • Support hashing of PolymorphicValue types

  • Add unit test for Fusion hash functionality


Changes walkthrough 📝

Relevant files
Enhancement
fusion.cpp
Add Fusion::hash() implementation                                               

csrc/fusion.cpp

  • Added hash() method to Fusion class
  • Hash combines inputs, exprs, and outputs
  • Uses hashCombine for incremental hashing
  • +17/-0   
    base_nodes.cpp
    Implement hashing for IR values and expressions                   

    csrc/ir/base_nodes.cpp

  • Implemented getHash() for Val and Expr
  • Val hash includes type, data type, and value
  • Expr hash includes op string and I/O vals
  • Added helper to hash vector of Vals
  • +32/-0   
    polymorphic_value.cpp
    Add hashing support for PolymorphicValue                                 

    csrc/polymorphic_value.cpp

  • Added hash() function for PolymorphicValue
  • Supports hashing of complex, double, int64_t, bool
  • Throws error for unsupported types
  • +21/-0   
    fusion.h
    Declare Fusion::hash() method                                                       

    csrc/fusion.h

  • Declared hash() method in Fusion class
  • Added documentation for hash usage
  • +3/-0     
    base_nodes.h
    Add hashing interface to IR nodes                                               

    csrc/ir/base_nodes.h

  • Added hash() method to Statement
  • Added virtual getHash() in base class
  • Declared final getHash() in Val and Expr
  • +14/-0   
    polymorphic_value.h
    Declare PolymorphicValue hash function                                     

    csrc/polymorphic_value.h

  • Declared hash() function for PolymorphicValue
  • Added to PolymorphicValue_functions namespace
  • +2/-0     
    Tests
    test_fusion_hash.cpp
    Add test for Fusion hash functionality                                     

    tests/cpp/test_fusion_hash.cpp

  • Added unit test for Fusion hash
  • Tests hash stability and non-zero result
  • Validates fusion execution with cache
  • +48/-0   
    Configuration changes
    CMakeLists.txt
    Register fusion hash test in build                                             

    CMakeLists.txt

    • Added test_fusion_hash.cpp to JIT test sources
    +1/-0     

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Incomplete Hash Handling

    The getHash method in Statement class throws an exception by default, but derived classes may not all override it. This could lead to runtime errors if a derived class does not implement getHash. All subclasses of Statement must ensure they provide a valid hash implementation.

    size_t Statement::getHash() const {
      NVF_THROW("getHash for IR node ", typeid(*this).name(), " is not defined");
    }
    Unhandled Types in Hash

    The hash function in PolymorphicValue_functions does not handle all possible types that PolymorphicValue can hold. If a new type is added without updating the hash function, it will result in a NVF_THROW. This makes the hashing fragile and error-prone for future extensions.

    size_t hash(const PolymorphicValue& v) {
      size_t hash = 0;
      if (v.is<std::monostate>()) {
        return 0;
      } else if (v.is<std::complex<double>>()) {
        std::complex<double> val = v.as<std::complex<double>>();
        std::hash<double> hasher;
        hashCombine(hash, hasher(val.real()));
        hashCombine(hash, hasher(val.imag()));
      } else if (v.is<double>()) {
        hashCombine(hash, std::hash<double>()(v.as<double>()));
      } else if (v.is<int64_t>()) {
        hashCombine(hash, std::hash<int64_t>()(v.as<int64_t>()));
      } else if (v.is<bool>()) {
        hashCombine(hash, std::hash<bool>()(v.as<bool>()));
      } else {
        NVF_THROW("Cannot hash PolymorphicValue");
      }
      return hash;
    }
    Hash Stability Assumption

    The hash function in Fusion combines hashes of inputs, expressions, and outputs. However, the order of exprs() and outputs() must be deterministic across runs; otherwise, the hash may not be stable. The implementation should ensure that these collections are consistently ordered.

    size_t Fusion::hash() const {
      size_t hash = 0;
    
      for (const Val* val : inputs()) {
        hashCombine(hash, val->hash());
      }
    
      for (const Expr* expr : exprs()) {
        hashCombine(hash, expr->hash());
      }
    
      for (const Val* val : outputs()) {
        hashCombine(hash, val->hash());
      }
      return hash;
    }

    Copy link
    Contributor

    @greptile-apps greptile-apps bot left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    Greptile Overview

    Greptile Summary

    Implements hash function for Fusion by recursively hashing inputs, expressions, and outputs. Each Statement combines its sequential name_ with content-based hash from getHash().

    Critical Issues Found:

    • Val::getHash() uses std::get<PrimDataType> which will throw std::bad_variant_access when dtype_.type contains PointerType, ArrayType, StructType, or OpaqueType (all actively used in codebase)
    • Statement::hash() includes name_ field, causing semantically identical statements from different containers to produce different hashes, defeating cache lookup purpose

    Implementation Details:

    • Fusion::hash() sequentially combines hashes of inputs → exprs → outputs
    • Val::getHash() hashes vtype_, dtype_, and value_
    • Expr::getHash() hashes operation string and all inputs/outputs
    • PolymorphicValue_functions::hash() handles monostate, complex, double, int64_t, bool

    Confidence Score: 1/5

    • This PR has critical bugs that will cause runtime crashes and incorrect cache behavior
    • Two critical logic errors: (1) Val::getHash() will crash with std::bad_variant_access when hashing Vals with non-primitive DataTypes like PointerType, which are used throughout the codebase, and (2) including Statement name_ in hash defeats the cache lookup purpose since equivalent fusions from different containers will have different hashes
    • csrc/ir/base_nodes.cpp and csrc/ir/base_nodes.h require immediate fixes before merge

    Important Files Changed

    File Analysis

    Filename Score Overview
    csrc/ir/base_nodes.cpp 1/5 Added Val::getHash() and Expr::getHash() with critical bug: assumes dtype_ is always PrimDataType, will crash on PointerType/ArrayType/StructType
    csrc/ir/base_nodes.h 1/5 Added Statement::hash() that includes name_, causing semantically identical statements to hash differently across containers
    csrc/fusion.cpp 3/5 Implements Fusion::hash() by combining hashes of inputs, exprs, and outputs in order
    csrc/polymorphic_value.cpp 4/5 Implements hash for PolymorphicValue supporting std::monostate, std::complex<double>, double, int64_t, bool

    Sequence Diagram

    sequenceDiagram
        participant User
        participant Fusion
        participant Val
        participant Expr
        participant Statement
        participant PolymorphicValue
        
        User->>Fusion: hash()
        Fusion->>Fusion: hash = 0
        
        loop For each input
            Fusion->>Val: hash()
            Val->>Statement: hash()
            Statement->>Val: getHash()
            Val->>Val: hash vtype_
            Val->>Val: hash dtype_ (PrimDataType only)
            Val->>PolymorphicValue: hash(value_)
            PolymorphicValue-->>Val: value hash
            Val-->>Statement: content hash
            Statement->>Statement: combine name_ + content hash
            Statement-->>Fusion: val hash
            Fusion->>Fusion: hashCombine(hash, val_hash)
        end
        
        loop For each expr
            Fusion->>Expr: hash()
            Expr->>Statement: hash()
            Statement->>Expr: getHash()
            Expr->>Expr: hash opString
            loop For each input/output
                Expr->>Val: hash()
                Val-->>Expr: val hash
            end
            Expr-->>Statement: content hash
            Statement->>Statement: combine name_ + content hash
            Statement-->>Fusion: expr hash
            Fusion->>Fusion: hashCombine(hash, expr_hash)
        end
        
        loop For each output
            Fusion->>Val: hash()
            Val-->>Fusion: val hash
            Fusion->>Fusion: hashCombine(hash, val_hash)
        end
        
        Fusion-->>User: final hash
    
    Loading

    8 files reviewed, 2 comments

    Edit Code Review Agent Settings | Greptile

    Comment on lines +175 to +177
    hashCombine(
    hash,
    std::hash<int>()(static_cast<int>(std::get<PrimDataType>(dtype_.type))));
    Copy link
    Contributor

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    logic: std::get<PrimDataType> will throw std::bad_variant_access if dtype_.type holds ArrayType, PointerType, StructType, or OpaqueType. The codebase uses these types (e.g., csrc/ir/builder.cpp:144, csrc/kernel_ir.cpp:1965).

    Suggested change
    hashCombine(
    hash,
    std::hash<int>()(static_cast<int>(std::get<PrimDataType>(dtype_.type))));
    hashCombine(hash, std::hash<ValType>()(vtype_));
    std::visit([&hash](auto&& arg) {
    using T = std::decay_t<decltype(arg)>;
    if constexpr (std::is_same_v<T, PrimDataType>) {
    hashCombine(hash, std::hash<int>()(static_cast<int>(arg)));
    } else {
    // Handle other DataType variants - for now, hash the type index
    hashCombine(hash, std::hash<size_t>()(dtype_.type.index()));
    }
    }, dtype_.type);

    Comment on lines +179 to +181
    size_t hash() const {
    size_t hash = 0;
    hashCombine(hash, std::hash<StmtNameType>()(name_));
    Copy link
    Contributor

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    logic: Hashing name_ means two Fusion objects with semantically identical definitions will have different hashes if their statements have different names (which are assigned sequentially per container). This defeats the purpose of using hash for cache lookups where equivalent fusions should map to the same cache entry.

    @rdspring1
    Copy link
    Collaborator Author

    !test

    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

    Labels

    Direct Bindings Python extension with direct mapping to NvFuser CPP objects.

    Projects

    None yet

    Development

    Successfully merging this pull request may close these issues.

    2 participants