Skip to content

Commit 368a7e9

Browse files
committed
Implemented multimodal chat endpoint with image, textand audio as input and audio, text as output
2 parents 05eb65b + d622df3 commit 368a7e9

10 files changed

Lines changed: 286 additions & 65 deletions

app/api/v1/multimodal.py

Lines changed: 107 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import uuid
2+
from typing import Optional
23

34
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
45
from sqlalchemy.ext.asyncio import AsyncSession
@@ -8,9 +9,11 @@
89
from app.crud.message import create_message
910
from app.crud.session import create_chat_session, get_chat_session
1011
from app.db.session import get_async_session
12+
from app.models import VoiceStyle
1113
from app.models.attachment import MediaType
1214
from app.models.message import RoleEnum
1315
from app.services import (
16+
AudioOutput,
1417
UploadToS3,
1518
extract_text_from_s3_image,
1619
generate_response,
@@ -20,18 +23,33 @@
2023
router = APIRouter(prefix="/multimodal", tags=["Multimodal"])
2124

2225

23-
@router.post("/media")
24-
async def upload_media(
25-
file: UploadFile = File(...),
26-
session_id: uuid.UUID | None = Form(None),
27-
prompt: str | None = Form(None),
26+
@router.post("/chat")
27+
async def multimodal_chat(
28+
file: Optional[UploadFile] = File(None, description="Optional image/audio file."),
29+
session_id: Optional[uuid.UUID] = Form(None),
30+
prompt: Optional[str] = Form(None, description="User text input or question."),
31+
audio_output: bool = Form(False, description="Return response as audio if True."),
32+
voice_style: VoiceStyle = Form(
33+
VoiceStyle.alloy,
34+
description="""
35+
Choose the output voice style:\n
36+
- alloy: Versatile and neutral-sounding voice.\n
37+
- echo: Warm and resonant voice.\n
38+
- fable: Clear and articulate voice.\n
39+
- onyx: Deep and commanding voice.\n
40+
- nova: Bright and energetic voice.\n
41+
- shimmer: Smooth and calming voice.\n
42+
""",
43+
),
2844
db: AsyncSession = Depends(get_async_session),
2945
current_user=Depends(get_current_user),
3046
):
3147
"""
32-
Accepts image and audio
48+
Handles multimodal chat:
49+
- Accepts optional text (`prompt`) or media file (image/audio)
50+
- Supports text or audio output response
51+
- Returns assistant response and optional audio file URL
3352
"""
34-
# Accept images and audio formats
3553
SUPPORTED_TYPES = {
3654
"image": ["image/jpeg", "image/png", "image/webp"],
3755
"audio": [
@@ -43,71 +61,107 @@ async def upload_media(
4361
"audio/ogg",
4462
],
4563
}
64+
4665
all_types = SUPPORTED_TYPES["image"] + SUPPORTED_TYPES["audio"]
4766

48-
if file.content_type not in all_types:
49-
raise HTTPException(status_code=400, detail="Unsupported file type")
67+
# Check if both file and prompt are empty
68+
if not file and not prompt:
69+
raise HTTPException(status_code=400, detail="Either file or prompt is required")
5070

51-
# If no session provided, create one
71+
# Session handling
5272
if not session_id:
5373
session = await create_chat_session(db, current_user.id, title="Media Session")
5474
else:
5575
session = await get_chat_session(db, session_id)
5676
if not session or session.user_id != current_user.id:
5777
raise HTTPException(status_code=403, detail="Invalid session")
5878

59-
# Read file
60-
file_bytes = await file.read()
61-
62-
# Upload to S3
63-
s3_obj = UploadToS3()
64-
file_url = s3_obj.upload_file_to_s3(file_bytes, file.filename, file.content_type)
65-
66-
# Determine type and handle processing
67-
if file.content_type in SUPPORTED_TYPES["image"]:
68-
media_type = MediaType.image
69-
text_content = extract_text_from_s3_image(file_url) # OCR for images
70-
content_summary = f"User uploaded an image: The content of image is: \n{text_content}. Prompt:\n {prompt}"
71-
72-
elif file.content_type in SUPPORTED_TYPES["audio"]:
73-
media_type = MediaType.audio
74-
# Transcribe the audio file using AWS Transcribe
75-
job_name = "audio_transcribe"
76-
text_content = transcribe_file(job_name=job_name, s3_uri=file_url)
77-
78-
content_summary = f"User uploaded an audio file: The transcribe of audio is: \n{text_content}. In whatever language the transcribe is convert it into english and then reply only in English, Unless user explicitly asks for the specific language. \nPrompt:\n {prompt}"
79+
# Step 1: Determine content
80+
content_summary = ""
81+
file_url = None
82+
media_type = None
83+
84+
if file:
85+
if file.content_type not in all_types:
86+
raise HTTPException(status_code=400, detail="Unsupported file type")
87+
88+
file_bytes = await file.read()
89+
s3_obj = UploadToS3()
90+
file_url = s3_obj.upload_file_to_s3(
91+
file_bytes, file.filename, file.content_type
92+
)
93+
94+
# Image Processing
95+
if file.content_type in SUPPORTED_TYPES["image"]:
96+
media_type = MediaType.image
97+
ocr_text = extract_text_from_s3_image(file_url)
98+
content_summary = (
99+
f"User uploaded an image. Extracted text: {ocr_text}. Prompt: {prompt}"
100+
)
101+
102+
# Audio Processing
103+
elif file.content_type in SUPPORTED_TYPES["audio"]:
104+
media_type = MediaType.audio
105+
text_content = transcribe_file(job_name="audio_transcribe", s3_uri=file_url)
106+
content_summary = (
107+
f"User uploaded an audio file. Transcription: {text_content}. "
108+
f"Convert it to English unless explicitly asked otherwise. Prompt: {prompt}"
109+
)
110+
111+
# Save attachment
112+
user_msg = await create_message(
113+
db, session.id, RoleEnum.user, prompt or f"Uploaded a {media_type.value}"
114+
)
115+
await create_attachment(
116+
db,
117+
session.id,
118+
user_msg.id,
119+
file_url,
120+
media_type,
121+
{"filename": file.filename},
122+
)
79123
else:
80-
raise HTTPException(status_code=400, detail="Unsupported media type")
81-
82-
# Save user message with attachment
83-
user_msg = await create_message(
84-
db, session.id, RoleEnum.user, prompt or f"Uploaded a {media_type.value}"
85-
)
86-
await create_attachment(
87-
db,
88-
session.id,
89-
user_msg.id,
90-
file_url,
91-
media_type,
92-
{"filename": file.filename},
93-
)
94-
95-
history = [
96-
{
97-
"role": "user",
98-
"content": content_summary,
99-
}
100-
]
124+
# Text-only chat
125+
content_summary = f"User says: {prompt}"
101126

127+
# Step 2: Generate Assistant Response
128+
history = [{"role": "user", "content": content_summary}]
102129
assistant_content = await generate_response(history)
103130

104131
assistant_msg = await create_message(
105132
db, session.id, RoleEnum.assistant, assistant_content
106133
)
107134

108-
return {
135+
response_payload = {
109136
"assistant_message": assistant_msg.content,
110-
"file_url": file_url,
111137
"session_id": str(session.id),
112138
"message_id": str(assistant_msg.id),
113139
}
140+
141+
# Step 3: Audio Output
142+
if audio_output:
143+
# Convert text to audio and upload to S3
144+
audio_output_service = AudioOutput()
145+
audio_s3_url = await audio_output_service.convert_text_into_audio(
146+
assistant_content=assistant_content,
147+
voice_style=voice_style.value,
148+
)
149+
150+
response_payload["audio_output_url"] = audio_s3_url
151+
152+
# Save assistant audio as an attachment in DB
153+
await create_attachment(
154+
db=db,
155+
session_id=session.id,
156+
message_id=assistant_msg.id,
157+
url=file_url,
158+
media_type=media_type,
159+
metadata_={"voice_style": voice_style.value},
160+
audio_url=audio_s3_url,
161+
)
162+
163+
# Add media file link if uploaded
164+
if file_url:
165+
response_payload["uploaded_file_url"] = file_url
166+
167+
return response_payload

app/crud/attachments.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,25 @@
1+
from app.models.attachment import Attachment
12
from sqlalchemy.ext.asyncio import AsyncSession
2-
from app.models.attachment import Attachment, MediaType
33

4-
async def create_attachment(db: AsyncSession, session_id, message_id, url: str, media_type: MediaType, metadata: dict = None) -> Attachment:
5-
attach = Attachment(session_id=session_id, message_id=message_id, url=url, media_type=media_type, metadata=metadata)
6-
db.add(attach)
4+
5+
async def create_attachment(
6+
db: AsyncSession,
7+
session_id,
8+
message_id,
9+
url,
10+
media_type,
11+
metadata_=None,
12+
audio_url=None,
13+
):
14+
attachment = Attachment(
15+
session_id=session_id,
16+
message_id=message_id,
17+
url=url,
18+
media_type=media_type,
19+
metadata_=metadata_,
20+
audio_url=audio_url,
21+
)
22+
db.add(attachment)
723
await db.commit()
8-
await db.refresh(attach)
9-
return attach
24+
await db.refresh(attachment)
25+
return attachment

app/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .chat_session import ChatSession
22
from .message import Message, RoleEnum
33
from .user import User
4-
from .attachment import Attachment
4+
from .attachment import Attachment
5+
from .voice_styles import VoiceStyle

app/models/attachment.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,19 @@ class Attachment(Base):
2323
nullable=True,
2424
)
2525
message_id = Column(
26-
UUID(as_uuid=True), ForeignKey("messages.id", ondelete="CASCADE"), nullable=True
26+
UUID(as_uuid=True),
27+
ForeignKey("messages.id", ondelete="CASCADE"),
28+
nullable=True,
2729
)
2830

29-
url = Column(String, nullable=False)
30-
media_type = Column(Enum(MediaType), nullable=False)
31+
url = Column(String, nullable=True)
32+
media_type = Column(Enum(MediaType), nullable=True)
3133
metadata_ = Column(JSONB, nullable=True)
3234

35+
# To store generated assistant audio responses
36+
audio_url = Column(String, nullable=True)
37+
3338
created_at = Column(DateTime, default=datetime.datetime.utcnow)
3439

3540
session = relationship("ChatSession", back_populates="attachments")
3641
message = relationship("Message", back_populates="attachments")
37-

app/models/voice_styles.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from enum import Enum
2+
3+
4+
class VoiceStyle(str, Enum):
5+
alloy = "alloy"
6+
verse = "verse"
7+
fable = "fable"
8+
onyx = "onyx"
9+
nova = "nova"
10+
shimmer = "shimmer"

app/services/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .llm_client import generate_response
22
from .textract import extract_text_from_s3_image
33
from .s3_storage import UploadToS3
4-
from .transcribe import transcribe_file
4+
from .transcribe import transcribe_file
5+
from .audio_output import AudioOutput

app/services/audio_output.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from openai import AsyncOpenAI
2+
import tempfile
3+
from app.core.config import settings
4+
import uuid
5+
from app.services import UploadToS3
6+
7+
8+
class AudioOutput:
9+
def __init__(self):
10+
self.client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
11+
self.s3_obj = UploadToS3()
12+
13+
async def convert_text_into_audio(self, voice_style: str, assistant_content: str):
14+
"""
15+
Converts text into audio using OpenAI's TTS model and uploads it to S3.
16+
"""
17+
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_audio:
18+
audio_resp = await self.client.audio.speech.create(
19+
model="gpt-4o-mini-tts",
20+
voice=voice_style,
21+
input=assistant_content,
22+
)
23+
audio_resp.stream_to_file(temp_audio.name)
24+
25+
# Upload generated audio to S3
26+
with open(temp_audio.name, "rb") as audio_file:
27+
audio_url = self.s3_obj.upload_file_to_s3(
28+
audio_file.read(), f"{uuid.uuid4()}.mp3", "audio/mpeg"
29+
)
30+
31+
return audio_url
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""Make media_type nullable
2+
3+
Revision ID: 0adeb6494274
4+
Revises: f8e5c60d6d7c
5+
Create Date: 2025-10-05 17:50:45.936151
6+
7+
"""
8+
from typing import Sequence, Union
9+
10+
from alembic import op
11+
import sqlalchemy as sa
12+
from sqlalchemy.dialects import postgresql
13+
14+
# revision identifiers, used by Alembic.
15+
revision: str = '0adeb6494274'
16+
down_revision: Union[str, Sequence[str], None] = 'f8e5c60d6d7c'
17+
branch_labels: Union[str, Sequence[str], None] = None
18+
depends_on: Union[str, Sequence[str], None] = None
19+
20+
21+
def upgrade() -> None:
22+
"""Upgrade schema."""
23+
# ### commands auto generated by Alembic - please adjust! ###
24+
op.alter_column('attachments', 'media_type',
25+
existing_type=postgresql.ENUM('image', 'audio', name='mediatype'),
26+
nullable=True)
27+
# ### end Alembic commands ###
28+
29+
30+
def downgrade() -> None:
31+
"""Downgrade schema."""
32+
# ### commands auto generated by Alembic - please adjust! ###
33+
op.alter_column('attachments', 'media_type',
34+
existing_type=postgresql.ENUM('image', 'audio', name='mediatype'),
35+
nullable=False)
36+
# ### end Alembic commands ###

0 commit comments

Comments
 (0)