Skip to content

Commit 677a743

Browse files
Add Math dialect to EmitC conversion support (#729)
- Add MathToEmitC pass and ArithExpandOps pass to StablehloToExecutable - Add Math Log conversion pattern for f32 and f64 types - Add cmath include for log/logf functions - Add test for Math Log to EmitC conversion GitOrigin-RevId: a6a7f80125e3b92981ca40194181f7ff923d09d2
1 parent e3a2350 commit 677a743

6 files changed

Lines changed: 70 additions & 2 deletions

File tree

mlir-tensorrt/compiler/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,14 @@ mtrt_add_project_targets(MLIRTensorRT
3939
MLIRAsyncDialect
4040
MLIRBufferizationTransforms
4141
MLIRControlFlowTransforms
42+
MLIRDebug
4243
MLIRFuncInlinerExtension
4344
MLIRNVVMTarget
4445
MLIRPtrDialect
46+
MLIRTargetCpp
4547
MLIRTargetLLVM
4648
MLIRTensorInferTypeOpInterfaceImpl
4749
MLIRTensorTransformOps
48-
MLIRTargetCpp
49-
MLIRDebug
5050
MLIRTransformDialectTransforms
5151
)
5252

mlir-tensorrt/compiler/lib/Compiler/StablehloToExecutable/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@ add_mlir_tensorrt_library(MLIRTensorRTCompilerStableHloToExecutable
4141
MLIRTensorRTTransformsUnrollForLoops
4242
StablehloLinalgTransforms
4343
MLIR_LIBS PUBLIC
44+
MLIRArithTransforms
4445
MLIREmitCTransforms
4546
MLIRIR
4647
MLIRLLVMIRTransforms
48+
MLIRMathToEmitC
4749
)
4850

mlir-tensorrt/compiler/lib/Compiler/StablehloToExecutable/StablehloToExecutable.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,11 @@
3737
#include "mlir-tensorrt/Transforms/Passes.h"
3838
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
3939
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
40+
#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"
41+
#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
4042
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
4143
#include "mlir/Dialect/Affine/Passes.h"
44+
#include "mlir/Dialect/Arith/Transforms/Passes.h"
4245
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
4346
#include "mlir/Dialect/EmitC/Transforms/Passes.h"
4447
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -241,6 +244,10 @@ void StablehloToExecutableTask::populatePassManager() {
241244
// For EmitC, just run Host-to-EmitC followed
242245
// by cleanup and expression forming.
243246
if (hostTarget == HostTarget::EmitC) {
247+
pm.addNestedPass<func::FuncOp>(mlir::createConvertMathToEmitC());
248+
pm.addNestedPass<func::FuncOp>(arith::createArithExpandOpsPass());
249+
pm.addNestedPass<func::FuncOp>(createCSEPass());
250+
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
244251
pm.addPass(createConvertHostToEmitCPass({options.artifactsDirectory}));
245252
addCleanupPasses(pm);
246253
// The EmitC "form-expressions" pass combines operations into

mlir-tensorrt/compiler/lib/Conversion/HostToEmitC/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@ add_mlir_tensorrt_library(MLIRTensorRTHostToEmitC
1616
MLIRSCFToEmitC
1717
MLIRFuncToEmitC
1818
MLIRFuncTransforms
19+
MLIRMathDialect
1920
)

mlir-tensorrt/compiler/lib/Conversion/HostToEmitC/HostToEmitC.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "mlir/Dialect/EmitC/IR/EmitC.h"
3535
#include "mlir/Dialect/Func/IR/FuncOps.h"
3636
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
37+
#include "mlir/Dialect/Math/IR/Math.h"
3738
#include "mlir/Dialect/MemRef/IR/MemRef.h"
3839
#include "mlir/Dialect/Utils/IndexingUtils.h"
3940
#include "mlir/Interfaces/DataLayoutInterfaces.h"
@@ -1686,7 +1687,33 @@ class MemRefExtractAlignedPointerAsIndexConverter
16861687
return success();
16871688
}
16881689
};
1690+
} // namespace
1691+
1692+
//===----------------------------------------------------------------------===//
1693+
// Math conversions missing from 'math-to-emitc' pass
1694+
//===----------------------------------------------------------------------===//
1695+
1696+
namespace {
1697+
struct MathLogToEmitCPattern : public EmitCConversionPattern<math::LogOp> {
1698+
using EmitCConversionPattern::EmitCConversionPattern;
16891699

1700+
LogicalResult
1701+
matchAndRewrite(math::LogOp op, OpAdaptor adaptor,
1702+
ConversionPatternRewriter &rewriter) const override {
1703+
1704+
Value input = adaptor.getOperand();
1705+
Type inputType = input.getType();
1706+
if (!inputType.isF32() && !inputType.isF64())
1707+
return failure();
1708+
llvm::StringRef funcName = "log";
1709+
if (inputType.isF32())
1710+
funcName = "logf";
1711+
auto callOp =
1712+
createCallOpaque(rewriter, op.getLoc(), inputType, funcName, {input});
1713+
rewriter.replaceOp(op, callOp.getResult(0));
1714+
return success();
1715+
}
1716+
};
16901717
} // namespace
16911718

16921719
//===----------------------------------------------------------------------===//
@@ -1762,6 +1789,7 @@ static void populateEmitCConversionPatternsAndLegality(
17621789
CUDAStreamSyncConverter,
17631790
ExecutorPrintConverter,
17641791
ExtractStridedMetadataOpLowering,
1792+
MathLogToEmitCPattern,
17651793
MemRefAllocOpLowering,
17661794
MemrefCastOpLowering,
17671795
MemRefDimOpLowering,
@@ -1858,6 +1886,7 @@ class HostToEmitCPass
18581886
rewriter.create<emitc::IncludeOp>(moduleOp->getLoc(), "cstdio", true);
18591887
rewriter.create<emitc::IncludeOp>(moduleOp->getLoc(), "cstdint", true);
18601888
rewriter.create<emitc::IncludeOp>(moduleOp->getLoc(), "cstdlib", true);
1889+
rewriter.create<emitc::IncludeOp>(moduleOp->getLoc(), "cmath", true);
18611890
rewriter.create<emitc::IncludeOp>(moduleOp->getLoc(), "MTRTRuntime.h",
18621891
false);
18631892

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: rm -rf %t || true
2+
// RUN: mkdir -p %t
3+
// RUN: mlir-tensorrt-opt -split-input-file -convert-host-to-emitc="artifacts-dir=%t" %s | \
4+
// RUN: mlir-tensorrt-translate -split-input-file -mlir-to-cpp | FileCheck %s --check-prefix=CPP
5+
6+
// Test Math Log conversion to EmitC
7+
8+
func.func @math_log_f32(%arg0: f32) -> f32 {
9+
%0 = math.log %arg0 : f32
10+
return %0 : f32
11+
}
12+
13+
// CPP-LABEL: float math_log_f32(float v1) {
14+
// CPP-NEXT: float v2 = logf(v1);
15+
// CPP-NEXT: return v2;
16+
// CPP-NEXT: }
17+
18+
// -----
19+
20+
func.func @math_log_f64(%arg0: f64) -> f64 {
21+
%0 = math.log %arg0 : f64
22+
return %0 : f64
23+
}
24+
25+
// CPP-LABEL: double math_log_f64(double v1) {
26+
// CPP-NEXT: double v2 = log(v1);
27+
// CPP-NEXT: return v2;
28+
// CPP-NEXT: }
29+

0 commit comments

Comments
 (0)