Skip to content

Commit 8e5f841

Browse files
committed
feat: add save tags endpoint for image URLs
1 parent 89dc84b commit 8e5f841

File tree

3 files changed

+57
-3
lines changed

3 files changed

+57
-3
lines changed

tagger/api/schema/tags.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,9 @@ class Tags(BaseModel):
2525

2626
class TagsResponse(BaseModel):
2727
tags: List[Tags]
28+
29+
30+
class SaveTagsRequest(BaseModel):
31+
category: str
32+
image: Image
33+
tags: List[Tags]

tagger/api/v1/tags.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,39 @@
11
from fastapi import APIRouter
22

3-
from tagger.api.schema.tags import TagsRequest, TagsResponse
4-
from tagger.core.tags import generate_tags
3+
from tagger.api.schema.tags import SaveTagsRequest, TagsRequest, TagsResponse
4+
from tagger.config.models import VISION_EMBEDDING_MODEL
5+
from tagger.core.tags import (
6+
download_image_url,
7+
generate_tags,
8+
save_tag_embedding,
9+
resize_image,
10+
)
511

612
router = APIRouter(prefix="/tags")
713

814

915
@router.post("/", response_model=TagsResponse)
1016
async def create_tags(tag: TagsRequest):
1117
return generate_tags(tag)
18+
19+
20+
@router.post("/save", response_model=TagsResponse)
21+
async def save_tags(tag: SaveTagsRequest):
22+
"""
23+
Save generated tags for an image to the database.
24+
"""
25+
26+
# Download image from url
27+
base64_image = resize_image(download_image_url(tag.image.url))
28+
29+
# Generate image embedding
30+
image_embedding_value = VISION_EMBEDDING_MODEL.image_embedding([base64_image])[0]
31+
32+
# Save image embedding + tags to database
33+
save_tag_embedding(
34+
category=tag.category,
35+
image_url=tag.image.url,
36+
image_embeddings=image_embedding_value,
37+
coordinates=tag.image.coordinates,
38+
tags=tag.tags,
39+
)

tagger/core/tags.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import requests
1414
import boto3
1515

16-
from tagger.api.schema.tags import Tags, TagsRequest, TagsResponse
16+
from tagger.api.schema.tags import Coordinates, Tags, TagsRequest, TagsResponse
1717
from tagger.config.models import JSON_OUTPUT_MODEL, VISION_EMBEDDING_MODEL, VISION_MODEL
1818
from tagger.config.db import TAGGING_DB_ENGINE
1919
from tagger.config.storage import S3_CLIENT
@@ -218,6 +218,26 @@ def get_similar_images(
218218
]
219219

220220

221+
def save_tag_embedding(
222+
category: str,
223+
image_url: str,
224+
image_embeddings: List[float],
225+
coordinates: Coordinates,
226+
tags: List[Tags],
227+
):
228+
with Session(TAGGING_DB_ENGINE) as session:
229+
tag_embedding = TagEmbedding(
230+
id=None,
231+
category=category,
232+
image_url=image_url,
233+
image_embeddings=image_embeddings,
234+
coordinates=f"POINT({coordinates.lon} {coordinates.lat})",
235+
tags={tag.key: tag.value for tag in tags},
236+
)
237+
session.add(tag_embedding)
238+
session.commit()
239+
240+
221241
def download_image_s3(image_s3_url: str) -> BytesIO:
222242
# Parse S3 URL
223243
parsed_url = urlparse(image_s3_url)

0 commit comments

Comments
 (0)