Skip to content

Commit 7a80643

Browse files
steffenlarsenvmaksimo
authored andcommitted
[Backport to 17] Implement SPV_INTEL_ternary_bitwise_function (#3104)
This commit implements support for bi-directional translation of the `BitwiseFunctionINTEL` operation added in https://github.khronos.org/SPIRV-Registry/extensions/INTEL/SPV_INTEL_ternary_bitwise_function.html together with the corresponding capability. Signed-off-by: Larsen, Steffen <[email protected]>
1 parent c392352 commit 7a80643

File tree

5 files changed

+136
-0
lines changed

5 files changed

+136
-0
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,4 @@ EXT(SPV_INTEL_bindless_images)
7676
EXT(SPV_INTEL_2d_block_io)
7777
EXT(SPV_INTEL_subgroup_matrix_multiply_accumulate)
7878
EXT(SPV_KHR_bfloat16)
79+
EXT(SPV_INTEL_ternary_bitwise_function)

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4034,5 +4034,60 @@ class SPIRVSubgroupMatrixMultiplyAccumulateINTELInst
40344034
_SPIRV_OP(SubgroupMatrixMultiplyAccumulate, true, 7, true, 4)
40354035
#undef _SPIRV_OP
40364036

4037+
class SPIRVTernaryBitwiseFunctionINTELInst : public SPIRVInstTemplateBase {
4038+
public:
4039+
void validate() const override {
4040+
SPIRVInstruction::validate();
4041+
SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog();
4042+
std::string InstName = "BitwiseFunctionINTEL";
4043+
4044+
const SPIRVType *ResTy = this->getType();
4045+
SPVErrLog.checkError(
4046+
ResTy->isTypeInt() || (ResTy->isTypeVector() &&
4047+
ResTy->getVectorComponentType()->isTypeInt()),
4048+
SPIRVEC_InvalidInstruction,
4049+
InstName + "\nResult type must be an integer scalar or vector.\n");
4050+
4051+
auto CommonArgCheck = [this, ResTy, &InstName,
4052+
&SPVErrLog](size_t ArgI, const char *ArgPlacement) {
4053+
SPIRVValue *Arg =
4054+
const_cast<SPIRVTernaryBitwiseFunctionINTELInst *>(this)->getOperand(
4055+
ArgI);
4056+
SPVErrLog.checkError(
4057+
Arg->getType() == ResTy, SPIRVEC_InvalidInstruction,
4058+
InstName + "\n" + ArgPlacement +
4059+
" argument must be the same as the result type.\n");
4060+
};
4061+
4062+
CommonArgCheck(0, "First");
4063+
CommonArgCheck(1, "Second");
4064+
CommonArgCheck(2, "Third");
4065+
4066+
SPIRVValue *LUTIndexArg =
4067+
const_cast<SPIRVTernaryBitwiseFunctionINTELInst *>(this)->getOperand(3);
4068+
const SPIRVType *LUTIndexArgTy = LUTIndexArg->getType();
4069+
SPVErrLog.checkError(
4070+
LUTIndexArgTy->isTypeInt(32), SPIRVEC_InvalidInstruction,
4071+
InstName + "\nFourth argument must be a 32-bit integer scalar.\n");
4072+
SPVErrLog.checkError(
4073+
isConstantOpCode(LUTIndexArg->getOpCode()), SPIRVEC_InvalidInstruction,
4074+
InstName + "\nFourth argument must be constant instruction.\n");
4075+
}
4076+
4077+
std::optional<ExtensionID> getRequiredExtension() const override {
4078+
return ExtensionID::SPV_INTEL_ternary_bitwise_function;
4079+
}
4080+
SPIRVCapVec getRequiredCapability() const override {
4081+
return getVec(CapabilityTernaryBitwiseFunctionINTEL);
4082+
}
4083+
};
4084+
4085+
#define _SPIRV_OP(x, ...) \
4086+
typedef SPIRVInstTemplate<SPIRVTernaryBitwiseFunctionINTELInst, \
4087+
Op##x##INTEL, __VA_ARGS__> \
4088+
SPIRV##x##INTEL;
4089+
_SPIRV_OP(BitwiseFunction, true, 7)
4090+
#undef _SPIRV_OP
4091+
40374092
} // namespace SPIRV
40384093
#endif // SPIRV_LIBSPIRV_SPIRVINSTRUCTION_H

lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,7 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
643643
add(CapabilitySubgroup2DBlockTransposeINTEL, "Subgroup2DBlockTransposeINTEL");
644644
add(CapabilitySubgroupMatrixMultiplyAccumulateINTEL,
645645
"SubgroupMatrixMultiplyAccumulateINTEL");
646+
add(CapabilityTernaryBitwiseFunctionINTEL, "TernaryBitwiseFunctionINTEL");
646647
// From spirv_internal.hpp
647648
add(internal::CapabilityFastCompositeINTEL, "FastCompositeINTEL");
648649
add(internal::CapabilityTokenTypeINTEL, "TokenTypeINTEL");

lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,7 @@ _SPIRV_OP(Subgroup2DBlockLoadTransposeINTEL, 6233)
570570
_SPIRV_OP(Subgroup2DBlockPrefetchINTEL, 6234)
571571
_SPIRV_OP(Subgroup2DBlockStoreINTEL, 6235)
572572
_SPIRV_OP(SubgroupMatrixMultiplyAccumulateINTEL, 6237)
573+
_SPIRV_OP(BitwiseFunctionINTEL, 6242)
573574
_SPIRV_OP(GroupIMulKHR, 6401)
574575
_SPIRV_OP(GroupFMulKHR, 6402)
575576
_SPIRV_OP(GroupBitwiseAndKHR, 6403)
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_ternary_bitwise_function -o %t.spv
3+
; RUN: llvm-spirv %t.spv --to-text -o %t.spt
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
6+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
7+
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
8+
9+
; RUN: not llvm-spirv %t.bc 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
10+
; CHECK-ERROR: RequiresExtension: Feature requires the following SPIR-V extension:
11+
; CHECK-ERROR-NEXT: SPV_INTEL_ternary_bitwise_function
12+
13+
; CHECK-SPIRV-NOT: Name [[#]] "_Z28__spirv_BitwiseFunctionINTELiiij"
14+
; CHECK-SPIRV-NOT: Name [[#]] "_Z28__spirv_BitwiseFunctionINTELDv4_iS_S_j"
15+
16+
; CHECK-SPIRV-DAG: Capability TernaryBitwiseFunctionINTEL
17+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_ternary_bitwise_function"
18+
19+
; CHECK-SPIRV-DAG: TypeInt [[#TYPEINT:]] 32 0
20+
; CHECK-SPIRV-DAG: TypeVector [[#TYPEINTVEC4:]] [[#TYPEINT]] 4
21+
; CHECK-SPIRV-DAG: Constant [[#TYPEINT]] [[#ScalarLUT:]] 24
22+
; CHECK-SPIRV-DAG: Constant [[#TYPEINT]] [[#VecLUT:]] 42
23+
24+
; CHECK-SPIRV: Load [[#TYPEINT]] [[#ScalarA:]]
25+
; CHECK-SPIRV: Load [[#TYPEINT]] [[#ScalarB:]]
26+
; CHECK-SPIRV: Load [[#TYPEINT]] [[#ScalarC:]]
27+
; CHECK-SPIRV: BitwiseFunctionINTEL [[#TYPEINT]] {{.*}} [[#ScalarA]] [[#ScalarB]] [[#ScalarC]] [[#ScalarLUT]]
28+
; CHECK-SPIRV: Load [[#TYPEINTVEC4]] [[#VecA:]]
29+
; CHECK-SPIRV: Load [[#TYPEINTVEC4]] [[#VecB:]]
30+
; CHECK-SPIRV: Load [[#TYPEINTVEC4]] [[#VecC:]]
31+
; CHECK-SPIRV: BitwiseFunctionINTEL [[#TYPEINTVEC4]] {{.*}} [[#VecA]] [[#VecB]] [[#VecC]] [[#VecLUT]]
32+
33+
; CHECK-LLVM: %[[ScalarA:.*]] = load i32, ptr
34+
; CHECK-LLVM: %[[ScalarB:.*]] = load i32, ptr
35+
; CHECK-LLVM: %[[ScalarC:.*]] = load i32, ptr
36+
; CHECK-LLVM: call spir_func i32 @_Z28__spirv_BitwiseFunctionINTELiiii(i32 %[[ScalarA]], i32 %[[ScalarB]], i32 %[[ScalarC]], i32 24)
37+
; CHECK-LLVM: %[[VecA:.*]] = load <4 x i32>, ptr
38+
; CHECK-LLVM: %[[VecB:.*]] = load <4 x i32>, ptr
39+
; CHECK-LLVM: %[[VecC:.*]] = load <4 x i32>, ptr
40+
; CHECK-LLVM: call spir_func <4 x i32> @_Z28__spirv_BitwiseFunctionINTELDv4_iS_S_i(<4 x i32> %[[VecA]], <4 x i32> %[[VecB]], <4 x i32> %[[VecC]], i32 42)
41+
42+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
43+
target triple = "spir"
44+
45+
; Function Attrs: nounwind readnone
46+
define spir_kernel void @fooScalar() {
47+
entry:
48+
%argA = alloca i32
49+
%argB = alloca i32
50+
%argC = alloca i32
51+
%A = load i32, ptr %argA
52+
%B = load i32, ptr %argB
53+
%C = load i32, ptr %argC
54+
%res = call spir_func i32 @_Z28__spirv_BitwiseFunctionINTELiiii(i32 %A, i32 %B, i32 %C, i32 24)
55+
ret void
56+
}
57+
58+
; Function Attrs: nounwind readnone
59+
define spir_kernel void @fooVec() {
60+
entry:
61+
%argA = alloca <4 x i32>
62+
%argB = alloca <4 x i32>
63+
%argC = alloca <4 x i32>
64+
%A = load <4 x i32>, ptr %argA
65+
%B = load <4 x i32>, ptr %argB
66+
%C = load <4 x i32>, ptr %argC
67+
%res = call spir_func <4 x i32> @_Z28__spirv_BitwiseFunctionINTELDv4_iS_S_i(<4 x i32> %A, <4 x i32> %B, <4 x i32> %C, i32 42)
68+
ret void
69+
}
70+
71+
declare dso_local spir_func i32 @_Z28__spirv_BitwiseFunctionINTELiiii(i32, i32, i32, i32)
72+
declare dso_local spir_func <4 x i32> @_Z28__spirv_BitwiseFunctionINTELDv4_iS_S_i(<4 x i32>, <4 x i32>, <4 x i32>, i32)
73+
74+
!llvm.module.flags = !{!0}
75+
!opencl.spir.version = !{!1}
76+
77+
!0 = !{i32 1, !"wchar_size", i32 4}
78+
!1 = !{i32 1, i32 2}

0 commit comments

Comments
 (0)