@@ -153,12 +153,11 @@ def comps2vis(
153153 divide_by_n = False ,
154154 freq_min = - np .inf ,
155155 freq_max = np .inf ,
156- ncorr_out = 4 ,
157- product = 'I' ,
158- poltype = 'linear' ):
156+ product = 'I' ):
159157
160158 # determine output type
161159 complex_type = da .result_type (mds .coefficients .dtype , np .complex64 )
160+ ncorr_out = len (product )
162161
163162 return da .blockwise (_comps2vis , 'rfc' ,
164163 uvw , 'r3' ,
@@ -181,9 +180,7 @@ def comps2vis(
181180 divide_by_n , None ,
182181 freq_min , None ,
183182 freq_max , None ,
184- ncorr_out , None ,
185183 product , None ,
186- poltype , None ,
187184 new_axes = {'c' : ncorr_out },
188185 # it should be getting these from uvw and freq?
189186 adjust_chunks = {'r' : uvw .chunks [0 ]},
@@ -209,9 +206,7 @@ def _comps2vis(
209206 divide_by_n = False ,
210207 freq_min = - np .inf ,
211208 freq_max = np .inf ,
212- ncorr_out = 4 ,
213- product = 'I' ,
214- poltype = 'linear' ):
209+ product = 'I' ):
215210 return _comps2vis_impl (
216211 uvw [0 ],
217212 utime ,
@@ -230,9 +225,7 @@ def _comps2vis(
230225 divide_by_n = divide_by_n ,
231226 freq_min = freq_min ,
232227 freq_max = freq_max ,
233- ncorr_out = ncorr_out ,
234- product = product ,
235- poltype = poltype )
228+ product = product )
236229
237230
238231
@@ -253,13 +246,9 @@ def _comps2vis_impl(uvw,
253246 divide_by_n = False ,
254247 freq_min = - np .inf ,
255248 freq_max = np .inf ,
256- ncorr_out = 4 ,
257- product = 'I' ,
258- poltype = 'linear' ):
249+ product = 'I' ):
259250 # why is this necessary?
260251 resize_thread_pool (nthreads )
261- msg = f"Polarisation product { product } is not compatible with the " \
262- f"number of correlations { ncorr_out } "
263252
264253 # adjust for chunking
265254 # need a copy here if using multiple row chunks
@@ -272,7 +261,8 @@ def _comps2vis_impl(uvw,
272261
273262 nrow = uvw .shape [0 ]
274263 nchan = freq .size
275- vis = np .zeros ((nrow , nchan , ncorr_out ),
264+ nstokes_out = len (product )
265+ vis = np .zeros ((nrow , nchan , nstokes_out ),
276266 dtype = np .result_type (mds .coefficients .dtype , np .complex64 ))
277267 if not ((freq >= freq_min ) & (freq <= freq_max )).any ():
278268 return vis
@@ -308,65 +298,19 @@ def _comps2vis_impl(uvw,
308298 image [Ix , Iy ] = modelf (tout , fout , * comps [:, :]) # too magical?
309299 if np .any (region_mask ):
310300 image = np .where (region_mask , image , 0.0 )
311- vis_stokes = dirty2vis (uvw = uvw ,
312- freq = f ,
313- dirty = image ,
314- pixsize_x = cellx , pixsize_y = celly ,
315- center_x = x0 , center_y = y0 ,
316- flip_u = flip_u ,
317- flip_v = flip_v ,
318- flip_w = flip_w ,
319- epsilon = epsilon ,
320- do_wgridding = do_wgridding ,
321- divide_by_n = divide_by_n ,
322- nthreads = nthreads )
323- if ncorr_out == 1 :
324- vis [indr , indf , 0 ] = vis_stokes
325- elif ncorr_out == 2 :
326- if product .upper () == 'I' :
327- vis [indr , indf , 0 ] = vis_stokes
328- vis [indr , indf , - 1 ] = vis_stokes
329- elif product .upper () == 'Q' :
330- if poltype .lower () == 'linear' :
331- vis [indr , indf , 0 ] = vis_stokes
332- vis [indr , indf , - 1 ] = vis_stokes
333- else :
334- raise ValueError (msg )
335- elif product .upper () == 'V' :
336- if poltype .lower () == 'linear' :
337- raise ValueError (msg )
338- else :
339- vis [indr , indf , 0 ] = vis_stokes
340- vis [indr , indf , - 1 ] = - vis_stokes
341- else :
342- raise ValueError (msg )
343- elif ncorr_out == 4 :
344- if product .upper () == 'I' :
345- vis [indr , indf , 0 ] = vis_stokes
346- vis [indr , indf , - 1 ] = vis_stokes
347- elif product .upper () == 'Q' :
348- if poltype .lower () == 'linear' :
349- vis [indr , indf , 0 ] = vis_stokes
350- vis [indr , indf , - 1 ] = vis_stokes
351- else :
352- vis [indr , indf , 1 ] = vis_stokes
353- vis [indr , indf , 2 ] = vis_stokes
354- elif product .upper () == 'U' :
355- if poltype .lower () == 'linear' :
356- vis [indr , indf , 1 ] = vis_stokes
357- vis [indr , indf , 2 ] = vis_stokes
358- else :
359- vis [indr , indf , 1 ] = 1.0j * vis_stokes
360- vis [indr , indf , 2 ] = - 1.0j * vis_stokes
361- elif product .upper () == 'V' :
362- if poltype .lower () == 'linear' :
363- vis [indr , indf , 1 ] = 1.0j * vis_stokes
364- vis [indr , indf , 2 ] = - 1.0j * vis_stokes
365- else :
366- vis [indr , indf , 0 ] = vis_stokes
367- vis [indr , indf , 1 ] = vis_stokes
368- else :
369- raise ValueError (f"Unknown product { product } " )
301+ for c in range (nstokes_out ):
302+ vis [indr , indf , c ] = dirty2vis (uvw = uvw ,
303+ freq = f ,
304+ dirty = image ,
305+ pixsize_x = cellx , pixsize_y = celly ,
306+ center_x = x0 , center_y = y0 ,
307+ flip_u = flip_u ,
308+ flip_v = flip_v ,
309+ flip_w = flip_w ,
310+ epsilon = epsilon ,
311+ do_wgridding = do_wgridding ,
312+ divide_by_n = divide_by_n ,
313+ nthreads = nthreads )
370314
371315 return vis
372316
0 commit comments