Skip to content

Experimental codegen for types of different stack sizes #14620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions libsolidity/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ set(sources
experimental/ast/TypeSystemHelper.h
experimental/codegen/Common.h
experimental/codegen/Common.cpp
experimental/codegen/IRVariable.cpp
experimental/codegen/IRVariable.h
experimental/codegen/IRGenerationContext.h
experimental/codegen/IRGenerator.cpp
experimental/codegen/IRGenerator.h
Expand Down
3 changes: 3 additions & 0 deletions libsolidity/ast/AST.h
Original file line number Diff line number Diff line change
Expand Up @@ -2162,6 +2162,9 @@ class BinaryOperation: public Expression
Expression const& rightExpression() const { return *m_right; }
Token getOperator() const { return m_operator; }

/// @returns the given arguments in the order they were written.
std::vector<ASTPointer<Expression const>> arguments() const { return {m_left, m_right}; }

FunctionType const* userDefinedFunctionType() const;

BinaryOperationAnnotation& annotation() const override;
Expand Down
1 change: 1 addition & 0 deletions libsolidity/experimental/analysis/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,7 @@ bool TypeInference::visit(TypeDefinition const& _typeDefinition)

members->second.emplace("abs", TypeMember{helper.functionType(*underlyingType, definedType)});
members->second.emplace("rep", TypeMember{helper.functionType(definedType, *underlyingType)});
annotation().underlyingTypes[constructor] = *underlyingType;
}

if (helper.isPrimitiveType(definedType, PrimitiveType::Pair))
Expand Down
1 change: 1 addition & 0 deletions libsolidity/experimental/analysis/TypeInference.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class TypeInference: public ASTConstVisitor
std::map<TypeClass, std::map<std::string, Type>> typeClassFunctions;
std::map<Token, std::tuple<TypeClass, std::string>> operators;
std::map<TypeConstructor, std::map<std::string, TypeMember>> members;
std::map<TypeConstructor, Type> underlyingTypes;
};
bool visit(Block const&) override { return true; }
bool visit(VariableDeclarationStatement const&) override { return true; }
Expand Down
155 changes: 119 additions & 36 deletions libsolidity/experimental/codegen/IRGeneratorForStatements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <libsolidity/experimental/codegen/Common.h>

#include <range/v3/view/drop_last.hpp>
#include <range/v3/view/zip.hpp>

using namespace solidity;
using namespace solidity::util;
Expand Down Expand Up @@ -101,7 +102,7 @@ struct CopyTranslate: public yul::ASTCopier
auto type = m_context.analysis.annotation<TypeInference>(*varDecl).type;
solAssert(type);
solAssert(m_context.env->typeEquals(*type, m_context.analysis.typeSystem().type(PrimitiveType::Word, {})));
std::string value = IRNames::localVariable(*varDecl);
std::string value = IRVariable{*varDecl, *type, IRGeneratorForStatements::stackSize(m_context, *type)}.name();
return yul::Identifier{_identifier.debugData, yul::YulString{value}};
}

Expand All @@ -112,14 +113,75 @@ struct CopyTranslate: public yul::ASTCopier

}

