Skip to content

[HLSL] Implement SpirvType and SpirvOpaqueType #134034

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion clang/include/clang-c/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -3034,7 +3034,8 @@ enum CXTypeKind {

/* HLSL Types */
CXType_HLSLResource = 179,
CXType_HLSLAttributedResource = 180
CXType_HLSLAttributedResource = 180,
CXType_HLSLInlineSpirv = 181
};

/**
Expand Down
5 changes: 5 additions & 0 deletions clang/include/clang/AST/ASTContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ class ASTContext : public RefCountedBase<ASTContext> {
DependentBitIntTypes;
mutable llvm::FoldingSet<BTFTagAttributedType> BTFTagAttributedTypes;
llvm::FoldingSet<HLSLAttributedResourceType> HLSLAttributedResourceTypes;
llvm::FoldingSet<HLSLInlineSpirvType> HLSLInlineSpirvTypes;

mutable llvm::FoldingSet<CountAttributedType> CountAttributedTypes;

Expand Down Expand Up @@ -1795,6 +1796,10 @@ class ASTContext : public RefCountedBase<ASTContext> {
QualType Wrapped, QualType Contained,
const HLSLAttributedResourceType::Attributes &Attrs);

QualType getHLSLInlineSpirvType(uint32_t Opcode, uint32_t Size,
uint32_t Alignment,
ArrayRef<SpirvOperand> Operands);

QualType
getSubstTemplateTypeParmType(QualType Replacement, Decl *AssociatedDecl,
unsigned Index,
Expand Down
18 changes: 18 additions & 0 deletions clang/include/clang/AST/ASTNodeTraverser.h
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,24 @@ class ASTNodeTraverser
if (!Contained.isNull())
Visit(Contained);
}
void VisitHLSLInlineSpirvType(const HLSLInlineSpirvType *T) {
for (auto &Operand : T->getOperands()) {
using SpirvOperandKind = SpirvOperand::SpirvOperandKind;

switch (Operand.getKind()) {
case SpirvOperandKind::kConstantId:
case SpirvOperandKind::kLiteral:
break;

case SpirvOperandKind::kTypeId:
Visit(Operand.getResultType());
break;

default:
llvm_unreachable("Invalid SpirvOperand kind!");
}
}
}
void VisitSubstTemplateTypeParmType(const SubstTemplateTypeParmType *) {}
void
VisitSubstTemplateTypeParmPackType(const SubstTemplateTypeParmPackType *T) {
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/AST/PropertiesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def UInt64 : CountPropertyType<"uint64_t">;
def UnaryTypeTransformKind : EnumPropertyType<"UnaryTransformType::UTTKind">;
def VectorKind : EnumPropertyType<"VectorKind">;
def TypeCoupledDeclRefInfo : PropertyType;
def HLSLSpirvOperand : PropertyType<"SpirvOperand"> { let PassByReference = 1; }

def ExceptionSpecInfo : PropertyType<"FunctionProtoType::ExceptionSpecInfo"> {
let BufferElementTypes = [ QualType ];
Expand Down
11 changes: 11 additions & 0 deletions clang/include/clang/AST/RecursiveASTVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,14 @@ DEF_TRAVERSE_TYPE(BTFTagAttributedType,
DEF_TRAVERSE_TYPE(HLSLAttributedResourceType,
{ TRY_TO(TraverseType(T->getWrappedType())); })

DEF_TRAVERSE_TYPE(HLSLInlineSpirvType, {
for (auto &Operand : T->getOperands()) {
if (Operand.isConstant() || Operand.isType()) {
TRY_TO(TraverseType(Operand.getResultType()));
}
}
})

DEF_TRAVERSE_TYPE(ParenType, { TRY_TO(TraverseType(T->getInnerType())); })

DEF_TRAVERSE_TYPE(MacroQualifiedType,
Expand Down Expand Up @@ -1457,6 +1465,9 @@ DEF_TRAVERSE_TYPELOC(BTFTagAttributedType,
DEF_TRAVERSE_TYPELOC(HLSLAttributedResourceType,
{ TRY_TO(TraverseTypeLoc(TL.getWrappedLoc())); })

DEF_TRAVERSE_TYPELOC(HLSLInlineSpirvType,
{ TRY_TO(TraverseType(TL.getType())); })

DEF_TRAVERSE_TYPELOC(ElaboratedType, {
if (TL.getQualifierLoc()) {
TRY_TO(TraverseNestedNameSpecifierLoc(TL.getQualifierLoc()));
Expand Down
142 changes: 141 additions & 1 deletion clang/include/clang/AST/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -2652,6 +2652,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
bool isHLSLSpecificType() const; // Any HLSL specific type
bool isHLSLBuiltinIntangibleType() const; // Any HLSL builtin intangible type
bool isHLSLAttributedResourceType() const;
bool isHLSLInlineSpirvType() const;
bool isHLSLResourceRecord() const;
bool isHLSLIntangibleType()
const; // Any HLSL intangible type (builtin, array, class)
Expand Down Expand Up @@ -6330,6 +6331,140 @@ class HLSLAttributedResourceType : public Type, public llvm::FoldingSetNode {
findHandleTypeOnResource(const Type *RT);
};

/// Instances of this class represent operands to a SPIR-V type instruction.
class SpirvOperand {
public:
enum SpirvOperandKind : unsigned char {
kInvalid, ///< Uninitialized.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kConstantId, ///< Integral value to represent as a SPIR-V OpConstant
///< instruction ID.
kLiteral, ///< Integral value to represent as an immediate literal.
kTypeId, ///< Type to represent as a SPIR-V type ID.

kMax,
};

private:
SpirvOperandKind Kind = kInvalid;

QualType ResultType;
llvm::APInt Value; // Signedness of constants is represented by ResultType.

public:
SpirvOperand() : Kind(kInvalid), ResultType() {}

SpirvOperand(SpirvOperandKind Kind, QualType ResultType, llvm::APInt Value)
: Kind(Kind), ResultType(ResultType), Value(Value) {}

SpirvOperand(const SpirvOperand &Other) { *this = Other; }
~SpirvOperand() {}

SpirvOperand &operator=(const SpirvOperand &Other) {
this->Kind = Other.Kind;
this->ResultType = Other.ResultType;
this->Value = Other.Value;
return *this;
}

bool operator==(const SpirvOperand &Other) const {
return Kind == Other.Kind && ResultType == Other.ResultType &&
Value == Other.Value;
}

bool operator!=(const SpirvOperand &Other) const { return !(*this == Other); }

SpirvOperandKind getKind() const { return Kind; }

bool isValid() const { return Kind != kInvalid && Kind < kMax; }
bool isConstant() const { return Kind == kConstantId; }
bool isLiteral() const { return Kind == kLiteral; }
bool isType() const { return Kind == kTypeId; }

llvm::APInt getValue() const {
assert((isConstant() || isLiteral()) &&
"This is not an operand with a value!");
return Value;
}

QualType getResultType() const {
assert((isConstant() || isType()) &&
"This is not an operand with a result type!");
return ResultType;
}

static SpirvOperand createConstant(QualType ResultType, llvm::APInt Val) {
return SpirvOperand(kConstantId, ResultType, Val);
}

static SpirvOperand createLiteral(llvm::APInt Val) {
return SpirvOperand(kLiteral, QualType(), Val);
}

static SpirvOperand createType(QualType T) {
return SpirvOperand(kTypeId, T, llvm::APSInt());
}

void Profile(llvm::FoldingSetNodeID &ID) const {
ID.AddInteger(Kind);
ID.AddPointer(ResultType.getAsOpaquePtr());
Value.Profile(ID);
}
};

/// Represents an arbitrary, user-specified SPIR-V type instruction.
class HLSLInlineSpirvType final
: public Type,
public llvm::FoldingSetNode,
private llvm::TrailingObjects<HLSLInlineSpirvType, SpirvOperand> {
friend class ASTContext; // ASTContext creates these
friend TrailingObjects;

private:
uint32_t Opcode;
uint32_t Size;
uint32_t Alignment;
size_t NumOperands;

HLSLInlineSpirvType(uint32_t Opcode, uint32_t Size, uint32_t Alignment,
ArrayRef<SpirvOperand> Operands)
: Type(HLSLInlineSpirv, QualType(), TypeDependence::None), Opcode(Opcode),
Size(Size), Alignment(Alignment), NumOperands(Operands.size()) {
for (size_t I = 0; I < NumOperands; I++) {
getTrailingObjects<SpirvOperand>()[I] = Operands[I];
}
Comment on lines +6432 to +6434
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: https://llvm.org/docs/CodingStandards.html#don-t-use-braces-on-simple-single-statement-bodies-of-if-else-loop-statements

Suggested change
for (size_t I = 0; I < NumOperands; I++) {
getTrailingObjects<SpirvOperand>()[I] = Operands[I];
}
for (size_t I = 0; I < NumOperands; I++)
getTrailingObjects<SpirvOperand>()[I] = Operands[I];

}

public:
uint32_t getOpcode() const { return Opcode; }
uint32_t getSize() const { return Size; }
uint32_t getAlignment() const { return Alignment; }
ArrayRef<SpirvOperand> getOperands() const {
return {getTrailingObjects<SpirvOperand>(), NumOperands};
}

bool isSugared() const { return false; }
QualType desugar() const { return QualType(this, 0); }

void Profile(llvm::FoldingSetNodeID &ID) {
Profile(ID, Opcode, Size, Alignment, getOperands());
}

static void Profile(llvm::FoldingSetNodeID &ID, uint32_t Opcode,
uint32_t Size, uint32_t Alignment,
ArrayRef<SpirvOperand> Operands) {
ID.AddInteger(Opcode);
ID.AddInteger(Size);
ID.AddInteger(Alignment);
for (auto &Operand : Operands) {
Operand.Profile(ID);
}
Comment on lines +6458 to +6460
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: https://llvm.org/docs/CodingStandards.html#don-t-use-braces-on-simple-single-statement-bodies-of-if-else-loop-statements

Suggested change
for (auto &Operand : Operands) {
Operand.Profile(ID);
}
for (auto &Operand : Operands)
Operand.Profile(ID);

}

static bool classof(const Type *T) {
return T->getTypeClass() == HLSLInlineSpirv;
}
};

class TemplateTypeParmType : public Type, public llvm::FoldingSetNode {
friend class ASTContext; // ASTContext creates these

Expand Down Expand Up @@ -8458,13 +8593,18 @@ inline bool Type::isHLSLBuiltinIntangibleType() const {
}

inline bool Type::isHLSLSpecificType() const {
return isHLSLBuiltinIntangibleType() || isHLSLAttributedResourceType();
return isHLSLBuiltinIntangibleType() || isHLSLAttributedResourceType() ||
isHLSLInlineSpirvType();
}

inline bool Type::isHLSLAttributedResourceType() const {
return isa<HLSLAttributedResourceType>(this);
}

inline bool Type::isHLSLInlineSpirvType() const {
return isa<HLSLInlineSpirvType>(this);
}

inline bool Type::isTemplateTypeParmType() const {
return isa<TemplateTypeParmType>(CanonicalType);
}
Expand Down
19 changes: 19 additions & 0 deletions clang/include/clang/AST/TypeLoc.h
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,25 @@ class HLSLAttributedResourceTypeLoc
}
};

struct HLSLInlineSpirvTypeLocInfo {
SourceLocation Loc;
}; // Nothing.

class HLSLInlineSpirvTypeLoc
: public ConcreteTypeLoc<UnqualTypeLoc, HLSLInlineSpirvTypeLoc,
HLSLInlineSpirvType, HLSLInlineSpirvTypeLocInfo> {
public:
SourceLocation getSpirvTypeLoc() const { return getLocalData()->Loc; }
void setSpirvTypeLoc(SourceLocation loc) const { getLocalData()->Loc = loc; }

SourceRange getLocalSourceRange() const {
return SourceRange(getSpirvTypeLoc(), getSpirvTypeLoc());
}
void initializeLocal(ASTContext &Context, SourceLocation loc) {
setSpirvTypeLoc(loc);
}
};

struct ObjCObjectTypeLocInfo {
SourceLocation TypeArgsLAngleLoc;
SourceLocation TypeArgsRAngleLoc;
Expand Down
18 changes: 18 additions & 0 deletions clang/include/clang/AST/TypeProperties.td
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,24 @@ let Class = HLSLAttributedResourceType in {
}]>;
}

let Class = HLSLInlineSpirvType in {
def : Property<"opcode", UInt32> {
let Read = [{ node->getOpcode() }];
}
def : Property<"size", UInt32> {
let Read = [{ node->getSize() }];
}
def : Property<"alignment", UInt32> {
let Read = [{ node->getAlignment() }];
}
def : Property<"operands", Array<HLSLSpirvOperand>> {
let Read = [{ node->getOperands() }];
}
def : Creator<[{
return ctx.getHLSLInlineSpirvType(opcode, size, alignment, operands);
}]>;
}

let Class = DependentAddressSpaceType in {
def : Property<"pointeeType", QualType> {
let Read = [{ node->getPointeeType() }];
Expand Down
18 changes: 15 additions & 3 deletions clang/include/clang/Basic/BuiltinTemplates.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,37 @@ class BuiltinNTTP<string type_name> : TemplateArg<""> {
}

def SizeT : BuiltinNTTP<"size_t"> {}
def Uint32T: BuiltinNTTP<"uint32_t"> {}

class BuiltinTemplate<list<TemplateArg> template_head> {
list<TemplateArg> TemplateHead = template_head;
}

class CPlusPlusBuiltinTemplate<list<TemplateArg> template_head> : BuiltinTemplate<template_head>;

class HLSLBuiltinTemplate<list<TemplateArg> template_head> : BuiltinTemplate<template_head>;

// template <template <class T, T... Ints> IntSeq, class T, T N>
def __make_integer_seq : BuiltinTemplate<
def __make_integer_seq : CPlusPlusBuiltinTemplate<
[Template<[Class<"T">, NTTP<"T", "Ints", /*is_variadic=*/1>], "IntSeq">, Class<"T">, NTTP<"T", "N">]>;

