Skip to content

Commit 2219f8b

Browse files
committed
LLVMCodeBuilder: Implement list type prediction in scripts
1 parent 7f62c8b commit 2219f8b

File tree

3 files changed

+526
-13
lines changed

3 files changed

+526
-13
lines changed

src/dev/engine/internal/llvm/llvmcodebuilder.cpp

Lines changed: 69 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,6 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
9292
varPtr.onStack = false; // use heap before the first assignment
9393
}
9494

95-
m_scopeVariables.clear();
96-
m_scopeVariables.push_back({});
97-
9895
// Create list pointers
9996
for (auto &[list, listPtr] : m_listPtrs) {
10097
listPtr.ptr = getListPtr(targetLists, list);
@@ -109,6 +106,10 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
109106
m_builder.CreateStore(m_builder.getInt1(false), listPtr.dataPtrDirty);
110107
}
111108

109+
m_scopeVariables.clear();
110+
m_scopeLists.clear();
111+
pushScopeLevel();
112+
112113
// Execute recorded steps
113114
for (const LLVMInstruction &step : m_instructions) {
114115
switch (step.type) {
@@ -498,6 +499,8 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
498499
llvm::Value *dataPtrDirty = m_builder.CreateLoad(m_builder.getInt1Ty(), listPtr.dataPtrDirty);
499500
llvm::Value *allocatedSize = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.allocatedSizePtr);
500501
m_builder.CreateStore(m_builder.CreateOr(dataPtrDirty, m_builder.CreateICmpNE(allocatedSize, oldAllocatedSize)), listPtr.dataPtrDirty);
502+
503+
m_scopeLists.back().erase(&listPtr);
501504
break;
502505
}
503506

@@ -515,7 +518,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
515518
assert(step.args.size() == 1);
516519
const auto &arg = step.args[0];
517520
Compiler::StaticType type = optimizeRegisterType(arg.second);
518-
const LLVMListPtr &listPtr = m_listPtrs[step.workList];
521+
LLVMListPtr &listPtr = m_listPtrs[step.workList];
519522

520523
// Check if enough space is allocated
521524
llvm::Value *allocatedSize = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.allocatedSizePtr);
@@ -541,7 +544,16 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
541544
m_builder.CreateBr(nextBlock);
542545

543546
m_builder.SetInsertPoint(nextBlock);
544-
// TODO: Implement list type prediction
547+
auto &typeMap = m_scopeLists.back();
548+
549+
if (typeMap.find(&listPtr) == typeMap.cend()) {
550+
listPtr.type = type;
551+
typeMap[&listPtr] = listPtr.type;
552+
} else if (listPtr.type != type) {
553+
listPtr.type = Compiler::StaticType::Unknown;
554+
typeMap[&listPtr] = listPtr.type;
555+
}
556+
545557
break;
546558
}
547559

@@ -550,7 +562,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
550562
const auto &indexArg = step.args[0];
551563
const auto &valueArg = step.args[1];
552564
Compiler::StaticType type = optimizeRegisterType(valueArg.second);
553-
const LLVMListPtr &listPtr = m_listPtrs[step.workList];
565+
LLVMListPtr &listPtr = m_listPtrs[step.workList];
554566

555567
// dataPtrDirty
556568
llvm::Value *dataPtrDirty = m_builder.CreateLoad(m_builder.getInt1Ty(), listPtr.dataPtrDirty);
@@ -561,8 +573,18 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
561573
// Insert
562574
llvm::Value *index = m_builder.CreateFPToUI(castValue(indexArg.second, indexArg.first), m_builder.getInt64Ty());
563575
llvm::Value *itemPtr = m_builder.CreateCall(resolve_list_insert_empty(), { listPtr.ptr, index });
564-
// TODO: Implement list type prediction
565576
createReusedValueStore(valueArg.second, itemPtr, type);
577+
578+
auto &typeMap = m_scopeLists.back();
579+
580+
if (typeMap.find(&listPtr) == typeMap.cend()) {
581+
listPtr.type = type;
582+
typeMap[&listPtr] = listPtr.type;
583+
} else if (listPtr.type != type) {
584+
listPtr.type = Compiler::StaticType::Unknown;
585+
typeMap[&listPtr] = listPtr.type;
586+
}
587+
566588
break;
567589
}
568590

