Skip to content

Commit 6169ad6

Browse files
committed
gridflag : Gridding still flaky
Generating the UV grid is still flaky; it seems to work with the "multithreaded" dask scheduler, but not with Cluster().
1 parent 7428d52 commit 6169ad6

2 files changed

Lines changed: 259 additions & 86 deletions

File tree

src/astroviper/_domain/_flagging/_gridflag_histogram.py

Lines changed: 236 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,51 @@
33
"""
44

55
import dask
6+
import time
67

78
import numba as nb
9+
from numba.typed import List
810
import numpy as np
11+
import gc
12+
#from memory_profiler import profile
913

10-
from scipy.stats import median_abs_deviation
1114

15+
@nb.njit(nogil=True, cache=True, fastmath=True)
def _merge_vis_lists(ps_vis_accum, vis_accum, npixu, npixv):
    """
    Append the per-pixel lists of vis_accum onto the matching per-pixel lists
    of ps_vis_accum.

    Inputs:
    ps_vis_accum    : nested list [npixu][npixv] - Accumulator, extended in place
    vis_accum       : nested list [npixu][npixv] - Values to append, per pixel
    npixu           : int - Number of pixels in U
    npixv           : int - Number of pixels in V

    Returns:
    ps_vis_accum    : The accumulator that was passed in, after the merge
    """

    print("start _merge_vis_lists")

    for iu in range(npixu):
        src_row = vis_accum[iu]
        dst_row = ps_vis_accum[iu]

        for iv in range(npixv):
            cell = src_row[iv]

            # Nothing to copy over for an empty pixel
            if len(cell) == 0:
                continue

            dst_row[iv].extend(cell)

    print("end _merge_vis_lists ")

    return ps_vis_accum
28+
29+
30+
#@profile
31+
def accumulate_uv_points(input_params):
1532
"""
1633
Read in the input XDS and calculate a histogram per pixel in the UV plane.
1734
"""
18-
print('Processing task with id: ', input_params['task_id'])
35+
gc.collect()
36+
1937
from xradio.vis.load_processing_set import load_processing_set
2038

2139
uvrange = np.asarray(sorted(input_params['uvrange']))
2240
uvcell = input_params['uvcell']
2341
nhistbin = input_params['nhistbin']
2442
npixu, npixv = input_params['npixels']
2543

26-
accum_uv_hist_med = np.zeros((npixu, npixv))
27-
accum_uv_hist_std = np.zeros((npixu, npixv))
28-
vis_accum = np.empty((npixu, npixv), dtype=object)
44+
ps_vis_accum = List([List([List.empty_list(nb.f8) for y in range(npixv)]) for z in range(npixu)])
45+
46+
print(input_params["data_selection"].items())
2947

3048
for ms_v4_name, slice_description in input_params["data_selection"].items():
3149
if input_params["input_data"] is None:
32-
ps = load_processing_set(
33-
ps_name=input_params["input_data_store"],
50+
ps = load_processing_set(input_params["input_data_store"],
3451
sel_parms={ms_v4_name: slice_description},
3552
)
3653
else:
@@ -39,8 +56,8 @@ def compute_uv_histogram(input_params):
3956
ms_xds = ps.get(0)
4057

4158
#ref_freq = float(ms_xds.frequency.attrs['reference_frequency']['data'])
42-
ref_freq = np.mean(ms_xds.frequency).values
43-
59+
ref_freq = np.mean(ms_xds.frequency).values
60+
4461
min_baseline = ms_xds.baseline_id.min().data
4562
max_baseline = ms_xds.baseline_id.max().data
4663

@@ -52,70 +69,112 @@ def compute_uv_histogram(input_params):
5269
flag = ms_xds.FLAG.data.astype(bool)
5370
freq = ms_xds.frequency.data
5471

55-
# Flip flags and replace NaNs with zeros, so they flag the corresponding visibilities
56-
#flag = np.nan_to_num(~flag).astype(bool)
72+
print(uvw.shape, vis.shape, flag.shape, freq.shape)
73+
74+
def getsize(arr):
75+
return round(arr.nbytes/1024/1024, 2)
76+
77+
print("UVW, Vis, flag, freq in MB", getsize(uvw), getsize(vis), getsize(flag), getsize(freq))
78+
79+
import sys, psutil
80+
process = psutil.Process()
81+
82+
print("Size of PS, ms_xds in bytes ", sys.getsizeof(ps), sys.getsizeof(ms_xds))
83+
print("Total process memory in MB ", process.memory_info().rss/1024/1024)
84+
85+
del ps, ms_xds
86+
gc.collect()
87+
print("Total process memory in MB after GC", process.memory_info().rss/1024/1024)
88+
5789
vis = np.nan_to_num(vis)
58-
# Flag visibilities
59-
#vis = np.asarray(vis*~flag)
90+
# Apply previously computed flags
91+
vis = np.asarray(vis*~flag)
6092

6193
uvw = np.nan_to_num(uvw)
94+
t1 = time.time()
6295
uv_scaled = scale_uv_freq(np.asarray(uvw), np.asarray(freq), ref_freq)
96+
t2 = time.time()
97+
print(f"scale_uv_freq time {t2-t1} s")
6398

64-
# Create a histogram per UV pixel - some might be entirely zeros, with no data.
99+
npt = uv_scaled.reshape([-1,2]).shape[0]
100+
101+
# Create a list of visibilities per UV pixel - some might be entirely zeros, with no data.
65102
# Manually verified that the reshape works for a handful of random indices
66-
uv_histogram(vis_accum, uv_scaled.reshape([-1,2]), vis.reshape([-1,2]), uvrange, uvcell, npixu, npixv)
103+
t1 = time.time()
104+
vis_accum = vis_per_uv_pixel(uv_scaled.reshape([-1,2]), vis.reshape([-1,2]), uvrange, uvcell, npixu, npixv, npt)
105+
t2 = time.time()
106+
print(f"vis_per_uv_pixel time {t2-t1} s")
67107

108+
t1 = time.time()
109+
ps_vis_accum = _merge_vis_lists(ps_vis_accum, vis_accum, npixu, npixv)
110+
t2 = time.time()
111+
print(f"_merge_vis_lists time {t2-t1} s")
68112

69-
return vis_accum
113+
#ps_vis_accum = np.asarray(ps_vis_accum)
114+
t1 = time.time()
115+
uv_med_grid, uv_std_grid, uv_npt_grid = calc_uv_stats(ps_vis_accum, npixu, npixv)
116+
t2 = time.time()
117+
print(f"calc_uv_stats time {t2-t1} s")
70118

119+
return uv_med_grid, uv_std_grid, uv_npt_grid
71120

72-
#@nb.jit(nopython=True, nogil=True, cache=True)
73-
def merge_uv_grids(results, input_parms):
121+
122+
123+
@nb.njit(cache=True, nogil=True, fastmath=True)
def mad_std(data):
    """
    Estimate the standard deviation of the data from its median absolute
    deviation (MAD), i.e. 1.4826 * median(|data - median(data)|).

    Cannot use a "built-in" function like astropy.stats.mad_std because we
    want to call this function inside numba.jit

    Inputs:
    data    : np.array - Data

    Returns:
    std     : float - MAD of the data scaled by 1.4826
    """

    median = np.median(data)
    mad = np.median(np.abs(data - median))

    # 1.4826 is the usual factor that turns the MAD into a standard deviation
    # estimate for normally distributed data.
    return 1.4826 * mad
143+
91144

92-
nchunk = len(results)
93-
accum_uv_hist_med = np.zeros((npixu, npixv))
94-
accum_uv_hist_std = np.zeros((npixu, npixv))
145+
@nb.njit(cache=True, nogil=True, fastmath=True)
def calc_uv_stats(ps_vis_accum, npixu, npixv):
    """
    Calculate the median, the MAD based standard deviation and the number of
    points of the visibilities per UV pixel.

    Inputs:
    ps_vis_accum    : nested list [npixu][npixv] - Visibility values per UV pixel
    npixu           : int - Number of pixels in U
    npixv           : int - Number of pixels in V

    Returns:
    uv_med_grid     : np.array - Median of the values per pixel
    uv_std_grid     : np.array - mad_std() of the values per pixel
    uv_npt_grid     : np.array (int64) - Number of values per pixel

    Pixels without any values are left at zero in all three grids.
    """

    print("start calc_uv_stats")

    uv_med_grid = np.zeros((npixu, npixv))
    uv_std_grid = np.zeros((npixu, npixv))
    uv_npt_grid = np.zeros((npixu, npixv), dtype=np.int64)

    for uu in range(npixu):
        for vv in range(npixv):
            npt = len(ps_vis_accum[uu][vv])

            # Empty pixels keep the zeros the grids were initialised with
            if npt == 0:
                continue

            # Convert the list to an array once and reuse it for both statistics
            pix_vis = np.asarray(ps_vis_accum[uu][vv])

            uv_med_grid[uu, vv] = np.median(pix_vis)
            uv_std_grid[uu, vv] = mad_std(pix_vis)
            uv_npt_grid[uu, vv] = npt

    print("end calc_uv_stats")
    return uv_med_grid, uv_std_grid, uv_npt_grid
119178

120179

121180

@@ -136,37 +195,146 @@ def hermitian_conjugate(uv, vis):
136195

137196

138197
#@profile
139-
#@nb.jit(nopython=True, nogil=True, cache=True)
140-
def uv_histogram(vis_hist,uv, vis, uvrange, uvcell, npixu, npixv):
198+
@nb.jit(nopython=True, nogil=True, cache=True, fastmath=True)
def vis_per_uv_pixel(uv, vis, uvrange, uvcell, npixu, npixv, npt):
    """
    Accumulate list of visibilities per UV pixel

    Inputs:
    uv          : np.array - UV coordinates, indexed as uv[i, 0] (U) and uv[i, 1] (V)
    vis         : np.array - Visibilities; the first and last entries of the
                  last axis are averaged and the absolute value is gridded
    uvrange     : Minimum and maximum UV distance to grid (sorted internally)
    uvcell      : Size of a UV cell, in the same units as uv
    npixu       : int - Number of pixels in U
    npixv       : int - Number of pixels in V
    npt         : int - Not used inside this function, kept for the callers

    Returns:
    vis_accum   : nested list [npixu][npixv] of float - Values falling in each UV pixel
    """

    uvrange = sorted(uvrange)
    uv, vis = hermitian_conjugate(np.asarray(uv), np.asarray(vis))
    nptuv = np.zeros((npixu, npixv))

    # numba hack : Create an empty typed list
    vis_accum = List([List([List([float(x) for x in range(0)]) for y in range(npixv)]) for z in range(npixu)])

    stokesI = np.abs((vis[...,0] + vis[...,-1])/2.)

    for didx, dat in enumerate(stokesI):
        # Exact zeros carry no data and are not gridded
        if dat == 0:
            continue

        uvdist = np.sqrt(uv[didx, 0]**2 + uv[didx, 1]**2)

        if uvdist > uvrange[1] or uvdist < uvrange[0]:
            continue

        ubin = int((uv[didx, 0] + uvrange[1])//uvcell)
        vbin = int(uv[didx, 1]//uvcell)

        # Drop points that land outside the grid. A negative index would
        # silently wrap around to the other side of the grid, and numba does
        # not bounds check array writes by default, so an index past the end
        # would corrupt memory instead of raising.
        # NOTE(review): presumably hermitian_conjugate guarantees V >= 0 - confirm
        if ubin < 0 or ubin >= npixu or vbin < 0 or vbin >= npixv:
            continue

        vis_accum[ubin][vbin].append(dat)
        nptuv[ubin, vbin] += 1

    print("Number of points appended ", np.sum(nptuv))
    print("Approx size ", np.sum(nptuv)*8/1024/1024)
    print("end vis_per_uv_pixel")
    return vis_accum
233+
234+
235+
236+
@nb.njit(nogil=True, cache=True, fastmath=True)
def _accum_means(mean1, npt1, mean2, npt2):
    """
    Combine two mean values into the mean of the pooled sample, weighting
    each input by its number of points.

    Inputs:
    mean1   : float - Mean of the first sample
    npt1    : int - Number of points in the first sample
    mean2   : float - Mean of the second sample
    npt2    : int - Number of points in the second sample

    Returns:
    mean    : float - Mean of the pooled sample, 0 if both samples are empty
    npt     : int - Number of points in the pooled sample
    """

    total = npt1 + npt2

    # Nothing to average, and avoids dividing by zero
    if total == 0:
        return 0, 0

    weighted_sum = mean1*npt1 + mean2*npt2

    return weighted_sum/total, total
259+
260+
261+
@nb.njit(nogil=True, cache=True, fastmath=True)
def _accum_std(mean1, std1, npt1, mean2, std2, npt2):
    """
    Calculate the aggregate standard deviation given two input standard deviation values.

    The pooled variance is the sum of a within-sample term and a term for the
    separation of the two means :

        var = ((npt1 - 1)*std1**2 + (npt2 - 1)*std2**2)/(npt - 1)
              + npt1*npt2*(mean1 - mean2)**2/(npt*(npt - 1))

    with npt = npt1 + npt2.

    Inputs:
    mean1   : float - Mean 1
    std1    : float - Standard deviation 1
    npt1    : int - Number of points for mean 1
    mean2   : float - Mean 2
    std2    : float - Standard deviation 2
    npt2    : int - Number of points for mean 2

    Returns:
    std     : float - Aggregate standard deviation, 0 when fewer than two points are available
    npt     : int - Number of points
    """

    npt = npt1 + npt2

    # With fewer than two points there is no spread to measure (and npt - 1
    # would be zero below). Still report the real number of points.
    if npt < 2:
        return 0.0, npt

    var_within = ((npt1 - 1)*std1**2 + (npt2 - 1)*std2**2)/(npt - 1)
    var_between = ((npt1*npt2) * (mean1 - mean2)**2)/(npt*(npt - 1))

    std = np.sqrt(var_within + var_between)

    return std, npt
289+
290+
291+
#@nb.jit(nopython=True, nogil=True, cache=True)
292+
def merge_uv_grids(graph, input_params):
    """
    Given the list of results from accumulate_uv_points, merge the UV grids together to compute stats

    Any number of input grids is accepted; they are folded one by one onto
    the first one.

    Inputs:
    graph           : sequence of (median, std, npt) - One triplet of UV grids per
                      input node (or per previous call to merge_uv_grids)
    input_params    : dict - Must contain 'npixels', the (npixu, npixv) size of the grids

    Returns:
    accum_uv_med    : np.array - Aggregate of the median grids
    accum_uv_std    : np.array - Aggregate of the standard deviation grids
    uvnpt           : np.array - Number of points per pixel

    Raises:
    IndexError      : If graph is empty
    """

    npixu = input_params['npixels'][0]
    npixv = input_params['npixels'][1]

    # Start from a copy of the first triplet so the inputs are never modified
    accum_uv_med = np.array(graph[0][0], dtype=np.float64)
    accum_uv_std = np.array(graph[0][1], dtype=np.float64)
    uvnpt = np.array(graph[0][2], dtype=np.int64)

    # NOTE(review): the per-node grids hold medians and MAD based standard
    # deviations, but they are combined with the formulas for means and
    # standard deviations, so the merged values are approximate - confirm
    # this is intended.
    for nn in range(1, len(graph)):
        med = graph[nn][0]
        std = graph[nn][1]
        npt = graph[nn][2]

        for uu in range(npixu):
            for vv in range(npixv):
                # Both updates need the values from before this merge step
                prev_med = accum_uv_med[uu, vv]
                prev_npt = uvnpt[uu, vv]

                accum_uv_std[uu, vv], __ = _accum_std(prev_med, accum_uv_std[uu, vv], prev_npt, med[uu, vv], std[uu, vv], npt[uu, vv])
                accum_uv_med[uu, vv], uvnpt[uu, vv] = _accum_means(prev_med, prev_npt, med[uu, vv], npt[uu, vv])

    # NOTE(review): looks like a debugging dump - it is written to the current
    # working directory on every call, so concurrent calls overwrite each other.
    np.savez('accum_uv_med.npz', accum_uv_med, accum_uv_std, uvnpt)
    return accum_uv_med, accum_uv_std, uvnpt
168337

169-
#return vis_hist
170338

171339

172340
@nb.jit(nopython=True, nogil=True, cache=True)
@@ -183,8 +351,8 @@ def scale_uv_freq(uvw, frequency, ref_freq):
183351

184352
for ffidx, ff in enumerate(frequency):
185353
delta_nu = (ff - ref_freq)/ref_freq
186-
uv_scaled[:,:,ffidx,0] = uvw[:,:,0] * (1 + delta_nu/ff)
187-
uv_scaled[:,:,ffidx,1] = uvw[:,:,1] * (1 + delta_nu/ff)
354+
uv_scaled[:,:,ffidx,0] = uvw[:,:,0] * (1 + delta_nu)
355+
uv_scaled[:,:,ffidx,1] = uvw[:,:,1] * (1 + delta_nu)
188356

189357
return uv_scaled
190358

0 commit comments

Comments
 (0)