From 9147aab7aa4dc4a28e7461750b0ae4bb3ef6d26f Mon Sep 17 00:00:00 2001 From: "T. E. Pickering" Date: Tue, 31 Mar 2026 15:31:19 -0700 Subject: [PATCH 1/3] fix: reduce memory in flatfield by evaluating tilts only at slit pixels Replace full-frame meshgrid tilt evaluation with per-slit-pixel evaluation using PypeItFit.eval directly. For spectrographs with many slits (e.g., fiber-fed IFUs with hundreds of fibers), the previous approach allocated a full-frame tilts array per slit, causing excessive memory usage. Co-Authored-By: Claude Opus 4.6 --- pypeit/flatfield.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/pypeit/flatfield.py b/pypeit/flatfield.py index 68b194d890..44a4ab1189 100644 --- a/pypeit/flatfield.py +++ b/pypeit/flatfield.py @@ -989,14 +989,28 @@ def fit(self, spat_illum_only=False, doqa=True, debug=False): # Collapse the slit spatially and fit the spectral function # TODO: Put this stuff in a self.spectral_fit method? - # Create the tilts image for this slit + # Create the tilts for pixels in this slit only (not full image) if self.slitless: tilts = np.tile(np.arange(rawflat.shape[0]) / rawflat.shape[0], (rawflat.shape[1], 1)).T + spec_coo = tilts * (nspec-1) else: # TODO -- JFH Confirm the sign of this shift is correct! _flexure = 0. if self.wavetilts.spat_flexure is None else self.wavetilts.spat_flexure - tilts = tracewave.fit2tilts(rawflat.shape, self.wavetilts['coeffs'][:,:,slit_idx], - self.wavetilts['func2d'], spat_shift=-1*_flexure) + # Evaluate tilts only at slit pixels to save memory + _coeff = self.wavetilts['coeffs'][:,:,slit_idx] + _spec, _spat = np.where(onslit_padded) + _pypeitFit = fitting.PypeItFit(fitc=_coeff, minx=0.0, maxx=1.0, + minx2=0.0, maxx2=1.0, + func=self.wavetilts['func2d']) + _xnspecmin1 = float(nspec - 1) + _xnspatmin1 = float(rawflat.shape[1] - 1) + _tilts_slit = _pypeitFit.eval(_spec / _xnspecmin1, + x2=(_spat + _flexure) / _xnspatmin1) + _tilts_slit = np.fmax(np.fmin(_tilts_slit, 1.2), -0.2) + # Build a full-frame tilts image placeholder with only slit pixels filled + tilts = np.zeros(rawflat.shape, dtype=float) + tilts[onslit_padded] = _tilts_slit + del _tilts_slit, _spec, _spat # Convert the tilt image to an image with the spectral pixel index spec_coo = tilts * (nspec-1) From a16ca6bc22d859458887ba171649e9ec96289873 Mon Sep 17 00:00:00 2001 From: "T. E. Pickering" Date: Tue, 31 Mar 2026 18:19:29 -0700 Subject: [PATCH 2/3] refactor: move slit-masked tilt evaluation into fit2tilts Add an optional `slit_mask` parameter to `tracewave.fit2tilts` so that tilt evaluation at only the relevant slit pixels is handled inside the function rather than being inlined at each call site. This addresses review feedback on the memory optimization: the logic now lives in the canonical location and both callers in `flatfield.py` and `wavetilts.py` benefit from reduced memory usage. Co-Authored-By: Claude Opus 4.6 --- pypeit/core/tracewave.py | 48 +++++++++++++++++++++++++++++----------- pypeit/flatfield.py | 23 +++++-------------- pypeit/wavetilts.py | 6 +++-- 3 files changed, 45 insertions(+), 32 deletions(-) diff --git a/pypeit/core/tracewave.py b/pypeit/core/tracewave.py index 85c2a6ccc0..2de3cedf2b 100644 --- a/pypeit/core/tracewave.py +++ b/pypeit/core/tracewave.py @@ -858,9 +858,15 @@ def fit_tilts(trc_tilt_dict, thismask, slit_cen, spat_order=3, spec_order=4, max # log.info("RMS/FWHM: {}".format(rms_real/fwhm)) -def fit2tilts(shape, coeff2, func2d, spat_shift=None): +def fit2tilts(shape, coeff2, func2d, spat_shift=None, slit_mask=None): """ - Evaluate the wavelength tilt model over the full image. + Evaluate the wavelength tilt model. + + When ``slit_mask`` is not provided, the model is evaluated over a full + meshgrid spanning the image. When ``slit_mask`` is provided, the model + is evaluated only at the ``True`` pixels in the mask, which avoids + allocating full-frame meshgrid arrays and significantly reduces memory + usage for spectrographs with many slits (e.g., fiber-fed IFUs). Parameters ---------- @@ -871,9 +877,13 @@ def fit2tilts(shape, coeff2, func2d, spat_shift=None): func2d : str the 2d function used to fit the tilts spat_shift : float, optional - Spatial shift to be added to image pixels before evaluation + Spatial shift to be added to image pixels before evaluation. If you are accounting for flexure, then you probably wish to input -1*flexure_shift into this parameter. + slit_mask : `numpy.ndarray`_, bool, optional + Boolean mask with the same shape as the image. If provided, + tilts are evaluated only where ``slit_mask`` is ``True`` and + the returned image is zero elsewhere. Returns ------- @@ -882,22 +892,34 @@ def fit2tilts(shape, coeff2, func2d, spat_shift=None): image. This output is used in the pipeline. """ - # Init _spat_shift = 0. if spat_shift is None else spat_shift - # Compute the tilts image nspec, nspat = shape xnspecmin1 = float(nspec - 1) xnspatmin1 = float(nspat - 1) - spec_vec = np.arange(nspec) - spat_vec = np.arange(nspat) - _spat_shift - spat_img, spec_img = np.meshgrid(spat_vec, spec_vec) - # + pypeitFit = fitting.PypeItFit(fitc=coeff2, minx=0.0, maxx=1.0, minx2=0.0, maxx2=1.0, func=func2d) - tilts = pypeitFit.eval(spec_img / xnspecmin1, x2=spat_img / xnspatmin1) - # Added this to ensure that tilts are never crazy values due to extrapolation of fits which can break - # wavelength solution fitting - return np.fmax(np.fmin(tilts, 1.2), -0.2) + + if slit_mask is not None: + # Evaluate only at slit pixels to save memory + spec_pix, spat_pix = np.where(slit_mask) + tilts_vals = pypeitFit.eval(spec_pix / xnspecmin1, + x2=(spat_pix - _spat_shift) / xnspatmin1) + tilts_vals = np.fmax(np.fmin(tilts_vals, 1.2), -0.2) + tilts = np.zeros(shape, dtype=float) + tilts[slit_mask] = tilts_vals + del tilts_vals, spec_pix, spat_pix + else: + # Full-frame meshgrid evaluation + spec_vec = np.arange(nspec) + spat_vec = np.arange(nspat) - _spat_shift + spat_img, spec_img = np.meshgrid(spat_vec, spec_vec) + tilts = pypeitFit.eval(spec_img / xnspecmin1, x2=spat_img / xnspatmin1) + # Added this to ensure that tilts are never crazy values due to + # extrapolation of fits which can break wavelength solution fitting + tilts = np.fmax(np.fmin(tilts, 1.2), -0.2) + + return tilts # This method needs to match the name in pypeit.core.qa.set_qa_filename() diff --git a/pypeit/flatfield.py b/pypeit/flatfield.py index 44a4ab1189..7d36973b53 100644 --- a/pypeit/flatfield.py +++ b/pypeit/flatfield.py @@ -989,28 +989,17 @@ def fit(self, spat_illum_only=False, doqa=True, debug=False): # Collapse the slit spatially and fit the spectral function # TODO: Put this stuff in a self.spectral_fit method? - # Create the tilts for pixels in this slit only (not full image) + # Create the tilts image for this slit if self.slitless: tilts = np.tile(np.arange(rawflat.shape[0]) / rawflat.shape[0], (rawflat.shape[1], 1)).T - spec_coo = tilts * (nspec-1) else: # TODO -- JFH Confirm the sign of this shift is correct! _flexure = 0. if self.wavetilts.spat_flexure is None else self.wavetilts.spat_flexure - # Evaluate tilts only at slit pixels to save memory - _coeff = self.wavetilts['coeffs'][:,:,slit_idx] - _spec, _spat = np.where(onslit_padded) - _pypeitFit = fitting.PypeItFit(fitc=_coeff, minx=0.0, maxx=1.0, - minx2=0.0, maxx2=1.0, - func=self.wavetilts['func2d']) - _xnspecmin1 = float(nspec - 1) - _xnspatmin1 = float(rawflat.shape[1] - 1) - _tilts_slit = _pypeitFit.eval(_spec / _xnspecmin1, - x2=(_spat + _flexure) / _xnspatmin1) - _tilts_slit = np.fmax(np.fmin(_tilts_slit, 1.2), -0.2) - # Build a full-frame tilts image placeholder with only slit pixels filled - tilts = np.zeros(rawflat.shape, dtype=float) - tilts[onslit_padded] = _tilts_slit - del _tilts_slit, _spec, _spat + tilts = tracewave.fit2tilts(rawflat.shape, + self.wavetilts['coeffs'][:,:,slit_idx], + self.wavetilts['func2d'], + spat_shift=-1*_flexure, + slit_mask=onslit_padded) # Convert the tilt image to an image with the spectral pixel index spec_coo = tilts * (nspec-1) diff --git a/pypeit/wavetilts.py b/pypeit/wavetilts.py index 02a318dffc..b2b849e2d1 100644 --- a/pypeit/wavetilts.py +++ b/pypeit/wavetilts.py @@ -151,9 +151,11 @@ def fit2tiltimg(self, slitmask, flexure=None): slit_idx = self.spatid_to_zero(slit_spat) # Calculate coeff_out = self.coeffs[:self.spec_order[slit_idx]+1,:self.spat_order[slit_idx]+1,slit_idx] - _tilts = tracewave.fit2tilts(final_tilts.shape, coeff_out, self.func2d, spat_shift=-1*_flexure) - # Fill thismask_science = slitmask == slit_spat + _tilts = tracewave.fit2tilts(final_tilts.shape, coeff_out, self.func2d, + spat_shift=-1*_flexure, + slit_mask=thismask_science) + # Fill final_tilts[thismask_science] = _tilts[thismask_science] # Return return final_tilts From 88da475b07a385b65ea82ec2c79e207fb1c95f43 Mon Sep 17 00:00:00 2001 From: "T. E. Pickering" Date: Thu, 2 Apr 2026 14:24:32 -0700 Subject: [PATCH 3/3] apply suggested refactor --- pypeit/core/tracewave.py | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/pypeit/core/tracewave.py b/pypeit/core/tracewave.py index 2de3cedf2b..91d4a7615b 100644 --- a/pypeit/core/tracewave.py +++ b/pypeit/core/tracewave.py @@ -900,24 +900,18 @@ def fit2tilts(shape, coeff2, func2d, spat_shift=None, slit_mask=None): pypeitFit = fitting.PypeItFit(fitc=coeff2, minx=0.0, maxx=1.0, minx2=0.0, maxx2=1.0, func=func2d) - if slit_mask is not None: - # Evaluate only at slit pixels to save memory - spec_pix, spat_pix = np.where(slit_mask) - tilts_vals = pypeitFit.eval(spec_pix / xnspecmin1, - x2=(spat_pix - _spat_shift) / xnspatmin1) - tilts_vals = np.fmax(np.fmin(tilts_vals, 1.2), -0.2) - tilts = np.zeros(shape, dtype=float) - tilts[slit_mask] = tilts_vals - del tilts_vals, spec_pix, spat_pix + if slit_mask is None: + spat_pix, spec_pix = map( + lambda x : x.ravel(), np.meshgrid(np.arange(nspat), np.arange(nspec)) + ) else: - # Full-frame meshgrid evaluation - spec_vec = np.arange(nspec) - spat_vec = np.arange(nspat) - _spat_shift - spat_img, spec_img = np.meshgrid(spat_vec, spec_vec) - tilts = pypeitFit.eval(spec_img / xnspecmin1, x2=spat_img / xnspatmin1) - # Added this to ensure that tilts are never crazy values due to - # extrapolation of fits which can break wavelength solution fitting - tilts = np.fmax(np.fmin(tilts, 1.2), -0.2) + spec_pix, spat_pix = np.where(slit_mask) + + tilts_vals = pypeitFit.eval(spec_pix / xnspecmin1, x2=(spat_pix - _spat_shift) / xnspatmin1) + tilts_vals = np.fmax(np.fmin(tilts_vals, 1.2), -0.2) + tilts = np.zeros(shape, dtype=float) + tilts[(spec_pix,spat_pix)] = tilts_vals + del tilts_vals, spec_pix, spat_pix return tilts