Skip to content

Commit 11ee8de

Browse files
authored
Revert "Fix crash in DXIL.dll caused by illegal DXIL intrinsic. (#6302) (#6342)" (#6418)
This file deleted with conflicts from subsequent changes: tools/clang/test/LitDXILValidation/illegalDXILOp.ll This reverts commit 487080f. Fixes #6419.
1 parent c9660a8 commit 11ee8de

File tree

8 files changed

+52
-211
lines changed

8 files changed

+52
-211
lines changed

docs/DXIL.rst

-2
Original file line numberDiff line numberDiff line change
@@ -3080,8 +3080,6 @@ INSTR.EVALINTERPOLATIONMODE Interpolation mode on %0 used with eva
30803080
INSTR.EXTRACTVALUE ExtractValue should only be used on dxil struct types and cmpxchg.
30813081
INSTR.FAILTORESLOVETGSMPOINTER TGSM pointers must originate from an unambiguous TGSM global variable.
30823082
INSTR.HANDLENOTFROMCREATEHANDLE Resource handle should returned by createHandle.
3083-
INSTR.ILLEGALDXILOPCODE DXILOpCode must be [0..%0]. %1 specified.
3084-
INSTR.ILLEGALDXILOPFUNCTION '%0' is not a DXILOpFuncition for DXILOpcode '%1'.
30853083
INSTR.IMMBIASFORSAMPLEB bias amount for sample_b must be in the range [%0,%1], but %2 was specified as an immediate.
30863084
INSTR.INBOUNDSACCESS Access to out-of-bounds memory is disallowed.
30873085
INSTR.MINPRECISIONNOTPRECISE Instructions marked precise may not refer to minprecision values.

lib/DXIL/DxilCounters.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,9 @@ void CountInstructions(llvm::Module &M, DxilCounters &counters) {
336336
}
337337
} else if (CallInst *CI = dyn_cast<CallInst>(I)) {
338338
if (hlsl::OP::IsDxilOpFuncCallInst(CI)) {
339-
unsigned opcode = static_cast<unsigned>(hlsl::OP::getOpCode(CI));
339+
unsigned opcode =
340+
(unsigned)llvm::cast<llvm::ConstantInt>(I->getOperand(0))
341+
->getZExtValue();
340342
CountDxilOp(opcode, counters);
341343
}
342344
} else if (isa<LoadInst>(I) || isa<StoreInst>(I)) {

lib/DXIL/DxilOperations.cpp

+40-42
Original file line numberDiff line numberDiff line change
@@ -2705,6 +2705,8 @@ llvm::StringRef OP::ConstructOverloadName(Type *Ty, DXIL::OpCode opCode,
27052705
}
27062706

27072707
const char *OP::GetOpCodeName(OpCode opCode) {
2708+
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
2709+
"otherwise caller passed OOB index");
27082710
return m_OpCodeProps[(unsigned)opCode].pOpCodeName;
27092711
}
27102712

@@ -2717,22 +2719,26 @@ const char *OP::GetAtomicOpName(DXIL::AtomicBinOpCode OpCode) {
27172719
}
27182720

27192721
OP::OpCodeClass OP::GetOpCodeClass(OpCode opCode) {
2722+
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
2723+
"otherwise caller passed OOB index");
27202724
return m_OpCodeProps[(unsigned)opCode].opCodeClass;
27212725
}
27222726

27232727
const char *OP::GetOpCodeClassName(OpCode opCode) {
2728+
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
2729+
"otherwise caller passed OOB index");
27242730
return m_OpCodeProps[(unsigned)opCode].pOpCodeClassName;
27252731
}
27262732

27272733
llvm::Attribute::AttrKind OP::GetMemAccessAttr(OpCode opCode) {
2734+
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
2735+
"otherwise caller passed OOB index");
27282736
return m_OpCodeProps[(unsigned)opCode].FuncAttr;
27292737
}
27302738

