Skip to content

Commit 953b0d9

Browse files
committed
LLVMCodeBuilder: Add list range checks
1 parent a3de164 commit 953b0d9

File tree

2 files changed

+160
-20
lines changed

2 files changed

+160
-20
lines changed

src/dev/engine/internal/llvm/llvmcodebuilder.cpp

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -660,9 +660,25 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
660660
assert(step.args.size() == 1);
661661
const auto &arg = step.args[0];
662662
const LLVMListPtr &listPtr = m_listPtrs[step.workList];
663-
llvm::Value *index = m_builder.CreateFPToUI(castValue(arg.second, arg.first), m_builder.getInt64Ty());
663+
664+
// Range check
665+
llvm::Value *min = llvm::ConstantFP::get(m_llvmCtx, llvm::APFloat(0.0));
666+
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
667+
size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy());
668+
llvm::Value *index = castValue(arg.second, arg.first);
669+
llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateFCmpOGE(index, min), m_builder.CreateFCmpOLT(index, size));
670+
llvm::BasicBlock *removeBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
671+
llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
672+
m_builder.CreateCondBr(inRange, removeBlock, nextBlock);
673+
674+
// Remove
675+
m_builder.SetInsertPoint(removeBlock);
676+
index = m_builder.CreateFPToUI(castValue(arg.second, arg.first), m_builder.getInt64Ty());
664677
m_builder.CreateCall(resolve_list_remove(), { listPtr.ptr, index });
665678
// NOTE: Removing doesn't deallocate (see List::removeAt()), so there's no need to update the data pointer
679+
m_builder.CreateBr(nextBlock);
680+
681+
m_builder.SetInsertPoint(nextBlock);
666682
break;
667683
}
668684

@@ -733,11 +749,23 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
733749
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
734750
m_builder.CreateStore(m_builder.CreateOr(dataPtrDirty, m_builder.CreateICmpEQ(allocatedSize, size)), listPtr.dataPtrDirty);
735751

752+
// Range check
753+
llvm::Value *min = llvm::ConstantFP::get(m_llvmCtx, llvm::APFloat(0.0));
754+
size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy());
755+
llvm::Value *index = castValue(indexArg.second, indexArg.first);
756+
llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateFCmpOGE(index, min), m_builder.CreateFCmpOLE(index, size));
757+
llvm::BasicBlock *insertBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
758+
llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
759+
m_builder.CreateCondBr(inRange, insertBlock, nextBlock);
760+
736761
// Insert
737-
llvm::Value *index = m_builder.CreateFPToUI(castValue(indexArg.second, indexArg.first), m_builder.getInt64Ty());
762+
m_builder.SetInsertPoint(insertBlock);
763+
index = m_builder.CreateFPToUI(index, m_builder.getInt64Ty());
738764
llvm::Value *itemPtr = m_builder.CreateCall(resolve_list_insert_empty(), { listPtr.ptr, index });
739765
createReusedValueStore(valueArg.second, itemPtr, type, listPtr.type);
766+
m_builder.CreateBr(nextBlock);
740767

768+
m_builder.SetInsertPoint(nextBlock);
741769
break;
742770
}
743771

@@ -747,9 +775,23 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
747775
const auto &valueArg = step.args[1];
748776
Compiler::StaticType type = optimizeRegisterType(valueArg.second);
749777
LLVMListPtr &listPtr = m_listPtrs[step.workList];
750-
llvm::Value *index = m_builder.CreateFPToUI(castValue(indexArg.second, indexArg.first), m_builder.getInt64Ty());
778+
779+
// Range check
780+
llvm::Value *min = llvm::ConstantFP::get(m_llvmCtx, llvm::APFloat(0.0));
781+
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
782+
size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy());
783+
llvm::Value *index = castValue(indexArg.second, indexArg.first);
784+
llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateFCmpOGE(index, min), m_builder.CreateFCmpOLT(index, size));
785+
llvm::BasicBlock *replaceBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
786+
llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
787+
m_builder.CreateCondBr(inRange, replaceBlock, nextBlock);
788+
789+
// Replace
790+
m_builder.SetInsertPoint(replaceBlock);
791+
index = m_builder.CreateFPToUI(index, m_builder.getInt64Ty());
751792
llvm::Value *itemPtr = getListItem(listPtr, index);
752793
createValueStore(valueArg.second, itemPtr, type, listPtr.type);
794+
m_builder.CreateBr(nextBlock);
753795

754796
auto &typeMap = m_scopeLists.back();
755797

@@ -761,6 +803,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
761803
typeMap[&listPtr] = listPtr.type;
762804
}
763805

806+
m_builder.SetInsertPoint(nextBlock);
764807
break;
765808
}
766809

