diff --git a/pypeit/core/tracewave.py b/pypeit/core/tracewave.py index 85c2a6ccc..91d4a7615 100644 --- a/pypeit/core/tracewave.py +++ b/pypeit/core/tracewave.py @@ -858,9 +858,15 @@ def fit_tilts(trc_tilt_dict, thismask, slit_cen, spat_order=3, spec_order=4, max # log.info("RMS/FWHM: {}".format(rms_real/fwhm)) -def fit2tilts(shape, coeff2, func2d, spat_shift=None): +def fit2tilts(shape, coeff2, func2d, spat_shift=None, slit_mask=None): """ - Evaluate the wavelength tilt model over the full image. + Evaluate the wavelength tilt model. + + When ``slit_mask`` is not provided, the model is evaluated over a full + meshgrid spanning the image. When ``slit_mask`` is provided, the model + is evaluated only at the ``True`` pixels in the mask, which avoids + allocating full-frame meshgrid arrays and significantly reduces memory + usage for spectrographs with many slits (e.g., fiber-fed IFUs). Parameters ---------- @@ -871,9 +877,13 @@ def fit2tilts(shape, coeff2, func2d, spat_shift=None): func2d : str the 2d function used to fit the tilts spat_shift : float, optional - Spatial shift to be added to image pixels before evaluation + Spatial shift to be added to image pixels before evaluation. If you are accounting for flexure, then you probably wish to input -1*flexure_shift into this parameter. + slit_mask : `numpy.ndarray`_, bool, optional + Boolean mask with the same shape as the image. If provided, + tilts are evaluated only where ``slit_mask`` is ``True`` and + the returned image is zero elsewhere. Returns ------- @@ -882,22 +892,28 @@ def fit2tilts(shape, coeff2, func2d, spat_shift=None): image. This output is used in the pipeline. """ - # Init _spat_shift = 0. if spat_shift is None else spat_shift - # Compute the tilts image nspec, nspat = shape xnspecmin1 = float(nspec - 1) xnspatmin1 = float(nspat - 1) - spec_vec = np.arange(nspec) - spat_vec = np.arange(nspat) - _spat_shift - spat_img, spec_img = np.meshgrid(spat_vec, spec_vec) - # + pypeitFit = fitting.PypeItFit(fitc=coeff2, minx=0.0, maxx=1.0, minx2=0.0, maxx2=1.0, func=func2d) - tilts = pypeitFit.eval(spec_img / xnspecmin1, x2=spat_img / xnspatmin1) - # Added this to ensure that tilts are never crazy values due to extrapolation of fits which can break - # wavelength solution fitting - return np.fmax(np.fmin(tilts, 1.2), -0.2) + + if slit_mask is None: + spat_pix, spec_pix = map( + lambda x : x.ravel(), np.meshgrid(np.arange(nspat), np.arange(nspec)) + ) + else: + spec_pix, spat_pix = np.where(slit_mask) + + tilts_vals = pypeitFit.eval(spec_pix / xnspecmin1, x2=(spat_pix - _spat_shift) / xnspatmin1) + tilts_vals = np.fmax(np.fmin(tilts_vals, 1.2), -0.2) + tilts = np.zeros(shape, dtype=float) + tilts[(spec_pix,spat_pix)] = tilts_vals + del tilts_vals, spec_pix, spat_pix + + return tilts # This method needs to match the name in pypeit.core.qa.set_qa_filename() diff --git a/pypeit/flatfield.py b/pypeit/flatfield.py index 68b194d89..7d36973b5 100644 --- a/pypeit/flatfield.py +++ b/pypeit/flatfield.py @@ -995,8 +995,11 @@ def fit(self, spat_illum_only=False, doqa=True, debug=False): else: # TODO -- JFH Confirm the sign of this shift is correct! _flexure = 0. if self.wavetilts.spat_flexure is None else self.wavetilts.spat_flexure - tilts = tracewave.fit2tilts(rawflat.shape, self.wavetilts['coeffs'][:,:,slit_idx], - self.wavetilts['func2d'], spat_shift=-1*_flexure) + tilts = tracewave.fit2tilts(rawflat.shape, + self.wavetilts['coeffs'][:,:,slit_idx], + self.wavetilts['func2d'], + spat_shift=-1*_flexure, + slit_mask=onslit_padded) # Convert the tilt image to an image with the spectral pixel index spec_coo = tilts * (nspec-1) diff --git a/pypeit/wavetilts.py b/pypeit/wavetilts.py index 02a318dff..b2b849e2d 100644 --- a/pypeit/wavetilts.py +++ b/pypeit/wavetilts.py @@ -151,9 +151,11 @@ def fit2tiltimg(self, slitmask, flexure=None): slit_idx = self.spatid_to_zero(slit_spat) # Calculate coeff_out = self.coeffs[:self.spec_order[slit_idx]+1,:self.spat_order[slit_idx]+1,slit_idx] - _tilts = tracewave.fit2tilts(final_tilts.shape, coeff_out, self.func2d, spat_shift=-1*_flexure) - # Fill thismask_science = slitmask == slit_spat + _tilts = tracewave.fit2tilts(final_tilts.shape, coeff_out, self.func2d, + spat_shift=-1*_flexure, + slit_mask=thismask_science) + # Fill final_tilts[thismask_science] = _tilts[thismask_science] # Return return final_tilts