|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +import xradio |
| 3 | +import xarray |
| 4 | +import copy |
| 5 | +from typing import List |
| 6 | +from astropy.coordinates import SpectralCoord, EarthLocation, SkyCoord |
| 7 | +from astropy.time import Time |
| 8 | +import astropy.units as u |
| 9 | +import numpy as np |
| 10 | +from scipy.interpolate import interp1d |
| 11 | +import warnings |
| 12 | + |
| 13 | +warnings.filterwarnings( |
| 14 | + "ignore", |
| 15 | + category=UserWarning, |
| 16 | + module="astropy.coordinates.spectral_coordinate", |
| 17 | +) |
| 18 | +np.set_printoptions(precision=12) |
| 19 | + |
| 20 | + |
| 21 | +def _ms_spectral_frame_conversion( |
| 22 | + ms: xarray.core.datatree.DataTree, |
| 23 | + freqrange: List[float] = [], |
| 24 | + outframe: str = "LSRK", |
| 25 | +) -> xarray.core.datatree.DataTree: |
| 26 | + """ |
| 27 | +
|
| 28 | +
|
| 29 | + Parameters |
| 30 | + ---------- |
| 31 | + ms_xds : xarray.core.datatree.DataTree |
| 32 | + DESCRIPTION. This is the input ms v4 xds that is to be transformed to |
| 33 | + a new spectral frame |
| 34 | + freqrange : List(float, float) |
| 35 | + DESCRIPTION. Selection of subset of channels falling in the |
| 36 | + freqrange to transform |
| 37 | +
|
| 38 | + Returns |
| 39 | + ------- |
| 40 | + ms v4 xds |
| 41 | +
|
| 42 | + """ |
| 43 | + |
| 44 | + # outms = ms.copy(inherit=False, deep=True) |
| 45 | + outms = copy.deepcopy(ms) |
| 46 | + # need to do selection over frequency here |
| 47 | + pc = ms.ms.get_field_and_source_xds().FIELD_PHASE_CENTER |
| 48 | + pcdir = [d * u.Unit(e) for d, e in zip(pc.data[0], pc.units)] |
| 49 | + phcen = SkyCoord(pcdir[0], pcdir[1], frame=pc.frame) |
| 50 | + locATt = _get_all_itrs_loc(ms) |
| 51 | + obsfreq = ms.frequency.data * u.Unit(ms.frequency.attrs["units"][0]) |
| 52 | + newFreqInFrame = _outframe_freq(outms, outframe) |
| 53 | + ##outMSFreq = outms.frequency.assign_coords(frequency=newFreqInFrame) buggy |
| 54 | + outMSFreq = xarray.DataArray( |
| 55 | + data=newFreqInFrame, |
| 56 | + dims=outms.frequency.dims, |
| 57 | + coords=outms.frequency.coords, |
| 58 | + attrs=outms.frequency.attrs, |
| 59 | + ) |
| 60 | + outMSFreq.attrs["observer"] = outframe |
| 61 | + # cannot assign the values or array directly |
| 62 | + # for k in range(len(newFreqInFrame)): |
| 63 | + # outMSFreq.values[k] = newFreqInFrame[k] # wopuld rather assign by coordinate value |
| 64 | + outms.frequency = outMSFreq |
| 65 | + print(obsfreq.value - outms.frequency.data, newFreqInFrame, obsfreq.value) |
| 66 | + # for aTLoc in locATt: |
| 67 | + _interpolate_data_weight_from_TOPO(outms, obsfreq) |
| 68 | + |
| 69 | + return outms |
| 70 | + |
| 71 | + |
| 72 | +def _outframe_freq(ms: xarray.core.datatree.DataTree, outframe: str = "lsrk"): |
| 73 | + """ |
| 74 | + Function to get a set of uniformly spaced frequencies that covers the range of frequencies |
| 75 | + in the ms over time in the outframe requested |
| 76 | + Parameters |
| 77 | + ---------- |
| 78 | + ms : xarray.core.datatree.DataTree |
| 79 | + DESCRIPTION. |
| 80 | + outframe : str, optional |
| 81 | + For now accepts lsrk, lsrd, bary, galacto(centric). |
| 82 | + The default is "lsrk". |
| 83 | + Returns |
| 84 | + ------- |
| 85 | + outfreqs : numpy array |
| 86 | + The nchannels in the ms frequencies in the outframe requested |
| 87 | +
|
| 88 | + """ |
| 89 | + # loc_itrs = obs.get_itrs(obstime=t) |
| 90 | + obsframe = ms["frequency"].attrs["observer"] |
| 91 | + if "TOPO" not in obsframe: |
| 92 | + raise Exception("This function works with TOPO only for now") |
| 93 | + obsfreq = ms["frequency"].data * u.Unit(ms.frequency.attrs["units"][0]) |
| 94 | + # print("_outframe_freq originfreq", obsfreq) |
| 95 | + nchan = len(obsfreq) |
| 96 | + |
| 97 | + locATt = _get_all_itrs_loc(ms) |
| 98 | + # There should be only one field name in each ms |
| 99 | + # should we test for uniqueness ? |
| 100 | + fldname = ms.field_name.data[0] |
| 101 | + phcen = _get_phase_center(ms, fldname) |
| 102 | + maxFreq = -1 |
| 103 | + minFreq = 1e12 |
| 104 | + for a_loc in locATt: |
| 105 | + frameSpec = SpectralCoord( |
| 106 | + obsfreq, observer=a_loc, target=phcen |
| 107 | + ).with_observer_stationary_relative_to(_frame_from_str(outframe)) |
| 108 | + freqATt = frameSpec.quantity.value |
| 109 | + minFreq = min(minFreq, np.min(freqATt)) |
| 110 | + maxFreq = max(maxFreq, np.max(freqATt)) |
| 111 | + outfreqs = np.zeros([nchan]) |
| 112 | + |
| 113 | + if nchan > 1: |
| 114 | + width = maxFreq - minFreq |
| 115 | + outfreqs = np.vectorize(lambda k: k * width + minFreq)( |
| 116 | + np.arange(nchan) |
| 117 | + ) |
| 118 | + else: |
| 119 | + outfreqs = np.array([minFreq]) |
| 120 | + # print("minfreq ", minFreq, " maxfreq ", maxFreq) |
| 121 | + # print("difference obsfreq ...new frame freq ", obsfreq.value - outfreqs) |
| 122 | + return outfreqs |
| 123 | + |
| 124 | + |
| 125 | +def _get_phase_center( |
| 126 | + ms: xarray.core.datatree.DataTree, fieldname: str |
| 127 | +) -> SkyCoord: |
| 128 | + pc = ms.ms.get_field_and_source_xds().FIELD_PHASE_CENTER.sel( |
| 129 | + field_name=fieldname |
| 130 | + ) |
| 131 | + pcdir = [d * u.Unit(e) for d, e in zip(pc.data, pc.attrs["units"])] |
| 132 | + phcen = SkyCoord(pcdir[0], pcdir[1], frame=pc.attrs["frame"]) |
| 133 | + return phcen |
| 134 | + |
| 135 | + |
| 136 | +def _frame_from_str(framestr: str): |
| 137 | + """ |
| 138 | + Tries to interprete string frame a la casa definition and return the |
| 139 | + appropriate astropy frame |
| 140 | + """ |
| 141 | + fr = "" |
| 142 | + if framestr.lower() == "lsrk": |
| 143 | + from astropy.coordinates import LSRK |
| 144 | + |
| 145 | + fr = LSRK |
| 146 | + elif framestr.lower() == "lsrd": |
| 147 | + from astropy.coordinates import LSRD |
| 148 | + |
| 149 | + fr = LSRD |
| 150 | + elif "galacto" in framestr.lower(): |
| 151 | + from astropy.coordinates import Galactocentric |
| 152 | + |
| 153 | + fr = Galactocentric |
| 154 | + elif "bary" in framestr.lower(): |
| 155 | + from astropy.coordinates import BarycentricTrueEcliptic |
| 156 | + |
| 157 | + fr = BarycentricTrueEcliptic |
| 158 | + else: |
| 159 | + raise (f"Don't know the frame {framestr}") |
| 160 | + return fr |
| 161 | + |
| 162 | + |
| 163 | +def _get_all_itrs_loc(ms: xarray.core.datatree.DataTree): |
| 164 | + """This returns an astropy itrs location at all unique times |
| 165 | + in the ms |
| 166 | + """ |
| 167 | + |
| 168 | + obsstr = ms["antenna_xds"].attrs["overall_telescope_name"] |
| 169 | + if obsstr == "EVLA": |
| 170 | + obsstr = "VLA" |
| 171 | + elif obsstr == "ATA": |
| 172 | + obsstr = "ALMA" |
| 173 | + |
| 174 | + obs = EarthLocation.of_site(obsstr) |
| 175 | + obs_t = ms["time"].data * u.Unit(ms["time"].attrs["units"][0]) |
| 176 | + t = Time( |
| 177 | + obs_t, |
| 178 | + format=ms["time"].attrs["format"], |
| 179 | + scale=ms["time"].attrs["scale"], |
| 180 | + ) |
| 181 | + locATt = obs.get_itrs(obstime=t) |
| 182 | + return locATt |
| 183 | + |
| 184 | + |
| 185 | +def _interpolate_data_weight_from_TOPO( |
| 186 | + ms: xarray.core.datatree.DataTree, origfreq: u.quantity.Quantity |
| 187 | +): |
| 188 | + """ |
| 189 | + interpolate the visibility (and weights) for every time stamp in the frame of the ms |
| 190 | + The outframe is assumed to have already been assigned to the |
| 191 | + ms.frequency attributes |
| 192 | +
|
| 193 | + """ |
| 194 | + infreq = origfreq.to(u.Hz).value |
| 195 | + interpfreq = ms.frequency.data |
| 196 | + locATt = _get_all_itrs_loc(ms) |
| 197 | + fldname = ms.field_name.data[0] |
| 198 | + phcen = _get_phase_center(ms, fldname) |
| 199 | + outframe = ms.frequency.attrs["observer"] |
| 200 | + for a_loc in locATt: |
| 201 | + frameSpec = SpectralCoord( |
| 202 | + origfreq, observer=a_loc, target=phcen |
| 203 | + ).with_observer_stationary_relative_to(_frame_from_str(outframe)) |
| 204 | + freqATt = frameSpec.quantity.to(u.Hz).value |
| 205 | + elvis = ms.VISIBILITY[a_loc.obstime.value] |
| 206 | + elwgt = ms.WEIGHT[a_loc.obstime.value] |
| 207 | + elflg = ms.FLAG[a_loc.obstime.value] |
| 208 | + elwgt = elwgt * np.logical_not(elflg) |
| 209 | + _interp_channels(elvis, elwgt, freqATt, interpfreq) |
| 210 | + |
| 211 | + |
| 212 | +def _interp_channels(data, weights, datafreq, interpfreq): |
| 213 | + wgtdata = data.data * weights.data |
| 214 | + for b in range(data.shape[0]): |
| 215 | + for p in range(data.shape[2]): |
| 216 | + fintd = interp1d( |
| 217 | + datafreq, |
| 218 | + wgtdata[b, :, p], |
| 219 | + kind="linear", |
| 220 | + fill_value="extrapolate", |
| 221 | + ) |
| 222 | + wgtdata[b, :, p] = fintd(interpfreq) |
| 223 | + fintw = interp1d( |
| 224 | + datafreq, |
| 225 | + weights[b, :, p], |
| 226 | + kind="linear", |
| 227 | + fill_value="extrapolate", |
| 228 | + ) |
| 229 | + weights[b, :, p] = fintw(interpfreq) |
| 230 | + data[weights != 0] = wgtdata[weights != 0] / weights[weights != 0] |
| 231 | + data[weights == 0] = 0.0 |
0 commit comments