77
88void llama_hparams::set_swa_pattern (uint32_t n_pattern, bool dense_first) {
99 if (dense_first) {
10- for (uint32_t il = 0 ; il < n_layer; ++il) {
10+ for (uint32_t il = 0 ; il < n_layer () ; ++il) {
1111 is_swa_impl[il] = n_pattern == 0 || (il % n_pattern != 0 );
1212 }
1313 } else {
14- for (uint32_t il = 0 ; il < n_layer; ++il) {
14+ for (uint32_t il = 0 ; il < n_layer () ; ++il) {
1515 is_swa_impl[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1 ));
1616 }
1717 }
18+
19+ for (uint32_t il = n_layer (); il < n_layer_all; ++il) {
20+ is_swa_impl[il] = false ;
21+ }
1822}
1923
20- // TODO: implement
21- // void llama_hparams::set_recr_pattern(uint32_t n_pattern, bool dense_first) {
22- // if (dense_first) {
23- // for (uint32_t il = 0; il < n_layer; ++il) {
24- // is_recr_impl[il] = n_pattern == 0 || (il % n_pattern != 0);
25- // }
26- // } else {
27- // for (uint32_t il = 0; il < n_layer; ++il) {
28- // is_recr_impl[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
29- // }
30- // }
31- // }
24+ void llama_hparams::set_recr_pattern (uint32_t n_pattern, bool dense_first) {
25+ if (dense_first) {
26+ for (uint32_t il = 0 ; il < n_layer (); ++il) {
27+ is_recr_impl[il] = n_pattern == 0 || (il % n_pattern != 0 );
28+ }
29+ } else {
30+ for (uint32_t il = 0 ; il < n_layer (); ++il) {
31+ is_recr_impl[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1 ));
32+ }
33+ }
34+
35+ for (uint32_t il = n_layer (); il < n_layer_all; ++il) {
36+ is_recr_impl[il] = false ;
37+ }
38+ }
3239
3340bool llama_hparams::is_swa_any () const {
34- for (uint32_t il = 0 ; il < n_layer ; ++il) {
41+ for (uint32_t il = 0 ; il < n_layer_all ; ++il) {
3542 if (is_swa_impl[il]) {
3643 return true ;
3744 }
@@ -41,23 +48,23 @@ bool llama_hparams::is_swa_any() const {
4148}
4249
4350uint32_t llama_hparams::n_head (uint32_t il) const {
44- if (il < n_layer ) {
51+ if (il < n_layer_all ) {
4552 return n_head_arr[il];
4653 }
4754
4855 GGML_ABORT (" fatal error" );
4956}
5057
5158uint32_t llama_hparams::n_head_kv (uint32_t il) const {
52- if (il < n_layer ) {
59+ if (il < n_layer_all ) {
5360 return n_head_kv_arr[il];
5461 }
5562
5663 GGML_ABORT (" fatal error" );
5764}
5865
5966uint32_t llama_hparams::n_ff (uint32_t il) const {
60- if (il < n_layer ) {
67+ if (il < n_layer_all ) {
6168 return n_ff_arr[il];
6269 }
6370
@@ -76,7 +83,7 @@ uint32_t llama_hparams::n_gqa(uint32_t il) const {
7683}
7784
7885uint32_t llama_hparams::n_rot (uint32_t il) const {
79- if (il < n_layer ) {
86+ if (il < n_layer_all ) {
8087 return is_swa (il) ? n_rot_swa : n_rot_full;
8188 }
8289
@@ -98,15 +105,15 @@ uint32_t llama_hparams::n_embd_out() const {
98105}
99106
100107uint32_t llama_hparams::n_embd_head_k (uint32_t il) const {
101- if (il < n_layer ) {
108+ if (il < n_layer_all ) {
102109 return is_swa (il) ? n_embd_head_k_swa : n_embd_head_k_full;
103110 }
104111
105112 GGML_ABORT (" fatal error" );
106113}
107114
108115uint32_t llama_hparams::n_embd_head_v (uint32_t il) const {
109- if (il < n_layer ) {
116+ if (il < n_layer_all ) {
110117 return is_swa (il) ? n_embd_head_v_swa : n_embd_head_v_full;
111118 }
112119
@@ -127,7 +134,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
127134
128135bool llama_hparams::is_n_embd_k_gqa_variable () const {
129136 const uint32_t val = n_embd_k_gqa ();
130- for (uint32_t il = 0 ; il < n_layer ; ++il) {
137+ for (uint32_t il = 0 ; il < n_layer_all ; ++il) {
131138 if (val != n_embd_k_gqa (il)) {
132139 return true ;
133140 }
@@ -138,7 +145,7 @@ bool llama_hparams::is_n_embd_k_gqa_variable() const {
138145
139146bool llama_hparams::is_n_embd_v_gqa_variable () const {
140147 const uint32_t val = n_embd_v_gqa ();
141- for (uint32_t il = 0 ; il < n_layer ; ++il) {
148+ for (uint32_t il = 0 ; il < n_layer_all ; ++il) {
142149 if (val != n_embd_v_gqa (il)) {
143150 return true ;
144151 }
@@ -149,7 +156,7 @@ bool llama_hparams::is_n_embd_v_gqa_variable() const {
149156
150157uint32_t llama_hparams::n_embd_k_gqa_max () const {
151158 uint32_t val = n_embd_k_gqa ();
152- for (uint32_t il = 0 ; il < n_layer ; ++il) {
159+ for (uint32_t il = 0 ; il < n_layer_all ; ++il) {
153160 val = std::max (val, n_embd_k_gqa (il));
154161 }
155162
@@ -158,7 +165,7 @@ uint32_t llama_hparams::n_embd_k_gqa_max() const {
158165
159166uint32_t llama_hparams::n_embd_v_gqa_max () const {
160167 uint32_t val = n_embd_v_gqa ();
161- for (uint32_t il = 0 ; il < n_layer ; ++il) {
168+ for (uint32_t il = 0 ; il < n_layer_all ; ++il) {
162169 val = std::max (val, n_embd_v_gqa (il));
163170 }
164171
@@ -207,23 +214,23 @@ uint32_t llama_hparams::n_embd_s() const {
207214}
208215
209216bool llama_hparams::is_recr (uint32_t il) const {
210- if (il < n_layer ) {
217+ if (il < n_layer_all ) {
211218 return is_recr_impl[il];
212219 }
213220
214- GGML_ABORT (" %s: il (%u) out of bounds (n_layer : %u)\n " , __func__, il, n_layer );
221+ GGML_ABORT (" %s: il (%u) out of bounds (n_layer_all : %u)\n " , __func__, il, n_layer_all );
215222}
216223
217224uint32_t llama_hparams::n_pos_per_embd () const {
218225 return rope_type == LLAMA_ROPE_TYPE_MROPE || rope_type == LLAMA_ROPE_TYPE_IMROPE ? 4 : 1 ;
219226}
220227
221228bool llama_hparams::is_swa (uint32_t il) const {
222- if (il < n_layer ) {
229+ if (il < n_layer_all ) {
223230 return is_swa_impl[il];
224231 }
225232
226- GGML_ABORT (" fatal error " );
233+ GGML_ABORT (" %s: il (%u) out of bounds (n_layer_all: %u) \n " , __func__, il, n_layer_all );
227234}
228235
229236bool llama_hparams::is_mla () const {
@@ -242,12 +249,6 @@ uint32_t llama_hparams::n_embd_head_v_mla() const {
242249}
243250
244251bool llama_hparams::has_kv (uint32_t il) const {
245- if (kv_only_nextn) {
246- // MTP head: only the trailing nextn_predict_layers blocks own a KV cache;
247- // the leading trunk blocks are not executed in this graph.
248- return nextn_predict_layers > 0 && il >= (n_layer - nextn_predict_layers);
249- }
250-
251252 if (n_layer_kv_from_start >= 0 ) {
252253 if (il < (uint32_t ) n_layer_kv_from_start) {
253254 return true ;
@@ -260,16 +261,8 @@ bool llama_hparams::has_kv(uint32_t il) const {
260261 return true ;
261262}
262263
263- uint32_t llama_hparams::n_layer_kv () const {
264- uint32_t res = 0 ;
265-
266- for (uint32_t il = 0 ; il < n_layer; ++il) {
267- if (has_kv (il)) {
268- res++;
269- }
270- }
271-
272- return res;
264+ uint32_t llama_hparams::n_layer () const {
265+ return n_layer_all - n_layer_nextn;
273266}
274267
275268bool llama_hparams::use_mrope () const {
0 commit comments