8 changes: 7 additions & 1 deletion core/conversion/conversionctx/ConversionCtx.cpp
@@ -49,8 +49,13 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
}

builder = make_trt(nvinfer1::createInferBuilder(logger));
// kEXPLICIT_BATCH was removed in TRT 11 (explicit batch is always on); pass 0 on TRT 11+.
#if NV_TENSORRT_MAJOR >= 11
net = make_trt(builder->createNetworkV2(0));
#else
net = make_trt(
builder->createNetworkV2(1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)));
#endif

LOG_INFO(settings);
cfg = make_trt(builder->createBuilderConfig());
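
The NV_TENSORRT_MAJOR macro driving these guards comes from NvInferVersion.h, which NvInfer.h already pulls in. As a self-contained sketch of the pattern above (the makeNetwork wrapper is illustrative, not part of this PR):

#include <NvInfer.h>

// Create a network portably across TRT majors; explicit batch is implied on 11+.
nvinfer1::INetworkDefinition* makeNetwork(nvinfer1::IBuilder& builder) {
#if NV_TENSORRT_MAJOR >= 11
  return builder.createNetworkV2(0);
#else
  return builder.createNetworkV2(
      1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));
#endif
}
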
@@ -59,7 +64,8 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
switch (*p) {
case nvinfer1::DataType::kHALF:
// tensorrt_rtx is strongly typed; FP16 cannot be set via the builder config, so only do this for the standard tensorrt build
#ifndef TRT_MAJOR_RTX
// TRT 11.0 removed platformHasFastFp16 and kFP16 (always strongly typed).
#if !defined(TRT_MAJOR_RTX) && NV_TENSORRT_MAJOR < 11
TORCHTRT_CHECK(
builder->platformHasFastFp16(), "Requested inference in FP16 but platform does not support FP16");
cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
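Under strong typing (tensorrt_rtx, and TRT 11+ generally), reduced precision is expressed through tensor dtypes rather than builder flags. A hedged sketch of the FP16 equivalent, assuming ICastLayer/addCast as available in recent TRT releases:

// Route a tensor through FP16 explicitly instead of setting BuilderFlag::kFP16.
nvinfer1::ITensor* toHalf(nvinfer1::INetworkDefinition& net, nvinfer1::ITensor& t) {
  auto* cast = net.addCast(t, nvinfer1::DataType::kHALF);
  return cast->getOutput(0);
}
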
6 changes: 6 additions & 0 deletions core/conversion/converters/impl/batch_norm.cpp
@@ -189,7 +189,13 @@ bool InstanceNormConverter(ConversionCtx* ctx, const torch::jit::Node* n, args&
fc.nbFields = f.size();
fc.fields = f.data();

// TRT 11.0 renamed getPluginCreator → getCreator with IPluginCreatorInterface return type
#if NV_TENSORRT_MAJOR >= 11
auto* creator_iface = getPluginRegistry()->getCreator("InstanceNormalization_TRT", "1", "");
auto* creator = dynamic_cast<nvinfer1::IPluginCreator*>(creator_iface);
#else
auto creator = getPluginRegistry()->getPluginCreator("InstanceNormalization_TRT", "1", "");
#endif
auto instance_norm_plugin = creator->createPlugin("instance_norm", &fc);

TORCHTRT_CHECK(instance_norm_plugin, "Unable to create instance_norm plugin from TensorRT plugin registry" << *n);
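The same registry-lookup guard recurs in interpolate.cpp, normalize.cpp, and pooling.cpp below; a small compatibility shim could factor it out. A sketch only — the name getPluginCreatorCompat is ours, not part of this PR:

#include <NvInferRuntime.h>

inline nvinfer1::IPluginCreator* getPluginCreatorCompat(
    char const* name, char const* version, char const* ns) {
#if NV_TENSORRT_MAJOR >= 11
  // getCreator returns the version-agnostic IPluginCreatorInterface*; the
  // dynamic_cast yields nullptr for creators that only implement the V3 API.
  return dynamic_cast<nvinfer1::IPluginCreator*>(
      getPluginRegistry()->getCreator(name, version, ns));
#else
  return getPluginRegistry()->getPluginCreator(name, version, ns);
#endif
}
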
6 changes: 6 additions & 0 deletions core/conversion/converters/impl/interpolate.cpp
@@ -54,7 +54,13 @@ void create_plugin(

fc.nbFields = f.size();
fc.fields = f.data();
// TRT 11.0 renamed getPluginCreator → getCreator with IPluginCreatorInterface return type
#if NV_TENSORRT_MAJOR >= 11
auto* creator_iface = getPluginRegistry()->getCreator("Interpolate", "1", "torch_tensorrt");
auto* creator = dynamic_cast<nvinfer1::IPluginCreator*>(creator_iface);
#else
auto creator = getPluginRegistry()->getPluginCreator("Interpolate", "1", "torch_tensorrt");
#endif
auto interpolate_plugin = creator->createPlugin(name, &fc);

auto resize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *interpolate_plugin);
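With a shim like getPluginCreatorCompat sketched above, the guarded lookup in this converter (and the identical ones below) would collapse to:

auto* creator = getPluginCreatorCompat("Interpolate", "1", "torch_tensorrt");
auto interpolate_plugin = creator->createPlugin(name, &fc);
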
2 changes: 2 additions & 0 deletions core/conversion/converters/impl/layer_norm.cpp
@@ -51,7 +51,9 @@ auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
TORCHTRT_CHECK(normalize_layer, "Unable to create layer_norm from node: " << *n);
normalize_layer->setName(util::node_info(n).c_str());
normalize_layer->setEpsilon(eps);
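// TRT 11 networks are strongly typed: compute precision follows the input dtype, so setComputePrecision applies only to older releases.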
#if NV_TENSORRT_MAJOR < 11
normalize_layer->setComputePrecision(input->getType());
#endif
auto normalized = normalize_layer->getOutput(0);

ctx->AssociateValueAndTensor(n->outputs()[0], normalized);
6 changes: 6 additions & 0 deletions core/conversion/converters/impl/normalize.cpp
@@ -42,7 +42,13 @@ void create_plugin(
}
}

// TRT 11.0 renamed getPluginCreator → getCreator with IPluginCreatorInterface return type
#if NV_TENSORRT_MAJOR >= 11
auto* creator_iface = getPluginRegistry()->getCreator("NormalizePlugin", "1", "torch_tensorrt");
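// Note: the dynamic_cast below yields nullptr for V3-only creators, which createPlugin would then dereference; a null check here would fail more gracefully.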
auto* creator = dynamic_cast<nvinfer1::IPluginCreator*>(creator_iface);
#else
auto creator = getPluginRegistry()->getPluginCreator("NormalizePlugin", "1", "torch_tensorrt");
#endif
auto plugin = creator->createPlugin(name, &fc);
auto normalize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);
TORCHTRT_CHECK(normalize_layer, "Unable to create normalization plugin from node" << *n);
6 changes: 6 additions & 0 deletions core/conversion/converters/impl/pooling.cpp
@@ -97,7 +97,13 @@ bool AdaptivePoolingConverter(
LOG_WARNING(
"Adaptive pooling layer will be using Aten library kernels in pytorch for execution. TensorRT does not support adaptive pooling natively. Consider switching to non-adaptive pooling if this is an issue");

// TRT 11.0 renamed getPluginCreator → getCreator with IPluginCreatorInterface return type
#if NV_TENSORRT_MAJOR >= 11
auto* creator_iface = getPluginRegistry()->getCreator("Interpolate", "1", "torch_tensorrt");
auto* creator = dynamic_cast<nvinfer1::IPluginCreator*>(creator_iface);
#else
auto creator = getPluginRegistry()->getPluginCreator("Interpolate", "1", "torch_tensorrt");
#endif
auto interpolate_plugin = creator->createPlugin(mode.c_str(), &fc);

new_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *interpolate_plugin);
2 changes: 2 additions & 0 deletions core/conversion/converters/impl/unary.cpp
@@ -89,7 +89,9 @@ auto sqrt_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().patter
auto unary_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kSQRT);
TORCHTRT_CHECK(unary_layer, "Unable to create sqrt layer from node: " << *n);
unary_layer->setName(util::node_info(n).c_str());
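// Under TRT 11 strong typing the output dtype is fixed by the input type, so setOutputType applies only to older releases.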
#if NV_TENSORRT_MAJOR < 11
unary_layer->setOutputType(0, in->getType());
#endif
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], unary_layer->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
return true;
19 changes: 19 additions & 0 deletions core/plugins/register_plugins.cpp
@@ -30,6 +30,24 @@ class TorchTRTPluginRegistry {
plugin_logger.set_reportable_log_level(util::logging::get_logger().get_reportable_log_level());

int numCreators = 0;
// TRT 11.0 renamed getPluginCreatorList → getAllCreators with IPluginCreatorInterface return type
#if NV_TENSORRT_MAJOR >= 11
auto pluginsList = getPluginRegistry()->getAllCreators(&numCreators);
for (int k = 0; k < numCreators; ++k) {
if (!pluginsList[k]) {
plugin_logger.log(util::logging::LogLevel::kDEBUG, "Plugin creator for plugin " + str(k) + " is a nullptr");
continue;
}
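// Creators implementing only the newer V3 interface fail the dynamic_cast below and are registered but not logged here.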
auto* creator_v1 = dynamic_cast<nvinfer1::IPluginCreator*>(pluginsList[k]);
if (creator_v1) {
std::string pluginNamespace = creator_v1->getPluginNamespace();
plugin_logger.log(
util::logging::LogLevel::kDEBUG,
"Registered plugin creator - " + std::string(creator_v1->getPluginName()) +
", Namespace: " + pluginNamespace);
}
}
#else
auto pluginsList = getPluginRegistry()->getPluginCreatorList(&numCreators);
for (int k = 0; k < numCreators; ++k) {
if (!pluginsList[k]) {
@@ -42,6 +60,7 @@ class TorchTRTPluginRegistry {
"Registered plugin creator - " + std::string(pluginsList[k]->getPluginName()) +
", Namespace: " + pluginNamespace);
}
#endif
plugin_logger.log(util::logging::LogLevel::kDEBUG, "Total number of plugins registered: " + str(numCreators));
}

4 changes: 3 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -935,7 +935,9 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
output_dtype = self.output_dtypes[i]

self.ctx.net.mark_output(output)
if output_dtype is not dtype.unknown:
if output_dtype is not dtype.unknown and not self.ctx.net.get_flag(
trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED
):
output.dtype = output_dtype.to(trt.DataType, use_default=True)
output.name = name
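
For reference, the C++ API exposes the same query; a hedged sketch of the equivalent guard (INetworkDefinition::getFlag and the kSTRONGLY_TYPED creation flag exist in recent TRT releases):

// Only force an output dtype when the network is not strongly typed.
bool can_set_output_dtype(nvinfer1::INetworkDefinition const& net) {
  return !net.getFlag(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED);
}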
