-
Notifications
You must be signed in to change notification settings - Fork 38
Description
Describe the bug
I ran into a problem when sampling a starry model on multiple cores while the orbit is variable (i.e. the orbital parameters are free). Because of this, I was running everything more slowly on a single core for a while. The error traceback I got was:
Error traceback
---------------------------------------------------------------------------
EOFError                                  Traceback (most recent call last)
Input In [24], in <cell line: 1>()
      1 with model:
      2     #trace = pmx.sample(
----> 3     trace = pm.sample(
      4         tune=250,
      5         draws=500,
      6         start=map_soln,
      7         chains=4,
      8         cores=4,
      9         target_accept=0.9,
     10     )
File ~/miniconda3/envs/ABATE/lib/python3.9/site-packages/pymc3/sampling.py:559, in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
557 _print_step_hierarchy(step)
558 try:
--> 559 trace = _mp_sample(**sample_args, **parallel_args)
560 except pickle.PickleError:
561 _log.warning("Could not pickle model, sampling singlethreaded.")
File ~/miniconda3/envs/ABATE/lib/python3.9/site-packages/pymc3/sampling.py:1477, in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, pickle_backend, **kwargs)
1475 try:
1476 with sampler:
-> 1477 for draw in sampler:
1478 trace = traces[draw.chain - chain]
1479 if trace.supports_sampler_stats and draw.stats is not None:
File ~/miniconda3/envs/ABATE/lib/python3.9/site-packages/pymc3/parallel_sampling.py:479, in ParallelSampler.iter(self)
476 self._progress.update(self._total_draws)
478 while self._active:
--> 479 draw = ProcessAdapter.recv_draw(self._active)
480 proc, is_last, draw, tuning, stats, warns = draw
481 self._total_draws += 1
File ~/miniconda3/envs/ABATE/lib/python3.9/site-packages/pymc3/parallel_sampling.py:351, in ProcessAdapter.recv_draw(processes, timeout)
349 idxs = {id(proc._msg_pipe): proc for proc in processes}
350 proc = idxs[id(ready[0])]
--> 351 msg = ready[0].recv()
353 if msg[0] == "error":
354 warns, old_error = msg[1:]
File ~/miniconda3/envs/ABATE/lib/python3.9/multiprocessing/connection.py:255, in _ConnectionBase.recv(self)
253 self._check_closed()
254 self._check_readable()
--> 255 buf = self._recv_bytes()
256 return _ForkingPickler.loads(buf.getbuffer())
File ~/miniconda3/envs/ABATE/lib/python3.9/multiprocessing/connection.py:419, in Connection._recv_bytes(self, maxsize)
418 def _recv_bytes(self, maxsize=None):
--> 419 buf = self._recv(4)
420 size, = struct.unpack("!i", buf.getvalue())
421 if size == -1:
File ~/miniconda3/envs/ABATE/lib/python3.9/multiprocessing/connection.py:388, in Connection._recv(self, size, read)
386 if n == 0:
387 if remaining == size:
--> 388 raise EOFError
389 else:
390 raise OSError("got end of file during message")
EOFError:
To Reproduce
A minimal(-ish) example, adapted from the "Hot Jupiter phase curve" example:
import starry
import matplotlib.pyplot as plt
import numpy as np
import pymc3 as pm
import pymc3_ext as pmx
import exoplanet
# Presumably silences starry's informational console output -- confirm in starry docs.
starry.config.quiet = True
# Seed NumPy's global RNG so the synthetic noise drawn with np.random.randn
# further down is reproducible from run to run.
np.random.seed(1)
# In[7]:
# Ground-truth star: a ydeg=0 (uniform) map with a degree-2 limb-darkening
# law (udeg=2), and unit mass, radius and rotation period.
A = starry.Primary(starry.Map(ydeg=0, udeg=2, amp=1.0), m=1.0, r=1.0, prot=1.0)
# Presumably the u1 and u2 limb-darkening coefficients (udeg=2) -- confirm in starry docs.
A.map[1] = 0.4
A.map[2] = 0.2
# In[8]:
# These are the parameters we're going to try to infer
# (note: the PyMC3 model below only fits log_amp and the orbital period;
# the offset is held fixed at offset_true there).
log_amp_true = -3.0
offset_true = 30.0
# Ground-truth planet: a degree-1 map whose amplitude is 10**log_amp_true,
# massless (m=0.0), radius 0.1, inclination 90 deg, and with its rotation
# period equal to its orbital period (both 1.0) -- presumably tidally locked.
b = starry.Secondary(
starry.Map(ydeg=1, udeg=0, amp=10 ** log_amp_true, inc=90.0, obl=0.0),
m=0.0,
r=0.1,
inc=90.0,
prot=1.0,
porb=1.0,
)
# Presumably the Y_{1,0} spherical-harmonic coefficient -- confirm in starry docs.
b.map[1, 0] = 0.5
# Shift the map's reference rotational phase by the hot-spot offset
# (assumed to be in degrees -- verify against starry docs).
b.theta0 = 180.0 + offset_true
# In[9]:
# Combine star and planet into one system. NOTE: the name `sys` shadows the
# standard-library module of the same name (which this script never imports).
sys = starry.System(A, b)
# In[10]:
# Time grid in days (per the axis label below): 1000 points spanning 1.6 days,
# i.e. about 1.6 orbits for the true porb=1.0.
t = np.linspace(-0.3, 1.3, 1000)
# `.eval()` presumably turns the symbolic (Theano) flux tensor into a NumPy array.
flux_true = sys.flux(t).eval()
# Per-point Gaussian noise level; the model's likelihood reuses it as `sd`.
ferr = 1e-4
flux = flux_true + ferr * np.random.randn(len(t))
# Plot the noisy synthetic data (black points) over the noise-free light curve.
plt.figure(figsize=(12, 5))
plt.plot(t, flux, "k.", alpha=0.3, ms=3)
plt.plot(t, flux_true)
plt.xlabel("Time [days]", fontsize=24)
plt.ylabel("Flux [normalized]", fontsize=24);
# In[21]:
with pm.Model() as model:
    # These are the variables we're solving for: the planet's log amplitude
    # (wide Gaussian prior) and its orbital period (narrow Gaussian prior
    # centred on the true value). Leaving the orbit free is the configuration
    # that triggers the multi-core EOFError described in this report.
    # Alternative: also fit the hot-spot offset instead of fixing it.
    #offset = pm.Normal("offset", 0.0, 50.0, testval=0.11)
    offset = offset_true
    log_amp = pm.Normal("log_amp", -4.0, 2.0, testval=-3.91)
    porb = pm.Normal("porb", mu=1.0, sigma=0.02)
    # Fixing the period instead (use this line and drop the one above)
    # is the case that samples fine on multiple cores:
    #porb = 1.0
    # Instantiate the star; all its parameters are assumed
    # to be known exactly
    A = starry.Primary(
        starry.Map(ydeg=0, udeg=2, amp=1.0, inc=90.0, obl=0.0), m=1.0, r=1.0, prot=1.0
    )
    A.map[1] = 0.4
    A.map[2] = 0.2
    # Instantiate the planet. Everything is fixed except for its luminosity
    # (through log_amp) and its orbital period (porb); the hot-spot offset
    # is held at its true value.
    b = starry.Secondary(
        starry.Map(ydeg=1, udeg=0, amp=10 ** log_amp, inc=90.0, obl=0.0),
        m=0.0,
        r=0.1,
        prot=1.0,
        porb=porb,
    )
    b.map[1, 0] = 0.5
    b.theta0 = 180.0 + offset
    # Instantiate the system as before
    sys = starry.System(A, b)
    # Our model for the flux, recorded in the trace as "flux_model"
    flux_model = pm.Deterministic("flux_model", sys.flux(t))
    # This is how we tell `pymc3` about our observations;
    # we are assuming they are normally distributed about
    # the true model. This line effectively defines our
    # likelihood function.
    pm.Normal("obs", flux_model, sd=ferr, observed=flux)
