Skip to content

Commit af72bef

Browse files
committed
LLVM-backend: use sret to match C calling convention
1 parent bc9b8cf commit af72bef

File tree

8 files changed

+154
-19
lines changed

8 files changed

+154
-19
lines changed

src/backend/c-backend.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ void CGenerator::codegenStore(const StoreInst* inst) {
142142
void CGenerator::codegenInsert(const InsertInst* inst) {
143143
stream.indent(4);
144144
auto type = inst->aggregate->getType();
145-
assert(type->isStruct() || type->isArrayType());
145+
ASSERT(type->isStruct() || type->isArrayType());
146146
codegenType(stream, type, true);
147147
auto name = "_insert" + std::to_string(valueSuffixCounter++);
148148
stream << " " << name;

src/backend/llvm.cpp

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ llvm::Type* LLVMGenerator::getStructType(IRStructType* type) {
5353
return llvmStruct;
5454
}
5555

56-
llvm::Type* LLVMGenerator::getLLVMType(IRType* type) {
56+
llvm::Type* LLVMGenerator::getLLVMType(IRType* type, bool* isSret) {
5757
switch (type->kind) {
5858
case IRTypeKind::IRBasicType: {
5959
return NOTNULL(getBuiltinType(type->getName()));
@@ -65,7 +65,15 @@ llvm::Type* LLVMGenerator::getLLVMType(IRType* type) {
6565
case IRTypeKind::IRFunctionType: {
6666
auto functionType = llvm::cast<IRFunctionType>(type);
6767
auto returnType = getLLVMType(functionType->returnType);
68-
auto paramTypes = map(functionType->paramTypes, [&](IRType* type) { return getLLVMType(type); });
68+
auto paramTypes = map(functionType->paramTypes, [&](IRType* p) { return getLLVMType(p); });
69+
// Use hidden sret pointer parameter to return larger structs to be compatible with the C calling convention.
70+
if (shouldUseSret(returnType)) {
71+
if (isSret) *isSret = true;
72+
paramTypes.insert(paramTypes.begin(), llvm::PointerType::get(returnType, 0));
73+
returnType = llvm::Type::getVoidTy(ctx);
74+
} else {
75+
if (isSret) *isSret = false;
76+
}
6977
return llvm::FunctionType::get(returnType, paramTypes, functionType->isVariadic);
7078
}
7179
case IRTypeKind::IRPointerType: {
@@ -101,31 +109,39 @@ llvm::Type* LLVMGenerator::getLLVMType(IRType* type) {
101109
llvm_unreachable("all cases handled");
102110
}
103111

112+
// Decides whether a value of `returnType` is returned through a hidden pointer parameter
// instead of directly: true when the type is not void and its alloc size, as reported by
// the module's data layout, is greater than 16.
// NOTE(review): the 16 threshold looks like the x86-64 System V limit for aggregates returned
// in registers — confirm it is appropriate for every target this backend supports.
bool LLVMGenerator::shouldUseSret(llvm::Type* returnType) {
113+
return !returnType->isVoidTy() && module->getDataLayout().getTypeAllocSize(returnType) > 16;
114+
}
115+
104116
// Returns the LLVM function for `function`, looking it up in the current module by mangled
// name and creating it with external linkage when it is not there yet.
// In the new version the LLVM function type comes from getLLVMType, which may prepend a hidden
// sret pointer parameter and report that through `isSret`.
llvm::Function* LLVMGenerator::getFunction(const Function* function) {
105117
// Reuse the function if the module already contains one with this mangled name.
if (auto* llvmFunction = module->getFunction(function->mangledName)) return llvmFunction;
106118

107-
llvm::SmallVector<llvm::Type*, 16> paramTypes;
108-
for (auto& param : function->params) {
109-
paramTypes.emplace_back(getLLVMType(param.type));
110-
}
111-
112-
auto* returnType = getLLVMType(function->returnType);
113-
auto* functionType = llvm::FunctionType::get(returnType, paramTypes, function->isVariadic);
114-
auto* llvmFunction = llvm::Function::Create(functionType, llvm::Function::ExternalLinkage, function->mangledName, &*module);
119+
// `isSret` is left uninitialized here; getLLVMType assigns it in both branches of its
// function-type case. Presumably function->getType() is always a pointer to a function type
// (the llvm::cast below requires the result to be a FunctionType) — TODO confirm.
bool isSret;
120+
auto llvmFunctionType = llvm::cast<llvm::FunctionType>(getLLVMType(function->getType()->getPointee(), &isSret));
121+
auto* llvmFunction = llvm::Function::Create(llvmFunctionType, llvm::Function::ExternalLinkage, function->mangledName, module);
115122

116123
auto arg = llvmFunction->arg_begin(), argsEnd = llvmFunction->arg_end();
117-
ASSERT(function->params.size() == size_t(std::distance(arg, argsEnd)));
124+
// The hidden sret pointer is the first LLVM argument; name it and skip it so the remaining
// arguments line up with the declared parameters.
if (isSret) {
125+
arg->setName("sret.arg");
126+
++arg;
127+
}
118128
// NOTE(review): this loop is bounded only by the LLVM argument iterator, and the size ASSERT
// above was removed. If the remaining LLVM arguments ever outnumber function->params, `param`
// would be advanced past the end — consider restoring the check, adjusted for the sret slot.
for (auto param = function->params.begin(); arg != argsEnd; ++param, ++arg) {
119129
arg->setName(param->name);
120130
}
121131

132+
// Attach the StructRet attribute, carrying the returned struct type, to the hidden first argument.
if (isSret) {
133+
auto structType = getLLVMType(function->returnType);
134+
llvmFunction->getArg(0)->addAttr(llvm::Attribute::get(ctx, llvm::Attribute::StructRet, structType));
135+
}
122136
return llvmFunction;
123137
}
124138

125139
void LLVMGenerator::codegenFunctionBody(const Function* function, llvm::Function* llvmFunction) {
140+
isCurrentFunctionSret = shouldUseSret(getLLVMType(function->returnType));
126141
llvm::IRBuilder<>::InsertPointGuard insertPointGuard(builder);
127142

128143
auto arg = llvmFunction->arg_begin();
144+
if (isCurrentFunctionSret) ++arg;
129145
for (auto& param : function->params) {
130146
generatedValues.emplace(&param, &*arg++);
131147
}
@@ -183,6 +199,11 @@ llvm::Value* LLVMGenerator::codegenAlloca(const AllocaInst* inst) {
183199
}
184200

185201
// Emits the LLVM return for `inst` and returns the created instruction.
// When the current function is flagged as sret, the returned value is stored through the
// function's first argument and a `ret void` is emitted. Otherwise the value is returned
// directly, or `ret void` is emitted when the instruction has no value.
llvm::Value* LLVMGenerator::codegenReturn(const ReturnInst* inst) {
202+
// NOTE(review): this branch reads inst->value without the null check used on the non-sret
// path below — presumably an sret function always returns a value; confirm.
if (isCurrentFunctionSret) {
203+
auto currentFunction = builder.GetInsertBlock()->getParent();
204+
// Argument 0 is the hidden pointer to the caller-provided result slot.
builder.CreateStore(getValue(inst->value), currentFunction->getArg(0));
205+
return builder.CreateRetVoid();
206+
}
186207
return inst->value ? builder.CreateRet(getValue(inst->value)) : builder.CreateRetVoid();
187208
}
188209

@@ -237,12 +258,20 @@ llvm::Value* LLVMGenerator::codegenCall(const CallInst* inst) {
237258
// Body of codegenCall: lowers a call instruction. The callee and arguments are translated with
// getValue; the callee's IR type is unwrapped from a pointer if needed and converted to an LLVM
// function type, which also reports whether the callee returns through a hidden sret pointer.
auto function = getValue(inst->function);
238259
auto args = map(inst->args, [&](auto* arg) { return getValue(arg); });
239260
auto cxFunctionType = inst->function->getType();
240-
if (cxFunctionType->isPointerType()) {
241-
cxFunctionType = cxFunctionType->getPointee();
261+
if (cxFunctionType->isPointerType()) cxFunctionType = cxFunctionType->getPointee();
262+
ASSERT(cxFunctionType->isFunctionType());
263+
264+
// `isSret` is assigned by getLLVMType in its function-type case, which the ASSERT above guarantees.
bool isSret;
265+
auto* llvmFunctionType = llvm::cast<llvm::FunctionType>(getLLVMType(cxFunctionType, &isSret));
266+
// sret call: allocate a result slot, pass its address as the first argument, then load the
// result from the slot so the call still yields a value to its users.
if (isSret) {
267+
auto sretType = getLLVMType(cxFunctionType->getReturnType());
268+
// NOTE(review): the alloca is created at the builder's current insert point. If this call
// sits inside a loop, the slot would presumably be allocated again on every iteration —
// TODO confirm whether it should be hoisted to the function's entry block.
auto sretAlloca = builder.CreateAlloca(sretType, nullptr, "sret.alloca");
269+
args.insert(args.begin(), sretAlloca);
270+
// NOTE(review): no sret attribute is added to this call instruction, while getFunction adds
// StructRet to the callee's first parameter — verify whether the call site needs it too,
// in particular for calls through function pointers.
builder.CreateCall(llvmFunctionType, function, args);
271+
return builder.CreateLoad(sretType, sretAlloca, "sret.load");
272+
} else {
273+
return builder.CreateCall(llvmFunctionType, function, args);
242274
}
243-
assert(cxFunctionType->isFunctionType());
244-
auto* llvmFunctionType = llvm::cast<llvm::FunctionType>(getLLVMType(cxFunctionType));
245-
return builder.CreateCall(llvmFunctionType, function, args);
246275
}
247276

248277
llvm::Value* LLVMGenerator::codegenBinary(const BinaryInst* inst) {

src/backend/llvm.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ struct LLVMGenerator {
5050
llvm::Function* getFunction(const Function* function);
5151
void codegenFunction(const Function* function);
5252
void codegenFunctionBody(const Function* function, llvm::Function* llvmFunction);
53-
llvm::Type* getLLVMType(IRType* type);
53+
llvm::Type* getLLVMType(IRType* type, bool* isSret = nullptr);
54+
bool shouldUseSret(llvm::Type* returnType);
5455
llvm::Type* getBuiltinType(llvm::StringRef name);
5556
llvm::Type* getStructType(IRStructType* type);
5657

@@ -60,6 +61,7 @@ struct LLVMGenerator {
6061
std::vector<llvm::Module*> generatedModules;
6162
std::unordered_map<const Value*, llvm::Value*> generatedValues;
6263
std::unordered_map<IRType*, llvm::StructType*> structs;
64+
bool isCurrentFunctionSret;
6365
};
6466

6567
} // namespace cx

src/driver/driver.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ cl::opt<bool> compileOnly("c", cl::desc("Compile only, generating an object file
7575

7676
cl::OptionCategory outputCategory("Output Options");
7777
// TODO: Add -print-llvm-all option.
78+
// TODO: support simultaneous -print-c and -print-llvm? (requires both C backend and LLVM backend)
7879
enum class PrintOpt { AST, IR, IRAll, C, LLVM };
7980
cl::bits<PrintOpt> printOpts(cl::desc("Print output from intermediate steps:"), cl::sub(build), cl::sub(cl::SubCommand::getTopLevel()), cl::cat(outputCategory),
8081
cl::values(clEnumValN(PrintOpt::AST, "print-ast", "Print the abstract syntax tree of main module"),
@@ -217,7 +218,7 @@ static void emitLLVMBitcode(const llvm::Module& module, llvm::StringRef fileName
217218
// Reads the file at `filePath` into a memory buffer, moves that buffer into
// `targetModule.fileBuffers`, and returns a reference to the stored buffer.
// Aborts with an error message when the file cannot be opened.
llvm::MemoryBufferRef cx::addFileBufferToModule(llvm::StringRef filePath, Module& targetModule) {
218219
auto buffer = llvm::MemoryBuffer::getFile(filePath);
219220
if (!buffer) ABORT("couldn't open file '" << filePath << "'");
220-
assert((*buffer)->getBufferIdentifier() == filePath);
221+
// Sanity check: the buffer is expected to be identified by the path it was loaded from.
ASSERT((*buffer)->getBufferIdentifier() == filePath);
221222
// The module takes over the buffer; the returned reference points at the stored element.
targetModule.fileBuffers.push_back(std::move(*buffer));
222223
return targetModule.fileBuffers.back()->getMemBufferRef();
223224
}

test/snapshot/c-import-sret.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#pragma once
2+
3+
struct S {
4+
double d[4];
5+
};
6+
7+
struct S returnsLargeStruct() {
8+
struct S s;
9+
s.d[0] = 0;
10+
s.d[1] = 1;
11+
s.d[2] = 2;
12+
s.d[3] = 3;
13+
return s;
14+
}

test/snapshot/struct-return.cx

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: check-snapshots
2+
3+
import "c-import-sret.h"
4+
5+
struct S2 {
6+
float64 x;
7+
float64 y;
8+
float64 z;
9+
float64 w;
10+
}
11+
12+
S2 returnsLargeStruct2() {
13+
return S2(0, 1, 2, 3);
14+
}
15+
16+
void main() {
17+
var s = returnsLargeStruct();
18+
var s2 = returnsLargeStruct2();
19+
}

test/snapshot/struct-return.cx.ir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
2+
S2 _EN4main19returnsLargeStruct2E() {
3+
S2* _0 = alloca S2
4+
void _1 = call _EN4main2S24initE7float647float647float647float64(S2* _0, float64 0, float64 1, float64 2, float64 3)
5+
S2 .load = load _0
6+
return .load
7+
}
8+
9+
void _EN4main2S24initE7float647float647float647float64(S2* this, float64 x, float64 y, float64 z, float64 w) {
10+
float64* x_0 = getelementptr this, 0
11+
store x to x_0
12+
float64* y_0 = getelementptr this, 1
13+
store y to y_0
14+
float64* z_0 = getelementptr this, 2
15+
store z to z_0
16+
float64* w_0 = getelementptr this, 3
17+
store w to w_0
18+
return void
19+
}
20+
21+
int main() {
22+
S* s = alloca S
23+
S2* s2 = alloca S2
24+
S _0 = call returnsLargeStruct()
25+
store _0 to s
26+
S2 _1 = call _EN4main19returnsLargeStruct2E()
27+
store _1 to s2
28+
return int 0
29+
}
30+
31+
extern S returnsLargeStruct()

test/snapshot/struct-return.cx.ll

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
2+
%S2 = type { double, double, double, double }
3+
%S = type { [4 x double] }
4+
5+
define void @_EN4main19returnsLargeStruct2E(ptr sret(%S2) %sret.arg) {
6+
%1 = alloca %S2, align 8
7+
call void @_EN4main2S24initE7float647float647float647float64(ptr %1, double 0.000000e+00, double 1.000000e+00, double 2.000000e+00, double 3.000000e+00)
8+
%.load = load %S2, ptr %1, align 8
9+
store %S2 %.load, ptr %sret.arg, align 8
10+
ret void
11+
}
12+
13+
define void @_EN4main2S24initE7float647float647float647float64(ptr %this, double %x, double %y, double %z, double %w) {
14+
%x1 = getelementptr inbounds %S2, ptr %this, i32 0, i32 0
15+
store double %x, ptr %x1, align 8
16+
%y2 = getelementptr inbounds %S2, ptr %this, i32 0, i32 1
17+
store double %y, ptr %y2, align 8
18+
%z3 = getelementptr inbounds %S2, ptr %this, i32 0, i32 2
19+
store double %z, ptr %z3, align 8
20+
%w4 = getelementptr inbounds %S2, ptr %this, i32 0, i32 3
21+
store double %w, ptr %w4, align 8
22+
ret void
23+
}
24+
25+
define i32 @main() {
26+
%s = alloca %S, align 8
27+
%s2 = alloca %S2, align 8
28+
%sret.alloca = alloca %S, align 8
29+
call void @returnsLargeStruct(ptr %sret.alloca)
30+
%sret.load = load %S, ptr %sret.alloca, align 8
31+
store %S %sret.load, ptr %s, align 8
32+
%sret.alloca1 = alloca %S2, align 8
33+
call void @_EN4main19returnsLargeStruct2E(ptr %sret.alloca1)
34+
%sret.load2 = load %S2, ptr %sret.alloca1, align 8
35+
store %S2 %sret.load2, ptr %s2, align 8
36+
ret i32 0
37+
}
38+
39+
declare void @returnsLargeStruct(ptr sret(%S))

0 commit comments

Comments
 (0)