50 changes: 26 additions & 24 deletions vllm_omni/diffusion/models/flux/pipeline_flux.py
@@ -573,51 +573,53 @@ def check_cfg_parallel_validity(self, true_cfg_scale: float, has_neg_prompt: boo
def forward(
self,
req: OmniDiffusionRequest,
prompt: str | list[str] | None = None,
prompt_2: str | list[str] | None = None,
negative_prompt: str | list[str] | None = None,
negative_prompt_2: str | list[str] | None = None,
true_cfg_scale: float = 1.0,
height: int | None = None,
width: int | None = None,
num_inference_steps: int = 28,
sigmas: list[float] | None = None,
guidance_scale: float = 3.5,
num_images_per_prompt: int = 1,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.FloatTensor | None = None,
prompt_embeds: torch.FloatTensor | None = None,
pooled_prompt_embeds: torch.FloatTensor | None = None,
negative_prompt_embeds: torch.FloatTensor | None = None,
negative_pooled_prompt_embeds: torch.FloatTensor | None = None,
output_type: str | None = "pil",
return_dict: bool = True,
joint_attention_kwargs: dict[str, Any] | None = None,
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
max_sequence_length: int = 512,
):
"""Forward pass for flux."""
# TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so we cannot use .get("...", "")
# TODO: There may be data formatting operations on the API side. Hack around it for now.
prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt
prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts]

# For the negative prompt, keep it None if ALL entries are None, so it stays falsy and CFG is skipped.
# If only some are None, set just those to empty strings, since we cannot skip CFG anyway.
if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts):
negative_prompt = None
elif req.prompts:
negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts]

req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts]
if any(p is not None for p in req_prompt_embeds):
# If at least one prompt is provided as an embedding,
# then assume the user wants to provide embeddings for all prompts, and enter this if block.
# If the user in fact provides a mixed input format, req_prompt_embeds will contain some Nones,
# and `torch.stack` automatically raises an exception for us.
prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError

req_negative_prompt_embeds = [
p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts
]
if any(p is not None for p in req_negative_prompt_embeds):
negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError

height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor
width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor
num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps
sigmas = req.sampling_params.sigmas or sigmas
guidance_scale = (
req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale
)
generator = req.sampling_params.generator or generator
true_cfg_scale = req.sampling_params.true_cfg_scale or 1.0
num_inference_steps = req.sampling_params.num_inference_steps or 28
sigmas = req.sampling_params.sigmas
max_sequence_length = req.sampling_params.max_sequence_length or 512
guidance_scale = req.sampling_params.guidance_scale or 3.5
generator = req.sampling_params.generator
num_images_per_prompt = (
req.sampling_params.num_outputs_per_prompt
if req.sampling_params.num_outputs_per_prompt > 0
else num_images_per_prompt
req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1
)
latents = req.sampling_params.latents

# 1. Check inputs. Raise error if not correct
self.check_inputs(
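A minimal sketch (not part of the diff) of how the batched prompt handling in the flux `forward` above behaves; `prompts` stands in for `req.prompts`, and the embedding shape is illustrative only:

```python
import torch

# Stand-in for req.prompts: online requests may mix plain strings and dicts.
prompts = ["a cat", {"prompt": "a dog", "negative_prompt": None}]

# Dict entries fall back to "" when "prompt" is missing or explicitly None.
prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in prompts]
assert prompt == ["a cat", "a dog"]

# negative_prompt stays None only when every entry lacks one, so the falsy
# value lets the pipeline skip classifier-free guidance entirely.
if all(isinstance(p, str) or p.get("negative_prompt") is None for p in prompts):
    negative_prompt = None
else:
    negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in prompts]
assert negative_prompt is None

# Mixed text/embedding input fails loudly: torch.stack rejects a list that
# contains None, which is the exception the comments above rely on.
try:
    torch.stack([None, torch.randn(512, 4096)])  # type: ignore[arg-type]
except TypeError as exc:
    print(f"mixed prompt formats rejected: {exc}")
```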
96 changes: 13 additions & 83 deletions vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py
@@ -648,72 +648,17 @@ def interrupt(self):
def forward(
self,
req: OmniDiffusionRequest,
image: PIL.Image.Image | list[PIL.Image.Image] | None = None,
prompt: str | list[str] | None = None,
height: int | None = None,
width: int | None = None,
num_inference_steps: int = 50,
sigmas: list[float] | None = None,
guidance_scale: float | None = 4.0,
num_images_per_prompt: int = 1,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.Tensor | None = None,
prompt_embeds: torch.Tensor | None = None,
negative_prompt_embeds: torch.Tensor | None = None,
output_type: str | None = "pil",
return_dict: bool = True,
attention_kwargs: dict[str, Any] | None = None,
callback_on_step_end: Callable[[int, int, dict], None] | None = None,
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
max_sequence_length: int = 512,
text_encoder_out_layers: tuple[int, ...] = (9, 18, 27),
) -> DiffusionOutput:
r"""
Function invoked when calling the pipeline for generation.

Args:
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or list of these):
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
latents as `image`, but if passing latents directly it is not encoded again.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
guidance_scale (`float`, *optional*, defaults to 4.0):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
of the [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality. For step-wise distilled models,
`guidance_scale` is ignored.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Note that "" is used as the negative prompt in this pipeline.
If not provided, they will be generated from "".
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -732,7 +677,6 @@ def forward(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`.
text_encoder_out_layers (`Tuple[int]`):
Layer indices to use in the `text_encoder` to derive the final prompt embeddings.

@@ -751,46 +695,32 @@ def forward(
first_prompt = req.prompts[0]
prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "")

prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("prompt_embeds")
negative_prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt_embeds")

if (
raw_image := None
if isinstance(first_prompt, str)
else first_prompt.get("multi_modal_data", {}).get("image")
) is None:
pass # use image from param list
image = None
elif isinstance(raw_image, list):
image = [PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image, im) for im in raw_image]
else:
image = PIL.Image.open(raw_image) if isinstance(raw_image, str) else cast(PIL.Image.Image, raw_image)

height = req.sampling_params.height or height
width = req.sampling_params.width or width
num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps
sigmas = req.sampling_params.sigmas or sigmas
guidance_scale = (
req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale
)
generator = req.sampling_params.generator or generator
height = req.sampling_params.height
width = req.sampling_params.width
num_inference_steps = req.sampling_params.num_inference_steps or 50
sigmas = req.sampling_params.sigmas
guidance_scale = req.sampling_params.guidance_scale or 4.0
generator = req.sampling_params.generator
num_images_per_prompt = (
req.sampling_params.num_outputs_per_prompt
if req.sampling_params.num_outputs_per_prompt > 0
else num_images_per_prompt
req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1
)
max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length
max_sequence_length = req.sampling_params.max_sequence_length or 512
text_encoder_out_layers = req.sampling_params.extra_args.get("text_encoder_out_layers", text_encoder_out_layers)

req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts]
if any(p is not None for p in req_prompt_embeds):
# If at least one prompt is provided as an embedding,
# then assume the user wants to provide embeddings for all prompts, and enter this if block.
# If the user in fact provides a mixed input format, req_prompt_embeds will contain some Nones,
# and `torch.stack` automatically raises an exception for us.
prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError

req_negative_prompt_embeds = [
p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts
]
if any(p is not None for p in req_negative_prompt_embeds):
negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError
latents = req.sampling_params.latents

# 1. Check inputs. Raise error if not correct
self.check_inputs(
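For the image input handling in the flux2_klein `forward` above, a minimal sketch (not part of the diff) of how `multi_modal_data["image"]` entries are normalized: path strings go through `PIL.Image.open`, already-decoded images pass through unchanged. The temporary file exists only so the example runs standalone:

```python
import tempfile
from typing import cast

import PIL.Image

# Write a throwaway PNG so the path branch below has a real file to open.
tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
PIL.Image.new("RGB", (64, 64), color="red").save(tmp.name)

# Stand-in for the first entry of req.prompts in an image-conditioned request.
first_prompt = {
    "prompt": "a cat wearing a hat",
    "multi_modal_data": {"image": [tmp.name, PIL.Image.new("RGB", (64, 64))]},
}

raw_image = (
    None if isinstance(first_prompt, str) else first_prompt.get("multi_modal_data", {}).get("image")
)
if raw_image is None:
    image = None  # plain text-to-image request, no conditioning image
elif isinstance(raw_image, list):
    # Open path strings with PIL; pass decoded images through unchanged.
    image = [PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image, im) for im in raw_image]
else:
    image = PIL.Image.open(raw_image) if isinstance(raw_image, str) else cast(PIL.Image.Image, raw_image)

print([type(im).__name__ for im in image])  # e.g. ['PngImageFile', 'Image']
```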
@@ -473,18 +473,6 @@ def cfg_normalize_function(self, noise_pred, comb_pred, cfg_renorm_min=0.0):
def forward(
self,
req: OmniDiffusionRequest,
prompt: str | list[str] | None = None,
negative_prompt: str | list[str] | None = None,
height: int | None = None,
width: int | None = None,
num_inference_steps: int = 50,
sigmas: list[float] | None = None,
guidance_scale: float = 4.5,
num_images_per_prompt: int = 1,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.FloatTensor | None = None,
prompt_embeds: torch.Tensor | None = None,
negative_prompt_embeds: torch.Tensor | None = None,
output_type: str | None = "pil",
return_dict: bool = True,
joint_attention_kwargs: dict[str, Any] | None = None,
@@ -494,29 +482,12 @@ def forward(
) -> DiffusionOutput:
# TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so we cannot use .get("...", "")
# TODO: There may be data formatting operations on the API side. Hack around it for now.
prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt
prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts]
if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts):
negative_prompt = None
elif req.prompts:
negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts]

height = req.sampling_params.height or height or self.default_sample_size * self.vae_scale_factor
width = req.sampling_params.width or width or self.default_sample_size * self.vae_scale_factor
num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps
sigmas = req.sampling_params.sigmas or sigmas
generator = req.sampling_params.generator or generator
guidance_scale = (
req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale
)
num_images_per_prompt = (
req.sampling_params.num_outputs_per_prompt
if req.sampling_params.num_outputs_per_prompt is not None
else num_images_per_prompt
)
enable_prompt_rewrite = req.sampling_params.extra_args.get("enable_prompt_rewrite", enable_prompt_rewrite)
enable_cfg_renorm = req.sampling_params.extra_args.get("enable_cfg_renorm", enable_cfg_renorm)
cfg_renorm_min = req.sampling_params.extra_args.get("cfg_renorm_min", cfg_renorm_min)

req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts]
if any(p is not None for p in req_prompt_embeds):
# If at least one prompt is provided as an embedding,
@@ -531,6 +502,20 @@ def forward(
if any(p is not None for p in req_negative_prompt_embeds):
negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError

height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor
width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor
num_inference_steps = req.sampling_params.num_inference_steps or 50
sigmas = req.sampling_params.sigmas
generator = req.sampling_params.generator
guidance_scale = req.sampling_params.guidance_scale or 4.5
num_images_per_prompt = (
req.sampling_params.num_outputs_per_prompt if req.sampling_params.num_outputs_per_prompt > 0 else 1
)
latents = req.sampling_params.latents
enable_prompt_rewrite = req.sampling_params.extra_args.get("enable_prompt_rewrite", enable_prompt_rewrite)
enable_cfg_renorm = req.sampling_params.extra_args.get("enable_cfg_renorm", enable_cfg_renorm)
cfg_renorm_min = req.sampling_params.extra_args.get("cfg_renorm_min", cfg_renorm_min)

self.check_inputs(
prompt,
height,
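The parameter resolution in these forward methods leans on Python's `value or default` idiom, which falls back on any falsy value, not just `None`. A small sketch (not part of the diff, with a hypothetical `SamplingParams` stand-in for `req.sampling_params`) showing the behavior, and why the output count is instead gated on `> 0`:

```python
from dataclasses import dataclass


@dataclass
class SamplingParams:  # hypothetical stand-in for req.sampling_params
    num_inference_steps: int | None = None
    guidance_scale: float | None = None
    sigmas: list[float] | None = None
    num_outputs_per_prompt: int = 0


params = SamplingParams(guidance_scale=0.0)

# `or` replaces every falsy value, so an explicit guidance_scale of 0.0 (or an
# empty sigmas list) still resolves to the hardcoded default.
guidance_scale = params.guidance_scale or 4.5            # -> 4.5
num_inference_steps = params.num_inference_steps or 50   # -> 50

# An explicit comparison keeps intentional zeros out of the fallback path,
# which is why the output count is gated on `> 0` rather than on truthiness.
num_images_per_prompt = (
    params.num_outputs_per_prompt if params.num_outputs_per_prompt > 0 else 1
)
print(guidance_scale, num_inference_steps, num_images_per_prompt)  # 4.5 50 1
```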