Skip to content

Commit 6a04e6e

Browse files
committed
The SyntaxError should be raised when there are multiple assignments to the same name in a pattern
1 parent 7b969d2 commit 6a04e6e

File tree

1 file changed

+138
-13
lines changed

1 file changed

+138
-13
lines changed

monic/expressions/interpreter.py

Lines changed: 138 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2435,12 +2435,14 @@ def _match_sequence_pattern(
24352435
self,
24362436
pattern: ast.MatchSequence,
24372437
value: t.Any,
2438+
pattern_vars: set[str],
24382439
) -> bool:
24392440
"""Match a sequence pattern.
24402441
24412442
Args:
24422443
pattern: MatchSequence AST node
24432444
value: Value to match against
2445+
pattern_vars: Set of pattern variables already bound in this pattern
24442446
24452447
Returns:
24462448
Whether the pattern matches the value
@@ -2456,40 +2458,51 @@ def _match_sequence_pattern(
24562458
break
24572459

24582460
if star_idx == -1:
2459-
return self._match_fixed_sequence(pattern.patterns, value)
2461+
return self._match_fixed_sequence(
2462+
pattern.patterns, value, pattern_vars
2463+
)
24602464
else:
2461-
return self._match_star_sequence(pattern.patterns, value, star_idx)
2465+
return self._match_star_sequence(
2466+
pattern.patterns, value, star_idx, pattern_vars
2467+
)
24622468

24632469
def _match_fixed_sequence(
24642470
self,
24652471
patterns: list[ast.pattern],
24662472
value: t.Any,
2473+
pattern_vars: set[str],
24672474
) -> bool:
24682475
"""Match a sequence pattern without star expressions.
24692476
24702477
Args:
24712478
patterns: List of pattern AST nodes
24722479
value: Value to match against
2480+
pattern_vars: Set of pattern variables already bound in this pattern
24732481
24742482
Returns:
24752483
Whether the pattern matches the value
24762484
"""
24772485
if len(patterns) != len(value):
24782486
return False
2479-
return all(self._match_pattern(p, v) for p, v in zip(patterns, value))
2487+
return all(
2488+
self._match_pattern(p, v, pattern_vars)
2489+
for p, v in zip(patterns, value)
2490+
)
24802491

24812492
def _match_star_sequence(
24822493
self,
24832494
patterns: list[ast.pattern],
24842495
value: t.Any,
24852496
star_idx: int,
2497+
pattern_vars: set[str],
24862498
) -> bool:
24872499
"""Match a sequence pattern with a star expression.
24882500
24892501
Args:
24902502
patterns: List of pattern AST nodes
24912503
value: Value to match against
24922504
star_idx: Index of the star pattern
2505+
pattern_vars: Set of pattern variables already bound in this pattern
24932506
24942507
Returns:
24952508
Whether the pattern matches the value
@@ -2499,7 +2512,7 @@ def _match_star_sequence(
24992512

25002513
# Match patterns before star
25012514
for p, v in zip(patterns[:star_idx], value[:star_idx]):
2502-
if not self._match_pattern(p, v):
2515+
if not self._match_pattern(p, v, pattern_vars):
25032516
return False
25042517

25052518
# Calculate remaining elements after star
@@ -2510,7 +2523,7 @@ def _match_star_sequence(
25102523
patterns[star_idx + 1 :],
25112524
value[-remaining_count:] if remaining_count > 0 else [],
25122525
):
2513-
if not self._match_pattern(p, v):
2526+
if not self._match_pattern(p, v, pattern_vars):
25142527
return False
25152528

25162529
# Bind star pattern if it has a name
@@ -2675,32 +2688,143 @@ def _match_pattern(
26752688
self,
26762689
pattern: ast.pattern,
26772690
value: t.Any,
2691+
pattern_vars: set[str] | None = None,
26782692
) -> bool:
26792693
"""Match a pattern against a value.
26802694
26812695
Args:
26822696
pattern: Pattern AST node
26832697
value: Value to match against
2698+
pattern_vars: Set of pattern variables already bound in this pattern
26842699
26852700
Returns:
26862701
Whether the pattern matches the value
2702+
2703+
Raises:
2704+
SyntaxError: If a pattern variable is bound multiple times
26872705
"""
2706+
# Initialize pattern_vars on first call
2707+
if pattern_vars is None:
2708+
pattern_vars = set()
2709+
2710+
# Helper to check and add pattern variables
2711+
def check_pattern_var(name: str) -> None:
2712+
if name in pattern_vars:
2713+
raise SyntaxError(
2714+
f"multiple assignments to name '{name}' in pattern"
2715+
)
2716+
pattern_vars.add(name)
2717+
26882718
if isinstance(pattern, ast.MatchValue):
26892719
return self._match_value_pattern(pattern, value)
26902720
elif isinstance(pattern, ast.MatchSingleton):
26912721
return value is pattern.value
26922722
elif isinstance(pattern, ast.MatchSequence):
2693-
return self._match_sequence_pattern(pattern, value)
2723+
if not isinstance(value, (list, tuple)):
2724+
return False
2725+
2726+
# Find star pattern index if exists
2727+
star_idx = -1
2728+
for i, p in enumerate(pattern.patterns):
2729+
if isinstance(p, ast.MatchStar):
2730+
star_idx = i
2731+
if p.name:
2732+
check_pattern_var(p.name)
2733+
break
2734+
2735+
if star_idx == -1:
2736+
return self._match_fixed_sequence(
2737+
pattern.patterns, value, pattern_vars
2738+
)
2739+
else:
2740+
return self._match_star_sequence(
2741+
pattern.patterns, value, star_idx, pattern_vars
2742+
)
26942743
elif isinstance(pattern, ast.MatchMapping):
2695-
return self._match_mapping_pattern(pattern, value)
2744+
if not isinstance(value, dict):
2745+
return False
2746+
2747+
# Check if all required keys are present
2748+
for key in pattern.keys:
2749+
key_value = self.visit(key)
2750+
if key_value not in value:
2751+
return False
2752+
2753+
# Match each key-pattern pair
2754+
for key, pat in zip(pattern.keys, pattern.patterns):
2755+
key_value = self.visit(key)
2756+
if not self._match_pattern(pat, value[key_value], pattern_vars):
2757+
return False
2758+
2759+
# Handle rest pattern if present
2760+
if pattern.rest is not None:
2761+
check_pattern_var(pattern.rest)
2762+
rest_dict = {
2763+
k: v
2764+
for k, v in value.items()
2765+
if not any(self.visit(key) == k for key in pattern.keys)
2766+
}
2767+
self._set_name_value(pattern.rest, rest_dict)
2768+
self.current_scope.locals.add(pattern.rest)
2769+
2770+
return True
26962771
elif isinstance(pattern, ast.MatchStar):
2697-
return self._match_star_pattern(pattern, value)
2772+
if pattern.name:
2773+
check_pattern_var(pattern.name)
2774+
return True
26982775
elif isinstance(pattern, ast.MatchAs):
2699-
return self._match_as_pattern(pattern, value)
2776+
if pattern.pattern is not None:
2777+
if not self._match_pattern(
2778+
pattern.pattern, value, pattern_vars
2779+
):
2780+
return False
2781+
if pattern.name is not None:
2782+
check_pattern_var(pattern.name)
2783+
self._set_name_value(pattern.name, value)
2784+
self.current_scope.locals.add(pattern.name)
2785+
return True
27002786
elif isinstance(pattern, ast.MatchOr):
2701-
return self._match_or_pattern(pattern, value)
2787+
# Try each alternative pattern
2788+
matched_vars = None
2789+
for p in pattern.patterns:
2790+
alt_vars: set[str] = set()
2791+
if self._match_pattern(p, value, alt_vars):
2792+
if matched_vars is None:
2793+
matched_vars = alt_vars
2794+
elif matched_vars != alt_vars:
2795+
raise SyntaxError(
2796+
"alternative patterns bind different names"
2797+
)
2798+
pattern_vars.update(alt_vars)
2799+
return True
2800+
return False
27022801
elif isinstance(pattern, ast.MatchClass):
2703-
return self._match_class_pattern(pattern, value)
2802+
cls = self.visit(pattern.cls)
2803+
if not isinstance(value, cls):
2804+
return False
2805+
2806+
# Get positional attributes from __match_args__
2807+
match_args = getattr(cls, "__match_args__", ())
2808+
if len(pattern.patterns) > len(match_args):
2809+
return False
2810+
2811+
# Match positional patterns
2812+
for pat, attr_name in zip(pattern.patterns, match_args):
2813+
if not self._match_pattern(
2814+
pat, getattr(value, attr_name), pattern_vars
2815+
):
2816+
return False
2817+
2818+
# Match keyword patterns
2819+
for name, pat in zip(pattern.kwd_attrs, pattern.kwd_patterns):
2820+
if not hasattr(value, name):
2821+
return False
2822+
if not self._match_pattern(
2823+
pat, getattr(value, name), pattern_vars
2824+
):
2825+
return False
2826+
2827+
return True
27042828

27052829
return False
27062830

@@ -2732,8 +2856,9 @@ def visit_Match(self, node: ast.Match) -> None:
27322856

27332857
# Create a temporary scope for pattern matching
27342858
with ScopeContext(self):
2735-
# Try to match the pattern
2736-
if not self._match_pattern(pattern, subject):
2859+
# Try to match the pattern with a new pattern_vars set
2860+
pattern_vars: set[str] = set()
2861+
if not self._match_pattern(pattern, subject, pattern_vars):
27372862
# If no match, continue to the next case
27382863
continue
27392864

0 commit comments

Comments
 (0)