Skip to content

Commit 766be56

Browse files
committed
[CP-SAT] add float value to the CpSolver and solution callback classes
1 parent dc267cd commit 766be56

File tree

5 files changed

+142
-16
lines changed

5 files changed

+142
-16
lines changed

ortools/sat/python/cp_model.py

+46
Original file line numberDiff line numberDiff line change
@@ -2521,6 +2521,35 @@ def values(self, variables: _IndexOrSeries) -> pd.Series:
25212521
index=_get_index(variables),
25222522
)
25232523

2524+
def float_value(self, expression: LinearExprT) -> float:
2525+
"""Returns the value of a linear expression after solve."""
2526+
return self._checked_response.float_value(expression)
2527+
2528+
def float_values(self, expressions: _IndexOrSeries) -> pd.Series:
2529+
"""Returns the float values of the input linear expressions.
2530+
2531+
If `expressions` is a `pd.Index`, then the output will be indexed by the
2532+
variables. If `variables` is a `pd.Series` indexed by the underlying
2533+
dimensions, then the output will be indexed by the same underlying
2534+
dimensions.
2535+
2536+
Args:
2537+
expressions (Union[pd.Index, pd.Series]): The set of expressions from
2538+
which to get the values.
2539+
2540+
Returns:
2541+
pd.Series: The values of all variables in the set.
2542+
2543+
Raises:
2544+
RuntimeError: if solve() has not been called.
2545+
"""
2546+
if self.__response_wrapper is None:
2547+
raise RuntimeError("solve() has not been called.")
2548+
return pd.Series(
2549+
data=[self.__response_wrapper.float_value(expr) for expr in expressions],
2550+
index=_get_index(expressions),
2551+
)
2552+
25242553
def boolean_value(self, literal: LiteralT) -> bool:
25252554
"""Returns the boolean value of a literal after solve."""
25262555
return self._checked_response.boolean_value(literal)
@@ -2796,6 +2825,23 @@ def value(self, expression: LinearExprT) -> int:
27962825
raise RuntimeError("solve() has not been called.")
27972826
return self.Value(expression)
27982827

2828+
def float_value(self, expression: LinearExprT) -> float:
2829+
"""Evaluates an linear expression in the current solution.
2830+
2831+
Args:
2832+
expression: a linear expression of the model.
2833+
2834+
Returns:
2835+
An integer value equal to the evaluation of the linear expression
2836+
against the current solution.
2837+
2838+
Raises:
2839+
RuntimeError: if 'expression' is not a LinearExpr.
2840+
"""
2841+
if not self.has_response():
2842+
raise RuntimeError("solve() has not been called.")
2843+
return self.FloatValue(expression)
2844+
27992845
def has_response(self) -> bool:
28002846
return self.HasResponse()
28012847

ortools/sat/python/cp_model_helper.cc

