Skip to content

Commit 74a8ddb

Browse files
committed
Add vk::LocalSizeId Attribute
In SPIR-V, the number of thread in the group can be specificed using the `LocalSize` execution mode. This corresponds nicely with the `numthreads` attribute in HLSL. However there is another way. You can use `LocalSizeId`, which uses ids of other instructions. This allows spec constants to be provided as a dimention on the local size id. This PR adds a new attribute that can be used instead of the `numthreads` attribute. It allows constant expression or spec constants as parameters.
1 parent a7b7a0c commit 74a8ddb

14 files changed

+182
-39
lines changed

tools/clang/include/clang/Basic/Attr.td

+5
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,11 @@ def HLSLNumThreads: InheritableAttr {
671671
let Args = [IntArgument<"X">, IntArgument<"Y">, IntArgument<"Z">];
672672
let Documentation = [Undocumented];
673673
}
674+
def HLSLSpirvNumThreads: InheritableAttr {
675+
let Spellings = [CXX11<"vk", "LocalSizeId", 2015>];
676+
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
677+
let Documentation = [Undocumented];
678+
}
674679
def HLSLRootSignature: InheritableAttr {
675680
let Spellings = [CXX11<"", "RootSignature", 2015>];
676681
let Args = [StringArgument<"SignatureName">];

tools/clang/include/clang/SPIRV/SpirvBuilder.h

+39-2
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,13 @@ class SpirvBuilder {
607607
SourceLocation,
608608
bool useIdParams = false);
609609

610+
/// \brief Adds an execution mode to the module under construction if it does
611+
/// not already exist. Return the newly added instruction or the existing
612+
/// instruction, if one already exists.
613+
inline SpirvInstruction *addExecutionModeId(SpirvFunction *entryPoint, spv::ExecutionMode em,
614+
llvm::ArrayRef<SpirvInstruction*> params,
615+
SourceLocation loc);
616+
610617
/// \brief Adds an OpModuleProcessed instruction to the module under
611618
/// construction.
612619
void addModuleProcessed(llvm::StringRef process);
@@ -954,15 +961,45 @@ SpirvBuilder::addExecutionMode(SpirvFunction *entryPoint, spv::ExecutionMode em,
954961
llvm::ArrayRef<uint32_t> params,
955962
SourceLocation loc, bool useIdParams) {
956963
SpirvExecutionMode *mode = nullptr;
957-
SpirvExecutionMode *existingInstruction =
964+
SpirvExecutionModeBase *existingInstruction =
958965
mod->findExecutionMode(entryPoint, em);
959966

960967
if (!existingInstruction) {
961968
mode = new (context)
962969
SpirvExecutionMode(loc, entryPoint, em, params, useIdParams);
963970
mod->addExecutionMode(mode);
964971
} else {
965-
mode = existingInstruction;
972+
// No execution mode can be used with both OpExecutionMode and
973+
// OpExecutionModeId. If this assert is triggered, then either this
974+
// `addExecutionModeId` should have been called with `em` or the existing
975+
// instruction is wrong.
976+
assert(existingInstruction->getKind() == SpirvInstruction::IK_ExecutionMode);
977+
mode = cast<SpirvExecutionMode>(existingInstruction);
978+
}
979+
980+
return mode;
981+
}
982+
983+
SpirvInstruction *
984+
SpirvBuilder::addExecutionModeId(SpirvFunction *entryPoint, spv::ExecutionMode em,
985+
llvm::ArrayRef<SpirvInstruction*> params,
986+
SourceLocation loc) {
987+
SpirvExecutionModeId *mode = nullptr;
988+
SpirvExecutionModeBase *existingInstruction =
989+
mod->findExecutionMode(entryPoint, em);
990+
assert(!existingInstruction || existingInstruction->getKind() == SpirvInstruction::IK_ExecutionModeId);
991+
992+
if (!existingInstruction) {
993+
mode = new (context)
994+
SpirvExecutionModeId(loc, entryPoint, em, params);
995+
mod->addExecutionMode(mode);
996+
} else {
997+
// No execution mode can be used with both OpExecutionMode and
998+
// OpExecutionModeId. If this assert is triggered, then either this
999+
// `addExecutionMode` should have been called with `em` or the existing
1000+
// instruction is wrong.
1001+
assert(existingInstruction->getKind() == SpirvInstruction::IK_ExecutionMode);
1002+
mode = cast<SpirvExecutionModeId>(existingInstruction);
9661003
}
9671004

9681005
return mode;

tools/clang/include/clang/SPIRV/SpirvInstruction.h

+47-5
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class SpirvInstruction {
5353
IK_MemoryModel, // OpMemoryModel
5454
IK_EntryPoint, // OpEntryPoint
5555
IK_ExecutionMode, // OpExecutionMode
56+
IK_ExecutionModeId, // OpExecutionModeId
5657
IK_String, // OpString (debug)
5758
IK_Source, // OpSource (debug)
5859
IK_ModuleProcessed, // OpModuleProcessed (debug)
@@ -396,8 +397,32 @@ class SpirvEntryPoint : public SpirvInstruction {
396397
llvm::SmallVector<SpirvVariable *, 8> interfaceVec;
397398
};
398399

400+
class SpirvExecutionModeBase : public SpirvInstruction {
401+
public:
402+
SpirvExecutionModeBase(Kind kind, spv::Op opcode, SourceLocation loc, SpirvFunction *entryPointFunction,
403+
spv::ExecutionMode executionMode)
404+
: SpirvInstruction(kind, opcode, QualType(), loc),
405+
entryPoint(entryPointFunction), execMode(executionMode) {}
406+
407+
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvExecutionModeBase)
408+
409+
// For LLVM-style RTTI
410+
static bool classof(const SpirvInstruction *inst) {
411+
return false;
412+
}
413+
414+
bool invokeVisitor(Visitor *v) override;
415+
416+
SpirvFunction *getEntryPoint() const { return entryPoint; }
417+
spv::ExecutionMode getExecutionMode() const { return execMode; }
418+
419+
private:
420+
SpirvFunction *entryPoint;
421+
spv::ExecutionMode execMode;
422+
};
423+
399424
/// \brief OpExecutionMode and OpExecutionModeId instructions
400-
class SpirvExecutionMode : public SpirvInstruction {
425+
class SpirvExecutionMode : public SpirvExecutionModeBase {
401426
public:
402427
SpirvExecutionMode(SourceLocation loc, SpirvFunction *entryPointFunction,
403428
spv::ExecutionMode, llvm::ArrayRef<uint32_t> params,
@@ -412,16 +437,33 @@ class SpirvExecutionMode : public SpirvInstruction {
412437

413438
bool invokeVisitor(Visitor *v) override;
414439

415-
SpirvFunction *getEntryPoint() const { return entryPoint; }
416-
spv::ExecutionMode getExecutionMode() const { return execMode; }
417440
llvm::ArrayRef<uint32_t> getParams() const { return params; }
418441

419442
private:
420-
SpirvFunction *entryPoint;
421-
spv::ExecutionMode execMode;
422443
llvm::SmallVector<uint32_t, 4> params;
423444
};
424445

446+
/// \brief OpExecutionModeId
447+
class SpirvExecutionModeId : public SpirvExecutionModeBase {
448+
public:
449+
SpirvExecutionModeId(SourceLocation loc, SpirvFunction *entryPointFunction,
450+
spv::ExecutionMode em, llvm::ArrayRef<SpirvInstruction*> params);
451+
452+
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvExecutionModeId)
453+
454+
// For LLVM-style RTTI
455+
static bool classof(const SpirvInstruction *inst) {
456+
return inst->getKind() == IK_ExecutionModeId;
457+
}
458+
459+
bool invokeVisitor(Visitor *v) override;
460+
461+
llvm::ArrayRef<SpirvInstruction*> getParams() const { return params; }
462+
463+
private:
464+
llvm::SmallVector<SpirvInstruction*, 4> params;
465+
};
466+
425467
/// \brief OpString instruction
426468
class SpirvString : public SpirvInstruction {
427469
public:

tools/clang/include/clang/SPIRV/SpirvModule.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,11 @@ class SpirvModule {
119119

120120
// Returns an existing execution mode instruction that is the same as em if it
121121
// exists. Return nullptr otherwise.
122-
SpirvExecutionMode *findExecutionMode(SpirvFunction *entryPoint,
122+
SpirvExecutionModeBase *findExecutionMode(SpirvFunction *entryPoint,
123123
spv::ExecutionMode em);
124124

125125
// Adds an execution mode to the module.
126-
void addExecutionMode(SpirvExecutionMode *);
126+
void addExecutionMode(SpirvExecutionModeBase *em);
127127

128128
// Adds an extension to the module. Returns true if the extension was added.
129129
// Returns false otherwise (e.g. if the extension already existed).
@@ -194,7 +194,7 @@ class SpirvModule {
194194
llvm::SmallVector<SpirvExtInstImport *, 1> extInstSets;
195195
SpirvMemoryModel *memoryModel;
196196
llvm::SmallVector<SpirvEntryPoint *, 1> entryPoints;
197-
llvm::SmallVector<SpirvExecutionMode *, 4> executionModes;
197+
llvm::SmallVector<SpirvExecutionModeBase *, 4> executionModes;
198198
llvm::SmallVector<SpirvString *, 4> constStrings;
199199
std::vector<SpirvSource *> sources;
200200
std::vector<SpirvModuleProcessed *> moduleProcesses;

tools/clang/include/clang/SPIRV/SpirvVisitor.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class Visitor {
6060
DEFINE_VISIT_METHOD(SpirvExtInstImport)
6161
DEFINE_VISIT_METHOD(SpirvMemoryModel)
6262
DEFINE_VISIT_METHOD(SpirvEntryPoint)
63-
DEFINE_VISIT_METHOD(SpirvExecutionMode)
63+
DEFINE_VISIT_METHOD(SpirvExecutionModeBase)
6464
DEFINE_VISIT_METHOD(SpirvString)
6565
DEFINE_VISIT_METHOD(SpirvSource)
6666
DEFINE_VISIT_METHOD(SpirvModuleProcessed)

tools/clang/lib/SPIRV/CapabilityVisitor.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,7 @@ bool CapabilityVisitor::visit(SpirvEntryPoint *entryPoint) {
634634
return true;
635635
}
636636

637-
bool CapabilityVisitor::visit(SpirvExecutionMode *execMode) {
637+
bool CapabilityVisitor::visit(SpirvExecutionModeBase *execMode) {
638638
spv::ExecutionMode executionMode = execMode->getExecutionMode();
639639
SourceLocation execModeSourceLocation = execMode->getSourceLocation();
640640
SourceLocation entryPointSourceLocation =

tools/clang/lib/SPIRV/CapabilityVisitor.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class CapabilityVisitor : public Visitor {
3131

3232
bool visit(SpirvDecoration *decor) override;
3333
bool visit(SpirvEntryPoint *) override;
34-
bool visit(SpirvExecutionMode *) override;
34+
bool visit(SpirvExecutionModeBase *execMode) override;
3535
bool visit(SpirvImageQuery *) override;
3636
bool visit(SpirvImageOp *) override;
3737
bool visit(SpirvImageSparseTexelsResident *) override;

tools/clang/lib/SPIRV/EmitVisitor.cpp

+16-7
Original file line numberDiff line numberDiff line change
@@ -613,18 +613,27 @@ bool EmitVisitor::visit(SpirvEntryPoint *inst) {
613613
return true;
614614
}
615615

616-
bool EmitVisitor::visit(SpirvExecutionMode *inst) {
616+
bool EmitVisitor::visit(SpirvExecutionModeBase *inst) {
617617
initInstruction(inst);
618618
curInst.push_back(getOrAssignResultId<SpirvFunction>(inst->getEntryPoint()));
619619
curInst.push_back(static_cast<uint32_t>(inst->getExecutionMode()));
620620
if (inst->getopcode() == spv::Op::OpExecutionMode) {
621-
curInst.insert(curInst.end(), inst->getParams().begin(),
622-
inst->getParams().end());
621+
ArrayRef<uint32_t> params = static_cast<SpirvExecutionMode*>(inst)->getParams();
622+
curInst.insert(curInst.end(), params.begin(), params.end());
623623
} else {
624-
for (uint32_t param : inst->getParams()) {
625-
curInst.push_back(typeHandler.getOrCreateConstantInt(
626-
llvm::APInt(32, param), context.getUIntType(32),
627-
/*isSpecConst */ false));
624+
if (inst->getKind() == SpirvInstruction::IK_ExecutionModeId) {
625+
auto* exeModeId = static_cast<SpirvExecutionModeId*>(inst);
626+
for (SpirvInstruction* param : exeModeId->getParams()) {
627+
uint32_t id = getOrAssignResultId<SpirvInstruction>(param);
628+
curInst.push_back(id);
629+
}
630+
} else {
631+
ArrayRef<uint32_t> params = static_cast<SpirvExecutionMode*>(inst)->getParams();
632+
for (uint32_t param : params) {
633+
curInst.push_back(typeHandler.getOrCreateConstantInt(
634+
llvm::APInt(32, param), context.getUIntType(32),
635+
/*isSpecConst */ false));
636+
}
628637
}
629638
}
630639
finalizeInstruction(&preambleBinary);

tools/clang/lib/SPIRV/EmitVisitor.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ class EmitVisitor : public Visitor {
228228
bool visit(SpirvEmitVertex *) override;
229229
bool visit(SpirvEndPrimitive *) override;
230230
bool visit(SpirvEntryPoint *) override;
231-
bool visit(SpirvExecutionMode *) override;
231+
bool visit(SpirvExecutionModeBase *) override;
232232
bool visit(SpirvString *) override;
233233
bool visit(SpirvSource *) override;
234234
bool visit(SpirvModuleProcessed *) override;

tools/clang/lib/SPIRV/SpirvEmitter.cpp

+28-8
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
#ifdef SUPPORT_QUERY_GIT_COMMIT_INFO
3636
#include "clang/Basic/Version.h"
37+
#include "clang/Sema/Lookup.h"
3738
#else
3839
namespace clang {
3940
uint32_t getGitCommitCount() { return 0; }
@@ -13226,14 +13227,24 @@ void SpirvEmitter::processPixelShaderAttributes(const FunctionDecl *decl) {
1322613227

1322713228
void SpirvEmitter::processComputeShaderAttributes(const FunctionDecl *decl) {
1322813229
auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>();
13229-
assert(numThreadsAttr && "thread group size missing from entry-point");
13230+
auto *localSizeIdAttr = decl->getAttr<HLSLSpirvNumThreadsAttr>();
13231+
assert((numThreadsAttr || localSizeIdAttr) && "thread group size missing from entry-point");
1323013232

13231-
uint32_t x = static_cast<uint32_t>(numThreadsAttr->getX());
13232-
uint32_t y = static_cast<uint32_t>(numThreadsAttr->getY());
13233-
uint32_t z = static_cast<uint32_t>(numThreadsAttr->getZ());
13233+
if (numThreadsAttr) {
13234+
uint32_t x = static_cast<uint32_t>(numThreadsAttr->getX());
13235+
uint32_t y = static_cast<uint32_t>(numThreadsAttr->getY());
13236+
uint32_t z = static_cast<uint32_t>(numThreadsAttr->getZ());
1323413237

13235-
spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
13236-
{x, y, z}, decl->getLocation());
13238+
spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
13239+
{x, y, z}, decl->getLocation());
13240+
} else {
13241+
auto * exprX = localSizeIdAttr->getX();
13242+
auto *x = doExpr(exprX);
13243+
auto *y = doExpr(localSizeIdAttr->getY());
13244+
auto *z = doExpr(localSizeIdAttr->getZ());
13245+
spvBuilder.addExecutionModeId(entryFunction, spv::ExecutionMode::LocalSizeId,
13246+
{x, y, z}, decl->getLocation());
13247+
}
1323713248

1323813249
auto *waveSizeAttr = decl->getAttr<HLSLWaveSizeAttr>();
1323913250
if (waveSizeAttr) {
@@ -13461,6 +13472,12 @@ bool SpirvEmitter::processMeshOrAmplificationShaderAttributes(
1346113472
z = static_cast<uint32_t>(numThreadsAttr->getZ());
1346213473
spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
1346313474
{x, y, z}, decl->getLocation());
13475+
} else if (auto *localSizeIdAttr = decl->getAttr<HLSLSpirvNumThreadsAttr>()) {
13476+
auto *x = doExpr(localSizeIdAttr->getX());
13477+
auto *y = doExpr(localSizeIdAttr->getY());
13478+
auto *z = doExpr(localSizeIdAttr->getZ());
13479+
spvBuilder.addExecutionModeId(entryFunction, spv::ExecutionMode::LocalSizeId,
13480+
{x, y, z}, decl->getLocation());
1346413481
}
1346513482

1346613483
// Early return for amplification shaders as they only take the 'numthreads'
@@ -15022,9 +15039,12 @@ bool SpirvEmitter::spirvToolsValidate(std::vector<uint32_t> *mod,
1502215039
void SpirvEmitter::addDerivativeGroupExecutionMode() {
1502315040
assert(spvContext.isCS());
1502415041

15025-
SpirvExecutionMode *numThreadsEm = spvBuilder.getModule()->findExecutionMode(
15042+
SpirvExecutionModeBase *numThreadsEm = spvBuilder.getModule()->findExecutionMode(
1502615043
entryFunction, spv::ExecutionMode::LocalSize);
15027-
auto numThreads = numThreadsEm->getParams();
15044+
15045+
// TODO: Need to handle LocalSizeID as well.
15046+
assert(numThreadsEm->getKind() == SpirvInstruction::IK_ExecutionMode);
15047+
auto numThreads = static_cast<SpirvExecutionMode*>(numThreadsEm)->getParams();
1502815048

1502915049
// The layout of the quad is determined by the numer of threads in each
1503015050
// dimention. From the HLSL spec

tools/clang/lib/SPIRV/SpirvInstruction.cpp

+12-4
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExtension)
2929
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExtInstImport)
3030
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvMemoryModel)
3131
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvEntryPoint)
32+
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExecutionModeBase)
3233
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExecutionMode)
34+
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExecutionModeId)
3335
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvString)
3436
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvSource)
3537
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvModuleProcessed)
@@ -203,13 +205,19 @@ SpirvExecutionMode::SpirvExecutionMode(SourceLocation loc, SpirvFunction *entry,
203205
spv::ExecutionMode em,
204206
llvm::ArrayRef<uint32_t> paramsVec,
205207
bool usesIdParams)
206-
: SpirvInstruction(IK_ExecutionMode,
208+
: SpirvExecutionModeBase(IK_ExecutionMode,
207209
usesIdParams ? spv::Op::OpExecutionModeId
208-
: spv::Op::OpExecutionMode,
209-
QualType(), loc),
210-
entryPoint(entry), execMode(em),
210+
: spv::Op::OpExecutionMode, loc, entry, em),
211211
params(paramsVec.begin(), paramsVec.end()) {}
212212

213+
SpirvExecutionModeId::SpirvExecutionModeId(SourceLocation loc, SpirvFunction *entry,
214+
spv::ExecutionMode em,
215+
llvm::ArrayRef<SpirvInstruction*> paramsVec)
216+
: SpirvExecutionModeBase(IK_ExecutionModeId,
217+
spv::Op::OpExecutionModeId, loc, entry, em),
218+
params(paramsVec.begin(), paramsVec.end()) {
219+
}
220+
213221
SpirvString::SpirvString(SourceLocation loc, llvm::StringRef stringLiteral)
214222
: SpirvInstruction(IK_String, spv::Op::OpString, QualType(), loc),
215223
str(stringLiteral) {}

tools/clang/lib/SPIRV/SpirvModule.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -294,9 +294,10 @@ void SpirvModule::addEntryPoint(SpirvEntryPoint *ep) {
294294
entryPoints.push_back(ep);
295295
}
296296

297-
SpirvExecutionMode *SpirvModule::findExecutionMode(SpirvFunction *entryPoint,
297+
SpirvExecutionModeBase *
298+
SpirvModule::findExecutionMode(SpirvFunction *entryPoint,
298299
spv::ExecutionMode em) {
299-
for (SpirvExecutionMode *cem : executionModes) {
300+
for (SpirvExecutionModeBase *cem : executionModes) {
300301
if (cem->getEntryPoint() != entryPoint)
301302
continue;
302303
if (cem->getExecutionMode() != em)
@@ -306,7 +307,7 @@ SpirvExecutionMode *SpirvModule::findExecutionMode(SpirvFunction *entryPoint,
306307
return nullptr;
307308
}
308309

309-
void SpirvModule::addExecutionMode(SpirvExecutionMode *em) {
310+
void SpirvModule::addExecutionMode(SpirvExecutionModeBase *em) {
310311
assert(em && "cannot add null execution mode");
311312
executionModes.push_back(em);
312313
}

0 commit comments

Comments
 (0)