Skip to content

Commit 1a61996

Browse files
committed
int2bool transformation
1 parent 6ed4537 commit 1a61996

File tree

3 files changed

+315
-16
lines changed

3 files changed

+315
-16
lines changed

cpmpy/solvers/pysat.py

+28-12
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,12 @@
5151
Module details
5252
==============
5353
"""
54+
from threading import Timer
55+
5456
from .solver_interface import SolverInterface, SolverStatus, ExitStatus
5557
from ..exceptions import NotSupportedError
5658
from ..expressions.core import Comparison, Operator, BoolVal
57-
from ..expressions.variables import _BoolVarImpl, NegBoolView, boolvar
59+
from ..expressions.variables import _BoolVarImpl, _IntVarImpl, NegBoolView, boolvar
5860
from ..expressions.globalconstraints import DirectConstraint
5961
from ..transformations.linearize import canonical_comparison, only_positive_coefficients
6062
from ..expressions.utils import is_int, flatlist
@@ -65,6 +67,7 @@
6567
from ..transformations.linearize import linearize_constraint
6668
from ..transformations.normalize import toplevel_list, simplify_boolean
6769
from ..transformations.reification import only_implies, only_bv_reifies, reify_rewrite
70+
from ..transformations.int2bool import int2bool, int2bool_make
6871

6972

7073
class CPM_pysat(SolverInterface):
@@ -161,6 +164,7 @@ def __init__(self, cpm_model=None, subsolver=None):
161164
# initialise the native solver object
162165
self.pysat_vpool = IDPool()
163166
self.pysat_solver = Solver(use_timer=True, name=subsolver)
167+
self.ivarmap = dict() # for the integer to boolean encoders
164168

165169
# initialise everything else and post the constraints/objective
166170
super().__init__(name="pysat:"+subsolver, cpm_model=cpm_model)
@@ -188,22 +192,25 @@ def solve(self, time_limit=None, assumptions=None):
188192
Note: the PySAT interface is statefull, so you can incrementally call solve() with assumptions and it will reuse learned clauses
189193
"""
190194

191-
# ensure all vars are known to solver
192-
self.solver_vars(list(self.user_vars))
195+
# ensure all Boolean vars are known to solver
196+
for v in list(self.user_vars): # can change during iteration
197+
if isinstance(v, _BoolVarImpl):
198+
self.solver_vars(v)
199+
else: # intvar
200+
if not v in self.ivarmap:
201+
self += int2bool_make(self.ivarmap, v)
193202

194203
if assumptions is None:
195204
pysat_assum_vars = [] # default if no assumptions
196205
else:
197206
pysat_assum_vars = self.solver_vars(assumptions)
198207
self.assumption_vars = assumptions
199208

200-
import time
201209
# set time limit
202210
if time_limit is not None:
203211
if time_limit <= 0:
204212
raise ValueError("Time limit must be positive")
205213

206-
from threading import Timer
207214
t = Timer(time_limit, lambda s: s.interrupt(), [self.pysat_solver])
208215
t.start()
209216
my_status = self.pysat_solver.solve_limited(assumptions=pysat_assum_vars, expect_interrupt=True)
@@ -237,13 +244,21 @@ def solve(self, time_limit=None, assumptions=None):
237244
sol = frozenset(self.pysat_solver.get_model()) # to speed up lookup
238245
# fill in variable values
239246
for cpm_var in self.user_vars:
240-
lit = self.solver_var(cpm_var)
241-
if lit in sol:
242-
cpm_var._value = True
243-
elif -lit in sol:
244-
cpm_var._value = False
245-
else: # not specified, dummy val
246-
cpm_var._value = True
247+
if isinstance(cpm_var, _BoolVarImpl):
248+
lit = self.solver_var(cpm_var)
249+
if lit in sol:
250+
cpm_var._value = True
251+
else: # -lit in sol (=False) or not specified (=False)
252+
cpm_var._value = False
253+
elif isinstance(cpm_var, _IntVarImpl):
254+
assert cpm_var.name in self.ivarmap, "Integer variable %s not found in ivarenc" % cpm_var.name
255+
varenc = self.ivarmap[cpm_var.name]
256+
lits = self.solver_vars(varenc.vars())
257+
# default value=False
258+
vals = [lit in sol for lit in lits]
259+
cpm_var._value = varenc.decode(vals)
260+
else:
261+
raise NotImplementedError(f"CPM_pysat: variable {cpm_var} not supported")
247262

