Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 3b41517

Browse files
committedNov 10, 2024·
Move coroutine code out of LLVMCodeBuilder
1 parent feb8a0b commit 3b41517

File tree

5 files changed

+168
-108
lines changed

5 files changed

+168
-108
lines changed
 

‎src/dev/engine/internal/llvm/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ target_sources(scratchcpp
66
llvminstruction.h
77
llvmifstatement.h
88
llvmloop.h
9+
llvmcoroutine.cpp
910
llvmcoroutine.h
1011
llvmvariableptr.h
1112
llvmprocedure.h

‎src/dev/engine/internal/llvm/llvmcodebuilder.cpp

Lines changed: 10 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,10 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
6060
m_builder.SetInsertPoint(entry);
6161

6262
// Init coroutine
63-
LLVMCoroutine coro;
63+
std::unique_ptr<LLVMCoroutine> coro;
6464

6565
if (!m_warp)
66-
coro = initCoroutine(func);
66+
coro = std::make_unique<LLVMCoroutine>(m_module.get(), &m_builder, func);
6767

6868
std::vector<LLVMIfStatement> ifStatements;
6969
std::vector<LLVMLoop> loops;
@@ -439,14 +439,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
439439
if (!m_warp) {
440440
freeHeap();
441441
syncVariables(targetVariables);
442-
m_builder.CreateStore(m_builder.getInt1(true), coro.didSuspend);
443-
llvm::BasicBlock *resumeBranch = llvm::BasicBlock::Create(m_ctx, "", func);
444-
llvm::Value *noneToken = llvm::ConstantTokenNone::get(m_ctx);
445-
llvm::Value *suspendResult = m_builder.CreateCall(llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_suspend), { noneToken, m_builder.getInt1(false) });
446-
llvm::SwitchInst *sw = m_builder.CreateSwitch(suspendResult, coro.suspend, 2);
447-
sw->addCase(m_builder.getInt8(0), resumeBranch);
448-
sw->addCase(m_builder.getInt8(1), coro.cleanup);
449-
m_builder.SetInsertPoint(resumeBranch);
442+
coro->createSuspend();
450443
}
451444

452445
break;
@@ -651,34 +644,18 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
651644
freeHeap();
652645
syncVariables(targetVariables);
653646

654-
// Add final suspend point
655-
if (!m_warp) {
656-
llvm::BasicBlock *endBranch = llvm::BasicBlock::Create(m_ctx, "end", func);
657-
llvm::BasicBlock *finalSuspendBranch = llvm::BasicBlock::Create(m_ctx, "finalSuspend", func);
658-
m_builder.CreateCondBr(m_builder.CreateLoad(m_builder.getInt1Ty(), coro.didSuspend), finalSuspendBranch, endBranch);
659-
660-
m_builder.SetInsertPoint(finalSuspendBranch);
661-
llvm::Value *suspendResult =
662-
m_builder.CreateCall(llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_suspend), { llvm::ConstantTokenNone::get(m_ctx), m_builder.getInt1(true) });
663-
llvm::SwitchInst *sw = m_builder.CreateSwitch(suspendResult, coro.suspend, 2);
664-
sw->addCase(m_builder.getInt8(0), endBranch); // unreachable
665-
sw->addCase(m_builder.getInt8(1), coro.cleanup);
666-
667-
m_builder.SetInsertPoint(endBranch);
668-
}
669-
670647
// End and verify the function
648+
if (m_warp)
649+
m_builder.CreateRet(llvm::ConstantPointerNull::get(pointerType));
650+
else
651+
coro->end();
652+
671653
if (!m_tmpRegs.empty()) {
672654
std::cout
673655
<< "warning: " << m_tmpRegs.size() << " registers were leaked by script '" << m_module->getName().str() << "', function '" << func->getName().str()
674656
<< "' (if you see this as a regular user, this is a bug and should be reported)" << std::endl;
675657
}
676658

677-
if (m_warp)
678-
m_builder.CreateRet(llvm::ConstantPointerNull::get(pointerType));
679-
else
680-
m_builder.CreateBr(coro.freeMemRet);
681-
682659
verifyFunction(func);
683660

684661
// Create resume function
@@ -691,12 +668,8 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
691668

692669
if (m_warp)
693670
m_builder.CreateRet(m_builder.getInt1(true));
694-
else {
695-
llvm::Value *coroHandle = func->getArg(0);
696-
m_builder.CreateCall(llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_resume), { coroHandle });
697-
llvm::Value *done = m_builder.CreateCall(llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_done), { coroHandle });
698-
m_builder.CreateRet(done);
699-
}
671+
else
672+
m_builder.CreateRet(coro->createResume(func->getArg(0)));
700673

