Skip to content

Commit 31e8249

Browse files
ngxsonandrewmd5
andauthored
mtmd: support "frame merge" for qwen-vl-based models (#21858)
* feat: add video support for Qwen3.5 * various clean up * revise the design * fix llava-uhd case * nits * nits 2 --------- Co-authored-by: andrewmd5 <1297077+andrewmd5@users.noreply.github.com>
1 parent 6b80c74 commit 31e8249

10 files changed

Lines changed: 197 additions & 74 deletions

File tree

tools/mtmd/clip-graph.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ struct clip_graph {
3737
float kq_scale; // TODO: maybe move this to hparams
3838
const clip_flash_attn_type flash_attn_type;
3939

40+
// TODO [QWEN_VIDEO]: improve this in the future
41+
int n_batch = 1;
42+
4043
ggml_context_ptr ctx0_ptr;
4144
ggml_context * ctx0;
4245
ggml_cgraph * gf;

tools/mtmd/clip-impl.h

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -480,10 +480,6 @@ struct clip_image_u8 {
480480
buf[idx + 2] = rgb[2];
481481
}
482482

483-
size_t n_pixels() const {
484-
return (size_t) nx * (size_t) ny;
485-
}
486-
487483
size_t n_elements() const {
488484
return n_pixels() * 3;
489485
}
@@ -492,10 +488,16 @@ struct clip_image_u8 {
492488
std::vector<uint8_t> buf;
493489
int nx = 0;
494490
int ny = 0;
491+
492+
size_t n_pixels() const {
493+
return (size_t) nx * (size_t) ny;
494+
}
495495
};
496496

497497
// For images, buf.size() == nx*ny*3
498498
// Memory layout: RGBRGBRGB...
499+
// For seq, buf.size() == nx*ny*3*nt
500+
// Memory layout: RGBRGB...RGBRGB... (nt times)
499501
// For audio, only one channel is used, buf.size() == nx*ny
500502
// nx will be n_frames and ny will be n_mel
501503
struct clip_image_f32 {
@@ -544,10 +546,6 @@ struct clip_image_f32 {
544546
}
545547
}
546548

547-
size_t n_pixels() const {
548-
return (size_t) nx_ * (size_t) ny_;
549-
}
550-
551549
size_t n_elements() const {
552550
return n_pixels() * 3;
553551
}
@@ -580,6 +578,10 @@ struct clip_image_f32 {
580578
std::vector<float> buf;
581579
int nx_ = 0;
582580
int ny_ = 0;
581+
582+
size_t n_pixels() const {
583+
return (size_t) nx_ * (size_t) ny_;
584+
}
583585
};
584586

585587
//
@@ -627,6 +629,7 @@ static void clip_log_internal(enum ggml_log_level level, const char * format, ..
627629
va_end(args);
628630
}
629631

632+
#define LOG_TRC(...) clip_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
630633
#define LOG_DBG(...) clip_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
631634
#define LOG_INF(...) clip_log_internal(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
632635
#define LOG_WRN(...) clip_log_internal(GGML_LOG_LEVEL_WARN, __VA_ARGS__)

tools/mtmd/clip.cpp

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ ggml_tensor * clip_graph::build_inp() {
527527
}
528528

529529
ggml_tensor * clip_graph::build_inp_raw(int channels) {
530-
ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx(), img.ny(), channels);
530+
ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, img.nx(), img.ny(), channels, n_batch);
531531
ggml_set_name(inp_raw, "inp_raw");
532532
ggml_set_input(inp_raw);
533533
return inp_raw;
@@ -848,8 +848,6 @@ ggml_tensor * clip_graph::build_patch_merge_permute(ggml_tensor * cur, int scale
848848
}
849849

850850
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
851-
GGML_ASSERT(imgs.entries.size() == 1 && "n_batch > 1 is not supported");
852-
853851
const clip_image_f32 & img = *imgs.entries[0];
854852
std::unique_ptr<clip_graph> builder;
855853

@@ -1009,6 +1007,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
10091007
GGML_ABORT("missing cgraph builder");
10101008
}
10111009

1010+
// TODO [QWEN_VIDEO]: improve this in the future
1011+
builder->n_batch = imgs.entries.size();
1012+
10121013
return builder->build();
10131014
}
10141015

@@ -3479,12 +3480,15 @@ bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f3
34793480