@@ -571,11 +593,21 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
571593
const auto &indexArg = step.args[0];
572594
const auto &valueArg = step.args[1];
573595
Compiler::StaticType type = optimizeRegisterType(valueArg.second);
574-
const LLVMListPtr &listPtr = m_listPtrs[step.workList];
596+
LLVMListPtr &listPtr = m_listPtrs[step.workList];
575597
llvm::Value *index = m_builder.CreateFPToUI(castValue(indexArg.second, indexArg.first), m_builder.getInt64Ty());
576598
llvm::Value *itemPtr = getListItem(listPtr, index, func);
577599
createValueStore(valueArg.second, itemPtr, type, listPtr.type);
578-
// TODO: Implement list type prediction
600+
601+
auto &typeMap = m_scopeLists.back();
602+
603+
if (typeMap.find(&listPtr) == typeMap.cend()) {
604+
listPtr.type = type;
605+
typeMap[&listPtr] = listPtr.type;
606+
} else if (listPtr.type != type) {
607+
listPtr.type = Compiler::StaticType::Unknown;
608+
typeMap[&listPtr] = listPtr.type;
609+
}
610+
579611
break;
580612
}
581613

@@ -618,9 +650,9 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
618650
if (!m_warp) {
619651
freeHeap();
620652
syncVariables(targetVariables);
621-
reloadLists();
622653
coro->createSuspend();
623654
reloadVariables(targetVariables);
655+
reloadLists();
624656
}
625657

626658
break;
@@ -1285,6 +1317,16 @@ void LLVMCodeBuilder::createListMap()
12851317
void LLVMCodeBuilder::pushScopeLevel()
12861318
{
12871319
m_scopeVariables.push_back({});
1320+
1321+
if (m_scopeLists.empty()) {
1322+
std::unordered_map<LLVMListPtr *, Compiler::StaticType> listTypes;
1323+
1324+
for (auto &[list, listPtr] : m_listPtrs)
1325+
listTypes[&listPtr] = Compiler::StaticType::Unknown;
1326+
1327+
m_scopeLists.push_back(listTypes);
1328+
} else
1329+
m_scopeLists.push_back(m_scopeLists.back());
12881330
}
12891331

12901332
void LLVMCodeBuilder::popScopeLevel()
@@ -1297,6 +1339,15 @@ void LLVMCodeBuilder::popScopeLevel()
12971339
}
12981340

12991341
m_scopeVariables.pop_back();
1342+
1343+
for (size_t i = 0; i < m_scopeLists.size() - 1; i++) {
1344+
for (auto &[ptr, type] : m_scopeLists[i]) {
1345+
if (ptr->type != type)
1346+
ptr->type = Compiler::StaticType::Unknown;
1347+
}
1348+
}
1349+
1350+
m_scopeLists.pop_back();
13001351
}
13011352

13021353
void LLVMCodeBuilder::verifyFunction(llvm::Function *func)
@@ -1605,9 +1656,14 @@ void LLVMCodeBuilder::reloadVariables(llvm::Value *targetVariables)
16051656

16061657
void LLVMCodeBuilder::reloadLists()
16071658
{
1608-
// Reload list data pointers
1609-
for (auto &[list, listPtr] : m_listPtrs)
1610-
m_builder.CreateStore(m_builder.CreateCall(resolve_list_data(), listPtr.ptr), listPtr.dataPtr);
1659+
// Reset list data dirty and list types
1660+
auto &typeMap = m_scopeLists.back();
1661+
1662+
for (auto &[list, listPtr] : m_listPtrs) {
1663+
m_builder.CreateStore(m_builder.getInt1(true), listPtr.dataPtrDirty);
1664+
listPtr.type = Compiler::StaticType::Unknown;
1665+
typeMap[&listPtr] = listPtr.type;
1666+
}
16111667
}
16121668

16131669
void LLVMCodeBuilder::updateListDataPtr(const LLVMListPtr &listPtr, llvm::Function *func)

src/dev/engine/internal/llvm/llvmcodebuilder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ class LLVMCodeBuilder : public ICodeBuilder
165165

166166
std::unordered_map<List *, size_t> m_targetListMap;
167167
std::unordered_map<List *, LLVMListPtr> m_listPtrs;
168+
std::vector<std::unordered_map<LLVMListPtr *, Compiler::StaticType>> m_scopeLists;
168169

169170
std::string m_id;
170171
llvm::LLVMContext m_ctx;

0 commit comments

Comments
 (0)