Skip to content

Commit 757a58b

Browse files
committed
Don't check types twice
new semantic analyser is enough
1 parent e66d350 commit 757a58b

File tree

1 file changed

+2
-28
lines changed

1 file changed

+2
-28
lines changed

interpreter.py

+2-28
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from generated.MyParser import MyParser
55
from generated.MyParserVisitor import MyParserVisitor
66
from utils.memory import MemoryStack
7-
from utils.values import Value, Int, Float, String, Vector
7+
from utils.values import Int, Float, String, Vector
88

99

1010
class Break(Exception):
@@ -15,13 +15,6 @@ class Continue(Exception):
1515
pass
1616

1717

18-
def not_same_type(a: Value, b: Value):
19-
return type(a) is not type(b) or (
20-
isinstance(a, Vector)
21-
and (a.dims != b.dims or a.primitive_type != b.primitive_type)
22-
)
23-
24-
2518
class Interpreter(MyParserVisitor):
2619
def __init__(self):
2720
self.memory_stack = MemoryStack()
@@ -60,8 +53,6 @@ def visitForLoop(self, ctx: MyParser.ForLoopContext):
6053
def visitRange(self, ctx: MyParser.RangeContext):
6154
a = self.visit(ctx.expression(0))
6255
b = self.visit(ctx.expression(1))
63-
if {type(a), type(b)} != {Int}:
64-
raise TypeError
6556
return (a.value, b.value)
6657

6758
def visitWhileLoop(self, ctx: MyParser.WhileLoopContext):
@@ -96,8 +87,6 @@ def visitSimpleAssignment(self, ctx: MyParser.SimpleAssignmentContext):
9687
else: # a[0] = 1
9788
ref_value = self.visit(ctx.elementReference())
9889
new_value = self.visit(ctx.expression())
99-
if not_same_type(ref_value, new_value):
100-
raise TypeError
10190
ref_value.value = new_value.value
10291

10392
def visitCompoundAssignment(self, ctx: MyParser.CompoundAssignmentContext):
@@ -117,8 +106,6 @@ def visitCompoundAssignment(self, ctx: MyParser.CompoundAssignmentContext):
117106
else: # a[0] += 1
118107
ref_value = self.visit(ctx.elementReference())
119108
new_value = self.visit(ctx.expression())
120-
if not_same_type(ref_value, new_value):
121-
raise TypeError
122109
match ctx.getChild(1).symbol.type:
123110
case MyParser.ASSIGN_PLUS:
124111
new_value = ref_value + new_value
@@ -137,8 +124,6 @@ def visitPrint(self, ctx: MyParser.PrintContext):
137124
def visitReturn(self, ctx: MyParser.ReturnContext):
138125
if ctx.expression():
139126
return_value = self.visit(ctx.expression())
140-
if not isinstance(return_value, Int):
141-
raise TypeError
142127
sys.exit(return_value.value)
143128
sys.exit()
144129

@@ -167,10 +152,7 @@ def visitParenthesesExpression(self, ctx: MyParser.ParenthesesExpressionContext)
167152
return self.visit(ctx.expression())
168153

169154
def visitTransposeExpression(self, ctx: MyParser.TransposeExpressionContext):
170-
vector = self.visit(ctx.expression())
171-
if not isinstance(vector, Vector):
172-
raise TypeError
173-
return vector.transpose()
155+
return self.visit(ctx.expression()).transpose()
174156

175157
def visitMinusExpression(self, ctx: MyParser.MinusExpressionContext):
176158
return -self.visit(ctx.expression())
@@ -179,8 +161,6 @@ def visitSpecialMatrixFunction(self, ctx: MyParser.SpecialMatrixFunctionContext)
179161
fname = ctx.getChild(0).symbol.type
180162
if fname == MyParser.EYE:
181163
dim = self.visit(ctx.expression(0))
182-
if not isinstance(dim, Int):
183-
raise TypeError
184164
rows = [
185165
Vector([Int(i == j) for j in range(dim.value)])
186166
for i in range(dim.value)
@@ -191,8 +171,6 @@ def visitSpecialMatrixFunction(self, ctx: MyParser.SpecialMatrixFunctionContext)
191171
self.visit(ctx.expression(i))
192172
for i in range(ctx.getChildCount() // 2 - 1)
193173
]
194-
if {type(dim) for dim in dims} != {Int}:
195-
raise TypeError
196174
vector = {MyParser.ZEROS: Int(0), MyParser.ONES: Int(1)}[fname]
197175
for dim in reversed(dims):
198176
vector = Vector([deepcopy(vector) for _ in range(dim.value)])
@@ -214,12 +192,8 @@ def visitElementReference(self, ctx: MyParser.ElementReferenceContext):
214192
indices = [
215193
self.visit(ctx.expression(i)) for i in range(ctx.getChildCount() // 2 - 1)
216194
]
217-
if {type(idx) for idx in indices} != {Int}:
218-
raise TypeError
219195
result = self.visit(ctx.id_())
220196
for idx in indices:
221-
if not isinstance(result, Vector):
222-
raise TypeError
223197
result = result.value[idx.value]
224198
return result
225199

0 commit comments

Comments
 (0)