Skip to content

Commit fedc9e5

Browse files
authored
Merge pull request #599 from scratchcpp/llvm_variable_script_type_prediction
LLVM: Implement variable type prediction in scripts
2 parents d8dc2bb + 48cf8d3 commit fedc9e5

File tree

5 files changed (+425 −33 lines changed)

src/dev/engine/internal/llvm/llvmcodebuilder.cpp

Lines changed: 146 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -80,16 +80,18 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
8080
for (auto &[var, varPtr] : m_variablePtrs) {
8181
llvm::Value *ptr = getVariablePtr(targetVariables, var);
8282

83-
// Access variable directly (slow?)
84-
// varPtr.ptr = ptr;
83+
// Direct access
84+
varPtr.heapPtr = ptr;
8585

86-
// All variables are currently copied to the stack and synced later (seems to be faster)
86+
// All variables are currently created on the stack and synced later (seems to be faster)
8787
// NOTE: Strings are NOT copied, only the pointer and string size are copied
88-
varPtr.ptr = m_builder.CreateAlloca(m_valueDataType);
89-
varPtr.onStack = true;
90-
createValueCopy(ptr, varPtr.ptr);
88+
varPtr.stackPtr = m_builder.CreateAlloca(m_valueDataType);
89+
varPtr.onStack = false; // use heap before the first assignment
9190
}
9291

92+
m_scopeVariables.clear();
93+
m_scopeVariables.push_back({});
94+
9395
// Execute recorded steps
9496
for (const LLVMInstruction &step : m_instructions) {
9597
switch (step.type) {
@@ -107,12 +109,13 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
107109
args.push_back(castValue(arg.second, arg.first));
108110
}
109111

110-
llvm::Value *ret = m_builder.CreateCall(resolveFunction(step.functionName, llvm::FunctionType::get(getType(step.functionReturnType), types, false)), args);
112+
llvm::Type *retType = getType(step.functionReturnReg ? step.functionReturnReg->type : Compiler::StaticType::Void);
113+
llvm::Value *ret = m_builder.CreateCall(resolveFunction(step.functionName, llvm::FunctionType::get(retType, types, false)), args);
111114

112115
if (step.functionReturnReg) {
113116
step.functionReturnReg->value = ret;
114117

115-
if (step.functionReturnType == Compiler::StaticType::String)
118+
if (step.functionReturnReg->type == Compiler::StaticType::String)
116119
m_heap.push_back(step.functionReturnReg->value);
117120
}
118121

@@ -429,16 +432,41 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
429432
assert(step.args.size() == 1);
430433
assert(m_variablePtrs.find(step.workVariable) != m_variablePtrs.cend());
431434
const auto &arg = step.args[0];
435+
Compiler::StaticType type = optimizeRegisterType(arg.second);
432436
LLVMVariablePtr &varPtr = m_variablePtrs[step.workVariable];
433437
varPtr.changed = true;
434-
createValueStore(arg.second, varPtr.ptr);
438+
439+
// Initialize stack variable on first assignment
440+
if (!varPtr.onStack) {
441+
varPtr.onStack = true;
442+
varPtr.type = type; // don't care about unknown type on first assignment
443+
444+
ValueType mappedType;
445+
446+
if (type == Compiler::StaticType::String || type == Compiler::StaticType::Unknown) {
447+
// Value functions are used for these types, so don't break them
448+
mappedType = ValueType::Number;
449+
} else {
450+
auto it = std::find_if(TYPE_MAP.begin(), TYPE_MAP.end(), [type](const std::pair<ValueType, Compiler::StaticType> &pair) { return pair.second == type; });
451+
assert(it != TYPE_MAP.cend());
452+
mappedType = it->first;
453+
}
454+
455+
llvm::Value *typeField = m_builder.CreateStructGEP(m_valueDataType, varPtr.stackPtr, 1);
456+
m_builder.CreateStore(m_builder.getInt32(static_cast<uint32_t>(mappedType)), typeField);
457+
}
458+
459+
createValueStore(arg.second, varPtr.stackPtr, type, varPtr.type);
460+
varPtr.type = type;
461+
m_scopeVariables.back()[&varPtr] = varPtr.type;
435462
break;
436463
}
437464

438465
case LLVMInstruction::Type::ReadVariable: {
439466
assert(step.args.size() == 0);
440467
const LLVMVariablePtr &varPtr = m_variablePtrs[step.workVariable];
441-
step.functionReturnReg->value = varPtr.ptr;
468+
step.functionReturnReg->value = varPtr.onStack ? varPtr.stackPtr : varPtr.heapPtr;
469+
step.functionReturnReg->type = varPtr.type;
442470
break;
443471
}
444472

@@ -447,6 +475,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
447475
freeHeap();
448476
syncVariables(targetVariables);
449477
coro->createSuspend();
478+
reloadVariables(targetVariables);
450479
}
451480

