Skip to content

Commit c4ae2ae

Browse files
committed
Parse binary operations with correct operator precedence
1 parent 902b364 commit c4ae2ae

File tree

5 files changed

+372
-112
lines changed

5 files changed

+372
-112
lines changed

parser/src/parser/parser.rs

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -929,7 +929,7 @@ impl<'a> Parser<'a> {
929929
// https://docs.python.org/3/reference/compound_stmts.html#literal-patterns
930930
fn parse_literal_pattern(&mut self) -> Result<MatchPattern, ParsingError> {
931931
let node = self.start_node();
932-
let value = self.parse_binary_arithmetic_operation()?;
932+
let value = self.parse_binary_arithmetic_operation(0)?;
933933
Ok(MatchPattern::MatchValue(MatchValue {
934934
node: self.finish_node(node),
935935
value,
@@ -2189,15 +2189,15 @@ impl<'a> Parser<'a> {
21892189
// https://docs.python.org/3/reference/expressions.html#shifting-operations
21902190
fn parse_shift_expr(&mut self) -> Result<Expression, ParsingError> {
21912191
let node = self.start_node();
2192-
let mut arith_expr = self.parse_binary_arithmetic_operation()?;
2192+
let mut arith_expr = self.parse_binary_arithmetic_operation(0)?;
21932193
if self.at(Kind::LeftShift) || self.at(Kind::RightShift) {
21942194
let op = if self.eat(Kind::LeftShift) {
21952195
BinaryOperator::LShift
21962196
} else {
21972197
self.bump(Kind::RightShift);
21982198
BinaryOperator::RShift
21992199
};
2200-
let lhs = self.parse_binary_arithmetic_operation()?;
2200+
let lhs = self.parse_binary_arithmetic_operation(0)?;
22012201
arith_expr = Expression::BinOp(Box::new(BinOp {
22022202
node: self.finish_node(node),
22032203
op,
@@ -2209,12 +2209,24 @@ impl<'a> Parser<'a> {
22092209
}
22102210

22112211
// https://docs.python.org/3/reference/expressions.html#binary-arithmetic-operations
2212-
fn parse_binary_arithmetic_operation(&mut self) -> Result<Expression, ParsingError> {
2212+
//
2213+
fn parse_binary_arithmetic_operation(
2214+
&mut self,
2215+
min_precedence: u8,
2216+
) -> Result<Expression, ParsingError> {
22132217
let node = self.start_node();
22142218
let mut lhs = self.parse_unary_arithmetic_operation()?;
2215-
while self.cur_kind().is_bin_arithmetic_op() {
2216-
let op = self.parse_bin_arithmetic_op()?;
2217-
let rhs = self.parse_unary_arithmetic_operation()?;
2219+
while let Some((op, precedence, associativity)) = self.cur_kind().bin_op_precedence() {
2220+
if precedence < min_precedence {
2221+
break;
2222+
}
2223+
self.bump_any();
2224+
let next_precedence = match associativity {
2225+
0 => precedence + 1,
2226+
1 => precedence,
2227+
_ => unreachable!(),
2228+
};
2229+
let rhs = self.parse_binary_arithmetic_operation(next_precedence)?;
22182230
lhs = Expression::BinOp(Box::new(BinOp {
22192231
node: self.finish_node(node),
22202232
op,

parser/src/token.rs

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use core::panic;
22
use std::fmt::Display;
33

4+
use crate::ast::BinaryOperator;
5+
46
#[derive(Debug, Clone, PartialEq)]
57
pub struct Token {
68
pub kind: Kind,
@@ -177,18 +179,18 @@ impl Kind {
177179
matches!(self, Kind::Not | Kind::BitNot | Kind::Minus | Kind::Plus)
178180
}
179181

180-
pub fn is_bin_arithmetic_op(&self) -> bool {
181-
matches!(
182-
self,
183-
Kind::Plus
184-
| Kind::Minus
185-
| Kind::Mul
186-
| Kind::MatrixMul
187-
| Kind::Div
188-
| Kind::Mod
189-
| Kind::Pow
190-
| Kind::IntDiv
191-
)
182+
pub fn bin_op_precedence(&self) -> Option<(BinaryOperator, u8, u8)> {
183+
match self {
184+
Kind::Plus => Some((BinaryOperator::Add, 9, 0)),
185+
Kind::Minus => Some((BinaryOperator::Sub, 9, 0)),
186+
Kind::Mul => Some((BinaryOperator::Mult, 10, 0)),
187+
Kind::MatrixMul => Some((BinaryOperator::MatMult, 10, 0)),
188+
Kind::Div => Some((BinaryOperator::Div, 10, 0)),
189+
Kind::Mod => Some((BinaryOperator::Mod, 10, 0)),
190+
Kind::Pow => Some((BinaryOperator::Pow, 10, 0)),
191+
Kind::IntDiv => Some((BinaryOperator::FloorDiv, 10, 0)),
192+
_ => None,
193+
}
192194
}
193195

194196
pub fn is_comparison_operator(&self) -> bool {

parser/test_data/inputs/binary_op.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
1 % 2
1212

13-
1 ** 2
13+
1**2
1414

1515
1 << 2
1616

@@ -25,3 +25,11 @@
2525
1 | 2 | 3
2626

2727
1 @ 2
28+
29+
1 + 2 * 3
30+
31+
1 * 2 + 3
32+
33+
1 ^ 2 + 3
34+
35+
3 + (1 + 2) * 3
Lines changed: 73 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
---
22
source: parser/src/lexer/mod.rs
3-
description: "1 + 2\n\n1 - 2\n\n1 * 2\n\n1 / 2\n\n1 // 2\n\n1 % 2\n\n1 ** 2\n\n1 << 2\n\n1 >> 2\n\n1 & 2\n\n1 ^ 2\n\n1 | 2\n\n1 | 2 | 3\n\n1 @ 2\n"
3+
description: "1 + 2\n\n1 - 2\n\n1 * 2\n\n1 / 2\n\n1 // 2\n\n1 % 2\n\n1**2\n\n1 << 2\n\n1 >> 2\n\n1 & 2\n\n1 ^ 2\n\n1 | 2\n\n1 | 2 | 3\n\n1 @ 2\n\n1 + 2 * 3\n\n1 * 2 + 3\n\n1 ^ 2 + 3\n\n3 + (1 + 2) * 3\n"
44
input_file: parser/test_data/inputs/binary_op.py
55
---
66
0,1: Integer (Number("1"))
@@ -34,43 +34,75 @@ input_file: parser/test_data/inputs/binary_op.py
3434
41,42: NewLine (None)
3535
42,43: NL (None)
3636
43,44: Integer (Number("1"))
37-
45,47: ** (None)
38-
48,49: Integer (Number("2"))
39-
49,50: NewLine (None)
40-
50,51: NL (None)
41-
51,52: Integer (Number("1"))
42-
53,55: << (None)
43-
56,57: Integer (Number("2"))
44-
57,58: NewLine (None)
45-
58,59: NL (None)
46-
59,60: Integer (Number("1"))
47-
61,63: >> (None)
48-
64,65: Integer (Number("2"))
49-
65,66: NewLine (None)
50-
66,67: NL (None)
51-
67,68: Integer (Number("1"))
52-
69,70: & (None)
53-
71,72: Integer (Number("2"))
54-
72,73: NewLine (None)
55-
73,74: NL (None)
56-
74,75: Integer (Number("1"))
57-
76,77: ^ (None)
58-
78,79: Integer (Number("2"))
59-
79,80: NewLine (None)
60-
80,81: NL (None)
61-
81,82: Integer (Number("1"))
62-
83,84: | (None)
63-
85,86: Integer (Number("2"))
64-
86,87: NewLine (None)
65-
87,88: NL (None)
66-
88,89: Integer (Number("1"))
67-
90,91: | (None)
68-
92,93: Integer (Number("2"))
69-
94,95: | (None)
70-
96,97: Integer (Number("3"))
71-
97,98: NewLine (None)
72-
98,99: NL (None)
73-
99,100: Integer (Number("1"))
74-
101,102: @ (None)
75-
103,104: Integer (Number("2"))
76-
104,105: NewLine (None)
37+
44,46: ** (None)
38+
46,47: Integer (Number("2"))
39+
47,48: NewLine (None)
40+
48,49: NL (None)
41+
49,50: Integer (Number("1"))
42+
51,53: << (None)
43+
54,55: Integer (Number("2"))
44+
55,56: NewLine (None)
45+
56,57: NL (None)
46+
57,58: Integer (Number("1"))
47+
59,61: >> (None)
48+
62,63: Integer (Number("2"))
49+
63,64: NewLine (None)
50+
64,65: NL (None)
51+
65,66: Integer (Number("1"))
52+
67,68: & (None)
53+
69,70: Integer (Number("2"))
54+
70,71: NewLine (None)
55+
71,72: NL (None)
56+
72,73: Integer (Number("1"))
57+
74,75: ^ (None)
58+
76,77: Integer (Number("2"))
59+
77,78: NewLine (None)
60+
78,79: NL (None)
61+
79,80: Integer (Number("1"))
62+
81,82: | (None)
63+
83,84: Integer (Number("2"))
64+
84,85: NewLine (None)
65+
85,86: NL (None)
66+
86,87: Integer (Number("1"))
67+
88,89: | (None)
68+
90,91: Integer (Number("2"))
69+
92,93: | (None)
70+
94,95: Integer (Number("3"))
71+
95,96: NewLine (None)
72+
96,97: NL (None)
73+
97,98: Integer (Number("1"))
74+
99,100: @ (None)
75+
101,102: Integer (Number("2"))
76+
102,103: NewLine (None)
77+
103,104: NL (None)
78+
104,105: Integer (Number("1"))
79+
106,107: + (None)
80+
108,109: Integer (Number("2"))
81+
110,111: * (None)
82+
112,113: Integer (Number("3"))
83+
113,114: NewLine (None)
84+
114,115: NL (None)
85+
115,116: Integer (Number("1"))
86+
117,118: * (None)
87+
119,120: Integer (Number("2"))
88+
121,122: + (None)
89+
123,124: Integer (Number("3"))
90+
124,125: NewLine (None)
91+
125,126: NL (None)
92+
126,127: Integer (Number("1"))
93+
128,129: ^ (None)
94+
130,131: Integer (Number("2"))
95+
132,133: + (None)
96+
134,135: Integer (Number("3"))
97+
135,136: NewLine (None)
98+
136,137: NL (None)
99+
137,138: Integer (Number("3"))
100+
139,140: + (None)
101+
141,142: ( (None)
102+
142,143: Integer (Number("1"))
103+
144,145: + (None)
104+
146,147: Integer (Number("2"))
105+
147,148: ) (None)
106+
149,150: * (None)
107+
151,152: Integer (Number("3"))
108+
152,153: NewLine (None)

0 commit comments

Comments
 (0)