forked from llvm/llvm-project
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathParseHLSLRootSignature.cpp
309 lines (257 loc) · 9.5 KB
/
ParseHLSLRootSignature.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
//=== ParseHLSLRootSignature.cpp - Parse Root Signature -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "clang/Parse/ParseHLSLRootSignature.h"
#include "clang/Lex/LiteralSupport.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm::hlsl::rootsig;
namespace clang {
namespace hlsl {
using TokenKind = RootSignatureToken::Kind;
// Builds a parser that pulls tokens from Lexer and appends every element it
// parses to Elements. All three arguments are bound into members by the
// initializer list; PP is used later for numeric-literal parsing
// (source manager, language options, target info and diagnostics).
// CurToken starts out constructed from a default SourceLocation, i.e. no token
// has been consumed yet.
RootSignatureParser::RootSignatureParser(SmallVector<RootElement> &Elements,
RootSignatureLexer &Lexer,
Preprocessor &PP)
: Elements(Elements), Lexer(Lexer), PP(PP), CurToken(SourceLocation()) {}
// Parses a comma-separated sequence of root elements followed by the end of
// the token stream.
//
// Returns true if an error was diagnosed, false on success.
bool RootSignatureParser::parse() {
  // Keep going for as long as a root element keyword leads the stream.
  bool HaveElement = tryConsumeExpectedToken(TokenKind::kw_DescriptorTable);
  while (HaveElement) {
    // CurToken was just consumed as one of the element keywords, so the
    // default case is only reachable if a keyword is accepted above without a
    // matching handler being added here.
    switch (CurToken.TokKind) {
    case TokenKind::kw_DescriptorTable: {
      const bool Failed = parseDescriptorTable();
      if (Failed)
        return true;
      break;
    }
    default:
      llvm_unreachable("Switch for consumed token was not provided");
    }

    // A further element is only attempted after a separating comma.
    HaveElement = tryConsumeExpectedToken(TokenKind::pu_comma) &&
                  tryConsumeExpectedToken(TokenKind::kw_DescriptorTable);
  }

  // Whatever follows the element list has to be the end of the stream;
  // consumeExpectedToken reports the diagnostic and yields true otherwise.
  return consumeExpectedToken(TokenKind::end_of_stream,
                              diag::err_hlsl_unexpected_end_of_params,
                              /*param of=*/TokenKind::kw_RootSignature);
}
// Parses the parenthesised body of a DescriptorTable: a comma-separated list
// of CBV/SRV/UAV/Sampler clauses. Must be called with the DescriptorTable
// keyword as the current token.
//
// Each clause is parsed by parseDescriptorTableClause; the table itself, which
// records how many clauses were parsed, is appended to Elements once the
// closing parenthesis has been consumed.
//
// Returns true if an error was diagnosed, false on success.
bool RootSignatureParser::parseDescriptorTable() {
  assert(CurToken.TokKind == TokenKind::kw_DescriptorTable &&
         "Expects to only be invoked starting at given keyword");

  DescriptorTable Table;

  const bool MissingOpenParen = consumeExpectedToken(
      TokenKind::pu_l_paren, diag::err_expected_after, CurToken.TokKind);
  if (MissingOpenParen)
    return true;

  // Consume clauses until the next token is not a clause keyword or no comma
  // follows the clause that was just parsed.
  while (true) {
    const bool SawClauseKeyword = tryConsumeExpectedToken(
        {TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV,
         TokenKind::kw_Sampler});
    if (!SawClauseKeyword)
      break;

    if (parseDescriptorTableClause())
      return true;
    ++Table.NumClauses;

    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
      break;
  }

  const bool MissingCloseParen = consumeExpectedToken(
      TokenKind::pu_r_paren, diag::err_hlsl_unexpected_end_of_params,
      /*param of=*/TokenKind::kw_DescriptorTable);
  if (MissingCloseParen)
    return true;

  Elements.push_back(Table);
  return false;
}
bool RootSignatureParser::parseDescriptorTableClause() {
assert((CurToken.TokKind == TokenKind::kw_CBV ||
CurToken.TokKind == TokenKind::kw_SRV ||
CurToken.TokKind == TokenKind::kw_UAV ||
CurToken.TokKind == TokenKind::kw_Sampler) &&
"Expects to only be invoked starting at given keyword");
TokenKind ParamKind = CurToken.TokKind; // retain for diagnostics
DescriptorTableClause Clause;
TokenKind ExpectedRegister;
switch (ParamKind) {
default:
llvm_unreachable("Switch for consumed token was not provided");
case TokenKind::kw_CBV:
Clause.Type = ClauseType::CBuffer;
ExpectedRegister = TokenKind::bReg;
break;
case TokenKind::kw_SRV:
Clause.Type = ClauseType::SRV;
ExpectedRegister = TokenKind::tReg;
break;
case TokenKind::kw_UAV:
Clause.Type = ClauseType::UAV;
ExpectedRegister = TokenKind::uReg;
break;
case TokenKind::kw_Sampler:
Clause.Type = ClauseType::Sampler;
ExpectedRegister = TokenKind::sReg;
break;
}
if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
ParamKind))
return true;
llvm::SmallDenseMap<TokenKind, ParamType> Params = {
{ExpectedRegister, &Clause.Register},
{TokenKind::kw_space, &Clause.Space},
};
llvm::SmallDenseSet<TokenKind> Mandatory = {
ExpectedRegister,
};
if (parseParams(Params, Mandatory))
return true;
if (consumeExpectedToken(TokenKind::pu_r_paren,
diag::err_hlsl_unexpected_end_of_params,
/*param of=*/ParamKind))
return true;
Elements.push_back(Clause);
return false;
}
// Overload-set helper for std::visit: derives from every callable it is given
// and re-exports each of their call operators, so ordinary overload
// resolution selects the callable matching the argument type.
template <typename... Callables>
struct ParseParamTypeMethods : Callables... {
  using Callables::operator()...;
};

