diff --git a/clash-prelude/src/Clash/Explicit/Signal/Delayed.hs b/clash-prelude/src/Clash/Explicit/Signal/Delayed.hs index acb70b473e..34443f96cf 100644 --- a/clash-prelude/src/Clash/Explicit/Signal/Delayed.hs +++ b/clash-prelude/src/Clash/Explicit/Signal/Delayed.hs @@ -17,6 +17,8 @@ Maintainer : Christiaan Baaij {-# LANGUAGE Trustworthy #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +{-# OPTIONS_GHC -fconstraint-solver-iterations=10 #-} {-# OPTIONS_HADDOCK show-extensions #-} module Clash.Explicit.Signal.Delayed @@ -42,22 +44,25 @@ module Clash.Explicit.Signal.Delayed ) where -import Prelude ((.), (<$>), (<*>), id, Num(..)) +import Prelude ((.), ($), (<$>), id, Num(..), Maybe(..), fmap) +import Control.Applicative (liftA2) import Data.Coerce (coerce) import Data.Kind (Type) -import Data.Proxy (Proxy (..)) -import Data.Singletons (Apply, TyFun, type (@@)) -import GHC.TypeLits (KnownNat, Nat, type (+), type (^), type (*)) +import Data.Type.Equality ((:~:)(Refl)) +import GHC.TypeLits (sameNat, Div, Mod, KnownNat, Nat, type (+), type (*), type (<=)) +import GHC.TypeLits.Extra (CLog) +import Clash.Magic (clashCompileError) import Clash.Sized.Vector import Clash.Signal.Delayed.Internal (DSignal(..), dfromList, dfromList_lazy, fromSignal, toSignal, unsafeFromSignal, antiDelay, feedback, forward) +import qualified Clash.Signal.Delayed.Bundle as D import Clash.Explicit.Signal (KnownDomain, Clock, Domain, Reset, Signal, Enable, register, delay, bundle, unbundle) -import Clash.Promoted.Nat (SNat (..), snatToInteger) +import Clash.Promoted.Nat (SNat (..), SNatLE (..), compareSNat, snatToInteger) import Clash.XException (NFDataX) {- $setup @@ -230,12 +235,9 @@ delayI -> DSignal dom (n+d) a delayI dflt = delayN (SNat :: SNat d) dflt -data DelayedFold (dom :: Domain) (n :: Nat) (delay :: Nat) (a :: Type) (f :: TyFun Nat Type) :: Type -type instance Apply (DelayedFold dom n delay a) k = DSignal dom (n + (delay*k)) a - -- | Tree fold over a 'Vec' of 'DSignal's with a combinatorial function, -- and delaying @delay@ cycles after each application. --- Values at times 0..(delay*k)-1 are set to a default. +-- Values at times 0..(delay * CLog 2 n)-1 are set to a default. -- -- @ -- countingSignals :: Vec 4 (DSignal dom 0 Int) @@ -248,11 +250,12 @@ type instance Apply (DelayedFold dom n delay a) k = DSignal dom (n + (delay*k)) -- >>> printX $ sampleN 8 (delayedFold d2 (-1) (*) enableGen systemClockGen countingSignals) -- [-1,-1,1,1,0,1,16,81] delayedFold - :: forall dom n delay k a + :: forall dom d delay n a . ( NFDataX a , KnownDomain dom , KnownNat delay - , KnownNat k ) + , KnownNat n + , 1 <= n ) => SNat delay -- ^ Delay applied after each step -> a @@ -261,14 +264,60 @@ delayedFold -- ^ Fold operation to apply -> Enable dom -> Clock dom - -> Vec (2^k) (DSignal dom n a) - -- ^ Vector input of size 2^k - -> DSignal dom (n + (delay * k)) a - -- ^ Output Signal delayed by (delay * k) -delayedFold _ dflt op ena clk = dtfold (Proxy :: Proxy (DelayedFold dom n delay a)) id go - where - go :: SNat l - -> DelayedFold dom n delay a @@ l - -> DelayedFold dom n delay a @@ l - -> DelayedFold dom n delay a @@ (l+1) - go SNat x y = delayI dflt ena clk (op <$> x <*> y) + -> Vec n (DSignal dom d a) + -- ^ Vector input of size @n@ + -> DSignal dom (d + delay * CLog 2 n) a + -- ^ Output Signal delayed by @delay * CLog 2 n@ +delayedFold SNat initial f ena clk inps = case sameNat (SNat @1) (SNat @n) of + Just Refl -> head inps + _ -> case (modProof, strictlyPosDivRu, divMulProof) of + (SNatLE, SNatLE, Just Refl) -> + case sameNat (SNat @(1 + CLog 2 (n `Div` 2 + n `Mod` 2))) (SNat @(CLog 2 n)) of + Just Refl -> delayedFold (SNat @delay) initial f ena clk newLayer + where + newLayer = D.unbundle $ + step @(n `Div` 2) @(n `Mod` 2) @d @delay (SNat @(n `Div` 2)) initial f ena clk (D.bundle inps) + _ -> clashCompileError + "delayedFold0: absurd, report this to the clash-compiler team: https://github.com/clash-lang/clash-compiler/issues" + _ -> clashCompileError + "delayedFold1: absurd, report this to the clash-compiler team: https://github.com/clash-lang/clash-compiler/issues" + where + modProof = compareSNat (SNat @(n `Mod` 2)) (SNat @1) + strictlyPosDivRu = compareSNat (SNat @1) (SNat @(n `Div` 2 + n `Mod` 2)) + divMulProof = sameNat (SNat @n) (SNat @(2 * (n `Div` 2) + n `Mod` 2)) + +-- | A single layer of the pipelined fold +step :: forall (m :: Nat) (p :: Nat) (d :: Nat) (delay :: Nat) (dom :: Domain) (a :: Type). + KnownNat p + => KnownNat delay + => KnownDomain dom + => p <= 1 + => NFDataX a + => SNat m + -> a + -> (a -> a -> a) + -> Enable dom + -> Clock dom + -> DSignal dom d (Vec (2 * m + p) a) + -> DSignal dom (d + delay) (Vec (m + p) a) +step SNat initial f ena clk inps = + let + layerCalc :: DSignal dom d (Vec (2 * m) a) -> DSignal dom d (Vec m a) + layerCalc = fmap (map applyF . unconcatI) + + applyF :: Vec 2 a -> a + applyF (a `Cons` b `Cons` _) = f a b + in + case (sameNat (SNat @p) (SNat @0), sameNat (SNat @p) (SNat @1)) of + -- Size of the input vector is even + (Just Refl, Nothing) -> + delayI (repeat initial) ena clk (layerCalc inps) + -- Size of the input vector is odd + (Nothing, Just Refl) -> + delayI (repeat initial) ena clk $ + liftA2 + (++) + (singleton . head <$> inps) + (layerCalc (tail <$> inps)) + _ -> clashCompileError + "delayedFold step: absurd, report this to the clash-compiler team: https://github.com/clash-lang/clash-compiler/issues" diff --git a/clash-prelude/src/Clash/Signal/Delayed.hs b/clash-prelude/src/Clash/Signal/Delayed.hs index fcbc919d07..ae2142a291 100644 --- a/clash-prelude/src/Clash/Signal/Delayed.hs +++ b/clash-prelude/src/Clash/Signal/Delayed.hs @@ -41,7 +41,8 @@ module Clash.Signal.Delayed where import GHC.TypeLits - (KnownNat, type (^), type (+), type (*)) + (KnownNat, type (+), type (*), type (<=)) +import GHC.TypeLits.Extra (CLog) import Clash.Signal.Delayed.Internal (DSignal(..), dfromList, dfromList_lazy, fromSignal, toSignal, @@ -192,7 +193,7 @@ delayI dflt = hideClock (hideEnable (E.delayI dflt)) -- | Tree fold over a 'Vec' of 'DSignal's with a combinatorial function, -- and delaying @delay@ cycles after each application. --- Values at times 0..(delay*k)-1 are set to a default. +-- Values at times 0..(delay * CLog 2 n)-1 are set to a default. -- -- @ -- countingSignals :: Vec 4 (DSignal dom 0 Int) @@ -205,20 +206,21 @@ delayI dflt = hideClock (hideEnable (E.delayI dflt)) -- >>> printX $ sampleN @System 8 (toSignal (delayedFold d2 (-1) (*) countingSignals)) -- [-1,-1,1,1,0,1,16,81] delayedFold - :: forall dom n delay k a + :: forall dom d delay n a . ( HiddenClock dom , HiddenEnable dom , NFDataX a , KnownNat delay - , KnownNat k ) + , KnownNat n + , 1 <= n) => SNat delay -- ^ Delay applied after each step -> a -- ^ Initial value -> (a -> a -> a) -- ^ Fold operation to apply - -> Vec (2^k) (DSignal dom n a) + -> Vec n (DSignal dom d a) -- ^ Vector input of size 2^k - -> DSignal dom (n + (delay * k)) a + -> DSignal dom (d + (delay * CLog 2 n)) a -- ^ Output Signal delayed by (delay * k) delayedFold d dflt f = hideClock (hideEnable (E.delayedFold d dflt f)) diff --git a/clash-prelude/src/Clash/Signal/Delayed/Bundle.hs b/clash-prelude/src/Clash/Signal/Delayed/Bundle.hs index ba7613fb8d..e54692cef7 100644 --- a/clash-prelude/src/Clash/Signal/Delayed/Bundle.hs +++ b/clash-prelude/src/Clash/Signal/Delayed/Bundle.hs @@ -25,7 +25,7 @@ import GHC.TypeLits (KnownNat) import Prelude hiding (head, map, tail) import Clash.Signal.Internal (Domain) -import Clash.Signal.Delayed (DSignal, toSignal, unsafeFromSignal) +import Clash.Signal.Delayed.Internal (DSignal, toSignal, unsafeFromSignal) import qualified Clash.Signal.Bundle as B import Clash.Sized.BitVector (Bit, BitVector)