@@ -72,7 +72,7 @@ def forward(self, x, mask = None):
72
72
return out
73
73
74
74
class SparseConvCausalAttention (nn .Module ):
75
- def __init__ (self , dim , seq_len , image_size = 32 , kernel_size = 5 , heads = 8 , dim_head = 64 , dropout = 0. , ** kwargs ):
75
+ def __init__ (self , dim , seq_len , image_size = 32 , kernel_size = 5 , dilation = 1 , heads = 8 , dim_head = 64 , dropout = 0. , ** kwargs ):
76
76
super ().__init__ ()
77
77
assert kernel_size % 2 == 1 , 'kernel size must be odd'
78
78
@@ -81,6 +81,7 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, heads = 8, di
81
81
self .scale = dim_head ** - 0.5
82
82
self .image_size = image_size
83
83
self .kernel_size = kernel_size
84
+ self .dilation = dilation
84
85
85
86
self .to_qkv = nn .Linear (dim , inner_dim * 3 , bias = False )
86
87
@@ -90,7 +91,7 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, heads = 8, di
90
91
)
91
92
92
93
def forward (self , x , mask = None ):
93
- b , n , _ , h , img_size , kernel_size , device = * x .shape , self .heads , self .image_size , self .kernel_size , x .device
94
+ b , n , _ , h , img_size , kernel_size , dilation , device = * x .shape , self .heads , self .image_size , self .kernel_size , self . dilation , x .device
94
95
qkv = self .to_qkv (x ).chunk (3 , dim = - 1 )
95
96
q , k , v = map (lambda t : rearrange (t , 'b n (h d) -> (b h) n d' , h = h ), qkv )
96
97
@@ -114,8 +115,11 @@ def forward(self, x, mask = None):
114
115
115
116
# image attention
116
117
118
+ effective_kernel_size = (kernel_size - 1 ) * dilation + 1
119
+ padding = effective_kernel_size // 2
120
+
117
121
k_img , v_img = map (lambda t : rearrange (t , 'b (h w) c -> b c h w' , h = img_size ), (k_img , v_img ))
118
- k_img , v_img = map (lambda t : F .unfold (t , kernel_size , padding = ( kernel_size // 2 ) ), (k_img , v_img ))
122
+ k_img , v_img = map (lambda t : F .unfold (t , kernel_size , padding = padding , dilation = dilation ), (k_img , v_img ))
119
123
k_img , v_img = map (lambda t : rearrange (t , 'b (j d) i -> b i j d' , j = kernel_size ** 2 ), (k_img , v_img ))
120
124
121
125
k_text , v_text = map (lambda t : repeat (t , 'b j d -> b i j d' , i = img_seq_len ), (k_text , v_text ))
@@ -132,8 +136,8 @@ def forward(self, x, mask = None):
132
136
i , j = dots_image .shape [- 2 :]
133
137
img_seq = torch .arange (img_seq_len , device = device )
134
138
k_img_indices = rearrange (img_seq .float (), '(h w) -> () () h w' , h = img_size )
135
- k_img_indices = F .pad (k_img_indices , (kernel_size // 2 ,) * 4 , value = img_seq_len )
136
- k_img_indices = F .unfold (k_img_indices , kernel_size )
139
+ k_img_indices = F .pad (k_img_indices , (padding ,) * 4 , value = img_seq_len ) # padding set to be max, so it is never attended to
140
+ k_img_indices = F .unfold (k_img_indices , kernel_size , dilation = dilation )
137
141
k_img_indices = rearrange (k_img_indices , 'b j i -> b i j' )
138
142
139
143
# mask image attention
0 commit comments