diff --git a/src/optimum/rbln/diffusers/models/controlnet.py b/src/optimum/rbln/diffusers/models/controlnet.py
index d71ab1da2..07d03e398 100644
--- a/src/optimum/rbln/diffusers/models/controlnet.py
+++ b/src/optimum/rbln/diffusers/models/controlnet.py
@@ -219,6 +219,21 @@ def forward(
         return_dict: bool = True,
         **kwargs,
     ):
+        """
+        Forward pass for the RBLN-optimized ControlNetModel.
+
+        Args:
+            sample (torch.FloatTensor): The noisy input tensor.
+            timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
+            encoder_hidden_states (torch.Tensor): The encoder hidden states.
+            controlnet_cond (torch.FloatTensor): The conditional input tensor of shape `(batch_size, channels, height, width)`.
+            conditioning_scale (torch.Tensor): The scale factor for ControlNet outputs.
+            added_cond_kwargs (Dict[str, torch.Tensor]): Additional conditions for the Stable Diffusion XL UNet.
+            return_dict (bool): Whether or not to return a [`~diffusers.models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple.
+
+        Returns:
+            (Union[`~diffusers.models.controlnets.controlnet.ControlNetOutput`, Tuple])
+        """
         sample_batch_size = sample.size()[0]
         compiled_batch_size = self.compiled_batch_size
         if sample_batch_size != compiled_batch_size and (
diff --git a/src/optimum/rbln/diffusers/models/transformers/prior_transformer.py b/src/optimum/rbln/diffusers/models/transformers/prior_transformer.py
index 4a33164f8..371531140 100644
--- a/src/optimum/rbln/diffusers/models/transformers/prior_transformer.py
+++ b/src/optimum/rbln/diffusers/models/transformers/prior_transformer.py
@@ -128,13 +128,27 @@ def post_process_latents(self, prior_latents):
 
     def forward(
         self,
-        hidden_states,
+        hidden_states: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         proj_embedding: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ):
+        """
+        Forward pass for the RBLN-optimized PriorTransformer.
+
+        Args:
+            hidden_states (torch.Tensor): The currently predicted image embeddings.
+            timestep (Union[torch.Tensor, float, int]): Current denoising step.
+            proj_embedding (torch.Tensor): Projected embedding vector the denoising process is conditioned on.
+            encoder_hidden_states (Optional[torch.Tensor]): Hidden states of the text embeddings the denoising process is conditioned on.
+            attention_mask (Optional[torch.Tensor]): Text mask for the text embeddings.
+            return_dict (bool): Whether or not to return a [`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`] instead of a plain tuple.
+
+        Returns:
+            (Union[`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`, Tuple])
+        """
         # Convert timestep(long) and attention_mask(bool) to float
         return super().forward(
             hidden_states,
diff --git a/src/optimum/rbln/diffusers/models/transformers/transformer_cosmos.py b/src/optimum/rbln/diffusers/models/transformers/transformer_cosmos.py
index 97dcc69a9..11acfa3b8 100644
--- a/src/optimum/rbln/diffusers/models/transformers/transformer_cosmos.py
+++ b/src/optimum/rbln/diffusers/models/transformers/transformer_cosmos.py
@@ -303,6 +303,21 @@ def forward(
         padding_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ):
+        """
+        Forward pass for the RBLN-optimized CosmosTransformer3DModel.
+
+        Args:
+            hidden_states (torch.Tensor): The input latent video tensor.
+            timestep (torch.Tensor): Current denoising step.
+            encoder_hidden_states (torch.Tensor): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+            fps (Optional[int]): Frames per second for the video being generated.
+            condition_mask (Optional[torch.Tensor]): The condition mask tensor.
+            padding_mask (Optional[torch.Tensor]): The padding mask tensor.
+            return_dict (bool): Whether or not to return a [`~diffusers.models.modeling_outputs.Transformer2DModelOutput`] instead of a plain tuple.
+
+        Returns:
+            (Union[`~diffusers.models.modeling_outputs.Transformer2DModelOutput`, Tuple])
+        """
         (
             hidden_states,
             temb,
diff --git a/src/optimum/rbln/diffusers/models/transformers/transformer_sd3.py b/src/optimum/rbln/diffusers/models/transformers/transformer_sd3.py
index 5c98fdb0c..305727144 100644
--- a/src/optimum/rbln/diffusers/models/transformers/transformer_sd3.py
+++ b/src/optimum/rbln/diffusers/models/transformers/transformer_sd3.py
@@ -161,6 +161,19 @@ def forward(
         return_dict: bool = True,
         **kwargs,
     ):
+        """
+        Forward pass for the RBLN-optimized SD3Transformer2DModel.
+
+        Args:
+            hidden_states (torch.FloatTensor): The noisy input latent tensor.
+            encoder_hidden_states (torch.FloatTensor): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+            pooled_projections (torch.FloatTensor): Embeddings projected from the embeddings of input conditions.
+            timestep (torch.LongTensor): Current denoising step.
+            return_dict (bool): Whether or not to return a [`~diffusers.models.modeling_outputs.Transformer2DModelOutput`] instead of a plain tuple.
+
+        Returns:
+            (Union[`~diffusers.models.modeling_outputs.Transformer2DModelOutput`, Tuple])
+        """
         sample_batch_size = hidden_states.size()[0]
         compiled_batch_size = self.compiled_batch_size
         if sample_batch_size != compiled_batch_size and (
diff --git a/src/optimum/rbln/diffusers/models/unets/unet_2d_condition.py b/src/optimum/rbln/diffusers/models/unets/unet_2d_condition.py
index e288abd50..681a71cdc 100644
--- a/src/optimum/rbln/diffusers/models/unets/unet_2d_condition.py
+++ b/src/optimum/rbln/diffusers/models/unets/unet_2d_condition.py
@@ -349,6 +349,22 @@ def forward(
         return_dict: bool = True,
         **kwargs,
     ) -> Union[UNet2DConditionOutput, Tuple]:
+        """
+        Forward pass for the RBLN-optimized UNet2DConditionModel.
+
+        Args:
+            sample (torch.Tensor): The noisy input tensor with the following shape `(batch, channel, height, width)`.
+            timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
+            encoder_hidden_states (torch.Tensor): The encoder hidden states.
+            added_cond_kwargs (Dict[str, torch.Tensor]): A kwargs dictionary containing additional embeddings that,
+                if specified, are added to the embeddings passed along to the UNet blocks.
+            down_block_additional_residuals (Optional[Tuple[torch.Tensor]]): A tuple of tensors that, if specified, are added to the residuals of the down UNet blocks.
+            mid_block_additional_residual (Optional[torch.Tensor]): A tensor that, if specified, is added to the residual of the middle UNet block.
+            return_dict (bool): Whether or not to return a [`~diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+        Returns:
+            (Union[`~diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`, Tuple])
+        """
         sample_batch_size = sample.size()[0]
         compiled_batch_size = self.compiled_batch_size
         if sample_batch_size != compiled_batch_size and (
diff --git a/src/optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py b/src/optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py
index dc33db7a5..e71f35522 100644
--- a/src/optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py
+++ b/src/optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py
@@ -96,6 +96,27 @@ def forward(
         guess_mode: bool = False,
         return_dict: bool = True,
     ):
+        """
+        Forward pass for the RBLN-optimized MultiControlNetModel.
+
+        This method processes multiple ControlNet models in sequence, applying each one to the input sample
+        with its corresponding conditioning image and scale factor. The outputs from all ControlNets are
+        merged by addition to produce the final control signals.
+
+        Args:
+            sample (torch.FloatTensor): The noisy input tensor.
+            timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
+            encoder_hidden_states (torch.Tensor): The encoder hidden states from the text encoder.
+            controlnet_cond (List[torch.Tensor]): A list of conditional input tensors, one for each ControlNet model.
+            conditioning_scale (List[float]): A list of scale factors for each ControlNet output. Each scale
+                controls the strength of the corresponding ControlNet's influence on the generation.
+            guess_mode (bool): If True, each ControlNet encoder tries to recognize the input content even when all prompts are removed.
+            return_dict (bool): Whether or not to return a dictionary instead of a plain tuple. Currently,
+                this method always returns a tuple regardless of this parameter.
+
+        Returns:
+            (Tuple[List[torch.Tensor], torch.Tensor])
+        """
         for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
            down_samples, mid_sample = controlnet(
                sample=sample.contiguous(),
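
For context on the merge-by-addition behavior the MultiControlNetModel docstring describes, here is a minimal illustrative sketch (not part of the diff). The helper name `merge_controlnet_outputs` and the shape of `per_net_outputs` are assumptions for illustration, not the optimum-rbln API.

# Illustrative sketch, not part of the diff: the merge-by-addition pattern
# described in the MultiControlNetModel docstring above. The helper name and
# the shape of `per_net_outputs` are assumptions for illustration only.
from typing import List, Tuple

import torch


def merge_controlnet_outputs(
    per_net_outputs: List[Tuple[List[torch.Tensor], torch.Tensor]],
) -> Tuple[List[torch.Tensor], torch.Tensor]:
    # Each entry is a (down_samples, mid_sample) pair from one ControlNet,
    # assumed to be already scaled by that net's conditioning_scale.
    down_acc, mid_acc = per_net_outputs[0]
    for down_samples, mid_sample in per_net_outputs[1:]:
        # Element-wise addition merges the control signals from each net.
        down_acc = [d_acc + d for d_acc, d in zip(down_acc, down_samples)]
        mid_acc = mid_acc + mid_sample
    return down_acc, mid_acc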