|
38 | 38 | from distributed.utils_test import ( |
39 | 39 | NO_AMM, |
40 | 40 | BlockedGetData, |
| 41 | + async_poll_for, |
41 | 42 | captured_logger, |
42 | 43 | freeze_batched_send, |
43 | 44 | gen_cluster, |
@@ -1877,3 +1878,54 @@ async def test_trivial_workload_should_not_cause_work_stealing(c, s, *workers): |
1877 | 1878 | await c.gather(futs) |
1878 | 1879 | events = s.get_events("stealing") |
1879 | 1880 | assert len(events) == 0 |
| 1881 | + |
| 1882 | + |
| 1883 | +@gen_cluster( |
| 1884 | + nthreads=[("", 1)], |
| 1885 | + client=True, |
| 1886 | + config={"distributed.scheduler.worker-saturation": "inf"}, |
| 1887 | +) |
| 1888 | +async def test_stealing_objective_accounts_for_in_flight(c, s, a): |
| 1889 | + """Regression test that work-stealing's objective correctly accounts for in-flight data requests |
| 1890 | + """ |
| 1891 | + in_event = Event() |
| 1892 | + block_event = Event() |
| 1893 | + |
| 1894 | + def block(i: int, in_event: Event, block_event: Event) -> int: |
| 1895 | + in_event.set() |
| 1896 | + block_event.wait() |
| 1897 | + return i |
| 1898 | + |
| 1899 | + # Stop stealing for deterministic testing |
| 1900 | + extension = s.extensions["stealing"] |
| 1901 | + await extension.stop() |
| 1902 | + |
| 1903 | + try: |
| 1904 | + futs = c.map(block, range(20), in_event=in_event, block_event=block_event) |
| 1905 | + await in_event.wait() |
| 1906 | + |
| 1907 | + async with Worker(s.address, nthreads=1) as b: |
| 1908 | + try: |
| 1909 | + await async_poll_for(lambda: s.idle, timeout=5) |
| 1910 | + wsA = s.workers[a.address] |
| 1911 | + wsB = s.workers[b.address] |
| 1912 | + ts = next(iter(wsA.processing)) |
| 1913 | + |
| 1914 | + # No in-flight requests, so both match |
| 1915 | + assert extension.stealing_objective(ts, wsA) == s.worker_objective(ts, wsA) |
| 1916 | + assert extension.stealing_objective(ts, wsB) == s.worker_objective(ts, wsB) |
| 1917 | + |
| 1918 | + extension.balance() |
| 1919 | + assert extension.in_flight |
| 1920 | + # We move tasks from a to b |
| 1921 | + assert extension.stealing_objective(ts, wsA) < s.worker_objective(ts, wsA) |
| 1922 | + assert extension.stealing_objective(ts, wsB) > s.worker_objective(ts, wsB) |
| 1923 | + |
| 1924 | + await async_poll_for(lambda: not extension.in_flight, timeout=5) |
| 1925 | + # No in-flight requests, so both match |
| 1926 | + assert extension.stealing_objective(ts, wsA) == s.worker_objective(ts, wsA) |
| 1927 | + assert extension.stealing_objective(ts, wsB) == s.worker_objective(ts, wsB) |
| 1928 | + finally: |
| 1929 | + await block_event.set() |
| 1930 | + finally: |
| 1931 | + await block_event.set() |
0 commit comments