adding chunk converter back #3314
Conversation
Since PyTorch implicitly decomposes `chunk` into `split`, I think the `chunk` converter will never be hit.
Thanks @HolyWu! Yes, it does look like `chunk` is still being decomposed into `split`.
Hmm, but isn't it the case that `chunk` was earlier getting implicitly changed to `split` during torch export, and now it is not in the recent nightlies? That was the reason the converter was removed before, and why it has been added back now.
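For reference, one quick way to check whether export still decomposes `chunk` into `split` (a minimal sketch; the toy module and shapes are illustrative, not from this PR) is to export a small module and print the graph before and after running decompositions:

```python
import torch


class ChunkModule(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        return torch.chunk(x, 3, 0)


# If aten.chunk still appears in the exported graph but becomes aten.split
# after run_decompositions(), the rewrite happens in the decomposition step
# rather than during tracing itself.
ep = torch.export.export(ChunkModule(), (torch.randn(6),))
print(ep.graph)
print(ep.run_decompositions().graph)
```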
@apbose
You can repro with this code:

```python
import torch
import torch_tensorrt as torch_trt


class MyModel(torch.nn.Module):
    def forward(self, a):
        return torch.randn(4, 3).chunk(2, 1)


inputs = torch.randn((4, 3)).to("cuda")
model = MyModel().eval().cuda()
compiled_model = torch_trt.compile(model, inputs=[inputs])
print(compiled_model.graph)
out = compiled_model(inputs)
print(out)
```
@zewenli98 thanks for the clarification.
In the recent nightlies it leads to a "chunk converter not found" error.
Sorry, I don't know where the list is. I guess the decomp is on the C++ side rather than the Python side, so it doesn't show up in a Python-side decomposition table.
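If it helps, one way to look for a Python-side entry (a small sketch, assuming the core ATen decomposition table in `torch._decomp` is the relevant list) is:

```python
import torch
from torch._decomp import core_aten_decompositions

# If aten.chunk has no entry here, the chunk -> split rewrite presumably
# comes from the C++ CompositeImplicitAutograd kernel rather than from a
# Python-side decomposition.
decomp_table = core_aten_decompositions()
print(torch.ops.aten.chunk.default in decomp_table)
```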
I don't think so. Using the code below shows that the registered converter for `chunk` is never invoked:

```python
import os
from typing import Dict, Sequence, Tuple, Union

import torch
import torch_tensorrt
from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo.conversion import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterPriority, dynamo_tensorrt_converter
from torch_tensorrt.dynamo.types import TRTTensor

os.environ["CI_BUILD"] = "1"

print(f"\n{torch.__version__=}")
print(f"{torch_tensorrt.__version__=}\n")


@dynamo_tensorrt_converter(torch.ops.aten.chunk.default, priority=ConverterPriority.HIGH, supports_dynamic_shapes=True)
def my_chunk(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    raise NotImplementedError("chunk")


class MyModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.aten.chunk.default(x, 3, 0)


with torch.inference_mode():
    model = MyModule().eval().cuda()
    inputs = [torch.randn(1, device="cuda")]
    trt_model = torch_tensorrt.compile(
        model,
        "dynamo",
        inputs,
        debug=True,
        min_block_size=1,
    )
    print(f"\n{trt_model(*inputs)}")
```

Output:

```
torch.__version__='2.6.0.dev20241209+cu124'
torch_tensorrt.__version__='2.6.0.dev20241210+cu124'
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
%x : [num_users=1] = placeholder[target=x]
%chunk : [num_users=1] = call_function[target=torch.ops.aten.chunk.default](args = (%x, 3), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%chunk, 0), kwargs = {})
return (getitem,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
%x : [num_users=1] = placeholder[target=x]
%split : [num_users=1] = call_function[target=torch.ops.aten.split.Tensor](args = (%x, 1), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%split, 0), kwargs = {})
return (getitem,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
%x : [num_users=1] = placeholder[target=x]
%split : [num_users=1] = call_function[target=torch.ops.aten.split.Tensor](args = (%x, 1), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%split, 0), kwargs = {})
return (getitem,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
%x : [num_users=1] = placeholder[target=x]
%split : [num_users=1] = call_function[target=torch.ops.aten.split.Tensor](args = (%x, 1), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%split, 0), kwargs = {})
return (getitem,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
%x : [num_users=1] = placeholder[target=x]
%split : [num_users=1] = call_function[target=torch.ops.aten.split.Tensor](args = (%x, 1), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%split, 0), kwargs = {})
return (getitem,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.split.Tensor + Operator Count: 1
- _operator.getitem + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 2 operators out of 2 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.split.Tensor + Operator Count: 1
- _operator.getitem + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Updated metadata for node: _run_on_acc_0 with its corresponding submodule outputs
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
Input shapes: [(1,)]
graph():
%x : [num_users=1] = placeholder[target=x]
%split : [num_users=1] = call_function[target=torch.ops.aten.split.Tensor](args = (%x, 1), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%split, 0), kwargs = {})
return getitem
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[1], dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (1,)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /split (kind: aten.split.Tensor, args: ('x <Node>', '1 <int>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /split [aten.split.Tensor] (Inputs: (x: (1,)@torch.float32, 1) | Outputs: (split: ((1,)@torch.float32)))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /getitem (kind: <built-in function getitem>, args: ('split <Node>', '0 <int>'))
DEBUG:torch_tensorrt.dynamo.conversion.ops_evaluators:Evaluating _operator.getitem on object with name: /getitem
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /getitem [<built-in function getitem>] (Inputs: (split: ((1,)@torch.float32), 0) | Outputs: (getitem: (1,)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('getitem <Node>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(1,), dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (getitem: (1,)@torch.float32) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.003443
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.024018
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 2516 bytes of Memory
DEBUG: [Torch-TensorRT] - Deserializing Device Info: 0%8%9%0%NVIDIA GeForce RTX 4060 Ti
DEBUG: [Torch-TensorRT] - Deserialized Device Info: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Target Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Setting Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU) as active device
INFO: [Torch-TensorRT] - Loaded engine size: 0 MiB
DEBUG: [Torch-TensorRT] - Deserialization required 127 microseconds.
DEBUG: [Torch-TensorRT] - Total per-runner device persistent memory is 0
DEBUG: [Torch-TensorRT] - Total per-runner host persistent memory is 0
DEBUG: [Torch-TensorRT] - Allocated device scratch memory of size 0
DEBUG: [Torch-TensorRT] - - Runner scratch: 0 bytes
INFO: [Torch-TensorRT] - [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 0 (MiB)
DEBUG: [Torch-TensorRT] - CUDA lazy loading is enabled.
DEBUG: [Torch-TensorRT] - Input binding name: x has TensorRT binding index: 0, Torch binding index: 0
DEBUG: [Torch-TensorRT] - Output binding name: output0 has TensorRT binding index: 1, Torch binding index: 1
DEBUG: [Torch-TensorRT] - Torch-TensorRT TensorRT Engine:
Name: _run_on_acc_0_engine
Inputs: [
id: 0
name: x
shape: [1]
dtype: Float
]
Outputs: [
id: 0
name: output0
shape: [1]
dtype: Float
]
Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
Hardware Compatibility: Disabled
Target Platform: windows_x86_64
DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++
The graph consists of 2 Total Operators, of which 2 operators are supported, 100.0% coverage
Compiled with: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refittable=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='C:\\Users\\HolyWu\\AppData\\Local\\Temp\\torch_tensorrt_engine_cache\\timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False, use_explicit_typing=False, use_fp32_acc=False, enable_weight_streaming=False, enable_cross_compile_for_windows=False)
Graph Structure:
Inputs: List[Tensor: (1)@float32]
...
TRT Engine #1 - Submodule name: _run_on_acc_0
Engine Inputs: List[Tensor: (1)@float32]
Number of Operators in Engine: 2
Engine Outputs: List[Tensor: (1)@float32]
...
Outputs: List[Tensor: (1)@float32]
------------------------- Aggregate Stats -------------------------
Average Number of Operators per TRT Engine: 2.0
Most Operators in a TRT Engine: 2
********** Recommendations **********
- For minimal graph segmentation, select min_block_size=2 which would generate 1 TRT engine(s)
- The current level of graph segmentation is equivalent to selecting min_block_size=2 which generates 1 TRT engine(s)
DEBUG: [Torch-TensorRT] - Attempting to run engine (ID: _run_on_acc_0_engine); Hardware Compatible: 0
DEBUG: [Torch-TensorRT] - Input Name: x Shape: [1]
DEBUG: [Torch-TensorRT] - Output Name: output0 Shape: [1]
[tensor([-1.2514], device='cuda:0')]
```
@apbose I think if you run the chunk converter tests, unlike compiling a model, they do not run any transformations like the lowering passes on the graph, so the chunk converter is still exercised there. When compiling a full model, however, `chunk` is already lowered to `split` before conversion:

```
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
%x : [num_users=1] = placeholder[target=x]
%chunk : [num_users=1] = call_function[target=torch.ops.aten.chunk.default](args = (%x, 3), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%chunk, 0), kwargs = {})
return (getitem,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
%x : [num_users=1] = placeholder[target=x]
%split : [num_users=1] = call_function[target=torch.ops.aten.split.Tensor](args = (%x, 1), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%split, 0), kwargs = {})
return (getitem,)
```

The `chunk` node is already replaced by `split` in the input graph, so the chunk converter is never called during compilation (see the sketch below).
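As a rough illustration (a sketch, not from this PR: it assumes `get_decompositions` is importable from `torch_tensorrt.dynamo.lowering`, as the compiler's lowering step uses), applying the lowering decompositions to an exported program shows `chunk` turning into `split`, which is why the converter is dead code during full-model compilation while the standalone converter tests still reach it:

```python
import torch
from torch_tensorrt.dynamo.lowering import get_decompositions  # assumed import path


class ChunkModule(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        return torch.ops.aten.chunk.default(x, 3, 0)


ep = torch.export.export(ChunkModule(), (torch.randn(3),))
# Running the same decomposition table used during lowering rewrites
# aten.chunk into aten.split before any converter is consulted.
lowered = ep.run_decompositions(get_decompositions())
print(lowered.graph)
```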
I think this PR could be closed in favor of #3167, where the chunk converter tests were entirely removed.
Closed in favor of #3167 |
PR https://github.com/pytorch/TensorRT/pull/3120/files removed the chunk converter since `chunk` was getting lowered to `split` in dynamo tracing. In recent nightlies, compilation complains that the chunk converter is not found.
TODO: handle dynamic cases by adding a validator (a possible sketch follows).
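For that TODO, a possible shape of the validator (a hedged sketch: the validator signature, the `tensor_meta` lookup, and the function names are assumptions, not the final implementation) could be:

```python
from torch.fx.node import Node


def chunk_static_shape_validator(node: Node) -> bool:
    # Claim converter support only when the chunked input has a fully static
    # shape; dynamic dimensions would need extra handling in the converter.
    inp = node.args[0]
    if not isinstance(inp, Node):
        return False
    meta = inp.meta.get("tensor_meta", None)
    return meta is not None and all(isinstance(d, int) for d in meta.shape)


# Hypothetical registration gated by the validator:
# @dynamo_tensorrt_converter(
#     torch.ops.aten.chunk.default,
#     capability_validator=chunk_static_shape_validator,
# )
# def aten_ops_chunk(ctx, target, args, kwargs, name): ...
```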