Skip to content

Commit d532354

Browse files
Migrate internal changes (#12)
- Fix most non-functional style casts in order to conform with upstream deprecations
- Fix minor build/documentation issues
- Add psutil and build packages

Co-authored-by: Sagar Shelke
Signed-off-by: Christopher Bate <[email protected]>
1 parent 31115b4 commit d532354

File tree

48 files changed

+356
-337
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

48 files changed

+356
-337
lines changed

mlir-tensorrt/README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ git apply ../build_tools/llvm-project.patch
5050

5151
# Do the build
5252
cd ..
53-
./build_tools/scripts/build_mlir.sh llvm-project build/llvm
53+
./build_tools/scripts/build_mlir.sh llvm-project build/llvm-project
5454
```
5555

5656
2. Build the project and run all tests
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#!/usr/bin/env bash
2+
set -ex
3+
set -o pipefail
4+
5+
REPO_ROOT=$(pwd)
6+
BUILD_DIR="${BUILD_DIR:=${REPO_ROOT}/build/mlir-tensorrt}"
7+
8+
ENABLE_NCCL=${ENABLE_NCCL:-OFF}
9+
RUN_LONG_TESTS=${RUN_LONG_TESTS:-False}
10+
LLVM_LIT_ARGS=${LLVM_LIT_ARGS:-"-v --xunit-xml-output ${BUILD_DIR}/test-results.xml --timeout=1200 --time-tests -Drun_long_tests=${RUN_LONG_TESTS}"}
11+
DOWNLOAD_TENSORRT_VERSION=${DOWNLOAD_TENSORRT_VERSION:-10.0.0.6}
12+
ENABLE_ASAN=${ENABLE_ASAN:-OFF}
13+
14+
echo "Using DOWNLOAD_TENSORRT_VERSION=${DOWNLOAD_TENSORRT_VERSION}"
15+
echo "Using LLVM_LIT_ARGS=${LLVM_LIT_ARGS}"
16+
17+
cmake -GNinja -B "${BUILD_DIR}" -S "${REPO_ROOT}" \
18+
-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \
19+
-DMLIR_TRT_USE_LINKER=lld -DCMAKE_BUILD_TYPE=RelWithDebInfo \
20+
-DMLIR_TRT_PACKAGE_CACHE_DIR=$PWD/.cache.cpm \
21+
-DMLIR_TRT_ENABLE_PYTHON=ON \
22+
-DMLIR_TRT_ENABLE_NCCL=${ENABLE_NCCL} \
23+
-DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION="$DOWNLOAD_TENSORRT_VERSION" \
24+
-DLLVM_LIT_ARGS="${LLVM_LIT_ARGS}" \
25+
-DENABLE_ASAN="${ENABLE_ASAN}" \
26+
-DMLIR_DIR=${REPO_ROOT}/build/llvm-project/lib/cmake/mlir \
27+
-DCMAKE_PLATFORM_NO_VERSIONED_SONAME=ON
28+
29+
echo "==== Running Build ==="
30+
ninja -C ${BUILD_DIR} -k 0 check-mlir-executor
31+
ninja -C ${BUILD_DIR} -k 0 check-mlir-tensorrt
32+
ninja -C ${BUILD_DIR} -k 0 check-mlir-tensorrt-dialect

mlir-tensorrt/compiler/lib/Conversion/CUDAToExecutor/CUDAToExecutor.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,10 @@ struct CudaBlasRunGemmOpConverter
124124
SmallVector<Value> newOperands = {adaptor.getHandle(), adaptor.getStream()};
125125
newOperands.push_back(adaptor.getAlgo());
126126
auto createMemRefAndExractPtr = [&](Value oldVal, Value newVal) {
127-
auto memrefType = oldVal.getType().cast<MemRefType>();
127+
auto memrefType = cast<MemRefType>(oldVal.getType());
128128
if (!memrefType)
129129
return failure();
130-
assert(newVal.getType().isa<TableType>());
130+
assert(isa<TableType>(newVal.getType()));
131131
executor::MemRefDescriptor memref(newVal, memrefType);
132132
newOperands.push_back(memref.alignedPtr(b));
133133
return success();

mlir-tensorrt/compiler/lib/Conversion/PlanToExecutor/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ add_mlir_tensorrt_library(MLIRTensorRTPlanToExecutor
55
MLIRTensorRTExecutorDialect
66
MLIRTensorRTPlanDialect
77
MLIRTransforms
8+
MLIRSCFTransforms
89
)

mlir-tensorrt/compiler/lib/Conversion/PlanToExecutor/PlanToExecutor.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,8 @@ struct ConstantOpConverter : public OpConversionPattern<arith::ConstantOp> {
9898
LogicalResult
9999
matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
100100
ConversionPatternRewriter &rewriter) const override {
101-
auto resultType = getTypeConverter()
102-
->convertType(op.getType())
103-
.dyn_cast_or_null<RankedTensorType>();
101+
auto resultType = dyn_cast_or_null<RankedTensorType>(
102+
getTypeConverter()->convertType(op.getType()));
104103
if (!resultType)
105104
return failure();
106105

mlir-tensorrt/compiler/lib/Conversion/StablehloScalarToArith/StablehloScalarToArith.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ struct StablehloRewriteConcat
119119
matchAndRewrite(stablehlo::ConcatenateOp op, OpAdaptor adaptor,
120120
OneToNPatternRewriter &rewriter) const override {
121121
if (!llvm::all_of(op->getOperandTypes(), [](Type t) {
122-
return t.cast<RankedTensorType>().getRank() == 1;
122+
return cast<RankedTensorType>(t).getRank() == 1;
123123
}))
124124
return failure();
125125
rewriter.replaceOp(op, adaptor.getFlatOperands(),
@@ -133,13 +133,13 @@ struct StablehloRewriteConcat
133133
/// scalar `type`.
134134
static Attribute getScalarValue(RewriterBase &rewriter, Type type,
135135
int64_t idx) {
136-
if (type.isa<FloatType>())
136+
if (isa<FloatType>(type))
137137
return rewriter.getFloatAttr(type, static_cast<double>(idx));
138-
if (type.isa<IndexType>())
138+
if (isa<IndexType>(type))
139139
return rewriter.getIndexAttr(idx);
140-
if (auto integerType = type.dyn_cast<IntegerType>())
140+
if (auto integerType = dyn_cast<IntegerType>(type))
141141
return rewriter.getIntegerAttr(
142-
type, APInt(type.cast<IntegerType>().getWidth(), idx));
142+
type, APInt(cast<IntegerType>(type).getWidth(), idx));
143143
return {};
144144
}
145145

mlir-tensorrt/compiler/lib/Conversion/StablehloToScf/StablehloToScf.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ static void inlineStablehloRegionIntoSCFRegion(PatternRewriter &rewriter,
5151
static Value extractScalarFromTensorValue(OpBuilder &b, Value tensor) {
5252
Location loc = tensor.getLoc();
5353
// If ranked tensor, first collapse shape.
54-
if (tensor.getType().cast<RankedTensorType>().getRank() != 0)
54+
if (cast<RankedTensorType>(tensor.getType()).getRank() != 0)
5555
tensor = b.create<tensor::CollapseShapeOp>(
5656
loc, tensor, SmallVector<ReassociationIndices>());
5757

@@ -129,10 +129,10 @@ static scf::IfOp createNestedCases(int currentIdx, stablehlo::CaseOp op,
129129

130130
// Determine if the current index matches the case index.
131131
auto scalarType = idxValue.getType();
132-
auto shapedType = scalarType.cast<ShapedType>();
132+
auto shapedType = cast<ShapedType>(scalarType);
133133
auto constAttr = DenseElementsAttr::get(
134134
shapedType,
135-
{outerBuilder.getI32IntegerAttr(currentIdx).cast<mlir::Attribute>()});
135+
{cast<mlir::Attribute>(outerBuilder.getI32IntegerAttr(currentIdx))});
136136
Value currentIdxVal = outerBuilder.create<stablehlo::ConstantOp>(
137137
loc, idxValue.getType(), constAttr);
138138

mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/ChloToTensorRT.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ struct ConvertChloErfToTensorRT
4141
ConversionPatternRewriter &rewriter) const override {
4242
Location loc = op->getLoc();
4343
auto operand = adaptor.getOperand();
44-
auto operandType = operand.getType().cast<RankedTensorType>();
44+
auto operandType = cast<RankedTensorType>(operand.getType());
4545
Type resultType = typeConverter->convertType(op.getType());
4646
if (!resultType)
4747
return failure();
@@ -74,7 +74,7 @@ struct ConvertChloTopKOpToTensorRT
7474
matchAndRewrite(chlo::TopKOp op, OpAdaptor adaptor,
7575
ConversionPatternRewriter &rewriter) const override {
7676
auto operand = adaptor.getOperand();
77-
RankedTensorType operandType = operand.getType().cast<RankedTensorType>();
77+
RankedTensorType operandType = cast<RankedTensorType>(operand.getType());
7878

7979
int64_t rank = operandType.getRank();
8080
uint64_t axis = static_cast<uint64_t>(rank) - 1;

mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/ControlFlowOps.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ struct ConvertCaseOp : public ConvertHloOpToTensorRTPattern<stablehlo::CaseOp> {
118118
if (!isa_and_nonnull<tensorrt::IdentityOp, stablehlo::ConvertOp>(op))
119119
return false;
120120
RankedTensorType producerType =
121-
op->getOperand(0).getType().cast<RankedTensorType>();
121+
cast<RankedTensorType>(op->getOperand(0).getType());
122122
return isa_and_nonnull<tensorrt::IdentityOp, stablehlo::ConvertOp>(op) &&
123123
producerType.getElementType().isInteger(1) &&
124124
producerType.getNumElements() == 1;

0 commit comments

Comments
 (0)