Skip to content

Commit c989cc9

Browse files
committed
Refactor, use recursion to avoid repeating cases
1 parent d6b34a8 commit c989cc9

File tree

2 files changed

+33
-43
lines changed

2 files changed

+33
-43
lines changed

cpmpy/solvers/pindakaas.py

+32-42
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def transform(self, cpm_expr):
233233
cpm_cons = only_bv_reifies(cpm_cons)
234234
cpm_cons = only_implies(cpm_cons)
235235
cpm_cons = linearize_constraint(
236-
cpm_cons, supported=frozenset({"sum", "wsum", "and", "or", "bv"})
236+
cpm_cons, supported=frozenset({"sum", "wsum", "and", "or"})
237237
)
238238
return cpm_cons
239239

@@ -265,48 +265,36 @@ def __add__(self, cpm_expr_orig):
265265
# transform and post the constraints
266266
try:
267267
for cpm_expr in self.transform(cpm_expr_orig):
268-
if isinstance(cpm_expr, BoolVal):
269-
# base case: Boolean value
270-
if cpm_expr.args[0] is False:
271-
self.pkd_solver.add_clause([])
272-
273-
elif isinstance(cpm_expr, _BoolVarImpl):
274-
# base case, just var or ~var
275-
self.pkd_solver.add_clause([self.solver_var(cpm_expr)])
268+
self._add(cpm_expr)
269+
except pkd.Unsatisfiable:
270+
self.unsatisfiable = True
276271

277-
elif cpm_expr.name == "or":
278-
self.pkd_solver.add_clause(self.solver_vars(cpm_expr.args))
272+
return self
279273

280-
elif cpm_expr.name == "->":
281-
a0, a1 = cpm_expr.args
282-
self._add_bool_linear(a1, conditions=[~a0])
274+
def _add(self, cpm_expr, conditions=[]):
275+
import pindakaas as pkd
283276

284-
elif isinstance(cpm_expr, Comparison):
285-
self._add_bool_linear(cpm_expr)
277+
if isinstance(cpm_expr, BoolVal):
278+
# base case: Boolean value
279+
if cpm_expr.args[0] is False:
280+
self.pkd_solver.add_clause(conditions)
286281

287-
else:
288-
raise NotSupportedError(
289-
f"{self.name}: Unsupported constraint {cpm_expr}"
290-
)
291-
except pkd.Unsatisfiable:
292-
self.unsatisfiable = True
282+
elif isinstance(cpm_expr, _BoolVarImpl): # (implied) literal
283+
self.pkd_solver.add_clause(conditions + [self.solver_var(cpm_expr)])
293284

294-
return self
285+
elif cpm_expr.name == "or": # (implied) clause
286+
self.pkd_solver.add_clause(conditions + self.solver_vars(cpm_expr.args))
295287

296-
""" Unpack implied literal, clause, sum, or weighted sum """
288+
elif cpm_expr.name == "->": # implication
289+
a0, a1 = cpm_expr.args
290+
self._add(a1, conditions=conditions + [~self.solver_var(a0)])
297291

298-
def _add_bool_linear(self, cpm_expr, conditions=[]):
299-
import pindakaas as pkd
292+
elif isinstance(cpm_expr, Comparison): # Bool linear
293+
literals = None
294+
coefficients = None
295+
comparator = None
296+
k = None
300297

301-
literals = None
302-
coefficients = None
303-
comparator = None
304-
k = None
305-
if isinstance(cpm_expr, _BoolVarImpl):
306-
literals = [cpm_expr]
307-
elif isinstance(cpm_expr, Operator) and cpm_expr.name == "or":
308-
literals = cpm_expr.args
309-
elif isinstance(cpm_expr, Comparison):
310298
lhs, k = cpm_expr.args
311299
if lhs.name == "sum":
312300
literals = lhs.args
@@ -325,10 +313,12 @@ def _add_bool_linear(self, cpm_expr, conditions=[]):
325313
else:
326314
raise ValueError(f"Unsupported comparator: {cpm_expr.name}")
327315

328-
self.pkd_solver.add_linear(
329-
self.solver_vars(literals),
330-
coefficients=coefficients,
331-
comparator=comparator,
332-
k=k,
333-
conditions=self.solver_vars(conditions),
334-
)
316+
self.pkd_solver.add_linear(
317+
self.solver_vars(literals),
318+
coefficients=coefficients,
319+
comparator=comparator,
320+
k=k,
321+
conditions=conditions,
322+
)
323+
else:
324+
raise NotSupportedError(f"{self.name}: Unsupported constraint {cpm_expr}")

0 commit comments

Comments
 (0)