@@ -777,8 +820,18 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
777820
assert(step.args.size() == 1);
778821
const auto &arg = step.args[0];
779822
const LLVMListPtr &listPtr = m_listPtrs[step.workList];
780-
llvm::Value *index = m_builder.CreateFPToUI(castValue(arg.second, arg.first), m_builder.getInt64Ty());
781-
step.functionReturnReg->value = getListItem(listPtr, index);
823+
824+
llvm::Value *min = llvm::ConstantFP::get(m_llvmCtx, llvm::APFloat(0.0));
825+
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
826+
size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy());
827+
llvm::Value *index = castValue(arg.second, arg.first);
828+
llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateFCmpOGE(index, min), m_builder.CreateFCmpOLT(index, size));
829+
830+
LLVMConstantRegister nullReg(listPtr.type == Compiler::StaticType::Unknown ? Compiler::StaticType::Number : listPtr.type, Value());
831+
llvm::Value *null = createValue(static_cast<LLVMRegister *>(static_cast<CompilerValue *>(&nullReg)));
832+
833+
index = m_builder.CreateFPToUI(index, m_builder.getInt64Ty());
834+
step.functionReturnReg->value = m_builder.CreateSelect(inRange, getListItem(listPtr, index), null);
782835
step.functionReturnReg->setType(listPtr.type);
783836
break;
784837
}

test/dev/llvm/llvmcodebuilder_test.cpp

Lines changed: 102 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1913,9 +1913,21 @@ TEST_F(LLVMCodeBuilderTest, RemoveFromList)
19131913
CompilerValue *v = m_builder->addConstValue(1);
19141914
m_builder->createListRemove(globalList.get(), v);
19151915

1916+
v = m_builder->addConstValue(-1);
1917+
m_builder->createListRemove(globalList.get(), v);
1918+
1919+
v = m_builder->addConstValue(3);
1920+
m_builder->createListRemove(globalList.get(), v);
1921+
19161922
v = m_builder->addConstValue(3);
19171923
m_builder->createListRemove(localList.get(), v);
19181924

1925+
v = m_builder->addConstValue(-1);
1926+
m_builder->createListRemove(localList.get(), v);
1927+
1928+
v = m_builder->addConstValue(4);
1929+
m_builder->createListRemove(localList.get(), v);
1930+
19191931
auto code = m_builder->finalize();
19201932
Script script(&sprite, nullptr, nullptr);
19211933
script.setCode(code);
@@ -2040,6 +2052,18 @@ TEST_F(LLVMCodeBuilderTest, InsertToList)
20402052
v2 = m_builder->addConstValue("hello world");
20412053
m_builder->createListInsert(localList.get(), v1, v2);
20422054

2055+
v1 = m_builder->addConstValue(3);
2056+
v2 = m_builder->addConstValue("test");
2057+
m_builder->createListInsert(localList.get(), v1, v2);
2058+
2059+
v1 = m_builder->addConstValue(-1);
2060+
v2 = m_builder->addConstValue(123);
2061+
m_builder->createListInsert(localList.get(), v1, v2);
2062+
2063+
v1 = m_builder->addConstValue(6);
2064+
v2 = m_builder->addConstValue(123);
2065+
m_builder->createListInsert(localList.get(), v1, v2);
2066+
20432067
auto code = m_builder->finalize();
20442068
Script script(&sprite, nullptr, nullptr);
20452069
script.setCode(code);
@@ -2049,7 +2073,7 @@ TEST_F(LLVMCodeBuilderTest, InsertToList)
20492073
code->run(ctx.get());
20502074

20512075
ASSERT_EQ(globalList->toString(), "1 2 1 test 3");
2052-
ASSERT_EQ(localList->toString(), "false hello world true");
2076+
ASSERT_EQ(localList->toString(), "false hello world true test");
20532077
}
20542078

20552079
TEST_F(LLVMCodeBuilderTest, ListReplace)
@@ -2099,6 +2123,14 @@ TEST_F(LLVMCodeBuilderTest, ListReplace)
20992123
v2 = m_builder->addConstValue("hello world");
21002124
m_builder->createListReplace(localList.get(), v1, v2);
21012125

2126+
v1 = m_builder->addConstValue(-1);
2127+
v2 = m_builder->addConstValue(123);
2128+
m_builder->createListReplace(localList.get(), v1, v2);
2129+
2130+
v1 = m_builder->addConstValue(5);
2131+
v2 = m_builder->addConstValue(123);
2132+
m_builder->createListReplace(localList.get(), v1, v2);
2133+
21022134
auto code = m_builder->finalize();
21032135
Script script(&sprite, nullptr, nullptr);
21042136
script.setCode(code);
@@ -2169,26 +2201,36 @@ TEST_F(LLVMCodeBuilderTest, GetListItem)
21692201
sprite.setEngine(&m_engine);
21702202
EXPECT_CALL(m_engine, stage()).WillRepeatedly(Return(&stage));
21712203

2172-
std::unordered_map<List *, std::string> strings;
2173-
21742204
auto globalList = std::make_shared<List>("", "");
21752205
stage.addList(globalList);
21762206

