Skip to content

Commit dfd512a

Browse files
committed
Interpret vector element references
1 parent c852cf4 commit dfd512a

File tree

4 files changed

+73
-16
lines changed

4 files changed

+73
-16
lines changed

interpreter.py

+47-15
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from generated.MyParser import MyParser
55
from generated.MyParserVisitor import MyParserVisitor
66
from utils.memory import MemoryStack
7-
from utils.values import Int, Float, String, Vector
7+
from utils.values import Value, Int, Float, String, Vector
88

99

1010
class Break(Exception):
@@ -15,6 +15,13 @@ class Continue(Exception):
1515
pass
1616

1717

18+
def not_same_type(a: Value, b: Value):
19+
return type(a) is not type(b) or (
20+
isinstance(a, Vector)
21+
and (a.dims != b.dims or a.primitive_type != b.primitive_type)
22+
)
23+
24+
1825
class Interpreter(MyParserVisitor):
1926
def __init__(self):
2027
self.memory_stack = MemoryStack()
@@ -84,29 +91,44 @@ def visitComparison(self, ctx: MyParser.ComparisonContext):
8491
return a >= b
8592

8693
def visitSimpleAssignment(self, ctx: MyParser.SimpleAssignmentContext):
87-
if ctx.id_():
88-
self.visitChildren(ctx)
94+
if ctx.id_(): # a = 1
8995
self.memory_stack.put(ctx.id_().getText(), self.visit(ctx.expression()))
90-
else:
91-
pass # todo
96+
else: # a[0] = 1
97+
ref_value = self.visit(ctx.elementReference())
98+
new_value = self.visit(ctx.expression())
99+
if not_same_type(ref_value, new_value):
100+
raise TypeError
101+
ref_value.value = new_value.value
92102

93103
def visitCompoundAssignment(self, ctx: MyParser.CompoundAssignmentContext):
94-
if ctx.id_():
95-
self.visitChildren(ctx)
96-
old_value = self.memory_stack.get(ctx.id_().getText())
104+
if ctx.id_(): # a += 1
105+
value = self.memory_stack.get(ctx.id_().getText())
97106
new_value = self.visit(ctx.expression())
98107
match ctx.getChild(1).symbol.type:
99108
case MyParser.ASSIGN_PLUS:
100-
new_value = old_value + new_value
109+
new_value = value + new_value
101110
case MyParser.ASSIGN_MINUS:
102-
new_value = old_value - new_value
111+
new_value = value - new_value
103112
case MyParser.ASSIGN_MULTIPLY:
104-
new_value = old_value * new_value
113+
new_value = value * new_value
105114
case MyParser.ASSIGN_DIVIDE:
106-
new_value = old_value / new_value
115+
new_value = value / new_value
107116
self.memory_stack.put(ctx.id_().getText(), new_value)
108-
else:
109-
pass # todo
117+
else: # a[0] += 1
118+
ref_value = self.visit(ctx.elementReference())
119+
new_value = self.visit(ctx.expression())
120+
if not_same_type(ref_value, new_value):
121+
raise TypeError
122+
match ctx.getChild(1).symbol.type:
123+
case MyParser.ASSIGN_PLUS:
124+
new_value = ref_value + new_value
125+
case MyParser.ASSIGN_MINUS:
126+
new_value = ref_value - new_value
127+
case MyParser.ASSIGN_MULTIPLY:
128+
new_value = ref_value * new_value
129+
case MyParser.ASSIGN_DIVIDE:
130+
new_value = ref_value / new_value
131+
ref_value.value = new_value.value
110132

111133
def visitPrint(self, ctx: MyParser.PrintContext):
112134
for i in range(ctx.getChildCount() // 2):
@@ -182,7 +204,17 @@ def visitVector(self, ctx: MyParser.VectorContext):
182204
return Vector(elements)
183205

184206
def visitElementReference(self, ctx: MyParser.ElementReferenceContext):
185-
return self.visitChildren(ctx) # todo
207+
indices = [
208+
self.visit(ctx.expression(i)) for i in range(ctx.getChildCount() // 2 - 1)
209+
]
210+
if {type(idx) for idx in indices} != {Int}:
211+
raise TypeError
212+
result = self.visit(ctx.id_())
213+
for idx in indices:
214+
if not isinstance(result, Vector):
215+
raise TypeError
216+
result = result.value[idx.value]
217+
return result
186218

187219
def visitId(self, ctx: MyParser.IdContext):
188220
return self.memory_stack.get(ctx.getText())

test_main.py

+10
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,16 @@ def test_sem_errors(name: str, line_numbers: list[int], additional: str):
8585
("while", [4, 3, 2, 1, 0]),
8686
("for", [1, 10, 2, 10, 3, 10, 4, 10]),
8787
("break_continue", [1, 2, 1, 2, 4] * 2),
88+
(
89+
"element_reference",
90+
[
91+
[1, 0],
92+
0,
93+
[[1, 2], [0, 1]],
94+
[[0, 2], [0, 1]],
95+
[[0, 0], [0, 1]],
96+
],
97+
),
8898
],
8999
)
90100
def test_interpreter(name: str, output: str):
+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
A = eye(2);
2+
print A[0];
3+
print A[0, 1];
4+
A[0, 1] = 2;
5+
print A;
6+
A[0] = [0, 2];
7+
print A;
8+
A[0, 1] -= 2;
9+
print A;

utils/values.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,11 @@ def __init__(self, value: list):
108108
if (
109109
len(
110110
{
111-
(type(elem), elem.dims if isinstance(elem, Vector) else None)
111+
(
112+
(elem.dims, elem.primitive_type)
113+
if isinstance(elem, Vector)
114+
else type(elem)
115+
)
112116
for elem in value
113117
}
114118
)
@@ -118,8 +122,10 @@ def __init__(self, value: list):
118122

119123
if isinstance(value[0], Vector):
120124
self.dims = (len(value), *value[0].dims)
125+
self.primitive_type = value[0].primitive_type
121126
else:
122127
self.dims = (len(value),)
128+
self.primitive_type = type(value[0])
123129

124130
def __str__(self):
125131
return "[" + ", ".join(str(elem) for elem in self.value) + "]"

0 commit comments

Comments
 (0)