@@ -146,6 +146,7 @@ def forward(
146146 mask (torch.Tensor, optional): Attention mask of floating-point values in the range
147147 `[-inf, 0)` with shape of `(nW, ws, ws)`, where `nW` is the number of windows,
148148 and `ws` is the window size (i.e. total tokens inside the window).
149+ rollout_step (int, optional): Roll-out step. Defaults to `0`.
149150
150151 Returns:
151152 torch.Tensor: Output of shape `(nW*B, N, C)`.
@@ -198,8 +199,8 @@ def window_partition_3d(x: torch.Tensor, ws: tuple[int, int, int]) -> torch.Tens
198199 """Partition into windows.
199200
200201 Args:
201- x: (torch.Tensor): Input tensor of shape `(B, C, H, W, D)`.
202- ws: (tuple[int, int, int]): A 3D window size `(Wc, Wh, Ww)`.
202+ x (torch.Tensor): Input tensor of shape `(B, C, H, W, D)`.
203+ ws (tuple[int, int, int]): A 3D window size `(Wc, Wh, Ww)`.
203204
204205 Returns:
205206 torch.Tensor: Partitioning of shape `(num_windows*B, Wc, Wh, Ww, D)`.
@@ -318,7 +319,8 @@ def compute_3d_shifted_window_mask(
318319 H (int): Height of the image.
319320 W (int): Width of the image.
320321 ws (tuple[int, int, int]): Window sizes of the form `(Wc, Wh, Ww)`.
321- ss (tuple[int, int, int]): Shift sizes of the form `(Sc, Sh, Sw)`
322+ ss (tuple[int, int, int]): Shift sizes of the form `(Sc, Sh, Sw)`.
323+ device (torch.device): Device of the mask.
322324 dtype (torch.dtype, optional): Data type of the mask. Defaults to `torch.bfloat16`.
323325 warped (bool): If `True`, assume that the left and right sides of the image are connected.
324326 Defaults to `True`.
@@ -768,7 +770,8 @@ def __init__(
768770 lora_mode : LoRAMode = "single" ,
769771 use_lora : bool = False ,
770772 ) -> None :
771- """
773+ """Initialise.
774+
772775 Args:
773776 embed_dim (int): Patch embedding dimension. Defaults to `96`.
774777 encoder_depths (tuple[int, ...]): Number of blocks in each encoder layer. Defaults to
0 commit comments