|
34 | 34 |
|
35 | 35 | #ifdef SUPPORT_QUERY_GIT_COMMIT_INFO
|
36 | 36 | #include "clang/Basic/Version.h"
|
| 37 | +#include "clang/Sema/Lookup.h" |
37 | 38 | #else
|
38 | 39 | namespace clang {
|
39 | 40 | uint32_t getGitCommitCount() { return 0; }
|
@@ -13234,14 +13235,26 @@ void SpirvEmitter::processPixelShaderAttributes(const FunctionDecl *decl) {
|
13234 | 13235 |
|
13235 | 13236 | void SpirvEmitter::processComputeShaderAttributes(const FunctionDecl *decl) {
|
13236 | 13237 | auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>();
|
13237 |
| - assert(numThreadsAttr && "thread group size missing from entry-point"); |
| 13238 | + auto *localSizeIdAttr = decl->getAttr<HLSLSpirvNumThreadsAttr>(); |
| 13239 | + assert((numThreadsAttr || localSizeIdAttr) && |
| 13240 | + "thread group size missing from entry-point"); |
13238 | 13241 |
|
13239 |
| - uint32_t x = static_cast<uint32_t>(numThreadsAttr->getX()); |
13240 |
| - uint32_t y = static_cast<uint32_t>(numThreadsAttr->getY()); |
13241 |
| - uint32_t z = static_cast<uint32_t>(numThreadsAttr->getZ()); |
| 13242 | + if (numThreadsAttr) { |
| 13243 | + uint32_t x = static_cast<uint32_t>(numThreadsAttr->getX()); |
| 13244 | + uint32_t y = static_cast<uint32_t>(numThreadsAttr->getY()); |
| 13245 | + uint32_t z = static_cast<uint32_t>(numThreadsAttr->getZ()); |
13242 | 13246 |
|
13243 |
| - spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize, |
13244 |
| - {x, y, z}, decl->getLocation()); |
| 13247 | + spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize, |
| 13248 | + {x, y, z}, decl->getLocation()); |
| 13249 | + } else { |
| 13250 | + auto *exprX = localSizeIdAttr->getX(); |
| 13251 | + auto *x = doExpr(exprX); |
| 13252 | + auto *y = doExpr(localSizeIdAttr->getY()); |
| 13253 | + auto *z = doExpr(localSizeIdAttr->getZ()); |
| 13254 | + spvBuilder.addExecutionModeId(entryFunction, |
| 13255 | + spv::ExecutionMode::LocalSizeId, {x, y, z}, |
| 13256 | + decl->getLocation()); |
| 13257 | + } |
13245 | 13258 |
|
13246 | 13259 | auto *waveSizeAttr = decl->getAttr<HLSLWaveSizeAttr>();
|
13247 | 13260 | if (waveSizeAttr) {
|
@@ -13469,6 +13482,13 @@ bool SpirvEmitter::processMeshOrAmplificationShaderAttributes(
|
13469 | 13482 | z = static_cast<uint32_t>(numThreadsAttr->getZ());
|
13470 | 13483 | spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
|
13471 | 13484 | {x, y, z}, decl->getLocation());
|
| 13485 | + } else if (auto *localSizeIdAttr = decl->getAttr<HLSLSpirvNumThreadsAttr>()) { |
| 13486 | + auto *x = doExpr(localSizeIdAttr->getX()); |
| 13487 | + auto *y = doExpr(localSizeIdAttr->getY()); |
| 13488 | + auto *z = doExpr(localSizeIdAttr->getZ()); |
| 13489 | + spvBuilder.addExecutionModeId(entryFunction, |
| 13490 | + spv::ExecutionMode::LocalSizeId, {x, y, z}, |
| 13491 | + decl->getLocation()); |
13472 | 13492 | }
|
13473 | 13493 |
|
13474 | 13494 | // Early return for amplification shaders as they only take the 'numthreads'
|
@@ -15030,9 +15050,14 @@ bool SpirvEmitter::spirvToolsValidate(std::vector<uint32_t> *mod,
|
15030 | 15050 | void SpirvEmitter::addDerivativeGroupExecutionMode() {
|
15031 | 15051 | assert(spvContext.isCS());
|
15032 | 15052 |
|
15033 |
| - SpirvExecutionMode *numThreadsEm = spvBuilder.getModule()->findExecutionMode( |
15034 |
| - entryFunction, spv::ExecutionMode::LocalSize); |
15035 |
| - auto numThreads = numThreadsEm->getParams(); |
| 15053 | + SpirvExecutionModeBase *numThreadsEm = |
| 15054 | + spvBuilder.getModule()->findExecutionMode(entryFunction, |
| 15055 | + spv::ExecutionMode::LocalSize); |
| 15056 | + |
| 15057 | + // TODO: Need to handle LocalSizeId as well. |
| 15058 | + assert(numThreadsEm->getKind() == SpirvInstruction::IK_ExecutionMode); |
| 15059 | + auto numThreads = |
| 15060 | + static_cast<SpirvExecutionMode *>(numThreadsEm)->getParams(); |
15036 | 15061 |
|
15037 | 15062 | // The layout of the quad is determined by the number of threads in each
|
15038 | 15063 | // dimension. From the HLSL spec
|
|
0 commit comments