44from generated .MyParser import MyParser
55from generated .MyParserVisitor import MyParserVisitor
66from utils .memory import MemoryStack
7- from utils .values import Int , Float , String , Vector
7+ from utils .values import Value , Int , Float , String , Vector
88
99
1010class Break (Exception ):
@@ -15,6 +15,13 @@ class Continue(Exception):
1515 pass
1616
1717
18+ def not_same_type (a : Value , b : Value ):
19+ return type (a ) is not type (b ) or (
20+ isinstance (a , Vector )
21+ and (a .dims != b .dims or a .primitive_type != b .primitive_type )
22+ )
23+
24+
1825class Interpreter (MyParserVisitor ):
1926 def __init__ (self ):
2027 self .memory_stack = MemoryStack ()
@@ -84,29 +91,44 @@ def visitComparison(self, ctx: MyParser.ComparisonContext):
8491 return a >= b
8592
8693 def visitSimpleAssignment (self , ctx : MyParser .SimpleAssignmentContext ):
87- if ctx .id_ ():
88- self .visitChildren (ctx )
94+ if ctx .id_ (): # a = 1
8995 self .memory_stack .put (ctx .id_ ().getText (), self .visit (ctx .expression ()))
90- else :
91- pass # todo
96+ else : # a[0] = 1
97+ ref_value = self .visit (ctx .elementReference ())
98+ new_value = self .visit (ctx .expression ())
99+ if not_same_type (ref_value , new_value ):
100+ raise TypeError
101+ ref_value .value = new_value .value
92102
93103 def visitCompoundAssignment (self , ctx : MyParser .CompoundAssignmentContext ):
94- if ctx .id_ ():
95- self .visitChildren (ctx )
96- old_value = self .memory_stack .get (ctx .id_ ().getText ())
104+ if ctx .id_ (): # a += 1
105+ value = self .memory_stack .get (ctx .id_ ().getText ())
97106 new_value = self .visit (ctx .expression ())
98107 match ctx .getChild (1 ).symbol .type :
99108 case MyParser .ASSIGN_PLUS :
100- new_value = old_value + new_value
109+ new_value = value + new_value
101110 case MyParser .ASSIGN_MINUS :
102- new_value = old_value - new_value
111+ new_value = value - new_value
103112 case MyParser .ASSIGN_MULTIPLY :
104- new_value = old_value * new_value
113+ new_value = value * new_value
105114 case MyParser .ASSIGN_DIVIDE :
106- new_value = old_value / new_value
115+ new_value = value / new_value
107116 self .memory_stack .put (ctx .id_ ().getText (), new_value )
108- else :
109- pass # todo
117+ else : # a[0] += 1
118+ ref_value = self .visit (ctx .elementReference ())
119+ new_value = self .visit (ctx .expression ())
120+ if not_same_type (ref_value , new_value ):
121+ raise TypeError
122+ match ctx .getChild (1 ).symbol .type :
123+ case MyParser .ASSIGN_PLUS :
124+ new_value = ref_value + new_value
125+ case MyParser .ASSIGN_MINUS :
126+ new_value = ref_value - new_value
127+ case MyParser .ASSIGN_MULTIPLY :
128+ new_value = ref_value * new_value
129+ case MyParser .ASSIGN_DIVIDE :
130+ new_value = ref_value / new_value
131+ ref_value .value = new_value .value
110132
111133 def visitPrint (self , ctx : MyParser .PrintContext ):
112134 for i in range (ctx .getChildCount () // 2 ):
@@ -182,7 +204,17 @@ def visitVector(self, ctx: MyParser.VectorContext):
182204 return Vector (elements )
183205
184206 def visitElementReference (self , ctx : MyParser .ElementReferenceContext ):
185- return self .visitChildren (ctx ) # todo
207+ indices = [
208+ self .visit (ctx .expression (i )) for i in range (ctx .getChildCount () // 2 - 1 )
209+ ]
210+ if {type (idx ) for idx in indices } != {Int }:
211+ raise TypeError
212+ result = self .visit (ctx .id_ ())
213+ for idx in indices :
214+ if not isinstance (result , Vector ):
215+ raise TypeError
216+ result = result .value [idx .value ]
217+ return result
186218
187219 def visitId (self , ctx : MyParser .IdContext ):
188220 return self .memory_stack .get (ctx .getText ())
0 commit comments