diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py
index 9cf16f578..d98040d88 100644
--- a/transformer_lens/HookedTransformerConfig.py
+++ b/transformer_lens/HookedTransformerConfig.py
@@ -264,6 +264,8 @@ class HookedTransformerConfig:
     NTK_by_parts_high_freq_factor: float = 4.0
     NTK_by_parts_factor: float = 8.0
     NTK_original_ctx_len: int = 8192
+    n_query_heads: Optional[List[int]] = None
+    d_mlps: Optional[List[int]] = None
 
     def __post_init__(self):
         if self.n_heads == -1:
diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index 8bfb6315d..7e281ddbf 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -37,6 +37,7 @@
     convert_neel_solu_old_weights,
     convert_neo_weights,
     convert_neox_weights,
+    convert_openelm_weights,
     convert_opt_weights,
     convert_phi3_weights,
     convert_phi_weights,
@@ -263,6 +264,14 @@
    "google-t5/t5-base",
    "google-t5/t5-large",
    "ai-forever/mGPT",
+    "apple/OpenELM-270M",
+    "apple/OpenELM-450M",
+    "apple/OpenELM-1_1B",
+    "apple/OpenELM-3B",
+    "apple/OpenELM-270M-Instruct",
+    "apple/OpenELM-450M-Instruct",
+    "apple/OpenELM-1_1B-Instruct",
+    "apple/OpenELM-3B-Instruct",
 ]
 """Official model names for models on HuggingFace."""
@@ -1436,6 +1445,53 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
             "parallel_attn_mlp": False,
             "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
         }
+    elif architecture == "OpenELMForCausalLM":
+
+        def make_divisible(
+            v: Union[float, int],
+            divisor: Optional[int] = 8,
+            min_value: Optional[Union[float, int]] = None,
+        ) -> Union[float, int]:
+            """
+            Taken from the original TensorFlow repo. It ensures that all layers
+            have a channel number that is divisible by the divisor. See:
+            https://github.com/tensorflow/models/blob/2cfc99eff5e5eb729c6793d2f3d03aa1c9be2b15/research/slim/nets/mobilenet/mobilenet.py#L62
+
+            Args:
+                v: input value
+                divisor: defaults to 8
+                min_value: minimum value; defaults to the divisor
+
+            Returns:
+                new_v: new divisible value
+            """
+            if min_value is None:
+                min_value = divisor
+            new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+            # Make sure that rounding down does not go down by more than 10%.
+            if new_v < 0.9 * v:
+                new_v += divisor
+            return new_v
+
+        cfg_dict = {
+            "d_model": hf_config.model_dim,
+            "d_head": hf_config.head_dim,
+            # OpenELM varies the number of query heads per layer; the per-layer
+            # counts live in n_query_heads below, so use the maximum here.
+            "n_heads": max(hf_config.num_query_heads),
+            "n_layers": hf_config.num_transformer_layers,
+            "n_ctx": hf_config.max_context_length,
+            # RMSNorm eps; OpenELM's RMSNorm defaults to 1e-6 and the HF config
+            # does not expose it.
+            "eps": 1e-6,
+            "d_vocab": hf_config.vocab_size,
+            "act_fn": "silu",
+            "initializer_range": hf_config.initializer_range,
+            "normalization_type": "RMS",
+            "positional_embedding_type": "rotary",
+            "trust_remote_code": True,
+            "n_key_value_heads": hf_config.num_kv_heads,
+            "n_query_heads": hf_config.num_query_heads,
+            # Per-layer FFN width. proj_1 stacks the gate and up projections,
+            # so its output dimension is 2 * d_mlps[l].
+            "d_mlps": [
+                int(make_divisible(val * hf_config.model_dim, hf_config.ffn_dim_divisor))
+                for val in hf_config.ffn_multipliers
+            ],
+        }
     elif official_model_name.startswith("google/gemma-2b"):
         # Architecture for Gemma 2b and Gemma 2b Instruct models
@@ -1587,7 +1643,10 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
     # All of these models use LayerNorm
     cfg_dict["original_architecture"] = architecture
     # The name such that AutoTokenizer.from_pretrained works
-    cfg_dict["tokenizer_name"] = official_model_name
+    if architecture == "OpenELMForCausalLM":
+        # OpenELM does not ship its own tokenizer; it uses the Llama 2 tokenizer.
+        cfg_dict["tokenizer_name"] = "meta-llama/Llama-2-7b-hf"
+    else:
+        cfg_dict["tokenizer_name"] = official_model_name
     if kwargs.get("trust_remote_code", False):
         cfg_dict["trust_remote_code"] = True
     return cfg_dict
