Skip to content

Commit 2803105

Browse files
authored
Merge pull request #77 from ix-dcourtois/feature/+_operator_for_strings
+ operator for strings
2 parents 8056dea + ffd65f9 commit 2803105

File tree

7 files changed

+246
-70
lines changed

7 files changed

+246
-70
lines changed

src/SeExpr/Evaluator.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,11 @@ class LLVMEvaluator {
123123
Function *SeExpr2LLVMEvalCustomFunctionFunc = nullptr;
124124
Function *SeExpr2LLVMEvalFPVarRefFunc = nullptr;
125125
Function *SeExpr2LLVMEvalStrVarRefFunc = nullptr;
126+
Function *SeExpr2LLVMEvalstrlenFunc = nullptr;
127+
Function *SeExpr2LLVMEvalmallocFunc = nullptr;
128+
Function *SeExpr2LLVMEvalfreeFunc = nullptr;
129+
Function *SeExpr2LLVMEvalmemsetFunc = nullptr;
130+
Function *SeExpr2LLVMEvalstrcatFunc = nullptr;
126131
{
127132
{
128133
FunctionType *FT = FunctionType::get(voidTy, {i32PtrTy, doublePtrTy, i8PtrPtrTy, i8PtrPtrTy, i64Ty}, false);
@@ -136,6 +141,26 @@ class LLVMEvaluator {
136141
FunctionType *FT = FunctionType::get(voidTy, {i8PtrTy, i8PtrPtrTy}, false);
137142
SeExpr2LLVMEvalStrVarRefFunc = Function::Create(FT, GlobalValue::ExternalLinkage, "SeExpr2LLVMEvalStrVarRef", TheModule.get());
138143
}
144+
{
145+
FunctionType *FT = FunctionType::get(i32Ty, { i8PtrTy }, false);
146+
SeExpr2LLVMEvalstrlenFunc = Function::Create(FT, Function::ExternalLinkage, "strlen", TheModule.get());
147+
}
148+
{
149+
FunctionType *FT = FunctionType::get(i8PtrTy, { i32Ty }, false);
150+
SeExpr2LLVMEvalmallocFunc = Function::Create(FT, Function::ExternalLinkage, "malloc", TheModule.get());
151+
}
152+
{
153+
FunctionType *FT = FunctionType::get(voidTy, { i8PtrTy }, false);
154+
SeExpr2LLVMEvalfreeFunc = Function::Create(FT, Function::ExternalLinkage, "free", TheModule.get());
155+
}
156+
{
157+
FunctionType *FT = FunctionType::get(voidTy, { i8PtrTy, i32Ty, i32Ty }, false);
158+
SeExpr2LLVMEvalmemsetFunc = Function::Create(FT, Function::ExternalLinkage, "memset", TheModule.get());
159+
}
160+
{
161+
FunctionType *FT = FunctionType::get(i8PtrTy, { i8PtrTy, i8PtrTy }, false);
162+
SeExpr2LLVMEvalstrcatFunc = Function::Create(FT, Function::ExternalLinkage, "strcat", TheModule.get());
163+
}
139164
}
140165

141166
// create function and entry BB
@@ -286,6 +311,11 @@ class LLVMEvaluator {
286311
TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalFPVarRefFunc, (void *)SeExpr2LLVMEvalFPVarRef);
287312
TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalStrVarRefFunc, (void *)SeExpr2LLVMEvalStrVarRef);
288313
TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalCustomFunctionFunc, (void *)SeExpr2LLVMEvalCustomFunction);
314+
TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalstrlenFunc, (void *)strlen);
315+
TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalstrcatFunc, (void *)strcat);
316+
TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalmemsetFunc, (void *)memset);
317+
TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalmallocFunc, (void *)malloc);
318+
TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalfreeFunc, (void *)free);
289319

290320
// [verify]
291321
std::string errorStr;

src/SeExpr/ExprLLVMCodeGeneration.cpp

