Skip to content

Fix err dist none #909

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changes/909.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix ligthcurve breaking when err_dist=None and changed default err_dist to None to match docs
10 changes: 5 additions & 5 deletions docs/simulator.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Simulate Method
Stingray provides multiple ways to simulate a light curve. However, all these methods follow a common recipe::

>>> sim = simulator.Simulator(N=1024, mean=0.5, dt=0.125, rms=1.0)
>>> lc = sim.simulate(2)
>>> lc = sim.simulate(2) # doctest: +IGNORE_WARNINGS

Using Power-Law Spectrum
------------------------
Expand Down Expand Up @@ -154,7 +154,7 @@ Channel Simulation
The `simulator` class provides the functionality to simulate light curves independently for each channel. This is useful, for example, when dealing with energy dependent impulse responses where we can create a di↵erent simulation channel for each energy range. The module provides options to count, retrieve and delete channels.::

>>> sim = simulator.Simulator(N=1024, mean=0.5, dt=0.125, rms=1.0)
>>> sim.simulate_channel('3.5 - 4.5', 2)
>>> sim.simulate_channel('3.5 - 4.5', 2) # doctest: +IGNORE_WARNINGS
>>> sim.count_channels()
1
>>> lc = sim.get_channel('3.5 - 4.5')
Expand All @@ -164,9 +164,9 @@ Alternatively, assume that we have light curves in the simulated energy channels

>>> sim.count_channels()
0
>>> sim.simulate_channel('3.5 - 4.5', 2)
>>> sim.simulate_channel('4.5 - 5.5', 2)
>>> sim.simulate_channel('5.5 - 6.5', 2)
>>> sim.simulate_channel('3.5 - 4.5', 2) # doctest: +IGNORE_WARNINGS
>>> sim.simulate_channel('4.5 - 5.5', 2) # doctest: +IGNORE_WARNINGS
>>> sim.simulate_channel('5.5 - 6.5', 2) # doctest: +IGNORE_WARNINGS
>>> chans = sim.get_channels(['3.5 - 4.5','4.5 - 5.5','5.5 - 6.5'])
>>> sim.delete_channels(['3.5 - 4.5','4.5 - 5.5','5.5 - 6.5'])

Expand Down
15 changes: 9 additions & 6 deletions stingray/lightcurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

__all__ = ["Lightcurve"]

valid_statistics = ["poisson", "gauss", None]
valid_statistics = ["poisson", "gauss", "none"]

logger = setup_logger()

