Added OLMo(E) v1 #816

Open · wants to merge 16 commits into base: dev
1 change: 1 addition & 0 deletions .gitignore
@@ -20,3 +20,4 @@ docs/build
docs/source/generated
**.orig
.venv

4,171 changes: 2,291 additions & 1,880 deletions poetry.lock

Large diffs are not rendered by default.

49 changes: 32 additions & 17 deletions transformer_lens/HookedTransformer.py
@@ -137,7 +137,6 @@ def __init__(
)

self.cfg = HookedTransformerConfig.unwrap(cfg)

if tokenizer is not None:
self.set_tokenizer(tokenizer, default_padding_side=default_padding_side)
elif self.cfg.tokenizer_name is not None:
@@ -155,13 +154,18 @@ def __init__(
if "phi" in self.cfg.tokenizer_name.lower():
use_fast = False
huggingface_token = os.environ.get("HF_TOKEN", "")
add_bos_token = self.cfg.original_architecture not in [
"OlmoForCausalLM",
"OlmoeForCausalLM",
"Olmo2ForCausalLM",
]
self.set_tokenizer(
AutoTokenizer.from_pretrained(
self.cfg.tokenizer_name,
add_bos_token=True,
trust_remote_code=self.cfg.trust_remote_code,
use_fast=use_fast,
token=huggingface_token if len(huggingface_token) > 0 else None,
add_bos_token=add_bos_token,
),
default_padding_side=default_padding_side,
)
@@ -726,7 +730,14 @@ def set_tokenizer(
# tokenizers like LlamaTokenizer are different when bos token is automatically/manually
# prepended, and add_bos_token cannot be dynamically controlled after initialization
# (https://github.com/huggingface/transformers/issues/25886).
tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer)
if self.cfg.original_architecture not in [
"OlmoForCausalLM",
"OlmoeForCausalLM",
"Olmo2ForCausalLM",
]:
tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer)
else:
tokenizer_with_bos = tokenizer
self.tokenizer = tokenizer_with_bos
assert self.tokenizer is not None # keep mypy happy
self.tokenizer.padding_side = default_padding_side
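
The BOS handling in the two hunks above boils down to one flag. A minimal standalone sketch of that behaviour (the architecture list mirrors the ones added here; the `AutoTokenizer` call is illustrative, not the exact code path through `set_tokenizer`):

```python
from transformers import AutoTokenizer

# Architectures whose tokenizers should NOT auto-prepend a BOS token,
# mirroring the lists added in __init__ and set_tokenizer above.
NO_BOS_ARCHITECTURES = {"OlmoForCausalLM", "OlmoeForCausalLM", "Olmo2ForCausalLM"}


def load_tokenizer(tokenizer_name: str, original_architecture: str):
    # OLMo-family checkpoints are loaded with add_bos_token=False; every other
    # architecture keeps the previous behaviour of forcing add_bos_token=True.
    add_bos_token = original_architecture not in NO_BOS_ARCHITECTURES
    return AutoTokenizer.from_pretrained(tokenizer_name, add_bos_token=add_bos_token)


# Hypothetical usage (model name taken from the list added later in this PR):
# tok = load_tokenizer("allenai/OLMo-1B-hf", "OlmoForCausalLM")
# tok("hello")["input_ids"] should not begin with a BOS id.
```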
@@ -1786,18 +1797,18 @@ def fold_layer_norm(
if not self.cfg.final_rms and fold_biases:
# Dumb bug from my old SoLU training code, some models have RMSNorm instead of LayerNorm
# pre unembed.
state_dict[f"unembed.b_U"] = state_dict[f"unembed.b_U"] + (
state_dict[f"unembed.W_U"] * state_dict[f"ln_final.b"][:, None]
state_dict["unembed.b_U"] = state_dict["unembed.b_U"] + (
state_dict["unembed.W_U"] * state_dict["ln_final.b"][:, None]
).sum(dim=-2)
del state_dict[f"ln_final.b"]
del state_dict["ln_final.b"]

state_dict[f"unembed.W_U"] = state_dict[f"unembed.W_U"] * state_dict[f"ln_final.w"][:, None]
del state_dict[f"ln_final.w"]
state_dict["unembed.W_U"] = state_dict["unembed.W_U"] * state_dict["ln_final.w"][:, None]
del state_dict["ln_final.w"]

if center_weights:
# Center the weights that read in from the LayerNormPre
state_dict[f"unembed.W_U"] -= einops.reduce(
state_dict[f"unembed.W_U"], "d_model d_vocab -> 1 d_vocab", "mean"
state_dict["unembed.W_U"] -= einops.reduce(
state_dict["unembed.W_U"], "d_model d_vocab -> 1 d_vocab", "mean"
)

return state_dict
@@ -1809,13 +1820,17 @@ def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]):
W_out. This is done by subtracting the mean of the weights from the weights themselves. This
is done in-place. See fold_layer_norm for more details.
"""
state_dict["embed.W_E"] = state_dict["embed.W_E"] - state_dict["embed.W_E"].mean(
-1, keepdim=True
)
if self.cfg.positional_embedding_type != "rotary":
state_dict["pos_embed.W_pos"] = state_dict["pos_embed.W_pos"] - state_dict[
"pos_embed.W_pos"
].mean(-1, keepdim=True)
if self.cfg.original_architecture == "Olmo2ForCausalLM":
print("Not centering embedding weights for Olmo2ForCausalLM")
pass  # skip centering: OLMo 2 is post-norm, so the first attention layer reads the un-normalized embedding and centering W_E would change the model's output
else:
state_dict["embed.W_E"] = state_dict["embed.W_E"] - state_dict["embed.W_E"].mean(
-1, keepdim=True
)
if self.cfg.positional_embedding_type != "rotary":
state_dict["pos_embed.W_pos"] = state_dict["pos_embed.W_pos"] - state_dict[
"pos_embed.W_pos"
].mean(-1, keepdim=True)
for l in range(self.cfg.n_layers):
state_dict[f"blocks.{l}.attn.W_O"] = state_dict[f"blocks.{l}.attn.W_O"] - state_dict[
f"blocks.{l}.attn.W_O"
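
The special case above rests on a simple fact: subtracting the per-row mean from W_E is invisible to any model whose first operation on the residual stream is a LayerNorm, because the norm removes that mean anyway. OLMo 2's post-norm blocks feed the raw embedding straight into attention, so the shortcut does not apply. A quick sanity check of that claim (a sketch, not part of the PR):

```python
import torch

# Centering W_E commutes with a leading LayerNorm: LN(x) == LN(x - mean(x)).
d_model = 8
ln = torch.nn.LayerNorm(d_model, elementwise_affine=False)
w_e = torch.randn(5, d_model)                      # toy embedding rows
w_e_centered = w_e - w_e.mean(-1, keepdim=True)    # what center_writing_weights does
assert torch.allclose(ln(w_e), ln(w_e_centered), atol=1e-6)
print("centering W_E is a no-op once a LayerNorm is applied first")
```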
4 changes: 2 additions & 2 deletions transformer_lens/HookedTransformerConfig.py
@@ -192,8 +192,7 @@ class HookedTransformerConfig:
NTK_by_parts_factor (float): The overall factor used in the "NTK-by-parts" method that
affects the rate of change between low and high-frequency interpolation strategies.
Defaults to 8.0.


norm_topk_prob (bool): Whether to normalize the top-k probabilities in the MoE layer.
"""

n_layers: int
@@ -262,6 +261,7 @@ class HookedTransformerConfig:
NTK_by_parts_low_freq_factor: float = 1.0
NTK_by_parts_high_freq_factor: float = 4.0
NTK_by_parts_factor: float = 8.0
norm_topk_prob: bool = False

def __post_init__(self):
if self.n_heads == -1:
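
A minimal sketch of the new flag in isolation (the other values are arbitrary toy settings, not an OLMoE config):

```python
from transformer_lens import HookedTransformerConfig

cfg = HookedTransformerConfig(
    n_layers=2,
    d_model=64,
    n_ctx=32,
    d_head=16,
    act_fn="silu",
    num_experts=4,
    experts_per_token=2,
    norm_topk_prob=False,  # OLMoE keeps the raw top-k softmax weights
)
print(cfg.norm_topk_prob)
```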
29 changes: 29 additions & 0 deletions transformer_lens/components/abstract_attention.py
@@ -10,6 +10,7 @@
from jaxtyping import Float, Int
from transformers.utils import is_bitsandbytes_available

from transformer_lens.components.rms_norm import RMSNorm
from transformer_lens.FactoredMatrix import FactoredMatrix
from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
@@ -140,6 +141,11 @@ def __init__(
# will be overwritten by the child T5Attention class
self.has_relative_attention_bias = False

if self.cfg.original_architecture == "OlmoeForCausalLM" or self.cfg.original_architecture == "Olmo2ForCausalLM":
self.q_norm = RMSNorm(self.cfg, self.cfg.d_model)
k_norm_dim = self.cfg.d_model if self.cfg.original_architecture == "Olmo2ForCausalLM" else self.cfg.d_head * self.cfg.n_key_value_heads
self.k_norm = RMSNorm(self.cfg, k_norm_dim)

@property
def OV(self) -> FactoredMatrix:
"""
@@ -195,6 +201,29 @@ def forward(

q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)

# OLMoE and OLMo 2 apply RMSNorm to the queries and keys (QK-norm) before attention.
if self.cfg.original_architecture == "OlmoeForCausalLM" or self.cfg.original_architecture == "Olmo2ForCausalLM":
q = einops.rearrange(
self.q_norm(
einops.rearrange(
q,
"batch pos head_index d_head -> batch pos (head_index d_head)",
)
),
"batch kv_pos (head_index d_head) -> batch kv_pos head_index d_head",
head_index=q.shape[2],
)
k = einops.rearrange(
self.k_norm(
einops.rearrange(
k,
"batch pos head_index d_head -> batch pos (head_index d_head)",
)
),
"batch kv_pos (head_index d_head) -> batch kv_pos head_index d_head",
head_index=k.shape[2],
)

if past_kv_cache_entry is not None:
# Appends the new keys and values to the cached values, and automatically updates the cache
kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1)
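
A standalone sketch of the flatten → norm → unflatten pattern above (the `rms_norm` here is a stand-in for the library's `RMSNorm`, which takes a cfg and a width; per the constructor change, that width is `d_model` for OLMo 2 and `d_head * n_key_value_heads` for OLMoE's keys):

```python
import einops
import torch


def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Stand-in for transformer_lens.components.rms_norm.RMSNorm (no learned scale).
    return x / torch.sqrt((x**2).mean(-1, keepdim=True) + eps)


batch, pos, n_heads, d_head = 2, 3, 4, 8
q = torch.randn(batch, pos, n_heads, d_head)

# Heads are flattened so a single norm acts over the full (head_index * d_head)
# width, then the result is reshaped back to per-head form.
q_flat = einops.rearrange(q, "batch pos head_index d_head -> batch pos (head_index d_head)")
q_normed = einops.rearrange(
    rms_norm(q_flat),
    "batch pos (head_index d_head) -> batch pos head_index d_head",
    head_index=n_heads,
)
assert q_normed.shape == q.shape
```

Note that normalizing over the flattened width ties the RMS statistics across heads rather than normalizing each head separately.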
3 changes: 2 additions & 1 deletion transformer_lens/components/mlps/moe.py
@@ -88,7 +88,8 @@ def forward(
# both are [batch, pos, experts_per_token]
weights = self.hook_expert_weights(F.softmax(gate_logits, dim=1, dtype=torch.float))
weights, expert_indices = torch.topk(weights, self.experts_per_token, dim=-1)
weights /= weights.sum(dim=-1, keepdim=True)
if self.cfg.norm_topk_prob:
weights /= weights.sum(dim=-1, keepdim=True)
expert_indices = self.hook_expert_indices(expert_indices)
weights = weights.to(x.dtype)

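
A simplified sketch of the routing step with the new switch made explicit (softmax taken over the expert dimension; the hook points and dtype bookkeeping of the real forward are omitted):

```python
import torch
import torch.nn.functional as F


def route(gate_logits: torch.Tensor, experts_per_token: int, norm_topk_prob: bool):
    # Softmax over experts, then keep the top-k weights per token.
    weights = F.softmax(gate_logits, dim=-1, dtype=torch.float)
    weights, expert_indices = torch.topk(weights, experts_per_token, dim=-1)
    if norm_topk_prob:
        # Previous hard-coded behaviour: renormalize so the kept weights sum to 1.
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, expert_indices


gate_logits = torch.randn(2, 5, 8)  # [batch, pos, num_experts]
w_raw, idx = route(gate_logits, 2, norm_topk_prob=False)   # OLMoE-style: raw top-k weights
w_norm, _ = route(gate_logits, 2, norm_topk_prob=True)
print(w_raw.sum(-1)[0, 0].item(), w_norm.sum(-1)[0, 0].item())  # < 1.0 vs exactly 1.0
```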
48 changes: 32 additions & 16 deletions transformer_lens/components/transformer_block.py
@@ -153,33 +153,49 @@ def forward(
key_input = attn_in
value_input = attn_in

attn_out = (
# hook the residual stream states that are used to calculate the
# queries, keys and values, independently.
# Then take the layer norm of these inputs, and pass these to the attention module.
self.attn(
query_input=self.ln1(query_input)
+ (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
key_input=self.ln1(key_input)
+ (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
value_input=self.ln1(value_input),
past_kv_cache_entry=past_kv_cache_entry,
attention_mask=attention_mask,
)
) # [batch, pos, d_model]
if self.cfg.original_architecture == "Olmo2ForCausalLM":
attn_out = self.attn(
query_input=query_input,
key_input=key_input,
value_input=value_input,
past_kv_cache_entry=past_kv_cache_entry,
attention_mask=attention_mask,
)
else:
attn_out = (
# hook the residual stream states that are used to calculate the
# queries, keys and values, independently.
# Then take the layer norm of these inputs, and pass these to the attention module.
self.attn(
query_input=self.ln1(query_input)
+ (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
key_input=self.ln1(key_input)
+ (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
value_input=self.ln1(value_input),
past_kv_cache_entry=past_kv_cache_entry,
attention_mask=attention_mask,
)
) # [batch, pos, d_model]
if self.cfg.use_normalization_before_and_after:
# If we use LayerNorm both before and after, then apply the second LN after the layer
# and before the hook. We do it before the hook so hook_attn_out captures "that which
# is added to the residual stream"
attn_out = self.ln1_post(attn_out)
attn_out = self.hook_attn_out(attn_out)
if self.cfg.original_architecture == "Olmo2ForCausalLM":
attn_out = self.ln1(attn_out)

if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp:
resid_mid = self.hook_resid_mid(resid_pre + attn_out) # [batch, pos, d_model]
mlp_in = (
resid_mid if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_mid.clone())
)
normalized_resid_mid = self.ln2(mlp_in)
mlp_out = self.apply_mlp(normalized_resid_mid)
if self.cfg.original_architecture == "Olmo2ForCausalLM":
mlp_out = self.apply_mlp(mlp_in)
mlp_out = self.ln2(mlp_out)
else:
normalized_resid_mid = self.ln2(mlp_in)
mlp_out = self.apply_mlp(normalized_resid_mid)
resid_post = self.hook_resid_post(resid_mid + mlp_out) # [batch, pos, d_model]
elif self.cfg.parallel_attn_mlp:
# Dumb thing done by GPT-J, both MLP and Attn read from resid_pre and write to resid_post, no resid_mid used.
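
A dataflow-level summary of what the `Olmo2ForCausalLM` branches above change relative to the default pre-norm path (hook points are omitted and `attn`/`mlp`/`ln` are toy stand-ins, so this is a sketch rather than the real block):

```python
import torch


def pre_norm_block(resid_pre, attn, mlp, ln1, ln2):
    resid_mid = resid_pre + attn(ln1(resid_pre))   # inputs are normalized
    return resid_mid + mlp(ln2(resid_mid))


def olmo2_block(resid_pre, attn, mlp, ln1, ln2):
    resid_mid = resid_pre + ln1(attn(resid_pre))   # outputs are normalized (post-norm)
    return resid_mid + ln2(mlp(resid_mid))


# Toy demonstration with linear stand-ins:
d = 4
x = torch.randn(2, 3, d)
ln = torch.nn.LayerNorm(d)
attn = torch.nn.Linear(d, d)
mlp = torch.nn.Linear(d, d)
print(pre_norm_block(x, attn, mlp, ln, ln).shape)  # torch.Size([2, 3, 4])
print(olmo2_block(x, attn, mlp, ln, ln).shape)
```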
102 changes: 102 additions & 0 deletions transformer_lens/loading_from_pretrained.py
@@ -16,6 +16,7 @@
AutoConfig,
AutoModelForCausalLM,
BertForPreTraining,
PretrainedConfig,
T5ForConditionalGeneration,
)

@@ -35,6 +36,9 @@
convert_neel_solu_old_weights,
convert_neo_weights,
convert_neox_weights,
convert_olmo2_weights,
convert_olmo_weights,
convert_olmoe_weights,
convert_opt_weights,
convert_phi3_weights,
convert_phi_weights,
@@ -251,6 +255,20 @@
"google-t5/t5-base",
"google-t5/t5-large",
"ai-forever/mGPT",
"allenai/OLMo-1B-hf",
"allenai/OLMo-7B-hf",
"allenai/OLMo-7B-0724-hf",
"allenai/OLMo-7B-0724-SFT-hf",
"allenai/OLMo-7B-0724-Instruct-hf",
"allenai/OLMo-7B-0424-hf",
"allenai/OLMo-7B-Twin-2T-hf",
"allenai/OLMo-1B-0724-hf",
"allenai/OLMo-7B-Instruct-hf",
"allenai/OLMo-7B-SFT-hf",
"allenai/OLMoE-1B-7B-0924",
"allenai/OLMoE-1B-7B-0924-SFT",
"allenai/OLMoE-1B-7B-0924-Instruct",
"allenai/OLMo-2-1124-7B"
]
"""Official model names for models on HuggingFace."""

@@ -1472,6 +1490,84 @@ def convert_hf_model_config(model_name: str, **kwargs):
"final_rms": True,
"use_normalization_before_and_after": True,
}
elif official_model_name.startswith("allenai/OLMo-1B") and official_model_name.endswith("hf"):
cfg_dict = {
"d_model": 2048,
"d_head": 128,
"n_heads": 16,
"d_mlp": 8192,
"n_layers": 16,
"n_ctx": 2048,
"eps": 1e-05,
"d_vocab": 50304,
"act_fn": "silu",
"initializer_range": 0.02,
"normalization_type": "LN",
"rotary_base": 10000.0,
"attn_types": ["global"] * 16,
"positional_embedding_type": "rotary",
"gated_mlp": True,
}
elif official_model_name.startswith("allenai/OLMo-7B") and official_model_name.endswith("hf"):
cfg_dict = {
"d_model": 4096,
"d_head": 128,
"n_heads": 32,
"d_mlp": 11008,
"n_layers": 32,
"n_ctx": 2048,
"eps": 1e-05,
"d_vocab": 50304,
"act_fn": "silu",
"initializer_range": 0.02,
"normalization_type": "LN",
"rotary_base": 10000.0,
"attn_types": ["global"] * 32,
"positional_embedding_type": "rotary",
"gated_mlp": True,
}
elif official_model_name == "allenai/OLMo-2-1124-7B":
cfg_dict = {
"d_model": 4096,
"d_head": 128,
"n_heads": 32,
"d_mlp": 11008,
"n_layers": 32,
"n_ctx": 4096,
"eps": 1e-06,
"d_vocab": 100352,
"act_fn": "silu",
"initializer_range": 0.02,
"normalization_type": "RMSPre",
"rotary_base": 500000.0,
"attn_types": ["global"] * 32,
"positional_embedding_type": "rotary",
"gated_mlp": True,
}
elif architecture == "OlmoeForCausalLM":
cfg_dict = {
"d_model": hf_config.hidden_size,
"d_head": hf_config.hidden_size // hf_config.num_attention_heads,
"n_heads": hf_config.num_attention_heads,
"d_mlp": hf_config.intermediate_size,
"n_layers": hf_config.num_hidden_layers,
"n_ctx": hf_config.max_position_embeddings,
"eps": hf_config.rms_norm_eps,
"d_vocab": hf_config.vocab_size,
"act_fn": hf_config.hidden_act,
"num_experts": hf_config.num_experts,
"experts_per_token": hf_config.num_experts_per_tok,
"norm_topk_prob": hf_config.norm_topk_prob,
"n_key_value_heads": hf_config.num_key_value_heads,
"rotary_base": hf_config.rope_theta,
"tie_word_embeddings": hf_config.tie_word_embeddings,
"initializer_range": hf_config.initializer_range,
"positional_embedding_type": "rotary",
"rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
"final_rms": True,
"gated_mlp": True,
"normalization_type": "LN",
}
elif architecture == "T5ForConditionalGeneration":
cfg_dict = {
"d_model": hf_config.d_model,
@@ -1891,6 +1987,12 @@ def get_pretrained_state_dict(
state_dict = convert_gemma_weights(hf_model, cfg)
elif cfg.original_architecture == "Gemma2ForCausalLM":
state_dict = convert_gemma_weights(hf_model, cfg)
elif cfg.original_architecture == "OlmoForCausalLM":
state_dict = convert_olmo_weights(hf_model, cfg)
elif cfg.original_architecture == "Olmo2ForCausalLM":
state_dict = convert_olmo2_weights(hf_model, cfg)
elif cfg.original_architecture == "OlmoeForCausalLM":
state_dict = convert_olmoe_weights(hf_model, cfg)
else:
raise ValueError(
f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
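
With the configs and converters above registered, OLMo checkpoints should load through the usual entry point. A usage sketch (the model name comes from the official list added in this PR; the expected values follow the hard-coded OLMo-1B config above, and running this downloads the HF weights):

```python
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("allenai/OLMo-1B-hf")
print(model.cfg.original_architecture)        # expected: "OlmoForCausalLM"
print(model.cfg.n_layers, model.cfg.d_model)  # expected: 16 2048

logits, cache = model.run_with_cache("The capital of France is")
print(logits.shape)                           # [batch, pos, d_vocab]
```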
3 changes: 3 additions & 0 deletions transformer_lens/pretrained/weight_conversions/__init__.py
@@ -18,3 +18,6 @@
from .nanogpt import convert_nanogpt_weights
from .t5 import convert_t5_weights
from .neel_solu_old import convert_neel_solu_old_weights
from .olmo import convert_olmo_weights
from .olmoe import convert_olmoe_weights
from .olmo2 import convert_olmo2_weights