34803481
bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs_c_ptr, float * vec) {
34813482
const clip_image_f32_batch & imgs = *imgs_c_ptr;
3482-
int batch_size = imgs.entries.size();
3483+
int n_batch_cur = imgs.entries.size();
3484+
3485+
// maximum supported batch size, usually == 2 for qwen-vl-based models
3486+
int n_batch_max = clip_model_n_batch_max(ctx);
34833487

34843488
// TODO @ngxson : implement batch size > 1 as a loop
34853489
// we don't need true batching support because the cgraph will gonna be big anyway
3486-
if (batch_size != 1) {
3487-
return false; // only support batch size of 1
3490+
if (n_batch_cur > n_batch_max) {
3491+
return false;
34883492
}
34893493

34903494
// if buffers are not allocated, we need to do a warmup run to allocate them
@@ -3555,18 +3559,20 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
35553559
// └─────┘ │
35563560
// ──────┘ x B
35573561

3558-
for (size_t i = 0; i < imgs.entries.size(); i++) {
3559-
const int nx = imgs.entries[i]->nx();
3560-
const int ny = imgs.entries[i]->ny();
3561-
const int n = nx * ny;
3562+
// IMPORTANT: [QWEN_VIDEO] the batch dim is currently used for temporal dim in Qwen-VL models
3563+
// All entries must have the same spatial size (enforced by can_batch_with() during merging)
3564+
{
3565+
const int nx = imgs.entries[0]->nx();
3566+
const int ny = imgs.entries[0]->ny();
3567+
const int n = nx * ny;
35623568

3563-
for (int b = 0; b < batch_size; b++) {
3569+
for (int b = 0; b < n_batch_cur; b++) {
35643570
const auto & buf = imgs.entries[b]->get_ro_buf();
35653571
float * batch_entry = inp_raw.data() + b * (3*n);
35663572
for (int y = 0; y < ny; y++) {
35673573
for (int x = 0; x < nx; x++) {
3568-
size_t base_src = 3*(y * nx + x); // idx of the first channel
3569-
size_t base_dst = y * nx + x; // idx of the first channel
3574+
size_t base_src = 3*(y * nx + x);
3575+
size_t base_dst = y * nx + x;
35703576
batch_entry[ base_dst] = buf[base_src ];
35713577
batch_entry[1*n + base_dst] = buf[base_src + 1];
35723578
batch_entry[2*n + base_dst] = buf[base_src + 2];
@@ -4549,6 +4555,17 @@ bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
45494555
return ctx->model.modality == CLIP_MODALITY_AUDIO;
45504556
}
45514557

4558+
int clip_model_n_batch_max(const struct clip_ctx * ctx) {
4559+
switch (ctx->proj_type()) {
4560+
case PROJECTOR_TYPE_QWEN2VL:
4561+
case PROJECTOR_TYPE_QWEN25VL:
4562+
case PROJECTOR_TYPE_QWEN3VL:
4563+
return 2;
4564+
default:
4565+
return 1;
4566+
}
4567+
}
4568+
45524569
//
45534570
// API used internally with mtmd
45544571
//

tools/mtmd/clip.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ struct clip_image_size {
2020
bool operator==(const clip_image_size & other) const {
2121
return width == other.width && height == other.height;
2222
}
23+
bool operator!=(const clip_image_size & other) const {
24+
return !(*this == other);
25+
}
26+
int area() const {
27+
return width * height;
28+
}
2329
};
2430

2531
struct clip_image_f32;
@@ -101,6 +107,8 @@ bool clip_is_llava(const struct clip_ctx * ctx);
101107
bool clip_has_vision_encoder(const struct clip_ctx * ctx);
102108
bool clip_has_audio_encoder(const struct clip_ctx * ctx);
103109

110+
int clip_model_n_batch_max(const struct clip_ctx * ctx);
111+
104112
std::map<ggml_backend_dev_t, size_t> clip_get_mem_usage(const struct clip_ctx * ctx);
105113

106114
struct clip_cap {

tools/mtmd/models/models.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,11 @@ struct clip_graph_pixtral : clip_graph {
3131
struct clip_graph_qwen2vl : clip_graph {
3232
clip_graph_qwen2vl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
3333
ggml_cgraph * build() override;
34+
ggml_tensor * build_inp_with_temporal_merge();
3435
};
3536

36-
struct clip_graph_qwen3vl : clip_graph {
37-
clip_graph_qwen3vl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
37+
struct clip_graph_qwen3vl : clip_graph_qwen2vl {
38+
clip_graph_qwen3vl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph_qwen2vl(ctx, img) {}
3839
ggml_cgraph * build() override;
3940
};
4041

tools/mtmd/models/qwen2vl.cpp

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,34 @@
11
#include "models.h"
22

3+
ggml_tensor * clip_graph_qwen2vl::build_inp_with_temporal_merge() {
4+
ggml_tensor * inp_raw = build_inp_raw();
5+
6+
GGML_ASSERT(img.nx() % (patch_size * 2) == 0);
7+
GGML_ASSERT(img.ny() % (patch_size * 2) == 0);
8+
9+
const size_t nb1 = ggml_row_size(inp_raw->type, img.nx());
10+
const size_t nb2 = ggml_row_size(inp_raw->type, img.nx() * img.ny());
11+
12+
if (n_batch == 1) {
13+
// still image input
14+
return ggml_add(ctx0,
15+
ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1),
16+
ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1));
17+
} else if (n_batch == 2) {
18+
// 2 frames input (video input)
19+
ggml_tensor * inp_0 = ggml_view_3d(ctx0, inp_raw,
20+
img.nx(), img.ny(), 3, nb1, nb2, 0);
21+
ggml_tensor * inp_1 = ggml_view_3d(ctx0, inp_raw,
22+
img.nx(), img.ny(), 3, nb1, nb2,
23+
nb2 * 3); // move to the second frame
24+
return ggml_add(ctx0,
25+
ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_0, patch_size, patch_size, 0, 0, 1, 1),
26+
ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_1, patch_size, patch_size, 0, 0, 1, 1));
27+
} else {
28+
GGML_ASSERT(false && "n_batch > 2 is not supported");
29+
}
30+
}
31+
332
ggml_cgraph * clip_graph_qwen2vl::build() {
433
GGML_ASSERT(model.patch_bias == nullptr);
534
GGML_ASSERT(model.class_embedding == nullptr);
@@ -16,17 +45,10 @@ ggml_cgraph * clip_graph_qwen2vl::build() {
1645

1746
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
1847

19-
ggml_tensor * inp_raw = build_inp_raw();
20-
ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
21-
22-
GGML_ASSERT(img.nx() % (patch_size * 2) == 0);
23-
GGML_ASSERT(img.ny() % (patch_size * 2) == 0);
48+
ggml_tensor * inp = build_inp_with_temporal_merge();
2449

2550
// second conv dimension
2651
{
27-
auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
28-
inp = ggml_add(ctx0, inp, inp_1);
29-
3052
inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b]
3153
inp = ggml_cont_4d(
3254
ctx0, inp,

tools/mtmd/models/qwen3vl.cpp

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,10 @@ ggml_cgraph * clip_graph_qwen3vl::build() {
1313

1414
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
1515

16-
ggml_tensor * inp_raw = build_inp_raw();
17-
ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
16+
ggml_tensor * inp = build_inp_with_temporal_merge();
1817

19-
GGML_ASSERT(img.nx() % (patch_size * 2) == 0);
20-
GGML_ASSERT(img.ny() % (patch_size * 2) == 0);
21-
22-
// second conv dimension
18+
// spatial merge
2319
{
24-
auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
25-
inp = ggml_add(ctx0, inp, inp_1);
26-
2720
inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b]
2821
inp = ggml_cont_4d(
2922
ctx0, inp,

tools/mtmd/mtmd-image.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1116,7 +1116,7 @@ bool mtmd_image_preprocessor_deepseekocr::preprocess(const clip_image_u8 & img,
11161116
static constexpr int native_resolutions[] = { 1024 /* base */, 1280 /* large */ };
11171117
// TODO: support 512 (tiny) and 640 (small) once we have eval data for them
11181118

1119-
const int64_t orig_area = static_cast<int64_t>(img.n_pixels());
1119+
const int64_t orig_area = static_cast<int64_t>(img.get_size().area());
11201120

11211121
size_t mode_i = 0;
11221122
int64_t min_diff = std::numeric_limits<int64_t>::max();

0 commit comments

Comments
 (0)