Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Complete interpretation #10

Merged
merged 7 commits into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 107 additions & 15 deletions interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,40 @@

from generated.MyParser import MyParser
from generated.MyParserVisitor import MyParserVisitor
from utils.values import Int, Float, String, Vector
from utils.memory import MemoryStack
from utils.values import Value, Int, Float, String, Vector


class Break(Exception):
pass


class Continue(Exception):
pass


def not_same_type(a: Value, b: Value):
return type(a) is not type(b) or (
isinstance(a, Vector)
and (a.dims != b.dims or a.primitive_type != b.primitive_type)
)


class Interpreter(MyParserVisitor):
def __init__(self):
self.memory_stack = MemoryStack()
self.memory_stack.push_memory()

def visitScopeStatement(self, ctx: MyParser.ScopeStatementContext):
return self.visitChildren(ctx) # todo
self.memory_stack.push_memory()
self.visitChildren(ctx)
self.memory_stack.pop_memory()

def visitIfThenElse(self, ctx: MyParser.IfThenElseContext):
condition = self.visit(ctx.if_())
if condition:
return self.visit(ctx.then())
elif ctx.else_() is not None:
elif ctx.else_():
return self.visit(ctx.else_())

def visitIf(self, ctx: MyParser.IfContext):
Expand All @@ -24,13 +46,32 @@ def visitElse(self, ctx: MyParser.ElseContext):
return self.visit(ctx.statement())

def visitForLoop(self, ctx: MyParser.ForLoopContext):
return self.visitChildren(ctx) # todo
a, b = self.visit(ctx.range_())
variable = ctx.id_().getText()
for i in range(a, b + 1):
self.memory_stack.put(variable, Int(i))
try:
self.visit(ctx.statement())
except Continue:
continue
except Break:
break

def visitRange(self, ctx: MyParser.RangeContext):
return self.visitChildren(ctx) # todo
a = self.visit(ctx.expression(0))
b = self.visit(ctx.expression(1))
if {type(a), type(b)} != {Int}:
raise TypeError
return (a.value, b.value)

def visitWhileLoop(self, ctx: MyParser.WhileLoopContext):
return self.visitChildren(ctx) # todo
while self.visit(ctx.comparison()):
try:
self.visit(ctx.statement())
except Continue:
continue
except Break:
break

def visitComparison(self, ctx: MyParser.ComparisonContext):
a = self.visit(ctx.expression(0))
Expand All @@ -50,17 +91,51 @@ def visitComparison(self, ctx: MyParser.ComparisonContext):
return a >= b

def visitSimpleAssignment(self, ctx: MyParser.SimpleAssignmentContext):
return self.visitChildren(ctx) # todo
if ctx.id_(): # a = 1
self.memory_stack.put(ctx.id_().getText(), self.visit(ctx.expression()))
else: # a[0] = 1
ref_value = self.visit(ctx.elementReference())
new_value = self.visit(ctx.expression())
if not_same_type(ref_value, new_value):
raise TypeError
ref_value.value = new_value.value

def visitCompoundAssignment(self, ctx: MyParser.CompoundAssignmentContext):
return self.visitChildren(ctx) # todo
if ctx.id_(): # a += 1
value = self.memory_stack.get(ctx.id_().getText())
new_value = self.visit(ctx.expression())
match ctx.getChild(1).symbol.type:
case MyParser.ASSIGN_PLUS:
new_value = value + new_value
case MyParser.ASSIGN_MINUS:
new_value = value - new_value
case MyParser.ASSIGN_MULTIPLY:
new_value = value * new_value
case MyParser.ASSIGN_DIVIDE:
new_value = value / new_value
self.memory_stack.put(ctx.id_().getText(), new_value)
else: # a[0] += 1
ref_value = self.visit(ctx.elementReference())
new_value = self.visit(ctx.expression())
if not_same_type(ref_value, new_value):
raise TypeError
match ctx.getChild(1).symbol.type:
case MyParser.ASSIGN_PLUS:
new_value = ref_value + new_value
case MyParser.ASSIGN_MINUS:
new_value = ref_value - new_value
case MyParser.ASSIGN_MULTIPLY:
new_value = ref_value * new_value
case MyParser.ASSIGN_DIVIDE:
new_value = ref_value / new_value
ref_value.value = new_value.value

