Skip to content
Merged
29 changes: 16 additions & 13 deletions src/pst/SSP.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from astropy import units as u
from astropy import units

from pst.utils import check_unit
from pst.utils import check_unit, flux_conserving_interpolation

class SSPBase(object):
"""Base class that represents a model of Simple Stellar Populations.
Expand Down Expand Up @@ -223,32 +223,35 @@ def cut_wavelength(self, wl_min=None, wl_max=None, verbose=True):
if verbose:
print('[SSP] Models cut between {} {}'.format(wl_min, wl_max))

def interpolate_sed(self, new_wl_edges, verbose=True):
def interpolate_sed(self, new_wl, verbose=True, log=False, **interp_kwargs):
"""Flux-conserving interpolation.

Parameters
----------
- new_wl_edges: bin edges of the new interpolated points.
- new_wl: bin centers of the new interpolated points.
"""
if not isinstance(new_wl_edges, units.Quantity):
new_wl_edges *= self.wavelength.unit
if not isinstance(new_wl, units.Quantity):
new_wl = new_wl << self.wavelength.unit

new_wl = (new_wl_edges[1:] + new_wl_edges[:-1]) / 2
dwl = np.diff(new_wl_edges)
ori_dwl = np.hstack((np.diff(self.wavelength),
self.wavelength[-1] - self.wavelength[-2]))
if verbose:
print('[SSP] Interpolating SSP SEDs')

if log:
target_wl = np.log(new_wl.to_value(self.wavelength.unit))
ref_wl = np.log(self.wavelength.value)
else:
target_wl = new_wl
ref_wl = self.wavelength

new_l_lambda = np.empty(
shape=(self.metallicities.size, self.ages.size,
new_wl.size), dtype=np.float32) * self.L_lambda.unit

for i in range(self.L_lambda.shape[0]):
for j in range(self.L_lambda.shape[1]):
f = np.interp(new_wl_edges, self.wavelength,
np.cumsum(self.L_lambda[i, j] * ori_dwl))
new_flux = np.diff(f) / dwl
new_l_lambda[i, j] = new_flux
new_l_lambda[i, j] = flux_conserving_interpolation(
target_wl, ref_wl, self.L_lambda[i, j],
**interp_kwargs)

self.L_lambda = new_l_lambda
self.wavelength = new_wl
Expand Down
6 changes: 3 additions & 3 deletions src/pst/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ def interpolate_ssp_masses(self, ssp: SSPBase, t_obs: u.Gyr,
Stellar masses corresponding to each SSP age and metallicity, in units
of solar masses.
"""

# define age bins from 0 to t_obs
age_bins = np.hstack(
[0 << u.yr, np.sqrt(ssp.ages[1:] * ssp.ages[:-1]), 1e12 << u.yr])
Expand Down Expand Up @@ -645,11 +644,12 @@ def stellar_mass_formed(self, times: u.Gyr):
"""
interpolator = interpolate.PchipInterpolator(
self.table_t, self.table_mass)
integral = interpolator(times) << self.table_mass.unit
integral = interpolator(times.to_value(self.table_t.unit)
) << self.table_mass.unit
integral[times > self.table_t[-1]] = self.table_mass[-1]
integral[times < self.table_t[0]] = 0
return integral

@u.quantity_input
def ism_metallicity(self, times: u.Gyr):
"""Evaluate the integral of the SFR over a given set of times.
Expand Down
Loading