Expand Down Expand Up @@ -215,7 +215,7 @@ def __init__(
err=None,
input_counts=True,
gti=None,
err_dist="poisson",
err_dist=None,
bg_counts=None,
bg_ratio=None,
frac_exp=None,
Expand Down Expand Up @@ -275,6 +275,11 @@ def __init__(
if not skip_checks:
time, counts, err = self.initial_optional_checks(time, counts, err, gti=gti)

if err_dist is None and input_counts:
err_dist = "poisson"
elif err_dist is None:
err_dist = "none"

if err_dist.lower() not in valid_statistics:
# err_dist set can be increased with other statistics
raise StingrayError(
Expand All @@ -283,10 +288,8 @@ def __init__(
)
elif not err_dist.lower() == "poisson":
simon(
"Stingray only uses poisson err_dist at the moment. "
"All analysis in the light curve will assume Poisson "
"errors. "
"Sorry for the inconvenience."
"Beware! Stingray only supports poisson err_dist at the moment in many methods"
", and 'gauss' in a few more. "
)

self._time = time
Expand Down
2 changes: 1 addition & 1 deletion stingray/sampledata.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ def sample_data():
counts = data[0 : len(data), 1]

# Return class:`Lightcurve` object
return lightcurve.Lightcurve(dates, counts, dt=dt, skip_checks=True)
return lightcurve.Lightcurve(dates, counts, dt=dt, skip_checks=True, err_dist="poisson")
161 changes: 98 additions & 63 deletions stingray/simulator/tests/test_simulator.py

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions stingray/tests/test_crosscorrelation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def setup_class(cls):
freq = 1 / 50
flux1 = 0.5 + 0.5 * np.sin(2 * np.pi * freq * times)
flux2 = 0.5 + 0.5 * np.sin(2 * np.pi * freq * (times - 20))

cls.lc1 = Lightcurve(times, flux1, dt=dt, err_dist="gauss", gti=gti, skip_checks=True)
cls.lc2 = Lightcurve(times, flux2, dt=dt, err_dist="gauss", gti=gti, skip_checks=True)
with pytest.warns(UserWarning, match="Beware! Stingray only supports poisson err_dist at the moment in many methods, and 'gauss' in a few more."):
cls.lc1 = Lightcurve(times, flux1, dt=dt, err_dist="gauss", gti=gti, skip_checks=True)
cls.lc2 = Lightcurve(times, flux2, dt=dt, err_dist="gauss", gti=gti, skip_checks=True)

def test_crosscorr(self):
cr = CrossCorrelation(self.lc1, self.lc2)
Expand Down
22 changes: 12 additions & 10 deletions stingray/tests/test_crossspectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,14 +505,15 @@ def setup_class(self):
counts1 = np.random.poisson(10000, size=time.shape[0])
counts1_norm = counts1 / 13.4
counts1_norm_err = np.std(counts1) / 13.4
self.lc1_norm = Lightcurve(
time,
counts1_norm,
gti=[[tstart, self.tseg]],
dt=dt,
err_dist="gauss",
err=np.zeros_like(counts1_norm) + counts1_norm_err,
)
with pytest.warns(UserWarning, match="Beware! Stingray only supports poisson err_dist at the moment in many methods, and 'gauss' in a few more."):
self.lc1_norm = Lightcurve(
time,
counts1_norm,
gti=[[tstart, self.tseg]],
dt=dt,
err_dist="gauss",
err=np.zeros_like(counts1_norm) + counts1_norm_err,
)
self.lc1 = Lightcurve(time, counts1, gti=[[tstart, self.tseg]], dt=dt)
self.rate1 = np.mean(counts1) / dt # mean count rate (counts/sec) of light curve 1

Expand Down Expand Up @@ -1217,7 +1218,8 @@ def test_rebin_log_returns_complex_values_and_errors(self):
def test_timelag(self):
dt = 0.1
simulator = Simulator(dt, 10000, rms=0.2, mean=1000)
test_lc1 = simulator.simulate(2)
with pytest.warns(UserWarning, match="Beware! Stingray only supports poisson err_dist at the moment in many methods, and 'gauss' in a few more."):
test_lc1 = simulator.simulate(2)
test_lc1.counts -= np.min(test_lc1.counts)

with pytest.warns(UserWarning):
Expand Down Expand Up @@ -1417,7 +1419,7 @@ def setup_class(cls):
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)

lc = Lightcurve(timestamps, signal, err_dist="poisson", dt=dt, gti=[[0, 100]])
lc = Lightcurve(timestamps, signal, dt=dt, gti=[[0, 100]])

cls.lc = lc

Expand Down
20 changes: 11 additions & 9 deletions stingray/tests/test_lightcurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,15 +391,15 @@ def test_irregular_time_warning(self):
)

with pytest.warns(UserWarning, match=warn_str):
_ = Lightcurve(times, counts, err_dist="poisson")
_ = Lightcurve(times, counts)

def test_unrecognize_err_dist_warning(self):
"""
Check if a non-poisson error_dist throws the correct warning.
"""
times = [1, 2, 3, 4, 5]
counts = [2, 2, 2, 2, 2]
warn_str = "SIMON says: Stingray only uses poisson err_dist at " "the moment"
warn_str = "SIMON says: Beware! Stingray only supports poisson err_dist at the moment in many methods, and 'gauss' in a few more."

with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings("always")
Expand Down Expand Up @@ -554,7 +554,7 @@ def test_input_countrate(self):
mean_counts = 2.0
times = np.arange(0 + dt / 2, 5 - dt / 2, dt)
countrate = np.zeros_like(times) + mean_counts
lc = Lightcurve(times, countrate, input_counts=False)
lc = Lightcurve(times, countrate, input_counts=False, err_dist="poisson")
assert np.allclose(lc.counts, np.zeros_like(countrate) + mean_counts * dt)

def test_meanrate(self):
Expand Down Expand Up @@ -726,7 +726,7 @@ def test_slicing(self):
dt=self.dt,
gti=self.gti,
err=self.counts / 10,
err_dist="gauss",
err_dist="poisson",
)
assert np.allclose(lc[1:3].counts, np.array([4, 6]))
assert np.allclose(lc[:2].counts, np.array([2, 4]))
Expand Down Expand Up @@ -1326,14 +1326,14 @@ def test_split_lc_by_gtis_minpoints(self):
def test_shift(self):
times = [1, 2, 3, 4, 5, 6, 7, 8]
counts = [1, 1, 1, 1, 2, 3, 3, 2]
lc = Lightcurve(times, counts, input_counts=True)
lc = Lightcurve(times, counts, input_counts=True, err_dist="poisson")
lc2 = lc.shift(1)
assert np.allclose(lc2.time - 1, times)
lc2 = lc.shift(-1)
assert np.allclose(lc2.time + 1, times)
assert np.allclose(lc2.counts, lc.counts)
assert np.allclose(lc2.countrate, lc.countrate)
lc = Lightcurve(times, counts, input_counts=False)
lc = Lightcurve(times, counts, input_counts=False, err_dist="poisson")
lc2 = lc.shift(1)
assert np.allclose(lc2.counts, lc.counts)
assert np.allclose(lc2.countrate, lc.countrate)
Expand All @@ -1350,6 +1350,7 @@ def test_table_roundtrip(self):
mission="BUBU",
instr="BABA",
mjdref=53467.0,
err_dist="poisson"
)

