@@ -27,6 +27,8 @@ module Control.Monad.Bayes.Population
2727 resampleSystematic ,
2828 stratified ,
2929 resampleStratified ,
30+ onlyBelowEffectiveSampleSize ,
31+ effectiveSampleSize ,
3032 extractEvidence ,
3133 pushEvidence ,
3234 proper ,
@@ -244,6 +246,43 @@ resampleMultinomial ::
244246 PopulationT m a
245247resampleMultinomial = resampleGeneric multinomial
246248
249+ -- ** Effective sample size
250+
251+ -- | Only use the given resampler when the effective sample size is below a certain threshold.
252+ --
253+ -- See 'withEffectiveSampleSize'.
254+ onlyBelowEffectiveSampleSize ::
255+ (MonadDistribution m ) =>
256+ -- | The threshold under which the effective sample size must fall before the resampler is used.
257+ -- For example, this may be half of the number of particles.
258+ Double ->
259+ -- | The resampler to user under the threshold
260+ (forall n . (MonadDistribution n ) => PopulationT n a -> PopulationT n a ) ->
261+ -- | The new resampler
262+ (PopulationT m a -> PopulationT m a )
263+ onlyBelowEffectiveSampleSize threshold resampler pop = fromWeightedList $ do
264+ (as, ess) <- withEffectiveSampleSize pop
265+ if ess < threshold then runPopulationT $ resampler $ fromWeightedList $ pure as else return as
266+
267+ -- | Compute the effective sample size of a population from the weights.
268+ --
269+ -- See https://en.wikipedia.org/wiki/Design_effect#Effective_sample_size
270+ effectiveSampleSize :: (Functor m ) => PopulationT m a -> m Double
271+ effectiveSampleSize = fmap snd . withEffectiveSampleSize
272+
273+ -- | Compute the effective sample size alongside the samples themselves.
274+ --
275+ -- The advantage over 'effectiveSampleSize' is that the samples need not be created a second time.
276+ withEffectiveSampleSize :: (Functor m ) => PopulationT m a -> m ([(a , Log Double )], Double )
277+ withEffectiveSampleSize = fmap (\ as -> (as, effectiveSampleSizeKish $ (exp . ln . snd ) <$> as)) . runPopulationT
278+ where
279+ effectiveSampleSizeKish :: [Double ] -> Double
280+ effectiveSampleSizeKish weights = square (Data.List. sum weights) / Data.List. sum (square <$> weights)
281+ square :: Double -> Double
282+ square x = x * x
283+
284+ -- ** Utility functions
285+
247286-- | Separate the sum of weights into the 'WeightedT' transformer.
248287-- Weights are normalized after this operation.
249288extractEvidence ::
0 commit comments