88from astropy .io import fits
99from expecto import get_spectrum
1010from scipy .interpolate import RegularGridInterpolator
11+ import numba
1112
1213from .fast_starspot import doppler_shift
1314from .gaussian import ip_convolution
1415from .units import U , has_unit , kms , maybe_quantity_input , unit_arange
1516from .utils import c , read_rdb , sqrt2pi
16-
17+ from numba import float64 , int32
18+ from numba .experimental import jitclass
1719
1820def set_object_attributes (object , attrs ):
1921 """
@@ -618,9 +620,6 @@ def __repr__(self):
618620 pars = f"depth={ self ._depth :.3f} , fwhm={ self ._fwhm :.3f} "
619621 return f"gaussianCCF({ pars } )"
620622
621- from numba import float64 , int32
622- from numba .experimental import jitclass
623-
624623spectrum_attr_spec = [("wave" , float64 [:]), ("flux_arr" , float64 [:]), ("n" , int32 )]
625624
626625
@@ -641,34 +640,63 @@ def flux(self, μ: float = 0.0):
641640 ("wave" , float64 [:]),
642641 ("flux2d" , float64 [:, :]),
643642 ("n" , int32 ),
644- # ('interp', RegularGridInterpolator),
643+ ( "μ" , float64 [:])
645644]
646645
646+ @numba .njit
647+ def _find_interval (x , grid ):
648+ # Return i such that grid[i] <= x <= grid[i+1], clamped to [0, len-2]
649+ n = grid .size
650+ if x <= grid [0 ]:
651+ return 0
652+ if x >= grid [n - 2 ]:
653+ return n - 2
654+ lo = 0
655+ hi = n - 1
656+ while lo <= hi :
657+ mid = (lo + hi ) >> 1
658+ if grid [mid ] <= x :
659+ lo = mid + 1
660+ else :
661+ hi = mid - 1
662+ i = hi
663+ if i < 0 :
664+ i = 0
665+ if i > n - 2 :
666+ i = n - 2
667+ return i
668+
669+ @numba .njit
670+ def _blend_rows (mu , mu_grid , flux_mu_wave ):
671+ # flux_mu_wave shape: (M_mu, N_wave) == flux2d.T
672+ # linear interpolation along μ, returning one spectrum of length N_wave
673+ if mu <= mu_grid [0 ]:
674+ return flux_mu_wave [0 ].copy ()
675+ if mu >= mu_grid [- 1 ]:
676+ return flux_mu_wave [- 1 ].copy ()
677+ i = _find_interval (mu , mu_grid )
678+ mu0 = mu_grid [i ]
679+ mu1 = mu_grid [i + 1 ]
680+ t = (mu - mu0 ) / (mu1 - mu0 )
681+ return (1.0 - t ) * flux_mu_wave [i ] + t * flux_mu_wave [i + 1 ]
647682
648683@jitclass (spectrum2d_attr_spec )
649684class SpectrumNumbaInterpolated :
650- def __init__ (self , wave , flux2d ):
685+ def __init__ (self , wave , flux2d , μ ):
651686 self .n = wave .size
652687 self .wave = wave
653688 self .flux2d = flux2d
654-
655- # def interpolate(self):
656- # with objmode():
657- # return interp
689+ self .μ = μ
658690
659691 def flux (self , mu = 0.0 ):
660- # don't do extrapolations, use the μ=0.2 spectrum
661- if mu < 0.2 :
692+ # identical boundary policy as before (no extrapolation)
693+ if mu < np . min ( self . μ ) :
662694 return self .flux2d [:, 0 ]
663- # do interpolation
664- μ = np .array (
665- [0.2 , 0.3 , 0.35 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.9 , 0.95 , 0.97 , 0.98 , 0.99 , 1.0 ]
666- )
667- xi = np .column_stack ((np .full (self .wave .size , mu ), self .wave ))
668- with objmode (y = "float64[:]" ):
669- interp = RegularGridInterpolator ((μ , self .wave ), self .flux2d .T )
670- y = interp (xi )
671- return y
695+ # elif mu > np.max(self.μ):
696+ # return self.flux2d[:, -1]
697+ # interpolate along μ only; wavelength samples are native, so no λ interpolation required
698+ # flux2d is (N_wave, M_mu); transpose once on-the-fly for μ-major access
699+ return _blend_rows (mu , self .μ , self .flux2d .T )
672700
673701
674702class Spectrum :
@@ -919,7 +947,7 @@ def __call__(self, mu: float):
919947
920948 def to_numba (self ):
921949 return SpectrumNumbaInterpolated (
922- self .wave .astype (float ), self .flux .astype (float )
950+ self .wave .astype (float ), self .flux .astype (float ), self . μ . astype ( float )
923951 )
924952
925953class solarIAGatlas (Spectrum ):
@@ -982,7 +1010,7 @@ def __call__(self, mu: float):
9821010
9831011 def to_numba (self ):
9841012 return SpectrumNumbaInterpolated (
985- self .wave .astype (float ), self .flux .astype (float )
1013+ self .wave .astype (float ), self .flux .astype (float ), self . μ . astype ( float )
9861014 )
9871015
9881016
0 commit comments