Skip to content

Commit e477e73

Browse files
committed
Add similarity API endpoint
I've refactored `compute_similarity_metadata()` to look up the query tag by name instead of by ID, so that it can derive the metadata name from the query tag name when no metadata name is provided. I've also introduced the `TagNotFoundError` exception.
1 parent 25e8c85 commit e477e73

File tree

7 files changed

+236
-28
lines changed

7 files changed

+236
-28
lines changed

lightly_studio/src/lightly_studio/api/routes/api/metadata.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
from typing import List
66
from uuid import UUID
77

8-
from fastapi import APIRouter, Depends, Path
8+
from fastapi import APIRouter, Depends, HTTPException, Path
99
from pydantic import BaseModel, Field
1010
from typing_extensions import Annotated
1111

1212
from lightly_studio.api.routes.api.dataset import get_and_validate_dataset_id
13+
from lightly_studio.api.routes.api.status import HTTP_STATUS_NOT_FOUND
1314
from lightly_studio.db_manager import SessionDep
14-
from lightly_studio.metadata import compute_typicality
15+
from lightly_studio.errors import TagNotFoundError
16+
from lightly_studio.metadata import compute_similarity, compute_typicality
1517
from lightly_studio.models.dataset import DatasetTable
1618
from lightly_studio.models.metadata import MetadataInfoView
1719
from lightly_studio.resolvers import embedding_model_resolver
@@ -89,3 +91,73 @@ def compute_typicality_metadata(
8991
embedding_model_id=embedding_model.embedding_model_id,
9092
metadata_name=request.metadata_name,
9193
)
94+
95+
96+
class ComputeSimilarityRequest(BaseModel):
97+
"""Request model for computing similarity metadata."""
98+
99+
embedding_model_name: str | None = Field(
100+
default=None,
101+
description="Embedding model name (uses default if not specified)",
102+
)
103+
query_tag_name: str = Field(
104+
description="The name of the tag to use for the query",
105+
)
106+
metadata_name: str | None = Field(
107+
default=None,
108+
description="Metadata field name (defaults to None)",
109+
)
110+
111+
112+
@metadata_router.post(
113+
"/metadata/similarity",
114+
status_code=200,
115+
response_model=str,
116+
)
117+
def compute_similarity_metadata(
118+
session: SessionDep,
119+
dataset: Annotated[
120+
DatasetTable,
121+
Depends(get_and_validate_dataset_id),
122+
],
123+
request: ComputeSimilarityRequest,
124+
) -> str:
125+
"""Compute similarity metadata for a dataset.
126+
127+
Args:
128+
session: The database session.
129+
dataset: The dataset to compute similarity for.
130+
request: Request parameters: the query tag name, plus the optional
131+
embedding model name and metadata field name.
132+
133+
Returns:
134+
Metadata name used for the similarity.
135+
136+
Raises:
137+
HTTPException: 404 if invalid embedding model or query tag is given.
138+
"""
139+
try:
140+
embedding_model = embedding_model_resolver.get_by_name(
141+
session=session,
142+
dataset_id=dataset.dataset_id,
143+
embedding_model_name=request.embedding_model_name,
144+
)
145+
except ValueError as e:
146+
raise HTTPException(
147+
status_code=HTTP_STATUS_NOT_FOUND,
148+
detail=f"embedding model {request.embedding_model_name} not found",
149+
) from e
150+
151+
try:
152+
return compute_similarity.compute_similarity_metadata(
153+
session=session,
154+
key_dataset_id=dataset.dataset_id,
155+
query_tag_name=request.query_tag_name,
156+
embedding_model_id=embedding_model.embedding_model_id,
157+
metadata_name=request.metadata_name,
158+
)
159+
except TagNotFoundError as e:
160+
raise HTTPException(
161+
status_code=HTTP_STATUS_NOT_FOUND,
162+
detail=f"Query tag {request.query_tag_name} not found",
163+
) from e

