Skip to content

Commit e74b9eb

Browse files
committed
update docstrings for diffusers models
1 parent 60ae60e commit e74b9eb

File tree

5 files changed

+74
-1
lines changed

5 files changed

+74
-1
lines changed

src/optimum/rbln/diffusers/models/controlnet.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,21 @@ def forward(
219219
return_dict: bool = True,
220220
**kwargs,
221221
):
222+
"""
223+
Forward pass for the RBLN-optimized ControlNetModel.
224+
225+
Args:
226+
sample (torch.FloatTensor): The noisy input tensor.
227+
timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
228+
encoder_hidden_states (torch.Tensor): The encoder hidden states.
229+
controlnet_cond (torch.FloatTensor): The conditional input tensor of shape `(batch_size, max_seq_len, hidden_size)`.
230+
conditioning_scale (torch.Tensor): The scale factor for ControlNet outputs.
231+
added_cond_kwargs (Dict[str, torch.Tensor]): Additional conditions for the Stable Diffusion XL UNet.
232+
return_dict (bool): Whether or not to return a [`~diffusers.models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple
233+
234+
Returns:
235+
(Union[`~diffusers.models.controlnets.controlnet.ControlNetOutput`], Tuple)
236+
"""
222237
sample_batch_size = sample.size()[0]
223238
compiled_batch_size = self.compiled_batch_size
224239
if sample_batch_size != compiled_batch_size and (

src/optimum/rbln/diffusers/models/transformers/prior_transformer.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,27 @@ def post_process_latents(self, prior_latents):
128128

129129
def forward(
130130
self,
131-
hidden_states,
131+
hidden_states: torch.Tensor,
132132
timestep: Union[torch.Tensor, float, int],
133133
proj_embedding: torch.Tensor,
134134
encoder_hidden_states: Optional[torch.Tensor] = None,
135135
attention_mask: Optional[torch.Tensor] = None,
136136
return_dict: bool = True,
137137
):
138+
"""
139+
Forward pass for the RBLN-optimized PriorTransformer.
140+
141+
Args:
142+
hidden_states (torch.Tensor): The currently predicted image embeddings.
143+
timestep (Union[torch.Tensor, float, int]): Current denoising step.
144+
proj_embedding (torch.Tensor): Projected embedding vector the denoising process is conditioned on.
145+
encoder_hidden_states (Optional[torch.Tensor]): Hidden states of the text embeddings the denoising process is conditioned on.
146+
attention_mask (Optional[torch.Tensor]): Text mask for the text embeddings.
147+
return_dict (bool): Whether or not to return a [`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`] instead of a plain tuple.
148+
149+
Returns:
150+
(Union[`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`, Tuple])
151+
"""
138152
# Convert timestep(long) and attention_mask(bool) to float
139153
return super().forward(
140154
hidden_states,

src/optimum/rbln/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,21 @@ def forward(
303303
padding_mask: Optional[torch.Tensor] = None,
304304
return_dict: bool = True,
305305
):
306+
"""
307+
Forward pass for the RBLN-optimized CosmosTransformer3DModel.
308+
309+
Args:
310+
hidden_states (torch.Tensor): The currently predicted image embeddings.
311+
timestep (torch.Tensor): Current denoising step.
312+
encoder_hidden_states (torch.Tensor): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
313+
fps: (Optional[int]): Frames per second for the video being generated.
314+
condition_mask (Optional[torch.Tensor]): Tensor of condition mask.
315+
padding_mask (Optional[torch.Tensor]): Tensor of padding mask.
316+
return_dict (bool): Whether or not to return a [`~diffusers.models.modeling_output.Transformer2DModelOutput`] instead of a plain tuple.
317+
318+
Returns:
319+
(Union[`~diffusers.models.modeling_output.Transformer2DModelOutput`, Tuple])
320+
"""
306321
(
307322
hidden_states,
308323
temb,

src/optimum/rbln/diffusers/models/transformers/transformer_sd3.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,19 @@ def forward(
161161
return_dict: bool = True,
162162
**kwargs,
163163
):
164+
"""
165+
Forward pass for the RBLN-optimized SD3Transformer2DModel.
166+
167+
Args:
168+
hidden_states (torch.FloatTensor): The currently predicted image embeddings.
169+
encoder_hidden_states (torch.FloatTensor): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
170+
pooled_projections (torch.FloatTensor): Embeddings projected from the embeddings of input conditions.
171+
timestep (torch.LongTensor): Current denoising step.
172+
return_dict (bool): Whether or not to return a [`~diffusers.models.modeling_output.Transformer2DModelOutput`] instead of a plain tuple.
173+
174+
Returns:
175+
(Union[`~diffusers.models.modeling_output.Transformer2DModelOutput`, Tuple])
176+
"""
164177
sample_batch_size = hidden_states.size()[0]
165178
compiled_batch_size = self.compiled_batch_size
166179
if sample_batch_size != compiled_batch_size and (

src/optimum/rbln/diffusers/models/unets/unet_2d_condition.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,22 @@ def forward(
349349
return_dict: bool = True,
350350
**kwargs,
351351
) -> Union[UNet2DConditionOutput, Tuple]:
352+
"""
353+
Forward pass for the RBLN-optimized UNet2DConditionModel.
354+
355+
Args:
356+
sample (torch.Tensor): The noisy input tensor with the following shape `(batch, channel, height, width)`.
357+
timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
358+
encoder_hidden_states (torch.Tensor): The encoder hidden states.
359+
added_cond_kwargs (Dict[str, torch.Tensor]): A kwargs dictionary containing additional embeddings that
360+
if specified are added to the embeddings that are passed along to the UNet blocks.
361+
down_block_additional_residuals (Optional[Tuple[torch.Tensor]]): A tuple of tensors that if specified are added to the residuals of down unet blocks.
362+
mid_block_additional_residual (Optional[torch.Tensor]): A tensor that if specified is added to the residual of the middle unet block.
363+
return_dict (bool): Whether or not to return a [`~diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
364+
365+
Returns:
366+
(Union[`~diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`], Tuple)
367+
"""
352368
sample_batch_size = sample.size()[0]
353369
compiled_batch_size = self.compiled_batch_size
354370
if sample_batch_size != compiled_batch_size and (

0 commit comments

Comments
 (0)