Skip to content

Commit 6665122

Browse files
committed
parallelization with joblib for AgeModel.trace2ages
also make sure ls_kwargs works for None in AgeModel init
1 parent dd951de commit 6665122

File tree

3 files changed

+60
-35
lines changed

3 files changed

+60
-35
lines changed

docs/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ numpy
88
pymc
99
scipy
1010
tqdm
11-
matplotlib
11+
matplotlib
12+
joblib

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ install_requires =
5555
scipy
5656
tqdm
5757
matplotlib
58+
joblib
5859

5960
[options.packages.find]
6061
where = src

src/stratage/stratage.py

Lines changed: 57 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,16 @@
33
import warnings
44

55
import numpy as np
6-
76
from numba import njit
7+
from joblib import Parallel, delayed
8+
import arviz as az
9+
from scipy.optimize import minimize_scalar, lsq_linear
10+
from tqdm.auto import tqdm
811

912
import pymc as pm
1013
import pytensor.tensor as pt
1114
from pytensor.graph import Apply, Op
1215

13-
import arviz as az
14-
15-
from scipy.optimize import minimize_scalar, lsq_linear
16-
17-
from tqdm import tqdm
18-
1916
from .geochron import Geochron
2017

2118
from stratage import __version__
@@ -425,9 +422,9 @@ def model_ls(units, geochron,
425422

426423
# set defaults for bounds
427424
if sed_rate_bounds is None:
428-
sed_rate_bounds = [1e-1, 1e2]
425+
sed_rate_bounds = [1e-2, 1e2]
429426
if hiatus_bounds is None:
430-
hiatus_bounds = [1e-1, 1e3]
427+
hiatus_bounds = [0, np.inf]
431428

432429
# bounds on model parameters
433430
lower_bounds = np.zeros(n_units+n_contacts)
@@ -473,7 +470,7 @@ def __init__(self, units, geochron, sed_rates_prior, hiatuses_prior, ls_kwargs=N
473470
geochron (geochron.Geochron): Geochron object containing geochron constraints.
474471
sed_rates_prior (function): Prior distribution for sedimentation rates. Must be valid as dist argument to pymc.CustomDist(dist=dist). Signature is sed_rate_prior(size=size).
475472
hiatuses_prior (function): Prior distribution for hiatuses. Must be valid as dist argument to pymc.CustomDist(dist=dist). Signature is hiatus_prior(size=size).
476-
ls_kwargs (dict, optional): Keyword arguments for model_ls. Defaults to None.
473+
ls_kwargs (dict, optional): Keyword arguments for model_ls. Defaults to None. If None, empty dictionary is passed to model_ls.
477474
"""
478475
# assign attributes
479476
self.units = units
@@ -490,6 +487,8 @@ def __init__(self, units, geochron, sed_rates_prior, hiatuses_prior, ls_kwargs=N
490487
# trim the section to the top and bottom of the geochron constraints
491488
self.units_trim = trim_units(self.units, self.geochron.h)
492489
# create least squares model as initial guess
490+
if ls_kwargs is None:
491+
ls_kwargs = {}
493492
self.sed_rates_ls, self.hiatuses_ls = model_ls(self.units, self.geochron, **ls_kwargs)
494493
# create time increment log-like function
495494
loglike_op = loglike_gen(self.geochron, self.units_trim)
@@ -574,13 +573,38 @@ def sample(self, draws=1000, **kwargs):
574573
trace = pm.sample(draws=draws, **kwargs)
575574
return trace
576575

577-
def trace2ages(self, trace, h=None, n_posterior=None):
576+
@staticmethod
577+
def fit_absolute_age(ii, sed_rates_post, hiatuses_post, units_trim, geochron, h):
578+
"""Fit the age model for the ii-th sample of the posterior.
579+
580+
Static method to work with joblib.Parallel.
581+
582+
Args:
583+
ii (int): Index of the posterior sample.
584+
sed_rates_post (ndarray): Sedimentation rates for each unit, ii-th posterior sample.
585+
hiatuses_post (ndarray): Hiatuses between units, ii-th posterior sample.
586+
units_trim (ndarray): Trimmed unit heights after adjusting for the top and bottom units.
587+
geochron (Geochron): Geochron object containing geochron constraints.
588+
h (arraylike): Heights at which to evaluate the age model.
589+
590+
Returns:
591+
ndarray: Ages at the given height(s).
592+
"""
593+
# Logic for fitting the age model for the ii-th sample
594+
cur_times = fit_floating_model(sed_rates_post,
595+
hiatuses_post,
596+
units_trim,
597+
geochron)
598+
return age(cur_times, units_trim, h)
599+
600+
def trace2ages(self, trace, h, n_posterior=None, n_jobs=1):
578601
"""Transform MCMC trace to age models.
579602
580603
Args:
581604
trace (arviz.InferenceData): ArviZ InferenceData object containing the MCMC trace.
582-
h (arraylike, optional): Heights at which to evaluate the age model. Defaults to None. If None, only times arrays are returned.
605+
h (arraylike): Heights at which to evaluate the age model.
583606
n_posterior (int, optional): Number of posterior samples. Defaults to None.
607+
n_jobs (int, optional): Number of parallel jobs. Defaults to 1. If 1, no parallelization is used. Uses joblib.Parallel for parallelization.
584608
585609
Returns:
586610
list: List of age models; each element is a nx2 array of unit bottom and top times for n units.
@@ -593,27 +617,26 @@ def trace2ages(self, trace, h=None, n_posterior=None):
593617
n_posterior = np.min([10000, n_chain*n_draws])
594618
posterior_params = az.extract(trace, num_samples=n_posterior)
595619
# get posterior samples
596-
sed_rates_post = posterior_params.sed_rates.to_numpy().squeeze().T
597-
hiatuses_post = posterior_params.hiatuses.to_numpy().squeeze().T
598-
# get times
599-
times_post = []
620+
sed_rates_post = posterior_params.sed_rates.to_numpy()
621+
hiatuses_post = posterior_params.hiatuses.to_numpy()
622+
# get ages at heights h
623+
t_post = []
624+
600625
# iterate over posterior samples to generate times
601-
for ii in tqdm(range(n_posterior),
602-
desc='Anchoring floating age models'):
603-
# fit floating model
604-
cur_time = fit_floating_model(sed_rates_post[ii],
605-
hiatuses_post[ii],
606-
self.units_trim,
607-
self.geochron)
608-
times_post.append(cur_time)
609-
# if no heights provided, return times arrays only
610-
if h is None:
611-
return times_post
612-
# create age-depth models for heights
626+
if n_jobs == 1:
627+
for ii in tqdm(range(n_posterior),
628+
desc='Anchoring floating age models'):
629+
t_post.append(self.fit_absolute_age(ii, sed_rates_post[:, ii],
630+
hiatuses_post[:, ii],
631+
self.units_trim, self.geochron, h))
613632
else:
614-
t_posterior = np.zeros((n_posterior, len(h)))
615-
for ii in tqdm(range(n_posterior),
616-
desc='Interpolating heights to ages'):
617-
t_posterior[ii, :] = age(times_post[ii],
618-
self.units_trim, h)
619-
return times_post, t_posterior
633+
t_post = Parallel(n_jobs=n_jobs)(delayed(self.fit_absolute_age)(ii,
634+
sed_rates_post[:, ii],
635+
hiatuses_post[:, ii],
636+
self.units_trim,
637+
self.geochron,
638+
h) \
639+
for ii in tqdm(range(n_posterior),
640+
desc='Anchoring floating age models'))
641+
642+
return t_post

0 commit comments

Comments
 (0)