Skip to content

Commit 9b48bb6

Browse files
committed
[experimental-solidity] Generate code for types of different stack sizes
1 parent c7db606 commit 9b48bb6

File tree

8 files changed

+313
-39
lines changed

8 files changed

+313
-39
lines changed

Diff for: libsolidity/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,8 @@ set(sources
207207
experimental/ast/TypeSystemHelper.h
208208
experimental/codegen/Common.h
209209
experimental/codegen/Common.cpp
210+
experimental/codegen/IRVariable.cpp
211+
experimental/codegen/IRVariable.h
210212
experimental/codegen/IRGenerationContext.h
211213
experimental/codegen/IRGenerator.cpp
212214
experimental/codegen/IRGenerator.h

Diff for: libsolidity/ast/AST.h

+3
Original file line numberDiff line numberDiff line change
@@ -2162,6 +2162,9 @@ class BinaryOperation: public Expression
21622162
Expression const& rightExpression() const { return *m_right; }
21632163
Token getOperator() const { return m_operator; }
21642164

2165+
/// @returns the given arguments in the order they were written.
2166+
std::vector<ASTPointer<Expression const>> arguments() const { return {m_left, m_right}; }
2167+
21652168
FunctionType const* userDefinedFunctionType() const;
21662169

21672170
BinaryOperationAnnotation& annotation() const override;

Diff for: libsolidity/experimental/analysis/TypeInference.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,7 @@ bool TypeInference::visit(TypeDefinition const& _typeDefinition)
822822

823823
members->second.emplace("abs", TypeMember{helper.functionType(*underlyingType, definedType)});
824824
members->second.emplace("rep", TypeMember{helper.functionType(definedType, *underlyingType)});
825+
annotation().underlyingTypes[constructor] = *underlyingType;
825826
}
826827

827828
if (helper.isPrimitiveType(definedType, PrimitiveType::Pair))

Diff for: libsolidity/experimental/analysis/TypeInference.h

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class TypeInference: public ASTConstVisitor
5050
std::map<TypeClass, std::map<std::string, Type>> typeClassFunctions;
5151
std::map<Token, std::tuple<TypeClass, std::string>> operators;
5252
std::map<TypeConstructor, std::map<std::string, TypeMember>> members;
53+
std::map<TypeConstructor, Type> underlyingTypes;
5354
};
5455
bool visit(Block const&) override { return true; }
5556
bool visit(VariableDeclarationStatement const&) override { return true; }

Diff for: libsolidity/experimental/codegen/IRGeneratorForStatements.cpp

+122-39
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include <libsolidity/experimental/codegen/Common.h>
3434

3535
#include <range/v3/view/drop_last.hpp>
36+
#include <range/v3/view/zip.hpp>
3637

3738
using namespace solidity;
3839
using namespace solidity::util;
@@ -47,9 +48,70 @@ std::string IRGeneratorForStatements::generate(ASTNode const& _node)
4748
}
4849

4950

