|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +import json |
| 5 | +from unittest.mock import MagicMock |
| 6 | + |
| 7 | +import pytest |
| 8 | + |
| 9 | +from aiperf.clients.model_endpoint_info import ModelEndpointInfo |
| 10 | +from aiperf.common.models import RequestRecord, ResponseData, SSEMessage, TextResponse |
| 11 | +from aiperf.parsers.openai_parsers import OpenAIResponseExtractor |
| 12 | + |
| 13 | + |
| 14 | +class TestOpenAIResponseExtractor: |
| 15 | + """Test cases for OpenAIResponseExtractor.""" |
| 16 | + |
| 17 | + @pytest.fixture |
| 18 | + def extractor(self): |
| 19 | + """Create an OpenAIResponseExtractor instance.""" |
| 20 | + mock_endpoint = MagicMock(spec=ModelEndpointInfo) |
| 21 | + return OpenAIResponseExtractor(mock_endpoint) |
| 22 | + |
| 23 | + def chat_completion_json(self, content) -> str: |
| 24 | + """Generate chat completion JSON with specified content and finish reason.""" |
| 25 | + completion = { |
| 26 | + "id": "test", |
| 27 | + "object": "chat.completion", |
| 28 | + "created": 1700000000, |
| 29 | + "model": "test-model", |
| 30 | + "choices": [ |
| 31 | + { |
| 32 | + "index": 0, |
| 33 | + "message": {"role": "assistant", "content": content}, |
| 34 | + "finish_reason": "stop", |
| 35 | + } |
| 36 | + ], |
| 37 | + } |
| 38 | + assert completion["choices"][0]["message"]["content"] == content |
| 39 | + return json.dumps(completion) |
| 40 | + |
| 41 | + def chat_completion_chunk_json(self, content, stop=True) -> str: |
| 42 | + """Generate chat completion chunk JSON with specified delta content and finish reason.""" |
| 43 | + chunk = { |
| 44 | + "id": "test", |
| 45 | + "object": "chat.completion.chunk", |
| 46 | + "created": 1700000000, |
| 47 | + "model": "test-model", |
| 48 | + "choices": [ |
| 49 | + { |
| 50 | + "index": 0, |
| 51 | + "delta": {"content": content}, |
| 52 | + "finish_reason": "stop" if stop else None, |
| 53 | + } |
| 54 | + ], |
| 55 | + } |
| 56 | + assert chunk["choices"][0]["delta"]["content"] == content |
| 57 | + return json.dumps(chunk) |
| 58 | + |
| 59 | + def create_raw_text_response(self, content, perf_ns=1000000) -> MagicMock: |
| 60 | + """Create a mock TextResponse with specified content.""" |
| 61 | + text_response = MagicMock(spec=TextResponse) |
| 62 | + text_response.text = content |
| 63 | + text_response.perf_ns = perf_ns |
| 64 | + return text_response |
| 65 | + |
| 66 | + def create_text_response(self, content, perf_ns=1000000) -> MagicMock: |
| 67 | + """Create a mock TextResponse with specified content.""" |
| 68 | + text_response = MagicMock(spec=TextResponse) |
| 69 | + text_response.text = self.chat_completion_json(content) |
| 70 | + text_response.perf_ns = perf_ns |
| 71 | + return text_response |
| 72 | + |
| 73 | + def create_sse_message(self, chunks, perf_ns=2000000) -> MagicMock: |
| 74 | + """Create a mock SSEMessage with specified chunk contents.""" |
| 75 | + sse_message = MagicMock(spec=SSEMessage) |
| 76 | + if isinstance(chunks, str): |
| 77 | + # Single chunk |
| 78 | + sse_message.extract_data_content.return_value = [ |
| 79 | + self.chat_completion_chunk_json(chunks) |
| 80 | + ] |
| 81 | + else: |
| 82 | + # Multiple chunks |
| 83 | + sse_message.extract_data_content.return_value = [ |
| 84 | + self.chat_completion_chunk_json(chunk) for chunk in chunks |
| 85 | + ] |
| 86 | + sse_message.perf_ns = perf_ns |
| 87 | + return sse_message |
| 88 | + |
| 89 | + def create_request_record(self, *responses) -> MagicMock: |
| 90 | + """Create a mock RequestRecord with specified responses.""" |
| 91 | + record = MagicMock(spec=RequestRecord) |
| 92 | + record.responses = list(responses) |
| 93 | + return record |
| 94 | + |
| 95 | + @pytest.mark.parametrize("text", ["[DONE]", "", None]) |
| 96 | + def test_parse_text_returns_none(self, extractor, text): |
| 97 | + """Test that _parse_text returns None for '[DONE]' marker, empty string, and None.""" |
| 98 | + result = extractor._parse_text(text) |
| 99 | + assert result is None |
| 100 | + |
| 101 | + @pytest.mark.parametrize("content", ["", None]) |
| 102 | + def test_parse_text_with_empty_content_returns_none(self, extractor, content): |
| 103 | + """Test that valid chat completion with empty/null content returns None.""" |
| 104 | + chat_completion_json = self.chat_completion_json(content) |
| 105 | + |
| 106 | + result = extractor._parse_text(chat_completion_json) |
| 107 | + assert result is None |
| 108 | + |
| 109 | + @pytest.mark.parametrize("content", ["", None]) |
| 110 | + def test_parse_text_with_empty_chunk_content_returns_none(self, extractor, content): |
| 111 | + """Test that valid chat completion chunk with empty/null delta content returns None.""" |
| 112 | + chunk_json = self.chat_completion_chunk_json(content) |
| 113 | + |
| 114 | + result = extractor._parse_text(chunk_json) |
| 115 | + assert result is None |
| 116 | + |
| 117 | + def test_parse_text_with_valid_content_returns_content(self, extractor): |
| 118 | + """Test that valid chat completion with actual content returns the content.""" |
| 119 | + test_content = "Hello, how can I help you?" |
| 120 | + chat_completion_json = self.chat_completion_json(test_content) |
| 121 | + |
| 122 | + result = extractor._parse_text(chat_completion_json) |
| 123 | + assert result == test_content |
| 124 | + |
| 125 | + def test_parse_text_with_valid_chunk_content_returns_content(self, extractor): |
| 126 | + """Test that valid chat completion chunk with actual delta content returns the content.""" |
| 127 | + test_content = "Stream chunk content" |
| 128 | + chunk_json = self.chat_completion_chunk_json(test_content) |
| 129 | + |
| 130 | + result = extractor._parse_text(chunk_json) |
| 131 | + assert result == test_content |
| 132 | + |
| 133 | + def test_parse_text_response_with_empty_content_returns_none(self, extractor): |
| 134 | + """Test that TextResponse with empty content is ignored.""" |
| 135 | + text_response = self.create_raw_text_response("") |
| 136 | + |
| 137 | + result = extractor._parse_text_response(text_response) |
| 138 | + assert result is None |
| 139 | + |
| 140 | + def test_parse_text_response_with_valid_content_returns_response_data( |
| 141 | + self, extractor |
| 142 | + ): |
| 143 | + """Test that TextResponse with valid content returns ResponseData.""" |
| 144 | + test_content = "Valid response" |
| 145 | + text_response = self.create_text_response(test_content) |
| 146 | + |
| 147 | + result = extractor._parse_text_response(text_response) |
| 148 | + |
| 149 | + assert result is not None |
| 150 | + assert isinstance(result, ResponseData) |
| 151 | + assert result.parsed_text == [test_content] |
| 152 | + assert result.perf_ns == 1000000 |
| 153 | + |
| 154 | + def test_parse_sse_response_with_empty_chunks_returns_none(self, extractor): |
| 155 | + """Test that SSEMessage with empty chunks is ignored.""" |
| 156 | + sse_message = self.create_sse_message("") |
| 157 | + |
| 158 | + result = extractor._parse_sse_response(sse_message) |
| 159 | + assert result is None |
| 160 | + |
| 161 | + def test_parse_sse_response_with_mixed_chunks_filters_empty(self, extractor): |
| 162 | + """Test that SSEMessage filters out empty chunks but keeps valid ones.""" |
| 163 | + sse_message = self.create_sse_message(["", "Valid chunk"]) |
| 164 | + |
| 165 | + result = extractor._parse_sse_response(sse_message) |
| 166 | + |
| 167 | + assert result is not None |
| 168 | + assert isinstance(result, ResponseData) |
| 169 | + assert result.parsed_text == ["Valid chunk"] |
| 170 | + assert result.perf_ns == 2000000 |
| 171 | + |
| 172 | + @pytest.mark.asyncio |
| 173 | + async def test_extract_response_data_filters_empty_responses(self, extractor): |
| 174 | + """Test that extract_response_data filters out responses with empty content.""" |
| 175 | + request = self.create_request_record( |
| 176 | + self.create_raw_text_response("", perf_ns=1000000), # Raw empty text |
| 177 | + self.create_text_response("Valid response", perf_ns=2000000), |
| 178 | + ) |
| 179 | + |
| 180 | + results = await extractor.extract_response_data(request, None) |
| 181 | + |
| 182 | + # Should only return the valid response, empty one should be filtered out |
| 183 | + assert len(results) == 1 |
| 184 | + assert results[0].parsed_text == ["Valid response"] |
| 185 | + assert results[0].perf_ns == 2000000 |
| 186 | + |
| 187 | + @pytest.mark.asyncio |
| 188 | + async def test_extract_response_data_handles_mixed_response_types(self, extractor): |
| 189 | + """Test that extract_response_data handles mixed TextResponse and SSEMessage types.""" |
| 190 | + request = self.create_request_record( |
| 191 | + self.create_text_response("Text response", perf_ns=1000000), |
| 192 | + self.create_sse_message("SSE chunk", perf_ns=2000000), |
| 193 | + ) |
| 194 | + |
| 195 | + results = await extractor.extract_response_data(request, None) |
| 196 | + |
| 197 | + # Should return both responses |
| 198 | + assert len(results) == 2 |
| 199 | + assert results[0].parsed_text == ["Text response"] |
| 200 | + assert results[0].perf_ns == 1000000 |
| 201 | + assert results[1].parsed_text == ["SSE chunk"] |
| 202 | + assert results[1].perf_ns == 2000000 |
| 203 | + |
| 204 | + @pytest.mark.asyncio |
| 205 | + async def test_extract_response_data_with_complex_sse_filtering(self, extractor): |
| 206 | + """Test extract_response_data with complex SSE message filtering.""" |
| 207 | + request = self.create_request_record( |
| 208 | + self.create_text_response("Valid text response", perf_ns=1000000), |
| 209 | + self.create_sse_message( |
| 210 | + ["", "Valid chunk 1", "", "Valid chunk 2"], perf_ns=3000000 |
| 211 | + ), |
| 212 | + self.create_raw_text_response("", perf_ns=4000000), # Should be filtered |
| 213 | + ) |
| 214 | + |
| 215 | + results = await extractor.extract_response_data(request, None) |
| 216 | + |
| 217 | + # Should return text response + filtered SSE response (empty raw_text filtered out) |
| 218 | + assert len(results) == 2 |
| 219 | + assert results[0].parsed_text == ["Valid text response"] |
| 220 | + assert results[0].perf_ns == 1000000 |
| 221 | + assert results[1].parsed_text == [ |
| 222 | + "Valid chunk 1", |
| 223 | + "Valid chunk 2", |
| 224 | + ] # Empty chunks filtered |
| 225 | + assert results[1].perf_ns == 3000000 |
0 commit comments