Skip to content

Commit b4e069a

Browse files
committed
Implemented basic knapsack cut
1 parent 52f7407 commit b4e069a

8 files changed

Lines changed: 717 additions & 1 deletion

File tree

src/configuration/OptionParser.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,11 @@ void OptionParser::initialize()
136136
"prove-unsat",
137137
boost::program_options::bool_switch( &( ( *_boolOptions )[Options::PRODUCE_PROOFS] ) )
138138
->default_value( ( *_boolOptions )[Options::PRODUCE_PROOFS] ),
139-
"Produce proofs of UNSAT and check them" )
139+
"Produce proofs of UNSAT and check them" )(
140+
"knapsack-cuts",
141+
boost::program_options::bool_switch( &( ( *_boolOptions )[Options::KNAPSACK_CUTS] ) )
142+
->default_value( ( *_boolOptions )[Options::KNAPSACK_CUTS] ),
143+
"Learn knapsack cuts at UNSAT leaves to prune subsumed subproblems" )
140144
#ifdef ENABLE_GUROBI
141145
#endif // ENABLE_GUROBI
142146
;

src/configuration/Options.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ void Options::initializeDefaultValues()
5555
_boolOptions[DEBUG_ASSIGNMENT] = false;
5656
_boolOptions[PRODUCE_PROOFS] = false;
5757
_boolOptions[DO_NOT_MERGE_CONSECUTIVE_WEIGHTED_SUM_LAYERS] = false;
58+
_boolOptions[KNAPSACK_CUTS] = false;
5859

5960
/*
6061
Int options

src/configuration/Options.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ class Options
8383
// logically-consecutive weighted sum layers into a single
8484
// weighted sum layer, to reduce the number of variables
8585
DO_NOT_MERGE_CONSECUTIVE_WEIGHTED_SUM_LAYERS,
86+
87+
// Enable knapsack cuts: at each UNSAT leaf, learn a linear cut
88+
// group over upstream ReLU phase indicators that prunes any
89+
// subproblem in which the group is implied.
90+
KNAPSACK_CUTS,
8691
};
8792

8893
enum IntOptions {

src/engine/Engine.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ Engine::Engine()
5858
, _lastIterationWithProgress( 0 )
5959
, _symbolicBoundTighteningType( Options::get()->getSymbolicBoundTighteningType() )
6060
, _solveWithMILP( Options::get()->getBool( Options::SOLVE_WITH_MILP ) )
61+
, _useKnapsackCuts( Options::get()->getBool( Options::KNAPSACK_CUTS ) )
6162
, _lpSolverType( Options::get()->getLPSolverType() )
6263
, _gurobi( nullptr )
6364
, _milpEncoder( nullptr )
@@ -200,6 +201,9 @@ bool Engine::solve( double timeoutInSeconds )
200201
// Before encoding, make sure all valid constraints are applied.
201202
applyAllValidConstraintCaseSplits();
202203

204+
if ( _useKnapsackCuts )
205+
_knapsackCutManager.initialize( _networkLevelReasoner, _plConstraints, &_boundManager );
206+
203207
if ( _solveWithMILP )
204208
return solveWithMILPEncoding( timeoutInSeconds );
205209

@@ -243,6 +247,8 @@ bool Engine::solve( double timeoutInSeconds )
243247
_statistics.print();
244248
}
245249

250+
if ( _useKnapsackCuts )
251+
_knapsackCutManager.printSummary();
246252
_exitCode = Engine::TIMEOUT;
247253
_statistics.timeout();
248254
return false;
@@ -257,6 +263,8 @@ bool Engine::solve( double timeoutInSeconds )
257263
_statistics.print();
258264
}
259265

266+
if ( _useKnapsackCuts )
267+
_knapsackCutManager.printSummary();
260268
_exitCode = Engine::QUIT_REQUESTED;
261269
return false;
262270
}
@@ -294,6 +302,9 @@ bool Engine::solve( double timeoutInSeconds )
294302
performBoundTighteningAfterCaseSplit();
295303
informLPSolverOfBounds();
296304
splitJustPerformed = false;
305+
306+
if ( _useKnapsackCuts && _knapsackCutManager.checkPruning() )
307+
throw InfeasibleQueryException();
297308
}
298309

299310
if ( _searchTreeHandler.needToSplit() )
@@ -342,6 +353,8 @@ bool Engine::solve( double timeoutInSeconds )
342353
ASSERT( _UNSATCertificateCurrentPointer );
343354
( **_UNSATCertificateCurrentPointer ).setSATSolutionFlag();
344355
}
356+
if ( _useKnapsackCuts )
357+
_knapsackCutManager.printSummary();
345358
_exitCode = Engine::SAT;
346359
return true;
347360
}
@@ -356,6 +369,8 @@ bool Engine::solve( double timeoutInSeconds )
356369
printf( "\nEngine::solve: at leaf node but solving inconclusive\n" );
357370
_statistics.print();
358371
}
372+
if ( _useKnapsackCuts )
373+
_knapsackCutManager.printSummary();
359374
_exitCode = Engine::UNKNOWN;
360375
return false;
361376
}
@@ -407,6 +422,9 @@ bool Engine::solve( double timeoutInSeconds )
407422
if ( _produceUNSATProofs )
408423
explainSimplexFailure();
409424

425+
if ( _useKnapsackCuts )
426+
_knapsackCutManager.collectFromCurrentLeaf();
427+
410428
if ( !_searchTreeHandler.popSplit() )
411429
{
412430
mainLoopEnd = TimeUtils::sampleMicro();
@@ -417,6 +435,8 @@ bool Engine::solve( double timeoutInSeconds )
417435
printf( "\nEngine::solve: unsat query\n" );
418436
_statistics.print();
419437
}
438+
if ( _useKnapsackCuts )
439+
_knapsackCutManager.printSummary();
420440
_exitCode = Engine::UNSAT;
421441
return false;
422442
}

src/engine/Engine.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "IEngine.h"
3333
#include "IQuery.h"
3434
#include "JsonWriter.h"
35+
#include "KnapsackCutManager.h"
3536
#include "LPSolverType.h"
3637
#include "LinearExpression.h"
3738
#include "MILPEncoder.h"
@@ -502,6 +503,13 @@ class Engine
502503
*/
503504
bool _solveWithMILP;
504505

506+
/*
507+
Knapsack cuts: learned at UNSAT leaves, evaluated at every new
508+
search node to prune subsumed subproblems.
509+
*/
510+
bool _useKnapsackCuts;
511+
KnapsackCutManager _knapsackCutManager;
512+
505513
/*
506514
The solver to solve the LP during the complete search.
507515
*/

src/engine/KnapsackCut.h

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/********************* */
2+
/*! \file KnapsackCut.h
3+
** \verbatim
4+
** This file is part of the Marabou project.
5+
** Copyright (c) 2017-2024 by the authors listed in the file AUTHORS
6+
** in the top-level source directory) and their institutional affiliations.
7+
** All rights reserved. See the file COPYING in the top-level source
8+
** directory for licensing information.\endverbatim
9+
**
10+
** KnapsackCut implements a topology-aware cutting plane for ReLU networks.
11+
**
12+
** For a ReLU neuron b with pre-activation pre_b = sum_i W_bi * post_i + bias_b,
13+
** the cut encodes: if the weighted combination of upstream phase indicators
14+
** exceeds a threshold, then b's phase is guaranteed.
15+
**
16+
** A cut group (from a single UNSAT leaf) prunes a subproblem when ALL cuts
17+
** in the group are satisfied, meaning all target neurons are in their
18+
** UNSAT-leaf phases.
19+
**
20+
** Only "furthest" fixes are included: neurons with no downstream fixed neurons.
21+
**
22+
** Cut for active phase (pre_b >= 0):
23+
** sum_i a_i * z_i + c >= 0
24+
** where a_i = W_bi * lb_i (if W_bi > 0) or W_bi * ub_i (if W_bi < 0)
25+
** c = bias_b + folded constant contributions
26+
**
27+
** Cut for inactive phase (pre_b <= 0), negated to >= form:
28+
** sum_i a_i * z_i + c >= 0
29+
** where a_i = -W_bi * ub_i (if W_bi > 0) or -W_bi * lb_i (if W_bi < 0)
30+
** c = -bias_b + folded constant contributions
31+
**/
32+
33+
#ifndef __KnapsackCut_h__
34+
#define __KnapsackCut_h__
35+
36+
#include "Map.h"
37+
#include "Vector.h"
38+
39+
struct KnapsackCut
40+
{
41+
// The pre-activation variable (b) of the target ReLU
42+
unsigned targetBVar;
43+
44+
// Phase of the target at the UNSAT leaf (true = active, false = inactive)
45+
bool isActive;
46+
47+
// NLR layer and neuron indices for identification
48+
unsigned reluLayerIdx;
49+
unsigned neuronIdx;
50+
51+
// Coefficients keyed by upstream ReLU f-variable
52+
// coeff[f_var] = contribution when upstream neuron is active (z=1)
53+
Map<unsigned, double> coefficients;
54+
55+
// Constant term (folded bias + non-phase-indicator contributions)
56+
double constant;
57+
};
58+
59+
struct KnapsackCutGroup
60+
{
61+
Vector<KnapsackCut> cuts;
62+
unsigned depth;
63+
};
64+
65+
#endif // __KnapsackCut_h__

0 commit comments

Comments
 (0)