@@ -9,8 +9,9 @@ import Prelude hiding (LT, GT)
99import GHC.Natural
1010import GHC.IO.Handle (Handle , hFlush , hSetBuffering , BufferMode (.. ))
1111import Control.Concurrent.Chan (Chan , newChan , writeChan , readChan )
12- import Control.Concurrent (forkIO , killThread )
13- import Control.Concurrent.STM (writeTChan , newTChan , TChan , tryReadTChan , atomically )
12+ import Control.Concurrent (forkIO , killThread , ThreadId , myThreadId )
13+ import Control.Concurrent.STM (writeTChan , newTChan , TChan , tryReadTChan , atomically , readTVar , readTVarIO , modifyTVar' , check )
14+ import Control.Concurrent.STM.TVar (TVar , newTVarIO )
1415import Control.Monad
1516import Control.Monad.State.Strict
1617import Control.Monad.IO.Unlift
@@ -61,7 +62,10 @@ data SolverInstance = SolverInstance
6162 }
6263
6364-- | A channel representing a group of solvers
64- newtype SolverGroup = SolverGroup (Chan Task )
65+ data SolverGroup = SolverGroup
66+ { taskQueue :: Chan Task
67+ , shouldAbort :: TVar Bool
68+ }
6569
6670data MultiSol = MultiSol
6771 { maxSols :: Int
@@ -92,13 +96,13 @@ supersetAny :: Set Prop -> [Set Prop] -> Bool
9296supersetAny a bs = any (`isSubsetOf` a) bs
9397
9498checkMulti :: SolverGroup -> Err SMT2 -> MultiSol -> IO (Maybe [W256 ])
95- checkMulti ( SolverGroup taskq) smt2 multiSol = do
99+ checkMulti sg smt2 multiSol = do
96100 if isLeft smt2 then pure Nothing
97101 else do
98102 -- prepare result channel
99103 resChan <- newChan
100104 -- send task to solver group
101- writeChan taskq (TaskMulti (MultiData (getNonError smt2) multiSol resChan))
105+ writeChan sg . taskQueue (TaskMulti (MultiData (getNonError smt2) multiSol resChan))
102106 -- collect result
103107 readChan resChan
104108
@@ -115,13 +119,13 @@ checkSatWithProps sg props = do
115119
116120-- When props is Nothing, the cache will not be filled or used
117121checkSat :: SolverGroup -> Maybe [Prop ] -> Err SMT2 -> IO SMTResult
118- checkSat ( SolverGroup taskq) props smt2 = do
122+ checkSat sg props smt2 = do
119123 if isLeft smt2 then pure $ Error $ getError smt2
120124 else do
121125 -- prepare result channel
122126 resChan <- newChan
123127 -- send task to solver group
124- writeChan taskq (TaskSingle (SingleData (getNonError smt2) props resChan))
128+ writeChan sg . taskQueue (TaskSingle (SingleData (getNonError smt2) props resChan))
125129 -- collect result
126130 readChan resChan
127131
@@ -139,44 +143,93 @@ withSolvers solver count threads timeout cont = do
139143 taskq <- liftIO newChan
140144 cacheq <- liftIO . atomically $ newTChan
141145 availableInstances <- liftIO newChan
146+ shouldAbort <- liftIO $ newTVarIO False
147+ runningThreads <- liftIO $ newTVarIO ([] :: [(ThreadId , Task )])
142148 liftIO $ forM_ instances (writeChan availableInstances)
143- orchestrate' <- toIO $ orchestrate taskq cacheq availableInstances [] 0
149+
150+ -- Spawn orchestration thread
151+ orchestrate' <- toIO $ orchestrate taskq cacheq availableInstances shouldAbort runningThreads [] 0
144152 orchestrateId <- liftIO $ forkIO orchestrate'
145153
154+ -- Spawn watcher thread that kills solver threads when abort flag is set
155+ abortWatcher' <- toIO $ abortWatcher shouldAbort runningThreads
156+ abortWatcherId <- liftIO $ forkIO abortWatcher'
157+
146158 -- run continuation with task queue
147- res <- cont (SolverGroup taskq)
159+ res <- cont (SolverGroup taskq shouldAbort )
148160
149161 -- cleanup and return results
150162 liftIO $ mapM_ (stopSolver) instances
151163 liftIO $ killThread orchestrateId
164+ liftIO $ killThread abortWatcherId
152165 pure res
153166 where
154- orchestrate :: App m => Chan Task -> TChan CacheEntry -> Chan SolverInstance -> [Set Prop ] -> Int -> m b
155- orchestrate taskq cacheq avail knownUnsat fileCounter = do
167+ -- Watcher thread that blocks until abort flag is set, then kills all solver threads
168+ abortWatcher :: App m => TVar Bool -> TVar [(ThreadId , Task )] -> m ()
169+ abortWatcher abortFlag runningThreads = do
156170 conf <- readConfig
157- mx <- liftIO . atomically $ tryReadTChan cacheq
158- case mx of
159- Just (CacheEntry props) -> do
160- let knownUnsat' = (fromList props): knownUnsat
161- when conf. debug $ liftIO $ putStrLn " adding UNSAT cache"
162- orchestrate taskq cacheq avail knownUnsat' fileCounter
163- Nothing -> do
164- task <- liftIO $ readChan taskq
171+ when conf. earlyAbort $ do
172+ -- Block until abort flag becomes True
173+ liftIO . atomically $ do
174+ abort <- readTVar abortFlag
175+ check abort
176+ -- Kill all running solver threads immediately and write errors to their result channels
177+ tidTasks <- liftIO $ readTVarIO runningThreads
178+ liftIO $ forM_ tidTasks $ \ (tid, task) -> do
179+ killThread tid
180+ -- Write error result to the task's result channel
165181 case task of
166- TaskSingle (SingleData _ props r) | isJust props && supersetAny (fromList (fromJust props)) knownUnsat -> do
167- liftIO $ writeChan r Qed
168- when conf. debug $ liftIO $ putStrLn " Qed found via cache!"
169- orchestrate taskq cacheq avail knownUnsat fileCounter
170- _ -> do
171- inst <- liftIO $ readChan avail
172- runTask' <- case task of
173- TaskSingle (SingleData smt2 props r) -> toIO $ getOneSol smt2 props r cacheq inst avail fileCounter
174- TaskMulti (MultiData smt2 multiSol r) -> toIO $ getMultiSol smt2 multiSol r inst avail fileCounter
175- _ <- liftIO $ forkIO runTask'
176- orchestrate taskq cacheq avail knownUnsat (fileCounter + 1 )
177-
178- getMultiSol :: forall m . (MonadIO m , ReadConfig m ) => SMT2 -> MultiSol -> (Chan (Maybe [W256 ])) -> SolverInstance -> Chan SolverInstance -> Int -> m ()
179- getMultiSol smt2@ (SMT2 cmds cexvars _) multiSol r inst availableInstances fileCounter = do
182+ TaskSingle (SingleData _ _ r) -> writeChan r (Error " Aborted due to early abort" )
183+ TaskMulti (MultiData _ _ r) -> writeChan r Nothing
184+ liftIO . atomically $ modifyTVar' runningThreads (const [] )
185+ when conf. debug $ liftIO $ putStrLn $ " [Abort Watcher] Killed " <> show (length tidTasks) <> " running solver thread(s) due to early abort"
186+
187+ orchestrate :: App m => Chan Task -> TChan CacheEntry -> Chan SolverInstance -> TVar Bool -> TVar [(ThreadId , Task )] -> [Set Prop ] -> Int -> m b
188+ orchestrate taskq cacheq avail abortFlag runningThreads knownUnsat fileCounter = do
189+ conf <- readConfig
190+ -- Check if we should abort early
191+ abort <- liftIO $ readTVarIO abortFlag
192+ if (conf. earlyAbort && abort) then do
193+ -- Drain pending tasks and return errors
194+ drainTasks taskq
195+ else do
196+ mx <- liftIO . atomically $ tryReadTChan cacheq
197+ case mx of
198+ Just (CacheEntry props) -> do
199+ let knownUnsat' = (fromList props): knownUnsat
200+ when conf. debug $ liftIO $ putStrLn " adding UNSAT cache"
201+ orchestrate taskq cacheq avail abortFlag runningThreads knownUnsat' fileCounter
202+ Nothing -> do
203+ task <- liftIO $ readChan taskq
204+ case task of
205+ TaskSingle (SingleData _ props r) | isJust props && supersetAny (fromList (fromJust props)) knownUnsat -> do
206+ liftIO $ writeChan r Qed
207+ when conf. debug $ liftIO $ putStrLn " Qed found via cache!"
208+ orchestrate taskq cacheq avail abortFlag runningThreads knownUnsat fileCounter
209+ _ -> do
210+ inst <- liftIO $ readChan avail
211+ runTask' <- case task of
212+ TaskSingle (SingleData smt2 props r) -> toIO $ getOneSol smt2 props r cacheq inst avail runningThreads fileCounter
213+ TaskMulti (MultiData smt2 multiSol r) -> toIO $ getMultiSol smt2 multiSol r inst avail runningThreads fileCounter
214+ tid <- liftIO $ forkIO runTask'
215+ liftIO . atomically $ modifyTVar' runningThreads ((tid, task): )
216+ orchestrate taskq cacheq avail abortFlag runningThreads knownUnsat (fileCounter + 1 )
217+
218+ -- Drain any pending tasks from the queue and return error results
219+ drainTasks :: App m => Chan Task -> m b
220+ drainTasks taskq = do
221+ conf <- readConfig
222+ when conf. debug $ liftIO $ putStrLn " [Orchestrate] Draining pending tasks due to early abort"
223+ forever $ do
224+ task <- liftIO $ readChan taskq
225+ case task of
226+ TaskSingle (SingleData _ _ r) -> liftIO $ writeChan r (Error " Aborted due to early abort" )
227+ TaskMulti (MultiData _ _ r) -> liftIO $ writeChan r Nothing
228+
229+ getMultiSol :: forall m . (MonadIO m , ReadConfig m ) => SMT2 -> MultiSol -> (Chan (Maybe [W256 ])) -> SolverInstance -> Chan SolverInstance -> TVar [(ThreadId , Task )] -> Int -> m ()
230+ getMultiSol smt2@ (SMT2 cmds cexvars _) multiSol r inst availableInstances runningThreads fileCounter = do
231+ tid <- liftIO myThreadId
232+ let cleanup = liftIO . atomically $ modifyTVar' runningThreads (filter (\ (t, _) -> t /= tid))
180233 conf <- readConfig
181234 when conf. dumpQueries $ liftIO $ writeSMT2File smt2 " ." (show fileCounter)
182235 -- reset solver and send all lines of provided script
@@ -195,6 +248,8 @@ getMultiSol smt2@(SMT2 cmds cexvars _) multiSol r inst availableInstances fileCo
195248 subRun [] smt2 sat
196249 -- put the instance back in the list of available instances
197250 liftIO $ writeChan availableInstances inst
251+ -- remove thread from running threads list
252+ cleanup
198253 where
199254 maskFromBytesCount k
200255 | k <= 32 = (2 ^ (8 * k) - 1 )
@@ -243,9 +298,10 @@ getMultiSol smt2@(SMT2 cmds cexvars _) multiSol r inst availableInstances fileCo
243298 when conf. debug $ putStrLn $ " Unable to write SMT to solver: " <> (T. unpack err)
244299 writeChan r Nothing
245300
246- getOneSol :: (MonadIO m , ReadConfig m ) => SMT2 -> Maybe [Prop ] -> Chan SMTResult -> TChan CacheEntry -> SolverInstance -> Chan SolverInstance -> Int -> m ()
247- getOneSol smt2@ (SMT2 cmds cexvars _) props r cacheq inst availableInstances fileCounter = do
301+ getOneSol :: (MonadIO m , ReadConfig m ) => SMT2 -> Maybe [Prop ] -> Chan SMTResult -> TChan CacheEntry -> SolverInstance -> Chan SolverInstance -> TVar [( ThreadId , Task )] -> Int -> m ()
302+ getOneSol smt2@ (SMT2 cmds cexvars _) props r cacheq inst availableInstances runningThreads fileCounter = do
248303 conf <- readConfig
304+ tid <- liftIO myThreadId
249305 liftIO $ do
250306 when (conf. dumpQueries) $ writeSMT2File smt2 " ." (show fileCounter)
251307 -- reset solver and send all lines of provided script
@@ -275,6 +331,8 @@ getOneSol smt2@(SMT2 cmds cexvars _) props r cacheq inst availableInstances file
275331
276332 -- put the instance back in the list of available instances
277333 writeChan availableInstances inst
334+ -- remove thread from running threads list
335+ liftIO . atomically $ modifyTVar' runningThreads (filter (\ (t, _) -> t /= tid))
278336
279337dumpUnsolved :: SMT2 -> Int -> Maybe FilePath -> IO ()
280338dumpUnsolved fullSmt fileCounter dump = do
0 commit comments