Skip to content

Commit 07abfcf

Browse files
committed
rescale values in linear attention to mitigate overflows in fp16 setting
1 parent 2e35a99 commit 07abfcf

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

dalle2_pytorch/dalle2_pytorch.py

+1
Original file line numberDiff line numberDiff line change
@@ -1503,6 +1503,7 @@ def forward(self, fmap):
15031503
k = k.softmax(dim = -2)
15041504

15051505
q = q * self.scale
1506+
v = v / (x * y)
15061507

15071508
context = einsum('b n d, b n e -> b d e', k, v)
15081509
out = einsum('b n d, b d e -> b n e', q, context)

dalle2_pytorch/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.1.0'
1+
__version__ = '1.2.0'

0 commit comments

Comments
 (0)