Commit 2dee2b5: support megrez2

1 parent 76de927

10 files changed: +546 -64 lines

README.md (1 addition, 0 deletions)

```diff
@@ -33,6 +33,7 @@ LittleAcademia[<a href="https://github.com/foldl/little-academia" style="text-
 
 **What's New:**
 
+* 2025-10-31: Megrez2-3x7B-A3B
 * 2025-10-25: LLaDA2.0-mini
 * 2025-10-14: Nanonets-OCR2
 * 2025-10-13: dots.ocr
```

convert.py (78 additions, 0 deletions)

```diff
@@ -221,6 +221,8 @@ class ModelType(Enum):
     BailingMoE2 = 0x2E00
     LlaDA2      = 0x2E01
 
+    MegrezMoE   = 0x2F00
+
     BCE_Embedding = 0x10000100
     BCE_ReRanker  = 0x10000101
     BGE_M3        = 0x10000102
```
```diff
@@ -8099,6 +8101,80 @@ def get_weight_names(config):
 
         return weight_names
 
+class MegrezMoEConverter(BaseConverter):
+    MODEL_TYPE = ModelType.MegrezMoE
+
+    @classmethod
+    def pp(cls, config, name: str, tensor):
+        return DeepSeekV1Converter.pp(config, name, tensor)
+
+    @staticmethod
+    def dump_config(f, config, ggml_type):
+        assert config.hidden_act == 'silu', "hidden_act must be silu"
+        assert config.attention_bias == False, "attention_bias must be False"
+        assert config.ep_size == 1, "ep_size must be 1"
+        assert config.rope_scaling is None
+        assert config.scoring_func == 'sigmoid', "scoring_func must be 'sigmoid'"
+        assert config.topk_method == 'noaux_tc', "topk_method must be 'noaux_tc'"
+        assert config.n_routed_experts is not None, "n_routed_experts must not be null"
+        assert config.pre_gate
+
+        config.scoring_func = 'softmax'
+        DeepSeekV1Converter.dump_config(f, config, ggml_type)
+
+        config_values = [
+            config.experts_shared_frequency,
+            config.n_group,
+            config.topk_group,
+            config.routed_scaling_factor,
+        ]
+        f.write(struct.pack("<iiif", *config_values))
+
+    @staticmethod
+    def get_weight_names(config):
+        weight_names = ["model.embed_tokens.weight",
+                        "model.norm.weight",
+                        "lm_head.weight"]
+        for i in range(config.num_hidden_layers):
+
+            weight_names += [
+                f"model.layers.{i}.self_attn.k_proj.weight",
+                f"model.layers.{i}.self_attn.q_proj.weight",
+                f"model.layers.{i}.self_attn.v_proj.weight",
+                f"model.layers.{i}.self_attn.o_proj.weight",
+            ]
+
+            if (config.n_routed_experts is not None
+                and (i >= config.first_k_dense_replace)
+                and (i % config.moe_layer_freq == 0)):
+                weight_names += [
+                    f"model.layers.{i}.mlp.gate.e_score_correction_bias",
+                    f"model.layers.{i}.mlp.gate.weight",
+                    f"model.layers.{i}.mlp.shared_experts.gate_proj.weight",
+                    f"model.layers.{i}.mlp.shared_experts.up_proj.weight",
+                    f"model.layers.{i}.mlp.shared_experts.down_proj.weight",
+                ]
+                if (i - config.first_k_dense_replace) % config.experts_shared_frequency == 0:
+                    for j in range(config.n_routed_experts):
+                        weight_names += [
+                            f"model.layers.{i}.mlp.experts.{j}.gate_proj.weight",
+                            f"model.layers.{i}.mlp.experts.{j}.up_proj.weight",
+                            f"model.layers.{i}.mlp.experts.{j}.down_proj.weight",
+                        ]
+            else:
+                weight_names += [
+                    f"model.layers.{i}.mlp.gate_proj.weight",
+                    f"model.layers.{i}.mlp.up_proj.weight",
+                    f"model.layers.{i}.mlp.down_proj.weight",
+                ]
+
+            weight_names += [
+                f"model.layers.{i}.input_layernorm.weight",
+                f"model.layers.{i}.post_attention_layernorm.weight",
+            ]
+
+        return weight_names
+
 def convert_grok_1_base(args, vocab, ggml_type):
     def ffn_size(emb_size, widening_factor):
         _ffn_size = int(widening_factor * emb_size) * 2 // 3
```
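The interesting part of `get_weight_names` is the nested condition: every MoE layer stores its own gate and shared experts, but routed expert tensors exist only on every `experts_shared_frequency`-th MoE layer; presumably the layers in between reuse the routed experts of the nearest earlier owning layer at inference time (that logic lives on the C++ side, not in this diff). A minimal sketch of the resulting layout, using a hypothetical `layer_kinds` helper that is not part of the commit:

```python
# Hypothetical helper (not in the commit): classify each layer according to
# the same conditions MegrezMoEConverter.get_weight_names uses.
def layer_kinds(config):
    kinds = []
    for i in range(config.num_hidden_layers):
        is_moe = (config.n_routed_experts is not None
                  and i >= config.first_k_dense_replace
                  and i % config.moe_layer_freq == 0)
        if not is_moe:
            kinds.append((i, "dense MLP"))
        elif (i - config.first_k_dense_replace) % config.experts_shared_frequency == 0:
            kinds.append((i, "MoE: gate + shared experts + own routed experts"))
        else:
            kinds.append((i, "MoE: gate + shared experts, routed experts reused"))
    return kinds
```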
```diff
@@ -8719,6 +8795,8 @@ def main():
         JanusConverter.convert(config, model_files, vocab, ggml_type, args.save_path)
     elif arch.endswith('DotsOCRForCausalLM'):
         DotsOCRConverter.convert(config, model_files, vocab, ggml_type, args.save_path)
+    elif arch.endswith('MegrezMoeForCausalLM'):
+        MegrezMoEConverter.convert(config, model_files, vocab, ggml_type, args.save_path)
     elif arch == 'deepseek-r1-distill-qwen3':
         QWen3Converter.MODEL_TYPE = ModelType.DeepSeek_R1_Distill_QWen3
         QWen3Converter.convert(config, model_files, vocab, ggml_type, args.save_path)
```
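Note that `dump_config` appends four Megrez-specific fields, three little-endian `int32`s and one `float32`, after the DeepSeekV1 config block. Here is a sketch of the matching read side, assuming only the `"<iiif"` layout shown above; the function name and returned dict are illustrative, not the loader's actual API:

```python
import struct

def read_megrez_extra_config(f):
    # Mirrors struct.pack("<iiif", ...) in MegrezMoEConverter.dump_config:
    # experts_shared_frequency, n_group, topk_group are int32;
    # routed_scaling_factor is float32.
    esf, n_group, topk_group, routed_scaling_factor = struct.unpack(
        "<iiif", f.read(struct.calcsize("<iiif")))
    return {
        "experts_shared_frequency": esf,
        "n_group": n_group,
        "topk_group": topk_group,
        "routed_scaling_factor": routed_scaling_factor,
    }
```

On the dispatch side, the new `elif` keys off the `architectures` entry in the checkpoint's `config.json`, so converting a Megrez2 model should not need an explicit `-a` override.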

docs/models.md (3 additions, 0 deletions)

```diff
@@ -172,6 +172,9 @@
 
     For other models that use the `LlamaForCausalLM` architecture, for example, [aiXcoder-7B](https://huggingface.co/aiXcoder/aixcoder-7b-base), try `-a Yi`.
 
+* Megrez (`MegrezMoeForCausalLM`)
+    * [x] [3x7B-A3B](https://huggingface.co/Infinigence/Megrez2-3x7B-A3B/tree/3ffc3b7c0ffc0f0b27d71fba2a97dcc14c797bb4)
+
 * MiniCPM (`MiniCPMForCausalLM`, `MiniCPM3ForCausalLM`)
     * [x] [DPO-2B](https://huggingface.co/openbmb/MiniCPM-2B-dpo-fp16), [SFT-2B](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16),
       [SFT-1B](https://huggingface.co/openbmb/MiniCPM-1B-sft-bf16)🔥
```

models/bailing.cpp (17 additions, 38 deletions)

```diff
@@ -119,52 +119,31 @@ namespace chatllm::bailing::moe2
     const int NUM_EXPERTS = 256;
     const int EXPERTS_PER_TOK = 8;
 
-    class BailingSparseMoE : public BaseSparseMLP
+    class BailingSparseMoE : public GenericGroupedSparseMoE
     {
     public:
-        BailingSparseMoE(InitContext *ctx, int hidden_size, int intermediate_size, int num_experts = NUM_EXPERTS, int experts_per_tok = EXPERTS_PER_TOK)
-            : BaseSparseMLP(ctx, hidden_size, intermediate_size, num_experts, experts_per_tok, ActFunc::SILU, true),
-              n_group(-1), topk_group(-1)
+        BailingSparseMoE(InitContext *ctx, int hidden_size, int intermediate_size, int num_experts = NUM_EXPERTS, int experts_per_tok = EXPERTS_PER_TOK):
+            GenericGroupedSparseMoE(ctx, hidden_size, num_experts, experts_per_tok, true, false, false, false),
+            experts(ctx, hidden_size, intermediate_size, num_experts, experts_per_tok, ActFunc::SILU, false)
         {
-            score_func = ScoreFunc::Sigmoid;
-            always_scaling = true;
+            set_experts(&experts);
         }
-    protected:
-        ggml::tensor *select_experts(ComputeContext *ctx, ggml::tensor *corrected_score) override;
 
+        int64_t get_param_num(bool effective_only) const override
+        {
+            int64_t r = GenericSparseMLP::get_param_num(effective_only);
+            r += experts.get_param_num(effective_only);
+            return r;
+        }
+        void load(const std::string &path, TensorLoader *loader) override
+        {
+            GenericSparseMLP::load(path, loader);
+            experts.load(path + "experts.", loader);
+        }
     public:
-        int n_group;
-        int topk_group;
+        MultiMLP experts;
     };
 
-    ggml::tensor *BailingSparseMoE::select_experts(ComputeContext *ctx, ggml::tensor *corrected_score)
-    {
-        const int n_expert = num_local_experts;
-        const int experts_per_group = n_expert / n_group;
-        CHATLLM_CHECK(ggml::get_dim(corrected_score, 2) == 1);
-
-        ggml::tensor * selected_experts = nullptr;
-
-        ggml::tensor *grouped_scores = ggml::reshape_4d(ctx, corrected_score, experts_per_group, num_experts_per_tok,
-            ggml::get_dim(corrected_score, 1), ggml::get_dim(corrected_score, 2));
-        selected_experts = ggml::top_k(ctx, grouped_scores, topk_group);
-
-        ggml::tensor *selected_experts_i64 = ggml::cast_int_to_i64(ctx, selected_experts);
-
-        CHATLLM_CHECK(ggml::get_dim(grouped_scores, 3) == 1);
-        grouped_scores = ggml::reshape_4d(ctx, grouped_scores, 1, ggml::get_dim(grouped_scores, 0), ggml::get_dim(grouped_scores, 1), ggml::get_dim(grouped_scores, 2));
-        ggml::tensor *selected_group_scores = ggml::scale(ctx, grouped_scores, 0.0f);
-        grouped_scores = ggml::get_rows(ctx, grouped_scores, selected_experts);
-        selected_group_scores = ggml::set_rows(ctx, selected_group_scores, selected_experts_i64, grouped_scores);
-
-        selected_group_scores = ggml::reshape_3d(ctx, selected_group_scores,
-            ggml::get_dim(corrected_score, 0), ggml::get_dim(corrected_score, 1), ggml::get_dim(corrected_score, 2));
-
-        selected_experts = ggml::top_k(ctx, selected_group_scores, num_experts_per_tok);
-
-        return selected_experts;
-    }
-
     class AttnParams
     {
     public:
```
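This hunk retires Bailing's hand-rolled `select_experts` in favor of the shared `GenericGroupedSparseMoE` base, presumably so the new Megrez2 model can reuse the same grouped routing. For intuition, a rough NumPy sketch of group-limited top-k selection in the spirit of the removed code (scores of non-selected groups are zeroed before the final top-k); how each group is scored varies between models, so treat the `max` below as one common choice rather than this repo's exact rule:

```python
import numpy as np

def grouped_topk(scores: np.ndarray, n_group: int, topk_group: int, top_k: int) -> np.ndarray:
    """Group-limited routing: keep the best `topk_group` expert groups,
    zero everything else, then take the per-token top-k experts."""
    experts_per_group = scores.shape[0] // n_group
    groups = scores.reshape(n_group, experts_per_group)
    # Score each group by its best expert (one common choice; some models
    # sum the top-2 experts per group instead).
    best_groups = np.argsort(groups.max(axis=-1))[-topk_group:]
    # Sigmoid gate scores are non-negative, so zero-masking excludes a group,
    # matching the removed code's scale-by-0 trick.
    masked = np.zeros_like(groups)
    masked[best_groups] = groups[best_groups]
    return np.argsort(masked.reshape(-1))[-top_k:][::-1]
```

For this file's constants, a call would look like `grouped_topk(sigmoid_scores, n_group, topk_group, EXPERTS_PER_TOK)` over 256 experts.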

models/gpt.cpp (2 additions, 1 deletion)

```diff
@@ -164,7 +164,8 @@ Reasoning: medium
         norm_topk_prob = false;
     }
 public:
-    ggml::tensor *forward(ComputeContext *ctx, ggml::tensor *hidden_states)
+    using Block::forward;
+    ggml::tensor *forward(ComputeContext *ctx, ggml::tensor *hidden_states) override
     {
         const int64_t qlen = hidden_states->ne[1];
         const int n_expert = num_local_experts;
```

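A note on the two added tokens: marking the derived `forward` as `override` makes the compiler verify it actually overrides a virtual function in the base, and `using Block::forward;` re-exposes the other `forward` overloads inherited from `Block` that the derived declaration would otherwise hide, since C++ name lookup stops at the first scope that declares the name.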