Skip to content

Commit 48cf8d3

Browse files
committed
LLVMCodeBuilder: Implement variable script type prediction
1 parent c8d6923 commit 48cf8d3

File tree

4 files changed

+422
-28
lines changed

4 files changed

+422
-28
lines changed

src/dev/engine/internal/llvm/llvmcodebuilder.cpp

Lines changed: 143 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -80,16 +80,18 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
8080
for (auto &[var, varPtr] : m_variablePtrs) {
8181
llvm::Value *ptr = getVariablePtr(targetVariables, var);
8282

83-
// Access variable directly (slow?)
84-
// varPtr.ptr = ptr;
83+
// Direct access
84+
varPtr.heapPtr = ptr;
8585

86-
// All variables are currently copied to the stack and synced later (seems to be faster)
86+
// All variables are currently created on the stack and synced later (seems to be faster)
8787
// NOTE: Strings are NOT copied, only the pointer and string size are copied
88-
varPtr.ptr = m_builder.CreateAlloca(m_valueDataType);
89-
varPtr.onStack = true;
90-
createValueCopy(ptr, varPtr.ptr);
88+
varPtr.stackPtr = m_builder.CreateAlloca(m_valueDataType);
89+
varPtr.onStack = false; // use heap before the first assignment
9190
}
9291

92+
m_scopeVariables.clear();
93+
m_scopeVariables.push_back({});
94+
9395
// Execute recorded steps
9496
for (const LLVMInstruction &step : m_instructions) {
9597
switch (step.type) {
@@ -430,16 +432,41 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
430432
assert(step.args.size() == 1);
431433
assert(m_variablePtrs.find(step.workVariable) != m_variablePtrs.cend());
432434
const auto &arg = step.args[0];
435+
Compiler::StaticType type = optimizeRegisterType(arg.second);
433436
LLVMVariablePtr &varPtr = m_variablePtrs[step.workVariable];
434437
varPtr.changed = true;
435-
createValueStore(arg.second, varPtr.ptr);
438+
439+
// Initialize stack variable on first assignment
440+
if (!varPtr.onStack) {
441+
varPtr.onStack = true;
442+
varPtr.type = type; // don't care about unknown type on first assignment
443+
444+
ValueType mappedType;
445+
446+
if (type == Compiler::StaticType::String || type == Compiler::StaticType::Unknown) {
447+
// Value functions are used for these types, so don't break them
448+
mappedType = ValueType::Number;
449+
} else {
450+
auto it = std::find_if(TYPE_MAP.begin(), TYPE_MAP.end(), [type](const std::pair<ValueType, Compiler::StaticType> &pair) { return pair.second == type; });
451+
assert(it != TYPE_MAP.cend());
452+
mappedType = it->first;
453+
}
454+
455+
llvm::Value *typeField = m_builder.CreateStructGEP(m_valueDataType, varPtr.stackPtr, 1);
456+
m_builder.CreateStore(m_builder.getInt32(static_cast<uint32_t>(mappedType)), typeField);
457+
}
458+
459+
createValueStore(arg.second, varPtr.stackPtr, type, varPtr.type);
460+
varPtr.type = type;
461+
m_scopeVariables.back()[&varPtr] = varPtr.type;
436462
break;
437463
}
438464

439465
case LLVMInstruction::Type::ReadVariable: {
440466
assert(step.args.size() == 0);
441467
const LLVMVariablePtr &varPtr = m_variablePtrs[step.workVariable];
442-
step.functionReturnReg->value = varPtr.ptr;
468+
step.functionReturnReg->value = varPtr.onStack ? varPtr.stackPtr : varPtr.heapPtr;
469+
step.functionReturnReg->type = varPtr.type;
443470
break;
444471
}
445472

@@ -448,6 +475,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
448475
freeHeap();
449476
syncVariables(targetVariables);
450477
coro->createSuspend();
478+
reloadVariables(targetVariables);
451479
}
452480

453481
break;
@@ -468,13 +496,23 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
468496
m_builder.SetInsertPoint(statement.body);
469497

470498
ifStatements.push_back(statement);
499+
pushScopeLevel();
471500
break;
472501
}
473502

