@@ -413,13 +413,18 @@ def _compute_basis_visibilities(
413413 else :
414414 _apparent_buf = np .empty (nsim_sources , dtype = complex_dtype )
415415
416- # Gather coefficients once, outside the loop
417- ant1_c = beam_coeffs [ant1_idxs , :] # (nbls, K)
418- ant2_c = beam_coeffs [ant2_idxs , :] # (nbls, K)
419-
416+ # Gather coefficients once, outside the loop.
417+ # The measurement equation is V_ij = A_i^H C A_j, so the left (ant1)
418+ # coefficients are conjugated and the right (ant2) are not.
419+ ant1_c = beam_coeffs [ant1_idxs , :].conj () # C_ik^* (nbls, K)
420+ ant2_c = beam_coeffs [ant2_idxs , :] # C_jl (nbls, K)
421+
422+ # Only iterate over the upper triangle (k <= l) and use the conjugate
423+ # symmetry V_tilde[l, k] = V_tilde[k, l]^* to handle the lower triangle
424+ # without an extra NUFFT. This halves the number of NUFFTs from K^2 to
425+ # K*(K+1)/2 at no cost to accuracy.
420426 for k in range (nbasis ):
421- for l in range (nbasis ):
422- # Reuse the same optimized kernels as the standard path.
427+ for l in range (k , nbasis ):
423428 phi_kl = _compute_apparent_coherency (
424429 beam_evaluations = beam_evaluations ,
425430 bi = k ,
@@ -452,45 +457,19 @@ def _compute_basis_visibilities(
452457 n_threads = n_threads ,
453458 upsample_factor = upsample_factor ,
454459 nfeeds = nfeeds ,
455- )
456-
457- weight = ant1_c [:, k ] * ant2_c [:, l ] # (nbls,)
458- vis_out += weight [:, None , None ] * vis_kl
460+ ) # (nbls, nfeeds, nfeeds)
459461
460- return vis_out
462+ # (k, l) contribution: weight = C1[b,k] * C2[b,l]^*
463+ w_kl = ant1_c [:, k ] * ant2_c [:, l ] # (nbls,)
464+ vis_out += w_kl [:, None , None ] * vis_kl
461465
466+ if l != k :
467+ # (l, k) contribution: V_tilde[l,k] = V_tilde[k,l]^*
468+ # but the weights are different since ant1 != ant2 in general
469+ w_lk = ant1_c [:, l ] * ant2_c [:, k ] # (nbls,)
470+ vis_out += w_lk [:, None , None ] * vis_kl .conj ()
462471
463- def _postprocess_basis_visibilities (
464- vis_basis : np .ndarray ,
465- beam_coeffs : np .ndarray ,
466- ant1_idxs : np .ndarray ,
467- ant2_idxs : np .ndarray ,
468- ) -> np .ndarray :
469- """Contract basis visibilities with per-antenna SVD coefficients.
470-
471- For each baseline b connecting antennas i and j:
472- V_b = sum_{k,l} C[i,k] * C[j,l]^* * V_tilde[k, l, b]
473-
474- Parameters
475- ----------
476- vis_basis : np.ndarray
477- Basis visibility tensor, shape (nbasis, nbasis, nbls, nfeeds, nfeeds).
478- beam_coeffs : np.ndarray
479- Per-antenna SVD coefficients, shape (N_ant, nbasis).
480- ant1_idxs, ant2_idxs : np.ndarray
481- Antenna index per baseline, each shape (nbls,).
482-
483- Returns
484- -------
485- np.ndarray
486- Visibilities shaped (nbls, nfeeds, nfeeds).
487- """
488- # Gather coefficients for each baseline's antenna pair
489- C1 = beam_coeffs [ant1_idxs , :] # (nbls, nbasis)
490- C2 = beam_coeffs [ant2_idxs , :].conj () # (nbls, nbasis)
491-
492- # V_b,p,q = sum_{k,l} C1[b,k] * C2[b,l] * vis_basis[k,l,b,p,q]
493- return np .einsum ('bk,bl,klbpq->bpq' , C1 , C2 , vis_basis )
472+ return vis_out
494473
495474
496475@ray .remote
0 commit comments