Skip to content

Commit b7098a2

Browse files
[None][feat] Eagle: Norm before FC (NVIDIA#12561)
Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
1 parent db1c637 commit b7098a2

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -282,6 +282,7 @@ def __init__(
282282
self.num_layers = model_config.pretrained_config.num_hidden_layers
283283
self._eh_proj_before_attn = eagle_config.get("eh_proj_before_attn",
284284
False)
285+
self._norm_before_fc = eagle_config.get("norm_before_fc", False)
285286
self._use_mla = use_mla
286287

287288
if hasattr(config, "target_hidden_size"):
@@ -303,6 +304,15 @@ def __init__(
303304
dtype=config.torch_dtype,
304305
quant_config=model_config.get_quant_config(),
305306
)
307+
if self._norm_before_fc:
308+
self.input_norm = RMSNorm(
309+
hidden_size=self.hidden_size_in *
310+
self.spec_config.num_capture_layers,
311+
eps=config.rms_norm_eps,
312+
dtype=config.torch_dtype,
313+
)
314+
else:
315+
self.input_norm = None
306316

307317
if self.num_layers > 1:
308318
self.midlayer = nn.ModuleList([
@@ -552,6 +562,8 @@ def apply_eagle3_fc(self, hidden_states: torch.Tensor) -> torch.Tensor:
552562

553563
expected_hidden_size = self.model.hidden_size
554564
if hidden_states.shape[-1] != expected_hidden_size:
565+
if self.model._norm_before_fc:
566+
hidden_states = self.model.input_norm(hidden_states)
555567
hidden_states = self.model.fc(hidden_states)
556568

557569
return hidden_states

0 commit comments

Comments (0)