Skip to content

Commit c2797ab

Browse files
committed
Add vk::LocalSizeId Attribute
In SPIR-V, the number of thread in the group can be specificed using the `LocalSize` execution mode. This corresponds nicely with the `numthreads` attribute in HLSL. However there is another way. You can use `LocalSizeId`, which uses ids of other instructions. This allows spec constants to be provided as a dimention on the local size id. This PR adds a new attribute that can be used instead of the `numthreads` attribute. It allows constant expression or spec constants as parameters.
1 parent 6475f98 commit c2797ab

14 files changed

+198
-43
lines changed

tools/clang/include/clang/Basic/Attr.td

+5
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,11 @@ def HLSLNumThreads: InheritableAttr {
671671
let Args = [IntArgument<"X">, IntArgument<"Y">, IntArgument<"Z">];
672672
let Documentation = [Undocumented];
673673
}
674+
def HLSLSpirvNumThreads : InheritableAttr {
675+
let Spellings = [CXX11<"vk", "LocalSizeId", 2015>];
676+
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
677+
let Documentation = [Undocumented];
678+
}
674679
def HLSLRootSignature: InheritableAttr {
675680
let Spellings = [CXX11<"", "RootSignature", 2015>];
676681
let Args = [StringArgument<"SignatureName">];

tools/clang/include/clang/SPIRV/SpirvBuilder.h

+41-2
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,14 @@ class SpirvBuilder {
607607
SourceLocation,
608608
bool useIdParams = false);
609609

610+
/// \brief Adds an execution mode to the module under construction if it does
611+
/// not already exist. Return the newly added instruction or the existing
612+
/// instruction, if one already exists.
613+
inline SpirvInstruction *
614+
addExecutionModeId(SpirvFunction *entryPoint, spv::ExecutionMode em,
615+
llvm::ArrayRef<SpirvInstruction *> params,
616+
SourceLocation loc);
617+
610618
/// \brief Adds an OpModuleProcessed instruction to the module under
611619
/// construction.
612620
void addModuleProcessed(llvm::StringRef process);
@@ -954,15 +962,46 @@ SpirvBuilder::addExecutionMode(SpirvFunction *entryPoint, spv::ExecutionMode em,
954962
llvm::ArrayRef<uint32_t> params,
955963
SourceLocation loc, bool useIdParams) {
956964
SpirvExecutionMode *mode = nullptr;
957-
SpirvExecutionMode *existingInstruction =
965+
SpirvExecutionModeBase *existingInstruction =
958966
mod->findExecutionMode(entryPoint, em);
959967

960968
if (!existingInstruction) {
961969
mode = new (context)
962970
SpirvExecutionMode(loc, entryPoint, em, params, useIdParams);
963971
mod->addExecutionMode(mode);
964972
} else {
965-
mode = existingInstruction;
973+
// No execution mode can be used with both OpExecutionMode and
974+
// OpExecutionModeId. If this assert is triggered, then either this
975+
// `addExecutionModeId` should have been called with `em` or the existing
976+
// instruction is wrong.
977+
assert(existingInstruction->getKind() ==
978+
SpirvInstruction::IK_ExecutionMode);
979+
mode = cast<SpirvExecutionMode>(existingInstruction);
980+
}
981+
982+
return mode;
983+
}
984+
985+
SpirvInstruction *SpirvBuilder::addExecutionModeId(
986+
SpirvFunction *entryPoint, spv::ExecutionMode em,
987+
llvm::ArrayRef<SpirvInstruction *> params, SourceLocation loc) {
988+
SpirvExecutionModeId *mode = nullptr;
989+
SpirvExecutionModeBase *existingInstruction =
990+
mod->findExecutionMode(entryPoint, em);
991+
assert(!existingInstruction || existingInstruction->getKind() ==
992+
SpirvInstruction::IK_ExecutionModeId);
993+
994+
if (!existingInstruction) {
995+
mode = new (context) SpirvExecutionModeId(loc, entryPoint, em, params);
996+
mod->addExecutionMode(mode);
997+
} else {
998+
// No execution mode can be used with both OpExecutionMode and
999+
// OpExecutionModeId. If this assert is triggered, then either this
1000+
// `addExecutionMode` should have been called with `em` or the existing
1001+
// instruction is wrong.
1002+
assert(existingInstruction->getKind() ==
1003+
SpirvInstruction::IK_ExecutionMode);
1004+
mode = cast<SpirvExecutionModeId>(existingInstruction);
9661005
}
9671006

9681007
return mode;

tools/clang/include/clang/SPIRV/SpirvInstruction.h

+47-5
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class SpirvInstruction {
5353
IK_MemoryModel, // OpMemoryModel
5454
IK_EntryPoint, // OpEntryPoint
5555
IK_ExecutionMode, // OpExecutionMode
56+
IK_ExecutionModeId, // OpExecutionModeId
5657
IK_String, // OpString (debug)
5758
IK_Source, // OpSource (debug)
5859
IK_ModuleProcessed, // OpModuleProcessed (debug)
@@ -396,8 +397,31 @@ class SpirvEntryPoint : public SpirvInstruction {
396397
llvm::SmallVector<SpirvVariable *, 8> interfaceVec;
397398
};
398399

400+
class SpirvExecutionModeBase : public SpirvInstruction {
401+
public:
402+
SpirvExecutionModeBase(Kind kind, spv::Op opcode, SourceLocation loc,
403+
SpirvFunction *entryPointFunction,
404+
spv::ExecutionMode executionMode)
405+
: SpirvInstruction(kind, opcode, QualType(), loc),
406+
entryPoint(entryPointFunction), execMode(executionMode) {}
407+
408+
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvExecutionModeBase)
409+
410+
// For LLVM-style RTTI
411+
static bool classof(const SpirvInstruction *inst) { return false; }
412+
413+
bool invokeVisitor(Visitor *v) override;
414+
415+
SpirvFunction *getEntryPoint() const { return entryPoint; }
416+
spv::ExecutionMode getExecutionMode() const { return execMode; }
417+
418+
private:
419+
SpirvFunction *entryPoint;
420+
spv::ExecutionMode execMode;
421+
};
422+
399423
/// \brief OpExecutionMode and OpExecutionModeId instructions
400-
class SpirvExecutionMode : public SpirvInstruction {
424+
class SpirvExecutionMode : public SpirvExecutionModeBase {
401425
public:
402426
SpirvExecutionMode(SourceLocation loc, SpirvFunction *entryPointFunction,
403427
spv::ExecutionMode, llvm::ArrayRef<uint32_t> params,
@@ -412,16 +436,34 @@ class SpirvExecutionMode : public SpirvInstruction {
412436

413437
bool invokeVisitor(Visitor *v) override;
414438

415-
SpirvFunction *getEntryPoint() const { return entryPoint; }
416-
spv::ExecutionMode getExecutionMode() const { return execMode; }
417439
llvm::ArrayRef<uint32_t> getParams() const { return params; }
418440

419441
private:
420-
SpirvFunction *entryPoint;
421-
spv::ExecutionMode execMode;
422442
llvm::SmallVector<uint32_t, 4> params;
423443
};
424444

445+
/// \brief OpExecutionModeId
446+
class SpirvExecutionModeId : public SpirvExecutionModeBase {
447+
public:
448+
SpirvExecutionModeId(SourceLocation loc, SpirvFunction *entryPointFunction,
449+
spv::ExecutionMode em,
450+
llvm::ArrayRef<SpirvInstruction *> params);
451+
452+
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvExecutionModeId)
453+
454+
// For LLVM-style RTTI
455+
static bool classof(const SpirvInstruction *inst) {
456+
return inst->getKind() == IK_ExecutionModeId;
457+
}
458+
459+
bool invokeVisitor(Visitor *v) override;
460+
461+
llvm::ArrayRef<SpirvInstruction *> getParams() const { return params; }
462+
463+
private:
464+
llvm::SmallVector<SpirvInstruction *, 4> params;
465+
};
466+
425467
/// \brief OpString instruction
426468
class SpirvString : public SpirvInstruction {
427469
public:

tools/clang/include/clang/SPIRV/SpirvModule.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,11 @@ class SpirvModule {
119119

120120
// Returns an existing execution mode instruction that is the same as em if it
121121
// exists. Return nullptr otherwise.
122-
SpirvExecutionMode *findExecutionMode(SpirvFunction *entryPoint,
123-
spv::ExecutionMode em);
122+
SpirvExecutionModeBase *findExecutionMode(SpirvFunction *entryPoint,
123+
spv::ExecutionMode em);
124124

125125
// Adds an execution mode to the module.
126-
void addExecutionMode(SpirvExecutionMode *);
126+
void addExecutionMode(SpirvExecutionModeBase *em);
127127

128128
// Adds an extension to the module. Returns true if the extension was added.
129129
// Returns false otherwise (e.g. if the extension already existed).
@@ -194,7 +194,7 @@ class SpirvModule {
194194
llvm::SmallVector<SpirvExtInstImport *, 1> extInstSets;
195195
SpirvMemoryModel *memoryModel;
196196
llvm::SmallVector<SpirvEntryPoint *, 1> entryPoints;
197-
llvm::SmallVector<SpirvExecutionMode *, 4> executionModes;
197+
llvm::SmallVector<SpirvExecutionModeBase *, 4> executionModes;
198198
llvm::SmallVector<SpirvString *, 4> constStrings;
199199
std::vector<SpirvSource *> sources;
200200
std::vector<SpirvModuleProcessed *> moduleProcesses;

tools/clang/include/clang/SPIRV/SpirvVisitor.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class Visitor {
6060
DEFINE_VISIT_METHOD(SpirvExtInstImport)
6161
DEFINE_VISIT_METHOD(SpirvMemoryModel)
6262
DEFINE_VISIT_METHOD(SpirvEntryPoint)
63-
DEFINE_VISIT_METHOD(SpirvExecutionMode)
63+
DEFINE_VISIT_METHOD(SpirvExecutionModeBase)
6464
DEFINE_VISIT_METHOD(SpirvString)
6565
DEFINE_VISIT_METHOD(SpirvSource)
6666
DEFINE_VISIT_METHOD(SpirvModuleProcessed)

tools/clang/lib/SPIRV/CapabilityVisitor.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,7 @@ bool CapabilityVisitor::visit(SpirvEntryPoint *entryPoint) {
634634
return true;
635635
}
636636

637-
bool CapabilityVisitor::visit(SpirvExecutionMode *execMode) {
637+
bool CapabilityVisitor::visit(SpirvExecutionModeBase *execMode) {
638638
spv::ExecutionMode executionMode = execMode->getExecutionMode();
639639
SourceLocation execModeSourceLocation = execMode->getSourceLocation();
640640
SourceLocation entryPointSourceLocation =

tools/clang/lib/SPIRV/CapabilityVisitor.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class CapabilityVisitor : public Visitor {
3131

3232
bool visit(SpirvDecoration *decor) override;
3333
bool visit(SpirvEntryPoint *) override;
34-
bool visit(SpirvExecutionMode *) override;
34+
bool visit(SpirvExecutionModeBase *execMode) override;
3535
bool visit(SpirvImageQuery *) override;
3636
bool visit(SpirvImageOp *) override;
3737
bool visit(SpirvImageSparseTexelsResident *) override;

tools/clang/lib/SPIRV/EmitVisitor.cpp

+18-7
Original file line numberDiff line numberDiff line change
@@ -613,18 +613,29 @@ bool EmitVisitor::visit(SpirvEntryPoint *inst) {
613613
return true;
614614
}
615615

616-
bool EmitVisitor::visit(SpirvExecutionMode *inst) {
616+
bool EmitVisitor::visit(SpirvExecutionModeBase *inst) {
617617
initInstruction(inst);
618618
curInst.push_back(getOrAssignResultId<SpirvFunction>(inst->getEntryPoint()));
619619
curInst.push_back(static_cast<uint32_t>(inst->getExecutionMode()));
620620
if (inst->getopcode() == spv::Op::OpExecutionMode) {
621-
curInst.insert(curInst.end(), inst->getParams().begin(),
622-
inst->getParams().end());
621+
ArrayRef<uint32_t> params =
622+
static_cast<SpirvExecutionMode *>(inst)->getParams();
623+
curInst.insert(curInst.end(), params.begin(), params.end());
623624
} else {
624-
for (uint32_t param : inst->getParams()) {
625-
curInst.push_back(typeHandler.getOrCreateConstantInt(
626-
llvm::APInt(32, param), context.getUIntType(32),
627-
/*isSpecConst */ false));
625+
if (inst->getKind() == SpirvInstruction::IK_ExecutionModeId) {
626+
auto *exeModeId = static_cast<SpirvExecutionModeId *>(inst);
627+
for (SpirvInstruction *param : exeModeId->getParams()) {
628+
uint32_t id = getOrAssignResultId<SpirvInstruction>(param);
629+
curInst.push_back(id);
630+
}
631+
} else {
632+
ArrayRef<uint32_t> params =
633+
static_cast<SpirvExecutionMode *>(inst)->getParams();
634+
for (uint32_t param : params) {
635+
curInst.push_back(typeHandler.getOrCreateConstantInt(
636+
llvm::APInt(32, param), context.getUIntType(32),
637+
/*isSpecConst */ false));
638+
}
628639
}
629640
}
630641
finalizeInstruction(&preambleBinary);

tools/clang/lib/SPIRV/EmitVisitor.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ class EmitVisitor : public Visitor {
228228
bool visit(SpirvEmitVertex *) override;
229229
bool visit(SpirvEndPrimitive *) override;
230230
bool visit(SpirvEntryPoint *) override;
231-
bool visit(SpirvExecutionMode *) override;
231+
bool visit(SpirvExecutionModeBase *) override;
232232
bool visit(SpirvString *) override;
233233
bool visit(SpirvSource *) override;
234234
bool visit(SpirvModuleProcessed *) override;

tools/clang/lib/SPIRV/SpirvEmitter.cpp

+34-9
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
#ifdef SUPPORT_QUERY_GIT_COMMIT_INFO
3636
#include "clang/Basic/Version.h"
37+
#include "clang/Sema/Lookup.h"
3738
#else
3839
namespace clang {
3940
uint32_t getGitCommitCount() { return 0; }
@@ -13234,14 +13235,26 @@ void SpirvEmitter::processPixelShaderAttributes(const FunctionDecl *decl) {
1323413235

1323513236
void SpirvEmitter::processComputeShaderAttributes(const FunctionDecl *decl) {
1323613237
auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>();
13237-
assert(numThreadsAttr && "thread group size missing from entry-point");
13238+
auto *localSizeIdAttr = decl->getAttr<HLSLSpirvNumThreadsAttr>();
13239+
assert((numThreadsAttr || localSizeIdAttr) &&
13240+
"thread group size missing from entry-point");
1323813241

13239-
uint32_t x = static_cast<uint32_t>(numThreadsAttr->getX());
13240-
uint32_t y = static_cast<uint32_t>(numThreadsAttr->getY());
13241-
uint32_t z = static_cast<uint32_t>(numThreadsAttr->getZ());
13242+
if (numThreadsAttr) {
13243+
uint32_t x = static_cast<uint32_t>(numThreadsAttr->getX());
13244+
uint32_t y = static_cast<uint32_t>(numThreadsAttr->getY());
13245+
uint32_t z = static_cast<uint32_t>(numThreadsAttr->getZ());
1324213246

13243-
spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
13244-
{x, y, z}, decl->getLocation());
13247+
spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
13248+
{x, y, z}, decl->getLocation());
13249+
} else {
13250+
auto *exprX = localSizeIdAttr->getX();
13251+
auto *x = doExpr(exprX);
13252+
auto *y = doExpr(localSizeIdAttr->getY());
13253+
auto *z = doExpr(localSizeIdAttr->getZ());
13254+
spvBuilder.addExecutionModeId(entryFunction,
13255+
spv::ExecutionMode::LocalSizeId, {x, y, z},
13256+
decl->getLocation());
13257+
}
1324513258

1324613259
auto *waveSizeAttr = decl->getAttr<HLSLWaveSizeAttr>();
1324713260
if (waveSizeAttr) {
@@ -13469,6 +13482,13 @@ bool SpirvEmitter::processMeshOrAmplificationShaderAttributes(
1346913482
z = static_cast<uint32_t>(numThreadsAttr->getZ());
1347013483
spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
1347113484
{x, y, z}, decl->getLocation());
13485+
} else if (auto *localSizeIdAttr = decl->getAttr<HLSLSpirvNumThreadsAttr>()) {
13486+
auto *x = doExpr(localSizeIdAttr->getX());
13487+
auto *y = doExpr(localSizeIdAttr->getY());
13488+
auto *z = doExpr(localSizeIdAttr->getZ());
13489+
spvBuilder.addExecutionModeId(entryFunction,
13490+
spv::ExecutionMode::LocalSizeId, {x, y, z},
13491+
decl->getLocation());
1347213492
}
1347313493

1347413494
// Early return for amplification shaders as they only take the 'numthreads'
@@ -15030,9 +15050,14 @@ bool SpirvEmitter::spirvToolsValidate(std::vector<uint32_t> *mod,
1503015050
void SpirvEmitter::addDerivativeGroupExecutionMode() {
1503115051
assert(spvContext.isCS());
1503215052

15033-
SpirvExecutionMode *numThreadsEm = spvBuilder.getModule()->findExecutionMode(
15034-
entryFunction, spv::ExecutionMode::LocalSize);
15035-
auto numThreads = numThreadsEm->getParams();
15053+
SpirvExecutionModeBase *numThreadsEm =
15054+
spvBuilder.getModule()->findExecutionMode(entryFunction,
15055+
spv::ExecutionMode::LocalSize);
15056+
15057+
// TODO: Need to handle LocalSizeID as well.
15058+
assert(numThreadsEm->getKind() == SpirvInstruction::IK_ExecutionMode);
15059+
auto numThreads =
15060+
static_cast<SpirvExecutionMode *>(numThreadsEm)->getParams();
1503615061

1503715062
// The layout of the quad is determined by the numer of threads in each
1503815063
// dimention. From the HLSL spec

tools/clang/lib/SPIRV/SpirvInstruction.cpp

+13-5
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExtension)
2929
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExtInstImport)
3030
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvMemoryModel)
3131
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvEntryPoint)
32+
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExecutionModeBase)
3233
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExecutionMode)
34+
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExecutionModeId)
3335
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvString)
3436
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvSource)
3537
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvModuleProcessed)
@@ -203,11 +205,17 @@ SpirvExecutionMode::SpirvExecutionMode(SourceLocation loc, SpirvFunction *entry,
203205
spv::ExecutionMode em,
204206
llvm::ArrayRef<uint32_t> paramsVec,
205207
bool usesIdParams)
206-
: SpirvInstruction(IK_ExecutionMode,
207-
usesIdParams ? spv::Op::OpExecutionModeId
208-
: spv::Op::OpExecutionMode,
209-
QualType(), loc),
210-
entryPoint(entry), execMode(em),
208+
: SpirvExecutionModeBase(IK_ExecutionMode,
209+
usesIdParams ? spv::Op::OpExecutionModeId
210+
: spv::Op::OpExecutionMode,
211+
loc, entry, em),
212+
params(paramsVec.begin(), paramsVec.end()) {}
213+
214+
SpirvExecutionModeId::SpirvExecutionModeId(
215+
SourceLocation loc, SpirvFunction *entry, spv::ExecutionMode em,
216+
llvm::ArrayRef<SpirvInstruction *> paramsVec)
217+
: SpirvExecutionModeBase(IK_ExecutionModeId, spv::Op::OpExecutionModeId,
218+
loc, entry, em),
211219
params(paramsVec.begin(), paramsVec.end()) {}
212220

213221
SpirvString::SpirvString(SourceLocation loc, llvm::StringRef stringLiteral)

tools/clang/lib/SPIRV/SpirvModule.cpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -294,9 +294,10 @@ void SpirvModule::addEntryPoint(SpirvEntryPoint *ep) {
294294
entryPoints.push_back(ep);
295295
}
296296

297-
SpirvExecutionMode *SpirvModule::findExecutionMode(SpirvFunction *entryPoint,
298-
spv::ExecutionMode em) {
299-
for (SpirvExecutionMode *cem : executionModes) {
297+
SpirvExecutionModeBase *
298+
SpirvModule::findExecutionMode(SpirvFunction *entryPoint,
299+
spv::ExecutionMode em) {
300+
for (SpirvExecutionModeBase *cem : executionModes) {
300301
if (cem->getEntryPoint() != entryPoint)
301302
continue;
302303
if (cem->getExecutionMode() != em)
@@ -306,7 +307,7 @@ SpirvExecutionMode *SpirvModule::findExecutionMode(SpirvFunction *entryPoint,
306307
return nullptr;
307308
}
308309

309-
void SpirvModule::addExecutionMode(SpirvExecutionMode *em) {
310+
void SpirvModule::addExecutionMode(SpirvExecutionModeBase *em) {
310311
assert(em && "cannot add null execution mode");
311312
executionModes.push_back(em);
312313
}

0 commit comments

Comments
 (0)