248263
else: # clear values of variables
249264
for cpm_var in self.user_vars:
@@ -302,6 +317,7 @@ def transform(self, cpm_expr):
302317
cpm_cons = only_bv_reifies(cpm_cons)
303318
cpm_cons = only_implies(cpm_cons)
304319
cpm_cons = linearize_constraint(cpm_cons, supported=frozenset({"sum","wsum", "and", "or"})) # the core of the MIP-linearization
320+
cpm_cons = int2bool(cpm_cons, ivarmap=self.ivarmap)
305321
cpm_cons = only_positive_coefficients(cpm_cons)
306322
return cpm_cons
307323

cpmpy/transformations/int2bool.py

+283
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
"""
2+
Convert integer linear constraints to pseudo-boolean constraints
3+
"""
4+
5+
from typing import List
6+
import cpmpy as cp
7+
from abc import ABC, abstractmethod
8+
from ..expressions.variables import _BoolVarImpl, _IntVarImpl
9+
from ..expressions.core import Comparison, Operator
10+
from ..transformations.get_variables import get_variables
11+
from ..expressions.core import Expression
12+
13+
def int2bool(cpm_lst: List[Expression], ivarmap=None, encoding="auto"):
14+
"""
15+
Convert integer linear constraints to pseudo-boolean constraints
16+
"""
17+
assert encoding in ("auto", "direct"), "Only auto or direct encoding is supported"
18+
if ivarmap is None:
19+
ivarmap = dict()
20+
21+
cpm_out = []
22+
for expr in cpm_lst:
23+
vs = get_variables(expr)
24+
# skip all Boolean expressions
25+
if all(isinstance(v, _BoolVarImpl) for v in vs):
26+
cpm_out.append(expr)
27+
continue
28+
29+
# check if all variables are in the ivarmap
30+
for v in vs:
31+
if type(v) is _IntVarImpl and v.name not in ivarmap:
32+
cons = int2bool_make(ivarmap, v, encoding, cpm_out)
33+
cpm_out.extend(cons)
34+
35+
# we also need to support b -> subexpr
36+
# where subexpr's transformation is identical to non-reified expr
37+
# we do this with a special flag
38+
is_halfreif = False
39+
if expr.name == "->":
40+
is_halfreif = True
41+
b = expr.args[0] # PAY ATTENTION: we will overwrite expr by the rhs of the ->
42+
expr = expr.args[1]
43+
44+
# now replace intvars with their encoding
45+
if isinstance(expr, Comparison):
46+
# special case: lhs is a single intvar
47+
lhs,rhs = expr.args
48+
if type(lhs) is _IntVarImpl:
49+
cons = ivarmap[lhs.name].encode_comparison(expr.name, rhs)
50+
if is_halfreif:
51+
cpm_out.extend([b.implies(c) for c in cons])
52+
else:
53+
cpm_out.extend(cons)
54+
elif lhs.name == "wsum":
55+
# if its a wsum, insert encoding of terms
56+
newweights = []
57+
newvars = []
58+
for w,v in zip(*lhs.args):
59+
if type(v) is _IntVarImpl:
60+
# get list of weights/vars to add
61+
ws,vs = ivarmap[v.name].encode_term(w)
62+
newweights.extend(ws)
63+
newvars.extend(vs)
64+
else:
65+
newweights.append(w)
66+
newvars.append(v)
67+
# make the new comparison over the new wsum
68+
expr = Comparison(expr.name, Operator("wsum", (newweights, newvars)), rhs)
69+
if is_halfreif:
70+
cpm_out.append(b.implies(expr))
71+
else:
72+
cpm_out.append(expr)
73+
elif lhs.name == "sum":
74+
if len(lhs.args) == 1:
75+
assert type(lhs.args[0]) is _IntVarImpl, "Expected single intvar in sum"
76+
cons = ivarmap[lhs.args[0].name].encode_comparison(expr.name, rhs)
77+
if is_halfreif:
78+
cpm_out.extend([b.implies(c) for c in cons])
79+
else:
80+
cpm_out.extend(cons)
81+
else:
82+
# need to translate to wsum and insert encoding of terms
83+
newweights = []
84+
newvars = []
85+
for v in lhs.args:
86+
if type(v) is _IntVarImpl:
87+
ws,vs = ivarmap[v.name].encode_term()
88+
newweights.extend(ws)
89+
newvars.extend(vs)
90+
else:
91+
newweights.append(1)
92+
newvars.append(v)
93+
# make the new comparison over the new wsum
94+
expr = Comparison(expr.name, Operator("wsum", (newweights, newvars)), rhs)
95+
if is_halfreif:
96+
cpm_out.append(b.implies(expr))
97+
else:
98+
cpm_out.append(expr)
99+
else:
100+
raise NotImplementedError(f"int2bool: comparison with lhs {lhs} not (yet?) supported")
101+
else:
102+
raise NotImplementedError(f"int2bool: non-comparison {expr} not (yet?) supported")
103+
104+
return cpm_out
105+
106+
def int2bool_wsum(expr: Expression, ivarmap, encoding="auto"):
107+
"""
108+
Convert a weighted sum to a pseudo-boolean constraint
109+
110+
Accepts only bool/int/sum/wsum expressions
111+
112+
Returns (newexpr, newcons)
113+
"""
114+
vs = get_variables(expr)
115+
# skip all Boolean expressions
116+
if all(isinstance(v, _BoolVarImpl) for v in vs):
117+
return expr, []
118+
119+
# check if all variables are in the ivarmap, add constraints if not
120+
newcons = []
121+
for v in vs:
122+
if type(v) is _IntVarImpl and v.name not in ivarmap:
123+
cons = int2bool_make(ivarmap, v, encoding)
124+
newcons.extend(cons)
125+
126+
if isinstance(expr, _IntVarImpl):
127+
ws,vs = ivarmap[expr.name].encode_term()
128+
return Operator("wsum", (ws, vs)), newcons
129+
130+
# rest: sum or wsum
131+
if expr.name == "sum":
132+
w = [1]*len(expr.args)
133+
v = expr.args
134+
elif expr.name == "wsum":
135+
w,v = expr.args
136+
else:
137+
raise NotImplementedError(f"int2bool_wsum: non-sum/wsum expression {expr} not supported")
138+
139+
new_w, new_v = [], []
140+
for wi,vi in zip(w,v):
141+
if type(vi) is _IntVarImpl:
142+
# get list of weights/vars to add
143+
ws,vs = ivarmap[vi.name].encode_term(wi)
144+
new_w.extend(ws)
145+
new_v.extend(vs)
146+
else:
147+
new_w.append(wi)
148+
new_v.append(vi)
149+
150+
return Operator("wsum", (new_w, new_v)), newcons
151+
152+
153+
def int2bool_make(ivarmap, v, encoding="auto", expr=None):
154+
"""
155+
Make the encoding for an integer variable
156+
"""
157+
# for now, the only encoding is 'direct', so we dont inspect 'expr' at all
158+
enc = IntVarEncDirect(v)
159+
ivarmap[v.name] = enc
160+
return enc.encode_self()
161+
162+
class IntVarEnc(ABC):
163+
"""
164+
Abstract base class for integer variable encodings.
165+
"""
166+
def __init__(self, varname):
167+
self.varname = varname
168+
169+
@abstractmethod
170+
def vars(self):
171+
"""
172+
Return the Boolean variables in the encoding.
173+
"""
174+
pass
175+
176+
def decode(self, vals):
177+
"""
178+
Decode the Boolean values to the integer value.
179+
"""
180+
pass
181+
182+
@abstractmethod
183+
def encode_self(self):
184+
"""
185+
Return consistency constraints for the encoding.
186+
187+
Returns:
188+
List[Expression]: a list of constraints
189+
"""
190+
pass
191+
192+
@abstractmethod
193+
def encode_comparison(self, op, rhs):
194+
"""
195+
Encode a comparison over the variable: self <op> rhs
196+
197+
Args:
198+
op: The comparison operator ("==", "!=", "<", "<=", ">", ">=")
199+
rhs: The right-hand side value to compare against
200+
201+
Returns:
202+
List[Expression]: a list of constraints
203+
"""
204+
pass
205+
206+
@abstractmethod
207+
def encode_term(self, w=1):
208+
"""
209+
Encode w*self as a weighted sum of Boolean variables
210+
211+
Args:
212+
w: The weight to multiply the variable by
213+
214+
Returns:
215+
tuple: (weights, variables) where weights is a list of weights and
216+
variables is a list of Boolean variables
217+
"""
218+
pass
219+
220+
class IntVarEncDirect(IntVarEnc):
221+
"""
222+
Direct (or sparse or one-hot) encoding of an integer variable.
223+
224+
Uses a Boolean 'equality' variable for each value in the domain.
225+
"""
226+
def __init__(self, v):
227+
super().__init__(v.name)
228+
self.offset = v.lb
229+
n = v.ub+1-v.lb # number of Boolean variables
230+
self.bvars = cp.boolvar(shape=n, name=f"EncDir({v.name})")
231+
232+
def vars(self):
233+
return self.bvars
234+
235+
def decode(self, vals):
236+
"""
237+
Decode the Boolean values to the integer value.
238+
"""
239+
assert sum(vals) == 1, f"Expected exactly one True value in vals: {vals}"
240+
return sum(i for i,v in enumerate(vals) if v) + self.offset
241+
242+
def encode_self(self):
243+
"""
244+
Return consistency constraints
245+
246+
Variable x has exactly one value from domain,
247+
so only one of the Boolean variables can be True
248+
"""
249+
return [cp.sum(self.bvars) == 1]
250+
251+
def encode_comparison(self, op, rhs):
252+
"""
253+
Encode a comparison over the variable: self <op> rhs
254+
"""
255+
if op == "==":
256+
# one yes, hence also rest no
257+
return [b if i==(rhs-self.offset) else ~b for i,b in enumerate(self.bvars)]
258+
elif op == "!=":
259+
return [~self.bvars[rhs - self.offset]]
260+
elif op == "<":
261+
# all higher-or-equal values are False
262+
return list(~self.bvars[rhs-self.offset:])
263+
elif op == "<=":
264+
# all higher values are False
265+
return list(~self.bvars[rhs-self.offset+1:])
266+
elif op == ">":
267+
# all lower values are False
268+
return list(~self.bvars[:rhs-self.offset+1])
269+
elif op == ">=":
270+
# all lower-or-equal values are False
271+
return list(~self.bvars[:rhs-self.offset])
272+
else:
273+
raise NotImplementedError(f"int2bool: comparison with op {op} unknown")
274+
275+
def encode_term(self, w=1):
276+
"""
277+
Rewrite term w*self to terms [w1, w2 ,...]*[bv1, bv2, ...]
278+
"""
279+
o = self.offset
280+
return [w*(o+i) for i in range(len(self.bvars))], self.bvars
281+
282+
# TODO: class IntVarEncOrder(IntVarEnc)
283+
# TODO: class IntVarEncLog(IntVarEnc)

0 commit comments

Comments
 (0)