Skip to content

make all manager methods async #273

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions dask_labextension/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

from jupyter_server.utils import url_path_join

from . import config
from . import config # noqa
from .clusterhandler import DaskClusterHandler
from .dashboardhandler import DaskDashboardCheckHandler, DaskDashboardHandler
from .manager import DaskClusterManager


from ._version import __version__
from ._version import __version__ # noqa


def _jupyter_labextension_paths():
Expand All @@ -33,6 +34,7 @@ def load_jupyter_server_extension(nb_server_app):
cluster_id_regex = r"(?P<cluster_id>[^/]+)"
web_app = nb_server_app.web_app
base_url = web_app.settings["base_url"]
web_app.settings["dask_cluster_manager"] = DaskClusterManager()
get_cluster_path = url_path_join(base_url, "dask/clusters/" + cluster_id_regex)
list_clusters_path = url_path_join(base_url, "dask/clusters/" + "?")
get_dashboard_path = url_path_join(
Expand Down
26 changes: 18 additions & 8 deletions dask_labextension/clusterhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,34 @@
# Distributed under the terms of the Modified BSD License.

import json
from inspect import isawaitable

from tornado import web
from jupyter_server.base.handlers import APIHandler

from .manager import manager
from .manager import DaskClusterManager


class DaskClusterHandler(APIHandler):
"""
A tornado HTTP handler for managing dask clusters.
"""

manager: DaskClusterManager

async def prepare(self):
    """
    Run the parent's prepare step, then resolve the cluster manager.

    The parent's ``prepare`` result is awaited only if it is awaitable.
    Afterwards, the object stored under the ``"dask_cluster_manager"``
    settings key is awaited and the result kept on ``self.manager``.
    """
    pending = super().prepare()
    if isawaitable(pending):
        await pending
    settings = self.settings
    self.manager = await settings["dask_cluster_manager"]

@web.authenticated
async def delete(self, cluster_id: str) -> None:
"""
Delete a cluster by id.
"""
try: # to delete the cluster.
val = await manager.close_cluster(cluster_id)
val = await self.manager.close_cluster(cluster_id)
if val is None:
raise web.HTTPError(404, f"Dask cluster {cluster_id} not found")

Expand All @@ -37,12 +46,13 @@ async def get(self, cluster_id: str = "") -> None:
"""
Get a cluster by id. If no id is given, lists known clusters.
"""
manager = self.manager
if cluster_id == "":
cluster_list = manager.list_clusters()
cluster_list = await manager.list_clusters()
self.set_status(200)
self.finish(json.dumps(cluster_list))
else:
cluster_model = manager.get_cluster(cluster_id)
cluster_model = await manager.get_cluster(cluster_id)
if cluster_model is None:
raise web.HTTPError(404, f"Dask cluster {cluster_id} not found")

Expand All @@ -55,13 +65,13 @@ async def put(self, cluster_id: str = "") -> None:
Create a new cluster with a given id. If no id is given, a random
one is selected.
"""
if manager.get_cluster(cluster_id):
if await self.manager.get_cluster(cluster_id):
raise web.HTTPError(
403, f"A Dask cluster with ID {cluster_id} already exists!"
)

try:
cluster_model = await manager.start_cluster(cluster_id)
cluster_model = await self.manager.start_cluster(cluster_id)
self.set_status(200)
self.finish(json.dumps(cluster_model))
except Exception as e:
Expand All @@ -76,13 +86,13 @@ async def patch(self, cluster_id):
new_model = json.loads(self.request.body)
try:
if new_model.get("adapt") is not None:
cluster_model = manager.adapt_cluster(
cluster_model = await self.manager.adapt_cluster(
cluster_id,
new_model["adapt"]["minimum"],
new_model["adapt"]["maximum"],
)
else:
cluster_model = await manager.scale_cluster(
cluster_model = await self.manager.scale_cluster(
cluster_id, new_model["workers"]
)
self.set_status(200)
Expand Down
22 changes: 16 additions & 6 deletions dask_labextension/dashboardhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,32 @@
server, preventing CORS issues.
"""
import json
from inspect import isawaitable
from urllib import parse

from tornado import httpclient, web


from jupyter_server.base.handlers import APIHandler
from jupyter_server.utils import url_path_join
from jupyter_server_proxy.handlers import ProxyHandler

from .manager import manager
from .manager import DaskClusterManager


class DaskDashboardCheckHandler(APIHandler):
"""
A handler for checking validity of a dask dashboard.
"""

manager: DaskClusterManager

async def prepare(self):
    """
    Prepare the request and attach the cluster manager to this handler.

    Calls the parent's ``prepare`` (awaiting its result when it is
    awaitable), then awaits the ``"dask_cluster_manager"`` entry of the
    application settings and stores what it yields as ``self.manager``.
    """
    parent_result = super().prepare()
    if isawaitable(parent_result):
        await parent_result
    awaited_manager = await self.settings["dask_cluster_manager"]
    self.manager = awaited_manager

@web.authenticated
async def get(self, url) -> None:
"""
Expand Down Expand Up @@ -133,7 +143,7 @@ async def http_get(self, cluster_id, proxied_path):
return await self.proxy(cluster_id, proxied_path)

async def open(self, cluster_id, proxied_path):
host, port = self._get_parsed(cluster_id)
host, port = await self._get_parsed(cluster_id)
return await super().proxy_open(host, port, proxied_path)

# We have to duplicate all these for now, I've no idea why!
Expand All @@ -157,17 +167,17 @@ def patch(self, cluster_id, proxied_path):
def options(self, cluster_id, proxied_path):
return self.proxy(cluster_id, proxied_path)

def proxy(self, cluster_id, proxied_path):
host, port = self._get_parsed(cluster_id)
async def proxy(self, cluster_id, proxied_path):
    """
    Proxy a request to the bokeh server of the cluster with the given ID.

    Looks up the host and port for ``cluster_id`` via ``_get_parsed``
    and delegates to the parent class's ``proxy`` with ``proxied_path``.
    """
    host, port = await self._get_parsed(cluster_id)
    # Now that this method is a coroutine, returning the parent's result
    # un-awaited would hand the caller an awaitable that nobody drives.
    # Await it here (when it is awaitable) so the proxied request runs.
    result = super().proxy(host, port, proxied_path)
    if isawaitable(result):
        result = await result
    return result

def _get_parsed(self, cluster_id):
async def _get_parsed(self, cluster_id):
"""
Given a cluster ID, get the hostname and port of its bokeh server.
"""
# Get the cluster by ID. If it is not found,
# raise an error.
cluster_model = manager.get_cluster(cluster_id)
cluster_model = await self.manager.get_cluster(cluster_id)
if not cluster_model:
raise web.HTTPError(404, f"Dask cluster {cluster_id} not found")

Expand Down
43 changes: 26 additions & 17 deletions dask_labextension/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.

import asyncio
import importlib
from inspect import isawaitable
from typing import Any, Dict, List, Union
Expand All @@ -11,8 +12,6 @@
import dask
from dask.utils import format_bytes
from dask.distributed import Adaptive
from tornado.ioloop import IOLoop
from tornado.concurrent import Future

# A type for a dask cluster model: a serializable
# representation of information about the cluster.
Expand Down Expand Up @@ -60,15 +59,28 @@ def __init__(self) -> None:
self._adaptives: Dict[str, Adaptive] = dict()
self._cluster_names: Dict[str, str] = dict()
self._n_clusters = 0
self._initialized = None

self.initialized = Future()
async def _async_init(self):
"""The async part of init

async def start_clusters():
for model in dask.config.get("labextension.initial"):
await self.start_cluster(configuration=model)
self.initialized.set_result(self)
Invoked by `await manager`
"""
for model in dask.config.get("labextension.initial"):
await self.start_cluster(configuration=model)
return self

IOLoop.current().add_callback(start_clusters)
@property
def initialized(self):
    """Return the initialization task, creating it on first access.

    The task wrapping ``_async_init`` is not created until this property
    is first read (typically via `await manager`), and is reused on every
    later access.

    Deferring creation makes it easier to ensure nothing is done before
    we are in the event loop.
    """
    task = self._initialized
    if task is None:
        task = asyncio.create_task(self._async_init())
        self._initialized = task
    return task

async def start_cluster(
self, cluster_id: str = "", configuration: dict = {}
Expand Down Expand Up @@ -121,7 +133,9 @@ async def close_cluster(self, cluster_id: str) -> Union[ClusterModel, None]:
"""
cluster = self._clusters.get(cluster_id)
if cluster:
await cluster.close()
r = cluster.close()
if isawaitable(r):
await r
self._clusters.pop(cluster_id)
name = self._cluster_names.pop(cluster_id)
adaptive = self._adaptives.pop(cluster_id, None)
Expand All @@ -130,7 +144,7 @@ async def close_cluster(self, cluster_id: str) -> Union[ClusterModel, None]:
else:
return None

def get_cluster(self, cluster_id) -> Union[ClusterModel, None]:
async def get_cluster(self, cluster_id) -> Union[ClusterModel, None]:
"""
Get a Dask cluster model.

Expand All @@ -151,7 +165,7 @@ def get_cluster(self, cluster_id) -> Union[ClusterModel, None]:

return make_cluster_model(cluster_id, name, cluster, adaptive)

def list_clusters(self) -> List[ClusterModel]:
async def list_clusters(self) -> List[ClusterModel]:
"""
List the Dask cluster models known to the manager.

Expand Down Expand Up @@ -188,7 +202,7 @@ async def scale_cluster(self, cluster_id: str, n: int) -> Union[ClusterModel, No
await t
return make_cluster_model(cluster_id, name, cluster, adaptive=None)

def adapt_cluster(
async def adapt_cluster(
self, cluster_id: str, minimum: int, maximum: int
) -> Union[ClusterModel, None]:
cluster = self._clusters.get(cluster_id)
Expand Down Expand Up @@ -290,8 +304,3 @@ def make_cluster_model(
model["adapt"] = {"minimum": adaptive.minimum, "maximum": adaptive.maximum}

return model


# Create a default cluster manager
# to keep track of clusters.
manager = DaskClusterManager()
26 changes: 13 additions & 13 deletions dask_labextension/tests/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def test_start():
assert not model.get("adapt")

# close cluster
assert len(manager.list_clusters()) == 1
assert len(await manager.list_clusters()) == 1
await manager.close_cluster(model["id"])

# add cluster with adaptive configuration
Expand All @@ -55,7 +55,7 @@ async def test_close():

# close the cluster
await manager.close_cluster(model["id"])
assert not manager.list_clusters()
assert not await manager.list_clusters()


@gen_test()
Expand All @@ -66,10 +66,10 @@ async def test_get():
model = await manager.start_cluster()

# return None if a nonexistent cluster is requested
assert not manager.get_cluster("fake")
assert not await manager.get_cluster("fake")

# get the cluster by id
assert model == manager.get_cluster(model["id"])
assert model == await manager.get_cluster(model["id"])


@pytest.mark.filterwarnings("ignore")
Expand All @@ -78,12 +78,12 @@ async def test_list():
with dask.config.set(config):
async with DaskClusterManager() as manager:
# start with an empty list
assert not manager.list_clusters()
assert not await manager.list_clusters()
# start clusters
model1 = await manager.start_cluster()
model2 = await manager.start_cluster()

models = manager.list_clusters()
models = await manager.list_clusters()
assert len(models) == 2
assert model1 in models
assert model2 in models
Expand All @@ -98,7 +98,7 @@ async def test_scale():
start = time()
while model["workers"] != 3:
await sleep(0.01)
model = manager.get_cluster(model["id"])
model = await manager.get_cluster(model["id"])
assert time() < start + 10, model["workers"]

await sleep(0.2) # let workers settle # TODO: remove need for this
Expand All @@ -108,7 +108,7 @@ async def test_scale():
start = time()
while model["workers"] != 6:
await sleep(0.01)
model = manager.get_cluster(model["id"])
model = await manager.get_cluster(model["id"])
assert time() < start + 10, model["workers"]


Expand All @@ -119,7 +119,7 @@ async def test_adapt():
# add a new cluster
model = await manager.start_cluster()
assert not model.get("adapt")
model = manager.adapt_cluster(model["id"], 0, 4)
model = await manager.adapt_cluster(model["id"], 0, 4)
adapt = model.get("adapt")
assert adapt
assert adapt["minimum"] == 0
Expand All @@ -144,21 +144,21 @@ async def test_initial():
):
# Test asynchronous starting of clusters via a context
async with DaskClusterManager() as manager:
clusters = manager.list_clusters()
clusters = await manager.list_clusters()
assert len(clusters) == 1
assert clusters[0]["name"] == "foo"

# Test asynchronous starting of clusters outside of a context
manager = DaskClusterManager()
assert len(manager.list_clusters()) == 0
assert len(await manager.list_clusters()) == 0
await manager
clusters = manager.list_clusters()
clusters = await manager.list_clusters()
assert len(clusters) == 1
assert clusters[0]["name"] == "foo"
await manager.close()

manager = await DaskClusterManager()
clusters = manager.list_clusters()
clusters = await manager.list_clusters()
assert len(clusters) == 1
assert clusters[0]["name"] == "foo"
await manager.close()
Loading