Skip to content

Commit b188fa8

Browse files
author
Alex McKenna
committed
WIP - concurrent normalization
1 parent 92adc32 commit b188fa8

File tree

17 files changed

+189
-111
lines changed

17 files changed

+189
-111
lines changed

benchmark/benchmark-normalization.hs

+6-3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import Clash.Netlist.Types (TopEntityT(topId))
1515

1616
import Criterion.Main
1717

18+
import qualified Control.Concurrent.MVar as MVar
1819
import qualified Control.Concurrent.Supply as Supply
1920
import Control.DeepSeq (NFData(..), rwhnf)
2021
import Data.List (isPrefixOf, partition)
@@ -42,7 +43,7 @@ main = do
4243
benchFile :: [FilePath] -> FilePath -> Benchmark
4344
benchFile idirs src =
4445
env (setupEnv idirs src) $
45-
\ ~(clashEnv, clashDesign, supplyN) -> do
46+
\ ~(clashEnv, clashDesign, supplyN, lock) -> do
4647
bench ("normalization of " ++ src)
4748
(nfIO
4849
(normalizeEntity
@@ -51,18 +52,20 @@ benchFile idirs src =
5152
(ghcTypeToHWType (opt_intWidth (envOpts clashEnv)))
5253
ghcEvaluator
5354
evaluator
55+
lock
5456
(fmap topId (designEntities clashDesign))
5557
supplyN
5658
(topId (head (designEntities clashDesign)))))
5759

5860
setupEnv
5961
:: [FilePath]
6062
-> FilePath
61-
-> IO (ClashEnv, ClashDesign, Supply.Supply)
63+
-> IO (ClashEnv, ClashDesign, Supply.Supply, MVar.MVar ())
6264
setupEnv idirs src = do
6365
(clashEnv, clashDesign) <- runInputStage idirs src
6466
supplyN <- Supply.newSupply
65-
return (clashEnv, clashDesign ,supplyN)
67+
lock <- MVar.newMVar ()
68+
return (clashEnv, clashDesign ,supplyN, lock)
6669

6770
instance NFData Supply.Supply where
6871
rnf = rwhnf

benchmark/common/BenchmarkCommon.hs

+3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import Clash.GHC.Evaluator
1515
import Clash.GHC.GenerateBindings
1616
import Clash.GHC.NetlistTypes
1717

18+
import qualified Control.Concurrent.MVar as MVar
1819
import qualified Control.Concurrent.Supply as Supply
1920

2021
defaultTests :: [FilePath]
@@ -57,6 +58,7 @@ runNormalisationStage
5758
-> IO (ClashEnv, ClashDesign, Id)
5859
runNormalisationStage idirs src = do
5960
supplyN <- Supply.newSupply
61+
lock <- MVar.newMVar ()
6062
(env, design) <- runInputStage idirs src
6163
let topEntityNames = fmap topId (designEntities design)
6264
let topEntity = head topEntityNames
@@ -65,5 +67,6 @@ runNormalisationStage idirs src = do
6567
(ghcTypeToHWType (opt_intWidth (opts idirs)))
6668
ghcEvaluator
6769
evaluator
70+
lock
6871
topEntityNames supplyN topEntity
6972
return (env, design{designBindings=transformedBindings},topEntity)

benchmark/profiling/run/profile-normalization-run.hs

+3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import Clash.GHC.PartialEval
77
import Clash.GHC.Evaluator
88
import Clash.GHC.NetlistTypes (ghcTypeToHWType)
99

10+
import qualified Control.Concurrent.MVar as MVar
1011
import qualified Control.Concurrent.Supply as Supply
1112
import Control.DeepSeq (deepseq)
1213
import Data.Binary (decode)
@@ -32,6 +33,7 @@ main = do
3233
benchFile :: [FilePath] -> FilePath -> IO ()
3334
benchFile idirs src = do
3435
supplyN <- Supply.newSupply
36+
lock <- MVar.newMVar ()
3537
(bindingsMap,tcm,tupTcm,primMap,reprs,topEntityNames,topEntity) <- setupEnv src
3638
putStrLn $ "Doing normalization of " ++ src
3739

@@ -47,6 +49,7 @@ benchFile idirs src = do
4749
(ghcTypeToHWType (opt_intWidth (envOpts clashEnv)))
4850
ghcEvaluator
4951
evaluator
52+
lock
5053
topEntityNames supplyN topEntity
5154
res `deepseq` putStrLn ".. done\n"
5255

clash-ghc/clash-ghc.cabal