# In[22]:
with model:
    # Find the maximum a posteriori point. It is plotted against the data
    # and used as the sampler's starting point further down.
    map_soln = pmx.optimize()
# In[23]:
# Visual check of the fit: overlay the MAP light curve (the "flux_model"
# deterministic evaluated at map_soln) on the noisy data.
plt.figure(figsize=(12, 5))
plt.plot(t, flux, "k.", alpha=0.3, ms=3)
plt.plot(t, map_soln["flux_model"])
plt.xlabel("Time [days]", fontsize=24)
plt.ylabel("Flux [normalized]", fontsize=24);
# In[24]:
with model:
    # Draw 4 chains in 4 worker processes, starting from the MAP point.
    # With cores > 1 and a free orbital period, this call fails with the
    # EOFError shown in the traceback above. Setting cores=1 avoids the
    # worker processes and (per this report) works, only more slowly.
    # NOTE(review): `pm.sample` also accepts `mp_ctx` (it appears in the
    # traceback's signature) to pick the multiprocessing start method --
    # worth varying when debugging this.
    trace = pm.sample(
        tune=250,
        draws=500,
        start=map_soln,
        chains=4,
        cores=4,
        target_accept=0.9,
    )
Expected behavior
`pm.sample` should sample the posterior and return a trace object. Instead, it raises the `EOFError` shown above.
Your setup (please complete the following information):
- Version of starry: 1.2.0
- Operating system: Mac OS X 12.0.1
- Python version & installation method (pip, conda, etc.): Python 3.9.12, pip 21.2.4, conda 4.10.1 (Apple M1 Pro chip)
Additional context
This happened to me in a fairly specific set of circumstances:
- when the number of cores > 1
- when the orbital parameters are variable
However, this is exactly the set of circumstances in which I primarily use starry.
I have found a solution/workaround and wanted to share it here in case anyone else runs into this `EOFError`.