Skip to content

Commit 50ad443

Browse files
committed
[WebAssembly] Support promoting lower lanes of f16x8 to f32x4.
1 parent 9102afc commit 50ad443

File tree

6 files changed

+80
-15
lines changed

6 files changed

+80
-15
lines changed

Diff for: clang/lib/Headers/wasm_simd128.h

+9
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ typedef int __i32x2 __attribute__((__vector_size__(8), __aligned__(8)));
4545
typedef unsigned int __u32x2
4646
__attribute__((__vector_size__(8), __aligned__(8)));
4747
typedef float __f32x2 __attribute__((__vector_size__(8), __aligned__(8)));
48+
typedef __fp16 __f16x4 __attribute__((__vector_size__(8), __aligned__(8)));
4849

4950
#define __DEFAULT_FN_ATTRS \
5051
__attribute__((__always_inline__, __nodebug__, __target__("simd128"), \
@@ -2010,6 +2011,14 @@ static __inline__ v128_t __FP16_FN_ATTRS wasm_f16x8_convert_u16x8(v128_t __a) {
20102011
return (v128_t) __builtin_convertvector((__u16x8)__a, __f16x8);
20112012
}
20122013

2014+
static __inline__ v128_t __FP16_FN_ATTRS
2015+
wasm_f32x4_promote_low_f16x8(v128_t __a) {
2016+
return (v128_t) __builtin_convertvector(
2017+
(__f16x4){((__f16x8)__a)[0], ((__f16x8)__a)[1], ((__f16x8)__a)[2],
2018+
((__f16x8)__a)[3]},
2019+
__f32x4);
2020+
}
2021+
20132022
static __inline__ v128_t __FP16_FN_ATTRS wasm_f16x8_relaxed_madd(v128_t __a,
20142023
v128_t __b,
20152024
v128_t __c) {

Diff for: cross-project-tests/intrinsic-header-tests/wasm_simd128.c

+6
Original file line numberDiff line numberDiff line change
@@ -1033,6 +1033,12 @@ v128_t test_f64x2_promote_low_f32x4(v128_t a) {
10331033
return wasm_f64x2_promote_low_f32x4(a);
10341034
}
10351035

1036+
// CHECK-LABEL: test_f32x4_promote_low_f16x8:
1037+
// CHECK: f32x4.promote_low_f16x8{{$}}
1038+
v128_t test_f32x4_promote_low_f16x8(v128_t a) {
1039+
return wasm_f32x4_promote_low_f16x8(a);
1040+
}
1041+
10361042
// CHECK-LABEL: test_i8x16_shuffle:
10371043
// CHECK: i8x16.shuffle 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1,
10381044
// 0{{$}}

Diff for: llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

+40-15
Original file line numberDiff line numberDiff line change
@@ -2341,7 +2341,7 @@ WebAssemblyTargetLowering::LowerEXTEND_VECTOR_INREG(SDValue Op,
23412341

23422342
static SDValue LowerConvertLow(SDValue Op, SelectionDAG &DAG) {
23432343
SDLoc DL(Op);
2344-
if (Op.getValueType() != MVT::v2f64)
2344+
if (Op.getValueType() != MVT::v2f64 && Op.getValueType() != MVT::v4f32)
23452345
return SDValue();
23462346

23472347
auto GetConvertedLane = [](SDValue Op, unsigned &Opcode, SDValue &SrcVec,
@@ -2354,6 +2354,7 @@ static SDValue LowerConvertLow(SDValue Op, SelectionDAG &DAG) {
23542354
Opcode = WebAssemblyISD::CONVERT_LOW_U;
23552355
break;
23562356
case ISD::FP_EXTEND:
2357+
case ISD::FP16_TO_FP:
23572358
Opcode = WebAssemblyISD::PROMOTE_LOW;
23582359
break;
23592360
default:
@@ -2372,36 +2373,60 @@ static SDValue LowerConvertLow(SDValue Op, SelectionDAG &DAG) {
23722373
return true;
23732374
};
23742375

2375-
unsigned LHSOpcode, RHSOpcode, LHSIndex, RHSIndex;
2376-
SDValue LHSSrcVec, RHSSrcVec;
2377-
if (!GetConvertedLane(Op.getOperand(0), LHSOpcode, LHSSrcVec, LHSIndex) ||
2378-
!GetConvertedLane(Op.getOperand(1), RHSOpcode, RHSSrcVec, RHSIndex))
2376+
unsigned NumLanes = Op.getValueType() == MVT::v2f64 ? 2 : 4;
2377+
unsigned FirstOpcode = 0, SecondOpcode = 0, ThirdOpcode = 0, FourthOpcode = 0;
2378+
unsigned FirstIndex = 0, SecondIndex = 0, ThirdIndex = 0, FourthIndex = 0;
2379+
SDValue FirstSrcVec, SecondSrcVec, ThirdSrcVec, FourthSrcVec;
2380+
2381+
if (!GetConvertedLane(Op.getOperand(0), FirstOpcode, FirstSrcVec,
2382+
FirstIndex) ||
2383+
!GetConvertedLane(Op.getOperand(1), SecondOpcode, SecondSrcVec,
2384+
SecondIndex))
2385+
return SDValue();
2386+
2387+
// If we're converting to v4f32, check the third and fourth lanes, too.
2388+
if (NumLanes == 4 && (!GetConvertedLane(Op.getOperand(2), ThirdOpcode,
2389+
ThirdSrcVec, ThirdIndex) ||
2390+
!GetConvertedLane(Op.getOperand(3), FourthOpcode,
2391+
FourthSrcVec, FourthIndex)))
2392+
return SDValue();
2393+
2394+
if (FirstOpcode != SecondOpcode)
23792395
return SDValue();
23802396

2381-
if (LHSOpcode != RHSOpcode)
2397+
// TODO Add an optimization similar to the v2f64 below for shuffling the
2398+
// vectors when the lanes are in the wrong order or come from different src
2399+
// vectors.
2400+
if (NumLanes == 4 &&
2401+
(FirstOpcode != ThirdOpcode || FirstOpcode != FourthOpcode ||
2402+
FirstSrcVec != SecondSrcVec || FirstSrcVec != ThirdSrcVec ||
2403+
FirstSrcVec != FourthSrcVec || FirstIndex != 0 || SecondIndex != 1 ||
2404+
ThirdIndex != 2 || FourthIndex != 3))
23822405
return SDValue();
23832406

23842407
MVT ExpectedSrcVT;
2385-
switch (LHSOpcode) {
2408+
switch (FirstOpcode) {
23862409
case WebAssemblyISD::CONVERT_LOW_S:
23872410
case WebAssemblyISD::CONVERT_LOW_U:
23882411
ExpectedSrcVT = MVT::v4i32;
23892412
break;
23902413
case WebAssemblyISD::PROMOTE_LOW:
2391-
ExpectedSrcVT = MVT::v4f32;
2414+
ExpectedSrcVT = NumLanes == 2 ? MVT::v4f32 : MVT::v8i16;
23922415
break;
23932416
}
2394-
if (LHSSrcVec.getValueType() != ExpectedSrcVT)
2417+
if (FirstSrcVec.getValueType() != ExpectedSrcVT)
23952418
return SDValue();
23962419

2397-
auto Src = LHSSrcVec;
2398-
if (LHSIndex != 0 || RHSIndex != 1 || LHSSrcVec != RHSSrcVec) {
2420+
auto Src = FirstSrcVec;
2421+
if (NumLanes == 2 &&
2422+
(FirstIndex != 0 || SecondIndex != 1 || FirstSrcVec != SecondSrcVec)) {
23992423
// Shuffle the source vector so that the converted lanes are the low lanes.
2400-
Src = DAG.getVectorShuffle(
2401-
ExpectedSrcVT, DL, LHSSrcVec, RHSSrcVec,
2402-
{static_cast<int>(LHSIndex), static_cast<int>(RHSIndex) + 4, -1, -1});
2424+
Src = DAG.getVectorShuffle(ExpectedSrcVT, DL, FirstSrcVec, SecondSrcVec,
2425+
{static_cast<int>(FirstIndex),
2426+
static_cast<int>(SecondIndex) + 4, -1, -1});
24032427
}
2404-
return DAG.getNode(LHSOpcode, DL, MVT::v2f64, Src);
2428+
return DAG.getNode(FirstOpcode, DL, NumLanes == 2 ? MVT::v2f64 : MVT::v4f32,
2429+
Src);
24052430
}
24062431

24072432
SDValue WebAssemblyTargetLowering::LowerBUILD_VECTOR(SDValue Op,

Diff for: llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td

+2
Original file line numberDiff line numberDiff line change
@@ -1468,6 +1468,8 @@ defm "" : SIMDConvert<F32x4, F64x2, demote_zero,
14681468
def promote_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>;
14691469
def promote_low : SDNode<"WebAssemblyISD::PROMOTE_LOW", promote_t>;
14701470
defm "" : SIMDConvert<F64x2, F32x4, promote_low, "promote_low_f32x4", 0x5f>;
1471+
defm "" : HalfPrecisionConvert<F32x4, I16x8, promote_low, "promote_low_f16x8",
1472+
0x14b>;
14711473

14721474
// Lower extending loads to load64_zero + promote_low
14731475
def extloadv2f32 : PatFrag<(ops node:$ptr), (extload node:$ptr)> {

Diff for: llvm/test/CodeGen/WebAssembly/half-precision.ll

+20
Original file line numberDiff line numberDiff line change
@@ -369,3 +369,23 @@ define <8 x half> @shuffle_poison_v8f16(<8 x half> %x, <8 x half> %y) {
369369
i32 poison, i32 poison, i32 poison, i32 poison>
370370
ret <8 x half> %res
371371
}
372+
373+
define <4 x float> @promote_low_v4f32(<8 x half> %x) {
374+
; CHECK-LABEL: promote_low_v4f32:
375+
; CHECK: .functype promote_low_v4f32 (v128) -> (v128){{$}}
376+
; CHECK-NEXT: f32x4.promote_low_f16x8 $push[[R:[0-9]+]]=, $0
377+
; CHECK-NEXT: return $pop[[R]]
378+
%v = shufflevector <8 x half> %x, <8 x half> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
379+
%a = fpext <4 x half> %v to <4 x float>
380+
ret <4 x float> %a
381+
}
382+
383+
define <4 x float> @promote_low_v4f32_2(<8 x half> %x) {
384+
; CHECK-LABEL: promote_low_v4f32_2:
385+
; CHECK: .functype promote_low_v4f32_2 (v128) -> (v128)
386+
; CHECK-NEXT: f32x4.promote_low_f16x8 $push[[R:[0-9]+]]=, $0
387+
; CHECK-NEXT: return $pop[[R]]
388+
%v = fpext <8 x half> %x to <8 x float>
389+
%a = shufflevector <8 x float> %v, <8 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
390+
ret <4 x float> %a
391+
}

Diff for: llvm/test/MC/WebAssembly/simd-encodings.s

+3
Original file line numberDiff line numberDiff line change
@@ -935,4 +935,7 @@ main:
935935
# CHECK: f16x8.convert_i16x8_u # encoding: [0xfd,0xc8,0x02]
936936
f16x8.convert_i16x8_u
937937

938+
# CHECK: f32x4.promote_low_f16x8 # encoding: [0xfd,0xcb,0x02]
939+
f32x4.promote_low_f16x8
940+
938941
end_function

0 commit comments

Comments
 (0)