Skip to content

Commit 60f9558

Browse files
committed
Generalization of spec_mu
1 parent d6bbe81 commit 60f9558

2 files changed

Lines changed: 53 additions & 24 deletions

File tree

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
SOAP/__pycache__
2-
docs/notebooks/*.npz
2+
docs/notebooks/*.npz
3+
data/IAGatlas

SOAP/classes.py

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
from astropy.io import fits
99
from expecto import get_spectrum
1010
from scipy.interpolate import RegularGridInterpolator
11+
import numba
1112

1213
from .fast_starspot import doppler_shift
1314
from .gaussian import ip_convolution
1415
from .units import U, has_unit, kms, maybe_quantity_input, unit_arange
1516
from .utils import c, read_rdb, sqrt2pi
16-
17+
from numba import float64, int32
18+
from numba.experimental import jitclass
1719

1820
def set_object_attributes(object, attrs):
1921
"""
@@ -618,9 +620,6 @@ def __repr__(self):
618620
pars = f"depth={self._depth:.3f}, fwhm={self._fwhm:.3f}"
619621
return f"gaussianCCF({pars})"
620622

621-
from numba import float64, int32
622-
from numba.experimental import jitclass
623-
624623
spectrum_attr_spec = [("wave", float64[:]), ("flux_arr", float64[:]), ("n", int32)]
625624

626625

@@ -641,34 +640,63 @@ def flux(self, μ: float = 0.0):
641640
("wave", float64[:]),
642641
("flux2d", float64[:, :]),
643642
("n", int32),
644-
# ('interp', RegularGridInterpolator),
643+
("μ", float64[:])
645644
]
646645

646+
@numba.njit
647+
def _find_interval(x, grid):
648+
# Return i such that grid[i] <= x <= grid[i+1], clamped to [0, len-2]
649+
n = grid.size
650+
if x <= grid[0]:
651+
return 0
652+
if x >= grid[n-2]:
653+
return n - 2
654+
lo = 0
655+
hi = n - 1
656+
while lo <= hi:
657+
mid = (lo + hi) >> 1
658+
if grid[mid] <= x:
659+
lo = mid + 1
660+
else:
661+
hi = mid - 1
662+
i = hi
663+
if i < 0:
664+
i = 0
665+
if i > n - 2:
666+
i = n - 2
667+
return i
668+
669+
@numba.njit
670+
def _blend_rows(mu, mu_grid, flux_mu_wave):
671+
# flux_mu_wave shape: (M_mu, N_wave) == flux2d.T
672+
# linear interpolation along μ, returning one spectrum of length N_wave
673+
if mu <= mu_grid[0]:
674+
return flux_mu_wave[0].copy()
675+
if mu >= mu_grid[-1]:
676+
return flux_mu_wave[-1].copy()
677+
i = _find_interval(mu, mu_grid)
678+
mu0 = mu_grid[i]
679+
mu1 = mu_grid[i+1]
680+
t = (mu - mu0) / (mu1 - mu0)
681+
return (1.0 - t) * flux_mu_wave[i] + t * flux_mu_wave[i+1]
647682

648683
@jitclass(spectrum2d_attr_spec)
649684
class SpectrumNumbaInterpolated:
650-
def __init__(self, wave, flux2d):
685+
def __init__(self, wave, flux2d, μ):
651686
self.n = wave.size
652687
self.wave = wave
653688
self.flux2d = flux2d
654-
655-
# def interpolate(self):
656-
# with objmode():
657-
# return interp
689+
self.μ = μ
658690

659691
def flux(self, mu=0.0):
660-
# don't do extrapolations, use the μ=0.2 spectrum
661-
if mu < 0.2:
692+
# identical boundary policy as before (no extrapolation)
693+
if mu < np.min(self.μ):
662694
return self.flux2d[:, 0]
663-
# do interpolation
664-
μ = np.array(
665-
[0.2, 0.3, 0.35, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.97, 0.98, 0.99, 1.0]
666-
)
667-
xi = np.column_stack((np.full(self.wave.size, mu), self.wave))
668-
with objmode(y="float64[:]"):
669-
interp = RegularGridInterpolator((μ, self.wave), self.flux2d.T)
670-
y = interp(xi)
671-
return y
695+
# elif mu > np.max(self.μ):
696+
# return self.flux2d[:, -1]
697+
# interpolate along μ only; wavelength samples are native, so no λ interpolation required
698+
# flux2d is (N_wave, M_mu); transpose once on-the-fly for μ-major access
699+
return _blend_rows(mu, self.μ, self.flux2d.T)
672700

673701

674702
class Spectrum:
@@ -919,7 +947,7 @@ def __call__(self, mu: float):
919947

920948
def to_numba(self):
921949
return SpectrumNumbaInterpolated(
922-
self.wave.astype(float), self.flux.astype(float)
950+
self.wave.astype(float), self.flux.astype(float),self.μ.astype(float)
923951
)
924952

925953
class solarIAGatlas(Spectrum):
@@ -982,7 +1010,7 @@ def __call__(self, mu: float):
9821010

9831011
def to_numba(self):
9841012
return SpectrumNumbaInterpolated(
985-
self.wave.astype(float), self.flux.astype(float)
1013+
self.wave.astype(float), self.flux.astype(float),self.μ.astype(float)
9861014
)
9871015

9881016

0 commit comments

Comments
 (0)