Skip to content

Commit c31c41f

Browse files
authored
fix: exclude valid openai packets with empty string data (#224)
1 parent a82154e commit c31c41f

File tree

3 files changed

+234
-5
lines changed

3 files changed

+234
-5
lines changed

aiperf/parsers/openai_parsers.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def _parse_text_response(self, response: TextResponse) -> ResponseData | None:
109109
"""Parse a TextResponse into a ResponseData object."""
110110
raw = response.text
111111
parsed = self._parse_text(raw)
112-
if parsed is None:
112+
if not parsed:
113113
return None
114114

115115
return ResponseData(
@@ -123,7 +123,7 @@ def _parse_sse_response(self, response: SSEMessage) -> ResponseData | None:
123123
"""Parse a SSEMessage into a ResponseData object."""
124124
raw = response.extract_data_content()
125125
parsed = self._parse_sse(raw)
126-
if parsed is None or len(parsed) == 0:
126+
if not parsed:
127127
return None
128128

129129
return ResponseData(
@@ -147,7 +147,7 @@ async def extract_response_data(
147147
results = []
148148
for response in record.responses:
149149
response_data = self._parse_response(response)
150-
if response_data is None:
150+
if not response_data:
151151
continue
152152

153153
if tokenizer is not None:
@@ -180,7 +180,9 @@ def _parse_text(self, raw_text: str) -> Any | None:
180180

181181
for obj_type, extractor in type_to_extractor.items():
182182
if isinstance(obj, obj_type):
183-
return extractor(obj)
183+
content = extractor(obj)
184+
# skip empty content
185+
return content if content else None
184186

185187
raise ValueError(f"Invalid OpenAI object: {raw_text}")
186188

@@ -189,7 +191,7 @@ def _parse_sse(self, raw_sse: list[str]) -> list[Any]:
189191
result = []
190192
for sse in raw_sse:
191193
parsed = self._parse_text(sse)
192-
if parsed is None:
194+
if not parsed:
193195
continue
194196
result.append(parsed)
195197
return result

tests/parsers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import json
5+
from unittest.mock import MagicMock
6+
7+
import pytest
8+
9+
from aiperf.clients.model_endpoint_info import ModelEndpointInfo
10+
from aiperf.common.models import RequestRecord, ResponseData, SSEMessage, TextResponse
11+
from aiperf.parsers.openai_parsers import OpenAIResponseExtractor
12+
13+
14+
class TestOpenAIResponseExtractor:
15+
"""Test cases for OpenAIResponseExtractor."""
16+
17+
@pytest.fixture
18+
def extractor(self):
19+
"""Create an OpenAIResponseExtractor instance."""
20+
mock_endpoint = MagicMock(spec=ModelEndpointInfo)
21+
return OpenAIResponseExtractor(mock_endpoint)
22+
23+
def chat_completion_json(self, content) -> str:
24+
"""Generate chat completion JSON with specified content and finish reason."""
25+
completion = {
26+
"id": "test",
27+
"object": "chat.completion",
28+
"created": 1700000000,
29+
"model": "test-model",
30+
"choices": [
31+
{
32+
"index": 0,
33+
"message": {"role": "assistant", "content": content},
34+
"finish_reason": "stop",
35+
}
36+
],
37+
}
38+
assert completion["choices"][0]["message"]["content"] == content
39+
return json.dumps(completion)
40+
41+
def chat_completion_chunk_json(self, content, stop=True) -> str:
42+
"""Generate chat completion chunk JSON with specified delta content and finish reason."""
43+
chunk = {
44+
"id": "test",
45+
"object": "chat.completion.chunk",
46+
"created": 1700000000,
47+
"model": "test-model",
48+
"choices": [
49+
{
50+
"index": 0,
51+
"delta": {"content": content},
52+
"finish_reason": "stop" if stop else None,
53+
}
54+
],
55+
}
56+
assert chunk["choices"][0]["delta"]["content"] == content
57+
return json.dumps(chunk)
58+
59+
def create_raw_text_response(self, content, perf_ns=1000000) -> MagicMock:
60+
"""Create a mock TextResponse with specified content."""
61+
text_response = MagicMock(spec=TextResponse)
62+
text_response.text = content
63+
text_response.perf_ns = perf_ns
64+
return text_response
65+
66+
def create_text_response(self, content, perf_ns=1000000) -> MagicMock:
67+
"""Create a mock TextResponse with specified content."""
68+
text_response = MagicMock(spec=TextResponse)
69+
text_response.text = self.chat_completion_json(content)
70+
text_response.perf_ns = perf_ns
71+
return text_response
72+
73+
def create_sse_message(self, chunks, perf_ns=2000000) -> MagicMock:
74+
"""Create a mock SSEMessage with specified chunk contents."""
75+
sse_message = MagicMock(spec=SSEMessage)
76+
if isinstance(chunks, str):
77+
# Single chunk
78+
sse_message.extract_data_content.return_value = [
79+
self.chat_completion_chunk_json(chunks)
80+
]
81+
else:
82+
# Multiple chunks
83+
sse_message.extract_data_content.return_value = [
84+
self.chat_completion_chunk_json(chunk) for chunk in chunks
85+
]
86+
sse_message.perf_ns = perf_ns
87+
return sse_message
88+
89+
def create_request_record(self, *responses) -> MagicMock:
90+
"""Create a mock RequestRecord with specified responses."""
91+
record = MagicMock(spec=RequestRecord)
92+
record.responses = list(responses)
93+
return record
94+
95+
@pytest.mark.parametrize("text", ["[DONE]", "", None])
96+
def test_parse_text_returns_none(self, extractor, text):
97+
"""Test that _parse_text returns None for '[DONE]' marker, empty string, and None."""
98+
result = extractor._parse_text(text)
99+
assert result is None
100+
101+
@pytest.mark.parametrize("content", ["", None])
102+
def test_parse_text_with_empty_content_returns_none(self, extractor, content):
103+
"""Test that valid chat completion with empty/null content returns None."""
104+
chat_completion_json = self.chat_completion_json(content)
105+
106+
result = extractor._parse_text(chat_completion_json)
107+
assert result is None
108+
109+
@pytest.mark.parametrize("content", ["", None])
110+
def test_parse_text_with_empty_chunk_content_returns_none(self, extractor, content):
111+
"""Test that valid chat completion chunk with empty/null delta content returns None."""
112+
chunk_json = self.chat_completion_chunk_json(content)
113+
114+
result = extractor._parse_text(chunk_json)
115+
assert result is None
116+
117+
def test_parse_text_with_valid_content_returns_content(self, extractor):
118+
"""Test that valid chat completion with actual content returns the content."""
119+
test_content = "Hello, how can I help you?"
120+
chat_completion_json = self.chat_completion_json(test_content)
121+
122+
result = extractor._parse_text(chat_completion_json)
123+
assert result == test_content
124+
125+
def test_parse_text_with_valid_chunk_content_returns_content(self, extractor):
126+
"""Test that valid chat completion chunk with actual delta content returns the content."""
127+
test_content = "Stream chunk content"
128+
chunk_json = self.chat_completion_chunk_json(test_content)
129+
130+
result = extractor._parse_text(chunk_json)
131+
assert result == test_content
132+
133+
def test_parse_text_response_with_empty_content_returns_none(self, extractor):
134+
"""Test that TextResponse with empty content is ignored."""
135+
text_response = self.create_raw_text_response("")
136+
137+
result = extractor._parse_text_response(text_response)
138+
assert result is None
139+
140+
def test_parse_text_response_with_valid_content_returns_response_data(
141+
self, extractor
142+
):
143+
"""Test that TextResponse with valid content returns ResponseData."""
144+
test_content = "Valid response"
145+
text_response = self.create_text_response(test_content)
146+
147+
result = extractor._parse_text_response(text_response)
148+
149+
assert result is not None
150+
assert isinstance(result, ResponseData)
151+
assert result.parsed_text == [test_content]
152+
assert result.perf_ns == 1000000
153+
154+
def test_parse_sse_response_with_empty_chunks_returns_none(self, extractor):
155+
"""Test that SSEMessage with empty chunks is ignored."""
156+
sse_message = self.create_sse_message("")
157+
158+
result = extractor._parse_sse_response(sse_message)
159+
assert result is None
160+
161+
def test_parse_sse_response_with_mixed_chunks_filters_empty(self, extractor):
162+
"""Test that SSEMessage filters out empty chunks but keeps valid ones."""
163+
sse_message = self.create_sse_message(["", "Valid chunk"])
164+
165+
result = extractor._parse_sse_response(sse_message)
166+
167+
assert result is not None
168+
assert isinstance(result, ResponseData)
169+
assert result.parsed_text == ["Valid chunk"]
170+
assert result.perf_ns == 2000000
171+
172+
@pytest.mark.asyncio
173+
async def test_extract_response_data_filters_empty_responses(self, extractor):
174+
"""Test that extract_response_data filters out responses with empty content."""
175+
request = self.create_request_record(
176+
self.create_raw_text_response("", perf_ns=1000000), # Raw empty text
177+
self.create_text_response("Valid response", perf_ns=2000000),
178+
)
179+
180+
results = await extractor.extract_response_data(request, None)
181+
182+
# Should only return the valid response, empty one should be filtered out
183+
assert len(results) == 1
184+
assert results[0].parsed_text == ["Valid response"]
185+
assert results[0].perf_ns == 2000000
186+
187+
@pytest.mark.asyncio
188+
async def test_extract_response_data_handles_mixed_response_types(self, extractor):
189+
"""Test that extract_response_data handles mixed TextResponse and SSEMessage types."""
190+
request = self.create_request_record(
191+
self.create_text_response("Text response", perf_ns=1000000),
192+
self.create_sse_message("SSE chunk", perf_ns=2000000),
193+
)
194+
195+
results = await extractor.extract_response_data(request, None)
196+
197+
# Should return both responses
198+
assert len(results) == 2
199+
assert results[0].parsed_text == ["Text response"]
200+
assert results[0].perf_ns == 1000000
201+
assert results[1].parsed_text == ["SSE chunk"]
202+
assert results[1].perf_ns == 2000000
203+
204+
@pytest.mark.asyncio
205+
async def test_extract_response_data_with_complex_sse_filtering(self, extractor):
206+
"""Test extract_response_data with complex SSE message filtering."""
207+
request = self.create_request_record(
208+
self.create_text_response("Valid text response", perf_ns=1000000),
209+
self.create_sse_message(
210+
["", "Valid chunk 1", "", "Valid chunk 2"], perf_ns=3000000
211+
),
212+
self.create_raw_text_response("", perf_ns=4000000), # Should be filtered
213+
)
214+
215+
results = await extractor.extract_response_data(request, None)
216+
217+
# Should return text response + filtered SSE response (empty raw_text filtered out)
218+
assert len(results) == 2
219+
assert results[0].parsed_text == ["Valid text response"]
220+
assert results[0].perf_ns == 1000000
221+
assert results[1].parsed_text == [
222+
"Valid chunk 1",
223+
"Valid chunk 2",
224+
] # Empty chunks filtered
225+
assert results[1].perf_ns == 3000000

0 commit comments

Comments
 (0)