Skip to content

Commit c349c34

Browse files
authored
[eudsl-tblgen] add lots of stuff (support retrieving ArgumentInits from Record) (#318)
This is kind of a niche PR, so I don't expect anyone to review it closely, but I'm describing it anyway just to capture some findings. The basic feature added in this PR is the ability to recover, from an instantiated `def`, which template args were used in the `class` call/invocation/thing. E.g.,

```
class Test_Op<string mnemonic, list<Trait> traits = []> : Op<Test_Dialect, mnemonic, traits>;
def Test_AndOp : Test_Op<"and"> { ... }
```

With this PR you can recover that the `mnemonic` passed as the first argument to `Test_Op` when instantiating `Test_AndOp` is `"and"`. This sounds like something that should always have been possible, but it turns out tablegen literally doesn't keep that information in any way/shape/form. In order to make this work I had to hack the parser to keep it around (note: this isn't completely straightforward because arguments can contain unevaluated [bang operators](https://llvm.org/docs/TableGen/ProgRef.html#bang-operators)). For context, I intended to use this functionality to translate LLVM `Intrinsic`s into MLIR OpDefs, but it turns out that didn't really work out (not because the feature doesn't work but because I have ~30k `Intrinsic`s lol).
1 parent c0c6123 commit c349c34

File tree

14 files changed

+426
-58
lines changed

14 files changed

+426
-58
lines changed

.github/workflows/build_llvm.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,9 @@ jobs:
181181
# Tar up MLIR/LLVM distro
182182
###############################
183183
184+
# if timezone causes datetime to be ahead of local datetime then ninja will go into a loop...
185+
find $LLVM_INSTALL_DIR -exec touch -a -m -t 197001010000 {} \;
186+
184187
tar -czf "mlir_${{ matrix.name }}_$WHEEL_VERSION.tar.gz" -C "$LLVM_INSTALL_DIR/.." llvm-install
185188
rm -rf "$LLVM_BUILD_DIR" "$LLVM_SOURCE_DIR"
186189

projects/common/eudsl/bind_vec_like.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@ extern nanobind::class_<_SmallVector> smallVector;
2525
extern nanobind::class_<_ArrayRef> arrayRef;
2626
extern nanobind::class_<_MutableArrayRef> mutableArrayRef;
2727

28-
template <typename Element, typename... Args>
29-
std::tuple<nanobind::class_<llvm::SmallVector<Element>>,
28+
template <typename Element, int Size = 4, typename... Args>
29+
std::tuple<nanobind::class_<llvm::SmallVector<Element, Size>>,
3030
nanobind::class_<llvm::ArrayRef<Element>>,
3131
nanobind::class_<llvm::MutableArrayRef<Element>>>
3232
bind_array_ref(nanobind::handle scope, Args &&...args) {
3333
using ArrayRef = llvm::ArrayRef<Element>;
34-
using SmallVec = llvm::SmallVector<Element>;
34+
using SmallVec = llvm::SmallVector<Element, Size>;
3535
using MutableArrayRef = llvm::MutableArrayRef<Element>;
3636
using ValueRef = Element &;
3737

projects/eudsl-llvmpy/eudsl-llvmpy-generate.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,14 +320,15 @@ class LLVMMatchType(Generic[_T]):
320320
),
321321
file=amdgcn_f,
322322
)
323-
intrins = RecordKeeper().parse_td(
323+
rk = RecordKeeper()
324+
rk.parse_td(
324325
str(llvm_include_root / "llvm" / "IR" / "Intrinsics.td"),
325326
include_dirs=[str(llvm_include_root)],
326327
)
327328
int_regex = re.compile(r"_i(\d+)")
328329
fp_regex = re.compile(r"_f(\d+)")
329330

330-
defs = intrins.get_defs()
331+
defs = rk.get_defs()
331332
for d in defs:
332333
intr = defs[d]
333334
if (
File renamed without changes.

projects/eudsl-tblgen/pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Homepage = "https://github.com/llvm/eudsl"
1313

1414
[build-system]
1515
requires = [
16-
"nanobind>=2.2.0",
16+
"nanobind>=2.9.2",
1717
"scikit-build-core>=0.10.7",
1818
"typing_extensions>=4.12.2"
1919
]
@@ -99,3 +99,6 @@ before-build = [
9999
before-build = [
100100
"ccache -z"
101101
]
102+
103+
[tool.pytest.ini_options]
104+
addopts = "-s"

projects/eudsl-tblgen/src/TGLexer.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,6 @@ class TGLexer {
209209
public:
210210
typedef std::set<std::string> DependenciesSetTy;
211211

212-
private:
213212
/// Dependencies - This is the list of all included files.
214213
DependenciesSetTy Dependencies;
215214

@@ -240,7 +239,6 @@ class TGLexer {
240239
SMLoc getLoc() const;
241240
SMRange getLocRange() const;
242241

243-
private:
244242
/// LexToken - Read the next token and return its code.
245243
tgtok::TokKind LexToken(bool FileOrLineStart = false);
246244

projects/eudsl-tblgen/src/TGParser.cpp

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ struct SubClassReference {
4848
SubClassReference() = default;
4949

5050
bool isInvalid() const { return Rec == nullptr; }
51+
void dump() const;
5152
};
5253

5354
struct SubMultiClassReference {
@@ -72,6 +73,19 @@ LLVM_DUMP_METHOD void SubMultiClassReference::dump() const {
7273
for (const Init *TA : TemplateArgs)
7374
TA->dump();
7475
}
76+
77+
/// Debug helper: print the referenced superclass record, then every template
/// argument (one per line), to errs().
///
/// An invalid reference (isInvalid(), i.e. Rec == nullptr) is printed as
/// "<null>" instead of being dereferenced, so dump() can be called while
/// inspecting a reference that failed to parse.
LLVM_DUMP_METHOD void SubClassReference::dump() const {
  errs() << "Subclass:\n";

  if (Rec)
    Rec->dump();
  else
    errs() << "<null>\n";

  errs() << "Template args:\n";
  for (const Init *TA : TemplateArgs) {
    TA->dump();
    errs() << "\n";
  }
  errs() << "\n";
}
7589
#endif
7690

7791
static bool checkBitsConcrete(Record &R, const RecordVal &RV) {
@@ -307,6 +321,66 @@ bool TGParser::SetValue(Record *CurRec, SMLoc Loc, const Init *ValName,
307321
return false;
308322
}
309323

324+
/// Return the unquoted string used as the lookup key for \p CurRec.
///
/// Non-class records use their own name init. Class records use a VarInit
/// built from the qualified implicit-name of the class, typed as a string
/// record type obtained from the record's RecordKeeper.
static std::string getCanonicalRecName(Record *CurRec) {
  if (!CurRec->isClass())
    return CurRec->getNameInit()->getAsUnquotedString();

  const Init *ImplicitName =
      VarInit::get(QualifiedNameOfImplicitName(*CurRec),
                   StringRecTy::get(CurRec->getRecords()));
  return ImplicitName->getAsUnquotedString();
}
333+
334+
// https://nimrod.blog/posts/cpp-how-to-access-private-members-validly/
//
// Storage<Tag> owns a single static slot, `ptr`, whose type is `Tag::type`.
// In this file the slot holds a pointer-to-member (see TemplateArgsTag), which
// PtrTaker's Transferer fills in and addOrReplaceTemplateArgs later reads.
335+
template <typename Tag>
336+
struct Storage {
337+
// One slot per Tag; `inline static` so the definition lives in the class.
inline static typename Tag::type ptr;
338+
};
339+
340+
// PtrTaker<Tag, V> copies the non-type template argument V (a value of type
// Tag::type) into Storage<Tag>::ptr. The copy happens in the constructor of
// the static member `tr`, i.e. when that static object is initialized.
template <typename Tag, typename Tag::type V>
341+
struct PtrTaker {
342+
struct Transferer {
343+
// Publish V through the Storage<Tag> slot.
Transferer() { Storage<Tag>::ptr = V; }
344+
};
345+
// Static instance whose construction performs the transfer above.
inline static Transferer tr;
346+
};
347+
348+
// Tag describing the member being exposed: a pointer to a
// `SmallVector<const Init *, 0>` data member of Record.
struct TemplateArgsTag {
349+
using type = SmallVector<const Init *, 0> Record::*;
350+
};
351+
352+
// Explicit instantiation naming &Record::TemplateArgs. C++ does not apply the
// usual access checks to names used in an explicit instantiation, which is
// what lets this line refer to the member (presumably private in Record --
// TODO confirm against the Record.h in use). Instantiating PtrTaker here is
// what causes Storage<TemplateArgsTag>::ptr to be set.
template struct PtrTaker<TemplateArgsTag, &Record::TemplateArgs>;
353+
354+
static void
355+
addOrReplaceTemplateArgs(Record *rec,
356+
SmallVector<const ArgumentInit *, 4> newTemplateArgs) {
357+
SmallVector<const Init *, 0> *mutableTemplateArgs =
358+
&(*rec.*Storage<TemplateArgsTag>::ptr);
359+
std::string recName = getCanonicalRecName(rec);
360+
for (const ArgumentInit *newTemplateArg : newTemplateArgs) {
361+
const Init **i = llvm::find_if(*mutableTemplateArgs, [&](const Init *tArg) {
362+
const auto *oTArg = llvm::dyn_cast<ArgumentInit>(tArg);
363+
if (oTArg->isNamed()) {
364+
if (newTemplateArg->isNamed())
365+
return oTArg->getName() == newTemplateArg->getName();
366+
return false;
367+
}
368+
if (newTemplateArg->isPositional())
369+
return oTArg->getIndex() == newTemplateArg->getIndex();
370+
return false;
371+
});
372+
auto newVal = RecordVal(newTemplateArg, RecordRecTy::get(rec),
373+
RecordVal::FK_TemplateArg);
374+
if (i == mutableTemplateArgs->end()) {
375+
mutableTemplateArgs->push_back(newTemplateArg);
376+
} else {
377+
rec->removeValue(*i);
378+
*i = newTemplateArg;
379+
}
380+
rec->addValue(newVal);
381+
}
382+
}
383+
310384
/// AddSubClass - Add SubClass as a subclass to CurRec, resolving its template
311385
/// args as SubClass's template arguments.
312386
bool TGParser::AddSubClass(Record *CurRec, SubClassReference &SubClass) {
@@ -336,9 +410,15 @@ bool TGParser::AddSubClass(Record *CurRec, SubClassReference &SubClass) {
336410
StringRecTy::get(Records));
337411
else
338412
Name = CurRec->getNameInit();
413+
auto canName = getCanonicalRecName(CurRec);
414+
recordTemplateArgs[canName] = SubClass.TemplateArgs;
339415
R.set(QualifiedNameOfImplicitName(*SC), Name);
340416

341417
CurRec->resolveReferences(R);
418+
// we only want to keep args for defs not classes which pass
419+
// args to other classes (i think so?)
420+
if (!CurRec->isClass() && !SubClass.TemplateArgs.empty())
421+
addOrReplaceTemplateArgs(CurRec, SubClass.TemplateArgs);
342422

343423
// Since everything went well, we can now set the "superclass" list for the
344424
// current record.
@@ -531,7 +611,20 @@ bool TGParser::resolve(const std::vector<RecordsEntry> &Source,
531611
MapResolver R(Rec.get());
532612
for (const auto &S : Substs)
533613
R.set(S.first, S.second);
614+
std::string eName = getCanonicalRecName(E.Rec.get());
615+
assert(recordTemplateArgs.count(eName));
616+
auto templateArgs = recordTemplateArgs[eName];
534617
Rec->resolveReferences(R);
618+
SmallVector<const ArgumentInit *, 4> resolvedTemplateArgs;
619+
for (const ArgumentInit *templateArg : templateArgs)
620+
resolvedTemplateArgs.emplace_back(
621+
llvm::cast<ArgumentInit>(templateArg->resolveReferences(R)));
622+
std::string recName = getCanonicalRecName(Rec.get());
623+
recordTemplateArgs[recName] = resolvedTemplateArgs;
624+
// we only want to keep args for defs not classes which pass
625+
// args to other classes (i think so?)
626+
if (!Rec->isClass() && !resolvedTemplateArgs.empty())
627+
addOrReplaceTemplateArgs(Rec.get(), resolvedTemplateArgs);
535628

536629
if (Dest)
537630
Dest->push_back(std::move(Rec));
@@ -574,7 +667,7 @@ bool TGParser::addDefOne(std::unique_ptr<Record> Rec) {
574667
Rec->emitRecordDumps();
575668

576669
// If ObjectBody has template arguments, it's an error.
577-
assert(Rec->getTemplateArgs().empty() && "How'd this get template args?");
670+
// assert(Rec->getTemplateArgs().empty() && "How'd this get template args?");
578671

579672
for (DefsetRecord *Defset : Defsets) {
580673
DefInit *I = Rec->getDefInit();
@@ -2872,7 +2965,9 @@ const Init *TGParser::ParseSimpleValue(Record *CurRec, const RecTy *ItemType,
28722965

28732966
if (TrackReferenceLocs)
28742967
Class->appendReferenceLoc(NameLoc);
2875-
return VarDefInit::get(NameLoc.Start, Class, Args)->Fold();
2968+
const VarDefInit *v = VarDefInit::get(NameLoc.Start, Class, Args);
2969+
TheVarDefInitPool.insert(v);
2970+
return v->Fold();
28762971
}
28772972
case tgtok::l_brace: { // Value ::= '{' ValueList '}'
28782973
SMLoc BraceLoc = Lex.getLoc();

projects/eudsl-tblgen/src/TGParser.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@ class TGVarScope {
9090
public:
9191
enum ScopeKind { SK_Local, SK_Record, SK_ForeachLoop, SK_MultiClass };
9292

93-
private:
9493
ScopeKind Kind;
9594
std::unique_ptr<TGVarScope> Parent;
9695
// A scope to hold variable definitions from defvar.
@@ -142,6 +141,9 @@ class TGParser {
142141
std::vector<SmallVector<LetRecord, 4>> LetStack;
143142
std::map<std::string, std::unique_ptr<MultiClass>> MultiClasses;
144143
std::map<std::string, const RecTy *> TypeAliases;
144+
DenseSet<const VarDefInit *> TheVarDefInitPool;
145+
std::unordered_map<std::string, SmallVector<const ArgumentInit *, 4>>
146+
recordTemplateArgs;
145147

146148
/// Loops - Keep track of any foreach loops we are within.
147149
///
@@ -181,6 +183,12 @@ class TGParser {
181183
NoWarnOnUnusedTemplateArgs(NoWarnOnUnusedTemplateArgs),
182184
TrackReferenceLocs(TrackReferenceLocs) {}
183185

186+
/// Return a copy of the set of VarDefInits collected in TheVarDefInitPool.
/// Const-qualified so it can be called on a const parser; the result is
/// returned by value and is independent of the parser.
DenseSet<const VarDefInit *> getVarDefInits() const {
  return TheVarDefInitPool;
}
187+
std::unordered_map<std::string, SmallVector<const ArgumentInit *, 4>>
188+
getRecordTemplateArgs() {
189+
return recordTemplateArgs;
190+
}
191+
184192
/// ParseFile - Main entrypoint for parsing a tblgen file. These parser
185193
/// routines return true on error, or false on success.
186194
bool ParseFile();
@@ -219,7 +227,6 @@ class TGParser {
219227
CurScope = CurScope->extractParent();
220228
}
221229

222-
private: // Semantic analysis methods.
223230
bool AddValue(Record *TheRec, SMLoc Loc, const RecordVal &RV);
224231
/// Set the value of a RecordVal within the given record. If `OverrideDefLoc`
225232
/// is set, the provided location overrides any existing location of the
@@ -253,7 +260,6 @@ class TGParser {
253260
ArrayRef<const ArgumentInit *> ArgValues,
254261
const Init *DefmName, SMLoc Loc);
255262

256-
private: // Parser methods.
257263
bool consume(tgtok::TokKind K);
258264
bool ParseObjectList(MultiClass *MC = nullptr);
259265
bool ParseObject(MultiClass *MC);

0 commit comments

Comments
 (0)