Skip to content

Commit b53db89

Browse files
authored
[CIR] Upstream SelectOp and ShiftOp (#133405)
Since SelectOp will only generated by a future pass that transforms a TernaryOp this only includes the lowering bits. This patch also improves the testing of the existing binary operators. --------- Co-authored-by: Morris Hafner <[email protected]>
1 parent c607180 commit b53db89

File tree

8 files changed

+813
-18
lines changed

8 files changed

+813
-18
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

+38
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,44 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
357357
return create<cir::CmpOp>(loc, getBoolTy(), kind, lhs, rhs);
358358
}
359359

360+
mlir::Value createShift(mlir::Location loc, mlir::Value lhs, mlir::Value rhs,
361+
bool isShiftLeft) {
362+
return create<cir::ShiftOp>(loc, lhs.getType(), lhs, rhs, isShiftLeft);
363+
}
364+
365+
mlir::Value createShift(mlir::Location loc, mlir::Value lhs,
366+
const llvm::APInt &rhs, bool isShiftLeft) {
367+
return createShift(loc, lhs, getConstAPInt(loc, lhs.getType(), rhs),
368+
isShiftLeft);
369+
}
370+
371+
mlir::Value createShift(mlir::Location loc, mlir::Value lhs, unsigned bits,
372+
bool isShiftLeft) {
373+
auto width = mlir::dyn_cast<cir::IntType>(lhs.getType()).getWidth();
374+
auto shift = llvm::APInt(width, bits);
375+
return createShift(loc, lhs, shift, isShiftLeft);
376+
}
377+
378+
mlir::Value createShiftLeft(mlir::Location loc, mlir::Value lhs,
379+
unsigned bits) {
380+
return createShift(loc, lhs, bits, true);
381+
}
382+
383+
mlir::Value createShiftRight(mlir::Location loc, mlir::Value lhs,
384+
unsigned bits) {
385+
return createShift(loc, lhs, bits, false);
386+
}
387+
388+
mlir::Value createShiftLeft(mlir::Location loc, mlir::Value lhs,
389+
mlir::Value rhs) {
390+
return createShift(loc, lhs, rhs, true);
391+
}
392+
393+
mlir::Value createShiftRight(mlir::Location loc, mlir::Value lhs,
394+
mlir::Value rhs) {
395+
return createShift(loc, lhs, rhs, false);
396+
}
397+
360398
//
361399
// Block handling helpers
362400
// ----------------------

clang/include/clang/CIR/Dialect/IR/CIROps.td

+73
Original file line numberDiff line numberDiff line change
@@ -1173,6 +1173,79 @@ def BinOp : CIR_Op<"binop", [Pure,
11731173
let hasVerifier = 1;
11741174
}
11751175

