Skip to content

Commit dcb3995

Browse files
committed
fix cancel requests
1 parent dde622e commit dcb3995

3 files changed

Lines changed: 49 additions & 44 deletions

File tree

patches/llama.cpp.patch

Lines changed: 43 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
2-
index c580ec12..cac11586 100644
2+
index c580ec12..84cc9584 100644
33
--- a/examples/server/server.cpp
44
+++ b/examples/server/server.cpp
55
@@ -1552,30 +1552,29 @@ struct server_queue {
@@ -66,7 +66,7 @@ index c580ec12..cac11586 100644
6666
queue_tasks.pop_front();
6767
lock.unlock();
6868

69-
@@ -1878,16 +1877,24 @@ struct server_context {
69+
@@ -1878,14 +1877,20 @@ struct server_context {
7070
~server_context() {
7171
// Clear any sampling context
7272
for (server_slot & slot : slots) {
@@ -91,14 +91,9 @@ index c580ec12..cac11586 100644
9191
+ slot.spec = nullptr;
9292
+ }
9393

94-
- llama_batch_free(slot.batch_spec);
95-
+ if (slot.ctx_dft) {
96-
+ llama_batch_free(slot.batch_spec);
97-
+ }
94+
llama_batch_free(slot.batch_spec);
9895
}
99-
100-
llama_batch_free(batch);
101-
@@ -2005,7 +2012,7 @@ struct server_context {
96+
@@ -2005,7 +2010,7 @@ struct server_context {
10297

10398
slot.reset();
10499

@@ -107,7 +102,7 @@ index c580ec12..cac11586 100644
107102
}
108103

109104
default_generation_settings_for_props = slots[0].to_json();
110-
@@ -2106,7 +2113,7 @@ struct server_context {
105+
@@ -2106,7 +2111,7 @@ struct server_context {
111106
return true;
112107
}
113108

@@ -116,7 +111,7 @@ index c580ec12..cac11586 100644
116111
slot.reset();
117112
slot.id_task = task.id;
118113
slot.index = task.index;
119-
@@ -2114,10 +2121,10 @@ struct server_context {
114+
@@ -2114,10 +2119,10 @@ struct server_context {
120115
slot.params = std::move(task.params);
121116
slot.prompt_tokens = std::move(task.prompt_tokens);
122117

@@ -129,7 +124,16 @@ index c580ec12..cac11586 100644
129124
}
130125

131126
bool can_detokenize = can_be_detokenized(ctx, slot.prompt_tokens);
132-
@@ -2548,10 +2555,10 @@ struct server_context {
127+
@@ -2214,7 +2219,7 @@ struct server_context {
128+
}
129+
130+
slot.add_token(result);
131+
- if (slot.params.stream) {
132+
+ if (slot.params.stream && slot.stop != STOP_TYPE_LIMIT) {
133+
send_partial_response(slot, result);
134+
}
135+
}
136+
@@ -2548,10 +2553,10 @@ struct server_context {
133137
server_task task(SERVER_TASK_TYPE_CANCEL);
134138
task.id_target = id_task;
135139
queue_results.remove_waiting_task_id(id_task);
@@ -142,7 +146,7 @@ index c580ec12..cac11586 100644
142146
}
143147

144148
// receive the results from task(s)
145-
@@ -2638,7 +2645,7 @@ struct server_context {
149+
@@ -2638,7 +2643,7 @@ struct server_context {
146150
// Functions to process the task
147151
//
148152

@@ -151,7 +155,7 @@ index c580ec12..cac11586 100644
151155
switch (task.type) {
152156
case SERVER_TASK_TYPE_COMPLETION:
153157
case SERVER_TASK_TYPE_INFILL:
154-
@@ -2652,17 +2659,17 @@ struct server_context {
158+
@@ -2652,17 +2657,17 @@ struct server_context {
155159
if (slot == nullptr) {
156160
// if no slot is available, we defer this task for processing later
157161
SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id);
@@ -172,7 +176,7 @@ index c580ec12..cac11586 100644
172176
SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
173177
break;
174178
}
175-
@@ -2741,7 +2748,7 @@ struct server_context {
179+
@@ -2741,7 +2746,7 @@ struct server_context {
176180
if (slot->is_processing()) {
177181
// if requested slot is unavailable, we defer this task for processing later
178182
SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
@@ -181,7 +185,7 @@ index c580ec12..cac11586 100644
181185
break;
182186
}
183187

184-
@@ -2777,7 +2784,7 @@ struct server_context {
188+
@@ -2777,7 +2782,7 @@ struct server_context {
185189
if (slot->is_processing()) {
186190
// if requested slot is unavailable, we defer this task for processing later
187191
SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
@@ -190,7 +194,7 @@ index c580ec12..cac11586 100644
190194
break;
191195
}
192196

193-
@@ -2820,7 +2827,7 @@ struct server_context {
197+
@@ -2820,7 +2825,7 @@ struct server_context {
194198
if (slot->is_processing()) {
195199
// if requested slot is unavailable, we defer this task for processing later
196200
SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
@@ -199,7 +203,7 @@ index c580ec12..cac11586 100644
199203
break;
200204
}
201205

202-
@@ -2872,7 +2879,7 @@ struct server_context {
206+
@@ -2872,7 +2877,7 @@ struct server_context {
203207

204208
server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE);
205209
task.id = queue_tasks.get_new_id();
@@ -208,7 +212,7 @@ index c580ec12..cac11586 100644
208212
}
209213

210214
// apply context-shift if needed
211-
@@ -3441,7 +3448,7 @@ inline void signal_handler(int signal) {
215+
@@ -3441,7 +3446,7 @@ inline void signal_handler(int signal) {
212216
shutdown_handler(signal);
213217
}
214218

@@ -217,7 +221,7 @@ index c580ec12..cac11586 100644
217221
// own arguments required by this example
218222
common_params params;
219223

220-
@@ -3634,17 +3641,14 @@ int main(int argc, char ** argv) {
224+
@@ -3634,17 +3639,14 @@ int main(int argc, char ** argv) {
221225
}
222226

223227
// request slots data using task queue
@@ -241,7 +245,7 @@ index c580ec12..cac11586 100644
241245

242246
if (result->is_error()) {
243247
res_error(res, result->to_json());
244-
@@ -3673,17 +3677,16 @@ int main(int argc, char ** argv) {
248+
@@ -3673,17 +3675,16 @@ int main(int argc, char ** argv) {
245249
}
246250

247251
// request slots data using task queue
@@ -267,7 +271,7 @@ index c580ec12..cac11586 100644
267271

268272
if (result->is_error()) {
269273
res_error(res, result->to_json());
270-
@@ -3780,20 +3783,17 @@ int main(int argc, char ** argv) {
274+
@@ -3780,20 +3781,17 @@ int main(int argc, char ** argv) {
271275
}
272276
std::string filepath = params.slot_save_path + filename;
273277

@@ -297,7 +301,7 @@ index c580ec12..cac11586 100644
297301

298302
if (result->is_error()) {
299303
res_error(res, result->to_json());
300-
@@ -3812,20 +3812,17 @@ int main(int argc, char ** argv) {
304+
@@ -3812,20 +3810,17 @@ int main(int argc, char ** argv) {
301305
}
302306
std::string filepath = params.slot_save_path + filename;
303307

@@ -327,7 +331,7 @@ index c580ec12..cac11586 100644
327331

328332
if (result->is_error()) {
329333
res_error(res, result->to_json());
330-
@@ -3837,18 +3834,15 @@ int main(int argc, char ** argv) {
334+
@@ -3837,18 +3832,15 @@ int main(int argc, char ** argv) {
331335
};
332336

333337
const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
@@ -353,7 +357,7 @@ index c580ec12..cac11586 100644
353357

354358
if (result->is_error()) {
355359
res_error(res, result->to_json());
356-
@@ -3952,10 +3946,9 @@ int main(int argc, char ** argv) {
360+
@@ -3952,10 +3944,9 @@ int main(int argc, char ** argv) {
357361
}
358362

359363
auto completion_id = gen_chatcmplid();
@@ -366,7 +370,7 @@ index c580ec12..cac11586 100644
366370
const auto & prompt = data.at("prompt");
367371
// TODO: this log can become very long, put it behind a flag or think about a more compact format
368372
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
369-
@@ -3970,9 +3963,9 @@ int main(int argc, char ** argv) {
373+
@@ -3970,9 +3961,9 @@ int main(int argc, char ** argv) {
370374

371375
task.prompt_tokens = std::move(tokenized_prompts[i]);
372376
task.params = server_task::params_from_json_cmpl(
@@ -379,7 +383,7 @@ index c580ec12..cac11586 100644
379383
task.id_selected_slot = json_value(data, "id_slot", -1);
380384

381385
// OAI-compat
382-
@@ -3980,18 +3973,18 @@ int main(int argc, char ** argv) {
386+
@@ -3980,18 +3971,18 @@ int main(int argc, char ** argv) {
383387
task.params.oaicompat_cmpl_id = completion_id;
384388
// oaicompat_model is already populated by params_from_json_cmpl
385389

@@ -403,15 +407,15 @@ index c580ec12..cac11586 100644
403407

404408
if (!stream) {
405409
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
406-
@@ -4283,7 +4276,6 @@ int main(int argc, char ** argv) {
410+
@@ -4283,7 +4274,6 @@ int main(int argc, char ** argv) {
407411
// create and queue the task
408412
json responses = json::array();
409413
bool error = false;
410414
- std::unordered_set<int> task_ids;
411415
{
412416
std::vector<server_task> tasks;
413417
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
414-
@@ -4296,26 +4288,27 @@ int main(int argc, char ** argv) {
418+
@@ -4296,26 +4286,27 @@ int main(int argc, char ** argv) {
415419
// OAI-compat
416420
task.params.oaicompat = oaicompat;
417421

@@ -437,7 +441,8 @@ index c580ec12..cac11586 100644
437441
- }, req.is_connection_closed);
438442
+ // get the result
439443
+ std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
440-
+
444+
445+
- ctx_server.queue_results.remove_waiting_task_ids(task_ids);
441446
+ ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
442447
+ for (auto & res : results) {
443448
+ GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
@@ -447,22 +452,21 @@ index c580ec12..cac11586 100644
447452
+ res_error(res, error_data);
448453
+ error = true;
449454
+ }, req.is_connection_closed);
450-
451-
- ctx_server.queue_results.remove_waiting_task_ids(task_ids);
455+
+
452456
+ ctx_server.queue_results.remove_waiting_task_ids(task_ids);
453457
+ }
454458

455459
if (error) {
456460
return;
457-
@@ -4382,7 +4375,6 @@ int main(int argc, char ** argv) {
461+
@@ -4382,7 +4373,6 @@ int main(int argc, char ** argv) {
458462
// create and queue the task
459463
json responses = json::array();
460464
bool error = false;
461465
- std::unordered_set<int> task_ids;
462466
{
463467
std::vector<server_task> tasks;
464468
std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
465-
@@ -4392,23 +4384,25 @@ int main(int argc, char ** argv) {
469+
@@ -4392,23 +4382,25 @@ int main(int argc, char ** argv) {
466470
task.id = ctx_server.queue_tasks.get_new_id();
467471
task.index = i;
468472
task.prompt_tokens = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
@@ -501,7 +505,7 @@ index c580ec12..cac11586 100644
501505

502506
if (error) {
503507
return;
504-
@@ -4445,19 +4439,14 @@ int main(int argc, char ** argv) {
508+
@@ -4445,19 +4437,14 @@ int main(int argc, char ** argv) {
505509
res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
506510
return;
507511
}
@@ -528,7 +532,7 @@ index c580ec12..cac11586 100644
528532

529533
if (result->is_error()) {
530534
res_error(res, result->to_json());
531-
@@ -4601,8 +4590,8 @@ int main(int argc, char ** argv) {
535+
@@ -4601,8 +4588,8 @@ int main(int argc, char ** argv) {
532536
common_chat_templates_source(ctx_server.chat_templates.get()),
533537
common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str());
534538

@@ -540,7 +544,7 @@ index c580ec12..cac11586 100644
540544

541545
ctx_server.queue_tasks.on_update_slots([&ctx_server]() {
542546
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
543-
index aba2f27f..f3ed4434 100644
547+
index b497959f..ccc33566 100644
544548
--- a/examples/server/utils.hpp
545549
+++ b/examples/server/utils.hpp
546550
@@ -26,20 +26,20 @@
@@ -622,10 +626,10 @@ index 43d9fc4f..0e8fa1db 100644
622626

623627
add_library(ggml-base
624628
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
625-
index c0bdb9e1..bcc25530 100644
629+
index eac0b422..d96727d3 100644
626630
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
627631
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
628-
@@ -72,7 +72,7 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
632+
@@ -90,7 +90,7 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
629633
if (err_ != vk::Result::eSuccess) { \
630634
fprintf(stderr, "ggml_vulkan: %s error %s at %s:%d\n", \
631635
#err, to_string(err_).c_str(), __FILE__, __LINE__); \

undreamai.cpp

100644 → 100755
Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -197,13 +197,13 @@ void handle_error(httplib::Response & res, const json error_data){
197197
res.status = 500;
198198
}
199199

200-
void LLM::release_slot(server_slot slot)
200+
void LLM::release_slot(server_slot& slot)
201201
{
202202
if (slot.task_type == SERVER_TASK_TYPE_COMPLETION)
203203
{
204-
slot.params.stream = false;
205204
slot.i_batch = -1;
206-
slot.params.n_predict = 1;
205+
slot.params.n_predict = 0;
206+
slot.stop = STOP_TYPE_LIMIT;
207207
}
208208
else {
209209
slot.release();
@@ -422,7 +422,8 @@ void LLM::stop_service(){
422422
LOG_INFO("shutting down tasks", {});
423423

424424
// hack completion slots to think task is completed
425-
for (server_slot & slot : ctx_server.slots) {
425+
for (server_slot& slot : ctx_server.slots)
426+
{
426427
release_slot(slot);
427428
}
428429
LOG_INFO("Wait until tasks are finished", {});

undreamai.h

100644 → 100755
Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ class LLM {
8282
std::function<bool()> is_connection_closed = always_true
8383
);
8484
bool middleware_validate_api_key(const httplib::Request & req, httplib::Response & res);
85-
void release_slot(server_slot slot);
85+
void release_slot(server_slot& slot);
8686
};
8787

8888
#ifdef _WIN32

0 commit comments

Comments
 (0)