
Commit 2b20260

Merge branch 'ikawrakow:main' into main
2 parents: 2746fc1 + 0a36cea

File tree: 1 file changed (+5, −4 lines)


src/llama-load-tensors.cpp

Lines changed: 5 additions & 4 deletions
@@ -2952,12 +2952,13 @@ bool create_tensors_helper::create_tensors() {
         throw std::runtime_error("unknown architecture");
     }
     if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) {
+        const int n_layer = model.layers.size() - model.hparams.nextn_predict_layers;
         printf("================================ max_gpu = %d\n", model.max_gpu);
         std::vector<size_t> mem_used(model.splits.size(), 0);
         const auto & hparams = model.hparams;
         int gqa_ratio = hparams.n_head() / hparams.n_head_kv();
         auto cur_splits = model.splits;
-        int adjust_step = std::max(1, int(model.layers.size() / (2*model.splits.size())));
+        int adjust_step = std::max(1, int(n_layer / (2*model.splits.size())));
         if (model.max_gpu > 1 && model.max_gpu < int(cur_splits.size())) {
             bool equal_split = true;
             for (int i = 0; i < int(cur_splits.size()); ++i) {
@@ -2969,13 +2970,13 @@ bool create_tensors_helper::create_tensors() {
             if (equal_split) {
                 if (cur_splits.size() % model.max_gpu == 0) {
                     int nadj = cur_splits.size()/model.max_gpu;
-                    adjust_step = (model.layers.size() + nadj - 1) / nadj;
+                    adjust_step = (n_layer + nadj - 1) / nadj;
                 } else {
-                    adjust_step = (model.layers.size() + cur_splits.size() - 1)/cur_splits.size();
+                    adjust_step = (n_layer + cur_splits.size() - 1)/cur_splits.size();
                 }
             }
         }
-        for (int il = 0; il < int(model.layers.size()); ++il) {
+        for (int il = 0; il < n_layer; ++il) {
             if (ggml_backend_buft_is_host(model.buft_layer[il].buft_matrix)) {
                 LLAMA_LOG_INFO("%s: not splitting layer %d because buffer type is host\n", __func__, il);
                 continue;
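To make the change concrete, below is a minimal standalone sketch of the adjust_step arithmetic after this commit. The layer counts, split vector, and max_gpu value are made-up example inputs (in the real code they come from model.layers, model.hparams.nextn_predict_layers, model.splits, and model.max_gpu), and the hardcoded equal_split flag stands in for the per-split comparison loop that falls outside the hunk:

// Standalone sketch of the adjust_step arithmetic in this diff.
// All input values below are illustrative, not taken from a real model.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    // Example: model.layers.size() == 49, of which the trailing layer is a
    // NextN (multi-token-prediction) layer; 4 devices, at most 2 GPUs used.
    const int total_layers         = 49;  // stands in for model.layers.size()
    const int nextn_predict_layers = 1;   // model.hparams.nextn_predict_layers
    const int max_gpu              = 2;   // model.max_gpu
    std::vector<float> cur_splits  = {0.25f, 0.25f, 0.25f, 0.25f};

    // The fix: NextN layers sit at the end of model.layers but are not
    // split, so they are excluded from the count that drives the step size.
    const int n_layer = total_layers - nextn_predict_layers;

    // Default: adjust the split roughly twice per device.
    int adjust_step = std::max(1, int(n_layer / (2 * cur_splits.size())));

    if (max_gpu > 1 && max_gpu < int(cur_splits.size())) {
        bool equal_split = true; // stands in for the elided per-split check
        if (equal_split) {
            if (cur_splits.size() % max_gpu == 0) {
                // Ceiling division: cover n_layer layers in nadj adjustments.
                int nadj = cur_splits.size() / max_gpu;
                adjust_step = (n_layer + nadj - 1) / nadj;
            } else {
                adjust_step = (n_layer + cur_splits.size() - 1) / cur_splits.size();
            }
        }
    }

    printf("n_layer = %d, adjust_step = %d\n", n_layer, adjust_step);

    // The per-layer split loop now also stops at n_layer, so the trailing
    // NextN layer is never handed to the splitter.
    for (int il = 0; il < n_layer; ++il) { /* assign layer il to a split */ }
    return 0;
}

With these example numbers the sketch prints n_layer = 48 and adjust_step = 24. Using model.layers.size() directly, as the code did before this commit, would give ceil(49/2) = 25 and would also let the split loop visit the NextN layer; subtracting nextn_predict_layers keeps both the step arithmetic and the loop restricted to splittable layers.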
