diff --git a/.ci/bindist/linux/debian/focal/buildinfo.json b/.ci/bindist/linux/debian/focal/buildinfo.json index 70e7a6c340..139e702dc6 100644 --- a/.ci/bindist/linux/debian/focal/buildinfo.json +++ b/.ci/bindist/linux/debian/focal/buildinfo.json @@ -26,6 +26,16 @@ "src": {"type": "hackage", "version": "0.7.0.12"}, "cabal_debian_options": ["--disable-tests"] }, + { + "name": "atomic-primops", + "src": {"type": "hackage", "version": "0.8.4"}, + "cabal_debian_options": ["--disable-tests"] + }, + { + "name": "lockfree-queue", + "src": {"type": "hackage", "version": "0.2.3.1"}, + "cabal_debian_options": ["--disable-tests"] + }, { "name": "ghc-tcplugins-extra", "src": {"type": "hackage"} diff --git a/benchmark/benchmark-normalization.hs b/benchmark/benchmark-normalization.hs index cc8070dbb2..62d7d9b9ef 100644 --- a/benchmark/benchmark-normalization.hs +++ b/benchmark/benchmark-normalization.hs @@ -15,6 +15,7 @@ import Clash.Netlist.Types (TopEntityT(topId)) import Criterion.Main +import qualified Control.Concurrent.MVar as MVar import qualified Control.Concurrent.Supply as Supply import Control.DeepSeq (NFData(..), rwhnf) import Data.List (isPrefixOf, partition) @@ -42,7 +43,7 @@ main = do benchFile :: [FilePath] -> FilePath -> Benchmark benchFile idirs src = env (setupEnv idirs src) $ - \ ~(clashEnv, clashDesign, supplyN) -> do + \ ~(clashEnv, clashDesign, supplyN, lock) -> do bench ("normalization of " ++ src) (nfIO (normalizeEntity @@ -51,6 +52,7 @@ benchFile idirs src = (ghcTypeToHWType (opt_intWidth (envOpts clashEnv))) ghcEvaluator evaluator + lock (fmap topId (designEntities clashDesign)) supplyN (topId (head (designEntities clashDesign))))) @@ -58,11 +60,12 @@ benchFile idirs src = setupEnv :: [FilePath] -> FilePath - -> IO (ClashEnv, ClashDesign, Supply.Supply) + -> IO (ClashEnv, ClashDesign, Supply.Supply, MVar.MVar ()) setupEnv idirs src = do (clashEnv, clashDesign) <- runInputStage idirs src supplyN <- Supply.newSupply - return (clashEnv, clashDesign 
,supplyN) + lock <- MVar.newMVar () + return (clashEnv, clashDesign ,supplyN, lock) instance NFData Supply.Supply where rnf = rwhnf diff --git a/benchmark/common/BenchmarkCommon.hs b/benchmark/common/BenchmarkCommon.hs index c3390ffa25..cbb93f7b7c 100644 --- a/benchmark/common/BenchmarkCommon.hs +++ b/benchmark/common/BenchmarkCommon.hs @@ -15,6 +15,7 @@ import Clash.GHC.Evaluator import Clash.GHC.GenerateBindings import Clash.GHC.NetlistTypes +import qualified Control.Concurrent.MVar as MVar import qualified Control.Concurrent.Supply as Supply defaultTests :: [FilePath] @@ -57,6 +58,7 @@ runNormalisationStage -> IO (ClashEnv, ClashDesign, Id) runNormalisationStage idirs src = do supplyN <- Supply.newSupply + lock <- MVar.newMVar () (env, design) <- runInputStage idirs src let topEntityNames = fmap topId (designEntities design) let topEntity = head topEntityNames @@ -65,5 +67,6 @@ runNormalisationStage idirs src = do (ghcTypeToHWType (opt_intWidth (opts idirs))) ghcEvaluator evaluator + lock topEntityNames supplyN topEntity return (env, design{designBindings=transformedBindings},topEntity) diff --git a/benchmark/profiling/run/profile-normalization-run.hs b/benchmark/profiling/run/profile-normalization-run.hs index f02c00619b..c88f538e1f 100644 --- a/benchmark/profiling/run/profile-normalization-run.hs +++ b/benchmark/profiling/run/profile-normalization-run.hs @@ -7,6 +7,7 @@ import Clash.GHC.PartialEval import Clash.GHC.Evaluator import Clash.GHC.NetlistTypes (ghcTypeToHWType) +import qualified Control.Concurrent.MVar as MVar import qualified Control.Concurrent.Supply as Supply import Control.DeepSeq (deepseq) import Data.Binary (decode) @@ -32,6 +33,7 @@ main = do benchFile :: [FilePath] -> FilePath -> IO () benchFile idirs src = do supplyN <- Supply.newSupply + lock <- MVar.newMVar () (bindingsMap,tcm,tupTcm,primMap,reprs,topEntityNames,topEntity) <- setupEnv src putStrLn $ "Doing normalization of " ++ src @@ -47,6 +49,7 @@ benchFile idirs src = do 
(ghcTypeToHWType (opt_intWidth (envOpts clashEnv))) ghcEvaluator evaluator + lock topEntityNames supplyN topEntity res `deepseq` putStrLn ".. done\n" diff --git a/changelog/2022-03-21T11_09_46-05_00_concurrent_normalization b/changelog/2022-03-21T11_09_46-05_00_concurrent_normalization new file mode 100644 index 0000000000..09dbe95604 --- /dev/null +++ b/changelog/2022-03-21T11_09_46-05_00_concurrent_normalization @@ -0,0 +1 @@ +CHANGED: Add concurrent normalization flag [#2074](https://github.com/clash-lang/clash-compiler/pull/2074) diff --git a/clash-ghc/clash-ghc.cabal b/clash-ghc/clash-ghc.cabal index 9de908032c..95c2f23583 100644 --- a/clash-ghc/clash-ghc.cabal +++ b/clash-ghc/clash-ghc.cabal @@ -77,7 +77,7 @@ executable clash executable clashi Main-Is: src-ghc/Interactive.hs Build-Depends: base, clash-ghc - GHC-Options: -Wall -Wcompat -rtsopts -with-rtsopts=-A128m + GHC-Options: -Wall -Wcompat -threaded -rtsopts -with-rtsopts=-A128m if flag(dynamic) GHC-Options: -dynamic extra-libraries: pthread diff --git a/clash-ghc/src-ghc/Clash/GHC/ClashFlags.hs b/clash-ghc/src-ghc/Clash/GHC/ClashFlags.hs index 22a12b3cb3..109e04ddfd 100644 --- a/clash-ghc/src-ghc/Clash/GHC/ClashFlags.hs +++ b/clash-ghc/src-ghc/Clash/GHC/ClashFlags.hs @@ -89,6 +89,7 @@ flagsClash r = [ , defFlag "fclash-inline-workfree-limit" $ IntSuffix (liftEwM . 
setInlineWFLimit r) , defFlag "fclash-edalize" $ NoArg (liftEwM (setEdalize r)) , defFlag "fclash-no-render-enums" $ NoArg (liftEwM (setNoRenderEnums r)) + , defFlag "fclash-concurrent-normalization" $ NoArg (liftEwM (setConcurrentNormalization r)) ] -- | Print deprecated flag warning @@ -313,6 +314,9 @@ setAggressiveXOptBB r = modifyIORef r (\c -> c { opt_aggressiveXOptBB = True }) setEdalize :: IORef ClashOpts -> IO () setEdalize r = modifyIORef r (\c -> c { opt_edalize = True }) +setConcurrentNormalization :: IORef ClashOpts -> IO () +setConcurrentNormalization r = modifyIORef r (\c -> c { opt_concurrentNormalization = True }) + setRewriteHistoryFile :: IORef ClashOpts -> String -> IO () setRewriteHistoryFile r arg = do let fileNm = case drop (length "-fclash-debug-history=") arg of diff --git a/clash-lib/clash-lib.cabal b/clash-lib/clash-lib.cabal index 558a6674f5..da728f30e9 100644 --- a/clash-lib/clash-lib.cabal +++ b/clash-lib/clash-lib.cabal @@ -130,7 +130,6 @@ Library aeson-pretty >= 0.8 && < 0.9, ansi-terminal >= 0.8.0.0 && < 0.12, array, - async >= 2.2.0 && < 2.3, attoparsec >= 0.10.4.0 && < 0.15, base >= 4.11 && < 5, base16-bytestring >= 0.1.1 && < 1.1, @@ -155,6 +154,10 @@ Library hint >= 0.7 && < 0.10, interpolate >= 0.2.0 && < 1.0, lens >= 4.10 && < 5.1.0, + lifted-async >=0.10 && <0.11, + lifted-base >=0.2 && <0.3, + lockfree-queue >=0.2 && <0.3, + monad-control >=1.0 && <1.1, mtl >= 2.1.2 && < 2.3, ordered-containers >= 0.2 && < 0.3, prettyprinter >= 1.2.0.1 && < 1.8, @@ -166,6 +169,7 @@ Library text >= 1.2.2 && < 2.1, time >= 1.4.0.1 && < 1.14, transformers >= 0.5.2.0 && < 0.7, + transformers-base, trifecta >= 1.7.1.1 && < 2.2, vector >= 0.11 && < 1.0, vector-binary-instances >= 0.2.3.5 && < 0.3, diff --git a/clash-lib/src/Clash/Core/PartialEval/Monad.hs b/clash-lib/src/Clash/Core/PartialEval/Monad.hs index 6191c324fa..0752ffd561 100644 --- a/clash-lib/src/Clash/Core/PartialEval/Monad.hs +++ b/clash-lib/src/Clash/Core/PartialEval/Monad.hs @@ -1,5 
+1,5 @@ {-| -Copyright : (C) 2020-2021, QBayLogic B.V. +Copyright : (C) 2020-2022, QBayLogic B.V. License : BSD2 (see the file LICENSE) Maintainer : QBayLogic B.V. @@ -84,7 +84,7 @@ import Clash.Core.Util (mkUniqSystemId, mkUniqSystemTyVar) import Clash.Core.Var (Id, TyVar, Var) import Clash.Core.VarEnv import Clash.Driver.Types (Binding(..)) -import Clash.Rewrite.WorkFree (isWorkFree) +import Clash.Rewrite.WorkFree (isWorkFreePure) {- NOTE [RWS monad] @@ -311,7 +311,11 @@ workFreeValue :: Value -> Eval Bool workFreeValue = \case VNeutral _ -> pure False VThunk x _ -> do - bindings <- fmap (fmap asTerm) . genvBindings <$> getGlobalEnv - isWorkFree workFreeCache bindings x + env <- getGlobalEnv + let bindings = fmap (fmap asTerm) (genvBindings env) + let (cache, wf) = isWorkFreePure (genvWorkCache env) bindings x + + modifyGlobalEnv (\genv -> genv { genvWorkCache = cache }) + pure wf _ -> pure True diff --git a/clash-lib/src/Clash/Core/PartialEval/NormalForm.hs b/clash-lib/src/Clash/Core/PartialEval/NormalForm.hs index d201aed0b3..5613a5e76a 100644 --- a/clash-lib/src/Clash/Core/PartialEval/NormalForm.hs +++ b/clash-lib/src/Clash/Core/PartialEval/NormalForm.hs @@ -1,6 +1,5 @@ {-| -Copyright : (C) 2020-2021, QBayLogic B.V., - 2022 , Google Inc. +Copyright : (C) 2020-2022, QBayLogic B.V. License : BSD2 (see the file LICENSE) Maintainer : QBayLogic B.V. @@ -29,11 +28,9 @@ module Clash.Core.PartialEval.NormalForm , Normal(..) , LocalEnv(..) , GlobalEnv(..) - , workFreeCache ) where import Control.Concurrent.Supply (Supply) -import Control.Lens (Lens', lens) import Data.IntMap.Strict (IntMap) import Data.Map.Strict (Map) @@ -201,6 +198,3 @@ data GlobalEnv = GlobalEnv -- ^ Cache for the results of isWorkFree. This is required to use -- Clash.Rewrite.WorkFree.isWorkFree. 
} - -workFreeCache :: Lens' GlobalEnv (VarEnv Bool) -workFreeCache = lens genvWorkCache (\env x -> env { genvWorkCache = x }) diff --git a/clash-lib/src/Clash/Core/VarEnv.hs b/clash-lib/src/Clash/Core/VarEnv.hs index 2e66c5ccea..37cbcc0fb6 100644 --- a/clash-lib/src/Clash/Core/VarEnv.hs +++ b/clash-lib/src/Clash/Core/VarEnv.hs @@ -38,12 +38,15 @@ module Clash.Core.VarEnv -- ** Conversions -- *** Lists , eltsVarEnv + , toListVarEnv + , listToVarEnv -- * Sets of variables , VarSet -- ** Construction , emptyVarSet , unitVarSet -- ** Modification + , extendVarSet , delVarSetByKey , unionVarSet , differenceVarSet @@ -260,6 +263,15 @@ eltsVarEnv -> [a] eltsVarEnv = eltsUniqMap +toListVarEnv :: VarEnv a -> [(Unique, a)] +toListVarEnv = toListUniqMap + +listToVarEnv + :: Uniquable a + => [(a, b)] + -> VarEnv b +listToVarEnv = listToUniqMap + -- | Does the variable exist in the environment elemVarEnv :: Var a diff --git a/clash-lib/src/Clash/Debug.hs b/clash-lib/src/Clash/Debug.hs index 8211d87dd1..8876b03eed 100644 --- a/clash-lib/src/Clash/Debug.hs +++ b/clash-lib/src/Clash/Debug.hs @@ -3,6 +3,7 @@ module Clash.Debug ( debugIsOn , traceIf + , traceWhen , module Debug.Trace ) where @@ -19,4 +20,9 @@ debugIsOn = False traceIf :: Bool -> String -> a -> a traceIf True msg = trace msg traceIf False _ = id + +traceWhen :: Monad m => Bool -> String -> m () +traceWhen True = traceM +traceWhen False = const (pure ()) + {-# INLINE traceIf #-} diff --git a/clash-lib/src/Clash/Driver.hs b/clash-lib/src/Clash/Driver.hs index 7cf662f0ff..fb01e4647c 100644 --- a/clash-lib/src/Clash/Driver.hs +++ b/clash-lib/src/Clash/Driver.hs @@ -22,7 +22,7 @@ module Clash.Driver where import Control.Concurrent (MVar, modifyMVar, modifyMVar_, newMVar, withMVar) -import Control.Concurrent.Async (mapConcurrently_) +import Control.Concurrent.Async.Lifted (mapConcurrently_) import qualified Control.Concurrent.Supply as Supply import Control.DeepSeq import Control.Exception (throw) @@ -443,7 +443,7 @@ 
generateHDL env design hdlState typeTrans peEval eval mainTopEntity startTime = -- 2. Normalize topEntity supplyN <- Supply.newSupply transformedBindings <- normalizeEntity env bindingsMap typeTrans peEval - eval topEntityNames supplyN topEntity + eval ioLockV topEntityNames supplyN topEntity normTime <- transformedBindings `deepseq` Clock.getCurrentTime let prepNormDiff = reportTimeDiff normTime prevTime @@ -1063,6 +1063,8 @@ normalizeEntity -- ^ Hardcoded evaluator for partial evaluation -> WHNF.Evaluator -- ^ Hardcoded evaluator for WHNF (old evaluator) + -> MVar () + -- ^ Synchronization for stdout -> [Id] -- ^ TopEntities -> Supply.Supply @@ -1070,14 +1072,14 @@ normalizeEntity -> Id -- ^ root of the hierarchy -> IO BindingMap -normalizeEntity env bindingsMap typeTrans peEval eval topEntities supply tm = transformedBindings +normalizeEntity env bindingsMap typeTrans peEval eval lock topEntities supply tm = transformedBindings where doNorm = do norm <- normalize [tm] let normChecked = checkNonRecursive norm cleaned <- cleanupGraph tm normChecked return cleaned transformedBindings = runNormalization env supply bindingsMap - typeTrans peEval eval emptyVarEnv + typeTrans peEval eval emptyVarEnv lock topEntities doNorm -- | topologically sort the top entities diff --git a/clash-lib/src/Clash/Driver/Types.hs b/clash-lib/src/Clash/Driver/Types.hs index f537dd3364..e7e840525f 100644 --- a/clash-lib/src/Clash/Driver/Types.hs +++ b/clash-lib/src/Clash/Driver/Types.hs @@ -392,6 +392,8 @@ data ClashOpts = ClashOpts , opt_renderEnums :: Bool -- ^ Render sum types with all zero-width fields as enums where supported, as -- opposed to rendering them as bitvectors. 
+ , opt_concurrentNormalization :: Bool + -- ^ Toggle concurrent normalization (usually slower, faster on large designs) } deriving (Show) @@ -424,6 +426,7 @@ instance NFData ClashOpts where opt_inlineWFCacheLimit o `deepseq` opt_edalize o `deepseq` opt_renderEnums o `deepseq` + opt_concurrentNormalization o `deepseq` () instance Eq ClashOpts where @@ -454,7 +457,8 @@ instance Eq ClashOpts where opt_aggressiveXOptBB s0 == opt_aggressiveXOptBB s1 && opt_inlineWFCacheLimit s0 == opt_inlineWFCacheLimit s1 && opt_edalize s0 == opt_edalize s1 && - opt_renderEnums s0 == opt_renderEnums s1 + opt_renderEnums s0 == opt_renderEnums s1 && + opt_concurrentNormalization s0 == opt_concurrentNormalization s1 where eqOverridingBool :: OverridingBool -> OverridingBool -> Bool @@ -492,7 +496,8 @@ instance Hashable ClashOpts where opt_aggressiveXOptBB `hashWithSalt` opt_inlineWFCacheLimit `hashWithSalt` opt_edalize `hashWithSalt` - opt_renderEnums + opt_renderEnums `hashWithSalt` + opt_concurrentNormalization where hashOverridingBool :: Int -> OverridingBool -> Int hashOverridingBool s1 Auto = hashWithSalt s1 (0 :: Int) @@ -501,36 +506,36 @@ instance Hashable ClashOpts where infixl 0 `hashOverridingBool` defClashOpts :: ClashOpts -defClashOpts - = ClashOpts - { opt_werror = False - , opt_inlineLimit = 20 - , opt_specLimit = 20 - , opt_inlineFunctionLimit = 15 - , opt_inlineConstantLimit = 0 - , opt_evaluatorFuelLimit = 20 - , opt_debug = debugNone - , opt_cachehdl = True - , opt_clear = False - , opt_primWarn = True - , opt_color = Auto - , opt_intWidth = WORD_SIZE_IN_BITS - , opt_hdlDir = Nothing - , opt_hdlSyn = Other - , opt_errorExtra = False - , opt_importPaths = [] - , opt_componentPrefix = Nothing - , opt_newInlineStrat = True - , opt_escapedIds = True - , opt_lowerCaseBasicIds = PreserveCase - , opt_ultra = False - , opt_forceUndefined = Nothing - , opt_checkIDir = True - , opt_aggressiveXOpt = False - , opt_aggressiveXOptBB = False - , opt_inlineWFCacheLimit = 10 -- TODO: 
find "optimal" value - , opt_edalize = False - , opt_renderEnums = True +defClashOpts = ClashOpts + { opt_werror = False + , opt_inlineLimit = 20 + , opt_specLimit = 20 + , opt_inlineFunctionLimit = 15 + , opt_inlineConstantLimit = 0 + , opt_evaluatorFuelLimit = 20 + , opt_debug = debugNone + , opt_cachehdl = True + , opt_clear = False + , opt_primWarn = True + , opt_color = Auto + , opt_intWidth = WORD_SIZE_IN_BITS + , opt_hdlDir = Nothing + , opt_hdlSyn = Other + , opt_errorExtra = False + , opt_importPaths = [] + , opt_componentPrefix = Nothing + , opt_newInlineStrat = True + , opt_escapedIds = True + , opt_lowerCaseBasicIds = PreserveCase + , opt_ultra = False + , opt_forceUndefined = Nothing + , opt_checkIDir = True + , opt_aggressiveXOpt = False + , opt_aggressiveXOptBB = False + , opt_inlineWFCacheLimit = 10 -- TODO: find "optimal" value + , opt_edalize = False + , opt_renderEnums = True + , opt_concurrentNormalization = False } -- | Synopsys Design Constraint (SDC) information for a component. 
diff --git a/clash-lib/src/Clash/Normalize.hs b/clash-lib/src/Clash/Normalize.hs index 70d7e504af..ae21ff8deb 100644 --- a/clash-lib/src/Clash/Normalize.hs +++ b/clash-lib/src/Clash/Normalize.hs @@ -10,26 +10,34 @@ -} {-# LANGUAGE CPP #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE TemplateHaskell #-} module Clash.Normalize where -import Control.Concurrent.Supply (Supply) +import qualified Control.Concurrent.Async.Lifted as Async +import Control.Concurrent.MVar.Lifted (MVar) +import qualified Control.Concurrent.MVar.Lifted as MVar +import Control.Concurrent.Supply (Supply, splitSupply) import Control.Exception (throw) import qualified Control.Lens as Lens import Control.Monad (when) +import qualified Control.Monad.IO.Class as Monad (liftIO) import Control.Monad.State.Strict (State) +import Data.Bifunctor (second) import Data.Default (def) import Data.Either (lefts,partitionEithers) -import qualified Data.IntMap as IntMap +import Data.Foldable (traverse_) +import qualified Data.HashMap.Strict as HashMap import Data.List (intersect, mapAccumL) import qualified Data.Map as Map import qualified Data.Maybe as Maybe import qualified Data.Set as Set import qualified Data.Set.Lens as Lens +import qualified Data.Concurrent.Queue.MichaelScott as MS #if MIN_VERSION_prettyprinter(1,7,0) import Prettyprinter (vcat) @@ -61,9 +69,10 @@ import Clash.Core.TyCon (TyConMap) import Clash.Core.Type (isPolyTy) import Clash.Core.Var (Id, varName, varType) import Clash.Core.VarEnv - (VarEnv, elemVarSet, eltsVarEnv, emptyInScopeSet, emptyVarEnv, - extendVarEnv, lookupVarEnv, mapVarEnv, mapMaybeVarEnv, - mkVarEnv, mkVarSet, notElemVarEnv, notElemVarSet, nullVarEnv, unionVarEnv) + (VarEnv, VarSet, elemVarSet, eltsVarEnv, emptyInScopeSet, emptyVarEnv, emptyVarSet, + extendVarEnv, extendVarSet, lookupVarEnv, mapMaybeVarEnv, + mkVarEnv, mkVarSet, notElemVarEnv, notElemVarSet, nullVarEnv, + listToVarEnv, toListVarEnv) import 
Clash.Debug (traceIf) import Clash.Driver.Types (BindingMap, Binding(..), DebugOpts(..), ClashEnv(..)) @@ -77,8 +86,8 @@ import Clash.Normalize.Types import Clash.Normalize.Util import Clash.Rewrite.Combinators ((>->),(!->),repeatR,topdownR) import Clash.Rewrite.Types - (RewriteEnv (..), RewriteState (..), bindings, debugOpts, extra, - tcCache, topEntities, newInlineStrategy) + (RewriteEnv (..), RewriteState (..), bindings, debugOpts, extra, uniqSupply, + tcCache, topEntities, newInlineStrategy, ioLock) import Clash.Rewrite.Util (apply, isUntranslatableType, runRewriteSession) import Clash.Util @@ -88,10 +97,8 @@ import Data.Binary (encode) import qualified Data.ByteString as BS import qualified Data.ByteString.Lazy as BL -import System.IO.Unsafe (unsafePerformIO) import Clash.Rewrite.Types (RewriteStep(..)) - -- | Run a NormalizeSession in a given environment runNormalization :: ClashEnv @@ -108,53 +115,85 @@ runNormalization -- ^ Hardcoded evaluator for WHNF (old evaluator) -> VarEnv Bool -- ^ Map telling whether a components is part of a recursive group + -> MVar () + -- ^ Synchronization on stdout -> [Id] -- ^ topEntities -> NormalizeSession a -- ^ NormalizeSession to run -> IO a -runNormalization env supply globals typeTrans peEval eval rcsMap topEnts = - runRewriteSession rwEnv rwState - where - -- TODO The RewriteEnv should just take ClashOpts. 
- rwEnv = RewriteEnv - env - typeTrans - peEval - eval - (mkVarSet topEnts) - - rwState = RewriteState - 0 - mempty -- transformCounters Map - globals - supply - (error $ $(curLoc) ++ "Report as bug: no curFun",noSrcSpan) - 0 - (IntMap.empty, 0) - emptyVarEnv - normState - - normState = NormalizeState - emptyVarEnv - Map.empty - emptyVarEnv - emptyVarEnv - Map.empty - rcsMap - -normalize - :: [Id] - -> NormalizeSession BindingMap -normalize [] = return emptyVarEnv -normalize top = do - (new,topNormalized) <- unzip <$> mapM normalize' top - newNormalized <- normalize (concat new) - return (unionVarEnv (mkVarEnv topNormalized) newNormalized) - -normalize' :: Id -> NormalizeSession ([Id], (Id, Binding Term)) -normalize' nm = do - exprM <- lookupVarEnv nm <$> Lens.use bindings +runNormalization env supply globals typeTrans peEval eval rcsMap lock entities session = do + normState <- NormalizeState + <$> MVar.newMVar emptyVarEnv + <*> MVar.newMVar Map.empty + <*> MVar.newMVar emptyVarEnv + <*> MVar.newMVar emptyVarEnv + <*> MVar.newMVar Map.empty + <*> MVar.newMVar rcsMap + + rwState <- RewriteState + <$> MVar.newMVar mempty + <*> MVar.newMVar globals + <*> pure supply + <*> MVar.newMVar HashMap.empty + <*> MVar.newMVar 0 + <*> MVar.newMVar (mempty, 0) + <*> MVar.newMVar emptyVarEnv + <*> pure lock + <*> pure normState + + runRewriteSession rwEnv rwState session + where + rwEnv = RewriteEnv + { _clashEnv = env + , _typeTranslator = typeTrans + , _peEvaluator = peEval + , _evaluator = eval + , _topEntities = mkVarSet entities + } + +supplies :: Int -> Supply -> [Supply] +supplies 0 _ = [] +supplies n s = let (s0', s1') = splitSupply s in s0' : supplies (n-1) s1' + +normalize :: [Id] -> NormalizeSession BindingMap +normalize tops = do + q <- Monad.liftIO MS.newQ + traverse_ (Monad.liftIO . 
MS.pushL q) tops + binds <- MVar.newMVar (emptyVarSet, []) + uniq0 <- Lens.use uniqSupply + let ss = supplies (length tops) uniq0 + -- one thread per top-level binding + Async.mapConcurrently_ (normalizeStep q binds) ss + mkVarEnv . snd <$> MVar.readMVar binds + +normalizeStep + :: MS.LinkedQueue Id + -> MVar (VarSet, [(Id, Binding Term)]) + -> Supply + -> NormalizeSession () +normalizeStep q binds s = do + uniqSupply Lens..= s + res <- Monad.liftIO $ MS.tryPopR q + case res of + Just id' -> do + (bound, pairs) <- MVar.takeMVar binds + if not (id' `elemVarSet` bound) + then do + -- mark that we are attempting to normalize id' + MVar.putMVar binds (bound `extendVarSet` id', pairs) + pair <- normalize' id' q + MVar.modifyMVar_ binds (pure . second (pair:)) + else + MVar.putMVar binds (bound, pairs) + nextS <- Lens.use uniqSupply + normalizeStep q binds nextS + Nothing -> pure () + +normalize' :: Id -> MS.LinkedQueue Id -> NormalizeSession (Id, Binding Term) +normalize' nm q = do + bndrsV <- Lens.use bindings + exprM <- MVar.withMVar bndrsV (pure . lookupVarEnv nm) let nmS = showPpr (varName nm) case exprM of Just (Binding nm' sp inl pr tm r) -> do @@ -192,10 +231,18 @@ normalize' nm = do , ") remains recursive after normalization:\n" , showPpr (bindingTerm tmNorm) ]) (return ()) - prevNorm <- mapVarEnv bindingId <$> Lens.use (extra.normalized) - let toNormalize = filter (`notElemVarSet` topEnts) - $ filter (`notElemVarEnv` (extendVarEnv nm nm prevNorm)) usedBndrs - return (toNormalize,(nm,tmNorm)) + + normV <- Lens.use (extra.normalized) + + toNormalize <- + MVar.withMVar normV $ \norm -> do + prevNorm <- listToVarEnv <$> traverse (\(k, v) -> (k,) . bindingId <$> MVar.readMVar v) (toListVarEnv norm) + let toNormalize = filter (`notElemVarSet` topEnts) + $ filter (`notElemVarEnv` extendVarEnv nm nm prevNorm) usedBndrs + in pure toNormalize + + traverse_ (Monad.liftIO . 
MS.pushL q) toNormalize + pure (nm, tmNorm) else do -- Throw an error for unrepresentable topEntities and functions @@ -217,7 +264,7 @@ normalize' nm = do , showPpr (coreTypeOf nm') , ") has a non-representable return type." , " Not normalising:\n", showPpr tm] ) - (return ([],(nm,(Binding nm' sp inl pr tm r)))) + (return (nm,(Binding nm' sp inl pr tm r))) Nothing -> error $ $(curLoc) ++ "Expr belonging to bndr: " ++ nmS ++ " not found" @@ -349,18 +396,22 @@ flattenCallTree (CBranch (nm,(Binding nm' sp inl pr tm r)) used) = do -- NB: When -fclash-debug-history is on, emit binary data holding the recorded rewrite steps opts <- Lens.view debugOpts let rewriteHistFile = dbg_historyFile opts - when (Maybe.isJust rewriteHistFile) $ - let !_ = unsafePerformIO - $ BS.appendFile (Maybe.fromJust rewriteHistFile) - $ BL.toStrict - $ encode RewriteStep - { t_ctx = [] - , t_name = "INLINE" - , t_bndrS = showPpr (varName nm') - , t_before = tm - , t_after = tm1 - } - in pure () + + when (Maybe.isJust rewriteHistFile) $ do + lock <- Lens.use ioLock + + MVar.withMVar lock $ \() -> + Monad.liftIO + . BS.appendFile (Maybe.fromJust rewriteHistFile) + . BL.toStrict + $ encode RewriteStep + { t_ctx = [] + , t_name = "INLINE" + , t_bndrS = showPpr (varName nm') + , t_before = tm + , t_after = tm1 + } + rewriteExpr ("flattenExpr",flatten) (showPpr nm, tm1) (nm', sp) let allUsed = newUsed ++ concat il_used -- inline all components when the resulting expression after flattening diff --git a/clash-lib/src/Clash/Normalize/PrimitiveReductions.hs b/clash-lib/src/Clash/Normalize/PrimitiveReductions.hs index 0d401a2455..d626a04738 100644 --- a/clash-lib/src/Clash/Normalize/PrimitiveReductions.hs +++ b/clash-lib/src/Clash/Normalize/PrimitiveReductions.hs @@ -1,7 +1,7 @@ {-| Copyright : (C) 2015-2016, University of Twente, 2016 , Myrtle Software Ltd, - 2021 , QBayLogic B.V. + 2021-2022, QBayLogic B.V. License : BSD2 (see the file LICENSE) Maintainer : QBayLogic B.V. 
@@ -428,7 +428,6 @@ reduceIterateI (TransformContext is0 ctx) n aTy vTy f0 a = do tcm <- Lens.view tcCache f1 <- constantPropagation (TransformContext is0 (AppArg Nothing:ctx)) f0 - -- Generate uniq ids for element assignments. uniqs0 <- Lens.use uniqSupply let is1 = extendInScopeSetList is0 (collectTermIds f1) @@ -919,28 +918,29 @@ reduceUnconcat inScope unconcatPrimInfo n m aTy sm arg = do (uniqs1,(vars,headsAndTails)) = second (second concat . unzip) (extractElems uniqs0 inScope consCon aTy 'U' (n*m) arg) + -- Build a vector out of the first m elements - mvec = mkVec nilCon consCon aTy m (take (fromInteger m) vars) - -- Get the vector representing the next ((n-1)*m) elements - -- N.B. `extractElems (xs :: Vec 2 a)` creates: - -- x0 = head xs - -- xs0 = tail xs - -- x1 = head xs0 - -- xs1 = tail xs0 - (lbs,head -> nextVec) = splitAt ((2*fromInteger m)-1) headsAndTails - -- recursively call unconcat - nextUnconcat = mkApps (Prim unconcatPrimInfo) - [ Right (LitTy (NumTy (n-1))) - , Right (LitTy (NumTy m)) - , Right aTy - , Left (Literal (NaturalLiteral (n-1))) - , Left sm - , Left (snd nextVec) - ] - -- let (mvec,nextVec) = splitAt sm arg - -- in Cons mvec (unconcat sm nextVec) - lBody = mkVecCons consCon innerVecTy n mvec nextUnconcat - lb = Letrec lbs lBody + let mvec = mkVec nilCon consCon aTy m (take (fromInteger m) vars) + -- Get the vector representing the next ((n-1)*m) elements + -- N.B. 
`extractElems (xs :: Vec 2 a)` creates: + -- x0 = head xs + -- xs0 = tail xs + -- x1 = head xs0 + -- xs1 = tail xs0 + (lbs,head -> nextVec) = splitAt ((2*fromInteger m)-1) headsAndTails + -- recursively call unconcat + nextUnconcat = mkApps (Prim unconcatPrimInfo) + [ Right (LitTy (NumTy (n-1))) + , Right (LitTy (NumTy m)) + , Right aTy + , Left (Literal (NaturalLiteral (n-1))) + , Left sm + , Left (snd nextVec) + ] + -- let (mvec,nextVec) = splitAt sm arg + -- in Cons mvec (unconcat sm nextVec) + lBody = mkVecCons consCon innerVecTy n mvec nextUnconcat + lb = Letrec lbs lBody uniqSupply Lens..= uniqs1 changed lb @@ -1078,7 +1078,6 @@ reduceReplace_int is0 n aTy vTy v i newA = do (Just iTc) = lookupUniqMap iTcNm tcm [iDc] = tyConDataCons iTc - -- Get elements from vector (uniqs1,(vars,elems)) = second (second concat . unzip) $ extractElems uniqs0 diff --git a/clash-lib/src/Clash/Normalize/Transformations/Case.hs b/clash-lib/src/Clash/Normalize/Transformations/Case.hs index 9b7562a724..a51042125d 100644 --- a/clash-lib/src/Clash/Normalize/Transformations/Case.hs +++ b/clash-lib/src/Clash/Normalize/Transformations/Case.hs @@ -26,6 +26,7 @@ module Clash.Normalize.Transformations.Case , elimExistentials ) where +import qualified Control.Concurrent.MVar.Lifted as MVar import qualified Control.Lens as Lens import Control.Monad.State.Strict (evalState) import Data.Bifunctor (second) @@ -75,7 +76,7 @@ import Clash.Normalize.Types (NormRewrite, NormalizeSession) import Clash.Rewrite.Combinators ((>-!)) import Clash.Rewrite.Types ( TransformContext(..), bindings, customReprs, debugOpts, tcCache - , typeTranslator, workFreeBinders) + , typeTranslator, workFreeBinders, ioLock) import Clash.Rewrite.Util (changed, isFromInt, whnfRW) import Clash.Rewrite.WorkFree import Clash.Util (curLoc) @@ -237,8 +238,9 @@ caseCon' ctx@(TransformContext is0 _) e@(Case subj ty alts) = do -- based on the fact on whether the argument has the potential to make -- the circuit larger than needed if 
we were to duplicate that argument. newBinder (isN0, substN) (x, arg) = do - bndrs <- Lens.use bindings - isWorkFree workFreeBinders bndrs arg >>= \case + bindingsV <- Lens.use bindings + wf <- MVar.withMVar bindingsV (\bndrs -> isWorkFree workFreeBinders bndrs arg) + case wf of True -> pure ((isN0, (x, arg):substN), Nothing) False -> let @@ -316,14 +318,16 @@ caseCon' ctx@(TransformContext is0 _) e@(Case subj ty alts) = do -> caseCon ctx1 (Case (Literal (IntegerLiteral 0)) ty alts) _ -> do opts <- Lens.view debugOpts + ioLockV <- Lens.use ioLock -- When invariants are being checked, report missing evaluation -- rules for the primitive evaluator. - traceIf (dbg_invariants opts && isConstant subj) - ("Unmatchable constant as case subject: " ++ showPpr subj ++ - "\nWHNF is: " ++ showPpr subj1) - -- Otherwise check whether the entire case-expression has a - -- single alternative, and pick that one. - (caseOneAlt e) + MVar.withMVar ioLockV $ \() -> + traceIf (dbg_invariants opts && isConstant subj) + ("Unmatchable constant as case subject: " ++ showPpr subj ++ + "\nWHNF is: " ++ showPpr subj1) + -- Otherwise check whether the entire case-expression has a + -- single alternative, and pick that one. 
+ (caseOneAlt e) -- The subject is a variable (Var v, [], _) | isNum0 (coreTypeOf v) -> diff --git a/clash-lib/src/Clash/Normalize/Transformations/Cast.hs b/clash-lib/src/Clash/Normalize/Transformations/Cast.hs index 222a5a91fe..7efccbdf68 100644 --- a/clash-lib/src/Clash/Normalize/Transformations/Cast.hs +++ b/clash-lib/src/Clash/Normalize/Transformations/Cast.hs @@ -9,9 +9,13 @@ module Clash.Normalize.Transformations.Cast , splitCastWork ) where +import Control.Concurrent.Lifted (myThreadId) +import qualified Control.Concurrent.MVar.Lifted as MVar import Control.Exception (throw) import qualified Control.Lens as Lens +import qualified Control.Monad as Monad (when) import Control.Monad.Writer (listen) +import qualified Data.HashMap.Strict as HashMap import qualified Data.Monoid as Monoid (Any(..)) import GHC.Stack (HasCallStack) @@ -22,11 +26,11 @@ import Clash.Core.TermInfo (isCast) import Clash.Core.Type (normalizeType) import Clash.Core.Var (isGlobalId, varName) import Clash.Core.VarEnv (InScopeSet) -import Clash.Debug (trace) +import Clash.Debug (traceM) import Clash.Normalize.Transformations.Specialize (specialize) import Clash.Normalize.Types (NormRewrite, NormalizeSession) import Clash.Rewrite.Types - (TransformContext(..), bindings, curFun, tcCache, workFreeBinders) + (TransformContext(..), bindings, curFun, tcCache, workFreeBinders, ioLock) import Clash.Rewrite.Util (changed, mkDerivedName, mkTmBinderFor) import Clash.Rewrite.WorkFree (isWorkFree) import Clash.Util (ClashException(..), curLoc) @@ -58,18 +62,23 @@ argCastSpec ctx e@(App f (stripTicks -> Cast e' _ _)) -- We can only push casts into global binders , (Var g, _) <- collectArgs f , isGlobalId g = do - bndrs <- Lens.use bindings - isWorkFree workFreeBinders bndrs e' >>= \case - True -> go - False -> warn go + bndrsV <- Lens.use bindings + wf <- MVar.withMVar bndrsV (\bndrs -> isWorkFree workFreeBinders bndrs e') + + ioLockV <- Lens.use ioLock + + Monad.when (not wf) $ + MVar.withMVar ioLockV $ 
\() -> traceM warn + + specialize ctx e where - go = specialize ctx e - warn = trace (unwords + warn = unwords [ "WARNING:", $(curLoc), "specializing a function on a non work-free" , "cast. Generated HDL implementation might contain duplicate work." , "Please report this as a bug.", "\n\nExpression where this occured:" , "\n\n" ++ showPpr e - ]) + ] + argCastSpec _ e = return e {-# SCC argCastSpec #-} @@ -96,7 +105,9 @@ elimCastCast _ c@(Cast (stripTicks -> Cast e tyA tyB) tyB' tyC) = do if ntyB == ntyB' && ntyA == ntyC then changed e else throwError where throwError = do - (nm,sp) <- Lens.use curFun + curFunsV <- Lens.use curFun + thread <- myThreadId + Just (nm,sp) <- MVar.withMVar curFunsV (pure . HashMap.lookup thread) throw (ClashException sp ($(curLoc) ++ showPpr nm ++ ": Found 2 nested casts whose types don't line up:\n" ++ showPpr c) diff --git a/clash-lib/src/Clash/Normalize/Transformations/DEC.hs b/clash-lib/src/Clash/Normalize/Transformations/DEC.hs index 3a2fa5beae..ecaa4ad163 100644 --- a/clash-lib/src/Clash/Normalize/Transformations/DEC.hs +++ b/clash-lib/src/Clash/Normalize/Transformations/DEC.hs @@ -1,6 +1,6 @@ {-| Copyright : (C) 2015-2016, University of Twente, - 2021, QBayLogic B.V. + 2021-2022, QBayLogic B.V. License : BSD2 (see the file LICENSE) Maintainer : QBayLogic B.V. 
@@ -34,6 +34,7 @@ module Clash.Normalize.Transformations.DEC ( disjointExpressionConsolidation ) where +import qualified Control.Concurrent.MVar.Lifted as MVar import Control.Concurrent.Supply (splitSupply) import Control.Lens ((^.), _1) import qualified Control.Lens as Lens @@ -292,13 +293,19 @@ collectGlobals' is0 substitution seen (Case scrut ty alts) _eIsConstant = do collectGlobals' is0 substitution seen e@(collectArgsTicks -> (fun, args@(_:_), ticks)) eIsconstant | not eIsconstant = do tcm <- Lens.view tcCache - bndrs <- Lens.use bindings + bndrsV <- Lens.use bindings evaluate <- Lens.view evaluator ids <- Lens.use uniqSupply let (ids1,ids2) = splitSupply ids uniqSupply Lens..= ids2 - gh <- Lens.use globalHeap - let eval = (Lens.view Lens._3) . whnf' evaluate bndrs tcm gh ids1 is0 False + + ghV <- Lens.use globalHeap + + eval <- + MVar.withMVar bndrsV $ \bndrs -> + MVar.withMVar ghV $ \gh -> + pure $ (Lens.view Lens._3) . whnf' evaluate bndrs tcm gh ids1 is0 False + let eTy = inferCoreTypeOf tcm e untran <- isUntranslatableType False eTy case untran of diff --git a/clash-lib/src/Clash/Normalize/Transformations/Inline.hs b/clash-lib/src/Clash/Normalize/Transformations/Inline.hs index c37dca4b73..8388471d7a 100644 --- a/clash-lib/src/Clash/Normalize/Transformations/Inline.hs +++ b/clash-lib/src/Clash/Normalize/Transformations/Inline.hs @@ -31,6 +31,8 @@ module Clash.Normalize.Transformations.Inline , inlineWorkFree ) where +import Control.Concurrent.Lifted (myThreadId) +import qualified Control.Concurrent.MVar.Lifted as MVar import qualified Control.Lens as Lens import qualified Control.Monad as Monad import Control.Monad.Trans.Maybe (MaybeT(..)) @@ -38,6 +40,7 @@ import Control.Monad.Writer ((>=>),lift,listen) import Data.Default (Default(..)) import Data.Either (lefts) import qualified Data.HashMap.Lazy as HashMap +import qualified Data.HashMap.Strict as HashMapS import qualified Data.List as List import qualified Data.Maybe as Maybe import qualified 
Data.Monoid as Monoid (Any(..)) @@ -76,7 +79,7 @@ import Clash.Core.VarEnv , eltsVarEnv, emptyVarEnv, extendInScopeSetList, extendVarEnv , foldlWithUniqueVarEnv', lookupVarEnv, lookupVarEnvDirectly, mkVarEnv , notElemVarSet, unionVarEnv, unionVarEnvWith, unitVarSet) -import Clash.Debug (trace) +import Clash.Debug (traceM) import Clash.Driver.Types (Binding(..)) import Clash.Netlist.Util (representableType) import Clash.Primitives.Types @@ -85,7 +88,7 @@ import Clash.Rewrite.Combinators (allR) import Clash.Rewrite.Types ( TransformContext(..), bindings, curFun, customReprs, tcCache, topEntities , typeTranslator, inlineConstantLimit, inlineFunctionLimit, inlineLimit - , inlineWFCacheLimit, primitives) + , inlineWFCacheLimit, primitives, ioLock) import Clash.Rewrite.Util ( changed, inlineBinders, inlineOrLiftBinders, isJoinPointIn , isUntranslatable, isUntranslatableType, isVoidWrapper, zoomExtra) @@ -398,32 +401,40 @@ collapseRHSNoops _ (Letrec binds body) = do runCollapseNoop orig = runMaybeT (collapseNoop orig) >>= Maybe.maybe (return orig) changed - collapseNoop (iD,term) = do + collapseNoop :: (Id, Term) -> MaybeT NormalizeSession (Id, Term) + collapseNoop (iD, term) = do (Prim info,args) <- return $ collectArgs term identity <- getIdentity info $ lefts args collapsed <- collapseToIdentity iD identity return (iD,collapsed) + collapseToIdentity :: Id -> Term -> MaybeT NormalizeSession Term collapseToIdentity iD identity = do tcm <- Lens.view tcCache let aTy = inferCoreTypeOf tcm identity bTy = coreTypeOf iD return $ primUCo `TyApp` aTy `TyApp` bTy `App` identity + getIdentity :: PrimInfo -> [Term] -> MaybeT NormalizeSession Term getIdentity primInfo termArgs = do WorkIdentity idIdx noopIdxs <- return $ primWorkInfo primInfo mapM_ (getTermArg termArgs >=> isNoop >=> Monad.guard) noopIdxs getTermArg termArgs idIdx + getTermArg :: [Term] -> Int -> MaybeT NormalizeSession Term getTermArg args i = do Monad.guard $ i <= length args - 1 return $ args !! 
i + isNoop :: Term -> MaybeT NormalizeSession Bool isNoop (Var i) = do - binding <- MaybeT $ lookupVarEnv i <$> Lens.use bindings - isRecursive <- lift $ isRecursiveBndr $ bindingId binding + bindingsV <- Lens.use bindings + binding <- MVar.withMVar bindingsV (MaybeT . pure . lookupVarEnv i) + isRecursive <- lift $ isRecursiveBndr (bindingId binding) + Monad.guard $ not isRecursive isNoop $ bindingTerm binding + isNoop (Prim PrimInfo{primWorkInfo=WorkIdentity _ []}) = return True isNoop (Lam x e) = isNoopApp x (collectArgs e) isNoop _ = return False @@ -498,7 +509,9 @@ inlineNonRepWorker e@(Case scrut altsTy alts) | (Var f, args,ticks) <- collectArgsTicks scrut , isGlobalId f = do - (cf,_) <- Lens.use curFun + curFunsV <- Lens.use curFun + thread <- myThreadId + Just (cf,_) <- MVar.withMVar curFunsV (pure . HashMapS.lookup thread) isInlined <- zoomExtra (alreadyInlined f cf) limit <- Lens.view inlineLimit tcm <- Lens.view tcCache @@ -511,7 +524,8 @@ inlineNonRepWorker e@(Case scrut altsTy alts) overLimit = notClassTy && (Maybe.fromMaybe 0 isInlined) > limit - bodyMaybe <- lookupVarEnv f <$> Lens.use bindings + bindingsV <- Lens.use bindings + bodyMaybe <- MVar.withMVar bindingsV (pure . lookupVarEnv f) nonRepScrut <- not <$> (representableType <$> Lens.view typeTranslator <*> Lens.view customReprs <*> pure False @@ -519,19 +533,24 @@ inlineNonRepWorker e@(Case scrut altsTy alts) <*> pure scrutTy) case (nonRepScrut, bodyMaybe) of (True, Just b) -> do - if overLimit then - trace ($(curLoc) ++ [I.i| - InlineNonRep: #{showPpr (varName f)} already inlined - #{limit} times in: #{showPpr (varName cf)}. The type of the subject - is: + if overLimit then do + ioLockV <- Lens.use ioLock + + MVar.withMVar ioLockV $ \() -> + traceM ($(curLoc) ++ [I.i| + InlineNonRep: #{showPpr (varName f)} already inlined + #{limit} times in: #{showPpr (varName cf)}. 
The type of the subject + is: - #{showPpr' def{displayTypes=True\} scrutTy} + #{showPpr' def{displayTypes=True\} scrutTy} - Function #{showPpr (varName cf)} will not reach a normal form and - compilation might fail. + Function #{showPpr (varName cf)} will not reach a normal form and + compilation might fail. - Run with '-fclash-inline-limit=N' to increase the inline limit to N. - |]) (return e) + Run with '-fclash-inline-limit=N' to increase the inline limit to N. + |]) + + return e else do Monad.when notClassTy (zoomExtra (addNewInline f cf)) @@ -596,9 +615,11 @@ inlineSmall _ e@(collectArgsTicks -> (Var f,args,ticks)) = do if untranslatable || f `elemVarSet` topEnts || lv then return e else do - bndrs <- Lens.use bindings sizeLimit <- Lens.view inlineFunctionLimit - case lookupVarEnv f bndrs of + bndrsV <- Lens.use bindings + mBind <- MVar.withMVar bndrsV (pure . lookupVarEnv f) + + case mBind of -- Don't inline recursive expressions Just b -> do isRecBndr <- isRecursiveBndr f @@ -631,8 +652,9 @@ inlineWorkFree _ e@(collectArgsTicks -> (Var f,args@(_:_),ticks)) if untranslatable || isSignal || argsHaveWork || lv || isTopEnt then return e else do - bndrs <- Lens.use bindings - case lookupVarEnv f bndrs of + bndrsV <- Lens.use bindings + bndr <- MVar.withMVar bndrsV (pure . lookupVarEnv f) + case bndr of -- Don't inline recursive expressions Just b -> do isRecBndr <- isRecursiveBndr f @@ -664,8 +686,9 @@ inlineWorkFree _ e@(Var f) = do let gv = isGlobalId f if closed && f `notElemVarSet` topEnts && not untranslatable && not isSignal && gv then do - bndrs <- Lens.use bindings - case lookupVarEnv f bndrs of + bndrsV <- Lens.use bindings + bndr <- MVar.withMVar bndrsV (pure . 
lookupVarEnv f) + case bndr of -- Don't inline recursive expressions Just top -> do isRecBndr <- isRecursiveBndr f @@ -682,7 +705,7 @@ inlineWorkFree _ e@(Var f) = do b <- normalizeTopLvlBndr False f top changed (bindingTerm b) _ -> return e - else return e + else return e inlineWorkFree _ e = return e {-# SCC inlineWorkFree #-} diff --git a/clash-lib/src/Clash/Normalize/Transformations/Letrec.hs b/clash-lib/src/Clash/Normalize/Transformations/Letrec.hs index 4b10c4e32b..25feba27aa 100644 --- a/clash-lib/src/Clash/Normalize/Transformations/Letrec.hs +++ b/clash-lib/src/Clash/Normalize/Transformations/Letrec.hs @@ -23,6 +23,8 @@ module Clash.Normalize.Transformations.Letrec , topLet ) where +import Control.Concurrent.Lifted (myThreadId) +import qualified Control.Concurrent.MVar.Lifted as MVar import qualified Control.Lens as Lens import qualified Control.Monad as Monad import Control.Monad.Trans.Except (runExcept) @@ -30,6 +32,7 @@ import Control.Monad.Writer (listen) import Data.Bifunctor (second) import qualified Data.Either as Either import qualified Data.HashMap.Lazy as HashMap +import qualified Data.HashMap.Strict as HashMapS import Data.List ((\\)) import qualified Data.List as List import qualified Data.List.Extra as List @@ -193,11 +196,15 @@ flattenLet (TransformContext is0 _) (Letrec binds body) = do emptyVarEnv (`unitVarEnv` (1 :: Int)) body (is2,binds1) <- second concat <$> List.mapAccumLM go is1 binds - bndrs <- Lens.use bindings + e1WorkFree <- case binds1 of - [(_,e1)] -> isWorkFree workFreeBinders bndrs e1 + [(_,e1)] -> do + bndrsV <- Lens.use bindings + MVar.withMVar bndrsV (\bndrs ->isWorkFree workFreeBinders bndrs e1) + _ -> pure (error "flattenLet: unreachable") + case binds1 of -- inline binders into the body when there's only a single binder, and only -- if that binder doesn't perform any work or is only used once in the body @@ -206,7 +213,7 @@ flattenLet (TransformContext is0 _) (Letrec binds body) = do -- Except when the binder is recursive! 
then return (Letrec binds1 body) else let subst = extendIdSubst (mkSubst is2) id1 e1 - in changed (substTm "flattenLet" subst body) + in changed (substTm "flattenLet" subst body) _ -> return (Letrec binds1 body) where go :: InScopeSet -> LetBinding -> NormalizeSession (InScopeSet,[LetBinding]) @@ -231,11 +238,15 @@ flattenLet (TransformContext is0 _) (Letrec binds body) = do emptyVarEnv (`unitVarEnv` (1 :: Int)) body2 (srcTicks,nmTicks) = partitionTicks ticks - bndrs <- Lens.use bindings + e2WorkFree <- case binds2 of - [(_,e2)] -> isWorkFree workFreeBinders bndrs e2 + [(_,e2)] -> do + bndrsV <- Lens.use bindings + MVar.withMVar bndrsV (\bndrs ->isWorkFree workFreeBinders bndrs e2) + _ -> pure (error "flattenLet: unreachable") + -- Distribute the name ticks of the let-expression over all the bindings (isN1,) . map (second (`mkTicks` nmTicks)) <$> case binds2 of -- inline binders into the body when there's only a single binder, and @@ -264,7 +275,9 @@ flattenLet _ e = return e -- found in the body of the top-level let-expression. recToLetRec :: HasCallStack => NormRewrite recToLetRec (TransformContext is0 []) e = do - (fn,_) <- Lens.use curFun + curFunsV <- Lens.use curFun + thread <- myThreadId + Just (fn,_) <- MVar.withMVar curFunsV (pure . 
HashMapS.lookup thread) tcm <- Lens.view tcCache case splitNormalized tcm e of Right (args,bndrs,res) -> do diff --git a/clash-lib/src/Clash/Normalize/Transformations/Specialize.hs b/clash-lib/src/Clash/Normalize/Transformations/Specialize.hs index d4465290d4..3dc8abc46b 100644 --- a/clash-lib/src/Clash/Normalize/Transformations/Specialize.hs +++ b/clash-lib/src/Clash/Normalize/Transformations/Specialize.hs @@ -28,17 +28,18 @@ module Clash.Normalize.Transformations.Specialize ) where import Control.Arrow ((***), (&&&)) -import Control.DeepSeq (deepseq) +import Control.Concurrent.Lifted (myThreadId) +import qualified Control.Concurrent.MVar.Lifted as MVar +import Control.DeepSeq (force) import Control.Exception (throw) -import Control.Lens ((%=)) import qualified Control.Lens as Lens import qualified Control.Monad as Monad -import Control.Monad.Extra (orM) import qualified Control.Monad.Writer as Writer (listen) import Data.Bifunctor (bimap) import Data.Coerce (coerce) import qualified Data.Either as Either import Data.Functor.Const (Const(..)) +import qualified Data.HashMap.Strict as HashMap import qualified Data.Map.Strict as Map import qualified Data.Monoid as Monoid (getAny) import qualified Data.Set.Ordered as OSet @@ -80,13 +81,13 @@ import Clash.Core.Var (Var(..), Id, TyVar, mkTyVar) import Clash.Core.VarEnv ( InScopeSet, extendInScopeSet, extendInScopeSetList, lookupVarEnv , mkInScopeSet, mkVarSet, unionInScope, elemVarSet) -import Clash.Debug (traceIf, traceM) +import Clash.Debug (traceM, traceWhen) import Clash.Driver.Types (Binding(..), TransformationInfo(..), hasTransformationInfo) import Clash.Netlist.Util (representableType) import Clash.Rewrite.Combinators (topdownR) import Clash.Rewrite.Types ( TransformContext(..), bindings, censor, curFun, customReprs, extra, tcCache - , typeTranslator, workFreeBinders, debugOpts, topEntities, specializationLimit) + , typeTranslator, workFreeBinders, debugOpts, topEntities, specializationLimit, ioLock) import 
Clash.Rewrite.Util ( mkBinderFor, mkDerivedName, mkFunction, mkTmBinderFor, setChanged, changed , normalizeTermTypes, normalizeId) @@ -203,8 +204,10 @@ appProp ctx@(TransformContext is _) = \case go is0 (Lam v e) (Left arg:args) ticks = do setChanged - bndrs <- Lens.use bindings - orM [pure (isVar arg), isWorkFree workFreeBinders bndrs arg] >>= \case + bndrsV <- Lens.use bindings + wf <- MVar.withMVar bndrsV (\bndrs -> isWorkFree workFreeBinders bndrs arg) + + case isVar arg || wf of True -> let subst = extendIdSubst (mkSubst is0) v arg in (`mkTicks` ticks) <$> go is0 (substTm "appProp.AppLam" subst e) args [] @@ -266,10 +269,12 @@ appProp ctx@(TransformContext is _) = \case goCaseArg isA0 ty0 ls0 (Left arg:args0) = do tcm <- Lens.view tcCache - bndrs <- Lens.use bindings let argTy = inferCoreTypeOf tcm arg ty1 = applyFunTy tcm ty0 argTy - orM [pure (isVar arg), isWorkFree workFreeBinders bndrs arg] >>= \case + bndrsV <- Lens.use bindings + wf <- MVar.withMVar bndrsV (\bndrs -> isWorkFree workFreeBinders bndrs arg) + + case isVar arg || wf of True -> do (ty2,ls1,args1) <- goCaseArg isA0 ty1 ls0 args0 return (ty2,ls1,Left arg:args1) @@ -355,24 +360,29 @@ specialize' specialize' (TransformContext is0 _) e (Var f, args, ticks) specArgIn = do opts <- Lens.view debugOpts tcm <- Lens.view tcCache + ioLockV <- Lens.use ioLock -- Don't specialize TopEntities topEnts <- Lens.view topEntities if f `elemVarSet` topEnts - then do + then case specArgIn of Left _ -> do - traceM ("Not specializing TopEntity: " ++ showPpr (varName f)) + MVar.withMVar ioLockV $ \() -> + traceM ("Not specializing TopEntity: " ++ showPpr (varName f)) + return e - Right tyArg -> - traceIf (hasTransformationInfo AppliedTerm opts) ("Dropping type application on TopEntity: " ++ showPpr (varName f) ++ "\ntype:\n" ++ showPpr tyArg) $ + Right tyArg -> do + MVar.withMVar ioLockV $ \() -> + traceWhen (hasTransformationInfo AppliedTerm opts) + ("Dropping type application on TopEntity: " ++ showPpr (varName f) 
++ "\ntype:\n" ++ showPpr tyArg) -- TopEntities aren't allowed to be semantically polymorphic. -- But using type equality constraints they may be syntactically polymorphic. -- > topEntity :: forall dom . (dom ~ "System") => Signal dom Bool -> Signal dom Bool -- The TyLam's in the body will have been removed by 'Clash.Normalize.Util.substWithTyEq'. -- So we drop the TyApp ("specializing" on it) and change the varType to match. let newVarTy = piResultTy tcm (coreTypeOf f) tyArg - in changed (mkApps (mkTicks (Var f{varType = newVarTy}) ticks) args) + changed (mkApps (mkTicks (Var f{varType = newVarTy}) ticks) args) else do -- NondecreasingIndentation let specArg = bimap (normalizeTermTypes tcm) (normalizeType tcm) specArgIn @@ -387,74 +397,81 @@ specialize' (TransformContext is0 _) e (Var f, args, ticks) specArgIn = do specAbs :: Either Term Type specAbs = either (Left . stripAllTicks . (`mkAbstraction` specBndrs)) (Right . id) specArg -- Determine if 'f' has already been specialized on (a type-normalized) 'specArg' - specM <- Map.lookup (f,argLen,specAbs) <$> Lens.use (extra.specialisationCache) - case specM of - -- Use previously specialized function - Just f' -> - traceIf (hasTransformationInfo AppliedTerm opts) - ("Using previous specialization of " ++ showPpr (varName f) ++ " on " ++ - (either showPpr showPpr) specAbs ++ ": " ++ showPpr (varName f')) $ - changed $ mkApps (mkTicks (Var f') ticks) (args ++ specVars) - -- Create new specialized function - Nothing -> do - -- Determine if we can specialize f - bodyMaybe <- fmap (lookupUniqMap (varName f)) $ Lens.use bindings - case bodyMaybe of - Just (Binding _ sp inl _ bodyTm _) -> do - -- Determine if we see a sequence of specializations on a growing argument - specHistM <- lookupUniqMap f <$> Lens.use (extra.specialisationHistory) - specLim <- Lens.view specializationLimit - if maybe False (> specLim) specHistM - then throw (ClashException - sp - (unlines [ "Hit specialization limit " ++ show specLim ++ " on 
function `" ++ showPpr (varName f) ++ "'.\n" - , "The function `" ++ showPpr f ++ "' is most likely recursive, and looks like it is being indefinitely specialized on a growing argument.\n" - , "Body of `" ++ showPpr f ++ "':\n" ++ showPpr bodyTm ++ "\n" - , "Argument (in position: " ++ show argLen ++ ") that triggered termination:\n" ++ (either showPpr showPpr) specArg - , "Run with '-fclash-spec-limit=N' to increase the specialization limit to N." - ]) - Nothing) - else do - let existingNames = collectBndrsMinusApps bodyTm - newNames = [ mkUnsafeInternalName ("pTS" `Text.append` Text.pack (show n)) n - | n <- [(0::Int)..] - ] - -- Make new binders for existing arguments - (boundArgs,argVars) <- fmap (unzip . map (either (Left &&& Left . Var) (Right &&& Right . VarTy))) $ - Monad.zipWithM - (mkBinderFor is0 tcm) - (existingNames ++ newNames) - args - -- Determine name the resulting specialized function, and the - -- form of the specialized-on argument - (fId,inl',specArg') <- case specArg of - Left a@(collectArgsTicks -> (Var g,gArgs,_gTicks)) -> if isPolyFun tcm a - then do - -- In case we are specialising on an argument that is a - -- global function then we use that function's name as the - -- name of the specialized higher-order function. - -- Additionally, we will return the body of the global - -- function, instead of a variable reference to the - -- global function. - -- - -- This will turn things like @mealy g k@ into a new - -- binding @g'@ where both the body of @mealy@ and @g@ - -- are inlined, meaning the state-transition-function - -- and the memory element will be in a single function. - gTmM <- fmap (lookupUniqMap (varName g)) $ Lens.use bindings - return (g,maybe inl bindingSpec gTmM, maybe specArg (Left . (`mkApps` gArgs) . 
bindingTerm) gTmM) - else return (f,inl,specArg) - _ -> return (f,inl,specArg) - -- Create specialized functions - let newBody = mkAbstraction (mkApps bodyTm (argVars ++ [specArg'])) (boundArgs ++ specBndrs) - newf <- mkFunction (varName fId) sp inl' newBody - -- Remember specialization - (extra.specialisationHistory) %= extendUniqMapWith f 1 (+) - (extra.specialisationCache) %= Map.insert (f,argLen,specAbs) newf - -- use specialized function - let newExpr = mkApps (mkTicks (Var newf) ticks) (args ++ specVars) - newf `deepseq` changed newExpr - Nothing -> return e + specCacheV <- Lens.use (extra.specialisationCache) + + MVar.modifyMVar specCacheV $ \specCache -> + case Map.lookup (f, argLen, specAbs) specCache of + -- Use previously specialized function + Just f' -> do + MVar.withMVar ioLockV $ \() -> + traceWhen (hasTransformationInfo AppliedTerm opts) + ("Using previous specialization of " ++ showPpr (varName f) ++ " on " ++ + (either showPpr showPpr) specAbs ++ ": " ++ showPpr (varName f')) + + changed (specCache, mkApps (mkTicks (Var f') ticks) (args ++ specVars)) + -- Create new specialized function + Nothing -> do + -- Determine if we can specialize f + bndrsV <- Lens.use bindings + bodyMaybe <- MVar.withMVar bndrsV (pure . 
lookupUniqMap (varName f)) + case bodyMaybe of + Just (Binding _ sp inl _ bodyTm _) -> do + -- Determine if we see a sequence of specializations on a growing argument + specHistMV <- Lens.use (extra.specialisationHistory) + specHist <- MVar.takeMVar specHistMV + let specHistM = lookupUniqMap f specHist + specLim <- Lens.view specializationLimit + if maybe False (> specLim) specHistM + then throw (ClashException + sp + (unlines [ "Hit specialization limit " ++ show specLim ++ " on function `" ++ showPpr (varName f) ++ "'.\n" + , "The function `" ++ showPpr f ++ "' is most likely recursive, and looks like it is being indefinitely specialized on a growing argument.\n" + , "Body of `" ++ showPpr f ++ "':\n" ++ showPpr bodyTm ++ "\n" + , "Argument (in position: " ++ show argLen ++ ") that triggered termination:\n" ++ (either showPpr showPpr) specArg + , "Run with '-fclash-spec-limit=N' to increase the specialization limit to N." + ]) + Nothing) + else do + let existingNames = collectBndrsMinusApps bodyTm + newNames = [ mkUnsafeInternalName ("pTS" `Text.append` Text.pack (show n)) n + | n <- [(0::Int)..] + ] + -- Make new binders for existing arguments + (boundArgs,argVars) <- fmap (unzip . map (either (Left &&& Left . Var) (Right &&& Right . VarTy))) $ + Monad.zipWithM + (mkBinderFor is0 tcm) + (existingNames ++ newNames) + args + -- Determine name the resulting specialized function, and the + -- form of the specialized-on argument + (fId,inl',specArg') <- case specArg of + Left a@(collectArgsTicks -> (Var g,gArgs,_gTicks)) -> if isPolyFun tcm a + then do + -- In case we are specialising on an argument that is a + -- global function then we use that function's name as the + -- name of the specialized higher-order function. + -- Additionally, we will return the body of the global + -- function, instead of a variable reference to the + -- global function. 
+ -- + -- This will turn things like @mealy g k@ into a new + -- binding @g'@ where both the body of @mealy@ and @g@ + -- are inlined, meaning the state-transition-function + -- and the memory element will be in a single function. + gTmM <- MVar.withMVar bndrsV (pure . lookupUniqMap (varName g)) + return (g,maybe inl bindingSpec gTmM, maybe specArg (Left . (`mkApps` gArgs) . bindingTerm) gTmM) + else return (f,inl,specArg) + _ -> return (f,inl,specArg) + -- Create specialized functions + let newBody = mkAbstraction (mkApps bodyTm (argVars ++ [specArg'])) (boundArgs ++ specBndrs) + newf <- force <$> mkFunction (varName fId) sp inl' newBody + -- Remember specialization + MVar.putMVar specHistMV (extendUniqMapWith f 1 (+) specHist) + -- use specialized function + let newCache = Map.insert (f, argLen, specAbs) newf specCache + let newExpr = mkApps (mkTicks (Var newf) ticks) (args ++ specVars) + changed (newCache, newExpr) + Nothing -> return (specCache, e) where collectBndrsMinusApps :: Term -> [Name a] collectBndrsMinusApps = reverse . go [] @@ -477,10 +494,14 @@ specialize' _ctx _ (appE,args,ticks) (Left specArg) = do newBody = mkAbstraction specArg specBndrs -- See if there's an existing binder that's alpha-equivalent to the -- specialized function - existing <- filterUniqMap ((`aeqTerm` newBody) . bindingTerm) <$> Lens.use bindings + bndrsV <- Lens.use bindings + existing <- MVar.withMVar bndrsV $ \bndrs -> + pure $ filterUniqMap ((`aeqTerm` newBody) . bindingTerm) bndrs -- Create a new function if an alpha-equivalent binder doesn't exist newf <- case eltsUniqMap existing of - [] -> do (cf,sp) <- Lens.use curFun + [] -> do curFunsV <- Lens.use curFun + thread <- myThreadId + Just (cf,sp) <- MVar.withMVar curFunsV (pure . 
HashMap.lookup thread) mkFunction (appendToName (varName cf) "_specF") sp NoUserInline newBody (b:_) -> return (bindingId b) -- Create specialized argument @@ -571,7 +592,8 @@ nonRepSpec ctx e@(App e1 e2) inlineInternalSpecialisationArgument app | (Var f,fArgs,ticks) <- collectArgsTicks app = do - fTmM <- lookupVarEnv f <$> Lens.use bindings + bndrsV <- Lens.use bindings + fTmM <- MVar.withMVar bndrsV (pure . lookupVarEnv f) case fTmM of Just b | nameSort (varName (bindingId b)) == Internal diff --git a/clash-lib/src/Clash/Normalize/Types.hs b/clash-lib/src/Clash/Normalize/Types.hs index ea25497559..839702a546 100644 --- a/clash-lib/src/Clash/Normalize/Types.hs +++ b/clash-lib/src/Clash/Normalize/Types.hs @@ -12,8 +12,9 @@ module Clash.Normalize.Types where +import Control.Concurrent.MVar (MVar) import qualified Control.Lens as Lens -import Control.Monad.State.Strict (State) +import Control.Monad.State.Strict (StateT) import Data.Map (Map) import Data.Set (Set) import Data.Text (Text) @@ -22,31 +23,31 @@ import Clash.Core.Term (Term) import Clash.Core.Type (Type) import Clash.Core.Var (Id) import Clash.Core.VarEnv (VarEnv) -import Clash.Driver.Types (BindingMap) +import Clash.Driver.Types (Binding) import Clash.Rewrite.Types (Rewrite, RewriteMonad) -- | State of the 'NormalizeMonad' data NormalizeState = NormalizeState - { _normalized :: BindingMap + { _normalized :: MVar (VarEnv (MVar (Binding Term))) -- ^ Global binders - , _specialisationCache :: Map (Id,Int,Either Term Type) Id + , _specialisationCache :: MVar (Map (Id,Int,Either Term Type) Id) -- ^ Cache of previously specialized functions: -- -- * Key: (name of the original function, argument position, specialized term/type) -- -- * Elem: (name of specialized function,type of specialized function) - , _specialisationHistory :: VarEnv Int + , _specialisationHistory :: MVar (VarEnv Int) -- ^ Cache of how many times a function was specialized - , _inlineHistory :: VarEnv (VarEnv Int) + , _inlineHistory :: MVar 
(VarEnv (VarEnv Int)) -- ^ Cache of function where inlining took place: -- -- * Key: function where inlining took place -- -- * Elem: (functions which were inlined, number of times inlined) - , _primitiveArgs :: Map Text (Set Int) + , _primitiveArgs :: MVar (Map Text (Set Int)) -- ^ Cache for looking up constantness of blackbox arguments - , _recursiveComponents :: VarEnv Bool + , _recursiveComponents :: MVar (VarEnv Bool) -- ^ Map telling whether a components is recursively defined. -- -- NB: there are only no mutually-recursive component, only self-recursive @@ -56,7 +57,7 @@ data NormalizeState Lens.makeLenses ''NormalizeState -- | State monad that stores specialisation and inlining information -type NormalizeMonad = State NormalizeState +type NormalizeMonad = StateT NormalizeState IO -- | RewriteSession with extra Normalisation information type NormalizeSession = RewriteMonad NormalizeState diff --git a/clash-lib/src/Clash/Normalize/Util.hs b/clash-lib/src/Clash/Normalize/Util.hs index 2d5109d5f2..d1b7f5a9a8 100644 --- a/clash-lib/src/Clash/Normalize/Util.hs +++ b/clash-lib/src/Clash/Normalize/Util.hs @@ -10,7 +10,7 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RecordWildCards #-} -{-# LANGUAGE TemplateHaskellQuotes #-} +{-# LANGUAGE TemplateHaskell #-} module Clash.Normalize.Util ( ConstantSpecInfo(..) 
@@ -33,13 +33,16 @@ module Clash.Normalize.Util ) where -import Control.Lens ((&),(+~),(%=),(.=)) +import Control.Concurrent.Lifted (myThreadId) +import qualified Control.Concurrent.MVar.Lifted as MVar +import Control.Lens ((&),(+~)) import qualified Control.Lens as Lens import Data.Bifunctor (bimap) import Data.Either (lefts,rights) import qualified Data.List as List import qualified Data.List.Extra as List import qualified Data.Map as Map +import Data.Maybe (fromMaybe) import qualified Data.HashMap.Strict as HashMapS import qualified Data.HashSet as HashSet import Data.Text (Text) @@ -74,7 +77,7 @@ import Clash.Core.Var (Id, TyVar, Var (..), isGlobalId) import Clash.Core.VarEnv (VarEnv, emptyInScopeSet, emptyVarEnv, extendVarEnv, extendVarEnvWith, lookupVarEnv, unionVarEnvWith, unitVarEnv, extendInScopeSetList, mkInScopeSet, mkVarSet) -import Clash.Debug (traceIf) +import Clash.Debug (traceWhen) import Clash.Driver.Types (BindingMap, Binding(..), TransformationInfo(FinalTerm), hasTransformationInfo) import Clash.Normalize.Primitives (removedArg) @@ -83,11 +86,11 @@ import Clash.Normalize.Types import Clash.Primitives.Util (constantArgs) import Clash.Rewrite.Types (RewriteMonad, TransformContext(..), bindings, curFun, debugOpts, extra, - tcCache, primitives) + tcCache, primitives, ioLock) import Clash.Rewrite.Util (runRewrite, mkTmBinderFor, mkDerivedName) import Clash.Unique -import Clash.Util (SrcSpan, makeCachedU) +import Clash.Util (SrcSpan, curLoc, noSrcSpan) -- | Determine if argument should reduce to a constant given a primitive and -- an argument number. Caches results. @@ -102,23 +105,26 @@ isConstantArg -- blackbox. 
isConstantArg "Clash.Explicit.SimIO.mealyIO" i = pure (i == 2 || i == 3) isConstantArg nm i = do - argMap <- Lens.use (extra.primitiveArgs) - case Map.lookup nm argMap of - Nothing -> do - -- Constant args not yet calculated, or primitive does not exist - prims <- Lens.view primitives - case extractPrim =<< HashMapS.lookup nm prims of - Nothing -> - -- Primitive does not exist: - pure False - Just p -> do - -- Calculate constant arguments: - let m = constantArgs nm p - (extra.primitiveArgs) Lens.%= Map.insert nm m - pure (i `elem` m) - Just m -> - -- Cached version found - pure (i `elem` m) + argMapV <- Lens.use (extra.primitiveArgs) + + MVar.modifyMVar argMapV $ \argMap -> + case Map.lookup nm argMap of + Nothing -> do + prims <- Lens.view primitives + -- Constant args not yet calculated, or primitive does not exist + case extractPrim =<< HashMapS.lookup nm prims of + Nothing -> + -- Primitive does not exist: + pure (argMap, False) + + Just p -> + -- Calculate constant arguments: + let m = constantArgs nm p + in pure (Map.insert nm m argMap, i `elem` m) + + Just m -> + -- Cached version found + pure (argMap, i `elem` m) -- | Given a list of transformation contexts, determine if any of the contexts -- indicates that the current arg is to be reduced to a constant / literal. 
@@ -139,10 +145,12 @@ alreadyInlined -- ^ Function in which we want to perform the inlining -> NormalizeMonad (Maybe Int) alreadyInlined f cf = do - inlinedHM <- Lens.use inlineHistory - case lookupVarEnv cf inlinedHM of - Nothing -> return Nothing - Just inlined' -> return (lookupVarEnv f inlined') + inlinedHMV <- Lens.use inlineHistory + + MVar.withMVar inlinedHMV $ \inlinedHM -> + case lookupVarEnv cf inlinedHM of + Nothing -> return Nothing + Just inlined' -> return (lookupVarEnv f inlined') -- | Record a new inlining in the `inlineHistory` addNewInline @@ -151,11 +159,11 @@ addNewInline -> Id -- ^ Function in which we're inlining it -> NormalizeMonad () -addNewInline f cf = - inlineHistory %= extendVarEnvWith - cf - (unitVarEnv f 1) - (\_ hm -> extendVarEnvWith f 1 (+) hm) +addNewInline f cf = do + inlineHistV <- Lens.use inlineHistory + + MVar.modifyMVar_ inlineHistV $ + pure . extendVarEnvWith cf (unitVarEnv f 1) (\_ hm -> extendVarEnvWith f 1 (+) hm) -- | Test whether a given term represents a non-recursive global variable isNonRecursiveGlobalVar @@ -172,20 +180,23 @@ isRecursiveBndr :: Id -> NormalizeSession Bool isRecursiveBndr f = do - cg <- Lens.use (extra.recursiveComponents) - case lookupVarEnv f cg of - Just isR -> return isR - Nothing -> do - fBodyM <- lookupVarEnv f <$> Lens.use bindings - case fBodyM of - Nothing -> return False - Just b -> do - -- There are no global mutually-recursive functions, only self-recursive - -- ones, so checking whether 'f' is part of the free variables of the - -- body of 'f' is sufficient. - let isR = f `globalIdOccursIn` bindingTerm b - (extra.recursiveComponents) %= extendVarEnv f isR - return isR + cgV <- Lens.use (extra.recursiveComponents) + + MVar.modifyMVar cgV $ \cg -> + case lookupVarEnv f cg of + Just isR -> pure (cg, isR) + Nothing -> do + bindingsV <- Lens.use bindings + mBind <- MVar.withMVar bindingsV (pure . 
lookupVarEnv f) + + case mBind of + Nothing -> pure (cg, False) + Just b -> + -- There are no global mutually-recursive functions, only self-recursive + -- ones, so checking whether 'f' is part of the free variables of the + -- body of 'f' is sufficient. + let isR = f `globalIdOccursIn` bindingTerm b + in pure (extendVarEnv f isR cg, isR) data ConstantSpecInfo = ConstantSpecInfo @@ -323,7 +334,9 @@ constantSpecInfo ctx e = do pure (constantCsr e) (var@(Var f), args, ticks) -> do - (curF, _) <- Lens.use curFun + curFunsV <- Lens.use curFun + thread <- myThreadId + Just (curF, _) <- MVar.withMVar curFunsV (pure . HashMapS.lookup thread) isNonRecGlobVar <- isNonRecursiveGlobalVar e if isNonRecGlobVar && f /= curF then do csr <- mergeCsrs ctx ticks e (mkApps var) args @@ -408,22 +421,44 @@ normalizeTopLvlBndr -> Id -> Binding Term -> NormalizeSession (Binding Term) -normalizeTopLvlBndr isTop nm (Binding nm' sp inl pr tm _) = makeCachedU nm (extra.normalized) $ do - tcm <- Lens.view tcCache - let nmS = showPpr (varName nm) - -- We deshadow the term because sometimes GHC gives us - -- code where a local binder has the same unique as a - -- global binder, sometimes causing the inliner to go - -- into a loop. Deshadowing freshens all the bindings - -- to avoid this. - let tm1 = deShadowTerm emptyInScopeSet tm - tm2 = if isTop then substWithTyEq tm1 else tm1 - old <- Lens.use curFun - tm3 <- rewriteExpr ("normalization",normalization) (nmS,tm2) (nm',sp) - curFun .= old - let ty' = inferCoreTypeOf tcm tm3 - let r' = nm' `globalIdOccursIn` tm3 - return (Binding nm'{varType = ty'} sp inl pr tm3 r') +normalizeTopLvlBndr isTop nm (Binding nm' sp inl pr tm _) = do + normalizedV <- Lens.use (extra.normalized) + + -- TODO This was a call to makeCachedU, but since there was no variation + -- for MVar, I unrolled everything. Maybe there should be MVar versions of + -- the makeCachedX functions needed in normalization. 
+ + cache <- MVar.takeMVar normalizedV + case lookupVarEnv nm cache of + Just vMVar -> do + MVar.putMVar normalizedV cache + MVar.readMVar vMVar + Nothing -> do + tmp <- MVar.newEmptyMVar + MVar.putMVar normalizedV (extendVarEnv nm tmp cache) + + tcm <- Lens.view tcCache + let nmS = showPpr (varName nm) + -- We deshadow the term because sometimes GHC gives us + -- code where a local binder has the same unique as a + -- global binder, sometimes causing the inliner to go + -- into a loop. Deshadowing freshens all the bindings + -- to avoid this. + let tm1 = deShadowTerm emptyInScopeSet tm + tm2 = if isTop then substWithTyEq tm1 else tm1 + -- TODO Should tm3 be done async / added to the job queue when it's made? + curFunsV <- Lens.use curFun + thread <- myThreadId + old <- MVar.withMVar curFunsV (pure . HashMapS.lookup thread) + tm3 <- rewriteExpr ("normalization",normalization) (nmS,tm2) (nm',sp) + MVar.modifyMVar_ curFunsV $ + pure . HashMapS.insert thread (fromMaybe (error $ $(curLoc) ++ "Report as bug: no curFun", noSrcSpan) old) + let ty' = inferCoreTypeOf tcm tm3 + let r' = nm' `globalIdOccursIn` tm3 + let value = Binding nm'{varType = ty'} sp inl pr tm3 r' + + MVar.putMVar tmp value + pure value -- | Turn type equality constraints into substitutions and apply them. -- @@ -494,17 +529,23 @@ rewriteExpr :: (String,NormRewrite) -- ^ Transformation to apply -> (Id, SrcSpan) -- ^ Renew current function being rewritten -> NormalizeSession Term rewriteExpr (nrwS,nrw) (bndrS,expr) (nm, sp) = do - curFun .= (nm, sp) + curFunsV <- Lens.use curFun + thread <- myThreadId + MVar.modifyMVar_ curFunsV (pure . 
HashMapS.insert thread (nm, sp)) opts <- Lens.view debugOpts - let before = showPpr expr - let expr' = traceIf (hasTransformationInfo FinalTerm opts) - (bndrS ++ " before " ++ nrwS ++ ":\n\n" ++ before ++ "\n") - expr - rewritten <- runRewrite nrwS emptyInScopeSet nrw expr' - let after = showPpr rewritten - traceIf (hasTransformationInfo FinalTerm opts) - (bndrS ++ " after " ++ nrwS ++ ":\n\n" ++ after ++ "\n") $ - return rewritten + ioLockV <- Lens.use ioLock + + MVar.withMVar ioLockV $ \() -> + traceWhen (hasTransformationInfo FinalTerm opts) + (bndrS ++ " before " ++ nrwS ++ ":\n\n" ++ showPpr expr ++ "\n") + + rewritten <- runRewrite nrwS emptyInScopeSet nrw expr + + MVar.withMVar ioLockV $ \() -> + traceWhen (hasTransformationInfo FinalTerm opts) + (bndrS ++ " after " ++ nrwS ++ ":\n\n" ++ showPpr rewritten ++ "\n") + + return rewritten -- | A tick to prefix an inlined expression with it's original name. -- For example, given diff --git a/clash-lib/src/Clash/Rewrite/Types.hs b/clash-lib/src/Clash/Rewrite/Types.hs index da4ddfd28d..7221bc64c6 100644 --- a/clash-lib/src/Clash/Rewrite/Types.hs +++ b/clash-lib/src/Clash/Rewrite/Types.hs @@ -17,22 +17,40 @@ {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeFamilies #-} + +#if MIN_VERSION_transformers(0,5,6) +{-# LANGUAGE UndecidableInstances #-} + +{-# OPTIONS_GHC -Wno-orphans #-} +#endif module Clash.Rewrite.Types where +import Control.Applicative (Alternative) +import Control.Concurrent (MVar, ThreadId) import Control.Concurrent.Supply (Supply, freshId) import Control.DeepSeq (NFData) import Control.Lens (Lens', use, (.=)) import qualified Control.Lens as Lens +import Control.Monad.Base +#if !MIN_VERSION_base(4,13,0) +import Control.Monad.Fail (MonadFail) +#endif import Control.Monad.Fix (MonadFix) +import Control.Monad.IO.Class (MonadIO) import Control.Monad.State.Strict (State) #if MIN_VERSION_transformers(0,5,6) import Control.Monad.Reader 
(MonadReader (..)) import Control.Monad.State (MonadState (..)) +import Control.Monad.Trans.Control + ( ComposeSt, MonadBaseControl(..), MonadTransControl(..) + , defaultLiftBaseWith, defaultRestoreM) import Control.Monad.Trans.RWS.CPS (RWST) import qualified Control.Monad.Trans.RWS.CPS as RWS import Control.Monad.Writer (MonadWriter (..)) #else +import Control.Monad.Trans.Control (MonadBaseControl(..)) import Control.Monad.Trans.RWS.Strict (RWST) import qualified Control.Monad.Trans.RWS.Strict as RWS #endif @@ -74,27 +92,40 @@ data RewriteStep -- ^ Term after `apply` } deriving (Show, Generic, NFData, Binary) +{- +Note [strictness in RewriteState] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Prior to concurrent normalization, the _bindings and _nameCounter +all had strictness marked in the fields. However, since they are now MVar, it +is not the field itself that needs to be strict but the contents of the MVar. +When these are updated in rewriting, it is necessary to use `seq` or bang +patterns to ensure that they are always forced to WHNF. + +Since the transform count was replaced in its entirety with the map of +counters, operations on the map are always forced completely with `deepseq`. +This prevents thunks being built up on map updates, since counting the number +of transformations applied is common when debugging. +-} + -- | State of a rewriting session data RewriteState extra = RewriteState - -- TODO Given we now keep transformCounters, this should just be 'fold' - -- over that map, otherwise the two counts could fall out of sync. 
- { _transformCounter :: {-# UNPACK #-} !Word - -- ^ Total number of applied transformations - , _transformCounters :: HashMap Text Word + { _transformCounters :: MVar (HashMap Text Word) -- ^ Map that tracks how many times each transformation is applied - , _bindings :: !BindingMap + , _bindings :: MVar BindingMap -- ^ Global binders , _uniqSupply :: !Supply -- ^ Supply of unique numbers - , _curFun :: (Id,SrcSpan) -- Initially set to undefined: no strictness annotation - -- ^ Function which is currently normalized - , _nameCounter :: {-# UNPACK #-} !Int + , _curFun :: MVar (HashMap ThreadId (Id,SrcSpan)) + -- ^ Function which is currently normalized for each thread + , _nameCounter :: MVar Int -- ^ Used for 'Fresh' - , _globalHeap :: PrimHeap + , _globalHeap :: MVar PrimHeap -- ^ Used as a heap for compile-time evaluation of primitives that live in I/O - , _workFreeBinders :: VarEnv Bool + , _workFreeBinders :: MVar (VarEnv Bool) -- ^ Map telling whether a binder's definition is work-free + , _ioLock :: MVar () + -- ^ Synchronization for logging to stdout , _extra :: !extra -- ^ Additional state } @@ -169,10 +200,15 @@ normalizeUltra = clashEnv . Lens.to (opt_ultra . envOpts) newtype RewriteMonad extra a = R { unR :: RWST RewriteEnv Any (RewriteState extra) IO a } deriving newtype - ( Applicative + ( Alternative + , Applicative , Functor , Monad + , MonadBase IO + , MonadBaseControl IO + , MonadFail , MonadFix + , MonadIO ) -- | Run the computation in the RewriteMonad @@ -214,6 +250,35 @@ instance MonadReader RewriteEnv (RewriteMonad extra) where {-# INLINE reader #-} #endif +#if MIN_VERSION_transformers(0,5,6) && !MIN_VERSION_transformers_base(0,4,6) +instance (Monoid w, MonadBase b m) => MonadBase b (RWST r w s m) where + liftBase = liftBaseDefault + {-# INLINE liftBase #-} +#endif + +#if MIN_VERSION_transformers(0,5,6) +-- For Control.Monad.Trans.RWS.Strict these are already defined, however +-- the CPS version of RWS is not included in `monad-control` yet. 
+ +instance (Monoid w) => MonadTransControl (RWST r w s) where + type StT (RWST r w s) a = (a, s, w) + + liftWith f = RWS.rwsT $ \r s -> + fmap (\x -> (x, s, mempty)) (f (\t -> RWS.runRWST t r s)) + {-# INLINE liftWith #-} + + restoreT m = RWS.rwsT $ \_ _ -> m + {-# INLINE restoreT #-} + +instance (Monoid w, MonadBaseControl b m) => MonadBaseControl b (RWST r w s m) where + type StM (RWST r w s m) a = ComposeSt (RWST r w s) m a + + liftBaseWith = defaultLiftBaseWith + {-# INLINE liftBaseWith #-} + restoreM = defaultRestoreM + {-# INLINE restoreM #-} +#endif + instance MonadUnique (RewriteMonad extra) where getUniqueM = do sup <- use uniqSupply @@ -239,7 +304,7 @@ type Rewrite extra = Transform (RewriteMonad extra) -- Moved into Clash.Rewrite.WorkFree {-# SPECIALIZE isWorkFree - :: Lens' (RewriteState extra) (VarEnv Bool) + :: Lens' (RewriteState extra) (MVar (VarEnv Bool)) -> BindingMap -> Term -> RewriteMonad extra Bool diff --git a/clash-lib/src/Clash/Rewrite/Util.hs b/clash-lib/src/Clash/Rewrite/Util.hs index 93630450a7..d9635cdf52 100644 --- a/clash-lib/src/Clash/Rewrite/Util.hs +++ b/clash-lib/src/Clash/Rewrite/Util.hs @@ -24,12 +24,15 @@ module Clash.Rewrite.Util , module Clash.Rewrite.WorkFree ) where +import Control.Concurrent.Lifted (myThreadId) +import qualified Control.Concurrent.MVar.Lifted as MVar import Control.Concurrent.Supply (splitSupply) import Control.DeepSeq import Control.Exception (throw) -import Control.Lens ((%=), (+=), (^.)) +import Control.Lens ((^.)) import qualified Control.Lens as Lens import qualified Control.Monad as Monad +import qualified Control.Monad.IO.Class as Monad import qualified Control.Monad.State.Strict as State #if MIN_VERSION_transformers(0,5,6) import qualified Control.Monad.Trans.RWS.CPS as RWS @@ -51,7 +54,6 @@ import qualified Data.Set as Set import qualified Data.Set.Lens as Lens import Data.Text (Text) import qualified Data.Text as Text -import System.IO.Unsafe (unsafePerformIO) import Data.Binary (encode) import 
qualified Data.ByteString as BS import qualified Data.ByteString.Lazy as BL @@ -94,10 +96,10 @@ import Clash.Util.Eq (fastEqBy) import qualified Clash.Util.Interpolate as I -- | Lift an action working in the '_extra' state to the 'RewriteMonad' -zoomExtra :: State.State extra a -> RewriteMonad extra a -zoomExtra m = R . RWS.rwsT $ \_ s -> - let (a, st') = State.runState m (_extra s) - in pure (a, s { _extra = st' }, mempty) +zoomExtra :: State.StateT extra IO a -> RewriteMonad extra a +zoomExtra m = R . RWS.rwsT $ \_ s -> do + (a, st') <- State.runStateT m (_extra s) + pure (a, s { _extra = st' }, mempty) -- | Some transformations might erroneously introduce shadowing. For example, -- a transformation might result in: @@ -147,30 +149,50 @@ apply -> Rewrite extra apply = \s rewrite ctx expr0 -> do opts <- Lens.view debugOpts - traceIf (hasDebugInfo TryName s opts) ("Trying: " <> s) (pure ()) + ioLockV <- Lens.use ioLock + + MVar.withMVar ioLockV $ \() -> + traceWhen (hasDebugInfo TryName s opts) ("Trying: " <> s) (!expr1,anyChanged) <- Writer.listen (rewrite ctx expr0) let hasChanged = Monoid.getAny anyChanged - Monad.when hasChanged (transformCounter += 1) + + Monad.when hasChanged $ do + countersV <- Lens.use transformCounters + MVar.modifyMVar_ countersV (pure . force . 
HashMap.insertWith (const succ) (Text.pack s) 1) -- NB: When -fclash-debug-history is on, emit binary data holding the recorded rewrite steps let rewriteHistFile = dbg_historyFile opts Monad.when (isJust rewriteHistFile && hasChanged) $ do - (curBndr, _) <- Lens.use curFun - let !_ = unsafePerformIO - $ BS.appendFile (fromJust rewriteHistFile) - $ BL.toStrict - $ encode RewriteStep - { t_ctx = tfContext ctx - , t_name = s - , t_bndrS = showPpr (varName curBndr) - , t_before = expr0 - , t_after = expr1 - } - return () + thread <- myThreadId + curFunsV <- Lens.use curFun + + MVar.withMVar curFunsV $ \curFuns -> + case fst <$> HashMap.lookup thread curFuns of + Just curBndr -> + -- TODO Although we're locking access to the history file, entries + -- may still be written to it interleaved by entity. I'm not sure if + -- clash-term can handle this correctly... + MVar.withMVar ioLockV $ \() -> + Monad.liftIO + . BS.appendFile (fromJust rewriteHistFile) + . BL.toStrict + $ encode RewriteStep + { t_ctx = tfContext ctx + , t_name = s + , t_bndrS = showPpr (varName curBndr) + , t_before = expr0 + , t_after = expr1 + } + + Nothing -> + error "apply: Normalizing from an unknown thread" if isDebugging opts - then applyDebug s expr0 hasChanged expr1 + then do + countersV <- Lens.use transformCounters + nTrans <- sum <$> MVar.readMVar countersV + applyDebug s expr0 hasChanged expr1 nTrans else return expr1 {-# INLINE apply #-} @@ -183,26 +205,26 @@ applyDebug -- ^ Whether the rewrite indicated change -> Term -- ^ New expression + -> Word -> RewriteMonad extra Term -applyDebug name exprOld hasChanged exprNew = do - nTrans <- Lens.use transformCounter +applyDebug name exprOld hasChanged exprNew nTrans = do opts <- Lens.view debugOpts let from = fromMaybe 0 (dbg_transformationsFrom opts) let limit = fromMaybe maxBound (dbg_transformationsLimit opts) - if | nTrans - from > limit -> + if | nTrans - from > limit -> do error "-fclash-debug-transformations-limit exceeded" - | nTrans <= 
from -> + | nTrans <= from -> do pure exprNew | otherwise -> - go opts + go (pred nTrans) opts where - go opts = traceIf (hasDebugInfo TryTerm name opts) ("Tried: " ++ name ++ " on:\n" ++ before) $ do - nTrans <- pred <$> Lens.use transformCounter + go nTrans' opts = do + ioLockV <- Lens.use ioLock - Monad.when (dbg_countTransformations opts && hasChanged) $ do - transformCounters %= HashMap.insertWith (const succ) (Text.pack name) 1 + MVar.withMVar ioLockV $ \() -> + traceWhen (hasDebugInfo TryTerm name opts) ("Tried: " ++ name ++ " on:\n" ++ before) Monad.when (dbg_invariants opts && hasChanged) $ do tcm <- Lens.view tcCache @@ -253,12 +275,14 @@ applyDebug name exprOld hasChanged exprNew = do error $ $(curLoc) ++ "Expression changed without notice(" ++ name ++ "): before" ++ before ++ "\nafter:\n" ++ after - traceIf (hasDebugInfo AppliedName name opts && hasChanged) (name <> " {" <> show nTrans <> "}") $ - traceIf (hasDebugInfo AppliedTerm name opts && hasChanged) ("Changes when applying rewrite to:\n" - ++ before ++ "\nResult:\n" ++ after ++ "\n") $ - traceIf (hasDebugInfo TryTerm name opts && not hasChanged) ("No changes when applying rewrite " - ++ name ++ " to:\n" ++ after ++ "\n") $ - return exprNew + MVar.withMVar ioLockV $ \() -> do + traceWhen (hasDebugInfo AppliedName name opts && hasChanged) (name <> " {" <> show nTrans' <> "}") + traceWhen (hasDebugInfo AppliedTerm name opts && hasChanged) + ("Changes when applying rewrite to:\n" ++ before ++ "\nResult:\n" ++ after ++ "\n") + traceWhen (hasDebugInfo TryTerm name opts && not hasChanged) + ("No changes when applying rewrite " ++ name ++ " to:\n" ++ after ++ "\n") + + return exprNew where before = showPpr exprOld after = showPpr exprNew @@ -282,11 +306,14 @@ runRewriteSession :: RewriteEnv -> IO a runRewriteSession r s m = do (a, s', _) <- runR m r s - traceIf (dbg_countTransformations (opt_debug (envOpts (_clashEnv r)))) - ("Clash: Transformations:\n" ++ Text.unpack (showCounters (s' ^. 
transformCounters))) $ - traceIf (None < dbg_transformationInfo (opt_debug (envOpts (_clashEnv r)))) - ("Clash: Applied " ++ show (s' ^. transformCounter) ++ " transformations") - pure a + MVar.withMVar (s' ^. transformCounters) $ \counters -> do + MVar.withMVar (s' ^. ioLock) $ \() -> do + traceWhen (dbg_countTransformations (opt_debug (envOpts (_clashEnv r)))) + ("Clash: Transformations:\n" ++ Text.unpack (showCounters counters)) + traceWhen (None < dbg_transformationInfo (opt_debug (envOpts (_clashEnv r)))) + ("Clash: Applied " ++ show (sum counters) ++ " transformations") + + pure a where showCounters = Text.unlines @@ -492,7 +519,9 @@ liftAndSubsituteBinders inScope toLift toKeep body = do (substTmEnv subst) } subst2 = extendIdSubst subst1 x e2 if x `elemFreeVars` e2 then do - (_,sp) <- Lens.use curFun + curFunsV <- Lens.use curFun + thread <- myThreadId + Just (_,sp) <- MVar.withMVar curFunsV (pure . HashMap.lookup thread) throw (ClashException sp [I.i| Internal error: inlineOrLiftBInders failed on: @@ -571,8 +600,11 @@ liftBinding (var@Id {varName = idName} ,e) = do -- Make a new global ID tcm <- Lens.view tcCache let newBodyTy = inferCoreTypeOf tcm $ mkTyLams (mkLams e boundFVs) boundFTVs - (cf,sp) <- Lens.use curFun - binders <- Lens.use bindings + curFunsV <- Lens.use curFun + thread <- myThreadId + Just (cf,sp) <- MVar.withMVar curFunsV (pure . HashMap.lookup thread) + bindersV <- Lens.use bindings + binders <- MVar.takeMVar bindersV newBodyNm <- cloneNameWithBindingMap binders @@ -594,30 +626,31 @@ liftBinding (var@Id {varName = idName} ,e) = do newBody = mkTyLams (mkLams e' boundFVs) boundFTVs -- Check if an alpha-equivalent global binder already exists - aeqExisting <- (eltsUniqMap . filterUniqMap ((`aeqTerm` newBody) . bindingTerm)) <$> Lens.use bindings + let aeqExisting = eltsUniqMap $ filterUniqMap ((`aeqTerm` newBody) . 
bindingTerm) binders case aeqExisting of -- If it doesn't, create a new binder [] -> do -- Add the created function to the list of global bindings let r = newBodyId `globalIdOccursIn` newBody - bindings %= extendUniqMap newBodyNm - -- We mark this function as internal so that - -- it can be inlined at the very end of - -- the normalisation pipeline as part of the - -- flattening pass. We don't inline - -- right away because we are lifting this - -- function at this moment for a reason! - -- (termination, CSE and DEC oppertunities, - -- ,etc.) - (Binding newBodyId sp NoUserInline IsFun newBody r) - -- Return the new binder + MVar.putMVar bindersV $ + extendUniqMap + newBodyNm + -- We mark this function as internal so that it can be inlined + -- at the very end of the normalisation pipeline as part of the + -- flattening pass. We don't inline right away because we are + -- lifting this function at this moment for a reason! + -- (termination, CSE and DEC oppertunities, etc.) + (Binding newBodyId sp NoUserInline IsFun newBody r) + binders + return (var, newExpr) -- If it does, use the existing binder - (b:_) -> + (b:_) -> do let newExpr' = mkTmApps (mkTyApps (Var $ bindingId b) (map VarTy boundFTVs)) (map Var boundFVs) - in return (var, newExpr') + MVar.putMVar bindersV binders + return (var, newExpr') liftBinding _ = error $ $(curLoc) ++ "liftBinding: invalid core, expr bound to tyvar" @@ -634,23 +667,16 @@ mkFunction mkFunction bndrNm sp inl body = do tcm <- Lens.view tcCache let bodyTy = inferCoreTypeOf tcm body - binders <- Lens.use bindings - bodyNm <- cloneNameWithBindingMap binders bndrNm - addGlobalBind bodyNm bodyTy sp inl body - return (mkGlobalId bodyTy bodyNm) + bindersV <- Lens.use bindings --- | Add a function to the set of global binders -addGlobalBind - :: TmName - -> Type - -> SrcSpan - -> InlineSpec - -> Term - -> RewriteMonad extra () -addGlobalBind vNm ty sp inl body = do - let vId = mkGlobalId ty vNm - r = vId `globalIdOccursIn` body - (ty,body) 
`deepseq` bindings %= extendUniqMap vNm (Binding vId sp inl IsFun body r) + MVar.modifyMVar bindersV $ \binders -> do + bodyNm <- cloneNameWithBindingMap binders bndrNm + let vId = mkGlobalId bodyTy bodyNm + r = vId `globalIdOccursIn` body + bind = Binding vId sp inl IsFun body r + binders' = extendUniqMap vId bind binders + + bodyTy `deepseq` body `deepseq` binders' `seq` pure (binders', vId) -- | Create a new name out of the given name, but with another unique. Resulting -- unique is guaranteed to not be in the given InScopeSet. @@ -717,77 +743,83 @@ normalizeId _ tyvar = tyvar -- | Evaluate an expression to weak-head normal form (WHNF), and apply a -- transformation on the expression in WHNF. whnfRW - :: Bool + :: forall extra + . Bool -- ^ Whether the expression we're reducing to WHNF is the subject of a -- case expression. -> TransformContext -> Term -> Rewrite extra -> RewriteMonad extra Term -whnfRW isSubj ctx@(TransformContext is0 _) e rw = do +whnfRW isSubj (TransformContext is0 hist) e0 rw = do tcm <- Lens.view tcCache - bndrs <- Lens.use bindings eval <- Lens.view evaluator + + bndrsV <- Lens.use bindings ids <- Lens.use uniqSupply + ghV <- Lens.use globalHeap + + bndrs <- MVar.takeMVar bndrsV + gh <- MVar.takeMVar ghV + let (ids1,ids2) = splitSupply ids uniqSupply Lens..= ids2 - gh <- Lens.use globalHeap - case whnf' eval bndrs tcm gh ids1 is0 isSubj e of + case whnf' eval bndrs tcm gh ids1 is0 isSubj e0 of (!gh1,ph,v) -> do - globalHeap Lens..= gh1 - bindPureHeap tcm ph rw ctx v -{-# SCC whnfRW #-} - --- | Binds variables on the PureHeap over the result of the rewrite --- --- To prevent unnecessary rewrites only do this when rewrite changed something. 
-bindPureHeap - :: TyConMap - -> PureHeap - -> Rewrite extra - -> Rewrite extra -bindPureHeap tcm heap rw ctx0@(TransformContext is0 hist) e = do - (e1, Monoid.getAny -> hasChanged) <- Writer.listen $ rw ctx e - if hasChanged && not (null bndrs) then do - -- The evaluator results are post-processed with two operations: - -- - -- 1. Inline work free binders. We've seen cases in the wild† where the - -- evaluator (or rather, 'bindPureHeap') would let-bind work-free - -- binders that were crucial for eliminating case constructs. If these - -- case constructs were used in a self-referential (but terminating) - -- manner, Clash would get stuck in an infinite loop. The proper - -- solution would be to use 'isWorkFree', instead of 'isWorkFreeIsh', - -- in 'bindConstantVar' such that these work free constructs would get - -- inlined again. However, this incurs a great performance penalty so - -- we opt to prevent the evaluator from introducing this situation in - -- the first place. - -- - -- I'd like to stress that this is not a proper solution though, as GHC - -- might produce a similar situation. We plan on properly solving this - -- by eliminating the current lift/bind/eval strategy, instead replacing - -- it by a partial evaluator‡. - -- - -- 2. Remove any unused let-bindings. Similar to (1), we risk Clash getting - -- stuck in an infinite loop if we don't remove unused (eliminated by - -- evaluation!) binders. 
- -- - -- † https://github.com/clash-lang/clash-compiler/pull/1354#issuecomment-635430374 - -- ‡ https://www.microsoft.com/en-us/research/wp-content/uploads/2016/07/supercomp-by-eval.pdf - bs <- Lens.use bindings - inlineBinders (inlineTest bs) ctx0 (Letrec bndrs e1) >>= \case - e2@(Let bnders1 e3) -> - pure (fromMaybe e2 (removeUnusedBinders bnders1 e3)) - e2 -> - pure e2 - else - return e1 - where - heapIds = map fst bndrs + let result = bindPureHeap tcm bndrs ph v + MVar.putMVar bndrsV bndrs + MVar.putMVar ghV gh1 + result + where + -- | Binds variables on the PureHeap over the result of the rewrite + -- To prevent unnecessary rewrites only do this when rewrite changed something. + bindPureHeap + :: TyConMap + -> BindingMap + -> PureHeap + -> Term + -> RewriteMonad extra Term + bindPureHeap tcm bs heap e1 = do + (e2, Monoid.getAny -> hasChanged) <- Writer.listen $ rw ctx e1 + if hasChanged && not (null letBndrs) then do + -- The evaluator results are post-processed with two operations: + -- + -- 1. Inline work free binders. We've seen cases in the wild† where the + -- evaluator (or rather, 'bindPureHeap') would let-bind work-free + -- binders that were crucial for eliminating case constructs. If these + -- case constructs were used in a self-referential (but terminating) + -- manner, Clash would get stuck in an infinite loop. The proper + -- solution would be to use 'isWorkFree', instead of 'isWorkFreeIsh', + -- in 'bindConstantVar' such that these work free constructs would get + -- inlined again. However, this incurs a great performance penalty so + -- we opt to prevent the evaluator from introducing this situation in + -- the first place. + -- + -- I'd like to stress that this is not a proper solution though, as GHC + -- might produce a similar situation. We plan on properly solving this + -- by eliminating the current lift/bind/eval strategy, instead replacing + -- it by a partial evaluator‡. + -- + -- 2. Remove any unused let-bindings. 
Similar to (1), we risk Clash getting + -- stuck in an infinite loop if we don't remove unused (eliminated by + -- evaluation!) binders. + -- + -- † https://github.com/clash-lang/clash-compiler/pull/1354#issuecomment-635430374 + -- ‡ https://www.microsoft.com/en-us/research/wp-content/uploads/2016/07/supercomp-by-eval.pdf + inlineBinders inlineTest ctx (Letrec letBndrs e2) >>= \case + e3@(Let bnders1 e4) -> + pure (fromMaybe e3 (removeUnusedBinders bnders1 e4)) + e3 -> + pure e3 + else + return e2 + where + heapIds = map fst letBndrs is1 = extendInScopeSetList is0 heapIds ctx = TransformContext is1 (LetBody heapIds : hist) - bndrs = map toLetBinding $ toListUniqMap heap + letBndrs = map toLetBinding $ toListUniqMap heap toLetBinding :: (Unique,Term) -> LetBinding toLetBinding (uniq,term) = (nm, term) @@ -795,7 +827,8 @@ bindPureHeap tcm heap rw ctx0@(TransformContext is0 hist) e = do ty = inferCoreTypeOf tcm term nm = mkLocalId ty (mkUnsafeSystemName "x" uniq) -- See [Note: Name re-creation] - inlineTest bs _ (_, stripTicks -> e_) = isWorkFree workFreeBinders bs e_ + inlineTest _ (_, stripTicks -> e_) = isWorkFree workFreeBinders bs e_ +{-# SCC whnfRW #-} -- | Remove unused binders in given let-binding. Returns /Nothing/ if no unused -- binders were found. diff --git a/clash-lib/src/Clash/Rewrite/WorkFree.hs b/clash-lib/src/Clash/Rewrite/WorkFree.hs index 576fdd6782..e368f92a0d 100644 --- a/clash-lib/src/Clash/Rewrite/WorkFree.hs +++ b/clash-lib/src/Clash/Rewrite/WorkFree.hs @@ -1,5 +1,5 @@ {-| -Copyright : (C) 2020-2021, QBayLogic B.V. +Copyright : (C) 2020-2022, QBayLogic B.V. License : BSD2 (see the file LICENSE) Maintainer : QBayLogic B.V. @@ -8,21 +8,25 @@ evaluation to check whether it is possible to perform changes without duplicating work in the result, e.g. inlining. 
-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE TemplateHaskellQuotes #-} module Clash.Rewrite.WorkFree ( isWorkFree + , isWorkFreePure , isWorkFreeClockOrResetOrEnable , isWorkFreeIsh , isConstant , isConstantNotClockReset ) where -import Control.Lens (Lens') -import Control.Monad.Extra (allM, andM, eitherM) +import Control.Concurrent.MVar (MVar) +import qualified Control.Concurrent.MVar.Lifted as MVar +import Control.Lens as Lens (Lens', use) import Control.Monad.State.Class (MonadState) +import Control.Monad.Trans.Control (MonadBaseControl) import qualified Data.Text.Extra as Text import GHC.Stack (HasCallStack) @@ -35,41 +39,59 @@ import Clash.Core.TyCon (TyConMap) import Clash.Core.Type (isPolyFunTy) import Clash.Core.Util import Clash.Core.Var (Id, isLocalId) -import Clash.Core.VarEnv (VarEnv, lookupVarEnv) +import Clash.Core.VarEnv (VarEnv, extendVarEnv, lookupVarEnv, unionVarEnv) import Clash.Driver.Types (BindingMap, Binding(..)) import Clash.Normalize.Primitives (removedArg) -import Clash.Util (makeCachedU) + +-- TODO I think isWorkFree only needs to exist within the rewriting monad, and +-- this extra polymorphism is probably unnecessary. Needs checking. -- Alex + +{-# INLINABLE isWorkFree #-} +isWorkFree + :: (HasCallStack, MonadState s m, MonadBaseControl IO m) + => Lens' s (MVar (VarEnv Bool)) + -> BindingMap + -> Term + -> m Bool +isWorkFree cacheL bndrs bndr = do + lock <- Lens.use cacheL + MVar.modifyMVar lock (\cache -> pure (isWorkFreePure cache bndrs bndr)) -- | Determines whether a global binder is work free. Errors if binder does -- not exist. 
isWorkFreeBinder - :: (HasCallStack, MonadState s m) - => Lens' s (VarEnv Bool) + :: HasCallStack + => VarEnv Bool -> BindingMap -> Id - -> m Bool + -> (VarEnv Bool, Bool) isWorkFreeBinder cache bndrs bndr = - makeCachedU bndr cache $ - case lookupVarEnv bndr bndrs of - Nothing -> error ("isWorkFreeBinder: couldn't find binder: " ++ showPpr bndr) - Just (bindingTerm -> t) -> - if bndr `globalIdOccursIn` t - then pure False - else isWorkFree cache bndrs t + case lookupVarEnv bndr cache of + Just value -> + (cache, value) -{-# INLINABLE isWorkFree #-} + Nothing -> + case lookupVarEnv bndr bndrs of + Nothing -> + error ("isWorkFreeBinder: couldn't find binder: " ++ showPpr bndr) + + Just (bindingTerm -> t) -> + if bndr `globalIdOccursIn` t + then (extendVarEnv bndr False cache, False) + else isWorkFreePure cache bndrs t + +{-# INLINABLE isWorkFreePure #-} -- | Determine whether a term does any work, i.e. adds to the size of the -- circuit. This function requires a cache (specified as a lens) to store the -- result for querying work info of global binders. -- -isWorkFree - :: forall s m - . (HasCallStack, MonadState s m) - => Lens' s (VarEnv Bool) +isWorkFreePure + :: HasCallStack + => VarEnv Bool -> BindingMap -> Term - -> m Bool -isWorkFree cache bndrs = go True + -> (VarEnv Bool, Bool) +isWorkFreePure cache bndrs = go True where -- If we are in the outermost level of a term (i.e. not checking a subterm) -- then a term is work free if it simply refers to a local variable. This @@ -79,7 +101,7 @@ isWorkFree cache bndrs = go True -- -- as being work free, as the term bound to f may introduce work. -- - go :: HasCallStack => Bool -> Term -> m Bool + go :: HasCallStack => Bool -> Term -> (VarEnv Bool, Bool) go isOutermost (collectArgs -> (fun, args)) = case fun of Var i @@ -91,38 +113,79 @@ isWorkFree cache bndrs = go True -- would need to be changed to know the FVs of global binders first. 
-- | isPolyFunTy (coreTypeOf i) -> - pure (isLocalId i && isOutermost && null args) + (cache, isLocalId i && isOutermost && null args) | isLocalId i -> - pure True + (cache, True) | otherwise -> - andM [isWorkFreeBinder cache bndrs i, allM goArg args] + let (cache', wf) = isWorkFreeBinder cache bndrs i + (caches, wfs) = unzip (fmap goArg args) + in (foldr unionVarEnv cache' caches, and (wf : wfs)) + + Data _ -> + let (caches, wfs) = unzip (fmap goArg args) + in (foldr unionVarEnv mempty caches, and wfs) + + Literal _ -> + (cache, True) - Data _ -> allM goArg args - Literal _ -> pure True Prim pr -> case primWorkInfo pr of -- We can ignore arguments because the primitive outputs a constant -- regardless of their values. - WorkConstant -> pure True - WorkNever -> allM goArg args - WorkIdentity _ _ -> allM goArg args - WorkVariable -> pure (all isConstantArg args) - WorkAlways -> pure False - - Lam _ e -> andM [go False e, allM goArg args] - TyLam _ e -> andM [go False e, allM goArg args] - Let (NonRec _ x) e -> andM [go False e, go False x, allM goArg args] - Let (Rec bs) e -> andM [go False e, allM (go False . 
snd) bs, allM goArg args] - Case s _ [(_, a)] -> andM [go False s, go False a, allM goArg args] - Case e _ _ -> andM [go False e, allM goArg args] - Cast e _ _ -> andM [go False e, allM goArg args] + WorkConstant -> (cache, True) + WorkNever -> + let (caches, wfs) = unzip (fmap goArg args) + in (foldr unionVarEnv mempty caches, and wfs) + WorkIdentity _ _ -> + let (caches, wfs) = unzip (fmap goArg args) + in (foldr unionVarEnv mempty caches, and wfs) + WorkVariable -> (cache, all isConstantArg args) + WorkAlways -> (cache, False) + + Lam _ e -> + let (cache', wf) = go False e + (caches, wfs) = unzip (fmap goArg args) + in (foldr unionVarEnv cache' caches, and (wf : wfs)) + + TyLam _ e -> + let (cache', wf) = go False e + (caches, wfs) = unzip (fmap goArg args) + in (foldr unionVarEnv cache' caches, and (wf : wfs)) + + Let (NonRec _ x) e -> + let (cacheE, wfE) = go False e + (cacheX, wfX) = go False x + (caches, wfs) = unzip (fmap goArg args) + in (foldr unionVarEnv cacheE (cacheX : caches), and (wfE : wfX : wfs)) + + Let (Rec bs) e -> + let (cacheE, wfE) = go False e + (cacheBs, wfBs) = unzip (fmap (go False . snd) bs) + (caches, wfs) = unzip (fmap goArg args) + in (foldr unionVarEnv cacheE (cacheBs <> caches), and (wfE : (wfBs <> wfs))) + + Case s _ [(_, a)] -> + let (cacheS, wfS) = go False s + (cacheA, wfA) = go False a + (caches, wfs) = unzip (fmap goArg args) + in (foldr unionVarEnv cacheS (cacheA : caches), and (wfS : wfA : wfs)) + + Case e _ _ -> + let (cache', wf) = go False e + (caches, wfs) = unzip (fmap goArg args) + in (foldr unionVarEnv cache' caches, and (wf : wfs)) + + Cast e _ _ -> + let (cache', wf) = go False e + (caches, wfs) = unzip (fmap goArg args) + in (foldr unionVarEnv cache' caches, and (wf : wfs)) -- (Ty)App's and Ticks are removed by collectArgs Tick _ _ -> error "isWorkFree: unexpected Tick" App {} -> error "isWorkFree: unexpected App" TyApp {} -> error "isWorkFree: unexpected TyApp" - goArg e = eitherM (go False) (pure . 
const True) (pure e) + goArg e = either (go False) (const (cache, True)) e isConstantArg = either isConstant (const True) -- | Determine if a term represents a constant diff --git a/clash-lib/tests/Test/Clash/Rewrite.hs b/clash-lib/tests/Test/Clash/Rewrite.hs index c202e119e1..8717614861 100644 --- a/clash-lib/tests/Test/Clash/Rewrite.hs +++ b/clash-lib/tests/Test/Clash/Rewrite.hs @@ -30,11 +30,11 @@ import Clash.Unique (emptyUniqMap) import qualified Clash.Util.Interpolate as I import Control.Applicative ((<|>)) +import Control.Concurrent.MVar (newMVar) import Control.Concurrent.Supply (newSupply) import Data.Default import Language.Haskell.Exts.Syntax import Language.Haskell.Exts.Parser (parseExp, fromParseResult) -import System.IO.Unsafe (unsafePerformIO) import Text.Read (readMaybe) import GHC.Stack (HasCallStack) @@ -73,28 +73,26 @@ instance Default RewriteEnv where , _topEntities=emptyVarSet } -instance Default extra => Default (RewriteState extra) where - def = RewriteState - { _transformCounter=0 - , _transformCounters=mempty - , _bindings=emptyVarEnv - , _uniqSupply=unsafePerformIO newSupply - , _curFun=error "_curFun: NYI" - , _nameCounter=2 - , _workFreeBinders=emptyVarEnv - , _globalHeap=error "_globalHeap: NYI" - , _extra=def - } - -instance Default NormalizeState where - def = NormalizeState - { _normalized=emptyVarEnv - , _specialisationCache=Map.empty - , _specialisationHistory=emptyVarEnv - , _inlineHistory=emptyVarEnv - , _primitiveArgs=Map.empty - , _recursiveComponents=emptyVarEnv - } +defRewriteState :: IO (RewriteState NormalizeState) +defRewriteState = do + normState <- NormalizeState + <$> newMVar emptyVarEnv + <*> newMVar Map.empty + <*> newMVar emptyVarEnv + <*> newMVar emptyVarEnv + <*> newMVar Map.empty + <*> newMVar emptyVarEnv + + RewriteState + <$> newMVar mempty + <*> newMVar emptyVarEnv + <*> newSupply + <*> pure (error "_curFun: NYI") + <*> newMVar 2 + <*> newMVar (error "_globalHeap: NYI") + <*> newMVar emptyVarEnv + <*> newMVar 
() + <*> pure normState instance Default InScopeSet where def = emptyInScopeSet @@ -124,8 +122,10 @@ runSingleTransformation rwEnv rwState is trans term = do -- include a type translator, evaluator, current function, or global heap. Maps, -- like the primitive and tycon map, are also empty. If the transformation under -- test needs these definitions, you should add them manually. -runSingleTransformationDef :: Default extra => Rewrite extra -> C.Term -> IO C.Term -runSingleTransformationDef = runSingleTransformation def def def +runSingleTransformationDef :: Rewrite NormalizeState -> C.Term -> IO C.Term +runSingleTransformationDef rewrite term = do + st <- defRewriteState + runSingleTransformation def st def rewrite term parseType :: Show l => Type l -> C.Type diff --git a/docs/developing-hardware/flags.rst b/docs/developing-hardware/flags.rst index 488f5a56a4..34bff811b7 100644 --- a/docs/developing-hardware/flags.rst +++ b/docs/developing-hardware/flags.rst @@ -271,6 +271,11 @@ Clash Compiler Flags .. _`Edalize`: https://github.com/olofk/edalize +-fclash-concurrent-normalization + Toggle concurrent normalization. Faster for large designs. + + **Default:** False + -main-is When using one of ``--vhdl``, ``--verilog``, or ``--systemverilog``, this flag refers to synthesis target. 
For example, running Clash with diff --git a/tests/src/Test/Tasty/Clash/NetlistTest.hs b/tests/src/Test/Tasty/Clash/NetlistTest.hs index 9cc0a81a35..4432da1eea 100644 --- a/tests/src/Test/Tasty/Clash/NetlistTest.hs +++ b/tests/src/Test/Tasty/Clash/NetlistTest.hs @@ -38,6 +38,7 @@ import Clash.GHC.NetlistTypes import Clash.Netlist import Clash.Netlist.Types hiding (backend, hdlDir) +import qualified Control.Concurrent.MVar as MVar import qualified Control.Concurrent.Supply as Supply import Control.DeepSeq (force) import Data.Maybe @@ -78,12 +79,14 @@ runToNetlistStage target f src = do tes2 = mkVarEnv (P.zip (P.map topId (designEntities design)) (designEntities design)) supplyN <- Supply.newSupply + lock <- MVar.newMVar () transformedBindings <- normalizeEntity env (designBindings design) (ghcTypeToHWType (opt_intWidth opts)) ghcEvaluator evaluator + lock teNames supplyN te fmap (\(_,x,_) -> force (P.map snd (OMap.assocs x))) $