@@ -60,10 +60,10 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
60
60
m_builder.SetInsertPoint (entry);
61
61
62
62
// Init coroutine
63
- LLVMCoroutine coro;
63
+ std::unique_ptr< LLVMCoroutine> coro;
64
64
65
65
if (!m_warp)
66
- coro = initCoroutine ( func);
66
+ coro = std::make_unique<LLVMCoroutine>(m_module. get (), &m_builder, func);
67
67
68
68
std::vector<LLVMIfStatement> ifStatements;
69
69
std::vector<LLVMLoop> loops;
@@ -439,14 +439,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
439
439
if (!m_warp) {
440
440
freeHeap ();
441
441
syncVariables (targetVariables);
442
- m_builder.CreateStore (m_builder.getInt1 (true ), coro.didSuspend );
443
- llvm::BasicBlock *resumeBranch = llvm::BasicBlock::Create (m_ctx, " " , func);
444
- llvm::Value *noneToken = llvm::ConstantTokenNone::get (m_ctx);
445
- llvm::Value *suspendResult = m_builder.CreateCall (llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_suspend), { noneToken, m_builder.getInt1 (false ) });
446
- llvm::SwitchInst *sw = m_builder.CreateSwitch (suspendResult, coro.suspend , 2 );
447
- sw->addCase (m_builder.getInt8 (0 ), resumeBranch);
448
- sw->addCase (m_builder.getInt8 (1 ), coro.cleanup );
449
- m_builder.SetInsertPoint (resumeBranch);
442
+ coro->createSuspend ();
450
443
}
451
444
452
445
break ;
@@ -651,34 +644,18 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
651
644
freeHeap ();
652
645
syncVariables (targetVariables);
653
646
654
- // Add final suspend point
655
- if (!m_warp) {
656
- llvm::BasicBlock *endBranch = llvm::BasicBlock::Create (m_ctx, " end" , func);
657
- llvm::BasicBlock *finalSuspendBranch = llvm::BasicBlock::Create (m_ctx, " finalSuspend" , func);
658
- m_builder.CreateCondBr (m_builder.CreateLoad (m_builder.getInt1Ty (), coro.didSuspend ), finalSuspendBranch, endBranch);
659
-
660
- m_builder.SetInsertPoint (finalSuspendBranch);
661
- llvm::Value *suspendResult =
662
- m_builder.CreateCall (llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_suspend), { llvm::ConstantTokenNone::get (m_ctx), m_builder.getInt1 (true ) });
663
- llvm::SwitchInst *sw = m_builder.CreateSwitch (suspendResult, coro.suspend , 2 );
664
- sw->addCase (m_builder.getInt8 (0 ), endBranch); // unreachable
665
- sw->addCase (m_builder.getInt8 (1 ), coro.cleanup );
666
-
667
- m_builder.SetInsertPoint (endBranch);
668
- }
669
-
670
647
// End and verify the function
648
+ if (m_warp)
649
+ m_builder.CreateRet (llvm::ConstantPointerNull::get (pointerType));
650
+ else
651
+ coro->end ();
652
+
671
653
if (!m_tmpRegs.empty ()) {
672
654
std::cout
673
655
<< " warning: " << m_tmpRegs.size () << " registers were leaked by script '" << m_module->getName ().str () << " ', function '" << func->getName ().str ()
674
656
<< " ' (if you see this as a regular user, this is a bug and should be reported)" << std::endl;
675
657
}
676
658
677
- if (m_warp)
678
- m_builder.CreateRet (llvm::ConstantPointerNull::get (pointerType));
679
- else
680
- m_builder.CreateBr (coro.freeMemRet );
681
-
682
659
verifyFunction (func);
683
660
684
661
// Create resume function
@@ -691,12 +668,8 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
691
668
692
669
if (m_warp)
693
670
m_builder.CreateRet (m_builder.getInt1 (true ));
694
- else {
695
- llvm::Value *coroHandle = func->getArg (0 );
696
- m_builder.CreateCall (llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_resume), { coroHandle });
697
- llvm::Value *done = m_builder.CreateCall (llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_done), { coroHandle });
698
- m_builder.CreateRet (done);
699
- }
671
+ else
672
+ m_builder.CreateRet (coro->createResume (func->getArg (0 )));
700
673
701
674
verifyFunction (func);
702
675
@@ -1010,63 +983,6 @@ void LLVMCodeBuilder::createVariableMap()
1010
983
}
1011
984
}
1012
985
1013
- LLVMCoroutine LLVMCodeBuilder::initCoroutine (llvm::Function *func)
1014
- {
1015
- // Set presplitcoroutine attribute
1016
- func->setPresplitCoroutine ();
1017
-
1018
- // Coroutine intrinsics
1019
- llvm::Function *coroId = llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_id);
1020
- llvm::Function *coroSize = llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_size, m_builder.getInt64Ty ());
1021
- llvm::Function *coroBegin = llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_begin);
1022
- llvm::Function *coroEnd = llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_end);
1023
- llvm::Function *coroFree = llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_free);
1024
-
1025
- // Init coroutine
1026
- LLVMCoroutine coro;
1027
- llvm::PointerType *pointerType = llvm::PointerType::get (llvm::Type::getInt8Ty (m_ctx), 0 );
1028
- llvm::Constant *nullPointer = llvm::ConstantPointerNull::get (pointerType);
1029
- llvm::Value *coroIdRet = m_builder.CreateCall (coroId, { m_builder.getInt32 (8 ), nullPointer, nullPointer, nullPointer });
1030
-
1031
- // Allocate memory
1032
- llvm::Value *coroSizeRet = m_builder.CreateCall (coroSize, std::nullopt, " size" );
1033
- llvm::Function *mallocFunc = llvm::Function::Create (llvm::FunctionType::get (pointerType, { m_builder.getInt64Ty () }, false ), llvm::Function::ExternalLinkage, " malloc" , m_module.get ());
1034
- llvm::Value *alloc = m_builder.CreateCall (mallocFunc, coroSizeRet, " mem" );
1035
-
1036
- // Begin
1037
- coro.handle = m_builder.CreateCall (coroBegin, { coroIdRet, alloc });
1038
- coro.didSuspend = m_builder.CreateAlloca (m_builder.getInt1Ty (), nullptr , " didSuspend" );
1039
- m_builder.CreateStore (m_builder.getInt1 (false ), coro.didSuspend );
1040
- llvm::BasicBlock *entry = m_builder.GetInsertBlock ();
1041
-
1042
- // Create suspend branch
1043
- coro.suspend = llvm::BasicBlock::Create (m_ctx, " suspend" , func);
1044
- m_builder.SetInsertPoint (coro.suspend );
1045
- m_builder.CreateCall (coroEnd, { coro.handle , m_builder.getInt1 (false ), llvm::ConstantTokenNone::get (m_ctx) });
1046
- m_builder.CreateRet (coro.handle );
1047
-
1048
- // Create free branches
1049
- coro.freeMemRet = llvm::BasicBlock::Create (m_ctx, " freeMemRet" , func);
1050
- m_builder.SetInsertPoint (coro.freeMemRet );
1051
- m_builder.CreateFree (alloc);
1052
- m_builder.CreateRet (llvm::ConstantPointerNull::get (pointerType));
1053
-
1054
- llvm::BasicBlock *freeBranch = llvm::BasicBlock::Create (m_ctx, " free" , func);
1055
- m_builder.SetInsertPoint (freeBranch);
1056
- m_builder.CreateFree (alloc);
1057
- m_builder.CreateBr (coro.suspend );
1058
-
1059
- // Create cleanup branch
1060
- coro.cleanup = llvm::BasicBlock::Create (m_ctx, " cleanup" , func);
1061
- m_builder.SetInsertPoint (coro.cleanup );
1062
- llvm::Value *mem = m_builder.CreateCall (coroFree, { coroIdRet, coro.handle });
1063
- llvm::Value *needFree = m_builder.CreateIsNotNull (mem);
1064
- m_builder.CreateCondBr (needFree, freeBranch, coro.suspend );
1065
-
1066
- m_builder.SetInsertPoint (entry);
1067
- return coro;
1068
- }
1069
-
1070
986
void LLVMCodeBuilder::verifyFunction (llvm::Function *func)
1071
987
{
1072
988
if (llvm::verifyFunction (*func, &llvm::errs ())) {
0 commit comments