Skip to content

Commit 8d649fb

Browse files
authored
[CIR][Dialect] Add FMaximumOp and FMinimumOp (#1237)
There are two sets of intrinsics for floating-point min and max operations: [Maximum](https://mlir.llvm.org/docs/Dialects/LLVM/#llvmintrmaximum-llvmmaximumop) vs. [Maxnum](https://mlir.llvm.org/docs/Dialects/LLVM/#llvmintrmaxnum-llvmmaxnumop), and [Minimum](https://mlir.llvm.org/docs/Dialects/LLVM/#llvmintrminimum-llvmminimumop) vs. [Minnum](https://mlir.llvm.org/docs/Dialects/LLVM/#llvmintrminnum-llvmminnumop). [The difference is whether NaN is propagated when one of the inputs is NaN](https://llvm.org/docs/LangRef.html#llvm-maximumnum-intrinsic): Maxnum and Minnum return the number when one input is NaN and the other is a number, whereas Maximum and Minimum return NaN (NaN propagation). The two sets also lower to different instructions, such as [FMAX](https://developer.arm.com/documentation/ddi0596/2021-03/SIMD-FP-Instructions/FMAX--vector---Floating-point-Maximum--vector--?lang=en) vs. [FMAXNM](https://developer.arm.com/documentation/ddi0596/2021-03/SIMD-FP-Instructions/FMAXNM--vector---Floating-point-Maximum-Number--vector--?lang=en). Both sets have use cases. Maxnum and Minnum were already implemented, but Maximum and Minimum are also needed, for example by the [NEON intrinsics](https://developer.arm.com/architectures/instruction-sets/intrinsics/vmax_f32) and by [__builtin_elementwise_maximum](https://github.com/llvm/clangir/blob/a989ecb2c55da1fe28e4072c31af025cba6c4f0f/clang/test/CodeGen/strictfp-elementwise-bulitins.cpp#L53).
1 parent 49edd4b commit 8d649fb

File tree

4 files changed

+46
-26
lines changed

4 files changed

+46
-26
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

+4-2
Original file line numberDiff line numberDiff line change
@@ -4520,8 +4520,10 @@ class BinaryFPToFPBuiltinOp<string mnemonic, string llvmOpName>
45204520
}
45214521

45224522
def CopysignOp : BinaryFPToFPBuiltinOp<"copysign", "CopySignOp">;
4523-
def FMaxOp : BinaryFPToFPBuiltinOp<"fmax", "MaxNumOp">;
4524-
def FMinOp : BinaryFPToFPBuiltinOp<"fmin", "MinNumOp">;
4523+
def FMaxNumOp : BinaryFPToFPBuiltinOp<"fmaxnum", "MaxNumOp">;
4524+
def FMinNumOp : BinaryFPToFPBuiltinOp<"fminnum", "MinNumOp">;
4525+
def FMaximumOp : BinaryFPToFPBuiltinOp<"fmaximum", "MaximumOp">;
4526+
def FMinimumOp : BinaryFPToFPBuiltinOp<"fminimum", "MinimumOp">;
45254527
def FModOp : BinaryFPToFPBuiltinOp<"fmod", "FRemOp">;
45264528
def PowOp : BinaryFPToFPBuiltinOp<"pow", "PowOp">;
45274529

clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
670670
case Builtin::BI__builtin_fmaxf:
671671
case Builtin::BI__builtin_fmaxl:
672672
return RValue::get(
673-
emitBinaryMaybeConstrainedFPBuiltin<cir::FMaxOp>(*this, *E));
673+
emitBinaryMaybeConstrainedFPBuiltin<cir::FMaxNumOp>(*this, *E));
674674

675675
case Builtin::BI__builtin_fmaxf16:
676676
case Builtin::BI__builtin_fmaxf128:
@@ -683,7 +683,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
683683
case Builtin::BI__builtin_fminf:
684684
case Builtin::BI__builtin_fminl:
685685
return RValue::get(
686-
emitBinaryMaybeConstrainedFPBuiltin<cir::FMinOp>(*this, *E));
686+
emitBinaryMaybeConstrainedFPBuiltin<cir::FMinNumOp>(*this, *E));
687687

688688
case Builtin::BI__builtin_fminf16:
689689
case Builtin::BI__builtin_fminf128:

clang/test/CIR/CodeGen/builtin-floating-point.c

