@@ -2543,12 +2543,14 @@ def _match_mapping_pattern(
25432543 self ,
25442544 pattern : ast .MatchMapping ,
25452545 value : t .Any ,
2546+ pattern_vars : set [str ],
25462547 ) -> bool :
25472548 """Match a mapping pattern.
25482549
25492550 Args:
25502551 pattern: MatchMapping AST node
25512552 value: Value to match against
2553+ pattern_vars: Set of pattern variables already bound in this pattern
25522554
25532555 Returns:
25542556 Whether the pattern matches the value
@@ -2565,11 +2567,16 @@ def _match_mapping_pattern(
25652567 # Match each key-pattern pair
25662568 for key , pat in zip (pattern .keys , pattern .patterns ):
25672569 key_value = self .visit (key )
2568- if not self ._match_pattern (pat , value [key_value ]):
2570+ if not self ._match_pattern (pat , value [key_value ], pattern_vars ):
25692571 return False
25702572
25712573 # Handle rest pattern if present
25722574 if pattern .rest is not None :
2575+ if pattern .rest in pattern_vars :
2576+ raise SyntaxError (
2577+ f"multiple assignments to name '{ pattern .rest } ' in pattern"
2578+ )
2579+ pattern_vars .add (pattern .rest )
25732580 rest_dict = {
25742581 k : v
25752582 for k , v in value .items ()
@@ -2580,87 +2587,31 @@ def _match_mapping_pattern(
25802587
25812588 return True
25822589
2583- def _match_or_pattern (
2584- self ,
2585- pattern : ast .MatchOr ,
2586- value : t .Any ,
2587- ) -> bool :
2588- """Match an OR pattern.
2589-
2590- Args:
2591- pattern: MatchOr AST node
2592- value: Value to match against
2593-
2594- Returns:
2595- Whether the pattern matches the value
2596- """
2597- for p in pattern .patterns :
2598- # Create a temporary scope for each OR pattern
2599- # to avoid variable binding conflicts
2600- temp_scope = Scope ()
2601- self .scope_stack .append (temp_scope )
2602- try :
2603- if self ._match_pattern (p , value ):
2604- return True
2605- finally :
2606- self .scope_stack .pop ()
2607- return False
2608-
2609- def _match_class_pattern (
2610- self ,
2611- pattern : ast .MatchClass ,
2612- value : t .Any ,
2613- ) -> bool :
2614- """Match a class pattern.
2615-
2616- Args:
2617- pattern: MatchClass AST node
2618- value: Value to match against
2619-
2620- Returns:
2621- Whether the pattern matches the value
2622- """
2623- cls = self .visit (pattern .cls )
2624- if not isinstance (value , cls ):
2625- return False
2626-
2627- # Get positional attributes from __match_args__
2628- match_args = getattr (cls , "__match_args__" , ())
2629- if len (pattern .patterns ) > len (match_args ):
2630- return False
2631-
2632- # Match positional patterns
2633- for pat , attr_name in zip (pattern .patterns , match_args ):
2634- if not self ._match_pattern (pat , getattr (value , attr_name )):
2635- return False
2636-
2637- # Match keyword patterns
2638- for name , pat in zip (pattern .kwd_attrs , pattern .kwd_patterns ):
2639- if not hasattr (value , name ):
2640- return False
2641- if not self ._match_pattern (pat , getattr (value , name )):
2642- return False
2643-
2644- return True
2645-
26462590 def _match_as_pattern (
26472591 self ,
26482592 pattern : ast .MatchAs ,
26492593 value : t .Any ,
2594+ pattern_vars : set [str ],
26502595 ) -> bool :
26512596 """Match an AS pattern.
26522597
26532598 Args:
26542599 pattern: MatchAs AST node
26552600 value: Value to match against
2601+ pattern_vars: Set of pattern variables already bound in this pattern
26562602
26572603 Returns:
26582604 Whether the pattern matches the value
26592605 """
26602606 if pattern .pattern is not None :
2661- if not self ._match_pattern (pattern .pattern , value ):
2607+ if not self ._match_pattern (pattern .pattern , value , pattern_vars ):
26622608 return False
26632609 if pattern .name is not None :
2610+ if pattern .name in pattern_vars :
2611+ raise SyntaxError (
2612+ f"multiple assignments to name '{ pattern .name } ' in pattern"
2613+ )
2614+ pattern_vars .add (pattern .name )
26642615 self ._set_name_value (pattern .name , value )
26652616 self .current_scope .locals .add (pattern .name )
26662617 return True
@@ -2707,82 +2658,18 @@ def _match_pattern(
27072658 if pattern_vars is None :
27082659 pattern_vars = set ()
27092660
2710- # Helper to check and add pattern variables
2711- def check_pattern_var (name : str ) -> None :
2712- if name in pattern_vars :
2713- raise SyntaxError (
2714- f"multiple assignments to name '{ name } ' in pattern"
2715- )
2716- pattern_vars .add (name )
2717-
27182661 if isinstance (pattern , ast .MatchValue ):
27192662 return self ._match_value_pattern (pattern , value )
27202663 elif isinstance (pattern , ast .MatchSingleton ):
27212664 return value is pattern .value
27222665 elif isinstance (pattern , ast .MatchSequence ):
2723- if not isinstance (value , (list , tuple )):
2724- return False
2725-
2726- # Find star pattern index if exists
2727- star_idx = - 1
2728- for i , p in enumerate (pattern .patterns ):
2729- if isinstance (p , ast .MatchStar ):
2730- star_idx = i
2731- if p .name :
2732- check_pattern_var (p .name )
2733- break
2734-
2735- if star_idx == - 1 :
2736- return self ._match_fixed_sequence (
2737- pattern .patterns , value , pattern_vars
2738- )
2739- else :
2740- return self ._match_star_sequence (
2741- pattern .patterns , value , star_idx , pattern_vars
2742- )
2666+ return self ._match_sequence_pattern (pattern , value , pattern_vars )
27432667 elif isinstance (pattern , ast .MatchMapping ):
2744- if not isinstance (value , dict ):
2745- return False
2746-
2747- # Check if all required keys are present
2748- for key in pattern .keys :
2749- key_value = self .visit (key )
2750- if key_value not in value :
2751- return False
2752-
2753- # Match each key-pattern pair
2754- for key , pat in zip (pattern .keys , pattern .patterns ):
2755- key_value = self .visit (key )
2756- if not self ._match_pattern (pat , value [key_value ], pattern_vars ):
2757- return False
2758-
2759- # Handle rest pattern if present
2760- if pattern .rest is not None :
2761- check_pattern_var (pattern .rest )
2762- rest_dict = {
2763- k : v
2764- for k , v in value .items ()
2765- if not any (self .visit (key ) == k for key in pattern .keys )
2766- }
2767- self ._set_name_value (pattern .rest , rest_dict )
2768- self .current_scope .locals .add (pattern .rest )
2769-
2770- return True
2668+ return self ._match_mapping_pattern (pattern , value , pattern_vars )
27712669 elif isinstance (pattern , ast .MatchStar ):
2772- if pattern .name :
2773- check_pattern_var (pattern .name )
2774- return True
2670+ return self ._match_star_pattern (pattern , value )
27752671 elif isinstance (pattern , ast .MatchAs ):
2776- if pattern .pattern is not None :
2777- if not self ._match_pattern (
2778- pattern .pattern , value , pattern_vars
2779- ):
2780- return False
2781- if pattern .name is not None :
2782- check_pattern_var (pattern .name )
2783- self ._set_name_value (pattern .name , value )
2784- self .current_scope .locals .add (pattern .name )
2785- return True
2672+ return self ._match_as_pattern (pattern , value , pattern_vars )
27862673 elif isinstance (pattern , ast .MatchOr ):
27872674 # Try each alternative pattern
27882675 matched_vars = None
@@ -2799,34 +2686,50 @@ def check_pattern_var(name: str) -> None:
27992686 return True
28002687 return False
28012688 elif isinstance (pattern , ast .MatchClass ):
2802- cls = self .visit (pattern .cls )
2803- if not isinstance (value , cls ):
2804- return False
2689+ return self ._match_class_pattern (pattern , value , pattern_vars )
28052690
2806- # Get positional attributes from __match_args__
2807- match_args = getattr (cls , "__match_args__" , ())
2808- if len (pattern .patterns ) > len (match_args ):
2809- return False
2691+ return False
28102692
2811- # Match positional patterns
2812- for pat , attr_name in zip (pattern .patterns , match_args ):
2813- if not self ._match_pattern (
2814- pat , getattr (value , attr_name ), pattern_vars
2815- ):
2816- return False
2693+ def _match_class_pattern (
2694+ self ,
2695+ pattern : ast .MatchClass ,
2696+ value : t .Any ,
2697+ pattern_vars : set [str ],
2698+ ) -> bool :
2699+ """Match a class pattern.
28172700
2818- # Match keyword patterns
2819- for name , pat in zip (pattern .kwd_attrs , pattern .kwd_patterns ):
2820- if not hasattr (value , name ):
2821- return False
2822- if not self ._match_pattern (
2823- pat , getattr (value , name ), pattern_vars
2824- ):
2825- return False
2701+ Args:
2702+ pattern: MatchClass AST node
2703+ value: Value to match against
2704+ pattern_vars: Set of pattern variables already bound in this pattern
28262705
2827- return True
2706+ Returns:
2707+ Whether the pattern matches the value
2708+ """
2709+ cls = self .visit (pattern .cls )
2710+ if not isinstance (value , cls ):
2711+ return False
28282712
2829- return False
2713+ # Get positional attributes from __match_args__
2714+ match_args = getattr (cls , "__match_args__" , ())
2715+ if len (pattern .patterns ) > len (match_args ):
2716+ return False
2717+
2718+ # Match positional patterns
2719+ for pat , attr_name in zip (pattern .patterns , match_args ):
2720+ if not self ._match_pattern (
2721+ pat , getattr (value , attr_name ), pattern_vars
2722+ ):
2723+ return False
2724+
2725+ # Match keyword patterns
2726+ for name , pat in zip (pattern .kwd_attrs , pattern .kwd_patterns ):
2727+ if not hasattr (value , name ):
2728+ return False
2729+ if not self ._match_pattern (pat , getattr (value , name ), pattern_vars ):
2730+ return False
2731+
2732+ return True
28302733
28312734 def visit_Match (self , node : ast .Match ) -> None :
28322735 """Handle match-case statements.
0 commit comments