452481
break;
@@ -467,13 +496,23 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
467496
m_builder.SetInsertPoint(statement.body);
468497

469498
ifStatements.push_back(statement);
499+
pushScopeLevel();
470500
break;
471501
}
472502

473503
case LLVMInstruction::Type::BeginElse: {
474504
assert(!ifStatements.empty());
475505
LLVMIfStatement &statement = ifStatements.back();
476506

507+
// Restore types from parent scope
508+
std::unordered_map<LLVMVariablePtr *, Compiler::StaticType> parentScopeVariables = m_scopeVariables[m_scopeVariables.size() - 2]; // no reference!
509+
popScopeLevel();
510+
511+
for (auto &[ptr, type] : parentScopeVariables)
512+
ptr->type = type;
513+
514+
pushScopeLevel();
515+
477516
// Jump to the branch after the if statement
478517
assert(!statement.afterIf);
479518
statement.afterIf = llvm::BasicBlock::Create(m_ctx, "", func);
@@ -515,6 +554,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
515554
m_builder.SetInsertPoint(statement.afterIf);
516555

517556
ifStatements.pop_back();
557+
popScopeLevel();
518558
break;
519559
}
520560

@@ -568,6 +608,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
568608
m_builder.SetInsertPoint(body);
569609

570610
loops.push_back(loop);
611+
pushScopeLevel();
571612
break;
572613
}
573614

@@ -589,6 +630,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
589630

590631
// Switch to body branch
591632
m_builder.SetInsertPoint(body);
633+
pushScopeLevel();
592634
break;
593635
}
594636

@@ -610,6 +652,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
610652

611653
// Switch to body branch
612654
m_builder.SetInsertPoint(body);
655+
pushScopeLevel();
613656
break;
614657
}
615658

@@ -643,6 +686,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
643686
m_builder.SetInsertPoint(loop.afterLoop);
644687

