diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index 1e1b6a401f51..b034609ce935 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -170,6 +170,18 @@ def _build_tool_strict_guided_decoding_params(tools, tool_parser_name): by_alias=True, exclude_none=True)) +def _normalize_image_output(image) -> list: + """Normalize image output to a list of individual images. + + Handles single tensors, batched 4D tensors, and lists. + """ + if isinstance(image, list): + return image + if hasattr(image, "dim") and image.dim() == 4: + return [image[i] for i in range(image.shape[0])] + return [image] + + class OpenAIServer: def __init__( @@ -1665,22 +1677,16 @@ async def openai_image_generation(self, request: ImageGenerationRequest, ) # Build response - output_images = output.image - MediaStorage.save_image( - output_images, - self.media_storage_path / f"{image_id}.png", - ) - - if not isinstance(output_images, list): - output_images = [output_images] + output_images = _normalize_image_output(output.image) if request.response_format == "b64_json": data = [ - ImageObject(b64_json=base64.b64encode( - MediaStorage.convert_image_to_bytes(image)).decode( - 'utf-8'), - revised_prompt=request.prompt) - for image in output_images + ImageObject( + b64_json=base64.b64encode( + MediaStorage.convert_image_to_bytes(image)).decode( + 'utf-8'), + revised_prompt=request.prompt, + ) for image in output_images ] response = ImageGenerationResponse( @@ -1690,6 +1696,10 @@ async def openai_image_generation(self, request: ImageGenerationRequest, ) elif request.response_format == "url": + MediaStorage.save_image( + output_images[0], + self.media_storage_path / f"{image_id}.png", + ) # TODO: Support URL mode return self._create_not_supported_error( "URL mode is not supported for image generation") @@ -1732,23 +1742,17 @@ async def openai_image_edit(self, request: ImageEditRequest, ) # Build response - output_images = output.image - MediaStorage.save_image( - output_images, - self.media_storage_path / f"{image_id}.png", - ) - - if not isinstance(output_images, list): - output_images = [output_images] + output_images = _normalize_image_output(output.image) response = ImageGenerationResponse( created=int(time.time()), data=[ - ImageObject(b64_json=base64.b64encode( - MediaStorage.convert_image_to_bytes(image)).decode( - 'utf-8'), - revised_prompt=request.prompt) - for image in output_images + ImageObject( + b64_json=base64.b64encode( + MediaStorage.convert_image_to_bytes(image)).decode( + 'utf-8'), + revised_prompt=request.prompt, + ) for image in output_images ], size=f"{params.width}x{params.height}", )