+16-16
Original file line numberDiff line numberDiff line change
@@ -1300,7 +1300,7 @@ long double call_copysignl(long double x, long double y) {
13001300
float my_fmaxf(float x, float y) {
13011301
return __builtin_fmaxf(x, y);
13021302
// CHECK: cir.func @my_fmaxf
1303-
// CHECK: %{{.+}} = cir.fmax %{{.+}}, %{{.+}} : !cir.float
1303+
// CHECK: %{{.+}} = cir.fmaxnum %{{.+}}, %{{.+}} : !cir.float
13041304

13051305
// LLVM: define dso_local float @my_fmaxf
13061306
// LLVM: %{{.+}} = call float @llvm.maxnum.f32(float %{{.+}}, float %{{.+}})
@@ -1310,7 +1310,7 @@ float my_fmaxf(float x, float y) {
13101310
double my_fmax(double x, double y) {
13111311
return __builtin_fmax(x, y);
13121312
// CHECK: cir.func @my_fmax
1313-
// CHECK: %{{.+}} = cir.fmax %{{.+}}, %{{.+}} : !cir.double
1313+
// CHECK: %{{.+}} = cir.fmaxnum %{{.+}}, %{{.+}} : !cir.double
13141314

13151315
// LLVM: define dso_local double @my_fmax
13161316
// LLVM: %{{.+}} = call double @llvm.maxnum.f64(double %{{.+}}, double %{{.+}})
@@ -1320,8 +1320,8 @@ double my_fmax(double x, double y) {
13201320
long double my_fmaxl(long double x, long double y) {
13211321
return __builtin_fmaxl(x, y);
13221322
// CHECK: cir.func @my_fmaxl
1323-
// CHECK: %{{.+}} = cir.fmax %{{.+}}, %{{.+}} : !cir.long_double<!cir.f80>
1324-
// AARCH64: %{{.+}} = cir.fmax %{{.+}}, %{{.+}} : !cir.long_double<!cir.double>
1323+
// CHECK: %{{.+}} = cir.fmaxnum %{{.+}}, %{{.+}} : !cir.long_double<!cir.f80>
1324+
// AARCH64: %{{.+}} = cir.fmaxnum %{{.+}}, %{{.+}} : !cir.long_double<!cir.double>
13251325

13261326
// LLVM: define dso_local x86_fp80 @my_fmaxl
13271327
// LLVM: %{{.+}} = call x86_fp80 @llvm.maxnum.f80(x86_fp80 %{{.+}}, x86_fp80 %{{.+}})
@@ -1335,7 +1335,7 @@ long double fmaxl(long double, long double);
13351335
float call_fmaxf(float x, float y) {
13361336
return fmaxf(x, y);
13371337
// CHECK: cir.func @call_fmaxf
1338-
// CHECK: %{{.+}} = cir.fmax %{{.+}}, %{{.+}} : !cir.float
1338+
// CHECK: %{{.+}} = cir.fmaxnum %{{.+}}, %{{.+}} : !cir.float
13391339

13401340
// LLVM: define dso_local float @call_fmaxf
13411341
// LLVM: %{{.+}} = call float @llvm.maxnum.f32(float %{{.+}}, float %{{.+}})
@@ -1345,7 +1345,7 @@ float call_fmaxf(float x, float y) {
13451345
double call_fmax(double x, double y) {
13461346
return fmax(x, y);
13471347
// CHECK: cir.func @call_fmax
1348-
// CHECK: %{{.+}} = cir.fmax %{{.+}}, %{{.+}} : !cir.double
1348+
// CHECK: %{{.+}} = cir.fmaxnum %{{.+}}, %{{.+}} : !cir.double
13491349

13501350
// LLVM: define dso_local double @call_fmax
13511351
// LLVM: %{{.+}} = call double @llvm.maxnum.f64(double %{{.+}}, double %{{.+}})
@@ -1355,8 +1355,8 @@ double call_fmax(double x, double y) {
13551355
long double call_fmaxl(long double x, long double y) {
13561356
return fmaxl(x, y);
13571357
// CHECK: cir.func @call_fmaxl
1358-
// CHECK: %{{.+}} = cir.fmax %{{.+}}, %{{.+}} : !cir.long_double<!cir.f80>
1359-
// AARCH64: %{{.+}} = cir.fmax %{{.+}}, %{{.+}} : !cir.long_double<!cir.double>
1358+
// CHECK: %{{.+}} = cir.fmaxnum %{{.+}}, %{{.+}} : !cir.long_double<!cir.f80>
1359+
// AARCH64: %{{.+}} = cir.fmaxnum %{{.+}}, %{{.+}} : !cir.long_double<!cir.double>
13601360

13611361
// LLVM: define dso_local x86_fp80 @call_fmaxl
13621362
// LLVM: %{{.+}} = call x86_fp80 @llvm.maxnum.f80(x86_fp80 %{{.+}}, x86_fp80 %{{.+}})
@@ -1368,7 +1368,7 @@ long double call_fmaxl(long double x, long double y) {
13681368
float my_fminf(float x, float y) {
13691369
return __builtin_fminf(x, y);
13701370
// CHECK: cir.func @my_fminf
1371-
// CHECK: %{{.+}} = cir.fmin %{{.+}}, %{{.+}} : !cir.float
1371+
// CHECK: %{{.+}} = cir.fminnum %{{.+}}, %{{.+}} : !cir.float
13721372

13731373
// LLVM: define dso_local float @my_fminf
13741374
// LLVM: %{{.+}} = call float @llvm.minnum.f32(float %{{.+}}, float %{{.+}})
@@ -1378,7 +1378,7 @@ float my_fminf(float x, float y) {
13781378
double my_fmin(double x, double y) {
13791379
return __builtin_fmin(x, y);
13801380
// CHECK: cir.func @my_fmin
1381-
// CHECK: %{{.+}} = cir.fmin %{{.+}}, %{{.+}} : !cir.double
1381+
// CHECK: %{{.+}} = cir.fminnum %{{.+}}, %{{.+}} : !cir.double
13821382

13831383
// LLVM: define dso_local double @my_fmin
13841384
// LLVM: %{{.+}} = call double @llvm.minnum.f64(double %{{.+}}, double %{{.+}})
@@ -1388,8 +1388,8 @@ double my_fmin(double x, double y) {
13881388
long double my_fminl(long double x, long double y) {
13891389
return __builtin_fminl(x, y);
13901390
// CHECK: cir.func @my_fminl
1391-
// CHECK: %{{.+}} = cir.fmin %{{.+}}, %{{.+}} : !cir.long_double<!cir.f80>
1392-
// AARCH64: %{{.+}} = cir.fmin %{{.+}}, %{{.+}} : !cir.long_double<!cir.double>
1391+
// CHECK: %{{.+}} = cir.fminnum %{{.+}}, %{{.+}} : !cir.long_double<!cir.f80>
1392+
// AARCH64: %{{.+}} = cir.fminnum %{{.+}}, %{{.+}} : !cir.long_double<!cir.double>
13931393

13941394
// LLVM: define dso_local x86_fp80 @my_fminl
13951395
// LLVM: %{{.+}} = call x86_fp80 @llvm.minnum.f80(x86_fp80 %{{.+}}, x86_fp80 %{{.+}})
@@ -1403,7 +1403,7 @@ long double fminl(long double, long double);
14031403
float call_fminf(float x, float y) {
14041404
return fminf(x, y);
14051405
// CHECK: cir.func @call_fminf
1406-
// CHECK: %{{.+}} = cir.fmin %{{.+}}, %{{.+}} : !cir.float
1406+
// CHECK: %{{.+}} = cir.fminnum %{{.+}}, %{{.+}} : !cir.float
14071407

14081408
// LLVM: define dso_local float @call_fminf
14091409
// LLVM: %{{.+}} = call float @llvm.minnum.f32(float %{{.+}}, float %{{.+}})
@@ -1413,7 +1413,7 @@ float call_fminf(float x, float y) {
14131413
double call_fmin(double x, double y) {
14141414
return fmin(x, y);
14151415
// CHECK: cir.func @call_fmin
1416-
// CHECK: %{{.+}} = cir.fmin %{{.+}}, %{{.+}} : !cir.double
1416+
// CHECK: %{{.+}} = cir.fminnum %{{.+}}, %{{.+}} : !cir.double
14171417

14181418
// LLVM: define dso_local double @call_fmin
14191419
// LLVM: %{{.+}} = call double @llvm.minnum.f64(double %{{.+}}, double %{{.+}})
@@ -1423,8 +1423,8 @@ double call_fmin(double x, double y) {
14231423
long double call_fminl(long double x, long double y) {
14241424
return fminl(x, y);
14251425
// CHECK: cir.func @call_fminl
1426-
// CHECK: %{{.+}} = cir.fmin %{{.+}}, %{{.+}} : !cir.long_double<!cir.f80>
1427-
// AARCH64: %{{.+}} = cir.fmin %{{.+}}, %{{.+}} : !cir.long_double<!cir.double>
1426+
// CHECK: %{{.+}} = cir.fminnum %{{.+}}, %{{.+}} : !cir.long_double<!cir.f80>
1427+
// AARCH64: %{{.+}} = cir.fminnum %{{.+}}, %{{.+}} : !cir.long_double<!cir.double>
14281428

14291429
// LLVM: define dso_local x86_fp80 @call_fminl
14301430
// LLVM: %{{.+}} = call x86_fp80 @llvm.minnum.f80(x86_fp80 %{{.+}}, x86_fp80 %{{.+}})

clang/test/CIR/Lowering/builtin-floating-point.cir

+24-6
Original file line numberDiff line numberDiff line change
@@ -138,22 +138,22 @@ module {
138138
%215 = cir.copysign %arg2, %arg2 : !cir.vector<!cir.float x 4>
139139
// CHECK: llvm.intr.copysign(%arg2, %arg2) : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
140140

141-
%16 = cir.fmax %arg0, %arg0 : !cir.float
141+
%16 = cir.fmaxnum %arg0, %arg0 : !cir.float
142142
// CHECK: llvm.intr.maxnum(%arg0, %arg0) : (f32, f32) -> f32
143143

144-
%116 = cir.fmax %arg1, %arg1 : !cir.vector<!cir.double x 2>
144+
%116 = cir.fmaxnum %arg1, %arg1 : !cir.vector<!cir.double x 2>
145145
// CHECK: llvm.intr.maxnum(%arg1, %arg1) : (vector<2xf64>, vector<2xf64>) -> vector<2xf64>
146146

147-
%216 = cir.fmax %arg2, %arg2 : !cir.vector<!cir.float x 4>
147+
%216 = cir.fmaxnum %arg2, %arg2 : !cir.vector<!cir.float x 4>
148148
// CHECK: llvm.intr.maxnum(%arg2, %arg2) : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
149149

150-
%17 = cir.fmin %arg0, %arg0 : !cir.float
150+
%17 = cir.fminnum %arg0, %arg0 : !cir.float
151151
// CHECK: llvm.intr.minnum(%arg0, %arg0) : (f32, f32) -> f32
152152

153-
%117 = cir.fmin %arg1, %arg1 : !cir.vector<!cir.double x 2>
153+
%117 = cir.fminnum %arg1, %arg1 : !cir.vector<!cir.double x 2>
154154
// CHECK: llvm.intr.minnum(%arg1, %arg1) : (vector<2xf64>, vector<2xf64>) -> vector<2xf64>
155155

156-
%217 = cir.fmin %arg2, %arg2 : !cir.vector<!cir.float x 4>
156+
%217 = cir.fminnum %arg2, %arg2 : !cir.vector<!cir.float x 4>
157157
// CHECK: llvm.intr.minnum(%arg2, %arg2) : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
158158

159159
%18 = cir.fmod %arg0, %arg0 : !cir.float
@@ -174,6 +174,24 @@ module {
174174
%219 = cir.pow %arg2, %arg2 : !cir.vector<!cir.float x 4>
175175
// CHECK: llvm.intr.pow(%arg2, %arg2) : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
176176

177+
%20 = cir.fmaximum %arg0, %arg0 : !cir.float
178+
// CHECK: llvm.intr.maximum(%arg0, %arg0) : (f32, f32) -> f32
179+
180+
%120 = cir.fmaximum %arg1, %arg1 : !cir.vector<!cir.double x 2>
181+
// CHECK: llvm.intr.maximum(%arg1, %arg1) : (vector<2xf64>, vector<2xf64>) -> vector<2xf64>
182+
183+
%220 = cir.fmaximum %arg2, %arg2 : !cir.vector<!cir.float x 4>
184+
// CHECK: llvm.intr.maximum(%arg2, %arg2) : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
185+
186+
%21 = cir.fminimum %arg0, %arg0 : !cir.float
187+
// CHECK: llvm.intr.minimum(%arg0, %arg0) : (f32, f32) -> f32
188+
189+
%121 = cir.fminimum %arg1, %arg1 : !cir.vector<!cir.double x 2>
190+
// CHECK: llvm.intr.minimum(%arg1, %arg1) : (vector<2xf64>, vector<2xf64>) -> vector<2xf64>
191+
192+
%221 = cir.fminimum %arg2, %arg2 : !cir.vector<!cir.float x 4>
193+
// CHECK: llvm.intr.minimum(%arg2, %arg2) : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
194+
177195
cir.return
178196
}
179197
}

0 commit comments

Comments
 (0)