Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,8 @@ def CIR_AnyScalarType : AnyTypeOf<CIR_ScalarTypes, "cir scalar type"> {

def CIR_AnyVectorType : CIR_TypeBase<"::cir::VectorType", "vector type">;

def CIR_VectorElementType : AnyTypeOf<[CIR_AnyIntOrFloatType, CIR_AnyPtrType],
"any cir integer, floating point or pointer type"
def CIR_VectorElementType : AnyTypeOf<[CIR_AnyIntOrFloatType, CIR_AnyPtrType, CIR_AnyBoolType],
"any cir integer, floating point, pointer or boolean type"
> {
let cppFunctionName = "isValidVectorTypeElementType";
}
Expand Down
87 changes: 78 additions & 9 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,8 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
#undef VISITCOMP

mlir::Value VisitBinAssign(const BinaryOperator *E);
mlir::Value emitVectorLogicalOp(const BinaryOperator *E,
cir::BinOpKind opKind);
mlir::Value VisitBinLAnd(const BinaryOperator *B);
mlir::Value VisitBinLOr(const BinaryOperator *B);
mlir::Value VisitBinComma(const BinaryOperator *E) {
Expand Down Expand Up @@ -1127,6 +1129,27 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
srcVal);
}

// Convert a vector value to a vector<bool> by testing each element for
// non-zero.
mlir::Value emitVectorToBoolConversion(mlir::Value src, mlir::Location loc) {
auto vecType = mlir::cast<cir::VectorType>(src.getType());
auto elemType = vecType.getElementType();
uint64_t numElts = vecType.getSize();

// Build a zero vector of the same type
auto zeroElemAttr = cir::IntAttr::get(elemType, 0);
llvm::SmallVector<mlir::Attribute> zeroElems(numElts, zeroElemAttr);
auto zeroVecAttr =
cir::ConstVectorAttr::get(vecType, Builder.getArrayAttr(zeroElems));
auto zeroVec = cir::ConstantOp::create(Builder, loc, vecType, zeroVecAttr);

// Perform elementwise comparison: src != 0
auto boolElemType = cir::BoolType::get(Builder.getContext());
auto boolVecType = cir::VectorType::get(boolElemType, numElts);
return cir::VecCmpOp::create(Builder, loc, boolVecType, cir::CmpOpKind::ne,
src, zeroVec);
}

/// Convert the specified expression value to a boolean (!cir.bool) truth
/// value. This is equivalent to "Val != 0".
mlir::Value emitConversionToBool(mlir::Value Src, QualType SrcType,
Expand All @@ -1137,12 +1160,17 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
return emitFloatToBoolConversion(Src, loc);

if (llvm::isa<MemberPointerType>(SrcType))
assert(0 && "not implemented");
llvm_unreachable("member pointer to bool not implemented");

if (SrcType->isIntegerType())
return emitIntToBoolConversion(Src, loc);

assert(::mlir::isa<cir::PointerType>(Src.getType()));
// Convert vector values to vector<bool>
if (SrcType->isVectorType())
return emitVectorToBoolConversion(Src, loc);

assert(::mlir::isa<cir::PointerType>(Src.getType()) &&
"expected pointer type for pointer-to-bool conversion");
return emitPointerToBoolConversion(Src, SrcType);
}

