Skip to content

Commit badd083

Browse files
committed
Add 'singlecollect' distribution mode
This adds a new 'singlecollect' distribution mode that only collects tests on the first worker node and skips redundant collection on other nodes. This can significantly improve startup time for large test suites with expensive collection. Key features: - Only the first worker performs test collection - Other workers skip collection verification entirely - Tests are distributed using the same algorithm as 'load' mode - Handles worker failures gracefully, including the collecting worker - Solves issues with floating parameters in pytest collection
1 parent 3508f8c commit badd083

File tree

7 files changed

+415
-0
lines changed

7 files changed

+415
-0
lines changed

changelog/1180.feature.rst

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added new 'singlecollect' distribution mode that only collects tests once on a single worker and skips collection verification on other workers. This can significantly improve startup time for test suites with expensive collection.

src/xdist/dsession.py

+3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from xdist.scheduler import LoadScheduling
2121
from xdist.scheduler import LoadScopeScheduling
2222
from xdist.scheduler import Scheduling
23+
from xdist.scheduler import SingleCollectScheduling
2324
from xdist.scheduler import WorkStealingScheduling
2425
from xdist.workermanage import NodeManager
2526
from xdist.workermanage import WorkerController
@@ -123,6 +124,8 @@ def pytest_xdist_make_scheduler(
123124
return LoadGroupScheduling(config, log)
124125
if dist == "worksteal":
125126
return WorkStealingScheduling(config, log)
127+
if dist == "singlecollect":
128+
return SingleCollectScheduling(config, log)
126129
return None
127130

128131
@pytest.hookimpl

src/xdist/plugin.py

+2
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def pytest_addoption(parser: pytest.Parser) -> None:
108108
"loadfile",
109109
"loadgroup",
110110
"worksteal",
111+
"singlecollect",
111112
"no",
112113
],
113114
dest="dist",
@@ -124,6 +125,7 @@ def pytest_addoption(parser: pytest.Parser) -> None:
124125
"loadgroup: Like 'load', but sends tests marked with 'xdist_group' to the same worker.\n\n"
125126
"worksteal: Split the test suite between available environments,"
126127
" then re-balance when any worker runs out of tests.\n\n"
128+
"singlecollect: Only collect tests once on a single worker and skip collection verification.\n\n"
127129
"(default) no: Run tests inprocess, don't distribute."
128130
),
129131
)

src/xdist/scheduler/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@
44
from xdist.scheduler.loadgroup import LoadGroupScheduling as LoadGroupScheduling
55
from xdist.scheduler.loadscope import LoadScopeScheduling as LoadScopeScheduling
66
from xdist.scheduler.protocol import Scheduling as Scheduling
7+
from xdist.scheduler.singlecollect import SingleCollectScheduling as SingleCollectScheduling
78
from xdist.scheduler.worksteal import WorkStealingScheduling as WorkStealingScheduling

src/xdist/scheduler/singlecollect.py

