@@ -366,4 +366,118 @@ def make_bar_plot(labels, amounts, ascii_only = False) -> Panel:
366366class _Timer (timeit .Timer ):
367367 def __init__ (self , * args , ** kwargs ):
368368 self .stmt = kwargs .get ('stmt' )
369- super ().__init__ (* args , ** kwargs )
369+ # Store setup for our custom handling
370+ self .setup_code = kwargs .get ('setup' , 'pass' )
371+ super ().__init__ (* args , ** kwargs )
372+
373+ def _extract_globals_and_assignments (self , code ):
374+ """Extract global declarations and top-level assignments that might conflict."""
375+ import ast
376+
377+ try :
378+ tree = ast .parse (code )
379+ except SyntaxError :
380+ return set (), {}
381+
382+ globals_vars = set ()
383+ assignments = {}
384+
385+ # Look at all nodes, including nested ones for global declarations
386+ for node in ast .walk (tree ):
387+ if isinstance (node , ast .Global ):
388+ globals_vars .update (node .names )
389+
390+ # Only look at top-level assignments
391+ for node in tree .body :
392+ if isinstance (node , ast .Assign ):
393+ for target in node .targets :
394+ if isinstance (target , ast .Name ):
395+ var_name = target .id
396+ # Try to get the value if it's a simple constant
397+ if isinstance (node .value , (ast .Constant , ast .Num , ast .Str )):
398+ try :
399+ if hasattr (node .value , 'value' ): # ast.Constant
400+ assignments [var_name ] = node .value .value
401+ elif hasattr (node .value , 'n' ): # ast.Num (older Python)
402+ assignments [var_name ] = node .value .n
403+ elif hasattr (node .value , 's' ): # ast.Str (older Python)
404+ assignments [var_name ] = node .value .s
405+ except :
406+ pass
407+
408+ return globals_vars , assignments
409+
410+ def timeit (self , number = timeit .default_number ):
411+ """Enhanced timeit that handles global variables properly."""
412+ # Check if we have a global/assignment conflict
413+ if self .stmt :
414+ globals_vars , assignments = self ._extract_globals_and_assignments (self .stmt )
415+ conflicting_vars = globals_vars & assignments .keys ()
416+
417+ if conflicting_vars :
418+ # We have a conflict - need to modify execution
419+ return self ._timeit_with_globals (number , conflicting_vars , assignments )
420+
421+ # No conflict, use standard timeit
422+ return super ().timeit (number )
423+
424+ def _timeit_with_globals (self , number , conflicting_vars , assignments ):
425+ """Execute timing with proper global variable handling."""
426+ import ast
427+ import types
428+
429+ # Create a modified version of the statement
430+ # Remove top-level assignments for conflicting variables
431+ tree = ast .parse (self .stmt )
432+
433+ # Filter out conflicting assignments from the statement
434+ new_body = []
435+ for node in tree .body :
436+ if isinstance (node , ast .Assign ):
437+ # Check if this assigns to any conflicting variable
438+ assigns_conflicting = False
439+ for target in node .targets :
440+ if isinstance (target , ast .Name ) and target .id in conflicting_vars :
441+ assigns_conflicting = True
442+ break
443+ if not assigns_conflicting :
444+ new_body .append (node )
445+ else :
446+ new_body .append (node )
447+
448+ tree .body = new_body
449+ modified_stmt = ast .unparse (tree ) if new_body else "pass"
450+
451+ # Create a global namespace with the conflicting variables
452+ execution_globals = {}
453+ execution_globals .update (self .inner .__globals__ )
454+
455+ # Add the conflicting variables to globals
456+ for var in conflicting_vars :
457+ if var in assignments :
458+ execution_globals [var ] = assignments [var ]
459+
460+ # Create a new timer with modified statement and proper globals
461+ # Need to properly indent the modified statement for the loop
462+ modified_stmt_lines = modified_stmt .split ('\n ' )
463+ indented_stmt = '\n ' .join (modified_stmt_lines )
464+
465+ timer_code = f"""
466+ def inner(_it, _timer):
467+ { self .setup_code }
468+ _t0 = _timer()
469+ for _i in _it:
470+ { indented_stmt }
471+ _t1 = _timer()
472+ return _t1 - _t0
473+ """
474+
475+ # Execute the timer code in our custom globals
476+ local_vars = {}
477+ exec (timer_code , execution_globals , local_vars )
478+ inner_func = local_vars ['inner' ]
479+
480+ # Time the execution
481+ it = iter (range (number ))
482+ timing = inner_func (it , self .timer )
483+ return timing
0 commit comments