Conversation

@ikawrakow (Owner) commented Dec 27, 2025

This PR adds support for Mimo-V2-Flash (https://huggingface.co/XiaomiMiMo/MiMo-V2-Flash), and closes #1076.

Unlike mainline PR 18328, which does not support flash attention (FA), this PR does support FA.

Split mode "graph" is not supported for now. It turns out my splitting logic for the attention tensors only works when the K and V attention head sizes are the same, which is not the case for Mimo-V2, so that will have to be a follow-up PR. Also, I did not add support for HF->GGUF conversion, so mainline will need to be used for that.
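Since HF->GGUF conversion is not included here, a minimal sketch of the mainline conversion step, assuming a mainline llama.cpp checkout with MiMo-V2-Flash support and a local copy of the HF repo (paths and output type are illustrative, not taken from this PR):

```bash
# Run from a mainline llama.cpp checkout; paths are placeholders.
python convert_hf_to_gguf.py /path/to/MiMo-V2-Flash \
    --outtype bf16 \
    --outfile MiMo-V2-Flash-BF16.gguf
# The resulting GGUF can then be quantized and loaded with ik_llama.cpp.
```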

Another limitation of this PR is that a quantized KV cache cannot be used on CUDA (we get NaNs). It works fine on the CPU, so I will need to investigate why the quantized KV cache fails on CUDA. Fixed with the latest commit.
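With that fix, a quantized KV cache together with FA should now work on CUDA as well. A minimal sketch of the relevant flags; the model path, cache types, and context length are chosen only for illustration:

```bash
# Illustrative: -fa enables flash attention, -ctk/-ctv set the K and V cache quantization types.
./bin/llama-server -m MiMo-V2-Flash-IQ2_XXS.gguf \
    -fa -ctk q8_0 -ctv q8_0 -c 32768 -ngl 100
```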

The other caveat is that the large saving in KV cache size that would be possible thanks to the aggressive SWA used by Mimo-V2 is not realized, so here mainline has an advantage.

On the other hand, because FA is supported here but not in mainline, I was able to go to a much larger context than with mainline. I downloaded the IQ2_XXS quantization from Bartowski; I picked that one so that I could use full GPU offload on the 4x3090 system. With mainline the best I could do before OOM was a context of 8192 with a u-batch size of 1024. With ik_llama.cpp I can go up to a context of 32k tokens with a u-batch size of 2048. Correspondingly, performance here is quite a bit better than over there (see the sweep-bench results below).

CPU-only performance is quite decent: I get 115 t/s for PP-2048 and 21.8 t/s for TG-128 on a Ryzen-3995WX CPU.
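The tables below are llama-sweep-bench results. For reference, a minimal sketch of the kind of invocation used on the ik_llama.cpp side, assuming full GPU offload; the exact batch and offload settings are assumptions, not taken from this PR:

```bash
# Illustrative sweep-bench run: 32k context, u-batch of 2048, FA enabled, full offload.
./bin/llama-sweep-bench -m MiMo-V2-Flash-IQ2_XXS.gguf \
    -c 32768 -ub 2048 -b 2048 -fa -ngl 100
```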

ik_llama.cpp, Mimo-V2-Flash, IQ2_XXS, 4x3090

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|---:|---:|-----:|-------:|---------:|-------:|---------:|
| 2048 | 128 | 0 | 2.722 | 752.41 | 1.427 | 89.68 |
| 2048 | 128 | 2048 | 2.662 | 769.44 | 1.443 | 88.70 |
| 2048 | 128 | 4096 | 2.698 | 759.18 | 1.459 | 87.75 |
| 2048 | 128 | 6144 | 2.720 | 752.83 | 1.471 | 87.04 |
| 2048 | 128 | 8192 | 2.740 | 747.45 | 1.482 | 86.36 |
| 2048 | 128 | 10240 | 2.771 | 739.05 | 1.498 | 85.45 |
| 2048 | 128 | 12288 | 2.784 | 735.62 | 1.522 | 84.11 |
| 2048 | 128 | 14336 | 2.814 | 727.84 | 1.533 | 83.52 |
| 2048 | 128 | 16384 | 2.828 | 724.26 | 1.534 | 83.44 |
| 2048 | 128 | 18432 | 2.857 | 716.92 | 1.545 | 82.87 |
| 2048 | 128 | 20480 | 2.881 | 710.90 | 1.551 | 82.50 |
| 2048 | 128 | 22528 | 2.891 | 708.32 | 1.562 | 81.95 |
| 2048 | 128 | 24576 | 2.916 | 702.26 | 1.586 | 80.72 |
| 2048 | 128 | 26624 | 2.948 | 694.81 | 1.592 | 80.38 |
| 2048 | 128 | 28672 | 2.970 | 689.65 | 1.608 | 79.59 |
| 2048 | 128 | 30720 | 2.996 | 683.61 | 1.610 | 79.51 |

llama.cpp, Mimo-V2-Flash, IQ2_XXS, 4x3090

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|---:|---:|-----:|-------:|---------:|-------:|---------:|
| 1024 | 128 | 0 | 1.736 | 589.87 | 1.613 | 79.35 |
| 1024 | 128 | 1024 | 1.611 | 635.74 | 1.856 | 68.96 |
| 1024 | 128 | 2048 | 1.640 | 624.32 | 1.884 | 67.94 |
| 1024 | 128 | 3072 | 1.651 | 620.09 | 1.922 | 66.59 |
| 1024 | 128 | 4096 | 1.687 | 607.12 | 1.961 | 65.26 |
| 1024 | 128 | 5120 | 1.710 | 598.66 | 1.997 | 64.11 |
| 1024 | 128 | 6144 | 1.732 | 591.33 | 2.037 | 62.83 |
| 1024 | 128 | 7168 | 1.760 | 581.83 | 2.074 | 61.72 |

@Ph0rk0z commented Dec 27, 2025

What do you get on hybrid? I wanted to grab Q3/Q4 and was hoping the small number of active params would let it reason at reasonable speeds. The quant is gonna take the rest of the weekend to download :(

Iwan Kawrakow added 2 commits December 27, 2025 17:16
It still does not solve the Mimo-2 quantized cache issue.
@Nexesenex (Contributor) commented:
If I use -khad for the KV cache:

llama-server -m XiaomiMiMo_MiMo-V2-Flash-IQ4_XS-00001-of-00005.gguf -t 18 -ngl 150 -b 128 -mg 1 -ts 18,18,12 -fa 1 -cuda fusion=1,offload-batch-size=128,mmq-id-size=128 -no-ooae -ot "^blk.([6-9]|1[0-6]|2[0-9]|3[0-0]|3[7-9]|4[0-3])\.ffn_(up|down|gate)_exps\.weight$=CPU" -ot "^blk.(6|30)\.ffn_(down)_exps\.weight$=CPU" -ot "^blk.(5|31)\.ffn_(up|gate)_exps\.weight$=CPU" -mqkv -gr -ger --chat-template chatglm4 --override-kv tokenizer.ggml.eot_token_id=int:151336 --override-kv glm4moe.expert_used_count=int:8 -ser 7,0.3 -c 49152 -ctk q6_0 -ctv q5_0 --context-shift 1 -khad --host 127.0.0.1 --port 8080

Then I get that:

llama_new_context_with_model: KV self size  = 3980.25 MiB, K (q6_0): 2544.75 MiB, V (q5_0): 1435.50 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0.58 MiB
ggml\src\ggml.c:6217: GGML_ASSERT(popcount(n) == 1) failed

Otherwise, it works.

@MrHills-2 commented Dec 27, 2025

OK, so, my system:

AMD 7800X3D (8 cores)
128 GB (2x64 GB) DDR5-6000 RAM, dual channel, about 80 GB/s
RTX 5090 (PCIe 5)

Command:

build/bin/llama-server -m models/MiMo-V2-Flash-IQ3_XS.gguf -ot "blk.(?:[0-9]|[1-3][0-9]|[4][0]).ffn.*=CPU" -c 32768 -b 8192 -ub 8192 -ctk q8_0 -ctv q8_0 --threads 7 -ngl 95 -sp -amb 512 --host 0.0.0.0 --port 8080 --webui none --repeat-last-n 2048 -mqkv --jinja

Performance:
PP @ 29000: 900 t/s
TG @ 29000: 11 t/s

This is slower than MiniMax M2.1, which with the same settings gives me about 15 t/s. Is MTP working?

Also, the model doesn't think, which is a problem because without thinking this model is kind of dumb. In SillyTavern I have the thinking settings for chat completion set to maximum, but it doesn't seem to work.

Edit: OK the model clearly has coherence problems. It's overall quite nonsensical, no matter the context size.

Edit2: Apparently the first layer is dense, so my -ot becomes -ot "blk.(?:[1-9]|[1-3][0-9]|[4][0]).ffn.*_exps.*=CPU".
With this the new performance is:
PP @ 29000: 1050 t/s
TG @ 29000: 16 t/s
I also noticed that the MTP layers are not even present in the GGUF, so I take it MTP is not functional.
That said, Xiaomi promises a 3x performance gain with MTP, which would mean 48 t/s, and that with almost the entire FFN offloaded to ~80 GB/s RAM. Doesn't that mean you could run this model from a PCIe 5 SSD at that point? If possible, that would be massive.
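A hedged sketch of what the corrected invocation from Edit2 might look like, keeping the dense first layer on the GPU and offloading only the expert FFN tensors of the remaining layers to the CPU. All flags other than the -ot pattern are copied from the command above, and the asterisks in the pattern are reconstructed, since the original comment's formatting appears to have swallowed them:

```bash
# Only the -ot pattern differs from the original command above.
build/bin/llama-server -m models/MiMo-V2-Flash-IQ3_XS.gguf \
    -ot "blk.(?:[1-9]|[1-3][0-9]|[4][0]).ffn.*_exps.*=CPU" \
    -c 32768 -b 8192 -ub 8192 -ctk q8_0 -ctv q8_0 --threads 7 -ngl 95 \
    -sp -amb 512 --host 0.0.0.0 --port 8080 --webui none \
    --repeat-last-n 2048 -mqkv --jinja
```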

@ikawrakow (Owner, Author) commented:
@MrHills-2

Neither ik_llama.cpp nor llama.cpp supports MTP at this point. ubergarm tends to leave the MTP layers in when preparing his models for ik_llama.cpp (so that, in case MTP gets added, people don't need to re-download the model), but apparently mainline quant cooks don't do that.

> That said, Xiaomi promises a 3x performance gain with MTP,

Haven't you learned yet that in the age of LLMs everybody shamelessly and massively exaggerates the utility of the thing they have done?

@ikawrakow (Owner, Author) commented:
> Edit: OK the model clearly has coherence problems. It's overall quite nonsensical, no matter the context size.

It looks like there is still an issue with SWA. I'm looking into it.

@ikawrakow ikawrakow marked this pull request as draft December 28, 2025 07:49
@ikawrakow (Owner, Author) commented:
Something is not quite right, so I converted this to a draft.

@ikawrakow (Owner, Author) commented:
OK, PPL is the same as mainline (actually it is slightly lower). I checked a few context lengths, and it is fine. If I had a bug in the SWA attention mask preparation, one would see it in PPL.
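For anyone wanting to reproduce the PPL comparison, a minimal sketch using llama-perplexity; the test file, context length, and offload settings are illustrative, not the exact ones used here:

```bash
# Illustrative: run the same command against a mainline build and compare the final PPL values.
./bin/llama-perplexity -m MiMo-V2-Flash-IQ2_XXS.gguf \
    -f wiki.test.raw -c 2048 -fa -ngl 100
```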

The issue is that, after generating for a while, the model starts endlessly repeating the same thing. I thought there was an issue with my implementation, but I have now observed the exact same behavior in mainline as well. The probability of endless repetition appears to be very sensitive to the temperature.

So, my best guess at this point is that my implementation is fine, and that the IQ2_XXS model I downloaded from Bartowski for testing is simply too low quality, causing the repetitions observed during TG.

So, I'll remove the draft status. Would appreciate test reports from more users.

@ikawrakow ikawrakow marked this pull request as ready for review December 28, 2025 08:49
@MrHills-2 commented:
I think there's something off again. The model is definitely better, but it's still a little incoherent, and it doesn't follow simple prompts consistently, like keeping the answer under 100 words (it's super yappy).

The weirdest thing is that I'm getting 16 t/s at 29000 tokens of context, but only 12 t/s at 12000 tokens of context.

Also, we need a way to turn thinking on and off.