645688
loops.pop_back();
689+
popScopeLevel();
646690
break;
647691
}
648692
}
@@ -705,8 +749,6 @@ void LLVMCodeBuilder::addFunctionCall(const std::string &functionName, Compiler:
705749

706750
m_tmpRegs.erase(m_tmpRegs.end() - argTypes.size(), m_tmpRegs.end());
707751

708-
ins.functionReturnType = returnType;
709-
710752
if (returnType != Compiler::StaticType::Void) {
711753
auto reg = std::make_shared<LLVMRegister>(returnType);
712754
reg->isRawValue = true;
@@ -729,7 +771,6 @@ void LLVMCodeBuilder::addConstValue(const Value &value)
729771

730772
void LLVMCodeBuilder::addVariableValue(Variable *variable)
731773
{
732-
// TODO: Implement type prediction
733774
LLVMInstruction ins(LLVMInstruction::Type::ReadVariable);
734775
ins.workVariable = variable;
735776
m_variablePtrs[variable] = LLVMVariablePtr();
@@ -990,6 +1031,23 @@ void LLVMCodeBuilder::createVariableMap()
9901031
}
9911032
}
9921033

1034+
void LLVMCodeBuilder::pushScopeLevel()
1035+
{
1036+
m_scopeVariables.push_back({});
1037+
}
1038+
1039+
void LLVMCodeBuilder::popScopeLevel()
1040+
{
1041+
for (size_t i = 0; i < m_scopeVariables.size() - 1; i++) {
1042+
for (auto &[ptr, type] : m_scopeVariables[i]) {
1043+
if (ptr->type != type)
1044+
ptr->type = Compiler::StaticType::Unknown;
1045+
}
1046+
}
1047+
1048+
m_scopeVariables.pop_back();
1049+
}
1050+
9931051
void LLVMCodeBuilder::verifyFunction(llvm::Function *func)
9941052
{
9951053
if (llvm::verifyFunction(*func, &llvm::errs())) {
@@ -1199,6 +1257,17 @@ llvm::Constant *LLVMCodeBuilder::castConstValue(const Value &value, Compiler::St
11991257
}
12001258
}
12011259

1260+
// Returns the static type that should be used for the given register.
// A constant string register whose value is a valid number is reported as Number;
// every other register is reported with its own type.
Compiler::StaticType LLVMCodeBuilder::optimizeRegisterType(LLVMRegisterPtr reg)
{
    // String constants that represent numbers can be handled as numbers
    const bool isNumericStringConstant = reg->isConstValue && reg->type == Compiler::StaticType::String && reg->constValue.isValidNumber();

    if (isNumericStringConstant)
        return Compiler::StaticType::Number;

    return reg->type;
}
1270+
12021271
llvm::Type *LLVMCodeBuilder::getType(Compiler::StaticType type)
12031272
{
12041273
switch (type) {
@@ -1248,14 +1317,25 @@ llvm::Value *LLVMCodeBuilder::getVariablePtr(llvm::Value *targetVariables, Varia
12481317

12491318
// Writes modified stack copies back to the actual variables and clears the "changed" flags.
// targetVariables is forwarded to getVariablePtr() to locate each variable's real storage.
void LLVMCodeBuilder::syncVariables(llvm::Value *targetVariables)
{
    for (auto &entry : m_variablePtrs) {
        Variable *variable = entry.first;
        LLVMVariablePtr &state = entry.second;

        // Only variables that live on the stack and were written need to be copied back
        const bool needsWriteBack = state.onStack && state.changed;

        if (needsWriteBack)
            createValueCopy(state.stackPtr, getVariablePtr(targetVariables, variable));

        // The flag is cleared for every variable, whether or not it was copied
        state.changed = false;
    }
}
12581328

1329+
// Resets the tracking state of every known variable: the variable is no longer considered
// to be on the stack, has no pending changes and its predicted type becomes Unknown.
// The targetVariables parameter is not used by this function.
void LLVMCodeBuilder::reloadVariables(llvm::Value *targetVariables)
{
    for (auto &entry : m_variablePtrs) {
        LLVMVariablePtr &state = entry.second;

        // Fall back to the heap pointer until the next assignment
        state.onStack = false;
        state.changed = false;
        state.type = Compiler::StaticType::Unknown;
    }
}
1338+
12591339
LLVMInstruction &LLVMCodeBuilder::createOp(LLVMInstruction::Type type, Compiler::StaticType retType, Compiler::StaticType argType, size_t argCount)
12601340
{
12611341
LLVMInstruction ins(type);
@@ -1280,33 +1360,69 @@ LLVMInstruction &LLVMCodeBuilder::createOp(LLVMInstruction::Type type, Compiler:
12801360
return m_instructions.back();
12811361
}
12821362

1283-
void LLVMCodeBuilder::createValueStore(LLVMRegisterPtr reg, llvm::Value *targetPtr)
1363+
void LLVMCodeBuilder::createValueStore(LLVMRegisterPtr reg, llvm::Value *targetPtr, Compiler::StaticType sourceType, Compiler::StaticType targetType)
12841364
{
1285-
// TODO: Implement type prediction
1286-
Compiler::StaticType type = reg->type;
12871365
llvm::Value *converted = nullptr;
12881366

1289-
// Optimize string constants that represent numbers
1290-
if (reg->isConstValue && reg->type == Compiler::StaticType::String && reg->constValue.isValidNumber())
1291-
type = Compiler::StaticType::Number;
1367+
if (sourceType != Compiler::StaticType::Unknown)
1368+
converted = castValue(reg, sourceType);
12921369

1293-
switch (type) {
1370+
auto it = std::find_if(TYPE_MAP.begin(), TYPE_MAP.end(), [sourceType](const std::pair<ValueType, Compiler::StaticType> &pair) { return pair.second == sourceType; });
1371+
const ValueType mappedType = it == TYPE_MAP.cend() ? ValueType::Number : it->first; // unknown type can be ignored
1372+
1373+
switch (sourceType) {
12941374
case Compiler::StaticType::Number:
1295-
converted = castValue(reg, type);
1296-
m_builder.CreateCall(resolve_value_assign_double(), { targetPtr, converted });
1297-
/*{
1298-
llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, targetPtr, 0);
1299-
m_builder.CreateStore(converted, ptr);
1300-
}*/
1375+
switch (targetType) {
1376+
case Compiler::StaticType::Number: {
1377+
// Write number to number directly
1378+
llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, targetPtr, 0);
1379+
m_builder.CreateStore(converted, ptr);
1380+
break;
1381+
}
1382+
1383+
case Compiler::StaticType::Bool: {
1384+
// Write number to bool value directly and change type
1385+
llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, targetPtr, 0);
1386+
llvm::Value *typePtr = m_builder.CreateStructGEP(m_valueDataType, targetPtr, 0);
1387+
m_builder.CreateStore(converted, ptr);
1388+
m_builder.CreateStore(m_builder.getInt32(static_cast<uint32_t>(mappedType)), typePtr);
1389+
break;
1390+
}
1391+
1392+
default:
1393+
m_builder.CreateCall(resolve_value_assign_double(), { targetPtr, converted });
1394+
break;
1395+
}
1396+
13011397
break;
13021398

13031399
case Compiler::StaticType::Bool:
1304-
converted = castValue(reg, type);
1305-
m_builder.CreateCall(resolve_value_assign_bool(), { targetPtr, converted });
1400+
switch (targetType) {
1401+
case Compiler::StaticType::Number: {
    // Bool source, target currently holds a Number:
    // write the bool into the value field and update the type field.
    // Field 0 holds the value; field 1 holds the type (the first-assignment code in
    // finalize() writes the type through index 1). The value is stored exactly once and
    // the type is written through index 1 so it does not overwrite the value.
    llvm::Value *valuePtr = m_builder.CreateStructGEP(m_valueDataType, targetPtr, 0);
    llvm::Value *typePtr = m_builder.CreateStructGEP(m_valueDataType, targetPtr, 1);
    m_builder.CreateStore(converted, valuePtr);
    m_builder.CreateStore(m_builder.getInt32(static_cast<uint32_t>(mappedType)), typePtr);
    break;
}
1410+
1411+
case Compiler::StaticType::Bool: {
1412+
// Write bool to bool directly
1413+
llvm::Value *ptr = m_builder.CreateStructGEP(m_valueDataType, targetPtr, 0);
1414+
m_builder.CreateStore(converted, ptr);
1415+
break;
1416+
}
1417+
1418+
default:
1419+
m_builder.CreateCall(resolve_value_assign_bool(), { targetPtr, converted });
1420+
break;
1421+
}
1422+
13061423
break;
13071424

13081425
case Compiler::StaticType::String:
1309-
converted = castValue(reg, type);
13101426
m_builder.CreateCall(resolve_value_assign_cstring(), { targetPtr, converted });
13111427
break;
13121428

src/dev/engine/internal/llvm/llvmcodebuilder.h

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -83,6 +83,8 @@ class LLVMCodeBuilder : public ICodeBuilder
8383

8484
void initTypes();
8585
void createVariableMap();
86+
void pushScopeLevel();
87+
void popScopeLevel();
8688

8789
void verifyFunction(llvm::Function *func);
8890
void optimize();
@@ -91,16 +93,18 @@ class LLVMCodeBuilder : public ICodeBuilder
9193
llvm::Value *castValue(LLVMRegisterPtr reg, Compiler::StaticType targetType);
9294
llvm::Value *castRawValue(LLVMRegisterPtr reg, Compiler::StaticType targetType);
9395
llvm::Constant *castConstValue(const Value &value, Compiler::StaticType targetType);
96+
Compiler::StaticType optimizeRegisterType(LLVMRegisterPtr reg);
9497
llvm::Type *getType(Compiler::StaticType type);
9598
llvm::Value *isNaN(llvm::Value *num);
9699
llvm::Value *removeNaN(llvm::Value *num);
97100

98101
llvm::Value *getVariablePtr(llvm::Value *targetVariables, Variable *variable);
99102
void syncVariables(llvm::Value *targetVariables);
103+
void reloadVariables(llvm::Value *targetVariables);
100104

101105
LLVMInstruction &createOp(LLVMInstruction::Type type, Compiler::StaticType retType, Compiler::StaticType argType, size_t argCount);
102106

103-
void createValueStore(LLVMRegisterPtr reg, llvm::Value *targetPtr);
107+
void createValueStore(LLVMRegisterPtr reg, llvm::Value *targetPtr, Compiler::StaticType sourceType, Compiler::StaticType targetType);
104108
void createValueCopy(llvm::Value *source, llvm::Value *target);
105109
void copyStructField(llvm::Value *source, llvm::Value *target, int index, llvm::StructType *structType, llvm::Type *fieldType);
106110
llvm::Value *createValue(LLVMRegisterPtr reg);
@@ -130,6 +134,7 @@ class LLVMCodeBuilder : public ICodeBuilder
130134
Target *m_target = nullptr;
131135
std::unordered_map<Variable *, size_t> m_targetVariableMap;
132136
std::unordered_map<Variable *, LLVMVariablePtr> m_variablePtrs;
137+
std::vector<std::unordered_map<LLVMVariablePtr *, Compiler::StaticType>> m_scopeVariables;
133138

134139
std::string m_id;
135140
llvm::LLVMContext m_ctx;

src/dev/engine/internal/llvm/llvminstruction.h

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,6 @@ struct LLVMInstruction
6161
Type type;
6262
std::string functionName;
6363
std::vector<std::pair<Compiler::StaticType, LLVMRegisterPtr>> args; // target type, register
64-
Compiler::StaticType functionReturnType = Compiler::StaticType::Void;
6564
LLVMRegisterPtr functionReturnReg;
6665
Variable *workVariable = nullptr; // for variables
6766
};

src/dev/engine/internal/llvm/llvmvariableptr.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,8 @@ namespace libscratchcpp
1616

1717
struct LLVMVariablePtr
1818
{
19-
llvm::Value *ptr = nullptr;
19+
llvm::Value *stackPtr = nullptr;
20+
llvm::Value *heapPtr = nullptr;
2021
Compiler::StaticType type = Compiler::StaticType::Unknown;
2122
bool onStack = false;
2223
bool changed = false;

0 commit comments

Comments
 (0)