@@ -1986,6 +2045,8 @@ def get_pretrained_state_dict(
         state_dict = convert_gemma_weights(hf_model, cfg)
     elif cfg.original_architecture == "Gemma2ForCausalLM":
         state_dict = convert_gemma_weights(hf_model, cfg)
+    elif cfg.original_architecture == "OpenELMForCausalLM":
+        state_dict = convert_openelm_weights(hf_model, cfg)
     else:
         raise ValueError(
             f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
diff --git a/transformer_lens/pretrained/weight_conversions/__init__.py b/transformer_lens/pretrained/weight_conversions/__init__.py
index c5ea9581b..3551ac474 100644
--- a/transformer_lens/pretrained/weight_conversions/__init__.py
+++ b/transformer_lens/pretrained/weight_conversions/__init__.py
@@ -19,3 +19,4 @@
 from .nanogpt import convert_nanogpt_weights
 from .t5 import convert_t5_weights
 from .neel_solu_old import convert_neel_solu_old_weights
+from .openelm import convert_openelm_weights
diff --git a/transformer_lens/pretrained/weight_conversions/openelm.py b/transformer_lens/pretrained/weight_conversions/openelm.py
new file mode 100644
index 000000000..129e9fad4
--- /dev/null
+++ b/transformer_lens/pretrained/weight_conversions/openelm.py
@@ -0,0 +1,68 @@
+import einops
+import torch
+
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
+def convert_openelm_weights(openelm, cfg: HookedTransformerConfig):
+    state_dict = {}
+
+    assert cfg.d_mlps is not None
+    assert cfg.n_query_heads is not None
+    assert cfg.n_key_value_heads is not None
+
+    state_dict["embed.W_E"] = openelm.transformer.token_embeddings.weight
+
+    for l in range(cfg.n_layers):
+        # qkv_proj stores the query, key and value projections consecutively
+        # along its output dimension.
+        WQ = openelm.transformer.layers[l].attn.qkv_proj.weight[
+            : (cfg.n_query_heads[l] * cfg.d_head)
+        ]
+        WK = openelm.transformer.layers[l].attn.qkv_proj.weight[
+            (cfg.n_query_heads[l] * cfg.d_head) : (
+                (cfg.n_query_heads[l] + cfg.n_key_value_heads[l]) * cfg.d_head
+            )
+        ]
+        WV = openelm.transformer.layers[l].attn.qkv_proj.weight[
+            -cfg.n_key_value_heads[l] * cfg.d_head :
+        ]
+
+        WQ = einops.rearrange(WQ, "(n h) m -> n m h", n=cfg.n_query_heads[l])
+        WK = einops.rearrange(WK, "(n h) m -> n m h", n=cfg.n_key_value_heads[l])
+        WV = einops.rearrange(WV, "(n h) m -> n m h", n=cfg.n_key_value_heads[l])
+
+        # If the checkpoint sets normalize_qk_projections, OpenELM also applies
+        # RMSNorm to the per-head queries and keys (q_norm / k_norm); those
+        # weights are not converted here.
+        state_dict[f"blocks.{l}.attn.W_Q"] = WQ
+        state_dict[f"blocks.{l}.attn._W_K"] = WK
+        state_dict[f"blocks.{l}.attn._W_V"] = WV
+
state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) + state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros( + cfg.n_key_value_heads[l], cfg.d_head, dtype=cfg.dtype + ) + state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros( + cfg.n_key_value_heads[l], cfg.d_head, dtype=cfg.dtype + ) + + WO = openelm.transformer.layers[l].attn.out_proj.weight + WO = einops.rearrange(WO, "m (n h)->n h m", n=cfg.n_query_heads[l]) + state_dict[f"blocks.{l}.attn.W_O"] = WO + + state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + state_dict[f"blocks.{l}.ln2.w"] = openelm.transformer.layers[l].attn_norm.weight + + state_dict[f"blocks.{l}.mlp.W_in"] = ( + openelm.transformer.layers[l].ffn.proj_1.weight[: cfg.d_mlps[l], :].T + ) + state_dict[f"blocks.{l}.mlp.W_gate"] = ( + openelm.transformer.layers[l].ffn.proj_1.weight[cfg.d_mlps[l] :, :].T + ) + state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlps[l], dtype=cfg.dtype) + state_dict[f"blocks.{l}.mlp.W_out"] = openelm.transformer.layers[l].ffn.proj_2.weight.T + state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros( + openelm.transformer.layers[l].ffn.proj_2.weight.shape[0], dtype=cfg.dtype + ) + + state_dict[f"blocks.{l}.mlp.ln3.w"] = openelm.transformer.layers[l].ffn_norm.weight + + state_dict["ln_final.w"] = openelm.transformer.norm.weight + + state_dict["unembed.W_U"] = openelm.transformer.token_embeddings.weight.T + state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)