1 change: 1 addition & 0 deletions changes/10391.adaptive_trace_model.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add multiprocessing support to the ``adaptive_trace_model`` step. In one NIRSpec IFU test case, fitting all 30 slices on 10 cores reduced processing time by a factor of 7.
11 changes: 11 additions & 0 deletions docs/jwst/adaptive_trace_model/arguments.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,14 @@ The ``adaptive_trace_model`` step has the following step-specific arguments:
``spline_full``, ``spline_used``, ``linear_interp``, and ``spline_residual``,
respectively. The linear interpolation and residual flux files are saved only
if oversampling was performed.

``--maximum_cores``: The number of cores to use for multiprocessing in this step.
The default value is 'none', which disables multiprocessing. The other options are
an integer, 'quarter', 'half', or 'all'. The fractions refer to the total number of
available cores, which on most CPUs includes both physical and virtual cores.
Only the spline fitting portion of the code uses multiprocessing; the oversampling
portion is not affected by this flag. The speedup is therefore less than linear in
the number of cores. The relative speedup is largest when the ``fit_threshold`` and
``slope_limit`` parameter values, and/or the ``psf_optimal`` flag, result in a larger
fraction of the image being fit with a spline.
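
The sub-linear speedup described above is the usual Amdahl's law behavior. The sketch below is only an idealized illustration: the function names are hypothetical, and the model ignores process start-up cost, data transfer, and uneven slice sizes. It estimates the speedup for a given parallel fraction, and inverts the relation to ask what parallel fraction would be consistent with the factor of 7 on 10 cores quoted in the change note for this PR.

```python
def amdahl_speedup(parallel_fraction: float, n_processes: int) -> float:
    """Ideal speedup when only part of the runtime is parallelized (Amdahl's law)."""
    if not 0.0 <= parallel_fraction <= 1.0:
        raise ValueError("parallel_fraction must be between 0 and 1")
    if n_processes < 1:
        raise ValueError("n_processes must be at least 1")
    serial_fraction = 1.0 - parallel_fraction
    return 1.0 / (serial_fraction + parallel_fraction / n_processes)


def implied_parallel_fraction(speedup: float, n_processes: int) -> float:
    """Invert Amdahl's law: the parallel fraction that would give this speedup."""
    if n_processes < 2:
        raise ValueError("need at least 2 processes to infer a parallel fraction")
    return (1.0 - 1.0 / speedup) / (1.0 - 1.0 / n_processes)


# Illustrative numbers only: a factor of 7 on 10 processes corresponds to
# roughly 95% of the single-process runtime being parallelizable in this model.
fraction = implied_parallel_fraction(7.0, 10)
print(f"implied parallel fraction: {fraction:.3f}")
print(f"ideal speedup on 4 processes: {amdahl_speedup(fraction, 4):.2f}")
```

Under this simplified model, the serial part (for example, oversampling) sets a ceiling on the speedup no matter how many cores are requested, which is why the benefit grows with the fraction of the image that is spline fit.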
6 changes: 5 additions & 1 deletion docs/jwst/adaptive_trace_model/main.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,17 @@ Slope values higher than a threshold value (step parameter ``slope_limit``) indi
a compact source region. The trace model will be evaluated for these regions, with some
padding for nearby pixels; it will not be evaluated in other regions.

If no oversampling is desired (i.e. the ``oversample`` parameter is set to 1.0), then the trace
If no oversampling is desired (i.e., the ``oversample`` parameter is set to 1.0), then the trace
model is evaluated at every input pixel in a compact source region to create a wavelength-dependent
spatial profile. This image is stored in the output datamodel, in the ``trace_model`` attribute.
Regions for which a spline model could not be computed, or which did not meet the compact source
criteria, are set to NaN in the image. The step then returns without further changes to the input
datamodel. The rest of the algorithm description, below, applies only to oversampling.

The trace modeling portion of the code can optionally use multiprocessing to improve
runtime by modeling the spectral regions (i.e., each slice or slit) independently in
separate processes.
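
How many processes are used depends on the ``maximum_cores`` option. As a rough illustration of how such a string option could be turned into a process count, here is a minimal sketch. The helper name and the capping rule are assumptions made for this example; it is not the ``compute_num_cores`` function from ``stcal`` that the step actually calls, and its behavior may differ.

```python
def resolve_process_count(maximum_cores: str, n_regions: int, n_available: int) -> int:
    """Hypothetical helper: map a maximum_cores string to a number of processes."""
    value = maximum_cores.strip().lower()
    if value == "none":
        requested = 1
    elif value == "quarter":
        requested = max(n_available // 4, 1)
    elif value == "half":
        requested = max(n_available // 2, 1)
    elif value == "all":
        requested = n_available
    elif value.isdigit():
        requested = max(int(value), 1)
    else:
        raise ValueError(f"Unrecognized maximum_cores value: {maximum_cores!r}")

    # Assumption for this sketch: never start more processes than there are
    # regions to fit or cores available.
    return max(1, min(requested, n_regions, n_available))


# Example: 30 slices on a machine reporting 16 cores
for option in ("none", "quarter", "half", "all", "10"):
    print(option, "->", resolve_process_count(option, n_regions=30, n_available=16))
```

A result of 1 corresponds to the plain serial loop over regions; anything larger would use a process pool.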

Oversample the Flux
^^^^^^^^^^^^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/jwst/user_documentation/multiprocessing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ the processing of a particular dataset running computationally-intensive steps:
* :ref:`jump <jump_step>` (jump detection)
* :ref:`ramp_fitting <ramp_fitting_step>`
* :ref:`wfss_contam <wfss_contam_step>` (WFSS contamination correction)
* :ref:`adaptive_trace_model <adaptive_trace_model_step>`

Unlike :ref:`multiproc_multiple-obs`, this usage is compatible with running
the pipeline within Jupyter Notebook/Lab.
Expand Down
2 changes: 2 additions & 0 deletions jwst/adaptive_trace_model/adaptive_trace_model_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class AdaptiveTraceModelStep(Step):
save_intermediate_results = boolean(default=False) # Save the full spline model and residuals.
skip = boolean(default=True) # By default, skip the step.
output_use_model = boolean(default=True) # Use input filenames in the output models
maximum_cores = string(default='none') # Cores for multiprocessing: 'none', an integer, 'quarter', 'half', or 'all'
""" # noqa: E501

def process(self, input_data):
Expand Down Expand Up @@ -99,6 +100,7 @@ def process(self, input_data):
oversample_factor=self.oversample,
psf_optimal=self.psf_optimal,
return_intermediate_models=self.save_intermediate_results,
maximum_cores=self.maximum_cores,
)

model.meta.cal_step.adaptive_trace_model = "COMPLETE"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,18 @@ def test_adaptive_trace_model_step_oversample(miri_mrs_model):
"dataset",
["miri_mrs_model_with_source", "nirspec_ifu_model_with_source", "nirspec_ifu_slice_wcs"],
)
def test_adaptive_trace_model_step_oversample_with_source(request, dataset):
@pytest.mark.parametrize("cores", ["none", "2"])
def test_adaptive_trace_model_step_oversample_with_source(request, dataset, cores):
model = request.getfixturevalue(dataset)

fit_threshold = 10.0
slope_limit = 0.05
result = AdaptiveTraceModelStep.call(
model, oversample=2, slope_limit=slope_limit, fit_threshold=fit_threshold
model,
oversample=2,
slope_limit=slope_limit,
fit_threshold=fit_threshold,
maximum_cores=cores,
)
assert result.meta.cal_step.adaptive_trace_model == "COMPLETE"

Expand Down Expand Up @@ -247,9 +252,16 @@ def test_adaptive_trace_model_unsupported_model(caplog):


@pytest.mark.slow
def test_adaptive_trace_model_step_psf_optimal(caplog, miri_mrs_model_with_source):
@pytest.mark.parametrize("cores", ["none", "2"])
def test_adaptive_trace_model_step_psf_optimal(caplog, miri_mrs_model_with_source, cores):
"""
Set psf_optimal to ensure all slices are fit and process in parallel.

Also test with multiprocessing, since it is meaningful in this case, where
multiple slices may be fit simultaneously.
"""
model = miri_mrs_model_with_source
result = AdaptiveTraceModelStep.call(model, oversample=2, psf_optimal=True)
result = AdaptiveTraceModelStep.call(model, oversample=2, psf_optimal=True, maximum_cores=cores)
assert result.meta.cal_step.adaptive_trace_model == "COMPLETE"
assert "Ignoring fit threshold and slope limit" in caplog.text

Expand Down
150 changes: 115 additions & 35 deletions jwst/adaptive_trace_model/trace_model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import functools
import logging
import multiprocessing
import warnings
from multiprocessing import cpu_count

import gwcs
import numpy as np
from astropy.modeling.models import Identity, Scale, Shift
from astropy.stats import sigma_clipped_stats as scs
from astropy.utils.exceptions import AstropyUserWarning
from scipy.signal import find_peaks
from stcal.multiprocessing import compute_num_cores
from stdatamodels.jwst import datamodels
from stdatamodels.jwst.datamodels import dqflags

Expand Down Expand Up @@ -567,9 +571,11 @@ def linear_oversample(
return os_data


def fit_all_regions(flux, alpha, region_map, signal_threshold, **fit_kwargs):
def _fit_one_region(flux, alpha, region_map, signal_threshold, region_number, **fit_kwargs):
"""
Fit a trace model to all regions in the flux image.
Fit a trace model to a single region in the flux image.

Called from fit_all_regions, optionally parallelized via multiprocessing.

Parameters
----------
Expand All @@ -584,56 +590,119 @@ def fit_all_regions(flux, alpha, region_map, signal_threshold, **fit_kwargs):
Threshold values for each valid region in the region map. If
the median peak value across columns in the region is below this
threshold, a fit will not be attempted for that region.
region_number : int
Index number for the single region to be fit in this invocation.
**fit_kwargs
Keyword arguments to pass to the fitting routine (see `fit_2d_spline_trace`).

Returns
-------
spline_models : dict
Keys are region numbers, values are dicts containing a spline model,
scale, and bounds for each column index in the region. If a spline model
could not be fit, the column index number is not present.
splines : dict
Dict containing a spline model, scale, and bounds for each column index in the region.
If a spline model could not be fit, the column index number is not present.
"""
# Arrays to reset with NaNs for each slice
data_slice = np.full_like(flux, np.nan)
alpha_slice = np.full_like(flux, np.nan)

spline_models = {}
slice_numbers = np.unique(region_map[region_map > 0])
# Copy the relevant data for this slice into the holding arrays
indx = region_map == region_number
data_slice[indx] = flux[indx]
alpha_slice[indx] = alpha[indx]

for slnum in slice_numbers:
log.info("Fitting slice %s", slnum)
# A running sum in a given detector column (used for normalization)
runsum = np.nansum(data_slice, axis=0)

# Reset holding arrays to NaN
data_slice[:] = np.nan
alpha_slice[:] = np.nan
# Collapse the slice along Y to get max in each column
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=RuntimeWarning)
collapse = np.nanmax(data_slice, axis=0)

# Copy the relevant data for this slice into the holding arrays
indx = region_map == slnum
data_slice[indx] = flux[indx]
alpha_slice[indx] = alpha[indx]
# Median column max across all columns
medcmax = np.nanmedian(collapse)

# A running sum in a given detector column (used for normalization)
runsum = np.nansum(data_slice, axis=0)
# Is medcmax over threshold? If so, do bspline for this slice.
dospline = False
if medcmax > signal_threshold[region_number]:
dospline = True

# Collapse the slice along Y to get max in each column
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=RuntimeWarning)
collapse = np.nanmax(data_slice, axis=0)
if dospline:
splines = fit_2d_spline_trace(data_slice, alpha_slice, fit_scale=runsum, **fit_kwargs)
else:
splines = {}

# Median column max across all columns
medcmax = np.nanmedian(collapse)
return splines

# Is medcmax over threshold? If so, do bspline for this slice.
dospline = False
if medcmax > signal_threshold[slnum]:
dospline = True

if dospline:
splines = fit_2d_spline_trace(data_slice, alpha_slice, fit_scale=runsum, **fit_kwargs)
else:
splines = {}
spline_models[slnum] = splines
def fit_all_regions(flux, alpha, region_map, signal_threshold, maximum_cores="none", **fit_kwargs):
"""
Fit a trace model to all regions in the flux image.

Parameters
----------
flux : ndarray
The flux image to fit.
alpha : ndarray
Alpha coordinates for all flux values.
region_map : ndarray of int
Map containing the slice or slit number for valid regions.
Values are >0 for pixels in valid regions, 0 otherwise.
signal_threshold : dict
Threshold values for each valid region in the region map. If
the median peak value across columns in the region is below this
threshold, a fit will not be attempted for that region.
maximum_cores : str
Number of cores to use for multiprocessing. If set to 'none' (the default),
no multiprocessing is done. The other allowable values are 'quarter',
'half', 'all', or an integer given as a string. These specify either a
fraction of the available cores or an explicit number of cores to use.
**fit_kwargs
Keyword arguments to pass to the fitting routine (see `fit_2d_spline_trace`).

Returns
-------
spline_models : dict
Keys are region numbers, values are dicts containing a spline model,
scale, and bounds for each column index in the region. If a spline model
could not be fit, the column index number is not present.
"""
spline_models = {}
slice_numbers = np.unique(region_map[region_map > 0])

# Determine the number of processes to use for multiprocessing
num_available_cores = cpu_count()
number_slices = compute_num_cores(maximum_cores, len(slice_numbers), num_available_cores)

# Fit the regions one at a time when only a single process is requested
if number_slices == 1:
# Single-process computation
log.info("Running single-process calculation")

for slnum in slice_numbers:
log.info("Fitting slice %s", slnum)
spline_models[slnum] = _fit_one_region(
flux, alpha, region_map, signal_threshold, slnum, **fit_kwargs
)
else:
# Parallelized computation
log.info(f"Fitting slices, multiprocessing on {number_slices} cores")

# Use functools.partial to supply all other inputs to _fit_one_region except slice number
# This is needed since pool.starmap doesn't support passing **fit_kwargs
fit_one_region_with_args = functools.partial(
_fit_one_region, flux, alpha, region_map, signal_threshold, **fit_kwargs
)

# Run the parallelized calc and collect results
ctx = multiprocessing.get_context("spawn")
pool = ctx.Pool(processes=number_slices)
try:
pool_results = pool.starmap(fit_one_region_with_args, [(n,) for n in slice_numbers])
finally:
pool.close()
pool.join()
for slnum, result in zip(slice_numbers, pool_results, strict=True):
spline_models[slnum] = result

return spline_models
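
The dispatch pattern used in ``fit_all_regions`` (bind the shared inputs with ``functools.partial``, then hand only the region number to each worker) can be sketched in isolation. Everything below is a simplified stand-in: the toy ``fit_one_region`` just sums values and compares against a threshold, and the dict-based inputs are assumptions for the example, not the real spline fit or data layout.

```python
import functools
import multiprocessing


def fit_one_region(data, thresholds, region_number, scale=1.0):
    """Toy stand-in for a per-region fit: scaled sum, or None if below threshold."""
    total = sum(data[region_number]) * scale
    return total if total > thresholds[region_number] else None


def fit_all_regions(data, thresholds, processes=1, **fit_kwargs):
    """Fit every region, serially or in a process pool."""
    region_numbers = sorted(data)

    # Bind everything except the region number, including keyword arguments,
    # so each task only needs to carry a single integer.
    worker = functools.partial(fit_one_region, data, thresholds, **fit_kwargs)

    if processes == 1:
        results = [worker(n) for n in region_numbers]
    else:
        ctx = multiprocessing.get_context("spawn")
        pool = ctx.Pool(processes=processes)
        try:
            # starmap unpacks each one-element tuple into the remaining argument
            results = pool.starmap(worker, [(n,) for n in region_numbers])
        finally:
            pool.close()
            pool.join()

    # starmap preserves input order, so results line up with region_numbers
    return dict(zip(region_numbers, results))


data = {1: [1.0, 2.0], 2: [0.1, 0.1], 3: [5.0]}
thresholds = {1: 2.0, 2: 1.0, 3: 4.0}
serial_result = fit_all_regions(data, thresholds, processes=1, scale=2.0)
print(serial_result)
```

With the "spawn" start method, the worker must be defined at module level so child processes can import it, and a script that requests ``processes > 1`` should do so from inside an ``if __name__ == "__main__":`` block. Note that the partial object, including the full input arrays it holds, is pickled and sent to the workers, so the cost of transferring data offsets some of the gain.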

Expand Down Expand Up @@ -1230,6 +1299,7 @@ def fit_and_oversample(
psf_optimal=False,
oversample_factor=1.0,
return_intermediate_models=False,
maximum_cores="none",
):
"""
Fit a trace model and optionally oversample an IFU datamodel.
Expand Down Expand Up @@ -1257,6 +1327,11 @@ def fit_and_oversample(
If True, additional image models will be returned, containing the full
spline model, the spline model as used for compact sources, the residual
model, and the linearly interpolated data.
maximum_cores : str
Number of cores to use for multiprocessing. If set to 'none' (the default),
no multiprocessing is done. The other allowable values are 'quarter',
'half', 'all', or an integer given as a string. These specify either a
fraction of the available cores or an explicit number of cores to use.

Returns
-------
Expand Down Expand Up @@ -1366,7 +1441,12 @@ def fit_and_oversample(
# Fit spline models to all regions
fit_kwargs = _set_fit_kwargs(detector, xsize)
spline_models = fit_all_regions(
flux_orig, alpha_orig, region_map, signal_threshold, **fit_kwargs
flux_orig,
alpha_orig,
region_map,
signal_threshold,
maximum_cores=maximum_cores,
**fit_kwargs,
)

# If oversampling is not needed, evaluate the spline models to create the
Expand Down
1 change: 1 addition & 0 deletions jwst/regtest/test_nirspec_ifu_trace_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def run_spec2_trace_model(rtdata_module):
"--steps.adaptive_trace_model.skip=false",
"--steps.adaptive_trace_model.save_results=true",
"--steps.adaptive_trace_model.oversample=2.0",
"--steps.adaptive_trace_model.maximum_cores=2",
]
Step.from_cmdline(args)
return rtdata
Expand Down