feat: caching attempts #3527
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft: peri044 wants to merge 13 commits into main from kv_cache
Conversation
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/cache_utils.py 2025-05-28 23:49:40.726795+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/cache_utils.py 2025-05-28 23:50:03.207085+00:00
@@ -5,81 +5,90 @@
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._pytree import _LEAF_SPEC
from torch._export.utils import _detect_fake_mode_from_gm
+
def get_kv_nodes(gm):
"""
Get the key and value nodes from the graph.
"""
kv_nodes = []
for node in gm.graph.nodes:
- if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention:
+ if (
+ node.op == "call_function"
+ and node.target == torch._C._nn.scaled_dot_product_attention
+ ):
q_node, k_node, v_node = node.args[:3]
kv_nodes.append((k_node, v_node))
return kv_nodes
+
def get_random_tensor_from_node(node: Node) -> torch.Tensor:
- """
- Creates a random tensor based on the shape information in a node's metadata.
- For symbolic dimensions, extracts the maximum value from the shape environment.
-
- Args:
- node: A torch.fx.Node object with metadata containing tensor information
-
- Returns:
- A random tensor with shape matching the node's metadata, or None if no valid
- tensor information is found
- """
- if "val" not in node.meta:
- raise ValueError(f"No tensor information found in node metadata for node: {node}")
-
- fake_tensor = node.meta["val"]
- shape = []
-
- # Iterate through each dimension and handle symbolic dimensions
- for dim in fake_tensor.shape:
- if isinstance(dim, torch.SymInt):
- # Extract the maximum value from the shape environment
- max_val = dim.node.hint
- shape.append(max_val)
- else:
- shape.append(dim)
-
- # Create a random tensor with the determined shape
- dtype = fake_tensor.dtype
- device = fake_tensor.device
- random_tensor = torch.rand(shape, dtype=dtype, device=device)
+ """
+ Creates a random tensor based on the shape information in a node's metadata.
+ For symbolic dimensions, extracts the maximum value from the shape environment.
- return random_tensor
+ Args:
+ node: A torch.fx.Node object with metadata containing tensor information
+
+ Returns:
+ A random tensor with shape matching the node's metadata, or None if no valid
+ tensor information is found
+ """
+ if "val" not in node.meta:
+ raise ValueError(
+ f"No tensor information found in node metadata for node: {node}"
+ )
+
+ fake_tensor = node.meta["val"]
+ shape = []
+
+ # Iterate through each dimension and handle symbolic dimensions
+ for dim in fake_tensor.shape:
+ if isinstance(dim, torch.SymInt):
+ # Extract the maximum value from the shape environment
+ max_val = dim.node.hint
+ shape.append(max_val)
+ else:
+ shape.append(dim)
+
+ # Create a random tensor with the determined shape
+ dtype = fake_tensor.dtype
+ device = fake_tensor.device
+ random_tensor = torch.rand(shape, dtype=dtype, device=device)
+
+ return random_tensor
+
def create_random_output_tensors(nodes: List[Node]) -> List[torch.Tensor]:
"""
Creates random tensors based on the shape information in node metadata.
For symbolic dimensions, extracts the maximum value from the shape environment.
-
+
Args:
nodes: List of torch.fx.Node objects with metadata
-
+
Returns:
List of random tensors with shapes matching the nodes' metadata
"""
random_tensors = []
-
+
for node in nodes:
if isinstance(node, Node):
node_tensor = get_random_tensor_from_node(node)
elif isinstance(node, tuple):
node_tensor_list = []
for n in node:
random_tensor = get_random_tensor_from_node(n)
node_tensor_list.append(random_tensor)
node_tensor = tuple(node_tensor_list)
-
+
random_tensors.append(node_tensor)
-
+
return random_tensors
+
def add_graph_input(
gm: GraphModule, name: str, val: Optional[torch.Tensor] = None, dynamic_shape=None
) -> Node:
"""Add a graph input to the given GraphModule and return the newly created node.
@@ -130,10 +139,11 @@
in_node.meta["tensor_meta"] = _extract_tensor_metadata(fake_tensor)
# return new node...
return in_node
+
def is_op(node: Node, ops: Union[OpOverloadPacket, Iterable[OpOverloadPacket]]) -> bool:
"""Check if the node is a call to one of the ops."""
if node.op != "call_function":
return False
# check if it's a single op that's provided
@@ -144,9 +154,10 @@
if any(node.target == op for op in ops):
return True
return False
+
def get_all_input_output_nodes(graph: Graph) -> Tuple[List[Node], List[Node]]:
input_nodes: List[Node] = graph.find_nodes(op="placeholder")
output_nodes: List[Node] = graph.find_nodes(op="output")
- return (input_nodes, output_nodes)
\ No newline at end of file
+ return (input_nodes, output_nodes)
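The `get_random_tensor_from_node` helper reformatted above materializes a concrete tensor from an FX node's FakeTensor metadata, resolving any symbolic dimension to its recorded hint before calling `torch.rand`. A minimal standalone sketch of that idea follows; the helper name `materialize_like` and the static-shape usage are illustrative, not the PR's API:

```python
# Sketch (not the PR's code) of the idea behind cache_utils.get_random_tensor_from_node:
# resolve symbolic dimensions to concrete values before materializing a tensor.
import torch

def materialize_like(fake_shape, dtype=torch.float32, device="cpu"):
    """Turn a shape that may contain torch.SymInt entries into a random tensor."""
    concrete = []
    for dim in fake_shape:
        if isinstance(dim, torch.SymInt):
            concrete.append(dim.node.hint)  # example value recorded by the shape env
        else:
            concrete.append(int(dim))
    return torch.rand(concrete, dtype=dtype, device=device)

# With a purely static shape the helper degenerates to torch.rand:
t = materialize_like((2, 8, 16, 64))
print(t.shape)  # torch.Size([2, 8, 16, 64])
```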
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llama_benchmark.py 2025-05-28 23:49:40.727795+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llama_benchmark.py 2025-05-28 23:50:03.276791+00:00
@@ -10,68 +10,72 @@
def main():
# Initialize model and tokenizer
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
- MODEL_NAME,
- torch_dtype=torch.float16,
- use_cache=False,
- device_map="auto"
+ MODEL_NAME, torch_dtype=torch.float16, use_cache=False, device_map="auto"
)
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward)
-
+
# Prepare input prompt
word = "What"
# Tokenize the word
- word_ids = tokenizer(word, return_tensors="pt").input_ids[0] # Get the first (and only) sequence
+ word_ids = tokenizer(word, return_tensors="pt").input_ids[
+ 0
+ ] # Get the first (and only) sequence
# Repeat the token 2048 times
- input_ids = word_ids.repeat(1024).unsqueeze(0).to(model.device) # Add batch dimension and move to device
+ input_ids = (
+ word_ids.repeat(1024).unsqueeze(0).to(model.device)
+ ) # Add batch dimension and move to device
print(f"Input tensor shape: {input_ids.shape}")
# # Warm-up pass
print("Running warm-up pass...")
output_ids = model.generate(
input_ids,
max_new_tokens=MAX_NEW_TOKENS,
do_sample=False,
pad_token_id=tokenizer.eos_token_id,
- use_cache=USE_CACHE
+ use_cache=USE_CACHE,
)
-
+
# Benchmark loop
print("Running benchmark...")
num_iterations = 10
total_time = 0
timings = []
-
+
for i in range(num_iterations):
start_time = timeit.default_timer()
output_ids = model.generate(
input_ids,
max_new_tokens=MAX_NEW_TOKENS,
do_sample=False,
pad_token_id=tokenizer.eos_token_id,
- use_cache=USE_CACHE
+ use_cache=USE_CACHE,
)
end_time = timeit.default_timer()
generation_time = end_time - start_time
total_time += generation_time
timings.append(generation_time)
-
+
# Decode and print first iteration output
# if i == 0:
# output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
# print("\nFirst generation output:")
# print(output_text)
-
+
# Calculate and print statistics
average_time = total_time / num_iterations
print(f"\nPerformance Statistics:")
- print(f"Average generation time over {num_iterations} iterations: {average_time*1000:.2f} milliseconds")
+ print(
+ f"Average generation time over {num_iterations} iterations: {average_time*1000:.2f} milliseconds"
+ )
print(f"Average tokens per second: {100/average_time:.2f}")
print("\nIndividual timings (ms):")
for i, t in enumerate(timings):
print(f"Iteration {i+1}: {t*1000:.2f}")
+
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
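The benchmark script above boils down to one warm-up `generate` call followed by a `timeit`-based averaging loop. A stripped-down sketch of that timing pattern with a stand-in workload (no model or tokenizer; `benchmark` and the lambda are placeholders):

```python
# Sketch of the timing pattern in llama_benchmark.py: warm up once, then average
# wall-clock time over N iterations with timeit.default_timer.
import timeit

def benchmark(run_generation, num_iterations=10):
    run_generation()  # warm-up: compilation / cache population is not measured
    timings = []
    for _ in range(num_iterations):
        start = timeit.default_timer()
        run_generation()
        timings.append(timeit.default_timer() - start)
    avg = sum(timings) / len(timings)
    print(f"Average generation time: {avg * 1000:.2f} ms")
    return timings

benchmark(lambda: sum(range(100_000)))  # toy workload in place of model.generate
```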
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/dynamic_cache.py 2025-05-28 23:49:40.726795+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/dynamic_cache.py 2025-05-28 23:50:03.285151+00:00
@@ -12,36 +12,45 @@
from torch_tensorrt.dynamo.utils import extract_var_range_info
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)
-from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes, is_op
+from cache_utils import (
+ add_graph_input,
+ create_random_output_tensors,
+ get_kv_nodes,
+ is_op,
+)
import tensorrt
import torch.utils._pytree as pytree
+
logger = logging.getLogger(__name__)
-@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(torch.ops.higher_order.cond, enabled=True, supports_dynamic_shapes=True)
+
+@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(
+ torch.ops.higher_order.cond, enabled=True, supports_dynamic_shapes=True
+)
def cond_converter(
ctx: torch_tensorrt.dynamo.conversion.ConversionContext,
target: Target,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
name: str,
) -> Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]:
"""
Converter for torch.ops.higher_order.cond operation to TensorRT.
-
+
This function handles the conversion of PyTorch's conditional operation to TensorRT.
The conditional operation selects between two tensors based on a boolean predicate.
-
+
Args:
ctx (torch_tensorrt.dynamo.conversion.ConversionCtx): The conversion context
target (Target): The target operation to convert
args (Tuple[Argument, ...]): The arguments to the operation
kwargs (Dict[str, Argument]): The keyword arguments to the operation
name (str): The name to give to the TensorRT layer
-
+
Returns:
Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: The converted TensorRT tensor(s)
"""
if_layer = ctx.net.add_if_conditional()
condition, true_branch, false_branch = args[0], args[1], args[2]
@@ -49,30 +58,31 @@
output_layer = if_layer.add_output(true_branch, false_branch)
output = output_layer.get_output(0)
return output
+
def add_kv_as_outputs(gm):
"""
Modifies the graph to add query, key, and value tensors as outputs.
-
+
This function identifies all scaled dot-product attention (SDPA) operations
in the graph, creates copies of their query, key, and value inputs, and adds
these copies to the graph's outputs. This allows for accessing these tensors
externally, which is useful for operations like key-value caching.
-
+
Args:
graph: The torch.fx.Graph to modify
-
+
Returns:
None. The graph is modified in-place.
"""
# list of MHA kernels we would want to detect and replace
mha_ops = {
torch._C._nn.scaled_dot_product_attention,
}
-
+
# Find all SDPA nodes in the graph
mha_nodes = []
for node in gm.graph.nodes:
if is_op(node, mha_ops):
mha_nodes.append(node)
@@ -80,157 +90,170 @@
# Iterate through each MHA node to extract shape information
for mha_node in mha_nodes:
if "val" in mha_node.meta and len(mha_node.args) >= 3:
# Get the input nodes (query, key, value)
q_node, k_node, v_node = mha_node.args[:3]
-
+
# Add the copy nodes as outputs to the graph
- output_node = next(node for node in gm.graph.nodes if node.op == "output")
+ output_node = next(node for node in gm.graph.nodes if node.op == "output")
# Get the current output args (typically a tuple)
current_outputs = output_node.args[0]
-
+
# If the current output is a tuple, extend it with our new outputs
if isinstance(current_outputs, tuple):
new_outputs = current_outputs + ((k_node, v_node),)
else:
# If there's only one output or it's not a tuple, create a new tuple
new_outputs = (current_outputs, (k_node, v_node))
-
+
gm.graph.output(new_outputs)
gm.graph.erase_node(output_node)
-
+
return new_outputs
-
-
def add_kv_and_indices_as_inputs(gm, fixed_kv: bool = True):
- """
- Add key-value tensors and index parameters as inputs to the graph.
-
- Args:
- gm: The GraphModule to modify
- fixed_kv: Boolean indicating whether to use static tensors for KV cache
-
- Returns:
- A tuple containing:
- - List of (k_input, v_input) node pairs for each SDPA operation
- - start_idx input node for slicing operations
- - end_idx input node for slicing operations
- """
-
- def get_static_tensor(tensor: torch.Tensor):
- key_shape = []
- for dim in tensor.shape:
- if isinstance(dim, torch.SymInt):
- min_max_opt = extract_var_range_info(dim)
- key_shape.append(min_max_opt["max"])
- else:
- key_shape.append(dim)
-
- static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device)
- return static_tensor
-
- keys_values = get_kv_nodes(gm)
-
- kv_inputs = []
- for idx, key_value in enumerate(keys_values):
- k_val = key_value[0].meta["val"]
- v_val = key_value[1].meta["val"]
- if fixed_kv:
- k_val = get_static_tensor(k_val)
- v_val = get_static_tensor(v_val)
-
- # Add new inputs using add_graph_input
- k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val)
- v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val)
- kv_inputs.append((k_input, v_input))
-
- return kv_inputs
-
-
-def insert_torch_cond_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]]):
+ """
+ Add key-value tensors and index parameters as inputs to the graph.
+
+ Args:
+ gm: The GraphModule to modify
+ fixed_kv: Boolean indicating whether to use static tensors for KV cache
+
+ Returns:
+ A tuple containing:
+ - List of (k_input, v_input) node pairs for each SDPA operation
+ - start_idx input node for slicing operations
+ - end_idx input node for slicing operations
+ """
+
+ def get_static_tensor(tensor: torch.Tensor):
+ key_shape = []
+ for dim in tensor.shape:
+ if isinstance(dim, torch.SymInt):
+ min_max_opt = extract_var_range_info(dim)
+ key_shape.append(min_max_opt["max"])
+ else:
+ key_shape.append(dim)
+
+ static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device)
+ return static_tensor
+
+ keys_values = get_kv_nodes(gm)
+
+ kv_inputs = []
+ for idx, key_value in enumerate(keys_values):
+ k_val = key_value[0].meta["val"]
+ v_val = key_value[1].meta["val"]
+ if fixed_kv:
+ k_val = get_static_tensor(k_val)
+ v_val = get_static_tensor(v_val)
+
+ # Add new inputs using add_graph_input
+ k_input = add_graph_input(gm, key_value[0].name + "_k_input", k_val)
+ v_input = add_graph_input(gm, key_value[1].name + "_v_input", v_val)
+ kv_inputs.append((k_input, v_input))
+
+ return kv_inputs
+
+
+def insert_torch_cond_before_sdpa(
+ gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]]
+):
"""
Insert a torch.cond operation before each scaled_dot_product_attention operation.
-
+
Args:
gm: The FX GraphModule to modify
-
+
Returns:
The modified GraphModule
"""
# Find all nodes with scaled_dot_product_attention
sdpa_nodes = []
for node in gm.graph.nodes:
- if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention:
+ if (
+ node.op == "call_function"
+ and node.target == torch._C._nn.scaled_dot_product_attention
+ ):
sdpa_nodes.append(node)
-
- # Get the is_causal input node
- is_causal_node = next((node for node in gm.graph.nodes if node.op == "placeholder" and node.name == "is_causal"), None)
+
+ # Get the is_causal input node
+ is_causal_node = next(
+ (
+ node
+ for node in gm.graph.nodes
+ if node.op == "placeholder" and node.name == "is_causal"
+ ),
+ None,
+ )
# For each SDPA node, insert a torch.cond operation before it
for idx, sdpa_node in enumerate(sdpa_nodes):
-
+
with gm.graph.inserting_before(sdpa_node):
# pred_node = add_graph_input(gm, "is_generate", torch.tensor(False, dtype=torch.bool))
q_node, k_node, v_node = sdpa_node.args[:3]
incoming_key, incoming_value = incoming_keys_values[idx]
# Create nodes for concatenating k with incoming_key and v with incoming_value
concatenated_k_node = gm.graph.create_node(
"call_function",
torch.ops.aten.cat.default,
- args=([incoming_key, k_node], 2), # Concatenate along sequence length dimension
- kwargs={}
+ args=(
+ [incoming_key, k_node],
+ 2,
+ ), # Concatenate along sequence length dimension
+ kwargs={},
)
concatenated_v_node = gm.graph.create_node(
"call_function",
torch.ops.aten.cat.default,
- args=([incoming_value, v_node], 2), # Concatenate along sequence length dimension
- kwargs={}
- )
-
+ args=(
+ [incoming_value, v_node],
+ 2,
+ ), # Concatenate along sequence length dimension
+ kwargs={},
+ )
+
# Create the torch.cond node
cond_k_node = gm.graph.create_node(
"call_function",
torch.ops.higher_order.cond,
args=(is_causal_node, concatenated_k_node, k_node),
)
-
+
cond_v_node = gm.graph.create_node(
"call_function",
torch.ops.higher_order.cond,
args=(is_causal_node, concatenated_v_node, v_node),
)
sdpa_node.args = (q_node, cond_k_node, cond_v_node) + sdpa_node.args[3:]
-
+
return gm
-
@_aten_lowering_pass
def insert_dynamic_kv_cache(
gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
"""Insert FlashInfer MHA + KV cache ops in the graph"""
"""Perform insertion of kv-caches and attention kernel."""
# Add static key and value as inputs to the graph
- kv_inputs = add_kv_and_indices_as_inputs(gm, fixed_kv=True)
+ kv_inputs = add_kv_and_indices_as_inputs(gm, fixed_kv=True)
# Call the function to add KV as outputs
logits_keys_values = add_kv_as_outputs(gm)
# Insert torch.cond before each SDPA node which acts toggles between prefill and generate phases
gm = insert_torch_cond_before_sdpa(gm, kv_inputs)
gm = clean_up_graph_after_modifications(gm)
-
+
new_output_tensors = create_random_output_tensors(logits_keys_values)
new_out_spec = pytree.tree_flatten(new_output_tensors)[1]
gm._out_spec = new_out_spec
-
+
logger.debug("After inserting KV cache into the graph: " + str(gm.graph))
return gm
-
-
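The dynamic-cache pass wires `torch.ops.higher_order.cond` in front of each SDPA node so that one branch uses the freshly computed K/V and the other concatenates the incoming cache with the new K/V along the sequence dimension. An eager sketch of those two branches; the predicate here is a plain bool (the pass itself keys off the graph's `is_causal` placeholder), and `select_kv` is an illustrative name:

```python
# Eager sketch of the selection dynamic_cache.py builds as graph nodes.
import torch

def select_kv(use_fresh_kv: bool, cached_k, cached_v, new_k, new_v):
    if use_fresh_kv:
        return new_k, new_v
    # dim=2 is the sequence-length axis for (batch, heads, seq, head_dim) layouts
    return torch.cat([cached_k, new_k], dim=2), torch.cat([cached_v, new_v], dim=2)

cached_k = torch.zeros(1, 8, 16, 64)
cached_v = torch.zeros(1, 8, 16, 64)
new_k = torch.randn(1, 8, 1, 64)
new_v = torch.randn(1, 8, 1, 64)
k, v = select_kv(False, cached_k, cached_v, new_k, new_v)
print(k.shape)  # torch.Size([1, 8, 17, 64])
```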
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/test_sdpa.py 2025-05-28 23:49:40.727795+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/test_sdpa.py 2025-05-28 23:50:03.466989+00:00
@@ -9,101 +9,117 @@
from contextlib import nullcontext
import argparse
# llama2_model_name = "meta-llama/Llama-2-7b-hf"
llama3_model_name = "meta-llama/Llama-3.2-1B-Instruct"
-llama_model = AutoModelForCausalLM.from_pretrained(
- llama3_model_name,
- use_cache=False,
- attn_implementation="sdpa",
- num_hidden_layers=1,
- ).eval().cuda()
+llama_model = (
+ AutoModelForCausalLM.from_pretrained(
+ llama3_model_name,
+ use_cache=False,
+ attn_implementation="sdpa",
+ num_hidden_layers=1,
+ )
+ .eval()
+ .cuda()
+)
LLAMA_CONFIG = llama_model.config
+
def test_llama_attention(args):
class LlamaAttentionBlock(nn.Module):
def __init__(self):
super().__init__()
self.config = LLAMA_CONFIG
- self.attn = LlamaAttention(
- config=self.config,
- layer_idx=0
+ self.attn = LlamaAttention(config=self.config, layer_idx=0)
+
+ def forward(self, hidden_states, position_embeddings):
+ attn_output, attn_weights = self.attn(
+ hidden_states, position_embeddings, None
)
- def forward(self, hidden_states, position_embeddings):
- attn_output, attn_weights = self.attn(hidden_states, position_embeddings, None)
return attn_output
-
+
DTYPE = torch.float32
# model = LlamaAttentionBlock().eval().cuda().to(DTYPE)
model = llama_model.model.layers[0].self_attn.to(DTYPE)
- # llama3
+ # llama3
# hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda()
# position_embeddings = (torch.randn((1, 6, 64), dtype=DTYPE).cuda(), torch.randn((1, 6, 64), dtype=DTYPE).cuda())
hidden_states = torch.load("hidden_states.pt")
position_embeddings = torch.load("position_embeddings.pt")
# breakpoint()
pyt_output = model(hidden_states, position_embeddings, None)
-
+
seq_len = torch.export.Dim("seq_len", min=2, max=2176)
dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None)
- ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes)
-
+ ep = torch.export.export(
+ model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes
+ )
+
with torch_tensorrt.logging.debug():
- trt_model = torch_tensorrt.dynamo.compile(ep,
- inputs=[hidden_states, position_embeddings, None],
- enabled_precisions={torch.float32},
- disable_tf32=True,
- debug=True)
+ trt_model = torch_tensorrt.dynamo.compile(
+ ep,
+ inputs=[hidden_states, position_embeddings, None],
+ enabled_precisions={torch.float32},
+ disable_tf32=True,
+ debug=True,
+ )
trt_output = trt_model(hidden_states, position_embeddings, None)
if isinstance(pyt_output, tuple):
- print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}")
+ print(
+ f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}"
+ )
else:
print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output - trt_output))}")
-
+
def test_llama_decoder(args):
class LlamaDecoder(nn.Module):
def __init__(self):
super().__init__()
self.config = LLAMA_CONFIG
- self.decoder_layer = LlamaDecoderLayer(
- config=self.config,
- layer_idx=0
+ self.decoder_layer = LlamaDecoderLayer(config=self.config, layer_idx=0)
+
+ def forward(self, hidden_states, position_embeddings):
+ decoder_output = self.decoder_layer(
+ hidden_states, position_embeddings=position_embeddings
)
- def forward(self, hidden_states, position_embeddings):
- decoder_output = self.decoder_layer(hidden_states, position_embeddings=position_embeddings)
return decoder_output[0]
-
+
DTYPE = torch.float32
model = LlamaDecoder().eval().cuda().to(DTYPE)
- # llama3
+ # llama3
hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda()
- position_embeddings = (torch.randn((1, 6, 64), dtype=DTYPE).cuda(), torch.randn((1, 6, 64), dtype=DTYPE).cuda())
+ position_embeddings = (
+ torch.randn((1, 6, 64), dtype=DTYPE).cuda(),
+ torch.randn((1, 6, 64), dtype=DTYPE).cuda(),
+ )
pyt_output = model(hidden_states, position_embeddings)
seq_len = torch.export.Dim("seq_len", min=2, max=2176)
dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}))
- ep = torch.export.export(model, (hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes)
-
- with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
- trt_model = torch_tensorrt.dynamo.compile(ep,
- inputs=[hidden_states, position_embeddings],
- enabled_precisions={torch.float32},
- debug=args.debug)
+ ep = torch.export.export(
+ model, (hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes
+ )
+
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+ trt_model = torch_tensorrt.dynamo.compile(
+ ep,
+ inputs=[hidden_states, position_embeddings],
+ enabled_precisions={torch.float32},
+ debug=args.debug,
+ )
trt_output = trt_model(hidden_states, position_embeddings)
print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output - trt_output))}")
if __name__ == "__main__":
arg_parser = argparse.ArgumentParser(
description="Run test cases for llama attention and decoder"
)
arg_parser.add_argument(
- "--debug",
- action="store_true",
- help="Enable debug (default: False)"
+ "--debug", action="store_true", help="Enable debug (default: False)"
)
args = arg_parser.parse_args()
with torch.inference_mode():
test_llama_attention(args)
- # test_llama_decoder(args)
\ No newline at end of file
+ # test_llama_decoder(args)
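The SDPA tests rely on marking the sequence dimension dynamic via `torch.export.Dim` so the exported program accepts variable-length inputs. A self-contained sketch of that export pattern; `TinyAttention` is a toy stand-in, not the Llama attention block used above:

```python
# Sketch of the dynamic-shape export pattern used in test_sdpa.py.
import torch
import torch.nn as nn

class TinyAttention(nn.Module):
    def forward(self, x):
        # self-attention with x as query, key and value
        return torch.nn.functional.scaled_dot_product_attention(x, x, x)

model = TinyAttention().eval()
x = torch.randn(1, 8, 6, 64)  # (batch, heads, seq, head_dim)
seq_len = torch.export.Dim("seq_len", min=2, max=2176)
ep = torch.export.export(model, (x,), dynamic_shapes=({2: seq_len},))
print(ep.graph_module.graph)  # placeholder carries a symbolic seq dimension
```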
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llama3_trt.py 2025-05-28 23:49:40.727795+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llama3_trt.py 2025-05-28 23:50:03.506307+00:00
@@ -17,40 +17,48 @@
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer
from contextlib import nullcontext
-from utils import export_llm, generate, recordStats, time_generate, generate_with_kv_cache, get_zeroed_kv_cache_inputs
+from utils import (
+ export_llm,
+ generate,
+ recordStats,
+ time_generate,
+ generate_with_kv_cache,
+ get_zeroed_kv_cache_inputs,
+)
DEVICE = torch.device("cuda:0")
+
def get_model(args):
with torch.no_grad():
if args.model == "meta-llama/Llama-2-7b-chat-hf":
model = (
AutoModelForCausalLM.from_pretrained(
args.model,
use_cache=False,
attn_implementation="sdpa",
- num_hidden_layers=1
+ num_hidden_layers=1,
)
.eval()
.cuda()
)
elif args.model == "meta-llama/Llama-3.2-1B-Instruct":
model = (
AutoModelForCausalLM.from_pretrained(
args.model,
use_cache=False,
attn_implementation="sdpa",
- num_hidden_layers=1
+ num_hidden_layers=1,
)
.eval()
.cuda()
)
-
+
elif args.model == "meta-llama/Llama-3.2-3B-Instruct":
model = (
AutoModelForCausalLM.from_pretrained(
args.model,
use_cache=False,
@@ -71,13 +79,11 @@
.cuda()
)
elif args.model == "google/gemma-3-1b-it":
model = (
AutoModelForCausalLM.from_pretrained(
- "google/gemma-3-1b-it",
- use_cache=False,
- attn_implementation="sdpa"
+ "google/gemma-3-1b-it", use_cache=False, attn_implementation="sdpa"
)
.eval()
.cuda()
)
if args.precision == "FP16":
@@ -91,25 +97,25 @@
def compile_torchtrt(model, input_ids, args):
max_seq_len = input_ids.shape[1] + args.max_tokens
ep = export_llm(model, input_ids, max_seq_len=max_seq_len)
-
+
# Set precision specific flags
- use_fp32_acc = False
+ use_fp32_acc = False
use_explicit_typing = False
if args.precision == "FP16":
enabled_precisions = {torch.float32}
- use_fp32_acc = True
+ use_fp32_acc = True
use_explicit_typing = True
elif args.precision == "BF16":
enabled_precisions = {torch.bfloat16}
use_fp32_acc = False
else:
enabled_precisions = {torch.float32}
- with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
trt_model = torch_tensorrt.dynamo.compile(
ep,
inputs=[input_ids],
enabled_precisions=enabled_precisions,
# truncate_double=True,
@@ -132,44 +138,51 @@
tokenizer.decode(gen_tokens[0], skip_special_tokens=True),
)
print("===================================")
-
def measure_perf(trt_model, input_signature, backend_name):
# Measure average time for 10 iterations
import timeit
import numpy as np
-
+
total_time = 0
iterations = 10
-
+
print("Running warmup iteration...")
# Warmup run
_ = trt_model(*input_signature)
torch.cuda.synchronize()
-
+
print(f"Measuring performance over {iterations} iterations...")
for i in range(iterations):
start_time = timeit.default_timer()
_ = trt_model(*input_signature)
torch.cuda.synchronize()
end_time = timeit.default_timer()
iter_time = end_time - start_time
total_time += iter_time
# print(f"Iteration {i+1}: {iter_time:.4f} seconds")
-
+
avg_time = total_time / iterations
- print(f"Backend: {backend_name} Average time per iteration: {avg_time*1000:.4f} milliseconds")
- print(f"Backend: {backend_name} Average throughput: {1.0/avg_time:.2f} iterations/second")
+ print(
+ f"Backend: {backend_name} Average time per iteration: {avg_time*1000:.4f} milliseconds"
+ )
+ print(
+ f"Backend: {backend_name} Average throughput: {1.0/avg_time:.2f} iterations/second"
+ )
+
if __name__ == "__main__":
arg_parser = argparse.ArgumentParser(
description="Run inference on a model with random input values"
)
arg_parser.add_argument(
- "--model", type=str, default="meta-llama/Llama-3.2-1B-Instruct", help="Name of LLM model"
+ "--model",
+ type=str,
+ default="meta-llama/Llama-3.2-1B-Instruct",
+ help="Name of LLM model",
)
arg_parser.add_argument(
"--tokenizer_path",
type=str,
default="meta-llama/Llama-3.2-1B-Instruct",
@@ -187,34 +200,28 @@
)
arg_parser.add_argument(
"--max_tokens", type=int, default=128, help="no. of max tokens to be generated"
)
arg_parser.add_argument(
- "--enable_pytorch_run",
- action="store_true",
- help="Enable pytorch run (default: False)"
+ "--enable_pytorch_run",
+ action="store_true",
+ help="Enable pytorch run (default: False)",
)
arg_parser.add_argument(
"--cache",
type=str,
default="static",
help="Type of KV cache to use",
)
arg_parser.add_argument(
- "--cudagraph",
- action="store_true",
- help="Enable cudagraphs (default: False)"
- )
- arg_parser.add_argument(
- "--debug",
- action="store_true",
- help="Enable debug (default: False)"
- )
- arg_parser.add_argument(
- "--benchmark",
- action="store_true",
- help="Enable benchmark (default: False)"
+ "--cudagraph", action="store_true", help="Enable cudagraphs (default: False)"
+ )
+ arg_parser.add_argument(
+ "--debug", action="store_true", help="Enable debug (default: False)"
+ )
+ arg_parser.add_argument(
+ "--benchmark", action="store_true", help="Enable benchmark (default: False)"
)
args = arg_parser.parse_args()
with torch.inference_mode():
model = get_model(args)
@@ -236,75 +243,118 @@
pyt_stats = None
if args.enable_pytorch_run:
pyt_gen_tokens = generate(
model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id
)
-
+
if args.benchmark:
pyt_timings = time_generate(
generate,
model,
input_ids.clone(),
MAX_OUTPUT_SEQ_LENGTH,
tokenizer.eos_token_id,
iterations=args.iterations,
)
pyt_stats = recordStats(
- "PyTorch", pyt_timings, args.precision, batch_size=1, compile_time_s=None
+ "PyTorch",
+ pyt_timings,
+ args.precision,
+ batch_size=1,
+ compile_time_s=None,
)
# TRT
pyt_logits_tok1 = model.cuda()(input_ids)
next_tokens = torch.argmax(pyt_logits_tok1.logits[:, -1, :], dim=-1)
input_seq = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
pyt_logits_tok2 = model.cuda()(input_seq)
from lower_sdpa import *
+
if args.cache == "static":
# This import is required to register static KV cache transformations as lowering passes
from static_cache2 import *
- trt_model = compile_torchtrt(model, input_ids, args)
+
+ trt_model = compile_torchtrt(model, input_ids, args)
kv_cache = get_zeroed_kv_cache_inputs(trt_model)
# First token generation
- pyt_keys = torch.load("key.pt"); pyt_values = torch.load("value.pt")
- trt_logits, key_cache, value_cache, trt_keys_1, trt_values_1 = trt_model(input_ids.clone(), True, *kv_cache, 0, input_ids.shape[1])
- print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits_tok1.logits - trt_logits))}")
- print(f"Diff between pyt and trt keys: {torch.mean(torch.abs(pyt_keys - trt_keys_1))}")
- print(f"Diff between pyt and trt keys in cache: {torch.mean(torch.abs(pyt_keys - key_cache[:, :, :-2, :]))}")
- print(f"Diff between pyt and trt values: {torch.mean(torch.abs(pyt_values - trt_values_1))}")
- print(f"Diff between pyt and trt values in cache: {torch.mean(torch.abs(pyt_values - value_cache[:, :, :-2, :]))}")
+ pyt_keys = torch.load("key.pt")
+ pyt_values = torch.load("value.pt")
+ trt_logits, key_cache, value_cache, trt_keys_1, trt_values_1 = trt_model(
+ input_ids.clone(), True, *kv_cache, 0, input_ids.shape[1]
+ )
+ print(
+ f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits_tok1.logits - trt_logits))}"
+ )
+ print(
+ f"Diff between pyt and trt keys: {torch.mean(torch.abs(pyt_keys - trt_keys_1))}"
+ )
+ print(
+ f"Diff between pyt and trt keys in cache: {torch.mean(torch.abs(pyt_keys - key_cache[:, :, :-2, :]))}"
+ )
+ print(
+ f"Diff between pyt and trt values: {torch.mean(torch.abs(pyt_values - trt_values_1))}"
+ )
+ print(
+ f"Diff between pyt and trt values in cache: {torch.mean(torch.abs(pyt_values - value_cache[:, :, :-2, :]))}"
+ )
next_tokens = torch.argmax(trt_logits[:, -1, :], dim=-1)
# Second token generation
- trt_logits_2, key_cache2, value_cache2, trt_keys_2, trt_values_2 = trt_model(next_tokens[:, None], False, key_cache.clone(), value_cache.clone(), input_ids.shape[1], input_ids.shape[1]+1)
- pyt_keys2 = torch.load("key2.pt"); pyt_values2 = torch.load("value2.pt")
- print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits_tok2.logits[:, -1:, :] - trt_logits_2))}")
- print(f"Diff between pyt and trt keys: {torch.mean(torch.abs(pyt_keys2[:, :, -2:-1, :] - trt_keys_2))}")
- print(f"Diff between pyt and trt keys in cache: {torch.mean(torch.abs(pyt_keys2 - key_cache2[:, :, :-1, :]))}")
- print(f"Diff between pyt and trt values: {torch.mean(torch.abs(pyt_values2[:, :, -2:-1, :] - trt_values_2))}")
- print(f"Diff between pyt and trt values in cache: {torch.mean(torch.abs(pyt_values2 - value_cache2[:, :, :-1, :]))}")
+ trt_logits_2, key_cache2, value_cache2, trt_keys_2, trt_values_2 = (
+ trt_model(
+ next_tokens[:, None],
+ False,
+ key_cache.clone(),
+ value_cache.clone(),
+ input_ids.shape[1],
+ input_ids.shape[1] + 1,
+ )
+ )
+ pyt_keys2 = torch.load("key2.pt")
+ pyt_values2 = torch.load("value2.pt")
+ print(
+ f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits_tok2.logits[:, -1:, :] - trt_logits_2))}"
+ )
+ print(
+ f"Diff between pyt and trt keys: {torch.mean(torch.abs(pyt_keys2[:, :, -2:-1, :] - trt_keys_2))}"
+ )
+ print(
+ f"Diff between pyt and trt keys in cache: {torch.mean(torch.abs(pyt_keys2 - key_cache2[:, :, :-1, :]))}"
+ )
+ print(
+ f"Diff between pyt and trt values: {torch.mean(torch.abs(pyt_values2[:, :, -2:-1, :] - trt_values_2))}"
+ )
+ print(
+ f"Diff between pyt and trt values in cache: {torch.mean(torch.abs(pyt_values2 - value_cache2[:, :, :-1, :]))}"
+ )
breakpoint()
elif args.cache == "dynamic":
from dynamic_cache import *
- trt_model = compile_torchtrt(model, input_ids, args)
+
+ trt_model = compile_torchtrt(model, input_ids, args)
breakpoint()
kv_cache = get_zeroed_kv_cache_inputs(trt_model)
else:
# pyt_logits = model.cuda()(input_ids.clone())
- trt_model = compile_torchtrt(model, input_ids, args)
+ trt_model = compile_torchtrt(model, input_ids, args)
# trt_logits = trt_model(input_ids.clone(), True)
# print(f"Diff between pyt and trt: {torch.mean(torch.abs(pyt_logits - trt_logits))}")
# print(f"Diff between pyt and trt logits: {torch.mean(torch.abs(pyt_logits.logits - trt_logits.logits))}")
if args.cache == "static":
if args.cudagraph:
# Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases.
# trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model)
torch_tensorrt.runtime.set_cudagraphs_mode(True)
-
+
trt_gen_tokens = generate_with_kv_cache(
- trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id,
- )
+ trt_model,
+ input_ids.clone(),
+ MAX_OUTPUT_SEQ_LENGTH,
+ tokenizer.eos_token_id,
+ )
if args.benchmark:
trt_timings = time_generate(
generate_with_kv_cache,
trt_model,
@@ -316,14 +366,17 @@
elif args.cache == "dynamic":
if args.cudagraph:
# Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases.
# trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model)
torch_tensorrt.runtime.set_cudagraphs_mode(True)
-
+
trt_gen_tokens = generate_with_kv_cache(
- trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id,
- )
+ trt_model,
+ input_ids.clone(),
+ MAX_OUTPUT_SEQ_LENGTH,
+ tokenizer.eos_token_id,
+ )
if args.benchmark:
trt_timings = time_generate(
generate_with_kv_cache,
trt_model,
@@ -333,32 +386,39 @@
iterations=args.iterations,
)
else:
trt_gen_tokens = generate(
- trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id,
+ trt_model,
+ input_ids.clone(),
+ MAX_OUTPUT_SEQ_LENGTH,
+ tokenizer.eos_token_id,
)
if args.benchmark:
trt_timings = time_generate(
generate,
trt_model,
input_ids.clone(),
MAX_OUTPUT_SEQ_LENGTH,
tokenizer.eos_token_id,
iterations=args.iterations,
)
-
+
if args.benchmark:
trt_stats = recordStats(
- "TensorRT", trt_timings, args.precision, batch_size=1, compile_time_s=None
- )
-
- if args.enable_pytorch_run:
+ "TensorRT",
+ trt_timings,
+ args.precision,
+ batch_size=1,
+ compile_time_s=None,
+ )
+
+ if args.enable_pytorch_run:
print_outputs("PyTorch", pyt_gen_tokens, tokenizer)
print_outputs("TensorRT", trt_gen_tokens, tokenizer)
- if args.benchmark:
+ if args.benchmark:
if args.enable_pytorch_run:
print("=========PyTorch PERFORMANCE============ \n")
print(pyt_stats)
print("===================== \n")
print("=========TensorRT PERFORMANCE============ \n")
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/static_cache.py 2025-05-28 23:49:40.727795+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/static_cache.py 2025-05-28 23:50:03.516074+00:00
@@ -12,55 +12,57 @@
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)
import torch.utils._pytree as pytree
from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes
+
logger = logging.getLogger(__name__)
SDPA_OP = torch._C._nn.scaled_dot_product_attention
+
def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]):
"""
Modifies the graph to add query, key, and value tensors as outputs.
-
+
This function identifies all scaled dot-product attention (SDPA) operations
in the graph, creates copies of their query, key, and value inputs, and adds
these copies to the graph's outputs. This allows for accessing these tensors
externally, which is useful for operations like key-value caching.
-
+
Args:
graph: The torch.fx.Graph to modify
-
+
Returns:
None. The graph is modified in-place.
"""
- output_node = next(node for node in gm.graph.nodes if node.op == "output")
+ output_node = next(node for node in gm.graph.nodes if node.op == "output")
# Get the current output args (typically a tuple)
current_outputs = output_node.args[0]
-
+
# If the current output is a tuple, extend it with our new outputs
if isinstance(current_outputs, tuple):
new_outputs = current_outputs + tuple(kv_cache_for_graph)
else:
# If there's only one output or it's not a tuple, create a new tuple
- new_outputs = (current_outputs,) + tuple(kv_cache_for_graph)
-
+ new_outputs = (current_outputs,) + tuple(kv_cache_for_graph)
+
gm.graph.output(new_outputs)
gm.graph.erase_node(output_node)
return new_outputs
def add_kv_cache_inputs(gm, fixed_kv: bool = True):
"""
Add key-value tensors, index parameters as inputs to the graph.
-
+
Args:
gm: The GraphModule to modify
fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True.
-
+
Returns:
A tuple containing:
- List of (k_input, v_input) node pairs for each SDPA operation
- start_idx input node for slicing operations
- end_idx input node for slicing operations
@@ -72,14 +74,14 @@
if isinstance(dim, torch.SymInt):
min_max_opt = extract_var_range_info(dim)
key_shape.append(min_max_opt["max"])
else:
key_shape.append(dim)
-
+
static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device)
return static_tensor
-
+
keys_values = get_kv_nodes(gm)
kv_inputs = []
for idx, key_value in enumerate(keys_values):
k_val = key_value[0].meta["val"]
@@ -87,12 +89,12 @@
if fixed_kv:
k_val = get_static_tensor(k_val)
v_val = get_static_tensor(v_val)
# Add new inputs using add_graph_input
- k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val)
- v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val)
+ k_input = add_graph_input(gm, key_value[0].name + "_k_input", k_val)
+ v_input = add_graph_input(gm, key_value[1].name + "_v_input", v_val)
kv_inputs.append((k_input, v_input))
# Add start_idx and end_idx as inputs
start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0))
end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1))
@@ -103,10 +105,11 @@
seq_len = input_ids_meta.shape[1]
min_max_opt = extract_var_range_info(seq_len)
max_seq_len = min_max_opt["max"]
from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
shape_env = ShapeEnv()
# Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive
start_idx_unbacked_symint = shape_env.create_unbacked_symint()
torch._check(start_idx_unbacked_symint >= 0)
torch._check(start_idx_unbacked_symint <= max_seq_len)
@@ -119,12 +122,16 @@
end_idx_input.meta["val"] = end_idx_unbacked_symint
return kv_inputs, start_idx_input, end_idx_input
-
-def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node):
+def insert_kv_slicing_before_sdpa(
+ gm,
+ incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]],
+ start_idx_input: Node,
+ end_idx_input: Node,
+):
"""
Insert slicing operations before each scaled_dot_product_attention operation.
"""
# Find all nodes with scaled_dot_product_attention
sdpa_nodes = []
@@ -135,106 +142,109 @@
for idx, sdpa_node in enumerate(sdpa_nodes):
q_node, k_node, v_node = sdpa_node.args[:3]
incoming_key, incoming_value = incoming_keys_values[idx]
kv_cache_for_sdpa_node = []
new_keys_values = []
- for key_or_value, current_key_or_value_node in zip([incoming_key, incoming_value], [k_node, v_node]):
+ for key_or_value, current_key_or_value_node in zip(
+ [incoming_key, incoming_value], [k_node, v_node]
+ ):
# Create a slice node for key_cache[:,:,:start_idx,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim
with gm.graph.inserting_before(sdpa_node):
slice_1 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
args=(key_or_value,),
- kwargs={}
+ kwargs={},
)
slice_2 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
args=(slice_1, 1),
- kwargs={}
+ kwargs={},
)
slice_3 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
- args=(slice_2, 2, None, start_idx_input),
- kwargs={}
+ args=(slice_2, 2, None, start_idx_input),
+ kwargs={},
)
slice_4 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
- args=(slice_3, 3),
- kwargs={}
- )
- # =============================================== #
+ args=(slice_3, 3),
+ kwargs={},
+ )
+ # =============================================== #
# Create a slice node for key_cache[:,:, end_idx:,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim
slice_5 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
args=(key_or_value,),
- kwargs={}
+ kwargs={},
)
slice_6 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
args=(slice_5, 1),
- kwargs={}
+ kwargs={},
)
slice_7 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
- args=(slice_6, 2, end_idx_input),
- kwargs={}
+ args=(slice_6, 2, end_idx_input),
+ kwargs={},
)
slice_8 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
- args=(slice_7, 3),
- kwargs={}
- )
- # =============================================== #
+ args=(slice_7, 3),
+ kwargs={},
+ )
+ # =============================================== #
# Concatenate the sliced tensors to build KV cache
cat = gm.graph.create_node(
"call_function",
torch.ops.aten.cat.default,
- args=([slice_4, current_key_or_value_node, slice_8], 2),
- kwargs={}
+ args=([slice_4, current_key_or_value_node, slice_8], 2),
+ kwargs={},
)
# Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph
cat.meta.update(key_or_value.meta)
kv_cache_for_sdpa_node.append(cat)
- # =============================================== #
+ # =============================================== #
# Get the current key and value by indexing the KV cache
slice_9 = gm.graph.create_node(
- "call_function",
- torch.ops.aten.slice.Tensor,
- args=(cat,),
- kwargs={}
+ "call_function", torch.ops.aten.slice.Tensor, args=(cat,), kwargs={}
)
slice_10 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
args=(slice_9, 1),
- kwargs={}
+ kwargs={},
)
slice_11 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
- args=(slice_10, 2, None, end_idx_input),
- kwargs={}
+ args=(slice_10, 2, None, end_idx_input),
+ kwargs={},
)
slice_12 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
- args=(slice_11, 3),
- kwargs={}
+ args=(slice_11, 3),
+ kwargs={},
)
new_keys_values.append(slice_12)
-
+
kv_cache_for_graph.extend(kv_cache_for_sdpa_node)
- sdpa_node.args = (q_node, new_keys_values[0], new_keys_values[1]) + sdpa_node.args[3:]
-
+ sdpa_node.args = (
+ q_node,
+ new_keys_values[0],
+ new_keys_values[1],
+ ) + sdpa_node.args[3:]
+
return gm, kv_cache_for_graph
@_aten_lowering_pass
def insert_kv_cache(
@@ -243,13 +253,15 @@
"""Insert KV cache ops in the graph"""
"""Perform insertion of kv-caches and attention kernel."""
# Add static key and value as inputs to the graph
kv_inputs, start_idx_input, end_idx_input = add_kv_cache_inputs(gm, fixed_kv=True)
- # Build and update the KV cache using computed KV inputs for current token and
+ # Build and update the KV cache using computed KV inputs for current token and
# incoming keys and values from previous tokens (which were added as inputs)
- gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input)
+ gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(
+ gm, kv_inputs, start_idx_input, end_idx_input
+ )
# Call the function to add KV as outputs
logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph)
gm = clean_up_graph_after_modifications(gm)
@@ -259,7 +271,5 @@
new_out_spec = pytree.tree_flatten(new_output_tensors)[1]
gm._out_spec = new_out_spec
logger.debug("After inserting KV cache into the graph: " + str(gm.graph))
return gm
-
-
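The slice/cat chains that `insert_kv_slicing_before_sdpa` builds node by node implement a fixed-size cache splice: write the new K (or V) into positions `[start_idx, end_idx)` and attend over the valid prefix. An eager equivalent, with `update_cache` as an illustrative name:

```python
# Eager equivalent of the slice/concat pattern static_cache.py builds as FX nodes.
import torch

def update_cache(cache, new, start_idx, end_idx):
    # cache: (batch, heads, max_seq, head_dim); new covers [start_idx, end_idx)
    updated = torch.cat(
        [cache[:, :, :start_idx, :], new, cache[:, :, end_idx:, :]], dim=2
    )
    return updated, updated[:, :, :end_idx, :]  # full cache, valid prefix

key_cache = torch.zeros(1, 8, 16, 64)
k_new = torch.randn(1, 8, 6, 64)        # prefill writes 6 positions
key_cache, k_valid = update_cache(key_cache, k_new, 0, 6)
print(key_cache.shape, k_valid.shape)   # [1, 8, 16, 64], [1, 8, 6, 64]
```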
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/static_cache2.py 2025-05-28 23:49:40.727795+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/static_cache2.py 2025-05-28 23:50:03.529014+00:00
@@ -12,55 +12,57 @@
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)
import torch.utils._pytree as pytree
from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes
+
logger = logging.getLogger(__name__)
SDPA_OP = torch._C._nn.scaled_dot_product_attention
+
def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]):
"""
Modifies the graph to add query, key, and value tensors as outputs.
-
+
This function identifies all scaled dot-product attention (SDPA) operations
in the graph, creates copies of their query, key, and value inputs, and adds
these copies to the graph's outputs. This allows for accessing these tensors
externally, which is useful for operations like key-value caching.
-
+
Args:
graph: The torch.fx.Graph to modify
-
+
Returns:
None. The graph is modified in-place.
"""
- output_node = next(node for node in gm.graph.nodes if node.op == "output")
+ output_node = next(node for node in gm.graph.nodes if node.op == "output")
# Get the current output args (typically a tuple)
current_outputs = output_node.args[0]
-
+
# If the current output is a tuple, extend it with our new outputs
if isinstance(current_outputs, tuple):
new_outputs = current_outputs + tuple(kv_cache_for_graph)
else:
# If there's only one output or it's not a tuple, create a new tuple
- new_outputs = (current_outputs,) + tuple(kv_cache_for_graph)
-
+ new_outputs = (current_outputs,) + tuple(kv_cache_for_graph)
+
gm.graph.output(new_outputs)
gm.graph.erase_node(output_node)
return new_outputs
def add_kv_cache_inputs(gm, fixed_kv: bool = True):
"""
Add key-value tensors, index parameters as inputs to the graph.
-
+
Args:
gm: The GraphModule to modify
fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True.
-
+
Returns:
A tuple containing:
- List of (k_input, v_input) node pairs for each SDPA operation
- start_idx input node for slicing operations
- end_idx input node for slicing operations
@@ -72,14 +74,14 @@
if isinstance(dim, torch.SymInt):
min_max_opt = extract_var_range_info(dim)
key_shape.append(min_max_opt["max"])
else:
key_shape.append(dim)
-
+
static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device)
return static_tensor
-
+
keys_values = get_kv_nodes(gm)
kv_inputs = []
for idx, key_value in enumerate(keys_values):
k_val = key_value[0].meta["val"]
@@ -87,12 +89,12 @@
if fixed_kv:
k_val = get_static_tensor(k_val)
v_val = get_static_tensor(v_val)
# Add new inputs using add_graph_input
- k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val)
- v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val)
+ k_input = add_graph_input(gm, key_value[0].name + "_k_input", k_val)
+ v_input = add_graph_input(gm, key_value[1].name + "_v_input", v_val)
kv_inputs.append((k_input, v_input))
# Add start_idx and end_idx as inputs
start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0))
end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1))
@@ -100,18 +102,19 @@
# Get the max sequence length from the first key_cache node. The order of nodes is: input_ids, is_causal, key_cache1, value_cache1, key_cache2, value_cache2, ..
input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
# Get the third last input which should be the last value cache node and store the max_seq_len
input_ids_meta = input_nodes[-3].meta["val"]
seq_len = input_ids_meta.shape[2]
-
+
if isinstance(seq_len, torch.SymInt):
min_max_opt = extract_var_range_info(seq_len)
max_seq_len = min_max_opt["max"]
else:
max_seq_len = seq_len
from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
shape_env = ShapeEnv()
# Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive
start_idx_unbacked_symint = shape_env.create_unbacked_symint()
torch._check(start_idx_unbacked_symint >= 0)
torch._check(start_idx_unbacked_symint <= max_seq_len)
@@ -127,14 +130,17 @@
is_causal_input = add_graph_input(gm, "is_causal", True)
is_causal_input.meta["val"] = torch.tensor(True)
return kv_inputs, start_idx_input, end_idx_input, is_causal_input
-def create_kv_cache_update_nodes(gm, sdpa_node, current_kv_node, incoming_kv_node, start_idx_input, end_idx_input):
+
+def create_kv_cache_update_nodes(
+ gm, sdpa_node, current_kv_node, incoming_kv_node, start_idx_input, end_idx_input
+):
"""
Create slicing and concatenation nodes for KV cache update.
-
+
This function creates the necessary slicing and concatenation nodes to update the KV cache
during the generation process. It takes the SDPA node, the current KV cache node, and the
incoming KV cache node as input.
Returns:
for a particular SDPA node, a tuple containing:
@@ -147,78 +153,73 @@
with gm.graph.inserting_before(sdpa_node):
slice_1 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
args=(incoming_kv_node,),
- kwargs={}
+ kwargs={},
)
slice_2 = gm.graph.create_node(
- "call_function",
- torch.ops.aten.slice.Tensor,
- args=(slice_1, 1),
- kwargs={}
+ "call_function", torch.ops.aten.slice.Tensor, args=(slice_1, 1), kwargs={}
)
slice_3 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
- args=(slice_2, 2, None, start_idx_input),
- kwargs={}
+ args=(slice_2, 2, None, start_idx_input),
+ kwargs={},
)
slice_4 = gm.graph.create_node(
- "call_function",
- torch.ops.aten.slice.Tensor,
- args=(slice_3, 3),
- kwargs={}
+ "call_function", torch.ops.aten.slice.Tensor, args=(slice_3, 3), kwargs={}
)
# Concat key_cache[:,:,:start_idx,:] with current key (k)
concat_keys_or_values = gm.graph.create_node(
"call_function",
torch.ops.aten.cat.default,
- args=([slice_4, current_kv_node], 2),
- kwargs={}
- )
-
- # =============================================== #
+ args=([slice_4, current_kv_node], 2),
+ kwargs={},
+ )
+
+ # =============================================== #
# Create nodes for key_cache[:,:, end_idx:,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim
slice_5 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
args=(incoming_kv_node,),
- kwargs={}
+ kwargs={},
)
slice_6 = gm.graph.create_node(
- "call_function",
- torch.ops.aten.slice.Tensor,
- args=(slice_5, 1),
- kwargs={}
+ "call_function", torch.ops.aten.slice.Tensor, args=(slice_5, 1), kwargs={}
)
slice_7 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
- args=(slice_6, 2, end_idx_input),
- kwargs={}
+ args=(slice_6, 2, end_idx_input),
+ kwargs={},
)
slice_8 = gm.graph.create_node(
- "call_function",
- torch.ops.aten.slice.Tensor,
- args=(slice_7, 3),
- kwargs={}
- )
- # =============================================== #
+ "call_function", torch.ops.aten.slice.Tensor, args=(slice_7, 3), kwargs={}
+ )
+ # =============================================== #
# Concatenate the sliced tensors to build KV cache
new_incoming_keys_or_values = gm.graph.create_node(
"call_function",
torch.ops.aten.cat.default,
- args=([concat_keys_or_values, slice_8], 2),
- kwargs={}
+ args=([concat_keys_or_values, slice_8], 2),
+ kwargs={},
)
# Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph
new_incoming_keys_or_values.meta.update(incoming_kv_node.meta)
return concat_keys_or_values, new_incoming_keys_or_values
-def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node, is_causal_input: Node):
+
+def insert_kv_slicing_before_sdpa(
+ gm,
+ incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]],
+ start_idx_input: Node,
+ end_idx_input: Node,
+ is_causal_input: Node,
+):
"""
Insert slicing and concatenation operations before each scaled_dot_product_attention operation as per the following KV cache update logic:
concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)
concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)
new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
@@ -232,21 +233,34 @@
sdpa_nodes.append(node)
kv_cache_for_graph = []
for idx, sdpa_node in enumerate(sdpa_nodes):
q_node, k_node, v_node = sdpa_node.args[:3]
incoming_key, incoming_value = incoming_keys_values[idx]
- # For keys
- new_current_key_node, new_incoming_key_cache_node = create_kv_cache_update_nodes(gm, sdpa_node, k_node, incoming_key, start_idx_input, end_idx_input)
+ # For keys
+ new_current_key_node, new_incoming_key_cache_node = (
+ create_kv_cache_update_nodes(
+ gm, sdpa_node, k_node, incoming_key, start_idx_input, end_idx_input
+ )
+ )
# For values
- new_current_value_node, new_incoming_value_cache_node = create_kv_cache_update_nodes(gm, sdpa_node, v_node, incoming_value, start_idx_input, end_idx_input)
+ new_current_value_node, new_incoming_value_cache_node = (
+ create_kv_cache_update_nodes(
+ gm, sdpa_node, v_node, incoming_value, start_idx_input, end_idx_input
+ )
+ )
# Store the KV cache nodes for the current SDPA node
- kv_cache_for_graph.extend([new_incoming_key_cache_node, new_incoming_value_cache_node])
+ kv_cache_for_graph.extend(
+ [new_incoming_key_cache_node, new_incoming_value_cache_node]
+ )
# Update the SDPA node arguments with current key and value nodes
- sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + (None, is_causal_input) # + sdpa_node.args[3:]
-
+ sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + (
+ None,
+ is_causal_input,
+ ) # + sdpa_node.args[3:]
+
kv_cache_for_graph.extend([k_node, v_node])
return gm, kv_cache_for_graph
@_aten_lowering_pass
@@ -254,15 +268,19 @@
gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
"""Insert KV cache ops in the graph"""
"""Perform insertion of kv-caches and attention kernel."""
# Add static key and value as inputs to the graph
- kv_inputs, start_idx_input, end_idx_input, is_causal_input = add_kv_cache_inputs(gm, fixed_kv=True)
-
- # Build and update the KV cache using computed KV inputs for current token and
+ kv_inputs, start_idx_input, end_idx_input, is_causal_input = add_kv_cache_inputs(
+ gm, fixed_kv=True
+ )
+
+ # Build and update the KV cache using computed KV inputs for current token and
# incoming keys and values from previous tokens (which were added as inputs)
- gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input, is_causal_input)
+ gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(
+ gm, kv_inputs, start_idx_input, end_idx_input, is_causal_input
+ )
# Call the function to add KV as outputs
logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph)
gm = clean_up_graph_after_modifications(gm)
@@ -272,7 +290,5 @@
new_out_spec = pytree.tree_flatten(new_output_tensors)[1]
gm._out_spec = new_out_spec
logger.debug("After inserting KV cache into the graph: " + str(gm.graph))
return gm
-
-
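Both static-cache passes start the same way: walk `gm.graph.nodes`, collect the `call_function` nodes whose target is the SDPA op, then rewrite their args in place. A minimal FX sketch of that discovery step on a toy traced module (not the exported Llama graph, where the target set may differ):

```python
# Minimal sketch of the SDPA node-discovery step used by the cache passes.
import torch
import torch.fx as fx

class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

gm = fx.symbolic_trace(Attn())
sdpa_targets = {torch.nn.functional.scaled_dot_product_attention}
sdpa_nodes = [
    n for n in gm.graph.nodes
    if n.op == "call_function" and n.target in sdpa_targets
]
print(sdpa_nodes)  # the single SDPA call_function node in this toy graph
```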
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/utils.py 2025-05-28 23:49:40.728795+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/utils.py 2025-05-28 23:50:03.690290+00:00
@@ -2,13 +2,14 @@
from transformers import StoppingCriteriaList
from transformers.generation.stopping_criteria import (
EosTokenCriteria,
MaxLengthCriteria,
)
-import numpy as np
-import copy
+import numpy as np
+import copy
import timeit
+
def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
"""
Exports the LLM model into an ExportedProgram with dynamic shapes.
In the case of guard failures due to some PyTorch kernel implements, we also
@@ -36,31 +37,38 @@
allow_complex_guards_as_runtime_asserts=True,
)
return ep
+
def get_zeroed_kv_cache_inputs(model: torch.fx.GraphModule):
"""
Extracts and returns zeroed KV cache tensors from a torch.fx.GraphModule.
-
+
This function identifies placeholder nodes in the graph that represent KV cache tensors,
and creates zeroed tensors with the same shape, dtype, and device as the original placeholders.
-
+
Args:
model (torch.fx.GraphModule): The exported model graph containing KV cache placeholders
-
+
Returns:
tuple: A tuple of zeroed tensors corresponding to the KV cache placeholders in the graph
"""
# placeholder nodes are expected to be in the following order:
# input_ids, kv_cache_key, kv_cache_value, start_idx, end_idx
placeholder_nodes = [node for node in model.graph.nodes if node.op == "placeholder"]
# The first two inputs are input_ids and is_causal. The last two inputs are start_idx and end_idx. In between are the KV cache tensors.
kv_cache_inputs = placeholder_nodes[2:-2]
zeroed_kv_cache_inputs = []
for input in kv_cache_inputs:
- zeroed_kv_cache_inputs.append(torch.zeros(input.meta["val"].shape, dtype=input.meta["val"].dtype, device=torch.device("cuda:0")))
+ zeroed_kv_cache_inputs.append(
+ torch.zeros(
+ input.meta["val"].shape,
+ dtype=input.meta["val"].dtype,
+ device=torch.device("cuda:0"),
+ )
+ )
return tuple(zeroed_kv_cache_inputs)
def generate(model, input_seq, max_output_seq_length, eos_token_id, benchmark=True):
@@ -87,10 +95,11 @@
if not benchmark and stopping_criteria(input_seq, logits).item():
break
# breakpoint()
return input_seq
+
def generate_with_kv_cache(model, input_seq, max_output_seq_length, eos_token_id):
"""
Greedy decoding of the model with KV cache.
"""
start_idx = 0
@@ -112,27 +121,26 @@
next_token_logits = logits[:, -1, :]
next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
output_seq = torch.cat([output_seq, next_tokens], dim=-1)
input_seq = next_tokens
start_idx = end_idx
- end_idx = start_idx + 1
+ end_idx = start_idx + 1
lkv = torch.cat(logits_concat, dim=1)
# breakpoint()
return output_seq
+
def time_generate(
generate_fn, model, inputs, output_seq_length, eos_token_id, iterations=10
):
"""
Measure the time for generating a sentence over certain number of iterations
"""
timings = []
for _ in range(iterations):
start_time = timeit.default_timer()
- _ = generate_fn(
- model, inputs, output_seq_length, eos_token_id
- )
+ _ = generate_fn(model, inputs, output_seq_length, eos_token_id)
torch.cuda.synchronize()
end_time = timeit.default_timer()
timings.append(end_time - start_time)
return timings
@@ -160,6 +168,6 @@
"Median-Latency(ms)": time_med * 1000,
"Mean-Latency(ms)": time_mean * 1000,
"Latency-StdDev(ms)": time_std * 1000,
"Compile Time(s)": compile_time_s,
}
- return stats
\ No newline at end of file
+ return stats
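A condensed sketch of the greedy decoding loop that generate_with_kv_cache implements above. The compiled model's calling convention here, (input_ids, is_causal, *kv_cache, start_idx, end_idx) -> (logits, *updated_kv_cache), is an assumption drawn from the surrounding code, not a published API.

import torch

def greedy_decode_with_kv_cache(model, input_seq, kv_cache, max_output_seq_length, eos_token_id):
    start_idx, end_idx = 0, input_seq.shape[1]
    output_seq = input_seq.clone()
    while end_idx < max_output_seq_length:
        is_causal = start_idx == 0  # causal masking only for the prefill step
        logits, *kv_cache = model(input_seq, is_causal, *kv_cache, start_idx, end_idx)
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        output_seq = torch.cat([output_seq, next_token], dim=-1)
        if next_token.item() == eos_token_id:  # batch size 1 assumed in this sketch
            break
        input_seq = next_token                 # decode steps feed only the new token
        start_idx, end_idx = end_idx, end_idx + 1
    return output_seq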
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/test_static_cache.py 2025-05-28 23:49:40.727795+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/test_static_cache.py 2025-05-28 23:50:03.847453+00:00
@@ -17,61 +17,83 @@
ATOL = 1e-5
RTOL = 1e-5
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
+
class DynamicCacheModel(nn.Module):
def __init__(self):
super().__init__()
-
+
def forward(self, q, k, v, k1, v1, flag):
- def true_fn(q, k, v, k1, v1):
+ def true_fn(q, k, v, k1, v1):
k_new = torch.cat((k, k1), dim=2)
v_new = torch.cat((v, v1), dim=2)
return torch._C._nn.scaled_dot_product_attention(q, k_new, v_new)
def false_fn(q, k, v, k1, v1):
return torch._C._nn.scaled_dot_product_attention(q, k, v)
out = torch.cond(flag, true_fn, false_fn, (q, k, v, k1, v1))
return 2 * out
-
+
+
class ModelNoCache(nn.Module):
def __init__(self):
super().__init__()
-
+
def forward(self, q, k, v):
- return torch._C._nn.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
+ return torch._C._nn.scaled_dot_product_attention(
+ q, k, v, dropout_p=0.0, is_causal=True
+ )
+
class StaticCacheModel(nn.Module):
def __init__(self):
super().__init__()
-
- # def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True):
+
+ # def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True):
# new_key_cache = torch.cat((key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2)
# new_value_cache = torch.cat((value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2)
# out = torch._C._nn.scaled_dot_product_attention(q, new_key_cache[:, :, :end_idx, :], new_value_cache[:, :, :end_idx, :], dropout_p=0.0, is_causal=is_causal)
-
+
# return out, new_key_cache, new_value_cache
-
- def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True):
- concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2) # key_cache[:, :, :6, :] + curr_keys + key_cache[:, : 7: ,: ]
+
+ def forward(
+ self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True
+ ):
+ concat_keys = torch.cat(
+ (key_cache[:, :, :start_idx, :], k), dim=2
+ ) # key_cache[:, :, :6, :] + curr_keys + key_cache[:, : 7: ,: ]
concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)
new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
- new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2)
- out = torch._C._nn.scaled_dot_product_attention(q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal)
-
+ new_value_cache = torch.cat(
+ (concat_values, value_cache[:, :, end_idx:, :]), dim=2
+ )
+ out = torch._C._nn.scaled_dot_product_attention(
+ q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal
+ )
+
return out, new_key_cache, new_value_cache
-def eager_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
- is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
+def eager_sdpa(
+ query,
+ key,
+ value,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=False,
+ scale=None,
+ enable_gqa=False,
+) -> torch.Tensor:
"""
Eager implementation of SDPA
"""
import math
+
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
breakpoint()
if is_causal:
@@ -85,24 +107,28 @@
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias = attn_mask + attn_bias
if enable_gqa:
- key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
- value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
+ key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+ value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value
+
def print_diff(tensor1, tensor2, prefix=""):
"""
Print the diff between two tensors
"""
- print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}")
+ print(
+ f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}"
+ )
+
def test_static_cache_model(args):
"""
Test the static cache model
"""
@@ -117,13 +143,15 @@
# Test Prefill
start_idx = 0
end_idx = 2048
out_no_cache = model_no_cache(q, k, v)
- out_static_cache, new_key_cache, new_value_cache = model_static_cache(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True)
+ out_static_cache, new_key_cache, new_value_cache = model_static_cache(
+ q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True
+ )
assert torch.allclose(out_no_cache, out_static_cache, atol=ATOL, rtol=RTOL)
-
+
# Test Generate
for start_idx in range(2048, 2176):
end_idx = start_idx + 1
q_curr = torch.randn(1, 32, 1, 64).cuda()
k_curr = torch.randn(1, 32, 1, 64).cuda()
@@ -133,17 +161,29 @@
q_full = torch.cat((q, q_curr), dim=2)
k_full = torch.cat((k, k_curr), dim=2)
v_full = torch.cat((v, v_curr), dim=2)
out_no_cache = model_no_cache(q_full, k_full, v_full)
- out_static_cache, new_key_cache, new_value_cache = model_static_cache(q_curr, k_curr, v_curr, new_key_cache, new_value_cache, start_idx, end_idx, is_causal=False)
-
- assert torch.allclose(out_no_cache[:, :, -1:, :], out_static_cache, atol=ATOL, rtol=RTOL)
- q = q_full
+ out_static_cache, new_key_cache, new_value_cache = model_static_cache(
+ q_curr,
+ k_curr,
+ v_curr,
+ new_key_cache,
+ new_value_cache,
+ start_idx,
+ end_idx,
+ is_causal=False,
+ )
+
+ assert torch.allclose(
+ out_no_cache[:, :, -1:, :], out_static_cache, atol=ATOL, rtol=RTOL
+ )
+ q = q_full
k = k_full
v = v_full
print("============== test_static_cache passed ==============")
+
def transform_gm_with_kv_cache(exported_program: torch.export.ExportedProgram, args):
"""
Transform the graph module by adding key and value cache to the graph
"""
@@ -155,52 +195,53 @@
use_python_runtime=True,
debug=args.debug,
min_block_size=1,
)
exported_program = pre_export_lowering(exported_program, settings)
- exported_program = exported_program.run_decompositions(
- get_decompositions(False)
- )
+ exported_program = exported_program.run_decompositions(get_decompositions(False))
gm = exported_program.module()
gm = post_lowering(gm, settings)
return gm
+
def test_static_cache_lowering(args):
"""
- Test static cache lowering pass applied to the model with no cache and run the graph module
+ Test static cache lowering pass applied to the model with no cache and run the graph module
and compare the output with the model with no cache
"""
import static_cache2
model_no_cache = ModelNoCache().eval().cuda()
q = torch.randn(1, 32, 2, 64).cuda()
k = torch.randn(1, 32, 2048, 64).cuda()
v = torch.randn(1, 32, 2048, 64).cuda()
key_cache = torch.zeros(1, 32, 2176, 64).cuda()
value_cache = torch.zeros(1, 32, 2176, 64).cuda()
-
+
# Export the model
q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176)
kv_seq_len = torch.export.Dim("kv_seq_len", min=2, max=2176)
exported_program = export(
model_no_cache,
args=(q, k, v),
- dynamic_shapes=({2 : q_seq_len}, {2 : kv_seq_len}, {2 : kv_seq_len}),
- strict=False
+ dynamic_shapes=({2: q_seq_len}, {2: kv_seq_len}, {2: kv_seq_len}),
+ strict=False,
)
gm = transform_gm_with_kv_cache(exported_program, args)
# Test Prefill
start_idx = 0
end_idx = 2048
is_causal = True
q = torch.randn(1, 32, 2048, 64).cuda()
out_no_cache = model_no_cache(q, k, v)
- out_pyt_cache, key_cache, value_cache = gm(q, k, v, is_causal, key_cache, value_cache, start_idx, end_idx)
+ out_pyt_cache, key_cache, value_cache = gm(
+ q, k, v, is_causal, key_cache, value_cache, start_idx, end_idx
+ )
assert torch.allclose(out_no_cache, out_pyt_cache, atol=ATOL, rtol=RTOL)
# Test Generate
for start_idx in range(2048, 2176):
end_idx = start_idx + 1
@@ -209,20 +250,32 @@
k_curr = torch.randn(1, 32, 1, 64).cuda()
v_curr = torch.randn(1, 32, 1, 64).cuda()
# Concatenate the current query, key, and value with the previous ones
q_full = torch.cat((q, q_curr), dim=2)
k_full = torch.cat((k, k_curr), dim=2)
- v_full = torch.cat((v, v_curr), dim=2)
-
+ v_full = torch.cat((v, v_curr), dim=2)
+
out_no_cache = model_no_cache(q_full, k_full, v_full)
- out_pyt_static_cache, key_cache, value_cache = gm(q_curr, k_curr, v_curr, is_causal, key_cache, value_cache, start_idx, end_idx)
- assert torch.allclose(out_no_cache[:, :, -1:, :], out_pyt_static_cache, atol=ATOL, rtol=RTOL)
- q = q_full
+ out_pyt_static_cache, key_cache, value_cache = gm(
+ q_curr,
+ k_curr,
+ v_curr,
+ is_causal,
+ key_cache,
+ value_cache,
+ start_idx,
+ end_idx,
+ )
+ assert torch.allclose(
+ out_no_cache[:, :, -1:, :], out_pyt_static_cache, atol=ATOL, rtol=RTOL
+ )
+ q = q_full
k = k_full
v = v_full
-
+
print("============== test_static_cache_lowering passed ==============")
+
def test_static_cache_export(args):
"""
Test the static cache model export
"""
@@ -236,19 +289,28 @@
start_idx = 0
end_idx = 2048
is_causal = True
# Export the model
seq_len = torch.export.Dim("seq_len", min=2, max=2048)
- seq_len_dyn_dim = {2 : seq_len}
+ seq_len_dyn_dim = {2: seq_len}
exported_program = export(
model_static_cache,
args=(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal),
- dynamic_shapes=(seq_len_dyn_dim, seq_len_dyn_dim, seq_len_dyn_dim, None, None, torch.export.Dim.DYNAMIC, torch.export.Dim.DYNAMIC, None),
- strict=False
- )
-
-
+ dynamic_shapes=(
+ seq_len_dyn_dim,
+ seq_len_dyn_dim,
+ seq_len_dyn_dim,
+ None,
+ None,
+ torch.export.Dim.DYNAMIC,
+ torch.export.Dim.DYNAMIC,
+ None,
+ ),
+ strict=False,
+ )
+
+
def test_static_cache_with_torch_tensorrt(args):
"""
Test the static cache model with torch_tensorrt
"""
import static_cache2
@@ -257,83 +319,104 @@
q = torch.randn(1, 32, 2, 64).cuda()
k = torch.randn(1, 32, 2048, 64).cuda()
v = torch.randn(1, 32, 2048, 64).cuda()
key_cache = torch.zeros(1, 32, 2176, 64).cuda()
value_cache = torch.zeros(1, 32, 2176, 64).cuda()
-
+
# Export the model
q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176)
kv_seq_len = torch.export.Dim("kv_seq_len", min=2, max=2176)
exported_program = export(
model_no_cache,
args=(q, k, v),
- dynamic_shapes=({2 : q_seq_len}, {2 : kv_seq_len}, {2 : kv_seq_len}),
- strict=False
- )
- with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
+ dynamic_shapes=({2: q_seq_len}, {2: kv_seq_len}, {2: kv_seq_len}),
+ strict=False,
+ )
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
trt_model = torch_tensorrt.dynamo.compile(
exported_program,
inputs=[q, k, v],
enabled_precisions={torch.float32},
disable_tf32=True,
use_python_runtime=True,
debug=args.debug,
min_block_size=1,
)
-
+
start_idx = 0
end_idx = 2048
is_causal = True
q = torch.randn(1, 32, 2048, 64).cuda()
# out_eager = eager_sdpa(q, k, v, is_causal=is_causal)
out_no_cache = model_no_cache(q, k, v)
- out_trt, trt_key_cache, trt_value_cache = trt_model(q, k, v, is_causal, key_cache, value_cache, start_idx, end_idx)
+ out_trt, trt_key_cache, trt_value_cache = trt_model(
+ q, k, v, is_causal, key_cache, value_cache, start_idx, end_idx
+ )
# breakpoint()
- assert torch.allclose(out_no_cache, out_trt, atol=ATOL, rtol=RTOL), "Prefill TRT logits don't match"
- assert torch.allclose(trt_key_cache[:, :, :end_idx, :], k, atol=ATOL, rtol=RTOL), "Prefill TRT key cache don't match"
- assert torch.allclose(trt_value_cache[:, :, :end_idx, :], v, atol=ATOL, rtol=RTOL), "Prefill TRT value cache don't match"
-
+ assert torch.allclose(
+ out_no_cache, out_trt, atol=ATOL, rtol=RTOL
+ ), "Prefill TRT logits don't match"
+ assert torch.allclose(
+ trt_key_cache[:, :, :end_idx, :], k, atol=ATOL, rtol=RTOL
+ ), "Prefill TRT key cache don't match"
+ assert torch.allclose(
+ trt_value_cache[:, :, :end_idx, :], v, atol=ATOL, rtol=RTOL
+ ), "Prefill TRT value cache don't match"
+
# Test Generate
for start_idx in range(2048, 2176):
end_idx = start_idx + 1
q_curr = torch.randn(1, 32, 1, 64).cuda()
k_curr = torch.randn(1, 32, 1, 64).cuda()
- v_curr = torch.randn(1, 32, 1, 64).cuda()
+ v_curr = torch.randn(1, 32, 1, 64).cuda()
# Concatenate the current query, key, and value with the previous ones
q_full = torch.cat((q, q_curr), dim=2)
k_full = torch.cat((k, k_curr), dim=2)
- v_full = torch.cat((v, v_curr), dim=2)
+ v_full = torch.cat((v, v_curr), dim=2)
is_causal = False
out_no_cache = model_no_cache(q_full, k_full, v_full)
- out_trt, trt_key_cache, trt_value_cache = trt_model(q_curr, k_curr, v_curr, is_causal, trt_key_cache, trt_value_cache, start_idx, end_idx)
+ out_trt, trt_key_cache, trt_value_cache = trt_model(
+ q_curr,
+ k_curr,
+ v_curr,
+ is_causal,
+ trt_key_cache,
+ trt_value_cache,
+ start_idx,
+ end_idx,
+ )
# breakpoint()
# print_diff(out_no_cache[:, :, -1:, :], out_trt, f"out_no_cache[:, :, -1:, :] vs out_trt for idx {start_idx}")
# print_diff(trt_key_cache[:, :, :end_idx, :], k_full, f"trt_key_cache[:, :, :end_idx, :] vs k_full for idx {start_idx}")
# print_diff(trt_value_cache[:, :, :end_idx, :], v_full, f"trt_value_cache[:, :, :end_idx, :] vs v_full for idx {start_idx}")
- assert torch.allclose(out_no_cache[:, :, -1:, :], out_trt, atol=ATOL, rtol=RTOL), f"Generate TRT logits don't match for idx {start_idx}"
- assert torch.allclose(trt_key_cache[:, :, :end_idx, :], k_full, atol=ATOL, rtol=RTOL), f"Generate TRT key cache don't match for idx {start_idx}"
- assert torch.allclose(trt_value_cache[:, :, :end_idx, :], v_full, atol=ATOL, rtol=RTOL), f"Generate TRT value cache don't match for idx {start_idx}"
- q = q_full
+ assert torch.allclose(
+ out_no_cache[:, :, -1:, :], out_trt, atol=ATOL, rtol=RTOL
+ ), f"Generate TRT logits don't match for idx {start_idx}"
+ assert torch.allclose(
+ trt_key_cache[:, :, :end_idx, :], k_full, atol=ATOL, rtol=RTOL
+ ), f"Generate TRT key cache don't match for idx {start_idx}"
+ assert torch.allclose(
+ trt_value_cache[:, :, :end_idx, :], v_full, atol=ATOL, rtol=RTOL
+ ), f"Generate TRT value cache don't match for idx {start_idx}"
+ q = q_full
k = k_full
v = v_full
print("============== test_static_cache_with_torch_tensorrt passed ==============")
-
+
def main():
arg_parser = argparse.ArgumentParser(
description="Run test cases for llama attention and decoder"
)
arg_parser.add_argument(
- "--debug",
- action="store_true",
- help="Enable debug (default: False)"
+ "--debug", action="store_true", help="Enable debug (default: False)"
)
args = arg_parser.parse_args()
with torch.inference_mode():
# test_static_cache_model(args)
# test_static_cache_lowering(args)
test_static_cache_with_torch_tensorrt(args)
-
+
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
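Each test above exports the cache-free SDPA model with dynamic sequence dimensions before applying the KV-cache lowering. The essential export call, condensed into a stand-alone sketch (the class body mirrors ModelNoCache above; shapes follow the tests):

import torch
import torch.nn as nn
from torch.export import export, Dim

class ModelNoCache(nn.Module):
    # Cache-free reference: a single SDPA call, as in the test above.
    def forward(self, q, k, v):
        return torch._C._nn.scaled_dot_product_attention(
            q, k, v, dropout_p=0.0, is_causal=True
        )

model = ModelNoCache().eval().cuda()
q = torch.randn(1, 32, 2, 64).cuda()
k = torch.randn(1, 32, 2048, 64).cuda()
v = torch.randn(1, 32, 2048, 64).cuda()

# One exported program covers both prefill and decode shapes via dynamic seq dims.
q_seq_len = Dim("q_seq_len", min=2, max=2176)
kv_seq_len = Dim("kv_seq_len", min=2, max=2176)
exported_program = export(
    model,
    args=(q, k, v),
    dynamic_shapes=({2: q_seq_len}, {2: kv_seq_len}, {2: kv_seq_len}),
    strict=False,
)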
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py 2025-05-28 23:49:40.738795+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py 2025-05-28 23:50:04.807491+00:00
@@ -692,10 +692,11 @@
gm = exported_program.module()
exported_program.module().to("cpu")
torch.cuda.empty_cache()
import gc
+
gc.collect()
logger.debug("Input graph: " + str(gm.graph))
# Apply lowering on the graph module
gm = post_lowering(gm, settings)
@@ -790,31 +791,30 @@
# TODO: For future, explore when nodes don't have metadata and if fake_tensor_prop can resolve this.
logger.warning(
"Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments."
)
-
# Store the original input spec for later use
- original_in_spec = getattr(gm, '_in_spec', None)
- original_out_spec = getattr(gm, '_out_spec', None)
-
+ original_in_spec = getattr(gm, "_in_spec", None)
+ original_out_spec = getattr(gm, "_out_spec", None)
+
# Function to preserve and restore module specs
def preserve_module_specs(in_spec, out_spec, target_module):
"""
Applies input and output specs to the target module.
-
+
Args:
in_spec: The input spec to apply
out_spec: The output spec to apply
target_module: The module to apply specs to
"""
# Apply specs to target module
if in_spec is not None:
target_module._in_spec = in_spec
if out_spec is not None:
target_module._out_spec = out_spec
-
+
return target_module
# Partition module into components that can be TRT-accelerated
fast_partitioner_failed = False
# If specified, try using the fast partitioner and fall back to the global one on failure
@@ -1197,11 +1197,11 @@
"enable_weight_streaming": enable_weight_streaming,
"tiling_optimization_level": tiling_optimization_level,
"l2_limit_for_tiling": l2_limit_for_tiling,
"offload_module_to_cpu": offload_module_to_cpu,
}
-
+
settings = CompilationSettings(**compilation_options)
logger.info("Compilation Settings: %s\n", settings)
exported_program = pre_export_lowering(exported_program, settings)
# Decompose the exported program
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py 2025-05-28 23:49:40.742795+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py 2025-05-28 23:50:05.774123+00:00
@@ -23,11 +23,15 @@
"""Replace specific versions of scaled_dot_product_attention with an equivalent
implementation which can be easily converted to TRT
"""
original_fns, replacement = scaled_dot_product_attention_replacement()
replaced_nodes = []
- sdpa_nodes = [node for node in gm.graph.nodes if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default]
+ sdpa_nodes = [
+ node
+ for node in gm.graph.nodes
+ if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default
+ ]
breakpoint()
# For each original function, search for it in the graph and replace
for original in original_fns:
replaced_nodes += torch.fx.subgraph_rewriter.replace_pattern_with_filters(
gm,
@@ -167,6 +171,6 @@
def replacement(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
return torch.nn.functional.scaled_dot_product_attention(query, key, value)
- return (efficient, flash, efficient_scale, flash_scale), replacement
\ No newline at end of file
+ return (efficient, flash, efficient_scale, flash_scale), replacement
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py 2025-05-28 23:49:40.743795+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py 2025-05-28 23:50:06.419303+00:00
@@ -742,14 +742,15 @@
"""
# Representation of input shapes to a given model
# Shapes are concatenated as so:
# x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
tensor_inputs = [
- t if isinstance(t, torch.Tensor) else torch.tensor(t)
- for t in inputs
+ t if isinstance(t, torch.Tensor) else torch.tensor(t) for t in inputs
]
- new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs)
+ new_shape_key = "".join(
+ str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs
+ )
# If the new shape key differs from the existing one,
# invalidate the old shape key and remove the CUDAGraph
if new_shape_key != self.shape_key:
logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}")
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonCUDAGraphModule.py 2025-05-28 23:49:40.743795+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonCUDAGraphModule.py 2025-05-28 23:50:06.446920+00:00
@@ -369,27 +369,35 @@
is_shape_tensor_input = self.engine.is_shape_inference_io(input_name)
if need_cudagraphs_record:
# If cudagraphs is enabled, this memory is reserved for future cudagraph runs
# Clone is required to avoid re-using user-provided GPU memory
if is_shape_tensor_input:
- self._input_buffers[inputs_shape_key][i] = contiguous_inputs[i].cpu().clone()
+ self._input_buffers[inputs_shape_key][i] = (
+ contiguous_inputs[i].cpu().clone()
+ )
else:
- self._input_buffers[inputs_shape_key][i] = contiguous_inputs[i].clone()
+ self._input_buffers[inputs_shape_key][i] = contiguous_inputs[
+ i
+ ].clone()
# For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers
# as per TensorRT requirements
if is_shape_tensor_input:
# Shape tensor inputs are casted to int64 explicitly
# Currently Torch CPU pointers are not working; numpy pointers are used instead
# to refer to underlying memory
inputs_cpu = contiguous_inputs[i].cpu().to(torch.int64)
- inputs_cpu_numpy = contiguous_inputs[i].cpu().to(torch.int64).numpy().copy()
+ inputs_cpu_numpy = (
+ contiguous_inputs[i].cpu().to(torch.int64).numpy().copy()
+ )
# if cudagraphs_enabled:
# self._input_buffers[inputs_shape_key][i].copy_(inputs_cpu)
# self.context.set_tensor_address(input_name, self._input_buffers[inputs_shape_key][i].numpy().copy().ctypes.data)
# else:
- self.context.set_tensor_address(input_name, inputs_cpu_numpy.ctypes.data)
+ self.context.set_tensor_address(
+ input_name, inputs_cpu_numpy.ctypes.data
+ )
else:
self.context.set_input_shape(
input_name, tuple(contiguous_inputs[i].shape)
)
if cudagraphs_enabled:
@@ -458,11 +466,14 @@
assert len(contiguous_inputs) == len(
self.input_names
), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}."
self.setup_input_tensors(
- contiguous_inputs, self.cudagraphs_enabled, need_cudagraphs_record, inputs_shape_key
+ contiguous_inputs,
+ self.cudagraphs_enabled,
+ need_cudagraphs_record,
+ inputs_shape_key,
)
if shape_changed:
# Check if input shapes can be inferred.
uninferred_input_names = self.context.infer_shapes()
@@ -496,11 +507,12 @@
if need_cudagraphs_record:
self._output_buffers[inputs_shape_key][o] = outputs[o].clone()
if self.cudagraphs_enabled:
self.context.set_tensor_address(
- output_name, self._output_buffers[inputs_shape_key][o].data_ptr()
+ output_name,
+ self._output_buffers[inputs_shape_key][o].data_ptr(),
)
else:
self.context.set_tensor_address(
output_name, outputs[o].data_ptr()
)
@@ -522,30 +534,35 @@
self._engine_stream.wait_stream(self._caller_stream)
with torch.cuda.stream(self._engine_stream):
if self.cudagraphs_enabled:
if need_cudagraphs_record:
-
- self.shape_key_to_cudagraph[inputs_shape_key] = torch.cuda.CUDAGraph()
+
+ self.shape_key_to_cudagraph[inputs_shape_key] = (
+ torch.cuda.CUDAGraph()
+ )
if self.profiling_enabled:
- self.shape_key_to_cudagraph[inputs_shape_key].enable_debug_mode()
+ self.shape_key_to_cudagraph[
+ inputs_shape_key
+ ].enable_debug_mode()
with torch.cuda.graph(
- self.shape_key_to_cudagraph[inputs_shape_key], stream=self._engine_stream
+ self.shape_key_to_cudagraph[inputs_shape_key],
+ stream=self._engine_stream,
):
self.context.execute_async_v3(
self._engine_stream.cuda_stream
)
if self.profiling_enabled:
import tempfile
with tempfile.TemporaryDirectory() as tmpdir:
- self.shape_key_to_cudagraph[inputs_shape_key].debug_dump(
- f"{tempdir}/{self.name}_cudagraph.dot"
- )
+ self.shape_key_to_cudagraph[
+ inputs_shape_key
+ ].debug_dump(f"{tempdir}/{self.name}_cudagraph.dot")
self.shape_key_to_cudagraph[inputs_shape_key].replay() # type: ignore
else:
self.context.execute_async_v3(self._engine_stream.cuda_stream)
@@ -754,18 +771,21 @@
"""
# Representation of input shapes to a given model
# Shapes are concatenated as so:
# x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
tensor_inputs = [
- t if isinstance(t, torch.Tensor) else torch.tensor(t)
- for t in inputs
+ t if isinstance(t, torch.Tensor) else torch.tensor(t) for t in inputs
]
- new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs)
+ new_shape_key = "".join(
+ str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs
+ )
# If the new shape key differs from the existing one,
# invalidate the old shape key and remove the CUDAGraph
if new_shape_key not in self.shape_key_to_cudagraph:
- logger.debug(f"The user provided input shape {new_shape_key} is not found in recorded CUDAGraph input shapes. A new CUDAGraph will be recorded with this input shape.")
+ logger.debug(
+ f"The user provided input shape {new_shape_key} is not found in recorded CUDAGraph input shapes. A new CUDAGraph will be recorded with this input shape."
+ )
# self.shape_key = new_shape_key
return True, new_shape_key
return False, new_shape_key
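The per-shape CUDAGraph bookkeeping in the runtime change above keys each recorded graph on a string built from the input shapes; a minimal sketch of that key construction (names are illustrative):

import torch

def make_shape_key(inputs):
    # x: (3, 4), y: (4, 5) --> "(3,4)(4,5)", matching the comment in the runtime module
    tensors = [t if isinstance(t, torch.Tensor) else torch.tensor(t) for t in inputs]
    return "".join(str(tuple(t.shape)).replace(" ", "") for t in tensors)

recorded_graphs = {}  # shape key -> torch.cuda.CUDAGraph, as in shape_key_to_cudagraph
key = make_shape_key([torch.randn(3, 4), torch.randn(4, 5)])
need_record = key not in recorded_graphs  # a new graph is recorded the first time a shape is seen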
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2025-05-28 23:49:40.739795+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2025-05-28 23:50:07.289014+00:00
@@ -1893,10 +1893,11 @@
SourceIR.ATEN,
name,
args[0],
args[1],
)
+
@dynamo_tensorrt_converter(operator.sub, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar, supports_dynamic_shapes=True)
def aten_ops_sub(

There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llama_benchmark.py 2025-05-31 02:00:39.719818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llama_benchmark.py 2025-05-31 02:01:02.496778+00:00
@@ -10,68 +10,72 @@
def main():
# Initialize model and tokenizer
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
- MODEL_NAME,
- torch_dtype=torch.float16,
- use_cache=False,
- device_map="auto"
+ MODEL_NAME, torch_dtype=torch.float16, use_cache=False, device_map="auto"
)
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward)
-
+
# Prepare input prompt
word = "What"
# Tokenize the word
- word_ids = tokenizer(word, return_tensors="pt").input_ids[0] # Get the first (and only) sequence
+ word_ids = tokenizer(word, return_tensors="pt").input_ids[
+ 0
+ ] # Get the first (and only) sequence
# Repeat the token 2048 times
- input_ids = word_ids.repeat(1024).unsqueeze(0).to(model.device) # Add batch dimension and move to device
+ input_ids = (
+ word_ids.repeat(1024).unsqueeze(0).to(model.device)
+ ) # Add batch dimension and move to device
print(f"Input tensor shape: {input_ids.shape}")
# # Warm-up pass
print("Running warm-up pass...")
output_ids = model.generate(
input_ids,
max_new_tokens=MAX_NEW_TOKENS,
do_sample=False,
pad_token_id=tokenizer.eos_token_id,
- use_cache=USE_CACHE
+ use_cache=USE_CACHE,
)
-
+
# Benchmark loop
print("Running benchmark...")
num_iterations = 10
total_time = 0
timings = []
-
+
for i in range(num_iterations):
start_time = timeit.default_timer()
output_ids = model.generate(
input_ids,
max_new_tokens=MAX_NEW_TOKENS,
do_sample=False,
pad_token_id=tokenizer.eos_token_id,
- use_cache=USE_CACHE
+ use_cache=USE_CACHE,
)
end_time = timeit.default_timer()
generation_time = end_time - start_time
total_time += generation_time
timings.append(generation_time)
-
+
# Decode and print first iteration output
# if i == 0:
# output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
# print("\nFirst generation output:")
# print(output_text)
-
+
# Calculate and print statistics
average_time = total_time / num_iterations
print(f"\nPerformance Statistics:")
- print(f"Average generation time over {num_iterations} iterations: {average_time*1000:.2f} milliseconds")
+ print(
+ f"Average generation time over {num_iterations} iterations: {average_time*1000:.2f} milliseconds"
+ )
print(f"Average tokens per second: {100/average_time:.2f}")
print("\nIndividual timings (ms):")
for i, t in enumerate(timings):
print(f"Iteration {i+1}: {t*1000:.2f}")
+
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
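The throughput figure printed at the end of the benchmark above is simply newly generated tokens divided by average wall-clock generation time; a minimal sketch of that arithmetic (the token count and timings below are illustrative placeholders, not measurements):

MAX_NEW_TOKENS = 100               # illustrative; the script divides by a hard-coded 100
timings = [1.91, 1.87, 1.90]       # illustrative per-iteration generation times in seconds
average_time = sum(timings) / len(timings)
print(f"Average generation time: {average_time * 1000:.2f} ms")
print(f"Average tokens per second: {MAX_NEW_TOKENS / average_time:.2f}")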
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/cache_utils.py 2025-05-31 02:00:39.719818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/cache_utils.py 2025-05-31 02:01:02.572391+00:00
@@ -5,81 +5,90 @@
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._pytree import _LEAF_SPEC
from torch._export.utils import _detect_fake_mode_from_gm
+
def get_kv_nodes(gm):
"""
Get the key and value nodes from the graph.
"""
kv_nodes = []
for node in gm.graph.nodes:
- if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention:
+ if (
+ node.op == "call_function"
+ and node.target == torch._C._nn.scaled_dot_product_attention
+ ):
q_node, k_node, v_node = node.args[:3]
kv_nodes.append((k_node, v_node))
return kv_nodes
+
def get_random_tensor_from_node(node: Node) -> torch.Tensor:
- """
- Creates a random tensor based on the shape information in a node's metadata.
- For symbolic dimensions, extracts the maximum value from the shape environment.
-
- Args:
- node: A torch.fx.Node object with metadata containing tensor information
-
- Returns:
- A random tensor with shape matching the node's metadata, or None if no valid
- tensor information is found
- """
- if "val" not in node.meta:
- raise ValueError(f"No tensor information found in node metadata for node: {node}")
-
- fake_tensor = node.meta["val"]
- shape = []
-
- # Iterate through each dimension and handle symbolic dimensions
- for dim in fake_tensor.shape:
- if isinstance(dim, torch.SymInt):
- # Extract the maximum value from the shape environment
- max_val = dim.node.hint
- shape.append(max_val)
- else:
- shape.append(dim)
-
- # Create a random tensor with the determined shape
- dtype = fake_tensor.dtype
- device = fake_tensor.device
- random_tensor = torch.rand(shape, dtype=dtype, device=device)
+ """
+ Creates a random tensor based on the shape information in a node's metadata.
+ For symbolic dimensions, extracts the maximum value from the shape environment.
- return random_tensor
+ Args:
+ node: A torch.fx.Node object with metadata containing tensor information
+
+ Returns:
+ A random tensor with shape matching the node's metadata, or None if no valid
+ tensor information is found
+ """
+ if "val" not in node.meta:
+ raise ValueError(
+ f"No tensor information found in node metadata for node: {node}"
+ )
+
+ fake_tensor = node.meta["val"]
+ shape = []
+
+ # Iterate through each dimension and handle symbolic dimensions
+ for dim in fake_tensor.shape:
+ if isinstance(dim, torch.SymInt):
+ # Extract the maximum value from the shape environment
+ max_val = dim.node.hint
+ shape.append(max_val)
+ else:
+ shape.append(dim)
+
+ # Create a random tensor with the determined shape
+ dtype = fake_tensor.dtype
+ device = fake_tensor.device
+ random_tensor = torch.rand(shape, dtype=dtype, device=device)
+
+ return random_tensor
+
def create_random_output_tensors(nodes: List[Node]) -> List[torch.Tensor]:
"""
Creates random tensors based on the shape information in node metadata.
For symbolic dimensions, extracts the maximum value from the shape environment.
-
+
Args:
nodes: List of torch.fx.Node objects with metadata
-
+
Returns:
List of random tensors with shapes matching the nodes' metadata
"""
random_tensors = []
-
+
for node in nodes:
if isinstance(node, Node):
node_tensor = get_random_tensor_from_node(node)
elif isinstance(node, tuple):
node_tensor_list = []
for n in node:
random_tensor = get_random_tensor_from_node(n)
node_tensor_list.append(random_tensor)
node_tensor = tuple(node_tensor_list)
-
+
random_tensors.append(node_tensor)
-
+
return random_tensors
+
def add_graph_input(
gm: GraphModule, name: str, val: Optional[torch.Tensor] = None, dynamic_shape=None
) -> Node:
"""Add a graph input to the given GraphModule and return the newly created node.
@@ -130,10 +139,11 @@
in_node.meta["tensor_meta"] = _extract_tensor_metadata(fake_tensor)
# return new node...
return in_node
+
def is_op(node: Node, ops: Union[OpOverloadPacket, Iterable[OpOverloadPacket]]) -> bool:
"""Check if the node is a call to one of the ops."""
if node.op != "call_function":
return False
# check if it's a single op that's provided
@@ -144,9 +154,10 @@
if any(node.target == op for op in ops):
return True
return False
+
def get_all_input_output_nodes(graph: Graph) -> Tuple[List[Node], List[Node]]:
input_nodes: List[Node] = graph.find_nodes(op="placeholder")
output_nodes: List[Node] = graph.find_nodes(op="output")
- return (input_nodes, output_nodes)
\ No newline at end of file
+ return (input_nodes, output_nodes)
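cache_utils above materializes real tensors for placeholders whose metadata still carries symbolic dimensions, by falling back to each SymInt's hint (the maximum from the shape environment); a stand-alone sketch of that resolution step:

import torch

def resolve_shape(fake_shape):
    # Replace each symbolic dimension with its hint, as the helper above does;
    # concrete integers pass through unchanged.
    return [d.node.hint if isinstance(d, torch.SymInt) else d for d in fake_shape]

# Usage inside the pass (sketch):
# torch.rand(resolve_shape(node.meta["val"].shape), dtype=..., device=...)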
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/dynamic_cache.py 2025-05-31 02:00:39.719818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/dynamic_cache.py 2025-05-31 02:01:02.649617+00:00
@@ -12,36 +12,45 @@
from torch_tensorrt.dynamo.utils import extract_var_range_info
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)
-from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes, is_op
+from cache_utils import (
+ add_graph_input,
+ create_random_output_tensors,
+ get_kv_nodes,
+ is_op,
+)
import tensorrt
import torch.utils._pytree as pytree
+
logger = logging.getLogger(__name__)
-@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(torch.ops.higher_order.cond, enabled=True, supports_dynamic_shapes=True)
+
+@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(
+ torch.ops.higher_order.cond, enabled=True, supports_dynamic_shapes=True
+)
def cond_converter(
ctx: torch_tensorrt.dynamo.conversion.ConversionContext,
target: Target,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
name: str,
) -> Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]:
"""
Converter for torch.ops.higher_order.cond operation to TensorRT.
-
+
This function handles the conversion of PyTorch's conditional operation to TensorRT.
The conditional operation selects between two tensors based on a boolean predicate.
-
+
Args:
ctx (torch_tensorrt.dynamo.conversion.ConversionCtx): The conversion context
target (Target): The target operation to convert
args (Tuple[Argument, ...]): The arguments to the operation
kwargs (Dict[str, Argument]): The keyword arguments to the operation
name (str): The name to give to the TensorRT layer
-
+
Returns:
Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: The converted TensorRT tensor(s)
"""
if_layer = ctx.net.add_if_conditional()
condition, true_branch, false_branch = args[0], args[1], args[2]
@@ -49,30 +58,31 @@
output_layer = if_layer.add_output(true_branch, false_branch)
output = output_layer.get_output(0)
return output
+
def add_kv_as_outputs(gm):
"""
Modifies the graph to add query, key, and value tensors as outputs.
-
+
This function identifies all scaled dot-product attention (SDPA) operations
in the graph, creates copies of their query, key, and value inputs, and adds
these copies to the graph's outputs. This allows for accessing these tensors
externally, which is useful for operations like key-value caching.
-
+
Args:
graph: The torch.fx.Graph to modify
-
+
Returns:
None. The graph is modified in-place.
"""
# list of MHA kernels we would want to detect and replace
mha_ops = {
torch._C._nn.scaled_dot_product_attention,
}
-
+
# Find all SDPA nodes in the graph
mha_nodes = []
for node in gm.graph.nodes:
if is_op(node, mha_ops):
mha_nodes.append(node)
@@ -80,157 +90,170 @@
# Iterate through each MHA node to extract shape information
for mha_node in mha_nodes:
if "val" in mha_node.meta and len(mha_node.args) >= 3:
# Get the input nodes (query, key, value)
q_node, k_node, v_node = mha_node.args[:3]
-
+
# Add the copy nodes as outputs to the graph
- output_node = next(node for node in gm.graph.nodes if node.op == "output")
+ output_node = next(node for node in gm.graph.nodes if node.op == "output")
# Get the current output args (typically a tuple)
current_outputs = output_node.args[0]
-
+
# If the current output is a tuple, extend it with our new outputs
if isinstance(current_outputs, tuple):
new_outputs = current_outputs + ((k_node, v_node),)
else:
# If there's only one output or it's not a tuple, create a new tuple
new_outputs = (current_outputs, (k_node, v_node))
-
+
gm.graph.output(new_outputs)
gm.graph.erase_node(output_node)
-
+
return new_outputs
-
-
def add_kv_and_indices_as_inputs(gm, fixed_kv: bool = True):
- """
- Add key-value tensors and index parameters as inputs to the graph.
-
- Args:
- gm: The GraphModule to modify
- fixed_kv: Boolean indicating whether to use static tensors for KV cache
-
- Returns:
- A tuple containing:
- - List of (k_input, v_input) node pairs for each SDPA operation
- - start_idx input node for slicing operations
- - end_idx input node for slicing operations
- """
-
- def get_static_tensor(tensor: torch.Tensor):
- key_shape = []
- for dim in tensor.shape:
- if isinstance(dim, torch.SymInt):
- min_max_opt = extract_var_range_info(dim)
- key_shape.append(min_max_opt["max"])
- else:
- key_shape.append(dim)
-
- static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device)
- return static_tensor
-
- keys_values = get_kv_nodes(gm)
-
- kv_inputs = []
- for idx, key_value in enumerate(keys_values):
- k_val = key_value[0].meta["val"]
- v_val = key_value[1].meta["val"]
- if fixed_kv:
- k_val = get_static_tensor(k_val)
- v_val = get_static_tensor(v_val)
-
- # Add new inputs using add_graph_input
- k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val)
- v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val)
- kv_inputs.append((k_input, v_input))
-
- return kv_inputs
-
-
-def insert_torch_cond_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]]):
+ """
+ Add key-value tensors and index parameters as inputs to the graph.
+
+ Args:
+ gm: The GraphModule to modify
+ fixed_kv: Boolean indicating whether to use static tensors for KV cache
+
+ Returns:
+ A tuple containing:
+ - List of (k_input, v_input) node pairs for each SDPA operation
+ - start_idx input node for slicing operations
+ - end_idx input node for slicing operations
+ """
+
+ def get_static_tensor(tensor: torch.Tensor):
+ key_shape = []
+ for dim in tensor.shape:
+ if isinstance(dim, torch.SymInt):
+ min_max_opt = extract_var_range_info(dim)
+ key_shape.append(min_max_opt["max"])
+ else:
+ key_shape.append(dim)
+
+ static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device)
+ return static_tensor
+
+ keys_values = get_kv_nodes(gm)
+
+ kv_inputs = []
+ for idx, key_value in enumerate(keys_values):
+ k_val = key_value[0].meta["val"]
+ v_val = key_value[1].meta["val"]
+ if fixed_kv:
+ k_val = get_static_tensor(k_val)
+ v_val = get_static_tensor(v_val)
+
+ # Add new inputs using add_graph_input
+ k_input = add_graph_input(gm, key_value[0].name + "_k_input", k_val)
+ v_input = add_graph_input(gm, key_value[1].name + "_v_input", v_val)
+ kv_inputs.append((k_input, v_input))
+
+ return kv_inputs
+
+
+def insert_torch_cond_before_sdpa(
+ gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]]
+):
"""
Insert a torch.cond operation before each scaled_dot_product_attention operation.
-
+
Args:
gm: The FX GraphModule to modify
-
+
Returns:
The modified GraphModule
"""
# Find all nodes with scaled_dot_product_attention
sdpa_nodes = []
for node in gm.graph.nodes:
- if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention:
+ if (
+ node.op == "call_function"
+ and node.target == torch._C._nn.scaled_dot_product_attention
+ ):
sdpa_nodes.append(node)
-
- # Get the is_causal input node
- is_causal_node = next((node for node in gm.graph.nodes if node.op == "placeholder" and node.name == "is_causal"), None)
+
+ # Get the is_causal input node
+ is_causal_node = next(
+ (
+ node
+ for node in gm.graph.nodes
+ if node.op == "placeholder" and node.name == "is_causal"
+ ),
+ None,
+ )
# For each SDPA node, insert a torch.cond operation before it
for idx, sdpa_node in enumerate(sdpa_nodes):
-
+
with gm.graph.inserting_before(sdpa_node):
# pred_node = add_graph_input(gm, "is_generate", torch.tensor(False, dtype=torch.bool))
q_node, k_node, v_node = sdpa_node.args[:3]
incoming_key, incoming_value = incoming_keys_values[idx]
# Create nodes for concatenating k with incoming_key and v with incoming_value
concatenated_k_node = gm.graph.create_node(
"call_function",
torch.ops.aten.cat.default,
- args=([incoming_key, k_node], 2), # Concatenate along sequence length dimension
- kwargs={}
+ args=(
+ [incoming_key, k_node],
+ 2,
+ ), # Concatenate along sequence length dimension
+ kwargs={},
)
concatenated_v_node = gm.graph.create_node(
"call_function",
torch.ops.aten.cat.default,
- args=([incoming_value, v_node], 2), # Concatenate along sequence length dimension
- kwargs={}
- )
-
+ args=(
+ [incoming_value, v_node],
+ 2,
+ ), # Concatenate along sequence length dimension
+ kwargs={},
+ )
+
# Create the torch.cond node
cond_k_node = gm.graph.create_node(
"call_function",
torch.ops.higher_order.cond,
args=(is_causal_node, concatenated_k_node, k_node),
)
-
+
cond_v_node = gm.graph.create_node(
"call_function",
torch.ops.higher_order.cond,
args=(is_causal_node, concatenated_v_node, v_node),
)
sdpa_node.args = (q_node, cond_k_node, cond_v_node) + sdpa_node.args[3:]
-
+
return gm
-
@_aten_lowering_pass
def insert_dynamic_kv_cache(
gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
"""Insert FlashInfer MHA + KV cache ops in the graph"""
"""Perform insertion of kv-caches and attention kernel."""
# Add static key and value as inputs to the graph
- kv_inputs = add_kv_and_indices_as_inputs(gm, fixed_kv=True)
+ kv_inputs = add_kv_and_indices_as_inputs(gm, fixed_kv=True)
# Call the function to add KV as outputs
logits_keys_values = add_kv_as_outputs(gm)
# Insert torch.cond before each SDPA node which acts toggles between prefill and generate phases
gm = insert_torch_cond_before_sdpa(gm, kv_inputs)
gm = clean_up_graph_after_modifications(gm)
-
+
new_output_tensors = create_random_output_tensors(logits_keys_values)
new_out_spec = pytree.tree_flatten(new_output_tensors)[1]
gm._out_spec = new_out_spec
-
+
logger.debug("After inserting KV cache into the graph: " + str(gm.graph))
return gm
-
-
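The torch.cond rewrite above mirrors the DynamicCacheModel from the tests: when the flag is set, the incoming cache is concatenated onto the fresh keys/values before SDPA, otherwise the fresh keys/values are used as-is. An eager sketch of that toggle (flag semantics follow the test model, not necessarily the final pass):

import torch

def sdpa_with_optional_cache(q, k, v, cached_k, cached_v, flag):
    def with_cache(q, k, v, cached_k, cached_v):
        # Concatenate cached keys/values along the sequence-length dimension
        k_new = torch.cat((cached_k, k), dim=2)
        v_new = torch.cat((cached_v, v), dim=2)
        return torch._C._nn.scaled_dot_product_attention(q, k_new, v_new)

    def without_cache(q, k, v, cached_k, cached_v):
        return torch._C._nn.scaled_dot_product_attention(q, k, v)

    # torch.cond selects a branch based on a boolean tensor predicate
    return torch.cond(flag, with_cache, without_cache, (q, k, v, cached_k, cached_v))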
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/run_llm.py 2025-05-31 02:00:39.719818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/run_llm.py 2025-05-31 02:01:02.719018+00:00
@@ -17,37 +17,44 @@
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer
from contextlib import nullcontext
-from utils import export_llm, generate, recordStats, time_generate, generate_with_kv_cache
+from utils import (
+ export_llm,
+ generate,
+ recordStats,
+ time_generate,
+ generate_with_kv_cache,
+)
import sys
import os
# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
-sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from register_sdpa import *
DEVICE = torch.device("cuda:0")
+
def get_model(args):
with torch.no_grad():
# Supported list of models:
# - meta-llama/Llama-3.2-1B-Instruct
# - meta-llama/Llama-3.2-3B-Instruct
# - meta-llama/Llama-3.1-8B-Instruct
# - Qwen/Qwen2.5-1.5B-Instruct
model = (
- AutoModelForCausalLM.from_pretrained(
- args.model,
- use_cache=False,
- attn_implementation="sdpa",
- # num_hidden_layers=1
- )
- .eval()
- .cuda()
- )
+ AutoModelForCausalLM.from_pretrained(
+ args.model,
+ use_cache=False,
+ attn_implementation="sdpa",
+ # num_hidden_layers=1
+ )
+ .eval()
+ .cuda()
+ )
if args.precision == "FP16":
model = model.to(torch.float16)
elif args.precision == "BF16":
model = model.to(torch.bfloat16)
else:
@@ -59,23 +66,23 @@
def compile_torchtrt(model, input_ids, args):
max_seq_len = input_ids.shape[1] + args.num_tokens
ep = export_llm(model, input_ids, max_seq_len=max_seq_len)
# Set precision specific flags
- use_fp32_acc = False
+ use_fp32_acc = False
use_explicit_typing = False
if args.precision == "FP16":
enabled_precisions = {torch.float32}
- use_fp32_acc = True
+ use_fp32_acc = True
use_explicit_typing = True
elif args.precision == "BF16":
enabled_precisions = {torch.bfloat16}
use_fp32_acc = False
else:
enabled_precisions = {torch.float32}
- with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
trt_model = torch_tensorrt.dynamo.compile(
ep,
inputs=[input_ids],
enabled_precisions=enabled_precisions,
# truncate_double=True,
@@ -99,112 +106,125 @@
tokenizer.decode(gen_tokens[0], skip_special_tokens=True),
)
print("===================================")
-
def measure_perf(trt_model, input_signature, backend_name):
# Measure average time for 10 iterations
import timeit
import numpy as np
-
+
total_time = 0
iterations = 10
-
+
print("Running warmup iteration...")
# Warmup run
_ = trt_model(*input_signature)
torch.cuda.synchronize()
-
+
print(f"Measuring performance over {iterations} iterations...")
for i in range(iterations):
start_time = timeit.default_timer()
_ = trt_model(*input_signature)
torch.cuda.synchronize()
end_time = timeit.default_timer()
iter_time = end_time - start_time
total_time += iter_time
# print(f"Iteration {i+1}: {iter_time:.4f} seconds")
-
+
avg_time = total_time / iterations
- print(f"Backend: {backend_name} Average time per iteration: {avg_time*1000:.4f} milliseconds")
- print(f"Backend: {backend_name} Average throughput: {1.0/avg_time:.2f} iterations/second")
+ print(
+ f"Backend: {backend_name} Average time per iteration: {avg_time*1000:.4f} milliseconds"
+ )
+ print(
+ f"Backend: {backend_name} Average throughput: {1.0/avg_time:.2f} iterations/second"
+ )
+
if __name__ == "__main__":
arg_parser = argparse.ArgumentParser(
description="Run inference on a model with random input values"
)
arg_parser.add_argument(
- "--model", type=str, default="meta-llama/Llama-3.2-1B-Instruct", help="Name of LLM model"
+ "--model",
+ type=str,
+ default="meta-llama/Llama-3.2-1B-Instruct",
+ help="Name of LLM model",
)
arg_parser.add_argument(
"--tokenizer",
type=str,
default="",
help="Name of LLM model tokenizer",
)
arg_parser.add_argument(
"--prompt", type=str, default="What is parallel programming ?", help="Prompt"
)
- arg_parser.add_argument("--precision", type=str, default="FP16", help="Precision to use in the model. Options: FP16, BF16, FP32")
+ arg_parser.add_argument(
+ "--precision",
+ type=str,
+ default="FP16",
+ help="Precision to use in the model. Options: FP16, BF16, FP32",
+ )
arg_parser.add_argument(
"--iterations", type=int, default=5, help="no. of iterations to run"
)
arg_parser.add_argument(
"--min_block_size", type=int, default=1, help="no. of iterations to run"
)
arg_parser.add_argument(
- "--num_tokens", type=int, default=128, help="no. of output tokens to be generated"
+ "--num_tokens",
+ type=int,
+ default=128,
+ help="no. of output tokens to be generated",
)
arg_parser.add_argument(
"--batch_size", type=int, default=1, help="Batch size used for benchmarking"
)
arg_parser.add_argument(
- "--isl", type=int, default=2048, help="Input sequence length used for benchmarking"
- )
- arg_parser.add_argument(
- "--enable_pytorch_run",
- action="store_true",
- help="Enable pytorch run (default: False)"
+ "--isl",
+ type=int,
+ default=2048,
+ help="Input sequence length used for benchmarking",
+ )
+ arg_parser.add_argument(
+ "--enable_pytorch_run",
+ action="store_true",
+ help="Enable pytorch run (default: False)",
)
arg_parser.add_argument(
"--cache",
type=str,
default="",
help="Type of KV cache to use. Options: static_v1, static_v2, dynamic",
)
arg_parser.add_argument(
- "--cudagraph",
- action="store_true",
- help="Enable cudagraphs (default: False)"
- )
- arg_parser.add_argument(
- "--debug",
- action="store_true",
- help="Enable debug (default: False)"
- )
- arg_parser.add_argument(
- "--benchmark",
- action="store_true",
- help="Enable benchmark (default: False)"
- )
-
+ "--cudagraph", action="store_true", help="Enable cudagraphs (default: False)"
+ )
+ arg_parser.add_argument(
+ "--debug", action="store_true", help="Enable debug (default: False)"
+ )
+ arg_parser.add_argument(
+ "--benchmark", action="store_true", help="Enable benchmark (default: False)"
+ )
+
args = arg_parser.parse_args()
with torch.inference_mode():
model = get_model(args)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model)
# Prepare input for benchmarking or evaluation
if args.benchmark:
- input_ids = torch.randint(1, 10000, (args.batch_size, args.isl), dtype=torch.int64).to(model.device)
+ input_ids = torch.randint(
+ 1, 10000, (args.batch_size, args.isl), dtype=torch.int64
+ ).to(model.device)
position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
else:
model_inputs = tokenizer(args.prompt, return_tensors="pt")
input_ids = model_inputs["input_ids"].to(DEVICE)
position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
-
MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.num_tokens
# Pyt
pyt_gen_tokens = None
pyt_timings = None
@@ -221,11 +241,15 @@
MAX_OUTPUT_SEQ_LENGTH,
tokenizer.eos_token_id,
iterations=args.iterations,
)
pyt_stats = recordStats(
- "PyTorch", pyt_timings, args.precision, batch_size=args.batch_size, compile_time_s=None
+ "PyTorch",
+ pyt_timings,
+ args.precision,
+ batch_size=args.batch_size,
+ compile_time_s=None,
)
if args.cache == "static_v1":
# This import is required to register static v1 KV cache transformations as lowering passes
import static_cache_v1
@@ -233,22 +257,28 @@
# This import is required to register static v2 KV cache transformations as lowering passes
import static_cache_v2
elif args.cache == "dynamic":
import dynamic_cache
-
- trt_model = compile_torchtrt(model, input_ids, args)
-
- if args.cache == "static_v1" or args.cache == "static_v2" or args.cache == "dynamic":
+ trt_model = compile_torchtrt(model, input_ids, args)
+
+ if (
+ args.cache == "static_v1"
+ or args.cache == "static_v2"
+ or args.cache == "dynamic"
+ ):
if args.cudagraph:
# Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases.
# trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model)
torch_tensorrt.runtime.set_cudagraphs_mode(True)
-
+
trt_gen_tokens = generate_with_kv_cache(
- trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id,
- )
+ trt_model,
+ input_ids.clone(),
+ MAX_OUTPUT_SEQ_LENGTH,
+ tokenizer.eos_token_id,
+ )
if args.benchmark:
trt_timings = time_generate(
generate_with_kv_cache,
trt_model,
@@ -257,36 +287,44 @@
tokenizer.eos_token_id,
iterations=args.iterations,
)
else:
trt_gen_tokens = generate(
- trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id,
+ trt_model,
+ input_ids.clone(),
+ MAX_OUTPUT_SEQ_LENGTH,
+ tokenizer.eos_token_id,
)
if args.benchmark:
trt_timings = time_generate(
generate,
trt_model,
input_ids.clone(),
MAX_OUTPUT_SEQ_LENGTH,
tokenizer.eos_token_id,
iterations=args.iterations,
)
-
+
if args.benchmark:
trt_stats = recordStats(
- "TensorRT", trt_timings, args.precision, batch_size=args.batch_size, compile_time_s=None
- )
-
-
+ "TensorRT",
+ trt_timings,
+ args.precision,
+ batch_size=args.batch_size,
+ compile_time_s=None,
+ )
+
if not args.benchmark:
- if args.enable_pytorch_run:
+ if args.enable_pytorch_run:
print_outputs("PyTorch", pyt_gen_tokens, tokenizer)
-
+
print_outputs("TensorRT", trt_gen_tokens, tokenizer)
- if args.enable_pytorch_run:
- print(f"PyTorch and TensorRT outputs match: {torch.equal(pyt_gen_tokens, trt_gen_tokens)}")
+ if args.enable_pytorch_run:
+ print(
+ f"PyTorch and TensorRT outputs match: {torch.equal(pyt_gen_tokens, trt_gen_tokens)}"
+ )
if args.benchmark:
if args.enable_pytorch_run:
print("=========PyTorch PERFORMANCE============ \n")
print(pyt_stats)
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/static_cache_v1.py 2025-05-31 02:00:39.719818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/static_cache_v1.py 2025-05-31 02:01:02.728059+00:00
@@ -12,55 +12,57 @@
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)
import torch.utils._pytree as pytree
from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes
+
logger = logging.getLogger(__name__)
SDPA_OP = torch._C._nn.scaled_dot_product_attention
+
def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]):
"""
Modifies the graph to add query, key, and value tensors as outputs.
-
+
This function identifies all scaled dot-product attention (SDPA) operations
in the graph, creates copies of their query, key, and value inputs, and adds
these copies to the graph's outputs. This allows for accessing these tensors
externally, which is useful for operations like key-value caching.
-
+
Args:
graph: The torch.fx.Graph to modify
-
+
Returns:
None. The graph is modified in-place.
"""
- output_node = next(node for node in gm.graph.nodes if node.op == "output")
+ output_node = next(node for node in gm.graph.nodes if node.op == "output")
# Get the current output args (typically a tuple)
current_outputs = output_node.args[0]
-
+
# If the current output is a tuple, extend it with our new outputs
if isinstance(current_outputs, tuple):
new_outputs = current_outputs + tuple(kv_cache_for_graph)
else:
# If there's only one output or it's not a tuple, create a new tuple
- new_outputs = (current_outputs,) + tuple(kv_cache_for_graph)
-
+ new_outputs = (current_outputs,) + tuple(kv_cache_for_graph)
+
gm.graph.output(new_outputs)
gm.graph.erase_node(output_node)
return new_outputs
def add_kv_cache_inputs(gm, fixed_kv: bool = True):
"""
Add key-value tensors, index parameters as inputs to the graph.
-
+
Args:
gm: The GraphModule to modify
fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True.
-
+
Returns:
A tuple containing:
- List of (k_input, v_input) node pairs for each SDPA operation
- start_idx input node for slicing operations
- end_idx input node for slicing operations
@@ -72,14 +74,14 @@
if isinstance(dim, torch.SymInt):
min_max_opt = extract_var_range_info(dim)
key_shape.append(min_max_opt["max"])
else:
key_shape.append(dim)
-
+
static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device)
return static_tensor
-
+
keys_values = get_kv_nodes(gm)
kv_inputs = []
for idx, key_value in enumerate(keys_values):
k_val = key_value[0].meta["val"]
@@ -87,12 +89,12 @@
if fixed_kv:
k_val = get_static_tensor(k_val)
v_val = get_static_tensor(v_val)
# Add new inputs using add_graph_input
- k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val)
- v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val)
+ k_input = add_graph_input(gm, key_value[0].name + "_k_input", k_val)
+ v_input = add_graph_input(gm, key_value[1].name + "_v_input", v_val)
kv_inputs.append((k_input, v_input))
# Add start_idx and end_idx as inputs
start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0))
end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1))
@@ -103,10 +105,11 @@
seq_len = input_ids_meta.shape[1]
min_max_opt = extract_var_range_info(seq_len)
max_seq_len = min_max_opt["max"]
from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
shape_env = ShapeEnv()
# Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive
start_idx_unbacked_symint = shape_env.create_unbacked_symint()
torch._check(start_idx_unbacked_symint >= 0)
torch._check(start_idx_unbacked_symint <= max_seq_len)
@@ -123,138 +126,152 @@
is_causal_input.meta["val"] = torch.tensor(True)
return kv_inputs, start_idx_input, end_idx_input, is_causal_input
-
-def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node, is_causal_input: Node):
+def insert_kv_slicing_before_sdpa(
+ gm,
+ incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]],
+ start_idx_input: Node,
+ end_idx_input: Node,
+ is_causal_input: Node,
+):
"""
Insert slicing operations before each scaled_dot_product_attention operation.
"""
# Find all nodes with scaled_dot_product_attention
sdpa_nodes = []
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == SDPA_OP:
sdpa_nodes.append(node)
kv_cache_for_graph = []
for idx, sdpa_node in enumerate(sdpa_nodes):
- assert len(sdpa_node.args) == 6, f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments"
+ assert (
+ len(sdpa_node.args) == 6
+ ), f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments"
q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_node.args
incoming_key, incoming_value = incoming_keys_values[idx]
kv_cache_for_sdpa_node = []
new_keys_values = []
- for key_or_value, current_key_or_value_node in zip([incoming_key, incoming_value], [k_node, v_node]):
+ for key_or_value, current_key_or_value_node in zip(
+ [incoming_key, incoming_value], [k_node, v_node]
+ ):
# Create a slice node for key_cache[:,:,:start_idx,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim
with gm.graph.inserting_before(sdpa_node):
slice_1 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
args=(key_or_value,),
- kwargs={}
+ kwargs={},
)
slice_2 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
args=(slice_1, 1),
- kwargs={}
+ kwargs={},
)
slice_3 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
- args=(slice_2, 2, None, start_idx_input),
- kwargs={}
+ args=(slice_2, 2, None, start_idx_input),
+ kwargs={},
)
slice_4 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
- args=(slice_3, 3),
- kwargs={}
- )
- # =============================================== #
+ args=(slice_3, 3),
+ kwargs={},
+ )
+ # =============================================== #
# Create a slice node for key_cache[:,:, end_idx:,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim
slice_5 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
args=(key_or_value,),
- kwargs={}
+ kwargs={},
)
slice_6 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
args=(slice_5, 1),
- kwargs={}
+ kwargs={},
)
slice_7 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
- args=(slice_6, 2, end_idx_input),
- kwargs={}
+ args=(slice_6, 2, end_idx_input),
+ kwargs={},
)
slice_8 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
- args=(slice_7, 3),
- kwargs={}
- )
- # =============================================== #
+ args=(slice_7, 3),
+ kwargs={},
+ )
+ # =============================================== #
# Concatenate the sliced tensors to build KV cache
cat = gm.graph.create_node(
"call_function",
torch.ops.aten.cat.default,
- args=([slice_4, current_key_or_value_node, slice_8], 2),
- kwargs={}
+ args=([slice_4, current_key_or_value_node, slice_8], 2),
+ kwargs={},
)
# Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph
cat.meta.update(key_or_value.meta)
kv_cache_for_sdpa_node.append(cat)
- # =============================================== #
+ # =============================================== #
# Get the current key and value by indexing the KV cache
slice_9 = gm.graph.create_node(
- "call_function",
- torch.ops.aten.slice.Tensor,
- args=(cat,),
- kwargs={}
+ "call_function", torch.ops.aten.slice.Tensor, args=(cat,), kwargs={}
)
slice_10 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
args=(slice_9, 1),
- kwargs={}
+ kwargs={},
)
slice_11 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
- args=(slice_10, 2, None, end_idx_input),
- kwargs={}
+ args=(slice_10, 2, None, end_idx_input),
+ kwargs={},
)
slice_12 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
- args=(slice_11, 3),
- kwargs={}
+ args=(slice_11, 3),
+ kwargs={},
)
new_keys_values.append(slice_12)
-
+
kv_cache_for_graph.extend(kv_cache_for_sdpa_node)
- sdpa_node.args = (q_node, new_keys_values[0], new_keys_values[1]) + (attn_mask, dropout_p, is_causal_input)
-
+ sdpa_node.args = (q_node, new_keys_values[0], new_keys_values[1]) + (
+ attn_mask,
+ dropout_p,
+ is_causal_input,
+ )
+
return gm, kv_cache_for_graph
@_aten_lowering_pass
def insert_kv_cache(
gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
"""Insert KV cache ops in the graph"""
"""Perform insertion of kv-caches and attention kernel."""
# Add static key and value as inputs to the graph
- kv_inputs, start_idx_input, end_idx_input, is_causal_input = add_kv_cache_inputs(gm, fixed_kv=True)
-
- # Build and update the KV cache using computed KV inputs for current token and
+ kv_inputs, start_idx_input, end_idx_input, is_causal_input = add_kv_cache_inputs(
+ gm, fixed_kv=True
+ )
+
+ # Build and update the KV cache using computed KV inputs for current token and
# incoming keys and values from previous tokens (which were added as inputs)
- gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input, is_causal_input)
+ gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(
+ gm, kv_inputs, start_idx_input, end_idx_input, is_causal_input
+ )
# Call the function to add KV as outputs
logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph)
gm = clean_up_graph_after_modifications(gm)
@@ -264,7 +281,5 @@
new_out_spec = pytree.tree_flatten(new_output_tensors)[1]
gm._out_spec = new_out_spec
logger.debug("After inserting KV cache into the graph: " + str(gm.graph))
return gm
-
-
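At the tensor level, the slice/concat graph that static_cache_v1.py builds around each SDPA node reduces to the update below. This is a minimal sketch, assuming a key_cache (and an identically shaped value_cache) of shape batch x heads x max_seq_len x head_dim and a current chunk k covering positions [start_idx, end_idx); the helper name is illustrative — the pass itself emits the equivalent aten.slice/aten.cat FX nodes rather than calling a function like this.

import torch

def kv_cache_update_v1(key_cache, k, start_idx, end_idx):
    # key_cache[:, :, :start_idx, :] -- already-populated prefix of the static buffer
    prefix = key_cache[:, :, :start_idx, :]
    # key_cache[:, :, end_idx:, :] -- untouched tail of the fixed-length buffer
    suffix = key_cache[:, :, end_idx:, :]
    # Splice the current keys in, keeping the overall cache length constant
    new_key_cache = torch.cat((prefix, k, suffix), dim=2)
    # SDPA only consumes the populated part of the rebuilt cache
    keys_for_sdpa = new_key_cache[:, :, :end_idx, :]
    return new_key_cache, keys_for_sdpa

The rebuilt caches are what add_kv_as_outputs then appends to the graph outputs, so the caller can feed them back in as inputs on the next decoding step.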
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/static_cache_v2.py 2025-05-31 02:00:39.719818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/static_cache_v2.py 2025-05-31 02:01:02.768255+00:00
@@ -12,55 +12,57 @@
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)
import torch.utils._pytree as pytree
from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes
+
logger = logging.getLogger(__name__)
SDPA_OP = torch._C._nn.scaled_dot_product_attention
+
def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]):
"""
Modifies the graph to add query, key, and value tensors as outputs.
-
+
This function identifies all scaled dot-product attention (SDPA) operations
in the graph, creates copies of their query, key, and value inputs, and adds
these copies to the graph's outputs. This allows for accessing these tensors
externally, which is useful for operations like key-value caching.
-
+
Args:
graph: The torch.fx.Graph to modify
-
+
Returns:
None. The graph is modified in-place.
"""
- output_node = next(node for node in gm.graph.nodes if node.op == "output")
+ output_node = next(node for node in gm.graph.nodes if node.op == "output")
# Get the current output args (typically a tuple)
current_outputs = output_node.args[0]
-
+
# If the current output is a tuple, extend it with our new outputs
if isinstance(current_outputs, tuple):
new_outputs = current_outputs + tuple(kv_cache_for_graph)
else:
# If there's only one output or it's not a tuple, create a new tuple
- new_outputs = (current_outputs,) + tuple(kv_cache_for_graph)
-
+ new_outputs = (current_outputs,) + tuple(kv_cache_for_graph)
+
gm.graph.output(new_outputs)
gm.graph.erase_node(output_node)
return new_outputs
def add_kv_cache_inputs(gm, fixed_kv: bool = True):
"""
Add key-value tensors, index parameters as inputs to the graph.
-
+
Args:
gm: The GraphModule to modify
fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True.
-
+
Returns:
A tuple containing:
- List of (k_input, v_input) node pairs for each SDPA operation
- start_idx input node for slicing operations
- end_idx input node for slicing operations
@@ -72,14 +74,14 @@
if isinstance(dim, torch.SymInt):
min_max_opt = extract_var_range_info(dim)
key_shape.append(min_max_opt["max"])
else:
key_shape.append(dim)
-
+
static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device)
return static_tensor
-
+
keys_values = get_kv_nodes(gm)
kv_inputs = []
for idx, key_value in enumerate(keys_values):
k_val = key_value[0].meta["val"]
@@ -87,12 +89,12 @@
if fixed_kv:
k_val = get_static_tensor(k_val)
v_val = get_static_tensor(v_val)
# Add new inputs using add_graph_input
- k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val)
- v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val)
+ k_input = add_graph_input(gm, key_value[0].name + "_k_input", k_val)
+ v_input = add_graph_input(gm, key_value[1].name + "_v_input", v_val)
kv_inputs.append((k_input, v_input))
# Add start_idx and end_idx as inputs
start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0))
end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1))
@@ -100,18 +102,19 @@
# Get the max sequence length from the first key_cache node. The order of input nodes is: input_ids, key_cache1, value_cache1, key_cache2, value_cache2, start_idx, end_idx
input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
# Get the third last input which should be the last value cache node and store the max_seq_len
input_ids_meta = input_nodes[-3].meta["val"]
seq_len = input_ids_meta.shape[2]
-
+
if isinstance(seq_len, torch.SymInt):
min_max_opt = extract_var_range_info(seq_len)
max_seq_len = min_max_opt["max"]
else:
max_seq_len = seq_len
from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
shape_env = ShapeEnv()
# Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive
start_idx_unbacked_symint = shape_env.create_unbacked_symint()
torch._check(start_idx_unbacked_symint >= 0)
torch._check(start_idx_unbacked_symint <= max_seq_len)
@@ -127,14 +130,17 @@
is_causal_input = add_graph_input(gm, "is_causal", True)
is_causal_input.meta["val"] = torch.tensor(True)
return kv_inputs, start_idx_input, end_idx_input, is_causal_input
-def create_kv_cache_update_nodes(gm, sdpa_node, current_kv_node, incoming_kv_node, start_idx_input, end_idx_input):
+
+def create_kv_cache_update_nodes(
+ gm, sdpa_node, current_kv_node, incoming_kv_node, start_idx_input, end_idx_input
+):
"""
Create slicing and concatenation nodes for KV cache update.
-
+
This function creates the necessary slicing and concatenation nodes to update the KV cache
during the generation process. It takes the SDPA node, the current KV cache node, and the
incoming KV cache node as input.
Returns:
for a particular SDPA node, a tuple containing:
@@ -147,78 +153,73 @@
with gm.graph.inserting_before(sdpa_node):
slice_1 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
args=(incoming_kv_node,),
- kwargs={}
+ kwargs={},
)
slice_2 = gm.graph.create_node(
- "call_function",
- torch.ops.aten.slice.Tensor,
- args=(slice_1, 1),
- kwargs={}
+ "call_function", torch.ops.aten.slice.Tensor, args=(slice_1, 1), kwargs={}
)
slice_3 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
- args=(slice_2, 2, None, start_idx_input),
- kwargs={}
+ args=(slice_2, 2, None, start_idx_input),
+ kwargs={},
)
slice_4 = gm.graph.create_node(
- "call_function",
- torch.ops.aten.slice.Tensor,
- args=(slice_3, 3),
- kwargs={}
+ "call_function", torch.ops.aten.slice.Tensor, args=(slice_3, 3), kwargs={}
)
# Concat key_cache[:,:,:start_idx,:] with current key (k)
concat_keys_or_values = gm.graph.create_node(
"call_function",
torch.ops.aten.cat.default,
- args=([slice_4, current_kv_node], 2),
- kwargs={}
- )
-
- # =============================================== #
+ args=([slice_4, current_kv_node], 2),
+ kwargs={},
+ )
+
+ # =============================================== #
# Create nodes for key_cache[:,:, end_idx:,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim
slice_5 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
args=(incoming_kv_node,),
- kwargs={}
+ kwargs={},
)
slice_6 = gm.graph.create_node(
- "call_function",
- torch.ops.aten.slice.Tensor,
- args=(slice_5, 1),
- kwargs={}
+ "call_function", torch.ops.aten.slice.Tensor, args=(slice_5, 1), kwargs={}
)
slice_7 = gm.graph.create_node(
"call_function",
torch.ops.aten.slice.Tensor,
- args=(slice_6, 2, end_idx_input),
- kwargs={}
+ args=(slice_6, 2, end_idx_input),
+ kwargs={},
)
slice_8 = gm.graph.create_node(
- "call_function",
- torch.ops.aten.slice.Tensor,
- args=(slice_7, 3),
- kwargs={}
- )
- # =============================================== #
+ "call_function", torch.ops.aten.slice.Tensor, args=(slice_7, 3), kwargs={}
+ )
+ # =============================================== #
# Concatenate the sliced tensors to build KV cache
new_incoming_keys_or_values = gm.graph.create_node(
"call_function",
torch.ops.aten.cat.default,
- args=([concat_keys_or_values, slice_8], 2),
- kwargs={}
+ args=([concat_keys_or_values, slice_8], 2),
+ kwargs={},
)
# Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph
new_incoming_keys_or_values.meta.update(incoming_kv_node.meta)
return concat_keys_or_values, new_incoming_keys_or_values
-def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node, is_causal_input: Node):
+
+def insert_kv_slicing_before_sdpa(
+ gm,
+ incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]],
+ start_idx_input: Node,
+ end_idx_input: Node,
+ is_causal_input: Node,
+):
"""
Insert slicing and concatenation operations before each scaled_dot_product_attention operation as per the following KV cache update logic:
concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)
concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)
new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
@@ -230,24 +231,40 @@
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == SDPA_OP:
sdpa_nodes.append(node)
kv_cache_for_graph = []
for idx, sdpa_node in enumerate(sdpa_nodes):
- assert len(sdpa_node.args) == 6, f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments"
+ assert (
+ len(sdpa_node.args) == 6
+ ), f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments"
q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_node.args
incoming_key, incoming_value = incoming_keys_values[idx]
- # For keys
- new_current_key_node, new_incoming_key_cache_node = create_kv_cache_update_nodes(gm, sdpa_node, k_node, incoming_key, start_idx_input, end_idx_input)
+ # For keys
+ new_current_key_node, new_incoming_key_cache_node = (
+ create_kv_cache_update_nodes(
+ gm, sdpa_node, k_node, incoming_key, start_idx_input, end_idx_input
+ )
+ )
# For values
- new_current_value_node, new_incoming_value_cache_node = create_kv_cache_update_nodes(gm, sdpa_node, v_node, incoming_value, start_idx_input, end_idx_input)
+ new_current_value_node, new_incoming_value_cache_node = (
+ create_kv_cache_update_nodes(
+ gm, sdpa_node, v_node, incoming_value, start_idx_input, end_idx_input
+ )
+ )
# Store the KV cache nodes for the current SDPA node
- kv_cache_for_graph.extend([new_incoming_key_cache_node, new_incoming_value_cache_node])
+ kv_cache_for_graph.extend(
+ [new_incoming_key_cache_node, new_incoming_value_cache_node]
+ )
# Update the SDPA node arguments with current key and value nodes
- sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + (attn_mask, dropout_p, is_causal_input)
-
+ sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + (
+ attn_mask,
+ dropout_p,
+ is_causal_input,
+ )
+
# kv_cache_for_graph.extend([k_node, v_node])
return gm, kv_cache_for_graph
@_aten_lowering_pass
@@ -255,15 +272,19 @@
gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
"""Insert KV cache ops in the graph"""
"""Perform insertion of kv-caches and attention kernel."""
# Add static key and value as inputs to the graph
- kv_inputs, start_idx_input, end_idx_input, is_causal_input = add_kv_cache_inputs(gm, fixed_kv=True)
-
- # Build and update the KV cache using computed KV inputs for current token and
+ kv_inputs, start_idx_input, end_idx_input, is_causal_input = add_kv_cache_inputs(
+ gm, fixed_kv=True
+ )
+
+ # Build and update the KV cache using computed KV inputs for current token and
# incoming keys and values from previous tokens (which were added as inputs)
- gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input, is_causal_input)
+ gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(
+ gm, kv_inputs, start_idx_input, end_idx_input, is_causal_input
+ )
# Call the function to add KV as outputs
logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph)
gm = clean_up_graph_after_modifications(gm)
@@ -273,7 +294,5 @@
new_out_spec = pytree.tree_flatten(new_output_tensors)[1]
gm._out_spec = new_out_spec
logger.debug("After inserting KV cache into the graph: " + str(gm.graph))
return gm
-
-
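static_cache_v2.py keeps the same fixed-length cache but splits the update so that SDPA consumes the prefix-plus-current concatenation directly instead of re-slicing the rebuilt cache, as its docstring spells out. A tensor-level sketch under the same shape assumptions as above (illustrative helper name; the pass emits the corresponding aten nodes via create_kv_cache_update_nodes):

import torch

def kv_cache_update_v2(key_cache, k, start_idx, end_idx):
    # Fed straight into scaled_dot_product_attention
    concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)
    # Returned as a graph output and carried over to the next step
    new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
    return concat_keys, new_key_cache

Relative to the v1 pass, this saves the extra slice back to end_idx on the re-assembled cache.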
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/test_qwen2.5_components.py 2025-05-31 02:00:39.719818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/test_qwen2.5_components.py 2025-05-31 02:01:02.885818+00:00
@@ -14,42 +14,50 @@
import argparse
import sys
import os
# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
-sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from register_sdpa import *
ATOL = 1e-5
RTOL = 1e-5
qwen2_5_model_name = "Qwen/Qwen2.5-1.5B-Instruct"
-qwen2_5_model = AutoModelForCausalLM.from_pretrained(
- qwen2_5_model_name,
- use_cache=False,
- attn_implementation="sdpa",
- num_hidden_layers=1,
- ).eval().cuda()
+qwen2_5_model = (
+ AutoModelForCausalLM.from_pretrained(
+ qwen2_5_model_name,
+ use_cache=False,
+ attn_implementation="sdpa",
+ num_hidden_layers=1,
+ )
+ .eval()
+ .cuda()
+)
QWEN_CONFIG = qwen2_5_model.config
+
def print_diff(tensor1, tensor2, prefix=""):
"""
Print the diff between two tensors
"""
- print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}")
+ print(
+ f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}"
+ )
+
def test_qwen_apply_rotary_pos_emb(args):
class QwenApplyRotaryPosEmb(nn.Module):
def __init__(self):
super().__init__()
-
+
def rotate_half(self, x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
-
+
def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (self.rotate_half(q) * sin)
k_embed = (k * cos) + (self.rotate_half(k) * sin)
@@ -63,95 +71,104 @@
DTYPE = torch.float16
elif args.precision == "BF16":
DTYPE = torch.bfloat16
# Set precision specific flags
- use_fp32_acc = False
+ use_fp32_acc = False
use_explicit_typing = False
if args.precision == "FP16":
enabled_precisions = {torch.float32}
- use_fp32_acc = True
+ use_fp32_acc = True
use_explicit_typing = True
elif args.precision == "BF16":
enabled_precisions = {torch.bfloat16}
use_fp32_acc = False
else:
enabled_precisions = {torch.float32}
model = QwenApplyRotaryPosEmb().eval().cuda().to(DTYPE)
- # Shapes for Qwen 2.5
+ # Shapes for Qwen 2.5
q = torch.randn((1, 12, 5, 128), dtype=DTYPE).cuda()
k = torch.randn((1, 12, 5, 128), dtype=DTYPE).cuda()
cos = torch.randn((1, 5, 128), dtype=DTYPE).cuda()
sin = torch.randn((1, 5, 128), dtype=DTYPE).cuda()
pyt_output = model(q, k, cos, sin)
-
+
seq_len = torch.export.Dim("seq_len", min=2, max=2176)
dynamic_shapes = ({2: seq_len}, {2: seq_len}, {1: seq_len}, {1: seq_len})
ep = torch.export.export(model, (q, k, cos, sin), dynamic_shapes=dynamic_shapes)
- with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
- trt_model = torch_tensorrt.dynamo.compile(ep,
- inputs=[q, k, cos, sin],
- enabled_precisions=enabled_precisions,
- disable_tf32=True,
- use_fp32_acc=use_fp32_acc,
- use_explicit_typing=use_explicit_typing,
- debug=args.debug)
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+ trt_model = torch_tensorrt.dynamo.compile(
+ ep,
+ inputs=[q, k, cos, sin],
+ enabled_precisions=enabled_precisions,
+ disable_tf32=True,
+ use_fp32_acc=use_fp32_acc,
+ use_explicit_typing=use_explicit_typing,
+ debug=args.debug,
+ )
trt_output = trt_model(q, k, cos, sin)
-
+
if isinstance(pyt_output, tuple):
print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt")
# print_diff(pyt_output[1], trt_output[1], "Diff b/w pyt and trt")
assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL)
else:
print_diff(pyt_output, trt_output, "Diff b/w pyt and trt")
assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL)
-
-
+
+
def test_qwen_attention(args):
-
+
DTYPE = torch.float32
if args.precision == "FP16":
DTYPE = torch.float16
elif args.precision == "BF16":
DTYPE = torch.bfloat16
-
+
# Set precision specific flags
- use_fp32_acc = False
+ use_fp32_acc = False
use_explicit_typing = False
if args.precision == "FP16":
enabled_precisions = {torch.float32}
- use_fp32_acc = True
+ use_fp32_acc = True
use_explicit_typing = True
elif args.precision == "BF16":
enabled_precisions = {torch.bfloat16}
use_fp32_acc = False
else:
enabled_precisions = {torch.float32}
model = qwen2_5_model.model.layers[0].self_attn.to(DTYPE)
- # qwen2.5
+ # qwen2.5
hidden_states = torch.randn((1, 5, 1536), dtype=DTYPE).cuda()
- position_embeddings = (torch.randn((1, 5, 128), dtype=DTYPE).cuda(), torch.randn((1, 5, 128), dtype=DTYPE).cuda())
+ position_embeddings = (
+ torch.randn((1, 5, 128), dtype=DTYPE).cuda(),
+ torch.randn((1, 5, 128), dtype=DTYPE).cuda(),
+ )
pyt_output = model(hidden_states, position_embeddings, None)
-
+
seq_len = torch.export.Dim("seq_len", min=2, max=2176)
dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None)
- ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes)
+ ep = torch.export.export(
+ model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes
+ )
- with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
- trt_model = torch_tensorrt.dynamo.compile(ep,
- inputs=[hidden_states, position_embeddings, None],
- enabled_precisions=enabled_precisions,
- disable_tf32=True,
- use_fp32_acc=use_fp32_acc,
- use_explicit_typing=use_explicit_typing,
- debug=args.debug)
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+ trt_model = torch_tensorrt.dynamo.compile(
+ ep,
+ inputs=[hidden_states, position_embeddings, None],
+ enabled_precisions=enabled_precisions,
+ disable_tf32=True,
+ use_fp32_acc=use_fp32_acc,
+ use_explicit_typing=use_explicit_typing,
+ debug=args.debug,
+ )
trt_output = trt_model(hidden_states, position_embeddings, None)
-
+
if isinstance(pyt_output, tuple):
print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt")
assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL)
else:
print_diff(pyt_output, trt_output, "Diff b/w pyt and trt")
@@ -161,14 +178,17 @@
if __name__ == "__main__":
arg_parser = argparse.ArgumentParser(
description="Run test cases for llama attention and decoder"
)
arg_parser.add_argument(
- "--debug",
- action="store_true",
- help="Enable debug (default: False)"
+ "--debug", action="store_true", help="Enable debug (default: False)"
)
- arg_parser.add_argument("--precision", type=str, default="FP16", help="Precision to use in the model. Options: FP16, BF16, FP32")
+ arg_parser.add_argument(
+ "--precision",
+ type=str,
+ default="FP16",
+ help="Precision to use in the model. Options: FP16, BF16, FP32",
+ )
args = arg_parser.parse_args()
with torch.inference_mode():
# test_qwen_apply_rotary_pos_emb(args)
test_qwen_attention(args)
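One pattern worth noting in the test above (and repeated in the llama tests further down): when --precision FP16 is requested, the compile call still passes enabled_precisions={torch.float32} and instead enables use_fp32_acc together with use_explicit_typing — as the flag names suggest, the intent is an FP16 engine whose matmul accumulation stays in FP32 under explicit typing. Condensed from the test (the comment is an interpretation of the flags, not quoted from the file):

import torch

use_fp32_acc = False
use_explicit_typing = False
if args.precision == "FP16":
    # FP16 execution, FP32 accumulation, dtypes pinned explicitly
    enabled_precisions = {torch.float32}
    use_fp32_acc = True
    use_explicit_typing = True
elif args.precision == "BF16":
    enabled_precisions = {torch.bfloat16}
else:
    enabled_precisions = {torch.float32}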
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/utils.py 2025-05-31 02:00:39.720818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/utils.py 2025-05-31 02:01:02.902987+00:00
@@ -2,13 +2,14 @@
from transformers import StoppingCriteriaList
from transformers.generation.stopping_criteria import (
EosTokenCriteria,
MaxLengthCriteria,
)
-import numpy as np
-import copy
+import numpy as np
+import copy
import timeit
+
def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
"""
Exports the LLM model into an ExportedProgram with dynamic shapes.
In the case of guard failures due to some PyTorch kernel implements, we also
@@ -20,49 +21,60 @@
position_ids = torch.arange(inputs.shape[1]).unsqueeze(0).to(inputs.device)
try:
print("Trying to export the model using torch.export.export()..")
# strict=False only enables aotautograd tracing and excludes dynamo.
ep = torch.export.export(
- model, args=(inputs,), kwargs={"position_ids":position_ids}, dynamic_shapes=({1: seq_len}, {1: seq_len}), strict=False
+ model,
+ args=(inputs,),
+ kwargs={"position_ids": position_ids},
+ dynamic_shapes=({1: seq_len}, {1: seq_len}),
+ strict=False,
)
except:
print(
"Trying torch.export._trace._export to trace the graph since torch.export.export() failed"
)
# This API is used to express the constraint violation guards as asserts in the graph.
ep = torch.export._trace._export(
model,
args=(inputs,),
- kwargs={"position_ids":position_ids},
+ kwargs={"position_ids": position_ids},
dynamic_shapes=({1: seq_len}, {1: seq_len}),
strict=False,
allow_complex_guards_as_runtime_asserts=True,
)
return ep
+
def get_zeroed_kv_cache_inputs(model: torch.fx.GraphModule):
"""
Extracts and returns zeroed KV cache tensors from a torch.fx.GraphModule.
-
+
This function identifies placeholder nodes in the graph that represent KV cache tensors,
and creates zeroed tensors with the same shape, dtype, and device as the original placeholders.
-
+
Args:
model (torch.fx.GraphModule): The exported model graph containing KV cache placeholders
-
+
Returns:
tuple: A tuple of zeroed tensors corresponding to the KV cache placeholders in the graph
"""
# placeholder nodes are expected to be in the following order:
# input_ids, kv_cache_key, kv_cache_value, start_idx, end_idx
placeholder_nodes = [node for node in model.graph.nodes if node.op == "placeholder"]
# The first two inputs are input_ids, position_ids. The last three inputs are start_idx, end_idx and is_causal. In between are the KV cache tensors.
kv_cache_inputs = placeholder_nodes[2:-3]
zeroed_kv_cache_inputs = []
for input in kv_cache_inputs:
- zeroed_kv_cache_inputs.append(torch.zeros(input.meta["val"].shape, dtype=input.meta["val"].dtype, device=torch.device("cuda:0")))
+ zeroed_kv_cache_inputs.append(
+ torch.zeros(
+ input.meta["val"].shape,
+ dtype=input.meta["val"].dtype,
+ device=torch.device("cuda:0"),
+ )
+ )
return tuple(zeroed_kv_cache_inputs)
def generate(model, input_seq, max_output_seq_length, eos_token_id, benchmark=True):
@@ -75,11 +87,11 @@
EosTokenCriteria(eos_token_id=eos_token_id),
]
)
isl = input_seq.shape[1]
osl = max_output_seq_length - isl
-
+
num_tokens_generated = 0
while num_tokens_generated < osl:
position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda()
outputs = model(input_seq, position_ids)
logits = outputs.logits
@@ -91,10 +103,11 @@
if not benchmark and stopping_criteria(input_seq, logits).item():
break
return input_seq
+
def generate_with_kv_cache(model, input_seq, max_output_seq_length, eos_token_id):
"""
Greedy decoding of the model with KV cache.
"""
start_idx = 0
@@ -105,38 +118,48 @@
logits_concat = []
num_tokens_generated = 0
kv_cache = get_zeroed_kv_cache_inputs(model)
while end_idx < max_output_seq_length:
is_causal = True if input_seq.shape[1] > 1 else False
- position_ids = torch.tensor([[start_idx]], dtype=torch.int64).cuda() if input_seq.shape[1] == 1 else position_ids
- input_signature = (input_seq, position_ids, *kv_cache, start_idx, end_idx, is_causal)
+ position_ids = (
+ torch.tensor([[start_idx]], dtype=torch.int64).cuda()
+ if input_seq.shape[1] == 1
+ else position_ids
+ )
+ input_signature = (
+ input_seq,
+ position_ids,
+ *kv_cache,
+ start_idx,
+ end_idx,
+ is_causal,
+ )
logits_keys_values = model(*input_signature)
num_tokens_generated += 1
logits = logits_keys_values[0]
logits_concat.append(logits)
kv_cache = logits_keys_values[1:]
next_token_logits = logits[:, -1, :]
next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
output_seq = torch.cat([output_seq, next_tokens], dim=-1)
input_seq = next_tokens
start_idx = end_idx
- end_idx = start_idx + 1
+ end_idx = start_idx + 1
lkv = torch.cat(logits_concat, dim=1)
return output_seq
+
def time_generate(
generate_fn, model, inputs, output_seq_length, eos_token_id, iterations=10
):
"""
Measure the time for generating a sentence over certain number of iterations
"""
timings = []
for _ in range(iterations):
start_time = timeit.default_timer()
- _ = generate_fn(
- model, inputs, output_seq_length, eos_token_id
- )
+ _ = generate_fn(model, inputs, output_seq_length, eos_token_id)
torch.cuda.synchronize()
end_time = timeit.default_timer()
timings.append(end_time - start_time)
return timings
@@ -164,6 +187,6 @@
"Median-Latency(ms)": time_med * 1000,
"Mean-Latency(ms)": time_mean * 1000,
"Latency-StdDev(ms)": time_std * 1000,
"Compile Time(s)": compile_time_s,
}
- return stats
\ No newline at end of file
+ return stats
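Taken together, the comments in utils.py fix the calling convention that the cache passes impose on the compiled module: placeholders are ordered (input_ids, position_ids, k_cache_0, v_cache_0, ..., start_idx, end_idx, is_causal), and the outputs are (logits, new_k_cache_0, new_v_cache_0, ...). A minimal single-step sketch under that assumption — model, input_seq and get_zeroed_kv_cache_inputs come from the code above, while the index initialization is assumed:

import torch

kv_cache = get_zeroed_kv_cache_inputs(model)   # zero tensors matching the cache placeholders
start_idx, end_idx = 0, input_seq.shape[1]     # prefill spans the whole prompt (assumed init)
position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda()
is_causal = input_seq.shape[1] > 1

outputs = model(input_seq, position_ids, *kv_cache, start_idx, end_idx, is_causal)
logits, kv_cache = outputs[0], outputs[1:]     # updated caches are fed back on the next step

next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
start_idx, end_idx = end_idx, end_idx + 1      # later steps advance one token at a time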
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/test_llama_components.py 2025-05-31 02:00:39.719818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/test_llama_components.py 2025-05-31 02:01:03.012288+00:00
@@ -14,253 +14,359 @@
import argparse
import sys
import os
# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
-sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from register_sdpa import *
+
ATOL = 1e-5
RTOL = 1e-5
# llama2_model_name = "meta-llama/Llama-2-7b-hf"
llama3_model_name = "meta-llama/Llama-3.2-1B-Instruct"
-llama_model = AutoModelForCausalLM.from_pretrained(
- llama3_model_name,
- use_cache=False,
- attn_implementation="sdpa",
- num_hidden_layers=1,
- ).eval().cuda()
+llama_model = (
+ AutoModelForCausalLM.from_pretrained(
+ llama3_model_name,
+ use_cache=False,
+ attn_implementation="sdpa",
+ num_hidden_layers=1,
+ )
+ .eval()
+ .cuda()
+)
LLAMA_CONFIG = llama_model.config
+
def test_llama_attention(args):
class LlamaAttentionBlock(nn.Module):
def __init__(self):
super().__init__()
self.config = LLAMA_CONFIG
- self.attn = LlamaAttention(
- config=self.config,
- layer_idx=0
+ self.attn = LlamaAttention(config=self.config, layer_idx=0)
+
+ def forward(self, hidden_states, position_embeddings):
+ attn_output, attn_weights = self.attn(
+ hidden_states, position_embeddings, None
)
- def forward(self, hidden_states, position_embeddings):
- attn_output, attn_weights = self.attn(hidden_states, position_embeddings, None)
return attn_output
-
+
DTYPE = torch.float32
if args.precision == "FP16":
DTYPE = torch.float16
elif args.precision == "BF16":
DTYPE = torch.bfloat16
-
+
# Set precision specific flags
- use_fp32_acc = False
+ use_fp32_acc = False
use_explicit_typing = False
if args.precision == "FP16":
enabled_precisions = {torch.float32}
- use_fp32_acc = True
+ use_fp32_acc = True
use_explicit_typing = True
elif args.precision == "BF16":
enabled_precisions = {torch.bfloat16}
use_fp32_acc = False
else:
enabled_precisions = {torch.float32}
# model = LlamaAttentionBlock().eval().cuda().to(DTYPE)
model = llama_model.model.layers[0].self_attn.to(DTYPE)
- # llama3
+ # llama3
hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda()
- position_embeddings = (torch.randn((1, 6, 64), dtype=DTYPE).cuda(), torch.randn((1, 6, 64), dtype=DTYPE).cuda())
+ position_embeddings = (
+ torch.randn((1, 6, 64), dtype=DTYPE).cuda(),
+ torch.randn((1, 6, 64), dtype=DTYPE).cuda(),
+ )
pyt_output = model(hidden_states, position_embeddings, None)
-
+
seq_len = torch.export.Dim("seq_len", min=2, max=2176)
dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None)
- ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes)
-
- with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
- trt_model = torch_tensorrt.dynamo.compile(ep,
- inputs=[hidden_states, position_embeddings, None],
- enabled_precisions=enabled_precisions,
- disable_tf32=True,
- use_fp32_acc=use_fp32_acc,
- use_explicit_typing=use_explicit_typing,
- debug=args.debug)
+ ep = torch.export.export(
+ model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes
+ )
+
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+ trt_model = torch_tensorrt.dynamo.compile(
+ ep,
+ inputs=[hidden_states, position_embeddings, None],
+ enabled_precisions=enabled_precisions,
+ disable_tf32=True,
+ use_fp32_acc=use_fp32_acc,
+ use_explicit_typing=use_explicit_typing,
+ debug=args.debug,
+ )
trt_output = trt_model(hidden_states, position_embeddings, None)
if isinstance(pyt_output, tuple):
- print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}")
+ print(
+ f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}"
+ )
breakpoint()
assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL)
else:
print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output - trt_output))}")
assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL)
+
def print_diff(tensor1, tensor2, prefix=""):
"""
Print the diff between two tensors
"""
- print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}")
+ print(
+ f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}"
+ )
+
def test_llama_attention_with_static_cache(args):
class LlamaAttentionBlock(nn.Module):
def __init__(self):
super().__init__()
self.config = LLAMA_CONFIG
- self.attn = LlamaAttention(
- config=self.config,
- layer_idx=0
+ self.attn = LlamaAttention(config=self.config, layer_idx=0)
+
+ def forward(self, hidden_states, position_embeddings):
+ attn_output, attn_weights = self.attn(
+ hidden_states, position_embeddings, None
)
- def forward(self, hidden_states, position_embeddings):
- attn_output, attn_weights = self.attn(hidden_states, position_embeddings, None)
return attn_output
-
+
DTYPE = torch.float32
model = llama_model.model.layers[0].self_attn.to(DTYPE)
- # Inputs
+ # Inputs
ISL = 2048
NUM_TOKENS = 128
OSL = ISL + NUM_TOKENS
hidden_states = torch.randn((1, ISL, 2048), dtype=DTYPE).cuda()
- position_embeddings = (torch.randn((1, ISL, 64), dtype=DTYPE).cuda(), torch.randn((1, ISL, 64), dtype=DTYPE).cuda())
+ position_embeddings = (
+ torch.randn((1, ISL, 64), dtype=DTYPE).cuda(),
+ torch.randn((1, ISL, 64), dtype=DTYPE).cuda(),
+ )
key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE)
value_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE)
start_idx = 0
end_idx = ISL
is_causal = True
pyt_output = model(hidden_states, position_embeddings, None)
seq_len = torch.export.Dim("seq_len", min=2, max=2176)
dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None)
- ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes)
+ ep = torch.export.export(
+ model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes
+ )
import register_sdpa
import static_cache2
- with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
- trt_model = torch_tensorrt.dynamo.compile(ep,
- inputs=[hidden_states, position_embeddings, None, key_cache, value_cache, start_idx, end_idx, is_causal],
- enabled_precisions={torch.float32},
- disable_tf32=True,
- debug=args.debug,
- # offload_module_to_cpu=True,
- use_python_runtime=True)
-
+
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+ trt_model = torch_tensorrt.dynamo.compile(
+ ep,
+ inputs=[
+ hidden_states,
+ position_embeddings,
+ None,
+ key_cache,
+ value_cache,
+ start_idx,
+ end_idx,
+ is_causal,
+ ],
+ enabled_precisions={torch.float32},
+ disable_tf32=True,
+ debug=args.debug,
+ # offload_module_to_cpu=True,
+ use_python_runtime=True,
+ )
+
# Test Prefill
- trt_output, _, key_cache, value_cache = trt_model(hidden_states, position_embeddings, None, key_cache, value_cache, start_idx, end_idx, is_causal)
+ trt_output, _, key_cache, value_cache = trt_model(
+ hidden_states,
+ position_embeddings,
+ None,
+ key_cache,
+ value_cache,
+ start_idx,
+ end_idx,
+ is_causal,
+ )
print_diff(pyt_output[0], trt_output[0], "pyt_output[0] vs trt_output[0] [Prefill]")
# Test Generate
for start_idx in range(2048, 2176):
end_idx = start_idx + 1
hidden_states_curr = torch.randn((1, 1, 2048), dtype=DTYPE).cuda()
- position_embeddings_curr = (torch.randn((1, 1, 64), dtype=DTYPE).cuda(), torch.randn((1, 1, 64), dtype=DTYPE).cuda())
+ position_embeddings_curr = (
+ torch.randn((1, 1, 64), dtype=DTYPE).cuda(),
+ torch.randn((1, 1, 64), dtype=DTYPE).cuda(),
+ )
# Concatenate the current hidden_states with the previous ones
hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1)
- position_embeddings_full = (torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1))
-
+ position_embeddings_full = (
+ torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1),
+ torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1),
+ )
+
is_causal = False
out_no_cache, _ = model(hidden_states_full, position_embeddings_full, None)
- out_trt, _, key_cache, value_cache = trt_model(hidden_states_curr, position_embeddings_curr, None, key_cache, value_cache, start_idx, end_idx, is_causal)
+ out_trt, _, key_cache, value_cache = trt_model(
+ hidden_states_curr,
+ position_embeddings_curr,
+ None,
+ key_cache,
+ value_cache,
+ start_idx,
+ end_idx,
+ is_causal,
+ )
out_pyt = out_no_cache[:, -1:, :]
print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}")
hidden_states = hidden_states_full
position_embeddings = position_embeddings_full
def test_llama_decoder(args):
-
+
DTYPE = torch.float32
model = llama_model.model.layers[0].to(DTYPE)
- # llama3
+ # llama3
hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda()
- position_embeddings = (torch.randn((1, 6, 64), dtype=DTYPE).cuda(), torch.randn((1, 6, 64), dtype=DTYPE).cuda())
+ position_embeddings = (
+ torch.randn((1, 6, 64), dtype=DTYPE).cuda(),
+ torch.randn((1, 6, 64), dtype=DTYPE).cuda(),
+ )
pyt_output = model(hidden_states, position_embeddings)
seq_len = torch.export.Dim("seq_len", min=2, max=2176)
dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}))
- ep = torch.export.export(model, (hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes)
-
- with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
- trt_model = torch_tensorrt.dynamo.compile(ep,
- inputs=[hidden_states, position_embeddings],
- enabled_precisions={torch.float32},
- debug=args.debug)
+ ep = torch.export.export(
+ model, (hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes
+ )
+
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+ trt_model = torch_tensorrt.dynamo.compile(
+ ep,
+ inputs=[hidden_states, position_embeddings],
+ enabled_precisions={torch.float32},
+ debug=args.debug,
+ )
trt_output = trt_model(hidden_states, position_embeddings)
print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output - trt_output))}")
assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL)
+
def test_llama_decoder_with_static_cache(args):
class LlamaDecoderLayerBlock(nn.Module):
def __init__(self, model):
super().__init__()
self.config = LLAMA_CONFIG
- self.decoder = LlamaDecoderLayer(
- config=self.config,
- layer_idx=0)
+ self.decoder = LlamaDecoderLayer(config=self.config, layer_idx=0)
self.model = model
+
def forward(self, hidden_states, position_embeddings):
return self.model(hidden_states, position_embeddings=position_embeddings)
DTYPE = torch.float32
model = LlamaDecoderLayerBlock(llama_model.model.layers[0].to(DTYPE))
-
- # Inputs
+
+ # Inputs
ISL = 2048
NUM_TOKENS = 128
OSL = ISL + NUM_TOKENS
hidden_states = torch.randn((1, ISL, 2048), dtype=DTYPE).cuda()
- position_embeddings = (torch.randn((1, ISL, 64), dtype=DTYPE).cuda(), torch.randn((1, ISL, 64), dtype=DTYPE).cuda())
+ position_embeddings = (
+ torch.randn((1, ISL, 64), dtype=DTYPE).cuda(),
+ torch.randn((1, ISL, 64), dtype=DTYPE).cuda(),
+ )
key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE)
value_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE)
start_idx = 0
end_idx = ISL
is_causal = True
pyt_output = model(hidden_states, position_embeddings)
seq_len = torch.export.Dim("seq_len", min=2, max=2176)
dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}))
- ep = torch.export.export(model, args=(hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes)
+ ep = torch.export.export(
+ model, args=(hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes
+ )
import register_sdpa
import static_cache2
- with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
- trt_model = torch_tensorrt.dynamo.compile(ep,
- arg_inputs=[hidden_states, position_embeddings, key_cache, value_cache, start_idx, end_idx, is_causal],
- enabled_precisions={torch.float32},
- disable_tf32=True,
- debug=args.debug,
- # offload_module_to_cpu=True,
- use_python_runtime=True)
-
+
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+ trt_model = torch_tensorrt.dynamo.compile(
+ ep,
+ arg_inputs=[
+ hidden_states,
+ position_embeddings,
+ key_cache,
+ value_cache,
+ start_idx,
+ end_idx,
+ is_causal,
+ ],
+ enabled_precisions={torch.float32},
+ disable_tf32=True,
+ debug=args.debug,
+ # offload_module_to_cpu=True,
+ use_python_runtime=True,
+ )
+
# Test Prefill
- trt_output, key_cache, value_cache = trt_model(hidden_states, position_embeddings, key_cache, value_cache, start_idx, end_idx, is_causal)
+ trt_output, key_cache, value_cache = trt_model(
+ hidden_states,
+ position_embeddings,
+ key_cache,
+ value_cache,
+ start_idx,
+ end_idx,
+ is_causal,
+ )
print_diff(pyt_output[0], trt_output, "pyt_output vs trt_output [Prefill]")
# Test Generate
for start_idx in range(2048, 2176):
end_idx = start_idx + 1
hidden_states_curr = torch.randn((1, 1, 2048), dtype=DTYPE).cuda()
- position_embeddings_curr = (torch.randn((1, 1, 64), dtype=DTYPE).cuda(), torch.randn((1, 1, 64), dtype=DTYPE).cuda())
+ position_embeddings_curr = (
+ torch.randn((1, 1, 64), dtype=DTYPE).cuda(),
+ torch.randn((1, 1, 64), dtype=DTYPE).cuda(),
+ )
# Concatenate the current hidden_states with the previous ones
hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1)
- position_embeddings_full = (torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1))
-
+ position_embeddings_full = (
+ torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1),
+ torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1),
+ )
+
is_causal = False
out_no_cache = model(hidden_states_full, position_embeddings_full)
- out_trt, key_cache, value_cache = trt_model(hidden_states_curr, position_embeddings_curr, key_cache, value_cache, start_idx, end_idx, is_causal)
+ out_trt, key_cache, value_cache = trt_model(
+ hidden_states_curr,
+ position_embeddings_curr,
+ key_cache,
+ value_cache,
+ start_idx,
+ end_idx,
+ is_causal,
+ )
out_pyt = out_no_cache[0][:, -1:, :]
print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}")
hidden_states = hidden_states_full
position_embeddings = position_embeddings_full
+
def test_llama_model_with_static_cache(args):
DTYPE = torch.float32
model = llama_model.model.to(DTYPE)
- # Inputs
+ # Inputs
ISL = 2048
NUM_TOKENS = 128
OSL = ISL + NUM_TOKENS
input_ids = torch.randint(1, 20, (1, ISL), dtype=torch.int64).cuda()
position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).cuda()
@@ -271,66 +377,77 @@
is_causal = True
pyt_output = model(input_ids)
seq_len = torch.export.Dim("seq_len", min=2, max=2176)
dynamic_shapes = ({1: seq_len}, {1: seq_len})
- kwarg_inputs = {"input_ids":input_ids, "position_ids":position_ids}
- ep = torch.export.export(model, args=(), kwargs=kwarg_inputs, dynamic_shapes=dynamic_shapes)
+ kwarg_inputs = {"input_ids": input_ids, "position_ids": position_ids}
+ ep = torch.export.export(
+ model, args=(), kwargs=kwarg_inputs, dynamic_shapes=dynamic_shapes
+ )
import register_sdpa
import static_cache2
- with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
- trt_model = torch_tensorrt.dynamo.compile(ep,
- arg_inputs=[],
- kwarg_inputs=kwarg_inputs,
- enabled_precisions={torch.float32},
- disable_tf32=True,
- debug=args.debug,
- # offload_module_to_cpu=True,
- use_python_runtime=True)
-
+
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+ trt_model = torch_tensorrt.dynamo.compile(
+ ep,
+ arg_inputs=[],
+ kwarg_inputs=kwarg_inputs,
+ enabled_precisions={torch.float32},
+ disable_tf32=True,
+ debug=args.debug,
+ # offload_module_to_cpu=True,
+ use_python_runtime=True,
+ )
+
# Test Prefill
- trt_output, key_cache, value_cache = trt_model(input_ids, position_ids, key_cache, value_cache, start_idx, end_idx, is_causal)
+ trt_output, key_cache, value_cache = trt_model(
+ input_ids, position_ids, key_cache, value_cache, start_idx, end_idx, is_causal
+ )
pyt_output = pyt_output.last_hidden_state
print_diff(pyt_output, trt_output, "pyt_output vs trt_output [Prefill]")
# Test Generate
for start_idx in range(2048, 2176):
end_idx = start_idx + 1
input_ids_curr = torch.randint(1, 20, (1, 1), dtype=torch.int64).cuda()
position_ids_curr = torch.tensor([[start_idx]], dtype=torch.int64).cuda()
-
+
# Concatenate the current hidden_states with the previous ones
input_ids_full = torch.cat((input_ids, input_ids_curr), dim=1)
position_ids_full = torch.cat((position_ids, position_ids_curr), dim=1)
is_causal = False
- kwarg_inputs = {"input_ids":input_ids_full, "position_ids":position_ids_full}
+ kwarg_inputs = {"input_ids": input_ids_full, "position_ids": position_ids_full}
out_no_cache = model(**kwarg_inputs)
- out_trt, key_cache, value_cache = trt_model(input_ids_curr, position_ids_curr, key_cache, value_cache, start_idx, end_idx, is_causal)
+ out_trt, key_cache, value_cache = trt_model(
+ input_ids_curr,
+ position_ids_curr,
+ key_cache,
+ value_cache,
+ start_idx,
+ end_idx,
+ is_causal,
+ )
out_pyt = out_no_cache.last_hidden_state[:, -1:, :]
print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}")
input_ids = input_ids_full
position_ids = position_ids_full
+
if __name__ == "__main__":
arg_parser = argparse.ArgumentParser(
description="Run test cases for llama attention and decoder"
)
arg_parser.add_argument(
- "--debug",
- action="store_true",
- help="Enable debug (default: False)"
+ "--debug", action="store_true", help="Enable debug (default: False)"
)
arg_parser.add_argument(
- "--precision",
- type=str,
- default="FP16",
- help="Precision (default: FP16)"
+ "--precision", type=str, default="FP16", help="Precision (default: FP16)"
)
args = arg_parser.parse_args()
with torch.inference_mode():
test_llama_attention(args)
# test_llama_decoder(args)
# test_llama_attention_with_static_cache(args)
# test_llama_decoder_with_static_cache(args)
- # test_llama_model_with_static_cache(args)
\ No newline at end of file
+ # test_llama_model_with_static_cache(args)
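The static-cache tests above all follow the same two-phase pattern: a prefill call that fills the cache for the whole prompt with is_causal=True, then single-token generate calls with is_causal=False and the cache indices advancing by one, each compared against a no-cache PyTorch run over the full sequence. Schematically (argument and return lists vary per test; x and x_curr stand for whichever inputs the module under test takes):

# Prefill: cache positions [0, ISL) are written in one shot
out, key_cache, value_cache = trt_model(*x, key_cache, value_cache, 0, ISL, True)

# Generate: one new token per call
for start_idx in range(ISL, ISL + NUM_TOKENS):
    end_idx = start_idx + 1
    out, key_cache, value_cache = trt_model(
        *x_curr, key_cache, value_cache, start_idx, end_idx, False
    )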
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/register_sdpa.py 2025-05-31 02:00:39.720818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/register_sdpa.py 2025-05-31 02:01:03.019858+00:00
@@ -72,13 +72,11 @@
if len(node.args) == 6:
query, key, value, dropout_p, is_causal, return_debug_mask = (
node.args
)
if len(node.args) == 5:
- query, key, value, dropout_p, is_causal = (
- node.args
- )
+ query, key, value, dropout_p, is_causal = node.args
elif len(node.args) == 3:
query, key, value = node.args
dropout_p = 0.0
is_causal = True
else:
@@ -95,11 +93,14 @@
# The input args is (query, key, value, is_causal). kwargs has scale
with gm.graph.inserting_after(node):
new_node = gm.graph.call_function(
torch.nn.functional.scaled_dot_product_attention,
args=modified_input_args,
- kwargs={"scale": node.kwargs.get("scale", None), "use_fp32_acc": settings.use_fp32_acc},
+ kwargs={
+ "scale": node.kwargs.get("scale", None),
+ "use_fp32_acc": settings.use_fp32_acc,
+ },
)
# Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead.
new_node.meta = copy.copy(node.meta)
# Check if there's a getitem node following this attention node
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/sdpa_converter.py 2025-05-31 02:00:39.720818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/sdpa_converter.py 2025-05-31 02:01:03.072635+00:00
@@ -74,12 +74,12 @@
# query = cast_trt_tensor(
# ctx, query, trt.float32, name + "_query_cast_to_fp32", target, source_ir
# )
# key = cast_trt_tensor(
# ctx, key, trt.float32, name + "_key_cast_to_fp32", target, source_ir
- # )
-
+ # )
+
if scale is None:
scale = query.shape[-1]
if scale < 0:
# dynamic shape
scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1)
@@ -112,11 +112,11 @@
name + "_mm",
query,
key,
other_matrix_op=trt.MatrixOperation.TRANSPOSE,
)
-
+
# if use_fp32_acc:
# mm = cast_trt_tensor(
# ctx, mm, query_dtype, name + "_mm_cast_to_fp16", target, source_ir
# )
@@ -162,11 +162,11 @@
scaled_add_attn_bias = impl.elementwise.add(
ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias
)
else:
scaled_add_attn_bias = mm
-
+
# Create a if condition to check if is_causal is True
if isinstance(is_causal, TRTTensor):
if_layer = ctx.net.add_if_conditional()
condition, true_branch, false_branch = is_causal, scaled_add_attn_bias, mm
if_layer.set_condition(condition)
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/test_static_cache.py 2025-05-31 02:00:39.719818+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/llm/test_static_cache.py 2025-05-31 02:01:03.107457+00:00
@@ -17,61 +17,83 @@
ATOL = 1e-5
RTOL = 1e-5
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
+
class DynamicCacheModel(nn.Module):
def __init__(self):
super().__init__()
-
+
def forward(self, q, k, v, k1, v1, flag):
- def true_fn(q, k, v, k1, v1):
+ def true_fn(q, k, v, k1, v1):
k_new = torch.cat((k, k1), dim=2)
v_new = torch.cat((v, v1), dim=2)
return torch._C._nn.scaled_dot_product_attention(q, k_new, v_new)
def false_fn(q, k, v, k1, v1):
return torch._C._nn.scaled_dot_product_attention(q, k, v)
out = torch.cond(flag, true_fn, false_fn, (q, k, v, k1, v1))
return 2 * out
-
+
+
class ModelNoCache(nn.Module):
def __init__(self):
super().__init__()
-
+
def forward(self, q, k, v):
- return torch._C._nn.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
+ return torch._C._nn.scaled_dot_product_attention(
+ q, k, v, dropout_p=0.0, is_causal=True
+ )
+
class StaticCacheModel(nn.Module):
def __init__(self):
super().__init__()
-
- # def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True):
+
+ # def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True):
# new_key_cache = torch.cat((key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2)
# new_value_cache = torch.cat((value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2)
# out = torch._C._nn.scaled_dot_product_attention(q, new_key_cache[:, :, :end_idx, :], new_value_cache[:, :, :end_idx, :], dropout_p=0.0, is_causal=is_causal)
-
+
# return out, new_key_cache, new_value_cache
-
- def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True):
- concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2) # key_cache[:, :, :6, :] + curr_keys + key_cache[:, : 7: ,: ]
+
+ def forward(
+ self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True
+ ):
+ concat_keys = torch.cat(
+ (key_cache[:, :, :start_idx, :], k), dim=2
+ ) # key_cache[:, :, :6, :] + curr_keys + key_cache[:, : 7: ,: ]
concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)
new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
- new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2)
- out = torch._C._nn.scaled_dot_product_attention(q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal)
-
+ new_value_cache = torch.cat(
+ (concat_values, value_cache[:, :, end_idx:, :]), dim=2
+ )
+ out = torch._C._nn.scaled_dot_product_attention(
+ q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal
+ )
+
return out, new_key_cache, new_value_cache
-def eager_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
- is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
+def eager_sdpa(
+ query,
+ key,
+ value,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=False,
+ scale=None,
+ enable_gqa=False,
+) -> torch.Tensor:
"""
Eager implementation of SDPA
"""
import math
+
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
breakpoint()
if is_causal:
@@ -85,24 +107,28 @@
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias = attn_mask + attn_bias
if enable_gqa:
- key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
- value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
+ key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+ value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value
+
def print_diff(tensor1, tensor2, prefix=""):
"""
Print the diff between two tensors
"""
- print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}")
+ print(
+ f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}"
+ )
+
def test_static_cache_model(args):
"""
Test the static cache model
"""
@@ -117,13 +143,15 @@
# Test Prefill
start_idx = 0
end_idx = 2048
out_no_cache = model_no_cache(q, k, v)
- out_static_cache, new_key_cache, new_value_cache = model_static_cache(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True)
+ out_static_cache, new_key_cache, new_value_cache = model_static_cache(
+ q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True
+ )
assert torch.allclose(out_no_cache, out_static_cache, atol=ATOL, rtol=RTOL)
-
+
# Test Generate
for start_idx in range(2048, 2176):
end_idx = start_idx + 1
q_curr = torch.randn(1, 32, 1, 64).cuda()
k_curr = torch.randn(1, 32, 1, 64).cuda()
@@ -133,17 +161,29 @@
q_full = torch.cat((q, q_curr), dim=2)
k_full = torch.cat((k, k_curr), dim=2)
v_full = torch.cat((v, v_curr), dim=2)
out_no_cache = model_no_cache(q_full, k_full, v_full)
- out_static_cache, new_key_cache, new_value_cache = model_static_cache(q_curr, k_curr, v_curr, new_key_cache, new_value_cache, start_idx, end_idx, is_causal=False)
-
- assert torch.allclose(out_no_cache[:, :, -1:, :], out_static_cache, atol=ATOL, rtol=RTOL)
- q = q_full
+ out_static_cache, new_key_cache, new_value_cache = model_static_cache(
+ q_curr,
+ k_curr,
+ v_curr,
+ new_key_cache,
+ new_value_cache,
+ start_idx,
+ end_idx,
+ is_causal=False,
+ )
+
+ assert torch.allclose(
+ out_no_cache[:, :, -1:, :], out_static_cache, atol=ATOL, rtol=RTOL
+ )
+ q = q_full
k = k_full
v = v_full
print("============== test_static_cache passed ==============")
+
def transform_gm_with_kv_cache(exported_program: torch.export.ExportedProgram, args):
"""
Transform the graph module by adding key and value cache to the graph
"""
@@ -155,52 +195,53 @@
use_python_runtime=True,
debug=args.debug,
min_block_size=1,
)
exported_program = pre_export_lowering(exported_program, settings)
- exported_program = exported_program.run_decompositions(
- get_decompositions(False)
- )
+ exported_program = exported_program.run_decompositions(get_decompositions(False))
gm = exported_program.module()
gm = post_lowering(gm, settings)
return gm
+
def test_static_cache_lowering(args):
"""
- Test static cache lowering pass applied to the model with no cache and run the graph module
+ Test static cache lowering pass applied to the model with no cache and run the graph module
and compare the output with the model with no cache
"""
import static_cache2
model_no_cache = ModelNoCache().eval().cuda()
q = torch.randn(1, 32, 2, 64).cuda()
k = torch.randn(1, 32, 2048, 64).cuda()
v = torch.randn(1, 32, 2048, 64).cuda()
key_cache = torch.zeros(1, 32, 2176, 64).cuda()
value_cache = torch.zeros(1, 32, 2176, 64).cuda()
-
+
# Export the model
q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176)
kv_seq_len = torch.export.Dim("kv_seq_len", min=2, max=2176)
exported_program = export(
model_no_cache,
args=(q, k, v),
- dynamic_shapes=({2 : q_seq_len}, {2 : kv_seq_len}, {2 : kv_seq_len}),
- strict=False
+ dynamic_shapes=({2: q_seq_len}, {2: kv_seq_len}, {2: kv_seq_len}),
+ strict=False,
)
gm = transform_gm_with_kv_cache(exported_program, args)
# Test Prefill
start_idx = 0
end_idx = 2048
is_causal = True
q = torch.randn(1, 32, 2048, 64).cuda()
out_no_cache = model_no_cache(q, k, v)
- out_pyt_cache, key_cache, value_cache = gm(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal)
+ out_pyt_cache, key_cache, value_cache = gm(
+ q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal
+ )
assert torch.allclose(out_no_cache, out_pyt_cache, atol=ATOL, rtol=RTOL)
# Test Generate
for start_idx in range(2048, 2176):
end_idx = start_idx + 1
@@ -209,20 +250,32 @@
k_curr = torch.randn(1, 32, 1, 64).cuda()
v_curr = torch.randn(1, 32, 1, 64).cuda()
# Concatenate the current query, key, and value with the previous ones
q_full = torch.cat((q, q_curr), dim=2)
k_full = torch.cat((k, k_curr), dim=2)
- v_full = torch.cat((v, v_curr), dim=2)
-
+ v_full = torch.cat((v, v_curr), dim=2)
+
out_no_cache = model_no_cache(q_full, k_full, v_full)
- out_pyt_static_cache, key_cache, value_cache = gm(q_curr, k_curr, v_curr, key_cache, value_cache, start_idx, end_idx, is_causal)
- assert torch.allclose(out_no_cache[:, :, -1:, :], out_pyt_static_cache, atol=ATOL, rtol=RTOL)
- q = q_full
+ out_pyt_static_cache, key_cache, value_cache = gm(
+ q_curr,
+ k_curr,
+ v_curr,
+ key_cache,
+ value_cache,
+ start_idx,
+ end_idx,
+ is_causal,
+ )
+ assert torch.allclose(
+ out_no_cache[:, :, -1:, :], out_pyt_static_cache, atol=ATOL, rtol=RTOL
+ )
+ q = q_full
k = k_full
v = v_full
-
+
print("============== test_static_cache_lowering passed ==============")
+
def test_static_cache_export(args):
"""
Test the static cache model export
"""
@@ -236,19 +289,28 @@
start_idx = 0
end_idx = 2048
is_causal = True
# Export the model
seq_len = torch.export.Dim("seq_len", min=2, max=2048)
- seq_len_dyn_dim = {2 : seq_len}
+ seq_len_dyn_dim = {2: seq_len}
exported_program = export(
model_static_cache,
args=(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal),
- dynamic_shapes=(seq_len_dyn_dim, seq_len_dyn_dim, seq_len_dyn_dim, None, None, torch.export.Dim.DYNAMIC, torch.export.Dim.DYNAMIC, None),
- strict=False
- )
-
-
+ dynamic_shapes=(
+ seq_len_dyn_dim,
+ seq_len_dyn_dim,
+ seq_len_dyn_dim,
+ None,
+ None,
+ torch.export.Dim.DYNAMIC,
+ torch.export.Dim.DYNAMIC,
+ None,
+ ),
+ strict=False,
+ )
+
+
def test_static_cache_with_torch_tensorrt(args):
"""
Test the static cache model with torch_tensorrt
"""
import static_cache2
@@ -257,83 +319,104 @@
q = torch.randn(1, 32, 2, 64).cuda()
k = torch.randn(1, 32, 2048, 64).cuda()
v = torch.randn(1, 32, 2048, 64).cuda()
key_cache = torch.zeros(1, 32, 2176, 64).cuda()
value_cache = torch.zeros(1, 32, 2176, 64).cuda()
-
+
# Export the model
q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176)
kv_seq_len = torch.export.Dim("kv_seq_len", min=2, max=2176)
exported_program = export(
model_no_cache,
args=(q, k, v),
- dynamic_shapes=({2 : q_seq_len}, {2 : kv_seq_len}, {2 : kv_seq_len}),
- strict=False
- )
- with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
+ dynamic_shapes=({2: q_seq_len}, {2: kv_seq_len}, {2: kv_seq_len}),
+ strict=False,
+ )
+ with torch_tensorrt.logging.debug() if args.debug else nullcontext():
trt_model = torch_tensorrt.dynamo.compile(
exported_program,
inputs=[q, k, v],
enabled_precisions={torch.float32},
disable_tf32=True,
use_python_runtime=True,
debug=args.debug,
min_block_size=1,
)
-
+
start_idx = 0
end_idx = 2048
is_causal = True
q = torch.randn(1, 32, 2048, 64).cuda()
# out_eager = eager_sdpa(q, k, v, is_causal=is_causal)
out_no_cache = model_no_cache(q, k, v)
- out_trt, trt_key_cache, trt_value_cache = trt_model(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal)
-
- assert torch.allclose(out_no_cache, out_trt, atol=ATOL, rtol=RTOL), "Prefill TRT logits don't match"
- assert torch.allclose(trt_key_cache[:, :, :end_idx, :], k, atol=ATOL, rtol=RTOL), "Prefill TRT key cache don't match"
- assert torch.allclose(trt_value_cache[:, :, :end_idx, :], v, atol=ATOL, rtol=RTOL), "Prefill TRT value cache don't match"
-
+ out_trt, trt_key_cache, trt_value_cache = trt_model(
+ q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal
+ )
+
+ assert torch.allclose(
+ out_no_cache, out_trt, atol=ATOL, rtol=RTOL
+ ), "Prefill TRT logits don't match"
+ assert torch.allclose(
+ trt_key_cache[:, :, :end_idx, :], k, atol=ATOL, rtol=RTOL
+ ), "Prefill TRT key cache don't match"
+ assert torch.allclose(
+ trt_value_cache[:, :, :end_idx, :], v, atol=ATOL, rtol=RTOL
+ ), "Prefill TRT value cache don't match"
+
# Test Generate
for start_idx in range(2048, 2176):
end_idx = start_idx + 1
q_curr = torch.randn(1, 32, 1, 64).cuda()
k_curr = torch.randn(1, 32, 1, 64).cuda()
- v_curr = torch.randn(1, 32, 1, 64).cuda()
+ v_curr = torch.randn(1, 32, 1, 64).cuda()
# Concatenate the current query, key, and value with the previous ones
q_full = torch.cat((q, q_curr), dim=2)
k_full = torch.cat((k, k_curr), dim=2)
- v_full = torch.cat((v, v_curr), dim=2)
+ v_full = torch.cat((v, v_curr), dim=2)
is_causal = False
out_no_cache = model_no_cache(q_full, k_full, v_full)
- out_trt, trt_key_cache, trt_value_cache = trt_model(q_curr, k_curr, v_curr, trt_key_cache, trt_value_cache, start_idx, end_idx, is_causal)
+ out_trt, trt_key_cache, trt_value_cache = trt_model(
+ q_curr,
+ k_curr,
+ v_curr,
+ trt_key_cache,
+ trt_value_cache,
+ start_idx,
+ end_idx,
+ is_causal,
+ )
# breakpoint()
# print_diff(out_no_cache[:, :, -1:, :], out_trt, f"out_no_cache[:, :, -1:, :] vs out_trt for idx {start_idx}")
# print_diff(trt_key_cache[:, :, :end_idx, :], k_full, f"trt_key_cache[:, :, :end_idx, :] vs k_full for idx {start_idx}")
# print_diff(trt_value_cache[:, :, :end_idx, :], v_full, f"trt_value_cache[:, :, :end_idx, :] vs v_full for idx {start_idx}")
- assert torch.allclose(out_no_cache[:, :, -1:, :], out_trt, atol=ATOL, rtol=RTOL), f"Generate TRT logits don't match for idx {start_idx}"
- assert torch.allclose(trt_key_cache[:, :, :end_idx, :], k_full, atol=ATOL, rtol=RTOL), f"Generate TRT key cache don't match for idx {start_idx}"
- assert torch.allclose(trt_value_cache[:, :, :end_idx, :], v_full, atol=ATOL, rtol=RTOL), f"Generate TRT value cache don't match for idx {start_idx}"
- q = q_full
+ assert torch.allclose(
+ out_no_cache[:, :, -1:, :], out_trt, atol=ATOL, rtol=RTOL
+ ), f"Generate TRT logits don't match for idx {start_idx}"
+ assert torch.allclose(
+ trt_key_cache[:, :, :end_idx, :], k_full, atol=ATOL, rtol=RTOL
+ ), f"Generate TRT key cache don't match for idx {start_idx}"
+ assert torch.allclose(
+ trt_value_cache[:, :, :end_idx, :], v_full, atol=ATOL, rtol=RTOL
+ ), f"Generate TRT value cache don't match for idx {start_idx}"
+ q = q_full
k = k_full
v = v_full
print("============== test_static_cache_with_torch_tensorrt passed ==============")
-
+
def main():
arg_parser = argparse.ArgumentParser(
description="Run test cases for llama attention and decoder"
)
arg_parser.add_argument(
- "--debug",
- action="store_true",
- help="Enable debug (default: False)"
+ "--debug", action="store_true", help="Enable debug (default: False)"
)
args = arg_parser.parse_args()
with torch.inference_mode():
test_static_cache_model(args)
test_static_cache_lowering(args)
test_static_cache_with_torch_tensorrt(args)
-
+
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
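For readers skimming the example file above: the StaticCacheModel variant avoids recomputing attention over the whole sequence by writing new keys/values into a preallocated cache and attending only over the populated prefix. Below is a minimal standalone sketch of that pattern; the shapes and names are illustrative, and it uses in-place slice assignment rather than the torch.cat formulation from the PR purely for brevity.

import torch
import torch.nn.functional as F

def static_cache_sdpa(q, k_new, v_new, key_cache, value_cache, start_idx, end_idx, is_causal):
    # Write the new keys/values into the preallocated cache slice [start_idx, end_idx).
    key_cache[:, :, start_idx:end_idx, :] = k_new
    value_cache[:, :, start_idx:end_idx, :] = v_new
    # Attend only over the populated prefix of the cache.
    out = F.scaled_dot_product_attention(
        q,
        key_cache[:, :, :end_idx, :],
        value_cache[:, :, :end_idx, :],
        dropout_p=0.0,
        is_causal=is_causal,
    )
    return out, key_cache, value_cache

# Prefill fills positions [0, 2048); each decode step then appends one position.
q = torch.randn(1, 32, 2048, 64)
k = torch.randn(1, 32, 2048, 64)
v = torch.randn(1, 32, 2048, 64)
key_cache = torch.zeros(1, 32, 2176, 64)
value_cache = torch.zeros(1, 32, 2176, 64)
out, key_cache, value_cache = static_cache_sdpa(q, k, v, key_cache, value_cache, 0, 2048, True)
print(out.shape)  # torch.Size([1, 32, 2048, 64])

The torch.cat version used in the PR keeps the update purely functional, which presumably exports and lowers more cleanly than in-place mutation of graph inputs.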
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py 2025-05-31 02:00:39.730818+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py 2025-05-31 02:01:04.362893+00:00
@@ -692,10 +692,11 @@
gm = exported_program.module()
exported_program.module().to("cpu")
torch.cuda.empty_cache()
import gc
+
gc.collect()
logger.debug("Input graph: " + str(gm.graph))
# Apply lowering on the graph module
gm = post_lowering(gm, settings)
@@ -790,31 +791,30 @@
# TODO: For future, explore when nodes don't have metadata and if fake_tensor_prop can resolve this.
logger.warning(
"Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments."
)
-
# Store the original input spec for later use
- original_in_spec = getattr(gm, '_in_spec', None)
- original_out_spec = getattr(gm, '_out_spec', None)
-
+ original_in_spec = getattr(gm, "_in_spec", None)
+ original_out_spec = getattr(gm, "_out_spec", None)
+
# Function to preserve and restore module specs
def preserve_module_specs(in_spec, out_spec, target_module):
"""
Applies input and output specs to the target module.
-
+
Args:
in_spec: The input spec to apply
out_spec: The output spec to apply
target_module: The module to apply specs to
"""
# Apply specs to target module
if in_spec is not None:
target_module._in_spec = in_spec
if out_spec is not None:
target_module._out_spec = out_spec
-
+
return target_module
# Partition module into components that can be TRT-accelerated
fast_partitioner_failed = False
# If specified, try using the fast partitioner and fall back to the global one on failure
@@ -1197,11 +1197,11 @@
"enable_weight_streaming": enable_weight_streaming,
"tiling_optimization_level": tiling_optimization_level,
"l2_limit_for_tiling": l2_limit_for_tiling,
"offload_module_to_cpu": offload_module_to_cpu,
}
-
+
settings = CompilationSettings(**compilation_options)
logger.info("Compilation Settings: %s\n", settings)
exported_program = pre_export_lowering(exported_program, settings)
# Decompose the exported program
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py 2025-05-31 02:00:39.735818+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py 2025-05-31 02:01:05.149640+00:00
@@ -23,11 +23,15 @@
"""Replace specific versions of scaled_dot_product_attention with an equivalent
implementation which can be easily converted to TRT
"""
original_fns, replacement = scaled_dot_product_attention_replacement()
replaced_nodes = []
- sdpa_nodes = [node for node in gm.graph.nodes if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default]
+ sdpa_nodes = [
+ node
+ for node in gm.graph.nodes
+ if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default
+ ]
breakpoint()
# For each original function, search for it in the graph and replace
for original in original_fns:
replaced_nodes += torch.fx.subgraph_rewriter.replace_pattern_with_filters(
gm,
@@ -167,6 +171,6 @@
def replacement(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
return torch.nn.functional.scaled_dot_product_attention(query, key, value)
- return (efficient, flash, efficient_scale, flash_scale), replacement
\ No newline at end of file
+ return (efficient, flash, efficient_scale, flash_scale), replacement
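The lowering pass above first gathers _scaled_dot_product_efficient_attention nodes before handing patterns to the subgraph rewriter. A self-contained sketch of that node-scanning step on a toy FX-traced module; the toy model and the exact target set are illustrative, not part of the PR.

import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace

class ToyAttention(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

gm = symbolic_trace(ToyAttention())

# Call targets of interest; the PR additionally matches the ATen efficient/flash
# overloads that appear after decomposition.
sdpa_targets = {
    torch.nn.functional.scaled_dot_product_attention,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
}

sdpa_nodes = [
    node
    for node in gm.graph.nodes
    if node.op == "call_function" and node.target in sdpa_targets
]
print([node.name for node in sdpa_nodes])  # expect the single SDPA call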
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonCUDAGraphModule.py 2025-05-31 02:00:39.736818+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonCUDAGraphModule.py 2025-05-31 02:01:05.811969+00:00
@@ -369,27 +369,35 @@
is_shape_tensor_input = self.engine.is_shape_inference_io(input_name)
if need_cudagraphs_record:
# If cudagraphs is enabled, this memory is reserved for future cudagraph runs
# Clone is required to avoid re-using user-provided GPU memory
if is_shape_tensor_input:
- self._input_buffers[inputs_shape_key][i] = contiguous_inputs[i].cpu().clone()
+ self._input_buffers[inputs_shape_key][i] = (
+ contiguous_inputs[i].cpu().clone()
+ )
else:
- self._input_buffers[inputs_shape_key][i] = contiguous_inputs[i].clone()
+ self._input_buffers[inputs_shape_key][i] = contiguous_inputs[
+ i
+ ].clone()
# For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers
# as per TensorRT requirements
if is_shape_tensor_input:
# Shape tensor inputs are casted to int64 explicitly
# Currently Torch CPU pointers are not working; numpy pointers are used instead
# to refer to underlying memory
inputs_cpu = contiguous_inputs[i].cpu().to(torch.int64)
- inputs_cpu_numpy = contiguous_inputs[i].cpu().to(torch.int64).numpy().copy()
+ inputs_cpu_numpy = (
+ contiguous_inputs[i].cpu().to(torch.int64).numpy().copy()
+ )
# if cudagraphs_enabled:
# self._input_buffers[inputs_shape_key][i].copy_(inputs_cpu)
# self.context.set_tensor_address(input_name, self._input_buffers[inputs_shape_key][i].numpy().copy().ctypes.data)
# else:
- self.context.set_tensor_address(input_name, inputs_cpu_numpy.ctypes.data)
+ self.context.set_tensor_address(
+ input_name, inputs_cpu_numpy.ctypes.data
+ )
else:
self.context.set_input_shape(
input_name, tuple(contiguous_inputs[i].shape)
)
if cudagraphs_enabled:
@@ -458,11 +466,14 @@
assert len(contiguous_inputs) == len(
self.input_names
), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}."
self.setup_input_tensors(
- contiguous_inputs, self.cudagraphs_enabled, need_cudagraphs_record, inputs_shape_key
+ contiguous_inputs,
+ self.cudagraphs_enabled,
+ need_cudagraphs_record,
+ inputs_shape_key,
)
if shape_changed:
# Check if input shapes can be inferred.
uninferred_input_names = self.context.infer_shapes()
@@ -496,11 +507,12 @@
if need_cudagraphs_record:
self._output_buffers[inputs_shape_key][o] = outputs[o].clone()
if self.cudagraphs_enabled:
self.context.set_tensor_address(
- output_name, self._output_buffers[inputs_shape_key][o].data_ptr()
+ output_name,
+ self._output_buffers[inputs_shape_key][o].data_ptr(),
)
else:
self.context.set_tensor_address(
output_name, outputs[o].data_ptr()
)
@@ -522,30 +534,35 @@
self._engine_stream.wait_stream(self._caller_stream)
with torch.cuda.stream(self._engine_stream):
if self.cudagraphs_enabled:
if need_cudagraphs_record:
-
- self.shape_key_to_cudagraph[inputs_shape_key] = torch.cuda.CUDAGraph()
+
+ self.shape_key_to_cudagraph[inputs_shape_key] = (
+ torch.cuda.CUDAGraph()
+ )
if self.profiling_enabled:
- self.shape_key_to_cudagraph[inputs_shape_key].enable_debug_mode()
+ self.shape_key_to_cudagraph[
+ inputs_shape_key
+ ].enable_debug_mode()
with torch.cuda.graph(
- self.shape_key_to_cudagraph[inputs_shape_key], stream=self._engine_stream
+ self.shape_key_to_cudagraph[inputs_shape_key],
+ stream=self._engine_stream,
):
self.context.execute_async_v3(
self._engine_stream.cuda_stream
)
if self.profiling_enabled:
import tempfile
with tempfile.TemporaryDirectory() as tmpdir:
- self.shape_key_to_cudagraph[inputs_shape_key].debug_dump(
- f"{tempdir}/{self.name}_cudagraph.dot"
- )
+ self.shape_key_to_cudagraph[
+ inputs_shape_key
+ ].debug_dump(f"{tempdir}/{self.name}_cudagraph.dot")
self.shape_key_to_cudagraph[inputs_shape_key].replay() # type: ignore
else:
self.context.execute_async_v3(self._engine_stream.cuda_stream)
@@ -754,18 +771,21 @@
"""
# Representation of input shapes to a given model
# Shapes are concatenated as so:
# x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
tensor_inputs = [
- t if isinstance(t, torch.Tensor) else torch.tensor(t)
- for t in inputs
+ t if isinstance(t, torch.Tensor) else torch.tensor(t) for t in inputs
]
- new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs)
+ new_shape_key = "".join(
+ str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs
+ )
# If the new shape key differs from the existing one,
# invalidate the old shape key and remove the CUDAGraph
if new_shape_key not in self.shape_key_to_cudagraph:
- logger.debug(f"The user provided input shape {new_shape_key} is not found in recorded CUDAGraph input shapes. A new CUDAGraph will be recorded with this input shape.")
+ logger.debug(
+ f"The user provided input shape {new_shape_key} is not found in recorded CUDAGraph input shapes. A new CUDAGraph will be recorded with this input shape."
+ )
# self.shape_key = new_shape_key
return True, new_shape_key
return False, new_shape_key
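The runtime change in this file keys recorded CUDA graphs on a string built from the input shapes, recording a new graph whenever an unseen shape combination arrives. A small TensorRT-free sketch of that idea; the dictionary of per-shape state below is a stand-in for shape_key_to_cudagraph.

import torch

def shape_key(inputs):
    # Non-tensor inputs (e.g. start/end indices) are wrapped so they contribute "()" to the key.
    tensors = [t if isinstance(t, torch.Tensor) else torch.tensor(t) for t in inputs]
    # x: (3, 4), y: (4, 5) --> "(3,4)(4,5)"
    return "".join(str(tuple(t.shape)).replace(" ", "") for t in tensors)

recorded = {}  # stand-in for shape_key_to_cudagraph: key -> per-shape state

def maybe_record(inputs):
    key = shape_key(inputs)
    need_record = key not in recorded
    if need_record:
        recorded[key] = object()  # placeholder for a freshly captured CUDA graph
    return need_record, key

print(maybe_record([torch.randn(1, 32, 2048, 64), 0, 2048]))  # (True, '(1,32,2048,64)()()')
print(maybe_record([torch.randn(1, 32, 2048, 64), 0, 2048]))  # (False, same key)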
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py 2025-05-31 02:00:39.736818+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py 2025-05-31 02:01:05.838721+00:00
@@ -742,14 +742,15 @@
"""
# Representation of input shapes to a given model
# Shapes are concatenated as so:
# x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
tensor_inputs = [
- t if isinstance(t, torch.Tensor) else torch.tensor(t)
- for t in inputs
+ t if isinstance(t, torch.Tensor) else torch.tensor(t) for t in inputs
]
- new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs)
+ new_shape_key = "".join(
+ str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs
+ )
# If the new shape key differs from the existing one,
# invalidate the old shape key and remove the CUDAGraph
if new_shape_key != self.shape_key:
logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}")
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2025-05-31 02:00:39.732818+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2025-05-31 02:01:06.640331+00:00
@@ -1893,10 +1893,11 @@
SourceIR.ATEN,
name,
args[0],
args[1],
)
+
@dynamo_tensorrt_converter(operator.sub, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar, supports_dynamic_shapes=True)
def aten_ops_sub(
Labels: cla signed, component: api [Python], component: conversion, component: dynamo, component: lowering, component: runtime, WIP (work in progress, pull request should not be merged yet)
Description
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
Fixes # (issue)
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: