@@ -9,8 +9,9 @@ import Prelude hiding (LT, GT)
99import GHC.Natural
1010import GHC.IO.Handle (Handle , hFlush , hSetBuffering , BufferMode (.. ))
1111import Control.Concurrent.Chan (Chan , newChan , writeChan , readChan )
12- import Control.Concurrent (forkIO , killThread )
13- import Control.Concurrent.STM (writeTChan , newTChan , TChan , tryReadTChan , atomically )
12+ import Control.Concurrent (forkIO , killThread , ThreadId , myThreadId )
13+ import Control.Concurrent.STM (writeTChan , newTChan , TChan , tryReadTChan , atomically , readTVar , readTVarIO , modifyTVar' , check )
14+ import Control.Concurrent.STM.TVar (TVar , newTVarIO )
1415import Control.Monad
1516import Control.Monad.State.Strict
1617import Control.Monad.IO.Unlift
@@ -61,7 +62,10 @@ data SolverInstance = SolverInstance
6162 }
6263
6364-- | A channel representing a group of solvers
64- newtype SolverGroup = SolverGroup (Chan Task )
65+ data SolverGroup = SolverGroup
66+ { taskQueue :: Chan Task
67+ , shouldAbort :: TVar Bool
68+ }
6569
6670data MultiSol = MultiSol
6771 { maxSols :: Int
@@ -92,13 +96,13 @@ supersetAny :: Set Prop -> [Set Prop] -> Bool
9296supersetAny a bs = any (`isSubsetOf` a) bs
9397
9498checkMulti :: SolverGroup -> Err SMT2 -> MultiSol -> IO (Maybe [W256 ])
95- checkMulti ( SolverGroup taskq) smt2 multiSol = do
99+ checkMulti sg smt2 multiSol = do
96100 if isLeft smt2 then pure Nothing
97101 else do
98102 -- prepare result channel
99103 resChan <- newChan
100104 -- send task to solver group
101- writeChan taskq (TaskMulti (MultiData (getNonError smt2) multiSol resChan))
105+ writeChan sg . taskQueue (TaskMulti (MultiData (getNonError smt2) multiSol resChan))
102106 -- collect result
103107 readChan resChan
104108
@@ -115,13 +119,13 @@ checkSatWithProps sg props = do
115119
116120-- When props is Nothing, the cache will not be filled or used
117121checkSat :: SolverGroup -> Maybe [Prop ] -> Err SMT2 -> IO SMTResult
118- checkSat ( SolverGroup taskq) props smt2 = do
122+ checkSat sg props smt2 = do
119123 if isLeft smt2 then pure $ Error $ getError smt2
120124 else do
121125 -- prepare result channel
122126 resChan <- newChan
123127 -- send task to solver group
124- writeChan taskq (TaskSingle (SingleData (getNonError smt2) props resChan))
128+ writeChan sg . taskQueue (TaskSingle (SingleData (getNonError smt2) props resChan))
125129 -- collect result
126130 readChan resChan
127131
@@ -139,44 +143,91 @@ withSolvers solver count threads timeout cont = do
139143 taskq <- liftIO newChan
140144 cacheq <- liftIO . atomically $ newTChan
141145 availableInstances <- liftIO newChan
146+ shouldAbort <- liftIO $ newTVarIO False
147+ runningThreads <- liftIO $ newTVarIO ([] :: [(ThreadId , Task )])
142148 liftIO $ forM_ instances (writeChan availableInstances)
143- orchestrate' <- toIO $ orchestrate taskq cacheq availableInstances [] 0
149+
150+ -- Spawn orchestration thread
151+ orchestrate' <- toIO $ orchestrate taskq cacheq availableInstances shouldAbort runningThreads [] 0
144152 orchestrateId <- liftIO $ forkIO orchestrate'
145153
154+ -- Spawn watcher thread that kills solver threads when abort flag is set
155+ abortWatcher' <- toIO $ abortWatcher shouldAbort runningThreads
156+ abortWatcherId <- liftIO $ forkIO abortWatcher'
157+
146158 -- run continuation with task queue
147- res <- cont (SolverGroup taskq)
159+ res <- cont (SolverGroup taskq shouldAbort )
148160
149161 -- cleanup and return results
150162 liftIO $ mapM_ (stopSolver) instances
151163 liftIO $ killThread orchestrateId
164+ liftIO $ killThread abortWatcherId
152165 pure res
153166 where
154- orchestrate :: App m => Chan Task -> TChan CacheEntry -> Chan SolverInstance -> [Set Prop ] -> Int -> m b
155- orchestrate taskq cacheq avail knownUnsat fileCounter = do
167+ -- Watcher thread that blocks until abort flag is set, then kills all solver threads
168+ abortWatcher :: App m => TVar Bool -> TVar [(ThreadId , Task )] -> m ()
169+ abortWatcher abortFlag runningThreads = do
156170 conf <- readConfig
157- mx <- liftIO . atomically $ tryReadTChan cacheq
158- case mx of
159- Just (CacheEntry props) -> do
160- let knownUnsat' = (fromList props): knownUnsat
161- when conf. debug $ liftIO $ putStrLn " adding UNSAT cache"
162- orchestrate taskq cacheq avail knownUnsat' fileCounter
163- Nothing -> do
164- task <- liftIO $ readChan taskq
171+ when conf. earlyAbort $ do
172+ -- Block until abort flag becomes True
173+ liftIO . atomically $ do
174+ abort <- readTVar abortFlag
175+ check abort
176+ -- Kill all running solver threads immediately and write errors to their result channels
177+ tidTasks <- liftIO $ readTVarIO runningThreads
178+ liftIO $ forM_ tidTasks $ \ (tid, task) -> do
179+ killThread tid
180+ -- Write error result to the task's result channel
165181 case task of
166- TaskSingle (SingleData _ props r) | isJust props && supersetAny (fromList (fromJust props)) knownUnsat -> do
167- liftIO $ writeChan r Qed
168- when conf. debug $ liftIO $ putStrLn " Qed found via cache!"
169- orchestrate taskq cacheq avail knownUnsat fileCounter
170- _ -> do
171- inst <- liftIO $ readChan avail
172- runTask' <- case task of
173- TaskSingle (SingleData smt2 props r) -> toIO $ getOneSol smt2 props r cacheq inst avail fileCounter
174- TaskMulti (MultiData smt2 multiSol r) -> toIO $ getMultiSol smt2 multiSol r inst avail fileCounter
175- _ <- liftIO $ forkIO runTask'
176- orchestrate taskq cacheq avail knownUnsat (fileCounter + 1 )
177-
178- getMultiSol :: forall m . (MonadIO m , ReadConfig m ) => SMT2 -> MultiSol -> (Chan (Maybe [W256 ])) -> SolverInstance -> Chan SolverInstance -> Int -> m ()
179- getMultiSol smt2@ (SMT2 cmds cexvars _) multiSol r inst availableInstances fileCounter = do
182+ TaskSingle (SingleData _ _ r) -> writeChan r (Error " Aborted due to early abort" )
183+ TaskMulti (MultiData _ _ r) -> writeChan r Nothing
184+ liftIO . atomically $ modifyTVar' runningThreads (const [] )
185+ when conf. debug $ liftIO $ putStrLn $ " [Abort Watcher] Killed " <> show (length tidTasks) <> " running solver thread(s) due to early abort"
186+
187+ orchestrate :: App m => Chan Task -> TChan CacheEntry -> Chan SolverInstance -> TVar Bool -> TVar [(ThreadId , Task )] -> [Set Prop ] -> Int -> m b
188+ orchestrate taskq cacheq avail abortFlag runningThreads knownUnsat fileCounter = do
189+ conf <- readConfig
190+ -- Check if we should abort early
191+ abort <- liftIO $ readTVarIO abortFlag
192+ if (conf. earlyAbort && abort) then drainTasks taskq
193+ else do
194+ mx <- liftIO . atomically $ tryReadTChan cacheq
195+ case mx of
196+ Just (CacheEntry props) -> do
197+ let knownUnsat' = (fromList props): knownUnsat
198+ when conf. debug $ liftIO $ putStrLn " adding UNSAT cache"
199+ orchestrate taskq cacheq avail abortFlag runningThreads knownUnsat' fileCounter
200+ Nothing -> do
201+ task <- liftIO $ readChan taskq
202+ case task of
203+ TaskSingle (SingleData _ props r) | isJust props && supersetAny (fromList (fromJust props)) knownUnsat -> do
204+ liftIO $ writeChan r Qed
205+ when conf. debug $ liftIO $ putStrLn " Qed found via cache!"
206+ orchestrate taskq cacheq avail abortFlag runningThreads knownUnsat fileCounter
207+ _ -> do
208+ inst <- liftIO $ readChan avail
209+ runTask' <- case task of
210+ TaskSingle (SingleData smt2 props r) -> toIO $ getOneSol smt2 props r cacheq inst avail runningThreads fileCounter
211+ TaskMulti (MultiData smt2 multiSol r) -> toIO $ getMultiSol smt2 multiSol r inst avail runningThreads fileCounter
212+ tid <- liftIO $ forkIO runTask'
213+ liftIO . atomically $ modifyTVar' runningThreads ((tid, task): )
214+ orchestrate taskq cacheq avail abortFlag runningThreads knownUnsat (fileCounter + 1 )
215+
216+ -- Drain any pending tasks from the queue and return error results
217+ drainTasks :: App m => Chan Task -> m b
218+ drainTasks taskq = do
219+ conf <- readConfig
220+ when conf. debug $ liftIO $ putStrLn " [Orchestrate] Draining pending tasks due to early abort"
221+ forever $ do
222+ task <- liftIO $ readChan taskq
223+ case task of
224+ TaskSingle (SingleData _ _ r) -> liftIO $ writeChan r (Error " Aborted due to early abort" )
225+ TaskMulti (MultiData _ _ r) -> liftIO $ writeChan r Nothing
226+
227+ getMultiSol :: forall m . (MonadIO m , ReadConfig m ) => SMT2 -> MultiSol -> (Chan (Maybe [W256 ])) -> SolverInstance -> Chan SolverInstance -> TVar [(ThreadId , Task )] -> Int -> m ()
228+ getMultiSol smt2@ (SMT2 cmds cexvars _) multiSol r inst availableInstances runningThreads fileCounter = do
229+ tid <- liftIO myThreadId
230+ let cleanup = liftIO . atomically $ modifyTVar' runningThreads (filter (\ (t, _) -> t /= tid))
180231 conf <- readConfig
181232 when conf. dumpQueries $ liftIO $ writeSMT2File smt2 " ." (show fileCounter)
182233 -- reset solver and send all lines of provided script
@@ -195,6 +246,8 @@ getMultiSol smt2@(SMT2 cmds cexvars _) multiSol r inst availableInstances fileCo
195246 subRun [] smt2 sat
196247 -- put the instance back in the list of available instances
197248 liftIO $ writeChan availableInstances inst
249+ -- remove thread from running threads list
250+ cleanup
198251 where
199252 maskFromBytesCount k
200253 | k <= 32 = (2 ^ (8 * k) - 1 )
@@ -243,9 +296,10 @@ getMultiSol smt2@(SMT2 cmds cexvars _) multiSol r inst availableInstances fileCo
243296 when conf. debug $ putStrLn $ " Unable to write SMT to solver: " <> (T. unpack err)
244297 writeChan r Nothing
245298
246- getOneSol :: (MonadIO m , ReadConfig m ) => SMT2 -> Maybe [Prop ] -> Chan SMTResult -> TChan CacheEntry -> SolverInstance -> Chan SolverInstance -> Int -> m ()
247- getOneSol smt2@ (SMT2 cmds cexvars _) props r cacheq inst availableInstances fileCounter = do
299+ getOneSol :: (MonadIO m , ReadConfig m ) => SMT2 -> Maybe [Prop ] -> Chan SMTResult -> TChan CacheEntry -> SolverInstance -> Chan SolverInstance -> TVar [( ThreadId , Task )] -> Int -> m ()
300+ getOneSol smt2@ (SMT2 cmds cexvars _) props r cacheq inst availableInstances runningThreads fileCounter = do
248301 conf <- readConfig
302+ tid <- liftIO myThreadId
249303 liftIO $ do
250304 when (conf. dumpQueries) $ writeSMT2File smt2 " ." (show fileCounter)
251305 -- reset solver and send all lines of provided script
@@ -275,6 +329,8 @@ getOneSol smt2@(SMT2 cmds cexvars _) props r cacheq inst availableInstances file
275329
276330 -- put the instance back in the list of available instances
277331 writeChan availableInstances inst
332+ -- remove thread from running threads list
333+ liftIO . atomically $ modifyTVar' runningThreads (filter (\ (t, _) -> t /= tid))
278334
279335dumpUnsolved :: SMT2 -> Int -> Maybe FilePath -> IO ()
280336dumpUnsolved fullSmt fileCounter dump = do
0 commit comments