+211
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Sequence
4+
from itertools import cycle
5+
6+
import pytest
7+
8+
from xdist.remote import Producer
9+
from xdist.workermanage import parse_tx_spec_config
10+
from xdist.workermanage import WorkerController
11+
12+
13+
class SingleCollectScheduling:
14+
"""Implement scheduling with a single test collection phase.
15+
16+
This differs from LoadScheduling by:
17+
1. Only collecting tests on the first node
18+
2. Skipping collection on other nodes
19+
3. Not checking for collection equality
20+
21+
This can significantly improve startup time by avoiding redundant collection
22+
and collection verification across multiple worker processes.
23+
"""
24+
25+
def __init__(self, config: pytest.Config, log: Producer | None = None) -> None:
26+
self.numnodes = len(parse_tx_spec_config(config))
27+
self.node2pending: dict[WorkerController, list[int]] = {}
28+
self.pending: list[int] = []
29+
self.collection: list[str] | None = None
30+
self.first_node: WorkerController | None = None
31+
if log is None:
32+
self.log = Producer("singlecollect")
33+
else:
34+
self.log = log.singlecollect
35+
self.config = config
36+
self.maxschedchunk = self.config.getoption("maxschedchunk")
37+
self.collection_done = False
38+
39+
@property
40+
def nodes(self) -> list[WorkerController]:
41+
"""A list of all nodes in the scheduler."""
42+
return list(self.node2pending.keys())
43+
44+
@property
45+
def collection_is_completed(self) -> bool:
46+
"""Return True once we have collected tests from the first node."""
47+
return self.collection_done
48+
49+
@property
50+
def tests_finished(self) -> bool:
51+
"""Return True if all tests have been executed by the nodes."""
52+
if not self.collection_is_completed:
53+
return False
54+
if self.pending:
55+
return False
56+
for pending in self.node2pending.values():
57+
if len(pending) >= 2:
58+
return False
59+
return True
60+
61+
@property
62+
def has_pending(self) -> bool:
63+
"""Return True if there are pending test items."""
64+
if self.pending:
65+
return True
66+
for pending in self.node2pending.values():
67+
if pending:
68+
return True
69+
return False
70+
71+
def add_node(self, node: WorkerController) -> None:
72+
"""Add a new node to the scheduler."""
73+
assert node not in self.node2pending
74+
self.node2pending[node] = []
75+
76+
# Remember the first node as our collector
77+
if self.first_node is None:
78+
self.first_node = node
79+
self.log(f"Using {node.gateway.id} as collection node")
80+
81+
def add_node_collection(
82+
self, node: WorkerController, collection: Sequence[str]
83+
) -> None:
84+
"""Only use collection from the first node."""
85+
# We only care about collection from the first node
86+
if node == self.first_node:
87+
self.log(f"Received collection from first node {node.gateway.id}")
88+
self.collection = list(collection)
89+
self.collection_done = True
90+
else:
91+
# Skip collection verification for other nodes
92+
self.log(f"Ignoring collection from node {node.gateway.id}")
93+
94+
def mark_test_complete(
95+
self, node: WorkerController, item_index: int, duration: float = 0
96+
) -> None:
97+
"""Mark test item as completed by node."""
98+
self.node2pending[node].remove(item_index)
99+
self.check_schedule(node, duration=duration)
100+
101+
def mark_test_pending(self, item: str) -> None:
102+
assert self.collection is not None
103+
self.pending.insert(
104+
0,
105+
self.collection.index(item),
106+
)
107+
for node in self.node2pending:
108+
self.check_schedule(node)
109+
110+
def remove_pending_tests_from_node(
111+
self,
112+
node: WorkerController,
113+
indices: Sequence[int],
114+
) -> None:
115+
raise NotImplementedError()
116+
117+
def check_schedule(self, node: WorkerController, duration: float = 0) -> None:
118+
"""Maybe schedule new items on the node."""
119+
if node.shutting_down:
120+
return
121+
122+
if self.pending:
123+
# how many nodes do we have?
124+
num_nodes = len(self.node2pending)
125+
# if our node goes below a heuristic minimum, fill it out to
126+
# heuristic maximum
127+
items_per_node_min = max(2, len(self.pending) // num_nodes // 4)
128+
items_per_node_max = max(2, len(self.pending) // num_nodes // 2)
129+
node_pending = self.node2pending[node]
130+
if len(node_pending) < items_per_node_min:
131+
if duration >= 0.1 and len(node_pending) >= 2:
132+
# seems the node is doing long-running tests
133+
# and has enough items to continue
134+
# so let's rather wait with sending new items
135+
return
136+
num_send = items_per_node_max - len(node_pending)
137+
# keep at least 2 tests pending even if --maxschedchunk=1
138+
maxschedchunk = max(2 - len(node_pending), self.maxschedchunk)
139+
self._send_tests(node, min(num_send, maxschedchunk))
140+
else:
141+
node.shutdown()
142+
143+
self.log("num items waiting for node:", len(self.pending))
144+
145+
def remove_node(self, node: WorkerController) -> str | None:
146+
"""Remove a node from the scheduler."""
147+
pending = self.node2pending.pop(node)
148+
149+
# If this is the first node (collector), reset it
150+
if node == self.first_node:
151+
self.first_node = None
152+
153+
if not pending:
154+
return None
155+
156+
# Reassign pending items if the node had any
157+
assert self.collection is not None
158+
crashitem = self.collection[pending.pop(0)]
159+
self.pending.extend(pending)
160+
for node in self.node2pending:
161+
self.check_schedule(node)
162+
return crashitem
163+
164+
def schedule(self) -> None:
165+
"""Initiate distribution of the test collection."""
166+
assert self.collection_is_completed
167+
168+
# Initial distribution already happened, reschedule on all nodes
169+
if self.pending:
170+
for node in self.nodes:
171+
self.check_schedule(node)
172+
return
173+
174+
# Initialize the index of pending items
175+
assert self.collection is not None
176+
self.pending[:] = range(len(self.collection))
177+
if not self.collection:
178+
return
179+
180+
if self.maxschedchunk is None:
181+
self.maxschedchunk = len(self.collection)
182+
183+
# Send a batch of tests to run. If we don't have at least two
184+
# tests per node, we have to send them all so that we can send
185+
# shutdown signals and get all nodes working.
186+
if len(self.pending) < 2 * len(self.nodes):
187+
# Distribute tests round-robin
188+
nodes = cycle(self.nodes)
189+
for _ in range(len(self.pending)):
190+
self._send_tests(next(nodes), 1)
191+
else:
192+
# how many items per node do we have about?
193+
items_per_node = len(self.collection) // len(self.node2pending)
194+
# take a fraction of tests for initial distribution
195+
node_chunksize = min(items_per_node // 4, self.maxschedchunk)
196+
node_chunksize = max(node_chunksize, 2)
197+
# and initialize each node with a chunk of tests
198+
for node in self.nodes:
199+
self._send_tests(node, node_chunksize)
200+
201+
if not self.pending:
202+
# initial distribution sent all tests, start node shutdown
203+
for node in self.nodes:
204+
node.shutdown()
205+
206+
def _send_tests(self, node: WorkerController, num: int) -> None:
207+
tests_per_node = self.pending[:num]
208+
if tests_per_node:
209+
del self.pending[:num]
210+
self.node2pending[node].extend(tests_per_node)
211+
node.send_runtest_some(tests_per_node)

testing/acceptance_test.py

+64
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,70 @@
1010
import xdist
1111

1212

13+
class TestSingleCollectScheduling:
14+
def test_singlecollect_mode(self, pytester: pytest.Pytester) -> None:
15+
"""Test that the singlecollect distribution mode works."""
16+
# Create a simple test file
17+
p1 = pytester.makepyfile(
18+
"""
19+
def test_ok():
20+
pass
21+
"""
22+
)
23+
result = pytester.runpytest(p1, "-n2", "--dist=singlecollect", "-v")
24+
assert result.ret == 0
25+
result.stdout.fnmatch_lines(["*1 passed*"])
26+
# Make sure the tests are correctly distributed
27+
result.stdout.fnmatch_lines(["*scheduling tests via SingleCollectScheduling*"])
28+
29+
def test_singlecollect_many_tests(self, pytester: pytest.Pytester) -> None:
30+
"""Test that the singlecollect mode correctly distributes many tests."""
31+
# Create test file with multiple tests
32+
p1 = pytester.makepyfile(
33+
"""
34+
import pytest
35+
@pytest.mark.parametrize("x", range(10))
36+
def test_ok(x):
37+
assert True
38+
"""
39+
)
40+
result = pytester.runpytest(p1, "-n2", "--dist=singlecollect", "-v")
41+
assert result.ret == 0
42+
result.stdout.fnmatch_lines(["*passed*"])
43+
# Make sure the tests are correctly distributed
44+
result.stdout.fnmatch_lines(["*scheduling tests via SingleCollectScheduling*"])
45+
46+
def test_singlecollect_failure(self, pytester: pytest.Pytester) -> None:
47+
"""Test that failures are correctly reported with singlecollect mode."""
48+
p1 = pytester.makepyfile(
49+
"""
50+
def test_fail():
51+
assert 0
52+
"""
53+
)
54+
result = pytester.runpytest(p1, "-n2", "--dist=singlecollect", "-v")
55+
assert result.ret == 1
56+
result.stdout.fnmatch_lines(["*1 failed*"])
57+
58+
def test_singlecollect_handles_fixtures(self, pytester: pytest.Pytester) -> None:
59+
"""Test that fixtures work correctly with singlecollect mode."""
60+
pytester.makepyfile(
61+
"""
62+
import pytest
63+
64+
@pytest.fixture
65+
def my_fixture():
66+
return 42
67+
68+
def test_with_fixture(my_fixture):
69+
assert my_fixture == 42
70+
"""
71+
)
72+
result = pytester.runpytest("-n2", "--dist=singlecollect", "-v")
73+
assert result.ret == 0
74+
result.stdout.fnmatch_lines(["*1 passed*"])
75+
76+
1377
class TestDistribution:
1478
def test_n1_pass(self, pytester: pytest.Pytester) -> None:
1579
p1 = pytester.makepyfile(

0 commit comments

Comments
 (0)