
Conversation

@jacobhinkle
Collaborator

@jacobhinkle jacobhinkle commented Oct 28, 2025

Prior to this change, we had two separate systems for computing parallel extents:

  1. ParallelDimensionMap is used to compute dimensions for generating the CUDA kernel.
  2. The KernelExecutor uses a separate system, including ParallelExtentMap, which gathers the extents of all IDs in the promoted group for each parallel type and computes the maximum extent across them.

This is redundant: we should have a centralized place to find the following information:

  1. Extent expressions for use in generated code
  2. Extent expressions for evaluation in the executor
  3. Index expressions used in generated code

In the future, we might include other information, such as predication for warp specialization. This PR is a step toward centralizing our parallel dimension mapping in ParallelDimensionMap by removing the redundant parts of KernelExecutor.

@jacobhinkle
Collaborator Author

!test --diff

@github-actions

github-actions bot commented Oct 28, 2025

Review updated until commit 2923cb6

Description

  • Remove redundant ParallelExtentMap in favor of ParallelDimensionMap

  • Centralize parallel dimension handling in KernelExecutor

  • Pre-bind launch constraints to avoid evaluation errors

  • Improve validation of parallel type constraints


Changes walkthrough 📝

Relevant files (Enhancement)

csrc/evaluator_common.cpp: Remove ParallelExtentMap binding logic (+0/-16)

  • Remove the bindParallelExtents function that used ParallelExtentMap
  • Clean up unused parallel extent binding logic

csrc/runtime/executor.cpp: Use ParallelDimensionMap for launch parameter computation (+58/-83)

  • Replace ParallelExtentMap usage with ParallelDimensionMap
  • Pre-bind launch constraints before evaluation
  • Improve constraint validation with better error messages
  • Simplify parallel dimension resolution logic

csrc/runtime/executor_utils.cpp: Remove deprecated parallel extent utility functions (+0/-56)

  • Remove the getParallelBindingsIterDomains and getParallelIterExtents functions
  • Delete unused ParallelBindingIterDomains and ParallelIterExtentMap template instantiations
  • Keep only the necessary executor utility functions

csrc/evaluator_common.h: Remove ParallelExtentMap from the evaluator interface (+0/-8)

  • Remove the ParallelExtentMap type alias
  • Delete the bindParallelExtents function declaration
  • Clean up unused parallel-extent-related declarations

csrc/runtime/executor_utils.h: Remove parallel extent map type definitions (+0/-42)

  • Remove the ParallelBindingIterDomains and ParallelIterExtentMap classes
  • Delete CompileTimeEntryType enum values for parallel extent entries
  • Remove the getParallelBindingsIterDomains and getParallelIterExtents declarations
  • Clean up unused parallel-extent-related types

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 No relevant tests
⚡ Recommended focus areas for review

Possible Issue

The PR removes the use of ParallelExtentMap and relies solely on ParallelDimensionMap, but the validation logic for launch constraints now only checks inferred values when a constraint is provided. Previously, extents were evaluated and validated via bindParallelExtents even without a constraint. This change may skip validation for parallel extents that are not user-constrained, potentially producing incorrect launch parameters when the evaluation context is incomplete.

    for (auto [p_type, extent] : parallel_dim_map) {
      if (launch_constraints.hasDim(p_type)) {
        // User provided a launch constraint for this parallel type
        int64_t constraint_value = launch_constraints.getDim(p_type);
        expr_eval.bind(extent, constraint_value);
        expr_eval.bind(p_type, constraint_value);
      }
    }
    
    // Process launch constraints and compute launch parameters.
    // For each parallel type in the ParallelDimensionMap, either use the
    // launch constraint if provided, or evaluate the extent to infer the size.
    for (auto [p_type, extent] : parallel_dim_map) {
      FUSER_PERF_SCOPE("KernelExecutor::ParallelBindingResolution");
    
      if (launch_constraints.hasDim(p_type)) {
        // User provided a launch constraint for this parallel type
        int64_t constraint_val = launch_constraints.getDim(p_type);
    
        // Try to evaluate the extent to validate the constraint
        auto inferred_val = expr_eval.evaluate(extent);
        if (inferred_val.hasValue()) {
          // We can infer the value - validate it matches the constraint
          bool valid = inferred_val.as<int64_t>() == constraint_val ||
              launch_constraints.getRawVal(p_type) == -1;
          if (!useFallback() && !valid) {
            TORCH_WARN_ONCE(
                "Cannot validate parallelization scheme for ",
                p_type,
                ": inferred value ",
                inferred_val.as<int64_t>(),
                " does not match constraint ",
                constraint_val,
                ". This may be due to mixed broadcast axes that are "
                "parallelized.");
          }
        }
    
        launch_params.bind(constraint_val, p_type);
      } else {
        // No launch constraint - infer the parallel dimension size
        auto val = expr_eval.evaluate(extent);
        NVF_ERROR(
            val.hasValue(),
            "Tried to evaluate the extent, ",
            extent->toInlineString(),
            " for the ptype: ",
            p_type,
            " to set launch bounds but could not.");
    
        if (val > 0) {
          expr_eval.bind(p_type, val);
          launch_params.bind(val.as<int64_t>(), p_type);
        }
      }
    }
Function Removal

The functions getParallelBindingsIterDomains and getParallelIterExtents, which collected and mapped parallelized IterDomains and their extents, are removed. While the PR intends to centralize this logic in ParallelDimensionMap, it is unclear whether ParallelDimensionMap fully captures the same semantics, especially for broadcast domains and unresolved mappings. This may leave parallel bindings missing in edge cases previously handled by the removed code.

    void validateIndexCasts(
        kir::Kernel* kernel,
        ExpressionEvaluator& expr_eval,
        const LaunchParams& launch_params) {
      if (!kernel->summary().has_narrowing_index_casts) {
        return;
      }
      ScalarBoundsCalculator calc(kernel, expr_eval, launch_params);
      NVF_ERROR(
          calc.castsFromIndexAreSafe(),
          "Found unsafe casts from DataType::Index. ",
          "This is likely because one coordinate of a TMA instruction overflowed "
          "Int32");
    }
    
    } // namespace executor_utils
    } // namespace nvfuser

@jacobhinkle
Collaborator Author

!test --diff

@jacobhinkle
Collaborator Author

!test --diff

@jacobhinkle
Collaborator Author

!test --diff