std::size_t IRGeneratorForStatements::stackSize(IRGenerationContext const& _context, Type _type)
{
TypeSystemHelpers helper{_context.analysis.typeSystem()};
_type = _context.env->resolve(_type);
solAssert(std::holds_alternative<TypeConstant>(_type), "No monomorphized type.");

// type -> # stack slots
// unit, itself -> 0
// void, literals(integer), typeFunction -> error (maybe generate a revert)
// word, bool, function -> 1
// pair -> sum(stackSize(args))
// user-defined -> stackSize(underlying type)
TypeConstant typeConstant = std::get<TypeConstant>(_type);
if (
helper.isPrimitiveType(_type, PrimitiveType::Unit) ||
helper.isPrimitiveType(_type, PrimitiveType::Itself)
)
return 0;
else if (
helper.isPrimitiveType(_type, PrimitiveType::Bool) ||
helper.isPrimitiveType(_type, PrimitiveType::Word)
)
{
solAssert(typeConstant.arguments.empty(), "Primitive type Bool or Word should have no arguments.");
return 1;
}
else if (helper.isFunctionType(_type))
return 1;
else if (
helper.isPrimitiveType(_type, PrimitiveType::Integer) ||
helper.isPrimitiveType(_type, PrimitiveType::Void) ||
helper.isPrimitiveType(_type, PrimitiveType::TypeFunction)
)
solAssert(false, "Attempted to query the stack size of a type without stack representation.");
else if (helper.isPrimitiveType(_type, PrimitiveType::Pair))
{
solAssert(typeConstant.arguments.size() == 2);
return stackSize(_context, typeConstant.arguments.front()) + stackSize(_context, typeConstant.arguments.back());
}
else
{
Type underlyingType = _context.env->resolve(
_context.analysis.annotation<TypeInference>().underlyingTypes.at(typeConstant.constructor));
if (helper.isTypeConstant(underlyingType))
return stackSize(_context, underlyingType);

TypeEnvironment env = _context.env->clone();
Type genericFunctionType = helper.typeFunctionType(
helper.tupleType(typeConstant.arguments),
env.typeSystem().freshTypeVariable({}));
solAssert(env.unify(genericFunctionType, underlyingType).empty());

Type resolvedType = env.resolveRecursive(genericFunctionType);
auto [argumentType, resultType] = helper.destTypeFunctionType(resolvedType);
return stackSize(_context, resultType);
}

//TODO: sum types
return 0;
}