def visitPrint(self, ctx: MyParser.PrintContext):
for i in range(ctx.getChildCount() // 2):
print(str(self.visit(ctx.expression(i))))

def visitReturn(self, ctx: MyParser.ReturnContext):
if ctx.expression() is not None:
if ctx.expression():
return_value = self.visit(ctx.expression())
if not isinstance(return_value, Int):
raise TypeError
Expand All @@ -79,7 +154,14 @@ def visitBinaryExpression(self, ctx: MyParser.BinaryExpressionContext):
return a * b
case MyParser.DIVIDE:
return a / b
# todo: MAT_* operations
case MyParser.MAT_PLUS:
return a.mat_add(b)
case MyParser.MAT_MINUS:
return a.mat_sub(b)
case MyParser.MAT_MULTIPLY:
return a.mat_mul(b)
case MyParser.MAT_DIVIDE:
return a.mat_truediv(b)

def visitParenthesesExpression(self, ctx: MyParser.ParenthesesExpressionContext):
return self.visit(ctx.expression())
Expand Down Expand Up @@ -117,10 +199,10 @@ def visitSpecialMatrixFunction(self, ctx: MyParser.SpecialMatrixFunctionContext)
return vector

def visitBreak(self, ctx: MyParser.BreakContext):
return self.visitChildren(ctx) # todo
raise Break()

def visitContinue(self, ctx: MyParser.ContinueContext):
return self.visitChildren(ctx) # todo
raise Continue()

def visitVector(self, ctx: MyParser.VectorContext):
elements = [
Expand All @@ -129,10 +211,20 @@ def visitVector(self, ctx: MyParser.VectorContext):
return Vector(elements)

def visitElementReference(self, ctx: MyParser.ElementReferenceContext):
return self.visitChildren(ctx) # todo
indices = [
self.visit(ctx.expression(i)) for i in range(ctx.getChildCount() // 2 - 1)
]
if {type(idx) for idx in indices} != {Int}:
raise TypeError
result = self.visit(ctx.id_())
for idx in indices:
if not isinstance(result, Vector):
raise TypeError
result = result.value[idx.value]
return result

def visitId(self, ctx: MyParser.IdContext):
return self.visitChildren(ctx) # todo
return self.memory_stack.get(ctx.getText())

def visitInt(self, ctx: MyParser.IntContext):
return Int(ctx.getText())
Expand All @@ -141,4 +233,4 @@ def visitFloat(self, ctx: MyParser.FloatContext):
return Float(ctx.getText())

def visitString(self, ctx: MyParser.StringContext):
return String(ctx.getText())
return String(ctx.getText()[1:-1]) # without quotes
5 changes: 3 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ def run(filename: str):

tree = parser.program()
if parser.getNumberOfSyntaxErrors() == 0:
listener = SemanticListener()
ParseTreeWalker().walk(listener, tree)
# todo: Fix SemanticListener
# listener = SemanticListener()
# ParseTreeWalker().walk(listener, tree)
if parser.getNumberOfSyntaxErrors() == 0:
visitor = Interpreter()
visitor.visit(tree)
Expand Down
22 changes: 22 additions & 0 deletions test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,28 @@ def test_sem_errors(name: str, line_numbers: list[int], additional: str):
[[1, 1, 1]],
],
),
("variables", [2, 1, 3, "OK", 6]),
("while", [4, 3, 2, 1, 0]),
("for", [1, 10, 2, 10, 3, 10, 4, 10]),
("break_continue", [1, 2, 1, 2, 4] * 2),
(
"element_reference",
[
[1, 0],
0,
[[1, 2], [0, 1]],
[[0, 2], [0, 1]],
[[0, 0], [0, 1]],
],
),
(
"mat_operators",
[
[[2, 2], [2, 2]],
[[4, 4], [4, 4]],
[[3, 3], [3, 3]],
],
),
],
)
def test_interpreter(name: str, output: str):
Expand Down
27 changes: 27 additions & 0 deletions tests/interpreter/break_continue.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
for i = 1:4 {
if (i == 3)
break;
print i;
}

for i = 1:4 {
if (i == 3)
continue;
print i;
}

i = 0;
while (i < 4) {
i += 1;
if (i == 3)
break;
print i;
}

i = 0;
while (i < 4) {
i += 1;
if (i == 3)
continue;
print i;
}
9 changes: 9 additions & 0 deletions tests/interpreter/element_reference.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
A = eye(2);
print A[0];
print A[0, 1];
A[0, 1] = 2;
print A;
A[0] = [0, 2];
print A;
A[0, 1] -= 2;
print A;
6 changes: 6 additions & 0 deletions tests/interpreter/for.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
n = 4;
for i = 1:n {
print i;
i = 10;
print i;
}
9 changes: 9 additions & 0 deletions tests/interpreter/mat_operators.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
A = ones(2, 2);
B = ones(2, 2);
A = A .+ B;
print A;
A = A .* A;
print A;
A = A .- B;
print A;
A = A ./ B;
16 changes: 16 additions & 0 deletions tests/interpreter/variables.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
a = 2;
print a;

a -= 1;
print a;

a *= 3;
print a;

if (a == 3) {
b = "OK";
print b;
}

b = 2 * a;
print b;
7 changes: 7 additions & 0 deletions tests/interpreter/while.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
b = -4;
a = 4;
while (a >= b) {
print a;
a -= 1;
b += 1;
}
38 changes: 38 additions & 0 deletions utils/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from .values import Value


class Memory:
def __init__(self):
self.variables: dict[str, Value] = {}

def has_variable(self, name: str) -> bool:
return name in self.variables

def get(self, name: str) -> Value:
return self.variables[name]

def put(self, name: str, value: Value):
self.variables[name] = value


class MemoryStack:
def __init__(self):
self.stack: list[Memory] = []

def get(self, name: str) -> Value:
for memory in self.stack:
if memory.has_variable(name):
return memory.get(name)

def put(self, name: str, value: Value):
for memory in self.stack:
if memory.has_variable(name):
memory.put(name, value)
return
self.stack[-1].put(name, value)

def push_memory(self):
self.stack.append(Memory())

def pop_memory(self):
self.stack.pop()
Loading
Loading