27312739
bool OP::IsOverloadLegal(OpCode opCode, Type *pType) {
2732-
if (!pType)
2733-
return false;
2734-
if (opCode == OpCode::NumOpCodes)
2735-
return false;
2740+
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
2741+
"otherwise caller passed OOB index");
27362742
unsigned TypeSlot = GetTypeSlot(pType);
27372743
return TypeSlot != UINT_MAX &&
27382744
m_OpCodeProps[(unsigned)opCode].bAllowOverload[TypeSlot];
@@ -2808,13 +2814,8 @@ bool OP::IsDxilOpFuncCallInst(const llvm::Instruction *I, OpCode opcode) {
28082814
}
28092815

28102816
OP::OpCode OP::getOpCode(const llvm::Instruction *I) {
2811-
auto *OpConst = llvm::dyn_cast<llvm::ConstantInt>(I->getOperand(0));
2812-
if (!OpConst)
2813-
return OpCode::NumOpCodes;
2814-
uint64_t OpCodeVal = OpConst->getZExtValue();
2815-
if (OpCodeVal >= static_cast<uint64_t>(OP::OpCode::NumOpCodes))
2816-
return OP::OpCode::NumOpCodes;
2817-
return static_cast<OP::OpCode>(OpCodeVal);
2817+
return (OP::OpCode)llvm::cast<llvm::ConstantInt>(I->getOperand(0))
2818+
->getZExtValue();
28182819
}
28192820

28202821
OP::OpCode OP::GetDxilOpFuncCallInst(const llvm::Instruction *I) {
@@ -3524,7 +3525,9 @@ void OP::RefreshCache() {
35243525
CallInst *CI = cast<CallInst>(*F.user_begin());
35253526
OpCode OpCode = OP::GetDxilOpFuncCallInst(CI);
35263527
Type *pOverloadType = OP::GetOverloadType(OpCode, &F);
3527-
GetOpFunc(OpCode, pOverloadType);
3528+
Function *OpFunc = GetOpFunc(OpCode, pOverloadType);
3529+
(void)(OpFunc);
3530+
DXASSERT_NOMSG(OpFunc == &F);
35283531
}
35293532
}
35303533
}
@@ -3543,15 +3546,13 @@ void OP::FixOverloadNames() {
35433546
CallInst *CI = cast<CallInst>(*F.user_begin());
35443547
DXIL::OpCode opCode = OP::GetDxilOpFuncCallInst(CI);
35453548
llvm::Type *Ty = OP::GetOverloadType(opCode, &F);
3546-
if (!OP::IsOverloadLegal(opCode, Ty))
3547-
continue;
3548-
if (!isa<StructType>(Ty) && !isa<PointerType>(Ty))
3549-
continue;
3550-
3551-
std::string funcName;
3552-
if (OP::ConstructOverloadName(Ty, opCode, funcName)
3553-
.compare(F.getName()) != 0)
3554-
F.setName(funcName);
3549+
if (isa<StructType>(Ty) || isa<PointerType>(Ty)) {
3550+
std::string funcName;
3551+
if (OP::ConstructOverloadName(Ty, opCode, funcName)
3552+
.compare(F.getName()) != 0) {
3553+
F.setName(funcName);
3554+
}
3555+
}
35553556
}
35563557
}
35573558
}
@@ -3562,11 +3563,12 @@ void OP::UpdateCache(OpCodeClass opClass, Type *Ty, llvm::Function *F) {
35623563
}
35633564

