Skip to content

Commit 605b40e

Browse files
ankanezaithottakath
andcommitted
Added support for releasing GVL - resolves #81
Co-authored-by: Zai Thottakath <zaithottakath@gmail.com> Assisted-by: aider (gpt-5.2) <aider@aider.chat>
1 parent 08044d3 commit 605b40e

3 files changed

Lines changed: 87 additions & 13 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
## 0.18.0 (unreleased)
22

3+
- Added support for releasing GVL
34
- Added `stop_search` method to `CpSolverSolutionCallback`
45
- Fixed error with `apply_locks` and `set_allowed_vehicles_for_index` methods
56
- Dropped support for Ruby < 3.3

ext/or-tools/constraint.cpp

Lines changed: 74 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
#include <atomic>
2+
#include <mutex>
3+
#include <queue>
14
#include <string>
25
#include <vector>
36

@@ -426,29 +429,87 @@ void init_constraint(Rice::Module& m) {
426429
"_solve",
427430
[](Object self, CpModelBuilder& model, SatParameters& parameters, Object callback) {
428431
Model m;
432+
m.Add(NewSatParameters(parameters));
429433

430-
if (!callback.is_nil()) {
431-
// use a single worker since Ruby code cannot be run in a non-Ruby thread
432-
parameters.set_num_search_workers(1);
434+
std::atomic<bool> done{false};
435+
std::queue<CpSolverResponse> queue;
436+
std::mutex queue_lock;
437+
Rice::Object ruby_thread;
438+
std::optional<Rice::Exception> exception;
433439

434-
m.Add(NewFeasibleSolutionObserver(
435-
[&](const CpSolverResponse& r) {
436-
if (!ruby_native_thread_p()) {
437-
throw std::runtime_error{"Non-Ruby thread"};
440+
// TODO release GVL when not calling Ruby
441+
auto ruby_observer = [&]() {
442+
try {
443+
while (true) {
444+
if (done.load()) {
445+
std::lock_guard<std::mutex> guard(queue_lock);
446+
if (queue.empty()) {
447+
break;
448+
}
438449
}
439450

440-
callback.call("response=", r);
441-
callback.call("on_solution_callback");
451+
while (true) {
452+
CpSolverResponse r;
453+
{
454+
std::lock_guard<std::mutex> guard(queue_lock);
455+
if (queue.empty()) {
456+
break;
457+
}
458+
r = queue.front();
459+
queue.pop();
460+
}
461+
462+
callback.call("response=", r);
463+
callback.call("on_solution_callback");
442464

443-
if (callback.attr_get("@stopped")) {
444-
StopSearch(&m);
465+
if (callback.attr_get("@stopped")) {
466+
StopSearch(&m);
467+
return Qnil;
468+
}
445469
}
470+
471+
rb_thread_schedule();
472+
}
473+
} catch (const Rice::Exception& e) {
474+
exception = e;
475+
StopSearch(&m);
476+
} catch (const std::exception& e) {
477+
exception = Rice::Exception(rb_eRuntimeError, e.what());
478+
StopSearch(&m);
479+
}
480+
return Qnil;
481+
};
482+
483+
auto ruby_wrapper = [](void* arg) {
484+
return (*static_cast<decltype(ruby_observer)*>(arg))();
485+
};
486+
487+
if (!callback.is_nil()) {
488+
ruby_thread = rb_thread_create(ruby_wrapper, &ruby_observer);
489+
490+
m.Add(NewFeasibleSolutionObserver(
491+
[&](const CpSolverResponse& r) {
492+
std::lock_guard<std::mutex> guard(queue_lock);
493+
queue.push(r);
446494
})
447495
);
448496
}
449497

450-
m.Add(NewSatParameters(parameters));
451-
return SolveCpModel(model.Build(), &m);
498+
CpSolverResponse r;
499+
Rice::detail::no_gvl([&]() {
500+
r = SolveCpModel(model.Build(), &m);
501+
done = true;
502+
});
503+
504+
if (!callback.is_nil()) {
505+
ruby_thread.call("value");
506+
}
507+
508+
if (exception.has_value()) {
509+
throw exception.value();
510+
}
511+
512+
return r;
452513
})
453514
.define_method(
454515
"_solution_integer_value",

test/constraint_test.rb

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@ def on_solution_callback
6161
end
6262
end
6363

64+
class ExceptionCallback < ORTools::CpSolverSolutionCallback
65+
def on_solution_callback
66+
raise "Error!"
67+
end
68+
end
69+
6470
class ConstraintTest < Minitest::Test
6571
# https://developers.google.com/optimization/cp/cp_solver
6672
def test_cp_sat_solver
@@ -174,6 +180,12 @@ def test_cryptoarithmetic
174180
stop_callback = StopSearchCallback.new
175181
status = solver.solve(model, stop_callback)
176182
assert_equal 3, stop_callback.solution_count
183+
184+
exception_callback = ExceptionCallback.new
185+
error = assert_raises(RuntimeError) do
186+
solver.solve(model, exception_callback)
187+
end
188+
assert_equal "Error!", error.message
177189
end
178190

179191
# https://developers.google.com/optimization/cp/queens

0 commit comments

Comments
 (0)