Skip to content

Commit 6afdcb1

Browse files
committed
Partial review of task streams
1 parent f431280 commit 6afdcb1

3 files changed

Lines changed: 65 additions & 70 deletions

File tree

distributed/client.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6300,10 +6300,16 @@ def __init__(self, client=None, plot=False, filename="task-stream.html"):
63006300
self._filename = filename
63016301
self.figure = None
63026302
self.client = client or default_client()
6303-
self.client.get_task_stream(start=0, stop=0) # ensure plugin
6303+
self._init = False
63046304

63056305
def __enter__(self):
6306-
self.start = time()
6306+
if not self._init:
6307+
self.client.get_task_stream(start=0, stop=0) # ensure plugin
6308+
self._init = True
6309+
6310+
# Smooth over time differences of client vs. workers
6311+
# FIXME this is very crude. We should query TaskStreamPlugin.index instead.
6312+
self.start = time() - 0.1
63076313
return self
63086314

63096315
def __exit__(self, exc_type, exc_value, traceback):
@@ -6315,6 +6321,13 @@ def __exit__(self, exc_type, exc_value, traceback):
63156321
self.data.extend(L)
63166322

63176323
async def __aenter__(self):
6324+
if not self._init:
6325+
await self.client.get_task_stream(start=0, stop=0) # ensure plugin
6326+
self._init = True
6327+
6328+
# Smooth over time differences of client vs. workers
6329+
# FIXME this is very crude. We should query TaskStreamPlugin.index instead.
6330+
self.start = time() - 0.1
63186331
return self
63196332

63206333
async def __aexit__(self, exc_type, exc_value, traceback):

distributed/diagnostics/task_stream.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -31,31 +31,24 @@ def __init__(self, scheduler, maxlen=None):
3131
self.index = 0
3232

3333
def transition(self, key, start, finish, *args, **kwargs):
34-
if start == "processing":
35-
if key not in self.scheduler.tasks:
36-
return
37-
if not kwargs.get("startstops"):
38-
# Other methods require `kwargs` to have a non-empty list of `startstops`
39-
return
34+
if start == "processing" and finish in ("memory", "erred"):
35+
assert kwargs["startstops"]
4036
kwargs["key"] = key
41-
if finish == "memory" or finish == "erred":
42-
self.buffer.append(kwargs)
43-
self.index += 1
37+
self.buffer.append(kwargs)
38+
self.index += 1
4439

4540
def collect(self, start=None, stop=None, count=None):
4641
def bisect(target, left, right):
47-
if left == right:
48-
return left
49-
50-
mid = (left + right) // 2
51-
value = max(
52-
startstop["stop"] for startstop in self.buffer[mid]["startstops"]
53-
)
54-
55-
if value < target:
56-
return bisect(target, mid + 1, right)
57-
else:
58-
return bisect(target, left, mid)
42+
while left != right:
43+
mid = (left + right) // 2
44+
stop = max(
45+
startstop["stop"] for startstop in self.buffer[mid]["startstops"]
46+
)
47+
if stop < target:
48+
left = mid + 1
49+
else:
50+
right = mid
51+
return left
5952

6053
if isinstance(start, str):
6154
start = time() - parse_timedelta(start)

distributed/diagnostics/tests/test_task_stream.py

Lines changed: 36 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
from __future__ import annotations
22

3-
import os
4-
from time import sleep
5-
63
import pytest
74
from tlz import frequencies
85

@@ -85,50 +82,36 @@ async def test_collect(c, s, a, b):
8582
assert tasks.collect(start=start, count=3) == list(tasks.buffer)[:3]
8683

8784

88-
@gen_cluster(client=True)
89-
async def test_no_startstops(c, s, a, b):
90-
tasks = TaskStreamPlugin(s)
91-
s.add_plugin(tasks)
92-
# just to create the key on the scheduler
93-
future = c.submit(inc, 1)
94-
await wait(future)
95-
assert len(tasks.buffer) == 1
96-
97-
tasks.transition(future.key, "processing", "erred", stimulus_id="s1")
98-
# Transition was not recorded because it didn't contain `startstops`
99-
assert len(tasks.buffer) == 1
100-
101-
tasks.transition(future.key, "processing", "erred", stimulus_id="s2", startstops=[])
102-
# Transition was not recorded because `startstops` was empty
103-
assert len(tasks.buffer) == 1
104-
105-
tasks.transition(
106-
future.key,
107-
"processing",
108-
"erred",
109-
stimulus_id="s3",
110-
startstops=[dict(start=time(), stop=time())],
111-
)
112-
assert len(tasks.buffer) == 2
113-
114-
11585
@gen_cluster(client=True)
11686
async def test_client(c, s, a, b):
117-
L = await c.get_task_stream()
118-
assert L == ()
87+
await c.get_task_stream()
11988

120-
futures = c.map(slowinc, range(10), delay=0.1)
89+
futures = c.map(inc, range(10))
12190
await wait(futures)
122-
123-
tasks = s.plugins[TaskStreamPlugin.name]
124-
L = await c.get_task_stream()
125-
assert L == tuple(tasks.buffer)
91+
data = await c.get_task_stream()
92+
assert len(data) == 10
12693

12794

12895
def test_client_sync(client):
129-
with get_task_stream(client=client) as ts:
130-
sleep(0.1) # to smooth over time differences on the scheduler
131-
# to smooth over time differences on the scheduler
96+
client.get_task_stream()
97+
98+
futures = client.map(inc, range(10))
99+
wait(futures)
100+
data = client.get_task_stream()
101+
assert len(data) == 10
102+
103+
104+
@gen_cluster(client=True)
105+
async def test_client_ctx(c, s, a, b):
106+
async with get_task_stream() as ts:
107+
futures = c.map(inc, range(10))
108+
await wait(futures)
109+
110+
assert len(ts.data) == 10
111+
112+
113+
def test_client_ctx_sync(client):
114+
with get_task_stream() as ts:
132115
futures = client.map(inc, range(10))
133116
wait(futures)
134117

@@ -140,23 +123,29 @@ async def test_get_task_stream_plot(c, s, a, b):
140123
bkm = pytest.importorskip("bokeh.models")
141124
await c.get_task_stream()
142125

143-
futures = c.map(slowinc, range(10), delay=0.1)
126+
futures = c.map(inc, range(10))
144127
await wait(futures)
145128

146129
data, figure = await c.get_task_stream(plot=True)
130+
assert len(data) == 10
147131
assert isinstance(figure, bkm.Plot)
148132

149133

150-
def test_get_task_stream_save(client, tmp_path):
134+
@gen_cluster(client=True)
135+
async def test_get_task_stream_save(c, s, a, b, tmp_path):
151136
bkm = pytest.importorskip("bokeh.models")
152-
tmpdir = str(tmp_path)
153-
fn = os.path.join(tmpdir, "foo.html")
137+
await c.get_task_stream()
138+
139+
futures = c.map(inc, range(10))
140+
await wait(futures)
141+
142+
fn = str(tmp_path / "foo.html")
143+
data, figure = await c.get_task_stream(plot="save", filename=fn)
144+
assert len(data) == 10
154145

155-
with get_task_stream(plot="save", filename=fn) as ts:
156-
wait(client.map(inc, range(10)))
157146
with open(fn) as f:
158147
data = f.read()
159148
assert "inc" in data
160149
assert "bokeh" in data
161150

162-
assert isinstance(ts.figure, bkm.Plot)
151+
assert isinstance(figure, bkm.Plot)

0 commit comments

Comments
 (0)