
Commit 0f0012a

feat: Add Chat API + Multi-turn support
- Implement UserSessionChatAPIData for the Chat API with multi-turn chat
- Update SharedPrefixDataGenerator to support the Chat API + multi-turn combination
- Fix prompt selection to use a different question for each turn
- Initialize the user session context as a list for Chat API compatibility
1 parent 1fe4026 commit 0f0012a
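
The gist of the change: with the Completion API a session's multi-turn context is a single prompt string that keeps growing, while the Chat API expects a list of role-tagged messages, which is why the session context is now initialized and extended as a list. A minimal sketch of the two context shapes (placeholder strings, not code from this commit):

    # Completion API: context is one growing prompt string.
    completion_context = "<shared prefix> <question 1> <answer 1> <question 2>"

    # Chat API: context is a list of role-tagged messages.
    chat_context = [
        {"role": "system", "content": "<shared prefix>"},
        {"role": "user", "content": "<question 1>"},
        {"role": "assistant", "content": "<answer 1>"},
        {"role": "user", "content": "<question 2>"},
    ]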

File tree

3 files changed: +141 -24 lines


inference_perf/apis/chat.py

Lines changed: 3 additions & 0 deletions
@@ -31,6 +31,7 @@ class ChatMessage(BaseModel):
 class ChatCompletionAPIData(InferenceAPIData):
     messages: List[ChatMessage]
     max_tokens: int = 0
+    model_response: str = ""  # Store the assistant response for multi-turn chat

     def get_api_type(self) -> APIType:
         return APIType.Chat
@@ -81,6 +82,7 @@ async def process_response(self, response: ClientResponse, config: APIConfig, to
         prompt_text = "".join([msg.content for msg in self.messages if msg.content])
         prompt_len = tokenizer.count_tokens(prompt_text)
         output_len = tokenizer.count_tokens(output_text)
+        self.model_response = output_text  # Store response for multi-turn chat
         return InferenceInfo(
             input_tokens=prompt_len,
             output_tokens=output_len,
@@ -93,6 +95,7 @@ async def process_response(self, response: ClientResponse, config: APIConfig, to
         if len(choices) == 0:
             return InferenceInfo(input_tokens=prompt_len)
         output_text = "".join([choice.get("message", {}).get("content", "") for choice in choices])
+        self.model_response = output_text  # Store response for multi-turn chat
         output_len = tokenizer.count_tokens(output_text)
         return InferenceInfo(
             input_tokens=prompt_len,

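For reference, the non-streaming branch above reads an OpenAI-style Chat Completions body; a minimal sketch of that shape and of the extraction the diff performs (field values are made up):

    # Illustrative OpenAI-style Chat Completions response body.
    response_body = {
        "choices": [
            {"message": {"role": "assistant", "content": "Paris is the capital of France."}}
        ]
    }

    choices = response_body.get("choices", [])
    # Same extraction as above: join every choice's message content.
    output_text = "".join([choice.get("message", {}).get("content", "") for choice in choices])
    assert output_text == "Paris is the capital of France."
    # model_response keeps this text so the next turn can append it to the
    # session history as ChatMessage(role="assistant", content=output_text).
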
inference_perf/apis/user_session.py

Lines changed: 83 additions & 7 deletions
@@ -7,22 +7,24 @@
 from inference_perf.apis import CompletionAPIData, InferenceInfo
 from inference_perf.utils.custom_tokenizer import CustomTokenizer
 from inference_perf.config import APIConfig
+from inference_perf.apis.chat import ChatCompletionAPIData, ChatMessage

 logger = logging.getLogger(__name__)


 class LocalUserSession:
     user_session_id: str
-    context: str
+    context: Any

-    def __init__(self, user_session_id: str, context: str = ""):
+    def __init__(self, user_session_id: str, context: Any = ""):
         self.user_session_id = user_session_id
         self.contexts = context if context else ""
         self._current_round = 0
         self._in_flight: asyncio.Lock = asyncio.Lock()
         self._waiting_rounds: asyncio.Queue[asyncio.Future[bool]] = asyncio.Queue()

-    async def get_context(self, round: int) -> str:
+    async def get_context(self, round: int) -> Any:
         if not self._waiting_rounds.empty() or self._in_flight.locked():
             # entering waiting queue
             future: asyncio.Future[bool] = asyncio.Future()
@@ -32,7 +34,7 @@ async def get_context(self, round: int) -> str:
         self._current_round += 1
         return self.contexts

-    def update_context(self, response: str) -> None:
+    def update_context(self, response: Any) -> None:
         self.contexts = response

         if not self._waiting_rounds.empty():
@@ -76,6 +78,80 @@ async def process_failure(
         return inference_info


-# TODO: UserSessionChatAPIData need to be implemented
-# class UserSessionChatAPIData(ChatCompletionAPIData):
-#     ...
+
+class UserSessionChatAPIData(ChatCompletionAPIData):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    user_session: LocalUserSession = Field(exclude=True)
+    target_round: int
+
+    async def to_payload(self, model_name: str, max_tokens: int, ignore_eos: bool, streaming: bool) -> dict[str, Any]:
+        self._session_context = await self.user_session.get_context(self.target_round)
+        # Merge this turn's messages into the session history:
+        # self.messages holds the new user message(s) (possibly with a system message);
+        # self._session_context holds the history (system prompt + previous turns).
+        if isinstance(self._session_context, list):
+            # History already exists: append only the new user message(s),
+            # dropping any system message (it is already in the history).
+            new_messages = [msg for msg in self.messages if msg.role != "system"]
+            full_messages = self._session_context + new_messages
+        else:
+            # First turn: context is not a list yet, use all messages (including system)
+            full_messages = self.messages
+
+        # ChatCompletionAPIData.to_payload reads self.messages, so rather than
+        # temporarily overriding self.messages, construct the payload directly here.
+        if self.max_tokens == 0:
+            self.max_tokens = max_tokens
+
+        return {
+            "model": model_name,
+            "messages": [{"role": m.role, "content": m.content} for m in full_messages],
+            "max_tokens": self.max_tokens,
+            "ignore_eos": ignore_eos,
+            "stream": streaming,
+        }
+
+    def update_inference_info(self, inference_info: InferenceInfo) -> None:
+        inference_info.extra_info["user_session"] = self.user_session.user_session_id
+        inference_info.extra_info["chat_round"] = self.user_session._current_round
+
+    async def process_response(self, response: ClientResponse, config: APIConfig, tokenizer: CustomTokenizer) -> InferenceInfo:
+        inference_info = await super().process_response(response, config, tokenizer)
+        self.update_inference_info(inference_info)
+
+        # Update the context with the new turn:
+        #   history <- history + user message(s) + assistant response
+        # self._session_context is the history before this turn, self.messages
+        # holds this turn's user message(s), and self.model_response is the
+        # assistant response text.
+        new_history = []
+        if isinstance(self._session_context, list):
+            # History already exists, extend it
+            new_history.extend(self._session_context)
+            # Add only the new user message(s), excluding system (already in history)
+            new_messages = [msg for msg in self.messages if msg.role != "system"]
+            new_history.extend(new_messages)
+        else:
+            # First turn: include all messages (system + user)
+            new_history.extend(self.messages)
+
+        # Add assistant response
+        new_history.append(ChatMessage(role="assistant", content=self.model_response))
+
+        self.user_session.update_context(new_history)
+        return inference_info
+
+    async def process_failure(
+        self, response: Optional[ClientResponse], config: APIConfig, tokenizer: CustomTokenizer, exception: Exception
+    ) -> Optional[InferenceInfo]:
+        # No response returned: keep the context from the last round (do not add new messages)
+        inference_info = InferenceInfo()
+        self.update_inference_info(inference_info)
+        self.user_session.update_context(self._session_context)
+        return inference_info

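Taken together, to_payload and process_response grow the history by one user/assistant exchange per round. A standalone trace of that behavior (a sketch of the commit's logic, using a stand-in Msg class and made-up strings):

    from dataclasses import dataclass

    @dataclass
    class Msg:  # stand-in for ChatMessage
        role: str
        content: str

    def run_turn(history: list[Msg], turn_msgs: list[Msg], model_response: str) -> list[Msg]:
        # Mirrors to_payload/process_response: drop system messages that are
        # already in the history, then append this turn and the assistant reply.
        new_msgs = [m for m in turn_msgs if m.role != "system"]
        return history + new_msgs + [Msg("assistant", model_response)]

    history = [Msg("system", "<shared prefix>")]  # initial context for the Chat API
    history = run_turn(history, [Msg("user", "<question 1>")], "<answer 1>")
    history = run_turn(history, [Msg("user", "<question 2>")], "<answer 2>")
    assert [m.role for m in history] == ["system", "user", "assistant", "user", "assistant"]
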
inference_perf/datagen/shared_prefix_datagen.py

Lines changed: 55 additions & 17 deletions
@@ -4,7 +4,7 @@

 from inference_perf.apis.base import InferenceAPIData, LazyLoadInferenceAPIData
 from inference_perf.apis.completion import CompletionAPIData
-from inference_perf.apis.user_session import LocalUserSession, UserSessionCompletionAPIData
+from inference_perf.apis.user_session import LocalUserSession, UserSessionCompletionAPIData, UserSessionChatAPIData
 from inference_perf.apis.chat import ChatCompletionAPIData, ChatMessage
 from inference_perf.config import APIConfig, APIType, DataConfig
 from inference_perf.utils.custom_tokenizer import CustomTokenizer
@@ -65,25 +65,54 @@ def is_prefered_worker_requested(self) -> bool:
         return True

     def load_lazy_data(self, data: LazyLoadInferenceAPIData) -> InferenceAPIData:
-        i = data.data_index % len(self.prompts)
         if self.enable_multi_turn_chat:
             user_id = data.data_index % len(self.user_sessions)
             round = data.data_index // len(self.user_sessions)
-            return UserSessionCompletionAPIData(
-                prompt=self.prompts[i],
-                max_tokens=self.output_len,
-                user_session=self.user_sessions[user_id],
-                target_round=round,
-            )
-        elif self.api_config.type == APIType.Chat:
-            shared_prefix, question = self.prompt_pairs[i]
-            messages = [
-                ChatMessage(role="system", content=shared_prefix),
-                ChatMessage(role="user", content=question)
-            ]
-            return ChatCompletionAPIData(messages=messages, max_tokens=self.output_len)
+
+            # Each user belongs to a group, and each group has num_prompts_per_group
+            # questions. Work out which group this user belongs to and which
+            # question in the group it starts from.
+            group_id = user_id // self.num_prompts_per_group
+            prompt_in_group = user_id % self.num_prompts_per_group
+
+            # For each round, use a different question from the same group,
+            # cycling through the group's questions.
+            question_in_group = (prompt_in_group + round) % self.num_prompts_per_group
+            question_idx = group_id * self.num_prompts_per_group + question_in_group
+
+            if self.api_config.type == APIType.Chat:
+                # Chat API + multi-turn: use UserSessionChatAPIData
+                shared_prefix, question = self.prompt_pairs[question_idx]
+                messages = [
+                    ChatMessage(role="system", content=shared_prefix),
+                    ChatMessage(role="user", content=question)
+                ]
+                return UserSessionChatAPIData(
+                    messages=messages,
+                    max_tokens=self.output_len,
+                    user_session=self.user_sessions[user_id],
+                    target_round=round,
+                )
+            else:
+                # Completion API + multi-turn: use UserSessionCompletionAPIData
+                return UserSessionCompletionAPIData(
+                    prompt=self.prompts[question_idx],
+                    max_tokens=self.output_len,
+                    user_session=self.user_sessions[user_id],
+                    target_round=round,
+                )
         else:
-            return CompletionAPIData(prompt=self.prompts[i], max_tokens=self.output_len)
+            # Single-turn: use data_index directly
+            i = data.data_index % len(self.prompts)
+            if self.api_config.type == APIType.Chat:
+                shared_prefix, question = self.prompt_pairs[i]
+                messages = [
+                    ChatMessage(role="system", content=shared_prefix),
+                    ChatMessage(role="user", content=question)
+                ]
+                return ChatCompletionAPIData(messages=messages, max_tokens=self.output_len)
+            else:
+                return CompletionAPIData(prompt=self.prompts[i], max_tokens=self.output_len)

     def get_request(self, n: int) -> InferenceAPIData:
         i = n % len(self.prompts)
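
To make the index arithmetic above concrete, a worked example with an illustrative num_prompts_per_group of 4:

    num_prompts_per_group = 4  # illustrative value

    def question_index(user_id: int, round: int) -> int:
        # Same arithmetic as load_lazy_data above.
        group_id = user_id // num_prompts_per_group
        prompt_in_group = user_id % num_prompts_per_group
        question_in_group = (prompt_in_group + round) % num_prompts_per_group
        return group_id * num_prompts_per_group + question_in_group

    # User 6 sits in group 1 (questions 4..7) and starts at question 6; each
    # round moves to the next question in the group, wrapping around.
    assert [question_index(6, r) for r in range(4)] == [6, 7, 4, 5]
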
@@ -151,10 +180,19 @@ def _generate_prompts(self) -> None:

         if self.enable_multi_turn_chat:
             # multi turn chat, create user to keep conversation
+            # For the Chat API, the context should be a list of messages starting
+            # with the system prompt; for the Completion API, it is a string.
+            if self.api_config.type == APIType.Chat:
+                initial_context = [
+                    ChatMessage(role="system", content=shared_prefix_text)
+                ]
+            else:
+                initial_context = shared_prefix_text
+
             self.user_sessions.append(
                 LocalUserSession(
                     user_session_id=f"user_session_{self.num_prompts_per_group * group_id + prompt_id}",
-                    context=shared_prefix_text,
+                    context=initial_context,
                 )
             )