
Commit 6cbf666

Dino Chen and RezaYazdaniAminabadi authored
fix MegatronLayerPolicy to be compatible with the newest ParallelTransformerLayer (#4236)
Co-authored-by: Reza Yazdani <[email protected]>
1 parent 5dbc531 commit 6cbf666

File tree

1 file changed: +11 -4 lines changed


deepspeed/module_inject/containers/megatron_gpt.py

@@ -51,14 +51,21 @@ def __init__(self, client_module, inference=True):
                 try:
                     from megatron.model.transformer import ParallelTransformerLayer
                     MegatronLayerPolicy._orig_layer_class = ParallelTransformerLayer
+                    MegatronLayerPolicy.version = 1
                 except ImportError:
                     MegatronLayerPolicy._orig_layer_class = None

     def get_hidden_heads(self):
-        return self.client_module.attention.query_key_value.weight.shape[1], \
-                self.client_module.attention.num_attention_heads, \
-                self.client_module.input_layernorm.eps, \
-                DEFAULT_INTERMEDIATE_SIZE
+        if MegatronLayerPolicy.version == 0:
+            return self.client_module.attention.query_key_value.weight.shape[1], \
+                    self.client_module.attention.num_attention_heads, \
+                    self.client_module.input_layernorm.eps, \
+                    DEFAULT_INTERMEDIATE_SIZE
+        else:
+            return self.client_module.self_attention.query_key_value.weight.shape[1], \
+                    self.client_module.self_attention.num_attention_heads, \
+                    self.client_module.input_layernorm.eps, \
+                    DEFAULT_INTERMEDIATE_SIZE

     def attention(self, enable_training=False):
         if self.inference:
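
Context for the change: newer Megatron-LM builds of ParallelTransformerLayer expose the attention block as self_attention rather than attention, so the policy bumps MegatronLayerPolicy.version to 1 when the new import succeeds and branches on it in get_hidden_heads. Below is a minimal sketch of that dispatch, not the DeepSpeed implementation: make_stub, the standalone get_hidden_heads helper, and the DEFAULT_INTERMEDIATE_SIZE placeholder are all hypothetical stand-ins used only to illustrate the attribute rename.

from types import SimpleNamespace

DEFAULT_INTERMEDIATE_SIZE = -1  # placeholder for the container module's constant (assumption)


def make_stub(attn_name, hidden=8, heads=2, eps=1e-5):
    # Fake layer whose attention block lives under `attn_name`; the
    # query_key_value weight is [3 * hidden, hidden], so shape[1] is the hidden size.
    attn = SimpleNamespace(
        query_key_value=SimpleNamespace(weight=SimpleNamespace(shape=(3 * hidden, hidden))),
        num_attention_heads=heads,
    )
    return SimpleNamespace(**{attn_name: attn, "input_layernorm": SimpleNamespace(eps=eps)})


def get_hidden_heads(client_module, version):
    # Mirrors the patched policy: version 0 reads .attention, newer layers use .self_attention.
    attn = client_module.attention if version == 0 else client_module.self_attention
    return (attn.query_key_value.weight.shape[1], attn.num_attention_heads,
            client_module.input_layernorm.eps, DEFAULT_INTERMEDIATE_SIZE)


old_layer = make_stub("attention")       # pre-rename Megatron layout (version 0)
new_layer = make_stub("self_attention")  # newest ParallelTransformerLayer layout (version 1)
assert get_hidden_heads(old_layer, 0) == get_hidden_heads(new_layer, 1)

Either layout yields the same (hidden_size, num_heads, layernorm_eps, intermediate_size) tuple, which is why the single version flag is enough to keep the rest of the container code unchanged.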
