Skip to content

Commit ac1478f

Browse files
committed
First version of cvelling with for ms v4 xds. TOPO to frame to start with
1 parent 8ff589d commit ac1478f

1 file changed

Lines changed: 231 additions & 0 deletions

File tree

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
# -*- coding: utf-8 -*-
2+
import xradio
3+
import xarray
4+
import copy
5+
from typing import List
6+
from astropy.coordinates import SpectralCoord, EarthLocation, SkyCoord
7+
from astropy.time import Time
8+
import astropy.units as u
9+
import numpy as np
10+
from scipy.interpolate import interp1d
11+
import warnings
12+
13+
warnings.filterwarnings(
14+
"ignore",
15+
category=UserWarning,
16+
module="astropy.coordinates.spectral_coordinate",
17+
)
18+
np.set_printoptions(precision=12)
19+
20+
21+
def _ms_spectral_frame_conversion(
22+
ms: xarray.core.datatree.DataTree,
23+
freqrange: List[float] = [],
24+
outframe: str = "LSRK",
25+
) -> xarray.core.datatree.DataTree:
26+
"""
27+
28+
29+
Parameters
30+
----------
31+
ms_xds : xarray.core.datatree.DataTree
32+
DESCRIPTION. This is the input ms v4 xds that is to be transformed to
33+
a new spectral frame
34+
freqrange : List(float, float)
35+
DESCRIPTION. Selection of subset of channels falling in the
36+
freqrange to transform
37+
38+
Returns
39+
-------
40+
ms v4 xds
41+
42+
"""
43+
44+
# outms = ms.copy(inherit=False, deep=True)
45+
outms = copy.deepcopy(ms)
46+
# need to do selection over frequency here
47+
pc = ms.ms.get_field_and_source_xds().FIELD_PHASE_CENTER
48+
pcdir = [d * u.Unit(e) for d, e in zip(pc.data[0], pc.units)]
49+
phcen = SkyCoord(pcdir[0], pcdir[1], frame=pc.frame)
50+
locATt = _get_all_itrs_loc(ms)
51+
obsfreq = ms.frequency.data * u.Unit(ms.frequency.attrs["units"][0])
52+
newFreqInFrame = _outframe_freq(outms, outframe)
53+
##outMSFreq = outms.frequency.assign_coords(frequency=newFreqInFrame) buggy
54+
outMSFreq = xarray.DataArray(
55+
data=newFreqInFrame,
56+
dims=outms.frequency.dims,
57+
coords=outms.frequency.coords,
58+
attrs=outms.frequency.attrs,
59+
)
60+
outMSFreq.attrs["observer"] = outframe
61+
# cannot assign the values or array directly
62+
# for k in range(len(newFreqInFrame)):
63+
# outMSFreq.values[k] = newFreqInFrame[k] # wopuld rather assign by coordinate value
64+
outms.frequency = outMSFreq
65+
print(obsfreq.value - outms.frequency.data, newFreqInFrame, obsfreq.value)
66+
# for aTLoc in locATt:
67+
_interpolate_data_weight_from_TOPO(outms, obsfreq)
68+
69+
return outms
70+
71+
72+
def _outframe_freq(ms: xarray.core.datatree.DataTree, outframe: str = "lsrk"):
73+
"""
74+
Function to get a set of uniformly spaced frequencies that covers the range of frequencies
75+
in the ms over time in the outframe requested
76+
Parameters
77+
----------
78+
ms : xarray.core.datatree.DataTree
79+
DESCRIPTION.
80+
outframe : str, optional
81+
For now accepts lsrk, lsrd, bary, galacto(centric).
82+
The default is "lsrk".
83+
Returns
84+
-------
85+
outfreqs : numpy array
86+
The nchannels in the ms frequencies in the outframe requested
87+
88+
"""
89+
# loc_itrs = obs.get_itrs(obstime=t)
90+
obsframe = ms["frequency"].attrs["observer"]
91+
if "TOPO" not in obsframe:
92+
raise Exception("This function works with TOPO only for now")
93+
obsfreq = ms["frequency"].data * u.Unit(ms.frequency.attrs["units"][0])
94+
# print("_outframe_freq originfreq", obsfreq)
95+
nchan = len(obsfreq)
96+
97+
locATt = _get_all_itrs_loc(ms)
98+
# There should be only one field name in each ms
99+
# should we test for uniqueness ?
100+
fldname = ms.field_name.data[0]
101+
phcen = _get_phase_center(ms, fldname)
102+
maxFreq = -1
103+
minFreq = 1e12
104+
for a_loc in locATt:
105+
frameSpec = SpectralCoord(
106+
obsfreq, observer=a_loc, target=phcen
107+
).with_observer_stationary_relative_to(_frame_from_str(outframe))
108+
freqATt = frameSpec.quantity.value
109+
minFreq = min(minFreq, np.min(freqATt))
110+
maxFreq = max(maxFreq, np.max(freqATt))
111+
outfreqs = np.zeros([nchan])
112+
113+
if nchan > 1:
114+
width = maxFreq - minFreq
115+
outfreqs = np.vectorize(lambda k: k * width + minFreq)(
116+
np.arange(nchan)
117+
)
118+
else:
119+
outfreqs = np.array([minFreq])
120+
# print("minfreq ", minFreq, " maxfreq ", maxFreq)
121+
# print("difference obsfreq ...new frame freq ", obsfreq.value - outfreqs)
122+
return outfreqs
123+
124+
125+
def _get_phase_center(
126+
ms: xarray.core.datatree.DataTree, fieldname: str
127+
) -> SkyCoord:
128+
pc = ms.ms.get_field_and_source_xds().FIELD_PHASE_CENTER.sel(
129+
field_name=fieldname
130+
)
131+
pcdir = [d * u.Unit(e) for d, e in zip(pc.data, pc.attrs["units"])]
132+
phcen = SkyCoord(pcdir[0], pcdir[1], frame=pc.attrs["frame"])
133+
return phcen
134+
135+
136+
def _frame_from_str(framestr: str):
137+
"""
138+
Tries to interprete string frame a la casa definition and return the
139+
appropriate astropy frame
140+
"""
141+
fr = ""
142+
if framestr.lower() == "lsrk":
143+
from astropy.coordinates import LSRK
144+
145+
fr = LSRK
146+
elif framestr.lower() == "lsrd":
147+
from astropy.coordinates import LSRD
148+
149+
fr = LSRD
150+
elif "galacto" in framestr.lower():
151+
from astropy.coordinates import Galactocentric
152+
153+
fr = Galactocentric
154+
elif "bary" in framestr.lower():
155+
from astropy.coordinates import BarycentricTrueEcliptic
156+
157+
fr = BarycentricTrueEcliptic
158+
else:
159+
raise (f"Don't know the frame {framestr}")
160+
return fr
161+
162+
163+
def _get_all_itrs_loc(ms: xarray.core.datatree.DataTree):
164+
"""This returns an astropy itrs location at all unique times
165+
in the ms
166+
"""
167+
168+
obsstr = ms["antenna_xds"].attrs["overall_telescope_name"]
169+
if obsstr == "EVLA":
170+
obsstr = "VLA"
171+
elif obsstr == "ATA":
172+
obsstr = "ALMA"
173+
174+
obs = EarthLocation.of_site(obsstr)
175+
obs_t = ms["time"].data * u.Unit(ms["time"].attrs["units"][0])
176+
t = Time(
177+
obs_t,
178+
format=ms["time"].attrs["format"],
179+
scale=ms["time"].attrs["scale"],
180+
)
181+
locATt = obs.get_itrs(obstime=t)
182+
return locATt
183+
184+
185+
def _interpolate_data_weight_from_TOPO(
186+
ms: xarray.core.datatree.DataTree, origfreq: u.quantity.Quantity
187+
):
188+
"""
189+
interpolate the visibility (and weights) for every time stamp in the frame of the ms
190+
The outframe is assumed to have already been assigned to the
191+
ms.frequency attributes
192+
193+
"""
194+
infreq = origfreq.to(u.Hz).value
195+
interpfreq = ms.frequency.data
196+
locATt = _get_all_itrs_loc(ms)
197+
fldname = ms.field_name.data[0]
198+
phcen = _get_phase_center(ms, fldname)
199+
outframe = ms.frequency.attrs["observer"]
200+
for a_loc in locATt:
201+
frameSpec = SpectralCoord(
202+
origfreq, observer=a_loc, target=phcen
203+
).with_observer_stationary_relative_to(_frame_from_str(outframe))
204+
freqATt = frameSpec.quantity.to(u.Hz).value
205+
elvis = ms.VISIBILITY[a_loc.obstime.value]
206+
elwgt = ms.WEIGHT[a_loc.obstime.value]
207+
elflg = ms.FLAG[a_loc.obstime.value]
208+
elwgt = elwgt * np.logical_not(elflg)
209+
_interp_channels(elvis, elwgt, freqATt, interpfreq)
210+
211+
212+
def _interp_channels(data, weights, datafreq, interpfreq):
213+
wgtdata = data.data * weights.data
214+
for b in range(data.shape[0]):
215+
for p in range(data.shape[2]):
216+
fintd = interp1d(
217+
datafreq,
218+
wgtdata[b, :, p],
219+
kind="linear",
220+
fill_value="extrapolate",
221+
)
222+
wgtdata[b, :, p] = fintd(interpfreq)
223+
fintw = interp1d(
224+
datafreq,
225+
weights[b, :, p],
226+
kind="linear",
227+
fill_value="extrapolate",
228+
)
229+
weights[b, :, p] = fintw(interpfreq)
230+
data[weights != 0] = wgtdata[weights != 0] / weights[weights != 0]
231+
data[weights == 0] = 0.0

0 commit comments

Comments
 (0)