99from dotenv import find_dotenv , load_dotenv
1010from hypha_rpc import connect_to_server
1111from hypha_rpc .rpc import RemoteService
12+
1213# from kaibu_utils import mask_to_features
1314from ray .serve .config import AutoscalingConfig
1415from tifffile import imread
@@ -73,8 +74,8 @@ def connect_to_ray(address: str = None) -> None:
7374async def deploy_to_ray (
7475 cache_dir : str ,
7576 app_name : str = "SAM Image Encoder" ,
76- min_replicas : int = 1 ,
77- max_replicas : int = 2 ,
77+ num_replicas : int = 2 ,
78+ max_queued_requests : int = 10 ,
7879 restart_deployment : bool = False ,
7980 skip_test_runs : bool = False ,
8081) -> None :
@@ -86,16 +87,15 @@ async def deploy_to_ray(
8687 Returns:
8788 dict: Handles to the deployed image encoders.
8889 """
89- logger .info (f"Deploying the app '{ app_name } ' with { min_replicas } to { max_replicas } replicas on Ray Serve..." )
90- # Set autoscaling configuration
91- autoscaling_config = AutoscalingConfig (
92- min_replicas = min_replicas ,
93- max_replicas = max_replicas ,
90+ logger .info (
91+ f"Deploying the app '{ app_name } ' with { num_replicas } replicas on Ray Serve..."
9492 )
9593
9694 # Create a new Ray Serve deployment
9795 deployment = SamDeployment .options (
98- autoscaling_config = autoscaling_config ,
96+ num_replicas = num_replicas ,
97+ max_replicas_per_node = 1 ,
98+ max_queued_requests = max_queued_requests ,
9999 )
100100
101101 # Bind the arguments to the deployment and return an Application.
@@ -113,22 +113,28 @@ async def deploy_to_ray(
113113 # Deploy the application
114114 ray .serve .run (app , name = app_name , route_prefix = None )
115115
116+ # Get the application handle
117+ app_handle = ray .serve .get_app_handle (name = app_name )
118+ if not app_handle :
119+ raise ConnectionError ("Failed to get the application handle." )
120+
116121 if app_name in applications :
117122 logger .info (f"Updated application deployment '{ app_name } '." )
118123 else :
119124 logger .info (f"Deployed application '{ app_name } '." )
120-
125+
121126 if skip_test_runs :
122127 logger .info ("Skipping test runs for each model." )
123128 else :
124129 # Test run each model
125- handle = ray .serve .get_app_handle (name = app_name )
126130 img_file = os .path .join (BASE_DIR , "data/example_image.tif" )
127131 image = imread (img_file )
128132 for model_id in SAM_MODELS .keys ():
129- await handle . remote ( model_id = model_id , array = image )
133+ await app_handle . options ( multiplexed_model_id = model_id ). remote ( image )
130134 logger .info (f"Test run for model '{ model_id } ' completed successfully." )
131135
136+ return app_handle
137+
132138
def hello(context: dict = None) -> str:
    """
    Return the static welcome message of the service.

    Args:
        context: Optional context dictionary; it is not used by this function.

    Returns:
        str: The welcome message.
    """
    return "Welcome to the Interactive Segmentation service!"
@@ -139,8 +145,8 @@ def ping(context: dict = None) -> str:
139145
140146
141147async def compute_image_embedding (
148+ app_handle : ray .serve .handle .DeploymentHandle ,
142149 semaphore : asyncio .Semaphore ,
143- app_name : str ,
144150 image : np .ndarray ,
145151 model_id : str ,
146152 require_login : bool = False ,
@@ -162,8 +168,9 @@ async def compute_image_embedding(
162168
163169 # Compute the embedding
164170 # Returns: {"features": embedding, "input_size": input_size}
165- handle = ray .serve .get_app_handle (name = app_name )
166- result = await handle .remote (model_id = model_id , array = image )
171+ result = await app_handle .options (multiplexed_model_id = model_id ).remote (
172+ image
173+ )
167174
168175 logger .info (f"User '{ user_id } ' - Embedding computed successfully." )
169176
async def check_readiness(client: RemoteService, service_id: str) -> dict:
    """
    Readiness probe for the SAM service.

    Awaits ``client.list_services()`` and checks that one of the returned
    entries has ``service_id`` as its ``"id"``.

    Args:
        client: Client whose ``list_services()`` coroutine is awaited.
        service_id: ID that must appear among the listed services.

    Returns:
        dict: ``{"status": "ready"}`` when the service is found.

    Raises:
        AssertionError: If no listed service has the ID ``service_id``.
        Exception: Any other error raised during the probe is logged and
            re-raised unchanged.
    """
    try:
        services = await client.list_services()
        if not any(service["id"] == service_id for service in services):
            # Raised explicitly instead of using ``assert`` so that the check
            # is not stripped when Python runs with ``-O``.
            raise AssertionError(f"Service with ID '{service_id}' not found.")

        logger.info(f"Service with ID '{service_id}' is ready.")

        return {"status": "ready"}
    except Exception as e:
        logger.error(f"Error during readiness probe: {e}")
        # Bare ``raise`` re-raises the active exception as-is.
        raise
248257
249258
250259def format_time (last_deployed_time_s , tz : timezone = timezone .utc ) -> str :
async def check_liveness(app_name: str, client: RemoteService, service_id: str) -> dict:
    """
    Liveness probe for the SAM service.

    Checks, in order, that the Ray Serve application ``app_name`` is
    ``RUNNING``, that each of its deployments is ``HEALTHY`` with at least one
    ``RUNNING`` replica, and that the service ``service_id`` obtained from
    ``client`` answers ``ping()`` with ``"pong"``.

    Args:
        app_name: Name of the application looked up in ``ray.serve.status()``.
        client: Client whose ``get_service()`` coroutine is awaited.
        service_id: ID of the service that is fetched and pinged.

    Returns:
        dict: A report with an ``"application: <app_name>"`` entry (status,
        deployment time information and one ``"deployment: <name>"`` entry per
        deployment) and a ``"service"`` entry.

    Raises:
        AssertionError: If any of the checks above fails.
        Exception: Any other error raised during the probe is logged and
            re-raised unchanged.
    """
    try:
        # Check the Ray Serve application status
        serve_status = ray.serve.status()
        application = serve_status.applications[app_name]
        # Explicit raises instead of ``assert`` statements: asserts are
        # stripped under ``python -O``, which would disable every check below
        # (including the ``ping()`` call itself).
        if application.status != "RUNNING":
            raise AssertionError(
                f"Application '{app_name}' is not running (status: {application.status})."
            )

        formatted_time = format_time(application.last_deployed_time_s)
        app_key = f"application: {app_name}"
        output = {
            app_key: {
                "status": application.status.value,
                "last_deployed_at": formatted_time["last_deployed_at"],
                "duration_since": formatted_time["duration_since"],
            }
        }

        for name, deployment in application.deployments.items():
            if deployment.status != "HEALTHY":
                raise AssertionError(
                    f"Deployment '{name}' is not healthy (status: {deployment.status})."
                )
            # ``.get`` so that a missing "RUNNING" entry is reported as a
            # failed check rather than surfacing as a bare KeyError.
            if deployment.replica_states.get("RUNNING", 0) <= 0:
                raise AssertionError(f"Deployment '{name}' has no running replicas.")
            output[app_key][f"deployment: {name}"] = {
                "status": deployment.status.value,
                "replica_states": deployment.replica_states,
            }

        # Check if the service can be accessed
        service = await client.get_service(service_id)
        if await service.ping() != "pong":
            raise AssertionError(
                f"Service with ID '{service_id}' did not answer the ping with 'pong'."
            )

        output["service"] = {
            "status": "RUNNING",
            "service_id": service_id,
        }

        return output
    except Exception as e:
        logger.error(f"Error during liveness probe: {e}")
        # Bare ``raise`` re-raises the active exception as-is.
        raise
322332
323333
324334async def register_service (args : dict ) -> None :
@@ -353,18 +363,24 @@ async def register_service(args: dict) -> None:
353363 # Deploy SAM image encoders
354364 cache_dir = os .path .abspath (args .cache_dir )
355365 app_name = "SAM Image Encoder"
356- await deploy_to_ray (
366+ app_handle = await deploy_to_ray (
357367 cache_dir = cache_dir ,
358368 app_name = app_name ,
369+ num_replicas = args .num_replicas ,
370+ max_queued_requests = args .max_concurrent_requests ,
359371 restart_deployment = args .restart_deployment ,
360372 skip_test_runs = args .skip_test_runs ,
361373 )
362374
363375 # Register a new service
364376 semaphore = asyncio .Semaphore (args .max_concurrent_requests )
365- logger .info (f"Created semaphore for { args .max_concurrent_requests } concurrent requests." )
377+ logger .info (
378+ f"Created semaphore for { args .max_concurrent_requests } concurrent requests."
379+ )
366380
367- logger .ingo (f"Registering the SAM service: ID='{ args .service_id } ', require_login={ args .require_login } " )
381+ logger .info (
382+ f"Registering the SAM service: ID='{ args .service_id } ', require_login={ args .require_login } "
383+ )
368384 service_info = await client .register_service (
369385 {
370386 "name" : "Interactive Segmentation" ,
@@ -379,8 +395,8 @@ async def register_service(args: dict) -> None:
379395 "ping" : ping ,
380396 "compute_embedding" : partial (
381397 compute_image_embedding ,
398+ app_handle = app_handle ,
382399 semaphore = semaphore ,
383- app_name = app_name ,
384400 require_login = args .require_login ,
385401 ),
386402 # "compute_mask": partial(
0 commit comments