33"""
44
55import dask
6+ import time
67
78import numba as nb
9+ from numba .typed import List
810import numpy as np
11+ import gc
12+ #from memory_profiler import profile
913
10- from scipy .stats import median_abs_deviation
1114
@nb.njit(nogil=True, cache=True, fastmath=True)
def _merge_vis_lists(ps_vis_accum, vis_accum, npixu, npixv):
    """
    Append the per-pixel visibilities held in vis_accum onto ps_vis_accum.

    Inputs:
    ps_vis_accum    : nested list - Running accumulator indexed as [u][v]; extended in place
    vis_accum       : nested list - Visibilities to merge in, indexed as [u][v]
    npixu           : int - Number of pixels in U
    npixv           : int - Number of pixels in V

    Returns:
    ps_vis_accum    : nested list - The accumulator that was passed in
    """

    print("start _merge_vis_lists")

    for iu in range(npixu):
        incoming_row = vis_accum[iu]
        target_row = ps_vis_accum[iu]

        for iv in range(npixv):
            incoming = incoming_row[iv]

            # Nothing to merge for a pixel that received no points
            if len(incoming) == 0:
                continue

            target_row[iv].extend(incoming)

    print("end _merge_vis_lists ")

    return ps_vis_accum
28+
29+
30+ #@profile
31+ def accumulate_uv_points (input_params ):
1532 """
1633 Read in the input XDS and calculate a histogram per pixel in the UV plane.
1734 """
18- print ('Processing task with id: ' , input_params ['task_id' ])
35+ gc .collect ()
36+
1937 from xradio .vis .load_processing_set import load_processing_set
2038
2139 uvrange = np .asarray (sorted (input_params ['uvrange' ]))
2240 uvcell = input_params ['uvcell' ]
2341 nhistbin = input_params ['nhistbin' ]
2442 npixu , npixv = input_params ['npixels' ]
2543
26- accum_uv_hist_med = np . zeros (( npixu , npixv ))
27- accum_uv_hist_std = np . zeros (( npixu , npixv ))
28- vis_accum = np . empty (( npixu , npixv ), dtype = object )
44+ ps_vis_accum = List ([ List ([ List . empty_list ( nb . f8 ) for y in range ( npixv )]) for z in range ( npixu )] )
45+
46+ print ( input_params [ "data_selection" ]. items () )
2947
3048 for ms_v4_name , slice_description in input_params ["data_selection" ].items ():
3149 if input_params ["input_data" ] is None :
32- ps = load_processing_set (
33- ps_name = input_params ["input_data_store" ],
50+ ps = load_processing_set (input_params ["input_data_store" ],
3451 sel_parms = {ms_v4_name : slice_description },
3552 )
3653 else :
@@ -39,8 +56,8 @@ def compute_uv_histogram(input_params):
3956 ms_xds = ps .get (0 )
4057
4158 #ref_freq = float(ms_xds.frequency.attrs['reference_frequency']['data'])
42- ref_freq = np .mean (ms_xds .frequency ).values
43-
59+ ref_freq = np .mean (ms_xds .frequency ).values
60+
4461 min_baseline = ms_xds .baseline_id .min ().data
4562 max_baseline = ms_xds .baseline_id .max ().data
4663
@@ -52,70 +69,112 @@ def compute_uv_histogram(input_params):
5269 flag = ms_xds .FLAG .data .astype (bool )
5370 freq = ms_xds .frequency .data
5471
55- # Flip flags and replace NaNs with zeros, so they flag the corresponding visibilities
56- #flag = np.nan_to_num(~flag).astype(bool)
72+ print (uvw .shape , vis .shape , flag .shape , freq .shape )
73+
74+ def getsize (arr ):
75+ return round (arr .nbytes / 1024 / 1024 , 2 )
76+
77+ print ("UVW, Vis, flag, freq in MB" , getsize (uvw ), getsize (vis ), getsize (flag ), getsize (freq ))
78+
79+ import sys , psutil
80+ process = psutil .Process ()
81+
82+ print ("Size of PS, ms_xds in bytes " , sys .getsizeof (ps ), sys .getsizeof (ms_xds ))
83+ print ("Total process memory in MB " , process .memory_info ().rss / 1024 / 1024 )
84+
85+ del ps , ms_xds
86+ gc .collect ()
87+ print ("Total process memory in MB after GC" , process .memory_info ().rss / 1024 / 1024 )
88+
5789 vis = np .nan_to_num (vis )
58- # Flag visibilities
59- # vis = np.asarray(vis*~flag)
90+ # Apply previously computed flags
91+ vis = np .asarray (vis * ~ flag )
6092
6193 uvw = np .nan_to_num (uvw )
94+ t1 = time .time ()
6295 uv_scaled = scale_uv_freq (np .asarray (uvw ), np .asarray (freq ), ref_freq )
96+ t2 = time .time ()
97+ print (f"scale_uv_freq time { t2 - t1 } s" )
6398
64- # Create a histogram per UV pixel - some might be entirely zeros, with no data.
99+ npt = uv_scaled .reshape ([- 1 ,2 ]).shape [0 ]
100+
101+ # Create a list of visibilities per UV pixel - some might be entirely zeros, with no data.
65102 # Manually verified that the reshape works for a handful of random indices
66- uv_histogram (vis_accum , uv_scaled .reshape ([- 1 ,2 ]), vis .reshape ([- 1 ,2 ]), uvrange , uvcell , npixu , npixv )
103+ t1 = time .time ()
104+ vis_accum = vis_per_uv_pixel (uv_scaled .reshape ([- 1 ,2 ]), vis .reshape ([- 1 ,2 ]), uvrange , uvcell , npixu , npixv , npt )
105+ t2 = time .time ()
106+ print (f"vis_per_uv_pixel time { t2 - t1 } s" )
67107
108+ t1 = time .time ()
109+ ps_vis_accum = _merge_vis_lists (ps_vis_accum , vis_accum , npixu , npixv )
110+ t2 = time .time ()
111+ print (f"_merge_vis_lists time { t2 - t1 } s" )
68112
69- return vis_accum
113+ #ps_vis_accum = np.asarray(ps_vis_accum)
114+ t1 = time .time ()
115+ uv_med_grid , uv_std_grid , uv_npt_grid = calc_uv_stats (ps_vis_accum , npixu , npixv )
116+ t2 = time .time ()
117+ print (f"calc_uv_stats time { t2 - t1 } s" )
70118
119+ return uv_med_grid , uv_std_grid , uv_npt_grid
71120
72- #@nb.jit(nopython=True, nogil=True, cache=True)
73- def merge_uv_grids (results , input_parms ):
121+
122+
@nb.njit(cache=True, nogil=True, fastmath=True)
def mad_std(data):
    """
    Calculate a robust estimate of the standard deviation of the data from its
    median absolute deviation (MAD). Cannot use a "built-in" function like
    astropy.stats.mad_std because we want to call this function inside numba.jit

    Inputs:
    data    : np.array - Data

    Returns:
    std     : float - 1.4826 * median(|data - median(data)|)
    """

    median = np.median(data)
    mad = np.median(np.abs(data - median))

    # 1.4826 is the usual factor that scales the MAD to the standard deviation
    # of normally distributed data
    std = 1.4826 * mad

    return std
143+
91144
92- nchunk = len (results )
93- accum_uv_hist_med = np .zeros ((npixu , npixv ))
94- accum_uv_hist_std = np .zeros ((npixu , npixv ))
@nb.njit(cache=True, nogil=True, fastmath=True)
def calc_uv_stats(ps_vis_accum, npixu, npixv):
    """
    Calculate the median, the MAD-based standard deviation and the number of
    points of the visibilities in each UV pixel.

    Inputs:
    ps_vis_accum    : nested list - Visibilities per UV pixel, indexed as [u][v]
    npixu           : int - Number of pixels in U
    npixv           : int - Number of pixels in V

    Returns:
    uv_med_grid     : np.array - Median of the visibilities per pixel
    uv_std_grid     : np.array - mad_std of the visibilities per pixel
    uv_npt_grid     : np.array - Number of visibilities per pixel (int64)

    Pixels with no visibilities are left at zero in all three grids.
    """

    print("start calc_uv_stats")

    uv_med_grid = np.zeros((npixu, npixv))
    uv_std_grid = np.zeros((npixu, npixv))
    uv_npt_grid = np.zeros((npixu, npixv), dtype=np.int64)

    for uu in range(npixu):
        for vv in range(npixv):
            pixel_vis = ps_vis_accum[uu][vv]
            npt = len(pixel_vis)

            # Empty pixels keep the zeros the grids were initialised with
            if npt == 0:
                continue

            # Convert the list to an array once and reuse it for both statistics
            pixel_arr = np.asarray(pixel_vis)

            uv_med_grid[uu, vv] = np.median(pixel_arr)
            uv_std_grid[uu, vv] = mad_std(pixel_arr)
            uv_npt_grid[uu, vv] = npt

    print("end calc_uv_stats")
    return uv_med_grid, uv_std_grid, uv_npt_grid
119178
120179
121180
@@ -136,37 +195,146 @@ def hermitian_conjugate(uv, vis):
136195
137196
138197#@profile
139- # @nb.jit(nopython=True, nogil=True, cache=True)
140- def uv_histogram ( vis_hist , uv , vis , uvrange , uvcell , npixu , npixv ):
@nb.jit(nopython=True, nogil=True, cache=True, fastmath=True)
def vis_per_uv_pixel(uv, vis, uvrange, uvcell, npixu, npixv, npt):
    """
    Accumulate list of visibilities per UV pixel

    Inputs:
    uv          : np.array - UV coordinates, indexed as [point, 0] for U and [point, 1] for V
    vis         : np.array - Visibilities; the first and last entries along the last axis are averaged
    uvrange     : sequence - UV distance limits; sorted here so [0] is the minimum and [1] the maximum
    uvcell      : float - Cell size used to convert a UV coordinate into a pixel index
    npixu       : int - Number of pixels in U
    npixv       : int - Number of pixels in V
    npt         : int - Not used by this function; kept so existing callers keep working

    Returns:
    vis_accum   : nested list - Visibility amplitudes per UV pixel, indexed as [u][v]
    """

    uvrange = sorted(uvrange)
    uv, vis = hermitian_conjugate(np.asarray(uv), np.asarray(vis))
    nptuv = np.zeros((npixu, npixv))

    # numba hack : Create an empty typed list
    vis_accum = List([List([List([float(x) for x in range(0)]) for y in range(npixv)]) for z in range(npixu)])

    stokesI = np.abs((vis[..., 0] + vis[..., -1])/2.)

    for didx, dat in enumerate(stokesI):
        # Zero amplitudes are treated as "no data" and never accumulated
        if dat == 0:
            continue

        uvdist = np.sqrt(uv[didx, 0]**2 + uv[didx, 1]**2)

        if uvdist > uvrange[1] or uvdist < uvrange[0]:
            continue

        ubin = int((uv[didx, 0] + uvrange[1])//uvcell)
        vbin = int(uv[didx, 1]//uvcell)

        # Skip points whose bin falls outside the grid: an index past the end
        # would raise, and a negative index would silently wrap around to the
        # opposite side of the grid.
        # NOTE(review): assumes the grid is sized so that only edge points land
        # here - confirm against how npixu/npixv are derived by the caller.
        if ubin < 0 or ubin >= npixu or vbin < 0 or vbin >= npixv:
            continue

        vis_accum[ubin][vbin].append(dat)
        nptuv[ubin, vbin] += 1

    print("Number of points appended ", np.sum(nptuv))
    print("Approx size ", np.sum(nptuv)*8/1024/1024)
    print("end vis_per_uv_pixel")
    return vis_accum
233+
234+
235+
@nb.njit(nogil=True, cache=True, fastmath=True)
def _accum_means(mean1, npt1, mean2, npt2):
    """
    Combine two mean values into the mean of the pooled sample, weighting each
    input by its number of points.

    Inputs:
    mean1   : float - Mean 1
    npt1    : int - Number of points for mean 1
    mean2   : float - Mean 2
    npt2    : int - Number of points for mean 2

    Returns:
    mean    : float - Aggregate mean (0 when there are no points at all)
    npt     : int - Total number of points
    """

    total = npt1 + npt2

    # With no points on either side there is nothing to average
    if total == 0:
        return 0, 0

    weighted_sum = mean1*npt1 + mean2*npt2

    return weighted_sum/total, total
259+
260+
@nb.njit(nogil=True, cache=True, fastmath=True)
def _accum_std(mean1, std1, npt1, mean2, std2, npt2):
    """
    Calculate the aggregate standard deviation given two input standard deviation values.

    Inputs:
    mean1   : float - Mean 1
    std1    : float - Standard deviation 1
    npt1    : int - Number of points for mean 1
    mean2   : float - Mean 2
    std2    : float - Standard deviation 2
    npt2    : int - Number of points for mean 2

    Returns:
    std     : float - Aggregate standard deviation (0 when fewer than two points in total)
    npt     : int - Total number of points
    """

    npt = npt1 + npt2
    if npt < 2:
        # A spread cannot be estimated from fewer than two points, but the
        # count is still reported so it agrees with _accum_means
        return 0.0, npt

    # Degrees of freedom of each input; clamped so that an empty input
    # contributes nothing instead of a negative weight
    dof1 = max(npt1 - 1, 0)
    dof2 = max(npt2 - 1, 0)

    # Spread inside each input sample
    var_within = (dof1*std1**2 + dof2*std2**2)/(npt - 1)
    # Extra spread caused by the two inputs having different means
    var_between = ((npt1*npt2) * (mean1 - mean2)**2)/(npt*(npt - 1))

    std = np.sqrt(var_within + var_between)

    return std, npt
289+
290+
291+ #@nb.jit(nopython=True, nogil=True, cache=True)
def merge_uv_grids(graph, input_params):
    """
    Given the list of results from accumulate_uv_points, merge the UV grids together to compute stats

    Inputs:
    graph           : sequence - One (median, std, npt) tuple of UV grids per input node
    input_params    : dict - Must contain 'npixels', the (npixu, npixv) size of the UV grid

    Returns:
    accum_uv_med    : np.array - Point-weighted average of the input median grids
    accum_uv_std    : np.array - Combined standard deviation per pixel
    uvnpt           : np.array - Total number of points per pixel

    The merged grids are also written to 'accum_uv_med.npz'.
    """

    npixu = input_params['npixels'][0]
    npixv = input_params['npixels'][1]

    # Graph holds the (median, std, npt) grids from each node of the input DAG.
    # merge_uv_grids accumulates all of them onto a single set of grids and
    # returns it. The accumulators start as copies of the first entry, so the
    # inputs are never modified and a single-entry graph passes through as-is.
    accum_uv_med = np.zeros((npixu, npixv))
    accum_uv_std = np.zeros((npixu, npixv))
    uvnpt = np.zeros((npixu, npixv), dtype=np.int64)

    accum_uv_med[...] = graph[0][0]
    accum_uv_std[...] = graph[0][1]
    uvnpt[...] = graph[0][2]

    for nn in range(1, len(graph)):
        med_nn = graph[nn][0]
        std_nn = graph[nn][1]
        npt_nn = graph[nn][2]

        for uu in range(npixu):
            for vv in range(npixv):
                # Both merges need the values from before this step, so read
                # them out before anything is overwritten
                med_acc = accum_uv_med[uu, vv]
                std_acc = accum_uv_std[uu, vv]
                npt_acc = uvnpt[uu, vv]

                accum_uv_med[uu, vv], uvnpt[uu, vv] = _accum_means(med_acc, npt_acc, med_nn[uu, vv], npt_nn[uu, vv])
                accum_uv_std[uu, vv], __ = _accum_std(med_acc, std_acc, npt_acc, med_nn[uu, vv], std_nn[uu, vv], npt_nn[uu, vv])

    # NOTE(review): every call overwrites the same file in the current working
    # directory - confirm this output is still wanted.
    np.savez('accum_uv_med.npz', accum_uv_med, accum_uv_std, uvnpt)
    return accum_uv_med, accum_uv_std, uvnpt
168337
169- #return vis_hist
170338
171339
172340@nb .jit (nopython = True , nogil = True , cache = True )
@@ -183,8 +351,8 @@ def scale_uv_freq(uvw, frequency, ref_freq):
183351
184352 for ffidx , ff in enumerate (frequency ):
185353 delta_nu = (ff - ref_freq )/ ref_freq
186- uv_scaled [:,:,ffidx ,0 ] = uvw [:,:,0 ] * (1 + delta_nu / ff )
187- uv_scaled [:,:,ffidx ,1 ] = uvw [:,:,1 ] * (1 + delta_nu / ff )
354+ uv_scaled [:,:,ffidx ,0 ] = uvw [:,:,0 ] * (1 + delta_nu )
355+ uv_scaled [:,:,ffidx ,1 ] = uvw [:,:,1 ] * (1 + delta_nu )
188356
189357 return uv_scaled
190358
0 commit comments