3,204 changes: 1,665 additions & 1,539 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -34,7 +34,7 @@
    {platform="linux", version=">=1.10"}, # We can use any torch version on Linux (e.g colab)
]
tqdm=">=4.64.1"
-transformers=">=4.37.2"
+transformers={ git = "https://github.com/huggingface/transformers.git" }
typing-extensions="*"
wandb=">=0.13.5"

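The dependency moves from a released transformers version to the git main branch, presumably because `OlmoeForCausalLM` had not yet shipped in a tagged release. A quick, illustrative check (not part of this PR) that the installed build actually includes it:

```python
# Illustrative only: verify the installed transformers build knows about OLMoE.
# The import raises ImportError on releases that predate the architecture.
from transformers import OlmoeForCausalLM

print(OlmoeForCausalLM.__module__)  # e.g. transformers.models.olmoe.modeling_olmoe
```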
2 changes: 1 addition & 1 deletion transformer_lens/HookedTransformer.py
@@ -145,7 +145,7 @@ def __init__(
            self.set_tokenizer(
                AutoTokenizer.from_pretrained(
                    self.cfg.tokenizer_name,
-                    add_bos_token=True,
+                    # add_bos_token=True,
                    trust_remote_code=self.cfg.trust_remote_code,
                    use_fast=use_fast,
                    token=huggingface_token,
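With `add_bos_token=True` commented out, the tokenizer's own default (from its `tokenizer_config.json`) now decides whether a BOS token is prepended. A small, illustrative way to see what a given checkpoint does by default; the model name is just one of the checkpoints registered further down:

```python
# Illustrative only: check whether a tokenizer prepends BOS when the kwarg is not forced.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")
ids = tok("Hello")["input_ids"]
print(ids, tok.bos_token_id)  # compare the first id against the BOS id, if one is defined
```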
3 changes: 3 additions & 0 deletions transformer_lens/components/mlps/moe.py
@@ -63,6 +63,7 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):

        self.num_experts: int = self.cfg.num_experts
        self.experts_per_token: int = self.cfg.experts_per_token
+        # self.norm_topk_prob: bool = self.cfg.norm_topk_prob

        assert (
            self.cfg.experts_per_token <= self.cfg.num_experts
@@ -88,6 +89,8 @@ def forward(
        # both are [batch, pos, experts_per_token]
        weights = self.hook_expert_weights(F.softmax(gate_logits, dim=1, dtype=torch.float))
        weights, expert_indices = torch.topk(weights, self.experts_per_token, dim=-1)
+        # if self.norm_topk_prob:
+        # weights /= weights.sum(dim=-1, keepdim=True)
        weights /= weights.sum(dim=-1, keepdim=True)
        expert_indices = self.hook_expert_indices(expert_indices)
        weights = weights.to(x.dtype)
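For context, the Hugging Face OLMoE config exposes `norm_topk_prob`, which controls whether the selected top-k router weights are renormalized to sum to 1; the commented-out lines above sketch how that would become conditional. A standalone sketch of the two behaviours, using plain tensors and no hooks (assumed shapes, not the PR's code):

```python
# Minimal sketch of top-k MoE routing with optional renormalization of the
# selected expert weights, gated by a norm_topk_prob flag.
import torch
import torch.nn.functional as F


def route(gate_logits: torch.Tensor, experts_per_token: int, norm_topk_prob: bool):
    weights = F.softmax(gate_logits, dim=-1, dtype=torch.float)  # [batch, pos, num_experts]
    weights, expert_indices = torch.topk(weights, experts_per_token, dim=-1)
    if norm_topk_prob:
        weights = weights / weights.sum(dim=-1, keepdim=True)  # re-sum to 1 over the top-k
    return weights, expert_indices


gate_logits = torch.randn(1, 3, 64)  # batch=1, pos=3, 64 experts (OLMoE-like sizes)
raw, _ = route(gate_logits, 8, norm_topk_prob=False)
normed, _ = route(gate_logits, 8, norm_topk_prob=True)
print(raw.sum(-1)[0, 0].item(), normed.sum(-1)[0, 0].item())  # < 1.0 vs. ~1.0
```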
34 changes: 34 additions & 0 deletions transformer_lens/loading_from_pretrained.py
@@ -35,6 +35,7 @@
    convert_neel_solu_old_weights,
    convert_neo_weights,
    convert_neox_weights,
+    convert_olmoe_weights,
    convert_opt_weights,
    convert_phi3_weights,
    convert_phi_weights,
@@ -225,6 +226,9 @@
    "google-t5/t5-base",
    "google-t5/t5-large",
    "ai-forever/mGPT",
+    "allenai/OLMoE-1B-7B-0924",
+    "allenai/OLMoE-1B-7B-0924-SFT",
+    "allenai/OLMoE-1B-7B-0924-Instruct",
]
"""Official model names for models on HuggingFace."""

@@ -1329,6 +1333,34 @@ def convert_hf_model_config(model_name: str, **kwargs):
            "use_attn_scale": False,
            "tie_word_embeddings": hf_config.tie_word_embeddings,
        }
+    elif architecture == "OlmoeForCausalLM":
+        cfg_dict = {
+            "d_model": hf_config.hidden_size,
+            "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
+            "n_heads": hf_config.num_attention_heads,
+            "d_mlp": hf_config.intermediate_size,
+            "n_layers": hf_config.num_hidden_layers,
+            "n_ctx": hf_config.max_position_embeddings,
+            "eps": hf_config.rms_norm_eps,
+            "d_vocab": hf_config.vocab_size,
+            "act_fn": hf_config.hidden_act,
+            "num_experts": hf_config.num_experts,
+            "experts_per_token": hf_config.num_experts_per_tok,
+            # TODO: implement!
+            # "router_aux_loss_coef": hf_config.router_aux_loss_coef,
+            # "router_z_loss_coef": hf_config.router_z_loss_coef,
+            # "norm_topk_prob": hf_config.norm_topk_prob,
+            # end
+            "n_key_value_heads": hf_config.num_key_value_heads,
+            "rotary_base": hf_config.rope_theta,
+            "tie_word_embeddings": hf_config.tie_word_embeddings,
+            "initializer_range": hf_config.initializer_range,
+            "positional_embedding_type": "rotary",
+            "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
+            "final_rms": True,
+            "gated_mlp": True,
+            "normalization_type": "RMS",
+        }
    else:
        raise NotImplementedError(f"{architecture} is not currently supported.")
    # All of these models use LayerNorm
@@ -1714,6 +1746,8 @@ def get_pretrained_state_dict(
        state_dict = convert_gemma_weights(hf_model, cfg)
    elif cfg.original_architecture == "Gemma2ForCausalLM":
        state_dict = convert_gemma_weights(hf_model, cfg)
+    elif cfg.original_architecture == "OlmoeForCausalLM":
+        state_dict = convert_olmoe_weights(hf_model, cfg)
    else:
        raise ValueError(
            f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
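Together with the converter below, these entries are what route the new checkpoints through the standard loading path. A hedged usage sketch (assumes this branch is installed alongside a transformers build with OLMoE support; the checkpoint is several GB):

```python
# Illustrative usage, not part of the diff: load a newly registered OLMoE
# checkpoint and confirm the MoE-specific config fields came through.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("allenai/OLMoE-1B-7B-0924")
print(model.cfg.num_experts, model.cfg.experts_per_token)  # 64 experts, 8 active per token
```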
1 change: 1 addition & 0 deletions transformer_lens/pretrained/weight_conversions/__init__.py
@@ -18,3 +18,4 @@
from .nanogpt import convert_nanogpt_weights
from .t5 import convert_t5_weights
from .neel_solu_old import convert_neel_solu_old_weights
+from .olmoe import convert_olmoe_weights
64 changes: 64 additions & 0 deletions transformer_lens/pretrained/weight_conversions/olmoe.py
@@ -0,0 +1,64 @@
import einops
import torch

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


def convert_olmoe_weights(olmoe, cfg: HookedTransformerConfig):
    state_dict = {}

    assert cfg.n_key_value_heads is not None
    assert cfg.d_mlp is not None
    assert cfg.num_experts is not None

    state_dict["embed.W_E"] = olmoe.model.embed_tokens.weight

    for l in range(cfg.n_layers):
        olmoe_layer = olmoe.model.layers[l]
        state_dict[f"blocks.{l}.ln1.w"] = olmoe_layer.input_layernorm.weight

        W_Q = olmoe.model.layers[l].self_attn.q_proj.weight
        W_K = olmoe.model.layers[l].self_attn.k_proj.weight
        W_V = olmoe.model.layers[l].self_attn.v_proj.weight
        W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
        W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads)
        W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads)
        state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
        state_dict[f"blocks.{l}.attn._W_K"] = W_K
        state_dict[f"blocks.{l}.attn._W_V"] = W_V

        state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
        state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros(
            cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
        )
        state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros(
            cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
        )

        W_O = olmoe_layer.self_attn.o_proj.weight
        W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
        state_dict[f"blocks.{l}.attn.W_O"] = W_O

        state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)

        state_dict[f"blocks.{l}.ln2.w"] = olmoe_layer.post_attention_layernorm.weight

        state_dict[f"blocks.{l}.mlp.W_gate.weight"] = olmoe_layer.mlp.gate.weight

        for e in range(cfg.num_experts):
            state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.weight"] = olmoe_layer.mlp.experts[
                e
            ].up_proj.weight
            state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.weight"] = olmoe_layer.mlp.experts[
                e
            ].gate_proj.weight
            state_dict[f"blocks.{l}.mlp.experts.{e}.W_out.weight"] = olmoe_layer.mlp.experts[
                e
            ].down_proj.weight

    state_dict["ln_final.w"] = olmoe.model.norm.weight

    state_dict["unembed.W_U"] = olmoe.lm_head.weight.T
    state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)

    return state_dict
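A rough way to exercise the converter on its own, as a sanity check on shapes (illustrative only; it assumes this branch is installed, re-derives the config via `convert_hf_model_config`, and loads the full HF checkpoint into memory):

```python
# Illustrative sanity check, not part of the PR: run the converter directly and
# confirm a couple of converted tensors have the shapes HookedTransformer expects.
import torch
from transformers import AutoModelForCausalLM

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.loading_from_pretrained import convert_hf_model_config
from transformer_lens.pretrained.weight_conversions import convert_olmoe_weights

name = "allenai/OLMoE-1B-7B-0924"
hf_model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16)
cfg = HookedTransformerConfig.from_dict(convert_hf_model_config(name))

state_dict = convert_olmoe_weights(hf_model, cfg)
assert state_dict["blocks.0.attn.W_Q"].shape == (cfg.n_heads, cfg.d_model, cfg.d_head)
assert state_dict["blocks.0.attn._W_K"].shape == (cfg.n_key_value_heads, cfg.d_model, cfg.d_head)
print(f"{len(state_dict)} tensors converted")
```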
2 changes: 1 addition & 1 deletion transformer_lens/utils.py
@@ -1171,7 +1171,7 @@ def get_tokenizer_with_bos(tokenizer):
    huggingface_token = os.environ.get("HF_TOKEN", None)
    tokenizer_with_bos = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
-        add_bos_token=True,
+        # add_bos_token=True,
        token=huggingface_token,
        **init_kwargs,
    )