|
34 | 34 |
|
35 | 35 | #ifdef SUPPORT_QUERY_GIT_COMMIT_INFO
|
36 | 36 | #include "clang/Basic/Version.h"
|
| 37 | +#include "clang/Sema/Lookup.h" |
37 | 38 | #else
|
38 | 39 | namespace clang {
|
39 | 40 | uint32_t getGitCommitCount() { return 0; }
|
@@ -13226,14 +13227,24 @@ void SpirvEmitter::processPixelShaderAttributes(const FunctionDecl *decl) {
|
13226 | 13227 |
|
13227 | 13228 | void SpirvEmitter::processComputeShaderAttributes(const FunctionDecl *decl) {
|
13228 | 13229 | auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>();
|
13229 |
| - assert(numThreadsAttr && "thread group size missing from entry-point"); |
| 13230 | + auto *localSizeIdAttr = decl->getAttr<HLSLSpirvNumThreadsAttr>(); |
| 13231 | + assert((numThreadsAttr || localSizeIdAttr) && "thread group size missing from entry-point"); |
13230 | 13232 |
|
13231 |
| - uint32_t x = static_cast<uint32_t>(numThreadsAttr->getX()); |
13232 |
| - uint32_t y = static_cast<uint32_t>(numThreadsAttr->getY()); |
13233 |
| - uint32_t z = static_cast<uint32_t>(numThreadsAttr->getZ()); |
| 13233 | + if (numThreadsAttr) { |
| 13234 | + uint32_t x = static_cast<uint32_t>(numThreadsAttr->getX()); |
| 13235 | + uint32_t y = static_cast<uint32_t>(numThreadsAttr->getY()); |
| 13236 | + uint32_t z = static_cast<uint32_t>(numThreadsAttr->getZ()); |
13234 | 13237 |
|
13235 |
| - spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize, |
13236 |
| - {x, y, z}, decl->getLocation()); |
| 13238 | + spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize, |
| 13239 | + {x, y, z}, decl->getLocation()); |
| 13240 | + } else { |
| 13241 | + auto * exprX = localSizeIdAttr->getX(); |
| 13242 | + auto *x = doExpr(exprX); |
| 13243 | + auto *y = doExpr(localSizeIdAttr->getY()); |
| 13244 | + auto *z = doExpr(localSizeIdAttr->getZ()); |
| 13245 | + spvBuilder.addExecutionModeId(entryFunction, spv::ExecutionMode::LocalSizeId, |
| 13246 | + {x, y, z}, decl->getLocation()); |
| 13247 | + } |
13237 | 13248 |
|
13238 | 13249 | auto *waveSizeAttr = decl->getAttr<HLSLWaveSizeAttr>();
|
13239 | 13250 | if (waveSizeAttr) {
|
@@ -13461,6 +13472,12 @@ bool SpirvEmitter::processMeshOrAmplificationShaderAttributes(
|
13461 | 13472 | z = static_cast<uint32_t>(numThreadsAttr->getZ());
|
13462 | 13473 | spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
|
13463 | 13474 | {x, y, z}, decl->getLocation());
|
| 13475 | + } else if (auto *localSizeIdAttr = decl->getAttr<HLSLSpirvNumThreadsAttr>()) { |
| 13476 | + auto *x = doExpr(localSizeIdAttr->getX()); |
| 13477 | + auto *y = doExpr(localSizeIdAttr->getY()); |
| 13478 | + auto *z = doExpr(localSizeIdAttr->getZ()); |
| 13479 | + spvBuilder.addExecutionModeId(entryFunction, spv::ExecutionMode::LocalSizeId, |
| 13480 | + {x, y, z}, decl->getLocation()); |
13464 | 13481 | }
|
13465 | 13482 |
|
13466 | 13483 | // Early return for amplification shaders as they only take the 'numthreads'
|
@@ -15022,9 +15039,12 @@ bool SpirvEmitter::spirvToolsValidate(std::vector<uint32_t> *mod,
|
15022 | 15039 | void SpirvEmitter::addDerivativeGroupExecutionMode() {
|
15023 | 15040 | assert(spvContext.isCS());
|
15024 | 15041 |
|
15025 |
| - SpirvExecutionMode *numThreadsEm = spvBuilder.getModule()->findExecutionMode( |
| 15042 | + SpirvExecutionModeBase *numThreadsEm = spvBuilder.getModule()->findExecutionMode( |
15026 | 15043 | entryFunction, spv::ExecutionMode::LocalSize);
|
15027 |
| - auto numThreads = numThreadsEm->getParams(); |
| 15044 | + |
| 15045 | + // TODO: Need to handle LocalSizeID as well. |
| 15046 | + assert(numThreadsEm->getKind() == SpirvInstruction::IK_ExecutionMode); |
| 15047 | + auto numThreads = static_cast<SpirvExecutionMode*>(numThreadsEm)->getParams(); |
15028 | 15048 |
|
15029 | 15049 | // The layout of the quad is determined by the numer of threads in each
|
15030 | 15050 | // dimention. From the HLSL spec
|
|
0 commit comments