diff --git a/docsrc/index.rst b/docsrc/index.rst
index 67fbdc56f5..4d28d77640 100644
--- a/docsrc/index.rst
+++ b/docsrc/index.rst
@@ -140,11 +140,10 @@ Model Zoo
* :ref:`torch_compile_resnet`
* :ref:`torch_compile_transformer`
* :ref:`torch_compile_stable_diffusion`
+* :ref:`compile_hf_models`
* :ref:`torch_compile_gpt2`
* :ref:`torch_export_gpt2`
-* :ref:`torch_export_llama2`
* :ref:`torch_export_sam2`
-* :ref:`torch_export_flux_dev`
* :ref:`notebooks`
.. toctree::
@@ -155,11 +154,10 @@ Model Zoo
tutorials/_rendered_examples/dynamo/torch_compile_resnet_example
tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion
+ tutorials/compile_hf_models
tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion
tutorials/_rendered_examples/dynamo/torch_compile_gpt2
- tutorials/_rendered_examples/dynamo/torch_export_gpt2
- tutorials/_rendered_examples/dynamo/torch_export_llama2
tutorials/_rendered_examples/dynamo/torch_export_sam2
tutorials/_rendered_examples/dynamo/torch_export_flux_dev
tutorials/notebooks
diff --git a/docsrc/tutorials/compile_hf_models.rst b/docsrc/tutorials/compile_hf_models.rst
new file mode 100644
index 0000000000..f6da87b145
--- /dev/null
+++ b/docsrc/tutorials/compile_hf_models.rst
@@ -0,0 +1,218 @@
+.. _compile_hf_models:
+
+Compiling LLM models from Huggingface
+======================================
+
+This tutorial walks you through how to compile LLM models from Huggingface using Torch-TensorRT. We also introduce KV caching in Torch-TensorRT, which can significantly improve the performance of LLM inference.
+The code is available in the ``tools/llm`` directory. We use the ``run_llm.py`` script to compile the model, generate outputs, and measure performance.
+
+.. note::
+ This is an **experimental release** and APIs may change in future versions.
+
+.. note::
+ The compilation scripts and tutorials for the Llama-2-7b-chat-hf and gpt2 models have been consolidated into the unified ``run_llm.py`` script located in the ``tools/llm`` directory.
+
+Overview of tools/llm Directory
+-------------------------------
+
+The ``tools/llm`` directory provides the following tools to compile LLM models from Huggingface:
+
+* **run_llm.py**: Main entry point for compiling models, generating outputs, and benchmarking
+* **Static Cache Utilities**: ``static_cache_v1.py`` and ``static_cache_v2.py`` for KV cache optimization
+* **SDPA Attention**: ``sdpa_converter.py`` and ``register_sdpa.py`` for registering the scaled dot-product attention converter and lowering pass
+* **Testing Components**: Model-specific test files for validation
+* **Utility Functions**: ``utils.py`` and ``cache_utils.py`` for common operations
+
+Supported Models
+----------------
+We have officially verified support for the following LLM families:
+
+.. list-table::
+ :widths: 20 40 20 20
+ :header-rows: 1
+
+ * - Model Series
+ - HuggingFace Model Card
+ - Precision
+ - KV Cache Support?
+ * - GPT-2
+ - gpt2
+ - FP16, FP32
+ - Yes
+ * - LLaMA 2
+ - meta-llama/Llama-2-7b-chat-hf
+ - FP16, FP32
+ - Yes
+ * - LLaMA 3.1
+ - meta-llama/Llama-3.1-8B-Instruct
+ - FP16, FP32
+ - Yes
+ * - LLaMA 3.2
+ - | meta-llama/Llama-3.2-1B-Instruct
+ | meta-llama/Llama-3.2-3B-Instruct
+ - FP16, FP32
+ - Yes
+ * - Qwen 2.5
+ - | Qwen/Qwen2.5-0.5B-Instruct
+ | Qwen/Qwen2.5-1.5B-Instruct
+ | Qwen/Qwen2.5-3B-Instruct
+ | Qwen/Qwen2.5-7B-Instruct
+ - FP16, FP32
+ - Yes
+
+Getting Started with run_llm.py
+-------------------------------
+
+The main entry point is ``run_llm.py``, which provides a complete workflow for model compilation and benchmarking.
+
+Basic Usage
+^^^^^^^^^^^
+
+.. code-block:: bash
+
+ python tools/llm/run_llm.py \
+ --model meta-llama/Llama-3.2-1B-Instruct \
+ --prompt "What is parallel programming?" \
+ --precision FP16 \
+ --num_tokens 128 \
+ --cache static_v2 \
+ --benchmark
+
+Key Arguments
+^^^^^^^^^^^^^
+
+* ``--model``: Name or path of the HuggingFace LLM
+* ``--tokenizer``: (Optional) Tokenizer name; defaults to model name
+* ``--prompt``: Input prompt for text generation
+* ``--precision``: Precision mode (``FP16``, ``FP32``)
+* ``--num_tokens``: Number of output tokens to generate
+* ``--cache``: KV cache type (``static_v1``, ``static_v2``, or empty for no KV caching)
+* ``--benchmark``: Enable benchmarking mode for performance comparison
+* ``--enable_pytorch_run``: Also run and compare PyTorch baseline
+
+
+Other Usage Examples
+^^^^^^^^^^^^^^^^^^^^
+.. code-block:: bash
+
+ # Compare the performance of different models
+ python tools/llm/run_llm.py --model gpt2 --benchmark --enable_pytorch_run
+ python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --benchmark --enable_pytorch_run
+
+ # Generate outputs (benchmarking disabled) by specifying the number of tokens to generate (default: 128)
+ python tools/llm/run_llm.py --model gpt2 --prompt "What is parallel programming?" --num_tokens 128
+ python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is parallel programming?" --num_tokens 128
+
+ # Test different caching approaches
+ python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --cache static_v1
+ python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --cache static_v2
+
+ # Compare FP16 vs FP32 performance
+ python tools/llm/run_llm.py --model Qwen/Qwen2.5-1.5B-Instruct --precision FP16 --benchmark
+ python tools/llm/run_llm.py --model Qwen/Qwen2.5-1.5B-Instruct --precision FP32 --benchmark
+
+
+KV Caching in Torch-TensorRT
+---------------------------------
+
+We provide two versions of static KV caching: ``static_cache_v1.py`` and ``static_cache_v2.py``.
+In both implementations, static KV cache tensors are added as model inputs and outputs rather than being stored as external memory.
+The length of the KV cache equals the input sequence length plus the output sequence length (specified by ``--num_tokens``). The number of heads and the head dimension are determined by the model config, as sketched below.
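+
+As a rough illustration of how these shapes follow from the Huggingface model config, consider the sketch below. This helper is an assumption for illustration only (it is not part of ``tools/llm`` and assumes a config exposing ``num_attention_heads`` and ``hidden_size``); the lowering passes derive the actual shapes from the exported graph.
+
+.. code-block:: python
+
+    from transformers import AutoConfig
+
+    def static_kv_cache_shape(model_name: str, isl: int, num_tokens: int, batch_size: int = 1):
+        """Illustrative sketch: rough per-layer static KV cache shape."""
+        cfg = AutoConfig.from_pretrained(model_name)
+        max_seq_len = isl + num_tokens  # input sequence length + generated tokens
+        head_dim = cfg.hidden_size // cfg.num_attention_heads
+        # One (key_cache, value_cache) pair of roughly this shape per attention layer;
+        # the exact head count can differ for models using grouped-query attention.
+        return (batch_size, cfg.num_attention_heads, max_seq_len, head_dim)
+
+    print(static_kv_cache_shape("meta-llama/Llama-3.2-1B-Instruct", isl=64, num_tokens=128))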
+
+Static Cache v1
+^^^^^^^^^^^^^^^^
+
+``static_cache_v1.py`` implements the KV cache in the model graph as follows:
+
+.. code-block:: python
+
+ class StaticCacheV1Model(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True):
+ # Concatenate new key/value pairs with existing cache
+ new_key_cache = torch.cat((key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2)
+ new_value_cache = torch.cat((value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2)
+
+ # Compute attention using the updated cache
+ attn_output = torch._C._nn.scaled_dot_product_attention(
+ q,
+ new_key_cache[:, :, :end_idx, :],
+ new_value_cache[:, :, :end_idx, :],
+ dropout_p=0.0,
+ is_causal=is_causal
+ )
+
+ return attn_output, new_key_cache, new_value_cache
+
+In the above code, we concatenate the new key/value pairs with the existing cache to update it. To compute the attention, we use the updated cache and gather the corresponding keys/values up to and including the current token index.
+This logic is implemented as an FX graph transformation pass and registered as a Torch-TensorRT lowering pass via the ``@_aten_lowering_pass`` decorator when the ``static_cache_v1.py`` module is imported.
+
+.. note::
+ The ``start_idx`` and ``end_idx`` are the start and end indices of the current tokens in the cache. During the prefill phase, ``start_idx`` is 0 and ``end_idx`` is the input sequence length.
+ During the decode phase, ``start_idx`` begins at the input sequence length and ``end_idx`` equals ``start_idx + 1``. Both indices advance by one per generated token until the end of the sequence or the maximum number of tokens is reached, as illustrated below.
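+
+The index progression can be summarized with the following sketch. This is a simplified assumption of the driver loop, not the exact implementation in ``utils.py``:
+
+.. code-block:: python
+
+    # Simplified sketch of how start_idx / end_idx advance across prefill and decode
+    isl = 6               # input sequence length (number of prompt tokens)
+    num_output_tokens = 4
+
+    # Prefill phase: the whole prompt is inserted into the cache at once
+    start_idx, end_idx = 0, isl
+    print("prefill:", (start_idx, end_idx))
+
+    # Decode phase: one token is inserted into the cache per step
+    for _ in range(num_output_tokens):
+        start_idx, end_idx = end_idx, end_idx + 1
+        print("decode :", (start_idx, end_idx))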
+
+
+Static Cache v2
+^^^^^^^^^^^^^^^^
+
+``static_cache_v2.py`` is similar to ``static_cache_v1.py`` but uses fewer slice operations. It implements the KV cache in the model graph as follows:
+
+.. code-block:: python
+
+ class StaticCacheV2Model(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True):
+ concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)
+ concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)
+ new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
+ new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2)
+ attn_output = torch._C._nn.scaled_dot_product_attention(
+ q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal
+ )
+
+ return attn_output, new_key_cache, new_value_cache
+
+In the above code, we concatenate the existing key/value cache with the current token's key/value, use the result directly to compute the attention, and then update the cache by inserting the current key/value.
+This logic is implemented as an FX graph transformation pass and registered as a Torch-TensorRT lowering pass via the ``@_aten_lowering_pass`` decorator when the ``static_cache_v2.py`` module is imported.
+The definitions of ``start_idx`` and ``end_idx`` are the same as in ``static_cache_v1.py``.
+
+After the model is compiled with a static KV cache, its input signature changes. The new input signature is ``(input_ids, position_ids, key_cache_0, value_cache_0, ..., start_idx, end_idx)``.
+One key/value cache pair is added for each attention layer in the model. The ``generate_with_static_cache`` function can be used to generate outputs with this signature.
+
+Generating Outputs
+-------------------
+We use a custom ``generate`` function (defined in ``utils.py``) to generate the outputs. This function performs standard autoregressive decoding without KV caching.
+There is also a ``generate_with_static_cache`` function that performs autoregressive decoding with KV caching.
+
+The ``generate_with_static_cache`` function prepares the inputs for a model compiled with a static KV cache.
+The model inputs are ``input_ids``, ``position_ids``, ``key_cache_0``, ``value_cache_0``, ..., ``start_idx``, ``end_idx``.
+The key/value cache tensors are initialized with zeros; for every generated token, the updated cache tensors returned by the model are fed back as inputs for the next step, as sketched below.
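+
+At a high level, the decoding loop looks like the sketch below. This is a simplified assumption for illustration only; the actual ``generate_with_static_cache`` implementation in ``utils.py`` additionally handles stopping criteria, device placement, and other details:
+
+.. code-block:: python
+
+    import torch
+
+    def generate_with_static_cache_sketch(trt_model, input_ids, kv_caches, num_tokens, eos_token_id):
+        """Simplified sketch of autoregressive decoding with a static KV cache."""
+        isl = input_ids.shape[1]
+        generated = input_ids
+        start_idx, end_idx = 0, isl  # prefill covers the whole prompt
+        for _ in range(num_tokens):
+            current_ids = generated[:, start_idx:end_idx]
+            position_ids = torch.arange(start_idx, end_idx, device=input_ids.device).unsqueeze(0)
+            outputs = trt_model(current_ids, position_ids, *kv_caches, start_idx, end_idx)
+            # The first output is the logits; the remaining outputs are the updated caches,
+            # which are fed back as inputs for the next step.
+            logits, kv_caches = outputs[0], list(outputs[1:])
+            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
+            generated = torch.cat([generated, next_token], dim=-1)
+            if next_token.item() == eos_token_id:
+                break
+            start_idx, end_idx = end_idx, end_idx + 1  # decode inserts one token per step
+        return generated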
+
+SDPA Converter (sdpa_converter.py)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+* Converts the scaled dot-product attention operation using the TensorRT Python API.
+* Supports causal and standard self-attention.
+
+SDPA Registration (register_sdpa.py)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+* This is a Torch-TensorRT lowering pass that replaces variants of SDPA with ``torch.nn.functional.scaled_dot_product_attention``.
+* Registers the SDPA converter, which is used to convert the ``torch.nn.functional.scaled_dot_product_attention`` operation (see the sketch below).
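+
+A minimal sketch of how such a lowering pass is registered is shown below. The pass name is hypothetical and the body is illustrative only; the real ``register_sdpa.py`` additionally rewrites the matched attention variants:
+
+.. code-block:: python
+
+    import torch
+    from torch_tensorrt.dynamo._settings import CompilationSettings
+    from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
+        _aten_lowering_pass,
+    )
+
+    @_aten_lowering_pass
+    def my_attention_lowering_pass(
+        gm: torch.fx.GraphModule, settings: CompilationSettings
+    ) -> torch.fx.GraphModule:
+        # A real pass (such as register_sdpa.py) would walk gm.graph here and replace
+        # attention variants with torch.nn.functional.scaled_dot_product_attention.
+        return gm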
+
+
+Limitations and Known Issues
+----------------------------
+
+* Sliding window attention (used in Gemma3 and Qwen 3 models) is not yet supported.
+* Some model architectures (e.g. Phi-4) have issues with exporting the torch model.
+
+Requirements
+^^^^^^^^^^^^
+
+* Torch-TensorRT 2.8.0 or later
+* Transformers v4.52.3
\ No newline at end of file
diff --git a/examples/dynamo/torch_export_gpt2.py b/examples/dynamo/torch_export_gpt2.py
deleted file mode 100644
index 4d34c58de4..0000000000
--- a/examples/dynamo/torch_export_gpt2.py
+++ /dev/null
@@ -1,98 +0,0 @@
-"""
-.. _torch_export_gpt2:
-
-Compiling GPT2 using the dynamo backend
-==========================================================
-
-This script illustrates Torch-TensorRT workflow with dynamo backend on popular GPT2 model.
-"""
-
-# %%
-# Imports and Model Definition
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-import torch
-import torch_tensorrt
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from utils import export_llm, generate
-
-# %%
-
-# Define the parameters and initialize the model
-MAX_TOKENS = 32
-DEVICE = torch.device("cuda:0")
-
-# Define the GPT2 model from hugging face
-# kv_cache is not supported in Torch-TRT currently.
-# CPU is used here so that GPU memory is reserved for TRT compilation.
-with torch.no_grad():
- tokenizer = AutoTokenizer.from_pretrained("gpt2")
- model = (
- AutoModelForCausalLM.from_pretrained(
- "gpt2",
- pad_token_id=tokenizer.eos_token_id,
- use_cache=False,
- attn_implementation="eager",
- )
- .eval()
- .half()
- )
-
-# %%
-# Tokenize a sample input prompt and get pytorch model outputs
-prompt = "I enjoy walking with my cute dog"
-model_inputs = tokenizer(prompt, return_tensors="pt")
-input_ids = model_inputs["input_ids"]
-
-# Auto-regressive generation loop for greedy decoding using PyTorch model
-# We use a custom generate function which is very similar to the huggingface one.
-pyt_gen_tokens = generate(model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)
-
-
-# %%
-# Compilation with `Torch-TensorRT` using dynamo backend and generate TensorRT outputs
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-# Export the GPT2 model into an ExportedProgram which is input of TRT compilation
-# To compile the model in FP16, we do the following
-# 1) Cast the model to FP16 via model.half()
-# 2) Enable use_explicit_typing=True. Certain layers are explicitly casted to FP32 within the pytorch model and this flag respects this behavior during TRT compilation
-# 3) Enable use_fp32_acc=True. This ensures all the matmuls are accumulated in FP32 precision (similar to PyTorch)
-gpt2_ep = export_llm(model, input_ids, max_seq_len=1024)
-trt_model = torch_tensorrt.dynamo.compile(
- gpt2_ep,
- inputs=[input_ids],
- enabled_precisions={torch.float32},
- truncate_double=True,
- device=DEVICE,
- disable_tf32=True,
- use_explicit_typing=True,
- use_fp32_acc=True,
-)
-
-# Auto-regressive generation loop for greedy decoding using TensorRT model
-# We use a custom generate function which is very similar to the huggingface one.
-# Move inputs to GPU
-input_ids = input_ids.to(DEVICE)
-trt_gen_tokens = generate(trt_model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)
-
-# %%
-# Decode the output sentences of PyTorch and TensorRT
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-print("=============================")
-print(
- "Pytorch model generated text: ",
- tokenizer.decode(pyt_gen_tokens[0], skip_special_tokens=True),
-)
-print("=============================")
-print(
- "TensorRT model generated text: ",
- tokenizer.decode(trt_gen_tokens[0], skip_special_tokens=True),
-)
-
-# Prompt : What is parallel programming ?
-
-# =============================
-# Pytorch model generated text: The parallel programming paradigm is a set of programming languages that are designed to be used in parallel. The main difference between parallel programming and parallel programming is that
-
-# =============================
-# TensorRT model generated text: The parallel programming paradigm is a set of programming languages that are designed to be used in parallel. The main difference between parallel programming and parallel programming is that
diff --git a/examples/dynamo/torch_export_llama2.py b/examples/dynamo/torch_export_llama2.py
deleted file mode 100644
index 2f3e3cba43..0000000000
--- a/examples/dynamo/torch_export_llama2.py
+++ /dev/null
@@ -1,102 +0,0 @@
-"""
-.. _torch_export_llama2:
-
-Compiling Llama2 using the dynamo backend
-==========================================================
-
-This script illustrates Torch-TensorRT workflow with dynamo backend on popular Llama2 model.
-"""
-
-# %%
-# Imports and Model Definition
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-import torch
-import torch_tensorrt
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from utils import export_llm, generate
-
-# %%
-# Define the parameters and initialize the model
-MAX_TOKENS = 32
-DEVICE = torch.device("cuda:0")
-
-# Define the Llama2 model from hugging face
-# kv_cache is not supported in Torch-TRT currently.
-# CPU is used here so that GPU memory is reserved for TRT compilation.
-llama_path = "meta-llama/Llama-2-7b-chat-hf"
-with torch.no_grad():
- model = (
- AutoModelForCausalLM.from_pretrained(
- llama_path, use_cache=False, attn_implementation="eager"
- )
- .eval()
- .half()
- )
-
-tokenizer = AutoTokenizer.from_pretrained(llama_path)
-
-# %%
-# Tokenize a sample input prompt and get pytorch model outputs
-prompt = "What is dynamic programming?"
-model_inputs = tokenizer(prompt, return_tensors="pt")
-input_ids = model_inputs.input_ids
-
-# Auto-regressive generation loop for greedy decoding using PyTorch model
-# We use a custom generate function which is very similar to the huggingface one.
-pyt_gen_tokens = generate(model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)
-
-# %%
-# Compilation with `Torch-TensorRT` using dynamo backend and generate TensorRT outputs
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-# Export the llama2 model into an ExportedProgram which is input of TRT compilation
-# To compile the model in FP16, we do the following
-# 1) Cast the model to FP16 via model.half()
-# 2) Enable use_explicit_typing=True. Certain layers are explicitly casted to FP32 within the pytorch model and this flag respects this behavior during TRT compilation
-# 3) Enable use_fp32_acc=True. This ensures all the matmuls are accumulated in FP32 precision (similar to PyTorch)
-llama2_ep = export_llm(model, input_ids, max_seq_len=64)
-trt_model = torch_tensorrt.dynamo.compile(
- llama2_ep,
- inputs=[input_ids],
- enabled_precisions={torch.float32},
- truncate_double=True,
- device=DEVICE,
- disable_tf32=True,
- use_explicit_typing=True,
- use_fp32_acc=True,
-)
-
-# Auto-regressive generation loop for greedy decoding using TensorRT model
-# We use a custom generate function which is very similar to the huggingface one.
-# Move inputs to GPU
-input_ids = input_ids.to(DEVICE)
-trt_gen_tokens = generate(trt_model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)
-
-# %%
-# Decode the output sentences of PyTorch and TensorRT
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-print("=============================")
-print(
- "Pytorch model generated text: ",
- tokenizer.batch_decode(
- pyt_gen_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False
- )[0],
-)
-print("=============================")
-print(
- "TensorRT model generated text: ",
- tokenizer.batch_decode(
- trt_gen_tokens,
- skip_special_tokens=True,
- clean_up_tokenization_spaces=False,
- )[0],
-)
-
-
-# Prompt : What is dynamic programming?
-
-# =============================
-# Pytorch model generated text: Dynamic programming is an algorithmic technique used to solve complex problems by breaking them down into smaller subproblems, solving each subproblem only once, and
-
-# =============================
-# TensorRT model generated text: Dynamic programming is an algorithmic technique used to solve complex problems by breaking them down into smaller subproblems, solving each subproblem only once, and
diff --git a/examples/dynamo/utils.py b/examples/dynamo/utils.py
deleted file mode 100644
index 25ad99c12d..0000000000
--- a/examples/dynamo/utils.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import torch
-from transformers import StoppingCriteriaList
-from transformers.generation.stopping_criteria import (
- EosTokenCriteria,
- MaxLengthCriteria,
-)
-
-
-def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
- """
- Exports the LLM model into an ExportedProgram with dynamic shapes.
- In the case of guard failures due to some PyTorch kernel implements, we also
- try to re-export the graph by expressing them as runtime assert nodes
- """
- with torch.no_grad():
- # max=1024 has contraint violation error. https://github.com/pytorch/pytorch/issues/125604
- seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len)
- try:
- print("Trying to export the model using torch.export.export()..")
- # strict=False only enables aotautograd tracing and excludes dynamo.
- ep = torch.export.export(
- model, (inputs,), dynamic_shapes=({1: seq_len},), strict=False
- )
- except:
- print(
- "Trying torch.export._trace._export to trace the graph since torch.export.export() failed"
- )
- # This API is used to express the constraint violation guards as asserts in the graph.
- ep = torch.export._trace._export(
- model,
- (inputs,),
- dynamic_shapes=({1: seq_len},),
- strict=False,
- allow_complex_guards_as_runtime_asserts=True,
- )
-
- return ep
-
-
-def generate(model, input_seq, max_tokens, eos_token_id):
- """
- Greedy decoding of the model. This generates up to max_tokens.
- """
- # Max length of output seq = current input_seq length + max_tokens allowed to generate
- max_output_seq_length = input_seq.shape[1] + max_tokens
- stopping_criteria = StoppingCriteriaList(
- [
- MaxLengthCriteria(max_length=max_output_seq_length),
- EosTokenCriteria(eos_token_id=eos_token_id),
- ]
- )
-
- while True:
- outputs = model(input_seq)
- logits = outputs.logits
- next_token_logits = logits[:, -1, :]
- next_tokens = torch.argmax(next_token_logits, dim=-1)
- input_seq = torch.cat([input_seq, next_tokens[:, None]], dim=-1)
- # TODO: Handle batch in this check
- if stopping_criteria(input_seq, logits).item():
- break
-
- return input_seq
diff --git a/examples/dynamo/weight_streaming_example.py b/examples/dynamo/weight_streaming_example.py
index e1076a9e75..601292ba95 100644
--- a/examples/dynamo/weight_streaming_example.py
+++ b/examples/dynamo/weight_streaming_example.py
@@ -32,7 +32,43 @@
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM
-from utils import export_llm
+
+
+def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
+ """
+ Exports the LLM model into an ExportedProgram with dynamic shapes.
+ In the case of guard failures due to some PyTorch kernel implementations, we also
+ try to re-export the graph by expressing them as runtime assert nodes
+ """
+ with torch.no_grad():
+ # max=1024 has constraint violation error. https://github.com/pytorch/pytorch/issues/125604
+ seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len)
+ position_ids = torch.arange(inputs.shape[1]).unsqueeze(0).to(inputs.device)
+ try:
+ print("Trying to export the model using torch.export.export()..")
+ # strict=False only enables aotautograd tracing and excludes dynamo.
+ ep = torch.export.export(
+ model,
+ args=(inputs,),
+ kwargs={"position_ids": position_ids},
+ dynamic_shapes=({1: seq_len}, {1: seq_len}),
+ strict=False,
+ )
+ except:
+ print(
+ "Trying torch.export._trace._export to trace the graph since torch.export.export() failed"
+ )
+ # This API is used to express the constraint violation guards as asserts in the graph.
+ ep = torch.export._trace._export(
+ model,
+ args=(inputs,),
+ kwargs={"position_ids": position_ids},
+ dynamic_shapes=({1: seq_len}, {1: seq_len}),
+ strict=False,
+ allow_complex_guards_as_runtime_asserts=True,
+ )
+
+ return ep
def time_generate(model, inputs, output_seq_length, iterations=10):
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index d7092f1e0f..116eadfd41 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -799,6 +799,28 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
"Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments."
)
+ # Store the original input spec for later use
+ original_in_spec = getattr(gm, "_in_spec", None)
+ original_out_spec = getattr(gm, "_out_spec", None)
+
+ # Function to preserve and restore module specs
+ def preserve_module_specs(
+ in_spec: Any, out_spec: Any, target_module: torch.fx.GraphModule
+ ) -> None:
+ """
+ Applies input and output specs to the target module.
+
+ Args:
+ in_spec: The input spec to apply
+ out_spec: The output spec to apply
+ target_module: The module to apply specs to
+ """
+ # Apply specs to target module
+ if in_spec is not None:
+ target_module._in_spec = in_spec
+ if out_spec is not None:
+ target_module._out_spec = out_spec
+
# Partition module into components that can be TRT-accelerated
fast_partitioner_failed = False
# If specified, try using the fast partitioner and fall back to the global one on failure
@@ -844,6 +866,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
continue
submodule_node_dict[node.name] = node
+ preserve_module_specs(original_in_spec, original_out_spec, partitioned_module)
# Store TRT replicas of Torch subgraphs
trt_modules = {}
# Iterate over all components that can be accelerated
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index b134b3d5f5..8d7a914836 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -890,10 +890,9 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
else:
return converter(self.ctx, target, args, kwargs, self._cur_node_name)
- def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
+ def get_attr(self, target: str, args: Any, kwargs: Any) -> torch.Tensor:
with _disable_current_modes(), unset_fake_temporarily():
frozen_attr = self.fetch_attr(target)
-
if isinstance(frozen_attr, torch.nn.Parameter):
constant_tensor = frozen_attr.data
else:
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index e542f1d417..f243d091a4 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1935,6 +1935,7 @@ def aten_ops_minimum(
)
+@dynamo_tensorrt_converter(operator.sub, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar, supports_dynamic_shapes=True)
def aten_ops_sub(
diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
index fc76b20141..1d619b6ce3 100644
--- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -752,7 +752,14 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
# Representation of input shapes to a given model
# Shapes are concatenated as so:
# x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
- new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in inputs)
+ tensor_inputs = []
+ for t in inputs:
+ if not isinstance(t, torch.Tensor):
+ return True
+ tensor_inputs.append(t)
+ new_shape_key = "".join(
+ str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs
+ )
# If the new shape key differs from the existing one,
# invalidate the old shape key and remove the CUDAGraph
diff --git a/tools/llm/README.md b/tools/llm/README.md
new file mode 100644
index 0000000000..3fd55bc060
--- /dev/null
+++ b/tools/llm/README.md
@@ -0,0 +1,66 @@
+# Optimizing LLMs in Torch-TensorRT
+
+This directory provides utilities and scripts for compiling, optimizing, and benchmarking Large Language Models (LLMs) using Torch-TensorRT, with a focus on efficient inference on NVIDIA GPUs. The main entry point is `run_llm.py`, which demonstrates how to export, compile, and run LLMs with various caching strategies and precision modes. Note that this is an **experimental release** and APIs may change in future versions.
+
+### Key Features
+
+- **Model Support:** Works with popular LLMs such as Llama-3, Qwen2.5, etc.
+- **Precision Modes:** Supports FP16, BF16, and FP32.
+- **KV Cache:** Supports static and dynamic KV cache for efficient autoregressive decoding.
+- **Benchmarking:** Measures and compares throughput and latency for PyTorch and TensorRT backends.
+- **Custom Attention:** Registers and converts custom scaled dot-product attention (SDPA) for compatibility with TensorRT.
+
+
+### Supported Models
+
+We have officially verified support for the following models:
+
+| Model Series | HF Model Card | Precision | KV Cache Supported? |
+|--------------|---------------|-----------|---------------------|
+| GPT-2 | gpt2<br>gpt2-medium | FP16, FP32 | Yes |
+| LLaMA 2 | meta-llama/Llama-2-7b-chat-hf | FP16, FP32 | Yes |
+| LLaMA 3.1 | meta-llama/Llama-3.1-8B-Instruct | FP16, FP32 | Yes |
+| LLaMA 3.2 | meta-llama/Llama-3.2-1B-Instruct<br>meta-llama/Llama-3.2-3B-Instruct | FP16, FP32 | Yes |
+| Qwen 2.5 | Qwen/Qwen2.5-0.5B-Instruct<br>Qwen/Qwen2.5-1.5B-Instruct<br>Qwen/Qwen2.5-3B-Instruct<br>Qwen/Qwen2.5-7B-Instruct | FP16, FP32 | Yes |
+
+
+### Usage
+
+The main entry point is `run_llm.py`:
+
+```bash
+python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is parallel programming?" --precision FP16 --num_tokens 128 --cache static_v2 --benchmark
+```
+
+#### Key Arguments
+
+- `--model`: Name or path of the HuggingFace LLM.
+- `--tokenizer`: (Optional) Tokenizer name; defaults to model.
+- `--prompt`: Input prompt for generation.
+- `--precision`: Precision mode (`FP16`, `FP32`).
+- `--num_tokens`: Number of output tokens to generate.
+- `--cache`: KV cache type (`static_v1`, `static_v2`, or empty for no KV caching).
+- `--benchmark`: Enable benchmarking mode.
+- `--enable_pytorch_run`: Also run and compare PyTorch baseline.
+
+### Caching Strategies
+
+- **Static Cache v1/v2:** Adds static KV cache tensors as model inputs/outputs for efficient reuse.
+- **No Cache:** Standard autoregressive decoding.
+
+Please read our "Compiling LLM models from Huggingface" tutorial for details on how the static cache is implemented.
+
+## Extension
+
+This codebase can be extended to:
+- Add new models by specifying their HuggingFace name.
+- Implement new cache strategies by adding FX graph passes.
+- Customize SDPA conversion for new attention mechanisms.
+
+## Limitations
+- Sliding window attention (used in Gemma3 and Qwen 3 models) is not yet supported.
+
+## Requirements
+
+- Torch-TensorRT 2.8.0
+- Transformers v4.52.3
\ No newline at end of file
diff --git a/tools/llm/cache_utils.py b/tools/llm/cache_utils.py
new file mode 100644
index 0000000000..d25e5bb40e
--- /dev/null
+++ b/tools/llm/cache_utils.py
@@ -0,0 +1,177 @@
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+
+import tensorrt
+import torch
+import torch_tensorrt
+from torch._export.utils import _detect_fake_mode_from_gm
+from torch._ops import OpOverloadPacket
+from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
+from torch.fx import Graph, GraphModule, Node
+from torch.fx.node import Target
+from torch.fx.passes.shape_prop import _extract_tensor_metadata
+from torch.utils._pytree import _LEAF_SPEC
+
+
+def get_kv_nodes(gm):
+ """
+ Extract key and value nodes from scaled dot-product attention operations in the graph.
+
+ This function searches through the graph for scaled_dot_product_attention operations
+ and extracts the key and value tensor nodes from each operation's arguments.
+
+ Args:
+ gm: A torch.fx.GraphModule containing the computational graph
+
+ Returns:
+ List[Tuple[Node, Node]]: A list of tuples, where each tuple contains
+ (key_node, value_node) from a scaled dot-product attention operation
+ """
+ kv_nodes = []
+ for node in gm.graph.nodes:
+ if (
+ node.op == "call_function"
+ and node.target == torch._C._nn.scaled_dot_product_attention
+ ):
+ q_node, k_node, v_node = node.args[:3]
+ kv_nodes.append((k_node, v_node))
+ return kv_nodes
+
+
+def get_random_tensor_from_node(node: Node) -> torch.Tensor:
+ """
+ Creates a random tensor based on the shape information in a node's metadata.
+ For symbolic dimensions, extracts the maximum value from the shape environment.
+
+ Args:
+ node: A torch.fx.Node object with metadata containing tensor information
+
+ Returns:
+ A random tensor with shape matching the node's metadata
+
+ Raises:
+ ValueError: If no tensor information is found in the node metadata
+ """
+ if "val" not in node.meta:
+ raise ValueError(
+ f"No tensor information found in node metadata for node: {node}"
+ )
+
+ fake_tensor = node.meta["val"]
+ shape = []
+
+ # Iterate through each dimension and handle symbolic dimensions
+ for dim in fake_tensor.shape:
+ if isinstance(dim, torch.SymInt):
+ # Extract the maximum value from the shape environment
+ max_val = dim.node.hint
+ shape.append(max_val)
+ else:
+ shape.append(dim)
+
+ # Create a random tensor with the determined shape
+ dtype = fake_tensor.dtype
+ device = fake_tensor.device
+ random_tensor = torch.rand(shape, dtype=dtype, device=device)
+
+ return random_tensor
+
+
+def create_random_output_tensors(nodes: List[Node]) -> List[torch.Tensor]:
+ """
+ Creates random tensors based on the shape information in node metadata.
+ For symbolic dimensions, extracts the maximum value from the shape environment.
+
+ Args:
+ nodes: List of torch.fx.Node objects with metadata
+
+ Returns:
+ List of random tensors with shapes matching the nodes' metadata
+ """
+ random_tensors = []
+
+ for node in nodes:
+ if isinstance(node, Node):
+ node_tensor = get_random_tensor_from_node(node)
+ elif isinstance(node, tuple):
+ node_tensor_list = []
+ for n in node:
+ random_tensor = get_random_tensor_from_node(n)
+ node_tensor_list.append(random_tensor)
+ node_tensor = tuple(node_tensor_list)
+
+ random_tensors.append(node_tensor)
+
+ return random_tensors
+
+
+def _add_graph_input(
+ gm: GraphModule, name: str, val: Optional[torch.Tensor] = None, dynamic_shape=None
+) -> Node:
+ """Add a graph input to the given GraphModule and return the newly created node.
+
+ NOTE: function does NOT do any graph canonicalization. This is left to the user!
+
+ Args:
+ gm (GraphModule): The GraphModule to add the input to.
+ name (str): The name of the input.
+ val (torch.Tensor): An example tensor to use for the input.
+ dynamic_shape: The dynamic shape of the input tensor [NOT SUPPORTED YET]
+ """
+ # check that no dynamic shape is provided...
+ if dynamic_shape:
+ raise NotImplementedError("Dynamic shape not supported for adding graph inputs")
+
+ # extract graph and input spec
+ graph: Graph = gm.graph
+
+ in_spec = graph._codegen.pytree_info.in_spec
+ in_spec_for_args = in_spec.children_specs[0]
+ orig_args = graph._codegen.pytree_info.orig_args
+ assert in_spec_for_args.type is tuple
+
+ # insert input node after currently last input node
+ node_last_input = graph.find_nodes(op="placeholder", sort=True)[-1]
+ with graph.inserting_after(node_last_input):
+ in_node = graph.placeholder(name)
+ in_spec_for_args.children_specs.append(_LEAF_SPEC)
+ orig_args.append(f"arg_{name}")
+
+ # update pytree info recursively with __post_init__ starting at leaves
+ def call_post_init(spec):
+ for child_spec in spec.children_specs:
+ call_post_init(child_spec)
+ spec.__post_init__()
+
+ call_post_init(in_spec)
+
+ # set fake tensor information if all required information is available
+ fake_mode: Optional[FakeTensorMode] = _detect_fake_mode_from_gm(gm)
+ if fake_mode and val is not None and isinstance(val, torch.Tensor):
+ if isinstance(val, FakeTensor):
+ fake_tensor = val
+ else:
+ fake_tensor: FakeTensor = fake_mode.from_tensor(val, static_shapes=True)
+ in_node.meta["val"] = fake_tensor
+ in_node.meta["tensor_meta"] = _extract_tensor_metadata(fake_tensor)
+
+ # return new node...
+ return in_node
+
+
+def is_op(node: Node, ops: Union[OpOverloadPacket, Iterable[OpOverloadPacket]]) -> bool:
+ """Check if the node is a call to one of the ops."""
+ if node.op != "call_function":
+ return False
+ # check if it's a single op that's provided
+ if isinstance(ops, OpOverloadPacket):
+ ops = [ops]
+
+ # check if it's the op itself instead of an overload
+ if any(node.target == op for op in ops):
+ return True
+
+ return False
+
+
+def get_all_input_output_nodes(graph: Graph) -> Tuple[List[Node], List[Node]]:
+ input_nodes: List[Node] = graph.find_nodes(op="placeholder")
+ output_nodes: List[Node] = graph.find_nodes(op="output")
+ return (input_nodes, output_nodes)
diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py
new file mode 100644
index 0000000000..1a4030ea7d
--- /dev/null
+++ b/tools/llm/run_llm.py
@@ -0,0 +1,356 @@
+"""
+.. _run_llm:
+
+Running LLM inference with Torch-TensorRT
+==========================================================
+
+This script illustrates Torch-TensorRT workflow with dynamo backend on popular LLM models.
+"""
+
+import argparse
+import copy
+import os
+import timeit
+from contextlib import nullcontext
+
+# %%
+# Imports and Model Definition
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+import torch
+import torch_tensorrt
+from torchtrt_ext import register_sdpa
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from utils import (
+ export_llm,
+ generate,
+ generate_with_static_cache,
+ record_stats,
+ time_generate,
+)
+
+DEVICE = torch.device("cuda:0")
+
+
+def get_model(args):
+ """
+ Load and configure the language model for inference.
+
+ This function loads a pre-trained causal language model using the specified
+ model name and configures it with the appropriate precision and settings
+ for inference.
+
+ Args:
+ args: Parsed command line arguments containing:
+ - model (str): Name or path of the model to load
+ - precision (str): Precision to use ("FP16", "BF16", or "FP32")
+
+ Returns:
+ torch.nn.Module: The loaded and configured model ready for inference,
+ moved to CUDA device with the specified precision
+ """
+ with torch.no_grad():
+ model = (
+ AutoModelForCausalLM.from_pretrained(
+ args.model,
+ use_cache=False,
+ attn_implementation="sdpa",
+ )
+ .eval()
+ .cuda()
+ )
+ if args.precision == "FP16":
+ model = model.to(torch.float16)
+ elif args.precision == "BF16":
+ model = model.to(torch.bfloat16)
+ else:
+ model = model.to(torch.float32)
+
+ return model
+
+
+def compile_torchtrt(model, input_ids, args):
+ """
+ Compile a PyTorch model to TensorRT using torch_tensorrt.dynamo.compile.
+
+ This function exports the given model to an ExportedProgram via torch.export and then
+ compiles it to TensorRT for optimized inference. The compilation process includes
+ precision-specific optimizations and various performance tuning parameters.
+
+ Args:
+ model (torch.nn.Module): The PyTorch model to compile
+ input_ids (torch.Tensor): Input token IDs tensor used for model export
+ args: Parsed command line arguments containing:
+ - num_tokens (int): Number of tokens to generate (used for max sequence length)
+ - precision (str): Precision to use ("FP16", "BF16", or "FP32")
+ - debug (bool): Whether to enable debug logging
+ - min_block_size (int): Minimum block size for TensorRT compilation
+
+ Returns:
+ torch_tensorrt.dynamo.TorchTensorRTModule: The compiled TensorRT model ready
+ for optimized inference
+ """
+ max_seq_len = input_ids.shape[1] + args.num_tokens
+ ep = export_llm(model, input_ids, max_seq_len=max_seq_len)
+ position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
+ # Set precision specific flags
+ use_fp32_acc = False
+ use_explicit_typing = False
+ if args.precision == "FP16":
+ enabled_precisions = {torch.float32}
+ use_fp32_acc = True
+ use_explicit_typing = True
+ elif args.precision == "BF16":
+ enabled_precisions = {torch.bfloat16}
+ use_fp32_acc = False
+ else:
+ enabled_precisions = {torch.float32}
+
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+ trt_model = torch_tensorrt.dynamo.compile(
+ ep,
+ inputs=[input_ids, position_ids],
+ enabled_precisions=enabled_precisions,
+ # truncate_double=True,
+ use_explicit_typing=use_explicit_typing,
+ use_fp32_acc=use_fp32_acc,
+ device=DEVICE,
+ disable_tf32=True,
+ use_python_runtime=True,
+ debug=args.debug,
+ offload_module_to_cpu=True,
+ min_block_size=args.min_block_size,
+ )
+
+ return trt_model
+
+
+def print_outputs(backend_name, gen_tokens, tokenizer):
+ """
+ Print the generated tokens from the model.
+ """
+ print(f"========= {backend_name} =========")
+ print(
+ f"{backend_name} model generated text: ",
+ tokenizer.decode(gen_tokens[0], skip_special_tokens=True),
+ )
+ print("===================================")
+
+
+def measure_perf(trt_model, input_signature, backend_name):
+ """
+ Measure the performance of a TensorRT model by running it multiple times and
+ calculating the average time per iteration.
+ """
+ total_time = 0
+ iterations = 10
+
+ print("Running warmup iteration...")
+ # Warmup run
+ _ = trt_model(*input_signature)
+ torch.cuda.synchronize()
+
+ print(f"Measuring performance over {iterations} iterations...")
+ for i in range(iterations):
+ start_time = timeit.default_timer()
+ _ = trt_model(*input_signature)
+ torch.cuda.synchronize()
+ end_time = timeit.default_timer()
+ iter_time = end_time - start_time
+ total_time += iter_time
+
+ avg_time = total_time / iterations
+ print(
+ f"Backend: {backend_name} Average time per iteration: {avg_time*1000:.4f} milliseconds"
+ )
+ print(
+ f"Backend: {backend_name} Average throughput: {1.0/avg_time:.2f} iterations/second"
+ )
+
+
+if __name__ == "__main__":
+ arg_parser = argparse.ArgumentParser(
+ description="Run inference on a model with random input values"
+ )
+ arg_parser.add_argument(
+ "--model",
+ type=str,
+ default="meta-llama/Llama-3.2-1B-Instruct",
+ help="Name of LLM model",
+ )
+ arg_parser.add_argument(
+ "--tokenizer",
+ type=str,
+ default="",
+ help="Name of LLM model tokenizer",
+ )
+ arg_parser.add_argument(
+ "--prompt", type=str, default="What is parallel programming ?", help="Prompt"
+ )
+ arg_parser.add_argument(
+ "--precision",
+ type=str,
+ default="FP16",
+ help="Precision to use in the model. Options: FP16, BF16, FP32",
+ )
+ arg_parser.add_argument(
+ "--iterations", type=int, default=5, help="no. of iterations to run"
+ )
+ arg_parser.add_argument(
+ "--min_block_size", type=int, default=1, help="no. of iterations to run"
+ )
+ arg_parser.add_argument(
+ "--num_tokens",
+ type=int,
+ default=128,
+ help="no. of output tokens to be generated",
+ )
+ arg_parser.add_argument(
+ "--batch_size", type=int, default=1, help="Batch size used for benchmarking"
+ )
+ arg_parser.add_argument(
+ "--isl",
+ type=int,
+ default=2048,
+ help="Input sequence length used for benchmarking",
+ )
+ arg_parser.add_argument(
+ "--enable_pytorch_run",
+ action="store_true",
+ help="Enable pytorch run (default: False)",
+ )
+ arg_parser.add_argument(
+ "--cache",
+ type=str,
+ default="",
+ help="Type of KV cache to use. Options: static_v1, static_v2",
+ )
+ arg_parser.add_argument(
+ "--cudagraph", action="store_true", help="Enable cudagraphs (default: False)"
+ )
+ arg_parser.add_argument(
+ "--debug", action="store_true", help="Enable debug (default: False)"
+ )
+ arg_parser.add_argument(
+ "--benchmark", action="store_true", help="Enable benchmark (default: False)"
+ )
+
+ args = arg_parser.parse_args()
+ with torch.inference_mode():
+ model = get_model(args)
+
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model)
+
+ # Prepare input for benchmarking or evaluation
+ if args.benchmark:
+ input_ids = torch.randint(
+ 1, 10000, (args.batch_size, args.isl), dtype=torch.int64
+ ).to(model.device)
+ position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
+ else:
+ model_inputs = tokenizer(args.prompt, return_tensors="pt")
+ input_ids = model_inputs["input_ids"].to(DEVICE)
+ position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
+
+ MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.num_tokens
+ # Pyt
+ pyt_gen_tokens = None
+ pyt_timings = None
+ pyt_stats = None
+
+ if args.enable_pytorch_run:
+ pyt_gen_tokens = generate(
+ model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id
+ )
+ if args.benchmark:
+ pyt_timings = time_generate(
+ generate,
+ model,
+ input_ids.clone(),
+ MAX_OUTPUT_SEQ_LENGTH,
+ tokenizer.eos_token_id,
+ iterations=args.iterations,
+ )
+ pyt_stats = record_stats(
+ "PyTorch",
+ pyt_timings,
+ args.precision,
+ batch_size=args.batch_size,
+ compile_time_s=None,
+ )
+
+ if args.cache == "static_v1":
+ # This import is required to register static v1 KV cache transformations as lowering passes
+ from torchtrt_ext import static_cache_v1
+ if args.cache == "static_v2":
+ # This import is required to register static v2 KV cache transformations as lowering passes
+ from torchtrt_ext import static_cache_v2
+
+ # Compile the model with Torch-TensorRT
+ trt_model = compile_torchtrt(model, input_ids, args)
+
+ if args.cache == "static_v1" or args.cache == "static_v2":
+ if args.cudagraph:
+ # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases.
+ # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model)
+ torch_tensorrt.runtime.set_cudagraphs_mode(True)
+
+ trt_gen_tokens = generate_with_static_cache(
+ trt_model,
+ input_ids.clone(),
+ MAX_OUTPUT_SEQ_LENGTH,
+ tokenizer.eos_token_id,
+ )
+
+ if args.benchmark:
+ trt_timings = time_generate(
+ generate_with_static_cache,
+ trt_model,
+ input_ids.clone(),
+ MAX_OUTPUT_SEQ_LENGTH,
+ tokenizer.eos_token_id,
+ iterations=args.iterations,
+ )
+ else:
+ trt_gen_tokens = generate(
+ trt_model,
+ input_ids.clone(),
+ MAX_OUTPUT_SEQ_LENGTH,
+ tokenizer.eos_token_id,
+ )
+ if args.benchmark:
+ trt_timings = time_generate(
+ generate,
+ trt_model,
+ input_ids.clone(),
+ MAX_OUTPUT_SEQ_LENGTH,
+ tokenizer.eos_token_id,
+ iterations=args.iterations,
+ )
+
+ if args.benchmark:
+ trt_stats = record_stats(
+ "TensorRT",
+ trt_timings,
+ args.precision,
+ batch_size=args.batch_size,
+ compile_time_s=None,
+ )
+
+ if not args.benchmark:
+ if args.enable_pytorch_run:
+ print_outputs("PyTorch", pyt_gen_tokens, tokenizer)
+
+ print_outputs("TensorRT", trt_gen_tokens, tokenizer)
+
+ if args.enable_pytorch_run:
+ print(
+ f"PyTorch and TensorRT outputs match: {torch.equal(pyt_gen_tokens, trt_gen_tokens)}"
+ )
+
+ if args.benchmark:
+ if args.enable_pytorch_run:
+ print("=========PyTorch PERFORMANCE============ \n")
+ print(pyt_stats)
+ print("===================== \n")
+ print("=========TensorRT PERFORMANCE============ \n")
+ print(trt_stats)
diff --git a/tools/llm/static_cache_v1.py b/tools/llm/static_cache_v1.py
new file mode 100644
index 0000000000..b60396c08b
--- /dev/null
+++ b/tools/llm/static_cache_v1.py
@@ -0,0 +1,277 @@
+import logging
+from typing import List, Tuple
+
+import torch
+import torch.utils._pytree as pytree
+from cache_utils import _add_graph_input, create_random_output_tensors, get_kv_nodes
+from torch.fx import Node
+from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
+ _aten_lowering_pass,
+)
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+ clean_up_graph_after_modifications,
+)
+from torch_tensorrt.dynamo.utils import extract_var_range_info
+
+logger = logging.getLogger(__name__)
+
+SDPA_OP = torch._C._nn.scaled_dot_product_attention
+
+
+def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]):
+ """
+ Modifies the graph to add the KV cache tensors as outputs.
+
+ This function appends the given key/value cache nodes (one pair per
+ scaled dot-product attention operation) to the graph's output node so that
+ the updated caches can be accessed externally, which is required for
+ key-value caching across generation steps.
+
+ Args:
+ gm: The GraphModule to modify
+ kv_cache_for_graph: The list of KV cache nodes to append to the outputs
+
+ Returns:
+ The new tuple of graph outputs. The graph is modified in-place.
+ """
+ output_node = next(node for node in gm.graph.nodes if node.op == "output")
+
+ # Get the current output args (typically a tuple)
+ current_outputs = output_node.args[0]
+
+ # If the current output is a tuple, extend it with our new outputs
+ if isinstance(current_outputs, tuple):
+ new_outputs = current_outputs + tuple(kv_cache_for_graph)
+ else:
+ # If there's only one output or it's not a tuple, create a new tuple
+ new_outputs = (current_outputs,) + tuple(kv_cache_for_graph)
+
+ gm.graph.output(new_outputs)
+ gm.graph.erase_node(output_node)
+
+ return new_outputs
+
+
+def add_kv_cache_inputs(gm, fixed_kv: bool = True):
+ """
+ Add key-value tensors, index parameters as inputs to the graph.
+
+ Args:
+ gm: The GraphModule to modify
+ fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True.
+
+ Returns:
+ A tuple containing:
+ - List of (k_input, v_input) node pairs for each SDPA operation
+ - start_idx input node for slicing operations
+ - end_idx input node for slicing operations
+ """
+
+ def get_static_tensor(tensor: torch.Tensor):
+ key_shape = []
+ for dim in tensor.shape:
+ if isinstance(dim, torch.SymInt):
+ min_max_opt = extract_var_range_info(dim)
+ key_shape.append(min_max_opt["max"])
+ else:
+ key_shape.append(dim)
+
+ static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device)
+ return static_tensor
+
+ keys_values = get_kv_nodes(gm)
+
+ kv_inputs = []
+ for idx, key_value in enumerate(keys_values):
+ k_val = key_value[0].meta["val"]
+ v_val = key_value[1].meta["val"]
+ if fixed_kv:
+ k_val = get_static_tensor(k_val)
+ v_val = get_static_tensor(v_val)
+
+ # Add new inputs using _add_graph_input
+ k_input = _add_graph_input(gm, key_value[0].name + "_k_input", k_val)
+ v_input = _add_graph_input(gm, key_value[1].name + "_v_input", v_val)
+ kv_inputs.append((k_input, v_input))
+
+ # Add start_idx and end_idx as inputs
+ start_idx_input = _add_graph_input(gm, "start_idx", torch.tensor(0))
+ end_idx_input = _add_graph_input(gm, "end_idx", torch.tensor(1))
+
+ # Get the max sequence length from the first key_cache node. The order of nodes is: input_ids, is_causal, key_cache1, value_cache1, key_cache2, value_cache2, ..
+ input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
+ input_ids_meta = input_nodes[0].meta["val"]
+ seq_len = input_ids_meta.shape[1]
+ min_max_opt = extract_var_range_info(seq_len)
+ max_seq_len = min_max_opt["max"]
+
+ from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
+ shape_env = ShapeEnv()
+ # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive
+ start_idx_unbacked_symint = shape_env.create_unbacked_symint()
+ torch._check(start_idx_unbacked_symint >= 0)
+ torch._check(start_idx_unbacked_symint <= max_seq_len)
+
+ end_idx_unbacked_symint = shape_env.create_unbacked_symint()
+ torch._check(end_idx_unbacked_symint >= 0)
+ torch._check(end_idx_unbacked_symint <= max_seq_len)
+ # Set the symbolic ints as the metadata for start_idx and end_idx inputs
+ start_idx_input.meta["val"] = start_idx_unbacked_symint
+ end_idx_input.meta["val"] = end_idx_unbacked_symint
+
+ return kv_inputs, start_idx_input, end_idx_input
+
+
+def insert_kv_slicing_before_sdpa(
+ gm,
+ incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]],
+ start_idx_input: Node,
+ end_idx_input: Node,
+):
+ """
+ Insert slicing operations before each scaled_dot_product_attention operation.
+ """
+ # Find all nodes with scaled_dot_product_attention
+ sdpa_nodes = []
+ for node in gm.graph.nodes:
+ if node.op == "call_function" and node.target == SDPA_OP:
+ sdpa_nodes.append(node)
+ kv_cache_for_graph = []
+ for idx, sdpa_node in enumerate(sdpa_nodes):
+ assert (
+ len(sdpa_node.args) == 6
+ ), f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments"
+ q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_node.args
+ incoming_key, incoming_value = incoming_keys_values[idx]
+ kv_cache_for_sdpa_node = []
+ new_keys_values = []
+ for key_or_value, current_key_or_value_node in zip(
+ [incoming_key, incoming_value], [k_node, v_node]
+ ):
+ # Create a slice node for key_cache[:,:,:start_idx,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim
+ with gm.graph.inserting_before(sdpa_node):
+ slice_1 = gm.graph.create_node(
+ "call_function",
+ torch.ops.aten.slice.Tensor,
+ args=(key_or_value,),
+ kwargs={},
+ )
+ slice_2 = gm.graph.create_node(
+ "call_function",
+ torch.ops.aten.slice.Tensor,
+ args=(slice_1, 1),
+ kwargs={},
+ )
+ slice_3 = gm.graph.create_node(
+ "call_function",
+ torch.ops.aten.slice.Tensor,
+ args=(slice_2, 2, None, start_idx_input),
+ kwargs={},
+ )
+ slice_4 = gm.graph.create_node(
+ "call_function",
+ torch.ops.aten.slice.Tensor,
+ args=(slice_3, 3),
+ kwargs={},
+ )
+ # =============================================== #
+ # Create a slice node for key_cache[:,:, end_idx:,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim
+ slice_5 = gm.graph.create_node(
+ "call_function",
+ torch.ops.aten.slice.Tensor,
+ args=(key_or_value,),
+ kwargs={},
+ )
+ slice_6 = gm.graph.create_node(
+ "call_function",
+ torch.ops.aten.slice.Tensor,
+ args=(slice_5, 1),
+ kwargs={},
+ )
+ slice_7 = gm.graph.create_node(
+ "call_function",
+ torch.ops.aten.slice.Tensor,
+ args=(slice_6, 2, end_idx_input),
+ kwargs={},
+ )
+ slice_8 = gm.graph.create_node(
+ "call_function",
+ torch.ops.aten.slice.Tensor,
+ args=(slice_7, 3),
+ kwargs={},
+ )
+ # =============================================== #
+ # Concatenate the sliced tensors to build KV cache
+ cat = gm.graph.create_node(
+ "call_function",
+ torch.ops.aten.cat.default,
+ args=([slice_4, current_key_or_value_node, slice_8], 2),
+ kwargs={},
+ )
+ # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph
+ cat.meta.update(key_or_value.meta)
+ kv_cache_for_sdpa_node.append(cat)
+ # =============================================== #
+ # Get the current key and value by indexing the KV cache
+ slice_9 = gm.graph.create_node(
+ "call_function", torch.ops.aten.slice.Tensor, args=(cat,), kwargs={}
+ )
+ slice_10 = gm.graph.create_node(
+ "call_function",
+ torch.ops.aten.slice.Tensor,
+ args=(slice_9, 1),
+ kwargs={},
+ )
+ slice_11 = gm.graph.create_node(
+ "call_function",
+ torch.ops.aten.slice.Tensor,
+ args=(slice_10, 2, None, end_idx_input),
+ kwargs={},
+ )
+ slice_12 = gm.graph.create_node(
+ "call_function",
+ torch.ops.aten.slice.Tensor,
+ args=(slice_11, 3),
+ kwargs={},
+ )
+ new_keys_values.append(slice_12)
+
+ kv_cache_for_graph.extend(kv_cache_for_sdpa_node)
+
+ sdpa_node.args = (q_node, new_keys_values[0], new_keys_values[1]) + (
+ attn_mask,
+ dropout_p,
+ True,
+ )
+
+ return gm, kv_cache_for_graph
+
+
+@_aten_lowering_pass
+def insert_static_cache_v1(
+ gm: torch.fx.GraphModule, settings: CompilationSettings
+) -> torch.fx.GraphModule:
+ """Insert KV cache ops in the graph"""
+ """Perform insertion of kv-caches and attention kernel."""
+ # Add static key and value as inputs to the graph
+ kv_inputs, start_idx_input, end_idx_input = add_kv_cache_inputs(gm, fixed_kv=True)
+
+ # Build and update the KV cache using computed KV inputs for current token and
+ # incoming keys and values from previous tokens (which were added as inputs)
+ gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(
+ gm, kv_inputs, start_idx_input, end_idx_input
+ )
+
+ # Call the function to add KV as outputs
+ logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph)
+
+ gm = clean_up_graph_after_modifications(gm)
+
+ new_output_tensors = create_random_output_tensors(logits_keys_values)
+
+ new_out_spec = pytree.tree_flatten(new_output_tensors)[1]
+ gm._out_spec = new_out_spec
+ logger.debug("After inserting KV cache into the graph: " + str(gm.graph))
+
+ return gm
diff --git a/tools/llm/static_cache_v2.py b/tools/llm/static_cache_v2.py
new file mode 100644
index 0000000000..4634b79a52
--- /dev/null
+++ b/tools/llm/static_cache_v2.py
@@ -0,0 +1,290 @@
+import logging
+from typing import List, Tuple
+
+import torch
+import torch.utils._pytree as pytree
+from cache_utils import _add_graph_input, create_random_output_tensors, get_kv_nodes
+from torch.fx import Node
+from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
+ _aten_lowering_pass,
+)
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+ clean_up_graph_after_modifications,
+)
+from torch_tensorrt.dynamo.utils import extract_var_range_info
+
+logger = logging.getLogger(__name__)
+
+SDPA_OP = torch._C._nn.scaled_dot_product_attention
+
+
+def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]):
+ """
+ Modifies the graph to add the KV cache tensors as outputs.
+
+ This function appends the given key/value cache nodes (one pair per
+ scaled dot-product attention operation) to the graph's output node so that
+ the updated caches can be accessed externally, which is required for
+ key-value caching across generation steps.
+
+ Args:
+ gm: The GraphModule to modify
+ kv_cache_for_graph: The list of KV cache nodes to append to the outputs
+
+ Returns:
+ The new tuple of graph outputs. The graph is modified in-place.
+ """
+ output_node = next(node for node in gm.graph.nodes if node.op == "output")
+
+ # Get the current output args (typically a tuple)
+ current_outputs = output_node.args[0]
+
+ # If the current output is a tuple, extend it with our new outputs
+ if isinstance(current_outputs, tuple):
+ new_outputs = current_outputs + tuple(kv_cache_for_graph)
+ else:
+ # If there's only one output or it's not a tuple, create a new tuple
+ new_outputs = (current_outputs,) + tuple(kv_cache_for_graph)
+
+ gm.graph.output(new_outputs)
+ gm.graph.erase_node(output_node)
+
+ return new_outputs
+
+
+def add_kv_cache_inputs(gm, fixed_kv: bool = True):
+ """
+ Add key-value tensors, index parameters as inputs to the graph.
+
+ Args:
+ gm: The GraphModule to modify
+ fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True.
+
+ Returns:
+ A tuple containing:
+ - List of (k_input, v_input) node pairs for each SDPA operation
+ - start_idx input node for slicing operations
+ - end_idx input node for slicing operations
+ """
+
+ def get_static_tensor(tensor: torch.Tensor):
+ key_shape = []
+ for dim in tensor.shape:
+ if isinstance(dim, torch.SymInt):
+ min_max_opt = extract_var_range_info(dim)
+ key_shape.append(min_max_opt["max"])
+ else:
+ key_shape.append(dim)
+
+ static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device)
+ return static_tensor
+
+ keys_values = get_kv_nodes(gm)
+
+ kv_inputs = []
+ for idx, key_value in enumerate(keys_values):
+ k_val = key_value[0].meta["val"]
+ v_val = key_value[1].meta["val"]
+ if fixed_kv:
+ k_val = get_static_tensor(k_val)
+ v_val = get_static_tensor(v_val)
+
+ # Add new inputs using _add_graph_input
+ k_input = _add_graph_input(gm, key_value[0].name + "_k_input", k_val)
+ v_input = _add_graph_input(gm, key_value[1].name + "_v_input", v_val)
+ kv_inputs.append((k_input, v_input))
+
+ # Add start_idx and end_idx as inputs
+ start_idx_input = _add_graph_input(gm, "start_idx", torch.tensor(0))
+ end_idx_input = _add_graph_input(gm, "end_idx", torch.tensor(1))
+
+ # Get the max sequence length from the first key_cache node. The order of input nodes is: input_ids, key_cache1, value_cache1, key_cache2, value_cache2, start_idx, end_idx
+ input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
+    # The third-to-last placeholder is the last value cache node; use its shape to derive max_seq_len
+    last_value_cache_meta = input_nodes[-3].meta["val"]
+    seq_len = last_value_cache_meta.shape[2]
+
+ if isinstance(seq_len, torch.SymInt):
+ min_max_opt = extract_var_range_info(seq_len)
+ max_seq_len = min_max_opt["max"]
+ else:
+ max_seq_len = seq_len
+
+ from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
+ shape_env = ShapeEnv()
+ # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive
+ start_idx_unbacked_symint = shape_env.create_unbacked_symint()
+ torch._check(start_idx_unbacked_symint >= 0)
+ torch._check(start_idx_unbacked_symint <= max_seq_len)
+
+ end_idx_unbacked_symint = shape_env.create_unbacked_symint()
+ torch._check(end_idx_unbacked_symint >= 0)
+ torch._check(end_idx_unbacked_symint <= max_seq_len)
+ # Set the symbolic ints as the metadata for start_idx and end_idx inputs
+ start_idx_input.meta["val"] = start_idx_unbacked_symint
+ end_idx_input.meta["val"] = end_idx_unbacked_symint
+
+ return kv_inputs, start_idx_input, end_idx_input
+
+
+def create_kv_cache_update_nodes(
+ gm, sdpa_node, current_kv_node, incoming_kv_node, start_idx_input, end_idx_input
+):
+ """
+ Create slicing and concatenation nodes for KV cache update.
+
+ This function creates the necessary slicing and concatenation nodes to update the KV cache
+ during the generation process. It takes the SDPA node, the current KV cache node, and the
+ incoming KV cache node as input.
+ Returns:
+ for a particular SDPA node, a tuple containing:
+ - List of new current KV nodes
+ - List of updated incoming KV cache nodes
+
+ """
+
+ # Create a slice node for key_cache[:,:,:start_idx,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim
+ with gm.graph.inserting_before(sdpa_node):
+ slice_1 = gm.graph.create_node(
+ "call_function",
+ torch.ops.aten.slice.Tensor,
+ args=(incoming_kv_node,),
+ kwargs={},
+ )
+ slice_2 = gm.graph.create_node(
+ "call_function", torch.ops.aten.slice.Tensor, args=(slice_1, 1), kwargs={}
+ )
+ slice_3 = gm.graph.create_node(
+ "call_function",
+ torch.ops.aten.slice.Tensor,
+ args=(slice_2, 2, None, start_idx_input),
+ kwargs={},
+ )
+ slice_4 = gm.graph.create_node(
+ "call_function", torch.ops.aten.slice.Tensor, args=(slice_3, 3), kwargs={}
+ )
+ # Concat key_cache[:,:,:start_idx,:] with current key (k)
+ concat_keys_or_values = gm.graph.create_node(
+ "call_function",
+ torch.ops.aten.cat.default,
+ args=([slice_4, current_kv_node], 2),
+ kwargs={},
+ )
+
+ # =============================================== #
+ # Create nodes for key_cache[:,:, end_idx:,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim
+ slice_5 = gm.graph.create_node(
+ "call_function",
+ torch.ops.aten.slice.Tensor,
+ args=(incoming_kv_node,),
+ kwargs={},
+ )
+ slice_6 = gm.graph.create_node(
+ "call_function", torch.ops.aten.slice.Tensor, args=(slice_5, 1), kwargs={}
+ )
+ slice_7 = gm.graph.create_node(
+ "call_function",
+ torch.ops.aten.slice.Tensor,
+ args=(slice_6, 2, end_idx_input),
+ kwargs={},
+ )
+ slice_8 = gm.graph.create_node(
+ "call_function", torch.ops.aten.slice.Tensor, args=(slice_7, 3), kwargs={}
+ )
+ # =============================================== #
+ # Concatenate the sliced tensors to build KV cache
+ new_incoming_keys_or_values = gm.graph.create_node(
+ "call_function",
+ torch.ops.aten.cat.default,
+ args=([concat_keys_or_values, slice_8], 2),
+ kwargs={},
+ )
+ # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph
+ new_incoming_keys_or_values.meta.update(incoming_kv_node.meta)
+
+ return concat_keys_or_values, new_incoming_keys_or_values
+
+
+def insert_kv_slicing_before_sdpa(
+ gm,
+ incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]],
+ start_idx_input: Node,
+ end_idx_input: Node,
+):
+ """
+ Insert slicing and concatenation operations before each scaled_dot_product_attention operation as per the following KV cache update logic:
+ concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)
+ concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)
+ new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
+ new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2)
+ out = torch._C._nn.scaled_dot_product_attention(q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal)
+ """
+ # Find all nodes with scaled_dot_product_attention
+ sdpa_nodes = []
+ for node in gm.graph.nodes:
+ if node.op == "call_function" and node.target == SDPA_OP:
+ sdpa_nodes.append(node)
+ kv_cache_for_graph = []
+ for idx, sdpa_node in enumerate(sdpa_nodes):
+ assert (
+ len(sdpa_node.args) == 6
+ ), f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments"
+ q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_node.args
+ incoming_key, incoming_value = incoming_keys_values[idx]
+ # For keys
+ new_current_key_node, new_incoming_key_cache_node = (
+ create_kv_cache_update_nodes(
+ gm, sdpa_node, k_node, incoming_key, start_idx_input, end_idx_input
+ )
+ )
+ # For values
+ new_current_value_node, new_incoming_value_cache_node = (
+ create_kv_cache_update_nodes(
+ gm, sdpa_node, v_node, incoming_value, start_idx_input, end_idx_input
+ )
+ )
+
+ # Store the KV cache nodes for the current SDPA node
+ kv_cache_for_graph.extend(
+ [new_incoming_key_cache_node, new_incoming_value_cache_node]
+ )
+
+ # Update the SDPA node arguments with current key and value nodes
+ sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + (
+ attn_mask,
+ dropout_p,
+ True,
+ )
+
+ # kv_cache_for_graph.extend([k_node, v_node])
+ return gm, kv_cache_for_graph
+
+
+@_aten_lowering_pass
+def insert_static_cache_v2(
+ gm: torch.fx.GraphModule, settings: CompilationSettings
+) -> torch.fx.GraphModule:
+ """Insert KV cache ops in the graph"""
+ """Perform insertion of kv-caches and attention kernel."""
+ # Add static key and value as inputs to the graph
+ kv_inputs, start_idx_input, end_idx_input = add_kv_cache_inputs(gm, fixed_kv=True)
+
+ # Build and update the KV cache using computed KV inputs for current token and
+ # incoming keys and values from previous tokens (which were added as inputs)
+ gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(
+ gm, kv_inputs, start_idx_input, end_idx_input
+ )
+
+ # Call the function to add KV as outputs
+ logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph)
+
+ gm = clean_up_graph_after_modifications(gm)
+
+ new_output_tensors = create_random_output_tensors(logits_keys_values)
+
+ new_out_spec = pytree.tree_flatten(new_output_tensors)[1]
+ gm._out_spec = new_out_spec
+
+ logger.debug("After inserting KV cache into the graph: " + str(gm.graph))
+ return gm
diff --git a/tools/llm/test_llama_components.py b/tools/llm/test_llama_components.py
new file mode 100644
index 0000000000..ef7e59cd72
--- /dev/null
+++ b/tools/llm/test_llama_components.py
@@ -0,0 +1,603 @@
+import torch
+
+torch.backends.cuda.matmul.allow_tf32 = False
+torch.backends.cudnn.allow_tf32 = False
+
+import argparse
+import os
+import sys
+from contextlib import nullcontext
+
+import torch.nn as nn
+import torch_tensorrt
+from torch.testing._internal.common_utils import TestCase, run_tests
+from transformers import AutoModelForCausalLM
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer
+
+# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+from register_sdpa import *
+
+ATOL = 1e-5
+RTOL = 1e-5
+
+
+# llama2_model_name = "meta-llama/Llama-2-7b-hf"
+llama3_model_name = "meta-llama/Llama-3.2-1B-Instruct"
+llama_model = (
+ AutoModelForCausalLM.from_pretrained(
+ llama3_model_name,
+ use_cache=False,
+ attn_implementation="sdpa",
+ num_hidden_layers=1,
+ )
+ .eval()
+ .cuda()
+)
+LLAMA_CONFIG = llama_model.config
+
+
+def test_llama_attention(args):
+
+ DTYPE = torch.float32
+ if args.precision == "FP16":
+ DTYPE = torch.float16
+ elif args.precision == "BF16":
+ DTYPE = torch.bfloat16
+
+ # Set precision specific flags
+ use_fp32_acc = False
+ use_explicit_typing = False
+ if args.precision == "FP16":
+ enabled_precisions = {torch.float32}
+ use_fp32_acc = True
+ use_explicit_typing = True
+ elif args.precision == "BF16":
+ enabled_precisions = {torch.bfloat16}
+ use_fp32_acc = False
+ else:
+ enabled_precisions = {torch.float32}
+
+ # model = LlamaAttentionBlock().eval().cuda().to(DTYPE)
+ model = llama_model.model.layers[0].self_attn.to(DTYPE)
+ # llama3
+ hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda()
+ position_embeddings = (
+ torch.randn((1, 6, 64), dtype=DTYPE).cuda(),
+ torch.randn((1, 6, 64), dtype=DTYPE).cuda(),
+ )
+
+ pyt_output = model(hidden_states, position_embeddings, None)
+ seq_len = torch.export.Dim("seq_len", min=2, max=2176)
+ dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None)
+ from torch.export._trace import _export
+
+ # ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes, strict=False)
+ ep = _export(
+ model,
+ args=(hidden_states, position_embeddings, None),
+ dynamic_shapes=dynamic_shapes,
+ strict=False,
+ allow_complex_guards_as_runtime_asserts=True,
+ )
+
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+ trt_model = torch_tensorrt.dynamo.compile(
+ ep,
+ inputs=[hidden_states, position_embeddings, None],
+ enabled_precisions=enabled_precisions,
+ disable_tf32=True,
+ use_fp32_acc=use_fp32_acc,
+ use_explicit_typing=use_explicit_typing,
+ debug=args.debug,
+ )
+ trt_output = trt_model(hidden_states, position_embeddings, None)
+ if isinstance(pyt_output, tuple):
+ print(
+ f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}"
+ )
+ assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL)
+ else:
+ print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output - trt_output))}")
+ assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL)
+
+
+def print_diff(tensor1, tensor2, prefix=""):
+ """
+ Print the diff between two tensors
+ """
+ print(
+ f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}"
+ )
+
+
+def test_llama_attention_with_static_cache(args):
+ class LlamaAttentionBlock(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.config = LLAMA_CONFIG
+ self.attn = LlamaAttention(config=self.config, layer_idx=0)
+
+ def forward(self, hidden_states, position_embeddings):
+ attn_output, attn_weights = self.attn(
+ hidden_states, position_embeddings, None
+ )
+ return attn_output
+
+ DTYPE = torch.float32
+ if args.precision == "FP16":
+ DTYPE = torch.float16
+ elif args.precision == "BF16":
+ DTYPE = torch.bfloat16
+
+ # Set precision specific flags
+ use_fp32_acc = False
+ use_explicit_typing = False
+ if args.precision == "FP16":
+ enabled_precisions = {torch.float32}
+ use_fp32_acc = True
+ use_explicit_typing = True
+ elif args.precision == "BF16":
+ enabled_precisions = {torch.bfloat16}
+ use_fp32_acc = False
+ else:
+ enabled_precisions = {torch.float32}
+ model = llama_model.model.layers[0].self_attn.to(DTYPE)
+
+ # Inputs
+ ISL = 2048
+ NUM_TOKENS = 128
+ OSL = ISL + NUM_TOKENS
+ hidden_states = torch.randn((1, ISL, 2048), dtype=DTYPE).cuda()
+ position_embeddings = (
+ torch.randn((1, ISL, 64), dtype=DTYPE).cuda(),
+ torch.randn((1, ISL, 64), dtype=DTYPE).cuda(),
+ )
+ key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE)
+ value_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE)
+ start_idx = 0
+ end_idx = ISL
+ is_causal = True
+
+ pyt_output = model(hidden_states, position_embeddings, None)
+ seq_len = torch.export.Dim("seq_len", min=2, max=2176)
+ dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None)
+ ep = torch.export.export(
+ model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes
+ )
+ import static_cache_v2
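+    # The import above registers the insert_static_cache_v2 lowering pass (declared with
+    # @_aten_lowering_pass in static_cache_v2.py). During compilation it rewrites the graph
+    # so that the key/value caches and the start_idx/end_idx indices become additional
+    # inputs and the updated caches become additional outputs, which is why the compiled
+    # model below is called with these extra arguments.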
+
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+ trt_model = torch_tensorrt.dynamo.compile(
+ ep,
+ inputs=[
+ hidden_states,
+ position_embeddings,
+ None,
+ key_cache,
+ value_cache,
+ start_idx,
+ end_idx,
+ is_causal,
+ ],
+ enabled_precisions=enabled_precisions,
+ disable_tf32=True,
+ debug=args.debug,
+ # offload_module_to_cpu=True,
+ use_fp32_acc=use_fp32_acc,
+ use_explicit_typing=use_explicit_typing,
+ use_python_runtime=True,
+ )
+
+ # Test Prefill
+ trt_output, _, key_cache, value_cache = trt_model(
+ hidden_states,
+ position_embeddings,
+ None,
+ key_cache,
+ value_cache,
+ start_idx,
+ end_idx,
+ is_causal,
+ )
+ print_diff(pyt_output[0], trt_output[0], "pyt_output[0] vs trt_output[0] [Prefill]")
+
+ # Test Generate
+ for start_idx in range(2048, 2176):
+ end_idx = start_idx + 1
+ hidden_states_curr = torch.randn((1, 1, 2048), dtype=DTYPE).cuda()
+ position_embeddings_curr = (
+ torch.randn((1, 1, 64), dtype=DTYPE).cuda(),
+ torch.randn((1, 1, 64), dtype=DTYPE).cuda(),
+ )
+ # Concatenate the current hidden_states with the previous ones
+ hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1)
+ position_embeddings_full = (
+ torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1),
+ torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1),
+ )
+
+ is_causal = False
+ out_no_cache, _ = model(hidden_states_full, position_embeddings_full, None)
+ out_trt, _, key_cache, value_cache = trt_model(
+ hidden_states_curr,
+ position_embeddings_curr,
+ None,
+ key_cache,
+ value_cache,
+ start_idx,
+ end_idx,
+ is_causal,
+ )
+ out_pyt = out_no_cache[:, -1:, :]
+ print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}")
+
+ hidden_states = hidden_states_full
+ position_embeddings = position_embeddings_full
+
+
+def test_llama_decoder(args):
+
+ class LlamaDecoderLayerBlock(nn.Module):
+ def __init__(self, model):
+ super().__init__()
+ self.config = LLAMA_CONFIG
+ self.decoder = LlamaDecoderLayer(config=self.config, layer_idx=0)
+ self.model = model
+
+ def forward(self, hidden_states, position_embeddings):
+ return self.model(hidden_states, position_embeddings=position_embeddings)
+
+ DTYPE = torch.float32
+ if args.precision == "FP16":
+ DTYPE = torch.float16
+ elif args.precision == "BF16":
+ DTYPE = torch.bfloat16
+
+ # Set precision specific flags
+ use_fp32_acc = False
+ use_explicit_typing = False
+ if args.precision == "FP16":
+ enabled_precisions = {torch.float32}
+ use_fp32_acc = True
+ use_explicit_typing = True
+ elif args.precision == "BF16":
+ enabled_precisions = {torch.bfloat16}
+ use_fp32_acc = False
+ else:
+ enabled_precisions = {torch.float32}
+
+ model = LlamaDecoderLayerBlock(llama_model.model.layers[0].to(DTYPE))
+ # llama3
+ hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda()
+ position_embeddings = (
+ torch.randn((1, 6, 64), dtype=DTYPE).cuda(),
+ torch.randn((1, 6, 64), dtype=DTYPE).cuda(),
+ )
+
+ pyt_output = model(hidden_states, position_embeddings)
+ seq_len = torch.export.Dim("seq_len", min=2, max=2176)
+ dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}))
+ ep = torch.export.export(
+ model, (hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes
+ )
+
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+ trt_model = torch_tensorrt.dynamo.compile(
+ ep,
+ inputs=[hidden_states, position_embeddings],
+ enabled_precisions=enabled_precisions,
+ debug=args.debug,
+ use_fp32_acc=use_fp32_acc,
+ use_explicit_typing=use_explicit_typing,
+ )
+ trt_output = trt_model(hidden_states, position_embeddings)
+
+ print(
+ f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}"
+ )
+ assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL)
+
+
+def test_llama_decoder_with_static_cache(args):
+
+ class LlamaDecoderLayerBlock(nn.Module):
+ def __init__(self, model):
+ super().__init__()
+ self.config = LLAMA_CONFIG
+ self.decoder = LlamaDecoderLayer(config=self.config, layer_idx=0)
+ self.model = model
+
+ def forward(self, hidden_states, position_embeddings):
+ return self.model(hidden_states, position_embeddings=position_embeddings)
+
+ DTYPE = torch.float32
+ if args.precision == "FP16":
+ DTYPE = torch.float16
+ elif args.precision == "BF16":
+ DTYPE = torch.bfloat16
+
+ # Set precision specific flags
+ use_fp32_acc = False
+ use_explicit_typing = False
+ if args.precision == "FP16":
+ enabled_precisions = {torch.float32}
+ use_fp32_acc = True
+ use_explicit_typing = True
+ elif args.precision == "BF16":
+ enabled_precisions = {torch.bfloat16}
+ use_fp32_acc = False
+ else:
+ enabled_precisions = {torch.float32}
+ model = LlamaDecoderLayerBlock(llama_model.model.layers[0].to(DTYPE))
+
+ # Inputs
+ ISL = 2048
+ NUM_TOKENS = 128
+ OSL = ISL + NUM_TOKENS
+ hidden_states = torch.randn((1, ISL, 2048), dtype=DTYPE).cuda()
+ position_embeddings = (
+ torch.randn((1, ISL, 64), dtype=DTYPE).cuda(),
+ torch.randn((1, ISL, 64), dtype=DTYPE).cuda(),
+ )
+ key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE)
+ value_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE)
+ start_idx = 0
+ end_idx = ISL
+ is_causal = True
+
+ pyt_output = model(hidden_states, position_embeddings)
+ seq_len = torch.export.Dim("seq_len", min=2, max=2176)
+ dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}))
+ ep = torch.export.export(
+ model, args=(hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes
+ )
+ import static_cache_v2
+
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+ trt_model = torch_tensorrt.dynamo.compile(
+ ep,
+ arg_inputs=[
+ hidden_states,
+ position_embeddings,
+ key_cache,
+ value_cache,
+ start_idx,
+ end_idx,
+ is_causal,
+ ],
+ enabled_precisions=enabled_precisions,
+ disable_tf32=True,
+ debug=args.debug,
+ # offload_module_to_cpu=True,
+ use_fp32_acc=use_fp32_acc,
+ use_explicit_typing=use_explicit_typing,
+ use_python_runtime=True,
+ )
+
+ # Test Prefill
+ trt_output, key_cache, value_cache = trt_model(
+ hidden_states,
+ position_embeddings,
+ key_cache,
+ value_cache,
+ start_idx,
+ end_idx,
+ is_causal,
+ )
+ print_diff(pyt_output[0], trt_output, "pyt_output vs trt_output [Prefill]")
+
+ # Test Generate
+ for start_idx in range(2048, 2176):
+ end_idx = start_idx + 1
+ hidden_states_curr = torch.randn((1, 1, 2048), dtype=DTYPE).cuda()
+ position_embeddings_curr = (
+ torch.randn((1, 1, 64), dtype=DTYPE).cuda(),
+ torch.randn((1, 1, 64), dtype=DTYPE).cuda(),
+ )
+ # Concatenate the current hidden_states with the previous ones
+ hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1)
+ position_embeddings_full = (
+ torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1),
+ torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1),
+ )
+
+ is_causal = False
+ out_no_cache = model(hidden_states_full, position_embeddings_full)
+
+ out_trt, key_cache, value_cache = trt_model(
+ hidden_states_curr,
+ position_embeddings_curr,
+ key_cache,
+ value_cache,
+ start_idx,
+ end_idx,
+ is_causal,
+ )
+ out_pyt = out_no_cache[0][:, -1:, :]
+ print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}")
+ hidden_states = hidden_states_full
+ position_embeddings = position_embeddings_full
+
+
+def test_llama_model(args):
+
+ DTYPE = torch.float32
+ if args.precision == "FP16":
+ DTYPE = torch.float16
+ elif args.precision == "BF16":
+ DTYPE = torch.bfloat16
+
+ # Set precision specific flags
+ use_fp32_acc = False
+ use_explicit_typing = False
+ if args.precision == "FP16":
+ enabled_precisions = {torch.float32}
+ use_fp32_acc = True
+ use_explicit_typing = True
+ elif args.precision == "BF16":
+ enabled_precisions = {torch.bfloat16}
+ use_fp32_acc = False
+ else:
+ enabled_precisions = {torch.float32}
+
+ model = llama_model.model.to(DTYPE)
+
+ # Inputs
+ ISL = 2048
+ NUM_TOKENS = 128
+ OSL = ISL + NUM_TOKENS
+ input_ids = torch.randint(1, 20, (1, ISL), dtype=torch.int64).cuda()
+ position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).cuda()
+
+ pyt_output = model(input_ids, position_ids)
+ seq_len = torch.export.Dim("seq_len", min=2, max=2176)
+ dynamic_shapes = ({1: seq_len}, {1: seq_len})
+ kwarg_inputs = {"position_ids": position_ids}
+ from torch.export._trace import _export
+
+ ep = _export(
+ model,
+ args=(input_ids,),
+ kwargs=kwarg_inputs,
+ dynamic_shapes=dynamic_shapes,
+ strict=False,
+ allow_complex_guards_as_runtime_asserts=True,
+ )
+
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+ trt_model = torch_tensorrt.dynamo.compile(
+ ep,
+ arg_inputs=[],
+ kwarg_inputs=kwarg_inputs,
+ enabled_precisions=enabled_precisions,
+ disable_tf32=True,
+ debug=args.debug,
+ offload_module_to_cpu=True,
+ use_fp32_acc=use_fp32_acc,
+ use_explicit_typing=use_explicit_typing,
+ use_python_runtime=True,
+ )
+
+ trt_output = trt_model(input_ids, position_ids)
+
+ print(
+ f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}"
+ )
+ # print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[1] - trt_output[1]))}")
+    assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL)
+
+
+def test_llama_model_with_static_cache(args):
+
+ DTYPE = torch.float32
+ if args.precision == "FP16":
+ DTYPE = torch.float16
+ elif args.precision == "BF16":
+ DTYPE = torch.bfloat16
+
+ # Set precision specific flags
+ use_fp32_acc = False
+ use_explicit_typing = False
+ if args.precision == "FP16":
+ enabled_precisions = {torch.float32}
+ use_fp32_acc = True
+ use_explicit_typing = True
+ elif args.precision == "BF16":
+ enabled_precisions = {torch.bfloat16}
+ use_fp32_acc = False
+ else:
+ enabled_precisions = {torch.float32}
+ model = llama_model.model.to(DTYPE)
+
+ # Inputs
+ ISL = 2048
+ NUM_TOKENS = 128
+ OSL = ISL + NUM_TOKENS
+ input_ids = torch.randint(1, 20, (1, ISL), dtype=torch.int64).cuda()
+ position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).cuda()
+ key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE)
+ value_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE)
+ start_idx = 0
+ end_idx = ISL
+ is_causal = True
+
+ pyt_output = model(input_ids)
+ seq_len = torch.export.Dim("seq_len", min=2, max=2176)
+ dynamic_shapes = ({1: seq_len}, {1: seq_len})
+ kwarg_inputs = {"input_ids": input_ids, "position_ids": position_ids}
+ ep = torch.export.export(
+ model, args=(), kwargs=kwarg_inputs, dynamic_shapes=dynamic_shapes
+ )
+
+ import static_cache_v2
+
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+ trt_model = torch_tensorrt.dynamo.compile(
+ ep,
+ arg_inputs=[],
+ kwarg_inputs=kwarg_inputs,
+ enabled_precisions=enabled_precisions,
+ disable_tf32=True,
+ debug=args.debug,
+ # offload_module_to_cpu=True,
+ use_fp32_acc=use_fp32_acc,
+ use_explicit_typing=use_explicit_typing,
+ use_python_runtime=True,
+ )
+
+ # Test Prefill
+ trt_output, key_cache, value_cache = trt_model(
+ input_ids, position_ids, key_cache, value_cache, start_idx, end_idx, is_causal
+ )
+ pyt_output = pyt_output.last_hidden_state
+ print_diff(pyt_output, trt_output, "pyt_output vs trt_output [Prefill]")
+
+ # Test Generate
+ for start_idx in range(2048, 2176):
+ end_idx = start_idx + 1
+ input_ids_curr = torch.randint(1, 20, (1, 1), dtype=torch.int64).cuda()
+ position_ids_curr = torch.tensor([[start_idx]], dtype=torch.int64).cuda()
+
+ # Concatenate the current hidden_states with the previous ones
+ input_ids_full = torch.cat((input_ids, input_ids_curr), dim=1)
+ position_ids_full = torch.cat((position_ids, position_ids_curr), dim=1)
+ is_causal = False
+ kwarg_inputs = {"input_ids": input_ids_full, "position_ids": position_ids_full}
+ out_no_cache = model(**kwarg_inputs)
+
+ out_trt, key_cache, value_cache = trt_model(
+ input_ids_curr,
+ position_ids_curr,
+ key_cache,
+ value_cache,
+ start_idx,
+ end_idx,
+ is_causal,
+ )
+ out_pyt = out_no_cache.last_hidden_state[:, -1:, :]
+ print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}")
+ input_ids = input_ids_full
+ position_ids = position_ids_full
+
+
+if __name__ == "__main__":
+ arg_parser = argparse.ArgumentParser(
+ description="Run test cases for llama attention and decoder"
+ )
+ arg_parser.add_argument(
+ "--debug", action="store_true", help="Enable debug (default: False)"
+ )
+ arg_parser.add_argument(
+ "--precision", type=str, default="FP16", help="Precision (default: FP16)"
+ )
+ args = arg_parser.parse_args()
+ with torch.inference_mode():
+ # test_llama_attention(args)
+ # test_llama_decoder(args)
+ test_llama_model(args)
+ # test_llama_attention_with_static_cache(args)
+ # test_llama_decoder_with_static_cache(args)
+ # test_llama_model_with_static_cache(args)
diff --git a/tools/llm/test_qwen2.5_components.py b/tools/llm/test_qwen2.5_components.py
new file mode 100644
index 0000000000..60482bf22d
--- /dev/null
+++ b/tools/llm/test_qwen2.5_components.py
@@ -0,0 +1,193 @@
+import torch
+
+torch.backends.cuda.matmul.allow_tf32 = False
+torch.backends.cudnn.allow_tf32 = False
+
+import argparse
+import os
+import sys
+from contextlib import nullcontext
+
+import torch.nn as nn
+import torch_tensorrt
+from torch.testing._internal.common_utils import TestCase, run_tests
+from transformers import AutoModelForCausalLM
+from transformers.models.llama.configuration_llama import LlamaConfig
+
+# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+from register_sdpa import *
+
+ATOL = 1e-5
+RTOL = 1e-5
+
+
+qwen2_5_model_name = "Qwen/Qwen2.5-1.5B-Instruct"
+qwen2_5_model = (
+ AutoModelForCausalLM.from_pretrained(
+ qwen2_5_model_name,
+ use_cache=False,
+ attn_implementation="sdpa",
+ num_hidden_layers=1,
+ )
+ .eval()
+ .cuda()
+)
+QWEN_CONFIG = qwen2_5_model.config
+
+
+def print_diff(tensor1, tensor2, prefix=""):
+ """
+ Print the diff between two tensors
+ """
+ print(
+ f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}"
+ )
+
+
+def test_qwen_apply_rotary_pos_emb(args):
+ class QwenApplyRotaryPosEmb(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def rotate_half(self, x):
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+ def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=1):
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (self.rotate_half(q) * sin)
+ k_embed = (k * cos) + (self.rotate_half(k) * sin)
+ return q_embed, k_embed
+
+ def forward(self, q, k, cos, sin, unsqueeze_dim=1):
+ return self.apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim)
+
+ DTYPE = torch.float32
+ if args.precision == "FP16":
+ DTYPE = torch.float16
+ elif args.precision == "BF16":
+ DTYPE = torch.bfloat16
+
+ # Set precision specific flags
+ use_fp32_acc = False
+ use_explicit_typing = False
+ if args.precision == "FP16":
+ enabled_precisions = {torch.float32}
+ use_fp32_acc = True
+ use_explicit_typing = True
+ elif args.precision == "BF16":
+ enabled_precisions = {torch.bfloat16}
+ use_fp32_acc = False
+ else:
+ enabled_precisions = {torch.float32}
+
+ model = QwenApplyRotaryPosEmb().eval().cuda().to(DTYPE)
+ # Shapes for Qwen 2.5
+ q = torch.randn((1, 12, 5, 128), dtype=DTYPE).cuda()
+ k = torch.randn((1, 12, 5, 128), dtype=DTYPE).cuda()
+ cos = torch.randn((1, 5, 128), dtype=DTYPE).cuda()
+ sin = torch.randn((1, 5, 128), dtype=DTYPE).cuda()
+
+ pyt_output = model(q, k, cos, sin)
+
+ seq_len = torch.export.Dim("seq_len", min=2, max=2176)
+ dynamic_shapes = ({2: seq_len}, {2: seq_len}, {1: seq_len}, {1: seq_len})
+ ep = torch.export.export(model, (q, k, cos, sin), dynamic_shapes=dynamic_shapes)
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+ trt_model = torch_tensorrt.dynamo.compile(
+ ep,
+ inputs=[q, k, cos, sin],
+ enabled_precisions=enabled_precisions,
+ disable_tf32=True,
+ use_fp32_acc=use_fp32_acc,
+ use_explicit_typing=use_explicit_typing,
+ debug=args.debug,
+ )
+ trt_output = trt_model(q, k, cos, sin)
+
+ if isinstance(pyt_output, tuple):
+ print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt")
+ # print_diff(pyt_output[1], trt_output[1], "Diff b/w pyt and trt")
+ assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL)
+ else:
+ print_diff(pyt_output, trt_output, "Diff b/w pyt and trt")
+ assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL)
+
+
+def test_qwen_attention(args):
+
+ DTYPE = torch.float32
+ if args.precision == "FP16":
+ DTYPE = torch.float16
+ elif args.precision == "BF16":
+ DTYPE = torch.bfloat16
+
+ # Set precision specific flags
+ use_fp32_acc = False
+ use_explicit_typing = False
+ if args.precision == "FP16":
+ enabled_precisions = {torch.float32}
+ use_fp32_acc = True
+ use_explicit_typing = True
+ elif args.precision == "BF16":
+ enabled_precisions = {torch.bfloat16}
+ use_fp32_acc = False
+ else:
+ enabled_precisions = {torch.float32}
+
+ model = qwen2_5_model.model.layers[0].self_attn.to(DTYPE)
+ # qwen2.5
+ hidden_states = torch.randn((1, 5, 1536), dtype=DTYPE).cuda()
+ position_embeddings = (
+ torch.randn((1, 5, 128), dtype=DTYPE).cuda(),
+ torch.randn((1, 5, 128), dtype=DTYPE).cuda(),
+ )
+
+ pyt_output = model(hidden_states, position_embeddings, None)
+
+ seq_len = torch.export.Dim("seq_len", min=2, max=2176)
+ dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None)
+ ep = torch.export.export(
+ model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes
+ )
+
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+ trt_model = torch_tensorrt.dynamo.compile(
+ ep,
+ inputs=[hidden_states, position_embeddings, None],
+ enabled_precisions=enabled_precisions,
+ disable_tf32=True,
+ use_fp32_acc=use_fp32_acc,
+ use_explicit_typing=use_explicit_typing,
+ debug=args.debug,
+ )
+ trt_output = trt_model(hidden_states, position_embeddings, None)
+
+ if isinstance(pyt_output, tuple):
+ print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt")
+ assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL)
+ else:
+ print_diff(pyt_output, trt_output, "Diff b/w pyt and trt")
+ assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL)
+
+
+if __name__ == "__main__":
+ arg_parser = argparse.ArgumentParser(
+ description="Run test cases for llama attention and decoder"
+ )
+ arg_parser.add_argument(
+ "--debug", action="store_true", help="Enable debug (default: False)"
+ )
+ arg_parser.add_argument(
+ "--precision",
+ type=str,
+ default="FP16",
+ help="Precision to use in the model. Options: FP16, BF16, FP32",
+ )
+ args = arg_parser.parse_args()
+ with torch.inference_mode():
+ # test_qwen_apply_rotary_pos_emb(args)
+ test_qwen_attention(args)
diff --git a/tools/llm/test_static_cache.py b/tools/llm/test_static_cache.py
new file mode 100644
index 0000000000..603f84d3a6
--- /dev/null
+++ b/tools/llm/test_static_cache.py
@@ -0,0 +1,478 @@
+import argparse
+import os
+import sys
+from contextlib import nullcontext
+
+import torch
+import torch.nn as nn
+import torch_tensorrt
+from torch.export import export
+from torch_tensorrt.dynamo.lowering import (
+ get_decompositions,
+ post_lowering,
+ pre_export_lowering,
+)
+from transformers import AutoModelForCausalLM
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer
+
+# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+import register_sdpa
+
+ATOL = 1e-5
+RTOL = 1e-5
+torch.backends.cuda.matmul.allow_tf32 = False
+torch.backends.cudnn.allow_tf32 = False
+
+
+class DynamicCacheModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, q, k, v, k1, v1, flag):
+ def true_fn(q, k, v, k1, v1):
+ k_new = torch.cat((k, k1), dim=2)
+ v_new = torch.cat((v, v1), dim=2)
+ return torch._C._nn.scaled_dot_product_attention(q, k_new, v_new)
+
+ def false_fn(q, k, v, k1, v1):
+ return torch._C._nn.scaled_dot_product_attention(q, k, v)
+
+ out = torch.cond(flag, true_fn, false_fn, (q, k, v, k1, v1))
+
+ return 2 * out
+
+
+class ModelNoCache(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, q, k, v):
+ return torch._C._nn.scaled_dot_product_attention(
+ q, k, v, dropout_p=0.0, is_causal=True
+ )
+
+
+class StaticCacheModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+    def forward(
+ self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True
+ ):
+ concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)
+ concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)
+ new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
+ new_value_cache = torch.cat(
+ (concat_values, value_cache[:, :, end_idx:, :]), dim=2
+ )
+ attn_output = torch._C._nn.scaled_dot_product_attention(
+ q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal
+ )
+
+ return attn_output, new_key_cache, new_value_cache
+
+
+def eager_sdpa(
+ query,
+ key,
+ value,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=False,
+ scale=None,
+ enable_gqa=False,
+) -> torch.Tensor:
+ """
+ Eager implementation of SDPA
+ """
+ import math
+
+ L, S = query.size(-2), key.size(-2)
+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+ attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+
+ if is_causal:
+ assert attn_mask is None
+ temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).cuda()
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+ attn_bias.to(query.dtype)
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+ else:
+ attn_bias = attn_mask + attn_bias
+
+ if enable_gqa:
+ key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+ value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
+
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
+ attn_weight += attn_bias
+ attn_weight = torch.softmax(attn_weight, dim=-1)
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+ return attn_weight @ value
+
+
+def print_diff(tensor1, tensor2, prefix=""):
+ """
+ Print the diff between two tensors
+ """
+ print(
+ f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}"
+ )
+
+
+def test_no_cache_model_with_torch_tensorrt(args):
+ """
+ Test the no cache model
+ """
+ with torch.inference_mode():
+ model_no_cache = ModelNoCache().eval().cuda()
+        q = torch.randn(1, 32, 6, 64).cuda()
+        k = torch.randn(1, 32, 6, 64).cuda()
+        v = torch.randn(1, 32, 6, 64).cuda()
+        # Alternatively, load previously saved tensors if they are available on disk:
+        # q, k, v = torch.load("query.pt"), torch.load("key.pt"), torch.load("value.pt")
+ out_no_cache = model_no_cache(q, k, v)
+ out_eager = eager_sdpa(q, k, v, is_causal=True)
+ q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176)
+ # Export the model
+ exported_program = torch.export.export(
+ model_no_cache,
+ args=(q, k, v),
+ dynamic_shapes=({2: q_seq_len}, {2: q_seq_len}, {2: q_seq_len}),
+ strict=False,
+ )
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+ trt_model = torch_tensorrt.dynamo.compile(
+ exported_program,
+ inputs=[q, k, v],
+ enabled_precisions={torch.float32},
+ disable_tf32=True,
+ debug=args.debug,
+ min_block_size=1,
+ )
+ out_trt = trt_model(q, k, v)
+
+ print_diff(out_no_cache, out_eager, "out_no_cache vs out_eager")
+ print_diff(out_no_cache, out_trt, "out_no_cache vs out_trt")
+ print_diff(out_eager, out_trt, "out_eager vs out_trt")
+
+
+def test_static_cache_model(args):
+ """
+ Test the static cache model
+ """
+ with torch.inference_mode():
+ model_no_cache = ModelNoCache().eval().cuda()
+ model_static_cache = StaticCacheModel().eval().cuda()
+ q = torch.randn(1, 32, 2048, 64).cuda()
+ k = torch.randn(1, 32, 2048, 64).cuda()
+ v = torch.randn(1, 32, 2048, 64).cuda()
+ key_cache = torch.zeros(1, 32, 2176, 64).cuda()
+ value_cache = torch.zeros(1, 32, 2176, 64).cuda()
+
+ # Test Prefill
+ start_idx = 0
+ end_idx = 2048
+ out_no_cache = model_no_cache(q, k, v)
+ out_static_cache, new_key_cache, new_value_cache = model_static_cache(
+ q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True
+ )
+ assert torch.allclose(out_no_cache, out_static_cache, atol=ATOL, rtol=RTOL)
+
+ # Test Generate
+ for start_idx in range(2048, 2176):
+ end_idx = start_idx + 1
+ q_curr = torch.randn(1, 32, 1, 64).cuda()
+ k_curr = torch.randn(1, 32, 1, 64).cuda()
+ v_curr = torch.randn(1, 32, 1, 64).cuda()
+
+ # Concatenate the current query, key, and value with the previous ones
+ q_full = torch.cat((q, q_curr), dim=2)
+ k_full = torch.cat((k, k_curr), dim=2)
+ v_full = torch.cat((v, v_curr), dim=2)
+
+ out_no_cache = model_no_cache(q_full, k_full, v_full)
+ out_static_cache, new_key_cache, new_value_cache = model_static_cache(
+ q_curr,
+ k_curr,
+ v_curr,
+ new_key_cache,
+ new_value_cache,
+ start_idx,
+ end_idx,
+ is_causal=False,
+ )
+
+ assert torch.allclose(
+ out_no_cache[:, :, -1:, :], out_static_cache, atol=ATOL, rtol=RTOL
+ )
+ q = q_full
+ k = k_full
+ v = v_full
+ print("============== test_static_cache passed ==============")
+
+
+def transform_gm_with_kv_cache(exported_program: torch.export.ExportedProgram, args):
+ """
+ Transform the graph module by adding key and value cache to the graph
+ """
+ gm = exported_program.module()
+ # Post lower the model
+ settings = torch_tensorrt.dynamo.conversion.CompilationSettings(
+ enabled_precisions={torch.float32},
+ disable_tf32=True,
+ use_python_runtime=True,
+ debug=args.debug,
+ min_block_size=1,
+ )
+ exported_program = pre_export_lowering(exported_program, settings)
+ exported_program = exported_program.run_decompositions(get_decompositions(False))
+
+ gm = exported_program.module()
+ gm = post_lowering(gm, settings)
+
+ return gm
+
+
+def test_static_cache_lowering(args):
+ """
+ Test static cache lowering pass applied to the model with no cache and run the graph module
+ and compare the output with the model with no cache
+ """
+    import static_cache_v2
+
+ model_no_cache = ModelNoCache().eval().cuda()
+ q = torch.randn(1, 32, 2, 64).cuda()
+ k = torch.randn(1, 32, 2048, 64).cuda()
+ v = torch.randn(1, 32, 2048, 64).cuda()
+ key_cache = torch.zeros(1, 32, 2176, 64).cuda()
+ value_cache = torch.zeros(1, 32, 2176, 64).cuda()
+
+ # Export the model
+ q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176)
+ kv_seq_len = torch.export.Dim("kv_seq_len", min=2, max=2176)
+ exported_program = export(
+ model_no_cache,
+ args=(q, k, v),
+ dynamic_shapes=({2: q_seq_len}, {2: kv_seq_len}, {2: kv_seq_len}),
+ strict=False,
+ )
+
+ gm = transform_gm_with_kv_cache(exported_program, args)
+
+ # Test Prefill
+ start_idx = 0
+ end_idx = 2048
+ is_causal = True
+ q = torch.randn(1, 32, 2048, 64).cuda()
+ out_no_cache = model_no_cache(q, k, v)
+ out_pyt_cache, key_cache, value_cache = gm(
+ q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal
+ )
+ assert torch.allclose(out_no_cache, out_pyt_cache, atol=ATOL, rtol=RTOL)
+
+ # Test Generate
+ for start_idx in range(2048, 2176):
+ end_idx = start_idx + 1
+ is_causal = False
+ q_curr = torch.randn(1, 32, 1, 64).cuda()
+ k_curr = torch.randn(1, 32, 1, 64).cuda()
+ v_curr = torch.randn(1, 32, 1, 64).cuda()
+ # Concatenate the current query, key, and value with the previous ones
+ q_full = torch.cat((q, q_curr), dim=2)
+ k_full = torch.cat((k, k_curr), dim=2)
+ v_full = torch.cat((v, v_curr), dim=2)
+
+ out_no_cache = model_no_cache(q_full, k_full, v_full)
+ out_pyt_static_cache, key_cache, value_cache = gm(
+ q_curr,
+ k_curr,
+ v_curr,
+ key_cache,
+ value_cache,
+ start_idx,
+ end_idx,
+ is_causal,
+ )
+ assert torch.allclose(
+ out_no_cache[:, :, -1:, :], out_pyt_static_cache, atol=ATOL, rtol=RTOL
+ )
+ q = q_full
+ k = k_full
+ v = v_full
+
+ print("============== test_static_cache_lowering passed ==============")
+
+
+def test_static_cache_export(args):
+ """
+ Test the static cache model export
+ """
+ model_static_cache = StaticCacheModel().eval().cuda()
+ q = torch.randn(1, 32, 2048, 64).cuda()
+ k = torch.randn(1, 32, 2048, 64).cuda()
+ v = torch.randn(1, 32, 2048, 64).cuda()
+ key_cache = torch.zeros(1, 32, 2176, 64).cuda()
+ value_cache = torch.zeros(1, 32, 2176, 64).cuda()
+ # Test Prefill
+ start_idx = 0
+ end_idx = 2048
+ is_causal = True
+ # Export the model
+ seq_len = torch.export.Dim("seq_len", min=2, max=2048)
+ seq_len_dyn_dim = {2: seq_len}
+ exported_program = export(
+ model_static_cache,
+ args=(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal),
+ dynamic_shapes=(
+ seq_len_dyn_dim,
+ seq_len_dyn_dim,
+ seq_len_dyn_dim,
+ None,
+ None,
+ torch.export.Dim.DYNAMIC,
+ torch.export.Dim.DYNAMIC,
+ None,
+ ),
+ strict=False,
+ )
+
+
+def test_static_cache_with_torch_tensorrt(args):
+ """
+ Test the static cache model with torch_tensorrt
+ """
+ import static_cache_v2
+
+ model_no_cache = ModelNoCache().eval().cuda()
+ q = torch.randn(1, 32, 2, 64).cuda()
+ k = torch.randn(1, 32, 2048, 64).cuda()
+ v = torch.randn(1, 32, 2048, 64).cuda()
+ key_cache = torch.zeros(1, 32, 2176, 64).cuda()
+ value_cache = torch.zeros(1, 32, 2176, 64).cuda()
+
+ # Export the model
+ q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176)
+ kv_seq_len = torch.export.Dim("kv_seq_len", min=2, max=2176)
+ exported_program = export(
+ model_no_cache,
+ args=(q, k, v),
+ dynamic_shapes=({2: q_seq_len}, {2: kv_seq_len}, {2: kv_seq_len}),
+ strict=False,
+ )
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+ trt_model = torch_tensorrt.dynamo.compile(
+ exported_program,
+ inputs=[q, k, v],
+ enabled_precisions={torch.float32},
+ disable_tf32=True,
+ use_python_runtime=True,
+ debug=args.debug,
+ min_block_size=1,
+ )
+
+ start_idx = 0
+ end_idx = 2048
+ is_causal = True
+ q = torch.randn(1, 32, 2048, 64).cuda()
+ # out_eager = eager_sdpa(q, k, v, is_causal=is_causal)
+ out_no_cache = model_no_cache(q, k, v)
+ out_trt, trt_key_cache, trt_value_cache = trt_model(
+ q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal
+ )
+
+ assert torch.allclose(
+ out_no_cache, out_trt, atol=ATOL, rtol=RTOL
+ ), "Prefill TRT logits don't match"
+ assert torch.allclose(
+ trt_key_cache[:, :, :end_idx, :], k, atol=ATOL, rtol=RTOL
+ ), "Prefill TRT key cache don't match"
+ assert torch.allclose(
+ trt_value_cache[:, :, :end_idx, :], v, atol=ATOL, rtol=RTOL
+ ), "Prefill TRT value cache don't match"
+
+ # Test Generate
+ for start_idx in range(2048, 2176):
+ end_idx = start_idx + 1
+ q_curr = torch.randn(1, 32, 1, 64).cuda()
+ k_curr = torch.randn(1, 32, 1, 64).cuda()
+ v_curr = torch.randn(1, 32, 1, 64).cuda()
+ # Concatenate the current query, key, and value with the previous ones
+ q_full = torch.cat((q, q_curr), dim=2)
+ k_full = torch.cat((k, k_curr), dim=2)
+ v_full = torch.cat((v, v_curr), dim=2)
+ is_causal = True
+ out_no_cache = model_no_cache(q_full, k_full, v_full)
+ out_trt, trt_key_cache, trt_value_cache = trt_model(
+ q_curr,
+ k_curr,
+ v_curr,
+ trt_key_cache,
+ trt_value_cache,
+ start_idx,
+ end_idx,
+ is_causal,
+ )
+ # breakpoint()
+ # print_diff(out_no_cache[:, :, -1:, :], out_trt, f"out_no_cache[:, :, -1:, :] vs out_trt for idx {start_idx}")
+ # print_diff(trt_key_cache[:, :, :end_idx, :], k_full, f"trt_key_cache[:, :, :end_idx, :] vs k_full for idx {start_idx}")
+ # print_diff(trt_value_cache[:, :, :end_idx, :], v_full, f"trt_value_cache[:, :, :end_idx, :] vs v_full for idx {start_idx}")
+ assert torch.allclose(
+ out_no_cache[:, :, -1:, :], out_trt, atol=ATOL, rtol=RTOL
+ ), f"Generate TRT logits don't match for idx {start_idx}"
+ assert torch.allclose(
+ trt_key_cache[:, :, :end_idx, :], k_full, atol=ATOL, rtol=RTOL
+ ), f"Generate TRT key cache don't match for idx {start_idx}"
+ assert torch.allclose(
+ trt_value_cache[:, :, :end_idx, :], v_full, atol=ATOL, rtol=RTOL
+ ), f"Generate TRT value cache don't match for idx {start_idx}"
+ q = q_full
+ k = k_full
+ v = v_full
+
+ print("============== test_static_cache_with_torch_tensorrt passed ==============")
+
+
+def main():
+ arg_parser = argparse.ArgumentParser(
+ description="Run test cases for llama attention and decoder"
+ )
+ arg_parser.add_argument(
+ "--debug", action="store_true", help="Enable debug (default: False)"
+ )
+ args = arg_parser.parse_args()
+ with torch.inference_mode():
+ # test_no_cache_model_with_torch_tensorrt(args)
+ # test_static_cache_model(args)
+ # test_static_cache_lowering(args)
+ test_static_cache_with_torch_tensorrt(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/dynamo/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py
similarity index 86%
rename from examples/dynamo/register_sdpa.py
rename to tools/llm/torchtrt_ext/register_sdpa.py
index 7436f31939..90a00a5798 100644
--- a/examples/dynamo/register_sdpa.py
+++ b/tools/llm/torchtrt_ext/register_sdpa.py
@@ -4,7 +4,6 @@
from typing import Callable, Sequence, Tuple
import torch
-from sdpa_converter import *
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check
from torch_tensorrt.dynamo.lowering import TORCH_TRT_DECOMPOSITIONS
@@ -15,15 +14,19 @@
clean_up_graph_after_modifications,
)
+from .sdpa_converter import *
+
logger = logging.getLogger(__name__)
# Remove decompositions for aten.scaled_dot_product_attention, aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention
# This is because we want to have SDPA as a standalone operator in the graph and invoke the custom converter for it.
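+# A default of None is passed to pop() so that nothing fails if a decomposition is already
+# absent (for example, if it was removed earlier or is not registered in the installed torch version).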
-TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default)
+TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default, None)
+TORCH_TRT_DECOMPOSITIONS.pop(
+ torch.ops.aten._scaled_dot_product_efficient_attention.default, None
+)
TORCH_TRT_DECOMPOSITIONS.pop(
- torch.ops.aten._scaled_dot_product_efficient_attention.default
+ torch.ops.aten._scaled_dot_product_flash_attention.default, None
)
-TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten._scaled_dot_product_flash_attention.default)
REPLACEABLE_ATEN_OPS = {
torch.ops.aten._scaled_dot_product_efficient_attention.default,
@@ -59,6 +62,7 @@ def replace_variants_of_sdpa(
elif len(node.args) == 5:
query, key, value, attn_mask, is_causal = node.args
dropout_p = 0.0
+
else:
raise ValueError(
f"Unexpected number of arguments for {node.target} in the graph"
@@ -71,6 +75,8 @@ def replace_variants_of_sdpa(
query, key, value, dropout_p, is_causal, return_debug_mask = (
node.args
)
+            elif len(node.args) == 5:
+ query, key, value, dropout_p, is_causal = node.args
elif len(node.args) == 3:
query, key, value = node.args
dropout_p = 0.0
@@ -79,20 +85,21 @@ def replace_variants_of_sdpa(
raise ValueError(
f"Unexpected number of arguments for {node.target} in the graph"
)
- if attn_mask is not None:
- logger.warning(
- f"This current version of SDPA converter does not support attn_mask for {node.target} in the graph. Ignoring it and using is_causal=True configuration."
- )
-
- modified_input_args = (query, key, value, None, dropout_p, is_causal)
+        logger.warning(
+            "The current version of the SDPA converter only supports attn_mask=None, dropout_p=0.0 and is_causal=True. Accuracy may be affected for models that use a different configuration."
+        )
+ modified_input_args = (query, key, value, None, dropout_p, True)
# Create a new node with torch.nn.functional.scaled_dot_product_attention
# The input args is (query, key, value, is_causal). kwargs has scale
with gm.graph.inserting_after(node):
new_node = gm.graph.call_function(
torch.nn.functional.scaled_dot_product_attention,
args=modified_input_args,
- kwargs={"scale": node.kwargs.get("scale", None)},
+ kwargs={
+ "scale": node.kwargs.get("scale", None),
+ "use_fp32_acc": settings.use_fp32_acc,
+ },
)
# Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead.
@@ -113,7 +120,7 @@ def replace_variants_of_sdpa(
# Clean up the graph
clean_up_graph_after_modifications(gm)
- logger.info(
+ logger.debug(
"Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention"
)
return gm
diff --git a/examples/dynamo/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py
similarity index 51%
rename from examples/dynamo/sdpa_converter.py
rename to tools/llm/torchtrt_ext/sdpa_converter.py
index 903324dff5..47083c7b48 100644
--- a/examples/dynamo/sdpa_converter.py
+++ b/tools/llm/torchtrt_ext/sdpa_converter.py
@@ -62,25 +62,15 @@ def scaled_dot_product_attention(
) -> TRTTensor:
# TODO: Handle attn_mask and is_causal arguments in the future
query, key, value, attn_mask, dropout_p, is_causal = args
- logger.info(
- "Ignoring attn_mask and is_causal arguments provided by the original graph. "
- "This converter expects is_causal to be an input to the graph. For prefill phase, is_causal=True "
- "and for generate phase, is_causal=False since we pass only 1 input token at a time"
- )
# TODO: remove this once we have a better way to handle the causal mask
scale = kwargs.get("scale", None)
source_ir = SourceIR.ATEN
+ is_causal = True
# implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
- mm = impl.matmul.matrix_multiply(
- ctx,
- target,
- source_ir,
- name + "_mm",
- query,
- key,
- other_matrix_op=trt.MatrixOperation.TRANSPOSE,
- )
+ use_fp32_acc = kwargs.get("use_fp32_acc", False)
+ query_dtype = query.dtype
+
if scale is None:
scale = query.shape[-1]
if scale < 0:
@@ -90,80 +80,106 @@ def scaled_dot_product_attention(
else:
# static shape
sqrt_scaled = math.sqrt(scale)
- scaled = impl.elementwise.div(
+ key = impl.elementwise.div(
ctx,
target,
source_ir,
name + "_scale",
- mm,
+ key,
sqrt_scaled,
)
else:
- scaled = impl.elementwise.mul(
+ key = impl.elementwise.mul(
ctx,
target,
source_ir,
name + "_scale",
- mm,
+ key,
scale,
)
- # If is_causal is True, we need to generate a causal mask
- if is_causal:
- L, S = query.shape[-2], key.shape[-2]
- if L >= 0 and S >= 0:
- # static shape
- attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype))
- temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
- attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
- attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")
- else:
- # if any of the L or S is dynamic shape
- if L < 0:
- L = impl.shape.shape(
- ctx, target, source_ir, name + "_shape_0", query, 2
- )
- if S < 0:
- S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2)
-
- # generate the mask tensor
- tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S)
-
- temp_mask = impl.unary.logical_not(
- ctx, target, source_ir, name + "_logical_not", tril_tensor
- )
- temp_mask_casted = cast_trt_tensor(
- ctx, temp_mask, trt.float32, name + "_casted_bool", target, source_ir
- )
- one_minus_temp_mask = impl.elementwise.sub(
- ctx,
- target,
- source_ir,
- name + "_one_minus_temp_mask",
- 1.0,
- temp_mask_casted,
- )
- attn_bias = impl.unary.log(
- ctx, target, source_ir, name + "_log", one_minus_temp_mask
- )
-
- scaled_add_attn_bias = impl.elementwise.add(
- ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias
+ if use_fp32_acc and query_dtype == trt.float16:
+ query = cast_trt_tensor(
+ ctx, query, trt.float32, name + "_query_cast_to_fp32", target, source_ir
+ )
+ key = cast_trt_tensor(
+ ctx, key, trt.float32, name + "_key_cast_to_fp32", target, source_ir
)
+
+ mm = impl.matmul.matrix_multiply(
+ ctx,
+ target,
+ source_ir,
+ name + "_mm",
+ query,
+ key,
+ other_matrix_op=trt.MatrixOperation.TRANSPOSE,
+ )
+
+ if use_fp32_acc:
+ mm = cast_trt_tensor(
+ ctx, mm, query_dtype, name + "_mm_cast_to_fp16", target, source_ir
+ )
+
+ L, S = query.shape[-2], key.shape[-2]
+ if L >= 0 and S >= 0:
+ # static shape
+ attn_bias = np.zeros((L, S), dtype=dtype._from(query_dtype).to(np.dtype))
+ temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
+ attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
+ attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")
else:
- scaled_add_attn_bias = scaled
+ # if any of the L or S is dynamic shape
+ if L < 0:
+ L = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", query, 2)
+ if S < 0:
+ S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2)
- # Create a if condition to check if is_causal is True
- if isinstance(is_causal, TRTTensor):
- if_layer = ctx.net.add_if_conditional()
- condition, true_branch, false_branch = is_causal, scaled_add_attn_bias, scaled
- if_layer.set_condition(condition)
- output_layer = if_layer.add_output(true_branch, false_branch)
- scaled_add_attn_bias = output_layer.get_output(0)
+ # generate the mask tensor
+ tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S)
+
+ temp_mask = impl.unary.logical_not(
+ ctx, target, source_ir, name + "_logical_not", tril_tensor
+ )
+
+ # This need_mask determines if we want to use the causal mask or not
+ # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask.
+ # So need_mask will be all False values in this case.
+ # TODO: Implement more general case where L != 1 and S != L
+ need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S)
+ temp_mask = impl.elementwise.logical_and(
+ ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask
+ )
+ temp_mask_casted = cast_trt_tensor(
+ ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir
+ )
+
+ one_minus_temp_mask = impl.elementwise.sub(
+ ctx,
+ target,
+ source_ir,
+ name + "_one_minus_temp_mask",
+ 1.0,
+ temp_mask_casted,
+ )
+ attn_bias = impl.unary.log(
+ ctx, target, source_ir, name + "_log", one_minus_temp_mask
+ )
+
+ scaled_add_attn_bias = impl.elementwise.add(
+ ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias
+ )
softmax = impl.normalization.softmax(
ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False
)
+ if use_fp32_acc:
+ softmax = cast_trt_tensor(
+ ctx, softmax, trt.float32, name + "_softmax_cast_to_fp32", target, source_ir
+ )
+ value = cast_trt_tensor(
+ ctx, value, trt.float32, name + "_value_cast_to_fp32", target, source_ir
+ )
out = impl.matmul.matrix_multiply(
ctx,
target,
@@ -172,5 +188,9 @@ def scaled_dot_product_attention(
softmax,
value,
)
+ if use_fp32_acc:
+ out = cast_trt_tensor(
+ ctx, out, query_dtype, name + "_out_cast_to_fp16", target, source_ir
+ )
return out
diff --git a/tools/llm/utils.py b/tools/llm/utils.py
new file mode 100644
index 0000000000..2c3434b0ed
--- /dev/null
+++ b/tools/llm/utils.py
@@ -0,0 +1,244 @@
+import copy
+import timeit
+
+import numpy as np
+import torch
+from transformers import StoppingCriteriaList
+from transformers.generation.stopping_criteria import (
+ EosTokenCriteria,
+ MaxLengthCriteria,
+)
+
+
+def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
+ """
+ Exports the LLM model into an ExportedProgram with dynamic shapes.
+    If the export hits guard failures caused by certain PyTorch kernel implementations, we
+    retry the export while expressing those guards as runtime assert nodes.
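+
+    Illustrative usage (a sketch; assumes a HuggingFace causal LM and tokenizer created by the caller):
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        model = AutoModelForCausalLM.from_pretrained("gpt2").eval().cuda()
+        input_ids = tokenizer("Hello", return_tensors="pt").input_ids.cuda()
+        ep = export_llm(model, input_ids, max_seq_len=64)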
+ """
+ with torch.no_grad():
+        # max=1024 raises a constraint violation error. https://github.com/pytorch/pytorch/issues/125604
+ seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len)
+ position_ids = torch.arange(inputs.shape[1]).unsqueeze(0).to(inputs.device)
+ try:
+ print("Trying to export the model using torch.export.export()..")
+ # strict=False only enables aotautograd tracing and excludes dynamo.
+ ep = torch.export.export(
+ model,
+ args=(inputs,),
+ kwargs={"position_ids": position_ids},
+ dynamic_shapes=({1: seq_len}, {1: seq_len}),
+ strict=False,
+ )
+        except Exception:
+ print(
+ "Trying torch.export._trace._export to trace the graph since torch.export.export() failed"
+ )
+ # This API is used to express the constraint violation guards as asserts in the graph.
+ ep = torch.export._trace._export(
+ model,
+ args=(inputs,),
+ kwargs={"position_ids": position_ids},
+ dynamic_shapes=({1: seq_len}, {1: seq_len}),
+ strict=False,
+ allow_complex_guards_as_runtime_asserts=True,
+ )
+
+ return ep
+
+
+def get_zeroed_static_cache_inputs(model: torch.fx.GraphModule):
+ """
+    Extracts and returns zeroed static KV cache tensors from a torch.fx.GraphModule. This should only be used with models lowered via static_cache_v1 or static_cache_v2.
+
+ This function identifies placeholder nodes in the graph that represent KV cache tensors,
+ and creates zeroed tensors with the same shape, dtype, and device as the original placeholders.
+
+ Args:
+ model (torch.fx.GraphModule): The exported model graph containing KV cache placeholders
+
+ Returns:
+ tuple: A tuple of zeroed tensors corresponding to the KV cache placeholders in the graph
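+
+    Illustrative usage (``trt_model`` denotes a module compiled with the static cache lowering):
+        kv_cache = get_zeroed_static_cache_inputs(trt_model)
+        outputs = trt_model(input_ids, position_ids, *kv_cache, start_idx, end_idx)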
+ """
+ # placeholder nodes are expected to be in the following order:
+    # input_ids, position_ids, kv_cache_key, kv_cache_value, ..., start_idx, end_idx
+ placeholder_nodes = [node for node in model.graph.nodes if node.op == "placeholder"]
+ # The first two inputs are input_ids, position_ids. The last two inputs are start_idx, end_idx. In between are the KV cache tensors.
+ kv_cache_inputs = placeholder_nodes[2:-2]
+ zeroed_kv_cache_inputs = []
+ for input in kv_cache_inputs:
+ zeroed_kv_cache_inputs.append(
+ torch.zeros(
+ input.meta["val"].shape,
+ dtype=input.meta["val"].dtype,
+ device=torch.device("cuda:0"),
+ )
+ )
+
+ return tuple(zeroed_kv_cache_inputs)
+
+
+def get_zeroed_dynamic_cache_inputs(model: torch.fx.GraphModule):
+ """
+ Extracts and returns zeroed KV cache tensors from a torch.fx.GraphModule. This should only be used for dynamic cache.
+
+ This function identifies placeholder nodes in the graph that represent KV cache tensors,
+ and creates zeroed tensors with the same shape, dtype, and device as the original placeholders.
+
+ Args:
+ model (torch.fx.GraphModule): The exported model graph containing KV cache placeholders
+
+ Returns:
+ tuple: A tuple of zeroed tensors corresponding to the KV cache placeholders in the graph
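+
+    Illustrative usage (``trt_model`` denotes a module compiled with the dynamic cache lowering):
+        kv_cache = get_zeroed_dynamic_cache_inputs(trt_model)
+        outputs = trt_model(input_ids, position_ids, *kv_cache, is_generate)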
+ """
+ # placeholder nodes are expected to be in the following order:
+    # input_ids, position_ids, kv_cache_key, kv_cache_value, ..., is_generate
+ placeholder_nodes = [node for node in model.graph.nodes if node.op == "placeholder"]
+ # The first two inputs are input_ids, position_ids. The last input is is_generate. In between are the KV cache tensors.
+ kv_cache_inputs = placeholder_nodes[2:-1]
+ zeroed_kv_cache_inputs = []
+ for input in kv_cache_inputs:
+ zeroed_kv_cache_inputs.append(
+ torch.zeros(
+ input.meta["val"].shape,
+ dtype=input.meta["val"].dtype,
+ device=torch.device("cuda:0"),
+ )
+ )
+
+ return tuple(zeroed_kv_cache_inputs)
+
+
+def generate(model, input_seq, max_output_seq_length, eos_token_id, benchmark=True):
+ """
+    Greedy decoding of the model. Generates tokens until max_output_seq_length is reached
+    (or, when benchmark is False, until the stopping criteria are met).
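+
+    Illustrative usage (``trt_model`` and ``tokenizer`` are assumed to be created by the caller):
+        output_ids = generate(trt_model, input_ids, 128, tokenizer.eos_token_id)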
+ """
+ stopping_criteria = StoppingCriteriaList(
+ [
+ MaxLengthCriteria(max_length=max_output_seq_length),
+ EosTokenCriteria(eos_token_id=eos_token_id),
+ ]
+ )
+ isl = input_seq.shape[1]
+ osl = max_output_seq_length - isl
+
+ num_tokens_generated = 0
+ while num_tokens_generated < osl:
+ position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda()
+ outputs = model(input_seq, position_ids=position_ids)
+ logits = outputs.logits
+ next_token_logits = logits[:, -1, :]
+ next_tokens = torch.argmax(next_token_logits, dim=-1)
+ input_seq = torch.cat([input_seq, next_tokens[:, None]], dim=-1)
+ num_tokens_generated += 1
+ # TODO: Handle batch in this check
+ if not benchmark and stopping_criteria(input_seq, logits).item():
+ break
+
+ return input_seq
+
+
+def generate_with_static_cache(model, input_seq, max_output_seq_length, eos_token_id):
+ """
+ Greedy decoding of the model with static KV cache.
+ """
+ start_idx = 0
+ end_idx = input_seq.shape[1]
+ position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda()
+ output_seq = input_seq.clone()
+ # TODO: Confirm this: When end_idx = max_output_seq_length-1, number of tokens generated = OSL
+ num_tokens_generated = 0
+ kv_cache = get_zeroed_static_cache_inputs(model)
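+    # The first iteration runs prefill over the full prompt; every later iteration feeds only
+    # the newly generated token, advances the [start_idx, end_idx) window by one, and reuses
+    # the updated KV tensors returned by the previous call.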
+ while end_idx < max_output_seq_length:
+ position_ids = (
+ torch.tensor([[start_idx]], dtype=torch.int64).cuda()
+ if input_seq.shape[1] == 1
+ else position_ids
+ )
+ input_signature = (input_seq, position_ids, *kv_cache, start_idx, end_idx)
+ logits_keys_values = model(*input_signature)
+ num_tokens_generated += 1
+ logits = logits_keys_values[0]
+ kv_cache = logits_keys_values[1:]
+ next_token_logits = logits[:, -1, :]
+ next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+ output_seq = torch.cat([output_seq, next_tokens], dim=-1)
+ input_seq = next_tokens
+ start_idx = end_idx
+ end_idx = start_idx + 1
+ return output_seq
+
+
+def generate_with_dynamic_cache(model, input_seq, max_output_seq_length, eos_token_id):
+ """
+ Greedy decoding of the model with dynamic KV cache.
+ """
+ position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda()
+ output_seq = input_seq.clone()
+ num_output_tokens = max_output_seq_length - input_seq.shape[1]
+ num_tokens_generated = 0
+ kv_cache = get_zeroed_dynamic_cache_inputs(model)
+ last_position_id = position_ids[-1, -1].item()
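+    # is_generate is False while the multi-token prompt is processed (prefill) and True once
+    # single tokens are fed in the decode phase; the KV tensors returned by each call are
+    # passed back in on the next iteration.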
+ while num_tokens_generated < num_output_tokens:
+        is_generate = input_seq.shape[1] == 1
+ position_ids = (
+ torch.tensor([[last_position_id + 1]], dtype=torch.int64).cuda()
+ if input_seq.shape[1] == 1
+ else position_ids
+ )
+ input_signature = (input_seq, position_ids, *kv_cache, is_generate)
+ logits_keys_values = model(*input_signature)
+ num_tokens_generated += 1
+ logits = logits_keys_values[0]
+ kv_cache = logits_keys_values[1:]
+ next_token_logits = logits[:, -1, :]
+ next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+ output_seq = torch.cat([output_seq, next_tokens], dim=-1)
+ input_seq = next_tokens
+ last_position_id += 1
+ return output_seq
+
+
+def time_generate(
+ generate_fn, model, inputs, output_seq_length, eos_token_id, iterations=10
+):
+ """
+    Measure the time taken to generate a full output sequence, repeated over the given number of iterations.
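+
+    Illustrative usage (``trt_model`` and ``tokenizer`` are assumed to be created by the caller):
+        timings = time_generate(generate, trt_model, input_ids, 128, tokenizer.eos_token_id)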
+ """
+ timings = []
+ for _ in range(iterations):
+ start_time = timeit.default_timer()
+ _ = generate_fn(model, inputs, output_seq_length, eos_token_id)
+ torch.cuda.synchronize()
+ end_time = timeit.default_timer()
+ timings.append(end_time - start_time)
+
+ return timings
+
+
+def record_stats(backend, timings, precision, batch_size=1, compile_time_s=None):
+ """
+    Computes latency and throughput statistics from the timing samples and returns them as a dictionary.
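+
+    Illustrative usage (values are placeholders):
+        stats = record_stats("TensorRT", timings, "FP16", batch_size=1, compile_time_s=12.3)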
+ """
+ times = np.array(timings)
+ speeds = batch_size / times
+ time_mean = np.mean(times).item()
+ time_med = np.median(times).item()
+ time_99th = np.percentile(times, 99).item()
+ time_std = np.std(times, ddof=0).item()
+ speed_mean = np.mean(speeds).item()
+ speed_med = np.median(speeds).item()
+
+ stats = {
+ "Backend": backend,
+ "Precision": precision,
+ "Batch size": batch_size,
+ "Median(FPS)": speed_med,
+ "Mean(FPS)": speed_mean,
+ "Median-Latency(ms)": time_med * 1000,
+ "Mean-Latency(ms)": time_mean * 1000,
+ "Latency-StdDev(ms)": time_std * 1000,
+ "Compile Time(s)": compile_time_s,
+ }
+ return stats