Skip to content

Commit ff83d67

Browse files
authored
Merge pull request #10 from mdbrnowski/interpretation
Complete interpretation
2 parents ef43a2f + 0eef1fc commit ff83d67

11 files changed

+295
-18
lines changed

interpreter.py

+107-15
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,40 @@
33

44
from generated.MyParser import MyParser
55
from generated.MyParserVisitor import MyParserVisitor
6-
from utils.values import Int, Float, String, Vector
6+
from utils.memory import MemoryStack
7+
from utils.values import Value, Int, Float, String, Vector
8+
9+
10+
class Break(Exception):
11+
pass
12+
13+
14+
class Continue(Exception):
15+
pass
16+
17+
18+
def not_same_type(a: Value, b: Value):
19+
return type(a) is not type(b) or (
20+
isinstance(a, Vector)
21+
and (a.dims != b.dims or a.primitive_type != b.primitive_type)
22+
)
723

824

925
class Interpreter(MyParserVisitor):
26+
def __init__(self):
27+
self.memory_stack = MemoryStack()
28+
self.memory_stack.push_memory()
29+
1030
def visitScopeStatement(self, ctx: MyParser.ScopeStatementContext):
11-
return self.visitChildren(ctx) # todo
31+
self.memory_stack.push_memory()
32+
self.visitChildren(ctx)
33+
self.memory_stack.pop_memory()
1234

1335
def visitIfThenElse(self, ctx: MyParser.IfThenElseContext):
1436
condition = self.visit(ctx.if_())
1537
if condition:
1638
return self.visit(ctx.then())
17-
elif ctx.else_() is not None:
39+
elif ctx.else_():
1840
return self.visit(ctx.else_())
1941

2042
def visitIf(self, ctx: MyParser.IfContext):
@@ -24,13 +46,32 @@ def visitElse(self, ctx: MyParser.ElseContext):
2446
return self.visit(ctx.statement())
2547

2648
def visitForLoop(self, ctx: MyParser.ForLoopContext):
27-
return self.visitChildren(ctx) # todo
49+
a, b = self.visit(ctx.range_())
50+
variable = ctx.id_().getText()
51+
for i in range(a, b + 1):
52+
self.memory_stack.put(variable, Int(i))
53+
try:
54+
self.visit(ctx.statement())
55+
except Continue:
56+
continue
57+
except Break:
58+
break
2859

2960
def visitRange(self, ctx: MyParser.RangeContext):
30-
return self.visitChildren(ctx) # todo
61+
a = self.visit(ctx.expression(0))
62+
b = self.visit(ctx.expression(1))
63+
if {type(a), type(b)} != {Int}:
64+
raise TypeError
65+
return (a.value, b.value)
3166

3267
def visitWhileLoop(self, ctx: MyParser.WhileLoopContext):
33-
return self.visitChildren(ctx) # todo
68+
while self.visit(ctx.comparison()):
69+
try:
70+
self.visit(ctx.statement())
71+
except Continue:
72+
continue
73+
except Break:
74+
break
3475

3576
def visitComparison(self, ctx: MyParser.ComparisonContext):
3677
a = self.visit(ctx.expression(0))
@@ -50,17 +91,51 @@ def visitComparison(self, ctx: MyParser.ComparisonContext):
5091
return a >= b
5192

5293
def visitSimpleAssignment(self, ctx: MyParser.SimpleAssignmentContext):
53-
return self.visitChildren(ctx) # todo
94+
if ctx.id_(): # a = 1
95+
self.memory_stack.put(ctx.id_().getText(), self.visit(ctx.expression()))
96+
else: # a[0] = 1
97+
ref_value = self.visit(ctx.elementReference())
98+
new_value = self.visit(ctx.expression())
99+
if not_same_type(ref_value, new_value):
100+
raise TypeError
101+
ref_value.value = new_value.value
54102

