diff --git a/vllm_omni/diffusion/models/flux/pipeline_flux.py b/vllm_omni/diffusion/models/flux/pipeline_flux.py index b90aaa8ca4..54581876fa 100644 --- a/vllm_omni/diffusion/models/flux/pipeline_flux.py +++ b/vllm_omni/diffusion/models/flux/pipeline_flux.py @@ -573,51 +573,53 @@ def check_cfg_parallel_validity(self, true_cfg_scale: float, has_neg_prompt: boo def forward( self, req: OmniDiffusionRequest, - prompt: str | list[str] | None = None, prompt_2: str | list[str] | None = None, - negative_prompt: str | list[str] | None = None, negative_prompt_2: str | list[str] | None = None, - true_cfg_scale: float = 1.0, - height: int | None = None, - width: int | None = None, - num_inference_steps: int = 28, - sigmas: list[float] | None = None, - guidance_scale: float = 3.5, - num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.FloatTensor | None = None, - prompt_embeds: torch.FloatTensor | None = None, pooled_prompt_embeds: torch.FloatTensor | None = None, - negative_prompt_embeds: torch.FloatTensor | None = None, negative_pooled_prompt_embeds: torch.FloatTensor | None = None, output_type: str | None = "pil", return_dict: bool = True, joint_attention_kwargs: dict[str, Any] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, ): """Forward pass for flux.""" # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") # TODO: May be some data formatting operations on the API side. Hack for now. - prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] + + # For the negative prompt, set it to None if ALL entries are None, so it is falsy and CFG is skipped. + # If only some entries are None, replace just those with empty strings, since CFG cannot be skipped in that case. 
if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): negative_prompt = None elif req.prompts: negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] + req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts] + if any(p is not None for p in req_prompt_embeds): + # If at least one prompt is provided as an embedding, + # Then assume that the user wants to provide embeddings for all prompts, and enter this if block + # If the user in fact provides mixed input format, req_prompt_embeds will have some None's + # And `torch.stack` automatically raises an exception for us + prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError + + req_negative_prompt_embeds = [ + p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts + ] + if any(p is not None for p in req_negative_prompt_embeds): + negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError + height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps - sigmas = req.sampling_params.sigmas or sigmas - guidance_scale = ( - req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale - ) - generator = req.sampling_params.generator or generator + true_cfg_scale = req.sampling_params.true_cfg_scale or 1.0 + num_inference_steps = req.sampling_params.num_inference_steps or 28 + sigmas = req.sampling_params.sigmas + max_sequence_length = req.sampling_params.max_sequence_length or 512 + guidance_scale = req.sampling_params.guidance_scale or 3.5 + generator = req.sampling_params.generator num_images_per_prompt = ( - req.sampling_params.num_outputs_per_prompt - if req.sampling_params.num_outputs_per_prompt > 0 - else num_images_per_prompt + req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1 ) + latents = req.sampling_params.latents # 1. Check inputs. Raise error if not correct self.check_inputs( diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py index e1ef706c3f..eaf9a2b7f8 100644 --- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -648,72 +648,19 @@ def interrupt(self): def forward( self, req: OmniDiffusionRequest, - image: PIL.Image.Image | list[PIL.Image.Image] | None = None, - prompt: str | list[str] | None = None, - height: int | None = None, - width: int | None = None, - num_inference_steps: int = 50, - sigmas: list[float] | None = None, - guidance_scale: float | None = 4.0, - num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.Tensor | None = None, - prompt_embeds: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, output_type: str | None = "pil", return_dict: bool = True, attention_kwargs: dict[str, Any] | None = None, callback_on_step_end: Callable[[int, int, dict], None] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, text_encoder_out_layers: tuple[int, ...] 
= (9, 18, 27), ) -> DiffusionOutput: r""" Function invoked when calling the pipeline for generation. Args: - image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or list of these): - `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both - numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list - or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a - list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image - latents as `image`, but if passing latents directly it is not encoded again. - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - guidance_scale (`float`, *optional*, defaults to 4.0): - Guidance scale as defined in [Classifier-Free Diffusion - Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. - of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting - `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to - the text `prompt`, usually at the expense of lower image quality. For step-wise distilled models, - `guidance_scale` is ignored. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. This is set to 1024 by default for the best results. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. This is set to 1024 by default for the best results. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - sigmas (`List[float]`, *optional*): - Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in - their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed - will be used. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.Tensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will be generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Note that "" is used as the negative prompt in this pipeline. - If not provided, will be generated from "". + req (`OmniDiffusionRequest`): + The request object containing the prompts and sampling parameters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. 
Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -732,7 +679,6 @@ def forward( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. - max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. text_encoder_out_layers (`Tuple[int]`): Layer indices to use in the `text_encoder` to derive the final prompt embeddings. @@ -751,46 +697,32 @@ def forward( first_prompt = req.prompts[0] prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "") + prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("prompt_embeds") + negative_prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt_embeds") + if ( raw_image := None if isinstance(first_prompt, str) else first_prompt.get("multi_modal_data", {}).get("image") ) is None: - pass # use image from param list + image = None elif isinstance(raw_image, list): image = [PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image, im) for im in raw_image] else: image = PIL.Image.open(raw_image) if isinstance(raw_image, str) else cast(PIL.Image.Image, raw_image) - height = req.sampling_params.height or height - width = req.sampling_params.width or width - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps - sigmas = req.sampling_params.sigmas or sigmas - guidance_scale = ( - req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale - ) - generator = req.sampling_params.generator or generator + height = req.sampling_params.height + width = req.sampling_params.width + num_inference_steps = req.sampling_params.num_inference_steps or 50 + sigmas = req.sampling_params.sigmas + guidance_scale = req.sampling_params.guidance_scale or 4.0 + generator = req.sampling_params.generator num_images_per_prompt = ( - req.sampling_params.num_outputs_per_prompt - if req.sampling_params.num_outputs_per_prompt > 0 - else num_images_per_prompt + req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1 ) - max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length + max_sequence_length = req.sampling_params.max_sequence_length or 512 text_encoder_out_layers = req.sampling_params.extra_args.get("text_encoder_out_layers", text_encoder_out_layers) - - req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts] - if any(p is not None for p in req_prompt_embeds): - # If at list one prompt is provided as an embedding, - # Then assume that the user wants to provide embeddings for all prompts, and enter this if block - # If the user in fact provides mixed input format, req_prompt_embeds will have some None's - # And `torch.stack` automatically raises an exception for us - prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError - - req_negative_prompt_embeds = [ - p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts - ] - if any(p is not None for p in req_negative_prompt_embeds): - negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError + latents = req.sampling_params.latents # 1. 
Check inputs. Raise error if not correct self.check_inputs( diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py index 09f409f313..94476fbb3e 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py @@ -473,18 +473,6 @@ def cfg_normalize_function(self, noise_pred, comb_pred, cfg_renorm_min=0.0): def forward( self, req: OmniDiffusionRequest, - prompt: str | list[str] | None = None, - negative_prompt: str | list[str] | None = None, - height: int | None = None, - width: int | None = None, - num_inference_steps: int = 50, - sigmas: list[float] | None = None, - guidance_scale: float = 4.5, - num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.FloatTensor | None = None, - prompt_embeds: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, output_type: str | None = "pil", return_dict: bool = True, joint_attention_kwargs: dict[str, Any] | None = None, @@ -494,32 +482,15 @@ def forward( ) -> DiffusionOutput: # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") # TODO: May be some data formatting operations on the API side. Hack for now. - prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): negative_prompt = None elif req.prompts: negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] - height = req.sampling_params.height or height or self.default_sample_size * self.vae_scale_factor - width = req.sampling_params.width or width or self.default_sample_size * self.vae_scale_factor - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps - sigmas = req.sampling_params.sigmas or sigmas - generator = req.sampling_params.generator or generator - guidance_scale = ( - req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale - ) - num_images_per_prompt = ( - req.sampling_params.num_outputs_per_prompt - if req.sampling_params.num_outputs_per_prompt is not None - else num_images_per_prompt - ) - enable_prompt_rewrite = req.sampling_params.extra_args.get("enable_prompt_rewrite", enable_prompt_rewrite) - enable_cfg_renorm = req.sampling_params.extra_args.get("enable_cfg_renorm", enable_cfg_renorm) - cfg_renorm_min = req.sampling_params.extra_args.get("cfg_renorm_min", cfg_renorm_min) - req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts] if any(p is not None for p in req_prompt_embeds): - # If at list one prompt is provided as an embedding, + # If at least one prompt is provided as an embedding, # Then assume that the user wants to provide embeddings for all prompts, and enter this if block # If the user in fact provides mixed input format, req_prompt_embeds will have some None's # And `torch.stack` automatically raises an exception for us @@ -531,6 +502,20 @@ def forward( if any(p is not None for p in req_negative_prompt_embeds): negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError + height = req.sampling_params.height or 
self.default_sample_size * self.vae_scale_factor + width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor + num_inference_steps = req.sampling_params.num_inference_steps or 50 + sigmas = req.sampling_params.sigmas + generator = req.sampling_params.generator + guidance_scale = req.sampling_params.guidance_scale or 4.5 + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1 + ) + latents = req.sampling_params.latents + enable_prompt_rewrite = req.sampling_params.extra_args.get("enable_prompt_rewrite", enable_prompt_rewrite) + enable_cfg_renorm = req.sampling_params.extra_args.get("enable_cfg_renorm", enable_cfg_renorm) + cfg_renorm_min = req.sampling_params.extra_args.get("cfg_renorm_min", cfg_renorm_min) + self.check_inputs( prompt, height, diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py index 3ba5e34488..2202ae3c1d 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py @@ -522,17 +522,6 @@ def check_inputs( def forward( self, req: OmniDiffusionRequest, - image: PIL.Image.Image | torch.Tensor | None = None, - prompt: str | list[str] | None = None, - negative_prompt: str | list[str] | None = None, - num_inference_steps: int = 50, - sigmas: list[float] | None = None, - guidance_scale: float = 3.5, - num_images_per_prompt: int | None = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.FloatTensor | None = None, - prompt_embeds: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, output_type: str | None = "pil", return_dict: bool = True, joint_attention_kwargs: dict[str, Any] | None = None, @@ -548,21 +537,18 @@ def forward( prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "") negative_prompt = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt") prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("prompt_embeds") - negative_prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt_embeds") # type: ignore # Why it is list[torch.Tensor] in OmniTokenInputs or OmniEmbedsPrompt? 
Doesn't make sense + negative_prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt_embeds") - sigmas = req.sampling_params.sigmas or sigmas - guidance_scale = ( - req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale - ) - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + sigmas = req.sampling_params.sigmas + guidance_scale = req.sampling_params.guidance_scale or 3.5 + num_inference_steps = req.sampling_params.num_inference_steps or 50 num_images_per_prompt = ( - req.sampling_params.num_outputs_per_prompt - if req.sampling_params.num_outputs_per_prompt is not None - else num_images_per_prompt + req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1 ) - generator = req.sampling_params.generator or generator + generator = req.sampling_params.generator height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor + latents = req.sampling_params.latents if prompt is not None: batch_size = 1 if isinstance(prompt, str) else len(prompt) @@ -572,11 +558,24 @@ def forward( if not isinstance(first_prompt, str) and "preprocessed_image" in ( additional_information := first_prompt.get("additional_information", {}) ): + # Using preprocessed image prompt_image = additional_information.get("prompt_image") image = additional_information.get("preprocessed_image") calculated_height = additional_information.get("calculated_height", height) calculated_width = additional_information.get("calculated_width", width) else: + # Using original image + if ( + raw_image := None + if isinstance(first_prompt, str) + else first_prompt.get("multi_modal_data", {}).get("image") + ) is None: + image = None + elif isinstance(raw_image, list): + image = [PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image, im) for im in raw_image] + else: + image = PIL.Image.open(raw_image) if isinstance(raw_image, str) else cast(PIL.Image.Image, raw_image) + image_size = image[0].size if isinstance(image, list) else image.size calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] * 1.0 / image_size[1]) diff --git a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py index 963f1c483b..f17f448bdd 100644 --- a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py +++ b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py @@ -527,66 +527,18 @@ def interrupt(self): def forward( self, req: OmniDiffusionRequest, - prompt: str | list[str] | None = None, - negative_prompt: str | list[str] | None = None, - guidance_scale: float = 5.0, - height: int | None = None, - width: int | None = None, - num_inference_steps: int = 50, - sigmas: list[float] | None = None, - num_images_per_prompt: int | None = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.FloatTensor | None = None, - prompt_embeds: torch.FloatTensor | None = None, - negative_prompt_embeds: torch.FloatTensor | None = None, output_type: str | None = "pil", return_dict: bool = True, joint_attention_kwargs: dict[str, Any] | None = None, callback_on_step_end: Callable[[int, int, dict], None] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 256, ) -> DiffusionOutput: r""" Function invoked when 
calling the pipeline for generation. Args: - prompt (`str` or `list[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - negative_prompt (`str` or `list[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - not greater than `1`). - guidance_scale (`float`, *optional*, defaults to 1.0): - True classifier-free guidance (guidance scale) is enabled when `guidance_scale` > 1 and - `negative_prompt` is provided. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. This is set to 1024 by default for the best results. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. This is set to 1024 by default for the best results. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - sigmas (`list[float]`, *optional*): - Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in - their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed - will be used. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - generator (`torch.Generator` or `list[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will be generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. + req (`OmniDiffusionRequest`): + The request object containing the prompts and sampling parameters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -605,7 +557,6 @@ def forward( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. - max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. 
Examples: @@ -616,25 +567,37 @@ def forward( """ # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") # TODO: May be some data formatting operations on the API side. Hack for now. - prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): negative_prompt = None elif req.prompts: negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] + req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts] + if any(p is not None for p in req_prompt_embeds): + # If at least one prompt is provided as an embedding, + # Then assume that the user wants to provide embeddings for all prompts, and enter this if block + # If the user in fact provides mixed input format, req_prompt_embeds will have some None's + # And `torch.stack` automatically raises an exception for us + prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError + + req_negative_prompt_embeds = [ + p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts + ] + if any(p is not None for p in req_negative_prompt_embeds): + negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError + height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps - sigmas = req.sampling_params.sigmas or sigmas - guidance_scale = ( - req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale - ) - generator = req.sampling_params.generator or generator + num_inference_steps = req.sampling_params.num_inference_steps or 50 + sigmas = req.sampling_params.sigmas + guidance_scale = req.sampling_params.guidance_scale or 5.0 + generator = req.sampling_params.generator num_images_per_prompt = ( - req.sampling_params.num_outputs_per_prompt - if req.sampling_params.num_outputs_per_prompt > 0 - else num_images_per_prompt + req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1 ) + latents = req.sampling_params.latents + max_sequence_length = req.sampling_params.max_sequence_length or 256 # Steps: # 1. 
Check Inputs diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index d85d98b5bf..b96a1e21fc 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -535,48 +535,49 @@ def interrupt(self): def forward( self, req: OmniDiffusionRequest, - prompt: str | list[str] | None = None, - negative_prompt: str | list[str] | None = None, - true_cfg_scale: float = 4.0, - height: int | None = None, - width: int | None = None, - num_inference_steps: int = 50, - sigmas: list[float] | None = None, - guidance_scale: float = 1.0, - num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.Tensor | None = None, - prompt_embeds: torch.Tensor | None = None, prompt_embeds_mask: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, negative_prompt_embeds_mask: torch.Tensor | None = None, output_type: str | None = "pil", attention_kwargs: dict[str, Any] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, ) -> DiffusionOutput: # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") # TODO: May be some data formatting operations on the API side. Hack for now. - prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): negative_prompt = None elif req.prompts: negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] + req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts] + if any(p is not None for p in req_prompt_embeds): + # If at least one prompt is provided as an embedding, + # Then assume that the user wants to provide embeddings for all prompts, and enter this if block + # If the user in fact provides mixed input format, req_prompt_embeds will have some None's + # And `torch.stack` automatically raises an exception for us + prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError + + req_negative_prompt_embeds = [ + p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts + ] + if any(p is not None for p in req_negative_prompt_embeds): + negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError + height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps - sigmas = req.sampling_params.sigmas or sigmas - max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length - generator = req.sampling_params.generator or generator - true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale + num_inference_steps = req.sampling_params.num_inference_steps or 50 + sigmas = req.sampling_params.sigmas + max_sequence_length = req.sampling_params.max_sequence_length or 512 + generator = req.sampling_params.generator + true_cfg_scale = req.sampling_params.true_cfg_scale or 4.0 if 
req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale + else: + guidance_scale = 1.0 num_images_per_prompt = ( - req.sampling_params.num_outputs_per_prompt - if req.sampling_params.num_outputs_per_prompt > 0 - else num_images_per_prompt + req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1 ) + latents = req.sampling_params.latents # 1. check inputs # 2. encode prompts # 3. prepare latents and timesteps diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py index 082c9e7545..8d85122d04 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py @@ -600,26 +600,11 @@ def interrupt(self): def forward( self, req: OmniDiffusionRequest, - prompt: str | list[str] | None = None, - negative_prompt: str | list[str] | None = None, - image: PIL.Image.Image | torch.Tensor | None = None, - true_cfg_scale: float = 4.0, - height: int | None = None, - width: int | None = None, - num_inference_steps: int = 50, - sigmas: list[float] | None = None, - guidance_scale: float = 1.0, - num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.Tensor | None = None, - prompt_embeds: torch.Tensor | None = None, prompt_embeds_mask: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, negative_prompt_embeds_mask: torch.Tensor | None = None, output_type: str | None = "pil", attention_kwargs: dict[str, Any] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, ) -> DiffusionOutput: """Forward pass for image editing.""" # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") @@ -639,6 +624,8 @@ def forward( "Qwen official repository recommends to use whitespace string as negative_prompt. " "Note: some distilled variants may not be affected by this." 
) + prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("prompt_embeds") + negative_prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt_embeds") # Get preprocessed image from request (pre-processing is done in DiffusionEngine) if not isinstance(first_prompt, str) and "preprocessed_image" in ( @@ -652,10 +639,21 @@ def forward( width = req.sampling_params.width else: # fallback to run pre-processing in pipeline (debug only) + if ( + raw_image := None + if isinstance(first_prompt, str) + else first_prompt.get("multi_modal_data", {}).get("image") + ) is None: + image = None + elif isinstance(raw_image, list): + image = [PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image, im) for im in raw_image] + else: + image = PIL.Image.open(raw_image) if isinstance(raw_image, str) else cast(PIL.Image.Image, raw_image) + image_size = image[0].size if isinstance(image, list) else image.size calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) - height = height or calculated_height - width = width or calculated_width + height = req.sampling_params.height or calculated_height + width = req.sampling_params.width or calculated_width multiple_of = self.vae_scale_factor * 2 width = width // multiple_of * multiple_of @@ -667,18 +665,19 @@ def forward( image = self.image_processor.preprocess(image, calculated_height, calculated_width) image = image.unsqueeze(2) - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps - sigmas = req.sampling_params.sigmas or sigmas - max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length - generator = req.sampling_params.generator or generator - true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale + num_inference_steps = req.sampling_params.num_inference_steps or 50 + sigmas = req.sampling_params.sigmas + max_sequence_length = req.sampling_params.max_sequence_length or 512 + generator = req.sampling_params.generator + true_cfg_scale = req.sampling_params.true_cfg_scale or 4.0 if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale + else: + guidance_scale = 1.0 num_images_per_prompt = ( - req.sampling_params.num_outputs_per_prompt - if req.sampling_params.num_outputs_per_prompt > 0 - else num_images_per_prompt + req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1 ) + latents = req.sampling_params.latents # 1. check inputs # 2. 
encode prompts diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 2fdc7003ef..1c478aadce 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -531,26 +531,11 @@ def interrupt(self): def forward( self, req: OmniDiffusionRequest, - prompt: str | list[str] | None = None, - negative_prompt: str | list[str] | None = None, - image: PIL.Image.Image | list[PIL.Image.Image] | torch.Tensor | None = None, - true_cfg_scale: float = 4.0, - height: int | None = None, - width: int | None = None, - num_inference_steps: int = 50, - sigmas: list[float] | None = None, - guidance_scale: float = 1.0, - num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.Tensor | None = None, - prompt_embeds: torch.Tensor | None = None, prompt_embeds_mask: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, negative_prompt_embeds_mask: torch.Tensor | None = None, output_type: str | None = "pil", attention_kwargs: dict[str, Any] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, ) -> DiffusionOutput: """Forward pass for image editing with support for multiple images.""" # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") @@ -570,6 +555,8 @@ def forward( "Qwen official repository recommends to use whitespace string as negative_prompt. " "Note: some distilled variants may not be affected by this." ) + prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("prompt_embeds") + negative_prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt_embeds") # Get preprocessed images from request (pre-processing is done in DiffusionEngine) if ( @@ -587,16 +574,21 @@ def forward( width = req.sampling_params.width else: # fallback to run pre-processing in pipeline (debug only) - if image is None: + if ( + raw_image := None + if isinstance(first_prompt, str) + else first_prompt.get("multi_modal_data", {}).get("image") + ) is None: raise ValueError("Image is required for QwenImageEditPlusPipeline") - - if not isinstance(image, list): - image = [image] + elif isinstance(raw_image, list): + image = [PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image, im) for im in raw_image] + else: + image = [PIL.Image.open(raw_image) if isinstance(raw_image, str) else cast(PIL.Image.Image, raw_image)] image_size = image[0].size calculated_width, calculated_height = calculate_dimensions(VAE_IMAGE_SIZE, image_size[0] / image_size[1]) - height = height or calculated_height - width = width or calculated_width + height = req.sampling_params.height or calculated_height + width = req.sampling_params.width or calculated_width multiple_of = self.vae_scale_factor * 2 width = width // multiple_of * multiple_of @@ -618,18 +610,19 @@ def forward( condition_images.append(self.image_processor.resize(img, condition_height, condition_width)) vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2)) - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps - sigmas = req.sampling_params.sigmas or sigmas - max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length - generator = 
req.sampling_params.generator or generator - true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale + num_inference_steps = req.sampling_params.num_inference_steps or 50 + sigmas = req.sampling_params.sigmas + max_sequence_length = req.sampling_params.max_sequence_length or 512 + generator = req.sampling_params.generator + true_cfg_scale = req.sampling_params.true_cfg_scale or 4.0 if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale + else: + guidance_scale = 1.0 num_images_per_prompt = ( - req.sampling_params.num_outputs_per_prompt - if req.sampling_params.num_outputs_per_prompt > 0 - else num_images_per_prompt + req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1 ) + latents = req.sampling_params.latents # 1. check inputs # 2. encode prompts diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py index dbe0bfe4f8..5f90f44de1 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py @@ -578,27 +578,10 @@ def interrupt(self): def forward( self, req: OmniDiffusionRequest, - image: PIL.Image.Image | torch.Tensor | None = None, - prompt: str | list[str] | None = None, - negative_prompt: str | list[str] | None = None, - true_cfg_scale: float = 4.0, - layers: int | None = 4, - num_inference_steps: int = 50, - sigmas: list[float] | None = None, - guidance_scale: float | None = None, - num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.Tensor | None = None, - prompt_embeds: torch.Tensor | None = None, prompt_embeds_mask: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, negative_prompt_embeds_mask: torch.Tensor | None = None, output_type: str | None = "pil", attention_kwargs: dict[str, Any] | None = None, - max_sequence_length: int = 512, - resolution: int = 640, - cfg_normalize: bool = False, - use_en_prompt: bool = False, ) -> DiffusionOutput: """Forward pass for image layered.""" @@ -614,27 +597,26 @@ def forward( first_prompt = req.prompts[0] prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "") negative_prompt = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt") - - layers = req.sampling_params.layers if req.sampling_params.layers is not None else layers - resolution = req.sampling_params.resolution if req.sampling_params.resolution is not None else resolution - max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length - cfg_normalize = ( - req.sampling_params.cfg_normalize if req.sampling_params.cfg_normalize is not None else cfg_normalize - ) - use_en_prompt = ( - req.sampling_params.use_en_prompt if req.sampling_params.use_en_prompt is not None else use_en_prompt - ) - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps - sigmas = req.sampling_params.sigmas or sigmas - generator = req.sampling_params.generator or generator - true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale + prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("prompt_embeds") + negative_prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt_embeds") + + layers = req.sampling_params.layers or 4 + resolution = 
req.sampling_params.resolution or 640 + max_sequence_length = req.sampling_params.max_sequence_length or 512 + cfg_normalize = req.sampling_params.cfg_normalize + use_en_prompt = req.sampling_params.use_en_prompt + num_inference_steps = req.sampling_params.num_inference_steps or 50 + sigmas = req.sampling_params.sigmas + generator = req.sampling_params.generator + true_cfg_scale = req.sampling_params.true_cfg_scale or 4.0 if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale + else: + guidance_scale = 1.0 num_images_per_prompt = ( - req.sampling_params.num_outputs_per_prompt - if req.sampling_params.num_outputs_per_prompt > 0 - else num_images_per_prompt + req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1 ) + latents = req.sampling_params.latents if not isinstance(first_prompt, str) and "preprocessed_image" in ( additional_information := first_prompt.get("additional_information", {}) @@ -648,6 +630,17 @@ def forward( width = req.sampling_params.width else: # fallback to run pre-processing in pipeline (debug only) + if ( + raw_image := None + if isinstance(first_prompt, str) + else first_prompt.get("multi_modal_data", {}).get("image") + ) is None: + image = None + elif isinstance(raw_image, list): + image = [PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image, im) for im in raw_image] + else: + image = PIL.Image.open(raw_image) if isinstance(raw_image, str) else cast(PIL.Image.Image, raw_image) + image_size = image[0].size if isinstance(image, list) else image.size assert resolution in [640, 1024], f"resolution must be either 640 or 1024, but got {resolution}" calculated_width, calculated_height = calculate_dimensions( diff --git a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py index 3668c132f5..661275de08 100644 --- a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py +++ b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py @@ -569,43 +569,42 @@ def diffuse( def forward( self, req: OmniDiffusionRequest, - prompt: str | list[str] = "", prompt_2: str | list[str] = "", prompt_3: str | list[str] = "", - negative_prompt: str | list[str] = "", negative_prompt_2: str | list[str] = "", negative_prompt_3: str | list[str] = "", - height: int | None = None, - width: int | None = None, - num_inference_steps: int = 28, - sigmas: list[float] | None = None, - num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.Tensor | None = None, - prompt_embeds: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, pooled_prompt_embeds: torch.Tensor | None = None, negative_pooled_prompt_embeds: torch.Tensor | None = None, - max_sequence_length: int = 256, ) -> DiffusionOutput: # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") # TODO: May be some data formatting operations on the API side. Hack for now. 
- prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt - negative_prompt = [ - "" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts - ] or negative_prompt + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] + negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] + + req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts] + if any(p is not None for p in req_prompt_embeds): + # If at least one prompt is provided as an embedding, + # Then assume that the user wants to provide embeddings for all prompts, and enter this if block + # If the user in fact provides mixed input format, req_prompt_embeds will have some None's + # And `torch.stack` automatically raises an exception for us + prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError + + req_negative_prompt_embeds = [ + p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts + ] + if any(p is not None for p in req_negative_prompt_embeds): + negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor - sigmas = req.sampling_params.sigmas or sigmas - max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps - generator = req.sampling_params.generator or generator + sigmas = req.sampling_params.sigmas + max_sequence_length = req.sampling_params.max_sequence_length or 256 + num_inference_steps = req.sampling_params.num_inference_steps or 28 + generator = req.sampling_params.generator num_images_per_prompt = ( - req.sampling_params.num_outputs_per_prompt - if req.sampling_params.num_outputs_per_prompt > 0 - else num_images_per_prompt + req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1 ) + latents = req.sampling_params.latents # 1. check inputs # 2. encode prompts # 3. 
prepare latents and timesteps diff --git a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py index c48d68efd6..774c76eaa4 100644 --- a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py +++ b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py @@ -350,17 +350,7 @@ def prepare_latents( def forward( self, req: OmniDiffusionRequest, - prompt: str | list[str] | None = None, - negative_prompt: str | list[str] | None = None, - audio_end_in_s: float | None = None, - audio_start_in_s: float = 0.0, - num_inference_steps: int = 100, - guidance_scale: float = 7.0, num_waveforms_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.Tensor | None = None, - prompt_embeds: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, output_type: str = "np", ) -> DiffusionOutput: """ @@ -368,17 +358,7 @@ def forward( Args: req: OmniDiffusionRequest containing generation parameters - prompt: Text prompt for audio generation - negative_prompt: Negative prompt for CFG - audio_end_in_s: Audio end time in seconds (max ~47s for stable-audio-open-1.0) - audio_start_in_s: Audio start time in seconds - num_inference_steps: Number of denoising steps - guidance_scale: CFG scale num_waveforms_per_prompt: Number of audio outputs per prompt - generator: Random generator for reproducibility - latents: Pre-generated latents - prompt_embeds: Pre-computed prompt embeddings - negative_prompt_embeds: Pre-computed negative prompt embeddings output_type: Output format ("np", "pt", or "latent") Returns: @@ -387,24 +367,40 @@ def forward( # Extract from request # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") # TODO: May be some data formatting operations on the API side. Hack for now. 
- prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): negative_prompt = None elif req.prompts: negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts] + if any(p is not None for p in req_prompt_embeds): + # If at least one prompt is provided as an embedding, + # Then assume that the user wants to provide embeddings for all prompts, and enter this if block + # If the user in fact provides mixed input format, req_prompt_embeds will have some None's + # And `torch.stack` automatically raises an exception for us + prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError + + req_negative_prompt_embeds = [ + p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts + ] + if any(p is not None for p in req_negative_prompt_embeds): + negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError + + num_inference_steps = req.sampling_params.num_inference_steps or 100 if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale + else: + guidance_scale = 7.0 - if generator is None: - generator = req.sampling_params.generator + generator = req.sampling_params.generator if generator is None and req.sampling_params.seed is not None: generator = torch.Generator(device=self.device).manual_seed(req.sampling_params.seed) + latents = req.sampling_params.latents # Get audio duration from request extra params or defaults - audio_start_in_s = req.sampling_params.extra_args.get("audio_start_in_s", audio_start_in_s) - audio_end_in_s = req.sampling_params.extra_args.get("audio_end_in_s", audio_end_in_s) + audio_start_in_s = req.sampling_params.extra_args.get("audio_start_in_s", 0.0) + audio_end_in_s = req.sampling_params.extra_args.get("audio_end_in_s", None) # Calculate audio length downsample_ratio = self.vae.hop_length diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index b902bc692e..fda0cc0797 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -327,17 +327,7 @@ def current_timestep(self): def forward( self, req: OmniDiffusionRequest, - prompt: str | None = None, - negative_prompt: str | None = None, - height: int = 480, - width: int = 832, - num_inference_steps: int = 40, - guidance_scale: float | tuple[float, float] = 4.0, - frame_num: int = 81, output_type: str | None = "np", - generator: torch.Generator | list[torch.Generator] | None = None, - prompt_embeds: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, attention_kwargs: dict | None = None, **kwargs, ) -> DiffusionOutput: @@ -348,14 +338,20 @@ def forward( """Please pass in a single prompt object or string, or a single-item list.""", ) if len(req.prompts) == 1: # If req.prompt is empty, default to prompt & neg_prompt in param list - prompt = req.prompts[0] if isinstance(req.prompts[0], str) else req.prompts[0].get("prompt") - negative_prompt = None if 
isinstance(req.prompts[0], str) else req.prompts[0].get("negative_prompt") + first_prompt = req.prompts[0] + prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "") + negative_prompt = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt") + prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("prompt_embeds") + negative_prompt_embeds = ( + None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt_embeds") + ) + if prompt is None and prompt_embeds is None: raise ValueError("Prompt or prompt_embeds is required for Wan2.2 generation.") - height = req.sampling_params.height or height - width = req.sampling_params.width or width - num_frames = req.sampling_params.num_frames if req.sampling_params.num_frames else frame_num + height = req.sampling_params.height or 480 + width = req.sampling_params.width or 832 + num_frames = req.sampling_params.num_frames or 81 # Ensure dimensions are compatible with VAE and patch size # For expand_timesteps mode, we need latent dims to be even (divisible by patch_size) @@ -363,11 +359,13 @@ def forward( mod_value = self.vae_scale_factor_spatial * patch_size[1] # 16*2=32 for TI2V, 8*2=16 for I2V height = (height // mod_value) * mod_value width = (width // mod_value) * mod_value - num_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_steps = req.sampling_params.num_inference_steps or 40 # Respect per-request guidance_scale when explicitly provided. if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale + else: + guidance_scale = 4.0 guidance_low = guidance_scale if isinstance(guidance_scale, (int, float)) else guidance_scale[0] guidance_high = ( @@ -410,8 +408,7 @@ def forward( dtype = self.text_encoder.dtype # Seed / generator - if generator is None: - generator = req.sampling_params.generator + generator = req.sampling_params.generator if generator is None and req.sampling_params.seed is not None: generator = torch.Generator(device=device).manual_seed(req.sampling_params.seed) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index 1aed9b75de..40dc2b56c7 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -281,18 +281,7 @@ def encode_image( def forward( self, req: OmniDiffusionRequest, - prompt: str | None = None, - negative_prompt: str | None = None, - image: PIL.Image.Image | torch.Tensor | None = None, - height: int = 480, - width: int = 832, - num_inference_steps: int = 40, - guidance_scale: float | tuple[float, float] = 5.0, - frame_num: int = 81, output_type: str | None = "np", - generator: torch.Generator | list[torch.Generator] | None = None, - prompt_embeds: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, image_embeds: torch.Tensor | None = None, last_image: PIL.Image.Image | torch.Tensor | None = None, attention_kwargs: dict | None = None, @@ -305,39 +294,43 @@ def forward( """Please pass in a single prompt object or string, or a single-item list.""", ) if len(req.prompts) == 1: # If req.prompt is empty, default to prompt & neg_prompt in param list - prompt = req.prompts[0] if isinstance(req.prompts[0], str) else req.prompts[0].get("prompt") - negative_prompt = None if isinstance(req.prompts[0], str) else req.prompts[0].get("negative_prompt") + first_prompt = req.prompts[0] + prompt = 
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py
index 1aed9b75de..40dc2b56c7 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py
@@ -281,18 +281,7 @@ def encode_image(
     def forward(
         self,
         req: OmniDiffusionRequest,
-        prompt: str | None = None,
-        negative_prompt: str | None = None,
-        image: PIL.Image.Image | torch.Tensor | None = None,
-        height: int = 480,
-        width: int = 832,
-        num_inference_steps: int = 40,
-        guidance_scale: float | tuple[float, float] = 5.0,
-        frame_num: int = 81,
         output_type: str | None = "np",
-        generator: torch.Generator | list[torch.Generator] | None = None,
-        prompt_embeds: torch.Tensor | None = None,
-        negative_prompt_embeds: torch.Tensor | None = None,
         image_embeds: torch.Tensor | None = None,
         last_image: PIL.Image.Image | torch.Tensor | None = None,
         attention_kwargs: dict | None = None,
@@ -305,39 +294,43 @@ def forward(
                 """Please pass in a single prompt object or string, or a single-item list.""",
             )
         if len(req.prompts) == 1:  # If req.prompt is empty, default to prompt & neg_prompt in param list
-            prompt = req.prompts[0] if isinstance(req.prompts[0], str) else req.prompts[0].get("prompt")
-            negative_prompt = None if isinstance(req.prompts[0], str) else req.prompts[0].get("negative_prompt")
+            first_prompt = req.prompts[0]
+            prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "")
+            negative_prompt = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt")
+            prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("prompt_embeds")
+            negative_prompt_embeds = (
+                None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt_embeds")
+            )

         if prompt is None and prompt_embeds is None:
             raise ValueError("Prompt or prompt_embeds is required for Wan2.2 generation.")

         # Get image from request
-        if image is None:
-            multi_modal_data = (
-                req.prompts[0].get("multi_modal_data", {}) if not isinstance(req.prompts[0], str) else None
-            )
-            raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None
-            if raw_image is None:
-                raise ValueError("Image is required for I2V generation.")
-            if isinstance(raw_image, list):
-                if len(raw_image) > 1:
-                    logger.warning(
-                        """Received a list of image. Only a single image is supported by this model."""
-                        """Taking only the first image for now."""
-                    )
-                raw_image = raw_image[0]
-            if isinstance(raw_image, str):
-                image = PIL.Image.open(raw_image)
-            else:
-                image = cast(PIL.Image.Image | torch.Tensor, raw_image)
+        multi_modal_data = req.prompts[0].get("multi_modal_data", {}) if not isinstance(req.prompts[0], str) else None
+        raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None
+        if raw_image is None:
+            raise ValueError("Image is required for I2V generation.")
+        if isinstance(raw_image, list):
+            if len(raw_image) > 1:
+                logger.warning(
+                    """Received a list of images. Only a single image is supported by this model."""
+                    """ Taking only the first image for now."""
+                )
+            raw_image = raw_image[0]
+        if isinstance(raw_image, str):
+            image = PIL.Image.open(raw_image)
+        else:
+            image = cast(PIL.Image.Image | torch.Tensor, raw_image)

-        height = req.sampling_params.height or height
-        width = req.sampling_params.width or width
-        num_frames = req.sampling_params.num_frames or frame_num
-        num_steps = req.sampling_params.num_inference_steps or num_inference_steps
+        height = req.sampling_params.height or 480
+        width = req.sampling_params.width or 832
+        num_frames = req.sampling_params.num_frames or 81
+        num_steps = req.sampling_params.num_inference_steps or 40

         # Respect per-request guidance_scale when explicitly provided.
         if req.sampling_params.guidance_scale_provided:
             guidance_scale = req.sampling_params.guidance_scale
+        else:
+            guidance_scale = 5.0

         # Handle guidance scales
         guidance_low = guidance_scale if isinstance(guidance_scale, (int, float)) else guidance_scale[0]
@@ -376,8 +369,7 @@ def forward(
         dtype = self.transformer.dtype

         # Generator setup
-        if generator is None:
-            generator = req.sampling_params.generator
+        generator = req.sampling_params.generator
         if generator is None and req.sampling_params.seed is not None:
             generator = torch.Generator(device=device).manual_seed(req.sampling_params.seed)
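The dedented image lookup in pipeline_wan2_2_i2v.py keeps the same normalization rules as before: a path string is opened with PIL, a list is truncated to its first element, and a missing image is an error. A standalone sketch of that normalization, assuming `raw_image` has already been pulled from `multi_modal_data` (the helper name is made up for illustration):

```python
import PIL.Image
import torch


def normalize_i2v_image(raw_image) -> PIL.Image.Image | torch.Tensor:
    # raw_image is whatever was found under multi_modal_data["image"] in the request.
    if raw_image is None:
        raise ValueError("Image is required for I2V generation.")
    if isinstance(raw_image, list):
        # Only a single image is supported; fall back to the first element.
        raw_image = raw_image[0]
    if isinstance(raw_image, str):
        return PIL.Image.open(raw_image)  # treat strings as file paths
    return raw_image


print(normalize_i2v_image([PIL.Image.new("RGB", (64, 64))]).size)  # (64, 64)
```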
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py
index d32b7d697c..efe36b933f 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py
@@ -220,18 +220,7 @@ def current_timestep(self):
     def forward(
         self,
         req: OmniDiffusionRequest,
-        prompt: str | None = None,
-        negative_prompt: str | None = None,
-        image: PIL.Image.Image | torch.Tensor | None = None,
-        height: int = 704,
-        width: int = 1280,
-        num_inference_steps: int = 40,
-        guidance_scale: float = 5.0,
-        frame_num: int = 81,
         output_type: str | None = "np",
-        generator: torch.Generator | list[torch.Generator] | None = None,
-        prompt_embeds: torch.Tensor | None = None,
-        negative_prompt_embeds: torch.Tensor | None = None,
         attention_kwargs: dict | None = None,
         **kwargs,
     ) -> DiffusionOutput:
@@ -242,40 +231,44 @@ def forward(
                 """Please pass in a single prompt object or string, or a single-item list.""",
             )
         if len(req.prompts) == 1:  # If req.prompt is empty, default to prompt & neg_prompt in param list
-            prompt = req.prompts[0] if isinstance(req.prompts[0], str) else req.prompts[0].get("prompt")
-            negative_prompt = None if isinstance(req.prompts[0], str) else req.prompts[0].get("negative_prompt")
+            first_prompt = req.prompts[0]
+            prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "")
+            negative_prompt = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt")
+            prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("prompt_embeds")
+            negative_prompt_embeds = (
+                None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt_embeds")
+            )

         if prompt is None and prompt_embeds is None:
             raise ValueError("Prompt or prompt_embeds is required for Wan2.2 generation.")

         # Get image from request (optional for TI2V)
-        if image is None:
-            multi_modal_data = (
-                req.prompts[0].get("multi_modal_data", {}) if not isinstance(req.prompts[0], str) else None
-            )
-            raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None
-            if isinstance(raw_image, list):
-                if len(raw_image) > 1:
-                    logger.warning(
-                        """Received a list of image. Only a single image is supported by this model."""
-                        """Taking only the first image for now."""
-                    )
-                raw_image = raw_image[0]
-            if raw_image is None:
-                image = None
-            elif isinstance(raw_image, str):
-                image = PIL.Image.open(raw_image)
-            else:
-                image = cast(PIL.Image.Image | torch.Tensor, raw_image)
+        multi_modal_data = req.prompts[0].get("multi_modal_data", {}) if not isinstance(req.prompts[0], str) else None
+        raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None
+        if isinstance(raw_image, list):
+            if len(raw_image) > 1:
+                logger.warning(
+                    """Received a list of images. Only a single image is supported by this model."""
+                    """ Taking only the first image for now."""
+                )
+            raw_image = raw_image[0]
+        if raw_image is None:
+            image = None
+        elif isinstance(raw_image, str):
+            image = PIL.Image.open(raw_image)
+        else:
+            image = cast(PIL.Image.Image | torch.Tensor, raw_image)

         # Default dimensions for TI2V-5B (720P)
-        height = req.sampling_params.height or height
-        width = req.sampling_params.width or width
-        num_frames = req.sampling_params.num_frames if req.sampling_params.num_frames else frame_num
-        num_steps = req.sampling_params.num_inference_steps or num_inference_steps
+        height = req.sampling_params.height or 704
+        width = req.sampling_params.width or 1280
+        num_frames = req.sampling_params.num_frames or 81
+        num_steps = req.sampling_params.num_inference_steps or 40

         # Respect per-request guidance_scale when explicitly provided.
         if req.sampling_params.guidance_scale_provided:
             guidance_scale = req.sampling_params.guidance_scale
+        else:
+            guidance_scale = 5.0

         self._guidance_scale = guidance_scale
@@ -299,8 +292,7 @@ def forward(
         dtype = self.transformer.dtype

         # Generator setup
-        if generator is None:
-            generator = req.sampling_params.generator
+        generator = req.sampling_params.generator
         if generator is None and req.sampling_params.seed is not None:
             generator = torch.Generator(device=device).manual_seed(req.sampling_params.seed)
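In all three Wan 2.2 forwards the generator now comes straight from the request, with the seed as a fallback, instead of being threaded through a keyword argument. The fallback itself is small enough to show in isolation (the stand-in class and its field names are illustrative, not the real vllm_omni types):

```python
import torch
from dataclasses import dataclass


@dataclass
class SeedParamsSketch:
    # Illustrative stand-in for the per-request sampling params.
    generator: torch.Generator | None = None
    seed: int | None = None


def resolve_generator(params: SeedParamsSketch, device: str = "cpu") -> torch.Generator | None:
    # Prefer an explicit generator; otherwise derive one from the seed, if any was given.
    generator = params.generator
    if generator is None and params.seed is not None:
        generator = torch.Generator(device=device).manual_seed(params.seed)
    return generator


print(resolve_generator(SeedParamsSketch(seed=42)).initial_seed())  # 42
print(resolve_generator(SeedParamsSketch()))  # None -> pipeline falls back to its default RNG
```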
diff --git a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py
index 2c938f39bf..bb666ef08c 100644
--- a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py
+++ b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py
@@ -315,75 +315,24 @@ def interrupt(self):
     def forward(
         self,
         req: OmniDiffusionRequest,
-        prompt: str | list[str] | None = None,
-        height: int = 1024,
-        width: int = 1024,
-        num_inference_steps: int = 50,
-        sigmas: list[float] | None = None,
-        guidance_scale: float = 5.0,
         cfg_normalization: bool = False,
         cfg_truncation: float = 1.0,
-        negative_prompt: str | list[str] | None = None,
-        num_images_per_prompt: int = 1,
-        generator: torch.Generator | list[torch.Generator] | None = None,
-        latents: torch.FloatTensor | None = None,
-        prompt_embeds: list[torch.FloatTensor] | None = None,
-        negative_prompt_embeds: list[torch.FloatTensor] | None = None,
         output_type: str | None = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: dict[str, Any] | None = None,
         callback_on_step_end: Callable[[int, int, dict], None] | None = None,
         callback_on_step_end_tensor_inputs: list[str] = ["latents"],
-        max_sequence_length: int = 512,
     ) -> DiffusionOutput:
         r"""
         Function invoked when calling the pipeline for generation.

         Args:
-            prompt (`str` or `list[str]`, *optional*):
-                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
-                instead.
-            height (`int`, *optional*, defaults to 1024):
-                The height in pixels of the generated image.
-            width (`int`, *optional*, defaults to 1024):
-                The width in pixels of the generated image.
-            num_inference_steps (`int`, *optional*, defaults to 50):
-                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
-                expense of slower inference.
-            sigmas (`list[float]`, *optional*):
-                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
-                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
-                will be used.
-            guidance_scale (`float`, *optional*, defaults to 5.0):
-                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
-                `guidance_scale` is defined as `w` of equation 2. of [Imagen
-                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
-                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
-                usually at the expense of lower image quality.
+            req (`OmniDiffusionRequest`):
+                The request object containing the prompts and sampling parameters.
             cfg_normalization (`bool`, *optional*, defaults to False):
                 Whether to apply configuration normalization.
             cfg_truncation (`float`, *optional*, defaults to 1.0):
                 The truncation value for configuration.
-            negative_prompt (`str` or `list[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. If not defined, one has to pass
-                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
-                less than `1`).
-            num_images_per_prompt (`int`, *optional*, defaults to 1):
-                The number of images to generate per prompt.
-            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
-                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
-                to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
-                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
-                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
-                tensor will be generated by sampling using the supplied random `generator`.
-            prompt_embeds (`list[torch.FloatTensor]`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`list[torch.FloatTensor]`, *optional*):
-                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
-                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
-                argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
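With the keyword arguments gone, everything the z_image forward used to accept now has to arrive inside the request object that the new `req` docstring entry describes. Roughly, the per-prompt entries this forward reads look like the following; the dict keys match the `.get()` calls in the body below, while constructing the actual `OmniDiffusionRequest` is deliberately out of scope here:

```python
# Purely illustrative: req.prompts may mix plain strings with per-prompt dicts.
prompts = [
    "a watercolor fox in the snow",
    {"prompt": "a neon city at night", "negative_prompt": "blurry, low quality"},
    {"prompt": None},  # online mode sometimes sends explicit None values
]

# Same coercion the forward applies: dicts fall back to "" when "prompt" is missing or None.
prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in prompts]
print(prompt)  # ['a watercolor fox in the snow', 'a neon city at night', '']
```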
@@ -403,8 +352,6 @@ def forward(
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
-            max_sequence_length (`int`, *optional*, defaults to 512):
-                Maximum sequence length to use with the `prompt`.

         Examples:

@@ -415,26 +362,40 @@ def forward(
         """
         # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "")
         # TODO: May be some data formatting operations on the API side. Hack for now.
-        prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt
+        prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts]
+
+        # For the negative prompt, keep it None if ALL entries are None, so it stays falsy and CFG is skipped.
+        # If only some entries are None, set just those to empty strings, because CFG cannot be skipped anyway.
         if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts):
-            negative_prompt = None
-        elif req.prompts:
+            negative_prompt: list[str] | None = None
+        else:
             negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts]
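The comment above is the whole policy: keep `negative_prompt` as None (falsy) when no entry supplies one, so classifier-free guidance can be skipped, and pad with empty strings as soon as at least one entry does. A small sketch of just that decision:

```python
def resolve_negative_prompts(prompts: list[str | dict]) -> list[str] | None:
    # None (falsy) means no entry asked for a negative prompt, so CFG can be skipped outright.
    if all(isinstance(p, str) or p.get("negative_prompt") is None for p in prompts):
        return None
    # Otherwise CFG runs for the whole batch; entries without a negative prompt get "".
    return ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in prompts]


print(resolve_negative_prompts(["a cat", {"prompt": "a dog"}]))  # None
print(resolve_negative_prompts(["a cat", {"prompt": "a dog", "negative_prompt": "blurry"}]))  # ['', 'blurry']
```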
-        height = req.sampling_params.height or height
-        width = req.sampling_params.width or width
-        num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps
+        req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts]
+        if any(p is not None for p in req_prompt_embeds):
+            # If at least one prompt is provided as an embedding,
+            # assume the user intends to provide embeddings for all prompts and enter this block.
+            # If the input formats are in fact mixed, req_prompt_embeds will contain some Nones,
+            # and `torch.stack` raises a TypeError for us.
+            prompt_embeds = torch.stack(req_prompt_embeds)  # type: ignore  # intentionally expect TypeError
+
+        req_negative_prompt_embeds = [
+            p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts
+        ]
+        if any(p is not None for p in req_negative_prompt_embeds):
+            negative_prompt_embeds = torch.stack(req_negative_prompt_embeds)  # type: ignore  # intentionally expect TypeError
+
+        height = req.sampling_params.height or 1024
+        width = req.sampling_params.width or 1024
+        num_inference_steps = req.sampling_params.num_inference_steps or 50
         generator = req.sampling_params.generator
-        sigmas = req.sampling_params.sigmas or sigmas
-        max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length
-        guidance_scale = (
-            req.sampling_params.guidance_scale if req.sampling_params.guidance_rescale is not None else guidance_scale
-        )
+        sigmas = req.sampling_params.sigmas
+        max_sequence_length = req.sampling_params.max_sequence_length or 512
+        guidance_scale = req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else 5.0
         num_images_per_prompt = (
-            req.sampling_params.num_outputs_per_prompt
-            if req.sampling_params.num_outputs_per_prompt > 0
-            else num_images_per_prompt
+            req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1
         )
+        latents = req.sampling_params.latents

         vae_scale = self.vae_scale_factor * 2
         if height % vae_scale != 0:
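The stacking logic above deliberately lets `torch.stack` do the validation: if embeddings are supplied for only part of the batch, the placeholder `None` entries make the call raise a `TypeError`. A quick sketch of both the happy path and the rejected mixed batch (the embedding shapes here are made up for illustration):

```python
import torch

prompts = [
    {"prompt_embeds": torch.randn(77, 768)},
    {"prompt_embeds": torch.randn(77, 768)},
]
req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in prompts]
if any(p is not None for p in req_prompt_embeds):
    print(torch.stack(req_prompt_embeds).shape)  # torch.Size([2, 77, 768])

mixed = ["a plain text prompt", {"prompt_embeds": torch.randn(77, 768)}]
mixed_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in mixed]
try:
    torch.stack(mixed_embeds)  # type: ignore[arg-type]
except TypeError as err:
    print(f"mixed inputs rejected: {err}")
```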