File tree Expand file tree Collapse file tree 3 files changed +57
-3
lines changed
Expand file tree Collapse file tree 3 files changed +57
-3
lines changed Original file line number Diff line number Diff line change @@ -25,3 +25,9 @@ class Tags(BaseModel):
2525
2626class TagsResponse (BaseModel ):
2727 tags : List [Tags ]
28+
29+
30+ class SaveTagsRequest (BaseModel ):
31+ category : str
32+ image : Image
33+ tags : List [Tags ]
Original file line number Diff line number Diff line change 11from fastapi import APIRouter
22
3- from tagger .api .schema .tags import TagsRequest , TagsResponse
4- from tagger .core .tags import generate_tags
3+ from tagger .api .schema .tags import SaveTagsRequest , TagsRequest , TagsResponse
4+ from tagger .config .models import VISION_EMBEDDING_MODEL
5+ from tagger .core .tags import (
6+ download_image_url ,
7+ generate_tags ,
8+ save_tag_embedding ,
9+ resize_image ,
10+ )
511
612router = APIRouter (prefix = "/tags" )
713
814
915@router .post ("/" , response_model = TagsResponse )
1016async def create_tags (tag : TagsRequest ):
1117 return generate_tags (tag )
18+
19+
20+ @router .post ("/save" , response_model = TagsResponse )
21+ async def save_tags (tag : SaveTagsRequest ):
22+ """
23+ Save generated tags for an image to the database.
24+ """
25+
26+ # Download image from url
27+ base64_image = resize_image (download_image_url (tag .image .url ))
28+
29+ # Generate image embedding
30+ image_embedding_value = VISION_EMBEDDING_MODEL .image_embedding ([base64_image ])[0 ]
31+
32+ # Save image embedding + tags to database
33+ save_tag_embedding (
34+ category = tag .category ,
35+ image_url = tag .image .url ,
36+ image_embeddings = image_embedding_value ,
37+ coordinates = tag .image .coordinates ,
38+ tags = tag .tags ,
39+ )
Original file line number Diff line number Diff line change 1313import requests
1414import boto3
1515
16- from tagger .api .schema .tags import Tags , TagsRequest , TagsResponse
16+ from tagger .api .schema .tags import Coordinates , Tags , TagsRequest , TagsResponse
1717from tagger .config .models import JSON_OUTPUT_MODEL , VISION_EMBEDDING_MODEL , VISION_MODEL
1818from tagger .config .db import TAGGING_DB_ENGINE
1919from tagger .config .storage import S3_CLIENT
@@ -218,6 +218,26 @@ def get_similar_images(
218218 ]
219219
220220
221+ def save_tag_embedding (
222+ category : str ,
223+ image_url : str ,
224+ image_embeddings : List [float ],
225+ coordinates : Coordinates ,
226+ tags : List [Tags ],
227+ ):
228+ with Session (TAGGING_DB_ENGINE ) as session :
229+ tag_embedding = TagEmbedding (
230+ id = None ,
231+ category = category ,
232+ image_url = image_url ,
233+ image_embeddings = image_embeddings ,
234+ coordinates = f"POINT({ coordinates .lon } { coordinates .lat } )" ,
235+ tags = {tag .key : tag .value for tag in tags },
236+ )
237+ session .add (tag_embedding )
238+ session .commit ()
239+
240+
221241def download_image_s3 (image_s3_url : str ) -> BytesIO :
222242 # Parse S3 URL
223243 parsed_url = urlparse (image_s3_url )
You can’t perform that action at this time.
0 commit comments