Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support more type operations #11

Merged
merged 1 commit into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def test_sem_errors(name: str, line_numbers: list[int], additional: str):
@pytest.mark.parametrize(
"name,output",
[
("simple_math", [23, 35, 1, 1.0, 1, -2]),
("simple_math", [23, 35, 1.0, 1.0, 1, -2, 9.0, 6.0, 1.0]),
("strings", ["aaa", "abcd"]),
("conditions", [0, 1, 0, 1, 0, 1, 0, 1]),
(
"vectors",
Expand Down
3 changes: 2 additions & 1 deletion tests/interpreter/simple_math.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
print 3 + 4 * 5, (3 + 4) * 5;
print 2 / 2, 1.0 * 1.0, 2 - 1;
print -1 * 2;
print -1 * 2;
print 3 * 3.0, 3.0 + 3, 3.0 / 3;
2 changes: 2 additions & 0 deletions tests/interpreter/strings.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
print "a" * 3;
print "ab" + "cd";
67 changes: 31 additions & 36 deletions utils/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,28 @@ def __init__(self, value):
def __add__(self, other):
if isinstance(other, Int):
return Int(self.value + other.value)
elif isinstance(other, Float):
return Float(self.value + other.value)
raise TypeError()

def __sub__(self, other):
if isinstance(other, Int):
return Int(self.value - other.value)
elif isinstance(other, Float):
return Float(self.value - other.value)
raise TypeError()

def __mul__(self, other):
if isinstance(other, Int):
return Int(self.value * other.value)
elif isinstance(other, Float):
return Float(self.value * other.value)
raise TypeError()

def __truediv__(self, other):
if isinstance(other, Int):
if self.value % other.value == 0:
return Int(self.value // other.value)
return Float(self.value / other.value)
elif isinstance(other, Float):
return Float(self.value / other.value)
raise TypeError()

Expand All @@ -74,22 +80,22 @@ def __init__(self, value):
super().__init__(float(value))

def __add__(self, other):
if isinstance(other, Float):
if isinstance(other, Float) or isinstance(other, Int):
return Float(self.value + other.value)
raise TypeError()

def __sub__(self, other):
if isinstance(other, Float):
if isinstance(other, Float) or isinstance(other, Int):
return Float(self.value - other.value)
raise TypeError()

def __mul__(self, other):
if isinstance(other, Float):
if isinstance(other, Float) or isinstance(other, Int):
return Float(self.value * other.value)
raise TypeError()

def __truediv__(self, other):
if isinstance(other, Float):
if isinstance(other, Float) or isinstance(other, Int):
return Float(self.value / other.value)
raise TypeError()

Expand All @@ -101,6 +107,16 @@ class String(Value):
def __init__(self, value):
super().__init__(value)

def __add__(self, other):
if isinstance(other, String):
return String(self.value + other.value)
raise TypeError()

def __mul__(self, other):
if isinstance(other, Int):
return String(self.value * other.value)
raise TypeError()


class Vector(Value):
def __init__(self, value: list):
Expand Down Expand Up @@ -130,49 +146,28 @@ def __init__(self, value: list):
def __str__(self):
return "[" + ", ".join(str(elem) for elem in self.value) + "]"

def mat_add(self, other):
def _mat_op(self, other, op):
if isinstance(other, Vector):
rows = []
for elem, other_elem in zip(self.value, other.value):
if isinstance(elem, Vector):
rows.append(elem.mat_add(other_elem))
rows.append(elem._mat_op(other_elem, op))
else:
rows.append(elem + other_elem)
rows.append(op(elem, other_elem))
return Vector(rows)
raise TypeError()

def mat_add(self, other):
return self._mat_op(other, lambda x, y: x + y)

def mat_sub(self, other):
if isinstance(other, Vector):
rows = []
for elem, other_elem in zip(self.value, other.value):
if isinstance(elem, Vector):
rows.append(elem.mat_sub(other_elem))
else:
rows.append(elem - other_elem)
return Vector(rows)
raise TypeError()
return self._mat_op(other, lambda x, y: x - y)

def mat_mul(self, other):
if isinstance(other, Vector):
rows = []
for elem, other_elem in zip(self.value, other.value):
if isinstance(elem, Vector):
rows.append(elem.mat_mul(other_elem))
else:
rows.append(elem * other_elem)
return Vector(rows)
raise TypeError()
return self._mat_op(other, lambda x, y: x * y)

def mat_truediv(self, other):
if isinstance(other, Vector):
rows = []
for elem, other_elem in zip(self.value, other.value):
if isinstance(elem, Vector):
rows.append(elem.mat_truediv(other_elem))
else:
rows.append(elem / other_elem)
return Vector(rows)
raise TypeError()
return self._mat_op(other, lambda x, y: x / y)

def transpose(self):
if len(self.dims) != 2:
Expand Down
Loading