Skip to content

Commit c3ecfd3

Browse files
steffenlarsenvmaksimo
authored andcommitted
[Backport to 16] Implement SPV_INTEL_ternary_bitwise_function (KhronosGroup#3104)
This commit implements support for bi-directional translation of the `BitwiseFunctionINTEL` operation added in https://github.khronos.org/SPIRV-Registry/extensions/INTEL/SPV_INTEL_ternary_bitwise_function.html together with the corresponding capability. Signed-off-by: Larsen, Steffen <[email protected]>
1 parent f235666 commit c3ecfd3

File tree

5 files changed

+136
-0
lines changed

5 files changed

+136
-0
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,4 @@ EXT(SPV_INTEL_bindless_images)
7575
EXT(SPV_INTEL_2d_block_io)
7676
EXT(SPV_INTEL_subgroup_matrix_multiply_accumulate)
7777
EXT(SPV_KHR_bfloat16)
78+
EXT(SPV_INTEL_ternary_bitwise_function)

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4040,5 +4040,60 @@ class SPIRVSubgroupMatrixMultiplyAccumulateINTELInst
40404040
_SPIRV_OP(SubgroupMatrixMultiplyAccumulate, true, 7, true, 4)
40414041
#undef _SPIRV_OP
40424042

4043+
class SPIRVTernaryBitwiseFunctionINTELInst : public SPIRVInstTemplateBase {
4044+
public:
4045+
void validate() const override {
4046+
SPIRVInstruction::validate();
4047+
SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog();
4048+
std::string InstName = "BitwiseFunctionINTEL";
4049+
4050+
const SPIRVType *ResTy = this->getType();
4051+
SPVErrLog.checkError(
4052+
ResTy->isTypeInt() || (ResTy->isTypeVector() &&
4053+
ResTy->getVectorComponentType()->isTypeInt()),
4054+
SPIRVEC_InvalidInstruction,
4055+
InstName + "\nResult type must be an integer scalar or vector.\n");
4056+
4057+
auto CommonArgCheck = [this, ResTy, &InstName,
4058+
&SPVErrLog](size_t ArgI, const char *ArgPlacement) {
4059+
SPIRVValue *Arg =
4060+
const_cast<SPIRVTernaryBitwiseFunctionINTELInst *>(this)->getOperand(
4061+
ArgI);
4062+
SPVErrLog.checkError(
4063+
Arg->getType() == ResTy, SPIRVEC_InvalidInstruction,
4064+
InstName + "\n" + ArgPlacement +
4065+
" argument must be the same as the result type.\n");
4066+
};
4067+
4068+
CommonArgCheck(0, "First");
4069+
CommonArgCheck(1, "Second");
4070+
CommonArgCheck(2, "Third");
4071+
4072+
SPIRVValue *LUTIndexArg =
4073+
const_cast<SPIRVTernaryBitwiseFunctionINTELInst *>(this)->getOperand(3);
4074+
const SPIRVType *LUTIndexArgTy = LUTIndexArg->getType();
4075+
SPVErrLog.checkError(
4076+
LUTIndexArgTy->isTypeInt(32), SPIRVEC_InvalidInstruction,
4077+
InstName + "\nFourth argument must be a 32-bit integer scalar.\n");
4078+
SPVErrLog.checkError(
4079+
isConstantOpCode(LUTIndexArg->getOpCode()), SPIRVEC_InvalidInstruction,
4080+
InstName + "\nFourth argument must be constant instruction.\n");
4081+
}
4082+
4083+
std::optional<ExtensionID> getRequiredExtension() const override {
4084+
return ExtensionID::SPV_INTEL_ternary_bitwise_function;
4085+
}
4086+
SPIRVCapVec getRequiredCapability() const override {
4087+
return getVec(CapabilityTernaryBitwiseFunctionINTEL);
4088+
}
4089+
};
4090+
4091+
#define _SPIRV_OP(x, ...) \
4092+
typedef SPIRVInstTemplate<SPIRVTernaryBitwiseFunctionINTELInst, \
4093+
Op##x##INTEL, __VA_ARGS__> \
4094+
SPIRV##x##INTEL;
4095+
_SPIRV_OP(BitwiseFunction, true, 7)
4096+
#undef _SPIRV_OP
4097+
40434098
} // namespace SPIRV
40444099
#endif // SPIRV_LIBSPIRV_SPIRVINSTRUCTION_H

lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,7 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
639639
add(CapabilitySubgroup2DBlockTransposeINTEL, "Subgroup2DBlockTransposeINTEL");
640640
add(CapabilitySubgroupMatrixMultiplyAccumulateINTEL,
641641
"SubgroupMatrixMultiplyAccumulateINTEL");
642+
add(CapabilityTernaryBitwiseFunctionINTEL, "TernaryBitwiseFunctionINTEL");
642643

643644
// From spirv_internal.hpp
644645
add(internal::CapabilityFastCompositeINTEL, "FastCompositeINTEL");

lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,7 @@ _SPIRV_OP(Subgroup2DBlockLoadTransposeINTEL, 6233)
570570
_SPIRV_OP(Subgroup2DBlockPrefetchINTEL, 6234)
571571
_SPIRV_OP(Subgroup2DBlockStoreINTEL, 6235)
572572
_SPIRV_OP(SubgroupMatrixMultiplyAccumulateINTEL, 6237)
573+
_SPIRV_OP(BitwiseFunctionINTEL, 6242)
573574
_SPIRV_OP(GroupIMulKHR, 6401)
574575
_SPIRV_OP(GroupFMulKHR, 6402)
575576
_SPIRV_OP(GroupBitwiseAndKHR, 6403)
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_ternary_bitwise_function -o %t.spv
3+
; RUN: llvm-spirv %t.spv --to-text -o %t.spt
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
6+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
7+
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
8+
9+
; RUN: not llvm-spirv %t.bc 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
10+
; CHECK-ERROR: RequiresExtension: Feature requires the following SPIR-V extension:
11+
; CHECK-ERROR-NEXT: SPV_INTEL_ternary_bitwise_function
12+
13+
; CHECK-SPIRV-NOT: Name [[#]] "_Z28__spirv_BitwiseFunctionINTELiiij"
14+
; CHECK-SPIRV-NOT: Name [[#]] "_Z28__spirv_BitwiseFunctionINTELDv4_iS_S_j"
15+
16+
; CHECK-SPIRV-DAG: Capability TernaryBitwiseFunctionINTEL
17+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_ternary_bitwise_function"
18+
19+
; CHECK-SPIRV-DAG: TypeInt [[#TYPEINT:]] 32 0
20+
; CHECK-SPIRV-DAG: TypeVector [[#TYPEINTVEC4:]] [[#TYPEINT]] 4
21+
; CHECK-SPIRV-DAG: Constant [[#TYPEINT]] [[#ScalarLUT:]] 24
22+
; CHECK-SPIRV-DAG: Constant [[#TYPEINT]] [[#VecLUT:]] 42
23+
24+
; CHECK-SPIRV: Load [[#TYPEINT]] [[#ScalarA:]]
25+
; CHECK-SPIRV: Load [[#TYPEINT]] [[#ScalarB:]]
26+
; CHECK-SPIRV: Load [[#TYPEINT]] [[#ScalarC:]]
27+
; CHECK-SPIRV: BitwiseFunctionINTEL [[#TYPEINT]] {{.*}} [[#ScalarA]] [[#ScalarB]] [[#ScalarC]] [[#ScalarLUT]]
28+
; CHECK-SPIRV: Load [[#TYPEINTVEC4]] [[#VecA:]]
29+
; CHECK-SPIRV: Load [[#TYPEINTVEC4]] [[#VecB:]]
30+
; CHECK-SPIRV: Load [[#TYPEINTVEC4]] [[#VecC:]]
31+
; CHECK-SPIRV: BitwiseFunctionINTEL [[#TYPEINTVEC4]] {{.*}} [[#VecA]] [[#VecB]] [[#VecC]] [[#VecLUT]]
32+
33+
; CHECK-LLVM: %[[ScalarA:.*]] = load i32, ptr
34+
; CHECK-LLVM: %[[ScalarB:.*]] = load i32, ptr
35+
; CHECK-LLVM: %[[ScalarC:.*]] = load i32, ptr
36+
; CHECK-LLVM: call spir_func i32 @_Z28__spirv_BitwiseFunctionINTELiiii(i32 %[[ScalarA]], i32 %[[ScalarB]], i32 %[[ScalarC]], i32 24)
37+
; CHECK-LLVM: %[[VecA:.*]] = load <4 x i32>, ptr
38+
; CHECK-LLVM: %[[VecB:.*]] = load <4 x i32>, ptr
39+
; CHECK-LLVM: %[[VecC:.*]] = load <4 x i32>, ptr
40+
; CHECK-LLVM: call spir_func <4 x i32> @_Z28__spirv_BitwiseFunctionINTELDv4_iS_S_i(<4 x i32> %[[VecA]], <4 x i32> %[[VecB]], <4 x i32> %[[VecC]], i32 42)
41+
42+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
43+
target triple = "spir"
44+
45+
; Function Attrs: nounwind readnone
46+
define spir_kernel void @fooScalar() {
47+
entry:
48+
%argA = alloca i32
49+
%argB = alloca i32
50+
%argC = alloca i32
51+
%A = load i32, ptr %argA
52+
%B = load i32, ptr %argB
53+
%C = load i32, ptr %argC
54+
%res = call spir_func i32 @_Z28__spirv_BitwiseFunctionINTELiiii(i32 %A, i32 %B, i32 %C, i32 24)
55+
ret void
56+
}
57+
58+
; Function Attrs: nounwind readnone
59+
define spir_kernel void @fooVec() {
60+
entry:
61+
%argA = alloca <4 x i32>
62+
%argB = alloca <4 x i32>
63+
%argC = alloca <4 x i32>
64+
%A = load <4 x i32>, ptr %argA
65+
%B = load <4 x i32>, ptr %argB
66+
%C = load <4 x i32>, ptr %argC
67+
%res = call spir_func <4 x i32> @_Z28__spirv_BitwiseFunctionINTELDv4_iS_S_i(<4 x i32> %A, <4 x i32> %B, <4 x i32> %C, i32 42)
68+
ret void
69+
}
70+
71+
declare dso_local spir_func i32 @_Z28__spirv_BitwiseFunctionINTELiiii(i32, i32, i32, i32)
72+
declare dso_local spir_func <4 x i32> @_Z28__spirv_BitwiseFunctionINTELDv4_iS_S_i(<4 x i32>, <4 x i32>, <4 x i32>, i32)
73+
74+
!llvm.module.flags = !{!0}
75+
!opencl.spir.version = !{!1}
76+
77+
!0 = !{i32 1, !"wchar_size", i32 4}
78+
!1 = !{i32 1, i32 2}

0 commit comments

Comments
 (0)