Skip to content

Commit de846c5

Browse files
Move internal changes (#455)
## NFC: Simplify some aspects of options management (OptionsContext) - Adds a convenience 'OptionsContext::Option' class that simplifies how options are declared. - Closes a loophole where tuples of structs containing options can cause crashes if they populate options in their constructor. Due to how the external storage mechanism works, we can no longer use direct `std::tuple` of aggregate objects which invoke `addOption`. Instead, one must use `unique_ptr` to wrap those types when used as elements of a `std::tuple`. - To help enforce this, we explicitly delete the move constructor of `OptionsProvider`. ## [compiler|python] Update how cached pipelines/"Compiler Tasks" are registered This change updates how registration functions for "compilation tasks" invoked. We now expose a C API method that can be invoked within the Pybind11 module initializer. This decouples compiler task registration from pass or dialect registration. This change also cleans up the C API function naming for pass/dialect registration functions. ## [python] Add more robust CMake logic for fixing missing CAPI dependency in core MLIR PyBind module Adds CMake logic to ensure that the Core '_mlir' pybind extension has the correct CAPI dependencies declared until the upstream CMake declarations can be fixed. ## NFC: Remove unnecessary PyCapsule <-> CAPI casters in compiler and runtime bindings Removes unnecessary custom PyBind11 capsule -> C API object casters. These cast functions are only required when it is desired to allow PyBind11 to extract the C API object from the C++ python wrapper type automatically. ## [tensorrt|compiler] Drop "layer metadata callback" utility from TensorRT translation This change removes the "layer metadata callback" feature from the MLIR-to-TensorRT translation. It also removes the relevant APIs from the MLIR-TensorRT compiler's C++ and Python APIs. This capability was originally offered as a bridge for populating the generated TensorRT ILayers with custom metadata. However, the mechanism prevents caching of pass pipelines and therefore is too expensive to use. In the future, any metadata passed to TensorRT should be derived from the MLIR operations' location information. ## NFC: update various uses of "Stablehlo" in class and function names to have consistent capitalization ## NFC: Reorganize some directories This change: - Moves the top-level 'tools' to 'compiler/tools' - Moves the top-level 'test' to 'compiler/test' - Moves the 'mlir-tensorrt-tblgen' tool under 'tensorrt/tools' since the 'tensorrt' project is supposed to be independent. - Similarly move TensorRT-specific python definitions under `tensorrt/python`. ## [executor]: Add a missing guard for builds without CUDA enabled. Wrapping the makeCudaStringError function with MLIR_EXECUTOR_ENABLE_CUDA fixes builds without CUDA enabled. ## [executor] Use Lua locals for block arguments Previously, the Executor MLIR-to-Lua translator used Lua globals for block arguments outside of the entry block since the values that represent block arguments need to be passed between blocks. On the other hand, the scope of Lua local variables are restricted to their block. It is almost never a good idea to use Lua global variables in our translation strategy, however -- for coroutine functions, a translation that uses globals is obviously incorrect since all Lua coroutines in a single Lua environment share the same set of globals. This change declares all block arguments up front as locals in the "entry block" and just sets them to `nil` initially. Since we don't declare a block scope for the entry block, all the following Lua block scopes will have these locals in scope. This allows us to retain the use of locals for all block arguments. GitOrigin-RevId: e9dd03c47eab6145e889ea8ff56fd1c71181f72a
1 parent 1668010 commit de846c5

File tree

180 files changed

+436
-506
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

180 files changed

+436
-506
lines changed

.github/workflows/mlir-tensorrt-ci.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ jobs:
101101
cat > run_format_check.sh <<EOF
102102
#!/bin/bash
103103
set -e
104-
python3 -m black --check --exclude='.*\.pyi' mlir-tensorrt/test/
105-
python3 -m black --check --exclude='.*\.pyi' mlir-tensorrt/python/
104+
python3 -m black --check --extend-exclude='.*\.pyi' mlir-tensorrt/compiler/
105+
python3 -m black --check --extend-exclude='.*\.pyi' mlir-tensorrt/python/
106106
git clang-format HEAD~1 --diff
107107
EOF
108108

mlir-tensorrt/CMakeLists.txt

-6
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,3 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/tensorrt/include)
262262