1176+
//===----------------------------------------------------------------------===//
1177+
// ShiftOp
1178+
//===----------------------------------------------------------------------===//
1179+
1180+
def ShiftOp : CIR_Op<"shift", [Pure]> {
1181+
let summary = "Shift";
1182+
let description = [{
1183+
The `cir.shift` operation performs a bitwise shift, either to the left or to
1184+
the right, based on the first operand. The second operand specifies the
1185+
value to be shifted, and the third operand determines the number of
1186+
positions by which the shift is applied. Both the second and third operands
1187+
are required to be integers.
1188+
1189+
```mlir
1190+
%7 = cir.shift(left, %1 : !u64i, %4 : !s32i) -> !u64i
1191+
```
1192+
}];
1193+
1194+
// TODO(cir): Support vectors. CIR_IntType -> CIR_AnyIntOrVecOfInt. Also
1195+
// update the description above.
1196+
let results = (outs CIR_IntType:$result);
1197+
let arguments = (ins CIR_IntType:$value, CIR_IntType:$amount,
1198+
UnitAttr:$isShiftleft);
1199+
1200+
let assemblyFormat = [{
1201+
`(`
1202+
(`left` $isShiftleft^) : (```right`)?
1203+
`,` $value `:` type($value)
1204+
`,` $amount `:` type($amount)
1205+
`)` `->` type($result) attr-dict
1206+
}];
1207+
1208+
let hasVerifier = 1;
1209+
}
1210+
1211+
//===----------------------------------------------------------------------===//
1212+
// SelectOp
1213+
//===----------------------------------------------------------------------===//
1214+
1215+
def SelectOp : CIR_Op<"select", [Pure,
1216+
AllTypesMatch<["true_value", "false_value", "result"]>]> {
1217+
let summary = "Yield one of two values based on a boolean value";
1218+
let description = [{
1219+
The `cir.select` operation takes three operands. The first operand
1220+
`condition` is a boolean value of type `!cir.bool`. The second and the third
1221+
operand can be of any CIR types, but their types must be the same. If the
1222+
first operand is `true`, the operation yields its second operand. Otherwise,
1223+
the operation yields its third operand.
1224+
1225+
Example:
1226+
1227+
```mlir
1228+
%0 = cir.const #cir.bool<true> : !cir.bool
1229+
%1 = cir.const #cir.int<42> : !s32i
1230+
%2 = cir.const #cir.int<72> : !s32i
1231+
%3 = cir.select if %0 then %1 else %2 : (!cir.bool, !s32i, !s32i) -> !s32i
1232+
```
1233+
}];
1234+
1235+
let arguments = (ins CIR_BoolType:$condition, CIR_AnyType:$true_value,
1236+
CIR_AnyType:$false_value);
1237+
let results = (outs CIR_AnyType:$result);
1238+
1239+
let assemblyFormat = [{
1240+
`if` $condition `then` $true_value `else` $false_value
1241+
`:` `(`
1242+
qualified(type($condition)) `,`
1243+
qualified(type($true_value)) `,`
1244+
qualified(type($false_value))
1245+
`)` `->` qualified(type($result)) attr-dict
1246+
}];
1247+
}
1248+
11761249
//===----------------------------------------------------------------------===//
11771250
// GlobalOp
11781251
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -1303,8 +1303,7 @@ mlir::Value ScalarExprEmitter::emitShl(const BinOpInfo &ops) {
13031303
mlir::isa<cir::IntType>(ops.lhs.getType()))
13041304
cgf.cgm.errorNYI("sanitizers");
13051305

1306-
cgf.cgm.errorNYI("shift ops");
1307-
return {};
1306+
return builder.createShiftLeft(cgf.getLoc(ops.loc), ops.lhs, ops.rhs);
13081307
}
13091308

