Skip to content

Commit 3dd2005

Browse files
aittalam and claude authored
Reduce cuda dylibs size (#963)
* Reduce CUDA library size with opt-in build flags Reimplements the size-reduction ideas from PR #921 on top of llamacpp-7b8443a, and adds the upstream guards the original PR omitted so --no-iq-quants actually works at link time. New flags on llamafile/cuda.sh (all off by default): --minimize-size umbrella: enables the four flags below --minimal-archs virtual PTX for sm_75/sm_90, real SASS for 80/86/89 --no-iq-quants drop 8 mmq-instance-iq*.cu + define GGML_CUDA_NO_IQ_QUANTS --strip strip --strip-unneeded after link --compress --compress-mode=size (requires CUDA >= 12.8) build-functions.sh: collect_gpu_sources now categorizes template instances and includes only the 3 default fattn-vec quant combos (f16-f16, q4_0-q4_0, q8_0-q8_0), matching upstream CMake's default. Optional NO_IQ_QUANTS argument excludes mmq-instance-iq*.cu. New patches in llama.cpp.patches/patches/ wrap IQ-quant dispatch sites in #ifndef GGML_CUDA_NO_IQ_QUANTS so --no-iq-quants links cleanly without the excluded template instantiations: mmq.cu / mmq.cuh / mmvq.cu / convert.cu / cpy.cu Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Mirror cuda.sh size-reduction flags in cuda.bat and cuda_parallel.bat Adds the same option surface to both Windows build scripts: --minimize-size umbrella: enables the four flags below --minimal-archs virtual PTX for sm_75/sm_90, real SASS for 80/86/89 --no-iq-quants drop 8 mmq-instance-iq*.cu + define GGML_CUDA_NO_IQ_QUANTS --strip no-op on Windows (debug info lives in a separate .pdb); accepted for parity with cuda.sh --compress --compress-mode=size (requires CUDA >= 12.8) --fa-all-quants compile all fattn-vec quant combos and define GGML_CUDA_FA_ALL_QUANTS Source collection is restructured into the same 5 categories the bash script uses, defaulting to the 3 common fattn-vec combos (f16-f16, q4_0-q4_0, q8_0-q8_0) and gating mmq-instance-iq* under --no-iq-quants. 
CUDA version parsing now captures both major and minor for the >=12.8 compress check, with a safe 0 fallback when nvcc's output can't be parsed. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent bec4d28 commit 3dd2005

9 files changed

Lines changed: 590 additions & 62 deletions
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
2+
--- a/llama.cpp/ggml/src/ggml-cuda/convert.cu
3+
+++ b/llama.cpp/ggml/src/ggml-cuda/convert.cu
4+
@@ -736,6 +736,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
5+
return dequantize_row_q5_K_cuda;
6+
case GGML_TYPE_Q6_K:
7+
return dequantize_row_q6_K_cuda;
8+
+#ifndef GGML_CUDA_NO_IQ_QUANTS
9+
case GGML_TYPE_IQ2_XXS:
10+
return dequantize_row_iq2_xxs_cuda;
11+
case GGML_TYPE_IQ2_XS:
12+
@@ -754,6 +755,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
13+
return dequantize_row_iq4_xs_cuda;
14+
case GGML_TYPE_IQ3_S:
15+
return dequantize_row_iq3_s_cuda;
16+
+#endif // GGML_CUDA_NO_IQ_QUANTS
17+
case GGML_TYPE_MXFP4:
18+
return dequantize_row_mxfp4_cuda;
19+
case GGML_TYPE_NVFP4:
20+
@@ -791,6 +793,7 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
21+
return dequantize_row_q5_K_cuda;
22+
case GGML_TYPE_Q6_K:
23+
return dequantize_row_q6_K_cuda;
24+
+#ifndef GGML_CUDA_NO_IQ_QUANTS
25+
case GGML_TYPE_IQ2_XXS:
26+
return dequantize_row_iq2_xxs_cuda;
27+
case GGML_TYPE_IQ2_XS:
28+
@@ -809,6 +812,7 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
29+
return dequantize_row_iq4_xs_cuda;
30+
case GGML_TYPE_IQ3_S:
31+
return dequantize_row_iq3_s_cuda;
32+
+#endif // GGML_CUDA_NO_IQ_QUANTS
33+
case GGML_TYPE_MXFP4:
34+
return dequantize_row_mxfp4_cuda;
35+
case GGML_TYPE_NVFP4:
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
2+
--- a/llama.cpp/ggml/src/ggml-cuda/cpy.cu
3+
+++ b/llama.cpp/ggml/src/ggml-cuda/cpy.cu
4+
@@ -360,6 +360,7 @@ static void ggml_cpy_q5_1_f32_cuda(
5+
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
6+
}
7+
8+
+#ifndef GGML_CUDA_NO_IQ_QUANTS
9+
static void ggml_cpy_f32_iq4_nl_cuda(
10+
const char * cx, char * cdst, const int64_t ne,
11+
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
12+
@@ -371,6 +372,7 @@ static void ggml_cpy_f32_iq4_nl_cuda(
13+
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
14+
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
15+
}
16+
+#endif // GGML_CUDA_NO_IQ_QUANTS
17+
18+
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
19+
const int64_t ne = ggml_nelements(src0);
20+
@@ -465,9 +467,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
21+
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
22+
ggml_cpy_q5_0_f32_cuda
23+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
24+
+#ifndef GGML_CUDA_NO_IQ_QUANTS
25+
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
26+
ggml_cpy_f32_iq4_nl_cuda
27+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
28+
+#endif // GGML_CUDA_NO_IQ_QUANTS
29+
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
30+
ggml_cpy_f32_q5_1_cuda
31+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
2+
--- a/llama.cpp/ggml/src/ggml-cuda/mmq.cu
3+
+++ b/llama.cpp/ggml/src/ggml-cuda/mmq.cu
4+
@@ -44,6 +44,7 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con
5+
case GGML_TYPE_Q6_K:
6+
mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
7+
break;
8+
+#ifndef GGML_CUDA_NO_IQ_QUANTS
9+
case GGML_TYPE_IQ2_XXS:
10+
mul_mat_q_case<GGML_TYPE_IQ2_XXS>(ctx, args, stream);
11+
break;
12+
@@ -68,6 +69,7 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con
13+
case GGML_TYPE_IQ4_NL:
14+
mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
15+
break;
16+
+#endif // GGML_CUDA_NO_IQ_QUANTS
17+
default:
18+
GGML_ABORT("fatal error");
19+
break;
20+
@@ -286,6 +288,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
21+
case GGML_TYPE_Q4_K:
22+
case GGML_TYPE_Q5_K:
23+
case GGML_TYPE_Q6_K:
24+
+#ifndef GGML_CUDA_NO_IQ_QUANTS
25+
case GGML_TYPE_IQ2_XXS:
26+
case GGML_TYPE_IQ2_XS:
27+
case GGML_TYPE_IQ2_S:
28+
@@ -294,6 +297,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
29+
case GGML_TYPE_IQ1_S:
30+
case GGML_TYPE_IQ4_XS:
31+
case GGML_TYPE_IQ4_NL:
32+
+#endif // GGML_CUDA_NO_IQ_QUANTS
33+
mmq_supported = true;
34+
break;
35+
default:
36+
@@ -356,9 +360,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
37+
return ne11 <= 128;
38+
case GGML_TYPE_Q6_K:
39+
return ne11 <= (GGML_CUDA_CC_IS_RDNA3_0(cc) ? 128 : 256);
40+
+#ifndef GGML_CUDA_NO_IQ_QUANTS
41+
case GGML_TYPE_IQ2_XS:
42+
case GGML_TYPE_IQ2_S:
43+
return GGML_CUDA_CC_IS_RDNA3_5(cc) || ne11 <= 128;
44+
+#endif // GGML_CUDA_NO_IQ_QUANTS
45+
default:
46+
return true;
47+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
2+
--- a/llama.cpp/ggml/src/ggml-cuda/mmq.cuh
3+
+++ b/llama.cpp/ggml/src/ggml-cuda/mmq.cuh
4+
@@ -4088,6 +4088,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
5+
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
6+
extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
7+
extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
8+
+#ifndef GGML_CUDA_NO_IQ_QUANTS
9+
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
10+
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
11+
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
12+
@@ -4096,6 +4097,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
13+
extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
14+
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
15+
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
16+
+#endif // GGML_CUDA_NO_IQ_QUANTS
17+
18+
// -------------------------------------------------------------------------------------------------------------------------
19+
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
2+
--- a/llama.cpp/ggml/src/ggml-cuda/mmvq.cu
3+
+++ b/llama.cpp/ggml/src/ggml-cuda/mmvq.cu
4+
@@ -22,6 +22,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
5+
case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1;
6+
case GGML_TYPE_Q5_K: return vec_dot_q5_K_q8_1;
7+
case GGML_TYPE_Q6_K: return vec_dot_q6_K_q8_1;
8+
+#ifndef GGML_CUDA_NO_IQ_QUANTS
9+
case GGML_TYPE_IQ2_XXS: return vec_dot_iq2_xxs_q8_1;
10+
case GGML_TYPE_IQ2_XS: return vec_dot_iq2_xs_q8_1;
11+
case GGML_TYPE_IQ2_S: return vec_dot_iq2_s_q8_1;
12+
@@ -31,6 +32,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
13+
case GGML_TYPE_IQ4_NL: return vec_dot_iq4_nl_q8_1;
14+
case GGML_TYPE_IQ4_XS: return vec_dot_iq4_xs_q8_1;
15+
case GGML_TYPE_IQ3_S: return vec_dot_iq3_s_q8_1;
16+
+#endif // GGML_CUDA_NO_IQ_QUANTS
17+
default: return nullptr;
18+
}
19+
}
20+
@@ -50,6 +52,7 @@ static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
21+
case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ;
22+
case GGML_TYPE_Q5_K: return VDR_Q5_K_Q8_1_MMVQ;
23+
case GGML_TYPE_Q6_K: return VDR_Q6_K_Q8_1_MMVQ;
24+
+#ifndef GGML_CUDA_NO_IQ_QUANTS
25+
case GGML_TYPE_IQ2_XXS: return VDR_IQ2_XXS_Q8_1_MMVQ;
26+
case GGML_TYPE_IQ2_XS: return VDR_IQ2_XS_Q8_1_MMVQ;
27+
case GGML_TYPE_IQ2_S: return VDR_IQ2_S_Q8_1_MMVQ;
28+
@@ -57,6 +60,7 @@ static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
29+
case GGML_TYPE_IQ3_S: return VDR_IQ3_S_Q8_1_MMVQ;
30+
case GGML_TYPE_IQ4_NL: return VDR_IQ4_NL_Q8_1_MMVQ;
31+
case GGML_TYPE_IQ4_XS: return VDR_IQ4_XS_Q8_1_MMVQ;
32+
+#endif // GGML_CUDA_NO_IQ_QUANTS
33+
default: return 1;
34+
}
35+
}

