
Commit 8b9ac26

support ERNIE-MoE
1 parent 1899886 commit 8b9ac26

12 files changed, +360 -35 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ pure C++ implementation based on [@ggerganov](https://github.com/ggerganov)'s [g

 **What's New:**

+* 2025-07-04: ERNIE-MoE
 * 2025-06-30: Hunyuan-A13B, ERNIE-Dense
 * 2025-06-21: [I can hear](./docs/multimodal.md): Qwen2-Audio
 * 2025-06-10: SmolVLM2

convert.py

Lines changed: 91 additions & 9 deletions
@@ -197,6 +197,8 @@ class ModelType(Enum):

     Apriel = 0x2400

+    ERNIE_MoE = 0x2500
+
     BCE_Embedding = 0x10000100
     BCE_ReRanker = 0x10000101
     BGE_M3 = 0x10000102
@@ -285,6 +287,7 @@ def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor:
     tensor = torch.cat((scale.half().view(torch.int8), min_values.half().view(torch.int8), tensor), dim=-1)
     return tensor

+@torch.jit.script
 def qkx2_quants(x: torch.Tensor, nmax, rmin, rdelta, nstep: int, use_mad: bool):
     assert x.dim() == 1
     N = x.shape[0]
@@ -297,7 +300,7 @@ def qkx2_quants(x: torch.Tensor, nmax, rmin, rdelta, nstep: int, use_mad: bool):
     if min_x > 0: min_x = torch.tensor(0)
     if min_x == max_x:
         L = torch.zeros(N)
-        return 0.0, -min_x, L
+        return torch.tensor(0.0), -min_x, L

     iscale = nmax / (max_x - min_x)
     scale = 1 / iscale
@@ -322,7 +325,7 @@ def qkx2_quants(x: torch.Tensor, nmax, rmin, rdelta, nstep: int, use_mad: bool):
         this_scale = (sum_w * sum_xl - sum_x * sum_l)/D
         this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D
         if this_min > 0:
-            this_min = 0
+            this_min = torch.tensor(0)
             this_scale = sum_xl / sum_l2

         diff = this_scale * l + this_min - x
@@ -335,19 +338,20 @@ def qkx2_quants(x: torch.Tensor, nmax, rmin, rdelta, nstep: int, use_mad: bool):

     return scale, -min_x, L

-def quantize_q4_k_block(tensor: torch.Tensor) -> torch.CharTensor:
+@torch.jit.script
+def quantize_q4_k_block(tensor: torch.Tensor, GGML_QK_K: int) -> torch.CharTensor:
     assert tensor.shape == (GGML_QK_K, )
     tensor = tensor.view(-1, 32)

-    subblocks = [qkx2_quants(tensor[i], 15, -1.0, 0.1, 20, False) for i in range(tensor.shape[0])]
+    subblocks = [qkx2_quants(tensor[i], torch.tensor(15), torch.tensor(-1.0), torch.tensor(0.1), 20, False) for i in range(tensor.shape[0])]
     scale = torch.stack([x[0] for x in subblocks])
     min_x = torch.stack([x[1] for x in subblocks])

     max_scale = torch.max(scale)
     max_min = torch.max(min_x)

-    inv_scale = 63.0 / max_scale if max_scale > 0 else 0.0
-    inv_min = 64.0 / max_min if max_min > 0 else 0.0
+    inv_scale = torch.tensor(63.0) / max_scale if max_scale > 0 else torch.tensor(0.0)
+    inv_min = torch.tensor(64.0) / max_min if max_min > 0 else torch.tensor(0.0)

     ls = (inv_scale * scale).round().clamp(max=63)
     lm = (inv_min * min_x).round().clamp(max=63)
@@ -380,11 +384,12 @@ def quantize_q4_k_block(tensor: torch.Tensor) -> torch.CharTensor:

     return r

-def quantize_q4_k(tensor: torch.Tensor) -> torch.CharTensor:
+@torch.jit.script
+def quantize_q4_k(tensor: torch.Tensor, GGML_QK_K: int) -> torch.CharTensor:
     # equivalent to dequantize_row_q4_K in ggml-quants.c
     assert tensor.shape[tensor.ndim - 1] % GGML_QK_K == 0
     tensor = tensor.view(-1, GGML_QK_K)
-    blocks = [quantize_q4_k_block(tensor[i]) for i in range(tensor.shape[0])]
+    blocks = [quantize_q4_k_block(tensor[i], GGML_QK_K) for i in range(tensor.shape[0])]
     tensor = torch.cat(blocks, dim=-1)
     return tensor

@@ -411,7 +416,7 @@ def dump_tensor(f, name: str, tensor: torch.Tensor, ggml_type: GGMLType):
         elif ggml_type == GGMLType.Q4_1:
             tensor = quantize_q4_1(tensor)
         elif ggml_type == GGMLType.Q4_K:
-            tensor = quantize_q4_k(tensor)
+            tensor = quantize_q4_k(tensor, GGML_QK_K)
         else:
             raise NotImplementedError(f"Cannot dump tensor of dtype {tensor.dtype}")
     except Exception as e:
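Since `quantize_q4_k` is now TorchScript-compiled, the super-block size is passed in explicitly at every call site instead of being picked up as a module-level constant. A minimal usage sketch (shapes and values are made up; it assumes `convert.py` is importable and that `GGML_QK_K` is the usual 256-element super-block):

```python
import torch

from convert import quantize_q4_k  # assumes convert.py is on the import path

GGML_QK_K = 256  # assumed super-block size; convert.py defines its own constant

# Dummy weight tensor; the last dimension must be a multiple of GGML_QK_K,
# as asserted inside quantize_q4_k.
weights = torch.randn(4, GGML_QK_K)

# After this commit the block size is an explicit argument of the scripted
# function rather than a captured global.
packed = quantize_q4_k(weights, GGML_QK_K)
print(packed.shape, packed.dtype)  # int8 buffer holding the packed Q4_K blocks
```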
@@ -6364,6 +6369,81 @@ def get_weight_names(config):

         return weight_names

+class ERNIEMoEConverter(BaseConverter):
+    MODEL_TYPE = ModelType.ERNIE_MoE
+
+    @staticmethod
+    def dump_config(f, config, ggml_type):
+        assert not config.use_bias
+        assert len(config.moe_capacity) == 3
+        if config.rope_scaling is not None:
+            assert config.rope_scaling == 1.0, 'rope_scaling must equal to 1.0'
+
+        dump_llama_like_config(f, config, ggml_type)
+        config_values = [
+            config.num_key_value_heads,
+            1 if config.tie_word_embeddings else 0,
+            config.moe_num_experts,
+            config.moe_num_shared_experts,
+            config.moe_layer_start_index,
+            config.moe_intermediate_size,
+            config.moe_capacity[0],
+            config.moe_capacity[1],
+            config.moe_capacity[2],
+            config.moe_k,
+            config.moe_layer_interval,
+            1 if config.moe_use_aux_free else 0,
+        ]
+        f.write(struct.pack("i" * len(config_values), *config_values))
+        f.write(struct.pack("<f", config.rope_theta))
+
+    @staticmethod
+    def get_weight_names(config):
+        weight_names = ["model.embed_tokens.weight"]
+        for i in range(config.num_hidden_layers):
+            weight_names += [
+                f"model.layers.{i}.input_layernorm.weight",
+                f"model.layers.{i}.post_attention_layernorm.weight",
+                f"model.layers.{i}.self_attn.k_proj.weight",
+                f"model.layers.{i}.self_attn.o_proj.weight",
+                f"model.layers.{i}.self_attn.q_proj.weight",
+                f"model.layers.{i}.self_attn.v_proj.weight",
+            ]
+
+            if (i >= config.moe_layer_start_index) and ((i + 1) % config.moe_layer_interval == 0):
+                weight_names += [
+                    f"model.layers.{i}.mlp.gate.weight",
+                    f"model.layers.{i}.mlp.shared_experts.gate_proj.weight",
+                    f"model.layers.{i}.mlp.shared_experts.up_proj.weight",
+                    f"model.layers.{i}.mlp.shared_experts.down_proj.weight",
+                ]
+                if config.moe_use_aux_free:
+                    weight_names += [
+                        f"model.layers.{i}.mlp.moe_statics.e_score_correction_bias",
+                    ]
+                for j in range(config.moe_num_experts):
+                    weight_names += [
+                        f"model.layers.{i}.mlp.experts.{j}.gate_proj.weight",
+                        f"model.layers.{i}.mlp.experts.{j}.up_proj.weight",
+                        f"model.layers.{i}.mlp.experts.{j}.down_proj.weight",
+                    ]
+            else:
+                weight_names += [
+                    f"model.layers.{i}.mlp.down_proj.weight",
+                    f"model.layers.{i}.mlp.gate_proj.weight",
+                    f"model.layers.{i}.mlp.up_proj.weight",
+                ]
+
+        weight_names += [
+            "model.norm.weight",
+        ]
+
+        if not config.tie_word_embeddings:
+            weight_names += [
+                "lm_head.weight"
+            ]
+        return weight_names
+
 class KimiVLConverter(BaseConverter):
     MODEL_TYPE = ModelType.KimiVL

@@ -7516,6 +7596,8 @@ def main():
         QWen3Converter.convert(config, model_files, vocab, ggml_type, args.save_path)
     elif arch == 'Ernie4_5_ForCausalLM':
         ERNIEDenseConverter.convert(config, model_files, vocab, ggml_type, args.save_path)
+    elif arch == 'Ernie4_5_MoeForCausalLM':
+        ERNIEMoEConverter.convert(config, model_files, vocab, ggml_type, args.save_path)
     elif arch == 'deepseek-r1-distill-qwen3':
         QWen3Converter.MODEL_TYPE = ModelType.DeepSeek_R1_Distill_QWen3
         QWen3Converter.convert(config, model_files, vocab, ggml_type, args.save_path)
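For orientation, the twelve integers written by `ERNIEMoEConverter.dump_config` above correspond one-to-one to the `int` fields of the `chatllm::ernie::moe::Config` struct added in `models/ernie.h` further down, followed by a single little-endian `float` for `rope_theta`. A small round-trip sketch of just that tail of the header, using illustrative values rather than a real checkpoint's config:

```python
import struct

# Illustrative values only; a real conversion reads them from config.json.
config_ints = [
    4,        # num_key_value_heads
    1,        # tie_word_embeddings
    64,       # moe_num_experts
    2,        # moe_num_shared_experts
    1,        # moe_layer_start_index
    1536,     # moe_intermediate_size
    8, 8, 8,  # moe_capacity[0..2]
    6,        # moe_k
    1,        # moe_layer_interval
    1,        # moe_use_aux_free -> use_correction_bias on the C++ side
]
rope_theta = 500000.0

# The same two writes performed by ERNIEMoEConverter.dump_config.
blob = struct.pack("i" * len(config_ints), *config_ints) + struct.pack("<f", rope_theta)

# Reading it back in the field order of chatllm::ernie::moe::Config.
ints = struct.unpack("i" * len(config_ints), blob[:4 * len(config_ints)])
(theta,) = struct.unpack("<f", blob[4 * len(config_ints):])
assert list(ints) == config_ints and abs(theta - rope_theta) < 1e-3
```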

docs/models.md

Lines changed: 2 additions & 2 deletions
@@ -58,8 +58,8 @@

 Two optimization modes are defined: speed (default) and memory. See `BaseMLAttention`.

-* ERNIE (`Ernie4_5_ForCausalLM`)
-    * [x] [0.3B](https://huggingface.co/baidu/ERNIE-4.5-0.3B-PT/tree/c163aa422d265f995b024d1322d91c4e3cb52ec8)
+* ERNIE (`Ernie4_5_ForCausalLM`, `Ernie4_5_MoeForCausalLM`)
+    * [x] [0.3B](https://huggingface.co/baidu/ERNIE-4.5-0.3B-PT/tree/c163aa422d265f995b024d1322d91c4e3cb52ec8), [A3B](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-PT/tree/b24b8917f5379129992dad46c279683c7b845c96)

 * EXAONE (`ExaoneForCausalLM`)
     * [x] v3.5: [Instruct-2.4B](https://huggingface.co/LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct), [Instruct-7.8B](https://huggingface.co/LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct), [Instruct-32B](https://huggingface.co/LGAI-EXAONE/EXAONE-3.5-32B-Instruct)

models/ernie.cpp

Lines changed: 132 additions & 1 deletion
@@ -13,7 +13,7 @@ namespace chatllm::ernie::dense

     static ChatHistoryEncoder _chat_encoder;

-    Tokenizer::Tokenizer(const Config &config)
+    Tokenizer::Tokenizer(const BaseConfig &config)
         : chatllm::llama::v2::Tokenizer(config, &_chat_encoder)
     {}

@@ -65,4 +65,135 @@ namespace chatllm::ernie::dense
             attention.freq_base = config.rope_theta;
         }
     }
+}
+
+namespace chatllm::ernie::moe
+{
+    template <class ErnieMoEMLP> class ErnieMoEBlock : public LMBlock1<RMSNorm, LlamaSelfAttention, RMSNorm, ErnieMoEMLP>
+    {
+    public:
+        ErnieMoEBlock(InitContext *ctx, int hidden_size, int num_attention_heads, int intermediate_size,
+                      int mlp_intermediate_size1, int mlp_intermediate_size2,
+                      int num_kv_heads, int head_dim, int max_length)
+            : LMBlock1<RMSNorm, LlamaSelfAttention, RMSNorm, ErnieMoEMLP>(ctx, hidden_size, num_attention_heads, intermediate_size, mlp_intermediate_size1, mlp_intermediate_size2,
+                                                                          num_kv_heads, head_dim, max_length)
+        {}
+    };
+
+    template <int NUM_EXPERTS, int EXPERTS_PER_TOK> class ErnieSparseMoE : public BaseSparseMLP
+    {
+    public:
+        ErnieSparseMoE(InitContext *ctx, int hidden_size, int intermediate_size)
+            : BaseSparseMLP(ctx, hidden_size, intermediate_size, NUM_EXPERTS, EXPERTS_PER_TOK, ActFunc::SILU, false)
+        {
+        }
+    };
+
+    template <const int NUM_EXPERTS, const int EXPERTS_PER_TOK, const int EFFECTIVE_EXPERTS_PER_TOK> class GenericConditionalGeneration : public BaseModelForConditionalGeneration
+    {
+    public:
+        typedef CombinedMLP<ErnieSparseMoE<NUM_EXPERTS, EXPERTS_PER_TOK>, SiLUMLP> ErnieMoEMLP;
+        typedef ErnieMoEBlock<ErnieMoEMLP> MoEBlock;
+        typedef BaseModelForConditionalGeneration Base;
+        typedef HeterogeneousModel ModelClass;
+    public:
+        GenericConditionalGeneration() = default;
+
+        GenericConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config)
+            : BaseModelForConditionalGeneration(MODEL_TYPE_ERNIE_MOE, config, runtime_config, 4096 * 4),
+              config(config)
+        {
+            const size_t tensor_ovhd = ggml_tensor_overhead();
+            const size_t moe_layers = get_moe_layer_num();
+            const size_t dense_layers = config.num_hidden_layers - moe_layers;
+            const size_t num_tensors = 2 + dense_layers * (12) + moe_layers * (16 + 0) + (config.tie_word_embeddings ? 0 : 1);
+            const size_t ctx_size = num_tensors * tensor_ovhd;
+            w_ctx_.gctx = GGMLContext({.mem_size = ctx_size, .mem_buffer = nullptr, .no_alloc = true});
+            w_ctx_.dtype = config.dtype;
+
+            if (config.use_correction_bias)
+                ggml::log(GGML_LOG_LEVEL_WARN, "use_correction_bias is ignored, see https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-PT/blob/main/modeling_ernie4_5_moe.py#L369");
+
+            auto create_layer = [&](InitContext *ctx, int layer_index) -> Block * {
+                if (is_layer_moe(layer_index))
+                {
+                    auto layer = new MoEBlock(ctx, config.hidden_size, config.num_attention_heads, config.intermediate_size,
+                                              config.moe_intermediate_size, config.moe_intermediate_size * config.moe_num_shared_experts,
+                                              config.num_key_value_heads, config.hidden_size / config.num_attention_heads,
+                                              config.max_length);
+                    layer->attention.freq_base = config.rope_theta;
+                    layer->mlp.mlp1.norm_topk_prob = true;
+                    return layer;
+                }
+                else
+                {
+                    auto layer = new LlamaBlock(ctx, config.hidden_size, config.num_attention_heads, config.intermediate_size,
+                                                config.num_key_value_heads, config.max_length);
+                    layer->attention.freq_base = config.rope_theta;
+                    return layer;
+                }
+            };
+
+            auto transformer = new ModelClass(&w_ctx_, config.num_hidden_layers, config.hidden_size,
+                                              create_embedding<Embedding>(&w_ctx_, config),
+                                              create_final_norm<RMSNorm>(&w_ctx_, config),
+                                              config.tie_word_embeddings ? nullptr : create_lm_head(&w_ctx_, config, false), create_layer);
+            Base::transformer = transformer;
+
+            w_ctx_.check_used_mem_size(true);
+        }
+
+    protected:
+        int get_moe_layer_num(void)
+        {
+            int r = 0;
+            for (int i = 0; i < config.num_hidden_layers; i++)
+                if (is_layer_moe(i)) r++;
+            return r;
+        }
+
+        bool is_layer_moe(int i)
+        {
+            if (i < config.moe_layer_start_index) return false;
+            return (i % config.moe_layer_interval) == 0;
+        }
+    public:
+        Config config;
+    };
+
+    namespace experts_64
+    {
+        const int NUM_EXPERTS = 64;
+        const int EXPERTS_PER_TOK = 6;
+
+        // make it easy to test with different number of experts.
+        const int EFFECTIVE_EXPERTS_PER_TOK = EXPERTS_PER_TOK;
+
+        typedef GenericConditionalGeneration<NUM_EXPERTS, EXPERTS_PER_TOK, EFFECTIVE_EXPERTS_PER_TOK> ConditionalGeneration;
+    }
+
+    ConditionalGeneration::ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config)
+    {
+        switch (config.moe_num_experts)
+        {
+        case experts_64::NUM_EXPERTS:
+            set_proxy_model(new experts_64::ConditionalGeneration(config, runtime_config));
+            break;
+        default:
+            CHATLLM_CHECK(false) << "unsupported MoE param: num_experts = " << config.moe_num_experts;
+            break;
+        }
+    }
+
+    void ConditionalGeneration::load(ModelLoader &loader)
+    {
+        loader.add_tensor_name_translations({
+            {".mlp2.", ".shared_experts."},
+            {".mlp1.gate.", ".gate."},
+            {".mlp1.experts.", ".experts."},
+            {".mlp1.gate_score_correction_bias", ".moe_statics.e_score_correction_bias"}
+        });
+
+        ModelProxy::load(loader);
+    }
 }
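The constructor above decides per layer whether to build a `MoEBlock` or a plain `LlamaBlock`, then sizes the ggml context from that split: 12 weight tensors per dense layer, 16 per MoE layer, plus the embedding, the final norm, and an optional `lm_head`. A rough sketch of that bookkeeping in Python, with made-up layer counts:

```python
def is_layer_moe(i, moe_layer_start_index, moe_layer_interval):
    # Mirrors GenericConditionalGeneration::is_layer_moe above.
    if i < moe_layer_start_index:
        return False
    return i % moe_layer_interval == 0

# Illustrative values; a real ERNIE-4.5 config supplies the actual numbers.
num_hidden_layers = 28
moe_layer_start_index = 1
moe_layer_interval = 1
tie_word_embeddings = True

moe_layers = sum(is_layer_moe(i, moe_layer_start_index, moe_layer_interval)
                 for i in range(num_hidden_layers))
dense_layers = num_hidden_layers - moe_layers

# Same formula the constructor uses to size the ggml context.
num_tensors = 2 + dense_layers * 12 + moe_layers * 16 + (0 if tie_word_embeddings else 1)
print(moe_layers, dense_layers, num_tensors)  # 27 1 446 with these made-up values
```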

models/ernie.h

Lines changed: 30 additions & 2 deletions
@@ -18,7 +18,7 @@ namespace chatllm::ernie::dense
     class Tokenizer : public chatllm::llama::v2::Tokenizer
     {
     public:
-        Tokenizer(const Config &config);
+        Tokenizer(const BaseConfig &config);
     };

     class ConditionalGeneration : public chatllm::llama::v2::GenericConditionalGeneration<LlamaBlock>
@@ -27,4 +27,32 @@ namespace chatllm::ernie::dense
         ConditionalGeneration() = default;
         ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type = ModelType::MODEL_TYPE_ERNIE_DENSE);
     };
-}
+}
+
+namespace chatllm::ernie::moe
+{
+    struct Config : public chatllm::llama::v2::Config
+    {
+        int num_key_value_heads;
+        int tie_word_embeddings;
+        int moe_num_experts;
+        int moe_num_shared_experts;
+        int moe_layer_start_index;
+        int moe_intermediate_size;
+        int moe_capacity[3];
+        int moe_k;
+        int moe_layer_interval;
+        int use_correction_bias;
+
+        float rope_theta;
+    };
+
+    typedef dense::Tokenizer Tokenizer;
+
+    class ConditionalGeneration : public ModelProxy
+    {
+    public:
+        ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config);
+        void load(ModelLoader &loader);
+    };
+}

scripts/models.json

Lines changed: 9 additions & 0 deletions
@@ -2953,6 +2953,15 @@
                         "url": "chatllm_quantized_ernie/ernie-4.5-0.3b.bin"
                     }
                 }
+            },
+            "a3b": {
+                "default": "q4_1",
+                "quantized": {
+                    "q8": {
+                        "size": 13643262720,
+                        "url": "chatllm_quantized_ernie/ernie-4.5-21b-a3b-q4_1.bin"
+                    }
+                }
             }
         }
     }