+31-4
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,19 @@ class ResponseWrapper {
163163

164164
double UserTime() const { return response_.user_time(); }
165165

166+
double FloatValue(std::shared_ptr<LinearExpr> expr) const {
167+
FloatExprVisitor visitor;
168+
visitor.AddToProcess(expr, 1);
169+
return visitor.Evaluate(response_);
170+
}
171+
172+
double FixedFloatValue(double value) const { return value; }
173+
166174
int64_t Value(std::shared_ptr<LinearExpr> expr) const {
167-
IntExprVisitor visitor;
168175
int64_t value;
169-
if (!visitor.Evaluate(expr, response_, &value)) {
176+
IntExprVisitor visitor;
177+
visitor.AddToProcess(expr, 1);
178+
if (!visitor.Evaluate(response_, &value)) {
170179
ThrowError(PyExc_ValueError,
171180
absl::StrCat("Failed to evaluate linear expression: ",
172181
expr->DebugString()));
@@ -453,9 +462,10 @@ PYBIND11_MODULE(cp_model_helper, m) {
453462
"Value",
454463
[](const SolutionCallback& callback,
455464
std::shared_ptr<LinearExpr> expr) {
456-
IntExprVisitor visitor;
457465
int64_t value;
458-
if (!visitor.Evaluate(expr, callback.Response(), &value)) {
466+
IntExprVisitor visitor;
467+
visitor.AddToProcess(expr, 1);
468+
if (!visitor.Evaluate(callback.Response(), &value)) {
459469
ThrowError(PyExc_ValueError,
460470
absl::StrCat("Failed to evaluate linear expression: ",
461471
expr->DebugString()));
@@ -466,6 +476,21 @@ PYBIND11_MODULE(cp_model_helper, m) {
466476
.def(
467477
"Value", [](const SolutionCallback&, int64_t value) { return value; },
468478
"Returns the value of a linear expression after solve.")
479+
.def(
480+
"FloatValue",
481+
[](const SolutionCallback& callback,
482+
std::shared_ptr<LinearExpr> expr) {
483+
FloatExprVisitor visitor;
484+
visitor.AddToProcess(expr, 1.0);
485+
return visitor.Evaluate(callback.Response());
486+
},
487+
"Returns the value of a floating point linear expression after "
488+
"solve.")
489+
.def(
490+
"FloatValue",
491+
[](const SolutionCallback&, double value) { return value; },
492+
"Returns the value of a floating point linear expression after "
493+
"solve.")
469494
.def(
470495
"BooleanValue",
471496
[](const SolutionCallback& callback, std::shared_ptr<Literal> lit) {
@@ -495,6 +520,8 @@ PYBIND11_MODULE(cp_model_helper, m) {
495520
.def("sufficient_assumptions_for_infeasibility",
496521
&ResponseWrapper::SufficientAssumptionsForInfeasibility)
497522
.def("user_time", &ResponseWrapper::UserTime)
523+
.def("float_value", &ResponseWrapper::FloatValue, py::arg("expr"))
524+
.def("float_value", &ResponseWrapper::FixedFloatValue, py::arg("value"))
498525
.def("value", &ResponseWrapper::Value, py::arg("expr"))
499526
.def("value", &ResponseWrapper::FixedValue, py::arg("value"))
500527
.def("wall_time", &ResponseWrapper::WallTime);

ortools/sat/python/cp_model_test.py

+39
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,22 @@ def sum(self) -> int:
5757
return self.__sum
5858

5959

60+
class SolutionFloatValue(cp_model.CpSolverSolutionCallback):
61+
"""Record the evaluation of a float expression in the solution."""
62+
63+
def __init__(self, expr: cp_model.LinearExpr) -> None:
64+
cp_model.CpSolverSolutionCallback.__init__(self)
65+
self.__expr: cp_model.LinearExpr = expr
66+
self.__value: float = 0.0
67+
68+
def on_solution_callback(self) -> None:
69+
self.__value = self.float_value(self.__expr)
70+
71+
@property
72+
def value(self) -> float:
73+
return self.__value
74+
75+
6076
class SolutionObjective(cp_model.CpSolverSolutionCallback):
6177
"""Record the objective value of the solution."""
6278

@@ -1515,6 +1531,18 @@ def test_solve_with_solution_callback(self) -> None:
15151531
self.assertEqual(cp_model.OPTIMAL, status)
15161532
self.assertEqual(6, solution_sum.sum)
15171533

1534+
def test_solve_with_float_value_in_callback(self) -> None:
1535+
model = cp_model.CpModel()
1536+
x = model.new_int_var(0, 5, "x")
1537+
y = model.new_int_var(0, 5, "y")
1538+
model.add_linear_constraint(x + y, 6, 6)
1539+
1540+
solver = cp_model.CpSolver()
1541+
solution_float_value = SolutionFloatValue((x + y) * 0.5)
1542+
status = solver.solve(model, solution_float_value)
1543+
self.assertEqual(cp_model.OPTIMAL, status)
1544+
self.assertEqual(3.0, solution_float_value.value)
1545+
15181546
def test_best_bound_callback(self) -> None:
15191547
model = cp_model.CpModel()
15201548
x0 = model.new_bool_var("x0")
@@ -1545,6 +1573,17 @@ def test_value(self) -> None:
15451573
self.assertEqual(solver.value(y), 10)
15461574
self.assertEqual(solver.value(2), 2)
15471575

1576+
def test_float_value(self) -> None:
1577+
model = cp_model.CpModel()
1578+
x = model.new_int_var(0, 10, "x")
1579+
y = model.new_int_var(0, 10, "y")
1580+
model.add(x + 2 * y == 29)
1581+
solver = cp_model.CpSolver()
1582+
status = solver.solve(model)
1583+
self.assertEqual(cp_model.OPTIMAL, status)
1584+
self.assertEqual(solver.float_value(x * 1.5 + 0.25), 13.75)
1585+
self.assertEqual(solver.float_value(2.25), 2.25)
1586+
15481587
def test_boolean_value(self) -> None:
15491588
model = cp_model.CpModel()
15501589
x = model.new_bool_var("x")

ortools/sat/python/linear_expr.cc

+22-8
Original file line numberDiff line numberDiff line change
@@ -130,20 +130,25 @@ void FloatExprVisitor::AddToProcess(std::shared_ptr<LinearExpr> expr,
130130
double coeff) {
131131
to_process_.push_back(std::make_pair(expr, coeff));
132132
}
133+
133134
void FloatExprVisitor::AddConstant(double constant) { offset_ += constant; }
135+
134136
void FloatExprVisitor::AddVarCoeff(std::shared_ptr<BaseIntVar> var,
135137
double coeff) {
136138
canonical_terms_[var] += coeff;
137139
}
138-
double FloatExprVisitor::Process(std::shared_ptr<LinearExpr> expr,
139-
std::vector<std::shared_ptr<BaseIntVar>>* vars,
140-
std::vector<double>* coeffs) {
141-
AddToProcess(expr, 1.0);
140+
141+
void FloatExprVisitor::ProcessAll() {
142142
while (!to_process_.empty()) {
143143
const auto [expr, coeff] = to_process_.back();
144144
to_process_.pop_back();
145145
expr->VisitAsFloat(*this, coeff);
146146
}
147+
}
148+
149+
double FloatExprVisitor::Process(std::vector<std::shared_ptr<BaseIntVar>>* vars,
150+
std::vector<double>* coeffs) {
151+
ProcessAll();
147152

148153
vars->clear();
149154
coeffs->clear();
@@ -156,9 +161,20 @@ double FloatExprVisitor::Process(std::shared_ptr<LinearExpr> expr,
156161
return offset_;
157162
}
158163

164+
double FloatExprVisitor::Evaluate(const CpSolverResponse& solution) {
165+
ProcessAll();
166+
167+
for (const auto& [var, coeff] : canonical_terms_) {
168+
if (coeff == 0) continue;
169+
offset_ += coeff * solution.solution(var->index());
170+
}
171+
return offset_;
172+
}
173+
159174
FlatFloatExpr::FlatFloatExpr(std::shared_ptr<LinearExpr> expr) {
160175
FloatExprVisitor lin;
161-
offset_ = lin.Process(expr, &vars_, &coeffs_);
176+
lin.AddToProcess(expr, 1.0);
177+
offset_ = lin.Process(&vars_, &coeffs_);
162178
}
163179

164180
void FlatFloatExpr::VisitAsFloat(FloatExprVisitor& lin, double c) {
@@ -735,10 +751,8 @@ bool IntExprVisitor::Process(std::vector<std::shared_ptr<BaseIntVar>>* vars,
735751
return true;
736752
}
737753

738-
bool IntExprVisitor::Evaluate(std::shared_ptr<LinearExpr> expr,
739-
const CpSolverResponse& solution,
754+
bool IntExprVisitor::Evaluate(const CpSolverResponse& solution,
740755
int64_t* value) {
741-
AddToProcess(expr, 1);
742756
if (!ProcessAll()) return false;
743757

744758
*value = offset_;

ortools/sat/python/linear_expr.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,10 @@ class FloatExprVisitor {
160160
void AddToProcess(std::shared_ptr<LinearExpr> expr, double coeff);
161161
void AddConstant(double constant);
162162
void AddVarCoeff(std::shared_ptr<BaseIntVar> var, double coeff);
163-
double Process(std::shared_ptr<LinearExpr> expr,
164-
std::vector<std::shared_ptr<BaseIntVar>>* vars,
163+
void ProcessAll();
164+
double Process(std::vector<std::shared_ptr<BaseIntVar>>* vars,
165165
std::vector<double>* coeffs);
166+
double Evaluate(const CpSolverResponse& solution);
166167

167168
private:
168169
std::vector<std::pair<std::shared_ptr<LinearExpr>, double>> to_process_;
@@ -212,8 +213,7 @@ class IntExprVisitor {
212213
bool ProcessAll();
213214
bool Process(std::vector<std::shared_ptr<BaseIntVar>>* vars,
214215
std::vector<int64_t>* coeffs, int64_t* offset);
215-
bool Evaluate(std::shared_ptr<LinearExpr> expr,
216-
const CpSolverResponse& solution, int64_t* value);
216+
bool Evaluate(const CpSolverResponse& solution, int64_t* value);
217217

218218
private:
219219
std::vector<std::pair<std::shared_ptr<LinearExpr>, int64_t>> to_process_;

0 commit comments

Comments
 (0)