feat: caching attempts #3527

Draft · peri044 wants to merge 13 commits into base: main

Conversation

@peri044 peri044 (Collaborator) commented May 20, 2025

Description

Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@peri044 peri044 added the WIP label (work is in progress; pull request should not be merged yet) May 20, 2025
@peri044 peri044 marked this pull request as draft May 20, 2025 23:30
@github-actions github-actions bot added the component: lowering, component: conversion, component: api [Python], component: runtime, and component: dynamo labels May 20, 2025
@github-actions github-actions bot requested a review from gs-olive May 20, 2025 23:30
@peri044 peri044 changed the title from "feat : caching attempts" to "feat: caching attempts" May 20, 2025
@peri044 peri044 removed the request for review from gs-olive May 21, 2025 00:40
@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/cache_utils.py	2025-05-28 23:49:40.726795+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/cache_utils.py	2025-05-28 23:50:03.207085+00:00
@@ -5,81 +5,90 @@
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._pytree import _LEAF_SPEC
from torch._export.utils import _detect_fake_mode_from_gm

+
def get_kv_nodes(gm):
    """
    Get the key and value nodes from the graph.
    """
    kv_nodes = []
    for node in gm.graph.nodes:
-        if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention:
+        if (
+            node.op == "call_function"
+            and node.target == torch._C._nn.scaled_dot_product_attention
+        ):
            q_node, k_node, v_node = node.args[:3]
            kv_nodes.append((k_node, v_node))
    return kv_nodes

+
def get_random_tensor_from_node(node: Node) -> torch.Tensor:
-        """
-        Creates a random tensor based on the shape information in a node's metadata.
-        For symbolic dimensions, extracts the maximum value from the shape environment.
-        
-        Args:
-            node: A torch.fx.Node object with metadata containing tensor information
-            
-        Returns:
-            A random tensor with shape matching the node's metadata, or None if no valid
-            tensor information is found
-        """
-        if "val" not in node.meta:
-            raise ValueError(f"No tensor information found in node metadata for node: {node}")
-            
-        fake_tensor = node.meta["val"]
-        shape = []
-        
-        # Iterate through each dimension and handle symbolic dimensions
-        for dim in fake_tensor.shape:
-            if isinstance(dim, torch.SymInt):
-                # Extract the maximum value from the shape environment
-                max_val = dim.node.hint
-                shape.append(max_val)
-            else:
-                shape.append(dim)
-        
-        # Create a random tensor with the determined shape
-        dtype = fake_tensor.dtype
-        device = fake_tensor.device
-        random_tensor = torch.rand(shape, dtype=dtype, device=device)
+    """
+    Creates a random tensor based on the shape information in a node's metadata.
+    For symbolic dimensions, extracts the maximum value from the shape environment.

-        return random_tensor
+    Args:
+        node: A torch.fx.Node object with metadata containing tensor information
+
+    Returns:
+        A random tensor with shape matching the node's metadata, or None if no valid
+        tensor information is found
+    """
+    if "val" not in node.meta:
+        raise ValueError(
+            f"No tensor information found in node metadata for node: {node}"
+        )
+
+    fake_tensor = node.meta["val"]
+    shape = []
+
+    # Iterate through each dimension and handle symbolic dimensions
+    for dim in fake_tensor.shape:
+        if isinstance(dim, torch.SymInt):
+            # Extract the maximum value from the shape environment
+            max_val = dim.node.hint
+            shape.append(max_val)
+        else:
+            shape.append(dim)
+
+    # Create a random tensor with the determined shape
+    dtype = fake_tensor.dtype
+    device = fake_tensor.device
+    random_tensor = torch.rand(shape, dtype=dtype, device=device)
+
+    return random_tensor
+

def create_random_output_tensors(nodes: List[Node]) -> List[torch.Tensor]:
    """
    Creates random tensors based on the shape information in node metadata.
    For symbolic dimensions, extracts the maximum value from the shape environment.
-    
+
    Args:
        nodes: List of torch.fx.Node objects with metadata
-        
+
    Returns:
        List of random tensors with shapes matching the nodes' metadata
    """
    random_tensors = []
-    
+
    for node in nodes:
        if isinstance(node, Node):
            node_tensor = get_random_tensor_from_node(node)
        elif isinstance(node, tuple):
            node_tensor_list = []
            for n in node:
                random_tensor = get_random_tensor_from_node(n)
                node_tensor_list.append(random_tensor)
            node_tensor = tuple(node_tensor_list)
-               
+
        random_tensors.append(node_tensor)
-    
+
    return random_tensors
+

def add_graph_input(
    gm: GraphModule, name: str, val: Optional[torch.Tensor] = None, dynamic_shape=None
) -> Node:
    """Add a graph input to the given GraphModule and return the newly created node.
@@ -130,10 +139,11 @@
        in_node.meta["tensor_meta"] = _extract_tensor_metadata(fake_tensor)

    # return new node...
    return in_node

+
def is_op(node: Node, ops: Union[OpOverloadPacket, Iterable[OpOverloadPacket]]) -> bool:
    """Check if the node is a call to one of the ops."""
    if node.op != "call_function":
        return False
    # check if it's a single op that's provided
@@ -144,9 +154,10 @@
    if any(node.target == op for op in ops):
        return True

    return False

+
def get_all_input_output_nodes(graph: Graph) -> Tuple[List[Node], List[Node]]:
    input_nodes: List[Node] = graph.find_nodes(op="placeholder")
    output_nodes: List[Node] = graph.find_nodes(op="output")
-    return (input_nodes, output_nodes)
\ No newline at end of file
+    return (input_nodes, output_nodes)
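
(Editor's note: the cache_utils.py diff above covers the graph-surgery helpers the cache passes below rely on: get_kv_nodes collects the key/value argument nodes of every scaled_dot_product_attention call, add_graph_input registers a new placeholder on an exported GraphModule, and create_random_output_tensors materializes tensors from the fake-tensor metadata. A minimal, hypothetical usage sketch follows; register_kv_placeholders is an illustrative name, and it assumes cache_utils.py is importable and that gm is a torch.fx.GraphModule produced by torch.export whose nodes carry "val" metadata.)

import torch
from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes


def register_kv_placeholders(gm: torch.fx.GraphModule):
    # Sketch only: mirrors what the static/dynamic cache passes in this PR do.
    kv_inputs = []
    # (k_node, v_node) pairs feeding each scaled_dot_product_attention call
    for k_node, v_node in get_kv_nodes(gm):
        # Materialize tensors matching the recorded fake-tensor shapes/dtypes
        k_val, v_val = create_random_output_tensors([k_node, v_node])
        # Expose them as fresh graph inputs so a KV cache can be fed in at runtime
        k_in = add_graph_input(gm, k_node.name + "_k_input", k_val)
        v_in = add_graph_input(gm, v_node.name + "_v_input", v_val)
        kv_inputs.append((k_in, v_in))
    return kv_inputs
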
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llama_benchmark.py	2025-05-28 23:49:40.727795+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llama_benchmark.py	2025-05-28 23:50:03.276791+00:00
@@ -10,68 +10,72 @@
def main():
    # Initialize model and tokenizer
    print("Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_NAME,
-        torch_dtype=torch.float16,
-        use_cache=False,
-        device_map="auto"
+        MODEL_NAME, torch_dtype=torch.float16, use_cache=False, device_map="auto"
    )
    model.generation_config.cache_implementation = "static"
    model.forward = torch.compile(model.forward)
-    
+
    # Prepare input prompt
    word = "What"
    # Tokenize the word
-    word_ids = tokenizer(word, return_tensors="pt").input_ids[0]  # Get the first (and only) sequence
+    word_ids = tokenizer(word, return_tensors="pt").input_ids[
+        0
+    ]  # Get the first (and only) sequence
    # Repeat the token 2048 times
-    input_ids = word_ids.repeat(1024).unsqueeze(0).to(model.device)  # Add batch dimension and move to device
+    input_ids = (
+        word_ids.repeat(1024).unsqueeze(0).to(model.device)
+    )  # Add batch dimension and move to device
    print(f"Input tensor shape: {input_ids.shape}")

    # # Warm-up pass
    print("Running warm-up pass...")
    output_ids = model.generate(
        input_ids,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
-        use_cache=USE_CACHE
+        use_cache=USE_CACHE,
    )
-    
+
    # Benchmark loop
    print("Running benchmark...")
    num_iterations = 10
    total_time = 0
    timings = []
-    
+
    for i in range(num_iterations):
        start_time = timeit.default_timer()
        output_ids = model.generate(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
-            use_cache=USE_CACHE
+            use_cache=USE_CACHE,
        )
        end_time = timeit.default_timer()
        generation_time = end_time - start_time
        total_time += generation_time
        timings.append(generation_time)
-        
+
        # Decode and print first iteration output
        # if i == 0:
        #     output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        #     print("\nFirst generation output:")
        #     print(output_text)
-    
+
    # Calculate and print statistics
    average_time = total_time / num_iterations
    print(f"\nPerformance Statistics:")
-    print(f"Average generation time over {num_iterations} iterations: {average_time*1000:.2f} milliseconds")
+    print(
+        f"Average generation time over {num_iterations} iterations: {average_time*1000:.2f} milliseconds"
+    )
    print(f"Average tokens per second: {100/average_time:.2f}")
    print("\nIndividual timings (ms):")
    for i, t in enumerate(timings):
        print(f"Iteration {i+1}: {t*1000:.2f}")

+
if __name__ == "__main__":
-    main() 
\ No newline at end of file
+    main()
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/dynamic_cache.py	2025-05-28 23:49:40.726795+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/dynamic_cache.py	2025-05-28 23:50:03.285151+00:00
@@ -12,36 +12,45 @@
from torch_tensorrt.dynamo.utils import extract_var_range_info
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)

-from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes, is_op
+from cache_utils import (
+    add_graph_input,
+    create_random_output_tensors,
+    get_kv_nodes,
+    is_op,
+)
import tensorrt
import torch.utils._pytree as pytree
+
logger = logging.getLogger(__name__)

-@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(torch.ops.higher_order.cond, enabled=True, supports_dynamic_shapes=True)
+
+@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(
+    torch.ops.higher_order.cond, enabled=True, supports_dynamic_shapes=True
+)
def cond_converter(
    ctx: torch_tensorrt.dynamo.conversion.ConversionContext,
    target: Target,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    name: str,
) -> Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]:
    """
    Converter for torch.ops.higher_order.cond operation to TensorRT.
-    
+
    This function handles the conversion of PyTorch's conditional operation to TensorRT.
    The conditional operation selects between two tensors based on a boolean predicate.
-    
+
    Args:
        ctx (torch_tensorrt.dynamo.conversion.ConversionCtx): The conversion context
        target (Target): The target operation to convert
        args (Tuple[Argument, ...]): The arguments to the operation
        kwargs (Dict[str, Argument]): The keyword arguments to the operation
        name (str): The name to give to the TensorRT layer
-        
+
    Returns:
        Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: The converted TensorRT tensor(s)
    """
    if_layer = ctx.net.add_if_conditional()
    condition, true_branch, false_branch = args[0], args[1], args[2]
@@ -49,30 +58,31 @@
    output_layer = if_layer.add_output(true_branch, false_branch)
    output = output_layer.get_output(0)

    return output

+
def add_kv_as_outputs(gm):
    """
    Modifies the graph to add query, key, and value tensors as outputs.
-    
+
    This function identifies all scaled dot-product attention (SDPA) operations
    in the graph, creates copies of their query, key, and value inputs, and adds
    these copies to the graph's outputs. This allows for accessing these tensors
    externally, which is useful for operations like key-value caching.
-    
+
    Args:
        graph: The torch.fx.Graph to modify
-        
+
    Returns:
        None. The graph is modified in-place.
    """
    # list of MHA kernels we would want to detect and replace
    mha_ops = {
        torch._C._nn.scaled_dot_product_attention,
    }
-    
+
    # Find all SDPA nodes in the graph
    mha_nodes = []
    for node in gm.graph.nodes:
        if is_op(node, mha_ops):
            mha_nodes.append(node)
@@ -80,157 +90,170 @@
    # Iterate through each MHA node to extract shape information
    for mha_node in mha_nodes:
        if "val" in mha_node.meta and len(mha_node.args) >= 3:
            # Get the input nodes (query, key, value)
            q_node, k_node, v_node = mha_node.args[:3]
-            
+
            # Add the copy nodes as outputs to the graph
-            output_node = next(node for node in gm.graph.nodes if node.op == "output")            
+            output_node = next(node for node in gm.graph.nodes if node.op == "output")

            # Get the current output args (typically a tuple)
            current_outputs = output_node.args[0]
-            
+
            # If the current output is a tuple, extend it with our new outputs
            if isinstance(current_outputs, tuple):
                new_outputs = current_outputs + ((k_node, v_node),)
            else:
                # If there's only one output or it's not a tuple, create a new tuple
                new_outputs = (current_outputs, (k_node, v_node))
-            
+
            gm.graph.output(new_outputs)
            gm.graph.erase_node(output_node)
-        
+
    return new_outputs


-
-
def add_kv_and_indices_as_inputs(gm, fixed_kv: bool = True):
-        """
-        Add key-value tensors and index parameters as inputs to the graph.
-        
-        Args:
-            gm: The GraphModule to modify
-            fixed_kv: Boolean indicating whether to use static tensors for KV cache
-            
-        Returns:
-            A tuple containing:
-            - List of (k_input, v_input) node pairs for each SDPA operation
-            - start_idx input node for slicing operations
-            - end_idx input node for slicing operations
-        """
-
-        def get_static_tensor(tensor: torch.Tensor):
-            key_shape = []
-            for dim in tensor.shape:
-                if isinstance(dim, torch.SymInt):
-                    min_max_opt = extract_var_range_info(dim)
-                    key_shape.append(min_max_opt["max"])
-                else:
-                    key_shape.append(dim)
-            
-            static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device)
-            return static_tensor
-        
-        keys_values = get_kv_nodes(gm)
-
-        kv_inputs = []
-        for idx, key_value in enumerate(keys_values):
-            k_val = key_value[0].meta["val"]
-            v_val = key_value[1].meta["val"]
-            if fixed_kv:
-                k_val = get_static_tensor(k_val)
-                v_val = get_static_tensor(v_val)
-
-            # Add new inputs using add_graph_input
-            k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val)
-            v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val)
-            kv_inputs.append((k_input, v_input))
-
-        return kv_inputs
-
-
-def insert_torch_cond_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]]):
+    """
+    Add key-value tensors and index parameters as inputs to the graph.
+
+    Args:
+        gm: The GraphModule to modify
+        fixed_kv: Boolean indicating whether to use static tensors for KV cache
+
+    Returns:
+        A tuple containing:
+        - List of (k_input, v_input) node pairs for each SDPA operation
+        - start_idx input node for slicing operations
+        - end_idx input node for slicing operations
+    """
+
+    def get_static_tensor(tensor: torch.Tensor):
+        key_shape = []
+        for dim in tensor.shape:
+            if isinstance(dim, torch.SymInt):
+                min_max_opt = extract_var_range_info(dim)
+                key_shape.append(min_max_opt["max"])
+            else:
+                key_shape.append(dim)
+
+        static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device)
+        return static_tensor
+
+    keys_values = get_kv_nodes(gm)
+
+    kv_inputs = []
+    for idx, key_value in enumerate(keys_values):
+        k_val = key_value[0].meta["val"]
+        v_val = key_value[1].meta["val"]
+        if fixed_kv:
+            k_val = get_static_tensor(k_val)
+            v_val = get_static_tensor(v_val)
+
+        # Add new inputs using add_graph_input
+        k_input = add_graph_input(gm, key_value[0].name + "_k_input", k_val)
+        v_input = add_graph_input(gm, key_value[1].name + "_v_input", v_val)
+        kv_inputs.append((k_input, v_input))
+
+    return kv_inputs
+
+
+def insert_torch_cond_before_sdpa(
+    gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]]
+):
    """
    Insert a torch.cond operation before each scaled_dot_product_attention operation.
-    
+
    Args:
        gm: The FX GraphModule to modify
-        
+
    Returns:
        The modified GraphModule
    """
    # Find all nodes with scaled_dot_product_attention
    sdpa_nodes = []
    for node in gm.graph.nodes:
-        if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention:
+        if (
+            node.op == "call_function"
+            and node.target == torch._C._nn.scaled_dot_product_attention
+        ):
            sdpa_nodes.append(node)
-    
-    # Get the is_causal input node 
-    is_causal_node = next((node for node in gm.graph.nodes if node.op == "placeholder" and node.name == "is_causal"), None)
+
+    # Get the is_causal input node
+    is_causal_node = next(
+        (
+            node
+            for node in gm.graph.nodes
+            if node.op == "placeholder" and node.name == "is_causal"
+        ),
+        None,
+    )

    # For each SDPA node, insert a torch.cond operation before it
    for idx, sdpa_node in enumerate(sdpa_nodes):
- 
+
        with gm.graph.inserting_before(sdpa_node):
            # pred_node = add_graph_input(gm, "is_generate", torch.tensor(False, dtype=torch.bool))
            q_node, k_node, v_node = sdpa_node.args[:3]
            incoming_key, incoming_value = incoming_keys_values[idx]
            # Create nodes for concatenating k with incoming_key and v with incoming_value
            concatenated_k_node = gm.graph.create_node(
                "call_function",
                torch.ops.aten.cat.default,
-                args=([incoming_key, k_node], 2),  # Concatenate along sequence length dimension
-                kwargs={}
+                args=(
+                    [incoming_key, k_node],
+                    2,
+                ),  # Concatenate along sequence length dimension
+                kwargs={},
            )
            concatenated_v_node = gm.graph.create_node(
                "call_function",
                torch.ops.aten.cat.default,
-                args=([incoming_value, v_node], 2),  #  Concatenate along sequence length dimension
-                kwargs={}
-            )
-            
+                args=(
+                    [incoming_value, v_node],
+                    2,
+                ),  #  Concatenate along sequence length dimension
+                kwargs={},
+            )
+
            # Create the torch.cond node
            cond_k_node = gm.graph.create_node(
                "call_function",
                torch.ops.higher_order.cond,
                args=(is_causal_node, concatenated_k_node, k_node),
            )
- 
+
            cond_v_node = gm.graph.create_node(
                "call_function",
                torch.ops.higher_order.cond,
                args=(is_causal_node, concatenated_v_node, v_node),
            )

            sdpa_node.args = (q_node, cond_k_node, cond_v_node) + sdpa_node.args[3:]
-    
+
    return gm
-


@_aten_lowering_pass
def insert_dynamic_kv_cache(
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    """Insert FlashInfer MHA + KV cache ops in the graph"""
    """Perform insertion of kv-caches and attention kernel."""

    # Add static key and value as inputs to the graph
-    kv_inputs  = add_kv_and_indices_as_inputs(gm, fixed_kv=True)
+    kv_inputs = add_kv_and_indices_as_inputs(gm, fixed_kv=True)

    # Call the function to add KV as outputs
    logits_keys_values = add_kv_as_outputs(gm)

    # Insert torch.cond before each SDPA node which acts toggles between prefill and generate phases
    gm = insert_torch_cond_before_sdpa(gm, kv_inputs)

    gm = clean_up_graph_after_modifications(gm)
-    
+
    new_output_tensors = create_random_output_tensors(logits_keys_values)
    new_out_spec = pytree.tree_flatten(new_output_tensors)[1]
    gm._out_spec = new_out_spec
-    
+
    logger.debug("After inserting KV cache into the graph: " + str(gm.graph))
    return gm
-
-
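
(Editor's note: insert_torch_cond_before_sdpa above rewires each SDPA call so that its key/value inputs come from a torch.ops.higher_order.cond node: the true branch concatenates the incoming cached keys/values with the freshly computed ones along the sequence dimension (dim 2), the false branch passes the fresh ones through unchanged, and the graph's is_causal placeholder is wired up as the predicate. In eager terms the inserted computation amounts to roughly the sketch below; select_kv and is_generate are illustrative names, not part of the PR.)

import torch


def select_kv(is_generate: bool, cached: torch.Tensor, fresh: torch.Tensor) -> torch.Tensor:
    # cached/fresh: [batch, num_heads, seq_len, head_dim]
    if is_generate:
        # generate phase: append this step's keys/values to the cached ones
        return torch.cat([cached, fresh], dim=2)
    # prefill phase: use the freshly computed keys/values as-is
    return fresh
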
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/test_sdpa.py	2025-05-28 23:49:40.727795+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/test_sdpa.py	2025-05-28 23:50:03.466989+00:00
@@ -9,101 +9,117 @@
from contextlib import nullcontext
import argparse

# llama2_model_name = "meta-llama/Llama-2-7b-hf"
llama3_model_name = "meta-llama/Llama-3.2-1B-Instruct"
-llama_model = AutoModelForCausalLM.from_pretrained(
-                llama3_model_name,
-                use_cache=False,
-                attn_implementation="sdpa",
-                num_hidden_layers=1,
-            ).eval().cuda()
+llama_model = (
+    AutoModelForCausalLM.from_pretrained(
+        llama3_model_name,
+        use_cache=False,
+        attn_implementation="sdpa",
+        num_hidden_layers=1,
+    )
+    .eval()
+    .cuda()
+)
LLAMA_CONFIG = llama_model.config
+

def test_llama_attention(args):
    class LlamaAttentionBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.config = LLAMA_CONFIG
-            self.attn = LlamaAttention(
-                config=self.config,
-                layer_idx=0
+            self.attn = LlamaAttention(config=self.config, layer_idx=0)
+
+        def forward(self, hidden_states, position_embeddings):
+            attn_output, attn_weights = self.attn(
+                hidden_states, position_embeddings, None
            )
-        def forward(self, hidden_states, position_embeddings):
-            attn_output, attn_weights = self.attn(hidden_states, position_embeddings, None)
            return attn_output
-    
+
    DTYPE = torch.float32
    # model = LlamaAttentionBlock().eval().cuda().to(DTYPE)
    model = llama_model.model.layers[0].self_attn.to(DTYPE)
-    # llama3 
+    # llama3
    # hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda()
    # position_embeddings = (torch.randn((1, 6, 64), dtype=DTYPE).cuda(), torch.randn((1, 6, 64), dtype=DTYPE).cuda())
    hidden_states = torch.load("hidden_states.pt")
    position_embeddings = torch.load("position_embeddings.pt")
    # breakpoint()
    pyt_output = model(hidden_states, position_embeddings, None)
-    
+
    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None)
-    ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes)
-    
+    ep = torch.export.export(
+        model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes
+    )
+
    with torch_tensorrt.logging.debug():
-        trt_model = torch_tensorrt.dynamo.compile(ep, 
-                                                inputs=[hidden_states, position_embeddings, None], 
-                                                enabled_precisions={torch.float32},
-                                                disable_tf32=True,
-                                                debug=True)
+        trt_model = torch_tensorrt.dynamo.compile(
+            ep,
+            inputs=[hidden_states, position_embeddings, None],
+            enabled_precisions={torch.float32},
+            disable_tf32=True,
+            debug=True,
+        )
    trt_output = trt_model(hidden_states, position_embeddings, None)
    if isinstance(pyt_output, tuple):
-        print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}")
+        print(
+            f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}"
+        )
    else:
        print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output - trt_output))}")
-    
+

def test_llama_decoder(args):
    class LlamaDecoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.config = LLAMA_CONFIG
-            self.decoder_layer = LlamaDecoderLayer(
-                config=self.config,
-                layer_idx=0
+            self.decoder_layer = LlamaDecoderLayer(config=self.config, layer_idx=0)
+
+        def forward(self, hidden_states, position_embeddings):
+            decoder_output = self.decoder_layer(
+                hidden_states, position_embeddings=position_embeddings
            )
-        def forward(self, hidden_states, position_embeddings):
-            decoder_output = self.decoder_layer(hidden_states, position_embeddings=position_embeddings)
            return decoder_output[0]
-    
+
    DTYPE = torch.float32
    model = LlamaDecoder().eval().cuda().to(DTYPE)
-    # llama3 
+    # llama3
    hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda()
-    position_embeddings = (torch.randn((1, 6, 64), dtype=DTYPE).cuda(), torch.randn((1, 6, 64), dtype=DTYPE).cuda())
+    position_embeddings = (
+        torch.randn((1, 6, 64), dtype=DTYPE).cuda(),
+        torch.randn((1, 6, 64), dtype=DTYPE).cuda(),
+    )

    pyt_output = model(hidden_states, position_embeddings)
    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}))
-    ep = torch.export.export(model, (hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes)
-    
-    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
-        trt_model = torch_tensorrt.dynamo.compile(ep, 
-                                                inputs=[hidden_states, position_embeddings], 
-                                                enabled_precisions={torch.float32},
-                                                debug=args.debug)
+    ep = torch.export.export(
+        model, (hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes
+    )
+
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+        trt_model = torch_tensorrt.dynamo.compile(
+            ep,
+            inputs=[hidden_states, position_embeddings],
+            enabled_precisions={torch.float32},
+            debug=args.debug,
+        )
    trt_output = trt_model(hidden_states, position_embeddings)

    print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output - trt_output))}")


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(
        description="Run test cases for llama attention and decoder"
    )
    arg_parser.add_argument(
-        "--debug",
-        action="store_true",
-        help="Enable debug (default: False)"
+        "--debug", action="store_true", help="Enable debug (default: False)"
    )
    args = arg_parser.parse_args()
    with torch.inference_mode():
        test_llama_attention(args)
-        # test_llama_decoder(args)
\ No newline at end of file
+        # test_llama_decoder(args)
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llama3_trt.py	2025-05-28 23:49:40.727795+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llama3_trt.py	2025-05-28 23:50:03.506307+00:00
@@ -17,40 +17,48 @@
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer
from contextlib import nullcontext
-from utils import export_llm, generate, recordStats, time_generate, generate_with_kv_cache, get_zeroed_kv_cache_inputs
+from utils import (
+    export_llm,
+    generate,
+    recordStats,
+    time_generate,
+    generate_with_kv_cache,
+    get_zeroed_kv_cache_inputs,
+)


DEVICE = torch.device("cuda:0")
+

def get_model(args):
    with torch.no_grad():
        if args.model == "meta-llama/Llama-2-7b-chat-hf":
            model = (
                AutoModelForCausalLM.from_pretrained(
                    args.model,
                    use_cache=False,
                    attn_implementation="sdpa",
-                    num_hidden_layers=1
+                    num_hidden_layers=1,
                )
                .eval()
                .cuda()
            )
        elif args.model == "meta-llama/Llama-3.2-1B-Instruct":
            model = (
                AutoModelForCausalLM.from_pretrained(
                    args.model,
                    use_cache=False,
                    attn_implementation="sdpa",
-                    num_hidden_layers=1
+                    num_hidden_layers=1,
                )
                .eval()
                .cuda()
            )
-            
+
        elif args.model == "meta-llama/Llama-3.2-3B-Instruct":
            model = (
                AutoModelForCausalLM.from_pretrained(
                    args.model,
                    use_cache=False,
@@ -71,13 +79,11 @@
                .cuda()
            )
        elif args.model == "google/gemma-3-1b-it":
            model = (
                AutoModelForCausalLM.from_pretrained(
-                    "google/gemma-3-1b-it", 
-                    use_cache=False, 
-                    attn_implementation="sdpa"
+                    "google/gemma-3-1b-it", use_cache=False, attn_implementation="sdpa"
                )
                .eval()
                .cuda()
            )
    if args.precision == "FP16":
@@ -91,25 +97,25 @@


def compile_torchtrt(model, input_ids, args):
    max_seq_len = input_ids.shape[1] + args.max_tokens
    ep = export_llm(model, input_ids, max_seq_len=max_seq_len)
-    
+
    # Set precision specific flags
-    use_fp32_acc = False 
+    use_fp32_acc = False
    use_explicit_typing = False
    if args.precision == "FP16":
        enabled_precisions = {torch.float32}
-        use_fp32_acc = True 
+        use_fp32_acc = True
        use_explicit_typing = True
    elif args.precision == "BF16":
        enabled_precisions = {torch.bfloat16}
        use_fp32_acc = False
    else:
        enabled_precisions = {torch.float32}

-    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
        trt_model = torch_tensorrt.dynamo.compile(
            ep,
            inputs=[input_ids],
            enabled_precisions=enabled_precisions,
            # truncate_double=True,
@@ -132,44 +138,51 @@
        tokenizer.decode(gen_tokens[0], skip_special_tokens=True),
    )
    print("===================================")


-
def measure_perf(trt_model, input_signature, backend_name):
    # Measure average time for 10 iterations
    import timeit
    import numpy as np
-    
+
    total_time = 0
    iterations = 10
-    
+
    print("Running warmup iteration...")
    # Warmup run
    _ = trt_model(*input_signature)
    torch.cuda.synchronize()
-    
+
    print(f"Measuring performance over {iterations} iterations...")
    for i in range(iterations):
        start_time = timeit.default_timer()
        _ = trt_model(*input_signature)
        torch.cuda.synchronize()
        end_time = timeit.default_timer()
        iter_time = end_time - start_time
        total_time += iter_time
        # print(f"Iteration {i+1}: {iter_time:.4f} seconds")
-    
+
    avg_time = total_time / iterations
-    print(f"Backend: {backend_name} Average time per iteration: {avg_time*1000:.4f} milliseconds")
-    print(f"Backend: {backend_name} Average throughput: {1.0/avg_time:.2f} iterations/second")
+    print(
+        f"Backend: {backend_name} Average time per iteration: {avg_time*1000:.4f} milliseconds"
+    )
+    print(
+        f"Backend: {backend_name} Average throughput: {1.0/avg_time:.2f} iterations/second"
+    )
+

if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(
        description="Run inference on a model with random input values"
    )
    arg_parser.add_argument(
-        "--model", type=str, default="meta-llama/Llama-3.2-1B-Instruct", help="Name of LLM model"
+        "--model",
+        type=str,
+        default="meta-llama/Llama-3.2-1B-Instruct",
+        help="Name of LLM model",
    )
    arg_parser.add_argument(
        "--tokenizer_path",
        type=str,
        default="meta-llama/Llama-3.2-1B-Instruct",
@@ -187,34 +200,28 @@
    )
    arg_parser.add_argument(
        "--max_tokens", type=int, default=128, help="no. of max tokens to be generated"
    )
    arg_parser.add_argument(
-        "--enable_pytorch_run", 
-        action="store_true", 
-        help="Enable pytorch run (default: False)"
+        "--enable_pytorch_run",
+        action="store_true",
+        help="Enable pytorch run (default: False)",
    )
    arg_parser.add_argument(
        "--cache",
        type=str,
        default="static",
        help="Type of KV cache to use",
    )
    arg_parser.add_argument(
-        "--cudagraph",
-        action="store_true",
-        help="Enable cudagraphs (default: False)"
-    )
-    arg_parser.add_argument(
-        "--debug",
-        action="store_true",
-        help="Enable debug (default: False)"
-    )
-    arg_parser.add_argument(
-        "--benchmark",
-        action="store_true",
-        help="Enable benchmark (default: False)"
+        "--cudagraph", action="store_true", help="Enable cudagraphs (default: False)"
+    )
+    arg_parser.add_argument(
+        "--debug", action="store_true", help="Enable debug (default: False)"
+    )
+    arg_parser.add_argument(
+        "--benchmark", action="store_true", help="Enable benchmark (default: False)"
    )
    args = arg_parser.parse_args()
    with torch.inference_mode():
        model = get_model(args)

@@ -236,75 +243,118 @@
        pyt_stats = None
        if args.enable_pytorch_run:
            pyt_gen_tokens = generate(
                model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id
            )
-            
+
            if args.benchmark:
                pyt_timings = time_generate(
                    generate,
                    model,
                    input_ids.clone(),
                    MAX_OUTPUT_SEQ_LENGTH,
                    tokenizer.eos_token_id,
                    iterations=args.iterations,
                )
                pyt_stats = recordStats(
-                    "PyTorch", pyt_timings, args.precision, batch_size=1, compile_time_s=None
+                    "PyTorch",
+                    pyt_timings,
+                    args.precision,
+                    batch_size=1,
+                    compile_time_s=None,
                )

        # TRT
        pyt_logits_tok1 = model.cuda()(input_ids)
        next_tokens = torch.argmax(pyt_logits_tok1.logits[:, -1, :], dim=-1)
        input_seq = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        pyt_logits_tok2 = model.cuda()(input_seq)
        from lower_sdpa import *
+
        if args.cache == "static":
            # This import is required to register static KV cache transformations as lowering passes
            from static_cache2 import *
-            trt_model = compile_torchtrt(model, input_ids, args) 
+
+            trt_model = compile_torchtrt(model, input_ids, args)
            kv_cache = get_zeroed_kv_cache_inputs(trt_model)

            # First token generation
-            pyt_keys = torch.load("key.pt"); pyt_values = torch.load("value.pt")
-            trt_logits, key_cache, value_cache, trt_keys_1, trt_values_1 = trt_model(input_ids.clone(), True, *kv_cache, 0, input_ids.shape[1])
-            print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits_tok1.logits - trt_logits))}")
-            print(f"Diff between pyt and trt keys: {torch.mean(torch.abs(pyt_keys - trt_keys_1))}")
-            print(f"Diff between pyt and trt keys in cache: {torch.mean(torch.abs(pyt_keys - key_cache[:, :, :-2, :]))}")
-            print(f"Diff between pyt and trt values: {torch.mean(torch.abs(pyt_values - trt_values_1))}")
-            print(f"Diff between pyt and trt values in cache: {torch.mean(torch.abs(pyt_values - value_cache[:, :, :-2, :]))}")
+            pyt_keys = torch.load("key.pt")
+            pyt_values = torch.load("value.pt")
+            trt_logits, key_cache, value_cache, trt_keys_1, trt_values_1 = trt_model(
+                input_ids.clone(), True, *kv_cache, 0, input_ids.shape[1]
+            )
+            print(
+                f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits_tok1.logits - trt_logits))}"
+            )
+            print(
+                f"Diff between pyt and trt keys: {torch.mean(torch.abs(pyt_keys - trt_keys_1))}"
+            )
+            print(
+                f"Diff between pyt and trt keys in cache: {torch.mean(torch.abs(pyt_keys - key_cache[:, :, :-2, :]))}"
+            )
+            print(
+                f"Diff between pyt and trt values: {torch.mean(torch.abs(pyt_values - trt_values_1))}"
+            )
+            print(
+                f"Diff between pyt and trt values in cache: {torch.mean(torch.abs(pyt_values - value_cache[:, :, :-2, :]))}"
+            )
            next_tokens = torch.argmax(trt_logits[:, -1, :], dim=-1)

            # Second token generation
-            trt_logits_2, key_cache2, value_cache2, trt_keys_2, trt_values_2 = trt_model(next_tokens[:, None], False, key_cache.clone(), value_cache.clone(), input_ids.shape[1], input_ids.shape[1]+1)
-            pyt_keys2 = torch.load("key2.pt"); pyt_values2 = torch.load("value2.pt")
-            print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits_tok2.logits[:, -1:, :] - trt_logits_2))}")
-            print(f"Diff between pyt and trt keys: {torch.mean(torch.abs(pyt_keys2[:, :, -2:-1, :] - trt_keys_2))}")
-            print(f"Diff between pyt and trt keys in cache: {torch.mean(torch.abs(pyt_keys2 - key_cache2[:, :, :-1, :]))}")
-            print(f"Diff between pyt and trt values: {torch.mean(torch.abs(pyt_values2[:, :, -2:-1, :] - trt_values_2))}")
-            print(f"Diff between pyt and trt values in cache: {torch.mean(torch.abs(pyt_values2 - value_cache2[:, :, :-1, :]))}")
+            trt_logits_2, key_cache2, value_cache2, trt_keys_2, trt_values_2 = (
+                trt_model(
+                    next_tokens[:, None],
+                    False,
+                    key_cache.clone(),
+                    value_cache.clone(),
+                    input_ids.shape[1],
+                    input_ids.shape[1] + 1,
+                )
+            )
+            pyt_keys2 = torch.load("key2.pt")
+            pyt_values2 = torch.load("value2.pt")
+            print(
+                f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits_tok2.logits[:, -1:, :] - trt_logits_2))}"
+            )
+            print(
+                f"Diff between pyt and trt keys: {torch.mean(torch.abs(pyt_keys2[:, :, -2:-1, :] - trt_keys_2))}"
+            )
+            print(
+                f"Diff between pyt and trt keys in cache: {torch.mean(torch.abs(pyt_keys2 - key_cache2[:, :, :-1, :]))}"
+            )
+            print(
+                f"Diff between pyt and trt values: {torch.mean(torch.abs(pyt_values2[:, :, -2:-1, :] - trt_values_2))}"
+            )
+            print(
+                f"Diff between pyt and trt values in cache: {torch.mean(torch.abs(pyt_values2 - value_cache2[:, :, :-1, :]))}"
+            )
            breakpoint()
        elif args.cache == "dynamic":
            from dynamic_cache import *
-            trt_model = compile_torchtrt(model, input_ids, args) 
+
+            trt_model = compile_torchtrt(model, input_ids, args)
            breakpoint()
            kv_cache = get_zeroed_kv_cache_inputs(trt_model)
        else:
            # pyt_logits = model.cuda()(input_ids.clone())
-            trt_model = compile_torchtrt(model, input_ids, args) 
+            trt_model = compile_torchtrt(model, input_ids, args)
            # trt_logits = trt_model(input_ids.clone(), True)
            # print(f"Diff between pyt and trt: {torch.mean(torch.abs(pyt_logits - trt_logits))}")
            # print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits.logits - trt_logits.logits))}")
        if args.cache == "static":
            if args.cudagraph:
                # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases.
                # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model)
                torch_tensorrt.runtime.set_cudagraphs_mode(True)
-             
+
            trt_gen_tokens = generate_with_kv_cache(
-                trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id,
-                )
+                trt_model,
+                input_ids.clone(),
+                MAX_OUTPUT_SEQ_LENGTH,
+                tokenizer.eos_token_id,
+            )

            if args.benchmark:
                trt_timings = time_generate(
                    generate_with_kv_cache,
                    trt_model,
@@ -316,14 +366,17 @@
        elif args.cache == "dynamic":
            if args.cudagraph:
                # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases.
                # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model)
                torch_tensorrt.runtime.set_cudagraphs_mode(True)
-             
+
            trt_gen_tokens = generate_with_kv_cache(
-                trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id,
-                )
+                trt_model,
+                input_ids.clone(),
+                MAX_OUTPUT_SEQ_LENGTH,
+                tokenizer.eos_token_id,
+            )

            if args.benchmark:
                trt_timings = time_generate(
                    generate_with_kv_cache,
                    trt_model,
@@ -333,32 +386,39 @@
                    iterations=args.iterations,
                )

        else:
            trt_gen_tokens = generate(
-                trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id,
+                trt_model,
+                input_ids.clone(),
+                MAX_OUTPUT_SEQ_LENGTH,
+                tokenizer.eos_token_id,
            )
            if args.benchmark:
                trt_timings = time_generate(
                    generate,
                    trt_model,
                    input_ids.clone(),
                    MAX_OUTPUT_SEQ_LENGTH,
                    tokenizer.eos_token_id,
                    iterations=args.iterations,
                )
-        
+
        if args.benchmark:
            trt_stats = recordStats(
-                "TensorRT", trt_timings, args.precision, batch_size=1, compile_time_s=None
-            )
-
-        if args.enable_pytorch_run: 
+                "TensorRT",
+                trt_timings,
+                args.precision,
+                batch_size=1,
+                compile_time_s=None,
+            )
+
+        if args.enable_pytorch_run:
            print_outputs("PyTorch", pyt_gen_tokens, tokenizer)
        print_outputs("TensorRT", trt_gen_tokens, tokenizer)

-        if  args.benchmark:
+        if args.benchmark:
            if args.enable_pytorch_run:
                print("=========PyTorch PERFORMANCE============ \n")
                print(pyt_stats)
            print("===================== \n")
            print("=========TensorRT PERFORMANCE============ \n")
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/static_cache.py	2025-05-28 23:49:40.727795+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/static_cache.py	2025-05-28 23:50:03.516074+00:00
@@ -12,55 +12,57 @@
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)
import torch.utils._pytree as pytree
from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes
+
logger = logging.getLogger(__name__)

SDPA_OP = torch._C._nn.scaled_dot_product_attention

+
def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]):
    """
    Modifies the graph to add query, key, and value tensors as outputs.
-    
+
    This function identifies all scaled dot-product attention (SDPA) operations
    in the graph, creates copies of their query, key, and value inputs, and adds
    these copies to the graph's outputs. This allows for accessing these tensors
    externally, which is useful for operations like key-value caching.
-    
+
    Args:
        graph: The torch.fx.Graph to modify
-        
+
    Returns:
        None. The graph is modified in-place.
    """
-    output_node = next(node for node in gm.graph.nodes if node.op == "output")            
+    output_node = next(node for node in gm.graph.nodes if node.op == "output")

    # Get the current output args (typically a tuple)
    current_outputs = output_node.args[0]
-    
+
    # If the current output is a tuple, extend it with our new outputs
    if isinstance(current_outputs, tuple):
        new_outputs = current_outputs + tuple(kv_cache_for_graph)
    else:
        # If there's only one output or it's not a tuple, create a new tuple
-        new_outputs = (current_outputs,) +  tuple(kv_cache_for_graph)
-            
+        new_outputs = (current_outputs,) + tuple(kv_cache_for_graph)
+
    gm.graph.output(new_outputs)
    gm.graph.erase_node(output_node)

    return new_outputs


def add_kv_cache_inputs(gm, fixed_kv: bool = True):
    """
    Add key-value tensors, index parameters as inputs to the graph.
-    
+
    Args:
        gm: The GraphModule to modify
        fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True.
-        
+
    Returns:
        A tuple containing:
        - List of (k_input, v_input) node pairs for each SDPA operation
        - start_idx input node for slicing operations
        - end_idx input node for slicing operations
@@ -72,14 +74,14 @@
            if isinstance(dim, torch.SymInt):
                min_max_opt = extract_var_range_info(dim)
                key_shape.append(min_max_opt["max"])
            else:
                key_shape.append(dim)
-        
+
        static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device)
        return static_tensor
-    
+
    keys_values = get_kv_nodes(gm)

    kv_inputs = []
    for idx, key_value in enumerate(keys_values):
        k_val = key_value[0].meta["val"]
@@ -87,12 +89,12 @@
        if fixed_kv:
            k_val = get_static_tensor(k_val)
            v_val = get_static_tensor(v_val)

        # Add new inputs using add_graph_input
-        k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val)
-        v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val)
+        k_input = add_graph_input(gm, key_value[0].name + "_k_input", k_val)
+        v_input = add_graph_input(gm, key_value[1].name + "_v_input", v_val)
        kv_inputs.append((k_input, v_input))

    # Add start_idx and end_idx as inputs
    start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0))
    end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1))
@@ -103,10 +105,11 @@
    seq_len = input_ids_meta.shape[1]
    min_max_opt = extract_var_range_info(seq_len)
    max_seq_len = min_max_opt["max"]

    from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
    shape_env = ShapeEnv()
    # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive
    start_idx_unbacked_symint = shape_env.create_unbacked_symint()
    torch._check(start_idx_unbacked_symint >= 0)
    torch._check(start_idx_unbacked_symint <= max_seq_len)
@@ -119,12 +122,16 @@
    end_idx_input.meta["val"] = end_idx_unbacked_symint

    return kv_inputs, start_idx_input, end_idx_input


-
-def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node):
+def insert_kv_slicing_before_sdpa(
+    gm,
+    incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]],
+    start_idx_input: Node,
+    end_idx_input: Node,
+):
    """
    Insert slicing operations before each scaled_dot_product_attention operation.
    """
    # Find all nodes with scaled_dot_product_attention
    sdpa_nodes = []
@@ -135,106 +142,109 @@
    for idx, sdpa_node in enumerate(sdpa_nodes):
        q_node, k_node, v_node = sdpa_node.args[:3]
        incoming_key, incoming_value = incoming_keys_values[idx]
        kv_cache_for_sdpa_node = []
        new_keys_values = []
-        for key_or_value, current_key_or_value_node in zip([incoming_key, incoming_value], [k_node, v_node]):
+        for key_or_value, current_key_or_value_node in zip(
+            [incoming_key, incoming_value], [k_node, v_node]
+        ):
            # Create a slice node for key_cache[:,:,:start_idx,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim
            with gm.graph.inserting_before(sdpa_node):
                slice_1 = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.slice.Tensor,
                    args=(key_or_value,),
-                    kwargs={}
+                    kwargs={},
                )
                slice_2 = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.slice.Tensor,
                    args=(slice_1, 1),
-                    kwargs={}
+                    kwargs={},
                )
                slice_3 = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.slice.Tensor,
-                    args=(slice_2, 2, None, start_idx_input),  
-                    kwargs={}
+                    args=(slice_2, 2, None, start_idx_input),
+                    kwargs={},
                )
                slice_4 = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.slice.Tensor,
-                    args=(slice_3, 3), 
-                    kwargs={}
-                )
-                # =============================================== # 
+                    args=(slice_3, 3),
+                    kwargs={},
+                )
+                # =============================================== #
                # Create a slice node for key_cache[:,:, end_idx:,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim
                slice_5 = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.slice.Tensor,
                    args=(key_or_value,),
-                    kwargs={}
+                    kwargs={},
                )
                slice_6 = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.slice.Tensor,
                    args=(slice_5, 1),
-                    kwargs={}
+                    kwargs={},
                )
                slice_7 = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.slice.Tensor,
-                    args=(slice_6, 2, end_idx_input),  
-                    kwargs={}
+                    args=(slice_6, 2, end_idx_input),
+                    kwargs={},
                )
                slice_8 = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.slice.Tensor,
-                    args=(slice_7, 3), 
-                    kwargs={}
-                )
-                # =============================================== # 
+                    args=(slice_7, 3),
+                    kwargs={},
+                )
+                # =============================================== #
                # Concatenate the sliced tensors to build KV cache
                cat = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.cat.default,
-                    args=([slice_4, current_key_or_value_node, slice_8], 2), 
-                    kwargs={}
+                    args=([slice_4, current_key_or_value_node, slice_8], 2),
+                    kwargs={},
                )
                # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph
                cat.meta.update(key_or_value.meta)
                kv_cache_for_sdpa_node.append(cat)
-                # =============================================== # 
+                # =============================================== #
                # Get the current key and value by indexing the KV cache
                slice_9 = gm.graph.create_node(
-                    "call_function",
-                    torch.ops.aten.slice.Tensor,
-                    args=(cat,),
-                    kwargs={}
+                    "call_function", torch.ops.aten.slice.Tensor, args=(cat,), kwargs={}
                )
                slice_10 = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.slice.Tensor,
                    args=(slice_9, 1),
-                    kwargs={}
+                    kwargs={},
                )
                slice_11 = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.slice.Tensor,
-                    args=(slice_10, 2, None, end_idx_input),  
-                    kwargs={}
+                    args=(slice_10, 2, None, end_idx_input),
+                    kwargs={},
                )
                slice_12 = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.slice.Tensor,
-                    args=(slice_11, 3), 
-                    kwargs={}
+                    args=(slice_11, 3),
+                    kwargs={},
                )
                new_keys_values.append(slice_12)
-        
+
        kv_cache_for_graph.extend(kv_cache_for_sdpa_node)

-        sdpa_node.args = (q_node, new_keys_values[0], new_keys_values[1]) + sdpa_node.args[3:]
-    
+        sdpa_node.args = (
+            q_node,
+            new_keys_values[0],
+            new_keys_values[1],
+        ) + sdpa_node.args[3:]
+
    return gm, kv_cache_for_graph


@_aten_lowering_pass
def insert_kv_cache(
@@ -243,13 +253,15 @@
    """Insert KV cache ops in the graph"""
    """Perform insertion of kv-caches and attention kernel."""
    # Add static key and value as inputs to the graph
    kv_inputs, start_idx_input, end_idx_input = add_kv_cache_inputs(gm, fixed_kv=True)

-    # Build and update the KV cache using computed KV inputs for current token and 
+    # Build and update the KV cache using computed KV inputs for current token and
    # incoming keys and values from previous tokens (which were added as inputs)
-    gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input)
+    gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(
+        gm, kv_inputs, start_idx_input, end_idx_input
+    )

    # Call the function to add KV as outputs
    logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph)

    gm = clean_up_graph_after_modifications(gm)
@@ -259,7 +271,5 @@
    new_out_spec = pytree.tree_flatten(new_output_tensors)[1]
    gm._out_spec = new_out_spec
    logger.debug("After inserting KV cache into the graph: " + str(gm.graph))

    return gm
-
-
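
For reference, the slice/concat nodes this pass inserts around each SDPA call amount to the following eager-tensor update (a minimal sketch; the function name is illustrative, and the batch x num_heads x seq_len x head_dim cache layout is taken from the comments in the pass above):

import torch

def splice_kv_cache(key_cache, value_cache, k, v, start_idx, end_idx):
    # Rebuild the cache: keep [:start_idx], splice in the current k/v, keep [end_idx:].
    new_key_cache = torch.cat(
        (key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2
    )
    new_value_cache = torch.cat(
        (value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2
    )
    # The SDPA node is rewired to read only the populated prefix of the rebuilt cache.
    current_k = new_key_cache[:, :, :end_idx, :]
    current_v = new_value_cache[:, :, :end_idx, :]
    return current_k, current_v, new_key_cache, new_value_cache
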
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/static_cache2.py	2025-05-28 23:49:40.727795+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/static_cache2.py	2025-05-28 23:50:03.529014+00:00
@@ -12,55 +12,57 @@
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)
import torch.utils._pytree as pytree
from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes
+
logger = logging.getLogger(__name__)

SDPA_OP = torch._C._nn.scaled_dot_product_attention

+
def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]):
    """
    Modifies the graph to add query, key, and value tensors as outputs.
-    
+
    This function identifies all scaled dot-product attention (SDPA) operations
    in the graph, creates copies of their query, key, and value inputs, and adds
    these copies to the graph's outputs. This allows for accessing these tensors
    externally, which is useful for operations like key-value caching.
-    
+
    Args:
        graph: The torch.fx.Graph to modify
-        
+
    Returns:
        None. The graph is modified in-place.
    """
-    output_node = next(node for node in gm.graph.nodes if node.op == "output")            
+    output_node = next(node for node in gm.graph.nodes if node.op == "output")

    # Get the current output args (typically a tuple)
    current_outputs = output_node.args[0]
-    
+
    # If the current output is a tuple, extend it with our new outputs
    if isinstance(current_outputs, tuple):
        new_outputs = current_outputs + tuple(kv_cache_for_graph)
    else:
        # If there's only one output or it's not a tuple, create a new tuple
-        new_outputs = (current_outputs,) +  tuple(kv_cache_for_graph)
-            
+        new_outputs = (current_outputs,) + tuple(kv_cache_for_graph)
+
    gm.graph.output(new_outputs)
    gm.graph.erase_node(output_node)

    return new_outputs


def add_kv_cache_inputs(gm, fixed_kv: bool = True):
    """
    Add key-value tensors, index parameters as inputs to the graph.
-    
+
    Args:
        gm: The GraphModule to modify
        fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True.
-        
+
    Returns:
        A tuple containing:
        - List of (k_input, v_input) node pairs for each SDPA operation
        - start_idx input node for slicing operations
        - end_idx input node for slicing operations
@@ -72,14 +74,14 @@
            if isinstance(dim, torch.SymInt):
                min_max_opt = extract_var_range_info(dim)
                key_shape.append(min_max_opt["max"])
            else:
                key_shape.append(dim)
-        
+
        static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device)
        return static_tensor
-    
+
    keys_values = get_kv_nodes(gm)

    kv_inputs = []
    for idx, key_value in enumerate(keys_values):
        k_val = key_value[0].meta["val"]
@@ -87,12 +89,12 @@
        if fixed_kv:
            k_val = get_static_tensor(k_val)
            v_val = get_static_tensor(v_val)

        # Add new inputs using add_graph_input
-        k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val)
-        v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val)
+        k_input = add_graph_input(gm, key_value[0].name + "_k_input", k_val)
+        v_input = add_graph_input(gm, key_value[1].name + "_v_input", v_val)
        kv_inputs.append((k_input, v_input))

    # Add start_idx and end_idx as inputs
    start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0))
    end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1))
@@ -100,18 +102,19 @@
    # Get the max sequence length from the first key_cache node. The order of nodes is: input_ids, is_causal, key_cache1, value_cache1, key_cache2, value_cache2, ..
    input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
    # Get the third last input which should be the last value cache node and store the max_seq_len
    input_ids_meta = input_nodes[-3].meta["val"]
    seq_len = input_ids_meta.shape[2]
- 
+
    if isinstance(seq_len, torch.SymInt):
        min_max_opt = extract_var_range_info(seq_len)
        max_seq_len = min_max_opt["max"]
    else:
        max_seq_len = seq_len

    from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
    shape_env = ShapeEnv()
    # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive
    start_idx_unbacked_symint = shape_env.create_unbacked_symint()
    torch._check(start_idx_unbacked_symint >= 0)
    torch._check(start_idx_unbacked_symint <= max_seq_len)
@@ -127,14 +130,17 @@
    is_causal_input = add_graph_input(gm, "is_causal", True)
    is_causal_input.meta["val"] = torch.tensor(True)

    return kv_inputs, start_idx_input, end_idx_input, is_causal_input

-def create_kv_cache_update_nodes(gm, sdpa_node, current_kv_node, incoming_kv_node, start_idx_input, end_idx_input):
+
+def create_kv_cache_update_nodes(
+    gm, sdpa_node, current_kv_node, incoming_kv_node, start_idx_input, end_idx_input
+):
    """
    Create slicing and concatenation nodes for KV cache update.
-    
+
    This function creates the necessary slicing and concatenation nodes to update the KV cache
    during the generation process. It takes the SDPA node, the current KV cache node, and the
    incoming KV cache node as input.
    Returns:
        for a particular SDPA node, a tuple containing:
@@ -147,78 +153,73 @@
    with gm.graph.inserting_before(sdpa_node):
        slice_1 = gm.graph.create_node(
            "call_function",
            torch.ops.aten.slice.Tensor,
            args=(incoming_kv_node,),
-            kwargs={}
+            kwargs={},
        )
        slice_2 = gm.graph.create_node(
-            "call_function",
-            torch.ops.aten.slice.Tensor,
-            args=(slice_1, 1),
-            kwargs={}
+            "call_function", torch.ops.aten.slice.Tensor, args=(slice_1, 1), kwargs={}
        )
        slice_3 = gm.graph.create_node(
            "call_function",
            torch.ops.aten.slice.Tensor,
-            args=(slice_2, 2, None, start_idx_input),  
-            kwargs={}
+            args=(slice_2, 2, None, start_idx_input),
+            kwargs={},
        )
        slice_4 = gm.graph.create_node(
-            "call_function",
-            torch.ops.aten.slice.Tensor,
-            args=(slice_3, 3), 
-            kwargs={}
+            "call_function", torch.ops.aten.slice.Tensor, args=(slice_3, 3), kwargs={}
        )
        # Concat key_cache[:,:,:start_idx,:] with current key (k)
        concat_keys_or_values = gm.graph.create_node(
            "call_function",
            torch.ops.aten.cat.default,
-            args=([slice_4, current_kv_node], 2), 
-            kwargs={}
-        )
-
-        # =============================================== # 
+            args=([slice_4, current_kv_node], 2),
+            kwargs={},
+        )
+
+        # =============================================== #
        # Create nodes for key_cache[:,:, end_idx:,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim
        slice_5 = gm.graph.create_node(
            "call_function",
            torch.ops.aten.slice.Tensor,
            args=(incoming_kv_node,),
-            kwargs={}
+            kwargs={},
        )
        slice_6 = gm.graph.create_node(
-            "call_function",
-            torch.ops.aten.slice.Tensor,
-            args=(slice_5, 1),
-            kwargs={}
+            "call_function", torch.ops.aten.slice.Tensor, args=(slice_5, 1), kwargs={}
        )
        slice_7 = gm.graph.create_node(
            "call_function",
            torch.ops.aten.slice.Tensor,
-            args=(slice_6, 2, end_idx_input),  
-            kwargs={}
+            args=(slice_6, 2, end_idx_input),
+            kwargs={},
        )
        slice_8 = gm.graph.create_node(
-            "call_function",
-            torch.ops.aten.slice.Tensor,
-            args=(slice_7, 3), 
-            kwargs={}
-        )
-        # =============================================== # 
+            "call_function", torch.ops.aten.slice.Tensor, args=(slice_7, 3), kwargs={}
+        )
+        # =============================================== #
        # Concatenate the sliced tensors to build KV cache
        new_incoming_keys_or_values = gm.graph.create_node(
            "call_function",
            torch.ops.aten.cat.default,
-            args=([concat_keys_or_values, slice_8], 2), 
-            kwargs={}
+            args=([concat_keys_or_values, slice_8], 2),
+            kwargs={},
        )
        # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph
        new_incoming_keys_or_values.meta.update(incoming_kv_node.meta)

    return concat_keys_or_values, new_incoming_keys_or_values

-def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node, is_causal_input: Node):
+
+def insert_kv_slicing_before_sdpa(
+    gm,
+    incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]],
+    start_idx_input: Node,
+    end_idx_input: Node,
+    is_causal_input: Node,
+):
    """
    Insert slicing and concatenation operations before each scaled_dot_product_attention operation as per the following KV cache update logic:
    concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)
    concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)
    new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
@@ -232,21 +233,34 @@
            sdpa_nodes.append(node)
    kv_cache_for_graph = []
    for idx, sdpa_node in enumerate(sdpa_nodes):
        q_node, k_node, v_node = sdpa_node.args[:3]
        incoming_key, incoming_value = incoming_keys_values[idx]
-        # For keys  
-        new_current_key_node, new_incoming_key_cache_node = create_kv_cache_update_nodes(gm, sdpa_node, k_node, incoming_key, start_idx_input, end_idx_input)
+        # For keys
+        new_current_key_node, new_incoming_key_cache_node = (
+            create_kv_cache_update_nodes(
+                gm, sdpa_node, k_node, incoming_key, start_idx_input, end_idx_input
+            )
+        )
        # For values
-        new_current_value_node, new_incoming_value_cache_node = create_kv_cache_update_nodes(gm, sdpa_node, v_node, incoming_value, start_idx_input, end_idx_input)
+        new_current_value_node, new_incoming_value_cache_node = (
+            create_kv_cache_update_nodes(
+                gm, sdpa_node, v_node, incoming_value, start_idx_input, end_idx_input
+            )
+        )

        # Store the KV cache nodes for the current SDPA node
-        kv_cache_for_graph.extend([new_incoming_key_cache_node, new_incoming_value_cache_node])
+        kv_cache_for_graph.extend(
+            [new_incoming_key_cache_node, new_incoming_value_cache_node]
+        )

        # Update the SDPA node arguments with current key and value nodes
-        sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + (None, is_causal_input) # + sdpa_node.args[3:]
-    
+        sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + (
+            None,
+            is_causal_input,
+        )  # + sdpa_node.args[3:]
+
    kv_cache_for_graph.extend([k_node, v_node])
    return gm, kv_cache_for_graph


@_aten_lowering_pass
@@ -254,15 +268,19 @@
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    """Insert KV cache ops in the graph"""
    """Perform insertion of kv-caches and attention kernel."""
    # Add static key and value as inputs to the graph
-    kv_inputs, start_idx_input, end_idx_input, is_causal_input = add_kv_cache_inputs(gm, fixed_kv=True)
-
-    # Build and update the KV cache using computed KV inputs for current token and 
+    kv_inputs, start_idx_input, end_idx_input, is_causal_input = add_kv_cache_inputs(
+        gm, fixed_kv=True
+    )
+
+    # Build and update the KV cache using computed KV inputs for current token and
    # incoming keys and values from previous tokens (which were added as inputs)
-    gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input, is_causal_input)
+    gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(
+        gm, kv_inputs, start_idx_input, end_idx_input, is_causal_input
+    )

    # Call the function to add KV as outputs
    logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph)

    gm = clean_up_graph_after_modifications(gm)
@@ -272,7 +290,5 @@
    new_out_spec = pytree.tree_flatten(new_output_tensors)[1]
    gm._out_spec = new_out_spec
    logger.debug("After inserting KV cache into the graph: " + str(gm.graph))

    return gm
-
-
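
The static_cache2 variant performs the same update but feeds SDPA the prefix-plus-current concatenation directly instead of re-slicing the rebuilt cache, and threads an explicit is_causal input through the graph. A minimal eager-tensor sketch of what the inserted nodes compute (signature illustrative, mirroring the docstring formulas above):

import torch
import torch.nn.functional as F

def static_cache_sdpa(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal):
    # key_cache/value_cache layout: batch x num_heads x max_seq_len x head_dim
    concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)
    concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)
    # Rebuild the full-length caches by appending the untouched tail.
    new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
    new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2)
    out = F.scaled_dot_product_attention(
        q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal
    )
    return out, new_key_cache, new_value_cache
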
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/utils.py	2025-05-28 23:49:40.728795+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/utils.py	2025-05-28 23:50:03.690290+00:00
@@ -2,13 +2,14 @@
from transformers import StoppingCriteriaList
from transformers.generation.stopping_criteria import (
    EosTokenCriteria,
    MaxLengthCriteria,
)
-import numpy as np 
-import copy 
+import numpy as np
+import copy
import timeit
+

def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
    """
    Exports the LLM model into an ExportedProgram with dynamic shapes.
    In the case of guard failures due to some PyTorch kernel implements, we also
@@ -36,31 +37,38 @@
                allow_complex_guards_as_runtime_asserts=True,
            )

    return ep

+
def get_zeroed_kv_cache_inputs(model: torch.fx.GraphModule):
    """
    Extracts and returns zeroed KV cache tensors from a torch.fx.GraphModule.
-    
+
    This function identifies placeholder nodes in the graph that represent KV cache tensors,
    and creates zeroed tensors with the same shape, dtype, and device as the original placeholders.
-    
+
    Args:
        model (torch.fx.GraphModule): The exported model graph containing KV cache placeholders
-        
+
    Returns:
        tuple: A tuple of zeroed tensors corresponding to the KV cache placeholders in the graph
    """
    # placeholder nodes are expected to be in the following order:
    # input_ids, kv_cache_key, kv_cache_value, start_idx, end_idx
    placeholder_nodes = [node for node in model.graph.nodes if node.op == "placeholder"]
    # The first two inputs are input_ids and is_causal. The last two inputs are start_idx and end_idx. In between are the KV cache tensors.
    kv_cache_inputs = placeholder_nodes[2:-2]
    zeroed_kv_cache_inputs = []
    for input in kv_cache_inputs:
-        zeroed_kv_cache_inputs.append(torch.zeros(input.meta["val"].shape, dtype=input.meta["val"].dtype, device=torch.device("cuda:0")))
+        zeroed_kv_cache_inputs.append(
+            torch.zeros(
+                input.meta["val"].shape,
+                dtype=input.meta["val"].dtype,
+                device=torch.device("cuda:0"),
+            )
+        )

    return tuple(zeroed_kv_cache_inputs)


def generate(model, input_seq, max_output_seq_length, eos_token_id, benchmark=True):
@@ -87,10 +95,11 @@
        if not benchmark and stopping_criteria(input_seq, logits).item():
            break
    # breakpoint()
    return input_seq

+
def generate_with_kv_cache(model, input_seq, max_output_seq_length, eos_token_id):
    """
    Greedy decoding of the model with KV cache.
    """
    start_idx = 0
@@ -112,27 +121,26 @@
        next_token_logits = logits[:, -1, :]
        next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        output_seq = torch.cat([output_seq, next_tokens], dim=-1)
        input_seq = next_tokens
        start_idx = end_idx
-        end_idx = start_idx + 1 
+        end_idx = start_idx + 1
    lkv = torch.cat(logits_concat, dim=1)
    # breakpoint()
    return output_seq
+

def time_generate(
    generate_fn, model, inputs, output_seq_length, eos_token_id, iterations=10
):
    """
    Measure the time for generating a sentence over certain number of iterations
    """
    timings = []
    for _ in range(iterations):
        start_time = timeit.default_timer()
-        _ = generate_fn(
-            model, inputs, output_seq_length, eos_token_id
-        )
+        _ = generate_fn(model, inputs, output_seq_length, eos_token_id)
        torch.cuda.synchronize()
        end_time = timeit.default_timer()
        timings.append(end_time - start_time)

    return timings
@@ -160,6 +168,6 @@
        "Median-Latency(ms)": time_med * 1000,
        "Mean-Latency(ms)": time_mean * 1000,
        "Latency-StdDev(ms)": time_std * 1000,
        "Compile Time(s)": compile_time_s,
    }
-    return stats
\ No newline at end of file
+    return stats
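
Putting these utilities together, the decode loop that drives a KV-cache-lowered LLM module looks roughly like the following (a sketch only; the positional ordering of is_causal, the cache tensors, and start/end indices is an assumption based on the placeholder-order comment above, and the is_causal toggling follows the prefill/generate split in the tests):

import torch

def greedy_decode_with_kv_cache(model, input_seq, kv_caches, max_output_seq_length):
    # Prefill covers the whole prompt, then each step feeds one new token and
    # advances the [start_idx, end_idx) window by one position.
    start_idx = 0
    end_idx = input_seq.shape[1]
    is_causal = True
    output_seq = input_seq.clone()
    while end_idx < max_output_seq_length:
        logits, *kv_caches = model(input_seq, is_causal, *kv_caches, start_idx, end_idx)
        next_tokens = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        output_seq = torch.cat([output_seq, next_tokens], dim=-1)
        input_seq = next_tokens
        is_causal = False  # generation steps attend over the cached prefix
        start_idx = end_idx
        end_idx = start_idx + 1
    return output_seq
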
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/test_static_cache.py	2025-05-28 23:49:40.727795+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/test_static_cache.py	2025-05-28 23:50:03.847453+00:00
@@ -17,61 +17,83 @@
ATOL = 1e-5
RTOL = 1e-5
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

+
class DynamicCacheModel(nn.Module):
    def __init__(self):
        super().__init__()
-        
+
    def forward(self, q, k, v, k1, v1, flag):
-        def true_fn(q, k, v, k1, v1):   
+        def true_fn(q, k, v, k1, v1):
            k_new = torch.cat((k, k1), dim=2)
            v_new = torch.cat((v, v1), dim=2)
            return torch._C._nn.scaled_dot_product_attention(q, k_new, v_new)

        def false_fn(q, k, v, k1, v1):
            return torch._C._nn.scaled_dot_product_attention(q, k, v)

        out = torch.cond(flag, true_fn, false_fn, (q, k, v, k1, v1))

        return 2 * out
-    
+
+
class ModelNoCache(nn.Module):
    def __init__(self):
        super().__init__()
-        
+
    def forward(self, q, k, v):
-        return torch._C._nn.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
+        return torch._C._nn.scaled_dot_product_attention(
+            q, k, v, dropout_p=0.0, is_causal=True
+        )
+

class StaticCacheModel(nn.Module):
    def __init__(self):
        super().__init__()
-        
-    # def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True): 
+
+    # def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True):
    #     new_key_cache = torch.cat((key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2)
    #     new_value_cache = torch.cat((value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2)
    #     out = torch._C._nn.scaled_dot_product_attention(q, new_key_cache[:, :, :end_idx, :], new_value_cache[:, :, :end_idx, :], dropout_p=0.0, is_causal=is_causal)
-        
+
    #     return out, new_key_cache, new_value_cache
-    
-    def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True): 
-        concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)  # key_cache[:, :, :6, :] + curr_keys + key_cache[:, : 7: ,: ]
+
+    def forward(
+        self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True
+    ):
+        concat_keys = torch.cat(
+            (key_cache[:, :, :start_idx, :], k), dim=2
+        )  # key_cache[:, :, :6, :] + curr_keys + key_cache[:, : 7: ,: ]
        concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)
        new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
-        new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2)
-        out = torch._C._nn.scaled_dot_product_attention(q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal)
-        
+        new_value_cache = torch.cat(
+            (concat_values, value_cache[:, :, end_idx:, :]), dim=2
+        )
+        out = torch._C._nn.scaled_dot_product_attention(
+            q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal
+        )
+
        return out, new_key_cache, new_value_cache


-def eager_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
-        is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
+def eager_sdpa(
+    query,
+    key,
+    value,
+    attn_mask=None,
+    dropout_p=0.0,
+    is_causal=False,
+    scale=None,
+    enable_gqa=False,
+) -> torch.Tensor:
    """
    Eager implementation of SDPA
    """
    import math
+
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    breakpoint()
    if is_causal:
@@ -85,24 +107,28 @@
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_mask + attn_bias

    if enable_gqa:
-        key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
-        value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
+        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

+
def print_diff(tensor1, tensor2, prefix=""):
    """
    Print the diff between two tensors
    """
-    print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}")
+    print(
+        f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}"
+    )
+

def test_static_cache_model(args):
    """
    Test the static cache model
    """
@@ -117,13 +143,15 @@

        # Test Prefill
        start_idx = 0
        end_idx = 2048
        out_no_cache = model_no_cache(q, k, v)
-        out_static_cache, new_key_cache, new_value_cache = model_static_cache(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True)
+        out_static_cache, new_key_cache, new_value_cache = model_static_cache(
+            q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True
+        )
        assert torch.allclose(out_no_cache, out_static_cache, atol=ATOL, rtol=RTOL)
-        
+
        # Test Generate
        for start_idx in range(2048, 2176):
            end_idx = start_idx + 1
            q_curr = torch.randn(1, 32, 1, 64).cuda()
            k_curr = torch.randn(1, 32, 1, 64).cuda()
@@ -133,17 +161,29 @@
            q_full = torch.cat((q, q_curr), dim=2)
            k_full = torch.cat((k, k_curr), dim=2)
            v_full = torch.cat((v, v_curr), dim=2)

            out_no_cache = model_no_cache(q_full, k_full, v_full)
-            out_static_cache, new_key_cache, new_value_cache = model_static_cache(q_curr, k_curr, v_curr, new_key_cache, new_value_cache, start_idx, end_idx, is_causal=False)
-
-            assert torch.allclose(out_no_cache[:, :, -1:, :], out_static_cache, atol=ATOL, rtol=RTOL)
-            q = q_full 
+            out_static_cache, new_key_cache, new_value_cache = model_static_cache(
+                q_curr,
+                k_curr,
+                v_curr,
+                new_key_cache,
+                new_value_cache,
+                start_idx,
+                end_idx,
+                is_causal=False,
+            )
+
+            assert torch.allclose(
+                out_no_cache[:, :, -1:, :], out_static_cache, atol=ATOL, rtol=RTOL
+            )
+            q = q_full
            k = k_full
            v = v_full
        print("============== test_static_cache passed ==============")
+

def transform_gm_with_kv_cache(exported_program: torch.export.ExportedProgram, args):
    """
    Transform the graph module by adding key and value cache to the graph
    """
@@ -155,52 +195,53 @@
        use_python_runtime=True,
        debug=args.debug,
        min_block_size=1,
    )
    exported_program = pre_export_lowering(exported_program, settings)
-    exported_program = exported_program.run_decompositions(
-        get_decompositions(False)
-    )
+    exported_program = exported_program.run_decompositions(get_decompositions(False))

    gm = exported_program.module()
    gm = post_lowering(gm, settings)

    return gm

+
def test_static_cache_lowering(args):
    """
-    Test static cache lowering pass applied to the model with no cache and run the graph module 
+    Test static cache lowering pass applied to the model with no cache and run the graph module
    and compare the output with the model with no cache
    """
    import static_cache2

    model_no_cache = ModelNoCache().eval().cuda()
    q = torch.randn(1, 32, 2, 64).cuda()
    k = torch.randn(1, 32, 2048, 64).cuda()
    v = torch.randn(1, 32, 2048, 64).cuda()
    key_cache = torch.zeros(1, 32, 2176, 64).cuda()
    value_cache = torch.zeros(1, 32, 2176, 64).cuda()
-    
+
    # Export the model
    q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176)
    kv_seq_len = torch.export.Dim("kv_seq_len", min=2, max=2176)
    exported_program = export(
        model_no_cache,
        args=(q, k, v),
-        dynamic_shapes=({2 : q_seq_len}, {2 : kv_seq_len}, {2 : kv_seq_len}),
-        strict=False
+        dynamic_shapes=({2: q_seq_len}, {2: kv_seq_len}, {2: kv_seq_len}),
+        strict=False,
    )

    gm = transform_gm_with_kv_cache(exported_program, args)

    # Test Prefill
    start_idx = 0
    end_idx = 2048
    is_causal = True
    q = torch.randn(1, 32, 2048, 64).cuda()
    out_no_cache = model_no_cache(q, k, v)
-    out_pyt_cache, key_cache, value_cache = gm(q, k, v, is_causal, key_cache, value_cache, start_idx, end_idx)
+    out_pyt_cache, key_cache, value_cache = gm(
+        q, k, v, is_causal, key_cache, value_cache, start_idx, end_idx
+    )
    assert torch.allclose(out_no_cache, out_pyt_cache, atol=ATOL, rtol=RTOL)

    # Test Generate
    for start_idx in range(2048, 2176):
        end_idx = start_idx + 1
@@ -209,20 +250,32 @@
        k_curr = torch.randn(1, 32, 1, 64).cuda()
        v_curr = torch.randn(1, 32, 1, 64).cuda()
        # Concatenate the current query, key, and value with the previous ones
        q_full = torch.cat((q, q_curr), dim=2)
        k_full = torch.cat((k, k_curr), dim=2)
-        v_full = torch.cat((v, v_curr), dim=2)   
-        
+        v_full = torch.cat((v, v_curr), dim=2)
+
        out_no_cache = model_no_cache(q_full, k_full, v_full)
-        out_pyt_static_cache, key_cache, value_cache = gm(q_curr, k_curr, v_curr, is_causal, key_cache, value_cache, start_idx, end_idx)
-        assert torch.allclose(out_no_cache[:, :, -1:, :], out_pyt_static_cache, atol=ATOL, rtol=RTOL)
-        q = q_full 
+        out_pyt_static_cache, key_cache, value_cache = gm(
+            q_curr,
+            k_curr,
+            v_curr,
+            is_causal,
+            key_cache,
+            value_cache,
+            start_idx,
+            end_idx,
+        )
+        assert torch.allclose(
+            out_no_cache[:, :, -1:, :], out_pyt_static_cache, atol=ATOL, rtol=RTOL
+        )
+        q = q_full
        k = k_full
        v = v_full
-    
+
    print("============== test_static_cache_lowering passed ==============")
+

def test_static_cache_export(args):
    """
    Test the static cache model export
    """
@@ -236,19 +289,28 @@
    start_idx = 0
    end_idx = 2048
    is_causal = True
    # Export the model
    seq_len = torch.export.Dim("seq_len", min=2, max=2048)
-    seq_len_dyn_dim = {2 : seq_len}
+    seq_len_dyn_dim = {2: seq_len}
    exported_program = export(
        model_static_cache,
        args=(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal),
-        dynamic_shapes=(seq_len_dyn_dim, seq_len_dyn_dim, seq_len_dyn_dim, None, None, torch.export.Dim.DYNAMIC, torch.export.Dim.DYNAMIC, None),
-        strict=False
-    )
-    
-        
+        dynamic_shapes=(
+            seq_len_dyn_dim,
+            seq_len_dyn_dim,
+            seq_len_dyn_dim,
+            None,
+            None,
+            torch.export.Dim.DYNAMIC,
+            torch.export.Dim.DYNAMIC,
+            None,
+        ),
+        strict=False,
+    )
+
+
def test_static_cache_with_torch_tensorrt(args):
    """
    Test the static cache model with torch_tensorrt
    """
    import static_cache2
@@ -257,83 +319,104 @@
    q = torch.randn(1, 32, 2, 64).cuda()
    k = torch.randn(1, 32, 2048, 64).cuda()
    v = torch.randn(1, 32, 2048, 64).cuda()
    key_cache = torch.zeros(1, 32, 2176, 64).cuda()
    value_cache = torch.zeros(1, 32, 2176, 64).cuda()
-    
+
    # Export the model
    q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176)
    kv_seq_len = torch.export.Dim("kv_seq_len", min=2, max=2176)
    exported_program = export(
        model_no_cache,
        args=(q, k, v),
-        dynamic_shapes=({2 : q_seq_len}, {2 : kv_seq_len}, {2 : kv_seq_len}),
-        strict=False
-    )
-    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
+        dynamic_shapes=({2: q_seq_len}, {2: kv_seq_len}, {2: kv_seq_len}),
+        strict=False,
+    )
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
        trt_model = torch_tensorrt.dynamo.compile(
            exported_program,
            inputs=[q, k, v],
            enabled_precisions={torch.float32},
            disable_tf32=True,
            use_python_runtime=True,
            debug=args.debug,
            min_block_size=1,
        )
-    
+
    start_idx = 0
    end_idx = 2048
    is_causal = True
    q = torch.randn(1, 32, 2048, 64).cuda()
    # out_eager = eager_sdpa(q, k, v, is_causal=is_causal)
    out_no_cache = model_no_cache(q, k, v)
-    out_trt, trt_key_cache, trt_value_cache = trt_model(q, k, v, is_causal, key_cache, value_cache, start_idx, end_idx)
+    out_trt, trt_key_cache, trt_value_cache = trt_model(
+        q, k, v, is_causal, key_cache, value_cache, start_idx, end_idx
+    )
    # breakpoint()
-    assert torch.allclose(out_no_cache, out_trt, atol=ATOL, rtol=RTOL), "Prefill TRT logits don't match"
-    assert torch.allclose(trt_key_cache[:, :, :end_idx, :], k, atol=ATOL, rtol=RTOL), "Prefill TRT key cache don't match"
-    assert torch.allclose(trt_value_cache[:, :, :end_idx, :], v, atol=ATOL, rtol=RTOL), "Prefill TRT value cache don't match"
-    
+    assert torch.allclose(
+        out_no_cache, out_trt, atol=ATOL, rtol=RTOL
+    ), "Prefill TRT logits don't match"
+    assert torch.allclose(
+        trt_key_cache[:, :, :end_idx, :], k, atol=ATOL, rtol=RTOL
+    ), "Prefill TRT key cache don't match"
+    assert torch.allclose(
+        trt_value_cache[:, :, :end_idx, :], v, atol=ATOL, rtol=RTOL
+    ), "Prefill TRT value cache don't match"
+
    # Test Generate
    for start_idx in range(2048, 2176):
        end_idx = start_idx + 1
        q_curr = torch.randn(1, 32, 1, 64).cuda()
        k_curr = torch.randn(1, 32, 1, 64).cuda()
-        v_curr = torch.randn(1, 32, 1, 64).cuda()   
+        v_curr = torch.randn(1, 32, 1, 64).cuda()
        # Concatenate the current query, key, and value with the previous ones
        q_full = torch.cat((q, q_curr), dim=2)
        k_full = torch.cat((k, k_curr), dim=2)
-        v_full = torch.cat((v, v_curr), dim=2)   
+        v_full = torch.cat((v, v_curr), dim=2)
        is_causal = False
        out_no_cache = model_no_cache(q_full, k_full, v_full)
-        out_trt, trt_key_cache, trt_value_cache = trt_model(q_curr, k_curr, v_curr, is_causal, trt_key_cache, trt_value_cache, start_idx, end_idx)
+        out_trt, trt_key_cache, trt_value_cache = trt_model(
+            q_curr,
+            k_curr,
+            v_curr,
+            is_causal,
+            trt_key_cache,
+            trt_value_cache,
+            start_idx,
+            end_idx,
+        )
        # breakpoint()
        # print_diff(out_no_cache[:, :, -1:, :], out_trt, f"out_no_cache[:, :, -1:, :] vs out_trt for idx {start_idx}")
        # print_diff(trt_key_cache[:, :, :end_idx, :], k_full, f"trt_key_cache[:, :, :end_idx, :] vs k_full for idx {start_idx}")
        # print_diff(trt_value_cache[:, :, :end_idx, :], v_full, f"trt_value_cache[:, :, :end_idx, :] vs v_full for idx {start_idx}")
-        assert torch.allclose(out_no_cache[:, :, -1:, :], out_trt, atol=ATOL, rtol=RTOL), f"Generate TRT logits don't match for idx {start_idx}"
-        assert torch.allclose(trt_key_cache[:, :, :end_idx, :], k_full, atol=ATOL, rtol=RTOL), f"Generate TRT key cache don't match for idx {start_idx}"
-        assert torch.allclose(trt_value_cache[:, :, :end_idx, :], v_full, atol=ATOL, rtol=RTOL), f"Generate TRT value cache don't match for idx {start_idx}"
-        q = q_full 
+        assert torch.allclose(
+            out_no_cache[:, :, -1:, :], out_trt, atol=ATOL, rtol=RTOL
+        ), f"Generate TRT logits don't match for idx {start_idx}"
+        assert torch.allclose(
+            trt_key_cache[:, :, :end_idx, :], k_full, atol=ATOL, rtol=RTOL
+        ), f"Generate TRT key cache don't match for idx {start_idx}"
+        assert torch.allclose(
+            trt_value_cache[:, :, :end_idx, :], v_full, atol=ATOL, rtol=RTOL
+        ), f"Generate TRT value cache don't match for idx {start_idx}"
+        q = q_full
        k = k_full
        v = v_full

    print("============== test_static_cache_with_torch_tensorrt passed ==============")
-    
+

def main():
    arg_parser = argparse.ArgumentParser(
        description="Run test cases for llama attention and decoder"
    )
    arg_parser.add_argument(
-        "--debug",
-        action="store_true",
-        help="Enable debug (default: False)"
+        "--debug", action="store_true", help="Enable debug (default: False)"
    )
    args = arg_parser.parse_args()
    with torch.inference_mode():
        # test_static_cache_model(args)
        # test_static_cache_lowering(args)
        test_static_cache_with_torch_tensorrt(args)
-    
+

if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2025-05-28 23:49:40.738795+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2025-05-28 23:50:04.807491+00:00
@@ -692,10 +692,11 @@

    gm = exported_program.module()
    exported_program.module().to("cpu")
    torch.cuda.empty_cache()
    import gc
+
    gc.collect()
    logger.debug("Input graph: " + str(gm.graph))

    # Apply lowering on the graph module
    gm = post_lowering(gm, settings)
@@ -790,31 +791,30 @@
        # TODO: For future, explore when nodes don't have metadata and if fake_tensor_prop can resolve this.
        logger.warning(
            "Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments."
        )

-    
    # Store the original input spec for later use
-    original_in_spec = getattr(gm, '_in_spec', None)
-    original_out_spec = getattr(gm, '_out_spec', None)
-    
+    original_in_spec = getattr(gm, "_in_spec", None)
+    original_out_spec = getattr(gm, "_out_spec", None)
+
    # Function to preserve and restore module specs
    def preserve_module_specs(in_spec, out_spec, target_module):
        """
        Applies input and output specs to the target module.
-        
+
        Args:
            in_spec: The input spec to apply
            out_spec: The output spec to apply
            target_module: The module to apply specs to
        """
        # Apply specs to target module
        if in_spec is not None:
            target_module._in_spec = in_spec
        if out_spec is not None:
            target_module._out_spec = out_spec
-            
+
        return target_module

    # Partition module into components that can be TRT-accelerated
    fast_partitioner_failed = False
    # If specified, try using the fast partitioner and fall back to the global one on failure
@@ -1197,11 +1197,11 @@
        "enable_weight_streaming": enable_weight_streaming,
        "tiling_optimization_level": tiling_optimization_level,
        "l2_limit_for_tiling": l2_limit_for_tiling,
        "offload_module_to_cpu": offload_module_to_cpu,
    }
-    
+
    settings = CompilationSettings(**compilation_options)
    logger.info("Compilation Settings: %s\n", settings)

    exported_program = pre_export_lowering(exported_program, settings)
    # Decompose the exported program
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py	2025-05-28 23:49:40.742795+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py	2025-05-28 23:50:05.774123+00:00
@@ -23,11 +23,15 @@
    """Replace specific versions of scaled_dot_product_attention with an equivalent
    implementation which can be easily converted to TRT
    """
    original_fns, replacement = scaled_dot_product_attention_replacement()
    replaced_nodes = []
-    sdpa_nodes = [node for node in gm.graph.nodes if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default]
+    sdpa_nodes = [
+        node
+        for node in gm.graph.nodes
+        if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default
+    ]
    breakpoint()
    # For each original function, search for it in the graph and replace
    for original in original_fns:
        replaced_nodes += torch.fx.subgraph_rewriter.replace_pattern_with_filters(
            gm,
@@ -167,6 +171,6 @@
    def replacement(
        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> torch.Tensor:
        return torch.nn.functional.scaled_dot_product_attention(query, key, value)

-    return (efficient, flash, efficient_scale, flash_scale), replacement
\ No newline at end of file
+    return (efficient, flash, efficient_scale, flash_scale), replacement
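
The replacement machinery used here is torch.fx.subgraph_rewriter: each SDPA variant is registered as a pattern and swapped for a canonical scaled_dot_product_attention call. A toy, self-contained example of the same mechanism (the pattern/replacement pair below is deliberately simplified and unrelated to SDPA):

import torch
import torch.fx as fx
from torch.fx import subgraph_rewriter

class Toy(torch.nn.Module):
    def forward(self, x, y):
        return torch.relu(x + y)

def pattern(x, y):
    return torch.relu(x + y)

def replacement(x, y):
    return torch.clamp(x + y, min=0.0)

gm = fx.symbolic_trace(Toy())
matches = subgraph_rewriter.replace_pattern(gm, pattern, replacement)
print(len(matches), "pattern(s) replaced")
print(gm.code)
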
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2025-05-28 23:49:40.743795+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2025-05-28 23:50:06.419303+00:00
@@ -742,14 +742,15 @@
        """
        # Representation of input shapes to a given model
        # Shapes are concatenated as so:
        # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
        tensor_inputs = [
-            t if isinstance(t, torch.Tensor) else torch.tensor(t)
-            for t in inputs
+            t if isinstance(t, torch.Tensor) else torch.tensor(t) for t in inputs
        ]
-        new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs)
+        new_shape_key = "".join(
+            str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs
+        )

        # If the new shape key differs from the existing one,
        # invalidate the old shape key and remove the CUDAGraph
        if new_shape_key != self.shape_key:
            logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}")
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonCUDAGraphModule.py	2025-05-28 23:49:40.743795+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonCUDAGraphModule.py	2025-05-28 23:50:06.446920+00:00
@@ -369,27 +369,35 @@
            is_shape_tensor_input = self.engine.is_shape_inference_io(input_name)
            if need_cudagraphs_record:
                # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
                # Clone is required to avoid re-using user-provided GPU memory
                if is_shape_tensor_input:
-                    self._input_buffers[inputs_shape_key][i] = contiguous_inputs[i].cpu().clone()
+                    self._input_buffers[inputs_shape_key][i] = (
+                        contiguous_inputs[i].cpu().clone()
+                    )
                else:
-                    self._input_buffers[inputs_shape_key][i] = contiguous_inputs[i].clone()
+                    self._input_buffers[inputs_shape_key][i] = contiguous_inputs[
+                        i
+                    ].clone()

            # For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers
            # as per TensorRT requirements
            if is_shape_tensor_input:
                # Shape tensor inputs are casted to int64 explicitly
                # Currently Torch CPU pointers are not working; numpy pointers are used instead
                # to refer to underlying memory
                inputs_cpu = contiguous_inputs[i].cpu().to(torch.int64)
-                inputs_cpu_numpy = contiguous_inputs[i].cpu().to(torch.int64).numpy().copy()
+                inputs_cpu_numpy = (
+                    contiguous_inputs[i].cpu().to(torch.int64).numpy().copy()
+                )
                # if cudagraphs_enabled:
                #     self._input_buffers[inputs_shape_key][i].copy_(inputs_cpu)
                #     self.context.set_tensor_address(input_name, self._input_buffers[inputs_shape_key][i].numpy().copy().ctypes.data)
                # else:
-                self.context.set_tensor_address(input_name, inputs_cpu_numpy.ctypes.data)
+                self.context.set_tensor_address(
+                    input_name, inputs_cpu_numpy.ctypes.data
+                )
            else:
                self.context.set_input_shape(
                    input_name, tuple(contiguous_inputs[i].shape)
                )
                if cudagraphs_enabled:
@@ -458,11 +466,14 @@
                assert len(contiguous_inputs) == len(
                    self.input_names
                ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}."

                self.setup_input_tensors(
-                    contiguous_inputs, self.cudagraphs_enabled, need_cudagraphs_record, inputs_shape_key
+                    contiguous_inputs,
+                    self.cudagraphs_enabled,
+                    need_cudagraphs_record,
+                    inputs_shape_key,
                )

                if shape_changed:
                    # Check if input shapes can be inferred.
                    uninferred_input_names = self.context.infer_shapes()
@@ -496,11 +507,12 @@
                    if need_cudagraphs_record:
                        self._output_buffers[inputs_shape_key][o] = outputs[o].clone()

                    if self.cudagraphs_enabled:
                        self.context.set_tensor_address(
-                            output_name, self._output_buffers[inputs_shape_key][o].data_ptr()
+                            output_name,
+                            self._output_buffers[inputs_shape_key][o].data_ptr(),
                        )
                    else:
                        self.context.set_tensor_address(
                            output_name, outputs[o].data_ptr()
                        )
@@ -522,30 +534,35 @@
                self._engine_stream.wait_stream(self._caller_stream)

                with torch.cuda.stream(self._engine_stream):
                    if self.cudagraphs_enabled:
                        if need_cudagraphs_record:
-                            
-                            self.shape_key_to_cudagraph[inputs_shape_key] = torch.cuda.CUDAGraph()
+
+                            self.shape_key_to_cudagraph[inputs_shape_key] = (
+                                torch.cuda.CUDAGraph()
+                            )

                            if self.profiling_enabled:
-                                self.shape_key_to_cudagraph[inputs_shape_key].enable_debug_mode()
+                                self.shape_key_to_cudagraph[
+                                    inputs_shape_key
+                                ].enable_debug_mode()

                            with torch.cuda.graph(
-                                self.shape_key_to_cudagraph[inputs_shape_key], stream=self._engine_stream
+                                self.shape_key_to_cudagraph[inputs_shape_key],
+                                stream=self._engine_stream,
                            ):
                                self.context.execute_async_v3(
                                    self._engine_stream.cuda_stream
                                )

                            if self.profiling_enabled:
                                import tempfile

                                with tempfile.TemporaryDirectory() as tmpdir:
-                                    self.shape_key_to_cudagraph[inputs_shape_key].debug_dump(
-                                        f"{tempdir}/{self.name}_cudagraph.dot"
-                                    )
+                                    self.shape_key_to_cudagraph[
+                                        inputs_shape_key
+                                    ].debug_dump(f"{tempdir}/{self.name}_cudagraph.dot")

                        self.shape_key_to_cudagraph[inputs_shape_key].replay()  # type: ignore

                    else:
                        self.context.execute_async_v3(self._engine_stream.cuda_stream)
@@ -754,18 +771,21 @@
        """
        # Representation of input shapes to a given model
        # Shapes are concatenated as so:
        # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
        tensor_inputs = [
-            t if isinstance(t, torch.Tensor) else torch.tensor(t)
-            for t in inputs
+            t if isinstance(t, torch.Tensor) else torch.tensor(t) for t in inputs
        ]
-        new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs)
+        new_shape_key = "".join(
+            str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs
+        )

        # If the new shape key differs from the existing one,
        # invalidate the old shape key and remove the CUDAGraph
        if new_shape_key not in self.shape_key_to_cudagraph:
-            logger.debug(f"The user provided input shape {new_shape_key} is not found in recorded CUDAGraph input shapes. A new CUDAGraph will be recorded with this input shape.")
+            logger.debug(
+                f"The user provided input shape {new_shape_key} is not found in recorded CUDAGraph input shapes. A new CUDAGraph will be recorded with this input shape."
+            )
            # self.shape_key = new_shape_key
            return True, new_shape_key

        return False, new_shape_key
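
The runtime keeps one captured CUDAGraph per shape key and records a new graph the first time an unseen key arrives. A stripped-down sketch of that record-or-replay flow (requires a CUDA device; in the real module the caller first copies inputs into the buffers captured during recording):

import torch

shape_key_to_cudagraph = {}

def record_or_replay(shape_key, launch_kernels):
    # launch_kernels() must only touch pre-allocated, persistent buffers so that
    # replay() re-executes the same kernel launches on the same memory.
    if shape_key not in shape_key_to_cudagraph:
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            launch_kernels()
        shape_key_to_cudagraph[shape_key] = graph
    shape_key_to_cudagraph[shape_key].replay()
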
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2025-05-28 23:49:40.739795+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2025-05-28 23:50:07.289014+00:00
@@ -1893,10 +1893,11 @@
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
    )
+

@dynamo_tensorrt_converter(operator.sub, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar, supports_dynamic_shapes=True)
def aten_ops_sub(



@github-actions github-actions bot left a comment


There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llama_benchmark.py	2025-05-31 02:00:39.719818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llama_benchmark.py	2025-05-31 02:01:02.496778+00:00
@@ -10,68 +10,72 @@
def main():
    # Initialize model and tokenizer
    print("Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_NAME,
-        torch_dtype=torch.float16,
-        use_cache=False,
-        device_map="auto"
+        MODEL_NAME, torch_dtype=torch.float16, use_cache=False, device_map="auto"
    )
    model.generation_config.cache_implementation = "static"
    model.forward = torch.compile(model.forward)
-    
+
    # Prepare input prompt
    word = "What"
    # Tokenize the word
-    word_ids = tokenizer(word, return_tensors="pt").input_ids[0]  # Get the first (and only) sequence
+    word_ids = tokenizer(word, return_tensors="pt").input_ids[
+        0
+    ]  # Get the first (and only) sequence
    # Repeat the token 2048 times
-    input_ids = word_ids.repeat(1024).unsqueeze(0).to(model.device)  # Add batch dimension and move to device
+    input_ids = (
+        word_ids.repeat(1024).unsqueeze(0).to(model.device)
+    )  # Add batch dimension and move to device
    print(f"Input tensor shape: {input_ids.shape}")

    # # Warm-up pass
    print("Running warm-up pass...")
    output_ids = model.generate(
        input_ids,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
-        use_cache=USE_CACHE
+        use_cache=USE_CACHE,
    )
-    
+
    # Benchmark loop
    print("Running benchmark...")
    num_iterations = 10
    total_time = 0
    timings = []
-    
+
    for i in range(num_iterations):
        start_time = timeit.default_timer()
        output_ids = model.generate(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
-            use_cache=USE_CACHE
+            use_cache=USE_CACHE,
        )
        end_time = timeit.default_timer()
        generation_time = end_time - start_time
        total_time += generation_time
        timings.append(generation_time)
-        
+
        # Decode and print first iteration output
        # if i == 0:
        #     output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        #     print("\nFirst generation output:")
        #     print(output_text)
-    
+
    # Calculate and print statistics
    average_time = total_time / num_iterations
    print(f"\nPerformance Statistics:")
-    print(f"Average generation time over {num_iterations} iterations: {average_time*1000:.2f} milliseconds")
+    print(
+        f"Average generation time over {num_iterations} iterations: {average_time*1000:.2f} milliseconds"
+    )
    print(f"Average tokens per second: {100/average_time:.2f}")
    print("\nIndividual timings (ms):")
    for i, t in enumerate(timings):
        print(f"Iteration {i+1}: {t*1000:.2f}")

+
if __name__ == "__main__":
-    main() 
\ No newline at end of file
+    main()
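
The benchmark's timing loop reduces to per-iteration wall-clock measurement plus a tokens-per-second figure derived from the number of new tokens produced per call. A compact sketch of that bookkeeping (assuming each call emits max_new_tokens tokens):

import timeit

def benchmark_generation(generate_once, iterations=10, max_new_tokens=100):
    timings = []
    for _ in range(iterations):
        start = timeit.default_timer()
        generate_once()
        timings.append(timeit.default_timer() - start)
    avg_s = sum(timings) / len(timings)
    return avg_s * 1000.0, max_new_tokens / avg_s  # (avg latency in ms, tokens/sec)
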
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/cache_utils.py	2025-05-31 02:00:39.719818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/cache_utils.py	2025-05-31 02:01:02.572391+00:00
@@ -5,81 +5,90 @@
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._pytree import _LEAF_SPEC
from torch._export.utils import _detect_fake_mode_from_gm

+
def get_kv_nodes(gm):
    """
    Get the key and value nodes from the graph.
    """
    kv_nodes = []
    for node in gm.graph.nodes:
-        if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention:
+        if (
+            node.op == "call_function"
+            and node.target == torch._C._nn.scaled_dot_product_attention
+        ):
            q_node, k_node, v_node = node.args[:3]
            kv_nodes.append((k_node, v_node))
    return kv_nodes

+
def get_random_tensor_from_node(node: Node) -> torch.Tensor:
-        """
-        Creates a random tensor based on the shape information in a node's metadata.
-        For symbolic dimensions, extracts the maximum value from the shape environment.
-        
-        Args:
-            node: A torch.fx.Node object with metadata containing tensor information
-            
-        Returns:
-            A random tensor with shape matching the node's metadata, or None if no valid
-            tensor information is found
-        """
-        if "val" not in node.meta:
-            raise ValueError(f"No tensor information found in node metadata for node: {node}")
-            
-        fake_tensor = node.meta["val"]
-        shape = []
-        
-        # Iterate through each dimension and handle symbolic dimensions
-        for dim in fake_tensor.shape:
-            if isinstance(dim, torch.SymInt):
-                # Extract the maximum value from the shape environment
-                max_val = dim.node.hint
-                shape.append(max_val)
-            else:
-                shape.append(dim)
-        
-        # Create a random tensor with the determined shape
-        dtype = fake_tensor.dtype
-        device = fake_tensor.device
-        random_tensor = torch.rand(shape, dtype=dtype, device=device)
+    """
+    Creates a random tensor based on the shape information in a node's metadata.
+    For symbolic dimensions, extracts the maximum value from the shape environment.

-        return random_tensor
+    Args:
+        node: A torch.fx.Node object with metadata containing tensor information
+
+    Returns:
+        A random tensor with shape matching the node's metadata, or None if no valid
+        tensor information is found
+    """
+    if "val" not in node.meta:
+        raise ValueError(
+            f"No tensor information found in node metadata for node: {node}"
+        )
+
+    fake_tensor = node.meta["val"]
+    shape = []
+
+    # Iterate through each dimension and handle symbolic dimensions
+    for dim in fake_tensor.shape:
+        if isinstance(dim, torch.SymInt):
+            # Extract the maximum value from the shape environment
+            max_val = dim.node.hint
+            shape.append(max_val)
+        else:
+            shape.append(dim)
+
+    # Create a random tensor with the determined shape
+    dtype = fake_tensor.dtype
+    device = fake_tensor.device
+    random_tensor = torch.rand(shape, dtype=dtype, device=device)
+
+    return random_tensor
+

def create_random_output_tensors(nodes: List[Node]) -> List[torch.Tensor]:
    """
    Creates random tensors based on the shape information in node metadata.
    For symbolic dimensions, extracts the maximum value from the shape environment.
-    
+
    Args:
        nodes: List of torch.fx.Node objects with metadata
-        
+
    Returns:
        List of random tensors with shapes matching the nodes' metadata
    """
    random_tensors = []
-    
+
    for node in nodes:
        if isinstance(node, Node):
            node_tensor = get_random_tensor_from_node(node)
        elif isinstance(node, tuple):
            node_tensor_list = []
            for n in node:
                random_tensor = get_random_tensor_from_node(n)
                node_tensor_list.append(random_tensor)
            node_tensor = tuple(node_tensor_list)
-               
+
        random_tensors.append(node_tensor)
-    
+
    return random_tensors
+

def add_graph_input(
    gm: GraphModule, name: str, val: Optional[torch.Tensor] = None, dynamic_shape=None
) -> Node:
    """Add a graph input to the given GraphModule and return the newly created node.
@@ -130,10 +139,11 @@
        in_node.meta["tensor_meta"] = _extract_tensor_metadata(fake_tensor)

    # return new node...
    return in_node

+
def is_op(node: Node, ops: Union[OpOverloadPacket, Iterable[OpOverloadPacket]]) -> bool:
    """Check if the node is a call to one of the ops."""
    if node.op != "call_function":
        return False
    # check if it's a single op that's provided
@@ -144,9 +154,10 @@
    if any(node.target == op for op in ops):
        return True

    return False

+
def get_all_input_output_nodes(graph: Graph) -> Tuple[List[Node], List[Node]]:
    input_nodes: List[Node] = graph.find_nodes(op="placeholder")
    output_nodes: List[Node] = graph.find_nodes(op="output")
-    return (input_nodes, output_nodes)
\ No newline at end of file
+    return (input_nodes, output_nodes)
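Aside (illustrative, not part of the linter diff): the symbolic-shape handling in cache_utils above reduces to reading the hint off each torch.SymInt before materializing a placeholder tensor. A minimal stand-alone sketch of that step; the helper name is chosen here for illustration only.

import torch

def concretize_shape(shape) -> list:
    # torch.SymInt dimensions carry a tracing-time hint that can stand in
    # for a concrete size when building a representative tensor.
    concrete = []
    for dim in shape:
        if isinstance(dim, torch.SymInt):
            concrete.append(int(dim.node.hint))
        else:
            concrete.append(int(dim))
    return concrete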
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/dynamic_cache.py	2025-05-31 02:00:39.719818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/dynamic_cache.py	2025-05-31 02:01:02.649617+00:00
@@ -12,36 +12,45 @@
from torch_tensorrt.dynamo.utils import extract_var_range_info
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)

-from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes, is_op
+from cache_utils import (
+    add_graph_input,
+    create_random_output_tensors,
+    get_kv_nodes,
+    is_op,
+)
import tensorrt
import torch.utils._pytree as pytree
+
logger = logging.getLogger(__name__)

-@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(torch.ops.higher_order.cond, enabled=True, supports_dynamic_shapes=True)
+
+@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(
+    torch.ops.higher_order.cond, enabled=True, supports_dynamic_shapes=True
+)
def cond_converter(
    ctx: torch_tensorrt.dynamo.conversion.ConversionContext,
    target: Target,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    name: str,
) -> Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]:
    """
    Converter for torch.ops.higher_order.cond operation to TensorRT.
-    
+
    This function handles the conversion of PyTorch's conditional operation to TensorRT.
    The conditional operation selects between two tensors based on a boolean predicate.
-    
+
    Args:
        ctx (torch_tensorrt.dynamo.conversion.ConversionCtx): The conversion context
        target (Target): The target operation to convert
        args (Tuple[Argument, ...]): The arguments to the operation
        kwargs (Dict[str, Argument]): The keyword arguments to the operation
        name (str): The name to give to the TensorRT layer
-        
+
    Returns:
        Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: The converted TensorRT tensor(s)
    """
    if_layer = ctx.net.add_if_conditional()
    condition, true_branch, false_branch = args[0], args[1], args[2]
@@ -49,30 +58,31 @@
    output_layer = if_layer.add_output(true_branch, false_branch)
    output = output_layer.get_output(0)

    return output

+
def add_kv_as_outputs(gm):
    """
    Modifies the graph to add query, key, and value tensors as outputs.
-    
+
    This function identifies all scaled dot-product attention (SDPA) operations
    in the graph, creates copies of their query, key, and value inputs, and adds
    these copies to the graph's outputs. This allows for accessing these tensors
    externally, which is useful for operations like key-value caching.
-    
+
    Args:
        graph: The torch.fx.Graph to modify
-        
+
    Returns:
        The updated tuple of graph outputs with the key/value nodes appended.
    """
    # list of MHA kernels we would want to detect and replace
    mha_ops = {
        torch._C._nn.scaled_dot_product_attention,
    }
-    
+
    # Find all SDPA nodes in the graph
    mha_nodes = []
    for node in gm.graph.nodes:
        if is_op(node, mha_ops):
            mha_nodes.append(node)
@@ -80,157 +90,170 @@
    # Iterate through each MHA node to extract shape information
    for mha_node in mha_nodes:
        if "val" in mha_node.meta and len(mha_node.args) >= 3:
            # Get the input nodes (query, key, value)
            q_node, k_node, v_node = mha_node.args[:3]
-            
+
            # Add the copy nodes as outputs to the graph
-            output_node = next(node for node in gm.graph.nodes if node.op == "output")            
+            output_node = next(node for node in gm.graph.nodes if node.op == "output")

            # Get the current output args (typically a tuple)
            current_outputs = output_node.args[0]
-            
+
            # If the current output is a tuple, extend it with our new outputs
            if isinstance(current_outputs, tuple):
                new_outputs = current_outputs + ((k_node, v_node),)
            else:
                # If there's only one output or it's not a tuple, create a new tuple
                new_outputs = (current_outputs, (k_node, v_node))
-            
+
            gm.graph.output(new_outputs)
            gm.graph.erase_node(output_node)
-        
+
    return new_outputs


-
-
def add_kv_and_indices_as_inputs(gm, fixed_kv: bool = True):
-        """
-        Add key-value tensors and index parameters as inputs to the graph.
-        
-        Args:
-            gm: The GraphModule to modify
-            fixed_kv: Boolean indicating whether to use static tensors for KV cache
-            
-        Returns:
-            A tuple containing:
-            - List of (k_input, v_input) node pairs for each SDPA operation
-            - start_idx input node for slicing operations
-            - end_idx input node for slicing operations
-        """
-
-        def get_static_tensor(tensor: torch.Tensor):
-            key_shape = []
-            for dim in tensor.shape:
-                if isinstance(dim, torch.SymInt):
-                    min_max_opt = extract_var_range_info(dim)
-                    key_shape.append(min_max_opt["max"])
-                else:
-                    key_shape.append(dim)
-            
-            static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device)
-            return static_tensor
-        
-        keys_values = get_kv_nodes(gm)
-
-        kv_inputs = []
-        for idx, key_value in enumerate(keys_values):
-            k_val = key_value[0].meta["val"]
-            v_val = key_value[1].meta["val"]
-            if fixed_kv:
-                k_val = get_static_tensor(k_val)
-                v_val = get_static_tensor(v_val)
-
-            # Add new inputs using add_graph_input
-            k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val)
-            v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val)
-            kv_inputs.append((k_input, v_input))
-
-        return kv_inputs
-
-
-def insert_torch_cond_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]]):
+    """
+    Add key-value tensors and index parameters as inputs to the graph.
+
+    Args:
+        gm: The GraphModule to modify
+        fixed_kv: Boolean indicating whether to use static tensors for KV cache
+
+    Returns:
+        A tuple containing:
+        - List of (k_input, v_input) node pairs for each SDPA operation
+        - start_idx input node for slicing operations
+        - end_idx input node for slicing operations
+    """
+
+    def get_static_tensor(tensor: torch.Tensor):
+        key_shape = []
+        for dim in tensor.shape:
+            if isinstance(dim, torch.SymInt):
+                min_max_opt = extract_var_range_info(dim)
+                key_shape.append(min_max_opt["max"])
+            else:
+                key_shape.append(dim)
+
+        static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device)
+        return static_tensor
+
+    keys_values = get_kv_nodes(gm)
+
+    kv_inputs = []
+    for idx, key_value in enumerate(keys_values):
+        k_val = key_value[0].meta["val"]
+        v_val = key_value[1].meta["val"]
+        if fixed_kv:
+            k_val = get_static_tensor(k_val)
+            v_val = get_static_tensor(v_val)
+
+        # Add new inputs using add_graph_input
+        k_input = add_graph_input(gm, key_value[0].name + "_k_input", k_val)
+        v_input = add_graph_input(gm, key_value[1].name + "_v_input", v_val)
+        kv_inputs.append((k_input, v_input))
+
+    return kv_inputs
+
+
+def insert_torch_cond_before_sdpa(
+    gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]]
+):
    """
    Insert a torch.cond operation before each scaled_dot_product_attention operation.
-    
+
    Args:
        gm: The FX GraphModule to modify
-        
+
    Returns:
        The modified GraphModule
    """
    # Find all nodes with scaled_dot_product_attention
    sdpa_nodes = []
    for node in gm.graph.nodes:
-        if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention:
+        if (
+            node.op == "call_function"
+            and node.target == torch._C._nn.scaled_dot_product_attention
+        ):
            sdpa_nodes.append(node)
-    
-    # Get the is_causal input node 
-    is_causal_node = next((node for node in gm.graph.nodes if node.op == "placeholder" and node.name == "is_causal"), None)
+
+    # Get the is_causal input node
+    is_causal_node = next(
+        (
+            node
+            for node in gm.graph.nodes
+            if node.op == "placeholder" and node.name == "is_causal"
+        ),
+        None,
+    )

    # For each SDPA node, insert a torch.cond operation before it
    for idx, sdpa_node in enumerate(sdpa_nodes):
- 
+
        with gm.graph.inserting_before(sdpa_node):
            # pred_node = add_graph_input(gm, "is_generate", torch.tensor(False, dtype=torch.bool))
            q_node, k_node, v_node = sdpa_node.args[:3]
            incoming_key, incoming_value = incoming_keys_values[idx]
            # Create nodes for concatenating k with incoming_key and v with incoming_value
            concatenated_k_node = gm.graph.create_node(
                "call_function",
                torch.ops.aten.cat.default,
-                args=([incoming_key, k_node], 2),  # Concatenate along sequence length dimension
-                kwargs={}
+                args=(
+                    [incoming_key, k_node],
+                    2,
+                ),  # Concatenate along sequence length dimension
+                kwargs={},
            )
            concatenated_v_node = gm.graph.create_node(
                "call_function",
                torch.ops.aten.cat.default,
-                args=([incoming_value, v_node], 2),  #  Concatenate along sequence length dimension
-                kwargs={}
-            )
-            
+                args=(
+                    [incoming_value, v_node],
+                    2,
+                ),  #  Concatenate along sequence length dimension
+                kwargs={},
+            )
+
            # Create the torch.cond node
            cond_k_node = gm.graph.create_node(
                "call_function",
                torch.ops.higher_order.cond,
                args=(is_causal_node, concatenated_k_node, k_node),
            )
- 
+
            cond_v_node = gm.graph.create_node(
                "call_function",
                torch.ops.higher_order.cond,
                args=(is_causal_node, concatenated_v_node, v_node),
            )

            sdpa_node.args = (q_node, cond_k_node, cond_v_node) + sdpa_node.args[3:]
-    
+
    return gm
-


@_aten_lowering_pass
def insert_dynamic_kv_cache(
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    """Insert FlashInfer MHA + KV cache ops in the graph"""
    """Perform insertion of kv-caches and attention kernel."""

    # Add static key and value as inputs to the graph
-    kv_inputs  = add_kv_and_indices_as_inputs(gm, fixed_kv=True)
+    kv_inputs = add_kv_and_indices_as_inputs(gm, fixed_kv=True)

    # Call the function to add KV as outputs
    logits_keys_values = add_kv_as_outputs(gm)

    # Insert torch.cond before each SDPA node, which toggles between the prefill and generate phases
    gm = insert_torch_cond_before_sdpa(gm, kv_inputs)

    gm = clean_up_graph_after_modifications(gm)
-    
+
    new_output_tensors = create_random_output_tensors(logits_keys_values)
    new_out_spec = pytree.tree_flatten(new_output_tensors)[1]
    gm._out_spec = new_out_spec
-    
+
    logger.debug("After inserting KV cache into the graph: " + str(gm.graph))
    return gm
-
-
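Aside (illustrative sketch, not code from this PR): in eager terms, the torch.cond nodes inserted before each SDPA call implement the following prefill/generate toggle, assuming caches shaped batch x heads x seq x head_dim.

import torch
import torch.nn.functional as F

def sdpa_with_dynamic_cache(q, k, v, key_cache, value_cache, is_generate: bool):
    if is_generate:
        # Generate phase: prepend the running cache to the new k/v along the seq dim.
        k = torch.cat([key_cache, k], dim=2)
        v = torch.cat([value_cache, v], dim=2)
    # Prefill phase: the fresh k/v are used as-is.
    out = F.scaled_dot_product_attention(q, k, v)
    # k/v are also surfaced so the caller can carry them forward as the new cache,
    # mirroring what add_kv_as_outputs does at the graph level.
    return out, k, v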
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/run_llm.py	2025-05-31 02:00:39.719818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/run_llm.py	2025-05-31 02:01:02.719018+00:00
@@ -17,37 +17,44 @@
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer
from contextlib import nullcontext
-from utils import export_llm, generate, recordStats, time_generate, generate_with_kv_cache
+from utils import (
+    export_llm,
+    generate,
+    recordStats,
+    time_generate,
+    generate_with_kv_cache,
+)
import sys
import os

# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
-sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from register_sdpa import *

DEVICE = torch.device("cuda:0")
+

def get_model(args):
    with torch.no_grad():
        # Supported list of models:
        # - meta-llama/Llama-3.2-1B-Instruct
        # - meta-llama/Llama-3.2-3B-Instruct
        # - meta-llama/Llama-3.1-8B-Instruct
        # - Qwen/Qwen2.5-1.5B-Instruct
        model = (
-                AutoModelForCausalLM.from_pretrained(
-                    args.model,
-                    use_cache=False,
-                    attn_implementation="sdpa",
-                    # num_hidden_layers=1
-                )
-                .eval()
-                .cuda()
-            )
+            AutoModelForCausalLM.from_pretrained(
+                args.model,
+                use_cache=False,
+                attn_implementation="sdpa",
+                # num_hidden_layers=1
+            )
+            .eval()
+            .cuda()
+        )
    if args.precision == "FP16":
        model = model.to(torch.float16)
    elif args.precision == "BF16":
        model = model.to(torch.bfloat16)
    else:
@@ -59,23 +66,23 @@
def compile_torchtrt(model, input_ids, args):
    max_seq_len = input_ids.shape[1] + args.num_tokens
    ep = export_llm(model, input_ids, max_seq_len=max_seq_len)

    # Set precision specific flags
-    use_fp32_acc = False 
+    use_fp32_acc = False
    use_explicit_typing = False
    if args.precision == "FP16":
        enabled_precisions = {torch.float32}
-        use_fp32_acc = True 
+        use_fp32_acc = True
        use_explicit_typing = True
    elif args.precision == "BF16":
        enabled_precisions = {torch.bfloat16}
        use_fp32_acc = False
    else:
        enabled_precisions = {torch.float32}

-    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
        trt_model = torch_tensorrt.dynamo.compile(
            ep,
            inputs=[input_ids],
            enabled_precisions=enabled_precisions,
            # truncate_double=True,
@@ -99,112 +106,125 @@
        tokenizer.decode(gen_tokens[0], skip_special_tokens=True),
    )
    print("===================================")


-
def measure_perf(trt_model, input_signature, backend_name):
    # Measure average time for 10 iterations
    import timeit
    import numpy as np
-    
+
    total_time = 0
    iterations = 10
-    
+
    print("Running warmup iteration...")
    # Warmup run
    _ = trt_model(*input_signature)
    torch.cuda.synchronize()
-    
+
    print(f"Measuring performance over {iterations} iterations...")
    for i in range(iterations):
        start_time = timeit.default_timer()
        _ = trt_model(*input_signature)
        torch.cuda.synchronize()
        end_time = timeit.default_timer()
        iter_time = end_time - start_time
        total_time += iter_time
        # print(f"Iteration {i+1}: {iter_time:.4f} seconds")
-    
+
    avg_time = total_time / iterations
-    print(f"Backend: {backend_name} Average time per iteration: {avg_time*1000:.4f} milliseconds")
-    print(f"Backend: {backend_name} Average throughput: {1.0/avg_time:.2f} iterations/second")
+    print(
+        f"Backend: {backend_name} Average time per iteration: {avg_time*1000:.4f} milliseconds"
+    )
+    print(
+        f"Backend: {backend_name} Average throughput: {1.0/avg_time:.2f} iterations/second"
+    )
+

if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(
        description="Run inference on a model with random input values"
    )
    arg_parser.add_argument(
-        "--model", type=str, default="meta-llama/Llama-3.2-1B-Instruct", help="Name of LLM model"
+        "--model",
+        type=str,
+        default="meta-llama/Llama-3.2-1B-Instruct",
+        help="Name of LLM model",
    )
    arg_parser.add_argument(
        "--tokenizer",
        type=str,
        default="",
        help="Name of LLM model tokenizer",
    )
    arg_parser.add_argument(
        "--prompt", type=str, default="What is parallel programming ?", help="Prompt"
    )
-    arg_parser.add_argument("--precision", type=str, default="FP16", help="Precision to use in the model. Options: FP16, BF16, FP32")
+    arg_parser.add_argument(
+        "--precision",
+        type=str,
+        default="FP16",
+        help="Precision to use in the model. Options: FP16, BF16, FP32",
+    )
    arg_parser.add_argument(
        "--iterations", type=int, default=5, help="no. of iterations to run"
    )
    arg_parser.add_argument(
        "--min_block_size", type=int, default=1, help="no. of iterations to run"
    )
    arg_parser.add_argument(
-        "--num_tokens", type=int, default=128, help="no. of output tokens to be generated"
+        "--num_tokens",
+        type=int,
+        default=128,
+        help="no. of output tokens to be generated",
    )
    arg_parser.add_argument(
        "--batch_size", type=int, default=1, help="Batch size used for benchmarking"
    )
    arg_parser.add_argument(
-        "--isl", type=int, default=2048, help="Input sequence length used for benchmarking"
-    )
-    arg_parser.add_argument(
-        "--enable_pytorch_run", 
-        action="store_true", 
-        help="Enable pytorch run (default: False)"
+        "--isl",
+        type=int,
+        default=2048,
+        help="Input sequence length used for benchmarking",
+    )
+    arg_parser.add_argument(
+        "--enable_pytorch_run",
+        action="store_true",
+        help="Enable pytorch run (default: False)",
    )
    arg_parser.add_argument(
        "--cache",
        type=str,
        default="",
        help="Type of KV cache to use. Options: static_v1, static_v2, dynamic",
    )
    arg_parser.add_argument(
-        "--cudagraph",
-        action="store_true",
-        help="Enable cudagraphs (default: False)"
-    )
-    arg_parser.add_argument(
-        "--debug",
-        action="store_true",
-        help="Enable debug (default: False)"
-    )
-    arg_parser.add_argument(
-        "--benchmark",
-        action="store_true",
-        help="Enable benchmark (default: False)"
-    )
-    
+        "--cudagraph", action="store_true", help="Enable cudagraphs (default: False)"
+    )
+    arg_parser.add_argument(
+        "--debug", action="store_true", help="Enable debug (default: False)"
+    )
+    arg_parser.add_argument(
+        "--benchmark", action="store_true", help="Enable benchmark (default: False)"
+    )
+
    args = arg_parser.parse_args()
    with torch.inference_mode():
        model = get_model(args)

        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model)

        # Prepare input for benchmarking or evaluation
        if args.benchmark:
-            input_ids = torch.randint(1, 10000, (args.batch_size, args.isl), dtype=torch.int64).to(model.device)
+            input_ids = torch.randint(
+                1, 10000, (args.batch_size, args.isl), dtype=torch.int64
+            ).to(model.device)
            position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
        else:
            model_inputs = tokenizer(args.prompt, return_tensors="pt")
            input_ids = model_inputs["input_ids"].to(DEVICE)
            position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
-        

        MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.num_tokens
        # Pyt
        pyt_gen_tokens = None
        pyt_timings = None
@@ -221,11 +241,15 @@
                    MAX_OUTPUT_SEQ_LENGTH,
                    tokenizer.eos_token_id,
                    iterations=args.iterations,
                )
                pyt_stats = recordStats(
-                    "PyTorch", pyt_timings, args.precision, batch_size=args.batch_size, compile_time_s=None
+                    "PyTorch",
+                    pyt_timings,
+                    args.precision,
+                    batch_size=args.batch_size,
+                    compile_time_s=None,
                )

        if args.cache == "static_v1":
            # This import is required to register static v1 KV cache transformations as lowering passes
            import static_cache_v1
@@ -233,22 +257,28 @@
            # This import is required to register static v2 KV cache transformations as lowering passes
            import static_cache_v2
        elif args.cache == "dynamic":
            import dynamic_cache

-
-        trt_model = compile_torchtrt(model, input_ids, args) 
-            
-        if args.cache == "static_v1" or args.cache == "static_v2" or args.cache == "dynamic":
+        trt_model = compile_torchtrt(model, input_ids, args)
+
+        if (
+            args.cache == "static_v1"
+            or args.cache == "static_v2"
+            or args.cache == "dynamic"
+        ):
            if args.cudagraph:
                # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases.
                # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model)
                torch_tensorrt.runtime.set_cudagraphs_mode(True)
-             
+
            trt_gen_tokens = generate_with_kv_cache(
-                trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id,
-                )
+                trt_model,
+                input_ids.clone(),
+                MAX_OUTPUT_SEQ_LENGTH,
+                tokenizer.eos_token_id,
+            )

            if args.benchmark:
                trt_timings = time_generate(
                    generate_with_kv_cache,
                    trt_model,
@@ -257,36 +287,44 @@
                    tokenizer.eos_token_id,
                    iterations=args.iterations,
                )
        else:
            trt_gen_tokens = generate(
-                trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id,
+                trt_model,
+                input_ids.clone(),
+                MAX_OUTPUT_SEQ_LENGTH,
+                tokenizer.eos_token_id,
            )
            if args.benchmark:
                trt_timings = time_generate(
                    generate,
                    trt_model,
                    input_ids.clone(),
                    MAX_OUTPUT_SEQ_LENGTH,
                    tokenizer.eos_token_id,
                    iterations=args.iterations,
                )
-        
+
        if args.benchmark:
            trt_stats = recordStats(
-                "TensorRT", trt_timings, args.precision, batch_size=args.batch_size, compile_time_s=None
-            )
-
-        
+                "TensorRT",
+                trt_timings,
+                args.precision,
+                batch_size=args.batch_size,
+                compile_time_s=None,
+            )
+
        if not args.benchmark:
-            if args.enable_pytorch_run: 
+            if args.enable_pytorch_run:
                print_outputs("PyTorch", pyt_gen_tokens, tokenizer)
-            
+
            print_outputs("TensorRT", trt_gen_tokens, tokenizer)

-            if args.enable_pytorch_run: 
-                print(f"PyTorch and TensorRT outputs match: {torch.equal(pyt_gen_tokens, trt_gen_tokens)}")
+            if args.enable_pytorch_run:
+                print(
+                    f"PyTorch and TensorRT outputs match: {torch.equal(pyt_gen_tokens, trt_gen_tokens)}"
+                )

        if args.benchmark:
            if args.enable_pytorch_run:
                print("=========PyTorch PERFORMANCE============ \n")
                print(pyt_stats)
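Aside (sketch only; the module names are as in this PR, the helper itself is hypothetical): selecting a cache variant in run_llm.py works purely by import side effect, since each cache module registers its lowering pass through the @_aten_lowering_pass decorator at import time.

def register_cache_pass(cache_kind: str) -> None:
    # Importing the module is enough; the decorator registers the pass.
    if cache_kind == "static_v1":
        import static_cache_v1  # noqa: F401
    elif cache_kind == "static_v2":
        import static_cache_v2  # noqa: F401
    elif cache_kind == "dynamic":
        import dynamic_cache  # noqa: F401
    elif cache_kind:
        raise ValueError(f"Unknown cache type: {cache_kind}")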
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/static_cache_v1.py	2025-05-31 02:00:39.719818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/static_cache_v1.py	2025-05-31 02:01:02.728059+00:00
@@ -12,55 +12,57 @@
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)
import torch.utils._pytree as pytree
from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes
+
logger = logging.getLogger(__name__)

SDPA_OP = torch._C._nn.scaled_dot_product_attention

+
def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]):
    """
    Modifies the graph to add query, key, and value tensors as outputs.
-    
+
    This function identifies all scaled dot-product attention (SDPA) operations
    in the graph, creates copies of their query, key, and value inputs, and adds
    these copies to the graph's outputs. This allows for accessing these tensors
    externally, which is useful for operations like key-value caching.
-    
+
    Args:
        graph: The torch.fx.Graph to modify
-        
+
    Returns:
        The updated tuple of graph outputs with the KV cache nodes appended.
    """
-    output_node = next(node for node in gm.graph.nodes if node.op == "output")            
+    output_node = next(node for node in gm.graph.nodes if node.op == "output")

    # Get the current output args (typically a tuple)
    current_outputs = output_node.args[0]
-    
+
    # If the current output is a tuple, extend it with our new outputs
    if isinstance(current_outputs, tuple):
        new_outputs = current_outputs + tuple(kv_cache_for_graph)
    else:
        # If there's only one output or it's not a tuple, create a new tuple
-        new_outputs = (current_outputs,) +  tuple(kv_cache_for_graph)
-            
+        new_outputs = (current_outputs,) + tuple(kv_cache_for_graph)
+
    gm.graph.output(new_outputs)
    gm.graph.erase_node(output_node)

    return new_outputs


def add_kv_cache_inputs(gm, fixed_kv: bool = True):
    """
    Add key-value tensors, index parameters as inputs to the graph.
-    
+
    Args:
        gm: The GraphModule to modify
        fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True.
-        
+
    Returns:
        A tuple containing:
        - List of (k_input, v_input) node pairs for each SDPA operation
        - start_idx input node for slicing operations
        - end_idx input node for slicing operations
@@ -72,14 +74,14 @@
            if isinstance(dim, torch.SymInt):
                min_max_opt = extract_var_range_info(dim)
                key_shape.append(min_max_opt["max"])
            else:
                key_shape.append(dim)
-        
+
        static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device)
        return static_tensor
-    
+
    keys_values = get_kv_nodes(gm)

    kv_inputs = []
    for idx, key_value in enumerate(keys_values):
        k_val = key_value[0].meta["val"]
@@ -87,12 +89,12 @@
        if fixed_kv:
            k_val = get_static_tensor(k_val)
            v_val = get_static_tensor(v_val)

        # Add new inputs using add_graph_input
-        k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val)
-        v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val)
+        k_input = add_graph_input(gm, key_value[0].name + "_k_input", k_val)
+        v_input = add_graph_input(gm, key_value[1].name + "_v_input", v_val)
        kv_inputs.append((k_input, v_input))

    # Add start_idx and end_idx as inputs
    start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0))
    end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1))
@@ -103,10 +105,11 @@
    seq_len = input_ids_meta.shape[1]
    min_max_opt = extract_var_range_info(seq_len)
    max_seq_len = min_max_opt["max"]

    from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
    shape_env = ShapeEnv()
    # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive
    start_idx_unbacked_symint = shape_env.create_unbacked_symint()
    torch._check(start_idx_unbacked_symint >= 0)
    torch._check(start_idx_unbacked_symint <= max_seq_len)
@@ -123,138 +126,152 @@
    is_causal_input.meta["val"] = torch.tensor(True)

    return kv_inputs, start_idx_input, end_idx_input, is_causal_input


-
-def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node, is_causal_input: Node):
+def insert_kv_slicing_before_sdpa(
+    gm,
+    incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]],
+    start_idx_input: Node,
+    end_idx_input: Node,
+    is_causal_input: Node,
+):
    """
    Insert slicing operations before each scaled_dot_product_attention operation.
    """
    # Find all nodes with scaled_dot_product_attention
    sdpa_nodes = []
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == SDPA_OP:
            sdpa_nodes.append(node)
    kv_cache_for_graph = []
    for idx, sdpa_node in enumerate(sdpa_nodes):
-        assert len(sdpa_node.args) == 6, f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments"
+        assert (
+            len(sdpa_node.args) == 6
+        ), f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments"
        q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_node.args
        incoming_key, incoming_value = incoming_keys_values[idx]
        kv_cache_for_sdpa_node = []
        new_keys_values = []
-        for key_or_value, current_key_or_value_node in zip([incoming_key, incoming_value], [k_node, v_node]):
+        for key_or_value, current_key_or_value_node in zip(
+            [incoming_key, incoming_value], [k_node, v_node]
+        ):
            # Create a slice node for key_cache[:,:,:start_idx,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim
            with gm.graph.inserting_before(sdpa_node):
                slice_1 = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.slice.Tensor,
                    args=(key_or_value,),
-                    kwargs={}
+                    kwargs={},
                )
                slice_2 = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.slice.Tensor,
                    args=(slice_1, 1),
-                    kwargs={}
+                    kwargs={},
                )
                slice_3 = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.slice.Tensor,
-                    args=(slice_2, 2, None, start_idx_input),  
-                    kwargs={}
+                    args=(slice_2, 2, None, start_idx_input),
+                    kwargs={},
                )
                slice_4 = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.slice.Tensor,
-                    args=(slice_3, 3), 
-                    kwargs={}
-                )
-                # =============================================== # 
+                    args=(slice_3, 3),
+                    kwargs={},
+                )
+                # =============================================== #
                # Create a slice node for key_cache[:,:, end_idx:,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim
                slice_5 = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.slice.Tensor,
                    args=(key_or_value,),
-                    kwargs={}
+                    kwargs={},
                )
                slice_6 = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.slice.Tensor,
                    args=(slice_5, 1),
-                    kwargs={}
+                    kwargs={},
                )
                slice_7 = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.slice.Tensor,
-                    args=(slice_6, 2, end_idx_input),  
-                    kwargs={}
+                    args=(slice_6, 2, end_idx_input),
+                    kwargs={},
                )
                slice_8 = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.slice.Tensor,
-                    args=(slice_7, 3), 
-                    kwargs={}
-                )
-                # =============================================== # 
+                    args=(slice_7, 3),
+                    kwargs={},
+                )
+                # =============================================== #
                # Concatenate the sliced tensors to build KV cache
                cat = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.cat.default,
-                    args=([slice_4, current_key_or_value_node, slice_8], 2), 
-                    kwargs={}
+                    args=([slice_4, current_key_or_value_node, slice_8], 2),
+                    kwargs={},
                )
                # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph
                cat.meta.update(key_or_value.meta)
                kv_cache_for_sdpa_node.append(cat)
-                # =============================================== # 
+                # =============================================== #
                # Get the current key and value by indexing the KV cache
                slice_9 = gm.graph.create_node(
-                    "call_function",
-                    torch.ops.aten.slice.Tensor,
-                    args=(cat,),
-                    kwargs={}
+                    "call_function", torch.ops.aten.slice.Tensor, args=(cat,), kwargs={}
                )
                slice_10 = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.slice.Tensor,
                    args=(slice_9, 1),
-                    kwargs={}
+                    kwargs={},
                )
                slice_11 = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.slice.Tensor,
-                    args=(slice_10, 2, None, end_idx_input),  
-                    kwargs={}
+                    args=(slice_10, 2, None, end_idx_input),
+                    kwargs={},
                )
                slice_12 = gm.graph.create_node(
                    "call_function",
                    torch.ops.aten.slice.Tensor,
-                    args=(slice_11, 3), 
-                    kwargs={}
+                    args=(slice_11, 3),
+                    kwargs={},
                )
                new_keys_values.append(slice_12)
-        
+
        kv_cache_for_graph.extend(kv_cache_for_sdpa_node)

-        sdpa_node.args = (q_node, new_keys_values[0], new_keys_values[1]) + (attn_mask, dropout_p, is_causal_input)
-    
+        sdpa_node.args = (q_node, new_keys_values[0], new_keys_values[1]) + (
+            attn_mask,
+            dropout_p,
+            is_causal_input,
+        )
+
    return gm, kv_cache_for_graph


@_aten_lowering_pass
def insert_kv_cache(
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    """Insert KV cache ops in the graph"""
    """Perform insertion of kv-caches and attention kernel."""
    # Add static key and value as inputs to the graph
-    kv_inputs, start_idx_input, end_idx_input, is_causal_input = add_kv_cache_inputs(gm, fixed_kv=True)
-
-    # Build and update the KV cache using computed KV inputs for current token and 
+    kv_inputs, start_idx_input, end_idx_input, is_causal_input = add_kv_cache_inputs(
+        gm, fixed_kv=True
+    )
+
+    # Build and update the KV cache using computed KV inputs for current token and
    # incoming keys and values from previous tokens (which were added as inputs)
-    gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input, is_causal_input)
+    gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(
+        gm, kv_inputs, start_idx_input, end_idx_input, is_causal_input
+    )

    # Call the function to add KV as outputs
    logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph)

    gm = clean_up_graph_after_modifications(gm)
@@ -264,7 +281,5 @@
    new_out_spec = pytree.tree_flatten(new_output_tensors)[1]
    gm._out_spec = new_out_spec
    logger.debug("After inserting KV cache into the graph: " + str(gm.graph))

    return gm
-
-
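Aside (illustrative eager-mode equivalent, not code from this PR): per key or value tensor, the slice/concat chain built by insert_kv_slicing_before_sdpa in static_cache_v1 amounts to the following, with key_cache assumed to be batch x heads x max_seq_len x head_dim.

import torch

def static_cache_v1_update(key_cache, k, start_idx: int, end_idx: int):
    # Rebuild the full cache: kept prefix + freshly computed k + untouched tail.
    new_key_cache = torch.cat(
        [key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]], dim=2
    )
    # Only the valid prefix [0, end_idx) is fed to scaled_dot_product_attention.
    k_for_sdpa = new_key_cache[:, :, :end_idx, :]
    return new_key_cache, k_for_sdpa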
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/static_cache_v2.py	2025-05-31 02:00:39.719818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/static_cache_v2.py	2025-05-31 02:01:02.768255+00:00
@@ -12,55 +12,57 @@
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)
import torch.utils._pytree as pytree
from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes
+
logger = logging.getLogger(__name__)

SDPA_OP = torch._C._nn.scaled_dot_product_attention

+
def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]):
    """
    Modifies the graph to add query, key, and value tensors as outputs.
-    
+
    This function identifies all scaled dot-product attention (SDPA) operations
    in the graph, creates copies of their query, key, and value inputs, and adds
    these copies to the graph's outputs. This allows for accessing these tensors
    externally, which is useful for operations like key-value caching.
-    
+
    Args:
        graph: The torch.fx.Graph to modify
-        
+
    Returns:
        The updated tuple of graph outputs with the KV cache nodes appended.
    """
-    output_node = next(node for node in gm.graph.nodes if node.op == "output")            
+    output_node = next(node for node in gm.graph.nodes if node.op == "output")

    # Get the current output args (typically a tuple)
    current_outputs = output_node.args[0]
-    
+
    # If the current output is a tuple, extend it with our new outputs
    if isinstance(current_outputs, tuple):
        new_outputs = current_outputs + tuple(kv_cache_for_graph)
    else:
        # If there's only one output or it's not a tuple, create a new tuple
-        new_outputs = (current_outputs,) +  tuple(kv_cache_for_graph)
-            
+        new_outputs = (current_outputs,) + tuple(kv_cache_for_graph)
+
    gm.graph.output(new_outputs)
    gm.graph.erase_node(output_node)

    return new_outputs


def add_kv_cache_inputs(gm, fixed_kv: bool = True):
    """
    Add key-value tensors, index parameters as inputs to the graph.
-    
+
    Args:
        gm: The GraphModule to modify
        fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True.
-        
+
    Returns:
        A tuple containing:
        - List of (k_input, v_input) node pairs for each SDPA operation
        - start_idx input node for slicing operations
        - end_idx input node for slicing operations
@@ -72,14 +74,14 @@
            if isinstance(dim, torch.SymInt):
                min_max_opt = extract_var_range_info(dim)
                key_shape.append(min_max_opt["max"])
            else:
                key_shape.append(dim)
-        
+
        static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device)
        return static_tensor
-    
+
    keys_values = get_kv_nodes(gm)

    kv_inputs = []
    for idx, key_value in enumerate(keys_values):
        k_val = key_value[0].meta["val"]
@@ -87,12 +89,12 @@
        if fixed_kv:
            k_val = get_static_tensor(k_val)
            v_val = get_static_tensor(v_val)

        # Add new inputs using add_graph_input
-        k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val)
-        v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val)
+        k_input = add_graph_input(gm, key_value[0].name + "_k_input", k_val)
+        v_input = add_graph_input(gm, key_value[1].name + "_v_input", v_val)
        kv_inputs.append((k_input, v_input))

    # Add start_idx and end_idx as inputs
    start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0))
    end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1))
@@ -100,18 +102,19 @@
    # Get the max sequence length from the first key_cache node. The order of input nodes is: input_ids, key_cache1, value_cache1, key_cache2, value_cache2, start_idx, end_idx
    input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
    # Get the third last input which should be the last value cache node and store the max_seq_len
    input_ids_meta = input_nodes[-3].meta["val"]
    seq_len = input_ids_meta.shape[2]
- 
+
    if isinstance(seq_len, torch.SymInt):
        min_max_opt = extract_var_range_info(seq_len)
        max_seq_len = min_max_opt["max"]
    else:
        max_seq_len = seq_len

    from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
    shape_env = ShapeEnv()
    # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive
    start_idx_unbacked_symint = shape_env.create_unbacked_symint()
    torch._check(start_idx_unbacked_symint >= 0)
    torch._check(start_idx_unbacked_symint <= max_seq_len)
@@ -127,14 +130,17 @@
    is_causal_input = add_graph_input(gm, "is_causal", True)
    is_causal_input.meta["val"] = torch.tensor(True)

    return kv_inputs, start_idx_input, end_idx_input, is_causal_input

-def create_kv_cache_update_nodes(gm, sdpa_node, current_kv_node, incoming_kv_node, start_idx_input, end_idx_input):
+
+def create_kv_cache_update_nodes(
+    gm, sdpa_node, current_kv_node, incoming_kv_node, start_idx_input, end_idx_input
+):
    """
    Create slicing and concatenation nodes for KV cache update.
-    
+
    This function creates the necessary slicing and concatenation nodes to update the KV cache
    during the generation process. It takes the SDPA node, the current KV cache node, and the
    incoming KV cache node as input.
    Returns:
        for a particular SDPA node, a tuple containing:
@@ -147,78 +153,73 @@
    with gm.graph.inserting_before(sdpa_node):
        slice_1 = gm.graph.create_node(
            "call_function",
            torch.ops.aten.slice.Tensor,
            args=(incoming_kv_node,),
-            kwargs={}
+            kwargs={},
        )
        slice_2 = gm.graph.create_node(
-            "call_function",
-            torch.ops.aten.slice.Tensor,
-            args=(slice_1, 1),
-            kwargs={}
+            "call_function", torch.ops.aten.slice.Tensor, args=(slice_1, 1), kwargs={}
        )
        slice_3 = gm.graph.create_node(
            "call_function",
            torch.ops.aten.slice.Tensor,
-            args=(slice_2, 2, None, start_idx_input),  
-            kwargs={}
+            args=(slice_2, 2, None, start_idx_input),
+            kwargs={},
        )
        slice_4 = gm.graph.create_node(
-            "call_function",
-            torch.ops.aten.slice.Tensor,
-            args=(slice_3, 3), 
-            kwargs={}
+            "call_function", torch.ops.aten.slice.Tensor, args=(slice_3, 3), kwargs={}
        )
        # Concat key_cache[:,:,:start_idx,:] with current key (k)
        concat_keys_or_values = gm.graph.create_node(
            "call_function",
            torch.ops.aten.cat.default,
-            args=([slice_4, current_kv_node], 2), 
-            kwargs={}
-        )
-
-        # =============================================== # 
+            args=([slice_4, current_kv_node], 2),
+            kwargs={},
+        )
+
+        # =============================================== #
        # Create nodes for key_cache[:,:, end_idx:,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim
        slice_5 = gm.graph.create_node(
            "call_function",
            torch.ops.aten.slice.Tensor,
            args=(incoming_kv_node,),
-            kwargs={}
+            kwargs={},
        )
        slice_6 = gm.graph.create_node(
-            "call_function",
-            torch.ops.aten.slice.Tensor,
-            args=(slice_5, 1),
-            kwargs={}
+            "call_function", torch.ops.aten.slice.Tensor, args=(slice_5, 1), kwargs={}
        )
        slice_7 = gm.graph.create_node(
            "call_function",
            torch.ops.aten.slice.Tensor,
-            args=(slice_6, 2, end_idx_input),  
-            kwargs={}
+            args=(slice_6, 2, end_idx_input),
+            kwargs={},
        )
        slice_8 = gm.graph.create_node(
-            "call_function",
-            torch.ops.aten.slice.Tensor,
-            args=(slice_7, 3), 
-            kwargs={}
-        )
-        # =============================================== # 
+            "call_function", torch.ops.aten.slice.Tensor, args=(slice_7, 3), kwargs={}
+        )
+        # =============================================== #
        # Concatenate the sliced tensors to build KV cache
        new_incoming_keys_or_values = gm.graph.create_node(
            "call_function",
            torch.ops.aten.cat.default,
-            args=([concat_keys_or_values, slice_8], 2), 
-            kwargs={}
+            args=([concat_keys_or_values, slice_8], 2),
+            kwargs={},
        )
        # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph
        new_incoming_keys_or_values.meta.update(incoming_kv_node.meta)

    return concat_keys_or_values, new_incoming_keys_or_values

-def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node, is_causal_input: Node):
+
+def insert_kv_slicing_before_sdpa(
+    gm,
+    incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]],
+    start_idx_input: Node,
+    end_idx_input: Node,
+    is_causal_input: Node,
+):
    """
    Insert slicing and concatenation operations before each scaled_dot_product_attention operation as per the following KV cache update logic:
    concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)
    concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)
    new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
@@ -230,24 +231,40 @@
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == SDPA_OP:
            sdpa_nodes.append(node)
    kv_cache_for_graph = []
    for idx, sdpa_node in enumerate(sdpa_nodes):
-        assert len(sdpa_node.args) == 6, f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments"
+        assert (
+            len(sdpa_node.args) == 6
+        ), f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments"
        q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_node.args
        incoming_key, incoming_value = incoming_keys_values[idx]
-        # For keys  
-        new_current_key_node, new_incoming_key_cache_node = create_kv_cache_update_nodes(gm, sdpa_node, k_node, incoming_key, start_idx_input, end_idx_input)
+        # For keys
+        new_current_key_node, new_incoming_key_cache_node = (
+            create_kv_cache_update_nodes(
+                gm, sdpa_node, k_node, incoming_key, start_idx_input, end_idx_input
+            )
+        )
        # For values
-        new_current_value_node, new_incoming_value_cache_node = create_kv_cache_update_nodes(gm, sdpa_node, v_node, incoming_value, start_idx_input, end_idx_input)
+        new_current_value_node, new_incoming_value_cache_node = (
+            create_kv_cache_update_nodes(
+                gm, sdpa_node, v_node, incoming_value, start_idx_input, end_idx_input
+            )
+        )

        # Store the KV cache nodes for the current SDPA node
-        kv_cache_for_graph.extend([new_incoming_key_cache_node, new_incoming_value_cache_node])
+        kv_cache_for_graph.extend(
+            [new_incoming_key_cache_node, new_incoming_value_cache_node]
+        )

        # Update the SDPA node arguments with current key and value nodes
-        sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + (attn_mask, dropout_p, is_causal_input)
-    
+        sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + (
+            attn_mask,
+            dropout_p,
+            is_causal_input,
+        )
+
    # kv_cache_for_graph.extend([k_node, v_node])
    return gm, kv_cache_for_graph


@_aten_lowering_pass
@@ -255,15 +272,19 @@
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    """Insert KV cache ops in the graph"""
    """Perform insertion of kv-caches and attention kernel."""
    # Add static key and value as inputs to the graph
-    kv_inputs, start_idx_input, end_idx_input, is_causal_input = add_kv_cache_inputs(gm, fixed_kv=True)
-
-    # Build and update the KV cache using computed KV inputs for current token and 
+    kv_inputs, start_idx_input, end_idx_input, is_causal_input = add_kv_cache_inputs(
+        gm, fixed_kv=True
+    )
+
+    # Build and update the KV cache using computed KV inputs for current token and
    # incoming keys and values from previous tokens (which were added as inputs)
-    gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input, is_causal_input)
+    gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(
+        gm, kv_inputs, start_idx_input, end_idx_input, is_causal_input
+    )

    # Call the function to add KV as outputs
    logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph)

    gm = clean_up_graph_after_modifications(gm)
@@ -273,7 +294,5 @@
    new_out_spec = pytree.tree_flatten(new_output_tensors)[1]
    gm._out_spec = new_out_spec
    logger.debug("After inserting KV cache into the graph: " + str(gm.graph))

    return gm
-
-
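Aside (illustrative sketch under the same shape assumptions as the v1 sketch above): static_cache_v2 differs from v1 in that the first concatenation is reused directly as the SDPA input instead of re-slicing the rebuilt cache.

import torch

def static_cache_v2_update(key_cache, k, start_idx: int, end_idx: int):
    # Step 1: valid cache prefix plus the freshly computed k; this is what SDPA consumes.
    k_for_sdpa = torch.cat([key_cache[:, :, :start_idx, :], k], dim=2)
    # Step 2: append the untouched tail to obtain the full updated cache.
    new_key_cache = torch.cat([k_for_sdpa, key_cache[:, :, end_idx:, :]], dim=2)
    return new_key_cache, k_for_sdpa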
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/test_qwen2.5_components.py	2025-05-31 02:00:39.719818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/test_qwen2.5_components.py	2025-05-31 02:01:02.885818+00:00
@@ -14,42 +14,50 @@
import argparse
import sys
import os

# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
-sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from register_sdpa import *

ATOL = 1e-5
RTOL = 1e-5


qwen2_5_model_name = "Qwen/Qwen2.5-1.5B-Instruct"
-qwen2_5_model = AutoModelForCausalLM.from_pretrained(
-                qwen2_5_model_name,
-                use_cache=False,
-                attn_implementation="sdpa",
-                num_hidden_layers=1,
-            ).eval().cuda()
+qwen2_5_model = (
+    AutoModelForCausalLM.from_pretrained(
+        qwen2_5_model_name,
+        use_cache=False,
+        attn_implementation="sdpa",
+        num_hidden_layers=1,
+    )
+    .eval()
+    .cuda()
+)
QWEN_CONFIG = qwen2_5_model.config
+

def print_diff(tensor1, tensor2, prefix=""):
    """
    Print the diff between two tensors
    """
-    print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}")
+    print(
+        f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}"
+    )
+

def test_qwen_apply_rotary_pos_emb(args):
    class QwenApplyRotaryPosEmb(nn.Module):
        def __init__(self):
            super().__init__()
-    
+
        def rotate_half(self, x):
            x1 = x[..., : x.shape[-1] // 2]
            x2 = x[..., x.shape[-1] // 2 :]
            return torch.cat((-x2, x1), dim=-1)
-        
+
        def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=1):
            cos = cos.unsqueeze(unsqueeze_dim)
            sin = sin.unsqueeze(unsqueeze_dim)
            q_embed = (q * cos) + (self.rotate_half(q) * sin)
            k_embed = (k * cos) + (self.rotate_half(k) * sin)
@@ -63,95 +71,104 @@
        DTYPE = torch.float16
    elif args.precision == "BF16":
        DTYPE = torch.bfloat16

    # Set precision specific flags
-    use_fp32_acc = False 
+    use_fp32_acc = False
    use_explicit_typing = False
    if args.precision == "FP16":
        enabled_precisions = {torch.float32}
-        use_fp32_acc = True 
+        use_fp32_acc = True
        use_explicit_typing = True
    elif args.precision == "BF16":
        enabled_precisions = {torch.bfloat16}
        use_fp32_acc = False
    else:
        enabled_precisions = {torch.float32}

    model = QwenApplyRotaryPosEmb().eval().cuda().to(DTYPE)
-    # Shapes for Qwen 2.5 
+    # Shapes for Qwen 2.5
    q = torch.randn((1, 12, 5, 128), dtype=DTYPE).cuda()
    k = torch.randn((1, 12, 5, 128), dtype=DTYPE).cuda()
    cos = torch.randn((1, 5, 128), dtype=DTYPE).cuda()
    sin = torch.randn((1, 5, 128), dtype=DTYPE).cuda()

    pyt_output = model(q, k, cos, sin)
-    
+
    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
    dynamic_shapes = ({2: seq_len}, {2: seq_len}, {1: seq_len}, {1: seq_len})
    ep = torch.export.export(model, (q, k, cos, sin), dynamic_shapes=dynamic_shapes)
-    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
-        trt_model = torch_tensorrt.dynamo.compile(ep, 
-                                                inputs=[q, k, cos, sin], 
-                                                enabled_precisions=enabled_precisions,
-                                                disable_tf32=True,
-                                                use_fp32_acc=use_fp32_acc,
-                                                use_explicit_typing=use_explicit_typing,
-                                                debug=args.debug)
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+        trt_model = torch_tensorrt.dynamo.compile(
+            ep,
+            inputs=[q, k, cos, sin],
+            enabled_precisions=enabled_precisions,
+            disable_tf32=True,
+            use_fp32_acc=use_fp32_acc,
+            use_explicit_typing=use_explicit_typing,
+            debug=args.debug,
+        )
    trt_output = trt_model(q, k, cos, sin)
-    
+
    if isinstance(pyt_output, tuple):
        print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt")
        # print_diff(pyt_output[1], trt_output[1], "Diff b/w pyt and trt")
        assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL)
    else:
        print_diff(pyt_output, trt_output, "Diff b/w pyt and trt")
        assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL)
-    
-    
+
+
def test_qwen_attention(args):
-    
+
    DTYPE = torch.float32
    if args.precision == "FP16":
        DTYPE = torch.float16
    elif args.precision == "BF16":
        DTYPE = torch.bfloat16
-    
+
    # Set precision specific flags
-    use_fp32_acc = False 
+    use_fp32_acc = False
    use_explicit_typing = False
    if args.precision == "FP16":
        enabled_precisions = {torch.float32}
-        use_fp32_acc = True 
+        use_fp32_acc = True
        use_explicit_typing = True
    elif args.precision == "BF16":
        enabled_precisions = {torch.bfloat16}
        use_fp32_acc = False
    else:
        enabled_precisions = {torch.float32}

    model = qwen2_5_model.model.layers[0].self_attn.to(DTYPE)
-    # qwen2.5     
+    # qwen2.5
    hidden_states = torch.randn((1, 5, 1536), dtype=DTYPE).cuda()
-    position_embeddings = (torch.randn((1, 5, 128), dtype=DTYPE).cuda(), torch.randn((1, 5, 128), dtype=DTYPE).cuda())
+    position_embeddings = (
+        torch.randn((1, 5, 128), dtype=DTYPE).cuda(),
+        torch.randn((1, 5, 128), dtype=DTYPE).cuda(),
+    )

    pyt_output = model(hidden_states, position_embeddings, None)
-    
+
    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None)
-    ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes)
+    ep = torch.export.export(
+        model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes
+    )

-    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
-        trt_model = torch_tensorrt.dynamo.compile(ep, 
-                                                inputs=[hidden_states, position_embeddings, None], 
-                                                enabled_precisions=enabled_precisions,
-                                                disable_tf32=True,
-                                                use_fp32_acc=use_fp32_acc,
-                                                use_explicit_typing=use_explicit_typing,
-                                                debug=args.debug)
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+        trt_model = torch_tensorrt.dynamo.compile(
+            ep,
+            inputs=[hidden_states, position_embeddings, None],
+            enabled_precisions=enabled_precisions,
+            disable_tf32=True,
+            use_fp32_acc=use_fp32_acc,
+            use_explicit_typing=use_explicit_typing,
+            debug=args.debug,
+        )
    trt_output = trt_model(hidden_states, position_embeddings, None)
-    
+
    if isinstance(pyt_output, tuple):
        print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt")
        assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL)
    else:
        print_diff(pyt_output, trt_output, "Diff b/w pyt and trt")
@@ -161,14 +178,17 @@
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(
        description="Run test cases for llama attention and decoder"
    )
    arg_parser.add_argument(
-        "--debug",
-        action="store_true",
-        help="Enable debug (default: False)"
+        "--debug", action="store_true", help="Enable debug (default: False)"
    )
-    arg_parser.add_argument("--precision", type=str, default="FP16", help="Precision to use in the model. Options: FP16, BF16, FP32")
+    arg_parser.add_argument(
+        "--precision",
+        type=str,
+        default="FP16",
+        help="Precision to use in the model. Options: FP16, BF16, FP32",
+    )
    args = arg_parser.parse_args()
    with torch.inference_mode():
        # test_qwen_apply_rotary_pos_emb(args)
        test_qwen_attention(args)
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/utils.py	2025-05-31 02:00:39.720818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/utils.py	2025-05-31 02:01:02.902987+00:00
@@ -2,13 +2,14 @@
from transformers import StoppingCriteriaList
from transformers.generation.stopping_criteria import (
    EosTokenCriteria,
    MaxLengthCriteria,
)
-import numpy as np 
-import copy 
+import numpy as np
+import copy
import timeit
+

def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
    """
    Exports the LLM model into an ExportedProgram with dynamic shapes.
    In the case of guard failures due to some PyTorch kernel implementations, we also
@@ -20,49 +21,60 @@
        position_ids = torch.arange(inputs.shape[1]).unsqueeze(0).to(inputs.device)
        try:
            print("Trying to export the model using torch.export.export()..")
            # strict=False only enables aotautograd tracing and excludes dynamo.
            ep = torch.export.export(
-                model, args=(inputs,), kwargs={"position_ids":position_ids}, dynamic_shapes=({1: seq_len}, {1: seq_len}), strict=False
+                model,
+                args=(inputs,),
+                kwargs={"position_ids": position_ids},
+                dynamic_shapes=({1: seq_len}, {1: seq_len}),
+                strict=False,
            )
        except:
            print(
                "Trying torch.export._trace._export to trace the graph since torch.export.export() failed"
            )
            # This API is used to express the constraint violation guards as asserts in the graph.
            ep = torch.export._trace._export(
                model,
                args=(inputs,),
-                kwargs={"position_ids":position_ids},
+                kwargs={"position_ids": position_ids},
                dynamic_shapes=({1: seq_len}, {1: seq_len}),
                strict=False,
                allow_complex_guards_as_runtime_asserts=True,
            )

    return ep

+
def get_zeroed_kv_cache_inputs(model: torch.fx.GraphModule):
    """
    Extracts and returns zeroed KV cache tensors from a torch.fx.GraphModule.
-    
+
    This function identifies placeholder nodes in the graph that represent KV cache tensors,
    and creates zeroed tensors with the same shape, dtype, and device as the original placeholders.
-    
+
    Args:
        model (torch.fx.GraphModule): The exported model graph containing KV cache placeholders
-        
+
    Returns:
        tuple: A tuple of zeroed tensors corresponding to the KV cache placeholders in the graph
    """
    # placeholder nodes are expected to be in the following order:
    # input_ids, kv_cache_key, kv_cache_value, start_idx, end_idx
    placeholder_nodes = [node for node in model.graph.nodes if node.op == "placeholder"]
    # The first two inputs are input_ids, position_ids. The last three inputs are start_idx, end_idx and is_causal. In between are the KV cache tensors.
    kv_cache_inputs = placeholder_nodes[2:-3]
    zeroed_kv_cache_inputs = []
    for input in kv_cache_inputs:
-        zeroed_kv_cache_inputs.append(torch.zeros(input.meta["val"].shape, dtype=input.meta["val"].dtype, device=torch.device("cuda:0")))
+        zeroed_kv_cache_inputs.append(
+            torch.zeros(
+                input.meta["val"].shape,
+                dtype=input.meta["val"].dtype,
+                device=torch.device("cuda:0"),
+            )
+        )

    return tuple(zeroed_kv_cache_inputs)


def generate(model, input_seq, max_output_seq_length, eos_token_id, benchmark=True):
@@ -75,11 +87,11 @@
            EosTokenCriteria(eos_token_id=eos_token_id),
        ]
    )
    isl = input_seq.shape[1]
    osl = max_output_seq_length - isl
-    
+
    num_tokens_generated = 0
    while num_tokens_generated < osl:
        position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda()
        outputs = model(input_seq, position_ids)
        logits = outputs.logits
@@ -91,10 +103,11 @@
        if not benchmark and stopping_criteria(input_seq, logits).item():
            break

    return input_seq

+
def generate_with_kv_cache(model, input_seq, max_output_seq_length, eos_token_id):
    """
    Greedy decoding of the model with KV cache.
    """
    start_idx = 0
@@ -105,38 +118,48 @@
    logits_concat = []
    num_tokens_generated = 0
    kv_cache = get_zeroed_kv_cache_inputs(model)
    while end_idx < max_output_seq_length:
        is_causal = True if input_seq.shape[1] > 1 else False
-        position_ids = torch.tensor([[start_idx]], dtype=torch.int64).cuda() if input_seq.shape[1] == 1 else position_ids
-        input_signature = (input_seq, position_ids, *kv_cache, start_idx, end_idx, is_causal)
+        position_ids = (
+            torch.tensor([[start_idx]], dtype=torch.int64).cuda()
+            if input_seq.shape[1] == 1
+            else position_ids
+        )
+        input_signature = (
+            input_seq,
+            position_ids,
+            *kv_cache,
+            start_idx,
+            end_idx,
+            is_causal,
+        )
        logits_keys_values = model(*input_signature)
        num_tokens_generated += 1
        logits = logits_keys_values[0]
        logits_concat.append(logits)
        kv_cache = logits_keys_values[1:]
        next_token_logits = logits[:, -1, :]
        next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        output_seq = torch.cat([output_seq, next_tokens], dim=-1)
        input_seq = next_tokens
        start_idx = end_idx
-        end_idx = start_idx + 1 
+        end_idx = start_idx + 1
    lkv = torch.cat(logits_concat, dim=1)
    return output_seq
+

def time_generate(
    generate_fn, model, inputs, output_seq_length, eos_token_id, iterations=10
):
    """
    Measure the time for generating a sentence over certain number of iterations
    """
    timings = []
    for _ in range(iterations):
        start_time = timeit.default_timer()
-        _ = generate_fn(
-            model, inputs, output_seq_length, eos_token_id
-        )
+        _ = generate_fn(model, inputs, output_seq_length, eos_token_id)
        torch.cuda.synchronize()
        end_time = timeit.default_timer()
        timings.append(end_time - start_time)

    return timings
@@ -164,6 +187,6 @@
        "Median-Latency(ms)": time_med * 1000,
        "Mean-Latency(ms)": time_mean * 1000,
        "Latency-StdDev(ms)": time_std * 1000,
        "Compile Time(s)": compile_time_s,
    }
-    return stats
\ No newline at end of file
+    return stats
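# ----------------------------------------------------------------------------
# Editorial sketch (not part of the CI diff above): a hypothetical end-to-end
# driver showing how the utils.py helpers reformatted above are meant to be
# chained. The model name, sequence lengths, and compile options are
# illustrative assumptions, not values taken from this PR; only export_llm(),
# generate(), and time_generate() come from the file diffed above.
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import export_llm, generate, time_generate  # examples/dynamo/llm/utils.py

def sketch_compile_and_generate(model_name="meta-llama/Llama-3.2-1B-Instruct"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = (
        AutoModelForCausalLM.from_pretrained(
            model_name, use_cache=False, attn_implementation="sdpa"
        )
        .eval()
        .cuda()
    )
    input_ids = tokenizer("What is parallel programming?", return_tensors="pt").input_ids.cuda()
    osl = input_ids.shape[1] + 64  # total output sequence length (assumption)

    # Export with a dynamic sequence dimension, then compile with Torch-TensorRT.
    ep = export_llm(model, input_ids, max_seq_len=osl)
    position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).cuda()
    trt_model = torch_tensorrt.dynamo.compile(
        ep, inputs=[input_ids, position_ids], enabled_precisions={torch.float32}
    )

    # Greedy decoding without a KV cache, then latency over a few iterations.
    output_seq = generate(trt_model, input_ids, osl, tokenizer.eos_token_id, benchmark=False)
    timings = time_generate(generate, trt_model, input_ids, osl, tokenizer.eos_token_id, iterations=3)
    return tokenizer.decode(output_seq[0]), timings
# ----------------------------------------------------------------------------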
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/test_llama_components.py	2025-05-31 02:00:39.719818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/test_llama_components.py	2025-05-31 02:01:03.012288+00:00
@@ -14,253 +14,359 @@
import argparse
import sys
import os

# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
-sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from register_sdpa import *
+
ATOL = 1e-5
RTOL = 1e-5


# llama2_model_name = "meta-llama/Llama-2-7b-hf"
llama3_model_name = "meta-llama/Llama-3.2-1B-Instruct"
-llama_model = AutoModelForCausalLM.from_pretrained(
-                llama3_model_name,
-                use_cache=False,
-                attn_implementation="sdpa",
-                num_hidden_layers=1,
-            ).eval().cuda()
+llama_model = (
+    AutoModelForCausalLM.from_pretrained(
+        llama3_model_name,
+        use_cache=False,
+        attn_implementation="sdpa",
+        num_hidden_layers=1,
+    )
+    .eval()
+    .cuda()
+)
LLAMA_CONFIG = llama_model.config
+

def test_llama_attention(args):
    class LlamaAttentionBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.config = LLAMA_CONFIG
-            self.attn = LlamaAttention(
-                config=self.config,
-                layer_idx=0
+            self.attn = LlamaAttention(config=self.config, layer_idx=0)
+
+        def forward(self, hidden_states, position_embeddings):
+            attn_output, attn_weights = self.attn(
+                hidden_states, position_embeddings, None
            )
-        def forward(self, hidden_states, position_embeddings):
-            attn_output, attn_weights = self.attn(hidden_states, position_embeddings, None)
            return attn_output
-    
+
    DTYPE = torch.float32
    if args.precision == "FP16":
        DTYPE = torch.float16
    elif args.precision == "BF16":
        DTYPE = torch.bfloat16
-    
+
    # Set precision specific flags
-    use_fp32_acc = False 
+    use_fp32_acc = False
    use_explicit_typing = False
    if args.precision == "FP16":
        enabled_precisions = {torch.float32}
-        use_fp32_acc = True 
+        use_fp32_acc = True
        use_explicit_typing = True
    elif args.precision == "BF16":
        enabled_precisions = {torch.bfloat16}
        use_fp32_acc = False
    else:
        enabled_precisions = {torch.float32}

    # model = LlamaAttentionBlock().eval().cuda().to(DTYPE)
    model = llama_model.model.layers[0].self_attn.to(DTYPE)
-    # llama3 
+    # llama3
    hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda()
-    position_embeddings = (torch.randn((1, 6, 64), dtype=DTYPE).cuda(), torch.randn((1, 6, 64), dtype=DTYPE).cuda())
+    position_embeddings = (
+        torch.randn((1, 6, 64), dtype=DTYPE).cuda(),
+        torch.randn((1, 6, 64), dtype=DTYPE).cuda(),
+    )

    pyt_output = model(hidden_states, position_embeddings, None)
-    
+
    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None)
-    ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes)
-    
-    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
-        trt_model = torch_tensorrt.dynamo.compile(ep, 
-                                                inputs=[hidden_states, position_embeddings, None], 
-                                                enabled_precisions=enabled_precisions,
-                                                disable_tf32=True,
-                                                use_fp32_acc=use_fp32_acc,
-                                                use_explicit_typing=use_explicit_typing,
-                                                debug=args.debug)
+    ep = torch.export.export(
+        model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes
+    )
+
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+        trt_model = torch_tensorrt.dynamo.compile(
+            ep,
+            inputs=[hidden_states, position_embeddings, None],
+            enabled_precisions=enabled_precisions,
+            disable_tf32=True,
+            use_fp32_acc=use_fp32_acc,
+            use_explicit_typing=use_explicit_typing,
+            debug=args.debug,
+        )
    trt_output = trt_model(hidden_states, position_embeddings, None)
    if isinstance(pyt_output, tuple):
-        print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}")
+        print(
+            f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}"
+        )
        breakpoint()
        assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL)
    else:
        print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output - trt_output))}")
        assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL)

+
def print_diff(tensor1, tensor2, prefix=""):
    """
    Print the diff between two tensors
    """
-    print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}")
+    print(
+        f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}"
+    )
+

def test_llama_attention_with_static_cache(args):
    class LlamaAttentionBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.config = LLAMA_CONFIG
-            self.attn = LlamaAttention(
-                config=self.config,
-                layer_idx=0
+            self.attn = LlamaAttention(config=self.config, layer_idx=0)
+
+        def forward(self, hidden_states, position_embeddings):
+            attn_output, attn_weights = self.attn(
+                hidden_states, position_embeddings, None
            )
-        def forward(self, hidden_states, position_embeddings):
-            attn_output, attn_weights = self.attn(hidden_states, position_embeddings, None)
            return attn_output
-    
+
    DTYPE = torch.float32
    model = llama_model.model.layers[0].self_attn.to(DTYPE)

-    # Inputs 
+    # Inputs
    ISL = 2048
    NUM_TOKENS = 128
    OSL = ISL + NUM_TOKENS
    hidden_states = torch.randn((1, ISL, 2048), dtype=DTYPE).cuda()
-    position_embeddings = (torch.randn((1, ISL, 64), dtype=DTYPE).cuda(), torch.randn((1, ISL, 64), dtype=DTYPE).cuda())
+    position_embeddings = (
+        torch.randn((1, ISL, 64), dtype=DTYPE).cuda(),
+        torch.randn((1, ISL, 64), dtype=DTYPE).cuda(),
+    )
    key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE)
    value_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE)
    start_idx = 0
    end_idx = ISL
    is_causal = True

    pyt_output = model(hidden_states, position_embeddings, None)
    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None)
-    ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes)
+    ep = torch.export.export(
+        model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes
+    )
    import register_sdpa
    import static_cache2
-    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
-        trt_model = torch_tensorrt.dynamo.compile(ep, 
-                                                inputs=[hidden_states, position_embeddings, None, key_cache, value_cache, start_idx, end_idx, is_causal], 
-                                                enabled_precisions={torch.float32},
-                                                disable_tf32=True,
-                                                debug=args.debug, 
-                                                # offload_module_to_cpu=True, 
-                                                use_python_runtime=True)
-    
+
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+        trt_model = torch_tensorrt.dynamo.compile(
+            ep,
+            inputs=[
+                hidden_states,
+                position_embeddings,
+                None,
+                key_cache,
+                value_cache,
+                start_idx,
+                end_idx,
+                is_causal,
+            ],
+            enabled_precisions={torch.float32},
+            disable_tf32=True,
+            debug=args.debug,
+            # offload_module_to_cpu=True,
+            use_python_runtime=True,
+        )
+
    # Test Prefill
-    trt_output, _, key_cache, value_cache = trt_model(hidden_states, position_embeddings, None, key_cache, value_cache, start_idx, end_idx, is_causal)
+    trt_output, _, key_cache, value_cache = trt_model(
+        hidden_states,
+        position_embeddings,
+        None,
+        key_cache,
+        value_cache,
+        start_idx,
+        end_idx,
+        is_causal,
+    )
    print_diff(pyt_output[0], trt_output[0], "pyt_output[0] vs trt_output[0] [Prefill]")

    # Test Generate
    for start_idx in range(2048, 2176):
        end_idx = start_idx + 1
        hidden_states_curr = torch.randn((1, 1, 2048), dtype=DTYPE).cuda()
-        position_embeddings_curr = (torch.randn((1, 1, 64), dtype=DTYPE).cuda(), torch.randn((1, 1, 64), dtype=DTYPE).cuda())
+        position_embeddings_curr = (
+            torch.randn((1, 1, 64), dtype=DTYPE).cuda(),
+            torch.randn((1, 1, 64), dtype=DTYPE).cuda(),
+        )
        # Concatenate the current hidden_states with the previous ones
        hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1)
-        position_embeddings_full = (torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1))
-        
+        position_embeddings_full = (
+            torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1),
+            torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1),
+        )
+
        is_causal = False
        out_no_cache, _ = model(hidden_states_full, position_embeddings_full, None)
-        out_trt, _, key_cache, value_cache = trt_model(hidden_states_curr, position_embeddings_curr, None, key_cache, value_cache, start_idx, end_idx, is_causal)
+        out_trt, _, key_cache, value_cache = trt_model(
+            hidden_states_curr,
+            position_embeddings_curr,
+            None,
+            key_cache,
+            value_cache,
+            start_idx,
+            end_idx,
+            is_causal,
+        )
        out_pyt = out_no_cache[:, -1:, :]
        print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}")

        hidden_states = hidden_states_full
        position_embeddings = position_embeddings_full


def test_llama_decoder(args):
-    
+
    DTYPE = torch.float32
    model = llama_model.model.layers[0].to(DTYPE)
-    # llama3 
+    # llama3
    hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda()
-    position_embeddings = (torch.randn((1, 6, 64), dtype=DTYPE).cuda(), torch.randn((1, 6, 64), dtype=DTYPE).cuda())
+    position_embeddings = (
+        torch.randn((1, 6, 64), dtype=DTYPE).cuda(),
+        torch.randn((1, 6, 64), dtype=DTYPE).cuda(),
+    )

    pyt_output = model(hidden_states, position_embeddings)
    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}))
-    ep = torch.export.export(model, (hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes)
-    
-    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
-        trt_model = torch_tensorrt.dynamo.compile(ep, 
-                                                inputs=[hidden_states, position_embeddings], 
-                                                enabled_precisions={torch.float32},
-                                                debug=args.debug)
+    ep = torch.export.export(
+        model, (hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes
+    )
+
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+        trt_model = torch_tensorrt.dynamo.compile(
+            ep,
+            inputs=[hidden_states, position_embeddings],
+            enabled_precisions={torch.float32},
+            debug=args.debug,
+        )
    trt_output = trt_model(hidden_states, position_embeddings)

    print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output - trt_output))}")
    assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL)
+

def test_llama_decoder_with_static_cache(args):

    class LlamaDecoderLayerBlock(nn.Module):
        def __init__(self, model):
            super().__init__()
            self.config = LLAMA_CONFIG
-            self.decoder = LlamaDecoderLayer(
-                config=self.config,
-                layer_idx=0)
+            self.decoder = LlamaDecoderLayer(config=self.config, layer_idx=0)
            self.model = model
+
        def forward(self, hidden_states, position_embeddings):
            return self.model(hidden_states, position_embeddings=position_embeddings)

    DTYPE = torch.float32
    model = LlamaDecoderLayerBlock(llama_model.model.layers[0].to(DTYPE))
-    
-    # Inputs 
+
+    # Inputs
    ISL = 2048
    NUM_TOKENS = 128
    OSL = ISL + NUM_TOKENS
    hidden_states = torch.randn((1, ISL, 2048), dtype=DTYPE).cuda()
-    position_embeddings = (torch.randn((1, ISL, 64), dtype=DTYPE).cuda(), torch.randn((1, ISL, 64), dtype=DTYPE).cuda())
+    position_embeddings = (
+        torch.randn((1, ISL, 64), dtype=DTYPE).cuda(),
+        torch.randn((1, ISL, 64), dtype=DTYPE).cuda(),
+    )
    key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE)
    value_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE)
    start_idx = 0
    end_idx = ISL
    is_causal = True

    pyt_output = model(hidden_states, position_embeddings)
    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}))
-    ep = torch.export.export(model, args=(hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes)
+    ep = torch.export.export(
+        model, args=(hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes
+    )
    import register_sdpa
    import static_cache2
-    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
-        trt_model = torch_tensorrt.dynamo.compile(ep, 
-                                                arg_inputs=[hidden_states, position_embeddings, key_cache, value_cache, start_idx, end_idx, is_causal], 
-                                                enabled_precisions={torch.float32},
-                                                disable_tf32=True,
-                                                debug=args.debug, 
-                                                # offload_module_to_cpu=True, 
-                                                use_python_runtime=True)
-    
+
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+        trt_model = torch_tensorrt.dynamo.compile(
+            ep,
+            arg_inputs=[
+                hidden_states,
+                position_embeddings,
+                key_cache,
+                value_cache,
+                start_idx,
+                end_idx,
+                is_causal,
+            ],
+            enabled_precisions={torch.float32},
+            disable_tf32=True,
+            debug=args.debug,
+            # offload_module_to_cpu=True,
+            use_python_runtime=True,
+        )
+
    # Test Prefill
-    trt_output, key_cache, value_cache = trt_model(hidden_states, position_embeddings, key_cache, value_cache, start_idx, end_idx, is_causal)
+    trt_output, key_cache, value_cache = trt_model(
+        hidden_states,
+        position_embeddings,
+        key_cache,
+        value_cache,
+        start_idx,
+        end_idx,
+        is_causal,
+    )
    print_diff(pyt_output[0], trt_output, "pyt_output vs trt_output [Prefill]")

    # Test Generate
    for start_idx in range(2048, 2176):
        end_idx = start_idx + 1
        hidden_states_curr = torch.randn((1, 1, 2048), dtype=DTYPE).cuda()
-        position_embeddings_curr = (torch.randn((1, 1, 64), dtype=DTYPE).cuda(), torch.randn((1, 1, 64), dtype=DTYPE).cuda())
+        position_embeddings_curr = (
+            torch.randn((1, 1, 64), dtype=DTYPE).cuda(),
+            torch.randn((1, 1, 64), dtype=DTYPE).cuda(),
+        )
        # Concatenate the current hidden_states with the previous ones
        hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1)
-        position_embeddings_full = (torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1))
-        
+        position_embeddings_full = (
+            torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1),
+            torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1),
+        )
+
        is_causal = False
        out_no_cache = model(hidden_states_full, position_embeddings_full)

-        out_trt, key_cache, value_cache = trt_model(hidden_states_curr, position_embeddings_curr, key_cache, value_cache, start_idx, end_idx, is_causal)
+        out_trt, key_cache, value_cache = trt_model(
+            hidden_states_curr,
+            position_embeddings_curr,
+            key_cache,
+            value_cache,
+            start_idx,
+            end_idx,
+            is_causal,
+        )
        out_pyt = out_no_cache[0][:, -1:, :]
        print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}")
        hidden_states = hidden_states_full
        position_embeddings = position_embeddings_full

+
def test_llama_model_with_static_cache(args):

    DTYPE = torch.float32
    model = llama_model.model.to(DTYPE)

-    # Inputs 
+    # Inputs
    ISL = 2048
    NUM_TOKENS = 128
    OSL = ISL + NUM_TOKENS
    input_ids = torch.randint(1, 20, (1, ISL), dtype=torch.int64).cuda()
    position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).cuda()
@@ -271,66 +377,77 @@
    is_causal = True

    pyt_output = model(input_ids)
    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
    dynamic_shapes = ({1: seq_len}, {1: seq_len})
-    kwarg_inputs = {"input_ids":input_ids, "position_ids":position_ids}
-    ep = torch.export.export(model, args=(), kwargs=kwarg_inputs, dynamic_shapes=dynamic_shapes)
+    kwarg_inputs = {"input_ids": input_ids, "position_ids": position_ids}
+    ep = torch.export.export(
+        model, args=(), kwargs=kwarg_inputs, dynamic_shapes=dynamic_shapes
+    )

    import register_sdpa
    import static_cache2
-    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
-        trt_model = torch_tensorrt.dynamo.compile(ep, 
-                                                arg_inputs=[], 
-                                                kwarg_inputs=kwarg_inputs,
-                                                enabled_precisions={torch.float32},
-                                                disable_tf32=True,
-                                                debug=args.debug, 
-                                                # offload_module_to_cpu=True, 
-                                                use_python_runtime=True)
-    
+
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+        trt_model = torch_tensorrt.dynamo.compile(
+            ep,
+            arg_inputs=[],
+            kwarg_inputs=kwarg_inputs,
+            enabled_precisions={torch.float32},
+            disable_tf32=True,
+            debug=args.debug,
+            # offload_module_to_cpu=True,
+            use_python_runtime=True,
+        )
+
    # Test Prefill
-    trt_output, key_cache, value_cache = trt_model(input_ids, position_ids, key_cache, value_cache, start_idx, end_idx, is_causal)
+    trt_output, key_cache, value_cache = trt_model(
+        input_ids, position_ids, key_cache, value_cache, start_idx, end_idx, is_causal
+    )
    pyt_output = pyt_output.last_hidden_state
    print_diff(pyt_output, trt_output, "pyt_output vs trt_output [Prefill]")

    # Test Generate
    for start_idx in range(2048, 2176):
        end_idx = start_idx + 1
        input_ids_curr = torch.randint(1, 20, (1, 1), dtype=torch.int64).cuda()
        position_ids_curr = torch.tensor([[start_idx]], dtype=torch.int64).cuda()
-        
+
        # Concatenate the current input_ids and position_ids with the previous ones
        input_ids_full = torch.cat((input_ids, input_ids_curr), dim=1)
        position_ids_full = torch.cat((position_ids, position_ids_curr), dim=1)
        is_causal = False
-        kwarg_inputs = {"input_ids":input_ids_full, "position_ids":position_ids_full}
+        kwarg_inputs = {"input_ids": input_ids_full, "position_ids": position_ids_full}
        out_no_cache = model(**kwarg_inputs)

-        out_trt, key_cache, value_cache = trt_model(input_ids_curr, position_ids_curr, key_cache, value_cache, start_idx, end_idx, is_causal)
+        out_trt, key_cache, value_cache = trt_model(
+            input_ids_curr,
+            position_ids_curr,
+            key_cache,
+            value_cache,
+            start_idx,
+            end_idx,
+            is_causal,
+        )
        out_pyt = out_no_cache.last_hidden_state[:, -1:, :]
        print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}")
        input_ids = input_ids_full
        position_ids = position_ids_full

+
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(
        description="Run test cases for llama attention and decoder"
    )
    arg_parser.add_argument(
-        "--debug",
-        action="store_true",
-        help="Enable debug (default: False)"
+        "--debug", action="store_true", help="Enable debug (default: False)"
    )
    arg_parser.add_argument(
-        "--precision",
-        type=str,
-        default="FP16",
-        help="Precision (default: FP16)"
+        "--precision", type=str, default="FP16", help="Precision (default: FP16)"
    )
    args = arg_parser.parse_args()
    with torch.inference_mode():
        test_llama_attention(args)
        # test_llama_decoder(args)
        # test_llama_attention_with_static_cache(args)
        # test_llama_decoder_with_static_cache(args)
-        # test_llama_model_with_static_cache(args)
\ No newline at end of file
+        # test_llama_model_with_static_cache(args)
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/register_sdpa.py	2025-05-31 02:00:39.720818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/register_sdpa.py	2025-05-31 02:01:03.019858+00:00
@@ -72,13 +72,11 @@
                if len(node.args) == 6:
                    query, key, value, dropout_p, is_causal, return_debug_mask = (
                        node.args
                    )
                if len(node.args) == 5:
-                    query, key, value, dropout_p, is_causal = (
-                        node.args
-                    )
+                    query, key, value, dropout_p, is_causal = node.args
                elif len(node.args) == 3:
                    query, key, value = node.args
                    dropout_p = 0.0
                    is_causal = True
                else:
@@ -95,11 +93,14 @@
            # The input args are (query, key, value, is_causal). kwargs has scale
            with gm.graph.inserting_after(node):
                new_node = gm.graph.call_function(
                    torch.nn.functional.scaled_dot_product_attention,
                    args=modified_input_args,
-                    kwargs={"scale": node.kwargs.get("scale", None), "use_fp32_acc": settings.use_fp32_acc},
+                    kwargs={
+                        "scale": node.kwargs.get("scale", None),
+                        "use_fp32_acc": settings.use_fp32_acc,
+                    },
                )

                # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead.
                new_node.meta = copy.copy(node.meta)
                # Check if there's a getitem node following this attention node
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/sdpa_converter.py	2025-05-31 02:00:39.720818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/sdpa_converter.py	2025-05-31 02:01:03.072635+00:00
@@ -74,12 +74,12 @@
    #     query = cast_trt_tensor(
    #             ctx, query, trt.float32, name + "_query_cast_to_fp32", target, source_ir
    #         )
    #     key = cast_trt_tensor(
    #             ctx, key, trt.float32, name + "_key_cast_to_fp32", target, source_ir
-            # )
-    
+    # )
+
    if scale is None:
        scale = query.shape[-1]
        if scale < 0:
            # dynamic shape
            scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1)
@@ -112,11 +112,11 @@
        name + "_mm",
        query,
        key,
        other_matrix_op=trt.MatrixOperation.TRANSPOSE,
    )
-    
+
    # if use_fp32_acc:
    #     mm = cast_trt_tensor(
    #             ctx, mm, query_dtype, name + "_mm_cast_to_fp16", target, source_ir
    #         )

@@ -162,11 +162,11 @@
        scaled_add_attn_bias = impl.elementwise.add(
            ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias
        )
    else:
        scaled_add_attn_bias = mm
-    
+
    # Create a if condition to check if is_causal is True
    if isinstance(is_causal, TRTTensor):
        if_layer = ctx.net.add_if_conditional()
        condition, true_branch, false_branch = is_causal, scaled_add_attn_bias, mm
        if_layer.set_condition(condition)
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/test_static_cache.py	2025-05-31 02:00:39.719818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/test_static_cache.py	2025-05-31 02:01:03.107457+00:00
@@ -17,61 +17,83 @@
ATOL = 1e-5
RTOL = 1e-5
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

+
class DynamicCacheModel(nn.Module):
    def __init__(self):
        super().__init__()
-        
+
    def forward(self, q, k, v, k1, v1, flag):
-        def true_fn(q, k, v, k1, v1):   
+        def true_fn(q, k, v, k1, v1):
            k_new = torch.cat((k, k1), dim=2)
            v_new = torch.cat((v, v1), dim=2)
            return torch._C._nn.scaled_dot_product_attention(q, k_new, v_new)

        def false_fn(q, k, v, k1, v1):
            return torch._C._nn.scaled_dot_product_attention(q, k, v)

        out = torch.cond(flag, true_fn, false_fn, (q, k, v, k1, v1))

        return 2 * out
-    
+
+
class ModelNoCache(nn.Module):
    def __init__(self):
        super().__init__()
-        
+
    def forward(self, q, k, v):
-        return torch._C._nn.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
+        return torch._C._nn.scaled_dot_product_attention(
+            q, k, v, dropout_p=0.0, is_causal=True
+        )
+

class StaticCacheModel(nn.Module):
    def __init__(self):
        super().__init__()
-        
-    # def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True): 
+
+    # def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True):
    #     new_key_cache = torch.cat((key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2)
    #     new_value_cache = torch.cat((value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2)
    #     out = torch._C._nn.scaled_dot_product_attention(q, new_key_cache[:, :, :end_idx, :], new_value_cache[:, :, :end_idx, :], dropout_p=0.0, is_causal=is_causal)
-        
+
    #     return out, new_key_cache, new_value_cache
-    
-    def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True): 
-        concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)  # key_cache[:, :, :6, :] + curr_keys + key_cache[:, : 7: ,: ]
+
+    def forward(
+        self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True
+    ):
+        concat_keys = torch.cat(
+            (key_cache[:, :, :start_idx, :], k), dim=2
+        )  # key_cache[:, :, :6, :] + curr_keys + key_cache[:, : 7: ,: ]
        concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)
        new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
-        new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2)
-        out = torch._C._nn.scaled_dot_product_attention(q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal)
-        
+        new_value_cache = torch.cat(
+            (concat_values, value_cache[:, :, end_idx:, :]), dim=2
+        )
+        out = torch._C._nn.scaled_dot_product_attention(
+            q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal
+        )
+
        return out, new_key_cache, new_value_cache


-def eager_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
-        is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
+def eager_sdpa(
+    query,
+    key,
+    value,
+    attn_mask=None,
+    dropout_p=0.0,
+    is_causal=False,
+    scale=None,
+    enable_gqa=False,
+) -> torch.Tensor:
    """
    Eager implementation of SDPA
    """
    import math
+
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    breakpoint()
    if is_causal:
@@ -85,24 +107,28 @@
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_mask + attn_bias

    if enable_gqa:
-        key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
-        value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
+        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

+
def print_diff(tensor1, tensor2, prefix=""):
    """
    Print the diff between two tensors
    """
-    print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}")
+    print(
+        f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}"
+    )
+

def test_static_cache_model(args):
    """
    Test the static cache model
    """
@@ -117,13 +143,15 @@

        # Test Prefill
        start_idx = 0
        end_idx = 2048
        out_no_cache = model_no_cache(q, k, v)
-        out_static_cache, new_key_cache, new_value_cache = model_static_cache(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True)
+        out_static_cache, new_key_cache, new_value_cache = model_static_cache(
+            q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True
+        )
        assert torch.allclose(out_no_cache, out_static_cache, atol=ATOL, rtol=RTOL)
-        
+
        # Test Generate
        for start_idx in range(2048, 2176):
            end_idx = start_idx + 1
            q_curr = torch.randn(1, 32, 1, 64).cuda()
            k_curr = torch.randn(1, 32, 1, 64).cuda()
@@ -133,17 +161,29 @@
            q_full = torch.cat((q, q_curr), dim=2)
            k_full = torch.cat((k, k_curr), dim=2)
            v_full = torch.cat((v, v_curr), dim=2)

            out_no_cache = model_no_cache(q_full, k_full, v_full)
-            out_static_cache, new_key_cache, new_value_cache = model_static_cache(q_curr, k_curr, v_curr, new_key_cache, new_value_cache, start_idx, end_idx, is_causal=False)
-
-            assert torch.allclose(out_no_cache[:, :, -1:, :], out_static_cache, atol=ATOL, rtol=RTOL)
-            q = q_full 
+            out_static_cache, new_key_cache, new_value_cache = model_static_cache(
+                q_curr,
+                k_curr,
+                v_curr,
+                new_key_cache,
+                new_value_cache,
+                start_idx,
+                end_idx,
+                is_causal=False,
+            )
+
+            assert torch.allclose(
+                out_no_cache[:, :, -1:, :], out_static_cache, atol=ATOL, rtol=RTOL
+            )
+            q = q_full
            k = k_full
            v = v_full
        print("============== test_static_cache passed ==============")
+

def transform_gm_with_kv_cache(exported_program: torch.export.ExportedProgram, args):
    """
    Transform the graph module by adding key and value cache to the graph
    """
@@ -155,52 +195,53 @@
        use_python_runtime=True,
        debug=args.debug,
        min_block_size=1,
    )
    exported_program = pre_export_lowering(exported_program, settings)
-    exported_program = exported_program.run_decompositions(
-        get_decompositions(False)
-    )
+    exported_program = exported_program.run_decompositions(get_decompositions(False))

    gm = exported_program.module()
    gm = post_lowering(gm, settings)

    return gm

+
def test_static_cache_lowering(args):
    """
-    Test static cache lowering pass applied to the model with no cache and run the graph module 
+    Test static cache lowering pass applied to the model with no cache and run the graph module
    and compare the output with the model with no cache
    """
    import static_cache2

    model_no_cache = ModelNoCache().eval().cuda()
    q = torch.randn(1, 32, 2, 64).cuda()
    k = torch.randn(1, 32, 2048, 64).cuda()
    v = torch.randn(1, 32, 2048, 64).cuda()
    key_cache = torch.zeros(1, 32, 2176, 64).cuda()
    value_cache = torch.zeros(1, 32, 2176, 64).cuda()
-    
+
    # Export the model
    q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176)
    kv_seq_len = torch.export.Dim("kv_seq_len", min=2, max=2176)
    exported_program = export(
        model_no_cache,
        args=(q, k, v),
-        dynamic_shapes=({2 : q_seq_len}, {2 : kv_seq_len}, {2 : kv_seq_len}),
-        strict=False
+        dynamic_shapes=({2: q_seq_len}, {2: kv_seq_len}, {2: kv_seq_len}),
+        strict=False,
    )

    gm = transform_gm_with_kv_cache(exported_program, args)

    # Test Prefill
    start_idx = 0
    end_idx = 2048
    is_causal = True
    q = torch.randn(1, 32, 2048, 64).cuda()
    out_no_cache = model_no_cache(q, k, v)
-    out_pyt_cache, key_cache, value_cache = gm(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal)
+    out_pyt_cache, key_cache, value_cache = gm(
+        q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal
+    )
    assert torch.allclose(out_no_cache, out_pyt_cache, atol=ATOL, rtol=RTOL)

    # Test Generate
    for start_idx in range(2048, 2176):
        end_idx = start_idx + 1
@@ -209,20 +250,32 @@
        k_curr = torch.randn(1, 32, 1, 64).cuda()
        v_curr = torch.randn(1, 32, 1, 64).cuda()
        # Concatenate the current query, key, and value with the previous ones
        q_full = torch.cat((q, q_curr), dim=2)
        k_full = torch.cat((k, k_curr), dim=2)
-        v_full = torch.cat((v, v_curr), dim=2)   
-        
+        v_full = torch.cat((v, v_curr), dim=2)
+
        out_no_cache = model_no_cache(q_full, k_full, v_full)
-        out_pyt_static_cache, key_cache, value_cache = gm(q_curr, k_curr, v_curr, key_cache, value_cache, start_idx, end_idx, is_causal)
-        assert torch.allclose(out_no_cache[:, :, -1:, :], out_pyt_static_cache, atol=ATOL, rtol=RTOL)
-        q = q_full 
+        out_pyt_static_cache, key_cache, value_cache = gm(
+            q_curr,
+            k_curr,
+            v_curr,
+            key_cache,
+            value_cache,
+            start_idx,
+            end_idx,
+            is_causal,
+        )
+        assert torch.allclose(
+            out_no_cache[:, :, -1:, :], out_pyt_static_cache, atol=ATOL, rtol=RTOL
+        )
+        q = q_full
        k = k_full
        v = v_full
-    
+
    print("============== test_static_cache_lowering passed ==============")
+

def test_static_cache_export(args):
    """
    Test the static cache model export
    """
@@ -236,19 +289,28 @@
    start_idx = 0
    end_idx = 2048
    is_causal = True
    # Export the model
    seq_len = torch.export.Dim("seq_len", min=2, max=2048)
-    seq_len_dyn_dim = {2 : seq_len}
+    seq_len_dyn_dim = {2: seq_len}
    exported_program = export(
        model_static_cache,
        args=(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal),
-        dynamic_shapes=(seq_len_dyn_dim, seq_len_dyn_dim, seq_len_dyn_dim, None, None, torch.export.Dim.DYNAMIC, torch.export.Dim.DYNAMIC, None),
-        strict=False
-    )
-    
-        
+        dynamic_shapes=(
+            seq_len_dyn_dim,
+            seq_len_dyn_dim,
+            seq_len_dyn_dim,
+            None,
+            None,
+            torch.export.Dim.DYNAMIC,
+            torch.export.Dim.DYNAMIC,
+            None,
+        ),
+        strict=False,
+    )
+
+
def test_static_cache_with_torch_tensorrt(args):
    """
    Test the static cache model with torch_tensorrt
    """
    import static_cache2
@@ -257,83 +319,104 @@
    q = torch.randn(1, 32, 2, 64).cuda()
    k = torch.randn(1, 32, 2048, 64).cuda()
    v = torch.randn(1, 32, 2048, 64).cuda()
    key_cache = torch.zeros(1, 32, 2176, 64).cuda()
    value_cache = torch.zeros(1, 32, 2176, 64).cuda()
-    
+
    # Export the model
    q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176)
    kv_seq_len = torch.export.Dim("kv_seq_len", min=2, max=2176)
    exported_program = export(
        model_no_cache,
        args=(q, k, v),
-        dynamic_shapes=({2 : q_seq_len}, {2 : kv_seq_len}, {2 : kv_seq_len}),
-        strict=False
-    )
-    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
+        dynamic_shapes=({2: q_seq_len}, {2: kv_seq_len}, {2: kv_seq_len}),
+        strict=False,
+    )
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
        trt_model = torch_tensorrt.dynamo.compile(
            exported_program,
            inputs=[q, k, v],
            enabled_precisions={torch.float32},
            disable_tf32=True,
            use_python_runtime=True,
            debug=args.debug,
            min_block_size=1,
        )
-    
+
    start_idx = 0
    end_idx = 2048
    is_causal = True
    q = torch.randn(1, 32, 2048, 64).cuda()
    # out_eager = eager_sdpa(q, k, v, is_causal=is_causal)
    out_no_cache = model_no_cache(q, k, v)
-    out_trt, trt_key_cache, trt_value_cache = trt_model(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal)
-
-    assert torch.allclose(out_no_cache, out_trt, atol=ATOL, rtol=RTOL), "Prefill TRT logits don't match"
-    assert torch.allclose(trt_key_cache[:, :, :end_idx, :], k, atol=ATOL, rtol=RTOL), "Prefill TRT key cache don't match"
-    assert torch.allclose(trt_value_cache[:, :, :end_idx, :], v, atol=ATOL, rtol=RTOL), "Prefill TRT value cache don't match"
-    
+    out_trt, trt_key_cache, trt_value_cache = trt_model(
+        q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal
+    )
+
+    assert torch.allclose(
+        out_no_cache, out_trt, atol=ATOL, rtol=RTOL
+    ), "Prefill TRT logits don't match"
+    assert torch.allclose(
+        trt_key_cache[:, :, :end_idx, :], k, atol=ATOL, rtol=RTOL
+    ), "Prefill TRT key cache don't match"
+    assert torch.allclose(
+        trt_value_cache[:, :, :end_idx, :], v, atol=ATOL, rtol=RTOL
+    ), "Prefill TRT value cache don't match"
+
    # Test Generate
    for start_idx in range(2048, 2176):
        end_idx = start_idx + 1
        q_curr = torch.randn(1, 32, 1, 64).cuda()
        k_curr = torch.randn(1, 32, 1, 64).cuda()
-        v_curr = torch.randn(1, 32, 1, 64).cuda()   
+        v_curr = torch.randn(1, 32, 1, 64).cuda()
        # Concatenate the current query, key, and value with the previous ones
        q_full = torch.cat((q, q_curr), dim=2)
        k_full = torch.cat((k, k_curr), dim=2)
-        v_full = torch.cat((v, v_curr), dim=2)   
+        v_full = torch.cat((v, v_curr), dim=2)
        is_causal = False
        out_no_cache = model_no_cache(q_full, k_full, v_full)
-        out_trt, trt_key_cache, trt_value_cache = trt_model(q_curr, k_curr, v_curr, trt_key_cache, trt_value_cache, start_idx, end_idx, is_causal)
+        out_trt, trt_key_cache, trt_value_cache = trt_model(
+            q_curr,
+            k_curr,
+            v_curr,
+            trt_key_cache,
+            trt_value_cache,
+            start_idx,
+            end_idx,
+            is_causal,
+        )
        # breakpoint()
        # print_diff(out_no_cache[:, :, -1:, :], out_trt, f"out_no_cache[:, :, -1:, :] vs out_trt for idx {start_idx}")
        # print_diff(trt_key_cache[:, :, :end_idx, :], k_full, f"trt_key_cache[:, :, :end_idx, :] vs k_full for idx {start_idx}")
        # print_diff(trt_value_cache[:, :, :end_idx, :], v_full, f"trt_value_cache[:, :, :end_idx, :] vs v_full for idx {start_idx}")
-        assert torch.allclose(out_no_cache[:, :, -1:, :], out_trt, atol=ATOL, rtol=RTOL), f"Generate TRT logits don't match for idx {start_idx}"
-        assert torch.allclose(trt_key_cache[:, :, :end_idx, :], k_full, atol=ATOL, rtol=RTOL), f"Generate TRT key cache don't match for idx {start_idx}"
-        assert torch.allclose(trt_value_cache[:, :, :end_idx, :], v_full, atol=ATOL, rtol=RTOL), f"Generate TRT value cache don't match for idx {start_idx}"
-        q = q_full 
+        assert torch.allclose(
+            out_no_cache[:, :, -1:, :], out_trt, atol=ATOL, rtol=RTOL
+        ), f"Generate TRT logits don't match for idx {start_idx}"
+        assert torch.allclose(
+            trt_key_cache[:, :, :end_idx, :], k_full, atol=ATOL, rtol=RTOL
+        ), f"Generate TRT key cache don't match for idx {start_idx}"
+        assert torch.allclose(
+            trt_value_cache[:, :, :end_idx, :], v_full, atol=ATOL, rtol=RTOL
+        ), f"Generate TRT value cache don't match for idx {start_idx}"
+        q = q_full
        k = k_full
        v = v_full

    print("============== test_static_cache_with_torch_tensorrt passed ==============")
-    
+

def main():
    arg_parser = argparse.ArgumentParser(
        description="Run test cases for llama attention and decoder"
    )
    arg_parser.add_argument(
-        "--debug",
-        action="store_true",
-        help="Enable debug (default: False)"
+        "--debug", action="store_true", help="Enable debug (default: False)"
    )
    args = arg_parser.parse_args()
    with torch.inference_mode():
        test_static_cache_model(args)
        test_static_cache_lowering(args)
        test_static_cache_with_torch_tensorrt(args)
-    
+

if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
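# ----------------------------------------------------------------------------
# Editorial sketch (not part of the CI diff above): the cache-update rule that
# StaticCacheModel and the test_static_cache_* cases exercise, isolated into a
# single function. Pre-allocated key/value buffers hold max_seq_len slots; new
# keys/values for positions [start_idx, end_idx) are spliced in with torch.cat
# (kept out-of-place so torch.export can trace it), and attention attends only
# over the first end_idx slots. Shapes below are illustrative.
import torch
import torch.nn.functional as F

def static_cache_sdpa(q, k_new, v_new, key_cache, value_cache, start_idx, end_idx, is_causal):
    # Splice the new keys/values into the pre-allocated buffers.
    keys = torch.cat((key_cache[:, :, :start_idx, :], k_new), dim=2)
    values = torch.cat((value_cache[:, :, :start_idx, :], v_new), dim=2)
    new_key_cache = torch.cat((keys, key_cache[:, :, end_idx:, :]), dim=2)
    new_value_cache = torch.cat((values, value_cache[:, :, end_idx:, :]), dim=2)
    # Attend over the filled portion of the cache only.
    out = F.scaled_dot_product_attention(q, keys, values, dropout_p=0.0, is_causal=is_causal)
    return out, new_key_cache, new_value_cache

# Prefill over 4 tokens matches plain SDPA with no cache:
q = torch.randn(1, 8, 4, 64)
k = torch.randn(1, 8, 4, 64)
v = torch.randn(1, 8, 4, 64)
key_cache = torch.zeros(1, 8, 16, 64)
value_cache = torch.zeros(1, 8, 16, 64)
out, key_cache, value_cache = static_cache_sdpa(q, k, v, key_cache, value_cache, 0, 4, True)
assert torch.allclose(out, F.scaled_dot_product_attention(q, k, v, is_causal=True), atol=1e-5)
# ----------------------------------------------------------------------------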
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2025-05-31 02:00:39.730818+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2025-05-31 02:01:04.362893+00:00
@@ -692,10 +692,11 @@

    gm = exported_program.module()
    exported_program.module().to("cpu")
    torch.cuda.empty_cache()
    import gc
+
    gc.collect()
    logger.debug("Input graph: " + str(gm.graph))

    # Apply lowering on the graph module
    gm = post_lowering(gm, settings)
@@ -790,31 +791,30 @@
        # TODO: For future, explore when nodes don't have metadata and if fake_tensor_prop can resolve this.
        logger.warning(
            "Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments."
        )

-    
    # Store the original input spec for later use
-    original_in_spec = getattr(gm, '_in_spec', None)
-    original_out_spec = getattr(gm, '_out_spec', None)
-    
+    original_in_spec = getattr(gm, "_in_spec", None)
+    original_out_spec = getattr(gm, "_out_spec", None)
+
    # Function to preserve and restore module specs
    def preserve_module_specs(in_spec, out_spec, target_module):
        """
        Applies input and output specs to the target module.
-        
+
        Args:
            in_spec: The input spec to apply
            out_spec: The output spec to apply
            target_module: The module to apply specs to
        """
        # Apply specs to target module
        if in_spec is not None:
            target_module._in_spec = in_spec
        if out_spec is not None:
            target_module._out_spec = out_spec
-            
+
        return target_module

    # Partition module into components that can be TRT-accelerated
    fast_partitioner_failed = False
    # If specified, try using the fast partitioner and fall back to the global one on failure
@@ -1197,11 +1197,11 @@
        "enable_weight_streaming": enable_weight_streaming,
        "tiling_optimization_level": tiling_optimization_level,
        "l2_limit_for_tiling": l2_limit_for_tiling,
        "offload_module_to_cpu": offload_module_to_cpu,
    }
-    
+
    settings = CompilationSettings(**compilation_options)
    logger.info("Compilation Settings: %s\n", settings)

    exported_program = pre_export_lowering(exported_program, settings)
    # Decompose the exported program
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py	2025-05-31 02:00:39.735818+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py	2025-05-31 02:01:05.149640+00:00
@@ -23,11 +23,15 @@
    """Replace specific versions of scaled_dot_product_attention with an equivalent
    implementation which can be easily converted to TRT
    """
    original_fns, replacement = scaled_dot_product_attention_replacement()
    replaced_nodes = []
-    sdpa_nodes = [node for node in gm.graph.nodes if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default]
+    sdpa_nodes = [
+        node
+        for node in gm.graph.nodes
+        if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default
+    ]
    breakpoint()
    # For each original function, search for it in the graph and replace
    for original in original_fns:
        replaced_nodes += torch.fx.subgraph_rewriter.replace_pattern_with_filters(
            gm,
@@ -167,6 +171,6 @@
    def replacement(
        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> torch.Tensor:
        return torch.nn.functional.scaled_dot_product_attention(query, key, value)

-    return (efficient, flash, efficient_scale, flash_scale), replacement
\ No newline at end of file
+    return (efficient, flash, efficient_scale, flash_scale), replacement
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonCUDAGraphModule.py	2025-05-31 02:00:39.736818+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonCUDAGraphModule.py	2025-05-31 02:01:05.811969+00:00
@@ -369,27 +369,35 @@
            is_shape_tensor_input = self.engine.is_shape_inference_io(input_name)
            if need_cudagraphs_record:
                # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
                # Clone is required to avoid re-using user-provided GPU memory
                if is_shape_tensor_input:
-                    self._input_buffers[inputs_shape_key][i] = contiguous_inputs[i].cpu().clone()
+                    self._input_buffers[inputs_shape_key][i] = (
+                        contiguous_inputs[i].cpu().clone()
+                    )
                else:
-                    self._input_buffers[inputs_shape_key][i] = contiguous_inputs[i].clone()
+                    self._input_buffers[inputs_shape_key][i] = contiguous_inputs[
+                        i
+                    ].clone()

            # For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers
            # as per TensorRT requirements
            if is_shape_tensor_input:
                # Shape tensor inputs are casted to int64 explicitly
                # Currently Torch CPU pointers are not working; numpy pointers are used instead
                # to refer to underlying memory
                inputs_cpu = contiguous_inputs[i].cpu().to(torch.int64)
-                inputs_cpu_numpy = contiguous_inputs[i].cpu().to(torch.int64).numpy().copy()
+                inputs_cpu_numpy = (
+                    contiguous_inputs[i].cpu().to(torch.int64).numpy().copy()
+                )
                # if cudagraphs_enabled:
                #     self._input_buffers[inputs_shape_key][i].copy_(inputs_cpu)
                #     self.context.set_tensor_address(input_name, self._input_buffers[inputs_shape_key][i].numpy().copy().ctypes.data)
                # else:
-                self.context.set_tensor_address(input_name, inputs_cpu_numpy.ctypes.data)
+                self.context.set_tensor_address(
+                    input_name, inputs_cpu_numpy.ctypes.data
+                )
            else:
                self.context.set_input_shape(
                    input_name, tuple(contiguous_inputs[i].shape)
                )
                if cudagraphs_enabled:
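
The comments in this hunk carry the key constraint: TensorRT wants shape-tensor inputs bound to host memory and regular data tensors bound to device memory. A minimal sketch of that rule against the TensorRT Python API directly, where `engine` and `context` stand for an already-built `ICudaEngine` and its `IExecutionContext` and the buffer management is deliberately simplified (this is not the module's bookkeeping):

```python
import torch

def bind_inputs(engine, context, inputs: dict[str, torch.Tensor]) -> list:
    """Bind shape tensors via host pointers and data tensors via device pointers."""
    keep_alive = []  # hold references so bound memory outlives execution
    for name, tensor in inputs.items():
        if engine.is_shape_inference_io(name):
            # Shape tensor: cast to int64 on the host and hand TRT a CPU pointer
            host = tensor.cpu().to(torch.int64).numpy().copy()
            keep_alive.append(host)
            context.set_tensor_address(name, host.ctypes.data)
        else:
            # Data tensor: contiguous device memory, with the shape set explicitly
            dev = tensor.contiguous().cuda()
            keep_alive.append(dev)
            context.set_input_shape(name, tuple(dev.shape))
            context.set_tensor_address(name, dev.data_ptr())
    return keep_alive
```

Returning `keep_alive` mirrors the lifetime concern in the code above: the numpy host copy has to stay referenced until execution completes, consistent with the note that numpy pointers are used for the host side.
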
@@ -458,11 +466,14 @@
                assert len(contiguous_inputs) == len(
                    self.input_names
                ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}."

                self.setup_input_tensors(
-                    contiguous_inputs, self.cudagraphs_enabled, need_cudagraphs_record, inputs_shape_key
+                    contiguous_inputs,
+                    self.cudagraphs_enabled,
+                    need_cudagraphs_record,
+                    inputs_shape_key,
                )

                if shape_changed:
                    # Check if input shapes can be inferred.
                    uninferred_input_names = self.context.infer_shapes()
@@ -496,11 +507,12 @@
                    if need_cudagraphs_record:
                        self._output_buffers[inputs_shape_key][o] = outputs[o].clone()

                    if self.cudagraphs_enabled:
                        self.context.set_tensor_address(
-                            output_name, self._output_buffers[inputs_shape_key][o].data_ptr()
+                            output_name,
+                            self._output_buffers[inputs_shape_key][o].data_ptr(),
                        )
                    else:
                        self.context.set_tensor_address(
                            output_name, outputs[o].data_ptr()
                        )
@@ -522,30 +534,35 @@
                self._engine_stream.wait_stream(self._caller_stream)

                with torch.cuda.stream(self._engine_stream):
                    if self.cudagraphs_enabled:
                        if need_cudagraphs_record:
-                            
-                            self.shape_key_to_cudagraph[inputs_shape_key] = torch.cuda.CUDAGraph()
+
+                            self.shape_key_to_cudagraph[inputs_shape_key] = (
+                                torch.cuda.CUDAGraph()
+                            )

                            if self.profiling_enabled:
-                                self.shape_key_to_cudagraph[inputs_shape_key].enable_debug_mode()
+                                self.shape_key_to_cudagraph[
+                                    inputs_shape_key
+                                ].enable_debug_mode()

                            with torch.cuda.graph(
-                                self.shape_key_to_cudagraph[inputs_shape_key], stream=self._engine_stream
+                                self.shape_key_to_cudagraph[inputs_shape_key],
+                                stream=self._engine_stream,
                            ):
                                self.context.execute_async_v3(
                                    self._engine_stream.cuda_stream
                                )

                            if self.profiling_enabled:
                                import tempfile

                                with tempfile.TemporaryDirectory() as tmpdir:
-                                    self.shape_key_to_cudagraph[inputs_shape_key].debug_dump(
-                                        f"{tempdir}/{self.name}_cudagraph.dot"
-                                    )
+                                    self.shape_key_to_cudagraph[
+                                        inputs_shape_key
+                                    ].debug_dump(f"{tempdir}/{self.name}_cudagraph.dot")

                        self.shape_key_to_cudagraph[inputs_shape_key].replay()  # type: ignore

                    else:
                        self.context.execute_async_v3(self._engine_stream.cuda_stream)
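
Setting the formatting churn aside, the flow this hunk implements is: one `torch.cuda.CUDAGraph` per input-shape key, captured on first sight of that shape and replayed afterwards. A compact sketch of that caching pattern in isolation (requires a CUDA device; `run` and its `graphs` cache are illustrative names, not the runtime's attributes):

```python
import torch

graphs: dict[str, tuple[torch.cuda.CUDAGraph, torch.Tensor, torch.Tensor]] = {}

def run(fn, x: torch.Tensor) -> torch.Tensor:
    key = str(tuple(x.shape)).replace(" ", "")       # e.g. "(8,64)"
    if key not in graphs:
        static_in = x.clone()                        # graph-owned input buffer
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):                   # warm up off the default stream
            fn(static_in)
        torch.cuda.current_stream().wait_stream(s)
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):                    # capture a single invocation
            static_out = fn(static_in)
        graphs[key] = (g, static_in, static_out)
    g, static_in, static_out = graphs[key]
    static_in.copy_(x)                               # refresh the captured input memory
    g.replay()                                       # relaunch the recorded kernels
    return static_out.clone()

y = run(lambda t: (t @ t.T).relu(), torch.randn(8, 64, device="cuda"))
```

The module above additionally records on its dedicated engine stream and reuses pre-allocated output buffers per shape key, but the capture-once/replay-on-hit structure is the same.
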
@@ -754,18 +771,21 @@
        """
        # Representation of input shapes to a given model
        # Shapes are concatenated as so:
        # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
        tensor_inputs = [
-            t if isinstance(t, torch.Tensor) else torch.tensor(t)
-            for t in inputs
+            t if isinstance(t, torch.Tensor) else torch.tensor(t) for t in inputs
        ]
-        new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs)
+        new_shape_key = "".join(
+            str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs
+        )

        # If the new shape key differs from the existing one,
        # invalidate the old shape key and remove the CUDAGraph
        if new_shape_key not in self.shape_key_to_cudagraph:
-            logger.debug(f"The user provided input shape {new_shape_key} is not found in recorded CUDAGraph input shapes. A new CUDAGraph will be recorded with this input shape.")
+            logger.debug(
+                f"The user provided input shape {new_shape_key} is not found in recorded CUDAGraph input shapes. A new CUDAGraph will be recorded with this input shape."
+            )
            # self.shape_key = new_shape_key
            return True, new_shape_key

        return False, new_shape_key
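
As a worked example of the key format described in the comment above ("x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)"), the concatenation behaves as follows; `shape_key` here is a standalone stand-in for the method's logic:

```python
import torch

def shape_key(*inputs) -> str:
    # Non-tensor inputs are wrapped so every input contributes a shape tuple
    tensors = [t if isinstance(t, torch.Tensor) else torch.tensor(t) for t in inputs]
    return "".join(str(tuple(t.shape)).replace(" ", "") for t in tensors)

print(shape_key(torch.zeros(3, 4), torch.zeros(4, 5)))  # "(3,4)(4,5)"
print(shape_key(torch.zeros(3, 4), 7))                  # scalar becomes 0-d: "(3,4)()"
```
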
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2025-05-31 02:00:39.736818+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2025-05-31 02:01:05.838721+00:00
@@ -742,14 +742,15 @@
        """
        # Representation of input shapes to a given model
        # Shapes are concatenated as so:
        # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
        tensor_inputs = [
-            t if isinstance(t, torch.Tensor) else torch.tensor(t)
-            for t in inputs
+            t if isinstance(t, torch.Tensor) else torch.tensor(t) for t in inputs
        ]
-        new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs)
+        new_shape_key = "".join(
+            str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs
+        )

        # If the new shape key differs from the existing one,
        # invalidate the old shape key and remove the CUDAGraph
        if new_shape_key != self.shape_key:
            logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}")
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2025-05-31 02:00:39.732818+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2025-05-31 02:01:06.640331+00:00
@@ -1893,10 +1893,11 @@
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
    )
+

@dynamo_tensorrt_converter(operator.sub, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar, supports_dynamic_shapes=True)
def aten_ops_sub(
