From 78ac1bc4225b41bc4b9fbd9fd9ab9dc82a2953ca Mon Sep 17 00:00:00 2001 From: Cassandra Beckley Date: Tue, 1 Apr 2025 23:12:02 -0700 Subject: [PATCH 1/3] [HLSL] Implement `SpirvType` and `SpirvOpaqueType` This implements the design proposed by [Representing SpirvType in Clang's Type System](https://github.com/llvm/wg-hlsl/pull/181). It creates `HLSLInlineSpirvType` as a new `Type` subclass, and `__hlsl_spirv_type` as a new builtin type template to create such a type. This new type is lowered to the `spirv.Type` target extension type, as described in [Target Extension Types for Inline SPIR-V and Decorated Types](https://github.com/llvm/wg-hlsl/blob/main/proposals/0017-inline-spirv-and-decorated-types.md). --- clang/include/clang-c/Index.h | 3 +- clang/include/clang/AST/ASTContext.h | 5 + clang/include/clang/AST/ASTNodeTraverser.h | 18 +++ clang/include/clang/AST/PropertiesBase.td | 1 + clang/include/clang/AST/RecursiveASTVisitor.h | 11 ++ clang/include/clang/AST/Type.h | 142 +++++++++++++++++- clang/include/clang/AST/TypeLoc.h | 19 +++ clang/include/clang/AST/TypeProperties.td | 18 +++ clang/include/clang/Basic/BuiltinTemplates.td | 18 ++- .../clang/Basic/DiagnosticSemaKinds.td | 3 + clang/include/clang/Basic/TypeNodes.td | 1 + .../clang/Serialization/ASTRecordReader.h | 2 + .../clang/Serialization/ASTRecordWriter.h | 14 ++ .../clang/Serialization/TypeBitCodes.def | 1 + clang/lib/AST/ASTContext.cpp | 59 ++++++++ clang/lib/AST/ASTImporter.cpp | 42 ++++++ clang/lib/AST/ASTStructuralEquivalence.cpp | 17 +++ clang/lib/AST/ExprConstant.cpp | 1 + clang/lib/AST/ItaniumMangle.cpp | 40 ++++- clang/lib/AST/MicrosoftMangle.cpp | 5 + clang/lib/AST/Type.cpp | 14 ++ clang/lib/AST/TypePrinter.cpp | 48 ++++++ clang/lib/CodeGen/CGDebugInfo.cpp | 8 + clang/lib/CodeGen/CGDebugInfo.h | 1 + clang/lib/CodeGen/CodeGenFunction.cpp | 2 + clang/lib/CodeGen/CodeGenTypes.cpp | 6 + clang/lib/CodeGen/ItaniumCXXABI.cpp | 2 + clang/lib/CodeGen/Targets/SPIR.cpp | 90 ++++++++++- clang/lib/Headers/CMakeLists.txt | 1 + clang/lib/Headers/hlsl.h | 4 + clang/lib/Headers/hlsl/hlsl_spirv.h | 30 ++++ clang/lib/Sema/SemaExpr.cpp | 1 + clang/lib/Sema/SemaLookup.cpp | 21 ++- clang/lib/Sema/SemaTemplate.cpp | 103 ++++++++++++- clang/lib/Sema/SemaTemplateDeduction.cpp | 2 + clang/lib/Sema/SemaType.cpp | 1 + clang/lib/Sema/TreeTransform.h | 7 + clang/lib/Serialization/ASTReader.cpp | 9 ++ clang/lib/Serialization/ASTWriter.cpp | 4 + .../test/AST/HLSL/Inputs/pch_spirv_type.hlsl | 6 + clang/test/AST/HLSL/ast-dump-SpirvType.hlsl | 27 ++++ clang/test/AST/HLSL/pch_spirv_type.hlsl | 17 +++ clang/test/AST/HLSL/vector-alias.hlsl | 105 +++++++------ .../inline/SpirvType.alignment.hlsl | 16 ++ .../inline/SpirvType.dx.error.hlsl | 12 ++ clang/test/CodeGenHLSL/inline/SpirvType.hlsl | 68 +++++++++ .../inline/SpirvType.incomplete.hlsl | 14 ++ .../inline/SpirvType.literal.error.hlsl | 11 ++ clang/tools/libclang/CIndex.cpp | 5 + clang/tools/libclang/CXType.cpp | 1 + .../TableGen/ClangBuiltinTemplatesEmitter.cpp | 72 +++++++-- 51 files changed, 1052 insertions(+), 76 deletions(-) create mode 100644 clang/lib/Headers/hlsl/hlsl_spirv.h create mode 100644 clang/test/AST/HLSL/Inputs/pch_spirv_type.hlsl create mode 100644 clang/test/AST/HLSL/ast-dump-SpirvType.hlsl create mode 100644 clang/test/AST/HLSL/pch_spirv_type.hlsl create mode 100644 clang/test/CodeGenHLSL/inline/SpirvType.alignment.hlsl create mode 100644 clang/test/CodeGenHLSL/inline/SpirvType.dx.error.hlsl create mode 100644 clang/test/CodeGenHLSL/inline/SpirvType.hlsl create mode 100644 clang/test/CodeGenHLSL/inline/SpirvType.incomplete.hlsl create mode 100644 clang/test/CodeGenHLSL/inline/SpirvType.literal.error.hlsl diff --git a/clang/include/clang-c/Index.h b/clang/include/clang-c/Index.h index 38e2417dcd181..757f8a3afc758 100644 --- a/clang/include/clang-c/Index.h +++ b/clang/include/clang-c/Index.h @@ -3034,7 +3034,8 @@ enum CXTypeKind { /* HLSL Types */ CXType_HLSLResource = 179, - CXType_HLSLAttributedResource = 180 + CXType_HLSLAttributedResource = 180, + CXType_HLSLInlineSpirv = 181 }; /** diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h index a24f30815e6b9..c62f9f7672010 100644 --- a/clang/include/clang/AST/ASTContext.h +++ b/clang/include/clang/AST/ASTContext.h @@ -260,6 +260,7 @@ class ASTContext : public RefCountedBase { DependentBitIntTypes; mutable llvm::FoldingSet BTFTagAttributedTypes; llvm::FoldingSet HLSLAttributedResourceTypes; + llvm::FoldingSet HLSLInlineSpirvTypes; mutable llvm::FoldingSet CountAttributedTypes; @@ -1795,6 +1796,10 @@ class ASTContext : public RefCountedBase { QualType Wrapped, QualType Contained, const HLSLAttributedResourceType::Attributes &Attrs); + QualType getHLSLInlineSpirvType(uint32_t Opcode, uint32_t Size, + uint32_t Alignment, + ArrayRef Operands); + QualType getSubstTemplateTypeParmType(QualType Replacement, Decl *AssociatedDecl, unsigned Index, diff --git a/clang/include/clang/AST/ASTNodeTraverser.h b/clang/include/clang/AST/ASTNodeTraverser.h index f086d8134a64b..fd9108221590e 100644 --- a/clang/include/clang/AST/ASTNodeTraverser.h +++ b/clang/include/clang/AST/ASTNodeTraverser.h @@ -450,6 +450,24 @@ class ASTNodeTraverser if (!Contained.isNull()) Visit(Contained); } + void VisitHLSLInlineSpirvType(const HLSLInlineSpirvType *T) { + for (auto &Operand : T->getOperands()) { + using SpirvOperandKind = SpirvOperand::SpirvOperandKind; + + switch (Operand.getKind()) { + case SpirvOperandKind::kConstantId: + case SpirvOperandKind::kLiteral: + break; + + case SpirvOperandKind::kTypeId: + Visit(Operand.getResultType()); + break; + + default: + llvm_unreachable("Invalid SpirvOperand kind!"); + } + } + } void VisitSubstTemplateTypeParmType(const SubstTemplateTypeParmType *) {} void VisitSubstTemplateTypeParmPackType(const SubstTemplateTypeParmPackType *T) { diff --git a/clang/include/clang/AST/PropertiesBase.td b/clang/include/clang/AST/PropertiesBase.td index 5171555008ac9..7d5e6671fec7d 100644 --- a/clang/include/clang/AST/PropertiesBase.td +++ b/clang/include/clang/AST/PropertiesBase.td @@ -147,6 +147,7 @@ def UInt64 : CountPropertyType<"uint64_t">; def UnaryTypeTransformKind : EnumPropertyType<"UnaryTransformType::UTTKind">; def VectorKind : EnumPropertyType<"VectorKind">; def TypeCoupledDeclRefInfo : PropertyType; +def HLSLSpirvOperand : PropertyType<"SpirvOperand"> { let PassByReference = 1; } def ExceptionSpecInfo : PropertyType<"FunctionProtoType::ExceptionSpecInfo"> { let BufferElementTypes = [ QualType ]; diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h index 0530996ed20d3..255e39a46db09 100644 --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -1154,6 +1154,14 @@ DEF_TRAVERSE_TYPE(BTFTagAttributedType, DEF_TRAVERSE_TYPE(HLSLAttributedResourceType, { TRY_TO(TraverseType(T->getWrappedType())); }) +DEF_TRAVERSE_TYPE(HLSLInlineSpirvType, { + for (auto &Operand : T->getOperands()) { + if (Operand.isConstant() || Operand.isType()) { + TRY_TO(TraverseType(Operand.getResultType())); + } + } +}) + DEF_TRAVERSE_TYPE(ParenType, { TRY_TO(TraverseType(T->getInnerType())); }) DEF_TRAVERSE_TYPE(MacroQualifiedType, @@ -1457,6 +1465,9 @@ DEF_TRAVERSE_TYPELOC(BTFTagAttributedType, DEF_TRAVERSE_TYPELOC(HLSLAttributedResourceType, { TRY_TO(TraverseTypeLoc(TL.getWrappedLoc())); }) +DEF_TRAVERSE_TYPELOC(HLSLInlineSpirvType, + { TRY_TO(TraverseType(TL.getType())); }) + DEF_TRAVERSE_TYPELOC(ElaboratedType, { if (TL.getQualifierLoc()) { TRY_TO(TraverseNestedNameSpecifierLoc(TL.getQualifierLoc())); diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h index cfd417068abb7..f351e68d5297d 100644 --- a/clang/include/clang/AST/Type.h +++ b/clang/include/clang/AST/Type.h @@ -2652,6 +2652,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase { bool isHLSLSpecificType() const; // Any HLSL specific type bool isHLSLBuiltinIntangibleType() const; // Any HLSL builtin intangible type bool isHLSLAttributedResourceType() const; + bool isHLSLInlineSpirvType() const; bool isHLSLResourceRecord() const; bool isHLSLIntangibleType() const; // Any HLSL intangible type (builtin, array, class) @@ -6330,6 +6331,140 @@ class HLSLAttributedResourceType : public Type, public llvm::FoldingSetNode { findHandleTypeOnResource(const Type *RT); }; +/// Instances of this class represent operands to a SPIR-V type instruction. +class SpirvOperand { +public: + enum SpirvOperandKind : unsigned char { + kInvalid, ///< Uninitialized. + kConstantId, ///< Integral value to represent as a SPIR-V OpConstant + ///< instruction ID. + kLiteral, ///< Integral value to represent as an immediate literal. + kTypeId, ///< Type to represent as a SPIR-V type ID. + + kMax, + }; + +private: + SpirvOperandKind Kind = kInvalid; + + QualType ResultType; + llvm::APInt Value; // Signedness of constants is represented by ResultType. + +public: + SpirvOperand() : Kind(kInvalid), ResultType() {} + + SpirvOperand(SpirvOperandKind Kind, QualType ResultType, llvm::APInt Value) + : Kind(Kind), ResultType(ResultType), Value(Value) {} + + SpirvOperand(const SpirvOperand &Other) { *this = Other; } + ~SpirvOperand() {} + + SpirvOperand &operator=(const SpirvOperand &Other) { + this->Kind = Other.Kind; + this->ResultType = Other.ResultType; + this->Value = Other.Value; + return *this; + } + + bool operator==(const SpirvOperand &Other) const { + return Kind == Other.Kind && ResultType == Other.ResultType && + Value == Other.Value; + } + + bool operator!=(const SpirvOperand &Other) const { return !(*this == Other); } + + SpirvOperandKind getKind() const { return Kind; } + + bool isValid() const { return Kind != kInvalid && Kind < kMax; } + bool isConstant() const { return Kind == kConstantId; } + bool isLiteral() const { return Kind == kLiteral; } + bool isType() const { return Kind == kTypeId; } + + llvm::APInt getValue() const { + assert((isConstant() || isLiteral()) && + "This is not an operand with a value!"); + return Value; + } + + QualType getResultType() const { + assert((isConstant() || isType()) && + "This is not an operand with a result type!"); + return ResultType; + } + + static SpirvOperand createConstant(QualType ResultType, llvm::APInt Val) { + return SpirvOperand(kConstantId, ResultType, Val); + } + + static SpirvOperand createLiteral(llvm::APInt Val) { + return SpirvOperand(kLiteral, QualType(), Val); + } + + static SpirvOperand createType(QualType T) { + return SpirvOperand(kTypeId, T, llvm::APSInt()); + } + + void Profile(llvm::FoldingSetNodeID &ID) const { + ID.AddInteger(Kind); + ID.AddPointer(ResultType.getAsOpaquePtr()); + Value.Profile(ID); + } +}; + +/// Represents an arbitrary, user-specified SPIR-V type instruction. +class HLSLInlineSpirvType final + : public Type, + public llvm::FoldingSetNode, + private llvm::TrailingObjects { + friend class ASTContext; // ASTContext creates these + friend TrailingObjects; + +private: + uint32_t Opcode; + uint32_t Size; + uint32_t Alignment; + size_t NumOperands; + + HLSLInlineSpirvType(uint32_t Opcode, uint32_t Size, uint32_t Alignment, + ArrayRef Operands) + : Type(HLSLInlineSpirv, QualType(), TypeDependence::None), Opcode(Opcode), + Size(Size), Alignment(Alignment), NumOperands(Operands.size()) { + for (size_t I = 0; I < NumOperands; I++) { + getTrailingObjects()[I] = Operands[I]; + } + } + +public: + uint32_t getOpcode() const { return Opcode; } + uint32_t getSize() const { return Size; } + uint32_t getAlignment() const { return Alignment; } + ArrayRef getOperands() const { + return {getTrailingObjects(), NumOperands}; + } + + bool isSugared() const { return false; } + QualType desugar() const { return QualType(this, 0); } + + void Profile(llvm::FoldingSetNodeID &ID) { + Profile(ID, Opcode, Size, Alignment, getOperands()); + } + + static void Profile(llvm::FoldingSetNodeID &ID, uint32_t Opcode, + uint32_t Size, uint32_t Alignment, + ArrayRef Operands) { + ID.AddInteger(Opcode); + ID.AddInteger(Size); + ID.AddInteger(Alignment); + for (auto &Operand : Operands) { + Operand.Profile(ID); + } + } + + static bool classof(const Type *T) { + return T->getTypeClass() == HLSLInlineSpirv; + } +}; + class TemplateTypeParmType : public Type, public llvm::FoldingSetNode { friend class ASTContext; // ASTContext creates these @@ -8458,13 +8593,18 @@ inline bool Type::isHLSLBuiltinIntangibleType() const { } inline bool Type::isHLSLSpecificType() const { - return isHLSLBuiltinIntangibleType() || isHLSLAttributedResourceType(); + return isHLSLBuiltinIntangibleType() || isHLSLAttributedResourceType() || + isHLSLInlineSpirvType(); } inline bool Type::isHLSLAttributedResourceType() const { return isa(this); } +inline bool Type::isHLSLInlineSpirvType() const { + return isa(this); +} + inline bool Type::isTemplateTypeParmType() const { return isa(CanonicalType); } diff --git a/clang/include/clang/AST/TypeLoc.h b/clang/include/clang/AST/TypeLoc.h index 92661b8b13fe0..53c7ea8c65df2 100644 --- a/clang/include/clang/AST/TypeLoc.h +++ b/clang/include/clang/AST/TypeLoc.h @@ -973,6 +973,25 @@ class HLSLAttributedResourceTypeLoc } }; +struct HLSLInlineSpirvTypeLocInfo { + SourceLocation Loc; +}; // Nothing. + +class HLSLInlineSpirvTypeLoc + : public ConcreteTypeLoc { +public: + SourceLocation getSpirvTypeLoc() const { return getLocalData()->Loc; } + void setSpirvTypeLoc(SourceLocation loc) const { getLocalData()->Loc = loc; } + + SourceRange getLocalSourceRange() const { + return SourceRange(getSpirvTypeLoc(), getSpirvTypeLoc()); + } + void initializeLocal(ASTContext &Context, SourceLocation loc) { + setSpirvTypeLoc(loc); + } +}; + struct ObjCObjectTypeLocInfo { SourceLocation TypeArgsLAngleLoc; SourceLocation TypeArgsRAngleLoc; diff --git a/clang/include/clang/AST/TypeProperties.td b/clang/include/clang/AST/TypeProperties.td index 391fd26a086f7..784c2104f1bb2 100644 --- a/clang/include/clang/AST/TypeProperties.td +++ b/clang/include/clang/AST/TypeProperties.td @@ -719,6 +719,24 @@ let Class = HLSLAttributedResourceType in { }]>; } +let Class = HLSLInlineSpirvType in { + def : Property<"opcode", UInt32> { + let Read = [{ node->getOpcode() }]; + } + def : Property<"size", UInt32> { + let Read = [{ node->getSize() }]; + } + def : Property<"alignment", UInt32> { + let Read = [{ node->getAlignment() }]; + } + def : Property<"operands", Array> { + let Read = [{ node->getOperands() }]; + } + def : Creator<[{ + return ctx.getHLSLInlineSpirvType(opcode, size, alignment, operands); + }]>; +} + let Class = DependentAddressSpaceType in { def : Property<"pointeeType", QualType> { let Read = [{ node->getPointeeType() }]; diff --git a/clang/include/clang/Basic/BuiltinTemplates.td b/clang/include/clang/Basic/BuiltinTemplates.td index d46ce063d2f7e..5b9672b395955 100644 --- a/clang/include/clang/Basic/BuiltinTemplates.td +++ b/clang/include/clang/Basic/BuiltinTemplates.td @@ -28,25 +28,37 @@ class BuiltinNTTP : TemplateArg<""> { } def SizeT : BuiltinNTTP<"size_t"> {} +def Uint32T: BuiltinNTTP<"uint32_t"> {} class BuiltinTemplate template_head> { list TemplateHead = template_head; } +class CPlusPlusBuiltinTemplate template_head> : BuiltinTemplate; + +class HLSLBuiltinTemplate template_head> : BuiltinTemplate; + // template