Skip to content

Commit 4fb16ec

Browse files
XarbirusCISC
andauthored
model: add Mellum architecture (#23966)
* model: support for Mellum architecture * model: improve mellum.py formatting * model: improve mellum.py formatting once again * deps: downgrade transformers to 4.57.6 (to fix CI) * deps: remove huggingface_hub dependency * deps: remove huggingface_hub from test requirements --------- Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
1 parent bfb4308 commit 4fb16ec

20 files changed

Lines changed: 344 additions & 5 deletions

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
143143
- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
144144
- [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7)
145145
- [x] [BailingMoeV2 (Ring/Ling 2.0) models](https://huggingface.co/collections/inclusionAI/ling-v2-68bf1dd2fc34c306c1fa6f86)
146+
- [x] [Mellum models](https://huggingface.co/JetBrains/models?search=mellum)
146147

147148
#### Multimodal
148149

conversion/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@
135135
"Mamba2ForCausalLM": "mamba",
136136
"MambaForCausalLM": "mamba",
137137
"MambaLMHeadModel": "mamba",
138+
"MellumForCausalLM": "mellum",
138139
"MiMoV2FlashForCausalLM": "mimo",
139140
"MiMoV2ForCausalLM": "mimo",
140141
"MiniCPM3ForCausalLM": "minicpm",

conversion/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1663,6 +1663,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
16631663
if chkhsh == "789696f5946cc0fc59371f39f6097cafed196b3acded6140432f26bbb1ae1669":
16641664
# ref: https://huggingface.co/ibm-granite/granite-embedding-311m-multilingual-r2
16651665
res = "granite-embed-multi-311m"
1666+
if chkhsh == "9dcf830ee9990cdbf78cc523a5f7bd9ad8f3f9890c2d3581d2785ad10f07049d":
1667+
# ref: https://huggingface.co/JetBrains/Mellum2-12B-A2.5B-Base
1668+
res = "mellum2"
16661669

16671670
if res is None:
16681671
logger.warning("\n")

conversion/mellum.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from __future__ import annotations
2+
3+
from typing import Iterable, TYPE_CHECKING
4+
5+
import torch
6+
7+
if TYPE_CHECKING:
8+
from torch import Tensor
9+
10+
from .base import ModelBase, TextModel, gguf, logger
11+
12+
13+
@ModelBase.register("MellumForCausalLM")
14+
class MellumModel(TextModel):
15+
model_arch = gguf.MODEL_ARCH.MELLUM
16+
17+
def set_gguf_parameters(self):
18+
super().set_gguf_parameters()
19+
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
20+
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
21+
logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
22+
23+
use_sliding_window = self.hparams.get("use_sliding_window")
24+
sliding_window = self.hparams.get("sliding_window")
25+
if (use_sliding_window is True or use_sliding_window is None) and sliding_window is not None:
26+
self.gguf_writer.add_sliding_window(sliding_window)
27+
logger.info(f"gguf: sliding window = {sliding_window}")
28+
self.gguf_writer.add_sliding_window_pattern([t == "sliding_attention" for t in self.hparams["layer_types"]])
29+
logger.info(f"gguf: sliding window pattern length = {len(self.hparams['layer_types'])}")
30+
31+
_experts: list[dict[str, Tensor]] | None = None
32+
33+
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
34+
if name.find("experts") != -1:
35+
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
36+
assert bid is not None
37+
38+
if self._experts is None:
39+
self._experts = [{} for _ in range(self.block_count)]
40+
41+
self._experts[bid][name] = data_torch
42+
43+
if len(self._experts[bid]) >= n_experts * 3:
44+
for w_name in ["down_proj", "gate_proj", "up_proj"]:
45+
datas: list[Tensor] = []
46+
47+
for xid in range(n_experts):
48+
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
49+
datas.append(self._experts[bid][ename])
50+
del self._experts[bid][ename]
51+
52+
data_torch = torch.stack(datas, dim=0)
53+
54+
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
55+
56+
yield from super().modify_tensors(data_torch, merged_name, bid)
57+
return
58+
else:
59+
return
60+
61+
yield from super().modify_tensors(data_torch, name, bid)

convert_hf_to_gguf_update.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ class TOKENIZER_TYPE(IntEnum):
160160
{"name": "minicpm5", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openbmb/MiniCPM5-1B"},
161161
{"name": "granite-embed-multi-97m", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-embedding-97m-multilingual-r2", },
162162
{"name": "granite-embed-multi-311m", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-embedding-311m-multilingual-r2", },
163+
{"name": "mellum2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum2-12B-A2.5B-Base"},
163164
]
164165

165166
# some models are known to be broken upstream, so we will skip them as exceptions

gguf-py/gguf/constants.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,7 @@ class MODEL_ARCH(IntEnum):
510510
MAINCODER = auto()
511511
KIMI_LINEAR = auto()
512512
TALKIE = auto()
513+
MELLUM = auto()
513514

514515

515516
class VISION_PROJECTOR_TYPE(IntEnum):
@@ -1030,6 +1031,7 @@ class MODEL_TENSOR(IntEnum):
10301031
MODEL_ARCH.MAINCODER: "maincoder",
10311032
MODEL_ARCH.KIMI_LINEAR: "kimi-linear",
10321033
MODEL_ARCH.TALKIE: "talkie",
1034+
MODEL_ARCH.MELLUM: "mellum",
10331035
}
10341036

10351037
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -4093,6 +4095,23 @@ class MODEL_TENSOR(IntEnum):
40934095
MODEL_TENSOR.FFN_UP,
40944096
MODEL_TENSOR.LAYER_OUT_SCALE,
40954097
],
4098+
MODEL_ARCH.MELLUM: [
4099+
MODEL_TENSOR.TOKEN_EMBD,
4100+
MODEL_TENSOR.OUTPUT_NORM,
4101+
MODEL_TENSOR.OUTPUT,
4102+
MODEL_TENSOR.ATTN_NORM,
4103+
MODEL_TENSOR.ATTN_Q,
4104+
MODEL_TENSOR.ATTN_Q_NORM,
4105+
MODEL_TENSOR.ATTN_K,
4106+
MODEL_TENSOR.ATTN_K_NORM,
4107+
MODEL_TENSOR.ATTN_V,
4108+
MODEL_TENSOR.ATTN_OUT,
4109+
MODEL_TENSOR.FFN_NORM,
4110+
MODEL_TENSOR.FFN_GATE_INP,
4111+
MODEL_TENSOR.FFN_GATE_EXP,
4112+
MODEL_TENSOR.FFN_DOWN_EXP,
4113+
MODEL_TENSOR.FFN_UP_EXP,
4114+
],
40964115
# TODO
40974116
}
40984117

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ requires-python = '>=3.10,<3.15'
1010
dependencies = [
1111
'numpy (>=1.26.4,<3.0.0)',
1212
'sentencepiece (>=0.1.98,<0.3.0)',
13-
'transformers (==5.5.1)',
13+
'transformers (==4.57.6)',
1414
'protobuf (>=4.21.0,<5.0.0)',
1515
'torch (>=2.6.0,<3.0.0)',
1616
'gguf @ ./gguf-py',
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
numpy~=1.26.4
22
sentencepiece>=0.1.98,<0.3.0
33

4-
transformers==5.5.1
4+
transformers==4.57.6
55

66
gguf>=0.1.0
77
protobuf>=4.21.0,<5.0.0

requirements/requirements-tool_bench.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
aiohttp~=3.9.3
22
pytest~=8.3.3
3-
huggingface_hub>=1.5.0,<2.0
43
matplotlib~=3.10.0
54
numpy~=1.26.4
65
openai~=2.14.0

src/llama-arch.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
135135
{ LLM_ARCH_MAINCODER, "maincoder" },
136136
{ LLM_ARCH_KIMI_LINEAR, "kimi-linear" },
137137
{ LLM_ARCH_TALKIE, "talkie" },
138+
{ LLM_ARCH_MELLUM, "mellum" },
138139
{ LLM_ARCH_UNKNOWN, "(unknown)" },
139140
};
140141

0 commit comments

Comments
 (0)