diff --git a/monad-bayes.cabal b/monad-bayes.cabal index e74b1fe9..fcd047b9 100644 --- a/monad-bayes.cabal +++ b/monad-bayes.cabal @@ -94,10 +94,14 @@ library default-extensions: BlockArguments + DerivingVia FlexibleContexts + GeneralizedNewtypeDeriving ImportQualifiedPost + KindSignatures LambdaCase OverloadedStrings + StandaloneDeriving TupleSections if flag(dev) diff --git a/src/Control/Monad/Bayes/Class.hs b/src/Control/Monad/Bayes/Class.hs index b11f3f86..d1d46e1d 100644 --- a/src/Control/Monad/Bayes/Class.hs +++ b/src/Control/Monad/Bayes/Class.hs @@ -71,6 +71,7 @@ module Control.Monad.Bayes.Class Measure, Kernel, Log (ln, Exp), + MonadMeasureTrans (..), ) where @@ -82,9 +83,11 @@ import Control.Monad.Identity (IdentityT) import Control.Monad.List (ListT) import Control.Monad.Reader (ReaderT) import Control.Monad.State (StateT) +import Control.Monad.Trans (MonadTrans) import Control.Monad.Writer (WriterT) import Data.Histogram qualified as H import Data.Histogram.Fill qualified as H +import Data.Kind (Type) import Data.Matrix ( Matrix, cholDecomp, @@ -342,68 +345,82 @@ histogramToList = H.asList ---------------------------------------------------------------------------- -- Instances that lift probabilistic effects to standard tranformers. -instance MonadDistribution m => MonadDistribution (IdentityT m) where - random = lift random - bernoulli = lift . bernoulli +deriving via (MonadMeasureTrans IdentityT m) instance MonadDistribution m => MonadDistribution (IdentityT m) -instance MonadFactor m => MonadFactor (IdentityT m) where - score = lift . score +deriving via (MonadMeasureTrans IdentityT m) instance MonadFactor m => MonadFactor (IdentityT m) instance MonadMeasure m => MonadMeasure (IdentityT m) -instance MonadDistribution m => MonadDistribution (ExceptT e m) where - random = lift random - uniformD = lift . uniformD +deriving via (MonadMeasureTrans (ExceptT e) m) instance MonadDistribution m => MonadDistribution (ExceptT e m) -instance MonadFactor m => MonadFactor (ExceptT e m) where - score = lift . score +deriving via (MonadMeasureTrans (ExceptT e) m) instance MonadFactor m => MonadFactor (ExceptT e m) instance MonadMeasure m => MonadMeasure (ExceptT e m) -instance MonadDistribution m => MonadDistribution (ReaderT r m) where - random = lift random - bernoulli = lift . bernoulli +deriving via (MonadMeasureTrans (ReaderT r) m) instance MonadDistribution m => MonadDistribution (ReaderT r m) -instance MonadFactor m => MonadFactor (ReaderT r m) where - score = lift . score +deriving via (MonadMeasureTrans (ReaderT r) m) instance MonadFactor m => MonadFactor (ReaderT r m) instance MonadMeasure m => MonadMeasure (ReaderT r m) -instance (Monoid w, MonadDistribution m) => MonadDistribution (WriterT w m) where - random = lift random - bernoulli = lift . bernoulli - categorical = lift . categorical +deriving via (MonadMeasureTrans (WriterT w) m) instance (Monoid w, MonadDistribution m) => MonadDistribution (WriterT w m) -instance (Monoid w, MonadFactor m) => MonadFactor (WriterT w m) where - score = lift . score +deriving via (MonadMeasureTrans (WriterT w) m) instance (Monoid w, MonadFactor m) => MonadFactor (WriterT w m) instance (Monoid w, MonadMeasure m) => MonadMeasure (WriterT w m) -instance MonadDistribution m => MonadDistribution (StateT s m) where - random = lift random - bernoulli = lift . bernoulli - categorical = lift . categorical - uniformD = lift . uniformD +deriving via (MonadMeasureTrans (StateT s) m) instance MonadDistribution m => MonadDistribution (StateT s m) -instance MonadFactor m => MonadFactor (StateT s m) where - score = lift . score +deriving via (MonadMeasureTrans (StateT s) m) instance MonadFactor m => MonadFactor (StateT s m) instance MonadMeasure m => MonadMeasure (StateT s m) -instance MonadDistribution m => MonadDistribution (ListT m) where - random = lift random - bernoulli = lift . bernoulli - categorical = lift . categorical +deriving via (MonadMeasureTrans ListT m) instance MonadDistribution m => MonadDistribution (ListT m) -instance MonadFactor m => MonadFactor (ListT m) where - score = lift . score +deriving via (MonadMeasureTrans ListT m) instance MonadFactor m => MonadFactor (ListT m) instance MonadMeasure m => MonadMeasure (ListT m) -instance MonadDistribution m => MonadDistribution (ContT r m) where +deriving via (MonadMeasureTrans (ContT r) m) instance MonadDistribution m => MonadDistribution (ContT r m) + +deriving via (MonadMeasureTrans (ContT r) m) instance MonadFactor m => MonadFactor (ContT r m) + +instance MonadMeasure m => MonadMeasure (ContT r m) + +-- * Utility for deriving MonadDistribution, MonadFactor and MonadMeasure + +-- | Newtype to derive 'MonadDistribution', 'MonadFactor' and 'MonadMeasure' automatically for monad transformers. +-- +-- The typical usage is with the `StandaloneDeriving` and `DerivingVia` extensions. +-- For example, to derive all instances for the 'IdentityT' transformer, one writes: +-- +-- @ +-- deriving via (MonadMeasureTrans IdentityT m) instance MonadDistribution m => MonadDistribution (IdentityT m) +-- deriving via (MonadMeasureTrans IdentityT m) instance MonadFactor m => MonadFactor (IdentityT m) +-- instance MonadMeasure m => MonadMeasure (IdentityT m) +-- @ +-- (The final 'MonadMeasure' could also be derived `via`, but this isn't necessary because it doesn't contain any methods.) +newtype MonadMeasureTrans (t :: (Type -> Type) -> Type -> Type) (m :: Type -> Type) a = MonadMeasureTrans {getMonadMeasureTrans :: t m a} + deriving (Functor, Applicative, Monad) + +instance MonadTrans t => MonadTrans (MonadMeasureTrans t) where + lift = MonadMeasureTrans . lift + +instance (MonadTrans t, MonadDistribution m, Monad (t m)) => MonadDistribution (MonadMeasureTrans t m) where random = lift random + uniform = (lift .) . uniform + normal = (lift .) . normal + gamma = (lift .) . gamma + beta = (lift .) . beta + bernoulli = lift . bernoulli + categorical = lift . categorical + logCategorical = lift . logCategorical + uniformD = lift . uniformD + geometric = lift . geometric + poisson = lift . poisson + dirichlet = lift . dirichlet -instance MonadFactor m => MonadFactor (ContT r m) where +instance (MonadFactor m, MonadTrans t, Monad (t m)) => MonadFactor (MonadMeasureTrans t m) where score = lift . score -instance MonadMeasure m => MonadMeasure (ContT r m) +instance (MonadDistribution m, MonadFactor m, MonadTrans t, Monad (t m)) => MonadMeasure (MonadMeasureTrans t m) diff --git a/src/Control/Monad/Bayes/Inference/SMC2.hs b/src/Control/Monad/Bayes/Inference/SMC2.hs index 3b8a3787..c5608dda 100644 --- a/src/Control/Monad/Bayes/Inference/SMC2.hs +++ b/src/Control/Monad/Bayes/Inference/SMC2.hs @@ -20,7 +20,7 @@ module Control.Monad.Bayes.Inference.SMC2 where import Control.Monad.Bayes.Class - ( MonadDistribution (random), + ( MonadDistribution, MonadFactor (..), MonadMeasure, ) @@ -35,7 +35,7 @@ import Numeric.Log (Log) -- | Helper monad transformer for preprocessing the model for 'smc2'. newtype SMC2 m a = SMC2 (Sequential (Traced (Population m)) a) - deriving newtype (Functor, Applicative, Monad) + deriving newtype (Functor, Applicative, Monad, MonadDistribution, MonadFactor) setup :: SMC2 m a -> Sequential (Traced (Population m)) a setup (SMC2 m) = m @@ -43,12 +43,6 @@ setup (SMC2 m) = m instance MonadTrans SMC2 where lift = SMC2 . lift . lift . lift -instance MonadDistribution m => MonadDistribution (SMC2 m) where - random = lift random - -instance Monad m => MonadFactor (SMC2 m) where - score = SMC2 . score - instance MonadDistribution m => MonadMeasure (SMC2 m) -- | Sequential Monte Carlo squared. diff --git a/src/Control/Monad/Bayes/Sequential/Coroutine.hs b/src/Control/Monad/Bayes/Sequential/Coroutine.hs index d15c38f8..81a7b621 100644 --- a/src/Control/Monad/Bayes/Sequential/Coroutine.hs +++ b/src/Control/Monad/Bayes/Sequential/Coroutine.hs @@ -26,9 +26,10 @@ module Control.Monad.Bayes.Sequential.Coroutine where import Control.Monad.Bayes.Class - ( MonadDistribution (bernoulli, categorical, random), + ( MonadDistribution, MonadFactor (..), MonadMeasure, + MonadMeasureTrans (..), ) import Control.Monad.Coroutine ( Coroutine (..), @@ -54,10 +55,7 @@ newtype Sequential m a = Sequential {runSequential :: Coroutine (Await ()) m a} extract :: Await () a -> a extract (Await f) = f () -instance MonadDistribution m => MonadDistribution (Sequential m) where - random = lift random - bernoulli = lift . bernoulli - categorical = lift . categorical +deriving via (MonadMeasureTrans Sequential m) instance MonadDistribution m => MonadDistribution (Sequential m) -- | Execution is 'suspend'ed after each 'score'. instance MonadFactor m => MonadFactor (Sequential m) where