Skip to content

Commit 6562d9f

Browse files
committed
precommit
1 parent 5c3e8b3 commit 6562d9f

File tree

5 files changed

+19
-63
lines changed

5 files changed

+19
-63
lines changed

src/fundus_murag/api/routers/lookup.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,8 @@ def list_all_collections():
4242
summary="Get a `FundusCollection` by its name or MURAG ID.",
4343
)
4444
def get_fundus_collection_by_id(
45-
collection_name: str | None = Query(
46-
None, description="Unique internal name for the collection."
47-
),
48-
murag_id: str | None = Query(
49-
None, description="Unique identifier for the collection in the VectorDB."
50-
),
45+
collection_name: str | None = Query(None, description="Unique internal name for the collection."),
46+
murag_id: str | None = Query(None, description="Unique identifier for the collection in the VectorDB."),
5147
):
5248
try:
5349
if collection_name:
@@ -125,17 +121,13 @@ def get_fundus_records_by_id(
125121
None,
126122
description="An identifier for the `FundusRecord`. If a `FundusRecord` has multiple images, the records share the `fundus_id`.",
127123
),
128-
murag_id: str | None = Query(
129-
None, description="Unique identifier for the record in the VectorDB."
130-
),
124+
murag_id: str | None = Query(None, description="Unique identifier for the record in the VectorDB."),
131125
):
132126
try:
133127
if fundus_id is None and murag_id is None:
134128
raise ValueError("Either `fundus_id` or `murag_id` must be provided.")
135129
elif murag_id is not None and fundus_id is not None:
136-
raise ValueError(
137-
"Either `fundus_id` or `murag_id` must be provided, not both."
138-
)
130+
raise ValueError("Either `fundus_id` or `murag_id` must be provided, not both.")
139131
elif murag_id:
140132
record = vdb.get_fundus_record_by_murag_id(murag_id=murag_id)
141133
elif fundus_id:
@@ -159,9 +151,7 @@ def get_fundus_records_by_id(
159151
summary="Returns the `FundusRecordImage`s from the `FundusRecord` with the specified `murag_id`.",
160152
)
161153
def get_fundus_record_image_by_murag_id(
162-
murag_id: str = Query(
163-
..., description="Unique identifier for the `FundusRecord` in the VectorDB."
164-
),
154+
murag_id: str = Query(..., description="Unique identifier for the `FundusRecord` in the VectorDB."),
165155
):
166156
try:
167157
record_img = vdb.get_fundus_record_image_by_murag_id(murag_id=murag_id)

src/fundus_murag/api/routers/random.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@ def get_random_fundus_collection(
3535
)
3636
def get_random_fundus_record(
3737
n: int = Query(1, description="The number of random records to return."),
38-
collection_name: str | None = Query(
39-
None, description="Unique internal name for the collection."
40-
),
38+
collection_name: str | None = Query(None, description="Unique internal name for the collection."),
4139
):
4240
try:
4341
return vdb.get_random_fundus_records(n=n, collection_name=collection_name)

