|
1 | 1 | import argparse |
2 | 2 | import asyncio |
3 | | -from time import sleep |
| 3 | +import os |
| 4 | + |
4 | 5 | import numpy as np |
5 | 6 | from hypha_rpc import connect_to_server |
6 | 7 | from tifffile import imread |
7 | 8 |
|
| 9 | + |
8 | 10 | SERVER_URL = "https://hypha.aicell.io" |
9 | 11 | WORKSPACE_NAME = "bioimageio-colab" |
| 12 | +CLIENT_ID = os.getenv("CLIENT_ID") |
10 | 13 | SERVICE_ID = "microsam" |
11 | 14 | MODEL_IDS = ["sam_vit_b", "sam_vit_b_lm", "sam_vit_b_em_organelles"] |
12 | 15 | IMG_PATH = "./data/example_image.tif" |
13 | 16 |
|
14 | 17 |
|
15 | | -async def run_client( |
16 | | - client_id: int, image: np.ndarray, model_id: str, method_timeout: int = 300 |
17 | | -): |
18 | | - print(f"Client {client_id} started", flush=True) |
| 18 | +async def compute_embedding(req_id, service, image): |
| 19 | + # Prepare image and model ID |
| 20 | + image_prep = image + np.random.normal(0, 0.1, image.shape) |
| 21 | + model_id = MODEL_IDS[np.random.randint(0, len(MODEL_IDS))] |
| 22 | + |
| 23 | + print(f"Sending request {req_id + 1}") |
| 24 | + await service.compute_embedding( |
| 25 | + image=image_prep, |
| 26 | + model_id=model_id, |
| 27 | + ) |
| 28 | +    print(f"Request {req_id + 1} finished")
| 29 | + |
| 30 | + |
| 31 | +async def stress_test(num_requests: int, method_timeout: int = 30): |
| 32 | + # Connect to the server and get the compute service |
| 33 | + service_client_str = f"{CLIENT_ID}:" if CLIENT_ID else "" |
| 34 | + compute_service_id = f"{WORKSPACE_NAME}/{service_client_str}{SERVICE_ID}" |
| 35 | + print(f"Compute service ID: {compute_service_id}") |
19 | 36 | client = await connect_to_server( |
20 | 37 | {"server_url": SERVER_URL, "method_timeout": method_timeout} |
21 | 38 | ) |
22 | | - service = await client.get_service( |
23 | | - f"{WORKSPACE_NAME}/{SERVICE_ID}", {"mode": "random"} |
24 | | - ) |
25 | | - await service.compute_embedding(model_id=model_id, image=image) |
26 | | - print(f"Client {client_id} finished", flush=True) |
27 | | - |
| 39 | + service = await client.get_service(compute_service_id, {"mode": "first"}) |
28 | 40 |
|
29 | | -async def stress_test(num_clients: int): |
| 41 | + # Load the image |
30 | 42 | image = imread(IMG_PATH) |
| 43 | + |
| 44 | + # Send requests |
31 | 45 | tasks = [] |
32 | | - for client_id in range(num_clients): |
33 | | - sleep(0.1) |
34 | | - model_id = MODEL_IDS[np.random.randint(0, len(MODEL_IDS))] |
35 | | - tasks.append(run_client(client_id=client_id, image=image, model_id=model_id)) |
| 46 | + for req_id in range(num_requests): |
| 47 | + tasks.append(compute_embedding(req_id, service, image)) |
36 | 48 | await asyncio.gather(*tasks) |
37 | | - print("All clients finished") |
| 49 | + |
| 50 | + print("All requests completed successfully.") |
38 | 51 |
|
39 | 52 |
|
40 | 53 | if __name__ == "__main__": |
41 | 54 | parser = argparse.ArgumentParser() |
42 | | - parser.add_argument("--num_clients", type=int, default=50) |
| 55 | + parser.add_argument( |
| 56 | + "--num_requests", type=int, default=30, help="Number of requests" |
| 57 | + ) |
43 | 58 | args = parser.parse_args() |
44 | 59 |
|
45 | | - asyncio.run(stress_test(args.num_clients)) |
| 60 | + asyncio.run(stress_test(args.num_requests)) |
0 commit comments