lightly_studio/src/lightly_studio/core/dataset.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
import datetime
65
from pathlib import Path
76
from typing import Iterable, Iterator
87
from uuid import UUID
@@ -609,27 +608,18 @@ def compute_similarity_metadata(
609608
Returns:
610609
The name of the metadata storing the similarity values.
611610
"""
612-
query_tag = tag_resolver.get_by_name(
613-
session=self.session, tag_name=query_tag_name, dataset_id=self.dataset_id
614-
)
615-
if query_tag is None:
616-
raise ValueError("Query tag not found")
617611
embedding_model_id = embedding_model_resolver.get_by_name(
618612
session=self.session,
619613
dataset_id=self.dataset_id,
620614
embedding_model_name=embedding_model_name,
621615
).embedding_model_id
622-
date = datetime.datetime.now(datetime.timezone.utc)
623-
if metadata_name is None:
624-
metadata_name = f"similarity_{query_tag_name}_{date.isoformat()}"
625-
compute_similarity.compute_similarity_metadata(
616+
return compute_similarity.compute_similarity_metadata(
626617
session=self.session,
627618
key_dataset_id=self.dataset_id,
628619
embedding_model_id=embedding_model_id,
629-
query_tag_id=query_tag.tag_id,
620+
query_tag_name=query_tag_name,
630621
metadata_name=metadata_name,
631622
)
632-
return metadata_name
633623

634624

635625
def _generate_embeddings(session: Session, dataset_id: UUID, sample_ids: list[UUID]) -> None:
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Lightly Studio exception types."""
2+
3+
4+
class TagNotFoundError(Exception):
5+
"""Exception signaling that a tag has not been found."""
Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
11
"""Computes similarity from embeddings."""
22

3+
from datetime import datetime, timezone
4+
from typing import Optional
35
from uuid import UUID
46

57
from lightly_mundig import Similarity # type: ignore[import-untyped]
68
from sqlmodel import Session
79

810
from lightly_studio.dataset.env import LIGHTLY_STUDIO_LICENSE_KEY
9-
from lightly_studio.resolvers import metadata_resolver, sample_embedding_resolver
11+
from lightly_studio.errors import TagNotFoundError
12+
from lightly_studio.resolvers import metadata_resolver, sample_embedding_resolver, tag_resolver
1013
from lightly_studio.resolvers.sample_resolver.sample_filter import SampleFilter
1114

1215

1316
def compute_similarity_metadata(
1417
session: Session,
1518
key_dataset_id: UUID,
1619
embedding_model_id: UUID,
17-
query_tag_id: UUID,
18-
metadata_name: str,
19-
) -> None:
20+
query_tag_name: str,
21+
metadata_name: Optional[str] = None,
22+
) -> str:
2023
"""Computes similarity for each sample in the dataset from embeddings.
2124
2225
Similarity is a measure of how similar a sample is to its nearest neighbor
@@ -31,26 +34,38 @@ def compute_similarity_metadata(
3134
The ID of the dataset the similarity is computed on.
3235
embedding_model_id:
3336
The ID of the embedding model to use for the computation.
34-
query_tag_id:
35-
The ID of the tag describing the query.
37+
query_tag_name:
38+
The name of the tag describing the query.
3639
metadata_name:
3740
The name of the metadata field to store the similarity values in.
38-
Defaults to "similarity".
41+
42+
Raises:
43+
TagNotFoundError: If no tag named `query_tag_name` exists in the dataset.
44+
45+
Returns:
46+
The name of the metadata storing the similarity values.
3947
"""
4048
license_key = LIGHTLY_STUDIO_LICENSE_KEY
4149
if license_key is None:
4250
raise ValueError(
4351
"LIGHTLY_STUDIO_LICENSE_KEY environment variable is not set. "
4452
"Please set it to your LightlyStudio license key."
4553
)
54+
query_tag = tag_resolver.get_by_name(
55+
session=session,
56+
tag_name=query_tag_name,
57+
dataset_id=key_dataset_id,
58+
)
59+
if query_tag is None:
60+
raise TagNotFoundError(f"Query tag {query_tag_name} not found")
4661

4762
key_samples = sample_embedding_resolver.get_all_by_dataset_id(
4863
session=session, dataset_id=key_dataset_id, embedding_model_id=embedding_model_id
4964
)
5065
key_embeddings = [sample.embedding for sample in key_samples]
5166
similarity = Similarity(key_embeddings=key_embeddings, token=license_key)
5267

