Commit f09c7e8

Move @SpecConst rate propagation from IR cloning to IR instruction creation (#10694)
Fixes #10373

Spec-const rate propagation was performed in `cloneInstAndOperands`, which only applied the `@SpecConst` rate when instructions were cloned (e.g. during generic specialization), not when they were first created. As a result, `static const` expressions over specialization constants (like `FOO2 == 0 ? 0 : FOO1 / FOO2`) that were created directly did not receive the `@SpecConst` rate and failed to be emitted as `OpSpecConstantOp` in SPIR-V. PR #10391 attempted to fix this but caused regressions and was partially reverted in #10663.

This change moves `shouldHaveSpecConstRate` and `ensureSpecConstRate` (formerly `maybeAddSpecConstRate`) out of `slang-ir-clone.cpp` into `slang-ir-util.cpp` as general utilities, and calls them from `IRBuilder::_createInst` so that the spec-const rate is applied at instruction creation time regardless of how the instruction is produced. Instructions that receive the spec-const rate are routed through `_findOrEmitHoistableInst` for deduplication, matching the prior behavior for hoistable ops.

`ensureSpecConstRate` now replaces an existing `ConstExpr` rate with `SpecConst` when `shouldHaveSpecConstRate` determines that an instruction should be a specialization constant. This handles `static const` expressions whose operands are specialization constants: such values are only known at pipeline creation time, so `ConstExpr` (compile-time) is incorrect.

`isInstHoistable` is simplified to only check the `kIROpFlag_Hoistable` flag, and the now-unused `isSpecConstOpHoistable` is removed.

`isLegalGlobalInstForTarget` in the SPIR-V legalization pass treats any instruction with a spec-const rate type as a legal global instruction, preventing the global inst inlining pass from moving it into a function body. This is required because `OpSpecConstantOp` results must appear at module scope per the SPIR-V spec.

A targeted fix in `_cloneInstDecorationsAndChildren` prevents `NameHintDecoration` accumulation on deduplicated hoistable spec-const instructions across repeated generic-specialization passes, without blocking children cloning (which is needed for collection-like hoistable insts such as `FuncSet`).

Re-enables the test disabled in #10663 and adds a regression test that more closely resembles the reproducer code from #10605 (the regression issue fixed by #10663) to accompany the minimal regression test.

Also closes #10665. The changes from that PR are required for the changes described above and have been folded into this change. The description from that PR follows:

When a bool specialization constant is cast to an integer type (or vice versa) inside a global constant initializer, the SPIR-V emitter needs to lower the cast as an OpSpecConstantOp. The existing integer cast handling in emitSpecializationConstantOp only covered integer-to-integer conversions via UConvert/SConvert, but those opcodes cannot accept OpTypeBool as a source or destination type, so casts involving bools were unhandled.

This adds two cases before the integer-to-integer path. For bool-to-integer, the emitter now produces OpSpecConstantOp ... Select %bool_operand %one %zero, picking between literal 1 and 0 since UConvert/SConvert cannot accept a bool operand. For integer-to-bool, it produces OpSpecConstantOp ... INotEqual %int_operand %zero, comparing the operand against zero since UConvert/SConvert cannot produce a bool result.

Also adds a regression test with both a SPIR-V assembly filecheck verifying the expected OpSpecConstantOp opcodes and a Vulkan compute test verifying runtime correctness.
1 parent 41c9c1c commit f09c7e8

9 files changed

Lines changed: 348 additions & 56 deletions

source/slang/slang-emit-spirv.cpp

Lines changed: 44 additions & 4 deletions
@@ -3549,15 +3549,55 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
         }

         // Handle integer casts in spec constant context.
-        // Inside OpSpecConstantOp, UConvert/SConvert require different bit widths
-        // and matching signedness on the result type, and OpBitcast is not in the
-        // set of allowed opcodes. See emitSpecConstantSignReinterpret for the
-        // workaround used for signedness changes.
         auto irOp = inst->getOp();
         if (irOp == kIROp_IntCast || irOp == kIROp_ConstexprIntCast)
         {
             auto srcType = inst->getOperand(0)->getDataType();
             auto dstType = inst->getDataType();
+
+            // bool to integer: UConvert/SConvert cannot accept a bool operand,
+            // so we use OpSelect to pick between literal 1 and 0 instead.
+            if (srcType->getOp() == kIROp_BoolType && isIntegralType(dstType))
+            {
+                auto operand = emitSpecializationConstantOp(inst->getOperand(0));
+                IRBuilder builder(m_irModule);
+                auto one = emitLit(builder.getIntValue(dstType, 1));
+                auto zero = emitLit(builder.getIntValue(dstType, 0));
+                return emitInst(
+                    getSection(SpvLogicalSectionID::ConstantsAndTypes),
+                    inst,
+                    SpvOpSpecConstantOp,
+                    inst->getFullType(),
+                    kResultID,
+                    SpvOpSelect,
+                    operand,
+                    one,
+                    zero);
+            }
+
+            // integer to bool: UConvert/SConvert cannot produce a bool result,
+            // so we compare the operand against zero with OpINotEqual instead.
+            if (isIntegralType(srcType) && dstType->getOp() == kIROp_BoolType)
+            {
+                auto operand = emitSpecializationConstantOp(inst->getOperand(0));
+                IRBuilder builder(m_irModule);
+                auto zero = emitLit(builder.getIntValue(srcType, 0));
+                return emitInst(
+                    getSection(SpvLogicalSectionID::ConstantsAndTypes),
+                    inst,
+                    SpvOpSpecConstantOp,
+                    inst->getFullType(),
+                    kResultID,
+                    SpvOpINotEqual,
+                    operand,
+                    zero);
+            }
+
+            // Inside OpSpecConstantOp, UConvert/SConvert require different bit
+            // widths and matching signedness on the result type, and OpBitcast
+            // is not in the set of allowed opcodes. See
+            // emitSpecConstantSignReinterpret for the workaround used for
+            // signedness changes.
             if (isIntegralType(srcType) && isIntegralType(dstType))
             {
                 auto srcInfo = getIntTypeInfo(m_targetRequest, srcType);
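
The two new cast paths in the hunk above can be modelled as ordinary value semantics. The following is only a sketch of what the emitted `Select` and `INotEqual` operations compute, assuming a 32-bit signed destination type; the names `toySelect`, `toyINotEqual`, `boolToInt`, and `intToBool` are hypothetical and are not part of the Slang code base.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the value computed by OpSelect:
// pick one of two operands based on a boolean condition.
constexpr std::int32_t toySelect(bool condition, std::int32_t ifTrue, std::int32_t ifFalse)
{
    return condition ? ifTrue : ifFalse;
}

// Hypothetical stand-in for the value computed by OpINotEqual:
// integer comparison that yields a boolean.
constexpr bool toyINotEqual(std::int32_t lhs, std::int32_t rhs)
{
    return lhs != rhs;
}

// bool -> int, mirroring "Select %bool_operand %one %zero".
constexpr std::int32_t boolToInt(bool value)
{
    return toySelect(value, 1, 0);
}

// int -> bool, mirroring "INotEqual %int_operand %zero".
constexpr bool intToBool(std::int32_t value)
{
    return toyINotEqual(value, 0);
}
```

Under this model any non-zero integer, including negative values, converts to `true`, which matches the usual integer-to-bool cast semantics.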

source/slang/slang-ir-clone.cpp

Lines changed: 12 additions & 39 deletions
@@ -48,43 +48,6 @@ IRInst* findCloneForOperand(IRCloneEnv* env, IRInst* oldOperand)
     return oldOperand;
 }

-static bool shouldHaveSpecConstRate(IRInst* oldInst, IRInst* const* newOperands, UInt operandCount)
-{
-    if (operandCount == 0)
-        return false;
-
-    if (!canOperationBeSpecConst(
-            oldInst->getOp(),
-            oldInst->getDataType(),
-            nullptr,
-            oldInst->getOperands()))
-        return false;
-
-    // An instruction whose result carries a spec-const rate will be hoisted
-    // to global scope and, for SPIR-V, emitted as OpSpecConstantOp. That is
-    // only valid when every operand is itself a specialization constant or a
-    // plain constant. Mixing in a runtime value (e.g. a function parameter)
-    // would produce invalid SPIR-V.
-    //
-    bool hasSpecConstOperand = false;
-    for (UInt ii = 0; ii < operandCount; ++ii)
-    {
-        if (isSpecConstRateType(newOperands[ii]->getFullType()))
-            hasSpecConstOperand = true;
-        else if (!as<IRConstant>(newOperands[ii]))
-            return false;
-    }
-    return hasSpecConstOperand;
-}
-
-static IRType* maybeAddSpecConstRate(IRBuilder* builder, IRType* type)
-{
-    // Do not add a spec-const rate if the type already carries a rate.
-    if (as<IRRateQualifiedType>(type))
-        return type;
-    return builder->getRateQualifiedType(builder->getSpecConstRate(), type);
-}
-
 IRInst* cloneInstAndOperands(IRCloneEnv* env, IRBuilder* builder, IRInst* oldInst)
 {
     SLANG_ASSERT(env);
@@ -136,8 +99,12 @@ IRInst* cloneInstAndOperands(IRCloneEnv* env, IRBuilder* builder, IRInst* oldIns
         newOperands[ii] = newOperand;
     }

-    if (shouldHaveSpecConstRate(oldInst, newOperands.getArrayView().getBuffer(), operandCount))
-        newType = maybeAddSpecConstRate(builder, newType);
+    if (shouldHaveSpecConstRate(
+            oldInst->getOp(),
+            newType,
+            operandCount,
+            newOperands.getArrayView().getBuffer()))
+        newType = ensureSpecConstRate(builder, newType);

     // Finally we create the inst with the updated operands.
     auto newInst = builder->emitIntrinsicInst(
@@ -283,6 +250,12 @@ static void _cloneInstDecorationsAndChildren(
         if (lookUp(env, oldChild))
             continue;

+        // When dedup returns a pre-existing instruction (e.g. a hoistable inst),
+        // cloning NameHintDecorations onto it again would cause unbounded
+        // accumulation across repeated generic-specialization passes.
+        if (as<IRNameHintDecoration>(oldChild) && newInst->findDecoration<IRNameHintDecoration>())
+            continue;
+
         // Now we can perform the first phase of cloning
         // on the child, and register it in our map from
         // old to new values.
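
The name-hint guard in the last hunk can be illustrated with a toy model. This is a minimal sketch, assuming a deduplicated instruction is represented by a single shared object that every cloning pass visits again; `ToyInst` and `cloneNameHint` are hypothetical names and do not exist in Slang.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for an IR instruction that may carry name hints.
struct ToyInst
{
    std::vector<std::string> nameHints;
};

// Copies a name hint onto `target` only if it does not carry one yet.
// Without the early return, each pass over a deduplicated (shared)
// instruction would append one more hint.
inline void cloneNameHint(ToyInst& target, const std::string& hint)
{
    if (!target.nameHints.empty())
        return;
    target.nameHints.push_back(hint);
}
```

Calling `cloneNameHint` on the same `ToyInst` any number of times leaves it with exactly one hint, which is the property the guard is meant to restore.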

source/slang/slang-ir-spirv-legalize.cpp

Lines changed: 7 additions & 0 deletions
@@ -1780,6 +1780,13 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
     {
         bool isLegalGlobalInstForTarget(IRInst* inst) override
         {
+            // Spec-const-rate instructions must stay at their current scope
+            // (module level or inside a generic) because SPIR-V requires
+            // OpSpecConstantOp results to appear outside function bodies.
+            // Only ops validated by canOperationBeSpecConst (via
+            // shouldHaveSpecConstRate in _createInst) acquire this rate.
+            if (isSpecConstRateType(inst->getFullType()))
+                return true;
             switch (inst->getOp())
             {
             case kIROp_MakeStruct:

source/slang/slang-ir-util.cpp

Lines changed: 55 additions & 8 deletions
@@ -2855,6 +2855,27 @@ IRType* maybeAddRateType(IRBuilder* builder, IRType* rateQulifiedType, IRType* o
     return oldType;
 }

+// Ensures `type` carries a SpecConst rate.
+// If the type already has a different rate (e.g. ConstExpr from a `static const`
+// expression whose operands are specialization constants), the existing rate is
+// replaced, as a value that depends on a spec-const is only known at pipeline
+// creation time, so `ConstExpr` (compile-time) would be incorrect.
+//
+IRType* ensureSpecConstRate(IRBuilder* builder, IRType* type)
+{
+    if (isSpecConstRateType(type))
+        return type;
+
+    // Strip any existing rate (e.g. ConstExpr) to avoid double-wrapping,
+    // since getRateQualifiedType does not unwrap for us.
+    if (auto rateQualified = as<IRRateQualifiedType>(type))
+        return builder->getRateQualifiedType(
+            builder->getSpecConstRate(),
+            rateQualified->getValueType());
+
+    return builder->getRateQualifiedType(builder->getSpecConstRate(), type);
+}
+
 bool canOperationBeSpecConst(IROp op, IRType* resultType, IRInst* const* fixedArgs, IRUse* operands)
 {
     // Returns true for ops that can be declared as an operation under `OpSpecConstantOp`.
@@ -2915,18 +2936,44 @@ bool canOperationBeSpecConst(IROp op, IRType* resultType, IRInst* const* fixedAr
     }
 }

-bool isSpecConstOpHoistable(IROp op, IRType* type, IRInst* const* fixedArgs)
+bool shouldHaveSpecConstRate(
+    IROp op,
+    IRType* resultType,
+    UInt operandCount,
+    IRInst* const* operands)
 {
-    auto rateType = as<IRRateQualifiedType>(type);
-    return rateType && as<IRSpecConstRate>(rateType->getRate()) &&
-           canOperationBeSpecConst(op, rateType->getValueType(), fixedArgs, nullptr);
-}
+    if (operandCount == 0)
+        return false;
+
+    // Unwrap any rate qualification so canOperationBeSpecConst sees the bare
+    // value type. isFloatingType checks as<IRBasicType> which doesn't match
+    // rate-qualified types like @ConstExpr float, so without unwrapping we
+    // would incorrectly allow float arithmetic as `OpSpecConstantOp`.
+    IRType* valueType = resultType;
+    if (auto rateQualifiedType = as<IRRateQualifiedType>(resultType))
+        valueType = rateQualifiedType->getValueType();
+
+    if (!canOperationBeSpecConst(op, valueType, operands, nullptr))
+        return false;

+    // An instruction whose result carries a spec-const rate is hoisted and
+    // emitted as OpSpecConstantOp for SPIR-V. That is only valid when
+    // every operand is itself a specialization constant or a plain
+    // constant. Mixing in a runtime value would produce invalid SPIR-V.
+    bool hasSpecConstOperand = false;
+    for (UInt ii = 0; ii < operandCount; ++ii)
+    {
+        if (isSpecConstRateType(operands[ii]->getFullType()))
+            hasSpecConstOperand = true;
+        else if (!as<IRConstant>(operands[ii]))
+            return false;
+    }
+    return hasSpecConstOperand;
+}

-bool isInstHoistable(IROp op, IRType* type, IRInst* const* fixedArgs)
+bool isInstHoistable(IROp op)
 {
-    return (getIROpInfo(op).flags & kIROpFlag_Hoistable) ||
-           isSpecConstOpHoistable(op, type, fixedArgs);
+    return (getIROpInfo(op).flags & kIROpFlag_Hoistable);
 }

 IRType* getUnsignedTypeFromSignedType(IRBuilder* builder, IRType* type)
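
The operand rule and the rate replacement introduced in this file can be sketched in isolation. The model below is an illustration only: it assumes each operand is classified as a literal, a specialization constant, or a runtime value, and it omits the `canOperationBeSpecConst` opcode check that the real `shouldHaveSpecConstRate` performs first. `Rate`, `OperandKind`, `ToyType`, `operandsAllowSpecConstRate`, and `toyEnsureSpecConstRate` are hypothetical names.

```cpp
#include <cassert>
#include <vector>

// Hypothetical classification of rates and operands for this sketch.
enum class Rate { None, ConstExpr, SpecConst };
enum class OperandKind { Literal, SpecConst, Runtime };

// Toy type: a value type id plus an optional rate qualifier.
struct ToyType
{
    Rate rate;
    int valueTypeId;
};

// Operand rule: every operand must be a literal or a spec constant,
// and at least one must be a spec constant. An empty operand list
// never qualifies.
inline bool operandsAllowSpecConstRate(const std::vector<OperandKind>& operands)
{
    bool hasSpecConstOperand = false;
    for (OperandKind kind : operands)
    {
        if (kind == OperandKind::SpecConst)
            hasSpecConstOperand = true;
        else if (kind != OperandKind::Literal)
            return false; // a runtime operand disqualifies the instruction
    }
    return hasSpecConstOperand;
}

// Rate replacement: whatever rate the type carried (none, ConstExpr, or
// already SpecConst), the result wraps the same value type in SpecConst.
inline ToyType toyEnsureSpecConstRate(ToyType type)
{
    type.rate = Rate::SpecConst;
    return type;
}
```

In this model a `ConstExpr`-rated type is re-wrapped rather than double-wrapped, which corresponds to the "strip any existing rate" step in `ensureSpecConstRate`.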

source/slang/slang-ir-util.h

Lines changed: 7 additions & 1 deletion
@@ -453,12 +453,18 @@ bool isFirstBlock(IRInst* inst);
 bool isSpecConstRateType(IRType* type);
 void hoistInstAndOperandsToGlobal(IRBuilder* builder, IRInst* inst);
 IRType* maybeAddRateType(IRBuilder* builder, IRType* rateQulifiedType, IRType* oldType);
+IRType* ensureSpecConstRate(IRBuilder* builder, IRType* type);
 bool canOperationBeSpecConst(
     IROp op,
     IRType* resultType,
     IRInst* const* fixedArgs,
     IRUse* operands);
-bool isInstHoistable(IROp op, IRType* type, IRInst* const* fixedArgs);
+bool shouldHaveSpecConstRate(
+    IROp op,
+    IRType* resultType,
+    UInt operandCount,
+    IRInst* const* operands);
+bool isInstHoistable(IROp op);

 // most of <algorithm> doesn't work on out non-const iterators, so define this
 // version

source/slang/slang-ir.cpp

Lines changed: 13 additions & 1 deletion
@@ -1861,7 +1861,19 @@ IRInst* IRBuilder::_createInst(
     m_dedupContext->getInstReplacementMap().tryGetValue(type, instReplacement);
     type = (IRType*)instReplacement;

-    if (isInstHoistable(op, type, fixedArgs))
+    if (type && shouldHaveSpecConstRate(op, type, fixedArgCount, fixedArgs))
+    {
+        type = ensureSpecConstRate(this, type);
+        return _findOrEmitHoistableInst(
+            type,
+            op,
+            fixedArgCount,
+            fixedArgs,
+            varArgListCount,
+            listArgCounts,
+            listArgs);
+    }
+    else if (isInstHoistable(op))
     {
         return _findOrEmitHoistableInst(
             type,
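
Routing spec-const instructions through `_findOrEmitHoistableInst` means that structurally identical instructions are created once and then reused. The following is a generic find-or-emit sketch of that idea, not Slang's actual implementation; `ToyBuilder` and its integer ids for opcodes, types, and operands are assumptions made for illustration.

```cpp
#include <cassert>
#include <map>
#include <tuple>
#include <utility>
#include <vector>

// Hypothetical builder that deduplicates instructions by structure.
struct ToyBuilder
{
    // Key: (opcode, type id, operand ids). Value: id of the emitted instruction.
    using Key = std::tuple<int, int, std::vector<int>>;

    std::map<Key, int> emitted;
    int nextId = 0;

    // Returns the id of an existing instruction with the same opcode, type,
    // and operands, or emits a new one and remembers it.
    int findOrEmit(int op, int type, std::vector<int> operands)
    {
        Key key{op, type, std::move(operands)};
        auto found = emitted.find(key);
        if (found != emitted.end())
            return found->second;

        int id = nextId++;
        emitted.emplace(std::move(key), id);
        return id;
    }
};
```

Because the type is part of the key, applying the spec-const rate before the lookup (as the hunk above does) ensures that the rated and unrated forms of an instruction are never conflated.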
Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+//TEST:SIMPLE(filecheck=CHECK): -target cuda -experimental-feature
+
+// Full regression test for https://github.com/shader-slang/slang/issues/10605
+
+// 4-layer MLP with autodiff backward pass targeting CUDA. Exercises the
+// interaction between `static const` globals computed from macro expansions,
+// generic struct members with constexpr arithmetic (division, ternary),
+// and the autodiff pass. The original crash was a null-operand segfault in
+// AutoDiffPass::processReferencedFunctions.
+
+// CHECK-DAG: __global__ void __kernel__backward
+// CHECK-DAG: s_apply_runWaveMLP
+
+import slang.neural;
+
+#ifndef TCNN_MLP_IN_DIM
+#define TCNN_MLP_IN_DIM 2
+#endif
+
+#ifndef WT_WAVE_WARPS
+#define WT_WAVE_WARPS 4
+#endif
+
+static const int IN_DIM = TCNN_MLP_IN_DIM;
+static const int HIDDEN_DIM = 128;
+static const int OUT_DIM = 3;
+static const int SubgroupSize = 32;
+static const int WARPS_PER_BLOCK = WT_WAVE_WARPS;
+static const float LEAKY_ALPHA = 0.01f;
+
+typealias Storage = TorchTensorViewAddress<half>;
+typealias ShMemSize = SharedMemorySize<half, TargetEnum.CUDA, ExecutionMode.Training, SubgroupSize, WARPS_PER_BLOCK>;
+typealias ShMemSizeLayer = ShMemSize.OfLayer4<IN_DIM, HIDDEN_DIM, HIDDEN_DIM, HIDDEN_DIM, OUT_DIM>;
+typealias InVec = WaveTangledVector<half, ShMemSizeLayer, IN_DIM, SubgroupSize>;
+typealias HidVec = WaveTangledVector<half, ShMemSizeLayer, HIDDEN_DIM, SubgroupSize>;
+typealias OutVec = WaveTangledVector<half, ShMemSizeLayer, OUT_DIM, SubgroupSize>;
+typealias Layer1 = FFLayer<half, InVec, HidVec, LinearLayout, LeakyReLU<half>, true>;
+typealias Layer2 = FFLayer<half, HidVec, HidVec, LinearLayout, LeakyReLU<half>, true>;
+typealias Layer3 = FFLayer<half, HidVec, HidVec, LinearLayout, LeakyReLU<half>, true>;
+typealias Layer4 = FFLayer<half, HidVec, OutVec, LinearLayout, Sigmoid<half>, true>;
+
+static const int TotalParamCount =
+    Layer1.ParameterCount + Layer2.ParameterCount + Layer3.ParameterCount + Layer4.ParameterCount;
+
+[Differentiable]
+OutVec runWaveMLP(InVec input, Storage params)
+{
+    LeakyReLU<half> leaky = LeakyReLU<half>(half(LEAKY_ALPHA));
+    Layer1 layer1 = Layer1(leaky);
+    Layer2 layer2 = Layer2(leaky);
+    Layer3 layer3 = Layer3(leaky);
+    Layer4 layer4 = Layer4();
+
+    int offset = 0;
+    HidVec h1 = layer1.eval<Storage>(input, params.getOffset(offset)); offset += Layer1.ParameterCount;
+    HidVec h2 = layer2.eval<Storage>(h1, params.getOffset(offset)); offset += Layer2.ParameterCount;
+    HidVec h3 = layer3.eval<Storage>(h2, params.getOffset(offset)); offset += Layer3.ParameterCount;
+    return layer4.eval<Storage>(h3, params.getOffset(offset));
+}
+
+[AutoPyBindCUDA]
+[CUDAKernel]
+void backward(
+    DiffTensorView input,
+    TensorView<half> params,
+    TensorView<half> paramsGrad,
+    DiffTensorView output)
+{
+    uint idx = cudaBlockIdx().x * cudaBlockDim().x + cudaThreadIdx().x;
+    bool isActive = idx < input.size(0);
+
+    InVec x = InVec(half(0));
+    if (isActive)
+    {
+        for (int i = 0; i < IN_DIM; ++i) x[i] = half(input.primal[idx, i]);
+    }
+
+    OutVec dL_dOutput = OutVec(half(0));
+    if (isActive)
+    {
+        for (int i = 0; i < OUT_DIM; ++i) dL_dOutput[i] = half(output.diff.diff[idx, i]);
+    }
+
+    InVec dx = InVec(half(0));
+    var dpInput = diffPair(x, dx);
+    bwd_diff(runWaveMLP)(dpInput,
+        DifferentialPtrPair<Storage>(Storage(params), Storage(paramsGrad)),
+        dL_dOutput);
+
+    if (isActive)
+    {
+        for (int i = 0; i < IN_DIM; ++i) input.diff.diff[idx, i] = float(dpInput.d[i]);
+    }
+}