ts = lc.to_astropy_table()
Expand Down Expand Up @@ -1378,6 +1379,7 @@ def test_table_roundtrip_ctrate(self):
instr="BABA",
mjdref=53467.0,
input_counts=False,
err_dist="poisson"
)

ts = lc.to_astropy_table()
Expand All @@ -1393,7 +1395,7 @@ def test_timeseries_roundtrip(self):
wrong format is provided.
"""
N = len(self.times)
lc = Lightcurve(self.times, self.counts, mission="BUBU", instr="BABA", mjdref=53467.0)
lc = Lightcurve(self.times, self.counts, mission="BUBU", instr="BABA", mjdref=53467.0, err_dist="poisson")

ts = lc.to_astropy_timeseries()
new_lc = lc.from_astropy_timeseries(ts)
Expand All @@ -1413,7 +1415,7 @@ def test_timeseries_roundtrip_ctrate(self):
countrate = np.zeros_like(times) + mean_counts

lc = Lightcurve(
times, countrate, mission="BUBU", instr="BABA", mjdref=53467.0, input_counts=False
times, countrate, mission="BUBU", instr="BABA", mjdref=53467.0, input_counts=False, err_dist="poisson"
)

ts = lc.to_astropy_timeseries()
Expand Down Expand Up @@ -1558,7 +1560,7 @@ def test_apply_gtis_lc_rate(self, inplace):
time = np.arange(1, 10, dt)
countrate = np.zeros_like(time) + 5
# create the lightcurve from countrare
lc_rate = Lightcurve(time, counts=countrate, input_counts=False, gti=[[-0.5, 10.5]])
lc_rate = Lightcurve(time, counts=countrate, input_counts=False, gti=[[-0.5, 10.5]], err_dist="poisson")
lc_rate.gti = [[-0.5, 2.5]]
lc_rate_new = lc_rate.apply_gtis(inplace=inplace)
if inplace:
Expand Down
12 changes: 8 additions & 4 deletions stingray/tests/test_lombscargle.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ def test_autofrequency():
class TestLombScargleCrossspectrum:
def setup_class(self):
sim = Simulator(0.0001, 50, 100, 1, random_state=42, tstart=0)
lc1 = sim.simulate(0)
lc2 = sim.simulate(0)
with pytest.warns(UserWarning, match="Beware! Stingray only supports poisson err_dist at the moment in many methods, and 'gauss' in a few more."):
lc1 = sim.simulate(0)
lc2 = sim.simulate(0)
self.rate1 = lc1.countrate
self.rate2 = lc2.countrate
low, high = lc1.time.min(), lc1.time.max()
Expand Down Expand Up @@ -147,7 +148,9 @@ def test_init_with_negative_max_freq(self):
lscs = LombScargleCrossspectrum(self.lc1, self.lc2, max_freq=-1)

def test_make_crossspectrum_diff_lc_counts_shape(self):
lc_ = Simulator(0.0001, 103, 100, 1, random_state=42, tstart=0).simulate(0)
sim = Simulator(0.0001, 103, 100, 1, random_state=42, tstart=0)
with pytest.warns(UserWarning, match="Beware! Stingray only supports poisson err_dist at the moment in many methods, and 'gauss' in a few more."):
lc_ = sim.simulate(0)
with pytest.warns(UserWarning) as record:
lscs = LombScargleCrossspectrum(self.lc1, lc_)
assert np.any(["different statistics" in r.message.args[0] for r in record])
Expand Down Expand Up @@ -246,7 +249,8 @@ def func(time, phase=0):
class TestLombScarglePowerspectrum:
def setup_class(self):
sim = Simulator(0.0001, 100, 100, 1, random_state=42, tstart=0)
lc = sim.simulate(0)
with pytest.warns(UserWarning, match="Beware! Stingray only supports poisson err_dist at the moment in many methods, and 'gauss' in a few more."):
lc = sim.simulate(0)
self.rate = lc.countrate
low, high = lc.time.min(), lc.time.max()
s1 = lc.counts
Expand Down
15 changes: 9 additions & 6 deletions stingray/tests/test_varenergyspectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def setup_class(cls):
from ..simulator import Simulator

simulator = Simulator(0.1, 10000, rms=0.2, mean=200)
test_lc = simulator.simulate(1)
with pytest.warns(UserWarning, match="Beware! Stingray only supports poisson err_dist at the moment in many methods, and 'gauss' in a few more."):
test_lc = simulator.simulate(1)
cls.test_ev1, cls.test_ev2 = EventList(), EventList()
cls.test_ev1.simulate_times(test_lc)
cls.test_ev1.energy = np.random.uniform(0.3, 12, len(cls.test_ev1.time))
Expand Down Expand Up @@ -190,9 +191,10 @@ def setup_class(cls):
flux = data / 40
times = np.arange(data.size) * cls.bin_time
gti = np.asanyarray([[0, data.size * cls.bin_time]])
test_lc = Lightcurve(
times, flux, err_dist="gauss", gti=gti, dt=cls.bin_time, skip_checks=True
)
with pytest.warns(UserWarning, match="Beware! Stingray only supports poisson err_dist at the moment in many methods, and 'gauss' in a few more."):
test_lc = Lightcurve(
times, flux, err_dist="gauss", gti=gti, dt=cls.bin_time, skip_checks=True
)

cls.test_ev1, cls.test_ev2 = EventList(), EventList()
cls.test_ev1.simulate_times(test_lc)
Expand Down Expand Up @@ -551,8 +553,9 @@ def setup_class(cls):
times, flux, rolled_flux = times[good], flux[good], rolled_flux[good]

length = times[-1] - times[0]
test_ref = Lightcurve(times, flux, err_dist="gauss", dt=dt, skip_checks=True)
test_sub = Lightcurve(test_ref.time, rolled_flux, err_dist=test_ref.err_dist, dt=dt)
with pytest.warns(UserWarning, match="Beware! Stingray only supports poisson err_dist at the moment in many methods, and 'gauss' in a few more."):
test_ref = Lightcurve(times, flux, err_dist="gauss", dt=dt, skip_checks=True)
test_sub = Lightcurve(test_ref.time, rolled_flux, err_dist=test_ref.err_dist, dt=dt)
test_ref_ev, test_sub_ev = EventList(), EventList()
test_ref_ev.simulate_times(test_ref)
test_sub_ev.simulate_times(test_sub)
Expand Down