Expand Down Expand Up @@ -2757,7 +2785,12 @@ mlir::Value ScalarExprEmitter::VisitAbstractConditionalOperator(
// the select function.
if ((CGF.getLangOpts().OpenCL && condExpr->getType()->isVectorType()) ||
condExpr->getType()->isExtVectorType()) {
llvm_unreachable("NYI");
// Treat the conditional operator as an element-wise select on vectors
mlir::Value condValue = Visit(condExpr);
mlir::Value lhsValue = Visit(lhsExpr);
mlir::Value rhsValue = Visit(rhsExpr);
return cir::VecTernaryOp::create(builder, loc, condValue, lhsValue,
rhsValue);
}

if (condExpr->getType()->isVectorType() ||
Expand Down Expand Up @@ -2866,10 +2899,47 @@ mlir::Value CIRGenFunction::emitScalarPrePostIncDec(const UnaryOperator *E,
.emitScalarPrePostIncDec(E, LV, isInc, isPre);
}

// Emit elementwise vector logical operations
mlir::Value ScalarExprEmitter::emitVectorLogicalOp(const BinaryOperator *E,
cir::BinOpKind opKind) {
assert(!cir::MissingFeatures::incrementProfileCounter());
mlir::Location loc = CGF.getLoc(E->getExprLoc());
mlir::Type resTy = convertType(E->getType());

mlir::Value lhs = Visit(E->getLHS());
mlir::Value rhs = Visit(E->getRHS());

// Build zero vector of the same type
cir::ConstantOp zeroVec = Builder.getNullValue(lhs.getType(), loc);

auto vecTy = mlir::cast<cir::VectorType>(lhs.getType());
auto boolElemTy = Builder.getBoolTy();
auto boolVecTy = cir::VectorType::get(boolElemTy, vecTy.getSize());

// Compare operands to zero to produce vector<bool>
auto lhsBool = cir::VecCmpOp::create(Builder, loc, boolVecTy,
cir::CmpOpKind::ne, lhs, zeroVec);
auto rhsBool = cir::VecCmpOp::create(Builder, loc, boolVecTy,
cir::CmpOpKind::ne, rhs, zeroVec);

// Elementwise logical operation on vector<bool>
auto logicVal =
cir::BinOp::create(Builder, loc, boolVecTy, opKind, lhsBool, rhsBool);

if (resTy == boolVecTy)
return logicVal;

// Convert back to result vector type
auto resVecTy = mlir::cast<cir::VectorType>(resTy);
if (mlir::isa<cir::IntType>(resVecTy.getElementType()))
return Builder.createBoolToInt(logicVal, resVecTy);

llvm_unreachable("NYI");
}

mlir::Value ScalarExprEmitter::VisitBinLAnd(const clang::BinaryOperator *E) {
if (E->getType()->isVectorType()) {
llvm_unreachable("NYI");
}
if (E->getType()->isVectorType())
return emitVectorLogicalOp(E, cir::BinOpKind::And);

bool InstrumentRegions = CGF.CGM.getCodeGenOpts().hasProfileClangInstr();
mlir::Type ResTy = convertType(E->getType());
Expand Down Expand Up @@ -2935,9 +3005,8 @@ mlir::Value ScalarExprEmitter::VisitBinLAnd(const clang::BinaryOperator *E) {
}

mlir::Value ScalarExprEmitter::VisitBinLOr(const clang::BinaryOperator *E) {
if (E->getType()->isVectorType()) {
llvm_unreachable("NYI");
}
if (E->getType()->isVectorType())
return emitVectorLogicalOp(E, cir::BinOpKind::Or);

bool InstrumentRegions = CGF.CGM.getCodeGenOpts().hasProfileClangInstr();
mlir::Type ResTy = convertType(E->getType());
Expand Down
89 changes: 77 additions & 12 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,80 @@ mlir::Type CIRToLLVMCastOpLowering::convertTy(mlir::Type ty) const {
return getTypeConverter()->convertType(ty);
}

// Lower a bool-to-integer cast for either scalar or vector types.
// Mirrors LLVM IR semantics:
// - Same width: bitcast
// - Different width: zero-extend
mlir::LogicalResult CIRToLLVMCastOpLowering::lowerBoolToIntCast(
cir::CastOp castOp, mlir::Value srcValue, mlir::Type dstType,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Type srcType = srcValue.getType();
mlir::Type dstElemTy = dstType;

// If it's a vector, get the element type to check signedness
if (auto vt = mlir::dyn_cast<mlir::VectorType>(dstType))
dstElemTy = vt.getElementType();

bool isSigned =
mlir::isa<cir::VectorType>(castOp.getType())
? mlir::cast<cir::IntType>(elementTypeIfVector(castOp.getType()))
.isSigned()
: false;

// Scalar case: i1 -> iN
if (auto srcIntTy = mlir::dyn_cast<mlir::IntegerType>(srcType)) {
auto dstIntTy = mlir::cast<mlir::IntegerType>(dstType);
if (srcIntTy.getWidth() == dstIntTy.getWidth()) {
rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(castOp, dstType,
srcValue);
} else {
if (isSigned)
rewriter.replaceOpWithNewOp<mlir::LLVM::SExtOp>(castOp, dstType,
srcValue);
else
rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(castOp, dstType,
srcValue);
}
return mlir::success();
}

// Vector case: vector<i1> -> vector<iN>
if (auto srcVecTy = mlir::dyn_cast<mlir::VectorType>(srcType)) {
auto dstVecTy = mlir::dyn_cast<mlir::VectorType>(dstType);
if (!dstVecTy)
return rewriter.notifyMatchFailure(castOp, "Target must be vector");

// Shape check
if (srcVecTy.getShape() != dstVecTy.getShape())
return rewriter.notifyMatchFailure(castOp, "Vector shape mismatch");

auto srcElemTy =
mlir::dyn_cast<mlir::IntegerType>(srcVecTy.getElementType());
auto dstElemTy =
mlir::dyn_cast<mlir::IntegerType>(dstVecTy.getElementType());

if (!srcElemTy || !dstElemTy)
return rewriter.notifyMatchFailure(castOp, "Elements must be integers");

if (srcElemTy.getWidth() == dstElemTy.getWidth()) {
rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(castOp, dstType,
srcValue);
} else {
// If destination is signed, use SExt to get -1 for true (matches OG)
// If destination is unsigned/bool, use ZExt to get 1 for true
if (isSigned)
rewriter.replaceOpWithNewOp<mlir::LLVM::SExtOp>(castOp, dstType,
srcValue);
else
rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(castOp, dstType,
srcValue);
}
return mlir::success();
}

return rewriter.notifyMatchFailure(castOp, "Unsupported type combination");
}

mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite(
cir::CastOp castOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
Expand Down Expand Up @@ -1423,18 +1497,9 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite(
return mlir::success();
}
case cir::CastKind::bool_to_int: {
auto dstTy = mlir::cast<cir::IntType>(castOp.getType());
auto llvmSrcVal = adaptor.getSrc();
auto llvmSrcTy = mlir::cast<mlir::IntegerType>(llvmSrcVal.getType());
auto llvmDstTy =
mlir::cast<mlir::IntegerType>(getTypeConverter()->convertType(dstTy));
if (llvmSrcTy.getWidth() == llvmDstTy.getWidth())
rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(castOp, llvmDstTy,
llvmSrcVal);
else
rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(castOp, llvmDstTy,
llvmSrcVal);
return mlir::success();
mlir::Value srcValue = adaptor.getSrc();
mlir::Type dstType = getTypeConverter()->convertType(castOp.getType());
return lowerBoolToIntCast(castOp, srcValue, dstType, rewriter);
}
case cir::CastKind::bool_to_float: {
auto dstTy = castOp.getType();
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,11 @@ class CIRToLLVMCastOpLowering : public mlir::OpConversionPattern<cir::CastOp> {

mlir::Type convertTy(mlir::Type ty) const;

mlir::LogicalResult
lowerBoolToIntCast(cir::CastOp castOp, mlir::Value srcValue,
mlir::Type dstType,
mlir::ConversionPatternRewriter &rewriter) const;

public:
CIRToLLVMCastOpLowering(const mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context,
Expand Down
39 changes: 39 additions & 0 deletions clang/test/CIR/CodeGen/OpenCL/vec_logic.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: %clang -cc1 -triple spirv64-unknown-unknown -cl-std=CL2.0 -finclude-default-header -O0 -emit-cir -fclangir -o - %s | FileCheck %s --check-prefix=CIR
// RUN: %clang -cc1 -triple spirv64-unknown-unknown -cl-std=CL2.0 -finclude-default-header -O0 -emit-llvm -fclangir -o - %s | FileCheck %s --check-prefix=LLVM
// RUN: %clang -cc1 -triple spirv64-unknown-unknown -cl-std=CL2.0 -finclude-default-header -O0 -emit-llvm -o - %s | FileCheck %s --check-prefix=OG-LLVM

kernel void test(char4 in1, char4 in2, local char4 *out)
{
*out = (in1 == (char4)3 && (in1 == (char4)5 || in2 == (char4)7))
? in1 : in2;
}


// CIR: [[ZERO:%.*]] = cir.const #cir.zero : !cir.vector<!s8i x 4>
// CIR: [[CMP1:%.*]] = cir.vec.cmp(ne, %{{.*}}, [[ZERO]]) : !cir.vector<!s8i x 4>, !cir.vector<!cir.bool x 4>
// CIR: [[CMP2:%.*]] = cir.vec.cmp(ne, %{{.*}}, [[ZERO]]) : !cir.vector<!s8i x 4>, !cir.vector<!cir.bool x 4>
// CIR: [[OR:%.*]] = cir.binop(or, [[CMP1]], [[CMP2]]) : !cir.vector<!cir.bool x 4>
// CIR: [[CAST1:%.*]] = cir.cast bool_to_int [[OR]] : !cir.vector<!cir.bool x 4> -> !cir.vector<!s8i x 4>
// CIR: [[ZERO2:%.*]] = cir.const #cir.zero : !cir.vector<!s8i x 4>
// CIR: [[CMP3:%.*]] = cir.vec.cmp(ne, %{{.*}}, [[ZERO2]]) : !cir.vector<!s8i x 4>, !cir.vector<!cir.bool x 4>
// CIR: [[CMP4:%.*]] = cir.vec.cmp(ne, [[CAST1]], [[ZERO2]]) : !cir.vector<!s8i x 4>, !cir.vector<!cir.bool x 4>
// CIR: [[AND:%.*]] = cir.binop(and, [[CMP3]], [[CMP4]]) : !cir.vector<!cir.bool x 4>
// CIR: cir.cast bool_to_int [[AND]] : !cir.vector<!cir.bool x 4> -> !cir.vector<!s8i x 4>

// LLVM: [[CMP1:%.*]] = icmp ne <4 x i8> %{{.*}}, zeroinitializer
// LLVM: [[CMP2:%.*]] = icmp ne <4 x i8> %{{.*}}, zeroinitializer
// LLVM: [[OR:%.*]] = or <4 x i1> [[CMP1]], [[CMP2]]
// LLVM: [[SEXT:%.*]] = sext <4 x i1> [[OR]] to <4 x i8>
// LLVM: [[CMP3:%.*]] = icmp ne <4 x i8> %{{.*}}, zeroinitializer
// LLVM: [[CMP4:%.*]] = icmp ne <4 x i8> [[SEXT]], zeroinitializer
// LLVM: [[AND:%.*]] = and <4 x i1> [[CMP3]], [[CMP4]]
// LLVM: sext <4 x i1> [[AND]] to <4 x i8>

// OG-LLVM: [[CMP1:%.*]] = icmp ne <4 x i8> %{{.*}}, zeroinitializer
// OG-LLVM: [[CMP2:%.*]] = icmp ne <4 x i8> %{{.*}}, zeroinitializer
// OG-LLVM: [[OR:%.*]] = or <4 x i1> [[CMP1]], [[CMP2]]
// OG-LLVM: [[SEXT:%.*]] = sext <4 x i1> [[OR]] to <4 x i8>
// OG-LLVM: [[CMP3:%.*]] = icmp ne <4 x i8> %{{.*}}, zeroinitializer
// OG-LLVM: [[CMP4:%.*]] = icmp ne <4 x i8> [[SEXT]], zeroinitializer
// OG-LLVM: [[AND:%.*]] = and <4 x i1> [[CMP3]], [[CMP4]]
// OG-LLVM: sext <4 x i1> [[AND]] to <4 x i8>