35643565
Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
3565-
if (opCode == OpCode::NumOpCodes)
3566-
return nullptr;
3567-
if (!IsOverloadLegal(opCode, pOverloadType))
3568-
return nullptr;
3569-
3566+
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
3567+
"otherwise caller passed OOB OpCode");
3568+
assert(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes);
3569+
DXASSERT(IsOverloadLegal(opCode, pOverloadType),
3570+
"otherwise the caller requested illegal operation overload (eg HLSL "
3571+
"function with unsupported types for mapped intrinsic function)");
35703572
OpCodeClass opClass = m_OpCodeProps[(unsigned)opCode].opCodeClass;
35713573
Function *&F =
35723574
m_OpCodeClassCache[(unsigned)opClass].pOverloads[pOverloadType];
@@ -5509,8 +5511,8 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
55095511
// and return values to ensure that ResRetType is constructed in the
55105512
// RefreshCache case.
55115513
if (Function *existF = m_pModule->getFunction(funcName)) {
5512-
if (existF->getFunctionType() != pFT)
5513-
return nullptr;
5514+
DXASSERT(existF->getFunctionType() == pFT,
5515+
"existing function must have the expected function type");
55145516
F = existF;
55155517
UpdateCache(opClass, pOverloadType, F);
55165518
return F;
@@ -5529,6 +5531,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
55295531

55305532
const SmallMapVector<llvm::Type *, llvm::Function *, 8> &
55315533
OP::GetOpFuncList(OpCode opCode) const {
5534+
DXASSERT(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes,
5535+
"otherwise caller passed OOB OpCode");
5536+
assert(0 <= (unsigned)opCode && opCode < OpCode::NumOpCodes);
55325537
return m_OpCodeClassCache[(unsigned)m_OpCodeProps[(unsigned)opCode]
55335538
.opCodeClass]
55345539
.pOverloads;
@@ -5626,8 +5631,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
56265631
case OpCode::CallShader:
56275632
case OpCode::Pack4x8:
56285633
case OpCode::WaveMatrix_Fill:
5629-
if (FT->getNumParams() <= 2)
5630-
return nullptr;
5634+
DXASSERT_NOMSG(FT->getNumParams() > 2);
56315635
return FT->getParamType(2);
56325636
case OpCode::MinPrecXRegStore:
56335637
case OpCode::StoreOutput:
@@ -5637,8 +5641,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
56375641
case OpCode::StoreVertexOutput:
56385642
case OpCode::StorePrimitiveOutput:
56395643
case OpCode::DispatchMesh:
5640-
if (FT->getNumParams() <= 4)
5641-
return nullptr;
5644+
DXASSERT_NOMSG(FT->getNumParams() > 4);
56425645
return FT->getParamType(4);
56435646
case OpCode::IsNaN:
56445647
case OpCode::IsInf:
@@ -5656,27 +5659,22 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
56565659
case OpCode::WaveActiveAllEqual:
56575660
case OpCode::CreateHandleForLib:
56585661
case OpCode::WaveMatch:
5659-
if (FT->getNumParams() <= 1)
5660-
return nullptr;
5662+
DXASSERT_NOMSG(FT->getNumParams() > 1);
56615663
return FT->getParamType(1);
56625664
case OpCode::TextureStore:
56635665
case OpCode::TextureStoreSample:
5664-
if (FT->getNumParams() <= 5)
5665-
return nullptr;
5666+
DXASSERT_NOMSG(FT->getNumParams() > 5);
56665667
return FT->getParamType(5);
56675668
case OpCode::TraceRay:
5668-
if (FT->getNumParams() <= 15)
5669-
return nullptr;
5669+
DXASSERT_NOMSG(FT->getNumParams() > 15);
56705670
return FT->getParamType(15);
56715671
case OpCode::ReportHit:
56725672
case OpCode::WaveMatrix_ScalarOp:
5673-
if (FT->getNumParams() <= 3)
5674-
return nullptr;
5673+
DXASSERT_NOMSG(FT->getNumParams() > 3);
56755674
return FT->getParamType(3);
56765675
case OpCode::WaveMatrix_LoadGroupShared:
56775676
case OpCode::WaveMatrix_StoreGroupShared:
5678-
if (FT->getNumParams() <= 2)
5679-
return nullptr;
5677+
DXASSERT_NOMSG(FT->getNumParams() > 2);
56805678
return FT->getParamType(2)->getPointerElementType();
56815679
case OpCode::CreateHandle:
56825680
case OpCode::BufferUpdateCounter:

lib/DXIL/DxilShaderFlags.cpp

+7-3
Original file line numberDiff line numberDiff line change
@@ -585,9 +585,13 @@ ShaderFlags ShaderFlags::CollectShaderFlags(const Function *F,
585585
if (const CallInst *CI = dyn_cast<CallInst>(&I)) {
586586
if (!OP::IsDxilOpFunc(CI->getCalledFunction()))
587587
continue;
588-
DXIL::OpCode dxilOp = hlsl::OP::getOpCode(CI);
589-
if (dxilOp == DXIL::OpCode::NumOpCodes)
590-
continue;
588+
Value *opcodeArg = CI->getArgOperand(DXIL::OperandIndex::kOpcodeIdx);
589+
ConstantInt *opcodeConst = dyn_cast<ConstantInt>(opcodeArg);
590+
DXASSERT(opcodeConst, "DXIL opcode arg must be immediate");
591+
unsigned opcode = opcodeConst->getLimitedValue();
592+
DXASSERT(opcode < static_cast<unsigned>(DXIL::OpCode::NumOpCodes),
593+
"invalid DXIL opcode");
594+
DXIL::OpCode dxilOp = static_cast<DXIL::OpCode>(opcode);
591595
if (hlsl::OP::IsDxilOpWave(dxilOp))
592596
hasWaveOps = true;
593597
switch (dxilOp) {

lib/HLSL/DxilValidation.cpp

-24
Original file line numberDiff line numberDiff line change
@@ -3208,8 +3208,6 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
32083208
CallInst *setMeshOutputCounts = nullptr;
32093209
CallInst *getMeshPayload = nullptr;
32103210
CallInst *dispatchMesh = nullptr;
3211-
hlsl::OP *hlslOP = ValCtx.DxilMod.GetOP();
3212-
32133211
for (auto b = F->begin(), bend = F->end(); b != bend; ++b) {
32143212
for (auto i = b->begin(), iend = b->end(); i != iend; ++i) {
32153213
llvm::Instruction &I = *i;
@@ -3259,30 +3257,8 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
32593257
}
32603258

32613259
unsigned opcode = OpcodeConst->getLimitedValue();
3262-
if (opcode >= static_cast<unsigned>(DXIL::OpCode::NumOpCodes)) {
3263-
ValCtx.EmitInstrFormatError(
3264-
&I, ValidationRule::InstrIllegalDXILOpCode,
3265-
{std::to_string((unsigned)DXIL::OpCode::NumOpCodes),
3266-
std::to_string(opcode)});
3267-
continue;
3268-
}
32693260
DXIL::OpCode dxilOpcode = (DXIL::OpCode)opcode;
32703261

3271-
bool IllegalOpFunc = true;
3272-
for (auto &it : hlslOP->GetOpFuncList(dxilOpcode)) {
3273-
if (it.second == FCalled) {
3274-
IllegalOpFunc = false;
3275-
break;
3276-
}
3277-
}
3278-
3279-
if (IllegalOpFunc) {
3280-
ValCtx.EmitInstrFormatError(
3281-
&I, ValidationRule::InstrIllegalDXILOpFunction,
3282-
{FCalled->getName(), OP::GetOpCodeName(dxilOpcode)});
3283-
continue;
3284-
}
3285-
32863262
if (OP::IsDxilOpGradient(dxilOpcode)) {
32873263
gradientOps.push_back(CI);
32883264
}

tools/clang/test/LitDXILValidation/illegalDXILOp.ll

-120
This file was deleted.

utils/hct/hctdb.py

-7
Original file line numberDiff line numberDiff line change
@@ -7324,13 +7324,6 @@ def build_valrules(self):
73247324
"Instr.ImmBiasForSampleB",
73257325
"bias amount for sample_b must be in the range [%0,%1], but %2 was specified as an immediate.",
73267326
)
7327-
self.add_valrule(
7328-
"Instr.IllegalDXILOpCode", "DXILOpCode must be [0..%0]. %1 specified."
7329-
)
7330-
self.add_valrule(
7331-
"Instr.IllegalDXILOpFunction",
7332-
"'%0' is not a DXILOpFuncition for DXILOpcode '%1'.",
7333-
)
73347327
# If streams have not been declared, you must use cut instead of cut_stream in GS - is there an equivalent rule here?
73357328

73367329
# Need to clean up all error messages and actually implement.

0 commit comments

Comments
 (0)