Skip to content

Commit e05c0b2

Browse files
committed
remove autoscale and use multiplexed_model_id
1 parent 14aa7f2 commit e05c0b2

File tree

4 files changed

+88
-63
lines changed

4 files changed

+88
-63
lines changed

bioimageio_colab/__main__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@
3535
default=None,
3636
help="Address of the Ray cluster for running SAM",
3737
)
38+
parser.add_argument(
39+
"--num_replicas",
40+
default=2,
41+
type=int,
42+
help="Number of replicas for the SAM deployment",
43+
)
3844
parser.add_argument(
3945
"--restart_deployment",
4046
default=False,
@@ -49,7 +55,7 @@
4955
)
5056
parser.add_argument(
5157
"--max_concurrent_requests",
52-
default=10,
58+
default=4,
5359
type=int,
5460
help="Maximum number of concurrent requests to the service",
5561
)

bioimageio_colab/models/sam_deployment.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ async def get_model(self, model_id: str):
5858
model_architecture=self.models[model_id]["architecture"],
5959
)
6060

61-
async def __call__(self, model_id: str, array: np.ndarray):
61+
async def __call__(self, array: np.ndarray) -> np.ndarray:
62+
# Get the model from the request
63+
model_id = serve.get_multiplexed_model_id()
6264
model = await self.get_model(model_id)
65+
6366
return model.encode(array)

bioimageio_colab/register_sam_service.py

