Wrong implementation of Spiking_Self_Attention #14

@frostylight

Description

It seems that the last LIF (proj_lif) in Spiking_Self_Attention receives input with shape $[T\times B, C, N]$ rather than $[T, B, C, N]$.
The reshape should be executed after proj_bn and before proj_lif, so that the multi-step neuron sees the time dimension first.

The accuracy of the trained HST-10-384 is 78.732% when batch_size is set to 32 (the default in imagenet/test.py).
When the problem is avoided by setting batch_size to 1 (with $B = 1$, the flattened $[T\times B, \dots]$ layout coincides with the time-major $[T, \dots]$ layout), the accuracy drops to 74.202% (-4.530%).
When the problem is fixed by replacing the line with the following code, the accuracy drops to 74.196% (-4.536%).

x = self.proj_lif(self.proj_bn(self.proj_conv(x)).reshape(T, B, C, W, H))

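A minimal shape trace (NumPy stand-ins, with illustrative sizes and the transpose step omitted) shows why the reshape must come before the multi-step neuron: a multi-step LIF interprets its leading axis as time, so feeding it the fused $[T\times B, \dots]$ tensor mixes time and batch in the membrane-potential update.

```python
import numpy as np

# Illustrative sizes; the real model uses larger values.
T, B, C, W, H = 4, 2, 8, 3, 3
N = W * H

# conv/BN run on the time-flattened tensor, as in the repo.
x = np.random.rand(T * B, C, W, H)

# Buggy order: proj_lif sees [T*B, C, W, H] and treats the fused
# time-batch axis as the time axis.
buggy_in = x
assert buggy_in.shape == (T * B, C, W, H)

# Fixed order: restore the time-major [T, B, ...] layout first.
fixed_in = x.reshape(T, B, C, W, H)
assert fixed_in.shape == (T, B, C, W, H)

# After the neuron, spatial dims can be flattened back to N tokens.
out = fixed_in.reshape(T, B, C, N)
assert out.shape == (T, B, C, N)
```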

QKFormer/cifar10/model.py

Lines 115 to 118 in 43f0adf

x = x.transpose(3, 4).reshape(T, B, C, N).contiguous()
x = self.attn_lif(x)
x = x.flatten(0,1)
x = self.proj_lif(self.proj_bn(self.proj_conv(x))).reshape(T,B,C,W,H)

QKFormer/cifar100/model.py

Lines 115 to 118 in 43f0adf

x = x.transpose(3, 4).reshape(T, B, C, N).contiguous()
x = self.attn_lif(x)
x = x.flatten(0,1)
x = self.proj_lif(self.proj_bn(self.proj_conv(x))).reshape(T,B,C,W,H)
