3 changes: 3 additions & 0 deletions inference_perf/apis/chat.py
@@ -31,6 +31,7 @@ class ChatMessage(BaseModel):
class ChatCompletionAPIData(InferenceAPIData):
messages: List[ChatMessage]
max_tokens: int = 0
model_response: str = "" # Store the assistant response for multi-turn chat

def get_api_type(self) -> APIType:
return APIType.Chat
@@ -81,6 +82,7 @@ async def process_response(self, response: ClientResponse, config: APIConfig, to
prompt_text = "".join([msg.content for msg in self.messages if msg.content])
prompt_len = tokenizer.count_tokens(prompt_text)
output_len = tokenizer.count_tokens(output_text)
self.model_response = output_text # Store response for multi-turn chat
return InferenceInfo(
input_tokens=prompt_len,
output_tokens=output_len,
@@ -93,6 +95,7 @@ async def process_response(self, response: ClientResponse, config: APIConfig, to
if len(choices) == 0:
return InferenceInfo(input_tokens=prompt_len)
output_text = "".join([choice.get("message", {}).get("content", "") for choice in choices])
self.model_response = output_text # Store response for multi-turn chat
output_len = tokenizer.count_tokens(output_text)
return InferenceInfo(
input_tokens=prompt_len,
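A minimal sketch of what the new model_response field enables, assuming the package is importable; the message contents below are illustrative, and only the ChatMessage/ChatCompletionAPIData fields shown in this diff are relied on:

from inference_perf.apis.chat import ChatCompletionAPIData, ChatMessage

data = ChatCompletionAPIData(
    messages=[ChatMessage(role="user", content="Hello")],
    max_tokens=32,
)
# After process_response() runs, data.model_response holds the assistant text,
# so a follow-up round can extend the same conversation:
data.messages.append(ChatMessage(role="assistant", content=data.model_response))
data.messages.append(ChatMessage(role="user", content="Tell me more"))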
11 changes: 9 additions & 2 deletions inference_perf/apis/user_session.py
@@ -60,14 +60,21 @@ def update_inference_info(self, inference_info: InferenceInfo) -> None:
inference_info.extra_info["user_session"] = self.user_session.user_session_id
inference_info.extra_info["chat_round"] = self.user_session._current_round

async def process_response(self, response: ClientResponse, config: APIConfig, tokenizer: CustomTokenizer) -> InferenceInfo:
async def process_response(
self, response: ClientResponse, config: APIConfig, tokenizer: CustomTokenizer, lora_adapter: Optional[str] = None
) -> InferenceInfo:
inference_info = await super().process_response(response, config, tokenizer)
self.update_inference_info(inference_info)
self.user_session.update_context(self.prompt + " " + self.model_response)
return inference_info

async def process_failure(
self, response: Optional[ClientResponse], config: APIConfig, tokenizer: CustomTokenizer, exception: Exception
self,
response: Optional[ClientResponse],
config: APIConfig,
tokenizer: CustomTokenizer,
exception: Exception,
lora_adapter: Optional[str] = None,
) -> Optional[InferenceInfo]:
# No response returned; reuse the context from the last round
inference_info = InferenceInfo()
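A minimal sketch of the context accumulation this change performs, assuming LocalUserSession accepts the user_session_id and context keywords used elsewhere in this diff (the integer id and the strings are illustrative):

from inference_perf.apis.user_session import LocalUserSession

# Seed the session with the shared prefix, as _generate_prompts() does below.
session = LocalUserSession(user_session_id=0, context="shared system prefix")
# After each successful round, process_response() folds the prompt and the
# stored model response back into the session context for the next request:
session.update_context("round-1 question" + " " + "round-1 model answer")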
35 changes: 25 additions & 10 deletions inference_perf/datagen/shared_prefix_datagen.py
@@ -5,6 +5,7 @@
from inference_perf.apis.base import InferenceAPIData, LazyLoadInferenceAPIData
from inference_perf.apis.completion import CompletionAPIData
from inference_perf.apis.user_session import LocalUserSession, UserSessionCompletionAPIData
from inference_perf.apis.chat import ChatCompletionAPIData, ChatMessage
from inference_perf.config import APIConfig, APIType, DataConfig
from inference_perf.utils.custom_tokenizer import CustomTokenizer
from .base import DataGenerator, LazyLoadDataMixin
@@ -46,12 +47,13 @@ def __init__(self, api_config: APIConfig, config: DataConfig, tokenizer: Optiona
self.output_len: int = self.shared_prefix.output_len
self.enable_multi_turn_chat: bool = self.shared_prefix.enable_multi_turn_chat

self.prompts: List[str] = []
self.prompts: List[str] = [] # For completion API
self.prompt_pairs: List[tuple[str, str]] = [] # (shared_prefix, question) pairs for chat API
self.user_sessions: List[LocalUserSession] = []
self._generate_prompts()

def get_supported_apis(self) -> List[APIType]:
return [APIType.Completion]
return [APIType.Completion, APIType.Chat]

def is_io_distribution_supported(self) -> bool:
return True
@@ -73,6 +75,10 @@ def load_lazy_data(self, data: LazyLoadInferenceAPIData) -> InferenceAPIData:
user_session=self.user_sessions[user_id],
target_round=round,
)
elif self.api_config.type == APIType.Chat:
shared_prefix, question = self.prompt_pairs[i]
messages = [ChatMessage(role="system", content=shared_prefix), ChatMessage(role="user", content=question)]
return ChatCompletionAPIData(messages=messages, max_tokens=self.output_len)
else:
return CompletionAPIData(prompt=self.prompts[i], max_tokens=self.output_len)

@@ -82,9 +88,18 @@ def get_data(self) -> Generator[InferenceAPIData, None, None]:

i = 0
while True:
prefered_worker_id = i % self.num_groups if self.enable_multi_turn_chat else -1
yield LazyLoadInferenceAPIData(data_index=i, prefered_worker_id=prefered_worker_id)
i += 1
if self.enable_multi_turn_chat:
prefered_worker_id = i % self.num_groups
yield LazyLoadInferenceAPIData(data_index=i, prefered_worker_id=prefered_worker_id)
i += 1
elif self.api_config.type == APIType.Chat:
shared_prefix, question = self.prompt_pairs[i]
messages = [ChatMessage(role="system", content=shared_prefix), ChatMessage(role="user", content=question)]
yield ChatCompletionAPIData(messages=messages, max_tokens=self.output_len)
i = (i + 1) % len(self.prompt_pairs)
else:
yield CompletionAPIData(prompt=self.prompts[i], max_tokens=self.output_len)
i = (i + 1) % len(self.prompts)

def _generate_random_token_ids(self, length: int) -> List[int]:
"""Generates a list of random token IDs of a specified length."""
@@ -111,6 +126,11 @@ def _generate_prompts(self) -> None:
question_token_ids = self._generate_random_token_ids(self.question_len)
question_text = hf_tokenizer.decode(question_token_ids, skip_special_tokens=True)

# Combine shared prefix and question
full_prompt_text = shared_prefix_text + " " + question_text
self.prompts.append(full_prompt_text)
self.prompt_pairs.append((shared_prefix_text, question_text))

if self.enable_multi_turn_chat:
# Multi-turn chat: create a user session to carry the conversation across rounds
self.user_sessions.append(
@@ -119,11 +139,6 @@
context=shared_prefix_text,
)
)
else:
# Single turn chat, Combine shared prefix and question
question_text = shared_prefix_text + " " + question_text

self.prompts.append(question_text)

# Shuffle prompts and prompt_pairs with a single shared permutation so the two
# lists stay index-aligned and both APIs serve the groups in random order
paired = list(zip(self.prompts, self.prompt_pairs))
random.shuffle(paired)
self.prompts, self.prompt_pairs = (list(t) for t in zip(*paired))
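A minimal sketch of the per-API items this generator now produces, mirroring load_lazy_data() above; the prefix and question strings are illustrative:

from inference_perf.apis.chat import ChatCompletionAPIData, ChatMessage
from inference_perf.apis.completion import CompletionAPIData

shared_prefix, question = "common system context", "one per-prompt question"

# Chat API: the shared prefix rides in the system message, so every prompt
# in a group starts with the same leading message.
chat_item = ChatCompletionAPIData(
    messages=[
        ChatMessage(role="system", content=shared_prefix),
        ChatMessage(role="user", content=question),
    ],
    max_tokens=20,
)

# Completion API: prefix and question are concatenated into a single prompt.
completion_item = CompletionAPIData(prompt=shared_prefix + " " + question, max_tokens=20)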
239 changes: 239 additions & 0 deletions tests/datagen/test_shared_prefix.py
@@ -0,0 +1,239 @@
import pytest
from unittest.mock import Mock

from inference_perf.datagen.shared_prefix_datagen import SharedPrefixDataGenerator
from inference_perf.apis.completion import CompletionAPIData
from inference_perf.apis.chat import ChatCompletionAPIData
from inference_perf.apis.user_session import UserSessionCompletionAPIData
from inference_perf.apis.base import LazyLoadInferenceAPIData
from inference_perf.config import APIConfig, APIType, DataConfig


def create_mock_tokenizer() -> Mock:
"""Create a mock tokenizer for testing."""
mock_tokenizer = Mock()
mock_hf_tokenizer = Mock()
mock_hf_tokenizer.vocab_size = 32000
mock_hf_tokenizer.decode = Mock(side_effect=lambda ids, **kwargs: f"text_{len(ids)}")
mock_tokenizer.get_tokenizer.return_value = mock_hf_tokenizer
return mock_tokenizer


def create_api_config(api_type: APIType) -> APIConfig:
"""Create an APIConfig for testing."""
return APIConfig(type=api_type)


def create_data_config(
num_groups: int = 2,
num_prompts_per_group: int = 3,
system_prompt_len: int = 10,
question_len: int = 5,
output_len: int = 20,
enable_multi_turn_chat: bool = False,
) -> DataConfig:
"""Create a DataConfig with shared_prefix settings for testing."""
config = DataConfig()
config.shared_prefix = Mock()
config.shared_prefix.num_groups = num_groups
config.shared_prefix.num_prompts_per_group = num_prompts_per_group
config.shared_prefix.system_prompt_len = system_prompt_len
config.shared_prefix.question_len = question_len
config.shared_prefix.output_len = output_len
config.shared_prefix.enable_multi_turn_chat = enable_multi_turn_chat
return config


class TestSharedPrefixDataGeneratorBasic:
"""Basic tests for SharedPrefixDataGenerator."""

def test_get_supported_apis(self) -> None:
"""Test that both Completion and Chat APIs are supported."""
api_config = create_api_config(APIType.Completion)
data_config = create_data_config()
tokenizer = create_mock_tokenizer()

generator = SharedPrefixDataGenerator(api_config, data_config, tokenizer)

supported = generator.get_supported_apis()
assert APIType.Completion in supported
assert APIType.Chat in supported

def test_prompts_count(self) -> None:
"""Test that correct number of prompts are generated."""
api_config = create_api_config(APIType.Completion)
data_config = create_data_config(num_groups=2, num_prompts_per_group=3)
tokenizer = create_mock_tokenizer()

generator = SharedPrefixDataGenerator(api_config, data_config, tokenizer)

assert len(generator.prompts) == 6 # 2 groups * 3 prompts
assert len(generator.prompt_pairs) == 6

def test_prompt_pairs_structure(self) -> None:
"""Test that prompt_pairs contain (shared_prefix, question) tuples."""
api_config = create_api_config(APIType.Chat)
data_config = create_data_config()
tokenizer = create_mock_tokenizer()

generator = SharedPrefixDataGenerator(api_config, data_config, tokenizer)

for shared_prefix, question in generator.prompt_pairs:
assert isinstance(shared_prefix, str)
assert isinstance(question, str)


class TestSharedPrefixCompletionAPI:
"""Tests for Completion API support."""

def test_load_lazy_data_returns_completion_api_data(self) -> None:
"""Test that load_lazy_data returns CompletionAPIData for Completion API."""
api_config = create_api_config(APIType.Completion)
data_config = create_data_config(enable_multi_turn_chat=False)
tokenizer = create_mock_tokenizer()

generator = SharedPrefixDataGenerator(api_config, data_config, tokenizer)
lazy_data = LazyLoadInferenceAPIData(data_index=0, prefered_worker_id=0)

result = generator.load_lazy_data(lazy_data)

assert isinstance(result, CompletionAPIData)
assert result.max_tokens == generator.output_len

@pytest.mark.asyncio
async def test_completion_api_to_payload(self) -> None:
"""Test that CompletionAPIData generates correct payload."""
api_config = create_api_config(APIType.Completion)
data_config = create_data_config(output_len=50, enable_multi_turn_chat=False)
tokenizer = create_mock_tokenizer()

generator = SharedPrefixDataGenerator(api_config, data_config, tokenizer)
lazy_data = LazyLoadInferenceAPIData(data_index=0, prefered_worker_id=0)
result = generator.load_lazy_data(lazy_data)

payload = await result.to_payload("test-model", 100, False, True)

assert payload["model"] == "test-model"
assert "prompt" in payload
assert payload["max_tokens"] == 50
assert payload["stream"] is True

def test_get_data_yields_completion_api_data(self) -> None:
"""Test that get_data yields CompletionAPIData for Completion API."""
api_config = create_api_config(APIType.Completion)
data_config = create_data_config(enable_multi_turn_chat=False)
tokenizer = create_mock_tokenizer()

generator = SharedPrefixDataGenerator(api_config, data_config, tokenizer)
data_gen = generator.get_data()

first_item = next(data_gen)
assert isinstance(first_item, CompletionAPIData)


class TestSharedPrefixChatAPI:
"""Tests for Chat API support."""

def test_load_lazy_data_returns_chat_api_data(self) -> None:
"""Test that load_lazy_data returns ChatCompletionAPIData for Chat API."""
api_config = create_api_config(APIType.Chat)
data_config = create_data_config(enable_multi_turn_chat=False)
tokenizer = create_mock_tokenizer()

generator = SharedPrefixDataGenerator(api_config, data_config, tokenizer)
lazy_data = LazyLoadInferenceAPIData(data_index=0, prefered_worker_id=0)

result = generator.load_lazy_data(lazy_data)

assert isinstance(result, ChatCompletionAPIData)
assert result.max_tokens == generator.output_len

def test_chat_api_messages_structure(self) -> None:
"""Test that Chat API messages have system and user roles."""
api_config = create_api_config(APIType.Chat)
data_config = create_data_config(enable_multi_turn_chat=False)
tokenizer = create_mock_tokenizer()

generator = SharedPrefixDataGenerator(api_config, data_config, tokenizer)
lazy_data = LazyLoadInferenceAPIData(data_index=0, prefered_worker_id=0)

result = generator.load_lazy_data(lazy_data)

assert isinstance(result, ChatCompletionAPIData)
assert len(result.messages) == 2
assert result.messages[0].role == "system"
assert result.messages[1].role == "user"

def test_get_data_yields_chat_api_data(self) -> None:
"""Test that get_data yields ChatCompletionAPIData for Chat API."""
api_config = create_api_config(APIType.Chat)
data_config = create_data_config(enable_multi_turn_chat=False)
tokenizer = create_mock_tokenizer()

generator = SharedPrefixDataGenerator(api_config, data_config, tokenizer)
data_gen = generator.get_data()

first_item = next(data_gen)
assert isinstance(first_item, ChatCompletionAPIData)
assert len(first_item.messages) == 2


class TestSharedPrefixMultiTurn:
"""Tests for multi-turn chat support."""

def test_multi_turn_creates_user_sessions(self) -> None:
"""Test that multi-turn mode creates user sessions."""
api_config = create_api_config(APIType.Completion)
data_config = create_data_config(num_groups=2, num_prompts_per_group=3, enable_multi_turn_chat=True)
tokenizer = create_mock_tokenizer()

generator = SharedPrefixDataGenerator(api_config, data_config, tokenizer)

assert len(generator.user_sessions) == 6 # 2 groups * 3 prompts

def test_multi_turn_load_lazy_data_returns_user_session_data(self) -> None:
"""Test that multi-turn load_lazy_data returns UserSessionCompletionAPIData."""
api_config = create_api_config(APIType.Completion)
data_config = create_data_config(enable_multi_turn_chat=True)
tokenizer = create_mock_tokenizer()

generator = SharedPrefixDataGenerator(api_config, data_config, tokenizer)
lazy_data = LazyLoadInferenceAPIData(data_index=0, prefered_worker_id=0)

result = generator.load_lazy_data(lazy_data)

assert isinstance(result, UserSessionCompletionAPIData)

def test_multi_turn_get_data_yields_lazy_load_data(self) -> None:
"""Test that multi-turn get_data yields LazyLoadInferenceAPIData."""
api_config = create_api_config(APIType.Completion)
data_config = create_data_config(enable_multi_turn_chat=True)
tokenizer = create_mock_tokenizer()

generator = SharedPrefixDataGenerator(api_config, data_config, tokenizer)
data_gen = generator.get_data()

first_item = next(data_gen)
assert isinstance(first_item, LazyLoadInferenceAPIData)


class TestSharedPrefixValidation:
"""Tests for validation and error handling."""

def test_requires_tokenizer(self) -> None:
"""Test that tokenizer is required."""
api_config = create_api_config(APIType.Completion)
data_config = create_data_config()

with pytest.raises((ValueError, AttributeError)):
SharedPrefixDataGenerator(api_config, data_config, None)

def test_requires_shared_prefix_config(self) -> None:
"""Test that shared_prefix config is required."""
api_config = create_api_config(APIType.Completion)
data_config = DataConfig()
data_config.shared_prefix = None
tokenizer = create_mock_tokenizer()

with pytest.raises(ValueError, match="Shared Prefix config is required"):
SharedPrefixDataGenerator(api_config, data_config, tokenizer)
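The new tests can be run directly; a minimal sketch using pytest's Python entry point (the async payload test additionally assumes pytest-asyncio is installed, since it is marked with @pytest.mark.asyncio):

import pytest

raise SystemExit(pytest.main(["tests/datagen/test_shared_prefix.py", "-v"]))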