Skip to content

Commit 4625ed8

Browse files
authored
Merge pull request #1 from mdbrnowski/semantic-analysis
Semantic analysis
2 parents 93fabf3 + 4817b29 commit 4625ed8

13 files changed

+432
-0
lines changed

main.py

+17
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from generated.MyLexer import MyLexer
55
from generated.MyParser import MyParser
66
from ast_listener import ASTListener
7+
from semantic_listener import SemanticListener
78

89
app = typer.Typer(no_args_is_help=True)
910

@@ -50,5 +51,21 @@ def ast(filename: str):
5051
ParseTreeWalker().walk(listener, tree)
5152

5253

54+
@app.command()
55+
def sem(filename: str):
56+
"""Semantic analysis"""
57+
with open(filename, encoding="utf-8") as f:
58+
string = f.read()
59+
60+
lexer = MyLexer(InputStream(string))
61+
stream = CommonTokenStream(lexer)
62+
parser = MyParser(stream)
63+
64+
tree = parser.program()
65+
if parser.getNumberOfSyntaxErrors() == 0:
66+
listener = SemanticListener()
67+
ParseTreeWalker().walk(listener, tree)
68+
69+
5370
if __name__ == "__main__":
5471
app()

semantic_listener.py

+245
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
from enum import Enum, auto
2+
from antlr4 import ParserRuleContext
3+
from generated.MyParser import MyParser
4+
from generated.MyParserListener import MyParserListener
5+
6+
7+
class Type(Enum):
8+
INT = auto()
9+
FLOAT = auto()
10+
STRING = auto()
11+
TBD = auto() # to be determined (only during assignment)
12+
13+
14+
def is_plain_integer(ctx: ParserRuleContext) -> bool:
15+
return isinstance(ctx, MyParser.SingleExpressionContext) and isinstance(
16+
ctx.getChild(0), MyParser.IntContext
17+
)
18+
19+
20+
class SemanticListener(MyParserListener):
21+
"""Checks break and continue statements, variable declarations,types and assignments."""
22+
23+
def __init__(self):
24+
self.nested_loop_counter = 0
25+
self.variables: dict[str, Type | None] = {}
26+
self.expr_type: dict[
27+
ParserRuleContext, Type | tuple
28+
] = {} # values are either Type or (Type, int | None, int | None, ...)
29+
30+
# LOOP CHECKING
31+
32+
def enterForLoop(self, ctx: MyParser.ForLoopContext):
33+
self.nested_loop_counter += 1
34+
35+
def exitForLoop(self, ctx: MyParser.ForLoopContext):
36+
self.nested_loop_counter -= 1
37+
38+
def enterWhileLoop(self, ctx: MyParser.WhileLoopContext):
39+
self.nested_loop_counter += 1
40+
41+
def exitWhileLoop(self, ctx: MyParser.WhileLoopContext):
42+
self.nested_loop_counter -= 1
43+
44+
def enterBreak(self, ctx: MyParser.BreakContext):
45+
if self.nested_loop_counter == 0:
46+
ctx.parser.notifyErrorListeners(
47+
"Break statement outside of loop", ctx.BREAK().getSymbol()
48+
)
49+
50+
def enterContinue(self, ctx: MyParser.ContinueContext):
51+
if self.nested_loop_counter == 0:
52+
ctx.parser.notifyErrorListeners(
53+
"Continue statement outside of loop", ctx.CONTINUE().getSymbol()
54+
)
55+
56+
# VARIABLES & TYPES CHECKING
57+
58+
def enterRange(self, ctx: MyParser.RangeContext):
59+
pass
60+
61+
def exitRange(self, ctx: MyParser.RangeContext):
62+
pass
63+
64+
def enterComparison(self, ctx: MyParser.ComparisonContext):
65+
pass
66+
67+
def exitComparison(self, ctx: MyParser.ComparisonContext):
68+
children_types = {self.expr_type[ctx.getChild(i)] for i in [0, 2]}
69+
if not (
70+
children_types <= {Type.INT, Type.FLOAT}
71+
or (
72+
ctx.getChild(1).symbol.type in {MyParser.EQ, MyParser.NE}
73+
and children_types <= {Type.STRING}
74+
)
75+
):
76+
ctx.parser.notifyErrorListeners(
77+
"Incompatible types in a comparison", ctx.getChild(1).getSymbol()
78+
)
79+
self.expr_type[ctx] = None
80+
81+
def enterAssignment(self, ctx: MyParser.AssignmentContext):
82+
if (
83+
ctx.getChild(1).symbol.type == MyParser.ASSIGN
84+
and isinstance(ctx.getChild(0), MyParser.IdContext)
85+
and ctx.getChild(0).getText() not in self.variables
86+
):
87+
# type is unknown at this point
88+
self.variables[ctx.getChild(0).getText()] = Type.TBD
89+
90+
def exitAssignment(self, ctx: MyParser.AssignmentContext):
91+
if (
92+
ctx.getChild(1).symbol.type == MyParser.ASSIGN
93+
and isinstance(ctx.getChild(0), MyParser.IdContext)
94+
and self.variables[ctx.getChild(0).getText()] is Type.TBD
95+
):
96+
# we finally know the type
97+
if self.expr_type[ctx.getChild(2)] is Type.TBD:
98+
ctx.parser.notifyErrorListeners(
99+
"Using a variable while declaring it is not allowed",
100+
ctx.getChild(1).getSymbol(),
101+
)
102+
self.variables[ctx.getChild(0).getText()] = self.expr_type[ctx.getChild(2)]
103+
104+
def exitBinaryExpression(self, ctx: MyParser.BinaryExpressionContext):
105+
first = ctx.getChild(0)
106+
second = ctx.getChild(2)
107+
type_1 = self.expr_type[first]
108+
type_2 = self.expr_type[second]
109+
if ctx.op.type in [
110+
MyParser.PLUS,
111+
MyParser.MINUS,
112+
MyParser.MULTIPLY,
113+
MyParser.DIVIDE,
114+
]:
115+
if {type_1, type_2} == {Type.INT}:
116+
self.expr_type[ctx] = Type.INT
117+
elif {type_1, type_2} <= {Type.FLOAT, Type.INT}:
118+
self.expr_type[ctx] = Type.FLOAT
119+
else:
120+
ctx.parser.notifyErrorListeners(
121+
"Incompatible types in a binary operation",
122+
ctx.getChild(1).getSymbol(),
123+
)
124+
self.expr_type[ctx] = None
125+
else:
126+
if type_1 == type_2:
127+
self.expr_type[ctx] = type_1
128+
else:
129+
ctx.parser.notifyErrorListeners(
130+
"Incompatible types in a matrix binary operation",
131+
ctx.getChild(1).getSymbol(),
132+
)
133+
self.expr_type[ctx] = None
134+
135+
def exitParenthesesExpression(self, ctx: MyParser.ParenthesesExpressionContext):
136+
self.expr_type[ctx] = self.expr_type[ctx.getChild(1)]
137+
138+
def exitTransposeExpression(self, ctx: MyParser.TransposeExpressionContext):
139+
matrix = ctx.getChild(0)
140+
if ( # is a matrix
141+
isinstance(self.expr_type[matrix], tuple)
142+
and len(self.expr_type[matrix]) == 3
143+
):
144+
self.expr_type[ctx] = tuple(self.expr_type[matrix][i] for i in (0, 2, 1))
145+
else:
146+
ctx.parser.notifyErrorListeners(
147+
"Transpose operator can only be applied to matrices",
148+
ctx.getChild(1).getSymbol(),
149+
)
150+
self.expr_type[ctx] = self.expr_type[matrix]
151+
152+
def exitMinusExpression(self, ctx: MyParser.MinusExpressionContext):
153+
self.expr_type[ctx] = self.expr_type[ctx.getChild(1)]
154+
155+
def exitSingleExpression(self, ctx: MyParser.SingleExpressionContext):
156+
self.expr_type[ctx] = self.expr_type[ctx.getChild(0)]
157+
158+
def exitSpecialMatrixFunction(self, ctx: MyParser.SpecialMatrixFunctionContext):
159+
dimentions = ctx.children[2::2]
160+
for dim in dimentions:
161+
if self.expr_type[dim] != Type.INT:
162+
ctx.parser.notifyErrorListeners(
163+
"Matrix dimentions must be integers", ctx.getChild(0).getSymbol()
164+
)
165+
self.expr_type[ctx] = None
166+
return
167+
type_dimentions = []
168+
for dim in dimentions:
169+
if is_plain_integer(dim):
170+
type_dimentions.append(int(dim.getText()))
171+
else:
172+
type_dimentions.append(None)
173+
self.expr_type[ctx] = (Type.INT, *type_dimentions)
174+
175+
def exitVector(self, ctx: MyParser.VectorContext):
176+
elements = ctx.children[1::2]
177+
for i in range(1, len(elements)):
178+
if self.expr_type[elements[i]] != self.expr_type[elements[i - 1]]:
179+
wrong_token = ctx.COMMA(i) or ctx.CLOSE_BRACKET_SQUARE()
180+
ctx.parser.notifyErrorListeners(
181+
"Inconsistent types in a vector", wrong_token.getSymbol()
182+
)
183+
self.expr_type[ctx] = None
184+
return
185+
elem_type = self.expr_type[elements[1]]
186+
if isinstance(elem_type, Type):
187+
self.expr_type[ctx] = (elem_type, len(elements))
188+
else:
189+
self.expr_type[ctx] = (
190+
elem_type[0],
191+
len(elements),
192+
*elem_type[1:],
193+
)
194+
195+
def exitElementReference(self, ctx: MyParser.ElementReferenceContext):
196+
references = ctx.children[2::2]
197+
for ref in references:
198+
if self.expr_type[ref] != Type.INT:
199+
ctx.parser.notifyErrorListeners(
200+
"Indices must be integers", ctx.getChild(1).getSymbol()
201+
)
202+
self.expr_type[ctx] = None
203+
return
204+
id_type = self.expr_type[ctx.getChild(0)]
205+
if not isinstance(id_type, tuple):
206+
ctx.parser.notifyErrorListeners(
207+
"Indexing can only be applied to tensors", ctx.getChild(1).getSymbol()
208+
)
209+
self.expr_type[ctx] = None
210+
return
211+
elif len(references) > len(id_type) - 1:
212+
ctx.parser.notifyErrorListeners(
213+
"Too many indices", ctx.getChild(1).getSymbol()
214+
)
215+
self.expr_type[ctx] = None
216+
return
217+
elif len(references) < len(id_type) - 1:
218+
self.expr_type[ctx] = (id_type[0], *id_type[1 + len(references) :])
219+
else:
220+
self.expr_type[ctx] = id_type[0]
221+
222+
for i, ref in enumerate(references):
223+
if is_plain_integer(ref) and id_type[i + 1] is not None:
224+
if int(ref.getText()) >= id_type[i + 1]:
225+
ctx.parser.notifyErrorListeners(
226+
"Index out of bounds", ctx.getChild(1).getSymbol()
227+
)
228+
229+
def exitId(self, ctx: MyParser.IdContext):
230+
if ctx.getText() not in self.variables:
231+
ctx.parser.notifyErrorListeners(
232+
f"Variable {ctx.getText()} not declared", ctx.ID().getSymbol()
233+
)
234+
self.expr_type[ctx] = None
235+
else:
236+
self.expr_type[ctx] = self.variables[ctx.getText()]
237+
238+
def exitInt(self, ctx: MyParser.IntContext):
239+
self.expr_type[ctx] = Type.INT
240+
241+
def exitFloat(self, ctx: MyParser.FloatContext):
242+
self.expr_type[ctx] = Type.FLOAT
243+
244+
def exitString(self, ctx: MyParser.StringContext):
245+
self.expr_type[ctx] = Type.STRING

