1818
1919#define MAX (a, b ) ((a) > (b) ? (a) : (b))
2020
21+ constexpr size_t k_max_extra_alloc = 1024 *1024 *64 ;
22+
2123// backend buffer type
2224
2325const char * ggml_backend_buft_name (ggml_backend_buffer_type_t buft) {
@@ -1162,6 +1164,8 @@ struct ggml_backend_sched {
11621164 char * context_buffer;
11631165 size_t context_buffer_size;
11641166
1167+ std::array<ggml_backend_buffer_t , GGML_SCHED_MAX_BACKENDS> input_memory_bufs = {{ nullptr }};
1168+
11651169 uint32_t op_offload[(GGML_OP_COUNT + 31 )/32 ];
11661170
11671171 bool only_active_experts;
@@ -2093,14 +2097,62 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
20932097 return ggml_backend_sched_compute_splits_sm_graph (sched);
20942098 }
20952099
2100+ std::array<bool , GGML_SCHED_MAX_BACKENDS> needs_sync{{true }};
2101+ std::array<bool , GGML_SCHED_MAX_BACKENDS> own_cpy{{false }};
2102+
2103+ if (sched->split_mode_graph ) {
2104+ std::vector<size_t > input_size (sched->n_backends , 0 );
2105+ for (int i = 0 ; i < sched->n_splits ; i++) {
2106+ auto split = &sched->splits [i];
2107+ int split_backend_id = split->backend_id ;
2108+ for (int j = 0 ; j < split->n_inputs ; ++j) {
2109+ auto nbytes = ggml_nbytes (split->inputs [j]);
2110+ nbytes = 256 *((nbytes + 255 )/256 );
2111+ input_size[split_backend_id] += nbytes;
2112+ }
2113+ }
2114+ for (int backend_id = 0 ; backend_id < sched->n_backends ; ++backend_id) {
2115+ if (!input_size[backend_id]) continue ; // this backend has no inputs, so no need to worry about it.
2116+ if (input_size[backend_id] <= k_max_extra_alloc) {
2117+ if (sched->input_memory_bufs [backend_id] && sched->input_memory_bufs [backend_id]->size < input_size[backend_id]) {
2118+ ggml_backend_buffer_free (sched->input_memory_bufs [backend_id]);
2119+ sched->input_memory_bufs [backend_id] = nullptr ;
2120+ }
2121+ if (!sched->input_memory_bufs [backend_id]) {
2122+ sched->input_memory_bufs [backend_id] = ggml_backend_alloc_buffer (sched->backends [backend_id], input_size[backend_id]);
2123+ }
2124+ auto ptr = (char *)ggml_backend_buffer_get_base (sched->input_memory_bufs [backend_id]);
2125+ for (int i = 0 ; i < sched->n_splits ; ++i) {
2126+ auto split = &sched->splits [i];
2127+ if (split->backend_id != backend_id) continue ;
2128+ for (int j = 0 ; j < split->n_inputs ; ++j) {
2129+ auto input_cpy = tensor_copy (split->inputs [j], backend_id, sched->cur_copy );
2130+ for (int k = 0 ; k < split->graph .n_nodes ; ++k) {
2131+ auto node = split->graph .nodes [k];
2132+ for (int l = 0 ; l < GGML_MAX_SRC; ++l) {
2133+ if (node->src [l] && node->src [l]->data == input_cpy->data ) node->src [l]->data = ptr;
2134+ }
2135+ }
2136+ input_cpy->data = ptr;
2137+ auto nbytes = ggml_nbytes (split->inputs [j]);
2138+ nbytes = 256 *((nbytes + 255 )/256 );
2139+ ptr += nbytes;
2140+ }
2141+ }
2142+ needs_sync[backend_id] = false ;
2143+ own_cpy[backend_id] = true ;
2144+ }
2145+ }
2146+ // printf("=== Input memory per backend:\n");
2147+ // for (int i = 0; i < sched->n_backends; ++i) printf(" %d: %.2f MiB\n", i, input_size[i]/1024./1024.);
2148+ }
2149+
20962150 struct ggml_backend_sched_split * splits = sched->splits ;
20972151
20982152 std::vector<int32_t > ids;
20992153 std::vector<uint32_t > unique_ids;
21002154 ggml_tensor * last_ids_tensor = nullptr ;
21012155
2102- std::array<bool , GGML_SCHED_MAX_BACKENDS> needs_sync{{true }};
2103-
21042156 for (int i = 0 ; i < sched->n_splits ; i++) {
21052157#if IK_PRINT_TIMING
21062158 int64_t tim1 = ggml_time_us ();
@@ -2112,7 +2164,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
21122164 // copy the input tensors to the split backend
21132165 ggml_backend_sched_copy_inputs (sched, split, needs_sync, ids, unique_ids, last_ids_tensor);
21142166
2115- if (split->n_inputs > 0 ) {
2167+ if (split->n_inputs > 0 && !own_cpy[split_backend_id] ) {
21162168 needs_sync[split_backend_id] = true ;
21172169 }
21182170 if (!sched->callback_eval ) {
@@ -2240,6 +2292,11 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
22402292 ggml_backend_event_free (sched->events [b][c]);
22412293 }
22422294 }
2295+ for (int i = 0 ; i < sched->n_backends ; ++i) {
2296+ if (sched->input_memory_bufs [i]) {
2297+ ggml_backend_buffer_free (sched->input_memory_bufs [i]);
2298+ }
2299+ }
22432300 ggml_gallocr_free (sched->galloc );
22442301 ggml_free (sched->ctx );
22452302 ggml_hash_set_free (&sched->hash_set );
0 commit comments