transformer_lens_final_ln.patch
23 lines (20 loc) · 1.01 KB
diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index 8ee2e74f..dce09b51 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -179,3 +179,5 @@ class HookedTransformer(HookedRootModule):
-        if self.cfg.normalization_type == "RMS":
+        if not self.cfg.final_ln:
+            self.ln_final = nn.Identity()
+        elif self.cfg.normalization_type == "RMS":
             self.ln_final = RMSNorm(self.cfg)
diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py
index 6906de38..40318588 100644
--- a/transformer_lens/HookedTransformerConfig.py
+++ b/transformer_lens/HookedTransformerConfig.py
@@ -181,2 +181,3 @@ class HookedTransformerConfig:
             set.
+        final_ln (bool): Whether to apply normalization before the final unembed layer. Defaults to True.
@@ -245,2 +246,3 @@ class HookedTransformerConfig:
     output_logits_soft_cap: float = -1.0
+    final_ln: bool = True
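
A minimal sketch of how the patched flag might be exercised, assuming the patch above is applied to a local TransformerLens checkout. The config values other than final_ln are ordinary HookedTransformerConfig parameters chosen only to make the example self-contained.

import torch.nn as nn

from transformer_lens import HookedTransformer, HookedTransformerConfig

# Small toy config; final_ln=False is the new flag added by this patch.
cfg = HookedTransformerConfig(
    n_layers=2,
    d_model=64,
    n_ctx=128,
    d_head=16,
    d_vocab=256,
    act_fn="gelu",
    final_ln=False,  # skip normalization before the unembed
)
model = HookedTransformer(cfg)

# With final_ln=False, the patch swaps the final norm for a no-op module.
assert isinstance(model.ln_final, nn.Identity)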