# -------------------------------------------------------------------------
# Copyright (C) [2026] Advanced Micro Devices, Inc. All rights reserved.
# Portions of this file consist of AI-generated content.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import torch.nn as nn

from .base import Model


class InternLM2Model(Model):
    def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
        super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
        # The genai_config is exported with type "internlm2" (the C++ model_type.h already lists "internlm2" as an LLM)
        self.model_type = "InternLM2ForCausalLM"

    def load_weights(self, input_path):
        """
        Load the InternLM2 model and adapt attribute names to match the base class's expectations.
        InternLM2 uses:
          - attention_norm instead of input_layernorm
          - ffn_norm instead of post_attention_layernorm
          - feed_forward instead of mlp
          - wqkv (combined QKV) instead of separate q_proj, k_proj, v_proj
          - wo instead of o_proj
        """
        # Load the model using the parent class method
        model = super().load_weights(input_path)

        # Get the config from the loaded model
        config = model.config

        # Adapt each decoder layer to match the expected attribute names
        for layer in model.model.layers:
            # Map attention_norm to input_layernorm
            if hasattr(layer, 'attention_norm') and not hasattr(layer, 'input_layernorm'):
                layer.input_layernorm = layer.attention_norm

            # Map ffn_norm to post_attention_layernorm
            if hasattr(layer, 'ffn_norm') and not hasattr(layer, 'post_attention_layernorm'):
                layer.post_attention_layernorm = layer.ffn_norm

            # Map feed_forward to mlp
            if hasattr(layer, 'feed_forward') and not hasattr(layer, 'mlp'):
                layer.mlp = layer.feed_forward

            # Map attention to self_attn
            if hasattr(layer, 'attention') and not hasattr(layer, 'self_attn'):
                layer.self_attn = layer.attention

            # Map the MLP projections (w1/w2/w3 to gate_proj/down_proj/up_proj)
            if hasattr(layer.mlp, 'w1') and not hasattr(layer.mlp, 'gate_proj'):
                layer.mlp.gate_proj = layer.mlp.w1
            if hasattr(layer.mlp, 'w2') and not hasattr(layer.mlp, 'down_proj'):
                layer.mlp.down_proj = layer.mlp.w2
            if hasattr(layer.mlp, 'w3') and not hasattr(layer.mlp, 'up_proj'):
                layer.mlp.up_proj = layer.mlp.w3

            # Handle the combined wqkv projection in attention.
            # InternLM2 stores Q, K, and V in a single interleaved projection: for each KV head,
            # the rows are grouped as [Q_1, ..., Q_{num_kv_groups}, K, V], so the weight can be
            # viewed as [num_kv_heads, num_kv_groups + 2, head_dim, hidden_size].
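            # Example (assumed toy sizes): with num_attention_heads=8 and num_key_value_heads=2,
            # num_kv_groups = 4, so each of the 2 groups holds the rows [Q, Q, Q, Q, K, V].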
            if hasattr(layer, 'self_attn') and hasattr(layer.self_attn, 'wqkv'):
                attn = layer.self_attn
                wqkv_weight = attn.wqkv.weight  # Shape: [(num_q_heads + 2 * num_kv_heads) * head_dim, hidden_size]
                wqkv_bias = attn.wqkv.bias if hasattr(attn.wqkv, 'bias') and attn.wqkv.bias is not None else None

                # Calculate dimensions
                num_q_heads = config.num_attention_heads
                num_kv_heads = config.num_key_value_heads
                num_kv_groups = num_q_heads // num_kv_heads  # How many Q heads share each KV head
                head_dim = config.hidden_size // num_q_heads

                q_size = num_q_heads * head_dim
                kv_size = num_kv_heads * head_dim

                # InternLM2's wqkv is organized as interleaved groups: for each KV head group,
                # [the num_kv_groups Q heads for that group, K for that group, V for that group].
                # Reshape and reorder into the standard [all Q | all K | all V] layout.

                # Reshape to the grouped format: [num_kv_heads, num_kv_groups + 2, head_dim, hidden_size]
                group_size = num_kv_groups + 2
                wqkv_grouped = wqkv_weight.reshape(num_kv_heads, group_size, head_dim, config.hidden_size)

                # Extract Q, K, V from the grouped layout.
                # Q heads: the first num_kv_groups entries in each group
                q_weight = wqkv_grouped[:, :num_kv_groups, :, :].reshape(q_size, config.hidden_size)

                # K head: the second-to-last entry in each group
                k_weight = wqkv_grouped[:, -2, :, :].reshape(kv_size, config.hidden_size)

                # V head: the last entry in each group
                v_weight = wqkv_grouped[:, -1, :, :].reshape(kv_size, config.hidden_size)
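                # q_weight, k_weight, and v_weight now hold the de-interleaved rows in the standard
                # layout: [q_size, hidden_size] for Q and [kv_size, hidden_size] for K and V.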

                # Create separate projection layers, matching the dtype and device of the
                # original combined projection so the split weights stay consistent with the model
                attn.q_proj = nn.Linear(config.hidden_size, q_size, bias=config.bias,
                                        dtype=wqkv_weight.dtype, device=wqkv_weight.device)
                attn.k_proj = nn.Linear(config.hidden_size, kv_size, bias=config.bias,
                                        dtype=wqkv_weight.dtype, device=wqkv_weight.device)
                attn.v_proj = nn.Linear(config.hidden_size, kv_size, bias=config.bias,
                                        dtype=wqkv_weight.dtype, device=wqkv_weight.device)

                # Copy the weights (ensure a proper copy and contiguous memory)
                attn.q_proj.weight.data.copy_(q_weight.contiguous())
                attn.k_proj.weight.data.copy_(k_weight.contiguous())
                attn.v_proj.weight.data.copy_(v_weight.contiguous())

                # Handle the biases if they exist (same grouped layout)
                if wqkv_bias is not None:
                    bias_grouped = wqkv_bias.reshape(num_kv_heads, group_size, head_dim)

                    q_bias = bias_grouped[:, :num_kv_groups, :].reshape(q_size)
                    k_bias = bias_grouped[:, -2, :].reshape(kv_size)
                    v_bias = bias_grouped[:, -1, :].reshape(kv_size)

                    attn.q_proj.bias.data.copy_(q_bias.contiguous())
                    attn.k_proj.bias.data.copy_(k_bias.contiguous())
                    attn.v_proj.bias.data.copy_(v_bias.contiguous())

                # Map wo to o_proj
                if hasattr(attn, 'wo') and not hasattr(attn, 'o_proj'):
                    attn.o_proj = attn.wo

        return model
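

# A minimal standalone sketch (assumed toy sizes, not taken from any real InternLM2
# checkpoint) that mirrors the wqkv de-interleaving math used in load_weights above:
# it builds a fake grouped wqkv weight, splits it with the same reshapes, and checks
# that the resulting Q/K/V shapes match the standard [all Q | all K | all V] layout.
if __name__ == "__main__":
    import torch

    hidden_size, num_q_heads, num_kv_heads = 16, 4, 2
    head_dim = hidden_size // num_q_heads        # 4
    num_kv_groups = num_q_heads // num_kv_heads  # 2 Q heads per KV head
    group_size = num_kv_groups + 2               # each group holds [Q, Q, K, V]

    # Tag each output row with its row index so the grouping is easy to inspect
    num_rows = num_kv_heads * group_size * head_dim
    wqkv = torch.arange(num_rows, dtype=torch.float32).reshape(-1, 1).expand(-1, hidden_size)

    grouped = wqkv.reshape(num_kv_heads, group_size, head_dim, hidden_size)
    q = grouped[:, :num_kv_groups].reshape(num_q_heads * head_dim, hidden_size)
    k = grouped[:, -2].reshape(num_kv_heads * head_dim, hidden_size)
    v = grouped[:, -1].reshape(num_kv_heads * head_dim, hidden_size)

    assert q.shape == (num_q_heads * head_dim, hidden_size)
    assert k.shape == v.shape == (num_kv_heads * head_dim, hidden_size)
    print("q/k/v shapes:", tuple(q.shape), tuple(k.shape), tuple(v.shape))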