Skip to content

Commit e66d350

Browse files
authored
New semantic analysis (#12)
1 parent b096969 commit e66d350

8 files changed

+509
-287
lines changed

main.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from generated.MyLexer import MyLexer
88
from generated.MyParser import MyParser
99
from interpreter import Interpreter
10-
from semantic_listener import SemanticListener
10+
from semantic_analyser import SemanticAnalyser
1111

1212
app = typer.Typer(no_args_is_help=True)
1313
err_console = Console(stderr=True)
@@ -72,8 +72,8 @@ def sem(filename: str):
7272

7373
tree = parser.program()
7474
if parser.getNumberOfSyntaxErrors() == 0:
75-
listener = SemanticListener()
76-
ParseTreeWalker().walk(listener, tree)
75+
visitor = SemanticAnalyser()
76+
visitor.visit(tree)
7777

7878

7979
@app.command()
@@ -87,9 +87,8 @@ def run(filename: str):
8787

8888
tree = parser.program()
8989
if parser.getNumberOfSyntaxErrors() == 0:
90-
# todo: Fix SemanticListener
91-
# listener = SemanticListener()
92-
# ParseTreeWalker().walk(listener, tree)
90+
visitor = SemanticAnalyser()
91+
visitor.visit(tree)
9392
if parser.getNumberOfSyntaxErrors() == 0:
9493
visitor = Interpreter()
9594
visitor.visit(tree)

semantic_analyser.py

+282
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
from copy import deepcopy
2+
3+
from generated.MyParser import MyParser
4+
from generated.MyParserVisitor import MyParserVisitor
5+
from utils.memory import MemoryStack
6+
from utils.types import Int, Float, String, Vector, same_type
7+
8+
9+
class SemanticAnalyser(MyParserVisitor):
10+
"""Checks break and continue statements, variable declarations, types, assignments, etc."""
11+
12+
def __init__(self):
13+
self.nested_loop_counter = 0
14+
self.memory_stack = MemoryStack()
15+
self.memory_stack.push_memory()
16+
17+
def visitScopeStatement(self, ctx: MyParser.ScopeStatementContext):
18+
self.memory_stack.push_memory()
19+
self.visitChildren(ctx)
20+
self.memory_stack.pop_memory()
21+
22+
def visitForLoop(self, ctx: MyParser.ForLoopContext):
23+
self.visit(ctx.range_())
24+
variable = ctx.id_().getText()
25+
if self.memory_stack.get(variable) is None or (
26+
isinstance(self.memory_stack.get(variable), Int)
27+
):
28+
self.memory_stack.put(variable, Int())
29+
else:
30+
ctx.parser.notifyErrorListeners(
31+
"Incompatible types in an assignment", ctx.getChild(1).getSymbol()
32+
)
33+
self.nested_loop_counter += 1
34+
self.visit(ctx.statement())
35+
self.nested_loop_counter -= 1
36+
37+
def visitRange(self, ctx: MyParser.RangeContext):
38+
a = self.visit(ctx.expression(0))
39+
b = self.visit(ctx.expression(1))
40+
if not isinstance(a, Int) or not isinstance(b, Int):
41+
ctx.parser.notifyErrorListeners(
42+
"Range bounds must be integers", ctx.getChild(1).getSymbol()
43+
)
44+
45+
def visitWhileLoop(self, ctx: MyParser.WhileLoopContext):
46+
self.nested_loop_counter += 1
47+
self.visitChildren(ctx)
48+
self.nested_loop_counter -= 1
49+
50+
def visitComparison(self, ctx: MyParser.ComparisonContext):
51+
a = self.visit(ctx.expression(0))
52+
b = self.visit(ctx.expression(1))
53+
try:
54+
match ctx.getChild(1).symbol.type:
55+
case MyParser.EQ:
56+
return a == b
57+
case MyParser.NEQ:
58+
return a != b
59+
case MyParser.LT:
60+
return a < b
61+
case MyParser.LEQ:
62+
return a <= b
63+
case MyParser.GT:
64+
return a > b
65+
case MyParser.GEQ:
66+
return a >= b
67+
except TypeError:
68+
ctx.parser.notifyErrorListeners(
69+
"Incompatible types in a comparison", ctx.getChild(1).getSymbol()
70+
)
71+
72+
def visitSimpleAssignment(self, ctx: MyParser.SimpleAssignmentContext):
73+
if ctx.id_(): # a = 1
74+
variable = ctx.id_().getText()
75+
new_type = self.visit(ctx.expression())
76+
if self.memory_stack.get(variable) is None or (
77+
same_type(self.memory_stack.get(variable), new_type)
78+
):
79+
self.memory_stack.put(variable, new_type)
80+
else:
81+
ctx.parser.notifyErrorListeners(
82+
"Incompatible types in an assignment", ctx.getChild(1).getSymbol()
83+
)
84+
else: ## a[0] = 1
85+
reference = self.visit(ctx.elementReference())
86+
new_type = self.visit(ctx.expression())
87+
if not same_type(reference, new_type):
88+
ctx.parser.notifyErrorListeners(
89+
"Incompatible types in an assignment", ctx.getChild(1).getSymbol()
90+
)
91+
92+
def visitCompoundAssignment(self, ctx: MyParser.CompoundAssignmentContext):
93+
if ctx.id_(): # a = 1
94+
old_type = self.visit(ctx.id_())
95+
new_type = self.visit(ctx.expression())
96+
else: ## a[0] = 1
97+
old_type = self.visit(ctx.elementReference())
98+
new_type = self.visit(ctx.expression())
99+
try:
100+
match ctx.getChild(1).symbol.type:
101+
case MyParser.ASSIGN_PLUS:
102+
if not same_type(old_type, old_type + new_type):
103+
raise TypeError
104+
case MyParser.ASSIGN_MINUS:
105+
if not same_type(old_type, old_type - new_type):
106+
raise TypeError
107+
case MyParser.ASSIGN_MULTIPLY:
108+
if not same_type(old_type, old_type * new_type):
109+
raise TypeError
110+
case MyParser.ASSIGN_DIVIDE:
111+
if not same_type(old_type, old_type / new_type):
112+
raise TypeError
113+
except TypeError:
114+
ctx.parser.notifyErrorListeners(
115+
"Incompatible types in a compound assignment",
116+
ctx.getChild(1).getSymbol(),
117+
)
118+
119+
def visitReturn(self, ctx: MyParser.ReturnContext):
120+
if ctx.expression():
121+
return_type = self.visit(ctx.expression())
122+
if not isinstance(return_type, Int):
123+
ctx.parser.notifyErrorListeners(
124+
"Return type must be an integer", ctx.RETURN().getSymbol()
125+
)
126+
127+
def visitBinaryExpression(self, ctx: MyParser.BinaryExpressionContext):
128+
a = self.visit(ctx.expression(0))
129+
b = self.visit(ctx.expression(1))
130+
try:
131+
match ctx.op.type:
132+
case MyParser.PLUS:
133+
return a + b
134+
case MyParser.MINUS:
135+
return a - b
136+
case MyParser.MULTIPLY:
137+
return a * b
138+
case MyParser.DIVIDE:
139+
return a / b
140+
case MyParser.MAT_PLUS:
141+
return a.mat_add(b)
142+
case MyParser.MAT_MINUS:
143+
return a.mat_sub(b)
144+
case MyParser.MAT_MULTIPLY:
145+
return a.mat_mul(b)
146+
case MyParser.MAT_DIVIDE:
147+
return a.mat_truediv(b)
148+
except TypeError:
149+
ctx.parser.notifyErrorListeners(
150+
"Incompatible types in a binary operation",
151+
ctx.getChild(1).getSymbol(),
152+
)
153+
154+
def visitParenthesesExpression(self, ctx: MyParser.ParenthesesExpressionContext):
155+
return self.visit(ctx.expression())
156+
157+
def visitTransposeExpression(self, ctx: MyParser.TransposeExpressionContext):
158+
try:
159+
return self.visit(ctx.expression()).transpose()
160+
except TypeError:
161+
ctx.parser.notifyErrorListeners(
162+
"Transpose operator can only be applied to matrices",
163+
ctx.getChild(1).getSymbol(),
164+
)
165+
166+
def visitMinusExpression(self, ctx: MyParser.MinusExpressionContext):
167+
try:
168+
return -self.visit(ctx.expression())
169+
except TypeError:
170+
ctx.parser.notifyErrorListeners(
171+
"Unary minus can be applied only to integers or floats",
172+
ctx.MINUS().getSymbol(),
173+
)
174+
175+
def visitSpecialMatrixFunction(self, ctx: MyParser.SpecialMatrixFunctionContext):
176+
fname = ctx.getChild(0).symbol.type
177+
if fname == MyParser.EYE:
178+
dim = self.visit(ctx.expression(0))
179+
if not isinstance(dim, Int):
180+
ctx.parser.notifyErrorListeners(
181+
"Matrix dimentions must be integers", ctx.getChild(0).getSymbol()
182+
)
183+
return
184+
return Vector((dim.value, dim.value), Int())
185+
else:
186+
dims = [
187+
self.visit(ctx.expression(i))
188+
for i in range(ctx.getChildCount() // 2 - 1)
189+
]
190+
if not all(isinstance(dim, Int) for dim in dims):
191+
ctx.parser.notifyErrorListeners(
192+
"Matrix dimentions must be integers", ctx.getChild(0).getSymbol()
193+
) # todo: add more specific symbol
194+
return
195+
return Vector(
196+
tuple(dim.value for dim in dims), Int()
197+
) # todo: return Int(0) or Int(1)
198+
199+
def visitBreak(self, ctx: MyParser.BreakContext):
200+
if self.nested_loop_counter == 0:
201+
ctx.parser.notifyErrorListeners(
202+
"Break statement outside of loop", ctx.BREAK().getSymbol()
203+
)
204+
205+
def visitContinue(self, ctx: MyParser.ContinueContext):
206+
if self.nested_loop_counter == 0:
207+
ctx.parser.notifyErrorListeners(
208+
"Continue statement outside of loop", ctx.CONTINUE().getSymbol()
209+
)
210+
211+
def visitVector(self, ctx: MyParser.VectorContext):
212+
elements = [
213+
self.visit(ctx.expression(i)) for i in range(ctx.getChildCount() // 2)
214+
]
215+
for i in range(1, len(elements)):
216+
if not same_type(elements[i], elements[i - 1]):
217+
wrong_token = ctx.COMMA(i) or ctx.CLOSE_BRACKET_SQUARE()
218+
ctx.parser.notifyErrorListeners(
219+
"Inconsistent types in a vector", wrong_token.getSymbol()
220+
)
221+
return None
222+
elem = elements[0]
223+
if isinstance(elem, Int):
224+
elem.value = None
225+
if isinstance(elem, Vector):
226+
return Vector((len(elements), *elem.dims), elem.primitive_type)
227+
else:
228+
return Vector((len(elements),), elem)
229+
230+
def visitElementReference(self, ctx: MyParser.ElementReferenceContext):
231+
indices = [
232+
self.visit(ctx.expression(i)) for i in range(ctx.getChildCount() // 2 - 1)
233+
]
234+
if not all(isinstance(index, Int) for index in indices):
235+
ctx.parser.notifyErrorListeners(
236+
"Indices must be integers", ctx.OPEN_BRACKET_SQUARE().getSymbol()
237+
)
238+
return
239+
result = deepcopy(self.visit(ctx.id_()))
240+
if not isinstance(result, Vector):
241+
ctx.parser.notifyErrorListeners(
242+
"Indexing can only be applied to vectors",
243+
ctx.OPEN_BRACKET_SQUARE().getSymbol(),
244+
)
245+
return
246+
for idx in indices:
247+
if not isinstance(result, Vector):
248+
ctx.parser.notifyErrorListeners(
249+
"Too many indices", ctx.OPEN_BRACKET_SQUARE().getSymbol()
250+
)
251+
return
252+
if (
253+
idx.value is not None
254+
and result.dims[0] is not None
255+
and idx.value >= result.dims[0]
256+
):
257+
ctx.parser.notifyErrorListeners(
258+
"Index out of bounds", ctx.OPEN_BRACKET_SQUARE().getSymbol()
259+
)
260+
return
261+
result.dims = result.dims[1:]
262+
if len(result.dims) == 0:
263+
result = result.primitive_type
264+
return result
265+
266+
def visitId(self, ctx: MyParser.IdContext):
267+
result = self.memory_stack.get(ctx.getText())
268+
if result is not None:
269+
return result
270+
else:
271+
ctx.parser.notifyErrorListeners(
272+
f"Variable {ctx.getText()} not declared", ctx.ID().getSymbol()
273+
)
274+
275+
def visitInt(self, ctx: MyParser.IntContext):
276+
return Int(ctx.getText())
277+
278+
def visitFloat(self, ctx: MyParser.FloatContext):
279+
return Float()
280+
281+
def visitString(self, ctx: MyParser.StringContext):
282+
return String()

0 commit comments

Comments
 (0)