Skip to content

Commit 513c93d

Browse files
authored
Merge pull request #8 from wasi-master/copilot/fix-3
Fix NameError when timing code with global variable conflicts
2 parents b1fd2db + 2d54fbe commit 513c93d

1 file changed

Lines changed: 115 additions & 1 deletion

File tree

fastero/utils.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,4 +366,118 @@ def make_bar_plot(labels, amounts, ascii_only = False) -> Panel:
366366
class _Timer(timeit.Timer):
367367
def __init__(self, *args, **kwargs):
368368
self.stmt = kwargs.get('stmt')
369-
super().__init__(*args, **kwargs)
369+
# Store setup for our custom handling
370+
self.setup_code = kwargs.get('setup', 'pass')
371+
super().__init__(*args, **kwargs)
372+
373+
def _extract_globals_and_assignments(self, code):
374+
"""Extract global declarations and top-level assignments that might conflict."""
375+
import ast
376+
377+
try:
378+
tree = ast.parse(code)
379+
except SyntaxError:
380+
return set(), {}
381+
382+
globals_vars = set()
383+
assignments = {}
384+
385+
# Look at all nodes, including nested ones for global declarations
386+
for node in ast.walk(tree):
387+
if isinstance(node, ast.Global):
388+
globals_vars.update(node.names)
389+
390+
# Only look at top-level assignments
391+
for node in tree.body:
392+
if isinstance(node, ast.Assign):
393+
for target in node.targets:
394+
if isinstance(target, ast.Name):
395+
var_name = target.id
396+
# Try to get the value if it's a simple constant
397+
if isinstance(node.value, (ast.Constant, ast.Num, ast.Str)):
398+
try:
399+
if hasattr(node.value, 'value'): # ast.Constant
400+
assignments[var_name] = node.value.value
401+
elif hasattr(node.value, 'n'): # ast.Num (older Python)
402+
assignments[var_name] = node.value.n
403+
elif hasattr(node.value, 's'): # ast.Str (older Python)
404+
assignments[var_name] = node.value.s
405+
except:
406+
pass
407+
408+
return globals_vars, assignments
409+
410+
def timeit(self, number=timeit.default_number):
411+
"""Enhanced timeit that handles global variables properly."""
412+
# Check if we have a global/assignment conflict
413+
if self.stmt:
414+
globals_vars, assignments = self._extract_globals_and_assignments(self.stmt)
415+
conflicting_vars = globals_vars & assignments.keys()
416+
417+
if conflicting_vars:
418+
# We have a conflict - need to modify execution
419+
return self._timeit_with_globals(number, conflicting_vars, assignments)
420+
421+
# No conflict, use standard timeit
422+
return super().timeit(number)
423+
424+
def _timeit_with_globals(self, number, conflicting_vars, assignments):
425+
"""Execute timing with proper global variable handling."""
426+
import ast
427+
import types
428+
429+
# Create a modified version of the statement
430+
# Remove top-level assignments for conflicting variables
431+
tree = ast.parse(self.stmt)
432+
433+
# Filter out conflicting assignments from the statement
434+
new_body = []
435+
for node in tree.body:
436+
if isinstance(node, ast.Assign):
437+
# Check if this assigns to any conflicting variable
438+
assigns_conflicting = False
439+
for target in node.targets:
440+
if isinstance(target, ast.Name) and target.id in conflicting_vars:
441+
assigns_conflicting = True
442+
break
443+
if not assigns_conflicting:
444+
new_body.append(node)
445+
else:
446+
new_body.append(node)
447+
448+
tree.body = new_body
449+
modified_stmt = ast.unparse(tree) if new_body else "pass"
450+
451+
# Create a global namespace with the conflicting variables
452+
execution_globals = {}
453+
execution_globals.update(self.inner.__globals__)
454+
455+
# Add the conflicting variables to globals
456+
for var in conflicting_vars:
457+
if var in assignments:
458+
execution_globals[var] = assignments[var]
459+
460+
# Create a new timer with modified statement and proper globals
461+
# Need to properly indent the modified statement for the loop
462+
modified_stmt_lines = modified_stmt.split('\n')
463+
indented_stmt = '\n '.join(modified_stmt_lines)
464+
465+
timer_code = f"""
466+
def inner(_it, _timer):
467+
{self.setup_code}
468+
_t0 = _timer()
469+
for _i in _it:
470+
{indented_stmt}
471+
_t1 = _timer()
472+
return _t1 - _t0
473+
"""
474+
475+
# Execute the timer code in our custom globals
476+
local_vars = {}
477+
exec(timer_code, execution_globals, local_vars)
478+
inner_func = local_vars['inner']
479+
480+
# Time the execution
481+
it = iter(range(number))
482+
timing = inner_func(it, self.timer)
483+
return timing

0 commit comments

Comments
 (0)