@@ -53,7 +53,7 @@ llvm::Type* LLVMGenerator::getStructType(IRStructType* type) {
5353 return llvmStruct;
5454}
5555
56- llvm::Type* LLVMGenerator::getLLVMType (IRType* type) {
56+ llvm::Type* LLVMGenerator::getLLVMType (IRType* type, bool * isSret ) {
5757 switch (type->kind ) {
5858 case IRTypeKind::IRBasicType: {
5959 return NOTNULL (getBuiltinType (type->getName ()));
@@ -65,7 +65,15 @@ llvm::Type* LLVMGenerator::getLLVMType(IRType* type) {
6565 case IRTypeKind::IRFunctionType: {
6666 auto functionType = llvm::cast<IRFunctionType>(type);
6767 auto returnType = getLLVMType (functionType->returnType );
68- auto paramTypes = map (functionType->paramTypes , [&](IRType* type) { return getLLVMType (type); });
68+ auto paramTypes = map (functionType->paramTypes , [&](IRType* p) { return getLLVMType (p); });
69+ // Use hidden sret pointer parameter to return larger structs to be compatible with the C calling convention.
70+ if (shouldUseSret (returnType)) {
71+ if (isSret) *isSret = true ;
72+ paramTypes.insert (paramTypes.begin (), llvm::PointerType::get (returnType, 0 ));
73+ returnType = llvm::Type::getVoidTy (ctx);
74+ } else {
75+ if (isSret) *isSret = false ;
76+ }
6977 return llvm::FunctionType::get (returnType, paramTypes, functionType->isVariadic );
7078 }
7179 case IRTypeKind::IRPointerType: {
@@ -101,31 +109,39 @@ llvm::Type* LLVMGenerator::getLLVMType(IRType* type) {
101109 llvm_unreachable (" all cases handled" );
102110}
103111
112+ bool LLVMGenerator::shouldUseSret (llvm::Type* returnType) {
113+ return !returnType->isVoidTy () && module ->getDataLayout ().getTypeAllocSize (returnType) > 16 ;
114+ }
115+
104116llvm::Function* LLVMGenerator::getFunction (const Function* function) {
105117 if (auto * llvmFunction = module ->getFunction (function->mangledName )) return llvmFunction;
106118
107- llvm::SmallVector<llvm::Type*, 16 > paramTypes;
108- for (auto & param : function->params ) {
109- paramTypes.emplace_back (getLLVMType (param.type ));
110- }
111-
112- auto * returnType = getLLVMType (function->returnType );
113- auto * functionType = llvm::FunctionType::get (returnType, paramTypes, function->isVariadic );
114- auto * llvmFunction = llvm::Function::Create (functionType, llvm::Function::ExternalLinkage, function->mangledName , &*module );
119+ bool isSret;
120+ auto llvmFunctionType = llvm::cast<llvm::FunctionType>(getLLVMType (function->getType ()->getPointee (), &isSret));
121+ auto * llvmFunction = llvm::Function::Create (llvmFunctionType, llvm::Function::ExternalLinkage, function->mangledName , module );
115122
116123 auto arg = llvmFunction->arg_begin (), argsEnd = llvmFunction->arg_end ();
117- ASSERT (function->params .size () == size_t (std::distance (arg, argsEnd)));
124+ if (isSret) {
125+ arg->setName (" sret.arg" );
126+ ++arg;
127+ }
118128 for (auto param = function->params .begin (); arg != argsEnd; ++param, ++arg) {
119129 arg->setName (param->name );
120130 }
121131
132+ if (isSret) {
133+ auto structType = getLLVMType (function->returnType );
134+ llvmFunction->getArg (0 )->addAttr (llvm::Attribute::get (ctx, llvm::Attribute::StructRet, structType));
135+ }
122136 return llvmFunction;
123137}
124138
125139void LLVMGenerator::codegenFunctionBody (const Function* function, llvm::Function* llvmFunction) {
140+ isCurrentFunctionSret = shouldUseSret (getLLVMType (function->returnType ));
126141 llvm::IRBuilder<>::InsertPointGuard insertPointGuard (builder);
127142
128143 auto arg = llvmFunction->arg_begin ();
144+ if (isCurrentFunctionSret) ++arg;
129145 for (auto & param : function->params ) {
130146 generatedValues.emplace (¶m, &*arg++);
131147 }
@@ -183,6 +199,11 @@ llvm::Value* LLVMGenerator::codegenAlloca(const AllocaInst* inst) {
183199}
184200
185201llvm::Value* LLVMGenerator::codegenReturn (const ReturnInst* inst) {
202+ if (isCurrentFunctionSret) {
203+ auto currentFunction = builder.GetInsertBlock ()->getParent ();
204+ builder.CreateStore (getValue (inst->value ), currentFunction->getArg (0 ));
205+ return builder.CreateRetVoid ();
206+ }
186207 return inst->value ? builder.CreateRet (getValue (inst->value )) : builder.CreateRetVoid ();
187208}
188209
@@ -237,12 +258,20 @@ llvm::Value* LLVMGenerator::codegenCall(const CallInst* inst) {
237258 auto function = getValue (inst->function );
238259 auto args = map (inst->args , [&](auto * arg) { return getValue (arg); });
239260 auto cxFunctionType = inst->function ->getType ();
240- if (cxFunctionType->isPointerType ()) {
241- cxFunctionType = cxFunctionType->getPointee ();
261+ if (cxFunctionType->isPointerType ()) cxFunctionType = cxFunctionType->getPointee ();
262+ ASSERT (cxFunctionType->isFunctionType ());
263+
264+ bool isSret;
265+ auto * llvmFunctionType = llvm::cast<llvm::FunctionType>(getLLVMType (cxFunctionType, &isSret));
266+ if (isSret) {
267+ auto sretType = getLLVMType (cxFunctionType->getReturnType ());
268+ auto sretAlloca = builder.CreateAlloca (sretType, nullptr , " sret.alloca" );
269+ args.insert (args.begin (), sretAlloca);
270+ builder.CreateCall (llvmFunctionType, function, args);
271+ return builder.CreateLoad (sretType, sretAlloca, " sret.load" );
272+ } else {
273+ return builder.CreateCall (llvmFunctionType, function, args);
242274 }
243- assert (cxFunctionType->isFunctionType ());
244- auto * llvmFunctionType = llvm::cast<llvm::FunctionType>(getLLVMType (cxFunctionType));
245- return builder.CreateCall (llvmFunctionType, function, args);
246275}
247276
248277llvm::Value* LLVMGenerator::codegenBinary (const BinaryInst* inst) {
0 commit comments