|
20 | 20 | import re |
21 | 21 | from typing import Dict |
22 | 22 | from unittest import mock |
| 23 | +from unittest.mock import Mock |
23 | 24 |
|
24 | 25 | import avro.io |
25 | 26 | import avro.schema |
|
33 | 34 | from fastapi.testclient import TestClient |
34 | 35 | from ray import serve |
35 | 36 |
|
36 | | -from kserve import Model, ModelRepository, ModelServer |
| 37 | +from kserve import Model, ModelRepository, ModelServer, model_server |
37 | 38 | from kserve.constants.constants import ( |
38 | 39 | FASTAPI_APP_IMPORT_STRING, |
39 | 40 | INFERENCE_CONTENT_LENGTH_HEADER, |
@@ -1412,3 +1413,64 @@ def test_model_server_backwards_compatibility(self): |
1412 | 1413 |
|
1413 | 1414 | # Context should handle None predictor config gracefully |
1414 | 1415 | # The DataPlane should still be functional even without predictor config |
| 1416 | + |
| 1417 | + |
| 1418 | +@pytest.mark.asyncio |
| 1419 | +class TestModelServerEventLoopWithDummy: |
| 1420 | + async def test_start_passes_event_loop_to_rest_server(self, monkeypatch): |
| 1421 | + created = {} |
| 1422 | + |
| 1423 | + class DummyREST: |
| 1424 | + def __init__(self, *args, **kwargs): |
| 1425 | + created["instance"] = self |
| 1426 | + self.event_loop = kwargs.get("event_loop") |
| 1427 | + |
| 1428 | + def start(self): |
| 1429 | + return None |
| 1430 | + |
| 1431 | + # Patch to capture constructor and avoid side effects |
| 1432 | + monkeypatch.setattr(model_server, "RESTServer", DummyREST) |
| 1433 | + monkeypatch.setattr(model_server.asyncio, "run", Mock()) |
| 1434 | + monkeypatch.setattr( |
| 1435 | + model_server.ModelServer, "setup_event_loop", lambda _: None |
| 1436 | + ) |
| 1437 | + monkeypatch.setattr( |
| 1438 | + model_server.ModelServer, "register_signal_handler", lambda _: None |
| 1439 | + ) |
| 1440 | + |
| 1441 | + ms = model_server.ModelServer(workers=1, event_loop="uvloop") |
| 1442 | + m = DummyModel("TestModel") |
| 1443 | + m.load() |
| 1444 | + ms.start(models=[m]) |
| 1445 | + |
| 1446 | + assert isinstance(created.get("instance"), DummyREST) |
| 1447 | + assert created["instance"].event_loop == "uvloop" |
| 1448 | + |
| 1449 | + async def test_start_passes_event_loop_to_rest_multiprocess(self, monkeypatch): |
| 1450 | + created = {} |
| 1451 | + |
| 1452 | + class DummyMulti: |
| 1453 | + def __init__(self, *args, **kwargs): |
| 1454 | + created["instance"] = self |
| 1455 | + self.event_loop = kwargs.get("event_loop") |
| 1456 | + |
| 1457 | + def start(self): |
| 1458 | + return None |
| 1459 | + |
| 1460 | + # Patch multiprocess REST server and side effects |
| 1461 | + monkeypatch.setattr(model_server, "RESTServerMultiProcess", DummyMulti) |
| 1462 | + monkeypatch.setattr(model_server.asyncio, "run", Mock()) |
| 1463 | + monkeypatch.setattr( |
| 1464 | + model_server.ModelServer, "setup_event_loop", lambda _: None |
| 1465 | + ) |
| 1466 | + monkeypatch.setattr( |
| 1467 | + model_server.ModelServer, "register_signal_handler", lambda _: None |
| 1468 | + ) |
| 1469 | + |
| 1470 | + ms = model_server.ModelServer(workers=4, event_loop="asyncio") |
| 1471 | + m = DummyModel("TestModel") |
| 1472 | + m.load() |
| 1473 | + ms.start(models=[m]) |
| 1474 | + |
| 1475 | + assert isinstance(created.get("instance"), DummyMulti) |
| 1476 | + assert created["instance"].event_loop == "asyncio" |
0 commit comments