-
Notifications
You must be signed in to change notification settings - Fork 15.9k
[HLSL][DXIL][SPIRV] Added WaveActiveProduct intrinsic #164385 #165109
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
5ffce0b
f8db86c
e6a17ac
548299a
737aa73
690730e
3b1ca01
8941735
825c894
520c8aa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -257,6 +257,24 @@ static Intrinsic::ID getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch, | |
| } | ||
| } | ||
|
|
||
| // Pick the wave-active-product intrinsic for the target architecture. | ||
| // SPIR-V exposes a single intrinsic; DXIL uses the unsigned variant when QT | ||
| // is an unsigned integer type and the signed/float variant otherwise. | ||
| // RT is not referenced in the body. | ||
| static Intrinsic::ID getWaveActiveProductIntrinsic(llvm::Triple::ArchType Arch, | ||
|                                                    CGHLSLRuntime &RT, | ||
|                                                    QualType QT) { | ||
|   if (Arch == llvm::Triple::spirv) | ||
|     return Intrinsic::spv_wave_product; | ||
|   if (Arch == llvm::Triple::dxil) { | ||
|     const bool IsUnsigned = QT->isUnsignedIntegerType(); | ||
|     return IsUnsigned ? Intrinsic::dx_wave_uproduct | ||
|                       : Intrinsic::dx_wave_product; | ||
|   } | ||
|   // Any other architecture has no lowering for this builtin. | ||
|   llvm_unreachable("Intrinsic WaveActiveProduct" | ||
|                    " not supported by target architecture"); | ||
| } | ||
|
|
||
| // Return wave active max that corresponds to the QT scalar type | ||
| static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch, | ||
| CGHLSLRuntime &RT, QualType QT) { | ||
|
|
@@ -847,6 +865,17 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, | |
| &CGM.getModule(), IID, {OpExpr->getType()}), | ||
| ArrayRef{OpExpr}, "hlsl.wave.active.sum"); | ||
| } | ||
| case Builtin::BI__builtin_hlsl_wave_active_product: { | ||
| // Due to the use of variadic arguments, explicitly retrieve argument | ||
| Value *OpExpr = EmitScalarExpr(E->getArg(0)); | ||
| Intrinsic::ID IID = getWaveActiveProductIntrinsic( | ||
| getTarget().getTriple().getArch(), CGM.getHLSLRuntime(), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove |
||
| E->getArg(0)->getType()); | ||
|
|
||
| return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration( | ||
| &CGM.getModule(), IID, {OpExpr->getType()}), | ||
| ArrayRef{OpExpr}, "hlsl.wave.active.product"); | ||
| } | ||
| case Builtin::BI__builtin_hlsl_wave_active_max: { | ||
| // Due to the use of variadic arguments, explicitly retrieve argument | ||
| Value *OpExpr = EmitScalarExpr(E->getArg(0)); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| // RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \ | ||
| // RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \ | ||
| // RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL | ||
| // RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \ | ||
| // RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \ | ||
| // RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV | ||
|
|
||
| // Test basic lowering to runtime function call. | ||
|
|
||
| // Signed 32-bit scalar input: the DXIL run is matched against the signed | ||
| // dx.wave.product intrinsic, the SPIR-V run against spv.wave.product. | ||
| // CHECK-LABEL: test_int | ||
| int test_int(int expr) { | ||
| // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.product.i32([[TY]] %[[#]]) | ||
| // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.product.i32([[TY]] %[[#]]) | ||
| // CHECK: ret [[TY]] %[[RET]] | ||
| return WaveActiveProduct(expr); | ||
| } | ||
|
|
||
| // CHECK-DXIL: declare [[TY]] @llvm.dx.wave.product.i32([[TY]]) #[[#attr:]] | ||
| // CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.product.i32([[TY]]) #[[#attr:]] | ||
|
|
||
| // Unsigned 64-bit scalar input: the DXIL run is matched against the unsigned | ||
| // dx.wave.uproduct intrinsic, while the SPIR-V run still uses spv.wave.product. | ||
| // CHECK-LABEL: test_uint64_t | ||
| uint64_t test_uint64_t(uint64_t expr) { | ||
| // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.product.i64([[TY]] %[[#]]) | ||
| // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.uproduct.i64([[TY]] %[[#]]) | ||
| // CHECK: ret [[TY]] %[[RET]] | ||
| return WaveActiveProduct(expr); | ||
| } | ||
|
|
||
| // CHECK-DXIL: declare [[TY]] @llvm.dx.wave.uproduct.i64([[TY]]) #[[#attr:]] | ||
| // CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.product.i64([[TY]]) #[[#attr:]] | ||
|
|
||
| // Test basic lowering to runtime function call with a float vector value. | ||
|
|
||
| // float4 input: both runs are matched against the v4f32 overload of the | ||
| // intrinsic, with the fast-math flags present on the call. | ||
| // CHECK-LABEL: test_floatv4 | ||
| float4 test_floatv4(float4 expr) { | ||
| // CHECK-SPIRV: %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn spir_func [[TY1:.*]] @llvm.spv.wave.product.v4f32([[TY1]] %[[#]]) | ||
| // CHECK-DXIL: %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn [[TY1:.*]] @llvm.dx.wave.product.v4f32([[TY1]] %[[#]]) | ||
| // CHECK: ret [[TY1]] %[[RET1]] | ||
| return WaveActiveProduct(expr); | ||
| } | ||
|
|
||
| // CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.product.v4f32([[TY1]]) #[[#attr]] | ||
| // CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.product.v4f32([[TY1]]) #[[#attr]] | ||
|
|
||
| // CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}} |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| // RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify | ||
|
|
||
| // The builtin takes exactly one operand; a call with none must be diagnosed. | ||
| int test_too_few_arg() { | ||
| return __builtin_hlsl_wave_active_product(); | ||
| // expected-error@-1 {{too few arguments to function call, expected 1, have 0}} | ||
| } | ||
|
|
||
| // Passing a second operand to the one-operand builtin must be diagnosed. | ||
| float2 test_too_many_arg(float2 p0) { | ||
| return __builtin_hlsl_wave_active_product(p0, p0); | ||
| // expected-error@-1 {{too many arguments to function call, expected 1, have 2}} | ||
| } | ||
|
|
||
| // A scalar bool operand is rejected as an invalid operand type. | ||
| bool test_expr_bool_type_check(bool p0) { | ||
| return __builtin_hlsl_wave_active_product(p0); | ||
| // expected-error@-1 {{invalid operand of type 'bool'}} | ||
| } | ||
|
|
||
| // A vector of bool is rejected in the same way as a scalar bool. | ||
| bool2 test_expr_bool_vec_type_check(bool2 p0) { | ||
| return __builtin_hlsl_wave_active_product(p0); | ||
| // expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>')}} | ||
| } | ||
|
|
||
| // Aggregate wrapping a single float, used below as a non-scalar, non-vector operand. | ||
| struct S { float f; }; | ||
|
|
||
| // A struct operand is rejected: the builtin requires a scalar or vector. | ||
| S test_expr_struct_type_check(S p0) { | ||
| return __builtin_hlsl_wave_active_product(p0); | ||
| // expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}} | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -228,6 +228,9 @@ class SPIRVInstructionSelector : public InstructionSelector { | |
| bool selectWaveReduceSum(Register ResVReg, const SPIRVType *ResType, | ||
| MachineInstr &I) const; | ||
|
|
||
| bool selectWaveReduceProduct(Register ResVReg, const SPIRVType *ResType, | ||
| MachineInstr &I) const; | ||
|
|
||
| bool selectConst(Register ResVReg, const SPIRVType *ResType, | ||
| MachineInstr &I) const; | ||
|
|
||
|
|
@@ -2525,6 +2528,32 @@ bool SPIRVInstructionSelector::selectWaveReduceSum(Register ResVReg, | |
| .addUse(I.getOperand(2).getReg()); | ||
| } | ||
|
|
||
| // Select a wave product reduction: emits a Subgroup-scope Reduce group | ||
| // operation, using OpGroupNonUniformFMul when the input is a float scalar or | ||
| // vector and OpGroupNonUniformIMul otherwise. Operand 2 of I is the input. | ||
| bool SPIRVInstructionSelector::selectWaveReduceProduct(Register ResVReg, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I mentioned this on #165156 but it would be nice to do something like |
||
| const SPIRVType *ResType, | ||
| MachineInstr &I) const { | ||
|   assert(I.getNumOperands() == 3); | ||
|   assert(I.getOperand(2).isReg()); | ||
|   const Register SrcReg = I.getOperand(2).getReg(); | ||
|   // Abort when the type lookup for the input register returns null. | ||
|   if (GR.getSPIRVTypeForVReg(SrcReg) == nullptr) | ||
|     report_fatal_error("Input Type could not be determined."); | ||
|   // 32-bit integer type for the scope constant operand. | ||
|   SPIRVType *ScopeIntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII); | ||
|   // Retrieve the operation to use based on input type | ||
|   unsigned ReduceOpcode = SPIRV::OpGroupNonUniformIMul; | ||
|   if (GR.isScalarOrVectorOfType(SrcReg, SPIRV::OpTypeFloat)) | ||
|     ReduceOpcode = SPIRV::OpGroupNonUniformFMul; | ||
|   return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(ReduceOpcode)) | ||
|       .addDef(ResVReg) | ||
|       .addUse(GR.getSPIRVTypeID(ResType)) | ||
|       .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, ScopeIntTy, TII, | ||
|                                      !STI.isShader())) | ||
|       .addImm(SPIRV::GroupOperation::Reduce) | ||
|       .addUse(SrcReg); | ||
| } | ||
|
|
||
| bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg, | ||
| const SPIRVType *ResType, | ||
| MachineInstr &I) const { | ||
|
|
@@ -3540,6 +3569,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, | |
| return selectWaveReduceMin(ResVReg, ResType, I, /*IsUnsigned*/ false); | ||
| case Intrinsic::spv_wave_reduce_sum: | ||
| return selectWaveReduceSum(ResVReg, ResType, I); | ||
| case Intrinsic::spv_wave_product: | ||
| return selectWaveReduceProduct(ResVReg, ResType, I); | ||
| case Intrinsic::spv_wave_readlane: | ||
| return selectWaveOpInst(ResVReg, ResType, I, | ||
| SPIRV::OpGroupNonUniformShuffle); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
RT is not used in this function