Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/optimum/rbln/diffusers/models/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,21 @@ def forward(
return_dict: bool = True,
**kwargs,
):
"""
Forward pass for the RBLN-optimized ControlNetModel.

Args:
sample (torch.FloatTensor): The noisy input tensor.
timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
encoder_hidden_states (torch.Tensor): The encoder hidden states.
controlnet_cond (torch.FloatTensor): The conditional input tensor of shape `(batch_size, max_seq_len, hidden_size)`.
conditioning_scale (torch.Tensor): The scale factor for ControlNet outputs.
added_cond_kwargs (Dict[str, torch.Tensor]): Additional conditions for the Stable Diffusion XL UNet.
return_dict (bool): Whether or not to return a [`~diffusers.models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple
Comment on lines +225 to +232
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it okay if there is no mention of kwargs?


Returns:
(Union[`~diffusers.models.controlnets.controlnet.ControlNetOutput`], Tuple)
"""
sample_batch_size = sample.size()[0]
compiled_batch_size = self.compiled_batch_size
if sample_batch_size != compiled_batch_size and (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,27 @@ def post_process_latents(self, prior_latents):

def forward(
self,
hidden_states,
hidden_states: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
proj_embedding: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
"""
Forward pass for the RBLN-optimized PriorTransformer.

Args:
hidden_states (torch.Tensor): The currently predicted image embeddings.
timestep (Union[torch.Tensor, float, int]): Current denoising step.
proj_embedding (torch.Tensor): Projected embedding vector the denoising process is conditioned on.
encoder_hidden_states (Optional[torch.Tensor]): Hidden states of the text embeddings the denoising process is conditioned on.
attention_mask (Optional[torch.Tensor]): Text mask for the text embeddings.
return_dict (bool): Whether or not to return a [`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`] instead of a plain tuple.

Returns:
(Union[`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`, Tuple])
"""
# Convert timestep(long) and attention_mask(bool) to float
return super().forward(
hidden_states,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,21 @@ def forward(
padding_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
"""
Forward pass for the RBLN-optimized CosmosTransformer3DModel.

Args:
hidden_states (torch.Tensor): The currently predicted image embeddings.
timestep (torch.Tensor): Current denoising step.
encoder_hidden_states (torch.Tensor): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
fps: (Optional[int]): Frames per second for the video being generated.
condition_mask (Optional[torch.Tensor]): Tensor of condition mask.
padding_mask (Optional[torch.Tensor]): Tensor of padding mask.
return_dict (bool): Whether or not to return a [`~diffusers.models.modeling_output.Transformer2DModelOutput`] instead of a plain tuple.

Returns:
(Union[`~diffusers.models.modeling_output.Transformer2DModelOutput`, Tuple])
"""
(
hidden_states,
temb,
Expand Down
13 changes: 13 additions & 0 deletions src/optimum/rbln/diffusers/models/transformers/transformer_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,19 @@ def forward(
return_dict: bool = True,
**kwargs,
):
"""
Forward pass for the RBLN-optimized SD3Transformer2DModel.
Args:
hidden_states (torch.FloatTensor): The currently predicted image embeddings.
encoder_hidden_states (torch.FloatTensor): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (torch.FloatTensor): Embeddings projected from the embeddings of input conditions.
timestep (torch.LongTensor): Current denoising step.
return_dict (bool): Whether or not to return a [`~diffusers.models.modeling_output.Transformer2DModelOutput`] instead of a plain tuple.
Returns:
(Union[`~diffusers.models.modeling_output.Transformer2DModelOutput`, Tuple])
"""
sample_batch_size = hidden_states.size()[0]
compiled_batch_size = self.compiled_batch_size
if sample_batch_size != compiled_batch_size and (
Expand Down
16 changes: 16 additions & 0 deletions src/optimum/rbln/diffusers/models/unets/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,22 @@ def forward(
return_dict: bool = True,
**kwargs,
) -> Union[UNet2DConditionOutput, Tuple]:
"""
Forward pass for the RBLN-optimized UNet2DConditionModel.

Args:
sample (torch.Tensor): The noisy input tensor with the following shape `(batch, channel, height, width)`.
timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
encoder_hidden_states (torch.Tensor): The encoder hidden states.
added_cond_kwargs (Dict[str, torch.Tensor]): A kwargs dictionary containing additional embeddings that
if specified are added to the embeddings that are passed along to the UNet blocks.
down_block_additional_residuals (Optional[Tuple[torch.Tensor]]): A tuple of tensors that if specified are added to the residuals of down unet blocks.
mid_block_additional_residual (Optional[torch.Tensor]): A tensor that if specified is added to the residual of the middle unet block.
return_dict (bool): Whether or not to return a [`~diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

Returns:
(Union[`~diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`], Tuple)
"""
sample_batch_size = sample.size()[0]
compiled_batch_size = self.compiled_batch_size
if sample_batch_size != compiled_batch_size and (
Expand Down