Initial support for OpenELM models #868

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open: wants to merge 4 commits into base branch dev
4 changes: 3 additions & 1 deletion transformer_lens/HookedTransformerConfig.py
@@ -243,7 +243,7 @@ class HookedTransformerConfig:
default_prepend_bos: bool = True
dtype: torch.dtype = torch.float32
tokenizer_prepends_bos: Optional[bool] = None
n_key_value_heads: Optional[int] = None
n_key_value_heads: Optional[List[int]] = None
post_embedding_ln: bool = False
rotary_base: int = 10000
trust_remote_code: bool = False
@@ -262,6 +262,8 @@ class HookedTransformerConfig:
NTK_by_parts_low_freq_factor: float = 1.0
NTK_by_parts_high_freq_factor: float = 4.0
NTK_by_parts_factor: float = 8.0
n_query_heads: Optional[List[int]] = None
d_mlps: Optional[List[int]] = None

def __post_init__(self):
if self.n_heads == -1:
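For OpenELM the two new fields (together with n_key_value_heads, now a list) hold one entry per layer, reflecting the model's layer-wise scaling of head counts and FFN widths. A minimal sketch of the kind of values the loader ends up with, using illustrative numbers rather than values from a real checkpoint:

```python
# Illustrative values only; the real lists come from the HuggingFace OpenELM config.
openelm_style_cfg = {
    "n_layers": 3,
    "d_model": 1280,
    "d_head": 64,
    "n_key_value_heads": [3, 4, 5],   # KV heads per layer, growing with depth
    "n_query_heads": [12, 16, 20],    # query heads per layer
    "d_mlps": [768, 1024, 1536],      # FFN intermediate width per layer
}
```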
59 changes: 58 additions & 1 deletion transformer_lens/loading_from_pretrained.py
@@ -41,6 +41,7 @@
convert_qwen2_weights,
convert_qwen_weights,
convert_t5_weights,
convert_openelm_weights,
)

OFFICIAL_MODEL_NAMES = [
@@ -250,6 +251,14 @@
"google-t5/t5-base",
"google-t5/t5-large",
"ai-forever/mGPT",
"apple/OpenELM-270M",
"apple/OpenELM-450M",
"apple/OpenELM-1_1B",
"apple/OpenELM-3B",
"apple/OpenELM-270M-Instruct",
"apple/OpenELM-450M-Instruct",
"apple/OpenELM-1_1B-Instruct",
"apple/OpenELM-3B-Instruct",
]
"""Official model names for models on HuggingFace."""

@@ -1337,6 +1346,49 @@ def convert_hf_model_config(model_name: str, **kwargs):
"parallel_attn_mlp": False,
"rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
}
elif architecture == "OpenELMForCausalLM":
def make_divisible(
v: Union[float, int],
divisor: Optional[int] = 8,
min_value: Optional[Union[float, int]] = None,
) -> Union[float, int]:
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by the divisor
It can be seen at:
https://github.com/tensorflow/models/blob/2cfc99eff5e5eb729c6793d2f3d03aa1c9be2b15/research/slim/nets/mobilenet/mobilenet.py#L62
Args:
v: input value
divisor: default to 8
min_value: minimum divisor value
Returns:
new_v: new divisible value
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
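# Example (illustrative): make_divisible(30, 8) rounds 30 to the nearest multiple of 8, i.e. 32.
# The 10% guard matters for cases like make_divisible(300, 256): rounding to the nearest
# multiple gives 256, which loses more than 10% of 300, so the result is bumped up to 512.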

cfg_dict = {
"d_model": hf_config.model_dim,
"d_head": hf_config.head_dim,
"n_heads": 64,# is this variable too? ,
"n_layers": hf_config.num_transformer_layers,
"n_ctx": hf_config.max_context_length,
"eps": 23, # what is going on here??
"d_vocab": hf_config.vocab_size,
"act_fn": "silu",
"initializer_range": hf_config.initializer_range,
"normalization_type": "RMS",
"positional_embedding_type": "rotary",
"trust_remote_code": True,
"n_key_value_heads": hf_config.num_kv_heads,
"n_query_heads": hf_config.num_query_heads,
"d_mlps": [(2 * int(make_divisible(val * hf_config.model_dim, hf_config.ffn_dim_divisor))) for val in hf_config.ffn_multipliers],
}

elif official_model_name.startswith("google/gemma-2b"):
# Architecture for Gemma 2b and Gemma 2b Instruct models
@@ -1488,7 +1540,10 @@ def convert_hf_model_config(model_name: str, **kwargs):
# All of these models use LayerNorm
cfg_dict["original_architecture"] = architecture
# The name such that AutoTokenizer.from_pretrained works
cfg_dict["tokenizer_name"] = official_model_name
if architecture == "OpenELMForCausalLM":
cfg_dict["tokenizer_name"] = "meta-llama/Llama-2-7b-hf"
else:
cfg_dict["tokenizer_name"] = official_model_name
if kwargs.get("trust_remote_code", False):
cfg_dict["trust_remote_code"] = True
return cfg_dict
@@ -1882,6 +1937,8 @@ def get_pretrained_state_dict(
state_dict = convert_gemma_weights(hf_model, cfg)
elif cfg.original_architecture == "Gemma2ForCausalLM":
state_dict = convert_gemma_weights(hf_model, cfg)
elif cfg.original_architecture == "OpenELMForCausalLM":
state_dict = convert_openelm_weights(hf_model, cfg)
else:
raise ValueError(
f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
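With the config translation and the weight converter registered above, loading an OpenELM checkpoint should go through the usual from_pretrained path. A hedged usage sketch (the kwargs are assumptions based on this diff: trust_remote_code because OpenELM ships its modeling code in the HF repo, and the tokenizer resolving to meta-llama/Llama-2-7b-hf per the override above):

```python
from transformer_lens import HookedTransformer

# Sketch only: nothing below is part of this diff.
model = HookedTransformer.from_pretrained(
    "apple/OpenELM-270M",
    trust_remote_code=True,  # OpenELM's architecture code lives in the model repo
)
loss = model("OpenELM comes in 270M to 3B parameter sizes.", return_type="loss")
print(loss)
```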
1 change: 1 addition & 0 deletions transformer_lens/pretrained/weight_conversions/__init__.py
@@ -18,3 +18,4 @@
from .nanogpt import convert_nanogpt_weights
from .t5 import convert_t5_weights
from .neel_solu_old import convert_neel_solu_old_weights
from .openelm import convert_openelm_weights
53 changes: 53 additions & 0 deletions transformer_lens/pretrained/weight_conversions/openelm.py
@@ -0,0 +1,53 @@
import torch
import einops

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

def convert_openelm_weights(openelm, cfg: HookedTransformerConfig):
state_dict = {}

assert cfg.d_mlps is not None
assert cfg.n_key_value_heads is not None
assert cfg.n_query_heads is not None

state_dict["embed.W_E"] = openelm.transformer.token_embeddings.weight

for l in range(cfg.n_layers):
# qkv_proj stacks the projections along its output dimension: the first n_query_heads[l] * d_head
# rows are Q, followed by n_key_value_heads[l] * d_head rows each for K and V.
WQ = openelm.transformer.layers[l].attn.qkv_proj.weight[:(cfg.n_query_heads[l] * cfg.d_head)]
WK = openelm.transformer.layers[l].attn.qkv_proj.weight[(cfg.n_query_heads[l] * cfg.d_head) : ((cfg.n_query_heads[l] + cfg.n_key_value_heads[l]) * cfg.d_head)]
WV = openelm.transformer.layers[l].attn.qkv_proj.weight[-cfg.n_key_value_heads[l] * cfg.d_head:]

# Reshape from [(n_heads d_head), d_model] to [n_heads, d_model, d_head].
WQ = einops.rearrange(WQ, "(n h) m->n m h", n=cfg.n_query_heads[l])
WK = einops.rearrange(WK, "(n h) m->n m h", n=cfg.n_key_value_heads[l])
WV = einops.rearrange(WV, "(n h) m->n m h", n=cfg.n_key_value_heads[l])

state_dict[f"blocks.{l}.attn.W_Q"] = WQ
state_dict[f"blocks.{l}.attn._W_K"] = WK
state_dict[f"blocks.{l}.attn._W_V"] = WV

state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros(
cfg.n_key_value_heads[l], cfg.d_head, dtype=cfg.dtype
)
state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros(
cfg.n_key_value_heads[l], cfg.d_head, dtype=cfg.dtype
)

WO = openelm.transformer.layers[l].attn.out_proj.weight
WO = einops.rearrange(WO, "m (n h)->n h m", n=cfg.n_query_heads[l])
state_dict[f"blocks.{l}.attn.W_O"] = WO

state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
state_dict[f"blocks.{l}.ln2.w"] = openelm.transformer.layers[l].attn_norm.weight

state_dict[f"blocks.{l}.mlp.W_in"] = openelm.transformer.layers[l].ffn.proj_1.weight[:cfg.d_mlps[l], :].T
state_dict[f"blocks.{l}.mlp.W_gate"] = openelm.transformer.layers[l].ffn.proj_1.weight[cfg.d_mlps[l]:, :].T
state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlps[l], dtype=cfg.dtype)
state_dict[f"blocks.{l}.mlp.W_out"] = openelm.transformer.layers[l].ffn.proj_2.weight.T
state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(openelm.transformer.layers[l].ffn.proj_2.weight.shape[0], dtype=cfg.dtype)

state_dict[f"blocks.{l}.mlp.ln3.w"] = openelm.transformer.layers[l].ffn_norm.weight

state_dict["ln_final.w"] = openelm.transformer.norm.weight

state_dict["unembed.W_U"] = openelm.transformer.token_embeddings.weight.T
state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)