33import warnings
44
55import numpy as np
6-
76from numba import njit
7+ from joblib import Parallel , delayed
8+ import arviz as az
9+ from scipy .optimize import minimize_scalar , lsq_linear
10+ from tqdm .auto import tqdm
811
912import pymc as pm
1013import pytensor .tensor as pt
1114from pytensor .graph import Apply , Op
1215
13- import arviz as az
14-
15- from scipy .optimize import minimize_scalar , lsq_linear
16-
17- from tqdm import tqdm
18-
1916from .geochron import Geochron
2017
2118from stratage import __version__
@@ -425,9 +422,9 @@ def model_ls(units, geochron,
425422
426423 # set defaults for bounds
427424 if sed_rate_bounds is None :
428- sed_rate_bounds = [1e-1 , 1e2 ]
425+ sed_rate_bounds = [1e-2 , 1e2 ]
429426 if hiatus_bounds is None :
430- hiatus_bounds = [1e-1 , 1e3 ]
427+ hiatus_bounds = [0 , np . inf ]
431428
432429 # bounds on model parameters
433430 lower_bounds = np .zeros (n_units + n_contacts )
@@ -473,7 +470,7 @@ def __init__(self, units, geochron, sed_rates_prior, hiatuses_prior, ls_kwargs=N
473470 geochron (geochron.Geochron): Geochron object containing geochron constraints.
474471 sed_rates_prior (function): Prior distribution for sedimentation rates. Must be valid as dist argument to pymc.CustomDist(dist=dist). Signature is sed_rate_prior(size=size).
475472 hiatuses_prior (function): Prior distribution for hiatuses. Must be valid as dist argument to pymc.CustomDist(dist=dist). Signature is hiatus_prior(size=size).
476- ls_kwargs (dict, optional): Keyword arguments for model_ls. Defaults to None.
473+ ls_kwargs (dict, optional): Keyword arguments for model_ls. Defaults to None. If None, empty dictionary is passed to model_ls.
477474 """
478475 # assign attributes
479476 self .units = units
@@ -490,6 +487,8 @@ def __init__(self, units, geochron, sed_rates_prior, hiatuses_prior, ls_kwargs=N
490487 # trim the section to the top and bottom of the geochron constraints
491488 self .units_trim = trim_units (self .units , self .geochron .h )
492489 # create least squares model as initial guess
490+ if ls_kwargs is None :
491+ ls_kwargs = {}
493492 self .sed_rates_ls , self .hiatuses_ls = model_ls (self .units , self .geochron , ** ls_kwargs )
494493 # create time increment log-like function
495494 loglike_op = loglike_gen (self .geochron , self .units_trim )
@@ -574,13 +573,38 @@ def sample(self, draws=1000, **kwargs):
574573 trace = pm .sample (draws = draws , ** kwargs )
575574 return trace
576575
577- def trace2ages (self , trace , h = None , n_posterior = None ):
576+ @staticmethod
577+ def fit_absolute_age (ii , sed_rates_post , hiatuses_post , units_trim , geochron , h ):
578+ """Fit the age model for the ii-th sample of the posterior.
579+
580+ Static method to work with joblib.Parallel.
581+
582+ Args:
583+ ii (int): Index of the posterior sample.
584+ sed_rates_post (ndarray): Sedimentation rates for each unit, ii-th posterior sample.
585+ hiatuses_post (ndarray): Hiatuses between units, ii-th posterior sample.
586+ units_trim (ndarray): Trimmed unit heights after adjusting for the top and bottom units.
587+ geochron (Geochron): Geochron object containing geochron constraints.
588+ h (arraylike): Heights at which to evaluate the age model.
589+
590+ Returns:
591+ ndarray: Ages at the given height(s).
592+ """
593+ # Logic for fitting the age model for the ii-th sample
594+ cur_times = fit_floating_model (sed_rates_post ,
595+ hiatuses_post ,
596+ units_trim ,
597+ geochron )
598+ return age (cur_times , units_trim , h )
599+
600+ def trace2ages (self , trace , h , n_posterior = None , n_jobs = 1 ):
578601 """Transform MCMC trace to age models.
579602
580603 Args:
581604 trace (arviz.InferenceData): ArviZ InferenceData object containing the MCMC trace.
582- h (arraylike, optional ): Heights at which to evaluate the age model. Defaults to None. If None, only times arrays are returned.
605+ h (arraylike): Heights at which to evaluate the age model.
583606 n_posterior (int, optional): Number of posterior samples. Defaults to None.
607+ n_jobs (int, optional): Number of parallel jobs. Defaults to 1. If 1, no parallelization is used. Uses joblib.Parallel for parallelization.
584608
585609 Returns:
586610 list: List of age models; each element is a nx2 array of unit bottom and top times for n units.
@@ -593,27 +617,26 @@ def trace2ages(self, trace, h=None, n_posterior=None):
593617 n_posterior = np .min ([10000 , n_chain * n_draws ])
594618 posterior_params = az .extract (trace , num_samples = n_posterior )
595619 # get posterior samples
596- sed_rates_post = posterior_params .sed_rates .to_numpy ().squeeze ().T
597- hiatuses_post = posterior_params .hiatuses .to_numpy ().squeeze ().T
598- # get times
599- times_post = []
620+ sed_rates_post = posterior_params .sed_rates .to_numpy ()
621+ hiatuses_post = posterior_params .hiatuses .to_numpy ()
622+ # get ages at heights h
623+ t_post = []
624+
600625 # iterate over posterior samples to generate times
601- for ii in tqdm (range (n_posterior ),
602- desc = 'Anchoring floating age models' ):
603- # fit floating model
604- cur_time = fit_floating_model (sed_rates_post [ii ],
605- hiatuses_post [ii ],
606- self .units_trim ,
607- self .geochron )
608- times_post .append (cur_time )
609- # if no heights provided, return times arrays only
610- if h is None :
611- return times_post
612- # create age-depth models for heights
626+ if n_jobs == 1 :
627+ for ii in tqdm (range (n_posterior ),
628+ desc = 'Anchoring floating age models' ):
629+ t_post .append (self .fit_absolute_age (ii , sed_rates_post [:, ii ],
630+ hiatuses_post [:, ii ],
631+ self .units_trim , self .geochron , h ))
613632 else :
614- t_posterior = np .zeros ((n_posterior , len (h )))
615- for ii in tqdm (range (n_posterior ),
616- desc = 'Interpolating heights to ages' ):
617- t_posterior [ii , :] = age (times_post [ii ],
618- self .units_trim , h )
619- return times_post , t_posterior
633+ t_post = Parallel (n_jobs = n_jobs )(delayed (self .fit_absolute_age )(ii ,
634+ sed_rates_post [:, ii ],
635+ hiatuses_post [:, ii ],
636+ self .units_trim ,
637+ self .geochron ,
638+ h ) \
639+ for ii in tqdm (range (n_posterior ),
640+ desc = 'Anchoring floating age models' ))
641+
642+ return t_post
0 commit comments