 import json
 import os
 from collections.abc import Callable, Iterable
-from typing import Any
+from typing import Any, cast

 import numpy as np
 import PIL.Image
@@ -571,14 +571,36 @@ def forward(
     max_sequence_length: int = 512,
     sigmas: list[float] | None = None,
 ) -> DiffusionOutput:
-    prompt = req.prompt if req.prompt is not None else prompt
-    image = req.pil_image if req.pil_image is not None else image
-    height = req.height or height
-    width = req.width or width
-    num_inference_steps = req.num_inference_steps or num_inference_steps
-    guidance_scale = req.guidance_scale or guidance_scale
-    generator = req.generator or generator
-    latents = req.latents or latents
+    # Handle multiple prompts - only take the first one, similar to Flux2KleinPipeline.
+    if len(req.prompts) > 1:
+        logger.warning(
+            "This model only supports a single prompt, not a batched request. "
+            "Taking only the first prompt for now."
+        )
+    first_prompt = req.prompts[0] if req.prompts else None
+    if isinstance(first_prompt, str):
+        prompt = first_prompt
+    elif first_prompt:
+        prompt = first_prompt.get("prompt") or ""
+
+    # Handle image from prompt data; fall back to the `image` argument when absent.
+    raw_image = None
+    if first_prompt and not isinstance(first_prompt, str):
+        raw_image = first_prompt.get("multi_modal_data", {}).get("image")
+    if isinstance(raw_image, list):
+        image = [
+            PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image, im)
+            for im in raw_image
+        ]
+    elif raw_image is not None:
+        image = PIL.Image.open(raw_image) if isinstance(raw_image, str) else cast(PIL.Image.Image, raw_image)
+
+    height = req.sampling_params.height or height
+    width = req.sampling_params.width or width
+    num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps
+    guidance_scale = req.sampling_params.guidance_scale or guidance_scale
+    generator = req.sampling_params.generator or generator
+    extra_latents = req.sampling_params.extra_args.get("latents")
+    latents = extra_latents if extra_latents is not None else latents
     height = height or self.default_sample_size * self.vae_scale_factor
     width = width or self.default_sample_size * self.vae_scale_factor

@@ -645,7 +675,7 @@ def forward( |
     image_height = image_height // multiple_of * multiple_of
     image = self.image_processor.resize(image, image_height, image_width)
     image = self.image_processor.preprocess(image, image_height, image_width)
-    if req.height is None and req.width is None:
+    if height is None and width is None:
         height = image_height
         width = image_width

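For reference, the request-to-argument mapping in the first hunk can be exercised outside the pipeline. Below is a minimal sketch, assuming `prompts` entries are either plain strings or dicts carrying a "prompt" key and optional "multi_modal_data" (the shape implied by this diff); the helper name `extract_prompt_and_image` is hypothetical, and the per-element list branch from the diff is omitted for brevity.

# Standalone sketch of the first-prompt/image extraction above (hypothetical helper).
import PIL.Image


def extract_prompt_and_image(prompts: list) -> tuple[str | None, object]:
    first = prompts[0] if prompts else None
    prompt, image = None, None
    if isinstance(first, str):
        prompt = first
    elif first:
        prompt = first.get("prompt") or ""
        raw = first.get("multi_modal_data", {}).get("image")
        if isinstance(raw, str):
            image = PIL.Image.open(raw)  # treat strings as file paths
        else:
            image = raw  # assume an already-decoded PIL image (or None)
    return prompt, image


# A plain string prompt maps through unchanged, with no image.
assert extract_prompt_and_image(["a cat"]) == ("a cat", None)

A list-valued "image" entry follows the same per-element pattern, as the diff's `isinstance(raw_image, list)` branch shows.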