Skip to content

Commit d49b2b9

Browse files
committed
Revise significance and threshold tests
1 parent bf9aa41 commit d49b2b9

1 file changed

Lines changed: 95 additions & 66 deletions

File tree

jwst/adaptive_trace_model/trace_model.py

Lines changed: 95 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def _is_compact_source(
340340
# Bin the alpha coordinates for the high slope locations
341341
avec = np.arange(spline_bkpt) * native_dalpha / 2 - (native_dalpha * spline_bkpt / 4)
342342
with warnings.catch_warnings():
343-
warnings.simplefilter("ignore", RuntimeWarning)
343+
warnings.filterwarnings("ignore", category=RuntimeWarning)
344344
hist, edges = np.histogram(
345345
alpha_ptsource,
346346
bins=spline_bkpt,
@@ -595,34 +595,41 @@ def _significance_test(data_slice, err_slice, alpha_slice):
595595
& (err_slice > 0)
596596
)
597597
if not np.any(valid):
598-
return None, None, None
598+
return None, None, None, None
599599

600600
# Compute some SNR and noise statistics collapsed along wavelength
601-
with warnings.catch_warnings():
602-
warnings.simplefilter("ignore", RuntimeWarning)
603-
snr_slice = data_slice / err_slice
604-
snr_slice[~valid] = np.nan
601+
snr_slice = np.full_like(data_slice, np.nan)
602+
snr_slice[valid] = data_slice[valid] / err_slice[valid]
605603
native_dalpha = np.abs(np.nanmedian(np.diff(alpha_slice, axis=0)))
606604
step = native_dalpha / 2.0
607605

608606
# Bin errors and SNR by alpha values
609607
alpha_test = np.arange(np.nanmin(alpha_slice), np.nanmax(alpha_slice), step)
610-
err_test = np.zeros_like(alpha_test)
611-
snr_test = np.zeros_like(alpha_test)
612-
for kk in range(0, len(alpha_test)):
613-
indx = (alpha_slice >= alpha_test[kk] - step / 2.0) & (
614-
alpha_slice < alpha_test[kk] + step / 2.0
608+
flux_test = np.full_like(alpha_test, np.nan)
609+
err_test = np.full_like(alpha_test, np.nan)
610+
snr_test = np.full_like(alpha_test, np.nan)
611+
for kk in range(len(alpha_test)):
612+
indx = (
613+
(alpha_slice >= alpha_test[kk] - step / 2.0)
614+
& (alpha_slice < alpha_test[kk] + step / 2.0)
615+
& np.isfinite(snr_slice)
615616
)
617+
if not np.any(indx):
618+
continue
619+
flux_test[kk] = np.nanmedian(data_slice[indx])
616620
err_test[kk] = np.nanmedian(err_slice[indx])
617621
snr_test[kk] = np.nanmedian(snr_slice[indx])
618622

619-
return alpha_test, err_test, snr_test
623+
return alpha_test, flux_test, err_test, snr_test
620624

621625

622626
def _trim_edges(data_slice, alpha_slice, alpha_test, err_test, snr_test):
623627
# Bad edges are where SNR is low but ERR is high: set them to NaN
624-
err_mean, _, err_rms = scs(err_test)
625-
bad = (snr_test < 5) & (err_test > err_mean + 5 * err_rms)
628+
valid = np.isfinite(snr_test)
629+
if not np.any(valid):
630+
return
631+
err_mean, _, err_rms = scs(err_test[valid])
632+
bad = (np.abs(snr_test) < 5) & (err_test > err_mean + 5 * err_rms)
626633
if not np.any(bad) or np.all(bad):
627634
return
628635

@@ -640,7 +647,30 @@ def _trim_edges(data_slice, alpha_slice, alpha_test, err_test, snr_test):
640647
data_slice[indx] = np.nan
641648

642649

643-
def _fit_one_region(flux, error, alpha, region_map, signal_threshold, region_number, **fit_kwargs):
650+
def _threshold_test(flux_test, snr_test, region_number, peak_threshold, snr_threshold):
651+
if peak_threshold is None and snr_threshold is None:
652+
return True
653+
peak_over_threshold = (
654+
peak_threshold is not None
655+
and flux_test is not None
656+
and np.nanmax(flux_test) > peak_threshold[region_number]
657+
)
658+
snr_over_threshold = (
659+
snr_threshold is not None and snr_test is not None and np.nanmax(snr_test) > snr_threshold
660+
)
661+
return peak_over_threshold or snr_over_threshold
662+
663+
664+
def _fit_one_region(
665+
flux,
666+
error,
667+
alpha,
668+
region_map,
669+
region_number,
670+
peak_threshold=None,
671+
snr_threshold=None,
672+
**fit_kwargs,
673+
):
644674
"""
645675
Fit a trace model to a single region in the flux image.
646676
@@ -657,12 +687,16 @@ def _fit_one_region(flux, error, alpha, region_map, signal_threshold, region_num
657687
region_map : ndarray of int
658688
Map containing the slice or slit number for valid regions.
659689
Values are >0 for pixels in valid regions, 0 otherwise.
660-
signal_threshold : dict
661-
Threshold values for each valid region in the region map. If
662-
the median peak value across columns in the region is below this
663-
threshold, a fit will not be attempted for that region.
664690
region_number : int
665691
Index number for the single region to be fit in this invocation.
692+
peak_threshold : dict
693+
Flux threshold values for each valid region in the region map. If
694+
the median peak value across columns in the region is below this
695+
threshold, a fit will not be attempted for that region.
696+
snr_threshold : float
697+
Signal-to-noise ratio (SNR) threshold value. If the median SNR value
698+
across columns in the region is below this threshold, a fit will not
699+
be attempted for that region.
666700
**fit_kwargs
667701
Keyword arguments to pass to the fitting routine (see `fit_2d_spline_trace`).
668702
@@ -672,9 +706,6 @@ def _fit_one_region(flux, error, alpha, region_map, signal_threshold, region_num
672706
Dict containing a spline model, scale, and bounds for each column index in the region.
673707
If a spline model could not be fit, the column index number is not present.
674708
"""
675-
# Splines to return
676-
splines = {}
677-
678709
# Arrays to reset with NaNs for each slice
679710
data_slice = np.full_like(flux, np.nan)
680711
err_slice = np.full_like(flux, np.nan)
@@ -687,41 +718,43 @@ def _fit_one_region(flux, error, alpha, region_map, signal_threshold, region_num
687718
alpha_slice[indx] = alpha[indx]
688719

689720
# Estimate SNR for the slit or slice, collapsed along wavelength
690-
alpha_test, err_test, snr_test = _significance_test(data_slice, err_slice, alpha_slice)
721+
alpha_test, flux_test, err_test, snr_test = _significance_test(
722+
data_slice, err_slice, alpha_slice
723+
)
724+
725+
# Is either peak flux or SNR over threshold? If not, stop processing
726+
no_data_msg = "No data over threshold; not fitting splines."
727+
if not _threshold_test(flux_test, snr_test, region_number, peak_threshold, snr_threshold):
728+
log.debug(no_data_msg)
729+
return {}
691730

692731
# Use SNR estimates to trim slit or slice edges
693-
if snr_test is not None:
694-
_trim_edges(data_slice, alpha_slice, alpha_test, err_test, snr_test)
732+
_trim_edges(data_slice, alpha_slice, alpha_test, err_test, snr_test)
733+
734+
# Redo SNR estimate after trimming
735+
_, flux_test, _, snr_test = _significance_test(data_slice, err_slice, alpha_slice)
736+
737+
# Check again for flux over threshold
738+
if not _threshold_test(flux_test, snr_test, region_number, peak_threshold, snr_threshold):
739+
log.debug(no_data_msg)
740+
return {}
695741

696742
# Get a running sum in a given detector column (used for normalization)
697743
negative_nod_threshold = -5.0
698-
if snr_test is not None and np.nanmin(snr_test) < negative_nod_threshold:
744+
if np.nanmin(snr_test) < negative_nod_threshold:
699745
# If significant negative nods present, just sum positive data
700-
log.info("Found significant negative data; summing positive only for normalization.")
746+
log.debug("Found significant negative data; summing positive only for normalization.")
701747
runsum = np.sum(data_slice, where=(data_slice > 0), axis=0)
702748
else:
703749
runsum = np.nansum(data_slice, axis=0)
704750

705-
# Collapse the slice along Y to get max in each column
706-
with warnings.catch_warnings():
707-
warnings.filterwarnings("ignore", category=RuntimeWarning)
708-
collapse = np.nanmax(data_slice, axis=0)
709-
710-
# Median column max across all columns
711-
with warnings.catch_warnings():
712-
warnings.filterwarnings("ignore", category=RuntimeWarning)
713-
medcmax = np.nanmedian(collapse)
714-
715-
# Is medcmax over threshold? If so, do bspline for this slice.
716-
if medcmax > signal_threshold[region_number]:
717-
splines = fit_2d_spline_trace(data_slice, alpha_slice, fit_scale=runsum, **fit_kwargs)
751+
# Fit the splines
752+
splines = fit_2d_spline_trace(data_slice, alpha_slice, fit_scale=runsum, **fit_kwargs)
718753

719754
return splines
720755

721756

722-
def fit_all_regions(
723-
flux, error, alpha, region_map, signal_threshold, maximum_cores="none", **fit_kwargs
724-
):
757+
def fit_all_regions(flux, error, alpha, region_map, maximum_cores="none", **fit_kwargs):
725758
"""
726759
Fit a trace model to all regions in the flux image.
727760
@@ -736,10 +769,6 @@ def fit_all_regions(
736769
region_map : ndarray of int
737770
Map containing the slice or slit number for valid regions.
738771
Values are >0 for pixels in valid regions, 0 otherwise.
739-
signal_threshold : dict
740-
Threshold values for each valid region in the region map. If
741-
the median peak value across columns in the region is below this
742-
threshold, a fit will not be attempted for that region.
743772
maximum_cores : str
744773
Number of cores to use for multiprocessing. If set to 'none' (the default),
745774
then no multiprocessing will be done. The other allowable values are 'quarter',
@@ -765,21 +794,21 @@ def fit_all_regions(
765794
# Call adaptive trace model for the single processor (1 data slice) case
766795
if number_slices == 1:
767796
# Single threaded computation
768-
log.info("Running single-process calculation")
797+
log.debug("Running single-process calculation")
769798
for reg_num in region_numbers:
770799
if len(region_numbers) > 1:
771800
log.info("Fitting slice %s", reg_num)
772801
spline_models[reg_num] = _fit_one_region(
773-
flux, error, alpha, region_map, signal_threshold, reg_num, **fit_kwargs
802+
flux, error, alpha, region_map, reg_num, **fit_kwargs
774803
)
775804
else:
776805
# Parallelized computation
777-
log.info(f"Fitting slices, multiprocessing on {number_slices} cores")
806+
log.info(f"Multiprocessing on {number_slices} cores")
778807

779808
# Use functools.partial to supply all other inputs to _fit_one_region except slice number
780809
# This is needed since pool.starmap doesn't support passing **fit_kwargs
781810
fit_one_region_with_args = functools.partial(
782-
_fit_one_region, flux, error, alpha, region_map, signal_threshold, **fit_kwargs
811+
_fit_one_region, flux, error, alpha, region_map, **fit_kwargs
783812
)
784813

785814
# Run the parallelized calc and collect results
@@ -1650,12 +1679,6 @@ def fit_and_oversample(
16501679
else:
16511680
raise ValueError("Unknown detector")
16521681

1653-
# For single slit images, fit_threshold is not useful: disable it
1654-
region_numbers = np.unique(region_map[region_map > 0])
1655-
if len(region_numbers) < 2 and fit_threshold > 0:
1656-
log.info("Ignoring fit threshold for single slit image")
1657-
fit_threshold = 0
1658-
16591682
# For multiple integrations, fit the profile to the median image
16601683
if nint > 1:
16611684
# Also check for an input oversample factor:
@@ -1698,16 +1721,19 @@ def fit_and_oversample(
16981721
flux_orig = flux_orig - overall_mean
16991722
overall_mean = 0
17001723

1701-
# Define a per-slice analysis threshold (must be brighter than some level above background)
1724+
# Define a per-slice analysis threshold for IFU
1725+
# (must be brighter than some level above background)
1726+
peak_threshold = None
1727+
region_numbers = np.unique(region_map[region_map > 0])
17021728
if fit_threshold <= 0:
1703-
# In this case, all slices should be fit, so make the threshold
1704-
# lower than any real signal
1705-
signal_threshold = dict.fromkeys(region_numbers, -np.inf)
1729+
# In this case for any mode, all slices should be fit,
1730+
# so set both thresholds to None
1731+
fit_threshold = None
17061732
else:
17071733
if mode == "MIR_MRS":
17081734
# For MIRI MRS we need each channel to have its own threshold, particularly
17091735
# for Ch3/Ch4 since the sky is so much brighter in Ch4
1710-
signal_threshold = dict.fromkeys(region_numbers, np.nan)
1736+
peak_threshold = dict.fromkeys(region_numbers, np.nan)
17111737
for channel in [100, 200, 300, 400]:
17121738
ch_data = (region_map >= channel) & (region_map < channel + 100)
17131739
if not np.any(ch_data):
@@ -1717,20 +1743,23 @@ def fit_and_oversample(
17171743
ch_mean, _, ch_rms = scs(flux_orig[ch_data])
17181744
for reg_num in region_numbers:
17191745
if channel <= reg_num < channel + 100:
1720-
signal_threshold[reg_num] = ch_mean + fit_threshold * ch_rms
1721-
else:
1746+
peak_threshold[reg_num] = ch_mean + fit_threshold * ch_rms
1747+
elif mode == "NRS_IFU":
17221748
# For NIRSpec IFU data, all regions have the same threshold
17231749
threshold = overall_mean + fit_threshold * overall_rms
1724-
signal_threshold = dict.fromkeys(region_numbers, threshold)
1750+
peak_threshold = dict.fromkeys(region_numbers, threshold)
17251751

17261752
# Fit spline models to all regions
17271753
fit_kwargs = _set_fit_kwargs(mode, detector, slit, xsize)
1754+
if peak_threshold is not None:
1755+
fit_kwargs["peak_threshold"] = peak_threshold
1756+
else:
1757+
fit_kwargs["snr_threshold"] = fit_threshold
17281758
spline_models = fit_all_regions(
17291759
flux_orig,
17301760
err_orig,
17311761
alpha_orig,
17321762
region_map,
1733-
signal_threshold,
17341763
maximum_cores=maximum_cores,
17351764
**fit_kwargs,
17361765
)

0 commit comments

Comments
 (0)