Skip to content

Commit 09d137b

Browse files
committed
upd
Signed-off-by: Lancer <maruixiang6688@gmail.com>
1 parent abb1848 commit 09d137b

File tree

1 file changed

+40
-10
lines changed

1 file changed

+40
-10
lines changed

vllm_omni/diffusion/models/flux_kontext/pipeline_flux_kontext.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import json
1111
import os
1212
from collections.abc import Callable, Iterable
13-
from typing import Any
13+
from typing import Any, cast
1414

1515
import numpy as np
1616
import PIL.Image
@@ -571,14 +571,44 @@ def forward(
571571
max_sequence_length: int = 512,
572572
sigmas: list[float] | None = None,
573573
) -> DiffusionOutput:
574-
prompt = req.prompt if req.prompt is not None else prompt
575-
image = req.pil_image if req.pil_image is not None else image
576-
height = req.height or height
577-
width = req.width or width
578-
num_inference_steps = req.num_inference_steps or num_inference_steps
579-
guidance_scale = req.guidance_scale or guidance_scale
580-
generator = req.generator or generator
581-
latents = req.latents or latents
574+
# Handle multiple prompts - only take the first one, similar to Flux2KleinPipeline
575+
if len(req.prompts) > 1:
576+
logger.warning(
577+
"This model only supports a single prompt, not a batched request.",
578+
"Taking only the first prompt for now.",
579+
)
580+
first_prompt = req.prompts[0] if req.prompts else None
581+
prompt = (
582+
first_prompt
583+
if isinstance(first_prompt, str)
584+
else (first_prompt.get("prompt") or "")
585+
if first_prompt
586+
else prompt
587+
)
588+
589+
# Handle image from prompt data
590+
if (
591+
raw_image := None
592+
if isinstance(first_prompt, str)
593+
else first_prompt.get("multi_modal_data", {}).get("image")
594+
if first_prompt
595+
else None
596+
) is None:
597+
pass # use image from param list
598+
elif isinstance(raw_image, list):
599+
image = [PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image, im) for im in raw_image]
600+
else:
601+
image = PIL.Image.open(raw_image) if isinstance(raw_image, str) else cast(PIL.Image.Image, raw_image)
602+
height = req.sampling_params.height or height
603+
width = req.sampling_params.width or width
604+
num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps
605+
guidance_scale = req.sampling_params.guidance_scale or guidance_scale
606+
generator = req.sampling_params.generator or generator
607+
latents = (
608+
req.sampling_params.extra_args.get("latents")
609+
if req.sampling_params.extra_args.get("latents") is not None
610+
else latents
611+
)
582612
height = height or self.default_sample_size * self.vae_scale_factor
583613
width = width or self.default_sample_size * self.vae_scale_factor
584614

@@ -645,7 +675,7 @@ def forward(
645675
image_height = image_height // multiple_of * multiple_of
646676
image = self.image_processor.resize(image, image_height, image_width)
647677
image = self.image_processor.preprocess(image, image_height, image_width)
648-
if req.height is None and req.width is None:
678+
if height is None and width is None:
649679
height = image_height
650680
width = image_width
651681

0 commit comments

Comments
 (0)