Skip to content

Commit 979d22d

Browse files
authored
Let make_serving_endpoint reference a valid model version (#106)
## Changes

The `make_serving_endpoint` fixture [stopped working as of yesterday](https://github.com/databrickslabs/pytester/actions/runs/13405748534). Apparently, the models created by the `make_model` fixture no longer come with a model version (which used to be `'1'`). The `make_serving_endpoint` fixture is updated to fall back on a UC model version. It also allows users to provide input parameters in case they want to use another model.

### Linked issues

Resolves #databrickslabs/ucx#3714
Resolves #databrickslabs/ucx#3715

### Tests

- [x] added unit tests
- [ ] fixed integration test
1 parent 0202115 commit 979d22d

File tree

3 files changed

+133
-23
lines changed

3 files changed

+133
-23
lines changed

README.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,7 +1065,7 @@ def test_models(make_group, make_model, make_registered_model_permissions):
10651065
)
10661066
```
10671067

1068-
See also [`make_serving_endpoint`](#make_serving_endpoint-fixture), [`ws`](#ws-fixture), [`make_random`](#make_random-fixture), [`watchdog_remove_after`](#watchdog_remove_after-fixture).
1068+
See also [`ws`](#ws-fixture), [`make_random`](#make_random-fixture), [`watchdog_remove_after`](#watchdog_remove_after-fixture).
10691069

10701070

10711071
[[back to top](#python-testing-for-databricks)]
@@ -1204,6 +1204,13 @@ The function returns a [`ServingEndpointDetailed`](https://databricks-sdk-py.rea
12041204

12051205
Under the covers, this fixture serves an existing model (a system model by default) on a small workload size.
12061206

1207+
Keyword arguments:
1208+
* `endpoint_name` (str, optional): The name of the endpoint. Defaults to `dummy-*`.
1209+
* `model_name` (str, optional): The name of the model to serve on the endpoint.
1210+
Defaults to system model `system.ai.llama_v3_2_1b_instruct`.
1211+
* `model_version` (str, optional): The model version to serve. If None, tries to get the latest version for
1212+
workspace-local models. If that lookup fails, or the model is not workspace-local, defaults to version `1`.
1213+
12071214
Usage:
12081215
```python
12091216
def test_endpoints(make_group, make_serving_endpoint, make_serving_endpoint_permissions):
@@ -1216,7 +1223,7 @@ def test_endpoints(make_group, make_serving_endpoint, make_serving_endpoint_perm
12161223
)
12171224
```
12181225

1219-
See also [`ws`](#ws-fixture), [`make_random`](#make_random-fixture), [`make_model`](#make_model-fixture), [`watchdog_remove_after`](#watchdog_remove_after-fixture).
1226+
See also [`ws`](#ws-fixture), [`make_random`](#make_random-fixture), [`watchdog_remove_after`](#watchdog_remove_after-fixture).
12201227

12211228

12221229
[[back to top](#python-testing-for-databricks)]

src/databricks/labs/pytester/fixtures/ml.py

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,27 @@
1+
import logging
12
from collections.abc import Callable, Generator
3+
from unittest.mock import Mock
24

35
from pytest import fixture
6+
from databricks.sdk.errors import BadRequest
47
from databricks.sdk.service._internal import Wait
58
from databricks.sdk.service.serving import (
6-
ServingEndpointDetailed,
79
EndpointCoreConfigInput,
10+
EndpointPendingConfig,
11+
EndpointTag,
812
ServedModelInput,
913
ServedModelInputWorkloadSize,
10-
EndpointTag,
14+
ServedModelOutput,
15+
ServingEndpointDetailed,
1116
)
1217
from databricks.sdk.service.ml import CreateExperimentResponse, ModelDatabricks, ModelTag
1318

1419
from databricks.labs.pytester.fixtures.baseline import factory
1520

1621

22+
logger = logging.getLogger(__name__)
23+
24+
1725
@fixture
1826
def make_experiment(
1927
ws,
@@ -103,13 +111,20 @@ def create(*, model_name: str | None = None, **kwargs) -> ModelDatabricks:
103111

104112

105113
@fixture
106-
def make_serving_endpoint(ws, make_random, make_model, watchdog_remove_after):
114+
def make_serving_endpoint(ws, make_random, watchdog_remove_after):
107115
"""
108116
Returns a function to create Databricks Serving Endpoints and clean them up after the test.
109117
The function returns a `databricks.sdk.service.serving.ServingEndpointDetailed` object.
110118
111119
Under the covers, this fixture serves an existing model (a system model by default) on a small workload size.
112120
121+
Keyword arguments:
122+
* `endpoint_name` (str, optional): The name of the endpoint. Defaults to `dummy-*`.
123+
* `model_name` (str, optional): The name of the model to serve on the endpoint.
124+
Defaults to system model `system.ai.llama_v3_2_1b_instruct`.
125+
* `model_version` (str, optional): The model version to serve. If None, tries to get the latest version for
126+
workspace-local models. If that lookup fails, or the model is not workspace-local, defaults to version `1`.
127+
113128
Usage:
114129
```python
115130
def test_endpoints(make_group, make_serving_endpoint, make_serving_endpoint_permissions):
@@ -123,27 +138,51 @@ def test_endpoints(make_group, make_serving_endpoint, make_serving_endpoint_perm
123138
```
124139
"""
125140

126-
def create() -> Wait[ServingEndpointDetailed]:
127-
endpoint_name = make_random(4)
128-
model = make_model()
141+
def create(
142+
*,
143+
endpoint_name: str | None = None,
144+
model_name: str | None = None,
145+
model_version: str | None = None,
146+
) -> Wait[ServingEndpointDetailed]:
147+
endpoint_name = endpoint_name or make_random(4)
148+
model_name = model_name or "system.ai.llama_v3_2_1b_instruct"
149+
if not model_version and "." not in model_name: # The period in the name signals it is NOT workspace local
150+
try:
151+
model_version = ws.model_registry.get_latest_versions(model_name).version
152+
except BadRequest as e:
153+
logger.warning(
154+
f"Cannot get latest version for model: {model_name}. Fallback to version '1'.", exc_info=e
155+
)
156+
model_version = model_version or "1"
157+
tags = [EndpointTag(key="RemoveAfter", value=watchdog_remove_after)]
158+
served_model_input = ServedModelInput(
159+
model_name=model_name,
160+
model_version=model_version,
161+
scale_to_zero_enabled=True,
162+
workload_size=ServedModelInputWorkloadSize.SMALL,
163+
)
129164
endpoint = ws.serving_endpoints.create(
130165
endpoint_name,
131-
config=EndpointCoreConfigInput(
132-
served_models=[
133-
ServedModelInput(
134-
model_name=model.name,
135-
model_version="1",
136-
scale_to_zero_enabled=True,
137-
workload_size=ServedModelInputWorkloadSize.SMALL,
138-
)
139-
]
140-
),
141-
tags=[EndpointTag(key="RemoveAfter", value=watchdog_remove_after)],
166+
EndpointCoreConfigInput(served_models=[served_model_input]),
167+
tags=tags,
142168
)
169+
if isinstance(endpoint, Mock): # For testing
170+
served_model_output = ServedModelOutput(
171+
model_name=model_name,
172+
model_version=model_version,
173+
scale_to_zero_enabled=True,
174+
workload_size=ServedModelInputWorkloadSize.SMALL.value,
175+
)
176+
endpoint = ServingEndpointDetailed(
177+
name=endpoint_name,
178+
pending_config=EndpointPendingConfig(served_models=[served_model_output]),
179+
tags=tags,
180+
)
143181
return endpoint
144182

145-
def remove(endpoint_name: str):
146-
ws.serving_endpoints.delete(endpoint_name)
183+
def remove(endpoint: ServingEndpointDetailed) -> None:
184+
if endpoint.name:
185+
ws.serving_endpoints.delete(endpoint.name)
147186

148187
yield from factory("Serving endpoint", create, remove)
149188

tests/unit/fixtures/test_ml.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
import pytest
2+
3+
from databricks.sdk.errors import InvalidParameterValue
4+
from databricks.sdk.service.ml import ModelVersion
5+
16
from databricks.labs.pytester.fixtures.ml import make_experiment, make_model, make_serving_endpoint
2-
from databricks.labs.pytester.fixtures.unwrap import call_stateful
7+
from databricks.labs.pytester.fixtures.unwrap import CallContext, call_stateful
38

49

510
def test_make_experiment_no_args():
@@ -14,7 +19,66 @@ def test_make_model_no_args():
1419
assert model is not None
1520

1621

17-
def test_make_serving_endpoint_no_args():
22+
def test_make_serving_endpoint_no_args() -> None:
1823
ctx, serving_endpoint = call_stateful(make_serving_endpoint)
1924
assert ctx is not None
2025
assert serving_endpoint is not None
26+
27+
28+
def test_make_serving_endpoint_sets_default_endpoint_name() -> None:
29+
"""Default endpoint name should be random."""
30+
_, serving_endpoint = call_stateful(make_serving_endpoint)
31+
assert serving_endpoint.name == "RANDOM" # Mocked value in unit tests
32+
33+
34+
def test_make_serving_endpoint_sets_endpoint_name() -> None:
35+
_, serving_endpoint = call_stateful(make_serving_endpoint, endpoint_name="test")
36+
assert serving_endpoint.name == "test"
37+
38+
39+
def test_make_serving_endpoint_sets_default_model_name() -> None:
40+
"""The default model name should be 'system.ai.llama_v3_2_1b_instruct'."""
41+
_, serving_endpoint = call_stateful(make_serving_endpoint)
42+
assert serving_endpoint.pending_config.served_models[0].model_name == "system.ai.llama_v3_2_1b_instruct"
43+
44+
45+
def test_make_serving_endpoint_sets_model_name() -> None:
46+
_, serving_endpoint = call_stateful(make_serving_endpoint, model_name="test")
47+
assert serving_endpoint.pending_config.served_models[0].model_name == "test"
48+
49+
50+
@pytest.mark.parametrize("model_name", [None, "test"])
51+
def test_make_serving_endpoint_sets_default_model_version_to_one(model_name: str | None) -> None:
52+
"""The default model version should be '1'.
53+
54+
Independent of the model name, if the latest version cannot be retrieved.
55+
"""
56+
57+
def _setup_model_registry_api(call_context: CallContext) -> CallContext:
58+
"""Set up the model registry api for unit testing."""
59+
call_context["ws"].model_registry.get_latest_versions.side_effect = InvalidParameterValue("test")
60+
return call_context
61+
62+
_, serving_endpoint = call_stateful(
63+
make_serving_endpoint, model_name=model_name, call_context_setup=_setup_model_registry_api
64+
)
65+
assert serving_endpoint.pending_config.served_models[0].model_version == "1"
66+
67+
68+
@pytest.mark.parametrize("model_name", [None, "test"])
69+
def test_make_serving_endpoint_sets_model_version(model_name: str | None) -> None:
70+
"""The model version should be the value passed into the fixture.
71+
72+
Independent of the model name, even if the latest version can be retrieved.
73+
"""
74+
75+
def _setup_model_registry_api(call_context: CallContext) -> CallContext:
76+
"""Set up the model registry api for unit testing."""
77+
model_version = ModelVersion(version="3") # Latest version is higher than the expected version
78+
call_context["ws"].model_registry.get_latest_versions.return_value = model_version
79+
return call_context
80+
81+
_, serving_endpoint = call_stateful(
82+
make_serving_endpoint, model_name=model_name, model_version="2", call_context_setup=_setup_model_registry_api
83+
)
84+
assert serving_endpoint.pending_config.served_models[0].model_version == "2"

0 commit comments

Comments
 (0)