Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions src/pst/observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,12 +291,6 @@ def interpolate(self, wavelength=None):
self.wavelength= wavelength
return self.response

def _check_spectra(self, spectra, default_unit=u.Lsun / u.angstrom / u.cm**2):
if spectra is not None and not isinstance(spectra, u.Quantity):
return spectra * default_unit
else:
return spectra

def get_photons(self, spectra, spectra_err=None, mask_nan=True):
r"""Compute the photon flux from an input spectra.

Expand Down Expand Up @@ -325,8 +319,9 @@ def get_photons(self, spectra, spectra_err=None, mask_nan=True):
photon_flux_err : :class:``astropy.units.Quantity``
Filter photon flux associated error.
"""
spectra = self._check_spectra(spectra)
spectra_err = self._check_spectra(spectra_err)
spectra = utils.check_unit(spectra, default_unit=u.Lsun / u.angstrom / u.cm**2,
equivalence=u.spectral_density, wav=self.wavelength)

if mask_nan:
mask = np.isfinite(spectra)
photon_flux = np.trapz(
Expand All @@ -340,6 +335,10 @@ def get_photons(self, spectra, spectra_err=None, mask_nan=True):
x=self.wavelength)

if spectra_err is not None:

spectra_err = utils.check_unit(spectra_err,
default_unit=u.Lsun / u.angstrom / u.cm**2,
equivalence=u.spectral_density, wav=self.wavelength)
if mask_nan:
mask = mask & np.isfinite(spectra_err)
else:
Expand Down Expand Up @@ -455,8 +454,9 @@ def get_flambda_vegamag(self, spectra, spectra_err=None, mask_nan=True):
--------
:func:`get_photons`
"""
spectra = self._check_spectra(spectra)
spectra_err = self._check_spectra(spectra_err)
spectra = spectra = utils.check_unit(spectra, default_unit=u.Lsun / u.angstrom / u.cm**2,
equivalence=u.spectral_density, wav=self.wavelength)

if mask_nan:
mask = np.isfinite(spectra)
else:
Expand All @@ -466,6 +466,9 @@ def get_flambda_vegamag(self, spectra, spectra_err=None, mask_nan=True):
) / np.trapz(self.response[mask] * self.wavelength[mask], x=self.wavelength[mask])

if spectra_err is not None:

spectra_err = utils.check_unit(spectra_err, default_unit=u.Lsun / u.angstrom / u.cm**2,
equivalence=u.spectral_density, wav=self.wavelength)
if mask_nan:
mask = mask & np.isfinite(spectra_err)
else:
Expand Down
15 changes: 10 additions & 5 deletions src/pst/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def gaussian1d_conv(f, sigma, deltax):
f_convolved[pixel] = np.sum(f * g)
return f_convolved

def check_unit(quantity, default_unit=None):
def check_unit(quantity, default_unit=None, equivalence=None, **equiv_kwargs):
"""Check the units of an input quantity.

Parameters
Expand All @@ -81,12 +81,17 @@ def check_unit(quantity, default_unit=None):
isq = isinstance(quantity, u.Quantity)
if isq and default_unit is not None:
if not quantity.unit.is_equivalent(default_unit):
raise u.UnitTypeError(
"Input quantity does not have the appropriate units")
if equivalence is not None:
return quantity.to(default_unit,
equivalencies=equivalence(**equiv_kwargs))
else:
raise u.UnitTypeError(
"Input quantity does not have the appropriate units")
else:
return quantity
return quantity.to(default_unit)

elif not isq and default_unit is not None:
return quantity * default_unit
return quantity << default_unit
elif not isq and default_unit is None:
raise ValueError("Input value must be a astropy.units.Quantity")
else:
Expand Down
20 changes: 16 additions & 4 deletions tests/test_observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ def setUpClass(self):
print("Setting SSP model")
self.dummy_wavelength = np.geomspace(100, 1e5, 3000) * u.angstrom
# Monocromatic SED
self.dummy_spectra = np.ones(self.dummy_wavelength.size
self.dummy_flam = np.ones(self.dummy_wavelength.size
) * constants.c / self.dummy_wavelength**2 * 3631 * u.Jy
self.dummy_fnu = np.ones(self.dummy_wavelength.size
) * 3631 * u.Jy

def test_default_dir(self):
self.assertTrue(
Expand Down Expand Up @@ -52,20 +54,30 @@ def test_filter(self):
0.32484839450189695),
"Unexpected effective transmission value")

# Interpolate filter to input wavelength array
filter.interpolate(self.dummy_wavelength)
flux, _ = filter.get_fnu(self.dummy_spectra)
# Use flam
flux, _ = filter.get_fnu(self.dummy_flam)
self.assertTrue(np.isclose(flux, 3631.0 * u.Jy),
f"Unexpected integrated flux value: {flux}")

mag, _ = filter.get_ab(self.dummy_spectra)
mag, _ = filter.get_ab(self.dummy_flam)
self.assertTrue(np.isclose(mag, 0.0, atol=1e-4),
f"Unexpected magnitude value: {mag}")
# Use fnu
flux, _ = filter.get_fnu(self.dummy_fnu)
self.assertTrue(np.isclose(flux, 3631.0 * u.Jy),
f"Unexpected integrated flux value: {flux}")

mag, _ = filter.get_ab(self.dummy_fnu)
self.assertTrue(np.isclose(mag, 0.0, atol=1e-4),
f"Unexpected magnitude value: {mag}")

fig = filter.plot(show=False)

def test_equivalent_width(self):
eqwidth = observables.EquivalentWidth.from_name("lick_ha")
ew, ew_err = eqwidth.compute_ew(self.dummy_wavelength, self.dummy_spectra)
ew, ew_err = eqwidth.compute_ew(self.dummy_wavelength, self.dummy_flam)
self.assertTrue(np.isfinite(ew), "Unexpected EW value")


Expand Down