50-
namespace
51+
namespace solidity::frontend::experimental
5152
{
5253

54+
static std::size_t stackSize(IRGenerationContext const& _context, Type _type)
55+
{
56+
TypeSystemHelpers helper{_context.analysis.typeSystem()};
57+
_type = _context.env->resolve(_type);
58+
solAssert(std::holds_alternative<TypeConstant>(_type), "No monomorphized type.");
59+
60+
// type -> # stack slots
61+
// unit, itself -> 0
62+
// void, literals(integer), typeFunction -> error (maybe generate a revert)
63+
// word, bool, function -> 1
64+
// pair -> sum(stackSize(args))
65+
// user-defined -> stackSize(underlying type)
66+
TypeConstant typeConstant = std::get<TypeConstant>(_type);
67+
if (
68+
helper.isPrimitiveType(_type, PrimitiveType::Unit) ||
69+
helper.isPrimitiveType(_type, PrimitiveType::Itself)
70+
)
71+
return 0;
72+
else if (
73+
helper.isPrimitiveType(_type, PrimitiveType::Bool) ||
74+
helper.isPrimitiveType(_type, PrimitiveType::Word)
75+
)
76+
{
77+
solAssert(typeConstant.arguments.empty(), "Primitive type Bool or Word should have no arguments.");
78+
return 1;
79+
}
80+
else if (helper.isFunctionType(_type))
81+
return 1;
82+
else if (
83+
helper.isPrimitiveType(_type, PrimitiveType::Integer) ||
84+
helper.isPrimitiveType(_type, PrimitiveType::Void) ||
85+
helper.isPrimitiveType(_type, PrimitiveType::TypeFunction)
86+
)
87+
solAssert(false, "Attempted to query the stack size of a type without stack representation.");
88+
else if (helper.isPrimitiveType(_type, PrimitiveType::Pair))
89+
{
90+
solAssert(typeConstant.arguments.size() == 2);
91+
return stackSize(_context, typeConstant.arguments.front()) + stackSize(_context, typeConstant.arguments.back());
92+
}
93+
else
94+
{
95+
Type underlyingType = _context.env->resolve(
96+
_context.analysis.annotation<TypeInference>().underlyingTypes.at(typeConstant.constructor));
97+
if (helper.isTypeConstant(underlyingType))
98+
return stackSize(_context, underlyingType);
99+
100+
TypeEnvironment env = _context.env->clone();
101+
Type genericFunctionType = helper.typeFunctionType(
102+
helper.tupleType(typeConstant.arguments),
103+
env.typeSystem().freshTypeVariable({}));
104+
solAssert(env.unify(genericFunctionType, underlyingType).empty());
105+
106+
Type resolvedType = env.resolveRecursive(genericFunctionType);
107+
auto [argumentType, resultType] = helper.destTypeFunctionType(resolvedType);
108+
return stackSize(_context, resultType);
109+
}
110+
111+
//TODO: sum types
112+
return 0;
113+
}
114+
53115
struct CopyTranslate: public yul::ASTCopier
54116
{
55117
CopyTranslate(
@@ -101,7 +163,7 @@ struct CopyTranslate: public yul::ASTCopier
101163
auto type = m_context.analysis.annotation<TypeInference>(*varDecl).type;
102164
solAssert(type);
103165
solAssert(m_context.env->typeEquals(*type, m_context.analysis.typeSystem().type(PrimitiveType::Word, {})));
104-
std::string value = IRNames::localVariable(*varDecl);
166+
std::string value = IRVariable{*varDecl, *type, stackSize(m_context, *type)}.name();
105167
return yul::Identifier{_identifier.debugData, yul::YulString{value}};
106168
}
107169

@@ -110,16 +172,14 @@ struct CopyTranslate: public yul::ASTCopier
110172
std::map<yul::Identifier const*, InlineAssemblyAnnotation::ExternalIdentifierInfo> m_references;
111173
};
112174

113-
}
114-
115175
bool IRGeneratorForStatements::visit(TupleExpression const& _tupleExpression)
116176
{
117177
std::vector<std::string> components;
118178
for (auto const& component: _tupleExpression.components())
119179
{
120180
solUnimplementedAssert(component);
121181
component->accept(*this);
122-
components.emplace_back(IRNames::localVariable(*component));
182+
components.emplace_back(var(*component).commaSeparatedList());
123183
}
124184

125185
solUnimplementedAssert(false, "No support for tuples.");
@@ -144,10 +204,11 @@ bool IRGeneratorForStatements::visit(VariableDeclarationStatement const& _variab
144204
VariableDeclaration const* variableDeclaration = _variableDeclarationStatement.declarations().front().get();
145205
solAssert(variableDeclaration);
146206
// TODO: check the type of the variable; register local variable; initialize
147-
m_code << "let " << IRNames::localVariable(*variableDeclaration);
148207
if (_variableDeclarationStatement.initialValue())
149-
m_code << " := " << IRNames::localVariable(*_variableDeclarationStatement.initialValue());
150-
m_code << "\n";
208+
define(var(*variableDeclaration), var(*_variableDeclarationStatement.initialValue()));
209+
else
210+
declare(var(*variableDeclaration));
211+
151212
return false;
152213
}
153214

@@ -158,10 +219,8 @@ bool IRGeneratorForStatements::visit(ExpressionStatement const&)
158219

159220
bool IRGeneratorForStatements::visit(Identifier const& _identifier)
160221
{
161-
if (auto const* var = dynamic_cast<VariableDeclaration const*>(_identifier.annotation().referencedDeclaration))
162-
{
163-
m_code << "let " << IRNames::localVariable(_identifier) << " := " << IRNames::localVariable(*var) << "\n";
164-
}
222+
if (auto const* variable = dynamic_cast<VariableDeclaration const*>(_identifier.annotation().referencedDeclaration))
223+
define(var(_identifier), var(*variable));
165224
else if (auto const* function = dynamic_cast<FunctionDefinition const*>(_identifier.annotation().referencedDeclaration))
166225
solAssert(m_expressionDeclaration.emplace(&_identifier, function).second);
167226
else if (auto const* typeClass = dynamic_cast<TypeClassDefinition const*>(_identifier.annotation().referencedDeclaration))
@@ -179,7 +238,8 @@ void IRGeneratorForStatements::endVisit(Return const& _return)
179238
{
180239
solAssert(_return.annotation().function, "Invalid return.");
181240
solAssert(_return.annotation().function->experimentalReturnExpression(), "Invalid return.");
182-
m_code << IRNames::localVariable(*_return.annotation().function->experimentalReturnExpression()) << " := " << IRNames::localVariable(*value) << "\n";
241+
auto returnExpression = _return.annotation().function->experimentalReturnExpression();
242+
assign(var(*returnExpression), var(*value));
183243
}
184244

185245
m_code << "leave\n";
@@ -201,13 +261,44 @@ void IRGeneratorForStatements::endVisit(BinaryOperation const& _binaryOperation)
201261
Type functionType = helper.functionType(helper.tupleType({leftType, rightType}), resultType);
202262
auto [typeClass, memberName] = m_context.analysis.annotation<TypeInference>().operators.at(_binaryOperation.getOperator());
203263
auto const& functionDefinition = resolveTypeClassFunction(typeClass, memberName, functionType);
204-
// TODO: deduplicate with FunctionCall
264+
std::string result = var(_binaryOperation).commaSeparatedList();
265+
if (!result.empty())
266+
m_code << "let " << result << " := ";
267+
m_code << buildFunctionCall(functionDefinition, functionType, _binaryOperation.arguments());
268+
}
269+
270+
std::string IRGeneratorForStatements::buildFunctionCall(FunctionDefinition const& _functionDefinition, Type _functionType, std::vector<ASTPointer<Expression const>> const& _arguments)
271+
{
272+
// Ensure type is resolved
205273
// TODO: get around resolveRecursive by passing the environment further down?
206-
functionType = m_context.env->resolveRecursive(functionType);
207-
m_context.enqueueFunctionDefinition(&functionDefinition, functionType);
208-
// TODO: account for return stack size
209-
m_code << "let " << IRNames::localVariable(_binaryOperation) << " := " << IRNames::function(*m_context.env, functionDefinition, functionType) << "("
210-
<< IRNames::localVariable(_binaryOperation.leftExpression()) << ", " << IRNames::localVariable(_binaryOperation.rightExpression()) << ")\n";
274+
Type resolvedFunctionType = m_context.env->resolveRecursive(_functionType);
275+
m_context.enqueueFunctionDefinition(&_functionDefinition, resolvedFunctionType);
276+
277+
std::ostringstream output;
278+
output << IRNames::function(*m_context.env, _functionDefinition, resolvedFunctionType) << "(";
279+
if (_arguments.size() == 1)
280+
output << var(*_arguments.back()).commaSeparatedList();
281+
else if (_arguments.size() > 1)
282+
{
283+
for (auto arg: _arguments | ranges::views::drop_last(1))
284+
output << var(*arg).commaSeparatedList();
285+
output << var(*_arguments.back()).commaSeparatedListPrefixed();
286+
}
287+
output << ")\n";
288+
return output.str();
289+
}
290+
291+
void IRGeneratorForStatements::assign(IRVariable const& _lhs, IRVariable const& _rhs, bool _declare)
292+
{
293+
solAssert(stackSize(m_context, _lhs.type()) == stackSize(m_context, _rhs.type()));
294+
for (auto&& [lhsSlot, rhsSlot]: ranges::zip_view(_lhs.stackSlots(), _rhs.stackSlots()))
295+
m_code << (_declare ? "let " : "") << lhsSlot << " := " << rhsSlot << "\n";
296+
}
297+
298+
void IRGeneratorForStatements::declare(IRVariable const& _var)
299+
{
300+
if (_var.stackSize() > 0)
301+
m_code << "let " << _var.commaSeparatedList() << "\n";
211302
}
212303

213304
namespace
@@ -308,32 +399,23 @@ void IRGeneratorForStatements::endVisit(FunctionCall const& _functionCall)
308399
case Builtins::FromBool:
309400
case Builtins::Identity:
310401
solAssert(_functionCall.arguments().size() == 1);
311-
m_code << "let " << IRNames::localVariable(_functionCall) << " := " << IRNames::localVariable(*_functionCall.arguments().front()) << "\n";
402+
define(var(_functionCall), var(*_functionCall.arguments().front()));
312403
return;
313404
case Builtins::ToBool:
314405
solAssert(_functionCall.arguments().size() == 1);
315-
m_code << "let " << IRNames::localVariable(_functionCall) << " := iszero(iszero(" << IRNames::localVariable(*_functionCall.arguments().front()) << "))\n";
406+
m_code << "let " << var(_functionCall).name() << " := iszero(iszero(" << var(*_functionCall.arguments().front()).name() << "))\n";
316407
return;
317408
}
318409
solAssert(false);
319410
}
320411
FunctionDefinition const* functionDefinition = dynamic_cast<FunctionDefinition const*>(std::get<Declaration const*>(declaration));
321412
solAssert(functionDefinition);
322-
// TODO: get around resolveRecursive by passing the environment further down?
323-
functionType = m_context.env->resolveRecursive(functionType);
324-
m_context.enqueueFunctionDefinition(functionDefinition, functionType);
325413
// TODO: account for return stack size
326414
solAssert(!functionDefinition->returnParameterList());
327-
if (functionDefinition->experimentalReturnExpression())
328-
m_code << "let " << IRNames::localVariable(_functionCall) << " := ";
329-
m_code << IRNames::function(*m_context.env, *functionDefinition, functionType) << "(";
330-
auto const& arguments = _functionCall.arguments();
331-
if (arguments.size() > 1)
332-
for (auto arg: arguments | ranges::views::drop_last(1))
333-
m_code << IRNames::localVariable(*arg) << ", ";
334-
if (!arguments.empty())
335-
m_code << IRNames::localVariable(*arguments.back());
336-
m_code << ")\n";
415+
std::string result = var(_functionCall).commaSeparatedList();
416+
if (!result.empty())
417+
m_code << "let " << result << " := ";
418+
m_code << buildFunctionCall(*functionDefinition, functionType, _functionCall.arguments());
337419
}
338420

