@@ -51,14 +51,21 @@ def __init__(self, client_module, inference=True):
         try:
             from megatron.model.transformer import ParallelTransformerLayer
             MegatronLayerPolicy._orig_layer_class = ParallelTransformerLayer
+            MegatronLayerPolicy.version = 1
         except ImportError:
             MegatronLayerPolicy._orig_layer_class = None
 
     def get_hidden_heads(self):
-        return self.client_module.attention.query_key_value.weight.shape[1], \
-                self.client_module.attention.num_attention_heads, \
-                self.client_module.input_layernorm.eps, \
-                DEFAULT_INTERMEDIATE_SIZE
+        if MegatronLayerPolicy.version == 0:
+            return self.client_module.attention.query_key_value.weight.shape[1], \
+                    self.client_module.attention.num_attention_heads, \
+                    self.client_module.input_layernorm.eps, \
+                    DEFAULT_INTERMEDIATE_SIZE
+        else:
+            return self.client_module.self_attention.query_key_value.weight.shape[1], \
+                    self.client_module.self_attention.num_attention_heads, \
+                    self.client_module.input_layernorm.eps, \
+                    DEFAULT_INTERMEDIATE_SIZE
 
     def attention(self, enable_training=False):
         if self.inference:
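For context, a minimal runnable sketch of the dispatch pattern this change introduces: newer Megatron layers expose their attention block as self_attention while older ones use attention, and get_hidden_heads picks the attribute based on a class-level version flag. This is an illustration only, not DeepSpeed or Megatron code; the Sketch class name, the fake-layer helper, the concrete sizes, and the -1 sentinel for DEFAULT_INTERMEDIATE_SIZE are assumptions made for the example.

    from types import SimpleNamespace

    DEFAULT_INTERMEDIATE_SIZE = -1  # placeholder sentinel for "use the default"

    class MegatronLayerPolicySketch:
        version = 0  # bumped to 1 when the newer Megatron import succeeds

        def __init__(self, client_module):
            self.client_module = client_module

        def get_hidden_heads(self):
            # Older Megatron: layer.attention; newer Megatron: layer.self_attention.
            attn = (self.client_module.attention
                    if MegatronLayerPolicySketch.version == 0
                    else self.client_module.self_attention)
            return (attn.query_key_value.weight.shape[1],
                    attn.num_attention_heads,
                    self.client_module.input_layernorm.eps,
                    DEFAULT_INTERMEDIATE_SIZE)

    def make_fake_layer(new_style):
        # Toy stand-in for Megatron's ParallelTransformerLayer attribute layout.
        attn = SimpleNamespace(
            query_key_value=SimpleNamespace(weight=SimpleNamespace(shape=(3072, 1024))),
            num_attention_heads=16)
        norm = SimpleNamespace(eps=1e-5)
        if new_style:
            return SimpleNamespace(self_attention=attn, input_layernorm=norm)
        return SimpleNamespace(attention=attn, input_layernorm=norm)

    MegatronLayerPolicySketch.version = 1
    print(MegatronLayerPolicySketch(make_fake_layer(new_style=True)).get_hidden_heads())
    # -> (1024, 16, 1e-05, -1)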