3
3
from torch import LongTensor , nn , FloatTensor , BoolTensor
4
4
torch .set_grad_enabled (False )
5
5
6
- from .dalle_bart_encoder_torch import GLUTorch , AttentionTorch
6
+ from .dalle_bart_encoder import GLU , AttentionBase
7
7
8
8
9
- class DecoderCrossAttentionTorch ( AttentionTorch ):
9
+ class DecoderCrossAttention ( AttentionBase ):
10
10
def forward (
11
11
self ,
12
12
decoder_state : FloatTensor ,
@@ -19,7 +19,7 @@ def forward(
19
19
return super ().forward (keys , values , queries , attention_mask )
20
20
21
21
22
- class DecoderSelfAttentionTorch ( AttentionTorch ):
22
+ class DecoderSelfAttention ( AttentionBase ):
23
23
def forward (
24
24
self ,
25
25
decoder_state : FloatTensor ,
@@ -42,7 +42,7 @@ def forward(
42
42
return decoder_state , attention_state
43
43
44
44
45
- class DecoderLayerTorch (nn .Module ):
45
+ class DecoderLayer (nn .Module ):
46
46
def __init__ (
47
47
self ,
48
48
image_token_count : int ,
@@ -53,12 +53,12 @@ def __init__(
53
53
super ().__init__ ()
54
54
self .image_token_count = image_token_count
55
55
self .pre_self_attn_layer_norm = nn .LayerNorm (embed_count )
56
- self .self_attn = DecoderSelfAttentionTorch (head_count , embed_count )
56
+ self .self_attn = DecoderSelfAttention (head_count , embed_count )
57
57
self .self_attn_layer_norm = nn .LayerNorm (embed_count )
58
58
self .pre_encoder_attn_layer_norm = nn .LayerNorm (embed_count )
59
- self .encoder_attn = DecoderCrossAttentionTorch (head_count , embed_count )
59
+ self .encoder_attn = DecoderCrossAttention (head_count , embed_count )
60
60
self .encoder_attn_layer_norm = nn .LayerNorm (embed_count )
61
- self .glu = GLUTorch (embed_count , glu_embed_count )
61
+ self .glu = GLU (embed_count , glu_embed_count )
62
62
63
63
self .token_indices = torch .arange (self .image_token_count )
64
64
if torch .cuda .is_available ():
@@ -106,7 +106,7 @@ def forward(
106
106
return decoder_state , attention_state
107
107
108
108
109
- class DalleBartDecoderTorch (nn .Module ):
109
+ class DalleBartDecoder (nn .Module ):
110
110
def __init__ (
111
111
self ,
112
112
image_vocab_count : int ,
@@ -126,8 +126,8 @@ def __init__(
126
126
self .image_token_count = image_token_count
127
127
self .embed_tokens = nn .Embedding (image_vocab_count + 1 , embed_count )
128
128
self .embed_positions = nn .Embedding (image_token_count , embed_count )
129
- self .layers : List [DecoderLayerTorch ] = nn .ModuleList ([
130
- DecoderLayerTorch (
129
+ self .layers : List [DecoderLayer ] = nn .ModuleList ([
130
+ DecoderLayer (
131
131
image_token_count ,
132
132
attention_head_count ,
133
133
embed_count ,
0 commit comments