
Commit e463bbd

AesSedai and CISC authored
model: Add Kimi-K2.5 support (ggml-org#19170)
* Move dequant_model to after the text_config merge
  Add new kimi-k2.5 keys to mtmd convert
  Update V_MMPROJ tensor mapping for new mm_projector.proj keys
  Update V_MM_INP_NORM for new mm_projector.pre_norm key
* Fix a couple of oversights
* Add image support for Kimi-K2.5
* Revert changes to KimiVLForConditionalGeneration
* Fix an assert crash
* Fix permute accidentally swapping w / h
* Kimi-K2.5: Use merged QKV for vision
* Kimi-K2.5: pre-convert vision QK to use build_rope_2d
* Kimi-K2.5: support non-interleaved rope for vision
* Kimi-K2.5: fix min / max pixel
* Kimi-K2.5: remove v/o permutes, unnecessary
* Kimi-K2.5: update permute name to match
* Update convert_hf_to_gguf.py
  Co-authored-by: Sigbjørn Skjæret <[email protected]>
* Kimi-K2.5: replace build_rope_2d ggml_cont with ggml_view_3d pointers

---------

Co-authored-by: Sigbjørn Skjæret <[email protected]>
1 parent 53de59f commit e463bbd

8 files changed: +316 −13 lines changed


convert_hf_to_gguf.py

Lines changed: 115 additions & 9 deletions
@@ -160,8 +160,6 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
                 self.ftype = gguf.LlamaFileType.MOSTLY_F16
                 logger.info("heuristics unable to detect tensor dtype, defaulting to --outtype f16")
 
-        self.dequant_model()
-
         # Configure GGUF Writer
         self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
                                            split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
@@ -527,6 +525,8 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
         return ()
 
     def prepare_tensors(self):
+        self.dequant_model()
+
         # Handle empty tensor_map for models with block_count=0 (like MobileNetV5)
         if self.tensor_map.mapping:
             max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
@@ -1815,7 +1815,7 @@ class MmprojModel(ModelBase):
     preprocessor_config: dict[str, Any]
     global_config: dict[str, Any]
 
-    n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"]
+    n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers", "vt_num_hidden_layers"]
 
     has_vision_encoder: bool = True  # by default
     has_audio_encoder: bool = False
@@ -1870,7 +1870,15 @@ def __init__(self, *args, **kwargs):
         preprocessor_config_path = self.dir_model / "preprocessor_config.json"
         if preprocessor_config_path.is_file():
             with open(preprocessor_config_path, "r", encoding="utf-8") as f:
-                self.preprocessor_config = json.load(f)
+                cfg = json.load(f)
+                # move media_proc_cfg to root level for compat
+                if "media_proc_cfg" in cfg:
+                    cfg = {
+                        **cfg,
+                        **cfg["media_proc_cfg"],
+                    }
+                # merge configs
+                self.preprocessor_config = {**self.preprocessor_config, **cfg}
 
         # prefer processor_config.json if possible
         processor_config_path = self.dir_model / "processor_config.json"
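For reference, a minimal sketch (plain Python, toy values not taken from the Kimi-K2.5 repo) of what the media_proc_cfg handling above does: keys nested under "media_proc_cfg" are lifted to the top level of the preprocessor config, then merged over whatever defaults were already loaded.

# Toy illustration of the preprocessor_config merge above; the key names
# "in_patch_limit" and "patch_size" appear in this PR, the values are made up.
defaults = {"image_mean": [0.5, 0.5, 0.5], "image_std": [0.5, 0.5, 0.5]}
cfg = {"media_proc_cfg": {"in_patch_limit": 16384, "patch_size": 14}}

if "media_proc_cfg" in cfg:
    cfg = {**cfg, **cfg["media_proc_cfg"]}    # lift nested keys to the root

preprocessor_config = {**defaults, **cfg}     # later keys win on conflict
print(preprocessor_config["in_patch_limit"])  # 16384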
@@ -1919,10 +1927,10 @@ def set_gguf_parameters(self):
             self.image_size = self.find_vparam(["image_size"])
             self.gguf_writer.add_vision_image_size(self.image_size)
             self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"]))
-            self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"]))
-            self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"]))
+            self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size", "vt_hidden_size"]))
+            self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size", "vt_intermediate_size"]))
             self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys))
-            self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads"]))
+            self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads", "vt_num_attention_heads"]))
 
             # preprocessor config
             image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
@@ -7695,6 +7703,7 @@ def prepare_tensors(self):
     "DeepseekV2ForCausalLM",
     "DeepseekV3ForCausalLM",
     "KimiVLForConditionalGeneration",
+    "KimiK25ForConditionalGeneration",
     "YoutuForCausalLM",
     "YoutuVLForConditionalGeneration",
 )
@@ -7813,8 +7822,8 @@ def set_gguf_parameters(self):
     _experts: list[dict[str, Tensor]] | None = None
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        # skip vision tensors and remove "language_model." for Kimi-VL
-        if "vision_tower" in name or "multi_modal_projector" in name:
+        # skip vision tensors and remove "language_model." for Kimi-VL and Kimi-K2.5
+        if "vision_tower" in name or "multi_modal_projector" in name or "mm_projector" in name:
             return
         if name.startswith("siglip2.") or name.startswith("merger."):
             return
@@ -11176,6 +11185,103 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         yield from super().modify_tensors(data_torch, name, bid)
 
 
+@ModelBase.register("KimiK25ForConditionalGeneration")
+class KimiK25Model(MmprojModel):
+    """Kimi-K2.5 with MoonViT3d vision encoder"""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        assert self.hparams_vision is not None, "Kimi-K2.5 requires vision_config in model config"
+
+        self.merge_kernel_size = tuple(self.hparams_vision.get("merge_kernel_size", [2, 2]))
+        self.patch_size = self.hparams_vision.get("patch_size", 14)
+
+        # Set image_size for compatibility with base class
+        # Use position embedding dimensions as image_size reference
+        pos_emb_h = self.hparams_vision.get("init_pos_emb_height", 64)
+        self.hparams_vision["image_size"] = pos_emb_h * self.patch_size
+
+    def set_gguf_parameters(self):
+        # Base class MmprojModel.set_gguf_parameters() already writes:
+        # - vision_block_count, vision_head_count, vision_embedding_length
+        # - vision_feed_forward_length, vision_patch_size, image_mean, image_std
+        # via find_vparam() which handles the vt_* prefixed keys in Kimi-K2.5's config
+        super().set_gguf_parameters()
+        assert self.hparams_vision is not None
+
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIK25)
+
+        # Position embedding parameters (for interpolation)
+        self.gguf_writer.add_uint32("vision.pos_emb_height", self.hparams_vision.get("init_pos_emb_height", 64))
+        self.gguf_writer.add_uint32("vision.pos_emb_width", self.hparams_vision.get("init_pos_emb_width", 64))
+        self.gguf_writer.add_uint32("vision.pos_emb_time", self.hparams_vision.get("init_pos_emb_time", 4))
+
+        # Projector parameters
+        self.gguf_writer.add_vision_use_gelu(self.hparams_vision.get("projector_hidden_act", "gelu") == "gelu")
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("projector_ln_eps", 1e-5))
+        self.gguf_writer.add_vision_projector_scale_factor(self.merge_kernel_size[0])
+
+        # Image size limits
+        # Note: in_patch_limit is for images, in_patch_limit_each_frame is for video (not supported yet)
+        in_patch_limit = self.preprocessor_config.get("in_patch_limit", 16384)
+        min_patches = 8  # reasonable minimum
+        pixels_per_patch = self.patch_size ** 2
+        self.gguf_writer.add_vision_min_pixels(min_patches * pixels_per_patch)
+        self.gguf_writer.add_vision_max_pixels(in_patch_limit * pixels_per_patch)
+
+    @staticmethod
+    def permute(weights: Tensor, n_head: int) -> Tensor:
+        out_dim, in_dim = weights.shape
+        head_dim = out_dim // n_head
+        w = weights.reshape(n_head, head_dim // 4, 2, 2, in_dim)
+        w = w.permute(0, 2, 1, 3, 4)
+        return w.reshape(out_dim, in_dim)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Only process vision and projector tensors
+        is_vision = any(x in name for x in ["vision_tower", "mm_projector"])
+
+        if not is_vision:
+            return
+
+        assert self.hparams_vision is not None
+        n_head = self.hparams_vision.get("num_attention_heads", 16)
+
+        # Permute Q/K weights/biases from interleaved to split RoPE format
+        # This allows using build_rope_2d at runtime without post-permutation.
+        if "wqkv" in name:
+            out_dim = data_torch.shape[0]
+            qkv_dim = out_dim // 3
+            head_dim = qkv_dim // n_head
+
+            if "weight" in name:
+                wq, wk, wv = data_torch[:qkv_dim, :], data_torch[qkv_dim:2 * qkv_dim, :], data_torch[2 * qkv_dim:, :]
+                wq = self.permute(wq, n_head)
+                wk = self.permute(wk, n_head)
+                data_torch = torch.cat([wq, wk, wv], dim=0)
+            elif "bias" in name:
+                bq, bk, bv = data_torch[:qkv_dim], data_torch[qkv_dim:2 * qkv_dim], data_torch[2 * qkv_dim:]
+                bq = bq.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1)
+                bk = bk.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1)
+                data_torch = torch.cat([bq, bk, bv], dim=0)
+
+        # Temporal embeddings: (T, 1, C) → (T, C)
+        if "pos_emb.time_weight" in name:
+            T, _, C = data_torch.shape
+            data_torch = data_torch.reshape(T, C)
+
+        # PatchMergerMLP tensor name mapping
+        # proj.0.weight → proj.linear_1.weight
+        # proj.2.weight → proj.linear_2.weight
+        if "mm_projector.proj.0." in name:
+            name = name.replace(".proj.0.", ".proj.linear_1.")
+        elif "mm_projector.proj.2." in name:
+            name = name.replace(".proj.2.", ".proj.linear_2.")
+
+        yield from super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("CogVLMForCausalLM")
 class CogVLMVisionModel(MmprojModel):
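A small, self-contained sketch (torch, toy sizes, not part of the commit) of what KimiK25Model.permute() does to each head's rows: interleaved groups of four are regrouped so the rows feeding the first half of the head dimension come first, which is the layout build_rope_2d expects at runtime. Index 0..7 below stands in for the rows of one head with head_dim = 8.

import torch

def permute(weights: torch.Tensor, n_head: int) -> torch.Tensor:
    # same reshape/permute as KimiK25Model.permute() in the diff above
    out_dim, in_dim = weights.shape
    head_dim = out_dim // n_head
    w = weights.reshape(n_head, head_dim // 4, 2, 2, in_dim)
    return w.permute(0, 2, 1, 3, 4).reshape(out_dim, in_dim)

# one head, head_dim = 8, a single input column; each row is labelled by its index
w = torch.arange(8, dtype=torch.float32).reshape(8, 1)
print(permute(w, n_head=1).flatten().tolist())
# [0.0, 1.0, 4.0, 5.0, 2.0, 3.0, 6.0, 7.0]
# pairs 0-1 and 4-5 now form the first half of the head dim, pairs 2-3 and 6-7
# the second half, i.e. a split-halves layout instead of an interleaved one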

gguf-py/gguf/constants.py

Lines changed: 1 addition & 0 deletions
@@ -3766,6 +3766,7 @@ class VisionProjectorType:
     VOXTRAL = "voxtral"
     LFM2 = "lfm2"
     KIMIVL = "kimivl"
+    KIMIK25 = "kimik25"
     LIGHTONOCR = "lightonocr"
     COGVLM = "cogvlm"
     JANUS_PRO = "janus_pro"

gguf-py/gguf/tensor_mapping.py

Lines changed: 3 additions & 0 deletions
@@ -1303,6 +1303,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.V_MMPROJ: (
             "multi_modal_projector.linear_{bid}",
+            "mm_projector.proj.linear_{bid}", # Kimi-K2.5
             "visual.merger.mlp.{bid}", # qwen2vl
             "merger.mlp.{bid}",
         ),
@@ -1364,6 +1365,7 @@ class TensorNameMap:
         MODEL_TENSOR.V_ENC_ATTN_QKV: (
             "visual.blocks.{bid}.attn.qkv", # qwen3vl
             "model.vision.transformer.layers.{bid}.attention.query_key_value", # cogvlm
+            "vision_tower.encoder.blocks.{bid}.wqkv" # Kimi-K2.5
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_Q: (
@@ -1538,6 +1540,7 @@ class TensorNameMap:
             "multi_modal_projector.norm",
             "multi_modal_projector.layer_norm",
             "multi_modal_projector.pre_norm",
+            "mm_projector.pre_norm", # Kimi-K2.5
             "pre_mm_projector_norm",
             "model.vision.linear_proj.norm1", # cogvlm
             "merger.ln_q",

tools/mtmd/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ add_library(mtmd
     models/glm4v.cpp
     models/internvl.cpp
     models/kimivl.cpp
+    models/kimik25.cpp
     models/llama4.cpp
     models/llava.cpp
     models/minicpmv.cpp

tools/mtmd/clip-impl.h

Lines changed: 2 additions & 0 deletions
@@ -235,6 +235,7 @@ enum projector_type {
     PROJECTOR_TYPE_LFM2A,
     PROJECTOR_TYPE_GLM4V,
     PROJECTOR_TYPE_YOUTUVL,
+    PROJECTOR_TYPE_KIMIK25,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -268,6 +269,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_LFM2A, "lfm2a"},
     { PROJECTOR_TYPE_GLM4V, "glm4v"},
     { PROJECTOR_TYPE_YOUTUVL, "youtuvl"},
+    { PROJECTOR_TYPE_KIMIK25, "kimik25"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {

tools/mtmd/clip.cpp

Lines changed: 86 additions & 4 deletions
@@ -673,8 +673,8 @@ ggml_tensor * clip_graph::build_rope_2d(
     {
         first = ggml_view_3d(ctx0, cur,
             n_dim/2, n_head, n_pos,
-            ggml_row_size(cur->type, n_dim),
-            ggml_row_size(cur->type, n_dim*n_head),
+            cur->nb[1],
+            cur->nb[2],
             0);
         first = ggml_rope_ext(
             ctx0,
@@ -692,8 +692,8 @@ ggml_tensor * clip_graph::build_rope_2d(
     {
         second = ggml_view_3d(ctx0, cur,
             n_dim/2, n_head, n_pos,
-            ggml_row_size(cur->type, n_dim),
-            ggml_row_size(cur->type, n_dim*n_head),
+            cur->nb[1],
+            cur->nb[2],
             n_dim/2 * ggml_element_size(cur));
         second = ggml_rope_ext(
            ctx0,
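The two hunks above replace recomputed row sizes with the strides stored on cur itself (cur->nb[1], cur->nb[2]). The old expressions assumed cur was a contiguous [n_dim, n_head, n_pos] tensor; taking the parent's strides keeps the two half-views valid even when cur is itself a non-contiguous view (e.g. a slice of the merged QKV tensor), which is what lets the ggml_cont be dropped, per the commit message. A rough torch analogy of the idea (an illustration only, not ggml code):

import torch

base = torch.arange(2 * 4 * 6).reshape(2, 4, 6)   # pretend (n_pos, n_head, n_dim) storage
cur  = base.transpose(0, 1)                        # a non-contiguous view, shape (4, 2, 6)
half = cur[..., :3]                                # view of the first half of the last dim

print(cur.stride())                   # (6, 24, 1)  -> what cur's own strides describe
print(half.stride())                  # (6, 24, 1)  -> the half-view inherits the parent strides
print(torch.empty(4, 2, 6).stride())  # (12, 6, 1)  -> strides recomputed "as if contiguous" would be wrong here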
@@ -826,6 +826,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 builder = std::make_unique<clip_graph_kimivl>(ctx, img);
             } break;
+        case PROJECTOR_TYPE_KIMIK25:
+            {
+                builder = std::make_unique<clip_graph_kimik25>(ctx, img);
+            } break;
         case PROJECTOR_TYPE_COGVLM:
             {
                 builder = std::make_unique<clip_graph_cogvlm>(ctx, img);
@@ -1139,6 +1143,22 @@ struct clip_model_loader {
                     hparams.set_limit_image_tokens(8, 1024);
                     hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
                 } break;
+            case PROJECTOR_TYPE_KIMIK25:
+                {
+                    hparams.rope_theta = 10000.0f;
+                    get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
+
+                    int min_pixels = 0, max_pixels = 0;
+                    get_u32(KEY_IMAGE_MIN_PIXELS, min_pixels, false);
+                    get_u32(KEY_IMAGE_MAX_PIXELS, max_pixels, false);
+                    if (min_pixels > 0 && max_pixels > 0) {
+                        hparams.image_min_pixels = min_pixels;
+                        hparams.image_max_pixels = max_pixels;
+                        hparams.warmup_image_size = static_cast<int>(std::sqrt(max_pixels));
+                    } else {
+                        hparams.set_limit_image_tokens(2, 4096);
+                    }
+                } break;
            case PROJECTOR_TYPE_GEMMA3:
                {
                    // default value (used by all model sizes in gemma 3 family)
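With the defaults written by the converter above (8 patches minimum, in_patch_limit = 16384, patch_size = 14), the limits loaded here work out as follows; a quick check in plain Python, with the values assumed from the converter defaults:

patch_size        = 14
pixels_per_patch  = patch_size ** 2              # 196
min_pixels        = 8 * pixels_per_patch         # 1568
max_pixels        = 16384 * pixels_per_patch     # 3211264
warmup_image_size = int(max_pixels ** 0.5)       # 1792 (sqrt(3211264) is exactly 1792)
print(min_pixels, max_pixels, warmup_image_size)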
@@ -1668,6 +1688,7 @@
                     model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
                 } break;
             case PROJECTOR_TYPE_KIMIVL:
+            case PROJECTOR_TYPE_KIMIK25:
                 {
                     model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);
                     model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B);
@@ -3165,6 +3186,23 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
                 res_imgs->entries.push_back(std::move(res));
             } break;
 
+        case PROJECTOR_TYPE_KIMIK25:
+            {
+                GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0);
+                const clip_image_size target_size = img_tool::calc_size_preserved_ratio(
+                    original_size,
+                    params.patch_size * params.n_merge,
+                    params.image_min_pixels,
+                    params.image_max_pixels);
+                const std::array<uint8_t, 3> pad_color = {0, 0, 0};
+
+                clip_image_u8 resized_img;
+                img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BICUBIC, true, pad_color);
+                clip_image_f32_ptr res(clip_image_f32_init());
+                normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std);
+                res_imgs->entries.push_back(std::move(res));
+            } break;
+
         case PROJECTOR_TYPE_MLP:
         case PROJECTOR_TYPE_MLP_NORM:
         case PROJECTOR_TYPE_LDP:
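The actual target-size computation lives in img_tool::calc_size_preserved_ratio, which is not shown in this diff. A hedged Python sketch of the general idea — scale the image into the pixel budget while preserving aspect ratio, then snap both sides to a multiple of patch_size * n_merge — purely as an illustration, not the clip.cpp algorithm:

import math

def target_size(w: int, h: int, align: int, min_pixels: int, max_pixels: int) -> tuple[int, int]:
    # assumption: scale into [min_pixels, max_pixels], then round to the alignment grid
    scale = 1.0
    if w * h > max_pixels:
        scale = math.sqrt(max_pixels / (w * h))
    elif w * h < min_pixels:
        scale = math.sqrt(min_pixels / (w * h))
    snap = lambda v: max(align, round(v * scale / align) * align)
    return snap(w), snap(h)

# e.g. a 4000x3000 photo with patch_size=14, n_merge=2 and the default limits
print(target_size(4000, 3000, 14 * 2, 1568, 3211264))  # roughly (2072, 1540)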
@@ -3373,6 +3411,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
             } break;
         case PROJECTOR_TYPE_LFM2:
         case PROJECTOR_TYPE_KIMIVL:
+        case PROJECTOR_TYPE_KIMIK25:
             {
                 // dynamic size
                 int out_patch_size = params.patch_size * ctx->model.hparams.n_merge;
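For the dynamic-size projectors this case joins, the number of output tokens scales with how many merged patches fit in the resized image. A back-of-the-envelope check, assuming tokens = ceil(w / out_patch_size) * ceil(h / out_patch_size) (the rest of the computation is not visible in this hunk):

import math

patch_size, n_merge = 14, 2
out_patch_size = patch_size * n_merge            # 28, as computed in the hunk above

def n_output_tokens(w: int, h: int) -> int:
    # assumption about the formula, based on how out_patch_size is used for other dynamic-size projectors
    return math.ceil(w / out_patch_size) * math.ceil(h / out_patch_size)

print(n_output_tokens(1792, 1792))  # 4096 = 16384 patches / (2*2) merge kernel
print(n_output_tokens(2072, 1540))  # 74 * 55 = 4070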
@@ -3714,6 +3753,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             } break;
         case PROJECTOR_TYPE_PIXTRAL:
         case PROJECTOR_TYPE_KIMIVL:
+        case PROJECTOR_TYPE_KIMIK25:
         case PROJECTOR_TYPE_LIGHTONOCR:
             {
                 // set the 2D positions
@@ -3850,6 +3890,47 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
     }
 
+    // Debug: dump final embeddings if MTMD_DEBUG_EMBEDDINGS is set
+    if (std::getenv("MTMD_DEBUG_EMBEDDINGS") != nullptr) {
+        const int64_t n_embd = embeddings->ne[0];
+        const int64_t n_tokens = embeddings->ne[1];
+        std::vector<float> emb_data(n_embd * n_tokens);
+        ggml_backend_tensor_get(embeddings, emb_data.data(), 0, ggml_nbytes(embeddings));
+
+        LOG_INF("\n=== MTMD_DEBUG_EMBEDDINGS ===\n");
+        LOG_INF("Shape: [%lld, %lld]\n", (long long)n_embd, (long long)n_tokens);
+
+        // Print first few values of first token
+        LOG_INF("Token 0 (first 16 values): ");
+        for (int i = 0; i < std::min((int64_t)16, n_embd); i++) {
+            LOG_INF("%.6f ", emb_data[i]);
+        }
+        LOG_INF("\n");
+
+        // Print last few values of first token
+        if (n_embd > 16) {
+            LOG_INF("Token 0 (last 16 values): ");
+            for (int64_t i = n_embd - 16; i < n_embd; i++) {
+                LOG_INF("%.6f ", emb_data[i]);
+            }
+            LOG_INF("\n");
+        }
+
+        // Compute and print statistics
+        float sum = 0.0f, sum_sq = 0.0f, min_val = emb_data[0], max_val = emb_data[0];
+        for (size_t i = 0; i < emb_data.size(); i++) {
+            sum += emb_data[i];
+            sum_sq += emb_data[i] * emb_data[i];
+            min_val = std::min(min_val, emb_data[i]);
+            max_val = std::max(max_val, emb_data[i]);
+        }
+        float mean = sum / emb_data.size();
+        float variance = (sum_sq / emb_data.size()) - (mean * mean);
+        LOG_INF("Stats: mean=%.6f, std=%.6f, min=%.6f, max=%.6f, sum=%.6f\n",
+                mean, sqrtf(variance), min_val, max_val, sum);
+        LOG_INF("=== END MTMD_DEBUG_EMBEDDINGS ===\n\n");
+    }
+
     return true;
 }
 
@@ -3896,6 +3977,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->model.mm_2_w->ne[1];
         case PROJECTOR_TYPE_LFM2:
         case PROJECTOR_TYPE_KIMIVL:
+        case PROJECTOR_TYPE_KIMIK25:
             return ctx->model.mm_2_w->ne[1];
         case PROJECTOR_TYPE_COGVLM:
             return ctx->model.mm_4h_to_h_w->ne[1];