// Deduction guide so the helper can be brace-initialized directly from a list
// of lambdas without spelling out their types.
template <typename... Callables>
ParseParamTypeMethods(Callables...) -> ParseParamTypeMethods<Callables...>;
// Parses the value of the parameter whose keyword is the current token and
// stores it through the pointer held by Ref. The parsing routine is chosen by
// the alternative Ref currently holds.
//
// Returns true if an error was diagnosed, false on success.
bool RootSignatureParser::parseParam(ParamType Ref) {
  // Register parameters are handled entirely by parseRegister.
  auto ParseRegisterParam = [this](Register *Out) -> bool {
    return parseRegister(Out);
  };

  // Unsigned integer parameters are written as `keyword = value`, so the '='
  // has to be consumed before the value itself is parsed.
  auto ParseUIntAssignment = [this](uint32_t *Out) -> bool {
    if (consumeExpectedToken(TokenKind::pu_equal, diag::err_expected_after,
                             CurToken.TokKind))
      return true;
    return parseUIntParam(Out);
  };

  return std::visit(
      ParseParamTypeMethods{ParseRegisterParam, ParseUIntAssignment}, Ref);
}
// Parses a comma-separated list of parameters.
//
// Params maps every keyword that may start a parameter onto the location its
// value is written to; Mandatory lists the keywords that must appear.
// A keyword appearing twice is diagnosed immediately. After the list ends,
// one diagnostic is emitted for every mandatory keyword that never appeared.
//
// Returns true if an error was diagnosed, false on success.
bool RootSignatureParser::parseParams(
    llvm::SmallDenseMap<TokenKind, ParamType> &Params,
    llvm::SmallDenseSet<TokenKind> &Mandatory) {
  // The keys of Params are exactly the keywords accepted here.
  SmallVector<TokenKind> Accepted;
  for (const auto &Entry : Params)
    Accepted.push_back(Entry.first);

  // Keywords parsed so far, used for duplicate and missing-parameter checks.
  llvm::SmallDenseSet<TokenKind> Parsed;
  while (true) {
    if (!tryConsumeExpectedToken(Accepted))
      break;

    const TokenKind Keyword = CurToken.TokKind;
    const bool FirstOccurrence = Parsed.insert(Keyword).second;
    if (!FirstOccurrence) {
      getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
          << Keyword;
      return true;
    }

    if (parseParam(Params[Keyword]))
      return true;

    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
      break;
  }

  // Report every missing mandatory parameter rather than stopping at the
  // first one.
  bool MissingMandatory = false;
  for (TokenKind Required : Mandatory) {
    if (Parsed.contains(Required))
      continue;
    getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_missing_param)
        << Required;
    MissingMandatory = true;
  }

  return MissingMandatory;
}
// Parses the unsigned integer value that follows an '=' (optionally prefixed
// by '+') and stores it in *X. Must be called with the '=' as the current
// token.
//
// Returns true if an error was diagnosed, false on success.
bool RootSignatureParser::parseUIntParam(uint32_t *X) {
  assert(CurToken.TokKind == TokenKind::pu_equal &&
         "Expects to only be invoked starting at given keyword");

  // A leading '+' is optional, so whether one was consumed does not matter.
  (void)tryConsumeExpectedToken(TokenKind::pu_plus);

  // The diagnostic names whichever token precedes the missing literal: '+' if
  // one was just consumed, otherwise '='.
  const bool MissingLiteral = consumeExpectedToken(
      TokenKind::int_literal, diag::err_expected_after, CurToken.TokKind);
  if (MissingLiteral)
    return true;

  return handleUIntLiteral(X);
}
// Fills in *Register from the current token, which must be one of the
// register tokens (bReg, tReg, uReg or sReg). The token kind selects the view
// type and the token's numeric spelling supplies the register number.
//
// Returns true if the number could not be parsed (the error has already been
// reported), false on success.
bool RootSignatureParser::parseRegister(Register *Register) {
  assert((CurToken.TokKind == TokenKind::bReg ||
          CurToken.TokKind == TokenKind::tReg ||
          CurToken.TokKind == TokenKind::uReg ||
          CurToken.TokKind == TokenKind::sReg) &&
         "Expects to only be invoked starting at given keyword");

  switch (CurToken.TokKind) {
  case TokenKind::bReg:
    Register->ViewType = RegisterType::BReg;
    break;
  case TokenKind::tReg:
    Register->ViewType = RegisterType::TReg;
    break;
  case TokenKind::uReg:
    Register->ViewType = RegisterType::UReg;
    break;
  case TokenKind::sReg:
    Register->ViewType = RegisterType::SReg;
    break;
  default:
    llvm_unreachable("Switch for consumed token was not provided");
  }

  // Propagate any NumericLiteralParser error straight to the caller.
  return handleUIntLiteral(&Register->Number);
}
// Converts the numeric spelling of the current token into a 32-bit unsigned
// value and stores it in *X.
//
// X must not be null. Returns true if an error was diagnosed (either by
// NumericLiteralParser itself or because the value does not fit in 32 bits),
// false on success. *X is left untouched on failure.
bool RootSignatureParser::handleUIntLiteral(uint32_t *X) {
  assert(X && "Expects a valid destination for the parsed value");

  // Parse the numeric value and do semantic checks on its specification
  clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
                                      PP.getSourceManager(), PP.getLangOpts(),
                                      PP.getTargetInfo(), PP.getDiagnostics());
  if (Literal.hadError)
    return true; // Error has already been reported so just return

  assert(Literal.isIntegerLiteral() && "IsNumberChar will only support digits");

  // The destination is unsigned, so hold the value in an *unsigned* 32-bit
  // APSInt. (The second constructor argument is `isUnsigned`; passing false
  // would make the value signed and rely on sign-extension followed by an
  // implicit narrowing conversion to recover values >= 2^31.)
  llvm::APSInt Val(32, /*isUnsigned=*/true);
  if (Literal.GetIntegerValue(Val)) {
    // Report that the value has overflowed
    PP.getDiagnostics().Report(CurToken.TokLoc,
                               diag::err_hlsl_number_literal_overflow)
        << 0 << CurToken.NumSpelling;
    return true;
  }

  // Val is 32 bits wide, so the zero-extended value always fits in uint32_t.
  *X = static_cast<uint32_t>(Val.getZExtValue());
  return false;
}
// Single-kind convenience overload: returns true if the next token, which is
// not consumed, has kind Expected.
bool RootSignatureParser::peekExpectedToken(TokenKind Expected) {
  // Wrap the lone kind in a one-element array so the list overload can be
  // reused.
  const TokenKind Single[] = {Expected};
  return peekExpectedToken(Single);
}
// Returns true if the kind of the next token reported by the lexer's
// PeekNextToken is one of the kinds in AnyExpected. CurToken is not modified.
bool RootSignatureParser::peekExpectedToken(ArrayRef<TokenKind> AnyExpected) {
  const RootSignatureToken Upcoming = Lexer.PeekNextToken();
  const bool IsExpected = llvm::is_contained(AnyExpected, Upcoming.TokKind);
  return IsExpected;
}
// Consumes the next token if it has kind Expected; otherwise reports DiagID at
// the location of the current token.
//
// The arguments streamed into the diagnostic depend on DiagID:
//   - err_expected: Expected only.
//   - err_hlsl_unexpected_end_of_params, err_expected_either,
//     err_expected_after: Expected followed by Context.
//   - any other ID: no arguments.
//
// Returns false if the token was consumed, true if a diagnostic was reported.
bool RootSignatureParser::consumeExpectedToken(TokenKind Expected,
                                               unsigned DiagID,
                                               TokenKind Context) {
  const bool Consumed = tryConsumeExpectedToken(Expected);
  if (Consumed)
    return false;

  // Report unexpected token kind error
  DiagnosticBuilder DB = getDiags().Report(CurToken.TokLoc, DiagID);

  const bool TakesContext = DiagID == diag::err_hlsl_unexpected_end_of_params ||
                            DiagID == diag::err_expected_either ||
                            DiagID == diag::err_expected_after;
  if (DiagID == diag::err_expected)
    DB << Expected;
  else if (TakesContext)
    DB << Expected << Context;

  return true;
}
// Single-kind convenience overload: consumes the next token if it has kind
// Expected. Returns true if a token was consumed, false otherwise.
bool RootSignatureParser::tryConsumeExpectedToken(TokenKind Expected) {
  // Wrap the lone kind in a one-element array so the list overload can be
  // reused.
  const TokenKind Single[] = {Expected};
  return tryConsumeExpectedToken(Single);
}
// Consumes the next token only if its kind is one of AnyExpected.
//
// Returns true if a token was consumed and false if the next token did not
// match, in which case nothing is consumed and no diagnostic is emitted.
bool RootSignatureParser::tryConsumeExpectedToken(
    ArrayRef<TokenKind> AnyExpected) {
  const bool Matches = peekExpectedToken(AnyExpected);
  if (Matches)
    consumeNextToken();
  return Matches;
}
} // namespace hlsl
} // namespace clang