Commit 793c23e

cmikeh2, mrwyattii, and jeffra authored
Explicitly check for OPT activation function (#3278)
Co-authored-by: Michael Wyatt <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
1 parent 145c3a7 commit 793c23e

File tree

1 file changed (+13, -5 lines)
  • deepspeed/module_inject/containers/opt.py

deepspeed/module_inject/containers/opt.py (+13, -5)
@@ -72,18 +72,26 @@ class HFOPTLayerPolicy(TransformerPolicy):
     _orig_layer_class = None

     def __init__(self, client_module, inference=True, use_load_prefix=True):
-        super().__init__(inference,
-                         linear_layer=True,
-                         mlp_act_func_type=ActivationFuncType.ReLU,
-                         pre_attn_norm=True,
-                         use_load_prefix=use_load_prefix)
+        super().__init__(inference, linear_layer=True, pre_attn_norm=True, use_load_prefix=use_load_prefix)
         self.client_module = client_module
         try:
             import transformers
             HFOPTLayerPolicy._orig_layer_class = transformers.models.opt.modeling_opt.OPTDecoderLayer
         except:
             HFOPTLayerPolicy._orig_layer_class = None

+        if hasattr(TransformerPolicy, "hf_model_config") and hasattr(TransformerPolicy.hf_model_config,
+                                                                     "activation_function"):
+            if TransformerPolicy.hf_model_config.activation_function == "relu":
+                self.mlp_act_func_type = ActivationFuncType.ReLU
+            elif TransformerPolicy.hf_model_config.activation_function in ["gelu", "gelu_new"]:
+                self.mlp_act_func_type = ActivationFuncType.GELU
+            else:
+                raise ValueError("Unsupported activation function: {}".format(
+                    TransformerPolicy.hf_model_config.activation_function))
+        else:
+            self.mlp_act_func_type = ActivationFuncType.ReLU  # default
+
     def get_hidden_heads(self):
         return self.client_module.self_attn.embed_dim, \
                self.client_module.self_attn.num_heads, \
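
For reference, below is a minimal standalone sketch of the selection logic this commit introduces, not DeepSpeed's actual API: ActivationFuncType is stubbed with a local enum, select_mlp_act_func_type and _FakeOPTConfig are hypothetical names, and the real code reads TransformerPolicy.hf_model_config rather than taking the config as an argument.

from enum import Enum


class ActivationFuncType(Enum):
    # Local stub; the real enum lives in DeepSpeed and at least includes ReLU and GELU.
    ReLU = 1
    GELU = 2


def select_mlp_act_func_type(hf_model_config):
    # Mirrors the check added above: map the config's activation_function string
    # to an ActivationFuncType, default to ReLU when the attribute is missing,
    # and reject anything other than relu / gelu / gelu_new.
    if hasattr(hf_model_config, "activation_function"):
        if hf_model_config.activation_function == "relu":
            return ActivationFuncType.ReLU
        if hf_model_config.activation_function in ["gelu", "gelu_new"]:
            return ActivationFuncType.GELU
        raise ValueError("Unsupported activation function: {}".format(hf_model_config.activation_function))
    return ActivationFuncType.ReLU  # default


class _FakeOPTConfig:
    # Hypothetical stand-in for the Hugging Face OPT model config object.
    def __init__(self, activation_function):
        self.activation_function = activation_function


assert select_mlp_act_func_type(_FakeOPTConfig("relu")) is ActivationFuncType.ReLU
assert select_mlp_act_func_type(_FakeOPTConfig("gelu_new")) is ActivationFuncType.GELU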