263263
add_subdirectory(compiler)
264264
add_subdirectory(python)
265-
266-
if(MLIR_TRT_ENABLE_TESTING)
267-
add_subdirectory(test)
268-
endif()
269-
270-
add_subdirectory(tools)

mlir-tensorrt/build_tools/cmake/ManagedLLVM.cmake

+15
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,19 @@ macro(mtrt_llvm_project)
1313

1414
set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin)
1515
set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib)
16+
17+
# The 'MLIRPythonExtensions.Core' target upstream is missing an
18+
# EMBED_CAPI_LINK_LIBS argument on 'MLIRCAPITransforms'. Instead, it's
19+
# declared on the '_mlirRegisterEverything' extension, which appears to be wrong.
20+
# TODO: fix this upstream.
21+
if(MLIR_TRT_ENABLE_PYTHON)
22+
get_property(mlir_core_pybind_capi_embed
23+
TARGET MLIRPythonExtension.Core
24+
PROPERTY mlir_python_EMBED_CAPI_LINK_LIBS)
25+
list(FIND mlir_core_pybind_capi_embed MLIRCAPITransforms item_index)
26+
if(item_index EQUAL -1)
27+
set_property(TARGET MLIRPythonExtension.Core
28+
APPEND PROPERTY mlir_python_EMBED_CAPI_LINK_LIBS MLIRCAPITransforms)
29+
endif()
30+
endif()
1631
endmacro()

mlir-tensorrt/compiler/CMakeLists.txt

+4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
set(MLIR_TENSORRT_COMPILER_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
2+
13
include_directories(${CMAKE_CURRENT_LIST_DIR}/include)
24
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
35
include_directories(${MLIR_TENSORRT_ROOT_DIR}/executor/include)
46
include_directories(${MLIR_TENSORRT_ROOT_BINARY_DIR}/executor/include)
57

68
add_subdirectory(include)
79
add_subdirectory(lib)
10+
add_subdirectory(test)
11+
add_subdirectory(tools)
812

mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h

-14
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,6 @@ typedef struct MTRT_StableHLOToExecutableOptions {
8484
void *ptr;
8585
} MTRT_StableHLOToExecutableOptions;
8686

87-
/// A callback that allows the user to customize the metadata set for layers
88-
/// corresponding to each MLIR operation. The callback should invoke the
89-
/// provided append function in order to manipulate the result string.
90-
typedef void (*MTRT_MetadataCallback)(MlirOperation op,
91-
MlirStringCallback append,
92-
void *appendCtx, void *userData);
93-
9487
MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsCreate(
9588
MTRT_CompilerClient client, MTRT_StableHLOToExecutableOptions *options,
9689
int32_t tensorRTBuilderOptLevel, bool tensorRTStronglyTyped);
@@ -108,13 +101,6 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsSetDebugOptions(
108101
const char **debugTypes, size_t debugTypeSizes,
109102
const char *dumpIrTreeDir = nullptr, const char *dumpTensorRTDir = nullptr);
110103

111-
/// Sets the layer metadata callback. The `userData` argument is passed along
112-
/// to the callback when it is invoked.
113-
MLIR_CAPI_EXPORTED MTRT_Status
114-
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
115-
MTRT_StableHLOToExecutableOptions options, MTRT_MetadataCallback callback,
116-
void *userData);
117-
118104
MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsDestroy(
119105
MTRT_StableHLOToExecutableOptions options);
120106

mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Registration/RegisterAllDialects.h

+5-2
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,13 @@ extern "C" {
3232

3333
/// Add all the dialects used by MLIR-TensorRT to the registry.
3434
MLIR_CAPI_EXPORTED void
35-
mlirTensorRTRegisterAllDialects(MlirDialectRegistry registry);
35+
mtrtCompilerRegisterDialects(MlirDialectRegistry registry);
3636

3737
/// Register all the compiler passes used by MLIR-TensorRT.
38-
MLIR_CAPI_EXPORTED void mlirTensorRTRegisterAllPasses();
38+
MLIR_CAPI_EXPORTED void mtrtCompilerRegisterPasses();
39+
40+
/// Register all the compiler task types (pass manager types).
41+
MLIR_CAPI_EXPORTED void mtrtCompilerRegisterTasks();
3942

4043
#ifdef __cplusplus
4144
}

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h