test_main.py

+88
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,91 @@ def test_ast(n):
3535
with open(f"tests/ast/output_{n}.txt", encoding="utf-8") as f:
3636
output = f.read()
3737
assert result.stdout == output
38+
39+
40+
def test_sem_error_break():
41+
result = runner.invoke(app, ["sem", "tests/semantic/input_break.txt"])
42+
assert result.exit_code == 0
43+
assert "line 1" in result.stdout
44+
assert "break" in result.stdout.lower()
45+
assert result.stdout.count("line") == 1
46+
47+
48+
def test_sem_error_continue():
49+
result = runner.invoke(app, ["sem", "tests/semantic/input_continue.txt"])
50+
assert result.exit_code == 0
51+
assert "line 1" in result.stdout
52+
assert "continue" in result.stdout.lower()
53+
assert result.stdout.count("line") == 1
54+
55+
56+
def test_sem_error_vector():
57+
result = runner.invoke(app, ["sem", "tests/semantic/input_vector.txt"])
58+
assert result.exit_code == 0
59+
assert "line 1" in result.stdout
60+
assert "line 3" in result.stdout
61+
assert "line 7" in result.stdout
62+
assert result.stdout.count("line") == 3
63+
assert result.stdout.lower().count("vector") == 3
64+
65+
66+
def test_sem_error_variables():
67+
result = runner.invoke(app, ["sem", "tests/semantic/input_variables.txt"])
68+
assert result.exit_code == 0
69+
assert "line 5" in result.stdout
70+
assert "line 7" in result.stdout
71+
assert result.stdout.count("line") == 2
72+
assert result.stdout.lower().count("variable") == 2
73+
74+
75+
def test_sem_error_transpose():
76+
result = runner.invoke(app, ["sem", "tests/semantic/input_transpose.txt"])
77+
assert result.exit_code == 0
78+
assert "line 7" in result.stdout
79+
assert "transpose" in result.stdout.lower()
80+
assert result.stdout.count("line") == 1
81+
82+
83+
def test_sem_error_special_matrix():
84+
result = runner.invoke(app, ["sem", "tests/semantic/input_special_matrix.txt"])
85+
assert result.exit_code == 0
86+
assert "line 1" in result.stdout
87+
assert "line 11" in result.stdout
88+
assert result.stdout.count("line") == 2
89+
90+
91+
def test_sem_error_indexing():
92+
result = runner.invoke(app, ["sem", "tests/semantic/input_indexing.txt"])
93+
assert result.exit_code == 0
94+
assert "line 5" in result.stdout
95+
assert "line 6" in result.stdout
96+
assert "line 7" in result.stdout
97+
assert result.stdout.count("line") == 3
98+
99+
100+
def test_sem_error_indexing_bounds():
101+
result = runner.invoke(app, ["sem", "tests/semantic/input_indexing_bounds.txt"])
102+
assert result.exit_code == 0
103+
assert "line 4" in result.stdout
104+
assert "line 11" in result.stdout
105+
assert "line 12" in result.stdout
106+
assert result.stdout.count("line") == 3
107+
108+
109+
def test_sem_error_binary_operations():
110+
result = runner.invoke(app, ["sem", "tests/semantic/input_binary_operations.txt"])
111+
assert result.exit_code == 0
112+
assert "line 7" in result.stdout
113+
assert "line 8" in result.stdout
114+
assert "line 14" in result.stdout
115+
assert "line 16" in result.stdout
116+
assert "line 17" in result.stdout
117+
assert result.stdout.count("line") == 5
118+
119+
120+
def test_sem_error_comparisons():
121+
result = runner.invoke(app, ["sem", "tests/semantic/input_comparisons.txt"])
122+
assert result.exit_code == 0
123+
assert "line 7" in result.stdout
124+
assert "line 9" in result.stdout
125+
assert result.stdout.count("line") == 2
+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
a1 = 1;
2+
a2 = 2;
3+
4+
print a1 + a2;
5+
6+
B = zeros(2);
7+
print a1 + B;
8+
print a1 .+ B;
9+
10+
C = ones(2);
11+
print B .+ C;
12+
13+
D = eye(3);
14+
print B .+ D;
15+
16+
print 2 + "not ok";
17+
print 2 .+ "not ok";

tests/semantic/input_break.txt

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
break;

0 commit comments

Comments
 (0)