@@ -72,18 +72,26 @@ class HFOPTLayerPolicy(TransformerPolicy):
     _orig_layer_class = None
 
     def __init__(self, client_module, inference=True, use_load_prefix=True):
-        super().__init__(inference,
-                         linear_layer=True,
-                         mlp_act_func_type=ActivationFuncType.ReLU,
-                         pre_attn_norm=True,
-                         use_load_prefix=use_load_prefix)
+        super().__init__(inference, linear_layer=True, pre_attn_norm=True, use_load_prefix=use_load_prefix)
         self.client_module = client_module
         try:
             import transformers
             HFOPTLayerPolicy._orig_layer_class = transformers.models.opt.modeling_opt.OPTDecoderLayer
         except:
             HFOPTLayerPolicy._orig_layer_class = None
 
+        if hasattr(TransformerPolicy, "hf_model_config") and hasattr(TransformerPolicy.hf_model_config,
+                                                                     "activation_function"):
+            if TransformerPolicy.hf_model_config.activation_function == "relu":
+                self.mlp_act_func_type = ActivationFuncType.ReLU
+            elif TransformerPolicy.hf_model_config.activation_function in ["gelu", "gelu_new"]:
+                self.mlp_act_func_type = ActivationFuncType.GELU
+            else:
+                raise ValueError("Unsupported activation function: {}".format(
+                    TransformerPolicy.hf_model_config.activation_function))
+        else:
+            self.mlp_act_func_type = ActivationFuncType.ReLU  # default
+
     def get_hidden_heads(self):
         return self.client_module.self_attn.embed_dim, \
                self.client_module.self_attn.num_heads, \
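
For reference, the standalone sketch below (not part of the diff) mirrors the activation-selection branch added in __init__. ActivationFuncType is stubbed here with a local Enum so the snippet runs without DeepSpeed installed, and select_mlp_act is a hypothetical helper name used only for illustration; the real policy reads TransformerPolicy.hf_model_config and assigns self.mlp_act_func_type directly.

# Standalone sketch of the activation selection added above.
from enum import Enum


class ActivationFuncType(Enum):  # local stand-in, not the DeepSpeed enum
    ReLU = 1
    GELU = 2


def select_mlp_act(activation_function: str) -> ActivationFuncType:
    # Same mapping as the new branch in HFOPTLayerPolicy.__init__:
    # "relu" -> ReLU, "gelu"/"gelu_new" -> GELU, anything else is rejected.
    if activation_function == "relu":
        return ActivationFuncType.ReLU
    if activation_function in ["gelu", "gelu_new"]:
        return ActivationFuncType.GELU
    raise ValueError("Unsupported activation function: {}".format(activation_function))


assert select_mlp_act("relu") is ActivationFuncType.ReLU
assert select_mlp_act("gelu_new") is ActivationFuncType.GELU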