13101309
mlir::Value ScalarExprEmitter::emitShr(const BinOpInfo &ops) {
@@ -1328,8 +1327,7 @@ mlir::Value ScalarExprEmitter::emitShr(const BinOpInfo &ops) {
13281327

13291328
// Note that we don't need to distinguish unsigned treatment at this
13301329
// point since it will be handled later by LLVM lowering.
1331-
cgf.cgm.errorNYI("shift ops");
1332-
return {};
1330+
return builder.createShiftRight(cgf.getLoc(ops.loc), ops.lhs, ops.rhs);
13331331
}
13341332

13351333
mlir::Value ScalarExprEmitter::emitAnd(const BinOpInfo &ops) {

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

+21
Original file line numberDiff line numberDiff line change
@@ -1024,6 +1024,9 @@ mlir::OpTrait::impl::verifySameFirstOperandAndResultType(Operation *op) {
10241024
// been implemented yet.
10251025
mlir::LogicalResult cir::FuncOp::verify() { return success(); }
10261026

1027+
//===----------------------------------------------------------------------===//
1028+
// BinOp
1029+
//===----------------------------------------------------------------------===//
10271030
LogicalResult cir::BinOp::verify() {
10281031
bool noWrap = getNoUnsignedWrap() || getNoSignedWrap();
10291032
bool saturated = getSaturated();
@@ -1055,6 +1058,24 @@ LogicalResult cir::BinOp::verify() {
10551058
return mlir::success();
10561059
}
10571060

1061+
//===----------------------------------------------------------------------===//
1062+
// ShiftOp
1063+
//===----------------------------------------------------------------------===//
1064+
LogicalResult cir::ShiftOp::verify() {
1065+
mlir::Operation *op = getOperation();
1066+
mlir::Type resType = getResult().getType();
1067+
assert(!cir::MissingFeatures::vectorType());
1068+
bool isOp0Vec = false;
1069+
bool isOp1Vec = false;
1070+
if (isOp0Vec != isOp1Vec)
1071+
return emitOpError() << "input types cannot be one vector and one scalar";
1072+
if (isOp1Vec && op->getOperand(1).getType() != resType) {
1073+
return emitOpError() << "shift amount must have the type of the result "
1074+
<< "if it is vector shift";
1075+
}
1076+
return mlir::success();
1077+
}
1078+
10581079
//===----------------------------------------------------------------------===//
10591080
// UnaryOp
10601081
//===----------------------------------------------------------------------===//

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

+88
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/Dialect/DLTI/DLTI.h"
2020
#include "mlir/Dialect/Func/IR/FuncOps.h"
2121
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22+
#include "mlir/IR/BuiltinAttributes.h"
2223
#include "mlir/IR/BuiltinDialect.h"
2324
#include "mlir/IR/BuiltinOps.h"
2425
#include "mlir/IR/Types.h"
@@ -28,6 +29,7 @@
2829
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
2930
#include "mlir/Target/LLVMIR/Export.h"
3031
#include "mlir/Transforms/DialectConversion.h"
32+
#include "clang/CIR/Dialect/IR/CIRAttrs.h"
3133
#include "clang/CIR/Dialect/IR/CIRDialect.h"
3234
#include "clang/CIR/Dialect/Passes.h"
3335
#include "clang/CIR/LoweringHelpers.h"
@@ -1292,6 +1294,90 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
12921294
return mlir::success();
12931295
}
12941296

1297+
mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
1298+
cir::ShiftOp op, OpAdaptor adaptor,
1299+
mlir::ConversionPatternRewriter &rewriter) const {
1300+
auto cirAmtTy = mlir::dyn_cast<cir::IntType>(op.getAmount().getType());
1301+
auto cirValTy = mlir::dyn_cast<cir::IntType>(op.getValue().getType());
1302+
1303+
// Operands could also be vector type
1304+
assert(!cir::MissingFeatures::vectorType());
1305+
mlir::Type llvmTy = getTypeConverter()->convertType(op.getType());
1306+
mlir::Value amt = adaptor.getAmount();
1307+
mlir::Value val = adaptor.getValue();
1308+
1309+
// TODO(cir): Assert for vector types
1310+
assert((cirValTy && cirAmtTy) &&
1311+
"shift input type must be integer or vector type, otherwise NYI");
1312+
1313+
assert((cirValTy == op.getType()) && "inconsistent operands' types NYI");
1314+
1315+
// Ensure shift amount is the same type as the value. Some undefined
1316+
// behavior might occur in the casts below as per [C99 6.5.7.3].
1317+
// Vector type shift amount needs no cast as type consistency is expected to
1318+
// be already be enforced at CIRGen.
1319+
if (cirAmtTy)
1320+
amt = getLLVMIntCast(rewriter, amt, mlir::cast<mlir::IntegerType>(llvmTy),
1321+
true, cirAmtTy.getWidth(), cirValTy.getWidth());
1322+
1323+
// Lower to the proper LLVM shift operation.
1324+
if (op.getIsShiftleft()) {
1325+
rewriter.replaceOpWithNewOp<mlir::LLVM::ShlOp>(op, llvmTy, val, amt);
1326+
} else {
1327+
assert(!cir::MissingFeatures::vectorType());
1328+
bool isUnsigned = !cirValTy.isSigned();
1329+
if (isUnsigned)
1330+
rewriter.replaceOpWithNewOp<mlir::LLVM::LShrOp>(op, llvmTy, val, amt);
1331+
else
1332+
rewriter.replaceOpWithNewOp<mlir::LLVM::AShrOp>(op, llvmTy, val, amt);
1333+
}
1334+
1335+
return mlir::success();
1336+
}
1337+
1338+
mlir::LogicalResult CIRToLLVMSelectOpLowering::matchAndRewrite(
1339+
cir::SelectOp op, OpAdaptor adaptor,
1340+
mlir::ConversionPatternRewriter &rewriter) const {
1341+
auto getConstantBool = [](mlir::Value value) -> cir::BoolAttr {
1342+
auto definingOp =
1343+
mlir::dyn_cast_if_present<cir::ConstantOp>(value.getDefiningOp());
1344+
if (!definingOp)
1345+
return {};
1346+
1347+
auto constValue = mlir::dyn_cast<cir::BoolAttr>(definingOp.getValue());
1348+
if (!constValue)
1349+
return {};
1350+
1351+
return constValue;
1352+
};
1353+
1354+
// Two special cases in the LLVMIR codegen of select op:
1355+
// - select %0, %1, false => and %0, %1
1356+
// - select %0, true, %1 => or %0, %1
1357+
if (mlir::isa<cir::BoolType>(op.getTrueValue().getType())) {
1358+
cir::BoolAttr trueValue = getConstantBool(op.getTrueValue());
1359+
cir::BoolAttr falseValue = getConstantBool(op.getFalseValue());
1360+
if (falseValue && !falseValue.getValue()) {
1361+
// select %0, %1, false => and %0, %1
1362+
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, adaptor.getCondition(),
1363+
adaptor.getTrueValue());
1364+
return mlir::success();
1365+
}
1366+
if (trueValue && trueValue.getValue()) {
1367+
// select %0, true, %1 => or %0, %1
1368+
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, adaptor.getCondition(),
1369+
adaptor.getFalseValue());
1370+
return mlir::success();
1371+
}
1372+
}
1373+
1374+
mlir::Value llvmCondition = adaptor.getCondition();
1375+
rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(
1376+
op, llvmCondition, adaptor.getTrueValue(), adaptor.getFalseValue());
1377+
1378+
return mlir::success();
1379+
}
1380+
12951381
static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
12961382
mlir::DataLayout &dataLayout) {
12971383
converter.addConversion([&](cir::PointerType type) -> mlir::Type {
@@ -1465,6 +1551,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
14651551
CIRToLLVMConstantOpLowering,
14661552
CIRToLLVMFuncOpLowering,
14671553
CIRToLLVMGetGlobalOpLowering,
1554+
CIRToLLVMSelectOpLowering,
1555+
CIRToLLVMShiftOpLowering,
14681556
CIRToLLVMTrapOpLowering,
14691557
CIRToLLVMUnaryOpLowering
14701558
// clang-format on

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

+20
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,26 @@ class CIRToLLVMCmpOpLowering : public mlir::OpConversionPattern<cir::CmpOp> {
209209
mlir::ConversionPatternRewriter &) const override;
210210
};
211211

212+
class CIRToLLVMShiftOpLowering
213+
: public mlir::OpConversionPattern<cir::ShiftOp> {
214+
public:
215+
using mlir::OpConversionPattern<cir::ShiftOp>::OpConversionPattern;
216+
217+
mlir::LogicalResult
218+
matchAndRewrite(cir::ShiftOp op, OpAdaptor,
219+
mlir::ConversionPatternRewriter &) const override;
220+
};
221+
222+
class CIRToLLVMSelectOpLowering
223+
: public mlir::OpConversionPattern<cir::SelectOp> {
224+
public:
225+
using mlir::OpConversionPattern<cir::SelectOp>::OpConversionPattern;
226+
227+
mlir::LogicalResult
228+
matchAndRewrite(cir::SelectOp op, OpAdaptor,
229+
mlir::ConversionPatternRewriter &) const override;
230+
};
231+
212232
class CIRToLLVMBrOpLowering : public mlir::OpConversionPattern<cir::BrOp> {
213233
public:
214234
using mlir::OpConversionPattern<cir::BrOp>::OpConversionPattern;

0 commit comments

Comments
 (0)