4141warnings .filterwarnings ("ignore" , message = "divide by zero encountered in log" )
4242
4343MAXIT = 10000 # maximum number of iterations in self-cal minimizer (passed as 'maxiter')
44+ NHIST = 50 # number of steps to store for hessian approx (passed as 'maxcor' when minimizer_method is 'L-BFGS-B')
45+ MAXLS = 40 # maximum number of line search steps in L-BFGS-B (passed as 'maxls')
46+ STOP = 1e-6 # convergence criterion (passed as both 'ftol' and 'gtol' for L-BFGS-B)
4447
4548###################################################################################################
4649# Self-Calibration
@@ -52,7 +55,8 @@ def self_cal(obs, im, sites=[], method="both", pol='I', minimizer_method='BFGS',
5255 ttype = 'direct' , fft_pad_factor = 2 , caltable = False ,
5356 debias = True , apply_dterms = False ,
5457 copy_closure_tables = True ,
55- processes = - 1 , show_solution = False , msgtype = 'bar' ):
58+ processes = - 1 , show_solution = False , msgtype = 'bar' ,
59+ use_grad = False ):
5660 """Self-calibrate a dataset to an image.
5761
5862 Args:
@@ -83,12 +87,14 @@ def self_cal(obs, im, sites=[], method="both", pol='I', minimizer_method='BFGS',
8387 apply_dterms (bool): if True, apply dterms (in obs.tarr) to clean data before calibrating
8488 show_solution (bool): if True, display the solution as it is calculated
8589 msgtype (str): type of progress message to be printed, default is 'bar'
86-
90+ use_grad (bool): if True, use gradients in minimizer
91+
8792 Returns:
8893 (Obsdata): the calibrated observation, if caltable==False
8994 (Caltable): the derived calibration table, if caltable==True
9095 """
91-
96+ if use_grad and (method == 'phase' or method == 'amp' ):
97+ raise Exception ("errfunc_grad in self_cal only works with method=='both'!" )
9298 if pol not in ['I' , 'Q' , 'U' , 'V' , 'RR' , 'LL' ]:
9399 raise Exception ("Can only self-calibrate to I, Q, U, V, RR, or LL images!" )
94100 if pol in ['I' , 'Q' , 'U' , 'V' ]:
@@ -148,7 +154,8 @@ def self_cal(obs, im, sites=[], method="both", pol='I', minimizer_method='BFGS',
148154 obs .polrep , pol ,
149155 method , minimizer_method ,
150156 show_solution , pad_amp , gain_tol ,
151- caltable , debias , msgtype
157+ debias , caltable , msgtype ,
158+ use_grad
152159 ] for i in range (len (scans ))]))
153160
154161 else : # run on a single core
@@ -157,8 +164,10 @@ def self_cal(obs, im, sites=[], method="both", pol='I', minimizer_method='BFGS',
157164 scans_cal [i ] = self_cal_scan (scans [i ], im , V_scan = V_scans [i ], sites = sites ,
158165 polrep = obs .polrep , pol = pol ,
159166 method = method , minimizer_method = minimizer_method ,
160- show_solution = show_solution , debias = debias ,
161- pad_amp = pad_amp , gain_tol = gain_tol , caltable = caltable )
167+ show_solution = show_solution ,
168+ pad_amp = pad_amp , gain_tol = gain_tol ,
169+ debias = debias , caltable = caltable ,
170+ use_grad = use_grad )
162171
163172 tstop = time .time ()
164173 print ("\n self_cal time: %f s" % (tstop - tstart ))
@@ -201,7 +210,8 @@ def self_cal(obs, im, sites=[], method="both", pol='I', minimizer_method='BFGS',
201210
202211def self_cal_scan (scan , im , V_scan = [], sites = [], polrep = 'stokes' , pol = 'I' , method = "both" ,
203212 minimizer_method = 'BFGS' , show_solution = False ,
204- pad_amp = 0. , gain_tol = .2 , debias = True , caltable = False ):
213+ pad_amp = 0. , gain_tol = .2 , debias = True , caltable = False ,
214+ use_grad = False ):
205215 """Self-calibrate a scan to an image.
206216
207217 Args:
@@ -224,12 +234,16 @@ def self_cal_scan(scan, im, V_scan=[], sites=[], polrep='stokes', pol='I', metho
224234 debias (bool): If True, debias the amplitudes
225235 caltable (bool): if True, returns a Caltable instead of an Obsdata
226236 show_solution (bool): if True, display the solution as it is calculated
227-
237+ use_grad (bool): if True, use gradients in minimizer
238+
228239 Returns:
229240 (Obsdata): the calibrated observation, if caltable==False
230241 (Caltable): the derived calibration table, if caltable==True
231242 """
232-
243+
244+ if use_grad and (method == 'phase' or method == 'amp' ):
245+ raise Exception ("errfunc_grad in self_cal only works with method=='both'!" )
246+
233247 if len (sites ) == 0 :
234248 print ("No stations specified in self cal: defaulting to calibrating all !" )
235249 sites = list (set (scan ['t1' ]).union (set (scan ['t2' ])))
@@ -286,46 +300,25 @@ def self_cal_scan(scan, im, V_scan=[], sites=[], polrep='stokes', pol='I', metho
286300
287301 # error function
288302 def errfunc (gpar ):
289- # all the forward site gains (complex)
290- g = gpar .astype (np .float64 ).view (dtype = np .complex128 )
303+ return errfunc_full (gpar , vis , V_scan , sigma_inv , gain_tol , sites , g1_keys , g2_keys , method )
291304
292- if method == "phase" :
293- g = g / np .abs (g )
294- if method == "amp" :
295- g = np .abs (np .real (g ))
296-
297- # append the default values to g for missing gains
298- g = np .append (g , 1. )
299- g1 = g [g1_keys ]
300- g2 = g [g2_keys ]
301-
302- # build site specific tolerance parameters
303- tol0 = np .array ([gain_tol .get (s , gain_tol ['default' ])[0 ] for s in sites ])
304- tol1 = np .array ([gain_tol .get (s , gain_tol ['default' ])[1 ] for s in sites ])
305-
306- if method == 'amp' :
307- verr = np .abs (vis ) - g1 * g2 .conj () * np .abs (V_scan )
308- else :
309- verr = vis - g1 * g2 .conj () * V_scan
310-
311- nan_mask = [not np .isnan (v ) for v in verr ]
312- verr = verr [nan_mask ]
313-
314- # goodness-of-fit for gains
315- chisq = np .sum ((verr .real * sigma_inv [nan_mask ])** 2 ) + \
316- np .sum ((verr .imag * sigma_inv [nan_mask ])** 2 )
317-
318- # prior on the gains
319- # don't count the last (default missing site) gain dummy value
320- chisq_g = np .sum (np .log (np .abs (g [:- 1 ]))** 2 /
321- ((np .abs (g [:- 1 ]) > 1 ) * tol0 + (np .abs (g [:- 1 ]) <= 1 ) * tol1 )** 2 )
322-
323- return chisq + chisq_g
305+ def errfunc_grad (gpar ):
306+ return errfunc_grad_full (gpar , vis , V_scan , sigma_inv , gain_tol , sites , g1_keys , g2_keys , method )
324307
325308 # use gradient descent to find the gains
326- optdict = {'maxiter' : MAXIT } # minimizer params
327- res = opt .minimize (errfunc , gpar_guess , method = minimizer_method , options = optdict )
328-
309+ # minimizer params
310+ if minimizer_method == 'L-BFGS-B' :
311+ optdict = {'maxiter' : MAXIT ,
312+ 'ftol' : STOP , 'gtol' : STOP ,
313+ 'maxcor' : NHIST , 'maxls' : MAXLS }
314+ else :
315+ optdict = {'maxiter' : MAXIT }
316+
317+ if use_grad :
318+ res = opt .minimize (errfunc , gpar_guess , method = minimizer_method , options = optdict , jac = errfunc_grad )
319+ else :
320+ res = opt .minimize (errfunc , gpar_guess , method = minimizer_method , options = optdict )
321+
329322 # save the solution
330323 g_fit = res .x .view (np .complex128 )
331324
@@ -397,7 +390,7 @@ def get_selfcal_scan_cal(args):
397390
398391
399392def get_selfcal_scan_cal2 (i , n , scan , im , V_scan , sites , polrep , pol , method , minimizer_method ,
400- show_solution , pad_amp , gain_tol , caltable , debias , msgtype ):
393+ show_solution , pad_amp , gain_tol , debias , caltable , msgtype , use_grad ):
401394 if n > 1 :
402395 global counter
403396 counter .increment ()
@@ -406,4 +399,140 @@ def get_selfcal_scan_cal2(i, n, scan, im, V_scan, sites, polrep, pol, method, mi
406399 return self_cal_scan (scan , im , V_scan = V_scan , sites = sites , polrep = polrep , pol = pol ,
407400 method = method , minimizer_method = minimizer_method ,
408401 show_solution = show_solution ,
409- pad_amp = pad_amp , gain_tol = gain_tol , caltable = caltable , debias = debias )
402+ pad_amp = pad_amp , gain_tol = gain_tol , debias = debias , caltable = caltable ,
403+ use_grad = use_grad )
404+
405+ # error function
def errfunc_full(gpar, vis, v_scan, sigma_inv, gain_tol, sites, g1_keys, g2_keys, method):
    """Evaluate the self-calibration objective for a trial set of site gains.

    The returned value is the sum of two terms:
      * a data term: the squared residuals between `vis` and the model
        `g1 * conj(g2) * v_scan`, each scaled by `sigma_inv`;
      * a prior term: the squared log-amplitude of every fitted gain divided
        by the squared tolerance for that site.

    Args:
        gpar (ndarray): real parameter vector laid out as
                        [re_0, im_0, re_1, im_1, ...]; it is reinterpreted
                        as a complex128 array with one entry per site
        vis (ndarray): measured visibilities
        v_scan (ndarray): model visibilities, same length as vis
        sigma_inv (ndarray): per-visibility factor multiplying the residuals
                             (presumably 1/sigma -- confirm against the caller)
        gain_tol (dict): maps a site name to a pair (tol0, tol1); sites that are
                         not listed use gain_tol['default']. tol0 is applied
                         when |g| > 1 and tol1 when |g| <= 1
        sites (list): site names, one per fitted gain
        g1_keys (sequence of int): for each visibility, index of the gain that
                                   multiplies the model directly. A unit gain is
                                   appended after the fitted gains, so -1
                                   selects "no correction"
        g2_keys (sequence of int): for each visibility, index of the gain that
                                   enters the model conjugated (same convention)
        method (str): "phase" normalizes every gain to unit amplitude, "amp"
                      uses |Re(g)| as a real gain and compares amplitudes only;
                      any other value fits the full complex gain

    Returns:
        (float): total chi^2 (data term plus gain prior)
    """
    # all the forward site gains (complex)
    g = gpar.astype(np.float64).view(dtype=np.complex128)

    if method == "phase":
        g = g / np.abs(g)
    if method == "amp":
        g = np.abs(np.real(g))

    # append the default value to g for missing gains
    g = np.append(g, 1.)
    g1 = g[g1_keys]
    g2 = g[g2_keys]

    # build site specific tolerance parameters
    site_tols = [gain_tol.get(s, gain_tol['default']) for s in sites]
    tol0 = np.array([tol[0] for tol in site_tols])
    tol1 = np.array([tol[1] for tol in site_tols])

    if method == 'amp':
        verr = np.abs(vis) - g1 * g2.conj() * np.abs(v_scan)
    else:
        verr = vis - g1 * g2.conj() * v_scan

    # drop residuals that are NaN; vectorized instead of a per-element Python
    # loop because this function is called on every minimizer iteration
    good = ~np.isnan(verr)
    verr = verr[good]
    weights = sigma_inv[good]

    # goodness-of-fit for gains
    chisq = np.sum((verr.real * weights)**2) + np.sum((verr.imag * weights)**2)

    # prior on the gains
    # don't count the last (default missing site) gain dummy value
    gabs = np.abs(g[:-1])
    tolsq = ((gabs > 1) * tol0 + (gabs <= 1) * tol1)**2
    chisq_g = np.sum(np.log(gabs)**2 / tolsq)

    # total chi^2
    return chisq + chisq_g