Lines changed: 68 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -566,34 +566,75 @@ LLVM_VALUE ExprBinaryOpNode::codegen(LLVM_BUILDER Builder) LLVM_BODY {
566566
LLVM_VALUE op1 = pv.first;
567567
LLVM_VALUE op2 = pv.second;
568568

569-
switch (_op) {
570-
case '+':
571-
return Builder.CreateFAdd(op1, op2);
572-
case '-':
573-
return Builder.CreateFSub(op1, op2);
574-
case '*':
575-
return Builder.CreateFMul(op1, op2);
576-
case '/':
577-
return Builder.CreateFDiv(op1, op2);
578-
case '%': {
579-
// niceMod() from v1: b==0 ? 0 : a-floor(a/b)*b
580-
LLVM_VALUE a = op1, b = op2;
581-
LLVM_VALUE aOverB = Builder.CreateFDiv(a, b);
582-
Function *floorFun = Intrinsic::getDeclaration(llvm_getModule(Builder), Intrinsic::floor, op1->getType());
583-
LLVM_VALUE normal = Builder.CreateFSub(a, Builder.CreateFMul(Builder.CreateCall(floorFun, {aOverB}), b));
584-
Constant *zero = ConstantFP::get(op1->getType(), 0.0);
585-
return Builder.CreateSelect(Builder.CreateFCmpOEQ(zero, op1), zero, normal);
586-
}
587-
case '^': {
588-
// TODO: make external function reference work with interpreter, libffi
589-
// TODO: needed for MCJIT??
590-
// TODO: is the above not already done?!
591-
std::vector<Type *> arg_type;
592-
arg_type.push_back(op1->getType());
593-
Function *fun = Intrinsic::getDeclaration(llvm_getModule(Builder), Intrinsic::pow, arg_type);
594-
std::vector<LLVM_VALUE> ops = {op1, op2};
595-
return Builder.CreateCall(fun, ops);
569+
const bool isString = child(0)->type().isString();
570+
571+
if (isString == false) {
572+
switch (_op) {
573+
case '+':
574+
return Builder.CreateFAdd(op1, op2);
575+
case '-':
576+
return Builder.CreateFSub(op1, op2);
577+
case '*':
578+
return Builder.CreateFMul(op1, op2);
579+
case '/':
580+
return Builder.CreateFDiv(op1, op2);
581+
case '%': {
582+
// niceMod() from v1: b==0 ? 0 : a-floor(a/b)*b
583+
LLVM_VALUE a = op1, b = op2;
584+
LLVM_VALUE aOverB = Builder.CreateFDiv(a, b);
585+
Function *floorFun = Intrinsic::getDeclaration(llvm_getModule(Builder), Intrinsic::floor, op1->getType());
586+
LLVM_VALUE normal = Builder.CreateFSub(a, Builder.CreateFMul(Builder.CreateCall(floorFun, {aOverB}), b));
587+
Constant *zero = ConstantFP::get(op1->getType(), 0.0);
588+
return Builder.CreateSelect(Builder.CreateFCmpOEQ(zero, op1), zero, normal);
589+
}
590+
case '^': {
591+
// TODO: make external function reference work with interpreter, libffi
592+
// TODO: needed for MCJIT??
593+
// TODO: is the above not already done?!
594+
std::vector<Type *> arg_type;
595+
arg_type.push_back(op1->getType());
596+
Function *fun = Intrinsic::getDeclaration(llvm_getModule(Builder), Intrinsic::pow, arg_type);
597+
std::vector<LLVM_VALUE> ops = {op1, op2};
598+
return Builder.CreateCall(fun, ops);
599+
}
596600
}
601+
} else {
602+
// precompute a few things
603+
LLVMContext &context = Builder.getContext();
604+
Module *module = llvm_getModule(Builder);
605+
PointerType *i8PtrPtrTy = PointerType::getUnqual(Type::getInt8PtrTy(context));
606+
Type *i32Ty = Type::getInt32Ty(context);
607+
Function *strlen = module->getFunction("strlen");
608+
Function *malloc = module->getFunction("malloc");
609+
Function *free = module->getFunction("free");
610+
Function *memset = module->getFunction("memset");
611+
Function *strcat = module->getFunction("strcat");
612+
613+
// do magic (see the pseudo C code on the comments at the end
614+
// of each LLVM instruction)
615+
616+
// compute the length of the operand strings
617+
LLVM_VALUE len1 = Builder.CreateCall(strlen, { op1 }); // len1 = strlen(op1);
618+
LLVM_VALUE len2 = Builder.CreateCall(strlen, { op2 }); // len2 = strlen(op2);
619+
LLVM_VALUE len = Builder.CreateAdd(len1, len2); // len = len1 + len2;
620+
621+
// allocate and clear memory
622+
LLVM_VALUE alloc = Builder.CreateCall(malloc, { len }); // alloc = malloc(len1 + len2);
623+
LLVM_VALUE zero = ConstantInt::get(i32Ty, 0); // zero = 0;
624+
Builder.CreateCall(memset, { alloc, zero, len }); // memset(alloc, zero, len);
625+
626+
// concatenate operand strings into output string
627+
Builder.CreateCall(strcat, { alloc, op1 }); // strcat(alloc, op1);
628+
LLVM_VALUE newAlloc = Builder.CreateGEP(nullptr, alloc, len1); // newAlloc = alloc + len1
629+
Builder.CreateCall(strcat, { newAlloc, op2 }); // strcat(alloc, op2);
630+
631+
// store the address in the node's _out member so that it will be
632+
// cleaned up when the expression is destroyed.
633+
APInt outAddr = APInt(64, (uint64_t)&_out);
634+
LLVM_VALUE out = Constant::getIntegerValue(i8PtrPtrTy, outAddr); // out = &_out;
635+
Builder.CreateCall(free, { Builder.CreateLoad(out) }); // free(*out);
636+
Builder.CreateStore(alloc, out); // *out = alloc
637+
return alloc;
597638
}
598639

599640
assert(false && "unexpected op");

src/SeExpr/ExprNode.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -464,12 +464,13 @@ ExprType ExprBinaryOpNode::prep(bool wantScalar, ExprVarEnvBuilder& envBuilder)
464464

465465
bool error = false;
466466

467+
// prep children and get their types
467468
firstType = child(0)->prep(false, envBuilder);
468-
checkIsFP(firstType, error);
469469
secondType = child(1)->prep(false, envBuilder);
470-
checkIsFP(secondType, error);
471-
checkTypesCompatible(firstType, secondType, error);
472470

471+
// check compatibility and get return type
472+
// TODO: handle string + fp or fp + string, the same as in Python or equivalent
473+
checkTypesCompatible(firstType, secondType, error);
473474
if (error)
474475
setType(ExprType().Error());
475476
else

src/SeExpr/ExprNode.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,13 +450,15 @@ class ExprCompareNode : public ExprNode {
450450
/// Node that implements an binary operator
451451
class ExprBinaryOpNode : public ExprNode {
452452
public:
453-
ExprBinaryOpNode(const Expression* expr, ExprNode* a, ExprNode* b, char op) : ExprNode(expr, a, b), _op(op) {}
453+
ExprBinaryOpNode(const Expression* expr, ExprNode* a, ExprNode* b, char op) : ExprNode(expr, a, b), _op(op), _out(0) {}
454+
virtual ~ExprBinaryOpNode() { free(_out); }
454455

455456
virtual ExprType prep(bool wantScalar, ExprVarEnvBuilder& envBuilder);
456457
virtual int buildInterpreter(Interpreter* interpreter) const;
457458
virtual LLVM_VALUE codegen(LLVM_BUILDER) LLVM_BODY;
458459

459460
char _op;
461+
char* _out;
460462
};
461463

462464
/// Node that references a variable

src/SeExpr/Interpreter.cpp

Lines changed: 91 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,40 @@ static Interpreter::OpF getTemplatizedOp2(int i) {
134134

135135
namespace {
136136

137+
//! Binary operator for strings. Currently only handle '+'
138+
struct BinaryStringOp {
139+
static int f(int* opData, double* fp, char** c, std::vector<int>& callStack) {
140+
// get the operand data
141+
char*& out = *(char**)c[opData[0]];
142+
char* in1 = c[opData[1]];
143+
char* in2 = c[opData[2]];
144+
145+
// delete previous data and allocate a new buffer, only if needed
146+
// NOTE: this is more efficient, but might consume more memory...
147+
// Maybe make this behaviour configurable ?
148+
int len1 = strlen(in1);
149+
int len2 = strlen(in2);
150+
if (out == 0 || len1 + len2 + 1 > strlen(out))
151+
{
152+
delete [] out;
153+
out = new char [len1 + len2 + 1];
154+
}
155+
156+
// clear previous evaluation content
157+
memset(out, 0, len1 + len2 + 1);
158+
159+
// concatenate strings
160+
strcat(out, in1);
161+
strcat(out + len1, in2);
162+
out[len1 + len2] = '\0';
163+
164+
// copy to the output
165+
c[opData[3]] = out;
166+
167+
return 1;
168+
}
169+
};
170+
137171
//! Computes a binary op of vector dimension d
138172
template <char op, int d>
139173
struct BinaryOp {
@@ -539,33 +573,67 @@ int ExprBinaryOpNode::buildInterpreter(Interpreter* interpreter) const {
539573
}
540574
}
541575

542-
switch (_op) {
543-
case '+':
544-
interpreter->addOp(getTemplatizedOp2<'+', BinaryOp>(dimout));
545-
break;
546-
case '-':
547-
interpreter->addOp(getTemplatizedOp2<'-', BinaryOp>(dimout));
548-
break;
549-
case '*':
550-
interpreter->addOp(getTemplatizedOp2<'*', BinaryOp>(dimout));
551-
break;
552-
case '/':
553-
interpreter->addOp(getTemplatizedOp2<'/', BinaryOp>(dimout));
554-
break;
555-
case '^':
556-
interpreter->addOp(getTemplatizedOp2<'^', BinaryOp>(dimout));
557-
break;
558-
case '%':
559-
interpreter->addOp(getTemplatizedOp2<'%', BinaryOp>(dimout));
560-
break;
561-
default:
562-
assert(false);
576+
// check if the node will output a string of numerical value
577+
bool isString = child0->type().isString() || child1->type().isString();
578+
579+
// add the operator
580+
if (isString == false) {
581+
switch (_op) {
582+
case '+':
583+
interpreter->addOp(getTemplatizedOp2<'+', BinaryOp>(dimout));
584+
break;
585+
case '-':
586+
interpreter->addOp(getTemplatizedOp2<'-', BinaryOp>(dimout));
587+
break;
588+
case '*':
589+
interpreter->addOp(getTemplatizedOp2<'*', BinaryOp>(dimout));
590+
break;
591+
case '/':
592+
interpreter->addOp(getTemplatizedOp2<'/', BinaryOp>(dimout));
593+
break;
594+
case '^':
595+
interpreter->addOp(getTemplatizedOp2<'^', BinaryOp>(dimout));
596+
break;
597+
case '%':
598+
interpreter->addOp(getTemplatizedOp2<'%', BinaryOp>(dimout));
599+
break;
600+
default:
601+
assert(false);
602+
}
603+
} else {
604+
switch (_op) {
605+
case '+': {
606+
interpreter->addOp(BinaryStringOp::f);
607+
int intermediateOp = interpreter->allocPtr();
608+
interpreter->s[intermediateOp] = (char*)(&_out);
609+
interpreter->addOperand(intermediateOp);
610+
break;
611+
}
612+
default:
613+
assert(false);
614+
}
563615
}
564-
int op2 = interpreter->allocFP(dimout);
616+
617+
// allocate the output
618+
int op2 = -1;
619+
if (isString == false) {
620+
op2 = interpreter->allocFP(dimout);
621+
} else {
622+
op2 = interpreter->allocPtr();
623+
}
624+
565625
interpreter->addOperand(op0);
566626
interpreter->addOperand(op1);
567627
interpreter->addOperand(op2);
568-
interpreter->endOp();
628+
629+
// NOTE: one of the operand can be a function. If it's the case for
630+
// strings, since functions are not immediately executed (they have
631+
// endOp(false)) using endOp() here would result in a nullptr
632+
// input operand during eval, thus the following arg to endOp.
633+
//
634+
// TODO: only stop execution if one of the operand is either a
635+
// function of a var ref.
636+
interpreter->endOp(isString == false);
569637

570638
return op2;
571639
}

src/tests/string.cpp

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,11 @@ using namespace SeExpr2;
2424

2525

2626
struct StringFunc : public ExprFuncSimple {
27-
StringFunc() : ExprFuncSimple(true) {}
27+
StringFunc() : ExprFuncSimple(true) {
28+
}
2829

29-
struct StringData : public SeExpr2::ExprFuncNode::Data
30+
struct StringData : public SeExpr2::ExprFuncNode::Data, public std::string
3031
{
31-
std::string str;
32-
int numArgs;
3332
};
3433

3534
virtual ExprType prep(ExprFuncNode* node, bool scalarWanted, ExprVarEnvBuilder& envBuilder) const {
@@ -45,21 +44,21 @@ struct StringFunc : public ExprFuncSimple {
4544
}
4645
return constant == true ? SeExpr2::ExprType().String().Constant() : SeExpr2::ExprType().String().Varying();
4746
}
47+
4848
virtual ExprFuncNode::Data* evalConstant(const ExprFuncNode* node, ArgHandle args) const {
49-
StringData* data = new StringData();
50-
data->numArgs = node->numChildren();
51-
return data;
49+
return new StringData();
5250
}
51+
5352
virtual void eval(ArgHandle args) {
54-
StringData* data = reinterpret_cast<StringData*>(args.data);
55-
data->str.clear();
56-
for (int i = 0; i < data->numArgs; ++i) {
57-
data->str += args.inStr(i);
58-
if (i != data->numArgs - 1) {
59-
data->str += "/";
53+
StringData& data = *reinterpret_cast<StringData*>(args.data);
54+
data.clear();
55+
for (int i = 0, iend = args.nargs(); i < iend; ++i) {
56+
data += args.inStr(i);
57+
if (i != iend - 1) {
58+
data += "/";
6059
}
6160
}
62-
args.outStr = const_cast<char*>(data->str.c_str());
61+
args.outStr = const_cast<char*>(data.c_str());
6362
}
6463
} joinPath;
6564
ExprFunc joinPathFunc(joinPath, 2, 100);
@@ -126,3 +125,31 @@ TEST(StringTests, FunctionVarying) {
126125
EXPECT_TRUE(expr.isConstant() == false);
127126
EXPECT_STREQ(expr.evalStr(), "/home/foo/some/relative/path");
128127
}
128+
129+
TEST(StringTests, BinaryOp) {
130+
StringExpression expr1("\"hello \" + \"world!\"");
131+
EXPECT_TRUE(expr1.isValid() == true);
132+
EXPECT_TRUE(expr1.returnType().isString() == true);
133+
EXPECT_TRUE(expr1.isConstant() == true);
134+
EXPECT_STREQ(expr1.evalStr(), "hello world!");
135+
136+
StringExpression expr2("\"hello \" + \"world\" + \"!\"");
137+
EXPECT_TRUE(expr2.isValid() == true);
138+
EXPECT_TRUE(expr2.returnType().isString() == true);
139+
EXPECT_TRUE(expr2.isConstant() == true);
140+
EXPECT_STREQ(expr2.evalStr(), "hello world!");
141+
142+
StringExpression expr3("stringVar + \"world!\"");
143+
expr3.stringVar = "hello ";
144+
EXPECT_TRUE(expr3.isValid() == true);
145+
EXPECT_TRUE(expr3.returnType().isString() == true);
146+
EXPECT_TRUE(expr3.isConstant() == false);
147+
EXPECT_STREQ(expr3.evalStr(), "hello world!");
148+
149+
StringExpression expr4("join_path(\"a\", \"b\") + \"/c/\" + stringVar");
150+
expr4.stringVar = "d";
151+
EXPECT_TRUE(expr4.isValid() == true);
152+
EXPECT_TRUE(expr4.returnType().isString() == true);
153+
EXPECT_TRUE(expr4.isConstant() == false);
154+
EXPECT_STREQ(expr4.evalStr(), "a/b/c/d");
155+
}

0 commit comments

Comments
 (0)