+2-2
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ flag use-ghc-paths
6868
executable clash
6969
Main-Is: src-ghc/Batch.hs
7070
Build-Depends: base, clash-ghc
71-
GHC-Options: -Wall -Wcompat
71+
GHC-Options: -Wall -Wcompat -threaded -rtsopts
7272
if flag(dynamic)
7373
GHC-Options: -dynamic
7474
extra-libraries: pthread
@@ -77,7 +77,7 @@ executable clash
7777
executable clashi
7878
Main-Is: src-ghc/Interactive.hs
7979
Build-Depends: base, clash-ghc
80-
GHC-Options: -Wall -Wcompat
80+
GHC-Options: -Wall -Wcompat -threaded -rtsopts
8181
if flag(dynamic)
8282
GHC-Options: -dynamic
8383
extra-libraries: pthread

clash-lib/clash-lib.cabal

+1-1
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ Library
130130
aeson-pretty >= 0.8 && < 0.9,
131131
ansi-terminal >= 0.8.0.0 && < 0.12,
132132
array,
133-
async >= 2.2.0 && < 2.3,
134133
attoparsec >= 0.10.4.0 && < 0.15,
135134
base >= 4.11 && < 5,
136135
base16-bytestring >= 0.1.1 && < 1.1,
@@ -156,6 +155,7 @@ Library
156155
interpolate >= 0.2.0 && < 1.0,
157156
lens >= 4.10 && < 5.1.0,
158157
-- TODO bounds
158+
lifted-async,
159159
lifted-base,
160160
monad-control,
161161
mtl >= 2.1.2 && < 2.3,

clash-lib/src/Clash/Debug.hs

+6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
module Clash.Debug
44
( debugIsOn
55
, traceIf
6+
, traceWhen
67
, module Debug.Trace
78
) where
89

@@ -19,4 +20,9 @@ debugIsOn = False
1920
traceIf :: Bool -> String -> a -> a
2021
traceIf True msg = trace msg
2122
traceIf False _ = id
23+
24+
traceWhen :: Monad m => Bool -> String -> m ()
25+
traceWhen True = traceM
26+
traceWhen False = const (pure ())
27+
2228
{-# INLINE traceIf #-}

clash-lib/src/Clash/Driver.hs

+6-4
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
module Clash.Driver where
2323

2424
import Control.Concurrent (MVar, modifyMVar, modifyMVar_, newMVar, withMVar)
25-
import Control.Concurrent.Async (mapConcurrently_)
25+
import Control.Concurrent.Async.Lifted (mapConcurrently_)
2626
import qualified Control.Concurrent.Supply as Supply
2727
import Control.DeepSeq
2828
import Control.Exception (throw)
@@ -442,7 +442,7 @@ generateHDL env design hdlState typeTrans peEval eval mainTopEntity startTime =
442442
-- 2. Normalize topEntity
443443
supplyN <- Supply.newSupply
444444
transformedBindings <- normalizeEntity env bindingsMap typeTrans peEval
445-
eval topEntityNames supplyN topEntity
445+
eval ioLockV topEntityNames supplyN topEntity
446446

447447
normTime <- transformedBindings `deepseq` Clock.getCurrentTime
448448
let prepNormDiff = reportTimeDiff normTime prevTime
@@ -1062,21 +1062,23 @@ normalizeEntity
10621062
-- ^ Hardcoded evaluator for partial evaluation
10631063
-> WHNF.Evaluator
10641064
-- ^ Hardcoded evaluator for WHNF (old evaluator)
1065+
-> MVar ()
1066+
-- ^ Synchroniztion for stdout
10651067
-> [Id]
10661068
-- ^ TopEntities
10671069
-> Supply.Supply
10681070
-- ^ Unique supply
10691071
-> Id
10701072
-- ^ root of the hierarchy
10711073
-> IO BindingMap
1072-
normalizeEntity env bindingsMap typeTrans peEval eval topEntities supply tm = transformedBindings
1074+
normalizeEntity env bindingsMap typeTrans peEval eval lock topEntities supply tm = transformedBindings
10731075
where
10741076
doNorm = do norm <- normalize [tm]
10751077
let normChecked = checkNonRecursive norm
10761078
cleaned <- cleanupGraph tm normChecked
10771079
return cleaned
10781080
transformedBindings = runNormalization env supply bindingsMap
1079-
typeTrans peEval eval emptyVarEnv
1081+
typeTrans peEval eval emptyVarEnv lock
10801082
topEntities doNorm
10811083

10821084
-- | topologically sort the top entities

clash-lib/src/Clash/Normalize.hs

+44-30
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,21 @@
1010
-}
1111

