Skip to content

Commit 37fa79e

Browse files
Search Logged Models Support Datasets Filter for File Store and Sqlalchemy Store (mlflow#16262)
Signed-off-by: Raymond Zhou <[email protected]>
1 parent f8a885b commit 37fa79e

File tree

8 files changed

+252
-44
lines changed

8 files changed

+252
-44
lines changed

mlflow/server/handlers.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2827,6 +2827,7 @@ def _search_logged_models():
28272827
_assert_required,
28282828
],
28292829
"filter": [_assert_string],
2830+
"datasets": [_assert_array],
28302831
"max_results": [_assert_intlike],
28312832
"order_by": [_assert_array],
28322833
"page_token": [_assert_string],
@@ -2837,6 +2838,17 @@ def _search_logged_models():
28372838
# to avoid serialization issues
28382839
experiment_ids=list(request_message.experiment_ids),
28392840
filter_string=request_message.filter or None,
2841+
datasets=(
2842+
[
2843+
{
2844+
"dataset_name": d.dataset_name,
2845+
"dataset_digest": d.dataset_digest or None,
2846+
}
2847+
for d in request_message.datasets
2848+
]
2849+
if request_message.datasets
2850+
else None
2851+
),
28402852
max_results=request_message.max_results or None,
28412853
order_by=(
28422854
[

mlflow/store/tracking/file_store.py

Lines changed: 16 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
import uuid
99
from collections import defaultdict
1010
from dataclasses import dataclass
11-
from typing import Any, NamedTuple, Optional
11+
from typing import Any, NamedTuple, Optional, TypedDict
1212

1313
from mlflow.entities import (
1414
Dataset,
@@ -165,6 +165,15 @@ def _read_persisted_run_info_dict(run_info_dict):
165165
return RunInfo.from_dictionary(dict_copy)
166166

167167

168+
class DatasetFilter(TypedDict, total=False):
169+
"""
170+
Dataset filter used for search_logged_models.
171+
"""
172+
173+
dataset_name: str
174+
dataset_digest: str
175+
176+
168177
class FileStore(AbstractStore):
169178
TRASH_FOLDER_NAME = ".trash"
170179
ARTIFACTS_FOLDER_NAME = "artifacts"
@@ -2285,7 +2294,7 @@ def search_logged_models(
22852294
self,
22862295
experiment_ids: list[str],
22872296
filter_string: Optional[str] = None,
2288-
datasets: Optional[list[str]] = None,
2297+
datasets: Optional[list[DatasetFilter]] = None,
22892298
max_results: Optional[int] = None,
22902299
order_by: Optional[list[dict[str, Any]]] = None,
22912300
page_token: Optional[str] = None,
@@ -2299,8 +2308,8 @@ def search_logged_models(
22992308
datasets: List of dictionaries to specify datasets on which to apply metrics filters.
23002309
The following fields are supported:
23012310
2302-
name (str): Required. Name of the dataset.
2303-
digest (str): Optional. Digest of the dataset.
2311+
dataset_name (str): Required. Name of the dataset.
2312+
dataset_digest (str): Optional. Digest of the dataset.
23042313
max_results: Maximum number of logged models desired. Default is 100.
23052314
order_by: List of dictionaries to specify the ordering of the search results.
23062315
The following fields are supported:
@@ -2321,17 +2330,17 @@ def search_logged_models(
23212330
A :py:class:`PagedList <mlflow.store.entities.PagedList>` of
23222331
:py:class:`LoggedModel <mlflow.entities.LoggedModel>` objects.
23232332
"""
2324-
if datasets:
2333+
if datasets and not all(d.get("dataset_name") for d in datasets):
23252334
raise MlflowException(
2326-
"Filtering by datasets is not currently supported by FileStore",
2335+
"`dataset_name` in the `datasets` clause must be specified.",
23272336
INVALID_PARAMETER_VALUE,
23282337
)
23292338
max_results = max_results or SEARCH_LOGGED_MODEL_MAX_RESULTS_DEFAULT
23302339
all_models = []
23312340
for experiment_id in experiment_ids:
23322341
models = self._list_models(experiment_id)
23332342
all_models.extend(models)
2334-
filtered = SearchLoggedModelsUtils.filter_logged_models(all_models, filter_string)
2343+
filtered = SearchLoggedModelsUtils.filter_logged_models(all_models, filter_string, datasets)
23352344
sorted_logged_models = SearchLoggedModelsUtils.sort(filtered, order_by)
23362345
logged_models, next_page_token = SearchLoggedModelsUtils.paginate(
23372346
sorted_logged_models, page_token, max_results

mlflow/store/tracking/rest_store.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1008,8 +1008,8 @@ def search_logged_models(
10081008
datasets: List of dictionaries to specify datasets on which to apply metrics filters.
10091009
The following fields are supported:
10101010
1011-
name (str): Required. Name of the dataset.
1012-
digest (str): Optional. Digest of the dataset.
1011+
dataset_name (str): Required. Name of the dataset.
1012+
dataset_digest (str): Optional. Digest of the dataset.
10131013
max_results: Maximum number of logged models desired.
10141014
order_by: List of dictionaries to specify the ordering of the search results.
10151015
The following fields are supported:

mlflow/store/tracking/sqlalchemy_store.py

Lines changed: 48 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
import uuid
88
from collections import defaultdict
99
from functools import reduce
10-
from typing import Any, Optional
10+
from typing import Any, Optional, TypedDict
1111

1212
import sqlalchemy
1313
import sqlalchemy.orm
@@ -126,6 +126,15 @@
126126
sqlalchemy.orm.configure_mappers()
127127

128128

129+
class DatasetFilter(TypedDict, total=False):
130+
"""
131+
Dataset filter used for search_logged_models.
132+
"""
133+
134+
dataset_name: str
135+
dataset_digest: str
136+
137+
129138
class SqlAlchemyStore(AbstractStore):
130139
"""
131140
SQLAlchemy compliant backend store for tracking meta data for MLflow entities. MLflow
@@ -1968,31 +1977,46 @@ def _apply_order_by_search_logged_models(
19681977

19691978
return models.order_by(*order_by_clauses)
19701979

1971-
def _apply_filter_string_search_logged_models(
1980+
def _apply_filter_string_datasets_search_logged_models(
19721981
self,
19731982
models: sqlalchemy.orm.Query,
19741983
session: sqlalchemy.orm.Session,
19751984
experiment_ids: list[str],
19761985
filter_string: Optional[str],
1986+
datasets: Optional[list[dict[str, Any]]],
19771987
):
19781988
from mlflow.utils.search_logged_model_utils import EntityType, parse_filter_string
19791989

19801990
comparisons = parse_filter_string(filter_string)
19811991
dialect = self._get_dialect()
19821992
attr_filters: list[sqlalchemy.BinaryExpression] = []
19831993
non_attr_filters: list[sqlalchemy.BinaryExpression] = []
1994+
1995+
dataset_filters = []
1996+
if datasets:
1997+
for dataset in datasets:
1998+
dataset_filter = SqlLoggedModelMetric.dataset_name == dataset["dataset_name"]
1999+
if "dataset_digest" in dataset:
2000+
dataset_filter = dataset_filter & (
2001+
SqlLoggedModelMetric.dataset_digest == dataset["dataset_digest"]
2002+
)
2003+
dataset_filters.append(dataset_filter)
2004+
2005+
has_metric_filters = False
19842006
for comp in comparisons:
19852007
comp_func = SearchUtils.get_sql_comparison_func(comp.op, dialect)
19862008
if comp.entity.type == EntityType.ATTRIBUTE:
19872009
attr_filters.append(comp_func(getattr(SqlLoggedModel, comp.entity.key), comp.value))
19882010
elif comp.entity.type == EntityType.METRIC:
2011+
has_metric_filters = True
2012+
metric_filters = [
2013+
SqlLoggedModelMetric.metric_name == comp.entity.key,
2014+
comp_func(SqlLoggedModelMetric.metric_value, comp.value),
2015+
]
2016+
if dataset_filters:
2017+
metric_filters.append(sqlalchemy.or_(*dataset_filters))
19892018
non_attr_filters.append(
1990-
session.query(SqlLoggedModelMetric)
1991-
.filter(
1992-
SqlLoggedModelMetric.metric_name == comp.entity.key,
1993-
comp_func(SqlLoggedModelMetric.metric_value, comp.value),
1994-
)
1995-
.subquery()
2019+
session.query(SqlLoggedModelMetric).filter(*metric_filters).subquery()
19962020
)
19972021
elif comp.entity.type == EntityType.PARAM:
19982022
non_attr_filters.append(
@@ -2016,6 +2040,17 @@ def _apply_filter_string_search_logged_models(
20162040
for f in non_attr_filters:
20172041
models = models.join(f)
20182042

2043+
# If there are dataset filters but no metric filters,
2044+
# filter for models that have any metrics on the datasets
2045+
if dataset_filters and not has_metric_filters:
2046+
subquery = (
2047+
session.query(SqlLoggedModelMetric.model_id)
2048+
.filter(sqlalchemy.or_(*dataset_filters))
2049+
.distinct()
2050+
.subquery()
2051+
)
2052+
models = models.join(subquery)
2053+
20192054
return models.filter(
20202055
SqlLoggedModel.lifecycle_stage != LifecycleStage.DELETED,
20212056
SqlLoggedModel.experiment_id.in_(experiment_ids),
@@ -2026,14 +2061,14 @@ def search_logged_models(
20262061
self,
20272062
experiment_ids: list[str],
20282063
filter_string: Optional[str] = None,
2029-
datasets: Optional[list[DatasetInput]] = None,
2064+
datasets: Optional[list[DatasetFilter]] = None,
20302065
max_results: Optional[int] = None,
20312066
order_by: Optional[list[dict[str, Any]]] = None,
20322067
page_token: Optional[str] = None,
20332068
) -> PagedList[LoggedModel]:
2034-
if datasets:
2069+
if datasets and not all(d.get("dataset_name") for d in datasets):
20352070
raise MlflowException(
2036-
"Filtering by datasets is not currently supported by SqlAlchemyStore",
2071+
"`dataset_name` in the `datasets` clause must be specified.",
20372072
INVALID_PARAMETER_VALUE,
20382073
)
20392074
if page_token:
@@ -2046,8 +2081,8 @@ def search_logged_models(
20462081
max_results = max_results or SEARCH_LOGGED_MODEL_MAX_RESULTS_DEFAULT
20472082
with self.ManagedSessionMaker() as session:
20482083
models = session.query(SqlLoggedModel)
2049-
models = self._apply_filter_string_search_logged_models(
2050-
models, session, experiment_ids, filter_string
2084+
models = self._apply_filter_string_datasets_search_logged_models(
2085+
models, session, experiment_ids, filter_string, datasets
20512086
)
20522087
models = self._apply_order_by_search_logged_models(models, session, order_by)
20532088
models = models.offset(offset).limit(max_results + 1).all()

mlflow/utils/search_utils.py

Lines changed: 43 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020
)
2121
from sqlparse.tokens import Token as TokenType
2222

23-
from mlflow.entities import LoggedModel, RunInfo
23+
from mlflow.entities import LoggedModel, Metric, RunInfo
2424
from mlflow.entities.model_registry.model_version_stages import STAGE_DELETED_INTERNAL
2525
from mlflow.entities.model_registry.prompt_version import IS_PROMPT_TAG_KEY
2626
from mlflow.exceptions import MlflowException
@@ -593,6 +593,13 @@ def is_dataset(cls, key_type, comparator):
593593
return True
594594
return False
595595

596+
@classmethod
597+
def _is_metric_on_dataset(cls, metric: Metric, dataset: dict[str, Any]) -> bool:
598+
return metric.dataset_name == dataset.get("dataset_name") and (
599+
dataset.get("dataset_digest") is None
600+
or dataset.get("dataset_digest") == metric.dataset_digest
601+
)
602+
596603
@classmethod
597604
def _does_run_match_clause(cls, run, sed):
598605
key_type = sed.get("type")
@@ -1859,7 +1866,12 @@ class SearchLoggedModelsUtils(SearchUtils):
18591866
VALID_ORDER_BY_ATTRIBUTE_KEYS = VALID_SEARCH_ATTRIBUTE_KEYS
18601867

18611868
@classmethod
1862-
def _does_logged_model_match_clause(cls, model: LoggedModel, condition: dict[str, Any]):
1869+
def _does_logged_model_match_clause(
1870+
cls,
1871+
model: LoggedModel,
1872+
condition: dict[str, Any],
1873+
datasets: Optional[list[dict[str, Any]]] = None,
1874+
):
18631875
key_type = condition.get("type")
18641876
key = condition.get("key")
18651877
value = condition.get("value")
@@ -1869,6 +1881,12 @@ def _does_logged_model_match_clause(cls, model: LoggedModel, condition: dict[str
18691881

18701882
if cls.is_metric(key_type, comparator):
18711883
matching_metrics = [metric for metric in model.metrics if metric.key == key]
1884+
if datasets:
1885+
matching_metrics = [
1886+
metric
1887+
for metric in matching_metrics
1888+
if any(cls._is_metric_on_dataset(metric, dataset) for dataset in datasets)
1889+
]
18721890
lhs = matching_metrics[0].value if matching_metrics else None
18731891
value = float(value)
18741892
elif cls.is_param(key_type, comparator):
@@ -1896,15 +1914,34 @@ def validate_list_supported(cls, key: str) -> None:
18961914
"""
18971915

18981916
@classmethod
1899-
def filter_logged_models(cls, models: list[LoggedModel], filter_string: Optional[str] = None):
1900-
"""Filters a set of runs based on a search filter string."""
1901-
if not filter_string:
1917+
def filter_logged_models(
1918+
cls,
1919+
models: list[LoggedModel],
1920+
filter_string: Optional[str] = None,
1921+
datasets: Optional[list[dict[str, Any]]] = None,
1922+
):
1923+
"""Filters a set of runs based on a search filter string and list of dataset filters."""
1924+
if not filter_string and not datasets:
19021925
return models
19031926

19041927
parsed = cls.parse_search_filter(filter_string)
19051928

1929+
# If there are dataset filters but no metric filters in the filter string,
1930+
# filter for models that have any metrics on the datasets
1931+
if datasets and not any(
1932+
cls.is_metric(s.get("type"), s.get("comparator").upper()) for s in parsed
1933+
):
1934+
1935+
def model_has_metrics_on_datasets(model):
1936+
return any(
1937+
any(cls._is_metric_on_dataset(metric, dataset) for dataset in datasets)
1938+
for metric in model.metrics
1939+
)
1940+
1941+
models = [model for model in models if model_has_metrics_on_datasets(model)]
1942+
19061943
def model_matches(model):
1907-
return all(cls._does_logged_model_match_clause(model, s) for s in parsed)
1944+
return all(cls._does_logged_model_match_clause(model, s, datasets) for s in parsed)
19081945

19091946
return [model for model in models if model_matches(model)]
19101947

0 commit comments

Comments (0)