53-
tag_filter = SampleFilter(tag_ids=[query_tag_id])
68+
tag_filter = SampleFilter(tag_ids=[query_tag.tag_id])
5469
query_samples = sample_embedding_resolver.get_all_by_dataset_id(
5570
session=session,
5671
dataset_id=key_dataset_id,
@@ -59,10 +74,14 @@ def compute_similarity_metadata(
5974
)
6075
query_embeddings = [sample.embedding for sample in query_samples]
6176
similarity_values = similarity.calculate_similarity(query_embeddings=query_embeddings)
77+
if metadata_name is None:
78+
date = datetime.now(timezone.utc)
79+
metadata_name = f"similarity_{query_tag_name}_{date.isoformat()}"
6280

6381
metadata = [
6482
(sample.sample_id, {metadata_name: similarity})
6583
for sample, similarity in zip(key_samples, similarity_values)
6684
]
6785

6886
metadata_resolver.bulk_update_metadata(session, metadata)
87+
return metadata_name

lightly_studio/tests/api/routes/api/test_metadata.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
HTTP_STATUS_OK,
1111
)
1212
from lightly_studio.models.metadata import MetadataInfoView
13-
from lightly_studio.resolvers import image_resolver, metadata_resolver
14-
from tests import helpers_resolvers
13+
from lightly_studio.resolvers import image_resolver, metadata_resolver, tag_resolver
14+
from tests.helpers_resolvers import create_tag, fill_db_with_samples_and_embeddings
1515

1616

1717
def test_get_metadata_info(test_client: TestClient, mocker: MockerFixture) -> None:
@@ -66,7 +66,7 @@ def test_get_metadata_info__empty_response(test_client: TestClient, mocker: Mock
6666
def test_compute_typicality_metadata(test_client: TestClient, db_session: Session) -> None:
6767
"""Test compute typicality metadata endpoint."""
6868
# Create dataset with samples and embeddings
69-
dataset_id = helpers_resolvers.fill_db_with_samples_and_embeddings(
69+
dataset_id = fill_db_with_samples_and_embeddings(
7070
test_db=db_session, n_samples=10, embedding_model_names=["test_embedding_model"]
7171
)
7272

@@ -90,3 +90,40 @@ def test_compute_typicality_metadata(test_client: TestClient, db_session: Sessio
9090
)
9191
assert typicality_value is not None
9292
assert isinstance(typicality_value, float)
93+
94+
95+
def test_compute_similarity_metadata(test_client: TestClient, db_session: Session) -> None:
96+
"""Test compute similarity metadata endpoint."""
97+
dataset_id = fill_db_with_samples_and_embeddings(
98+
test_db=db_session, n_samples=10, embedding_model_names=["test_embedding_model"]
99+
)
100+
query_tag = create_tag(session=db_session, dataset_id=dataset_id, tag_name="query_tag")
101+
samples = image_resolver.get_all_by_dataset_id(
102+
session=db_session, dataset_id=dataset_id
103+
).samples
104+
tag_resolver.add_sample_ids_to_tag_id(
105+
session=db_session,
106+
tag_id=query_tag.tag_id,
107+
sample_ids=[samples[0].sample_id, samples[2].sample_id],
108+
)
109+
110+
response = test_client.post(
111+
f"/api/datasets/{dataset_id}/metadata/similarity", json={"query_tag_name": "query_tag"}
112+
)
113+
114+
assert response.status_code == 200
115+
metadata_name = response.text[1:-1] # We strip the double-quotes
116+
assert metadata_name.startswith("similarity_query_tag_20")
117+
118+
samples = image_resolver.get_all_by_dataset_id(
119+
session=db_session, dataset_id=dataset_id
120+
).samples
121+
assert len(samples) == 10
122+
123+
# Verify all samples have similarity metadata.
124+
for sample in samples:
125+
similarity_value = metadata_resolver.get_value_for_sample(
126+
session=db_session, sample_id=sample.sample_id, key=metadata_name
127+
)
128+
assert similarity_value is not None
129+
assert isinstance(similarity_value, float)

lightly_studio/tests/metadata/test_compute_similarity.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,17 @@ def test_compute_similarity_metadata(test_db: Session) -> None:
3737
)
3838

3939
query_tag = create_tag(session=test_db, dataset_id=dataset_id, tag_name="query_tag")
40-
query_tag_id = query_tag.tag_id
4140
tag_resolver.add_sample_ids_to_tag_id(
4241
session=test_db,
43-
tag_id=query_tag_id,
42+
tag_id=query_tag.tag_id,
4443
sample_ids=[samples[0].sample_id, samples[2].sample_id],
4544
)
4645

4746
compute_similarity.compute_similarity_metadata(
4847
session=test_db,
4948
key_dataset_id=dataset_id,
5049
embedding_model_id=embedding_model_id,
51-
query_tag_id=query_tag_id,
50+
query_tag_name="query_tag",
5251
metadata_name="similarity",
5352
)
5453

0 commit comments

Comments
 (0)