Skip to content

Commit 177f30c

Browse files
committed
Code cleanup
1 parent 6a04e6e commit 177f30c

File tree

1 file changed

+59
-156
lines changed

1 file changed

+59
-156
lines changed

monic/expressions/interpreter.py

Lines changed: 59 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -2543,12 +2543,14 @@ def _match_mapping_pattern(
25432543
self,
25442544
pattern: ast.MatchMapping,
25452545
value: t.Any,
2546+
pattern_vars: set[str],
25462547
) -> bool:
25472548
"""Match a mapping pattern.
25482549
25492550
Args:
25502551
pattern: MatchMapping AST node
25512552
value: Value to match against
2553+
pattern_vars: Set of pattern variables already bound in this pattern
25522554
25532555
Returns:
25542556
Whether the pattern matches the value
@@ -2565,11 +2567,16 @@ def _match_mapping_pattern(
25652567
# Match each key-pattern pair
25662568
for key, pat in zip(pattern.keys, pattern.patterns):
25672569
key_value = self.visit(key)
2568-
if not self._match_pattern(pat, value[key_value]):
2570+
if not self._match_pattern(pat, value[key_value], pattern_vars):
25692571
return False
25702572

25712573
# Handle rest pattern if present
25722574
if pattern.rest is not None:
2575+
if pattern.rest in pattern_vars:
2576+
raise SyntaxError(
2577+
f"multiple assignments to name '{pattern.rest}' in pattern"
2578+
)
2579+
pattern_vars.add(pattern.rest)
25732580
rest_dict = {
25742581
k: v
25752582
for k, v in value.items()
@@ -2580,87 +2587,31 @@ def _match_mapping_pattern(
25802587

25812588
return True
25822589

2583-
def _match_or_pattern(
2584-
self,
2585-
pattern: ast.MatchOr,
2586-
value: t.Any,
2587-
) -> bool:
2588-
"""Match an OR pattern.
2589-
2590-
Args:
2591-
pattern: MatchOr AST node
2592-
value: Value to match against
2593-
2594-
Returns:
2595-
Whether the pattern matches the value
2596-
"""
2597-
for p in pattern.patterns:
2598-
# Create a temporary scope for each OR pattern
2599-
# to avoid variable binding conflicts
2600-
temp_scope = Scope()
2601-
self.scope_stack.append(temp_scope)
2602-
try:
2603-
if self._match_pattern(p, value):
2604-
return True
2605-
finally:
2606-
self.scope_stack.pop()
2607-
return False
2608-
2609-
def _match_class_pattern(
2610-
self,
2611-
pattern: ast.MatchClass,
2612-
value: t.Any,
2613-
) -> bool:
2614-
"""Match a class pattern.
2615-
2616-
Args:
2617-
pattern: MatchClass AST node
2618-
value: Value to match against
2619-
2620-
Returns:
2621-
Whether the pattern matches the value
2622-
"""
2623-
cls = self.visit(pattern.cls)
2624-
if not isinstance(value, cls):
2625-
return False
2626-
2627-
# Get positional attributes from __match_args__
2628-
match_args = getattr(cls, "__match_args__", ())
2629-
if len(pattern.patterns) > len(match_args):
2630-
return False
2631-
2632-
# Match positional patterns
2633-
for pat, attr_name in zip(pattern.patterns, match_args):
2634-
if not self._match_pattern(pat, getattr(value, attr_name)):
2635-
return False
2636-
2637-
# Match keyword patterns
2638-
for name, pat in zip(pattern.kwd_attrs, pattern.kwd_patterns):
2639-
if not hasattr(value, name):
2640-
return False
2641-
if not self._match_pattern(pat, getattr(value, name)):
2642-
return False
2643-
2644-
return True
2645-
26462590
def _match_as_pattern(
26472591
self,
26482592
pattern: ast.MatchAs,
26492593
value: t.Any,
2594+
pattern_vars: set[str],
26502595
) -> bool:
26512596
"""Match an AS pattern.
26522597
26532598
Args:
26542599
pattern: MatchAs AST node
26552600
value: Value to match against
2601+
pattern_vars: Set of pattern variables already bound in this pattern
26562602
26572603
Returns:
26582604
Whether the pattern matches the value
26592605
"""
26602606
if pattern.pattern is not None:
2661-
if not self._match_pattern(pattern.pattern, value):
2607+
if not self._match_pattern(pattern.pattern, value, pattern_vars):
26622608
return False
26632609
if pattern.name is not None:
2610+
if pattern.name in pattern_vars:
2611+
raise SyntaxError(
2612+
f"multiple assignments to name '{pattern.name}' in pattern"
2613+
)
2614+
pattern_vars.add(pattern.name)
26642615
self._set_name_value(pattern.name, value)
26652616
self.current_scope.locals.add(pattern.name)
26662617
return True
@@ -2707,82 +2658,18 @@ def _match_pattern(
27072658
if pattern_vars is None:
27082659
pattern_vars = set()
27092660

2710-
# Helper to check and add pattern variables
2711-
def check_pattern_var(name: str) -> None:
2712-
if name in pattern_vars:
2713-
raise SyntaxError(
2714-
f"multiple assignments to name '{name}' in pattern"
2715-
)
2716-
pattern_vars.add(name)
2717-
27182661
if isinstance(pattern, ast.MatchValue):
27192662
return self._match_value_pattern(pattern, value)
27202663
elif isinstance(pattern, ast.MatchSingleton):
27212664
return value is pattern.value
27222665
elif isinstance(pattern, ast.MatchSequence):
2723-
if not isinstance(value, (list, tuple)):
2724-
return False
2725-
2726-
# Find star pattern index if exists
2727-
star_idx = -1
2728-
for i, p in enumerate(pattern.patterns):
2729-
if isinstance(p, ast.MatchStar):
2730-
star_idx = i
2731-
if p.name:
2732-
check_pattern_var(p.name)
2733-
break
2734-
2735-
if star_idx == -1:
2736-
return self._match_fixed_sequence(
2737-
pattern.patterns, value, pattern_vars
2738-
)
2739-
else:
2740-
return self._match_star_sequence(
2741-
pattern.patterns, value, star_idx, pattern_vars
2742-
)
2666+
return self._match_sequence_pattern(pattern, value, pattern_vars)
27432667
elif isinstance(pattern, ast.MatchMapping):
2744-
if not isinstance(value, dict):
2745-
return False
2746-
2747-
# Check if all required keys are present
2748-
for key in pattern.keys:
2749-
key_value = self.visit(key)
2750-
if key_value not in value:
2751-
return False
2752-
2753-
# Match each key-pattern pair
2754-
for key, pat in zip(pattern.keys, pattern.patterns):
2755-
key_value = self.visit(key)
2756-
if not self._match_pattern(pat, value[key_value], pattern_vars):
2757-
return False
2758-
2759-
# Handle rest pattern if present
2760-
if pattern.rest is not None:
2761-
check_pattern_var(pattern.rest)
2762-
rest_dict = {
2763-
k: v
2764-
for k, v in value.items()
2765-
if not any(self.visit(key) == k for key in pattern.keys)
2766-
}
2767-
self._set_name_value(pattern.rest, rest_dict)
2768-
self.current_scope.locals.add(pattern.rest)
2769-
2770-
return True
2668+
return self._match_mapping_pattern(pattern, value, pattern_vars)
27712669
elif isinstance(pattern, ast.MatchStar):
2772-
if pattern.name:
2773-
check_pattern_var(pattern.name)
2774-
return True
2670+
return self._match_star_pattern(pattern, value)
27752671
elif isinstance(pattern, ast.MatchAs):
2776-
if pattern.pattern is not None:
2777-
if not self._match_pattern(
2778-
pattern.pattern, value, pattern_vars
2779-
):
2780-
return False
2781-
if pattern.name is not None:
2782-
check_pattern_var(pattern.name)
2783-
self._set_name_value(pattern.name, value)
2784-
self.current_scope.locals.add(pattern.name)
2785-
return True
2672+
return self._match_as_pattern(pattern, value, pattern_vars)
27862673
elif isinstance(pattern, ast.MatchOr):
27872674
# Try each alternative pattern
27882675
matched_vars = None
@@ -2799,34 +2686,50 @@ def check_pattern_var(name: str) -> None:
27992686
return True
28002687
return False
28012688
elif isinstance(pattern, ast.MatchClass):
2802-
cls = self.visit(pattern.cls)
2803-
if not isinstance(value, cls):
2804-
return False
2689+
return self._match_class_pattern(pattern, value, pattern_vars)
28052690

2806-
# Get positional attributes from __match_args__
2807-
match_args = getattr(cls, "__match_args__", ())
2808-
if len(pattern.patterns) > len(match_args):
2809-
return False
2691+
return False
28102692

2811-
# Match positional patterns
2812-
for pat, attr_name in zip(pattern.patterns, match_args):
2813-
if not self._match_pattern(
2814-
pat, getattr(value, attr_name), pattern_vars
2815-
):
2816-
return False
2693+
def _match_class_pattern(
2694+
self,
2695+
pattern: ast.MatchClass,
2696+
value: t.Any,
2697+
pattern_vars: set[str],
2698+
) -> bool:
2699+
"""Match a class pattern.
28172700
2818-
# Match keyword patterns
2819-
for name, pat in zip(pattern.kwd_attrs, pattern.kwd_patterns):
2820-
if not hasattr(value, name):
2821-
return False
2822-
if not self._match_pattern(
2823-
pat, getattr(value, name), pattern_vars
2824-
):
2825-
return False
2701+
Args:
2702+
pattern: MatchClass AST node
2703+
value: Value to match against
2704+
pattern_vars: Set of pattern variables already bound in this pattern
28262705
2827-
return True
2706+
Returns:
2707+
Whether the pattern matches the value
2708+
"""
2709+
cls = self.visit(pattern.cls)
2710+
if not isinstance(value, cls):
2711+
return False
28282712

2829-
return False
2713+
# Get positional attributes from __match_args__
2714+
match_args = getattr(cls, "__match_args__", ())
2715+
if len(pattern.patterns) > len(match_args):
2716+
return False
2717+
2718+
# Match positional patterns
2719+
for pat, attr_name in zip(pattern.patterns, match_args):
2720+
if not self._match_pattern(
2721+
pat, getattr(value, attr_name), pattern_vars
2722+
):
2723+
return False
2724+
2725+
# Match keyword patterns
2726+
for name, pat in zip(pattern.kwd_attrs, pattern.kwd_patterns):
2727+
if not hasattr(value, name):
2728+
return False
2729+
if not self._match_pattern(pat, getattr(value, name), pattern_vars):
2730+
return False
2731+
2732+
return True
28302733

28312734
def visit_Match(self, node: ast.Match) -> None:
28322735
"""Handle match-case statements.

0 commit comments

Comments
 (0)