474503
case LLVMInstruction::Type::BeginElse: {
475504
assert(!ifStatements.empty());
476505
LLVMIfStatement &statement = ifStatements.back();
477506

507+
// Restore types from parent scope
508+
std::unordered_map<LLVMVariablePtr *, Compiler::StaticType> parentScopeVariables = m_scopeVariables[m_scopeVariables.size() - 2]; // no reference!
509+
popScopeLevel();
510+
511+
for (auto &[ptr, type] : parentScopeVariables)
512+
ptr->type = type;
513+
514+
pushScopeLevel();
515+
478516
// Jump to the branch after the if statement
479517
assert(!statement.afterIf);
480518
statement.afterIf = llvm::BasicBlock::Create(m_ctx, "", func);
@@ -516,6 +554,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
516554
m_builder.SetInsertPoint(statement.afterIf);
517555

518556
ifStatements.pop_back();
557+
popScopeLevel();
519558
break;
520559
}
521560

@@ -569,6 +608,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
569608
m_builder.SetInsertPoint(body);
570609

571610
loops.push_back(loop);
611+
pushScopeLevel();
572612
break;
573613
}
574614

@@ -590,6 +630,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
590630

591631
// Switch to body branch
592632
m_builder.SetInsertPoint(body);
633+
pushScopeLevel();
593634
break;
594635
}
595636

@@ -611,6 +652,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
611652

612653
// Switch to body branch
613654
m_builder.SetInsertPoint(body);
655+
pushScopeLevel();
614656
break;
615657
}
616658

@@ -644,6 +686,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
644686
m_builder.SetInsertPoint(loop.afterLoop);
645687

646688
loops.pop_back();
689+
popScopeLevel();
647690
break;
648691
}
649692
}
@@ -728,7 +771,6 @@ void LLVMCodeBuilder::addConstValue(const Value &value)
728771

729772
void LLVMCodeBuilder::addVariableValue(Variable *variable)
730773
{
731-
// TODO: Implement type prediction
732774
LLVMInstruction ins(LLVMInstruction::Type::ReadVariable);
733775
ins.workVariable = variable;
734776
m_variablePtrs[variable] = LLVMVariablePtr();
@@ -989,6 +1031,23 @@ void LLVMCodeBuilder::createVariableMap()
9891031
}
9901032
}
9911033

