Skip to content

Commit 7d2b45b

Browse files
mtp: support for gemma-4 E2B and E4B assistants (#24282)
* models: update converter to support smaller assistants
* models: add masked_embd tensors to gemma4-assist arch
* gemma-4: remove temp debug for conversion
* gemma-4-mtp: filter out masked_embedding tensors during conversion
1 parent: 42a0afd · commit: 7d2b45b

6 files changed

Lines changed: 34 additions & 0 deletions

File tree

conversion/gemma.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -789,6 +789,16 @@ def set_gguf_parameters(self):
789789
class Gemma4AssistantModel(Gemma4Model):
790790
model_arch = gguf.MODEL_ARCH.GEMMA4_ASSISTANT
791791

792+
@classmethod
793+
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
794+
name, gen = item
795+
796+
if "masked_embedding" in name:
797+
logger.debug(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.")
798+
return None
799+
800+
return super().filter_tensors(item)
801+
792802
def set_gguf_parameters(self):
793803
super().set_gguf_parameters()
794804
self.gguf_writer.add_embedding_length_out(self.hparams["backbone_hidden_size"])

gguf-py/gguf/constants.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,8 @@ class VISION_PROJECTOR_TYPE(IntEnum):
538538
class MODEL_TENSOR(IntEnum):
539539
TOKEN_EMBD = auto()
540540
TOKEN_EMBD_NORM = auto()
541+
MASKED_EMBD_CENTROIDS= auto()
542+
MASKED_EMBD_ORDERING = auto()
541543
TOKEN_TYPES = auto()
542544
POS_EMBD = auto()
543545
OUTPUT = auto()
@@ -1087,6 +1089,8 @@ class MODEL_TENSOR(IntEnum):
10871089
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
10881090
MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm",
10891091
MODEL_TENSOR.TOKEN_TYPES: "token_types",
1092+
MODEL_TENSOR.MASKED_EMBD_CENTROIDS: "masked_embd_centroids",
1093+
MODEL_TENSOR.MASKED_EMBD_ORDERING: "masked_embd_ordering",
10901094
MODEL_TENSOR.POS_EMBD: "position_embd",
10911095
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
10921096
MODEL_TENSOR.OUTPUT: "output",
@@ -2586,6 +2590,8 @@ class MODEL_TENSOR(IntEnum):
25862590
MODEL_ARCH.GEMMA4_ASSISTANT: [
25872591
MODEL_TENSOR.ROPE_FREQS,
25882592
MODEL_TENSOR.TOKEN_EMBD,
2593+
MODEL_TENSOR.MASKED_EMBD_CENTROIDS,
2594+
MODEL_TENSOR.MASKED_EMBD_ORDERING,
25892595
MODEL_TENSOR.OUTPUT_NORM,
25902596
MODEL_TENSOR.NEXTN_PROJ_PRE,
25912597
MODEL_TENSOR.NEXTN_PROJ_POST,

gguf-py/gguf/tensor_mapping.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ class TensorNameMap:
3737
"model.embed", # talkie
3838
),
3939

40+
# Masked embeddings
41+
MODEL_TENSOR.MASKED_EMBD_CENTROIDS: (
42+
"masked_embedding.centroids", # gemma-4 E2B/E4B assistants
43+
),
44+
MODEL_TENSOR.MASKED_EMBD_ORDERING: (
45+
"masked_embedding.token_ordering", # gemma-4 E2B/E4B assistants
46+
),
47+
4048
# Token type embeddings
4149
MODEL_TENSOR.TOKEN_TYPES: (
4250
"embeddings.token_type_embeddings", # bert nomic-bert

src/llama-arch.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,8 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
559559
{ LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" },
560560
{ LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" },
561561
{ LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" },
562+
{ LLM_TENSOR_MASKED_EMBD_CENTROIDS, "masked_embd_centroids" },
563+
{ LLM_TENSOR_MASKED_EMBD_ORDERING, "masked_embd_ordering" },
562564
};
563565

564566
// declare information about the model weight tensors:
@@ -783,6 +785,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
783785
// latent projections feed ggml_mul_mat, the buft probe must use MUL_MAT to keep them on GPU
784786
{LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
785787
{LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
788+
{LLM_TENSOR_MASKED_EMBD_CENTROIDS, {LLM_TENSOR_LAYER_INPUT, GGML_OP_NONE}},
789+
{LLM_TENSOR_MASKED_EMBD_ORDERING, {LLM_TENSOR_LAYER_INPUT, GGML_OP_NONE}},
786790
};
787791

788792
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}

src/llama-arch.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,8 +566,11 @@ enum llm_tensor {
566566
LLM_TENSOR_NEXTN_HNORM,
567567
LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
568568
LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
569+
LLM_TENSOR_MASKED_EMBD_CENTROIDS,
570+
LLM_TENSOR_MASKED_EMBD_ORDERING,
569571
};
570572

573+
571574
enum llm_tensor_layer {
572575
LLM_TENSOR_LAYER_INPUT,
573576
LLM_TENSOR_LAYER_REPEATING,

src/models/gemma4-assistant.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ void llama_model_gemma4_assistant::load_arch_tensors(llama_model_loader &) {
3939

4040
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
4141

42+
create_tensor(tn(LLM_TENSOR_MASKED_EMBD_CENTROIDS, "weight"), {}, TENSOR_NOT_REQUIRED);
43+
create_tensor(tn(LLM_TENSOR_MASKED_EMBD_ORDERING), {}, TENSOR_NOT_REQUIRED);
44+
4245
const int64_t n_embd_backbone = hparams.n_embd_inp();
4346
nextn_proj_post = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_POST, "weight"), { n_embd, n_embd_backbone }, 0);
4447

0 commit comments

Comments
 (0)