@@ -233,7 +233,7 @@ def transform(self, cpm_expr):
233
233
cpm_cons = only_bv_reifies (cpm_cons )
234
234
cpm_cons = only_implies (cpm_cons )
235
235
cpm_cons = linearize_constraint (
236
- cpm_cons , supported = frozenset ({"sum" , "wsum" , "and" , "or" , "bv" })
236
+ cpm_cons , supported = frozenset ({"sum" , "wsum" , "and" , "or" })
237
237
)
238
238
return cpm_cons
239
239
@@ -265,48 +265,36 @@ def __add__(self, cpm_expr_orig):
265
265
# transform and post the constraints
266
266
try :
267
267
for cpm_expr in self .transform (cpm_expr_orig ):
268
- if isinstance (cpm_expr , BoolVal ):
269
- # base case: Boolean value
270
- if cpm_expr .args [0 ] is False :
271
- self .pkd_solver .add_clause ([])
272
-
273
- elif isinstance (cpm_expr , _BoolVarImpl ):
274
- # base case, just var or ~var
275
- self .pkd_solver .add_clause ([self .solver_var (cpm_expr )])
268
+ self ._add (cpm_expr )
269
+ except pkd .Unsatisfiable :
270
+ self .unsatisfiable = True
276
271
277
- elif cpm_expr .name == "or" :
278
- self .pkd_solver .add_clause (self .solver_vars (cpm_expr .args ))
272
+ return self
279
273
280
- elif cpm_expr .name == "->" :
281
- a0 , a1 = cpm_expr .args
282
- self ._add_bool_linear (a1 , conditions = [~ a0 ])
274
+ def _add (self , cpm_expr , conditions = []):
275
+ import pindakaas as pkd
283
276
284
- elif isinstance (cpm_expr , Comparison ):
285
- self ._add_bool_linear (cpm_expr )
277
+ if isinstance (cpm_expr , BoolVal ):
278
+ # base case: Boolean value
279
+ if cpm_expr .args [0 ] is False :
280
+ self .pkd_solver .add_clause (conditions )
286
281
287
- else :
288
- raise NotSupportedError (
289
- f"{ self .name } : Unsupported constraint { cpm_expr } "
290
- )
291
- except pkd .Unsatisfiable :
292
- self .unsatisfiable = True
282
+ elif isinstance (cpm_expr , _BoolVarImpl ): # (implied) literal
283
+ self .pkd_solver .add_clause (conditions + [self .solver_var (cpm_expr )])
293
284
294
- return self
285
+ elif cpm_expr .name == "or" : # (implied) clause
286
+ self .pkd_solver .add_clause (conditions + self .solver_vars (cpm_expr .args ))
295
287
296
- """ Unpack implied literal, clause, sum, or weighted sum """
288
+ elif cpm_expr .name == "->" : # implication
289
+ a0 , a1 = cpm_expr .args
290
+ self ._add (a1 , conditions = conditions + [~ self .solver_var (a0 )])
297
291
298
- def _add_bool_linear (self , cpm_expr , conditions = []):
299
- import pindakaas as pkd
292
+ elif isinstance (cpm_expr , Comparison ): # Bool linear
293
+ literals = None
294
+ coefficients = None
295
+ comparator = None
296
+ k = None
300
297
301
- literals = None
302
- coefficients = None
303
- comparator = None
304
- k = None
305
- if isinstance (cpm_expr , _BoolVarImpl ):
306
- literals = [cpm_expr ]
307
- elif isinstance (cpm_expr , Operator ) and cpm_expr .name == "or" :
308
- literals = cpm_expr .args
309
- elif isinstance (cpm_expr , Comparison ):
310
298
lhs , k = cpm_expr .args
311
299
if lhs .name == "sum" :
312
300
literals = lhs .args
@@ -325,10 +313,12 @@ def _add_bool_linear(self, cpm_expr, conditions=[]):
325
313
else :
326
314
raise ValueError (f"Unsupported comparator: { cpm_expr .name } " )
327
315
328
- self .pkd_solver .add_linear (
329
- self .solver_vars (literals ),
330
- coefficients = coefficients ,
331
- comparator = comparator ,
332
- k = k ,
333
- conditions = self .solver_vars (conditions ),
334
- )
316
+ self .pkd_solver .add_linear (
317
+ self .solver_vars (literals ),
318
+ coefficients = coefficients ,
319
+ comparator = comparator ,
320
+ k = k ,
321
+ conditions = conditions ,
322
+ )
323
+ else :
324
+ raise NotSupportedError (f"{ self .name } : Unsupported constraint { cpm_expr } " )
0 commit comments