Skip to content

Commit 19d5dac

Browse files
authored
make all manager methods async (#273)
* make all ClusterManager methods async allows easier, more consistent override - defer initialization to `await manager` instead of immediately (avoids invoking asyncio at import time) - instantiate default manager as part of extension loading, not at import time * fix dask_cluster_manager key
1 parent edac755 commit 19d5dac

File tree

5 files changed

+77
-46
lines changed

5 files changed

+77
-46
lines changed

dask_labextension/__init__.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
from jupyter_server.utils import url_path_join
44

5-
from . import config
5+
from . import config # noqa
66
from .clusterhandler import DaskClusterHandler
77
from .dashboardhandler import DaskDashboardCheckHandler, DaskDashboardHandler
8+
from .manager import DaskClusterManager
89

910

10-
from ._version import __version__
11+
from ._version import __version__ # noqa
1112

1213

1314
def _jupyter_labextension_paths():
@@ -33,6 +34,7 @@ def load_jupyter_server_extension(nb_server_app):
3334
cluster_id_regex = r"(?P<cluster_id>[^/]+)"
3435
web_app = nb_server_app.web_app
3536
base_url = web_app.settings["base_url"]
37+
web_app.settings["dask_cluster_manager"] = DaskClusterManager()
3638
get_cluster_path = url_path_join(base_url, "dask/clusters/" + cluster_id_regex)
3739
list_clusters_path = url_path_join(base_url, "dask/clusters/" + "?")
3840
get_dashboard_path = url_path_join(

dask_labextension/clusterhandler.py

+18-8
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,34 @@
44
# Distributed under the terms of the Modified BSD License.
55

66
import json
7+
from inspect import isawaitable
78

89
from tornado import web
910
from jupyter_server.base.handlers import APIHandler
1011

11-
from .manager import manager
12+
from .manager import DaskClusterManager
1213

1314

1415
class DaskClusterHandler(APIHandler):
1516
"""
1617
A tornado HTTP handler for managing dask clusters.
1718
"""
1819

20+
manager: DaskClusterManager
21+
22+
async def prepare(self):
23+
r = super().prepare()
24+
if isawaitable(r):
25+
await r
26+
self.manager = await self.settings["dask_cluster_manager"]
27+
1928
@web.authenticated
2029
async def delete(self, cluster_id: str) -> None:
2130
"""
2231
Delete a cluster by id.
2332
"""
2433
try: # to delete the cluster.
25-
val = await manager.close_cluster(cluster_id)
34+
val = await self.manager.close_cluster(cluster_id)
2635
if val is None:
2736
raise web.HTTPError(404, f"Dask cluster {cluster_id} not found")
2837

@@ -37,12 +46,13 @@ async def get(self, cluster_id: str = "") -> None:
3746
"""
3847
Get a cluster by id. If no id is given, lists known clusters.
3948
"""
49+
manager = self.manager
4050
if cluster_id == "":
41-
cluster_list = manager.list_clusters()
51+
cluster_list = await manager.list_clusters()
4252
self.set_status(200)
4353
self.finish(json.dumps(cluster_list))
4454
else:
45-
cluster_model = manager.get_cluster(cluster_id)
55+
cluster_model = await manager.get_cluster(cluster_id)
4656
if cluster_model is None:
4757
raise web.HTTPError(404, f"Dask cluster {cluster_id} not found")
4858

@@ -55,13 +65,13 @@ async def put(self, cluster_id: str = "") -> None:
5565
Create a new cluster with a given id. If no id is given, a random
5666
one is selected.
5767
"""
58-
if manager.get_cluster(cluster_id):
68+
if await self.manager.get_cluster(cluster_id):
5969
raise web.HTTPError(
6070
403, f"A Dask cluster with ID {cluster_id} already exists!"
6171
)
6272

6373
try:
64-
cluster_model = await manager.start_cluster(cluster_id)
74+
cluster_model = await self.manager.start_cluster(cluster_id)
6575
self.set_status(200)
6676
self.finish(json.dumps(cluster_model))
6777
except Exception as e:
@@ -76,13 +86,13 @@ async def patch(self, cluster_id):
7686
new_model = json.loads(self.request.body)
7787
try:
7888
if new_model.get("adapt") is not None:
79-
cluster_model = manager.adapt_cluster(
89+
cluster_model = await self.manager.adapt_cluster(
8090
cluster_id,
8191
new_model["adapt"]["minimum"],
8292
new_model["adapt"]["maximum"],
8393
)
8494
else:
85-
cluster_model = await manager.scale_cluster(
95+
cluster_model = await self.manager.scale_cluster(
8696
cluster_id, new_model["workers"]
8797
)
8898
self.set_status(200)

dask_labextension/dashboardhandler.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,32 @@
44
server, preventing CORS issues.
55
"""
66
import json
7+
from inspect import isawaitable
78
from urllib import parse
89

910
from tornado import httpclient, web
1011

12+
1113
from jupyter_server.base.handlers import APIHandler
1214
from jupyter_server.utils import url_path_join
1315
from jupyter_server_proxy.handlers import ProxyHandler
1416

15-
from .manager import manager
17+
from .manager import DaskClusterManager
1618

1719

1820
class DaskDashboardCheckHandler(APIHandler):
1921
"""
2022
A handler for checking validity of a dask dashboard.
2123
"""
2224

25+
manager: DaskClusterManager
26+
27+
async def prepare(self):
28+
r = super().prepare()
29+
if isawaitable(r):
30+
await r
31+
self.manager = await self.settings["dask_cluster_manager"]
32+
2333
@web.authenticated
2434
async def get(self, url) -> None:
2535
"""
@@ -133,7 +143,7 @@ async def http_get(self, cluster_id, proxied_path):
133143
return await self.proxy(cluster_id, proxied_path)
134144

135145
async def open(self, cluster_id, proxied_path):
136-
host, port = self._get_parsed(cluster_id)
146+
host, port = await self._get_parsed(cluster_id)
137147
return await super().proxy_open(host, port, proxied_path)
138148

139149
# We have to duplicate all these for now, I've no idea why!
@@ -157,17 +167,17 @@ def patch(self, cluster_id, proxied_path):
157167
def options(self, cluster_id, proxied_path):
158168
return self.proxy(cluster_id, proxied_path)
159169

160-
def proxy(self, cluster_id, proxied_path):
161-
host, port = self._get_parsed(cluster_id)
170+
async def proxy(self, cluster_id, proxied_path):
171+
host, port = await self._get_parsed(cluster_id)
162172
return super().proxy(host, port, proxied_path)
163173

164-
def _get_parsed(self, cluster_id):
174+
async def _get_parsed(self, cluster_id):
165175
"""
166176
Given a cluster ID, get the hostname and port of its bokeh server.
167177
"""
168178
# Get the cluster by ID. If it is not found,
169179
# raise an error.
170-
cluster_model = manager.get_cluster(cluster_id)
180+
cluster_model = await self.manager.get_cluster(cluster_id)
171181
if not cluster_model:
172182
raise web.HTTPError(404, f"Dask cluster {cluster_id} not found")
173183

dask_labextension/manager.py

+26-17
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Copyright (c) Jupyter Development Team.
44
# Distributed under the terms of the Modified BSD License.
55

6+
import asyncio
67
import importlib
78
from inspect import isawaitable
89
from typing import Any, Dict, List, Union
@@ -11,8 +12,6 @@
1112
import dask
1213
from dask.utils import format_bytes
1314
from dask.distributed import Adaptive
14-
from tornado.ioloop import IOLoop
15-
from tornado.concurrent import Future
1615

1716
# A type for a dask cluster model: a serializable
1817
# representation of information about the cluster.
@@ -60,15 +59,28 @@ def __init__(self) -> None:
6059
self._adaptives: Dict[str, Adaptive] = dict()
6160
self._cluster_names: Dict[str, str] = dict()
6261
self._n_clusters = 0
62+
self._initialized = None
6363

64-
self.initialized = Future()
64+
async def _async_init(self):
65+
"""The async part of init
6566
66-
async def start_clusters():
67-
for model in dask.config.get("labextension.initial"):
68-
await self.start_cluster(configuration=model)
69-
self.initialized.set_result(self)
67+
Invoked by `await manager`
68+
"""
69+
for model in dask.config.get("labextension.initial"):
70+
await self.start_cluster(configuration=model)
71+
return self
7072

71-
IOLoop.current().add_callback(start_clusters)
73+
@property
74+
def initialized(self):
75+
"""Don't create initialization task until it's been requested
76+
77+
typically via `await manager`
78+
79+
Makes it easier to ensure we don't do anything before we are in the event loop.
80+
"""
81+
if self._initialized is None:
82+
self._initialized = asyncio.create_task(self._async_init())
83+
return self._initialized
7284

7385
async def start_cluster(
7486
self, cluster_id: str = "", configuration: dict = {}
@@ -121,7 +133,9 @@ async def close_cluster(self, cluster_id: str) -> Union[ClusterModel, None]:
121133
"""
122134
cluster = self._clusters.get(cluster_id)
123135
if cluster:
124-
await cluster.close()
136+
r = cluster.close()
137+
if isawaitable(r):
138+
await r
125139
self._clusters.pop(cluster_id)
126140
name = self._cluster_names.pop(cluster_id)
127141
adaptive = self._adaptives.pop(cluster_id, None)
@@ -130,7 +144,7 @@ async def close_cluster(self, cluster_id: str) -> Union[ClusterModel, None]:
130144
else:
131145
return None
132146

133-
def get_cluster(self, cluster_id) -> Union[ClusterModel, None]:
147+
async def get_cluster(self, cluster_id) -> Union[ClusterModel, None]:
134148
"""
135149
Get a Dask cluster model.
136150
@@ -151,7 +165,7 @@ def get_cluster(self, cluster_id) -> Union[ClusterModel, None]:
151165

152166
return make_cluster_model(cluster_id, name, cluster, adaptive)
153167

154-
def list_clusters(self) -> List[ClusterModel]:
168+
async def list_clusters(self) -> List[ClusterModel]:
155169
"""
156170
List the Dask cluster models known to the manager.
157171
@@ -188,7 +202,7 @@ async def scale_cluster(self, cluster_id: str, n: int) -> Union[ClusterModel, No
188202
await t
189203
return make_cluster_model(cluster_id, name, cluster, adaptive=None)
190204

191-
def adapt_cluster(
205+
async def adapt_cluster(
192206
self, cluster_id: str, minimum: int, maximum: int
193207
) -> Union[ClusterModel, None]:
194208
cluster = self._clusters.get(cluster_id)
@@ -290,8 +304,3 @@ def make_cluster_model(
290304
model["adapt"] = {"minimum": adaptive.minimum, "maximum": adaptive.maximum}
291305

292306
return model
293-
294-
295-
# Create a default cluster manager
296-
# to keep track of clusters.
297-
manager = DaskClusterManager()

dask_labextension/tests/test_manager.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ async def test_start():
3131
assert not model.get("adapt")
3232

3333
# close cluster
34-
assert len(manager.list_clusters()) == 1
34+
assert len(await manager.list_clusters()) == 1
3535
await manager.close_cluster(model["id"])
3636

3737
# add cluster with adaptive configuration
@@ -55,7 +55,7 @@ async def test_close():
5555

5656
# close the cluster
5757
await manager.close_cluster(model["id"])
58-
assert not manager.list_clusters()
58+
assert not await manager.list_clusters()
5959

6060

6161
@gen_test()
@@ -66,10 +66,10 @@ async def test_get():
6666
model = await manager.start_cluster()
6767

6868
# return None if a nonexistent cluster is requested
69-
assert not manager.get_cluster("fake")
69+
assert not await manager.get_cluster("fake")
7070

7171
# get the cluster by id
72-
assert model == manager.get_cluster(model["id"])
72+
assert model == await manager.get_cluster(model["id"])
7373

7474

7575
@pytest.mark.filterwarnings("ignore")
@@ -78,12 +78,12 @@ async def test_list():
7878
with dask.config.set(config):
7979
async with DaskClusterManager() as manager:
8080
# start with an empty list
81-
assert not manager.list_clusters()
81+
assert not await manager.list_clusters()
8282
# start clusters
8383
model1 = await manager.start_cluster()
8484
model2 = await manager.start_cluster()
8585

86-
models = manager.list_clusters()
86+
models = await manager.list_clusters()
8787
assert len(models) == 2
8888
assert model1 in models
8989
assert model2 in models
@@ -98,7 +98,7 @@ async def test_scale():
9898
start = time()
9999
while model["workers"] != 3:
100100
await sleep(0.01)
101-
model = manager.get_cluster(model["id"])
101+
model = await manager.get_cluster(model["id"])
102102
assert time() < start + 10, model["workers"]
103103

104104
await sleep(0.2) # let workers settle # TODO: remove need for this
@@ -108,7 +108,7 @@ async def test_scale():
108108
start = time()
109109
while model["workers"] != 6:
110110
await sleep(0.01)
111-
model = manager.get_cluster(model["id"])
111+
model = await manager.get_cluster(model["id"])
112112
assert time() < start + 10, model["workers"]
113113

114114

@@ -119,7 +119,7 @@ async def test_adapt():
119119
# add a new cluster
120120
model = await manager.start_cluster()
121121
assert not model.get("adapt")
122-
model = manager.adapt_cluster(model["id"], 0, 4)
122+
model = await manager.adapt_cluster(model["id"], 0, 4)
123123
adapt = model.get("adapt")
124124
assert adapt
125125
assert adapt["minimum"] == 0
@@ -144,21 +144,21 @@ async def test_initial():
144144
):
145145
# Test asynchronous starting of clusters via a context
146146
async with DaskClusterManager() as manager:
147-
clusters = manager.list_clusters()
147+
clusters = await manager.list_clusters()
148148
assert len(clusters) == 1
149149
assert clusters[0]["name"] == "foo"
150150

151151
# Test asynchronous starting of clusters outside of a context
152152
manager = DaskClusterManager()
153-
assert len(manager.list_clusters()) == 0
153+
assert len(await manager.list_clusters()) == 0
154154
await manager
155-
clusters = manager.list_clusters()
155+
clusters = await manager.list_clusters()
156156
assert len(clusters) == 1
157157
assert clusters[0]["name"] == "foo"
158158
await manager.close()
159159

160160
manager = await DaskClusterManager()
161-
clusters = manager.list_clusters()
161+
clusters = await manager.list_clusters()
162162
assert len(clusters) == 1
163163
assert clusters[0]["name"] == "foo"
164164
await manager.close()

0 commit comments

Comments
 (0)