@@ -77,7 +77,8 @@ def __init__(
         dim,
         heads = 8,
         dim_head = 64,
-        dropout = 0.
+        dropout = 0.,
+        max_pos_emb = 512
     ):
         super().__init__()
         inner_dim = dim_head * heads
@@ -87,17 +88,28 @@ def __init__(
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
         self.to_out = nn.Linear(inner_dim, dim)
 
+        self.max_pos_emb = max_pos_emb
+        self.rel_pos_emb = nn.Embedding(2 * max_pos_emb + 1, dim_head)
+
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, x, context = None, mask = None, context_mask = None):
-        device, h, has_context = x.device, self.heads, exists(context)
+        n, device, h, max_pos_emb, has_context = x.shape[-2], x.device, self.heads, self.max_pos_emb, exists(context)
         context = default(context, x)
 
         q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
 
         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
 
+        # shaw's relative positional embedding
+        seq = torch.arange(n, device = device)
+        dist = seq[:, None] - seq[None, :]
+        dist = dist.clip(-max_pos_emb, max_pos_emb) + max_pos_emb
+        rel_pos_emb = self.rel_pos_emb(dist).to(q)
+        pos_attn = einsum('b h n d, n r d -> b h n r', q, rel_pos_emb) * self.scale
+        dots = dots + pos_attn
+
         if exists(mask) or exists(context_mask):
             mask = default(mask, lambda: torch.ones(*x.shape[:2], device = device))
             context_mask = default(context_mask, mask) if not has_context else default(context_mask, lambda: torch.ones(*context.shape[:2], device = device))
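
For reference, a minimal standalone sketch of the bias this hunk introduces (Shaw-style relative positional attention). The shapes and variable names below are illustrative only, not taken from the repo, and the softmax scaling applied by the module is omitted:

```python
import torch
from torch import nn, einsum

# hypothetical sizes: batch b, heads h, sequence length n, head dim d
b, h, n, d, max_pos_emb = 2, 8, 10, 64, 512
q = torch.randn(b, h, n, d)

# one learned embedding per clipped relative offset in [-max_pos_emb, max_pos_emb]
rel_pos_emb = nn.Embedding(2 * max_pos_emb + 1, d)

seq = torch.arange(n)
# pairwise offsets i - j, clipped and shifted into valid embedding indices
dist = (seq[:, None] - seq[None, :]).clip(-max_pos_emb, max_pos_emb) + max_pos_emb
emb = rel_pos_emb(dist)                                  # (n, n, d)

# bias added to the content attention logits (the module also scales this term)
pos_attn = einsum('b h i d, i j d -> b h i j', q, emb)
print(pos_attn.shape)  # torch.Size([2, 8, 10, 10])
```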