bool IRGeneratorForStatements::visit(TupleExpression const& _tupleExpression)
{
std::vector<std::string> components;
for (auto const& component: _tupleExpression.components())
{
solUnimplementedAssert(component);
component->accept(*this);
components.emplace_back(IRNames::localVariable(*component));
components.emplace_back(var(*component).commaSeparatedList());
}

solUnimplementedAssert(false, "No support for tuples.");
Expand All @@ -144,10 +206,11 @@ bool IRGeneratorForStatements::visit(VariableDeclarationStatement const& _variab
VariableDeclaration const* variableDeclaration = _variableDeclarationStatement.declarations().front().get();
solAssert(variableDeclaration);
// TODO: check the type of the variable; register local variable; initialize
m_code << "let " << IRNames::localVariable(*variableDeclaration);
if (_variableDeclarationStatement.initialValue())
m_code << " := " << IRNames::localVariable(*_variableDeclarationStatement.initialValue());
m_code << "\n";
define(var(*variableDeclaration), var(*_variableDeclarationStatement.initialValue()));
else
declare(var(*variableDeclaration));

return false;
}

Expand All @@ -158,10 +221,8 @@ bool IRGeneratorForStatements::visit(ExpressionStatement const&)

bool IRGeneratorForStatements::visit(Identifier const& _identifier)
{
if (auto const* var = dynamic_cast<VariableDeclaration const*>(_identifier.annotation().referencedDeclaration))
{
m_code << "let " << IRNames::localVariable(_identifier) << " := " << IRNames::localVariable(*var) << "\n";
}
if (auto const* variable = dynamic_cast<VariableDeclaration const*>(_identifier.annotation().referencedDeclaration))
define(var(_identifier), var(*variable));
else if (auto const* function = dynamic_cast<FunctionDefinition const*>(_identifier.annotation().referencedDeclaration))
solAssert(m_expressionDeclaration.emplace(&_identifier, function).second);
else if (auto const* typeClass = dynamic_cast<TypeClassDefinition const*>(_identifier.annotation().referencedDeclaration))
Expand All @@ -179,7 +240,8 @@ void IRGeneratorForStatements::endVisit(Return const& _return)
{
solAssert(_return.annotation().function, "Invalid return.");
solAssert(_return.annotation().function->experimentalReturnExpression(), "Invalid return.");
m_code << IRNames::localVariable(*_return.annotation().function->experimentalReturnExpression()) << " := " << IRNames::localVariable(*value) << "\n";
auto returnExpression = _return.annotation().function->experimentalReturnExpression();
assign(var(*returnExpression), var(*value));
}

m_code << "leave\n";
Expand All @@ -201,13 +263,44 @@ void IRGeneratorForStatements::endVisit(BinaryOperation const& _binaryOperation)
Type functionType = helper.functionType(helper.tupleType({leftType, rightType}), resultType);
auto [typeClass, memberName] = m_context.analysis.annotation<TypeInference>().operators.at(_binaryOperation.getOperator());
auto const& functionDefinition = resolveTypeClassFunction(typeClass, memberName, functionType);
// TODO: deduplicate with FunctionCall
std::string result = var(_binaryOperation).commaSeparatedList();
if (!result.empty())
m_code << "let " << result << " := ";
m_code << buildFunctionCall(functionDefinition, functionType, _binaryOperation.arguments());
}

std::string IRGeneratorForStatements::buildFunctionCall(FunctionDefinition const& _functionDefinition, Type _functionType, std::vector<ASTPointer<Expression const>> const& _arguments)
{
// Ensure type is resolved
// TODO: get around resolveRecursive by passing the environment further down?
functionType = m_context.env->resolveRecursive(functionType);
m_context.enqueueFunctionDefinition(&functionDefinition, functionType);
// TODO: account for return stack size
m_code << "let " << IRNames::localVariable(_binaryOperation) << " := " << IRNames::function(*m_context.env, functionDefinition, functionType) << "("
<< IRNames::localVariable(_binaryOperation.leftExpression()) << ", " << IRNames::localVariable(_binaryOperation.rightExpression()) << ")\n";
Type resolvedFunctionType = m_context.env->resolveRecursive(_functionType);
m_context.enqueueFunctionDefinition(&_functionDefinition, resolvedFunctionType);

std::ostringstream output;
output << IRNames::function(*m_context.env, _functionDefinition, resolvedFunctionType) << "(";
if (_arguments.size() == 1)
output << var(*_arguments.back()).commaSeparatedList();
else if (_arguments.size() > 1)
{
for (auto arg: _arguments | ranges::views::drop_last(1))
output << var(*arg).commaSeparatedList();
output << var(*_arguments.back()).commaSeparatedListPrefixed();
}
output << ")\n";
return output.str();
}

void IRGeneratorForStatements::assign(IRVariable const& _lhs, IRVariable const& _rhs, bool _declare)
{
solAssert(stackSize(m_context, _lhs.type()) == stackSize(m_context, _rhs.type()));
for (auto&& [lhsSlot, rhsSlot]: ranges::zip_view(_lhs.stackSlots(), _rhs.stackSlots()))
m_code << (_declare ? "let " : "") << lhsSlot << " := " << rhsSlot << "\n";
}

void IRGeneratorForStatements::declare(IRVariable const& _var)
{
if (_var.stackSize() > 0)
m_code << "let " << _var.commaSeparatedList() << "\n";
}

namespace
Expand Down Expand Up @@ -308,32 +401,23 @@ void IRGeneratorForStatements::endVisit(FunctionCall const& _functionCall)
case Builtins::FromBool:
case Builtins::Identity:
solAssert(_functionCall.arguments().size() == 1);
m_code << "let " << IRNames::localVariable(_functionCall) << " := " << IRNames::localVariable(*_functionCall.arguments().front()) << "\n";
define(var(_functionCall), var(*_functionCall.arguments().front()));
return;
case Builtins::ToBool:
solAssert(_functionCall.arguments().size() == 1);
m_code << "let " << IRNames::localVariable(_functionCall) << " := iszero(iszero(" << IRNames::localVariable(*_functionCall.arguments().front()) << "))\n";
m_code << "let " << var(_functionCall).name() << " := iszero(iszero(" << var(*_functionCall.arguments().front()).name() << "))\n";
return;
}
solAssert(false);
}
FunctionDefinition const* functionDefinition = dynamic_cast<FunctionDefinition const*>(std::get<Declaration const*>(declaration));
solAssert(functionDefinition);
// TODO: get around resolveRecursive by passing the environment further down?
functionType = m_context.env->resolveRecursive(functionType);
m_context.enqueueFunctionDefinition(functionDefinition, functionType);
// TODO: account for return stack size
solAssert(!functionDefinition->returnParameterList());
if (functionDefinition->experimentalReturnExpression())
m_code << "let " << IRNames::localVariable(_functionCall) << " := ";
m_code << IRNames::function(*m_context.env, *functionDefinition, functionType) << "(";
auto const& arguments = _functionCall.arguments();
if (arguments.size() > 1)
for (auto arg: arguments | ranges::views::drop_last(1))
m_code << IRNames::localVariable(*arg) << ", ";
if (!arguments.empty())
m_code << IRNames::localVariable(*arguments.back());
m_code << ")\n";
std::string result = var(_functionCall).commaSeparatedList();
if (!result.empty())
m_code << "let " << result << " := ";
m_code << buildFunctionCall(*functionDefinition, functionType, _functionCall.arguments());
}

