3030from dask .distributed import futures_of , as_completed , wait
3131from dask import delayed
3232
33+
3334class PerturbationPrior (PriorBase ):
3435
3536 def __init__ (self , ref_prior , samples , normalized_weights , perturbation_kernel ,
36- use_logger = False ):
37+ use_logger = False ):
3738
3839 self .name = 'Perturbation Prior'
3940 self .ref_prior = ref_prior
@@ -42,7 +43,7 @@ def __init__(self, ref_prior, samples, normalized_weights, perturbation_kernel,
4243 self .perturbation_kernel = perturbation_kernel
4344 super (PerturbationPrior , self ).__init__ (self .name , use_logger )
4445
45- def draw (self , n = 1 , chunk_size = 1 ):
46+ def draw (self , n = 1 , chunk_size = 1 ):
4647
4748 assert n >= chunk_size
4849
@@ -57,9 +58,9 @@ def draw(self, n = 1, chunk_size = 1):
5758 return generated_samples
5859
5960 @delayed
60- def _weighted_draw_perturb (self ,m ):
61+ def _weighted_draw_perturb (self , m ):
6162 idxs = np .random .choice (self .samples .shape [0 ], m ,
62- p = self .normalized_weights )
63+ p = self .normalized_weights )
6364 s0 = [self .samples [idx ] for idx in idxs ]
6465 s = []
6566 for z in s0 :
@@ -72,6 +73,7 @@ def _weighted_draw_perturb(self,m):
7273
7374 return np .asarray (s )
7475
76+
7577class SMCABC (InferenceBase ):
7678 """
7779 SMC - Approximate Bayesian Computation
@@ -145,19 +147,20 @@ def infer(self, num_samples, batch_size, chunk_size=10, ensemble_size=1, normali
145147
146148 # Generate an initial population from the first epsilon
147149 abc_instance = abc_inference .ABC (self .data , self .sim , prior_function ,
148- epsilon = self .epsilons [0 ],
149- summaries_function = self .summaries_function ,
150- distance_function = self .distance_function ,
151- summaries_divisor = self .summaries_divisor ,
152- use_logger = self .use_logger )
150+ epsilon = self .epsilons [0 ],
151+ summaries_function = self .summaries_function ,
152+ distance_function = self .distance_function ,
153+ summaries_divisor = self .summaries_divisor ,
154+ use_logger = self .use_logger )
153155
154156 print ("Starting epsilon={}" .format (self .epsilons [0 ]))
155- abc_instance .compute_fixed_mean (chunk_size = chunk_size )
156- abc_results = abc_instance .infer (num_samples = t , batch_size = batch_size , chunk_size = chunk_size , normalize = normalize )
157+ abc_instance .compute_fixed_mean (chunk_size = chunk_size )
158+ abc_results = abc_instance .infer (num_samples = t , batch_size = batch_size , chunk_size = chunk_size ,
159+ normalize = normalize )
157160
158161 final_results = abc_results
159162 population = np .vstack (abc_results ['accepted_samples' ])[:t ]
160- normalized_weights = np .ones (t )/ t
163+ normalized_weights = np .ones (t ) / t
161164 d = population .shape [1 ]
162165
163166 # SMC iterations
@@ -175,23 +178,23 @@ def infer(self, num_samples, batch_size, chunk_size=10, ensemble_size=1, normali
175178 try :
176179 # Run ABC on the next epsilon using the proposal prior
177180 abc_instance = abc_inference .ABC (self .data , self .sim , new_prior ,
178- epsilon = eps , summaries_function = self .summaries_function ,
179- distance_function = self .distance_function ,
180- summaries_divisor = self .summaries_divisor ,
181- use_logger = self .use_logger )
182- abc_instance .compute_fixed_mean (chunk_size = chunk_size )
183- abc_results = abc_instance .infer (num_samples = t ,
184- batch_size = batch_size ,
185- chunk_size = chunk_size ,
186- normalize = normalize )
181+ epsilon = eps , summaries_function = self .summaries_function ,
182+ distance_function = self .distance_function ,
183+ summaries_divisor = self .summaries_divisor ,
184+ use_logger = self .use_logger )
185+ abc_instance .compute_fixed_mean (chunk_size = chunk_size )
186+ abc_results = abc_instance .infer (num_samples = t ,
187+ batch_size = batch_size ,
188+ chunk_size = chunk_size ,
189+ normalize = normalize )
187190 new_samples = np .vstack (abc_results ['accepted_samples' ])[:t ]
188191
189192 # Compute importance weights for the new samples
190193 prior_weights = self .prior_function .pdf (new_samples )
191194 kweights = self .perturbation_kernel .pdf (population , new_samples )
192195
193- new_weights = prior_weights / np .sum (kweights * normalized_weights [:, np .newaxis ], axis = 0 )
194- new_weights = new_weights / sum (new_weights )
196+ new_weights = prior_weights / np .sum (kweights * normalized_weights [:, np .newaxis ], axis = 0 )
197+ new_weights = new_weights / sum (new_weights )
195198
196199 population = new_samples
197200 normalized_weights = new_weights
0 commit comments