1034+
void LLVMCodeBuilder::pushScopeLevel()
1035+
{
1036+
m_scopeVariables.push_back({});
1037+
}
1038+
1039+
void LLVMCodeBuilder::popScopeLevel()
1040+
{
1041+
for (size_t i = 0; i < m_scopeVariables.size() - 1; i++) {
1042+
for (auto &[ptr, type] : m_scopeVariables[i]) {
1043+
if (ptr->type != type)
1044+
ptr->type = Compiler::StaticType::Unknown;
1045+
}
1046+
}
1047+
1048+
m_scopeVariables.pop_back();
1049+
}
1050+
9921051
void LLVMCodeBuilder::verifyFunction(llvm::Function *func)
9931052
{
9941053
if (llvm::verifyFunction(*func, &llvm::errs())) {
@@ -1198,6 +1257,17 @@ llvm::Constant *LLVMCodeBuilder::castConstValue(const Value &value, Compiler::St
11981257
}
11991258
}
12001259

1260+
Compiler::StaticType LLVMCodeBuilder::optimizeRegisterType(LLVMRegisterPtr reg)
1261+
{
1262+
Compiler::StaticType ret = reg->type;
1263+
1264+
// Optimize string constants that represent numbers
1265+
if (reg->isConstValue && reg->type == Compiler::StaticType::String && reg->constValue.isValidNumber())
1266+
ret = Compiler::StaticType::Number;
1267+
1268+
return ret;
1269+
}
1270+
12011271
llvm::Type *LLVMCodeBuilder::getType(Compiler::StaticType type)
12021272
{
12031273
switch (type) {
@@ -1247,14 +1317,25 @@ llvm::Value *LLVMCodeBuilder::getVariablePtr(llvm::Value *targetVariables, Varia
12471317

12481318
void LLVMCodeBuilder::syncVariables(llvm::Value *targetVariables)
12491319
{
1320+
// Copy stack variables to the actual variables
12501321
for (auto &[var, varPtr] : m_variablePtrs) {
12511322
if (varPtr.onStack && varPtr.changed)
1252-
createValueCopy(varPtr.ptr, getVariablePtr(targetVariables, var));
1323+
createValueCopy(varPtr.stackPtr, getVariablePtr(targetVariables, var));
12531324

12541325
varPtr.changed = false;
12551326
}
12561327
}
12571328

1329+
void LLVMCodeBuilder::reloadVariables(llvm::Value *targetVariables)
1330+
{
1331+
// Reset variables to use heap
1332+
for (auto &[var, varPtr] : m_variablePtrs) {
1333+
varPtr.onStack = false;
1334+
varPtr.changed = false;
1335+
varPtr.type = Compiler::StaticType::Unknown;
1336+
}
1337+
}
1338+
12581339
LLVMInstruction &LLVMCodeBuilder::createOp(LLVMInstruction::Type type, Compiler::StaticType retType, Compiler::StaticType argType, size_t argCount)
12591340
{
12601341
LLVMInstruction ins(type);
@@ -1279,33 +1360,69 @@ LLVMInstruction &LLVMCodeBuilder::createOp(LLVMInstruction::Type type, Compiler:
12791360
return m_instructions.back();
12801361
}
12811362

1282-
void LLVMCodeBuilder::createValueStore(LLVMRegisterPtr reg, llvm::Value *targetPtr)
1363+
void LLVMCodeBuilder::createValueStore(LLVMRegisterPtr reg, llvm::Value *targetPtr, Compiler::StaticType sourceType, Compiler::StaticType targetType)
12831364
{
1284-
// TODO: Implement type prediction
1285-
Compiler::StaticType type = reg->type;
12861365
llvm::Value *converted = nullptr;
12871366

1288-
// Optimize string constants that represent numbers
1289-
if (reg->isConstValue && reg->type == Compiler::StaticType::String && reg->constValue.isValidNumber())
1290-
type = Compiler::StaticType::Number;
1367+
if (sourceType != Compiler::StaticType::Unknown)
1368+
converted = castValue(reg, sourceType);
12911369

1292-
switch (type) {
1370+
auto it = std::find_if(TYPE_MAP.begin(), TYPE_MAP.end(), [sourceType](const std::pair<ValueType, Compiler::StaticType> &pair) { return pair.second == sourceType; });
1371+
const ValueType mappedType = it == TYPE_MAP.cend() ? ValueType::Number : it->first; // unknown type can be ignored
1372+
1373+
switch (sourceType) {
12931374
case Compiler::StaticType::Number:
1294-
converted = castValue(reg, type);
1295-
m_builder.CreateCall(resolve_value_assign_double(), { targetPtr, converted });
1296-
/*{
1297-
llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, targetPtr, 0);
1298-
m_builder.CreateStore(converted, ptr);
1299-
}*/
1375+
switch (targetType) {
1376+
case Compiler::StaticType::Number: {
1377+
// Write number to number directly
1378+
llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, targetPtr, 0);
1379+
m_builder.CreateStore(converted, ptr);
1380+
break;
1381+
}
1382+
1383+
case Compiler::StaticType::Bool: {
1384+
// Write number to bool value directly and change type
1385+
llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, targetPtr, 0);
1386+
llvm::Value *typePtr = m_builder.CreateStructGEP(m_valueDataType, targetPtr, 0);
1387+
m_builder.CreateStore(converted, ptr);
1388+
m_builder.CreateStore(m_builder.getInt32(static_cast<uint32_t>(mappedType)), typePtr);
1389+
break;
1390+
}
1391+
1392+
default:
1393+
m_builder.CreateCall(resolve_value_assign_double(), { targetPtr, converted });
1394+
break;
1395+
}
1396+
13001397
break;
13011398

13021399
case Compiler::StaticType::Bool:
1303-
converted = castValue(reg, type);
1304-
m_builder.CreateCall(resolve_value_assign_bool(), { targetPtr, converted });
1400+
switch (targetType) {
1401+
case Compiler::StaticType::Number: {
1402+
// Write bool to number value directly and change type
1403+
llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, targetPtr, 0);
1404+
m_builder.CreateStore(converted, ptr);
1405+
llvm::Value *typePtr = m_builder.CreateStructGEP(m_valueDataType, targetPtr, 0);
1406+
m_builder.CreateStore(converted, ptr);
1407+
m_builder.CreateStore(m_builder.getInt32(static_cast<uint32_t>(mappedType)), typePtr);
1408+
break;
1409+
}
1410+
1411+
case Compiler::StaticType::Bool: {
1412+
// Write bool to bool directly
1413+
llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, targetPtr, 0);
1414+
m_builder.CreateStore(converted, ptr);
1415+
break;
1416+
}
1417+
1418+
default:
1419+
m_builder.CreateCall(resolve_value_assign_bool(), { targetPtr, converted });
1420+
break;
1421+
}
1422+
13051423
break;
13061424

13071425
case Compiler::StaticType::String:
1308-
converted = castValue(reg, type);
13091426
m_builder.CreateCall(resolve_value_assign_cstring(), { targetPtr, converted });
13101427
break;
13111428

src/dev/engine/internal/llvm/llvmcodebuilder.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ class LLVMCodeBuilder : public ICodeBuilder
8383

8484
void initTypes();
8585
void createVariableMap();
86+
void pushScopeLevel();
87+
void popScopeLevel();
8688

8789
void verifyFunction(llvm::Function *func);
8890
void optimize();
@@ -91,16 +93,18 @@ class LLVMCodeBuilder : public ICodeBuilder
9193
llvm::Value *castValue(LLVMRegisterPtr reg, Compiler::StaticType targetType);
9294
llvm::Value *castRawValue(LLVMRegisterPtr reg, Compiler::StaticType targetType);
9395
llvm::Constant *castConstValue(const Value &value, Compiler::StaticType targetType);
96+
Compiler::StaticType optimizeRegisterType(LLVMRegisterPtr reg);
9497
llvm::Type *getType(Compiler::StaticType type);
9598
llvm::Value *isNaN(llvm::Value *num);
9699
llvm::Value *removeNaN(llvm::Value *num);
97100

98101
llvm::Value *getVariablePtr(llvm::Value *targetVariables, Variable *variable);
99102
void syncVariables(llvm::Value *targetVariables);
103+
void reloadVariables(llvm::Value *targetVariables);
100104

101105
LLVMInstruction &createOp(LLVMInstruction::Type type, Compiler::StaticType retType, Compiler::StaticType argType, size_t argCount);
102106

103-
void createValueStore(LLVMRegisterPtr reg, llvm::Value *targetPtr);
107+
void createValueStore(LLVMRegisterPtr reg, llvm::Value *targetPtr, Compiler::StaticType sourceType, Compiler::StaticType targetType);
104108
void createValueCopy(llvm::Value *source, llvm::Value *target);
105109
void copyStructField(llvm::Value *source, llvm::Value *target, int index, llvm::StructType *structType, llvm::Type *fieldType);
106110
llvm::Value *createValue(LLVMRegisterPtr reg);
@@ -130,6 +134,7 @@ class LLVMCodeBuilder : public ICodeBuilder
130134
Target *m_target = nullptr;
131135
std::unordered_map<Variable *, size_t> m_targetVariableMap;
132136
std::unordered_map<Variable *, LLVMVariablePtr> m_variablePtrs;
137+
std::vector<std::unordered_map<LLVMVariablePtr *, Compiler::StaticType>> m_scopeVariables;
133138

134139
std::string m_id;
135140
llvm::LLVMContext m_ctx;

src/dev/engine/internal/llvm/llvmvariableptr.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ namespace libscratchcpp
1616

1717
struct LLVMVariablePtr
1818
{
19-
llvm::Value *ptr = nullptr;
19+
llvm::Value *stackPtr = nullptr;
20+
llvm::Value *heapPtr = nullptr;
2021
Compiler::StaticType type = Compiler::StaticType::Unknown;
2122
bool onStack = false;
2223
bool changed = false;

0 commit comments

Comments
 (0)