Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 80 additions & 6 deletions scripts/generate_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,85 @@ def visit_Call(self, node: ast.Call) -> ast.AST: # noqa: N802
and node.func.value.id == "asyncio"
and node.func.attr == "gather"
):
# Wrap each arg in a lambda for executor submission
executor_func = ast.Attribute(
value=ast.Name(id="self", ctx=ast.Load()),
attr="_run_in_executor",
ctx=ast.Load(),
)

# Handle starred arg: asyncio.gather(*[expr for x in iter])
# -> self._run_in_executor(*[lambda x=x: expr for x in iter])
if len(node.args) == 1 and isinstance(node.args[0], ast.Starred):
starred = node.args[0]
if isinstance(starred.value, ast.ListComp):
comp = starred.value
# Build lambda with default-arg capture from comprehension vars
# to avoid late-binding closure bug.
# [expr for x in iter] -> [lambda x=x: expr for x in iter]
gen = comp.generators[0]
if isinstance(gen.target, ast.Name):
capture_name = gen.target.id
lambda_node = ast.Lambda(
args=ast.arguments(
posonlyargs=[],
args=[ast.arg(arg=capture_name)],
kwonlyargs=[],
kw_defaults=[],
defaults=[ast.Name(id=capture_name, ctx=ast.Load())],
),
body=comp.elt,
)
new_comp = ast.ListComp(
elt=lambda_node,
generators=comp.generators,
)
return ast.copy_location(
ast.Call(
func=executor_func,
args=[ast.Starred(value=new_comp, ctx=ast.Load())],
keywords=[],
),
node,
)
# Fall through to generic starred handling
# Generic: asyncio.gather(*tasks)
# -> self._run_in_executor(*[lambda fn=fn: fn() for fn in tasks])
iter_name = "fn"
new_comp = ast.ListComp(
elt=ast.Lambda(
args=ast.arguments(
posonlyargs=[],
args=[ast.arg(arg=iter_name)],
kwonlyargs=[],
kw_defaults=[],
defaults=[ast.Name(id=iter_name, ctx=ast.Load())],
),
body=ast.Call(
func=ast.Name(id=iter_name, ctx=ast.Load()),
args=[],
keywords=[],
),
),
generators=[
ast.comprehension(
target=ast.Name(id=iter_name, ctx=ast.Store()),
iter=starred.value,
ifs=[],
is_async=0,
)
],
)
return ast.copy_location(
ast.Call(
func=executor_func,
args=[ast.Starred(value=new_comp, ctx=ast.Load())],
keywords=[],
),
node,
)

# Fixed positional args: asyncio.gather(a, b)
# -> self._run_in_executor(lambda: a, lambda: b)
lambda_args = [
ast.Lambda(
args=ast.arguments(
Expand All @@ -406,11 +484,7 @@ def visit_Call(self, node: ast.Call) -> ast.AST: # noqa: N802
]
return ast.copy_location(
ast.Call(
func=ast.Attribute(
value=ast.Name(id="self", ctx=ast.Load()),
attr="_run_in_executor",
ctx=ast.Load(),
),
func=executor_func,
args=lambda_args,
keywords=[],
),
Expand Down
Loading