Skip to content

Commit 0eef1fc

Browse files
committed
Interpret matrix operators
1 parent dfd512a commit 0eef1fc

File tree

4 files changed

+69
-1
lines changed

4 files changed

+69
-1
lines changed

interpreter.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,14 @@ def visitBinaryExpression(self, ctx: MyParser.BinaryExpressionContext):
154154
return a * b
155155
case MyParser.DIVIDE:
156156
return a / b
157-
# todo: MAT_* operations
157+
case MyParser.MAT_PLUS:
158+
return a.mat_add(b)
159+
case MyParser.MAT_MINUS:
160+
return a.mat_sub(b)
161+
case MyParser.MAT_MULTIPLY:
162+
return a.mat_mul(b)
163+
case MyParser.MAT_DIVIDE:
164+
return a.mat_truediv(b)
158165

159166
def visitParenthesesExpression(self, ctx: MyParser.ParenthesesExpressionContext):
160167
return self.visit(ctx.expression())

test_main.py

+8
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,14 @@ def test_sem_errors(name: str, line_numbers: list[int], additional: str):
9595
[[0, 0], [0, 1]],
9696
],
9797
),
98+
(
99+
"mat_operators",
100+
[
101+
[[2, 2], [2, 2]],
102+
[[4, 4], [4, 4]],
103+
[[3, 3], [3, 3]],
104+
],
105+
),
98106
],
99107
)
100108
def test_interpreter(name: str, output: str):

tests/interpreter/mat_operators.txt

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
A = ones(2, 2);
2+
B = ones(2, 2);
3+
A = A .+ B;
4+
print A;
5+
A = A .* A;
6+
print A;
7+
A = A .- B;
8+
print A;
9+
A = A ./ B;

utils/values.py

+44
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,50 @@ def __init__(self, value: list):
130130
def __str__(self):
131131
return "[" + ", ".join(str(elem) for elem in self.value) + "]"
132132

133+
def mat_add(self, other):
134+
if isinstance(other, Vector):
135+
rows = []
136+
for elem, other_elem in zip(self.value, other.value):
137+
if isinstance(elem, Vector):
138+
rows.append(elem.mat_add(other_elem))
139+
else:
140+
rows.append(elem + other_elem)
141+
return Vector(rows)
142+
raise TypeError()
143+
144+
def mat_sub(self, other):
145+
if isinstance(other, Vector):
146+
rows = []
147+
for elem, other_elem in zip(self.value, other.value):
148+
if isinstance(elem, Vector):
149+
rows.append(elem.mat_sub(other_elem))
150+
else:
151+
rows.append(elem - other_elem)
152+
return Vector(rows)
153+
raise TypeError()
154+
155+
def mat_mul(self, other):
156+
if isinstance(other, Vector):
157+
rows = []
158+
for elem, other_elem in zip(self.value, other.value):
159+
if isinstance(elem, Vector):
160+
rows.append(elem.mat_mul(other_elem))
161+
else:
162+
rows.append(elem * other_elem)
163+
return Vector(rows)
164+
raise TypeError()
165+
166+
def mat_truediv(self, other):
167+
if isinstance(other, Vector):
168+
rows = []
169+
for elem, other_elem in zip(self.value, other.value):
170+
if isinstance(elem, Vector):
171+
rows.append(elem.mat_truediv(other_elem))
172+
else:
173+
rows.append(elem / other_elem)
174+
return Vector(rows)
175+
raise TypeError()
176+
133177
def transpose(self):
134178
if len(self.dims) != 2:
135179
raise TypeError

0 commit comments

Comments
 (0)