339421
bool IRGeneratorForStatements::visit(FunctionCall const&)
@@ -356,7 +438,7 @@ bool IRGeneratorForStatements::visit(IfStatement const& _ifStatement)
356438
_ifStatement.condition().accept(*this);
357439
if (_ifStatement.falseStatement())
358440
{
359-
m_code << "switch " << IRNames::localVariable(_ifStatement.condition()) << " {\n";
441+
m_code << "switch " << var(_ifStatement.condition()).name() << " {\n";
360442
m_code << "case 0 {\n";
361443
_ifStatement.falseStatement()->accept(*this);
362444
m_code << "}\n";
@@ -366,7 +448,7 @@ bool IRGeneratorForStatements::visit(IfStatement const& _ifStatement)
366448
}
367449
else
368450
{
369-
m_code << "if " << IRNames::localVariable(_ifStatement.condition()) << " {\n";
451+
m_code << "if " << var(_ifStatement.condition()).name() << " {\n";
370452
_ifStatement.trueStatement().accept(*this);
371453
m_code << "}\n";
372454
}
@@ -380,9 +462,8 @@ bool IRGeneratorForStatements::visit(Assignment const& _assignment)
380462
solAssert(lhs, "Can only assign to identifiers.");
381463
auto const* lhsVar = dynamic_cast<VariableDeclaration const*>(lhs->annotation().referencedDeclaration);
382464
solAssert(lhsVar, "Can only assign to identifiers referring to variables.");
383-
m_code << IRNames::localVariable(*lhsVar) << " := " << IRNames::localVariable(_assignment.rightHandSide()) << "\n";
384-
385-
m_code << "let " << IRNames::localVariable(_assignment) << " := " << IRNames::localVariable(*lhsVar) << "\n";
465+
assign(var(*lhsVar), var(_assignment.rightHandSide()));
466+
define(var(_assignment), var(*lhsVar));
386467
return false;
387468
}
388469