2177-
auto localList = std::make_shared<List>("", "");
2178-
sprite.addList(localList);
2207+
auto localList1 = std::make_shared<List>("", "");
2208+
sprite.addList(localList1);
2209+
2210+
auto localList2 = std::make_shared<List>("", "");
2211+
sprite.addList(localList2);
2212+
2213+
auto localList3 = std::make_shared<List>("", "");
2214+
sprite.addList(localList3);
21792215

21802216
globalList->append(1);
21812217
globalList->append(2);
21822218
globalList->append(3);
21832219

2184-
localList->append("Lorem");
2185-
localList->append("ipsum");
2186-
localList->append("dolor");
2187-
localList->append("sit");
2188-
strings[localList.get()] = localList->toString();
2220+
localList1->append("Lorem");
2221+
localList1->append("ipsum");
2222+
localList1->append("dolor");
2223+
localList1->append("sit");
2224+
2225+
localList2->append(-564.121);
2226+
localList2->append(4257.4);
2227+
2228+
localList3->append(true);
2229+
localList3->append(false);
21892230

21902231
createBuilder(&sprite, true);
21912232

2233+
// Global
21922234
CompilerValue *v = m_builder->addConstValue(2);
21932235
v = m_builder->addListItem(globalList.get(), v);
21942236
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });
@@ -2201,24 +2243,67 @@ TEST_F(LLVMCodeBuilderTest, GetListItem)
22012243
v = m_builder->addListItem(globalList.get(), v);
22022244
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });
22032245

2246+
v = m_builder->addConstValue(-1);
2247+
v = m_builder->addListItem(globalList.get(), v);
2248+
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });
2249+
2250+
v = m_builder->addConstValue(3);
2251+
v = m_builder->addListItem(globalList.get(), v);
2252+
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });
2253+
2254+
// Local 1
22042255
v = m_builder->addConstValue(0);
2205-
v = m_builder->addListItem(localList.get(), v);
2256+
v = m_builder->addListItem(localList1.get(), v);
22062257
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });
22072258

22082259
v = m_builder->addConstValue(2);
2209-
v = m_builder->addListItem(localList.get(), v);
2260+
v = m_builder->addListItem(localList1.get(), v);
22102261
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });
22112262

22122263
v = m_builder->addConstValue(3);
2213-
v = m_builder->addListItem(localList.get(), v);
2264+
v = m_builder->addListItem(localList1.get(), v);
2265+
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });
2266+
2267+
v = m_builder->addConstValue(-1);
2268+
v = m_builder->addListItem(localList1.get(), v);
22142269
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });
22152270

2271+
v = m_builder->addConstValue(4);
2272+
v = m_builder->addListItem(localList1.get(), v);
2273+
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });
2274+
2275+
// Local 2
2276+
v = m_builder->addConstValue(-1);
2277+
v = m_builder->addListItem(localList2.get(), v);
2278+
m_builder->addFunctionCall("test_print_number", Compiler::StaticType::Void, { Compiler::StaticType::Number }, { v });
2279+
2280+
v = m_builder->addConstValue(2);
2281+
v = m_builder->addListItem(localList2.get(), v);
2282+
m_builder->addFunctionCall("test_print_number", Compiler::StaticType::Void, { Compiler::StaticType::Number }, { v });
2283+
2284+
// Local 3
2285+
v = m_builder->addConstValue(-1);
2286+
v = m_builder->addListItem(localList3.get(), v);
2287+
m_builder->addFunctionCall("test_print_number", Compiler::StaticType::Void, { Compiler::StaticType::Number }, { v });
2288+
2289+
v = m_builder->addConstValue(2);
2290+
v = m_builder->addListItem(localList3.get(), v);
2291+
m_builder->addFunctionCall("test_print_number", Compiler::StaticType::Void, { Compiler::StaticType::Number }, { v });
2292+
22162293
static const std::string expected =
22172294
"3\n"
22182295
"1\n"
2296+
"0\n"
2297+
"0\n"
22192298
"Lorem\n"
22202299
"dolor\n"
2221-
"sit\n";
2300+
"sit\n"
2301+
"0\n"
2302+
"0\n"
2303+
"0\n"
2304+
"0\n"
2305+
"0\n"
2306+
"0\n";
22222307

22232308
auto code = m_builder->finalize();
22242309
Script script(&sprite, nullptr, nullptr);
@@ -2231,7 +2316,9 @@ TEST_F(LLVMCodeBuilderTest, GetListItem)
22312316
ASSERT_EQ(testing::internal::GetCapturedStdout(), expected);
22322317

22332318
ASSERT_EQ(globalList->toString(), "1 test 3");
2234-
ASSERT_EQ(localList->toString(), "Lorem ipsum dolor sit");
2319+
ASSERT_EQ(localList1->toString(), "Lorem ipsum dolor sit");
2320+
ASSERT_EQ(localList2->toString(), "-564.121 4257.4");
2321+
ASSERT_EQ(localList3->toString(), "true false");
22352322
}
22362323

22372324
TEST_F(LLVMCodeBuilderTest, GetListSize)

0 commit comments

Comments
 (0)