llamafile/build-functions.sh

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,17 +98,64 @@ setup_build_dir() {
9898
mkdir -p "$build_dir"
9999
}
100100

101-
# Collect CUDA/HIP source files
101+
# Collect CUDA/HIP source files with selective template inclusion
102102
# Sets: CUDA_SOURCES, NUM_SOURCES
103-
# Args: $1 = GGML_CUDA_DIR, $2 = extra sources (optional, e.g., tinyblas.cu path)
103+
# Args: $1 = GGML_CUDA_DIR
104+
# $2 = caller-supplied sources prepended to the list (e.g., tinyblas.cu
105+
# for the default TinyBLAS build; empty for the --cublas build)
106+
# $3 = NO_IQ_QUANTS (optional, "1" to exclude IQ quant MMQ templates)
107+
# $4 = FA_ALL_QUANTS (optional, "1" to include all fattn-vec quant combos
108+
# instead of the 4 default ones; mirrors upstream's GGML_CUDA_FA_ALL_QUANTS)
104109
collect_gpu_sources() {
105110
local ggml_cuda_dir="$1"
106-
local extra_sources="$2"
111+
local caller_sources="$2"
112+
local no_iq_quants="${3:-0}"
113+
local fa_all_quants="${4:-0}"
107114

108-
CUDA_SOURCES="$extra_sources"
115+
CUDA_SOURCES="$caller_sources"
109116

110-
for f in "$ggml_cuda_dir"/*.cu "$ggml_cuda_dir/template-instances"/*.cu; do
117+
# 1. Main CUDA sources (always included)
118+
for f in "$ggml_cuda_dir"/*.cu; do
119+
[ -f "$f" ] && CUDA_SOURCES="$CUDA_SOURCES $f"
120+
done
121+
122+
local ti_dir="$ggml_cuda_dir/template-instances"
123+
124+
# 2. fattn-mma and fattn-tile instances (always included)
125+
for f in "$ti_dir"/fattn-mma-*.cu "$ti_dir"/fattn-tile-*.cu; do
126+
[ -f "$f" ] && CUDA_SOURCES="$CUDA_SOURCES $f"
127+
done
128+
129+
# 3. fattn-vec: default to the 4 common quant combos (f16-f16, q4_0-q4_0,
130+
# q8_0-q8_0, bf16-bf16), matching upstream CMake. With FA_ALL_QUANTS=1
131+
# include all fattn-vec instances (mirrors upstream's
132+
# GGML_CUDA_FA_ALL_QUANTS opt-in).
133+
if [ "$fa_all_quants" = "1" ]; then
134+
for f in "$ti_dir"/fattn-vec-instance-*.cu; do
135+
[ -f "$f" ] && CUDA_SOURCES="$CUDA_SOURCES $f"
136+
done
137+
else
138+
for f in "$ti_dir"/fattn-vec-instance-f16-f16.cu \
139+
"$ti_dir"/fattn-vec-instance-q4_0-q4_0.cu \
140+
"$ti_dir"/fattn-vec-instance-q8_0-q8_0.cu \
141+
"$ti_dir"/fattn-vec-instance-bf16-bf16.cu; do
142+
[ -f "$f" ] && CUDA_SOURCES="$CUDA_SOURCES $f"
143+
done
144+
fi
145+
146+
# 4. mmf instances (always included)
147+
for f in "$ti_dir"/mmf-*.cu; do
148+
[ -f "$f" ] && CUDA_SOURCES="$CUDA_SOURCES $f"
149+
done
150+
151+
# 5. mmq instances: include all, but optionally exclude IQ quant templates
152+
for f in "$ti_dir"/mmq-*.cu; do
111153
if [ -f "$f" ]; then
154+
if [ "$no_iq_quants" = "1" ]; then
155+
case "$(basename "$f")" in
156+
mmq-instance-iq*) continue ;;
157+
esac
158+
fi
112159
CUDA_SOURCES="$CUDA_SOURCES $f"
113160
fi
114161
done

0 commit comments

Comments
 (0)