+15-21
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,13 @@ namespace mlirtrt::compiler {
4949
// StableHLOToExecutableOptions
5050
//===----------------------------------------------------------------------===//
5151

52-
class StableHloToExecutableTask;
52+
class StablehloToExecutableTask;
5353

54-
struct StableHLOToExecutableOptions
54+
struct StablehloToExecutableOptions
5555
: public mlir::OptionsBundle<DebugOptions, ExecutorOptions, DeviceOptions> {
5656
/// Initializes the options. The extensions in the provided registry
5757
/// must be extensions for the StableHloToExecutable task.
58-
StableHLOToExecutableOptions(TaskExtensionRegistry extensions);
59-
60-
/// Return the hash of the options. Returns `nullopt` when the TensorRT
61-
/// layer metadata callback is set since that can't be reliably hashed.
62-
std::optional<llvm::hash_code> getHash() const override;
58+
StablehloToExecutableOptions(TaskExtensionRegistry extensions);
6359

6460
/// Whether to disallow host tensors in TensorRT clusters.
6561
bool disallowHostTensorsInTensorRTClusters = false;
@@ -71,18 +67,16 @@ struct StableHLOToExecutableOptions
7167
/// Entrypoint function name.
7268
std::string entrypoint = "main";
7369

74-
std::function<std::string(mlir::Operation *)> layerMetadataCallback{nullptr};
75-
7670
/// Base class for extensions associated with StableHloToExecutableTask.
7771
class ExtensionBase : public TaskExtensionBase {
7872
public:
7973
ExtensionBase(mlir::TypeID typeID)
8074
: TaskExtensionBase(typeID,
81-
mlir::TypeID::get<StableHloToExecutableTask>()) {}
75+
mlir::TypeID::get<StablehloToExecutableTask>()) {}
8276

8377
static bool classof(const TaskExtensionBase *extension) {
8478
return extension->getTaskID() ==
85-
mlir::TypeID::get<StableHloToExecutableTask>();
79+
mlir::TypeID::get<StablehloToExecutableTask>();
8680
}
8781

8882
enum class Phase {
@@ -98,7 +92,7 @@ struct StableHLOToExecutableOptions
9892
/// relative to each other (yet).
9993
virtual void
10094
populatePasses(mlir::OpPassManager &pm, Phase phase,
101-
const StableHLOToExecutableOptions &options) const = 0;
95+
const StablehloToExecutableOptions &options) const = 0;
10296
};
10397

10498
/// A StableHLOToExecutableOptions::Extension is an extension that must
@@ -120,39 +114,39 @@ struct StableHLOToExecutableOptions
120114
/// A StableHloToExecutableTask is a concrete CompilationTask (PassManager) that
121115
/// accepts StableHLO input IR and lowers it down to Executor IR which can be
122116
/// translated into a MLIR-TensorRT executable.
123-
class StableHloToExecutableTask
124-
: public CompilationTask<StableHloToExecutableTask,
125-
StableHLOToExecutableOptions> {
117+
class StablehloToExecutableTask
118+
: public CompilationTask<StablehloToExecutableTask,
119+
StablehloToExecutableOptions> {
126120
public:
127121
using Base::Base;
128122

129123
/// Build the clustering pipeline that occurs on Stablehlo Ops.
130124
static void
131125
buildStablehloClusteringPipeline(mlir::OpPassManager &pm,
132-
const StableHLOToExecutableOptions &options);
126+
const StablehloToExecutableOptions &options);
133127

134128
/// Build the pipeline (bufferization and lowering) that runs after
135129
/// clustering.
136130
static void
137131
buildPostClusteringPipeline(mlir::OpPassManager &pm,
138-
const StableHLOToExecutableOptions &options);
132+
const StablehloToExecutableOptions &options);
139133

140134
static void populatePassManager(mlir::PassManager &pm,
141-
const StableHLOToExecutableOptions &options);
135+
const StablehloToExecutableOptions &options);
142136

