
Conversation


@wujingyue wujingyue commented Oct 23, 2025

This PR is not ready to merge -- too many safety checks are missing; see the code comments for details. However, it gives an estimate of how much properly enabling CUDA graphs in nvFuser can help this particular benchmark, which happens to use bounded dynamic shapes.

Without CUDA graph:

$ python thunder/benchmarks/benchmark_inference.py --input-length 4096 --output-length 2 --mode thunder --enable-nv-linear
  Prefill Time: 21.11 ms
  Decode Time: 9.34 ms

With CUDA graph:

$ NVFUSER_DISABLE=kernel_reuse python thunder/benchmarks/benchmark_inference.py --input-length 4096 --output-length 2 --mode thunder --enable-nv-linear
  Prefill Time: 14.30 ms
  Decode Time: 5.00 ms


github-actions bot commented Oct 23, 2025

Review updated until commit adc4e08

Description

  • Added CUDA graph support for kernel execution

  • Implemented state management for the CUDA graph lifecycle (see the sketch after this list)

  • Enhanced input/output handling for graph capture

  • Added test for CUDA graph capture and replay

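The lifecycle the bullets refer to is sketched below. Only the CudaGraphState enumerator names and the warmup → capture → replay order come from the PR; the transition helper itself is hypothetical.

// Minimal sketch of the capture/replay lifecycle. The enum and transition
// order mirror the PR; nextCudaGraphState is a hypothetical helper for
// illustration only.
enum class CudaGraphState { kWarmup, kCapture, kReplay };

CudaGraphState nextCudaGraphState(CudaGraphState state) {
  switch (state) {
    case CudaGraphState::kWarmup:
      // The first run executes eagerly so kernels are compiled and the
      // caching allocator settles on stable addresses.
      return CudaGraphState::kCapture;
    case CudaGraphState::kCapture:
      // The next run records the kernel launches into a CUDA graph.
      return CudaGraphState::kReplay;
    case CudaGraphState::kReplay:
      // All subsequent runs replay the captured graph.
      return CudaGraphState::kReplay;
  }
  return state;  // unreachable; keeps compilers quiet
}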

Changes walkthrough 📝

Relevant files

Enhancement

  fusion_kernel_runtime.cpp (csrc/runtime/fusion_kernel_runtime.cpp, +80/-1): Add CUDA graph state machine and capture logic

    • Initialize CUDA graph state based on KernelReuse option
    • Implement state transitions: warmup → capture → replay
    • Add synchronization before graph capture
    • Handle input/output during graph capture and replay

  fusion_kernel_runtime.h (csrc/runtime/fusion_kernel_runtime.h, +24/-4): Add CUDA graph state enum and members

    • Add CudaGraphState enum for state tracking
    • Include CUDAGraph header and dependencies
    • Add CUDA graph related member variables
    • Update header includes and ordering

Tests

  test_cuda_graph.py (tests/python/direct/test_cuda_graph.py, +38/-0): Add CUDA graph capture and replay test

    • Add test for CUDA graph capture and replay
    • Use private stream for graph capture
    • Verify output correctness after replay
    • Include NVTX ranges for profiling

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Safety of CUDA Graph Capture

The PR enables CUDA graph capture when KernelReuse is disabled, but the code comments acknowledge that fixed input shapes are not sufficient to guarantee a safe capture. Dynamic tensor shapes caused by scalar inputs or data-dependent operations (e.g., all-to-all) are potential risks. The PR should validate, or at least document, how such cases are handled to avoid incorrect graph captures.

    if (isOptionDisabled(DisableOption::KernelReuse)) {
      // It's safer to use CUDA graph when KernelReuse is disabled. When
      // KernelReuse is disabled, we only reuse a FusionKernelRuntime when the
      // input shapes match exactly.
      //
      // FIXME: input shapes staying fixed is not sufficient for a safe capture.
      // For example, a scalar input can affect intermediate tensor shapes. In
      // addition, intermediate tensor shapes can be data dependent, e.g., the
      // receiving tensor of all-to-all is size-dependent on the number of tokens
      // per expert. A safer approach could be to go through the complete fusion
      // or host IR to find if any allocation can be dynamic.
      cuda_graph_state_ = CudaGraphState::kWarmup;
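
To make the reuse condition above concrete, here is a hypothetical illustration (not nvFuser's actual caching code) of what "input shapes match exactly" implies for a captured graph: any change in an input extent must fall back to a fresh warmup and capture.

// Hypothetical sketch, not nvFuser code: with KernelReuse disabled, a runtime
// (and therefore its captured graph) is only reused on an exact shape match.
#include <cstdint>
#include <vector>

using ShapeKey = std::vector<std::vector<int64_t>>;  // one extent list per input

bool canReplayCapturedGraph(const ShapeKey& captured, const ShapeKey& incoming) {
  // Any differing extent means a different runtime, i.e. a new warmup/capture.
  return captured == incoming;
}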

Input/Output Address Stability

CUDA graphs require input and output tensor addresses to remain constant between capture and replay. The current implementation relies on PyTorch's caching allocator reusing the same addresses and does not implement staging buffers or validation; if address reuse is not guaranteed, this could lead to silent correctness issues.

switch (cuda_graph_state_) {
  case CudaGraphState::kCapture: {
    // FIXME: CUDA graph requires the input and output tensor addresses to
    // remain constant during capturing and replaying. If we can't guarantee
    // that, we'll have to copy the input tensors to sticky "staged" input
    // buffers. I'm **not** doing copy at this moment because weights take
    // too long to copy, defeating the benefit of CUDA Graph, and that
    // PyTorch's caching allocator is doing a fairly good job reusing same
    // addresses.
#if 0
    cuda_graph_inputs_.resize(args.size());
    for (auto [i, arg] : enumerate(args)) {
      if (arg.is<at::Tensor>()) {
        auto static_arg = arg.as<at::Tensor>().clone();
        args[i] = cuda_graph_inputs_[i] = static_arg;
      }
    }
#endif

    // This is to guarantee all in-flight lazy deallocations in PyTorch's
    // allocator have landed before capturing. Otherwise, they would be
    // captured and replayed.
    at::cuda::getCurrentCUDAStream().synchronize();

    // FIXME: at::cuda::CUDAGraph doesn't allow capturing on the default
    // stream. I'm currently working around the problem by changing the
    // benchmark to use a private stream:
    // https://github.com/Lightning-AI/lightning-thunder/pull/2692. But this
    // should be fixed in nvFuser instead.

    cuda_graph_.capture_begin();
  } break;
  case CudaGraphState::kReplay: {
#if 0
    for (auto [i, arg] : enumerate(args)) {
      if (arg.is<at::Tensor>()) {
        auto static_arg = cuda_graph_inputs_[i].as<at::Tensor>();
        static_arg.copy_(arg.as<at::Tensor>());
        args[i] = static_arg;
      }
    }
#endif
    cuda_graph_.replay();
    return cuda_graph_outputs_;
  }
  default:
    break;
}
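
For reference, a generic ATen sketch (not the nvFuser diff) of the staged-buffer pattern the FIXME describes: the graph is recorded against fixed tensor addresses, and fresh inputs are copied into those buffers before each replay. The StagedGraph type and the doubling operation are placeholders; only the at::cuda::CUDAGraph copy-then-replay pattern is the point.

// Generic sketch of staged input/output buffers for CUDA graph replay.
// Assumes the current stream is already a non-default stream (see the next
// focus area).
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGraph.h>
#include <torch/torch.h>

struct StagedGraph {
  at::Tensor static_in;   // address baked into the captured graph
  at::Tensor static_out;  // address baked into the captured graph
  at::cuda::CUDAGraph graph;

  void capture() {
    // Let in-flight lazy deallocations land so they are not recorded.
    at::cuda::getCurrentCUDAStream().synchronize();
    graph.capture_begin();
    static_out.copy_(static_in * 2.0f);  // stand-in for the fused kernels
    graph.capture_end();
  }

  at::Tensor run(const at::Tensor& input) {
    static_in.copy_(input);  // stage new data at the captured address
    graph.replay();          // launches reuse the recorded addresses
    return static_out;
  }
};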

Stream Safety for Graph Capture

The PR notes that CUDAGraph does not support capturing on the default stream and references a workaround in the benchmark. This limitation should be addressed in the core logic or clearly documented, as it may cause failures in standard use cases that rely on the default stream.

    // FIXME: at::cuda::CUDAGraph doesn't allow capturing on the default
    // stream. I'm currently working around the problem by changing the
    // benchmark to use a private stream:
    // https://github.com/Lightning-AI/lightning-thunder/pull/2692. But this
    // should be fixed in nvFuser instead.
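
A small sketch of the private-stream requirement mentioned above, kept generic rather than tied to nvFuser's runtime; the captureOnSideStream helper is hypothetical.

// Hypothetical helper showing the workaround: take a stream from the pool,
// make it current for the capture region, and restore the old stream after.
#include <functional>

#include <ATen/cuda/CUDAGraph.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

void captureOnSideStream(
    at::cuda::CUDAGraph& graph,
    const std::function<void()>& launch_kernels) {
  // at::cuda::CUDAGraph refuses to capture on the default (legacy) stream.
  c10::cuda::CUDAStream side_stream = c10::cuda::getStreamFromPool();
  c10::cuda::CUDAStreamGuard guard(side_stream);  // restored on scope exit

  side_stream.synchronize();  // drain pending work and lazy frees first
  graph.capture_begin();
  launch_kernels();           // everything launched here is recorded
  graph.capture_end();
}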

@wujingyue changed the title from "CUDA graph experiment" to "[prototype] CUDA graph capture/replay" on Oct 23, 2025.
Inline review comment on csrc/runtime/fusion_kernel_runtime.cpp:

    heuristics_ = std::move(maybe_heuristics.value());

    if (isOptionDisabled(DisableOption::KernelReuse)) {
      // It's safer to use CUDA graph when KernelReuse is disabled. When
Is there a way to disable kernel reuse from Python, without using environment variables?