// template <size_t, class... T>
def __type_pack_element : BuiltinTemplate<
def __type_pack_element : CPlusPlusBuiltinTemplate<
[SizeT, Class<"T", /*is_variadic=*/1>]>;

// template <template <class... Args> BaseTemplate,
// template <class TypeMember> HasTypeMember,
// class HasNoTypeMember
// class... Ts>
def __builtin_common_type : BuiltinTemplate<
def __builtin_common_type : CPlusPlusBuiltinTemplate<
[Template<[Class<"Args", /*is_variadic=*/1>], "BaseTemplate">,
Template<[Class<"TypeMember">], "HasTypeMember">,
Class<"HasNoTypeMember">,
Class<"Ts", /*is_variadic=*/1>]>;

// template <uint32_t Opcode,
// uint32_t Size,
// uint32_t Alignment,
// typename ...Operands>
def __hlsl_spirv_type : HLSLBuiltinTemplate<
[Uint32T, Uint32T, Uint32T, Class<"Operands", /*is_variadic=*/1>]>;
3 changes: 3 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -12709,6 +12709,9 @@ def err_hlsl_expect_arg_const_int_one_or_neg_one: Error<
def err_invalid_hlsl_resource_type: Error<
"invalid __hlsl_resource_t type attributes">;

def err_hlsl_spirv_only: Error<"%0 is only available for the SPIR-V target">;
def err_hlsl_vk_literal_must_contain_constant: Error<"the argument to vk::Literal must be a vk::integral_constant">;

