@@ -2435,12 +2435,14 @@ def _match_sequence_pattern(
24352435 self ,
24362436 pattern : ast .MatchSequence ,
24372437 value : t .Any ,
2438+ pattern_vars : set [str ],
24382439 ) -> bool :
24392440 """Match a sequence pattern.
24402441
24412442 Args:
24422443 pattern: MatchSequence AST node
24432444 value: Value to match against
2445+ pattern_vars: Set of pattern variables already bound in this pattern
24442446
24452447 Returns:
24462448 Whether the pattern matches the value
@@ -2456,40 +2458,51 @@ def _match_sequence_pattern(
24562458 break
24572459
24582460 if star_idx == - 1 :
2459- return self ._match_fixed_sequence (pattern .patterns , value )
2461+ return self ._match_fixed_sequence (
2462+ pattern .patterns , value , pattern_vars
2463+ )
24602464 else :
2461- return self ._match_star_sequence (pattern .patterns , value , star_idx )
2465+ return self ._match_star_sequence (
2466+ pattern .patterns , value , star_idx , pattern_vars
2467+ )
24622468
24632469 def _match_fixed_sequence (
24642470 self ,
24652471 patterns : list [ast .pattern ],
24662472 value : t .Any ,
2473+ pattern_vars : set [str ],
24672474 ) -> bool :
24682475 """Match a sequence pattern without star expressions.
24692476
24702477 Args:
24712478 patterns: List of pattern AST nodes
24722479 value: Value to match against
2480+ pattern_vars: Set of pattern variables already bound in this pattern
24732481
24742482 Returns:
24752483 Whether the pattern matches the value
24762484 """
24772485 if len (patterns ) != len (value ):
24782486 return False
2479- return all (self ._match_pattern (p , v ) for p , v in zip (patterns , value ))
2487+ return all (
2488+ self ._match_pattern (p , v , pattern_vars )
2489+ for p , v in zip (patterns , value )
2490+ )
24802491
24812492 def _match_star_sequence (
24822493 self ,
24832494 patterns : list [ast .pattern ],
24842495 value : t .Any ,
24852496 star_idx : int ,
2497+ pattern_vars : set [str ],
24862498 ) -> bool :
24872499 """Match a sequence pattern with a star expression.
24882500
24892501 Args:
24902502 patterns: List of pattern AST nodes
24912503 value: Value to match against
24922504 star_idx: Index of the star pattern
2505+ pattern_vars: Set of pattern variables already bound in this pattern
24932506
24942507 Returns:
24952508 Whether the pattern matches the value
@@ -2499,7 +2512,7 @@ def _match_star_sequence(
24992512
25002513 # Match patterns before star
25012514 for p , v in zip (patterns [:star_idx ], value [:star_idx ]):
2502- if not self ._match_pattern (p , v ):
2515+ if not self ._match_pattern (p , v , pattern_vars ):
25032516 return False
25042517
25052518 # Calculate remaining elements after star
@@ -2510,7 +2523,7 @@ def _match_star_sequence(
25102523 patterns [star_idx + 1 :],
25112524 value [- remaining_count :] if remaining_count > 0 else [],
25122525 ):
2513- if not self ._match_pattern (p , v ):
2526+ if not self ._match_pattern (p , v , pattern_vars ):
25142527 return False
25152528
25162529 # Bind star pattern if it has a name
@@ -2675,32 +2688,143 @@ def _match_pattern(
26752688 self ,
26762689 pattern : ast .pattern ,
26772690 value : t .Any ,
2691+ pattern_vars : set [str ] | None = None ,
26782692 ) -> bool :
26792693 """Match a pattern against a value.
26802694
26812695 Args:
26822696 pattern: Pattern AST node
26832697 value: Value to match against
2698+ pattern_vars: Set of pattern variables already bound in this pattern
26842699
26852700 Returns:
26862701 Whether the pattern matches the value
2702+
2703+ Raises:
2704+ SyntaxError: If a pattern variable is bound multiple times
26872705 """
2706+ # Initialize pattern_vars on first call
2707+ if pattern_vars is None :
2708+ pattern_vars = set ()
2709+
2710+ # Helper to check and add pattern variables
2711+ def check_pattern_var (name : str ) -> None :
2712+ if name in pattern_vars :
2713+ raise SyntaxError (
2714+ f"multiple assignments to name '{ name } ' in pattern"
2715+ )
2716+ pattern_vars .add (name )
2717+
26882718 if isinstance (pattern , ast .MatchValue ):
26892719 return self ._match_value_pattern (pattern , value )
26902720 elif isinstance (pattern , ast .MatchSingleton ):
26912721 return value is pattern .value
26922722 elif isinstance (pattern , ast .MatchSequence ):
2693- return self ._match_sequence_pattern (pattern , value )
2723+ if not isinstance (value , (list , tuple )):
2724+ return False
2725+
2726+ # Find star pattern index if exists
2727+ star_idx = - 1
2728+ for i , p in enumerate (pattern .patterns ):
2729+ if isinstance (p , ast .MatchStar ):
2730+ star_idx = i
2731+ if p .name :
2732+ check_pattern_var (p .name )
2733+ break
2734+
2735+ if star_idx == - 1 :
2736+ return self ._match_fixed_sequence (
2737+ pattern .patterns , value , pattern_vars
2738+ )
2739+ else :
2740+ return self ._match_star_sequence (
2741+ pattern .patterns , value , star_idx , pattern_vars
2742+ )
26942743 elif isinstance (pattern , ast .MatchMapping ):
2695- return self ._match_mapping_pattern (pattern , value )
2744+ if not isinstance (value , dict ):
2745+ return False
2746+
2747+ # Check if all required keys are present
2748+ for key in pattern .keys :
2749+ key_value = self .visit (key )
2750+ if key_value not in value :
2751+ return False
2752+
2753+ # Match each key-pattern pair
2754+ for key , pat in zip (pattern .keys , pattern .patterns ):
2755+ key_value = self .visit (key )
2756+ if not self ._match_pattern (pat , value [key_value ], pattern_vars ):
2757+ return False
2758+
2759+ # Handle rest pattern if present
2760+ if pattern .rest is not None :
2761+ check_pattern_var (pattern .rest )
2762+ rest_dict = {
2763+ k : v
2764+ for k , v in value .items ()
2765+ if not any (self .visit (key ) == k for key in pattern .keys )
2766+ }
2767+ self ._set_name_value (pattern .rest , rest_dict )
2768+ self .current_scope .locals .add (pattern .rest )
2769+
2770+ return True
26962771 elif isinstance (pattern , ast .MatchStar ):
2697- return self ._match_star_pattern (pattern , value )
2772+ if pattern .name :
2773+ check_pattern_var (pattern .name )
2774+ return True
26982775 elif isinstance (pattern , ast .MatchAs ):
2699- return self ._match_as_pattern (pattern , value )
2776+ if pattern .pattern is not None :
2777+ if not self ._match_pattern (
2778+ pattern .pattern , value , pattern_vars
2779+ ):
2780+ return False
2781+ if pattern .name is not None :
2782+ check_pattern_var (pattern .name )
2783+ self ._set_name_value (pattern .name , value )
2784+ self .current_scope .locals .add (pattern .name )
2785+ return True
27002786 elif isinstance (pattern , ast .MatchOr ):
2701- return self ._match_or_pattern (pattern , value )
2787+ # Try each alternative pattern
2788+ matched_vars = None
2789+ for p in pattern .patterns :
2790+ alt_vars : set [str ] = set ()
2791+ if self ._match_pattern (p , value , alt_vars ):
2792+ if matched_vars is None :
2793+ matched_vars = alt_vars
2794+ elif matched_vars != alt_vars :
2795+ raise SyntaxError (
2796+ "alternative patterns bind different names"
2797+ )
2798+ pattern_vars .update (alt_vars )
2799+ return True
2800+ return False
27022801 elif isinstance (pattern , ast .MatchClass ):
2703- return self ._match_class_pattern (pattern , value )
2802+ cls = self .visit (pattern .cls )
2803+ if not isinstance (value , cls ):
2804+ return False
2805+
2806+ # Get positional attributes from __match_args__
2807+ match_args = getattr (cls , "__match_args__" , ())
2808+ if len (pattern .patterns ) > len (match_args ):
2809+ return False
2810+
2811+ # Match positional patterns
2812+ for pat , attr_name in zip (pattern .patterns , match_args ):
2813+ if not self ._match_pattern (
2814+ pat , getattr (value , attr_name ), pattern_vars
2815+ ):
2816+ return False
2817+
2818+ # Match keyword patterns
2819+ for name , pat in zip (pattern .kwd_attrs , pattern .kwd_patterns ):
2820+ if not hasattr (value , name ):
2821+ return False
2822+ if not self ._match_pattern (
2823+ pat , getattr (value , name ), pattern_vars
2824+ ):
2825+ return False
2826+
2827+ return True
27042828
27052829 return False
27062830
@@ -2732,8 +2856,9 @@ def visit_Match(self, node: ast.Match) -> None:
27322856
27332857 # Create a temporary scope for pattern matching
27342858 with ScopeContext (self ):
2735- # Try to match the pattern
2736- if not self ._match_pattern (pattern , subject ):
2859+ # Try to match the pattern with a new pattern_vars set
2860+ pattern_vars : set [str ] = set ()
2861+ if not self ._match_pattern (pattern , subject , pattern_vars ):
27372862 # If no match, continue to the next case
27382863 continue
27392864
0 commit comments