55103
def visitCompoundAssignment(self, ctx: MyParser.CompoundAssignmentContext):
56-
return self.visitChildren(ctx) # todo
104+
if ctx.id_(): # a += 1
105+
value = self.memory_stack.get(ctx.id_().getText())
106+
new_value = self.visit(ctx.expression())
107+
match ctx.getChild(1).symbol.type:
108+
case MyParser.ASSIGN_PLUS:
109+
new_value = value + new_value
110+
case MyParser.ASSIGN_MINUS:
111+
new_value = value - new_value
112+
case MyParser.ASSIGN_MULTIPLY:
113+
new_value = value * new_value
114+
case MyParser.ASSIGN_DIVIDE:
115+
new_value = value / new_value
116+
self.memory_stack.put(ctx.id_().getText(), new_value)
117+
else: # a[0] += 1
118+
ref_value = self.visit(ctx.elementReference())
119+
new_value = self.visit(ctx.expression())
120+
if not_same_type(ref_value, new_value):
121+
raise TypeError
122+
match ctx.getChild(1).symbol.type:
123+
case MyParser.ASSIGN_PLUS:
124+
new_value = ref_value + new_value
125+
case MyParser.ASSIGN_MINUS:
126+
new_value = ref_value - new_value
127+
case MyParser.ASSIGN_MULTIPLY:
128+
new_value = ref_value * new_value
129+
case MyParser.ASSIGN_DIVIDE:
130+
new_value = ref_value / new_value
131+
ref_value.value = new_value.value
57132

58133
def visitPrint(self, ctx: MyParser.PrintContext):
59134
for i in range(ctx.getChildCount() // 2):
60135
print(str(self.visit(ctx.expression(i))))
61136

62137
def visitReturn(self, ctx: MyParser.ReturnContext):
63-
if ctx.expression() is not None:
138+
if ctx.expression():
64139
return_value = self.visit(ctx.expression())
65140
if not isinstance(return_value, Int):
66141
raise TypeError
@@ -79,7 +154,14 @@ def visitBinaryExpression(self, ctx: MyParser.BinaryExpressionContext):
79154
return a * b
80155
case MyParser.DIVIDE:
81156
return a / b
82-
# todo: MAT_* operations
157+
case MyParser.MAT_PLUS:
158+
return a.mat_add(b)
159+
case MyParser.MAT_MINUS:
160+
return a.mat_sub(b)
161+
case MyParser.MAT_MULTIPLY:
162+
return a.mat_mul(b)
163+
case MyParser.MAT_DIVIDE:
164+
return a.mat_truediv(b)
83165

84166
def visitParenthesesExpression(self, ctx: MyParser.ParenthesesExpressionContext):
85167
return self.visit(ctx.expression())
@@ -117,10 +199,10 @@ def visitSpecialMatrixFunction(self, ctx: MyParser.SpecialMatrixFunctionContext)
117199
return vector
118200

119201
def visitBreak(self, ctx: MyParser.BreakContext):
120-
return self.visitChildren(ctx) # todo
202+
raise Break()
121203

122204
def visitContinue(self, ctx: MyParser.ContinueContext):
123-
return self.visitChildren(ctx) # todo
205+
raise Continue()
124206

125207
def visitVector(self, ctx: MyParser.VectorContext):
126208
elements = [
@@ -129,10 +211,20 @@ def visitVector(self, ctx: MyParser.VectorContext):
129211
return Vector(elements)
130212

131213
def visitElementReference(self, ctx: MyParser.ElementReferenceContext):
132-
return self.visitChildren(ctx) # todo
214+
indices = [
215+
self.visit(ctx.expression(i)) for i in range(ctx.getChildCount() // 2 - 1)
216+
]
217+
if {type(idx) for idx in indices} != {Int}:
218+
raise TypeError
219+
result = self.visit(ctx.id_())
220+
for idx in indices:
221+
if not isinstance(result, Vector):
222+
raise TypeError
223+
result = result.value[idx.value]
224+
return result
133225

134226
def visitId(self, ctx: MyParser.IdContext):
135-
return self.visitChildren(ctx) # todo
227+
return self.memory_stack.get(ctx.getText())
136228

137229
def visitInt(self, ctx: MyParser.IntContext):
138230
return Int(ctx.getText())
@@ -141,4 +233,4 @@ def visitFloat(self, ctx: MyParser.FloatContext):
141233
return Float(ctx.getText())
142234

143235
def visitString(self, ctx: MyParser.StringContext):
144-
return String(ctx.getText())
236+
return String(ctx.getText()[1:-1]) # without quotes

main.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,9 @@ def run(filename: str):
8787

8888
tree = parser.program()
8989
if parser.getNumberOfSyntaxErrors() == 0:
90-
listener = SemanticListener()
91-
ParseTreeWalker().walk(listener, tree)
90+
# todo: Fix SemanticListener
91+
# listener = SemanticListener()
92+
# ParseTreeWalker().walk(listener, tree)
9293
if parser.getNumberOfSyntaxErrors() == 0:
9394
visitor = Interpreter()
9495
visitor.visit(tree)

test_main.py

+22
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,28 @@ def test_sem_errors(name: str, line_numbers: list[int], additional: str):
8181
[[1, 1, 1]],
8282
],
8383
),
84+
("variables", [2, 1, 3, "OK", 6]),
85+
("while", [4, 3, 2, 1, 0]),
86+
("for", [1, 10, 2, 10, 3, 10, 4, 10]),
87+
("break_continue", [1, 2, 1, 2, 4] * 2),
88+
(
89+
"element_reference",
90+
[
91+
[1, 0],
92+
0,
93+
[[1, 2], [0, 1]],
94+
[[0, 2], [0, 1]],
95+
[[0, 0], [0, 1]],
96+
],
97+
),
98+
(
99+
"mat_operators",
100+
[
101+
[[2, 2], [2, 2]],
102+
[[4, 4], [4, 4]],
103+
[[3, 3], [3, 3]],
104+
],
105+
),
84106
],
85107
)
86108
def test_interpreter(name: str, output: str):