701674
verifyFunction(func);
702675

@@ -1010,63 +983,6 @@ void LLVMCodeBuilder::createVariableMap()
1010983
}
1011984
}
1012985

1013-
LLVMCoroutine LLVMCodeBuilder::initCoroutine(llvm::Function *func)
1014-
{
1015-
// Set presplitcoroutine attribute
1016-
func->setPresplitCoroutine();
1017-
1018-
// Coroutine intrinsics
1019-
llvm::Function *coroId = llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_id);
1020-
llvm::Function *coroSize = llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_size, m_builder.getInt64Ty());
1021-
llvm::Function *coroBegin = llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_begin);
1022-
llvm::Function *coroEnd = llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_end);
1023-
llvm::Function *coroFree = llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_free);
1024-
1025-
// Init coroutine
1026-
LLVMCoroutine coro;
1027-
llvm::PointerType *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0);
1028-
llvm::Constant *nullPointer = llvm::ConstantPointerNull::get(pointerType);
1029-
llvm::Value *coroIdRet = m_builder.CreateCall(coroId, { m_builder.getInt32(8), nullPointer, nullPointer, nullPointer });
1030-
1031-
// Allocate memory
1032-
llvm::Value *coroSizeRet = m_builder.CreateCall(coroSize, std::nullopt, "size");
1033-
llvm::Function *mallocFunc = llvm::Function::Create(llvm::FunctionType::get(pointerType, { m_builder.getInt64Ty() }, false), llvm::Function::ExternalLinkage, "malloc", m_module.get());
1034-
llvm::Value *alloc = m_builder.CreateCall(mallocFunc, coroSizeRet, "mem");
1035-
1036-
// Begin
1037-
coro.handle = m_builder.CreateCall(coroBegin, { coroIdRet, alloc });
1038-
coro.didSuspend = m_builder.CreateAlloca(m_builder.getInt1Ty(), nullptr, "didSuspend");
1039-
m_builder.CreateStore(m_builder.getInt1(false), coro.didSuspend);
1040-
llvm::BasicBlock *entry = m_builder.GetInsertBlock();
1041-
1042-
// Create suspend branch
1043-
coro.suspend = llvm::BasicBlock::Create(m_ctx, "suspend", func);
1044-
m_builder.SetInsertPoint(coro.suspend);
1045-
m_builder.CreateCall(coroEnd, { coro.handle, m_builder.getInt1(false), llvm::ConstantTokenNone::get(m_ctx) });
1046-
m_builder.CreateRet(coro.handle);
1047-
1048-
// Create free branches
1049-
coro.freeMemRet = llvm::BasicBlock::Create(m_ctx, "freeMemRet", func);
1050-
m_builder.SetInsertPoint(coro.freeMemRet);
1051-
m_builder.CreateFree(alloc);
1052-
m_builder.CreateRet(llvm::ConstantPointerNull::get(pointerType));
1053-
1054-
llvm::BasicBlock *freeBranch = llvm::BasicBlock::Create(m_ctx, "free", func);
1055-
m_builder.SetInsertPoint(freeBranch);
1056-
m_builder.CreateFree(alloc);
1057-
m_builder.CreateBr(coro.suspend);
1058-
1059-
// Create cleanup branch
1060-
coro.cleanup = llvm::BasicBlock::Create(m_ctx, "cleanup", func);
1061-
m_builder.SetInsertPoint(coro.cleanup);
1062-
llvm::Value *mem = m_builder.CreateCall(coroFree, { coroIdRet, coro.handle });
1063-
llvm::Value *needFree = m_builder.CreateIsNotNull(mem);
1064-
m_builder.CreateCondBr(needFree, freeBranch, coro.suspend);
1065-
1066-
m_builder.SetInsertPoint(entry);
1067-
return coro;
1068-
}
1069-
1070986
void LLVMCodeBuilder::verifyFunction(llvm::Function *func)
1071987
{
1072988
if (llvm::verifyFunction(*func, &llvm::errs())) {

‎src/dev/engine/internal/llvm/llvmcodebuilder.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ class LLVMCodeBuilder : public ICodeBuilder
8484
void initTypes();
8585
void createVariableMap();
8686

87-
LLVMCoroutine initCoroutine(llvm::Function *func);
8887
void verifyFunction(llvm::Function *func);
8988
void optimize();
9089

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
#include "llvmcoroutine.h"
4+
5+
using namespace libscratchcpp;
6+
7+
LLVMCoroutine::LLVMCoroutine(llvm::Module *module, llvm::IRBuilder<> *builder, llvm::Function *func) :
8+
m_module(module),
9+
m_builder(builder),
10+
m_function(func)
11+
{
12+
llvm::LLVMContext &ctx = builder->getContext();
13+
14+
// Set presplitcoroutine attribute
15+
func->setPresplitCoroutine();
16+
17+
// Coroutine intrinsics
18+
llvm::Function *coroId = llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::coro_id);
19+
llvm::Function *coroSize = llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::coro_size, builder->getInt64Ty());
20+
llvm::Function *coroBegin = llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::coro_begin);
21+
llvm::Function *coroEnd = llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::coro_end);
22+
llvm::Function *coroFree = llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::coro_free);
23+
24+
// Init coroutine
25+
llvm::PointerType *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(ctx), 0);
26+
llvm::Constant *nullPointer = llvm::ConstantPointerNull::get(pointerType);
27+
llvm::Value *coroIdRet = builder->CreateCall(coroId, { builder->getInt32(8), nullPointer, nullPointer, nullPointer });
28+
29+
// Allocate memory
30+
llvm::Value *coroSizeRet = builder->CreateCall(coroSize, std::nullopt, "size");
31+
llvm::Function *mallocFunc = llvm::Function::Create(llvm::FunctionType::get(pointerType, { builder->getInt64Ty() }, false), llvm::Function::ExternalLinkage, "malloc", module);
32+
llvm::Value *alloc = builder->CreateCall(mallocFunc, coroSizeRet, "mem");
33+
34+
// Begin
35+
m_handle = builder->CreateCall(coroBegin, { coroIdRet, alloc });
36+
m_didSuspendVar = builder->CreateAlloca(builder->getInt1Ty(), nullptr, "didSuspend");
37+
builder->CreateStore(builder->getInt1(false), m_didSuspendVar);
38+
llvm::BasicBlock *entry = builder->GetInsertBlock();
39+
40+
// Create suspend branch
41+
m_suspendBlock = llvm::BasicBlock::Create(ctx, "suspend", func);
42+
builder->SetInsertPoint(m_suspendBlock);
43+
builder->CreateCall(coroEnd, { m_handle, builder->getInt1(false), llvm::ConstantTokenNone::get(ctx) });
44+
builder->CreateRet(m_handle);
45+
46+
// Create free branches
47+
m_freeMemRetBlock = llvm::BasicBlock::Create(ctx, "freeMemRet", func);
48+
builder->SetInsertPoint(m_freeMemRetBlock);
49+
builder->CreateFree(alloc);
50+
builder->CreateRet(llvm::ConstantPointerNull::get(pointerType));
51+
52+
llvm::BasicBlock *freeBranch = llvm::BasicBlock::Create(ctx, "free", func);
53+
builder->SetInsertPoint(freeBranch);
54+
builder->CreateFree(alloc);
55+
builder->CreateBr(m_suspendBlock);
56+
57+
// Create cleanup branch
58+
m_cleanupBlock = llvm::BasicBlock::Create(ctx, "cleanup", func);
59+
builder->SetInsertPoint(m_cleanupBlock);
60+
llvm::Value *mem = builder->CreateCall(coroFree, { coroIdRet, m_handle });
61+
llvm::Value *needFree = builder->CreateIsNotNull(mem);
62+
builder->CreateCondBr(needFree, freeBranch, m_suspendBlock);
63+
64+
builder->SetInsertPoint(entry);
65+
}
66+
67+
llvm::Value *LLVMCoroutine::handle() const
68+
{
69+
return m_handle;
70+
}
71+
72+
llvm::BasicBlock *LLVMCoroutine::suspendBlock() const
73+
{
74+
return m_suspendBlock;
75+
}
76+
77+
llvm::BasicBlock *LLVMCoroutine::cleanupBlock() const
78+
{
79+
return m_cleanupBlock;
80+
}
81+
82+
llvm::BasicBlock *LLVMCoroutine::freeMemRetBlock() const
83+
{
84+
return m_freeMemRetBlock;
85+
}
86+
87+
llvm::Value *LLVMCoroutine::didSuspendVar() const
88+
{
89+
return m_didSuspendVar;
90+
}
91+
92+
void LLVMCoroutine::createSuspend()
93+
{
94+
llvm::LLVMContext &ctx = m_builder->getContext();
95+
96+
m_builder->CreateStore(m_builder->getInt1(true), m_didSuspendVar);
97+
llvm::BasicBlock *resumeBranch = llvm::BasicBlock::Create(ctx, "", m_function);
98+
llvm::Value *noneToken = llvm::ConstantTokenNone::get(ctx);
99+
llvm::Value *suspendResult = m_builder->CreateCall(llvm::Intrinsic::getDeclaration(m_module, llvm::Intrinsic::coro_suspend), { noneToken, m_builder->getInt1(false) });
100+
llvm::SwitchInst *sw = m_builder->CreateSwitch(suspendResult, m_suspendBlock, 2);
101+
sw->addCase(m_builder->getInt8(0), resumeBranch);
102+
sw->addCase(m_builder->getInt8(1), m_cleanupBlock);
103+
m_builder->SetInsertPoint(resumeBranch);
104+
}
105+
106+
llvm::Value *LLVMCoroutine::createResume(llvm::Value *coroHandle)
107+
{
108+
m_builder->CreateCall(llvm::Intrinsic::getDeclaration(m_module, llvm::Intrinsic::coro_resume), { coroHandle });
109+
return m_builder->CreateCall(llvm::Intrinsic::getDeclaration(m_module, llvm::Intrinsic::coro_done), { coroHandle });
110+
}
111+
112+
void LLVMCoroutine::end()
113+
{
114+
llvm::LLVMContext &ctx = m_builder->getContext();
115+
116+
// Add final suspend point
117+
llvm::BasicBlock *endBranch = llvm::BasicBlock::Create(ctx, "end", m_function);
118+
llvm::BasicBlock *finalSuspendBranch = llvm::BasicBlock::Create(ctx, "finalSuspend", m_function);
119+
m_builder->CreateCondBr(m_builder->CreateLoad(m_builder->getInt1Ty(), m_didSuspendVar), finalSuspendBranch, endBranch);
120+
121+
m_builder->SetInsertPoint(finalSuspendBranch);
122+
llvm::Value *suspendResult = m_builder->CreateCall(llvm::Intrinsic::getDeclaration(m_module, llvm::Intrinsic::coro_suspend), { llvm::ConstantTokenNone::get(ctx), m_builder->getInt1(true) });
123+
llvm::SwitchInst *sw = m_builder->CreateSwitch(suspendResult, m_suspendBlock, 2);
124+
sw->addCase(m_builder->getInt8(0), endBranch); // unreachable
125+
sw->addCase(m_builder->getInt8(1), m_cleanupBlock);
126+
127+
// Jump to "free and return" branch
128+
m_builder->SetInsertPoint(endBranch);
129+
m_builder->CreateBr(m_freeMemRetBlock);
130+
}

