Commit 39233d1 — "inference fixes"
Parent commit: 31a2d16

5 files changed: 70 additions & 45 deletions

File tree

inference/server/oasst_inference_server/database.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def custom_json_deserializer(s):
3333
return inference.WorkParameters.parse_obj(d)
3434
case "WorkerConfig":
3535
return inference.WorkerConfig.parse_obj(d)
36-
case "MessageRequest":
37-
return chat_schema.MessageRequest.parse_obj(d)
36+
case "CreateMessageRequest":
37+
return chat_schema.CreateMessageRequest.parse_obj(d)
3838
case "WorkRequest":
3939
return inference.WorkRequest.parse_obj(d)
4040
case "WorkResponsePacket":

inference/server/oasst_inference_server/routes/chats.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import fastapi
22
from fastapi import Depends
33
from loguru import logger
4-
from oasst_inference_server import auth, deps, queueing
4+
from oasst_inference_server import auth, deps, models, queueing
55
from oasst_inference_server.schemas import chat as chat_schema
66
from oasst_inference_server.user_chat_repository import UserChatRepository
77
from oasst_shared.schemas import inference
@@ -48,22 +48,21 @@ async def get_chat(
4848
@router.post("/{chat_id}/messages")
4949
async def create_message(
5050
chat_id: str,
51-
message_request: chat_schema.MessageRequest,
52-
fastapi_request: fastapi.Request,
51+
request: chat_schema.CreateMessageRequest,
5352
user_id: str = Depends(auth.get_current_user_id),
54-
) -> EventSourceResponse:
53+
) -> chat_schema.CreateMessageResponse:
5554
"""Allows the client to stream the results of a request."""
5655

