@@ -2705,6 +2705,8 @@ llvm::StringRef OP::ConstructOverloadName(Type *Ty, DXIL::OpCode opCode,
2705
2705
}
2706
2706
2707
2707
const char *OP::GetOpCodeName (OpCode opCode) {
2708
+ DXASSERT (0 <= (unsigned )opCode && opCode < OpCode::NumOpCodes,
2709
+ " otherwise caller passed OOB index" );
2708
2710
return m_OpCodeProps[(unsigned )opCode].pOpCodeName ;
2709
2711
}
2710
2712
@@ -2717,22 +2719,26 @@ const char *OP::GetAtomicOpName(DXIL::AtomicBinOpCode OpCode) {
2717
2719
}
2718
2720
2719
2721
OP::OpCodeClass OP::GetOpCodeClass (OpCode opCode) {
2722
+ DXASSERT (0 <= (unsigned )opCode && opCode < OpCode::NumOpCodes,
2723
+ " otherwise caller passed OOB index" );
2720
2724
return m_OpCodeProps[(unsigned )opCode].opCodeClass ;
2721
2725
}
2722
2726
2723
2727
const char *OP::GetOpCodeClassName (OpCode opCode) {
2728
+ DXASSERT (0 <= (unsigned )opCode && opCode < OpCode::NumOpCodes,
2729
+ " otherwise caller passed OOB index" );
2724
2730
return m_OpCodeProps[(unsigned )opCode].pOpCodeClassName ;
2725
2731
}
2726
2732
2727
2733
llvm::Attribute::AttrKind OP::GetMemAccessAttr (OpCode opCode) {
2734
+ DXASSERT (0 <= (unsigned )opCode && opCode < OpCode::NumOpCodes,
2735
+ " otherwise caller passed OOB index" );
2728
2736
return m_OpCodeProps[(unsigned )opCode].FuncAttr ;
2729
2737
}
2730
2738
2731
2739
bool OP::IsOverloadLegal (OpCode opCode, Type *pType) {
2732
- if (!pType)
2733
- return false ;
2734
- if (opCode == OpCode::NumOpCodes)
2735
- return false ;
2740
+ DXASSERT (0 <= (unsigned )opCode && opCode < OpCode::NumOpCodes,
2741
+ " otherwise caller passed OOB index" );
2736
2742
unsigned TypeSlot = GetTypeSlot (pType);
2737
2743
return TypeSlot != UINT_MAX &&
2738
2744
m_OpCodeProps[(unsigned )opCode].bAllowOverload [TypeSlot];
@@ -2808,13 +2814,8 @@ bool OP::IsDxilOpFuncCallInst(const llvm::Instruction *I, OpCode opcode) {
2808
2814
}
2809
2815
2810
2816
OP::OpCode OP::getOpCode (const llvm::Instruction *I) {
2811
- auto *OpConst = llvm::dyn_cast<llvm::ConstantInt>(I->getOperand (0 ));
2812
- if (!OpConst)
2813
- return OpCode::NumOpCodes;
2814
- uint64_t OpCodeVal = OpConst->getZExtValue ();
2815
- if (OpCodeVal >= static_cast <uint64_t >(OP::OpCode::NumOpCodes))
2816
- return OP::OpCode::NumOpCodes;
2817
- return static_cast <OP::OpCode>(OpCodeVal);
2817
+ return (OP::OpCode)llvm::cast<llvm::ConstantInt>(I->getOperand (0 ))
2818
+ ->getZExtValue ();
2818
2819
}
2819
2820
2820
2821
OP::OpCode OP::GetDxilOpFuncCallInst (const llvm::Instruction *I) {
@@ -3524,7 +3525,9 @@ void OP::RefreshCache() {
3524
3525
CallInst *CI = cast<CallInst>(*F.user_begin ());
3525
3526
OpCode OpCode = OP::GetDxilOpFuncCallInst (CI);
3526
3527
Type *pOverloadType = OP::GetOverloadType (OpCode, &F);
3527
- GetOpFunc (OpCode, pOverloadType);
3528
+ Function *OpFunc = GetOpFunc (OpCode, pOverloadType);
3529
+ (void )(OpFunc);
3530
+ DXASSERT_NOMSG (OpFunc == &F);
3528
3531
}
3529
3532
}
3530
3533
}
@@ -3543,15 +3546,13 @@ void OP::FixOverloadNames() {
3543
3546
CallInst *CI = cast<CallInst>(*F.user_begin ());
3544
3547
DXIL::OpCode opCode = OP::GetDxilOpFuncCallInst (CI);
3545
3548
llvm::Type *Ty = OP::GetOverloadType (opCode, &F);
3546
- if (!OP::IsOverloadLegal (opCode, Ty))
3547
- continue ;
3548
- if (!isa<StructType>(Ty) && !isa<PointerType>(Ty))
3549
- continue ;
3550
-
3551
- std::string funcName;
3552
- if (OP::ConstructOverloadName (Ty, opCode, funcName)
3553
- .compare (F.getName ()) != 0 )
3554
- F.setName (funcName);
3549
+ if (isa<StructType>(Ty) || isa<PointerType>(Ty)) {
3550
+ std::string funcName;
3551
+ if (OP::ConstructOverloadName (Ty, opCode, funcName)
3552
+ .compare (F.getName ()) != 0 ) {
3553
+ F.setName (funcName);
3554
+ }
3555
+ }
3555
3556
}
3556
3557
}
3557
3558
}
@@ -3562,11 +3563,12 @@ void OP::UpdateCache(OpCodeClass opClass, Type *Ty, llvm::Function *F) {
3562
3563
}
3563
3564
3564
3565
Function *OP::GetOpFunc (OpCode opCode, Type *pOverloadType) {
3565
- if (opCode == OpCode::NumOpCodes)
3566
- return nullptr ;
3567
- if (!IsOverloadLegal (opCode, pOverloadType))
3568
- return nullptr ;
3569
-
3566
+ DXASSERT (0 <= (unsigned )opCode && opCode < OpCode::NumOpCodes,
3567
+ " otherwise caller passed OOB OpCode" );
3568
+ assert (0 <= (unsigned )opCode && opCode < OpCode::NumOpCodes);
3569
+ DXASSERT (IsOverloadLegal (opCode, pOverloadType),
3570
+ " otherwise the caller requested illegal operation overload (eg HLSL "
3571
+ " function with unsupported types for mapped intrinsic function)" );
3570
3572
OpCodeClass opClass = m_OpCodeProps[(unsigned )opCode].opCodeClass ;
3571
3573
Function *&F =
3572
3574
m_OpCodeClassCache[(unsigned )opClass].pOverloads [pOverloadType];
@@ -5509,8 +5511,8 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
5509
5511
// and return values to ensure that ResRetType is constructed in the
5510
5512
// RefreshCache case.
5511
5513
if (Function *existF = m_pModule->getFunction (funcName)) {
5512
- if (existF->getFunctionType () != pFT)
5513
- return nullptr ;
5514
+ DXASSERT (existF->getFunctionType () == pFT,
5515
+ " existing function must have the expected function type " ) ;
5514
5516
F = existF;
5515
5517
UpdateCache (opClass, pOverloadType, F);
5516
5518
return F;
@@ -5529,6 +5531,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
5529
5531
5530
5532
const SmallMapVector<llvm::Type *, llvm::Function *, 8 > &
5531
5533
OP::GetOpFuncList (OpCode opCode) const {
5534
+ DXASSERT (0 <= (unsigned )opCode && opCode < OpCode::NumOpCodes,
5535
+ " otherwise caller passed OOB OpCode" );
5536
+ assert (0 <= (unsigned )opCode && opCode < OpCode::NumOpCodes);
5532
5537
return m_OpCodeClassCache[(unsigned )m_OpCodeProps[(unsigned )opCode]
5533
5538
.opCodeClass ]
5534
5539
.pOverloads ;
@@ -5626,8 +5631,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
5626
5631
case OpCode::CallShader:
5627
5632
case OpCode::Pack4x8:
5628
5633
case OpCode::WaveMatrix_Fill:
5629
- if (FT->getNumParams () <= 2 )
5630
- return nullptr ;
5634
+ DXASSERT_NOMSG (FT->getNumParams () > 2 );
5631
5635
return FT->getParamType (2 );
5632
5636
case OpCode::MinPrecXRegStore:
5633
5637
case OpCode::StoreOutput:
@@ -5637,8 +5641,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
5637
5641
case OpCode::StoreVertexOutput:
5638
5642
case OpCode::StorePrimitiveOutput:
5639
5643
case OpCode::DispatchMesh:
5640
- if (FT->getNumParams () <= 4 )
5641
- return nullptr ;
5644
+ DXASSERT_NOMSG (FT->getNumParams () > 4 );
5642
5645
return FT->getParamType (4 );
5643
5646
case OpCode::IsNaN:
5644
5647
case OpCode::IsInf:
@@ -5656,27 +5659,22 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
5656
5659
case OpCode::WaveActiveAllEqual:
5657
5660
case OpCode::CreateHandleForLib:
5658
5661
case OpCode::WaveMatch:
5659
- if (FT->getNumParams () <= 1 )
5660
- return nullptr ;
5662
+ DXASSERT_NOMSG (FT->getNumParams () > 1 );
5661
5663
return FT->getParamType (1 );
5662
5664
case OpCode::TextureStore:
5663
5665
case OpCode::TextureStoreSample:
5664
- if (FT->getNumParams () <= 5 )
5665
- return nullptr ;
5666
+ DXASSERT_NOMSG (FT->getNumParams () > 5 );
5666
5667
return FT->getParamType (5 );
5667
5668
case OpCode::TraceRay:
5668
- if (FT->getNumParams () <= 15 )
5669
- return nullptr ;
5669
+ DXASSERT_NOMSG (FT->getNumParams () > 15 );
5670
5670
return FT->getParamType (15 );
5671
5671
case OpCode::ReportHit:
5672
5672
case OpCode::WaveMatrix_ScalarOp:
5673
- if (FT->getNumParams () <= 3 )
5674
- return nullptr ;
5673
+ DXASSERT_NOMSG (FT->getNumParams () > 3 );
5675
5674
return FT->getParamType (3 );
5676
5675
case OpCode::WaveMatrix_LoadGroupShared:
5677
5676
case OpCode::WaveMatrix_StoreGroupShared:
5678
- if (FT->getNumParams () <= 2 )
5679
- return nullptr ;
5677
+ DXASSERT_NOMSG (FT->getNumParams () > 2 );
5680
5678
return FT->getParamType (2 )->getPointerElementType ();
5681
5679
case OpCode::CreateHandle:
5682
5680
case OpCode::BufferUpdateCounter:
0 commit comments