 from djl_python.inputs import Input
 from djl_python.outputs import Output
 from djl_python.encode_decode import decode
-from djl_python.async_utils import handle_streaming_response, create_non_stream_output, _extract_lora_adapter
+from djl_python.async_utils import handle_streaming_response, create_non_stream_output, create_stream_chunk_output, _extract_lora_adapter
 from djl_python.custom_formatter_handling import CustomFormatterHandler, CustomFormatterError
 from djl_python.custom_handler_service import CustomHandlerService
 from djl_python.rolling_batch.rolling_batch_vllm_utils import create_lora_request, get_lora_request
@@ -162,6 +162,14 @@ async def initialize(self, properties: dict):
         self.session_manager: SessionManager = SessionManager(properties)
         self.initialized = True

+    def _get_custom_formatter(self, adapter_name: Optional[str] = None) -> bool:
+        """Check if a custom output formatter exists for the adapter or base model."""
+        if adapter_name:
+            adapter_formatter = self.get_adapter_formatter_handler(adapter_name)
+            if adapter_formatter and adapter_formatter.output_formatter:
+                return True
+        return self.output_formatter is not None
+
     def preprocess_request(self, inputs: Input) -> ProcessedRequest:
         batch = inputs.get_batches()
         assert len(batch) == 1, "only one request per batch allowed"
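
Note: a minimal sketch of the precedence the new _get_custom_formatter helper implements (adapter-specific formatter first, then the base-model formatter). Only the attribute and method names come from the diff; the stub classes and sample values below are assumptions for illustration.

# Hypothetical stand-ins for the handler service and adapter formatter handlers.
class _Handler:
    def __init__(self, output_formatter=None):
        self.output_formatter = output_formatter

class _Service:
    def __init__(self, base_formatter, adapter_handlers):
        self.output_formatter = base_formatter    # base-model formatter, may be None
        self._adapters = adapter_handlers         # adapter name -> handler

    def get_adapter_formatter_handler(self, name):
        return self._adapters.get(name)

    def _get_custom_formatter(self, adapter_name=None) -> bool:
        # Adapter-specific formatter wins; otherwise fall back to the base formatter.
        if adapter_name:
            handler = self.get_adapter_formatter_handler(adapter_name)
            if handler and handler.output_formatter:
                return True
        return self.output_formatter is not None

svc = _Service(base_formatter=None, adapter_handlers={"my-lora": _Handler(lambda r: r)})
assert svc._get_custom_formatter("my-lora") is True      # adapter formatter found
assert svc._get_custom_formatter("other-lora") is False  # no adapter or base formatter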
@@ -255,50 +263,67 @@ async def check_health(self):
             logger.fatal("vLLM engine is dead, terminating process")
             kill_process_tree(os.getpid())

-    async def inference(
-            self,
-            inputs: Input) -> Union[Output, AsyncGenerator[Output, None]]:
+    async def inference(self, inputs: Input) -> Union[Output, AsyncGenerator[Output, None]]:
         await self.check_health()
         try:
             processed_request = self.preprocess_request(inputs)
         except CustomFormatterError as e:
             logger.exception("Custom formatter failed")
-            output = create_non_stream_output(
+            return create_non_stream_output(
                 "", error=f"Custom formatter failed: {str(e)}", code=424)
-            return output
         except Exception as e:
             logger.exception("Input parsing failed")
-            output = create_non_stream_output(
+            return create_non_stream_output(
                 "", error=f"Input parsing failed: {str(e)}", code=424)
-            return output

         # vLLM will extract the adapter from the request object via _maybe_get_adapters()
         response = await processed_request.inference_invoker(
             processed_request.vllm_request)

+        # Check if custom formatter exists (applies to both streaming and non-streaming)
+        custom_formatter = self._get_custom_formatter(processed_request.adapter_name)
+
         if isinstance(response, types.AsyncGeneratorType):
-            # Apply streaming output formatter (adapter-specific or base model)
-            response = self.apply_output_formatter_streaming_raw(
+            return self._handle_streaming_response(response, processed_request, custom_formatter)
+
+        # Non-streaming response
+        if custom_formatter:
+            formatted_response = self.apply_output_formatter(
                 response, adapter_name=processed_request.adapter_name)
+            # If custom formatter returns a Pydantic model, serialize it
+            if hasattr(formatted_response, 'model_dump_json'):
+                formatted_response = formatted_response.model_dump_json()
+            elif hasattr(formatted_response, 'model_dump'):
+                formatted_response = formatted_response.model_dump()
+            return create_non_stream_output(formatted_response)
+
+        # LMI formatter for non-streaming
+        return processed_request.non_stream_output_formatter(
+            response,
+            request=processed_request.vllm_request,
+            tokenizer=self.tokenizer,
+        )

-        return handle_streaming_response(
+    async def _handle_streaming_response(self, response, processed_request, custom_formatter):
+        """Handle streaming responses as an async generator"""
+        if custom_formatter:
+            # Custom formatter: apply to each chunk and yield directly
+            async for chunk in response:
+                formatted_chunk = self.apply_output_formatter(
+                    chunk, adapter_name=processed_request.adapter_name)
+                yield create_stream_chunk_output(formatted_chunk, last_chunk=False)
+            yield create_stream_chunk_output("", last_chunk=True)
+        else:
+            # LMI formatter for streaming
+            async for output in handle_streaming_response(
                 response,
                 processed_request.stream_output_formatter,
                 request=processed_request.vllm_request,
                 accumulate_chunks=processed_request.accumulate_chunks,
                 include_prompt=processed_request.include_prompt,
                 tokenizer=self.tokenizer,
-        )
-
-        # Apply output formatter (adapter-specific or base model)
-        response = self.apply_output_formatter(
-            response, adapter_name=processed_request.adapter_name)
-
-        return processed_request.non_stream_output_formatter(
-            response,
-            request=processed_request.vllm_request,
-            tokenizer=self.tokenizer,
-        )
+            ):
+                yield output

     async def add_lora(self, lora_name: str, lora_alias: str, lora_path: str):
         logging.info(f"Adding LoRA {lora_name} from {lora_path}")
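
As a quick sanity check of the new control flow, here is a hedged sketch of how a caller might consume what inference() now returns: a single Output for non-streaming requests, or an async generator of per-chunk Output objects (built with create_stream_chunk_output) for streaming ones. The service/input wiring and the send_to_client hook are assumptions for illustration only.

import asyncio
import types

async def consume(service, request_input):
    result = await service.inference(request_input)
    if isinstance(result, types.AsyncGeneratorType):
        # Streaming: one Output per chunk; the final chunk is flagged upstream
        # (see last_chunk=True in the custom-formatter branch above).
        async for chunk_output in result:
            send_to_client(chunk_output)
    else:
        # Non-streaming: a single Output (error or formatted response).
        send_to_client(result)

def send_to_client(output):
    # Hypothetical transport hook; a real frontend would write this back to the client.
    print(output)

# asyncio.run(consume(service, inputs))  # wiring of service/inputs omitted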