tests/interpreter/break_continue.txt

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
for i = 1:4 {
2+
if (i == 3)
3+
break;
4+
print i;
5+
}
6+
7+
for i = 1:4 {
8+
if (i == 3)
9+
continue;
10+
print i;
11+
}
12+
13+
i = 0;
14+
while (i < 4) {
15+
i += 1;
16+
if (i == 3)
17+
break;
18+
print i;
19+
}
20+
21+
i = 0;
22+
while (i < 4) {
23+
i += 1;
24+
if (i == 3)
25+
continue;
26+
print i;
27+
}
+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
A = eye(2);
2+
print A[0];
3+
print A[0, 1];
4+
A[0, 1] = 2;
5+
print A;
6+
A[0] = [0, 2];
7+
print A;
8+
A[0, 1] -= 2;
9+
print A;

tests/interpreter/for.txt

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
n = 4;
2+
for i = 1:n {
3+
print i;
4+
i = 10;
5+
print i;
6+
}

tests/interpreter/mat_operators.txt

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
A = ones(2, 2);
2+
B = ones(2, 2);
3+
A = A .+ B;
4+
print A;
5+
A = A .* A;
6+
print A;
7+
A = A .- B;
8+
print A;
9+
A = A ./ B;

tests/interpreter/variables.txt

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
a = 2;
2+
print a;
3+
4+
a -= 1;
5+
print a;
6+
7+
a *= 3;
8+
print a;
9+
10+
if (a == 3) {
11+
b = "OK";
12+
print b;
13+
}
14+
15+
b = 2 * a;
16+
print b;

tests/interpreter/while.txt

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
b = -4;
2+
a = 4;
3+
while (a >= b) {
4+
print a;
5+
a -= 1;
6+
b += 1;
7+
}

utils/memory.py

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from .values import Value
2+
3+
4+
class Memory:
5+
def __init__(self):
6+
self.variables: dict[str, Value] = {}
7+
8+
def has_variable(self, name: str) -> bool:
9+
return name in self.variables
10+
11+
def get(self, name: str) -> Value:
12+
return self.variables[name]
13+
14+
def put(self, name: str, value: Value):
15+
self.variables[name] = value
16+
17+
18+
class MemoryStack:
19+
def __init__(self):
20+
self.stack: list[Memory] = []
21+
22+
def get(self, name: str) -> Value:
23+
for memory in self.stack:
24+
if memory.has_variable(name):
25+
return memory.get(name)
26+
27+
def put(self, name: str, value: Value):
28+
for memory in self.stack:
29+
if memory.has_variable(name):
30+
memory.put(name, value)
31+
return
32+
self.stack[-1].put(name, value)
33+
34+
def push_memory(self):
35+
self.stack.append(Memory())
36+
37+
def pop_memory(self):
38+
self.stack.pop()

0 commit comments

Comments
 (0)