|
| 1 | +#include <atomic> |
| 2 | +#include <mutex> |
| 3 | +#include <queue> |
1 | 4 | #include <string> |
2 | 5 | #include <vector> |
3 | 6 |
|
@@ -426,29 +429,87 @@ void init_constraint(Rice::Module& m) { |
426 | 429 | "_solve", |
427 | 430 | [](Object self, CpModelBuilder& model, SatParameters& parameters, Object callback) { |
428 | 431 | Model m; |
| 432 | + m.Add(NewSatParameters(parameters)); |
429 | 433 |
|
430 | | - if (!callback.is_nil()) { |
431 | | - // use a single worker since Ruby code cannot be run in a non-Ruby thread |
432 | | - parameters.set_num_search_workers(1); |
| 434 | + std::atomic<bool> done{false}; |
| 435 | + std::queue<CpSolverResponse> queue; |
| 436 | + std::mutex queue_lock; |
| 437 | + Rice::Object ruby_thread; |
| 438 | + std::optional<Rice::Exception> exception; |
433 | 439 |
|
434 | | - m.Add(NewFeasibleSolutionObserver( |
435 | | - [&](const CpSolverResponse& r) { |
436 | | - if (!ruby_native_thread_p()) { |
437 | | - throw std::runtime_error{"Non-Ruby thread"}; |
| 440 | + // TODO release GVL when not calling Ruby |
| 441 | + auto ruby_observer = [&]() { |
| 442 | + try { |
| 443 | + while (true) { |
| 444 | + if (done.load()) { |
| 445 | + std::lock_guard<std::mutex> guard(queue_lock); |
| 446 | + if (queue.empty()) { |
| 447 | + break; |
| 448 | + } |
438 | 449 | } |
439 | 450 |
|
440 | | - callback.call("response=", r); |
441 | | - callback.call("on_solution_callback"); |
| 451 | + while (true) { |
| 452 | + CpSolverResponse r; |
| 453 | + { |
| 454 | + std::lock_guard<std::mutex> guard(queue_lock); |
| 455 | + if (queue.empty()) { |
| 456 | + break; |
| 457 | + } |
| 458 | + r = queue.front(); |
| 459 | + queue.pop(); |
| 460 | + } |
| 461 | + |
| 462 | + callback.call("response=", r); |
| 463 | + callback.call("on_solution_callback"); |
442 | 464 |
|
443 | | - if (callback.attr_get("@stopped")) { |
444 | | - StopSearch(&m); |
| 465 | + if (callback.attr_get("@stopped")) { |
| 466 | + StopSearch(&m); |
| 467 | + return Qnil; |
| 468 | + } |
445 | 469 | } |
| 470 | + |
| 471 | + rb_thread_schedule(); |
| 472 | + } |
| 473 | + } catch (const Rice::Exception& e) { |
| 474 | + exception = e; |
| 475 | + StopSearch(&m); |
| 476 | + } catch (const std::exception& e) { |
| 477 | + exception = Rice::Exception(rb_eRuntimeError, e.what()); |
| 478 | + StopSearch(&m); |
| 479 | + } |
| 480 | + return Qnil; |
| 481 | + }; |
| 482 | + |
| 483 | + auto ruby_wrapper = [](void* arg) { |
| 484 | + return (*static_cast<decltype(ruby_observer)*>(arg))(); |
| 485 | + }; |
| 486 | + |
| 487 | + if (!callback.is_nil()) { |
| 488 | + ruby_thread = rb_thread_create(ruby_wrapper, &ruby_observer); |
| 489 | + |
| 490 | + m.Add(NewFeasibleSolutionObserver( |
| 491 | + [&](const CpSolverResponse& r) { |
| 492 | + std::lock_guard<std::mutex> guard(queue_lock); |
| 493 | + queue.push(r); |
446 | 494 | }) |
447 | 495 | ); |
448 | 496 | } |
449 | 497 |
|
450 | | - m.Add(NewSatParameters(parameters)); |
451 | | - return SolveCpModel(model.Build(), &m); |
| 498 | + CpSolverResponse r; |
| 499 | + Rice::detail::no_gvl([&]() { |
| 500 | + r = SolveCpModel(model.Build(), &m); |
| 501 | + done = true; |
| 502 | + }); |
| 503 | + |
| 504 | + if (!callback.is_nil()) { |
| 505 | + ruby_thread.call("value"); |
| 506 | + } |
| 507 | + |
| 508 | + if (exception.has_value()) { |
| 509 | + throw exception.value(); |
| 510 | + } |
| 511 | + |
| 512 | + return r; |
452 | 513 | }) |
453 | 514 | .define_method( |
454 | 515 | "_solution_integer_value", |
|
0 commit comments