// Layout randomization diagnostics.
def err_non_designated_init_used : Error<
"a randomized struct can only be initialized with a designated initializer">;
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/Basic/TypeNodes.td
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def ElaboratedType : TypeNode<Type>, NeverCanonical;
def AttributedType : TypeNode<Type>, NeverCanonical;
def BTFTagAttributedType : TypeNode<Type>, NeverCanonical;
def HLSLAttributedResourceType : TypeNode<Type>;
def HLSLInlineSpirvType : TypeNode<Type>;
def TemplateTypeParmType : TypeNode<Type>, AlwaysDependent, LeafType;
def SubstTemplateTypeParmType : TypeNode<Type>, NeverCanonical;
def SubstTemplateTypeParmPackType : TypeNode<Type>, AlwaysDependent;
Expand Down
2 changes: 2 additions & 0 deletions clang/include/clang/Serialization/ASTRecordReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ class ASTRecordReader

TypeCoupledDeclRefInfo readTypeCoupledDeclRefInfo();

SpirvOperand readHLSLSpirvOperand();

/// Read a declaration name, advancing Idx.
// DeclarationName readDeclarationName(); (inherited)
DeclarationNameLoc readDeclarationNameLoc(DeclarationName Name);
Expand Down
14 changes: 14 additions & 0 deletions clang/include/clang/Serialization/ASTRecordWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,20 @@ class ASTRecordWriter
writeBool(Info.isDeref());
}

void writeHLSLSpirvOperand(SpirvOperand Op) {
QualType ResultType;
llvm::APInt Value;

if (Op.isConstant() || Op.isType())
ResultType = Op.getResultType();
if (Op.isConstant() || Op.isLiteral())
Value = Op.getValue();

Record->push_back(Op.getKind());
writeQualType(ResultType);
writeAPInt(Value);
}

/// Emit a source range.
void AddSourceRange(SourceRange Range, LocSeq *Seq = nullptr) {
return Writer->AddSourceRange(Range, *Record, Seq);
Expand Down
Loading
Loading