bool IRGeneratorForStatements::visit(FunctionCall const&)
Expand All @@ -356,7 +440,7 @@ bool IRGeneratorForStatements::visit(IfStatement const& _ifStatement)
_ifStatement.condition().accept(*this);
if (_ifStatement.falseStatement())
{
m_code << "switch " << IRNames::localVariable(_ifStatement.condition()) << " {\n";
m_code << "switch " << var(_ifStatement.condition()).name() << " {\n";
m_code << "case 0 {\n";
_ifStatement.falseStatement()->accept(*this);
m_code << "}\n";
Expand All @@ -366,7 +450,7 @@ bool IRGeneratorForStatements::visit(IfStatement const& _ifStatement)
}
else
{
m_code << "if " << IRNames::localVariable(_ifStatement.condition()) << " {\n";
m_code << "if " << var(_ifStatement.condition()).name() << " {\n";
_ifStatement.trueStatement().accept(*this);
m_code << "}\n";
}
Expand All @@ -380,9 +464,8 @@ bool IRGeneratorForStatements::visit(Assignment const& _assignment)
solAssert(lhs, "Can only assign to identifiers.");
auto const* lhsVar = dynamic_cast<VariableDeclaration const*>(lhs->annotation().referencedDeclaration);
solAssert(lhsVar, "Can only assign to identifiers referring to variables.");
m_code << IRNames::localVariable(*lhsVar) << " := " << IRNames::localVariable(_assignment.rightHandSide()) << "\n";

m_code << "let " << IRNames::localVariable(_assignment) << " := " << IRNames::localVariable(*lhsVar) << "\n";
assign(var(*lhsVar), var(_assignment.rightHandSide()));
define(var(_assignment), var(*lhsVar));
return false;
}

Expand Down
20 changes: 20 additions & 0 deletions libsolidity/experimental/codegen/IRGeneratorForStatements.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#pragma once

#include <libsolidity/experimental/codegen/IRGenerationContext.h>
#include <libsolidity/experimental/codegen/IRVariable.h>

#include <libsolidity/ast/ASTVisitor.h>

#include <functional>
Expand All @@ -34,6 +36,7 @@ class IRGeneratorForStatements: public ASTConstVisitor
IRGeneratorForStatements(IRGenerationContext& _context): m_context(_context) {}

std::string generate(ASTNode const& _node);
static std::size_t stackSize(IRGenerationContext const& _context, Type _type);
private:
bool visit(ExpressionStatement const& _expressionStatement) override;
bool visit(Block const& _block) override;
Expand All @@ -54,6 +57,14 @@ class IRGeneratorForStatements: public ASTConstVisitor
void endVisit(Return const& _return) override;
/// Default visit will reject all AST nodes that are not explicitly supported.
bool visitNode(ASTNode const& _node) override;

/// Defines @a _var using the value of @a _value. It declares and assign the variable.
void define(IRVariable const& _var, IRVariable const& _value) { assign(_var, _value, true); }
/// Assigns @a _var to the value of @a _value. It does not declare the variable.
void assign(IRVariable const& _var, IRVariable const& _value, bool _declare = false);
/// Declares variable @a _var.
void declare(IRVariable const& _var);

IRGenerationContext& m_context;
std::stringstream m_code;
enum class Builtins
Expand All @@ -63,6 +74,15 @@ class IRGeneratorForStatements: public ASTConstVisitor
ToBool
};
std::map<Expression const*, std::variant<Declaration const*, Builtins>> m_expressionDeclaration;

std::string buildFunctionCall(FunctionDefinition const& _functionDefinition, Type _functionType, std::vector<ASTPointer<Expression const>> const& _arguments);

template<typename IRVariableType>
IRVariable var(IRVariableType const& _var) const
{
return IRVariable(_var, type(_var), stackSize(m_context, type(_var)));
}

Type type(ASTNode const& _node) const;

FunctionDefinition const& resolveTypeClassFunction(TypeClass _class, std::string _name, Type _type);
Expand Down
Loading