Lines changed: 76 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from dotenv import find_dotenv, load_dotenv
1010
from hypha_rpc import connect_to_server
1111
from hypha_rpc.rpc import RemoteService
12+
1213
# from kaibu_utils import mask_to_features
1314
from ray.serve.config import AutoscalingConfig
1415
from tifffile import imread
@@ -73,8 +74,8 @@ def connect_to_ray(address: str = None) -> None:
7374
async def deploy_to_ray(
7475
cache_dir: str,
7576
app_name: str = "SAM Image Encoder",
76-
min_replicas: int = 1,
77-
max_replicas: int = 2,
77+
num_replicas: int = 2,
78+
max_queued_requests: int = 10,
7879
restart_deployment: bool = False,
7980
skip_test_runs: bool = False,
8081
) -> None:
@@ -86,16 +87,15 @@ async def deploy_to_ray(
8687
Returns:
8788
dict: Handles to the deployed image encoders.
8889
"""
89-
logger.info(f"Deploying the app '{app_name}' with {min_replicas} to {max_replicas} replicas on Ray Serve...")
90-
# Set autoscaling configuration
91-
autoscaling_config = AutoscalingConfig(
92-
min_replicas=min_replicas,
93-
max_replicas=max_replicas,
90+
logger.info(
91+
f"Deploying the app '{app_name}' with {num_replicas} replicas on Ray Serve..."
9492
)
9593

9694
# Create a new Ray Serve deployment
9795
deployment = SamDeployment.options(
98-
autoscaling_config=autoscaling_config,
96+
num_replicas=num_replicas,
97+
max_replicas_per_node=1,
98+
max_queued_requests=max_queued_requests,
9999
)
100100

101101
# Bind the arguments to the deployment and return an Application.
@@ -113,22 +113,28 @@ async def deploy_to_ray(
113113
# Deploy the application
114114
ray.serve.run(app, name=app_name, route_prefix=None)
115115

116+
# Get the application handle
117+
app_handle = ray.serve.get_app_handle(name=app_name)
118+
if not app_handle:
119+
raise ConnectionError("Failed to get the application handle.")
120+
116121
if app_name in applications:
117122
logger.info(f"Updated application deployment '{app_name}'.")
118123
else:
119124
logger.info(f"Deployed application '{app_name}'.")
120-
125+
121126
if skip_test_runs:
122127
logger.info("Skipping test runs for each model.")
123128
else:
124129
# Test run each model
125-
handle = ray.serve.get_app_handle(name=app_name)
126130
img_file = os.path.join(BASE_DIR, "data/example_image.tif")
127131
image = imread(img_file)
128132
for model_id in SAM_MODELS.keys():
129-
await handle.remote(model_id=model_id, array=image)
133+
await app_handle.options(multiplexed_model_id=model_id).remote(image)
130134
logger.info(f"Test run for model '{model_id}' completed successfully.")
131135

136+
return app_handle
137+
132138

133139
def hello(context: dict = None) -> str:
134140
return "Welcome to the Interactive Segmentation service!"
@@ -139,8 +145,8 @@ def ping(context: dict = None) -> str:
139145

140146

141147
async def compute_image_embedding(
148+
app_handle: ray.serve.handle.DeploymentHandle,
142149
semaphore: asyncio.Semaphore,
143-
app_name: str,
144150
image: np.ndarray,
145151
model_id: str,
146152
require_login: bool = False,
@@ -162,8 +168,9 @@ async def compute_image_embedding(
162168

163169
# Compute the embedding
164170
# Returns: {"features": embedding, "input_size": input_size}
165-
handle = ray.serve.get_app_handle(name=app_name)
166-
result = await handle.remote(model_id=model_id, array=image)
171+
result = await app_handle.options(multiplexed_model_id=model_id).remote(
172+
image
173+
)
167174

168175
logger.info(f"User '{user_id}' - Embedding computed successfully.")
169176

@@ -232,19 +239,21 @@ async def check_readiness(client: RemoteService, service_id: str) -> dict:
232239
"""
233240
Readiness probe for the SAM service.
234241
"""
235-
logger.info("Checking the readiness of the SAM service...")
236-
237-
services = await client.list_services()
238-
service_found = False
239-
for service in services:
240-
if service["id"] == service_id:
241-
service_found = True
242-
break
243-
assert service_found, f"Service with ID '{service_id}' not found."
242+
try:
243+
services = await client.list_services()
244+
service_found = False
245+
for service in services:
246+
if service["id"] == service_id:
247+
service_found = True
248+
break
249+
assert service_found, f"Service with ID '{service_id}' not found."
244250

245-
logger.info(f"Service with ID '{service_id}' is ready.")
251+
logger.info(f"Service with ID '{service_id}' is ready.")
246252

247-
return {"status": "ready"}
253+
return {"status": "ready"}
254+
except Exception as e:
255+
logger.error(f"Error during readiness probe: {e}")
256+
raise e
248257

249258

250259
def format_time(last_deployed_time_s, tz: timezone = timezone.utc) -> str:
@@ -285,40 +294,41 @@ async def check_liveness(app_name: str, client: RemoteService, service_id: str)
285294
"""
286295
Liveness probe for the SAM service.
287296
"""
288-
logger.info(f"Checking the liveness of the SAM service...")
289-
output = {}
290-
291-
# Check the Ray Serve application status
292-
serve_status = ray.serve.status()
293-
application = serve_status.applications[app_name]
294-
assert application.status == "RUNNING"
295-
formatted_time = format_time(application.last_deployed_time_s)
296-
output[f"application: {app_name}"] = {
297-
"status": application.status.value,
298-
"last_deployed_at": formatted_time["last_deployed_at"],
299-
"duration_since": formatted_time["duration_since"],
300-
}
301-
302-
for name, deployment in serve_status.applications[app_name].deployments.items():
303-
assert deployment.status == "HEALTHY"
304-
# assert deployment.replica_states == {"RUNNING": 1}
305-
output[f"application: {app_name}"][f"deployment: {name}"] = {
306-
"status": deployment.status.value,
307-
# "replica_states": deployment.replica_states,
297+
try:
298+
output = {}
299+
300+
# Check the Ray Serve application status
301+
serve_status = ray.serve.status()
302+
application = serve_status.applications[app_name]
303+
assert application.status == "RUNNING"
304+
formatted_time = format_time(application.last_deployed_time_s)
305+
output[f"application: {app_name}"] = {
306+
"status": application.status.value,
307+
"last_deployed_at": formatted_time["last_deployed_at"],
308+
"duration_since": formatted_time["duration_since"],
308309
}
309310

310-
# Check if the service can be accessed
311-
service = await client.get_service(service_id)
312-
assert await service.ping() == "pong"
313-
314-
output["service"] = {
315-
"status": "RUNNING",
316-
"service_id": service_id,
317-
}
318-
319-
logger.info(f"Service with ID '{service_id}' is live.")
311+
for name, deployment in serve_status.applications[app_name].deployments.items():
312+
assert deployment.status == "HEALTHY"
313+
assert deployment.replica_states["RUNNING"] > 0
314+
output[f"application: {app_name}"][f"deployment: {name}"] = {
315+
"status": deployment.status.value,
316+
"replica_states": deployment.replica_states,
317+
}
318+
319+
# Check if the service can be accessed
320+
service = await client.get_service(service_id)
321+
assert await service.ping() == "pong"
322+
323+
output["service"] = {
324+
"status": "RUNNING",
325+
"service_id": service_id,
326+
}
320327

321-
return output
328+
return output
329+
except Exception as e:
330+
logger.error(f"Error during liveness probe: {e}")
331+
raise e
322332

323333

324334
async def register_service(args: dict) -> None:
@@ -353,18 +363,24 @@ async def register_service(args: dict) -> None:
353363
# Deploy SAM image encoders
354364
cache_dir = os.path.abspath(args.cache_dir)
355365
app_name = "SAM Image Encoder"
356-
await deploy_to_ray(
366+
app_handle = await deploy_to_ray(
357367
cache_dir=cache_dir,
358368
app_name=app_name,
369+
num_replicas=args.num_replicas,
370+
max_queued_requests=args.max_concurrent_requests,
359371
restart_deployment=args.restart_deployment,
360372
skip_test_runs=args.skip_test_runs,
361373
)
362374

363375
# Register a new service
364376
semaphore = asyncio.Semaphore(args.max_concurrent_requests)
365-
logger.info(f"Created semaphore for {args.max_concurrent_requests} concurrent requests.")
377+
logger.info(
378+
f"Created semaphore for {args.max_concurrent_requests} concurrent requests."
379+
)
366380

367-
logger.ingo(f"Registering the SAM service: ID='{args.service_id}', require_login={args.require_login}")
381+
logger.info(
382+
f"Registering the SAM service: ID='{args.service_id}', require_login={args.require_login}"
383+
)
368384
service_info = await client.register_service(
369385
{
370386
"name": "Interactive Segmentation",
@@ -379,8 +395,8 @@ async def register_service(args: dict) -> None:
379395
"ping": ping,
380396
"compute_embedding": partial(
381397
compute_image_embedding,
398+
app_handle=app_handle,
382399
semaphore=semaphore,
383-
app_name=app_name,
384400
require_login=args.require_login,
385401
),
386402
# "compute_mask": partial(

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ requires = ["setuptools", "wheel"]
33

44
[project]
55
name = "bioimageio-colab"
6-
version = "0.2.2"
6+
version = "0.2.3"
77
readme = "README.md"
88
description = "Collaborative image annotation and model training with human in the loop."
99
dependencies = [

0 commit comments

Comments
 (0)