diff --git a/core/conversion/conversionctx/ConversionCtx.cpp b/core/conversion/conversionctx/ConversionCtx.cpp index 07ade4b17f..4a86042129 100644 --- a/core/conversion/conversionctx/ConversionCtx.cpp +++ b/core/conversion/conversionctx/ConversionCtx.cpp @@ -49,8 +49,13 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings) } builder = make_trt(nvinfer1::createInferBuilder(logger)); + // kEXPLICIT_BATCH was removed in TRT 11 (explicit batch is always on); pass 0 on TRT 11+. +#if NV_TENSORRT_MAJOR >= 11 + net = make_trt(builder->createNetworkV2(0)); +#else net = make_trt( builder->createNetworkV2(1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH))); +#endif LOG_INFO(settings); cfg = make_trt(builder->createBuilderConfig()); @@ -59,7 +64,8 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings) switch (*p) { case nvinfer1::DataType::kHALF: // tensorrt_rtx is strong typed, cannot set fp16 by builder config, only do this for tensorrt build -#ifndef TRT_MAJOR_RTX +// TRT 11.0 removed platformHasFastFp16 and kFP16 (always strongly typed). 
+#if !defined(TRT_MAJOR_RTX) && NV_TENSORRT_MAJOR < 11 TORCHTRT_CHECK( builder->platformHasFastFp16(), "Requested inference in FP16 but platform does not support FP16"); cfg->setFlag(nvinfer1::BuilderFlag::kFP16); diff --git a/core/conversion/converters/impl/batch_norm.cpp b/core/conversion/converters/impl/batch_norm.cpp index 03b844fdd2..c787c0cce8 100644 --- a/core/conversion/converters/impl/batch_norm.cpp +++ b/core/conversion/converters/impl/batch_norm.cpp @@ -189,7 +189,13 @@ bool InstanceNormConverter(ConversionCtx* ctx, const torch::jit::Node* n, args& fc.nbFields = f.size(); fc.fields = f.data(); + // TRT 11.0 renamed getPluginCreator → getCreator with IPluginCreatorInterface return type +#if NV_TENSORRT_MAJOR >= 11 + auto* creator_iface = getPluginRegistry()->getCreator("InstanceNormalization_TRT", "1", ""); + auto* creator = dynamic_cast<nvinfer1::IPluginCreator*>(creator_iface); +#else auto creator = getPluginRegistry()->getPluginCreator("InstanceNormalization_TRT", "1", ""); +#endif auto instance_norm_plugin = creator->createPlugin("instance_norm", &fc); TORCHTRT_CHECK(instance_norm_plugin, "Unable to create instance_norm plugin from TensorRT plugin registry" << *n); diff --git a/core/conversion/converters/impl/interpolate.cpp b/core/conversion/converters/impl/interpolate.cpp index f3b0180188..27a35cdac1 100644 --- a/core/conversion/converters/impl/interpolate.cpp +++ b/core/conversion/converters/impl/interpolate.cpp @@ -54,7 +54,13 @@ void create_plugin( fc.nbFields = f.size(); fc.fields = f.data(); + // TRT 11.0 renamed getPluginCreator → getCreator with IPluginCreatorInterface return type +#if NV_TENSORRT_MAJOR >= 11 + auto* creator_iface = getPluginRegistry()->getCreator("Interpolate", "1", "torch_tensorrt"); + auto* creator = dynamic_cast<nvinfer1::IPluginCreator*>(creator_iface); +#else auto creator = getPluginRegistry()->getPluginCreator("Interpolate", "1", "torch_tensorrt"); +#endif auto interpolate_plugin = creator->createPlugin(name, &fc); auto resize_layer = 
ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *interpolate_plugin); diff --git a/core/conversion/converters/impl/layer_norm.cpp b/core/conversion/converters/impl/layer_norm.cpp index 4bb1c1211b..1db827945b 100644 --- a/core/conversion/converters/impl/layer_norm.cpp +++ b/core/conversion/converters/impl/layer_norm.cpp @@ -51,7 +51,9 @@ auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns() TORCHTRT_CHECK(normalize_layer, "Unable to create layer_norm from node: " << *n); normalize_layer->setName(util::node_info(n).c_str()); normalize_layer->setEpsilon(eps); +#if NV_TENSORRT_MAJOR < 11 normalize_layer->setComputePrecision(input->getType()); +#endif auto normalized = normalize_layer->getOutput(0); ctx->AssociateValueAndTensor(n->outputs()[0], normalized); diff --git a/core/conversion/converters/impl/normalize.cpp b/core/conversion/converters/impl/normalize.cpp index 9e50a0c418..7e2dd8fb1c 100644 --- a/core/conversion/converters/impl/normalize.cpp +++ b/core/conversion/converters/impl/normalize.cpp @@ -42,7 +42,13 @@ void create_plugin( } } + // TRT 11.0 renamed getPluginCreator → getCreator with IPluginCreatorInterface return type +#if NV_TENSORRT_MAJOR >= 11 + auto* creator_iface = getPluginRegistry()->getCreator("NormalizePlugin", "1", "torch_tensorrt"); + auto* creator = dynamic_cast<nvinfer1::IPluginCreator*>(creator_iface); +#else auto creator = getPluginRegistry()->getPluginCreator("NormalizePlugin", "1", "torch_tensorrt"); +#endif auto plugin = creator->createPlugin(name, &fc); auto normalize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin); TORCHTRT_CHECK(normalize_layer, "Unable to create normalization plugin from node" << *n); diff --git a/core/conversion/converters/impl/pooling.cpp b/core/conversion/converters/impl/pooling.cpp index 0e7f4e0dbc..5555fc11e3 100644 --- a/core/conversion/converters/impl/pooling.cpp +++ b/core/conversion/converters/impl/pooling.cpp @@ -97,7 +97,13 @@ bool AdaptivePoolingConverter( LOG_WARNING( "Adaptive pooling 
layer will be using Aten library kernels in pytorch for execution. TensorRT does not support adaptive pooling natively. Consider switching to non-adaptive pooling if this is an issue"); + // TRT 11.0 renamed getPluginCreator → getCreator with IPluginCreatorInterface return type +#if NV_TENSORRT_MAJOR >= 11 + auto* creator_iface = getPluginRegistry()->getCreator("Interpolate", "1", "torch_tensorrt"); + auto* creator = dynamic_cast<nvinfer1::IPluginCreator*>(creator_iface); +#else auto creator = getPluginRegistry()->getPluginCreator("Interpolate", "1", "torch_tensorrt"); +#endif auto interpolate_plugin = creator->createPlugin(mode.c_str(), &fc); new_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *interpolate_plugin); diff --git a/core/conversion/converters/impl/unary.cpp b/core/conversion/converters/impl/unary.cpp index 3e01869d68..e407906c16 100644 --- a/core/conversion/converters/impl/unary.cpp +++ b/core/conversion/converters/impl/unary.cpp @@ -89,7 +89,9 @@ auto sqrt_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().patter auto unary_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kSQRT); TORCHTRT_CHECK(unary_layer, "Unable to create sqrt layer from node: " << *n); unary_layer->setName(util::node_info(n).c_str()); +#if NV_TENSORRT_MAJOR < 11 unary_layer->setOutputType(0, in->getType()); +#endif auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], unary_layer->getOutput(0)); LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); return true; diff --git a/core/plugins/register_plugins.cpp b/core/plugins/register_plugins.cpp index 8590902a15..38b871019d 100644 --- a/core/plugins/register_plugins.cpp +++ b/core/plugins/register_plugins.cpp @@ -30,6 +30,24 @@ class TorchTRTPluginRegistry { plugin_logger.set_reportable_log_level(util::logging::get_logger().get_reportable_log_level()); int numCreators = 0; + // TRT 11.0 renamed getPluginCreatorList → getAllCreators with IPluginCreatorInterface return type +#if NV_TENSORRT_MAJOR >= 11 + auto 
pluginsList = getPluginRegistry()->getAllCreators(&numCreators); + for (int k = 0; k < numCreators; ++k) { + if (!pluginsList[k]) { + plugin_logger.log(util::logging::LogLevel::kDEBUG, "Plugin creator for plugin " + str(k) + " is a nullptr"); + continue; + } + auto* creator_v1 = dynamic_cast<nvinfer1::IPluginCreator*>(pluginsList[k]); + if (creator_v1) { + std::string pluginNamespace = creator_v1->getPluginNamespace(); + plugin_logger.log( + util::logging::LogLevel::kDEBUG, + "Registered plugin creator - " + std::string(creator_v1->getPluginName()) + + ", Namespace: " + pluginNamespace); + } + } +#else auto pluginsList = getPluginRegistry()->getPluginCreatorList(&numCreators); for (int k = 0; k < numCreators; ++k) { if (!pluginsList[k]) { @@ -42,6 +60,7 @@ class TorchTRTPluginRegistry { "Registered plugin creator - " + std::string(pluginsList[k]->getPluginName()) + ", Namespace: " + pluginNamespace); } +#endif plugin_logger.log(util::logging::LogLevel::kDEBUG, "Total number of plugins registered: " + str(numCreators)); } diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index eb7df76054..26c558494b 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -935,7 +935,9 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]: output_dtype = self.output_dtypes[i] self.ctx.net.mark_output(output) - if output_dtype is not dtype.unknown: + if output_dtype is not dtype.unknown and not self.ctx.net.get_flag( + trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED + ): output.dtype = output_dtype.to(trt.DataType, use_default=True) output.name = name