@@ -282,6 +282,7 @@ def __init__(
282282 self.num_layers = model_config.pretrained_config.num_hidden_layers
283283 self._eh_proj_before_attn = eagle_config.get("eh_proj_before_attn",
284284 False)
285+ self._norm_before_fc = eagle_config.get("norm_before_fc", False)
285286 self._use_mla = use_mla
286287
287288 if hasattr(config, "target_hidden_size"):
@@ -303,6 +304,15 @@ def __init__(
303304 dtype=config.torch_dtype,
304305 quant_config=model_config.get_quant_config(),
305306 )
307+ if self._norm_before_fc:
308+ self.input_norm = RMSNorm(
309+ hidden_size=self.hidden_size_in *
310+ self.spec_config.num_capture_layers,
311+ eps=config.rms_norm_eps,
312+ dtype=config.torch_dtype,
313+ )
314+ else:
315+ self.input_norm = None
306316
307317 if self.num_layers > 1:
308318 self.midlayer = nn.ModuleList([
@@ -552,6 +562,8 @@ def apply_eagle3_fc(self, hidden_states: torch.Tensor) -> torch.Tensor:
552562
553563 expected_hidden_size = self.model.hidden_size
554564 if hidden_states.shape[-1] != expected_hidden_size:
565+ if self.model._norm_before_fc:
566+ hidden_states = self.model.input_norm(hidden_states)
555567 hidden_states = self.model.fc(hidden_states)
556568
557569 return hidden_states
0 commit comments