Skip to content

Commit 295d4a5

Browse files
authored
feat: add customer config endpoints (#230)
* Add customer config endpoints - added custom summary_prompt as a first prop of the config. It is used in case there is no prompt on the summary request itself * fix lint * address review
1 parent 0946827 commit 295d4a5

File tree

11 files changed

+324
-7
lines changed

11 files changed

+324
-7
lines changed

skynet/env.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@ def tobool(val: str | None):
3333
app_port = int(os.environ.get('SKYNET_PORT', 8000))
3434
listen_ip = os.environ.get('SKYNET_LISTEN_IP', '0.0.0.0')
3535
log_level = os.environ.get('LOG_LEVEL', 'DEBUG').strip().upper()
36-
supported_modules = {'summaries:dispatcher', 'summaries:executor', 'streaming_whisper', 'assistant'}
37-
enabled_modules = set(os.environ.get('ENABLED_MODULES', 'summaries:dispatcher,summaries:executor,assistant').split(','))
36+
supported_modules = {'summaries:dispatcher', 'summaries:executor', 'streaming_whisper', 'assistant', 'customer_configs'}
37+
enabled_modules = set(
38+
os.environ.get('ENABLED_MODULES', 'summaries:dispatcher,summaries:executor,assistant,customer_configs').split(',')
39+
)
3840
modules = supported_modules.intersection(enabled_modules)
3941
file_refresh_interval = int(os.environ.get('FILE_REFRESH_INTERVAL', 30))
4042

skynet/index.html

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,13 @@
44
</head>
55
<body>
66
<h1>Skynet</h1>
7+
<ul>
78
<li>
89
<a href="/assistant/docs">Assistant API</a>
910
</li>
11+
<li>
12+
<a href="/customer-configs/docs">Customer configs API</a>
13+
</li>
1014
<li>
1115
<a href="/summaries/docs">Summaries API</a>
1216
</li>

skynet/main.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,15 @@ async def lifespan(main_app: FastAPI):
4646
main_app.mount('/assistant', rag_app)
4747
await assistant_startup()
4848

49+
if 'customer_configs' in modules:
50+
from skynet.modules.ttt.customer_configs.app import (
51+
app as customer_configs_app,
52+
app_startup as customer_configs_startup,
53+
)
54+
55+
main_app.mount('/customer-configs', customer_configs_app)
56+
await customer_configs_startup()
57+
4958
if 'summaries:dispatcher' in modules:
5059
from skynet.modules.ttt.openai_api.app import app as openai_api_app
5160
from skynet.modules.ttt.summaries.app import app as summaries_app, app_startup as summaries_startup
@@ -74,6 +83,11 @@ async def lifespan(main_app: FastAPI):
7483

7584
await assistant_shutdown()
7685

86+
if 'customer_configs' in modules:
87+
from skynet.modules.ttt.customer_configs.app import app_shutdown as customer_configs_shutdown
88+
89+
await customer_configs_shutdown()
90+
7791
await http_client.close()
7892

7993

skynet/modules/stt/streaming_whisper/connection_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ async def send(self, connection: MeetingConnection, results: list[utils.Transcri
5959
f'Meeting {connection.meeting_id}: the connection was closed before sending all results: {e}'
6060
)
6161
await self.disconnect(connection, True)
62-
break # stop trying to send results if websocket is disconnected
62+
break # stop trying to send results if websocket is disconnected
6363
except Exception as ex:
6464
log.error(f'Meeting {connection.meeting_id}: exception while sending transcription results {ex}')
6565

skynet/modules/stt/vox/connection_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,6 @@ async def send(self, connection: MeetingConnection, results: list[TranscriptionR
3232
f'Session {connection.meeting_id}: the connection was closed before sending all results: {e}'
3333
)
3434
await self.disconnect(connection, True)
35-
break # stop trying to send results if the websocket is disconnected
35+
break # stop trying to send results if the websocket is disconnected
3636
except Exception as ex:
3737
log.error(f'Session {connection.meeting_id}: exception while sending transcription results {ex}')
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from fastapi_versionizer.versionizer import Versionizer
2+
3+
from skynet.logs import get_logger
4+
from skynet.modules.ttt.persistence import db
5+
6+
from skynet.utils import create_app
7+
from .v1.router import router as v1_router
8+
9+
log = get_logger(__name__)
10+
11+
app = create_app()
12+
app.include_router(v1_router)
13+
14+
Versionizer(app=app, prefix_format='/v{major}', sort_routes=True).versionize()
15+
16+
17+
async def app_startup():
18+
"""Startup function for Customer Configs module."""
19+
await db.initialize()
20+
log.info('Persistence initialized')
21+
log.info('customer_configs module initialized')
22+
23+
24+
async def app_shutdown():
25+
"""Shutdown function for Customer Configs module."""
26+
await db.close()
27+
log.info('customer_configs shut down')
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import json
2+
from typing import Optional
3+
4+
from skynet.modules.ttt.persistence import db
5+
6+
7+
def get_customerconfig_key(customer_id: str) -> str:
8+
"""Generate database key for customer configuration."""
9+
return f"customer-config:{customer_id}"
10+
11+
12+
async def get_existing_customer_config(customer_id: str) -> Optional[dict]:
13+
"""Get the customer configuration for a customer if it exists."""
14+
key = get_customerconfig_key(customer_id)
15+
config_json = await db.get(key)
16+
17+
if config_json:
18+
return json.loads(config_json)
19+
20+
return None
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from typing import Optional
2+
3+
from pydantic import BaseModel
4+
5+
6+
class CustomerConfigPayload(BaseModel):
7+
summary_prompt: str
8+
9+
10+
class CustomerConfig(BaseModel):
11+
summary_prompt: Optional[str] = None
12+
13+
14+
class CustomerConfigResponse(BaseModel):
15+
success: bool = True
16+
message: str = "Customer configuration updated successfully"
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from fastapi import Depends, HTTPException
2+
from fastapi_versionizer.versionizer import api_version
3+
4+
from skynet.auth.customer_id import CustomerId
5+
from skynet.env import summary_minimum_payload_length
6+
from skynet.logs import get_logger
7+
from skynet.modules.ttt.customer_configs.utils import get_customerconfig_key, get_existing_customer_config
8+
from skynet.modules.ttt.customer_configs.v1.models import CustomerConfig, CustomerConfigPayload, CustomerConfigResponse
9+
from skynet.modules.ttt.persistence import db
10+
from skynet.utils import get_router
11+
12+
router = get_router()
13+
log = get_logger(__name__)
14+
15+
16+
def validate_customer_config_payload(payload: CustomerConfigPayload) -> None:
17+
if not payload.summary_prompt.strip():
18+
raise HTTPException(status_code=422, detail="summary_prompt cannot be empty")
19+
20+
if len(payload.summary_prompt.strip()) < summary_minimum_payload_length:
21+
raise HTTPException(
22+
status_code=422, detail=f"summary_prompt must be at least {summary_minimum_payload_length} characters"
23+
)
24+
25+
26+
@api_version(1)
27+
@router.get('/config')
28+
async def get_customer_config(customer_id=Depends(CustomerId())) -> CustomerConfig:
29+
"""
30+
Get the current customer config.
31+
"""
32+
config = await get_existing_customer_config(customer_id)
33+
34+
if config:
35+
return CustomerConfig(summary_prompt=config.get('summary_prompt'))
36+
37+
raise HTTPException(status_code=404, detail='Customer configuration not found')
38+
39+
40+
@api_version(1)
41+
@router.post('/config', dependencies=[Depends(validate_customer_config_payload)])
42+
async def set_customer_config(
43+
payload: CustomerConfigPayload, customer_id=Depends(CustomerId())
44+
) -> CustomerConfigResponse:
45+
"""
46+
Set the customer config.
47+
"""
48+
# Store in database
49+
key = get_customerconfig_key(customer_id)
50+
config = {'summary_prompt': payload.summary_prompt}
51+
52+
import json
53+
54+
await db.set(key, json.dumps(config))
55+
56+
log.info(f"Updated customer config for customer {customer_id}")
57+
58+
return CustomerConfigResponse()
59+
60+
61+
@api_version(1)
62+
@router.delete('/config')
63+
async def delete_customer_config(customer_id=Depends(CustomerId())) -> CustomerConfigResponse:
64+
"""
65+
Delete the customer config.
66+
"""
67+
config = await get_existing_customer_config(customer_id)
68+
69+
if not config:
70+
raise HTTPException(status_code=404, detail='Customer configuration not found')
71+
72+
key = get_customerconfig_key(customer_id)
73+
await db.delete(key)
74+
75+
log.info(f"Deleted customer config for customer {customer_id}")
76+
77+
return CustomerConfigResponse()

skynet/modules/ttt/processor.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,22 @@ async def assist(model: BaseChatModel, payload: AssistantDocumentPayload, custom
157157
return await rag_chain.ainvoke(input={'question': question})
158158

159159

160-
async def summarize(model: BaseChatModel, payload: DocumentPayload, job_type: JobType) -> str:
160+
async def summarize(model: BaseChatModel, payload: DocumentPayload, job_type: JobType, customer_id: str) -> str:
161161
chain = None
162162
text = payload.text
163163

164-
system_message = payload.prompt or hint_type_to_prompt[job_type][payload.hint]
164+
# Fallback priority: payload.prompt -> customer's summary_prompt -> hint_type_to_prompt[job_type][payload.hint]
165+
system_message = payload.prompt
166+
167+
if not system_message:
168+
from skynet.modules.ttt.customer_configs.utils import get_existing_customer_config
169+
170+
config = await get_existing_customer_config(customer_id)
171+
if config:
172+
system_message = config.get('summary_prompt')
173+
174+
if not system_message:
175+
system_message = hint_type_to_prompt[job_type][payload.hint]
165176

166177
prompt = ChatPromptTemplate(
167178
[
@@ -240,7 +251,7 @@ async def process(job: Job) -> str:
240251
if job_type == JobType.ASSIST:
241252
result = await assist(llm, payload, customer_id)
242253
elif job_type in [JobType.SUMMARY, JobType.ACTION_ITEMS, JobType.TABLE_OF_CONTENTS]:
243-
result = await summarize(llm, payload, job_type)
254+
result = await summarize(llm, payload, job_type, customer_id)
244255
elif job_type == JobType.PROCESS_TEXT:
245256
result = await process_text(llm, payload)
246257
else:

0 commit comments

Comments
 (0)