Diagnosing a pathological behavior: solution? #28
Replies: 4 comments 5 replies
Dear @Joshuaalbert,
Hello @Joshuaalbert,

Here is my NumPyro code:

    def model():
        # Cosmological params (test13 JEC)
        Omega_c = numpyro.sample('Omega_c', dist.Uniform(0.1, 0.4))
        sigma8 = numpyro.sample('sigma8', dist.Uniform(0.5, 1.2))
        Omega_b = numpyro.sample('Omega_b', dist.Uniform(0.01, 0.09))
        h = numpyro.sample('h', dist.Uniform(0.4, 1.0))
        n_s = numpyro.sample('n_s', dist.Uniform(0.5, 1.25))
        w0 = numpyro.sample('w0', dist.Uniform(-2.0, -0.0001))
        # Astrophysical params
        A = numpyro.sample('A', dist.Uniform(0., 2.5))
        eta = numpyro.sample('eta', dist.Uniform(0., 6.))
        # parameters for systematics
        m = [numpyro.sample('m%d'%i, dist.Normal(0.012, 0.023))
             for i in range(1,5)]
        dz1 = numpyro.sample('dz1', dist.Normal(0.001, 0.016))
        dz2 = numpyro.sample('dz2', dist.Normal(-0.019, 0.013))
        dz3 = numpyro.sample('dz3', dist.Normal(0.009, 0.011))
        dz4 = numpyro.sample('dz4', dist.Normal(-0.018, 0.022))
        # Now that params are defined, here is the forward model
        cosmo = FiducialCosmo(Omega_c=Omega_c, sigma8=sigma8, Omega_b=Omega_b,
                              h=h, n_s=n_s, w0=w0)
        signal = model_fn(get_params_vec(cosmo, m, [dz1, dz2, dz3, dz4], [A, eta]))
        # And here we define the likelihood
        numpyro.sample('cl_wl', dist.MultivariateNormal(signal, C), obs=cl_obs)

Here is my attempt to set up a JaxNS code:

    #@jax.jit
    #code from https://github.com/google/jax/issues/2314
    def multi_gauss_logpdf(x, mean, cov):
        """ Calculate the log probability density of a
        sample from the multivariate normal. """
        D = mean.shape[0]
        (sign, logdet) = np.linalg.slogdet(cov)
        p1 = D*np.log(2*np.pi) + logdet
        p2 = (x-mean).T @ np.linalg.inv(cov) @ (x-mean)
        return -1./2 * (p1 + p2)

    def solve(cov,cl_obs):
        def log_lik(Omega_c, sigma8, Omega_b,h, n_s, w0, A, eta,
                    m1, m2,m3, m4,
                    dz1, dz2, dz3, dz4,
                    **kwargs):
            cosmo = FiducialCosmo(Omega_c=Omega_c, sigma8=sigma8, Omega_b=Omega_b,
                                  h=h, n_s=n_s, w0=w0)
            signal = model_fn(get_params_vec(cosmo, [m1, m2,m3, m4], [dz1, dz2, dz3, dz4], [A, eta]))
            ### How to do numpyro.sample('cl_wl', dist.MultivariateNormal(signal, cov), obs='cl_obs')
            return multi_gauss_logpdf(cl_obs,signal,cov)

        prior_chain = PriorChain(UniformPrior('Omega_c',0.1, 0.4),
                                 UniformPrior('sigma8',0.5, 1.2),
                                 UniformPrior('Omega_b',0.01, 0.09),
                                 UniformPrior('h',0.4, 1.0),
                                 UniformPrior('n_s',0.5, 1.25),
                                 UniformPrior('w0',-2.0, -0.0001),
                                 UniformPrior('A',0.,2.5),
                                 UniformPrior('eta',0.,6.),
                                 NormalPrior('m1',0.012, 0.023),
                                 NormalPrior('m2',0.012, 0.023),
                                 NormalPrior('m3',0.012, 0.023),
                                 NormalPrior('m4',0.012, 0.023),
                                 NormalPrior('dz1',0.001, 0.016),
                                 NormalPrior('dz2',-0.019, 0.013),
                                 NormalPrior('dz3',0.009, 0.011),
                                 NormalPrior('dz4',-0.018, 0.022)
                                 )
        print('num_live_points:',prior_chain.U_ndims*500)
        ns = jaxns.nested_sampling.NestedSampler(log_lik, prior_chain,
                                                 num_live_points=prior_chain.U_ndims*500)
        print('Go...')
        results = ns(jax.random.PRNGKey(32564))
        return results

    ## Go.
    results = solve(C,data)

Now I try to run this code on GPU...

    2021-10-20 15:11:15.189842: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:461] Allocator (GPU_0_bfc) ran out of memory trying to allocate 933.17GiB (rounded to 1001989427968)requested by op
    The stack trace below excludes JAX-internal frames.
    The above exception was the direct cause of the following exception:
    Traceback (most recent call last):
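As a side note on the hand-written multi_gauss_logpdf in this comment: the same Gaussian log-density can be computed from a Cholesky factor of the covariance, which avoids forming the explicit inverse. The sketch below is only an illustration and not a diagnosis of the out-of-memory error. It uses plain NumPy so it runs anywhere (the same linalg calls exist under jax.numpy), and the 5x5 covariance, mean and sample are made-up test values, not the data from the post.

```python
import numpy as np

def mvn_logpdf_inv(x, mean, cov):
    """Log-density as written in the post: slogdet plus explicit inverse."""
    d = mean.shape[0]
    _, logdet = np.linalg.slogdet(cov)
    r = x - mean
    return -0.5 * (d * np.log(2 * np.pi) + logdet + r @ np.linalg.inv(cov) @ r)

def mvn_logpdf_chol(x, mean, cov):
    """Same log-density via a Cholesky factor, without inverting cov."""
    d = mean.shape[0]
    chol = np.linalg.cholesky(cov)             # cov = chol @ chol.T
    z = np.linalg.solve(chol, x - mean)        # z = chol^{-1} (x - mean)
    logdet = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (d * np.log(2 * np.pi) + logdet + z @ z)

# Hypothetical test inputs (NOT the covariance C or data from the post).
rng = np.random.default_rng(0)
a = rng.normal(size=(5, 5))
cov = a @ a.T + 5.0 * np.eye(5)                # symmetric positive definite
mean = rng.normal(size=5)
x = rng.normal(size=5)

lp_inv = mvn_logpdf_inv(x, mean, cov)
lp_chol = mvn_logpdf_chol(x, mean, cov)
print(lp_inv, lp_chol)                         # the two values should agree
```

Since the covariance is fixed across likelihood calls in solve, the Cholesky factor and log-determinant could in principle be computed once outside log_lik and reused; whether that matters in practice depends on the size of C.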
@jecampagne does this example require any data? If not, I could add it to the examples section.
Dear experts,

I posted this message on the NumPyro forum, and they invited me to ask the question here, as their NestedSampler class is a wrapper around JaxNS (see here). I must also confess that I am not an experienced user of nested sampling, so my question may not be well formulated.

So, I have a model which I sample with the NUTS sampler. Schematically, I do the following (sorry, this is in the NumPyro language):

I use numpyro.sample statements to define the priors (which, for this example, are all Gaussian distributions) and the likelihood.

I use fix_cond_model = numpyro.handlers.condition(model, <parameters default values>) to generate some data thanks to:

Then, I proceed to the MCMC run and finally get the samples:

So far so good. Now I wonder whether I can use the NestedSampler? I have tried:

From the point of view of the function calls, it seems OK according to the NumPyro developers, but the sampling of the variables is clearly pathological.

And here are the arviz KDE plots:

I am certainly missing something. Any ideas are welcome. Thanks