@@ -391,3 +472,5 @@ bool IRGeneratorForStatements::visitNode(ASTNode const&)
391472
{
392473
solAssert(false, "Unsupported AST node during statement code generation.");
393474
}
475+
476+
}

Diff for: libsolidity/experimental/codegen/IRGeneratorForStatements.h

+20
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#pragma once
2020

2121
#include <libsolidity/experimental/codegen/IRGenerationContext.h>
22+
#include <libsolidity/experimental/codegen/IRVariable.h>
23+
2224
#include <libsolidity/ast/ASTVisitor.h>
2325

2426
#include <functional>
@@ -34,6 +36,7 @@ class IRGeneratorForStatements: public ASTConstVisitor
3436
IRGeneratorForStatements(IRGenerationContext& _context): m_context(_context) {}
3537

3638
std::string generate(ASTNode const& _node);
39+
3740
private:
3841
bool visit(ExpressionStatement const& _expressionStatement) override;
3942
bool visit(Block const& _block) override;
@@ -54,6 +57,14 @@ class IRGeneratorForStatements: public ASTConstVisitor
5457
void endVisit(Return const& _return) override;
5558
/// Default visit will reject all AST nodes that are not explicitly supported.
5659
bool visitNode(ASTNode const& _node) override;
60+
61+
/// Defines @a _var using the value of @a _value. It declares and assign the variable.
62+
void define(IRVariable const& _var, IRVariable const& _value) { assign(_var, _value, true); }
63+
/// Assigns @a _var to the value of @a _value. It does not declare the variable.
64+
void assign(IRVariable const& _var, IRVariable const& _value, bool _declare = false);
65+
/// Declares variable @a _var.
66+
void declare(IRVariable const& _var);
67+
5768
IRGenerationContext& m_context;
5869
std::stringstream m_code;
5970
enum class Builtins
@@ -63,6 +74,15 @@ class IRGeneratorForStatements: public ASTConstVisitor
6374
ToBool
6475
};
6576
std::map<Expression const*, std::variant<Declaration const*, Builtins>> m_expressionDeclaration;
77+
78+
std::string buildFunctionCall(FunctionDefinition const& _functionDefinition, Type _functionType, std::vector<ASTPointer<Expression const>> const& _arguments);
79+
80+
template<typename IRVariableType>
81+
IRVariable var(IRVariableType const& _var) const
82+
{
83+
return IRVariable(_var, type(_var), stackSize(m_context, type(_var)));
84+
}
85+
6686
Type type(ASTNode const& _node) const;
6787

6888
FunctionDefinition const& resolveTypeClassFunction(TypeClass _class, std::string _name, Type _type);

0 commit comments

Comments
 (0)