1212
{-# LANGUAGE CPP #-}
13+
{-# LANGUAGE FlexibleContexts #-}
1314
{-# LANGUAGE OverloadedStrings #-}
1415
{-# LANGUAGE QuasiQuotes #-}
1516
{-# LANGUAGE TemplateHaskell #-}
1617

1718
module Clash.Normalize where
1819

20+
import qualified Control.Concurrent.Async.Lifted as Async
21+
import Control.Concurrent.MVar.Lifted (MVar)
1922
import qualified Control.Concurrent.MVar.Lifted as MVar
2023
import Control.Concurrent.Supply (Supply)
2124
import Control.Exception (throw)
2225
import qualified Control.Lens as Lens
2326
import Control.Monad (when)
27+
import qualified Control.Monad.IO.Class as Monad (liftIO)
2428
import Control.Monad.State.Strict (State)
2529
import Data.Default (def)
2630
import Data.Either (lefts,partitionEithers)
@@ -79,7 +83,7 @@ import Clash.Normalize.Util
7983
import Clash.Rewrite.Combinators ((>->),(!->),repeatR,topdownR)
8084
import Clash.Rewrite.Types
8185
(RewriteEnv (..), RewriteState (..), bindings, debugOpts, extra,
82-
tcCache, topEntities, newInlineStrategy)
86+
tcCache, topEntities, newInlineStrategy, ioLock)
8387
import Clash.Rewrite.Util
8488
(apply, isUntranslatableType, runRewriteSession)
8589
import Clash.Util
@@ -89,9 +93,9 @@ import Data.Binary (encode)
8993
import qualified Data.ByteString as BS
9094
import qualified Data.ByteString.Lazy as BL
9195

92-
import System.IO.Unsafe (unsafePerformIO)
9396
import Clash.Rewrite.Types (RewriteStep(..))
9497

98+
import Clash.Debug -- TODO
9599

96100
-- | Run a NormalizeSession in a given environment
97101
runNormalization
@@ -109,12 +113,14 @@ runNormalization
109113
-- ^ Hardcoded evaluator for WHNF (old evaluator)
110114
-> VarEnv Bool
111115
-- ^ Map telling whether a components is part of a recursive group
116+
-> MVar ()
117+
-- ^ Synchronization on stdout
112118
-> [Id]
113119
-- ^ topEntities
114120
-> NormalizeSession a
115121
-- ^ NormalizeSession to run
116122
-> IO a
117-
runNormalization env supply globals typeTrans peEval eval rcsMap entities session = do
123+
runNormalization env supply globals typeTrans peEval eval rcsMap lock entities session = do
118124
normState <- NormalizeState
119125
<$> MVar.newMVar emptyVarEnv
120126
<*> MVar.newMVar Map.empty
@@ -131,6 +137,7 @@ runNormalization env supply globals typeTrans peEval eval rcsMap entities sessio
131137
<*> MVar.newMVar 0
132138
<*> MVar.newMVar (mempty, 0)
133139
<*> MVar.newMVar emptyVarEnv
140+
<*> pure lock
134141
<*> pure normState
135142

136143
runRewriteSession rwEnv rwState session
@@ -143,20 +150,17 @@ runNormalization env supply globals typeTrans peEval eval rcsMap entities sessio
143150
, _topEntities = mkVarSet entities
144151
}
145152

146-
normalize
147-
:: [Id]
148-
-> NormalizeSession BindingMap
149-
normalize [] = return emptyVarEnv
150-
normalize top = do
151-
(new,topNormalized) <- unzip <$> mapM normalize' top
152-
newNormalized <- normalize (concat new)
153-
return (unionVarEnv (mkVarEnv topNormalized) newNormalized)
153+
normalize :: [Id] -> NormalizeSession BindingMap
154+
normalize tops = do
155+
normBinds <- Async.mapConcurrently normalize' tops
156+
pure (mkVarEnv (concat normBinds))
154157

155-
normalize' :: Id -> NormalizeSession ([Id], (Id, Binding Term))
158+
normalize' :: Id -> NormalizeSession [(Id, Binding Term)]
156159
normalize' nm = do
157160
bndrsV <- Lens.use bindings
158161
exprM <- MVar.withMVar bndrsV (pure . lookupVarEnv nm)
159162
let nmS = showPpr (varName nm)
163+
-- traceM ("normalize: start " <> nmS)
160164
case exprM of
161165
Just (Binding nm' sp inl pr tm r) -> do
162166
tcm <- Lens.view tcCache
@@ -196,11 +200,17 @@ normalize' nm = do
196200

197201
normV <- Lens.use (extra.normalized)
198202

199-
MVar.withMVar normV $ \norm ->
200-
let prevNorm = mapVarEnv bindingId norm
201-
toNormalize = filter (`notElemVarSet` topEnts)
202-
$ filter (`notElemVarEnv` extendVarEnv nm nm prevNorm) usedBndrs
203-
in return (toNormalize,(nm,tmNorm))
203+
toNormalize <-
204+
MVar.withMVar normV $ \norm ->
205+
let prevNorm = mapVarEnv bindingId norm
206+
toNormalize = filter (`notElemVarSet` topEnts)
207+
$ filter (`notElemVarEnv` extendVarEnv nm nm prevNorm) usedBndrs
208+
in pure toNormalize
209+
210+
-- traceM ("normalize: end: " <> nmS)
211+
212+
normChildren <- Async.mapConcurrently normalize' toNormalize
213+
return ((nm, tmNorm) : concat normChildren)
204214
else
205215
do
206216
-- Throw an error for unrepresentable topEntities and functions
@@ -222,7 +232,7 @@ normalize' nm = do
222232
, showPpr (coreTypeOf nm')
223233
, ") has a non-representable return type."
224234
, " Not normalising:\n", showPpr tm] )
225-
(return ([],(nm,(Binding nm' sp inl pr tm r))))
235+
(return [(nm,(Binding nm' sp inl pr tm r))])
226236

227237

228238
Nothing -> error $ $(curLoc) ++ "Expr belonging to bndr: " ++ nmS ++ " not found"
@@ -354,18 +364,22 @@ flattenCallTree (CBranch (nm,(Binding nm' sp inl pr tm r)) used) = do
354364
-- NB: When -fclash-debug-history is on, emit binary data holding the recorded rewrite steps
355365
opts <- Lens.view debugOpts
356366
let rewriteHistFile = dbg_historyFile opts
357-
when (Maybe.isJust rewriteHistFile) $
358-
let !_ = unsafePerformIO
359-
$ BS.appendFile (Maybe.fromJust rewriteHistFile)
360-
$ BL.toStrict
361-
$ encode RewriteStep
362-
{ t_ctx = []
363-
, t_name = "INLINE"
364-
, t_bndrS = showPpr (varName nm')
365-
, t_before = tm
366-
, t_after = tm1
367-
}
368-
in pure ()
367+
368+
when (Maybe.isJust rewriteHistFile) $ do
369+
lock <- Lens.use ioLock
370+
371+
MVar.withMVar lock $ \() ->
372+
Monad.liftIO
373+
. BS.appendFile (Maybe.fromJust rewriteHistFile)
374+
. BL.toStrict
375+
$ encode RewriteStep
376+
{ t_ctx = []
377+
, t_name = "INLINE"
378+
, t_bndrS = showPpr (varName nm')
379+
, t_before = tm
380+
, t_after = tm1
381+
}
382+
369383
rewriteExpr ("flattenExpr",flatten) (showPpr nm, tm1) (nm', sp)
370384
let allUsed = newUsed ++ concat il_used
371385
-- inline all components when the resulting expression after flattening

clash-lib/src/Clash/Normalize/Transformations/Case.hs

+9-7
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ import Clash.Normalize.Types (NormRewrite, NormalizeSession)
7676
import Clash.Rewrite.Combinators ((>-!))
7777
import Clash.Rewrite.Types
7878
( TransformContext(..), bindings, customReprs, debugOpts, tcCache
79-
, typeTranslator, workFreeBinders)
79+
, typeTranslator, workFreeBinders, ioLock)
8080
import Clash.Rewrite.Util (changed, isFromInt, whnfRW)
8181
import Clash.Rewrite.WorkFree
8282
import Clash.Util (curLoc)
@@ -309,14 +309,16 @@ caseCon' ctx@(TransformContext is0 _) e@(Case subj ty alts) = do
309309
-> caseCon ctx1 (Case (Literal (IntegerLiteral 0)) ty alts)
310310
_ -> do
311311
opts <- Lens.view debugOpts
312+
ioLockV <- Lens.use ioLock
312313
-- When invariants are being checked, report missing evaluation
313314
-- rules for the primitive evaluator.
314-
traceIf (dbg_invariants opts && isConstant subj)
315-
("Unmatchable constant as case subject: " ++ showPpr subj ++
316-
"\nWHNF is: " ++ showPpr subj1)
317-
-- Otherwise check whether the entire case-expression has a
318-
-- single alternative, and pick that one.
319-
(caseOneAlt e)
315+
MVar.withMVar ioLockV $ \() ->
316+
traceIf (dbg_invariants opts && isConstant subj)
317+
("Unmatchable constant as case subject: " ++ showPpr subj ++
318+
"\nWHNF is: " ++ showPpr subj1)
319+
-- Otherwise check whether the entire case-expression has a
320+
-- single alternative, and pick that one.
321+
(caseOneAlt e)
320322

321323
-- The subject is a variable
322324
(Var v, [], _) | isNum0 (coreTypeOf v) ->

0 commit comments

Comments
 (0)