src/fundus_murag/api/routers/search.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@
2929
)
3030
def fundus_record_i2i_similarity_search(query: SimilaritySearchQuery):
3131
try:
32-
query_embedding = mlc.compute_image_embedding(
33-
base64_image=query.query, return_tensor="np"
34-
).tolist() # type: ignore
32+
query_embedding = mlc.compute_image_embedding(base64_image=query.query, return_tensor="np").tolist() # type: ignore
3533
return vdb._fundus_record_image_similarity_search(
3634
query_embedding=query_embedding,
3735
search_in_collections=query.collection_names,
@@ -48,9 +46,7 @@ def fundus_record_i2i_similarity_search(query: SimilaritySearchQuery):
4846
)
4947
def fundus_record_t2i_similarity_search(query: SimilaritySearchQuery):
5048
try:
51-
query_embedding = mlc.compute_text_embedding(
52-
text=query.query, return_tensor="np"
53-
).tolist() # type: ignore
49+
query_embedding = mlc.compute_text_embedding(text=query.query, return_tensor="np").tolist() # type: ignore
5450
return vdb._fundus_record_image_similarity_search(
5551
query_embedding=query_embedding,
5652
search_in_collections=query.collection_names,
@@ -67,9 +63,7 @@ def fundus_record_t2i_similarity_search(query: SimilaritySearchQuery):
6763
)
6864
def fundus_record_i2t_similarity_search(query: SimilaritySearchQuery):
6965
try:
70-
query_embedding = mlc.compute_image_embedding(
71-
base64_image=query.query, return_tensor="np"
72-
).tolist() # type: ignore
66+
query_embedding = mlc.compute_image_embedding(base64_image=query.query, return_tensor="np").tolist() # type: ignore
7367
return vdb._fundus_record_title_similarity_search(
7468
query_embedding=query_embedding,
7569
search_in_collections=query.collection_names,
@@ -86,9 +80,7 @@ def fundus_record_i2t_similarity_search(query: SimilaritySearchQuery):
8680
)
8781
def fundus_record_t2t_similarity_search(query: SimilaritySearchQuery):
8882
try:
89-
query_embedding = mlc.compute_text_embedding(
90-
text=query.query, return_tensor="np"
91-
).tolist() # type: ignore
83+
query_embedding = mlc.compute_text_embedding(text=query.query, return_tensor="np").tolist() # type: ignore
9284
return vdb._fundus_record_title_similarity_search(
9385
query_embedding=query_embedding,
9486
search_in_collections=query.collection_names,
@@ -140,9 +132,7 @@ def fundus_collection_lexical_search(query: CollectionLexicalSearchQuery):
140132
summary="Perform a semantic similarity search on `FundusCollection`s based on their description.",
141133
)
142134
def fundus_collection_description_similarity_search(query: SimilaritySearchQuery):
143-
query_embedding = mlc.compute_text_embedding(
144-
text=query.query, return_tensor="np"
145-
).tolist() # type: ignore
135+
query_embedding = mlc.compute_text_embedding(text=query.query, return_tensor="np").tolist() # type: ignore
146136
try:
147137
return vdb.fundus_collection_description_similarity_search(
148138
query_embedding=query_embedding,
@@ -159,9 +149,7 @@ def fundus_collection_description_similarity_search(query: SimilaritySearchQuery
159149
)
160150
def fundus_collection_title_similarity_search(query: SimilaritySearchQuery):
161151
try:
162-
query_embedding = mlc.compute_text_embedding(
163-
text=query.query, return_tensor="np"
164-
).tolist() # type: ignore
152+
query_embedding = mlc.compute_text_embedding(text=query.query, return_tensor="np").tolist() # type: ignore
165153
return vdb.fundus_collection_title_similarity_search(
166154
query_embedding=query_embedding,
167155
top_k=query.top_k,

src/fundus_murag/ml/client.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@ def _wait_for_ready(self, s: int = 60, sleep_t: int = 3) -> None:
2929
if self._is_ready():
3030
logger.info(f"Fundus ML is ready at {self._fundus_ml_url}!")
3131
return
32-
logger.info(
33-
f"Waiting {sleep_t}s for Fundus ML to be ready at {self._fundus_ml_url}..."
34-
)
32+
logger.info(f"Waiting {sleep_t}s for Fundus ML to be ready at {self._fundus_ml_url}...")
3533
time.sleep(sleep_t)
3634
s -= sleep_t
3735

@@ -85,9 +83,7 @@ def _get_embeddings(
8583
return_tensor: Literal["pt", "np"] | None = "np",
8684
squeeze: bool = True,
8785
) -> "EmbeddingsOutput | np.ndarray | torch.Tensor":
88-
response = requests.post(
89-
f"{self._fundus_ml_url}/embed", json=input.model_dump()
90-
)
86+
response = requests.post(f"{self._fundus_ml_url}/embed", json=input.model_dump())
9187
response.raise_for_status()
9288
response_json = response.json()
9389
emb = EmbeddingsOutput.model_validate(response_json)
@@ -103,11 +99,7 @@ def _get_embeddings(
10399
if squeeze:
104100
emb = emb.squeeze()
105101
else:
106-
if (
107-
squeeze
108-
and len(emb.embeddings) == 1
109-
and isinstance(emb.embeddings[0], list)
110-
):
102+
if squeeze and len(emb.embeddings) == 1 and isinstance(emb.embeddings[0], list):
111103
emb.embeddings = emb.embeddings[0]
112104

113105
return emb

src/fundus_murag/ml/server.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,21 +37,11 @@ def setup(self, device: str):
3737

3838
def decode_request(self, request: EmbeddingsInput) -> dict[str, list | None]:
3939
if request.input_type == "text":
40-
text = (
41-
request.input_data
42-
if isinstance(request.input_data, list)
43-
else [request.input_data]
44-
)
40+
text = request.input_data if isinstance(request.input_data, list) else [request.input_data]
4541
image = None
4642
elif request.input_type == "image":
47-
image_data = (
48-
request.input_data
49-
if isinstance(request.input_data, list)
50-
else [request.input_data]
51-
)
52-
image = [
53-
Image.open(io.BytesIO(base64.b64decode(b64))) for b64 in image_data
54-
]
43+
image_data = request.input_data if isinstance(request.input_data, list) else [request.input_data]
44+
image = [Image.open(io.BytesIO(base64.b64decode(b64))) for b64 in image_data]
5545
text = None
5646
else:
5747
raise ValueError("Invalid request type")
@@ -65,9 +55,7 @@ def _compute_text_embedding(self, text_features: BatchFeature) -> torch.Tensor:
6555

6656
def _compute_image_embedding(self, image_features: BatchFeature) -> torch.Tensor:
6757
with torch.no_grad():
68-
img_emb = self.model.get_image_features(
69-
**image_features.to(self.device, dtype=TORCH_DTYPE)
70-
)
58+
img_emb = self.model.get_image_features(**image_features.to(self.device, dtype=TORCH_DTYPE))
7159
return img_emb
7260

7361
def predict(self, inputs: dict[str, list | None]) -> torch.Tensor:

0 commit comments

Comments
 (0)