1
+ """
2
+ Convert integer linear constraints to pseudo-boolean constraints
3
+ """
4
+
5
+ from typing import List
6
+ import cpmpy as cp
7
+ from abc import ABC , abstractmethod
8
+ from ..expressions .variables import _BoolVarImpl , _IntVarImpl
9
+ from ..expressions .core import Comparison , Operator
10
+ from ..transformations .get_variables import get_variables
11
+ from ..expressions .core import Expression
12
+
13
+ def int2bool (cpm_lst : List [Expression ], ivarmap = None , encoding = "auto" ):
14
+ """
15
+ Convert integer linear constraints to pseudo-boolean constraints
16
+ """
17
+ assert encoding in ("auto" , "direct" ), "Only auto or direct encoding is supported"
18
+ if ivarmap is None :
19
+ ivarmap = dict ()
20
+
21
+ cpm_out = []
22
+ for expr in cpm_lst :
23
+ vs = get_variables (expr )
24
+ # skip all Boolean expressions
25
+ if all (isinstance (v , _BoolVarImpl ) for v in vs ):
26
+ cpm_out .append (expr )
27
+ continue
28
+
29
+ # check if all variables are in the ivarmap
30
+ for v in vs :
31
+ if type (v ) is _IntVarImpl and v .name not in ivarmap :
32
+ cons = int2bool_make (ivarmap , v , encoding , cpm_out )
33
+ cpm_out .extend (cons )
34
+
35
+ # we also need to support b -> subexpr
36
+ # where subexpr's transformation is identical to non-reified expr
37
+ # we do this with a special flag
38
+ is_halfreif = False
39
+ if expr .name == "->" :
40
+ is_halfreif = True
41
+ b = expr .args [0 ] # PAY ATTENTION: we will overwrite expr by the rhs of the ->
42
+ expr = expr .args [1 ]
43
+
44
+ # now replace intvars with their encoding
45
+ if isinstance (expr , Comparison ):
46
+ # special case: lhs is a single intvar
47
+ lhs ,rhs = expr .args
48
+ if type (lhs ) is _IntVarImpl :
49
+ cons = ivarmap [lhs .name ].encode_comparison (expr .name , rhs )
50
+ if is_halfreif :
51
+ cpm_out .extend ([b .implies (c ) for c in cons ])
52
+ else :
53
+ cpm_out .extend (cons )
54
+ elif lhs .name == "wsum" :
55
+ # if its a wsum, insert encoding of terms
56
+ newweights = []
57
+ newvars = []
58
+ for w ,v in zip (* lhs .args ):
59
+ if type (v ) is _IntVarImpl :
60
+ # get list of weights/vars to add
61
+ ws ,vs = ivarmap [v .name ].encode_term (w )
62
+ newweights .extend (ws )
63
+ newvars .extend (vs )
64
+ else :
65
+ newweights .append (w )
66
+ newvars .append (v )
67
+ # make the new comparison over the new wsum
68
+ expr = Comparison (expr .name , Operator ("wsum" , (newweights , newvars )), rhs )
69
+ if is_halfreif :
70
+ cpm_out .append (b .implies (expr ))
71
+ else :
72
+ cpm_out .append (expr )
73
+ elif lhs .name == "sum" :
74
+ if len (lhs .args ) == 1 :
75
+ assert type (lhs .args [0 ]) is _IntVarImpl , "Expected single intvar in sum"
76
+ cons = ivarmap [lhs .args [0 ].name ].encode_comparison (expr .name , rhs )
77
+ if is_halfreif :
78
+ cpm_out .extend ([b .implies (c ) for c in cons ])
79
+ else :
80
+ cpm_out .extend (cons )
81
+ else :
82
+ # need to translate to wsum and insert encoding of terms
83
+ newweights = []
84
+ newvars = []
85
+ for v in lhs .args :
86
+ if type (v ) is _IntVarImpl :
87
+ ws ,vs = ivarmap [v .name ].encode_term ()
88
+ newweights .extend (ws )
89
+ newvars .extend (vs )
90
+ else :
91
+ newweights .append (1 )
92
+ newvars .append (v )
93
+ # make the new comparison over the new wsum
94
+ expr = Comparison (expr .name , Operator ("wsum" , (newweights , newvars )), rhs )
95
+ if is_halfreif :
96
+ cpm_out .append (b .implies (expr ))
97
+ else :
98
+ cpm_out .append (expr )
99
+ else :
100
+ raise NotImplementedError (f"int2bool: comparison with lhs { lhs } not (yet?) supported" )
101
+ else :
102
+ raise NotImplementedError (f"int2bool: non-comparison { expr } not (yet?) supported" )
103
+
104
+ return cpm_out
105
+
106
+ def int2bool_wsum (expr : Expression , ivarmap , encoding = "auto" ):
107
+ """
108
+ Convert a weighted sum to a pseudo-boolean constraint
109
+
110
+ Accepts only bool/int/sum/wsum expressions
111
+
112
+ Returns (newexpr, newcons)
113
+ """
114
+ vs = get_variables (expr )
115
+ # skip all Boolean expressions
116
+ if all (isinstance (v , _BoolVarImpl ) for v in vs ):
117
+ return expr , []
118
+
119
+ # check if all variables are in the ivarmap, add constraints if not
120
+ newcons = []
121
+ for v in vs :
122
+ if type (v ) is _IntVarImpl and v .name not in ivarmap :
123
+ cons = int2bool_make (ivarmap , v , encoding )
124
+ newcons .extend (cons )
125
+
126
+ if isinstance (expr , _IntVarImpl ):
127
+ ws ,vs = ivarmap [expr .name ].encode_term ()
128
+ return Operator ("wsum" , (ws , vs )), newcons
129
+
130
+ # rest: sum or wsum
131
+ if expr .name == "sum" :
132
+ w = [1 ]* len (expr .args )
133
+ v = expr .args
134
+ elif expr .name == "wsum" :
135
+ w ,v = expr .args
136
+ else :
137
+ raise NotImplementedError (f"int2bool_wsum: non-sum/wsum expression { expr } not supported" )
138
+
139
+ new_w , new_v = [], []
140
+ for wi ,vi in zip (w ,v ):
141
+ if type (vi ) is _IntVarImpl :
142
+ # get list of weights/vars to add
143
+ ws ,vs = ivarmap [vi .name ].encode_term (wi )
144
+ new_w .extend (ws )
145
+ new_v .extend (vs )
146
+ else :
147
+ new_w .append (wi )
148
+ new_v .append (vi )
149
+
150
+ return Operator ("wsum" , (new_w , new_v )), newcons
151
+
152
+
153
+ def int2bool_make (ivarmap , v , encoding = "auto" , expr = None ):
154
+ """
155
+ Make the encoding for an integer variable
156
+ """
157
+ # for now, the only encoding is 'direct', so we dont inspect 'expr' at all
158
+ enc = IntVarEncDirect (v )
159
+ ivarmap [v .name ] = enc
160
+ return enc .encode_self ()
161
+
162
+ class IntVarEnc (ABC ):
163
+ """
164
+ Abstract base class for integer variable encodings.
165
+ """
166
+ def __init__ (self , varname ):
167
+ self .varname = varname
168
+
169
+ @abstractmethod
170
+ def vars (self ):
171
+ """
172
+ Return the Boolean variables in the encoding.
173
+ """
174
+ pass
175
+
176
+ def decode (self , vals ):
177
+ """
178
+ Decode the Boolean values to the integer value.
179
+ """
180
+ pass
181
+
182
+ @abstractmethod
183
+ def encode_self (self ):
184
+ """
185
+ Return consistency constraints for the encoding.
186
+
187
+ Returns:
188
+ List[Expression]: a list of constraints
189
+ """
190
+ pass
191
+
192
+ @abstractmethod
193
+ def encode_comparison (self , op , rhs ):
194
+ """
195
+ Encode a comparison over the variable: self <op> rhs
196
+
197
+ Args:
198
+ op: The comparison operator ("==", "!=", "<", "<=", ">", ">=")
199
+ rhs: The right-hand side value to compare against
200
+
201
+ Returns:
202
+ List[Expression]: a list of constraints
203
+ """
204
+ pass
205
+
206
+ @abstractmethod
207
+ def encode_term (self , w = 1 ):
208
+ """
209
+ Encode w*self as a weighted sum of Boolean variables
210
+
211
+ Args:
212
+ w: The weight to multiply the variable by
213
+
214
+ Returns:
215
+ tuple: (weights, variables) where weights is a list of weights and
216
+ variables is a list of Boolean variables
217
+ """
218
+ pass
219
+
220
+ class IntVarEncDirect (IntVarEnc ):
221
+ """
222
+ Direct (or sparse or one-hot) encoding of an integer variable.
223
+
224
+ Uses a Boolean 'equality' variable for each value in the domain.
225
+ """
226
+ def __init__ (self , v ):
227
+ super ().__init__ (v .name )
228
+ self .offset = v .lb
229
+ n = v .ub + 1 - v .lb # number of Boolean variables
230
+ self .bvars = cp .boolvar (shape = n , name = f"EncDir({ v .name } )" )
231
+
232
+ def vars (self ):
233
+ return self .bvars
234
+
235
+ def decode (self , vals ):
236
+ """
237
+ Decode the Boolean values to the integer value.
238
+ """
239
+ assert sum (vals ) == 1 , f"Expected exactly one True value in vals: { vals } "
240
+ return sum (i for i ,v in enumerate (vals ) if v ) + self .offset
241
+
242
+ def encode_self (self ):
243
+ """
244
+ Return consistency constraints
245
+
246
+ Variable x has exactly one value from domain,
247
+ so only one of the Boolean variables can be True
248
+ """
249
+ return [cp .sum (self .bvars ) == 1 ]
250
+
251
+ def encode_comparison (self , op , rhs ):
252
+ """
253
+ Encode a comparison over the variable: self <op> rhs
254
+ """
255
+ if op == "==" :
256
+ # one yes, hence also rest no
257
+ return [b if i == (rhs - self .offset ) else ~ b for i ,b in enumerate (self .bvars )]
258
+ elif op == "!=" :
259
+ return [~ self .bvars [rhs - self .offset ]]
260
+ elif op == "<" :
261
+ # all higher-or-equal values are False
262
+ return list (~ self .bvars [rhs - self .offset :])
263
+ elif op == "<=" :
264
+ # all higher values are False
265
+ return list (~ self .bvars [rhs - self .offset + 1 :])
266
+ elif op == ">" :
267
+ # all lower values are False
268
+ return list (~ self .bvars [:rhs - self .offset + 1 ])
269
+ elif op == ">=" :
270
+ # all lower-or-equal values are False
271
+ return list (~ self .bvars [:rhs - self .offset ])
272
+ else :
273
+ raise NotImplementedError (f"int2bool: comparison with op { op } unknown" )
274
+
275
+ def encode_term (self , w = 1 ):
276
+ """
277
+ Rewrite term w*self to terms [w1, w2 ,...]*[bv1, bv2, ...]
278
+ """
279
+ o = self .offset
280
+ return [w * (o + i ) for i in range (len (self .bvars ))], self .bvars
281
+
282
+ # TODO: class IntVarEncOrder(IntVarEnc)
283
+ # TODO: class IntVarEncLog(IntVarEnc)
0 commit comments