@@ -2343,7 +2343,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
23432343 static_assert (MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
23442344 if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
23452345 if (ggml_is_quantized (src0->type )) {
2346- if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) {
2346+ const int mmvq_mmid_max = get_mmvq_mmid_max_batch (src0->type , cc);
2347+ if (ne2 <= mmvq_mmid_max) {
23472348 ggml_cuda_mul_mat_vec_q (ctx, src0, src1, ids, dst);
23482349 return ;
23492350 }
@@ -2946,14 +2947,18 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
29462947 }
29472948
29482949 // [TAG_MUL_MAT_ID_CUDA_GRAPHS]
2949- if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized (node->src [0 ]->type ) || node->ne [2 ] > MMVQ_MMID_MAX_BATCH_SIZE)) {
2950- // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
2951- // TODO: figure out a way to enable for larger batch sizes, without hurting performance
2952- // ref: https://github.com/ggml-org/llama.cpp/pull/18958
2953- use_cuda_graph = false ;
2950+ if (node->op == GGML_OP_MUL_MAT_ID) {
2951+ const int cc = ggml_cuda_info ().devices [ggml_cuda_get_device ()].cc ;
2952+ const int mmvq_mmid_max = get_mmvq_mmid_max_batch (node->src [0 ]->type , cc);
2953+ if (!ggml_is_quantized (node->src [0 ]->type ) || node->ne [2 ] > mmvq_mmid_max) {
2954+ // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
2955+ // TODO: figure out a way to enable for larger batch sizes, without hurting performance
2956+ // ref: https://github.com/ggml-org/llama.cpp/pull/18958
2957+ use_cuda_graph = false ;
29542958#ifndef NDEBUG
2955- GGML_LOG_DEBUG (" %s: disabling CUDA graphs due to unsupported node type\n " , __func__);
2959+ GGML_LOG_DEBUG (" %s: disabling CUDA graphs due to unsupported node type\n " , __func__);
29562960#endif
2961+ }
29572962 }
29582963
29592964 if (!use_cuda_graph) {
0 commit comments