swin v2 adding padding to shifted window attention breaks the algorithm #2438

Open
@alita-moore

Description

Here:

```python
pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
shifted_x = torch.nn.functional.pad(shifted_x, (0, 0, 0, pad_w, 0, pad_h))
_, Hp, Wp, _ = shifted_x.shape
```
the shifted window attention applies the padding after shifting the values. However,

```python
def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
```

assumes that the tensor has been rolled by `shift_size` relative to the padded grid. The padding offsets the rolled values, so the generated mask no longer lines up with them: patches end up being included in the attention calculation when they should be masked out.
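For context, reference Swin implementations typically build the shifted-window mask roughly like the sketch below (a hypothetical helper for illustration, not this repo's exact code). The region labels are assigned with slices measured from the end of the padded Hp × Wp grid, which only lines up with the wrapped-around tokens if the roll was applied to that same padded grid:

```python
# Sketch of the usual Swin-style mask construction (hypothetical helper,
# simplified from the reference Swin implementation; not this repo's code).
import torch

def build_attn_mask(Hp: int, Wp: int, window_size: int, shift_size: int) -> torch.Tensor:
    # Label regions of the padded grid; the last `shift_size` rows/cols are
    # assumed to hold the wrapped-around tokens.
    img_mask = torch.zeros((1, Hp, Wp, 1))
    slices = (
        slice(0, -window_size),
        slice(-window_size, -shift_size),
        slice(-shift_size, None),
    )
    cnt = 0
    for h in slices:
        for w in slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    # Partition the label map into windows and block attention between
    # positions that carry different labels.
    mask_windows = (
        img_mask.view(1, Hp // window_size, window_size, Wp // window_size, window_size, 1)
        .permute(0, 1, 3, 2, 4, 5)
        .reshape(-1, window_size * window_size)
    )
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    return attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
```

With shift-then-pad, the wrapped rows/columns sit `pad_h`/`pad_w` positions away from the border, so the `slice(-shift_size, None)` regions cover the zero padding rather than the wrapped content.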

Consider the following x:

x = [[1,2,3],[4,5,6],[7,8,9]]

After shifting the windows you get:

x = [[2,3,1],[5,6,4],[8,9,7]]

If the window size is 2, padding is applied like so:

x = [[2,3,1,0],[5,6,4,0],[8,9,7,0],[0,0,0,0]]

Because the shifted window attention mask is calculated from x at this point, the calculated attn mask only masks out the added padding tokens, not the shifted values. In this particular example the shifted values happen not to attend to each other inappropriately, but on a larger grid (e.g. 3x3) you would see cases where a token such as 7 attends to a token such as 1 even though it should not.
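A standalone reproduction of the shift-then-pad ordering from the example above (my own sketch, not the repo's code; it uses a single-axis shift of 1 and window size 2, as in the 3x3 example):

```python
# Reproduce the example: roll first, then pad up to a multiple of window_size.
import torch
import torch.nn.functional as F

H = W = 3
window_size = 2
shift_size = 1

x = torch.arange(1, H * W + 1, dtype=torch.float32).view(1, H, W, 1)

# 1. cyclic shift on the un-padded grid (what the code in question does first)
shifted = torch.roll(x, shifts=-shift_size, dims=2)

# 2. padding is applied afterwards
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
shifted_padded = F.pad(shifted, (0, 0, 0, pad_w, 0, pad_h))

print(shifted_padded[0, :, :, 0])
# tensor([[2., 3., 1., 0.],
#         [5., 6., 4., 0.],
#         [8., 9., 7., 0.],
#         [0., 0., 0., 0.]])
#
# The wrapped-around column (1, 4, 7) now sits one position away from the
# border of the 4x4 grid, so a mask built from slice(-shift_size, None) over
# the padded grid marks only the zero padding, not the wrapped tokens.
```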
