diff --git a/clang/include/clang-c/Index.h b/clang/include/clang-c/Index.h index 38e2417dcd181..757f8a3afc758 100644 --- a/clang/include/clang-c/Index.h +++ b/clang/include/clang-c/Index.h @@ -3034,7 +3034,8 @@ enum CXTypeKind { /* HLSL Types */ CXType_HLSLResource = 179, - CXType_HLSLAttributedResource = 180 + CXType_HLSLAttributedResource = 180, + CXType_HLSLInlineSpirv = 181 }; /** diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h index a24f30815e6b9..c62f9f7672010 100644 --- a/clang/include/clang/AST/ASTContext.h +++ b/clang/include/clang/AST/ASTContext.h @@ -260,6 +260,7 @@ class ASTContext : public RefCountedBase { DependentBitIntTypes; mutable llvm::FoldingSet BTFTagAttributedTypes; llvm::FoldingSet HLSLAttributedResourceTypes; + llvm::FoldingSet HLSLInlineSpirvTypes; mutable llvm::FoldingSet CountAttributedTypes; @@ -1795,6 +1796,10 @@ class ASTContext : public RefCountedBase { QualType Wrapped, QualType Contained, const HLSLAttributedResourceType::Attributes &Attrs); + QualType getHLSLInlineSpirvType(uint32_t Opcode, uint32_t Size, + uint32_t Alignment, + ArrayRef Operands); + QualType getSubstTemplateTypeParmType(QualType Replacement, Decl *AssociatedDecl, unsigned Index, diff --git a/clang/include/clang/AST/ASTNodeTraverser.h b/clang/include/clang/AST/ASTNodeTraverser.h index f086d8134a64b..fd9108221590e 100644 --- a/clang/include/clang/AST/ASTNodeTraverser.h +++ b/clang/include/clang/AST/ASTNodeTraverser.h @@ -450,6 +450,24 @@ class ASTNodeTraverser if (!Contained.isNull()) Visit(Contained); } + void VisitHLSLInlineSpirvType(const HLSLInlineSpirvType *T) { + for (auto &Operand : T->getOperands()) { + using SpirvOperandKind = SpirvOperand::SpirvOperandKind; + + switch (Operand.getKind()) { + case SpirvOperandKind::kConstantId: + case SpirvOperandKind::kLiteral: + break; + + case SpirvOperandKind::kTypeId: + Visit(Operand.getResultType()); + break; + + default: + llvm_unreachable("Invalid SpirvOperand kind!"); + } + } + } void VisitSubstTemplateTypeParmType(const SubstTemplateTypeParmType *) {} void VisitSubstTemplateTypeParmPackType(const SubstTemplateTypeParmPackType *T) { diff --git a/clang/include/clang/AST/PropertiesBase.td b/clang/include/clang/AST/PropertiesBase.td index 5171555008ac9..7d5e6671fec7d 100644 --- a/clang/include/clang/AST/PropertiesBase.td +++ b/clang/include/clang/AST/PropertiesBase.td @@ -147,6 +147,7 @@ def UInt64 : CountPropertyType<"uint64_t">; def UnaryTypeTransformKind : EnumPropertyType<"UnaryTransformType::UTTKind">; def VectorKind : EnumPropertyType<"VectorKind">; def TypeCoupledDeclRefInfo : PropertyType; +def HLSLSpirvOperand : PropertyType<"SpirvOperand"> { let PassByReference = 1; } def ExceptionSpecInfo : PropertyType<"FunctionProtoType::ExceptionSpecInfo"> { let BufferElementTypes = [ QualType ]; diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h index 0530996ed20d3..255e39a46db09 100644 --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -1154,6 +1154,14 @@ DEF_TRAVERSE_TYPE(BTFTagAttributedType, DEF_TRAVERSE_TYPE(HLSLAttributedResourceType, { TRY_TO(TraverseType(T->getWrappedType())); }) +DEF_TRAVERSE_TYPE(HLSLInlineSpirvType, { + for (auto &Operand : T->getOperands()) { + if (Operand.isConstant() || Operand.isType()) { + TRY_TO(TraverseType(Operand.getResultType())); + } + } +}) + DEF_TRAVERSE_TYPE(ParenType, { TRY_TO(TraverseType(T->getInnerType())); }) DEF_TRAVERSE_TYPE(MacroQualifiedType, @@ -1457,6 +1465,9 @@ DEF_TRAVERSE_TYPELOC(BTFTagAttributedType, DEF_TRAVERSE_TYPELOC(HLSLAttributedResourceType, { TRY_TO(TraverseTypeLoc(TL.getWrappedLoc())); }) +DEF_TRAVERSE_TYPELOC(HLSLInlineSpirvType, + { TRY_TO(TraverseType(TL.getType())); }) + DEF_TRAVERSE_TYPELOC(ElaboratedType, { if (TL.getQualifierLoc()) { TRY_TO(TraverseNestedNameSpecifierLoc(TL.getQualifierLoc())); diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h index cfd417068abb7..f351e68d5297d 100644 --- a/clang/include/clang/AST/Type.h +++ b/clang/include/clang/AST/Type.h @@ -2652,6 +2652,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase { bool isHLSLSpecificType() const; // Any HLSL specific type bool isHLSLBuiltinIntangibleType() const; // Any HLSL builtin intangible type bool isHLSLAttributedResourceType() const; + bool isHLSLInlineSpirvType() const; bool isHLSLResourceRecord() const; bool isHLSLIntangibleType() const; // Any HLSL intangible type (builtin, array, class) @@ -6330,6 +6331,140 @@ class HLSLAttributedResourceType : public Type, public llvm::FoldingSetNode { findHandleTypeOnResource(const Type *RT); }; +/// Instances of this class represent operands to a SPIR-V type instruction. +class SpirvOperand { +public: + enum SpirvOperandKind : unsigned char { + kInvalid, ///< Uninitialized. + kConstantId, ///< Integral value to represent as a SPIR-V OpConstant + ///< instruction ID. + kLiteral, ///< Integral value to represent as an immediate literal. + kTypeId, ///< Type to represent as a SPIR-V type ID. + + kMax, + }; + +private: + SpirvOperandKind Kind = kInvalid; + + QualType ResultType; + llvm::APInt Value; // Signedness of constants is represented by ResultType. + +public: + SpirvOperand() : Kind(kInvalid), ResultType() {} + + SpirvOperand(SpirvOperandKind Kind, QualType ResultType, llvm::APInt Value) + : Kind(Kind), ResultType(ResultType), Value(Value) {} + + SpirvOperand(const SpirvOperand &Other) { *this = Other; } + ~SpirvOperand() {} + + SpirvOperand &operator=(const SpirvOperand &Other) { + this->Kind = Other.Kind; + this->ResultType = Other.ResultType; + this->Value = Other.Value; + return *this; + } + + bool operator==(const SpirvOperand &Other) const { + return Kind == Other.Kind && ResultType == Other.ResultType && + Value == Other.Value; + } + + bool operator!=(const SpirvOperand &Other) const { return !(*this == Other); } + + SpirvOperandKind getKind() const { return Kind; } + + bool isValid() const { return Kind != kInvalid && Kind < kMax; } + bool isConstant() const { return Kind == kConstantId; } + bool isLiteral() const { return Kind == kLiteral; } + bool isType() const { return Kind == kTypeId; } + + llvm::APInt getValue() const { + assert((isConstant() || isLiteral()) && + "This is not an operand with a value!"); + return Value; + } + + QualType getResultType() const { + assert((isConstant() || isType()) && + "This is not an operand with a result type!"); + return ResultType; + } + + static SpirvOperand createConstant(QualType ResultType, llvm::APInt Val) { + return SpirvOperand(kConstantId, ResultType, Val); + } + + static SpirvOperand createLiteral(llvm::APInt Val) { + return SpirvOperand(kLiteral, QualType(), Val); + } + + static SpirvOperand createType(QualType T) { + return SpirvOperand(kTypeId, T, llvm::APSInt()); + } + + void Profile(llvm::FoldingSetNodeID &ID) const { + ID.AddInteger(Kind); + ID.AddPointer(ResultType.getAsOpaquePtr()); + Value.Profile(ID); + } +}; + +/// Represents an arbitrary, user-specified SPIR-V type instruction. +class HLSLInlineSpirvType final + : public Type, + public llvm::FoldingSetNode, + private llvm::TrailingObjects { + friend class ASTContext; // ASTContext creates these + friend TrailingObjects; + +private: + uint32_t Opcode; + uint32_t Size; + uint32_t Alignment; + size_t NumOperands; + + HLSLInlineSpirvType(uint32_t Opcode, uint32_t Size, uint32_t Alignment, + ArrayRef Operands) + : Type(HLSLInlineSpirv, QualType(), TypeDependence::None), Opcode(Opcode), + Size(Size), Alignment(Alignment), NumOperands(Operands.size()) { + for (size_t I = 0; I < NumOperands; I++) { + getTrailingObjects()[I] = Operands[I]; + } + } + +public: + uint32_t getOpcode() const { return Opcode; } + uint32_t getSize() const { return Size; } + uint32_t getAlignment() const { return Alignment; } + ArrayRef getOperands() const { + return {getTrailingObjects(), NumOperands}; + } + + bool isSugared() const { return false; } + QualType desugar() const { return QualType(this, 0); } + + void Profile(llvm::FoldingSetNodeID &ID) { + Profile(ID, Opcode, Size, Alignment, getOperands()); + } + + static void Profile(llvm::FoldingSetNodeID &ID, uint32_t Opcode, + uint32_t Size, uint32_t Alignment, + ArrayRef Operands) { + ID.AddInteger(Opcode); + ID.AddInteger(Size); + ID.AddInteger(Alignment); + for (auto &Operand : Operands) { + Operand.Profile(ID); + } + } + + static bool classof(const Type *T) { + return T->getTypeClass() == HLSLInlineSpirv; + } +}; + class TemplateTypeParmType : public Type, public llvm::FoldingSetNode { friend class ASTContext; // ASTContext creates these @@ -8458,13 +8593,18 @@ inline bool Type::isHLSLBuiltinIntangibleType() const { } inline bool Type::isHLSLSpecificType() const { - return isHLSLBuiltinIntangibleType() || isHLSLAttributedResourceType(); + return isHLSLBuiltinIntangibleType() || isHLSLAttributedResourceType() || + isHLSLInlineSpirvType(); } inline bool Type::isHLSLAttributedResourceType() const { return isa(this); } +inline bool Type::isHLSLInlineSpirvType() const { + return isa(this); +} + inline bool Type::isTemplateTypeParmType() const { return isa(CanonicalType); } diff --git a/clang/include/clang/AST/TypeLoc.h b/clang/include/clang/AST/TypeLoc.h index 92661b8b13fe0..53c7ea8c65df2 100644 --- a/clang/include/clang/AST/TypeLoc.h +++ b/clang/include/clang/AST/TypeLoc.h @@ -973,6 +973,25 @@ class HLSLAttributedResourceTypeLoc } }; +struct HLSLInlineSpirvTypeLocInfo { + SourceLocation Loc; +}; // Nothing. + +class HLSLInlineSpirvTypeLoc + : public ConcreteTypeLoc { +public: + SourceLocation getSpirvTypeLoc() const { return getLocalData()->Loc; } + void setSpirvTypeLoc(SourceLocation loc) const { getLocalData()->Loc = loc; } + + SourceRange getLocalSourceRange() const { + return SourceRange(getSpirvTypeLoc(), getSpirvTypeLoc()); + } + void initializeLocal(ASTContext &Context, SourceLocation loc) { + setSpirvTypeLoc(loc); + } +}; + struct ObjCObjectTypeLocInfo { SourceLocation TypeArgsLAngleLoc; SourceLocation TypeArgsRAngleLoc; diff --git a/clang/include/clang/AST/TypeProperties.td b/clang/include/clang/AST/TypeProperties.td index 391fd26a086f7..784c2104f1bb2 100644 --- a/clang/include/clang/AST/TypeProperties.td +++ b/clang/include/clang/AST/TypeProperties.td @@ -719,6 +719,24 @@ let Class = HLSLAttributedResourceType in { }]>; } +let Class = HLSLInlineSpirvType in { + def : Property<"opcode", UInt32> { + let Read = [{ node->getOpcode() }]; + } + def : Property<"size", UInt32> { + let Read = [{ node->getSize() }]; + } + def : Property<"alignment", UInt32> { + let Read = [{ node->getAlignment() }]; + } + def : Property<"operands", Array> { + let Read = [{ node->getOperands() }]; + } + def : Creator<[{ + return ctx.getHLSLInlineSpirvType(opcode, size, alignment, operands); + }]>; +} + let Class = DependentAddressSpaceType in { def : Property<"pointeeType", QualType> { let Read = [{ node->getPointeeType() }]; diff --git a/clang/include/clang/Basic/BuiltinTemplates.td b/clang/include/clang/Basic/BuiltinTemplates.td index d46ce063d2f7e..5b9672b395955 100644 --- a/clang/include/clang/Basic/BuiltinTemplates.td +++ b/clang/include/clang/Basic/BuiltinTemplates.td @@ -28,25 +28,37 @@ class BuiltinNTTP : TemplateArg<""> { } def SizeT : BuiltinNTTP<"size_t"> {} +def Uint32T: BuiltinNTTP<"uint32_t"> {} class BuiltinTemplate template_head> { list TemplateHead = template_head; } +class CPlusPlusBuiltinTemplate template_head> : BuiltinTemplate; + +class HLSLBuiltinTemplate template_head> : BuiltinTemplate; + // template