143137
/// Compile a StableHLO module into a MLIR-TensorRT Runtime executable.
144138
/// This is the "functional" entrypoint that will allocate a new PassManager
145139
/// for a single run.
146140
static mlirtrt::StatusOr<std::unique_ptr<runtime::Executable>>
147141
compileStableHLOToExecutable(mlir::ModuleOp module,
148-
const StableHLOToExecutableOptions &options);
142+
const StablehloToExecutableOptions &options);
149143

150144
/// Compile a StableHLO module into a MLIR-TensorRT Runtime executable.
151145
/// This is the "functional" entrypoint that will allocate a new PassManager
152146
/// for a single run.
153147
static mlirtrt::StatusOr<std::unique_ptr<runtime::Executable>>
154148
compileStableHLOToExecutable(CompilerClient &client, mlir::ModuleOp module,
155-
const StableHLOToExecutableOptions &options);
149+
const StablehloToExecutableOptions &options);
156150
};
157151

158152
/// Register the task/options with the client's registry.
@@ -175,7 +169,7 @@ void registerStablehloClusteringPipelines();
175169

176170
} // namespace mlirtrt::compiler
177171

178-
MLIR_DECLARE_EXPLICIT_TYPE_ID(mlirtrt::compiler::StableHloToExecutableTask)
172+
MLIR_DECLARE_EXPLICIT_TYPE_ID(mlirtrt::compiler::StablehloToExecutableTask)
179173