5756
async with deps.manual_user_chat_repository(user_id) as ucr:
5857
try:
5958
prompter_message = await ucr.add_prompter_message(
60-
chat_id=chat_id, parent_id=message_request.parent_id, content=message_request.content
59+
chat_id=chat_id, parent_id=request.parent_id, content=request.content
6160
)
6261
assistant_message = await ucr.initiate_assistant_message(
6362
parent_id=prompter_message.id,
64-
work_parameters=message_request.work_parameters,
63+
work_parameters=request.work_parameters,
6564
)
66-
queue = queueing.work_queue(deps.redis_client, message_request.worker_compat_hash)
65+
queue = queueing.work_queue(deps.redis_client, request.worker_compat_hash)
6766
logger.debug(f"Adding {assistant_message.id=} to {queue.queue_id} for {chat_id}")
6867
await queue.enqueue(assistant_message.id)
6968
logger.debug(f"Added {assistant_message.id=} to {queue.queue_id} for {chat_id}")
@@ -73,15 +72,26 @@ async def create_message(
7372
logger.exception("Error adding prompter message")
7473
return fastapi.Response(status_code=500)
7574

76-
async def event_generator(prompter_message: inference.MessageRead, assistant_message: inference.MessageRead):
77-
queue = queueing.message_queue(deps.redis_client, assistant_message.id)
75+
return chat_schema.CreateMessageResponse(
76+
prompter_message=prompter_message_read,
77+
assistant_message=assistant_message_read,
78+
)
79+
80+
81+
@router.get("/{chat_id}/messages/{message_id}/events")
82+
async def message_events(
83+
chat_id: str,
84+
message_id: str,
85+
fastapi_request: fastapi.Request,
86+
ucr: UserChatRepository = Depends(deps.create_user_chat_repository),
87+
) -> EventSourceResponse:
88+
message: models.DbMessage = await ucr.get_message_by_id(chat_id=chat_id, message_id=message_id)
89+
if message.role != "assistant":
90+
raise fastapi.HTTPException(status_code=400, detail="Only assistant messages can be streamed.")
91+
92+
async def event_generator(chat_id: str, message_id: str):
93+
queue = queueing.message_queue(deps.redis_client, message_id=message_id)
7894
try:
79-
yield {
80-
"data": chat_schema.MessageResponseEvent(
81-
prompter_message=prompter_message,
82-
assistant_message=assistant_message,
83-
).json(),
84-
}
8595
while True:
8696
item = await queue.dequeue()
8797
if item is None:
@@ -113,12 +123,7 @@ async def event_generator(prompter_message: inference.MessageRead, assistant_mes
113123
logger.exception(f"Error streaming {chat_id}")
114124
raise
115125

116-
return EventSourceResponse(
117-
event_generator(
118-
prompter_message=prompter_message_read,
119-
assistant_message=assistant_message_read,
120-
)
121-
)
126+
return EventSourceResponse(event_generator(chat_id=chat_id, message_id=message_id))
122127

123128

124129
@router.post("/{chat_id}/messages/{message_id}/votes")

inference/server/oasst_inference_server/schemas/chat.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,7 @@
22
from oasst_shared.schemas import inference
33

44

5-
class CreateChatRequest(pydantic.BaseModel):
6-
pass
7-
8-
9-
class ChatListRead(pydantic.BaseModel):
10-
id: str
11-
12-
13-
class ChatRead(pydantic.BaseModel):
14-
id: str
15-
messages: list[inference.MessageRead]
16-
17-
18-
class ListChatsResponse(pydantic.BaseModel):
19-
chats: list[ChatListRead]
20-
21-
22-
class MessageRequest(pydantic.BaseModel):
5+
class CreateMessageRequest(pydantic.BaseModel):
236
parent_id: str | None = None
247
content: str = pydantic.Field(..., repr=False)
258
work_parameters: inference.WorkParameters = pydantic.Field(default_factory=inference.WorkParameters)
@@ -29,9 +12,9 @@ def worker_compat_hash(self) -> str:
2912
return inference.compat_hash(model_name=self.work_parameters.model_name)
3013

3114

32-
class MessageResponseEvent(pydantic.BaseModel):
15+
class CreateMessageResponse(pydantic.BaseModel):
3316
prompter_message: inference.MessageRead
34-
assistant_message: inference.MessageRead | None
17+
assistant_message: inference.MessageRead
3518

3619

3720
class TokenResponseEvent(pydantic.BaseModel):
@@ -46,3 +29,20 @@ class VoteRequest(pydantic.BaseModel):
4629
class ReportRequest(pydantic.BaseModel):
4730
report_type: inference.ReportType
4831
reason: str
32+
33+
34+
class CreateChatRequest(pydantic.BaseModel):
35+
pass
36+
37+
38+
class ChatListRead(pydantic.BaseModel):
39+
id: str
40+
41+
42+
class ChatRead(pydantic.BaseModel):
43+
id: str
44+
messages: list[inference.MessageRead]
45+
46+
47+
class ListChatsResponse(pydantic.BaseModel):
48+
chats: list[ChatListRead]

inference/server/oasst_inference_server/user_chat_repository.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,21 @@ async def get_chat_by_id(self, chat_id: str) -> models.DbChat:
3333
chat = (await self.session.exec(query)).one()
3434
return chat
3535

36+
async def get_message_by_id(self, chat_id: str, message_id: str) -> models.DbMessage:
37+
query = (
38+
sqlmodel.select(models.DbMessage)
39+
.where(
40+
models.DbMessage.id == message_id,
41+
models.DbMessage.chat_id == chat_id,
42+
)
43+
.join(models.DbChat)
44+
.where(
45+
models.DbChat.user_id == self.user_id,
46+
)
47+
)
48+
message = (await self.session.exec(query)).one()
49+
return message
50+
3651
async def create_chat(self) -> models.DbChat:
3752
chat = models.DbChat(user_id=self.user_id)
3853
self.session.add(chat)

inference/text-client/__main__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,23 @@ def main(backend_url: str = "http://127.0.0.1:8000"):
4141
"parent_id": parent_id,
4242
"content": message,
4343
},
44+
headers=auth_headers,
45+
)
46+
response.raise_for_status()
47+
message_id = response.json()["assistant_message"]["id"]
48+
49+
response = requests.get(
50+
f"{backend_url}/chats/{chat_id}/messages/{message_id}/events",
4451
stream=True,
4552
headers={
4653
"Accept": "text/event-stream",
4754
**auth_headers,
4855
},
4956
)
5057
response.raise_for_status()
51-
5258
client = sseclient.SSEClient(response)
5359
print("Assistant: ", end="", flush=True)
5460
events = iter(client.events())
55-
message_id = json.loads(next(events).data)["assistant_message"]["id"]
5661
for event in events:
5762
try:
5863
data = json.loads(event.data)

Comments (0)