444+
def errfunc_grad_full(gpar, vis, v_scan, sigma_inv, gain_tol, sites, g1_keys, g2_keys, method):
    """Evaluate the gradient of the self-calibration objective w.r.t. gpar.

    This is the analytic derivative of the objective
        sum(|vis - g1 * conj(g2) * v_scan|^2 * sigma_inv^2)
        + sum(log(|g|)^2 / tol^2)
    with respect to the real and imaginary part of every fitted gain. Only
    the full complex-gain fit is supported.

    Args:
        gpar (ndarray): real parameter vector laid out as
                        [re_0, im_0, re_1, im_1, ...]; it is reinterpreted
                        as a complex128 array with one entry per site
        vis (ndarray): measured visibilities
        v_scan (ndarray): model visibilities, same length as vis
        sigma_inv (ndarray): per-visibility factor whose square weights the
                             data term (presumably 1/sigma -- confirm against
                             the caller)
        gain_tol (dict): maps a site name to a pair (tol0, tol1); sites that are
                         not listed use gain_tol['default']. tol0 is applied
                         when |g| > 1 and tol1 when |g| <= 1
        sites (list): site names, one per fitted gain
        g1_keys (sequence of int): for each visibility, index of the gain that
                                   multiplies the model directly. A unit gain is
                                   appended after the fitted gains, so -1
                                   selects "no correction"
        g2_keys (sequence of int): for each visibility, index of the gain that
                                   enters the model conjugated (same convention)
        method (str): must not be 'phase' or 'amp'

    Returns:
        (ndarray): real float64 array with the same layout and length as gpar

    Raises:
        Exception: if method is 'phase' or 'amp'
    """
    # does not work for method=='phase' or method=='amp'
    if method == 'phase' or method == 'amp':
        raise Exception("errfunc_grad in self_cal only works with method=='both'!")

    # all the forward site gains (complex)
    g = gpar.astype(np.float64).view(dtype=np.complex128)
    nsites = len(gpar) // 2  # len(gpar) must be even
    gr = np.real(g)
    gi = np.imag(g)

    # build site specific tolerance parameters
    site_tols = [gain_tol.get(s, gain_tol['default']) for s in sites]
    tol0 = np.array([tol[0] for tol in site_tols])
    tol1 = np.array([tol[1] for tol in site_tols])

    # append the default value to g for missing gains
    g_all = np.append(g, 1.)
    keys1 = np.asarray(g1_keys, dtype=int)
    keys2 = np.asarray(g2_keys, dtype=int)
    g1 = g_all[keys1]
    g2 = g_all[keys2]

    ###################################
    # data term chi^2 derivative
    ###################################

    # squared moduli, kept explicitly real
    v_scan_sq = (v_scan * v_scan.conj()).real
    g1sq = (g1 * g1.conj()).real
    g2sq = (g2 * g2.conj()).real
    weights = sigma_inv ** 2

    # cross terms: every derivative has the form -(c + conj(c)) or +/-1j*(c - conj(c)),
    # which is exactly -2*Re(c) or +/-2*Im(c); writing it that way keeps the
    # per-visibility derivatives real instead of casting complex values to float
    cross1 = g2.conj() * vis.conj() * v_scan
    cross2 = g1 * vis.conj() * v_scan

    dchisq_dg1r = (-2 * cross1.real + 2 * np.real(g1) * g2sq * v_scan_sq) * weights
    dchisq_dg1i = (2 * cross1.imag + 2 * np.imag(g1) * g2sq * v_scan_sq) * weights
    dchisq_dg2r = (-2 * cross2.real + 2 * np.real(g2) * g1sq * v_scan_sq) * weights
    dchisq_dg2i = (-2 * cross2.imag + 2 * np.imag(g2) * g1sq * v_scan_sq) * weights

    # same masking as in the objective: visibilities whose residual is NaN
    # contribute nothing; array lengths are preserved so the keys still line up
    verr = vis - g1 * g2.conj() * v_scan
    nan_mask = np.isnan(verr)
    for deriv in (dchisq_dg1r, dchisq_dg1i, dchisq_dg2r, dchisq_dg2i):
        deriv[nan_mask] = 0

    # derivatives w.r.t. the real and imaginary part of each fitted gain:
    # sum every visibility in which the site appears on either side
    dchisq_dgr = (_sum_by_site(keys1, dchisq_dg1r, nsites)
                  + _sum_by_site(keys2, dchisq_dg2r, nsites))
    dchisq_dgi = (_sum_by_site(keys1, dchisq_dg1i, nsites)
                  + _sum_by_site(keys2, dchisq_dg2i, nsites))

    ###################################
    # prior term chi^2 derivative
    ###################################

    # NOTE this derivative doesn't account for possible sharp change in tol at g=1
    gabs = np.abs(g)  # fitted gains only; the appended dummy value is not counted
    gsq = gabs ** 2
    tolsq = ((gabs > 1) * tol0 + (gabs <= 1) * tol1)**2

    dchisqg_dgr = gr * np.log(gsq) / gsq / tolsq
    dchisqg_dgi = gi * np.log(gsq) / gsq / tolsq

    # interleave final derivs to match the layout of gpar
    dchisqtot_dgpar = np.zeros(len(gpar))
    dchisqtot_dgpar[0::2] = dchisq_dgr + dchisqg_dgr
    dchisqtot_dgpar[1::2] = dchisq_dgi + dchisqg_dgi

    return dchisqtot_dgpar


def _sum_by_site(keys, values, nsites):
    """Sum per-visibility `values` into `nsites` bins selected by `keys`.

    Keys outside [0, nsites) (e.g. -1 for a site that is not being fitted)
    do not contribute to any bin.
    """
    valid = (keys >= 0) & (keys < nsites)
    return np.bincount(keys[valid], weights=values[valid], minlength=nsites)
536+
537+
538+
0 commit comments