Skip to content

Commit b7e4260

Browse files
frasercrmckfhahn
authored andcommitted
[clang] Restrict the use of scalar types in vector builtins (llvm#119423)
This commit restricts the use of scalar types in vector math builtins, particularly the `__builtin_elementwise_*` builtins. Previously, small scalar integer types would be promoted to `int`, as per the usual conversions. This would silently do the wrong thing for certain operations, such as `add_sat`, `popcount`, `bitreverse`, and others. Similarly, since unsigned integer types were promoted to `int`, something like `add_sat(unsigned char, unsigned char)` would perform a *signed* operation. With this patch, promotable scalar integer types are not promoted to int, and are kept intact. If any of the types differ in the binary and ternary builtins, an error is issued. Similarly an error is issued if builtins are supplied integer types of different signs. Mixing enums of different types in binary/ternary builtins now consistently raises an error in all language modes. This brings the behaviour surrounding scalar types more in line with that of vector types. No change is made to vector types, which are both not promoted and whose element types must match. Fixes llvm#84047. RFC: https://discourse.llvm.org/t/rfc-change-behaviour-of-elementwise-builtins-on-scalar-integer-types/83725
1 parent b0d6e64 commit b7e4260

File tree

10 files changed

+322
-141
lines changed

10 files changed

+322
-141
lines changed

clang/docs/LanguageExtensions.rst

+29-36
Original file line numberDiff line numberDiff line change
@@ -647,42 +647,35 @@ elementwise to the input.
647647

648648
Unless specified otherwise operation(±0) = ±0 and operation(±infinity) = ±infinity
649649

650-
=========================================== ================================================================ =========================================
651-
Name Operation Supported element types
652-
=========================================== ================================================================ =========================================
653-
T __builtin_elementwise_abs(T x) return the absolute value of a number x; the absolute value of signed integer and floating point types
654-
the most negative integer remains the most negative integer
655-
T __builtin_elementwise_fma(T x, T y, T z) fused multiply add, (x * y) + z. floating point types
656-
T __builtin_elementwise_ceil(T x) return the smallest integral value greater than or equal to x floating point types
657-
T __builtin_elementwise_sin(T x) return the sine of x interpreted as an angle in radians floating point types
658-
T __builtin_elementwise_cos(T x) return the cosine of x interpreted as an angle in radians floating point types
659-
T __builtin_elementwise_tan(T x) return the tangent of x interpreted as an angle in radians floating point types
660-
T __builtin_elementwise_asin(T x) return the arcsine of x interpreted as an angle in radians floating point types
661-
T __builtin_elementwise_acos(T x) return the arccosine of x interpreted as an angle in radians floating point types
662-
T __builtin_elementwise_atan(T x) return the arctangent of x interpreted as an angle in radians floating point types
663-
T __builtin_elementwise_sinh(T x) return the hyperbolic sine of angle x in radians floating point types
664-
T __builtin_elementwise_cosh(T x) return the hyperbolic cosine of angle x in radians floating point types
665-
T __builtin_elementwise_tanh(T x) return the hyperbolic tangent of angle x in radians floating point types
666-
T __builtin_elementwise_floor(T x) return the largest integral value less than or equal to x floating point types
667-
T __builtin_elementwise_log(T x) return the natural logarithm of x floating point types
668-
T __builtin_elementwise_log2(T x) return the base 2 logarithm of x floating point types
669-
T __builtin_elementwise_log10(T x) return the base 10 logarithm of x floating point types
670-
T __builtin_elementwise_pow(T x, T y) return x raised to the power of y floating point types
671-
T __builtin_elementwise_bitreverse(T x) return the integer represented after reversing the bits of x integer types
672-
T __builtin_elementwise_exp(T x) returns the base-e exponential, e^x, of the specified value floating point types
673-
T __builtin_elementwise_exp2(T x) returns the base-2 exponential, 2^x, of the specified value floating point types
674-
675-
T __builtin_elementwise_sqrt(T x) return the square root of a floating-point number floating point types
676-
T __builtin_elementwise_roundeven(T x) round x to the nearest integer value in floating point format, floating point types
677-
rounding halfway cases to even (that is, to the nearest value
678-
that is an even integer), regardless of the current rounding
679-
direction.
680-
T __builtin_elementwise_round(T x) round x to the nearest integer value in floating point format, floating point types
681-
rounding halfway cases away from zero, regardless of the
682-
current rounding direction. May raise floating-point
683-
exceptions.
684-
T __builtin_elementwise_trunc(T x) return the integral value nearest to but no larger in floating point types
685-
magnitude than x
650+
No implicit promotion of integer types takes place. The mixing of integer types
651+
of different sizes and signs is forbidden in binary and ternary builtins.
652+
653+
============================================== ====================================================================== =========================================
654+
Name Operation Supported element types
655+
============================================== ====================================================================== =========================================
656+
T __builtin_elementwise_abs(T x) return the absolute value of a number x; the absolute value of signed integer and floating point types
657+
the most negative integer remains the most negative integer
658+
T __builtin_elementwise_fma(T x, T y, T z) fused multiply add, (x * y) + z. floating point types
659+
T __builtin_elementwise_ceil(T x) return the smallest integral value greater than or equal to x floating point types
660+
T __builtin_elementwise_sin(T x) return the sine of x interpreted as an angle in radians floating point types
661+
T __builtin_elementwise_cos(T x) return the cosine of x interpreted as an angle in radians floating point types
662+
T __builtin_elementwise_tan(T x) return the tangent of x interpreted as an angle in radians floating point types
663+
T __builtin_elementwise_asin(T x) return the arcsine of x interpreted as an angle in radians floating point types
664+
T __builtin_elementwise_acos(T x) return the arccosine of x interpreted as an angle in radians floating point types
665+
T __builtin_elementwise_atan(T x) return the arctangent of x interpreted as an angle in radians floating point types
666+
T __builtin_elementwise_atan2(T y, T x) return the arctangent of y/x floating point types
667+
T __builtin_elementwise_sinh(T x) return the hyperbolic sine of angle x in radians floating point types
668+
T __builtin_elementwise_cosh(T x) return the hyperbolic cosine of angle x in radians floating point types
669+
T __builtin_elementwise_tanh(T x) return the hyperbolic tangent of angle x in radians floating point types
670+
T __builtin_elementwise_floor(T x) return the largest integral value less than or equal to x floating point types
671+
T __builtin_elementwise_log(T x) return the natural logarithm of x floating point types
672+
T __builtin_elementwise_log2(T x) return the base 2 logarithm of x floating point types
673+
T __builtin_elementwise_log10(T x) return the base 10 logarithm of x floating point types
674+
T __builtin_elementwise_popcount(T x) return the number of 1 bits in x integer types
675+
T __builtin_elementwise_pow(T x, T y) return x raised to the power of y floating point types
676+
T __builtin_elementwise_bitreverse(T x) return the integer represented after reversing the bits of x integer types
677+
T __builtin_elementwise_exp(T x) returns the base-e exponential, e^x, of the specified value floating point types
678+
T __builtin_elementwise_exp2(T x) returns the base-2 exponential, 2^x, of the specified value floating point types
686679

687680
T __builtin_elementwise_nearbyint(T x) round x to the nearest integer value in floating point format, floating point types
688681
rounding according to the current rounding direction.

clang/include/clang/Sema/Sema.h

+18-5
Original file line numberDiff line numberDiff line change
@@ -2551,7 +2551,9 @@ class Sema final : public SemaBase {
25512551
bool CheckFunctionCall(FunctionDecl *FDecl, CallExpr *TheCall,
25522552
const FunctionProtoType *Proto);
25532553

2554-
bool BuiltinVectorMath(CallExpr *TheCall, QualType &Res);
2554+
/// \param FPOnly restricts the arguments to floating-point types.
2555+
std::optional<QualType> BuiltinVectorMath(CallExpr *TheCall,
2556+
bool FPOnly = false);
25552557
bool BuiltinVectorToScalarMath(CallExpr *TheCall);
25562558

25572559
/// Handles the checks for format strings, non-POD arguments to vararg
@@ -2762,8 +2764,9 @@ class Sema final : public SemaBase {
27622764
ExprResult AtomicOpsOverloaded(ExprResult TheCallResult,
27632765
AtomicExpr::AtomicOp Op);
27642766

2765-
bool BuiltinElementwiseMath(CallExpr *TheCall);
2766-
bool PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall);
2767+
bool BuiltinElementwiseMath(CallExpr *TheCall, bool FPOnly);
2768+
bool PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall,
2769+
bool FPOnly = false);
27672770

27682771
bool BuiltinNonDeterministicValue(CallExpr *TheCall);
27692772

@@ -7751,10 +7754,15 @@ class Sema final : public SemaBase {
77517754
return K == ConditionKind::Switch ? Context.IntTy : Context.BoolTy;
77527755
}
77537756

7754-
// UsualUnaryConversions - promotes integers (C99 6.3.1.1p2) and converts
7755-
// functions and arrays to their respective pointers (C99 6.3.2.1).
7757+
// UsualUnaryConversions - promotes integers (C99 6.3.1.1p2), converts
7758+
// functions and arrays to their respective pointers (C99 6.3.2.1), and
7759+
// promotes floating-piont types according to the language semantics.
77567760
ExprResult UsualUnaryConversions(Expr *E);
77577761

7762+
// UsualUnaryFPConversions - promotes floating-point types according to the
7763+
// current language semantics.
7764+
ExprResult UsualUnaryFPConversions(Expr *E);
7765+
77587766
/// CallExprUnaryConversions - a special case of an unary conversion
77597767
/// performed on a function designator of a call expression.
77607768
ExprResult CallExprUnaryConversions(Expr *E);
@@ -7829,6 +7837,11 @@ class Sema final : public SemaBase {
78297837
ExprResult DefaultVariadicArgumentPromotion(Expr *E, VariadicCallType CT,
78307838
FunctionDecl *FDecl);
78317839

7840+
// Check that the usual arithmetic conversions can be performed on this pair
7841+
// of expressions that might be of enumeration type.
7842+
void checkEnumArithmeticConversions(Expr *LHS, Expr *RHS, SourceLocation Loc,
7843+
Sema::ArithConvKind ACK);
7844+
78327845
// UsualArithmeticConversions - performs the UsualUnaryConversions on it's
78337846
// operands and then handles various conversions that are common to binary
78347847
// operators (C99 6.3.1.8). If both operands aren't arithmetic, this

clang/lib/Sema/SemaChecking.cpp

+81-36
Original file line numberDiff line numberDiff line change
@@ -2711,7 +2711,7 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
27112711
// These builtins restrict the element type to floating point
27122712
// types only, and take in two arguments.
27132713
case Builtin::BI__builtin_elementwise_pow: {
2714-
if (BuiltinElementwiseMath(TheCall))
2714+
if (BuiltinElementwiseMath(TheCall, true))
27152715
return ExprError();
27162716

27172717
QualType ArgTy = TheCall->getArg(0)->getType();
@@ -2727,7 +2727,7 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
27272727
// types only.
27282728
case Builtin::BI__builtin_elementwise_add_sat:
27292729
case Builtin::BI__builtin_elementwise_sub_sat: {
2730-
if (BuiltinElementwiseMath(TheCall))
2730+
if (BuiltinElementwiseMath(TheCall, false))
27312731
return ExprError();
27322732

27332733
const Expr *Arg = TheCall->getArg(0);
@@ -2747,7 +2747,7 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
27472747

27482748
case Builtin::BI__builtin_elementwise_min:
27492749
case Builtin::BI__builtin_elementwise_max:
2750-
if (BuiltinElementwiseMath(TheCall))
2750+
if (BuiltinElementwiseMath(TheCall, false))
27512751
return ExprError();
27522752
break;
27532753

@@ -14899,11 +14899,23 @@ void Sema::CheckAddressOfPackedMember(Expr *rhs) {
1489914899
_2, _3, _4));
1490014900
}
1490114901

14902+
// Performs a similar job to Sema::UsualUnaryConversions, but without any
14903+
// implicit promotion of integral/enumeration types.
14904+
static ExprResult BuiltinVectorMathConversions(Sema &S, Expr *E) {
14905+
// First, convert to an r-value.
14906+
ExprResult Res = S.DefaultFunctionArrayLvalueConversion(E);
14907+
if (Res.isInvalid())
14908+
return ExprError();
14909+
14910+
// Promote floating-point types.
14911+
return S.UsualUnaryFPConversions(Res.get());
14912+
}
14913+
1490214914
bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
1490314915
if (checkArgCount(TheCall, 1))
1490414916
return true;
1490514917

14906-
ExprResult A = UsualUnaryConversions(TheCall->getArg(0));
14918+
ExprResult A = BuiltinVectorMathConversions(*this, TheCall->getArg(0));
1490714919
if (A.isInvalid())
1490814920
return true;
1490914921

@@ -14917,63 +14929,96 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
1491714929
return false;
1491814930
}
1491914931

14920-
bool Sema::BuiltinElementwiseMath(CallExpr *TheCall) {
14921-
QualType Res;
14922-
if (BuiltinVectorMath(TheCall, Res))
14923-
return true;
14924-
TheCall->setType(Res);
14925-
return false;
14932+
bool Sema::BuiltinElementwiseMath(CallExpr *TheCall, bool FPOnly) {
14933+
if (auto Res = BuiltinVectorMath(TheCall, FPOnly); Res.has_value()) {
14934+
TheCall->setType(*Res);
14935+
return false;
14936+
}
14937+
return true;
1492614938
}
1492714939

1492814940
bool Sema::BuiltinVectorToScalarMath(CallExpr *TheCall) {
14929-
QualType Res;
14930-
if (BuiltinVectorMath(TheCall, Res))
14941+
std::optional<QualType> Res = BuiltinVectorMath(TheCall);
14942+
if (!Res)
1493114943
return true;
1493214944

14933-
if (auto *VecTy0 = Res->getAs<VectorType>())
14945+
if (auto *VecTy0 = (*Res)->getAs<VectorType>())
1493414946
TheCall->setType(VecTy0->getElementType());
1493514947
else
14936-
TheCall->setType(Res);
14948+
TheCall->setType(*Res);
1493714949

1493814950
return false;
1493914951
}
1494014952

14941-
bool Sema::BuiltinVectorMath(CallExpr *TheCall, QualType &Res) {
14953+
static bool checkBuiltinVectorMathMixedEnums(Sema &S, Expr *LHS, Expr *RHS,
14954+
SourceLocation Loc) {
14955+
QualType L = LHS->getEnumCoercedType(S.Context),
14956+
R = RHS->getEnumCoercedType(S.Context);
14957+
if (L->isUnscopedEnumerationType() && R->isUnscopedEnumerationType() &&
14958+
!S.Context.hasSameUnqualifiedType(L, R)) {
14959+
return S.Diag(Loc, diag::err_conv_mixed_enum_types_cxx26)
14960+
<< LHS->getSourceRange() << RHS->getSourceRange()
14961+
<< /*Arithmetic Between*/ 0 << L << R;
14962+
}
14963+
return false;
14964+
}
14965+
14966+
std::optional<QualType> Sema::BuiltinVectorMath(CallExpr *TheCall,
14967+
bool FPOnly) {
1494214968
if (checkArgCount(TheCall, 2))
14943-
return true;
14969+
return std::nullopt;
1494414970

14945-
ExprResult A = TheCall->getArg(0);
14946-
ExprResult B = TheCall->getArg(1);
14947-
// Do standard promotions between the two arguments, returning their common
14948-
// type.
14949-
Res = UsualArithmeticConversions(A, B, TheCall->getExprLoc(), ACK_Comparison);
14950-
if (A.isInvalid() || B.isInvalid())
14951-
return true;
14971+
if (checkBuiltinVectorMathMixedEnums(
14972+
*this, TheCall->getArg(0), TheCall->getArg(1), TheCall->getExprLoc()))
14973+
return std::nullopt;
1495214974

14953-
QualType TyA = A.get()->getType();
14954-
QualType TyB = B.get()->getType();
14975+
Expr *Args[2];
14976+
for (int I = 0; I < 2; ++I) {
14977+
ExprResult Converted =
14978+
BuiltinVectorMathConversions(*this, TheCall->getArg(I));
14979+
if (Converted.isInvalid())
14980+
return std::nullopt;
14981+
Args[I] = Converted.get();
14982+
}
1495514983

14956-
if (Res.isNull() || TyA.getCanonicalType() != TyB.getCanonicalType())
14957-
return Diag(A.get()->getBeginLoc(),
14958-
diag::err_typecheck_call_different_arg_types)
14959-
<< TyA << TyB;
14984+
SourceLocation LocA = Args[0]->getBeginLoc();
14985+
QualType TyA = Args[0]->getType();
14986+
QualType TyB = Args[1]->getType();
1496014987

14961-
if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
14962-
return true;
14988+
if (TyA.getCanonicalType() != TyB.getCanonicalType()) {
14989+
Diag(LocA, diag::err_typecheck_call_different_arg_types) << TyA << TyB;
14990+
return std::nullopt;
14991+
}
1496314992

14964-
TheCall->setArg(0, A.get());
14965-
TheCall->setArg(1, B.get());
14966-
return false;
14993+
if (FPOnly) {
14994+
if (checkFPMathBuiltinElementType(*this, LocA, TyA, 1))
14995+
return std::nullopt;
14996+
} else {
14997+
if (checkMathBuiltinElementType(*this, LocA, TyA, 1))
14998+
return std::nullopt;
14999+
}
15000+
15001+
TheCall->setArg(0, Args[0]);
15002+
TheCall->setArg(1, Args[1]);
15003+
return TyA;
1496715004
}
1496815005

1496915006
bool Sema::BuiltinElementwiseTernaryMath(CallExpr *TheCall,
1497015007
bool CheckForFloatArgs) {
1497115008
if (checkArgCount(TheCall, 3))
1497215009
return true;
1497315010

15011+
SourceLocation Loc = TheCall->getExprLoc();
15012+
if (checkBuiltinVectorMathMixedEnums(*this, TheCall->getArg(0),
15013+
TheCall->getArg(1), Loc) ||
15014+
checkBuiltinVectorMathMixedEnums(*this, TheCall->getArg(1),
15015+
TheCall->getArg(2), Loc))
15016+
return true;
15017+
1497415018
Expr *Args[3];
1497515019
for (int I = 0; I < 3; ++I) {
14976-
ExprResult Converted = UsualUnaryConversions(TheCall->getArg(I));
15020+
ExprResult Converted =
15021+
BuiltinVectorMathConversions(*this, TheCall->getArg(I));
1497715022
if (Converted.isInvalid())
1497815023
return true;
1497915024
Args[I] = Converted.get();
@@ -15010,7 +15055,7 @@ bool Sema::BuiltinElementwiseTernaryMath(CallExpr *TheCall,
1501015055
return false;
1501115056
}
1501215057

15013-
bool Sema::PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall) {
15058+
bool Sema::PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall, bool FPOnly) {
1501415059
if (checkArgCount(TheCall, 1))
1501515060
return true;
1501615061

0 commit comments

Comments
 (0)