‎src/dev/engine/internal/llvm/llvmcoroutine.h

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,38 @@
22

33
#pragma once
44

5-
namespace llvm
6-
{
7-
8-
class Value;
9-
class BasicBlock;
10-
11-
} // namespace llvm
5+
#include <llvm/IR/IRBuilder.h>
126

137
namespace libscratchcpp
148
{
159

16-
struct LLVMCoroutine
10+
class LLVMCoroutine
1711
{
18-
llvm::Value *handle = nullptr;
19-
llvm::BasicBlock *suspend = nullptr;
20-
llvm::BasicBlock *cleanup = nullptr;
21-
llvm::BasicBlock *freeMemRet = nullptr;
22-
llvm::Value *didSuspend = nullptr;
12+
public:
13+
LLVMCoroutine(llvm::Module *module, llvm::IRBuilder<> *builder, llvm::Function *func);
14+
LLVMCoroutine(const LLVMCoroutine &) = delete;
15+
16+
llvm::Value *handle() const;
17+
18+
llvm::BasicBlock *suspendBlock() const;
19+
llvm::BasicBlock *cleanupBlock() const;
20+
llvm::BasicBlock *freeMemRetBlock() const;
21+
22+
llvm::Value *didSuspendVar() const;
23+
24+
void createSuspend();
25+
llvm::Value *createResume(llvm::Value *coroHandle);
26+
void end();
27+
28+
private:
29+
llvm::Module *m_module = nullptr;
30+
llvm::IRBuilder<> *m_builder = nullptr;
31+
llvm::Function *m_function = nullptr;
32+
llvm::Value *m_handle = nullptr;
33+
llvm::BasicBlock *m_suspendBlock = nullptr;
34+
llvm::BasicBlock *m_cleanupBlock = nullptr;
35+
llvm::BasicBlock *m_freeMemRetBlock = nullptr;
36+
llvm::Value *m_didSuspendVar = nullptr;
2337
};
2438

2539
} // namespace libscratchcpp

0 commit comments

Comments
 (0)
Please sign in to comment.