180174
#endif // MLIR_TRT_ENABLE_HLO
181175
#endif // MLIR_TENSORRT_COMPILER_STABLEHLOTOEXECUTABLE

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/TensorRTExtension/TensorRTExtension.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ namespace mlirtrt::compiler {
3434
//===----------------------------------------------------------------------===//
3535

3636
class StableHLOToExecutableTensorRTExtension
37-
: public StableHLOToExecutableOptions::Extension<
37+
: public StablehloToExecutableOptions::Extension<
3838
StableHLOToExecutableTensorRTExtension> {
3939
public:
4040
StableHLOToExecutableTensorRTExtension();
@@ -45,7 +45,7 @@ class StableHLOToExecutableTensorRTExtension
4545
/// It is not guarunteed the order in which different extensions are run
4646
/// relative to each other (yet).
4747
void populatePasses(mlir::OpPassManager &pm, Phase phase,
48-
const StableHLOToExecutableOptions &options) const final;
48+
const StablehloToExecutableOptions &options) const final;
4949

5050
/// Allows the extension to hook into the option parsing infrastructure.
5151
void addToOptions(mlir::OptionsContext &context) final {

mlir-tensorrt/compiler/include/mlir-tensorrt/Registration/RegisterMlirTensorRtPasses.h

-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ inline void registerAllMlirTensorRtPasses() {
5353
mlir::registerConvertPDLToPDLInterp();
5454

5555
#ifdef MLIR_TRT_ENABLE_HLO
56-
mlirtrt::compiler::registerStableHloToExecutableTask();
5756
mlirtrt::compiler::registerStablehloClusteringPipelines();
5857
registerStableHloInputPipelines();
5958
stablehlo_ext::registerStableHloExtPasses();

mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp

+10-36
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ using namespace mlir;
4747
#endif
4848
DEFINE_C_API_PTR_METHODS(MTRT_CompilerClient, CompilerClient)
4949
DEFINE_C_API_PTR_METHODS(MTRT_StableHLOToExecutableOptions,
50-
StableHLOToExecutableOptions)
50+
StablehloToExecutableOptions)
5151
DEFINE_C_API_PTR_METHODS(MTRT_OptionsContext, OptionsContext)
5252
#if defined(__GNUC__) || defined(__clang__)
5353
#pragma GCC diagnostic pop
@@ -84,7 +84,7 @@ MTRT_Status mtrtCompilerClientCreate(MlirContext context,
8484
ctx->getOrLoadDialect<mlir::plan::PlanDialect>();
8585
assert(planDialect && "expected loaded PlanDialect");
8686
if (failed(planDialect->extensionConstructors.addCheckedExtensionConstructor<
87-
compiler::StableHloToExecutableTask,
87+
compiler::StablehloToExecutableTask,
8888
compiler::StableHLOToExecutableTensorRTExtension>()))
8989
emitWarning(mlir::UnknownLoc::get(ctx))
9090
<< "ignoring duplicate extension load request; TensorRTExtension is "
@@ -156,7 +156,7 @@ MTRT_Status mtrtStableHloToExecutableOptionsCreate(
156156
context->getLoadedDialect<mlir::plan::PlanDialect>();
157157
compiler::TaskExtensionRegistry extensions =
158158
planDialect->extensionConstructors
159-
.getExtensionRegistryForTask<compiler::StableHloToExecutableTask>();
159+
.getExtensionRegistryForTask<compiler::StablehloToExecutableTask>();
160160

161161
// Check that default extension set is loaded and set options on the TRT
162162
// extension.
@@ -168,7 +168,7 @@ MTRT_Status mtrtStableHloToExecutableOptionsCreate(
168168
trtExtension->setOptions(translationOpts);
169169

170170
auto result =
171-
std::make_unique<StableHLOToExecutableOptions>(std::move(extensions));
171+
std::make_unique<StablehloToExecutableOptions>(std::move(extensions));
172172

173173
llvm::Error finalizeStatus = result->finalize();
174174

@@ -194,7 +194,7 @@ MTRT_Status mtrtStableHloToExecutableOptionsCreateFromArgs(
194194
context->getLoadedDialect<mlir::plan::PlanDialect>();
195195
compiler::TaskExtensionRegistry extensions =
196196
planDialect->extensionConstructors
197-
.getExtensionRegistryForTask<compiler::StableHloToExecutableTask>();
197+
.getExtensionRegistryForTask<compiler::StablehloToExecutableTask>();
198198

199199
// Check that default extension set is loaded.
200200
assert(
@@ -203,7 +203,7 @@ MTRT_Status mtrtStableHloToExecutableOptionsCreateFromArgs(
203203
"expected valid StableHLOToExecutableTensorRTExtension");
204204

205205
auto result =
206-
std::make_unique<StableHLOToExecutableOptions>(std::move(extensions));
206+
std::make_unique<StablehloToExecutableOptions>(std::move(extensions));
207207
std::vector<llvm::StringRef> argvStrRef(argc);
208208
for (unsigned i = 0; i < argc; i++)
209209
argvStrRef[i] = llvm::StringRef(argv[i].data, argv[i].length);
@@ -234,7 +234,7 @@ MTRT_Status mtrtStableHloToExecutableOptionsSetDebugOptions(
234234
const char **debugTypes, size_t debugTypeSizes, const char *dumpIrTreeDir,
235235
const char *dumpTensorRTDir) {
236236

237-
StableHLOToExecutableOptions *cppOpts = unwrap(options);
237+
StablehloToExecutableOptions *cppOpts = unwrap(options);
238238
cppOpts->get<DebugOptions>().enableLLVMDebugFlag = enableDebugging;
239239
for (unsigned i = 0; i < debugTypeSizes; i++)
240240
cppOpts->get<DebugOptions>().llvmDebugTypes.emplace_back(debugTypes[i]);
@@ -245,35 +245,9 @@ MTRT_Status mtrtStableHloToExecutableOptionsSetDebugOptions(
245245
return mtrtStatusGetOk();
246246
}
247247

248-
MTRT_Status
249-
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
250-
MTRT_StableHLOToExecutableOptions options, MTRT_MetadataCallback callback,
251-
void *userData) {
252-
StableHLOToExecutableOptions *cppOpts = unwrap(options);
253-
254-
// Construct the append callback which we will pass to the callback provided
255-
// by the user. We do it this way to avoid needing a string construct in the C
256-
// API.
257-
auto appendFunc = [](MlirStringRef str, void *appendCtx) {
258-
std::string &accum = *reinterpret_cast<std::string *>(appendCtx);
259-
accum += std::string(str.data, str.length);
260-
};
261-
262-
// Capturing by reference here will cause `callback` to point to the wrong
263-
// place at the time this callback is invoked.
264-
cppOpts->layerMetadataCallback = [=](Operation *op) {
265-
std::string accum;
266-
void *appendCtx = reinterpret_cast<void *>(&accum);
267-
callback(wrap(op), appendFunc, appendCtx, userData);
268-
return accum;
269-
};
270-
271-
return mtrtStatusGetOk();
272-
}
273-
274248
MTRT_Status mtrtStableHloToExecutableOptionsDestroy(
275249
MTRT_StableHLOToExecutableOptions options) {
276-
delete reinterpret_cast<StableHLOToExecutableOptions *>(options.ptr);
250+
delete reinterpret_cast<StablehloToExecutableOptions *>(options.ptr);
277251
return mtrtStatusGetOk();
278252
}
279253

@@ -288,7 +262,7 @@ mtrtStableHloPipelineGetCached(MTRT_CompilerClient client,
288262

289263
mlir::PassManager *runner{};
290264
if (unwrap(options)->getHash()) {
291-
runner = &unwrap(client)->getOrCreatePassManager<StableHloToExecutableTask>(
265+
runner = &unwrap(client)->getOrCreatePassManager<StablehloToExecutableTask>(
292266
*unwrap(options));
293267
result->ptr = runner;
294268
return mtrtStatusGetOk();
@@ -340,7 +314,7 @@ MTRT_Status mtrtCompilerStableHLOToExecutable(
340314
"StableHLO-to-Executable compilation expects a ModuleOp");
341315

342316
StatusOr<std::unique_ptr<mlirtrt::runtime::Executable>> exe =
343-
compiler::StableHloToExecutableTask::compileStableHLOToExecutable(
317+
compiler::StablehloToExecutableTask::compileStableHLOToExecutable(
344318
*unwrap(client), moduleOp, *unwrap(stableHloToExecutableOptions));
345319
if (!exe.isOk())
346320
return mtrtStatusCreate(MTRT_StatusCode::MTRT_StatusCode_InternalError,

mlir-tensorrt/compiler/lib/CAPI/Compiler/Registration/CMakeLists.txt

-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,5 @@ add_mlir_tensorrt_public_c_api_library(MLIRTensorRTCAPIRegisterAllDialects
55
LINK_LIBS PUBLIC
66
MLIRTensorRTRegistration
77
MLIRCAPIIR
8-
MLIRCAPITransforms
98
MLIRTensorRTCompilerStableHloToExecutable
109
)

mlir-tensorrt/compiler/lib/CAPI/Compiler/Registration/RegisterAllDialects.cpp

+7-2
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,19 @@
2323
//===----------------------------------------------------------------------===//
2424

2525
#include "mlir-tensorrt-c/Compiler/Registration/RegisterAllDialects.h"
26+
#include "mlir-tensorrt/Compiler/StableHloToExecutable.h"
2627
#include "mlir-tensorrt/Registration/RegisterMlirTensorRtDialects.h"
2728
#include "mlir-tensorrt/Registration/RegisterMlirTensorRtPasses.h"
2829
#include "mlir/CAPI/IR.h"
2930

30-
void mlirTensorRTRegisterAllDialects(MlirDialectRegistry registry) {
31+
void mtrtCompilerRegisterDialects(MlirDialectRegistry registry) {
3132
mlir::registerAllMlirTensorRtDialects(*unwrap(registry));
3233
}
3334

34-
void mlirTensorRTRegisterAllPasses() {
35+
void mtrtCompilerRegisterPasses() {
3536
mlir::tensorrt::registerAllMlirTensorRtPasses();
3637
}
38+
39+
void mtrtCompilerRegisterTasks() {
40+
mlirtrt::compiler::registerStableHloToExecutableTask();
41+
}

0 commit comments

Comments
 (0)