@@ -306,6 +306,9 @@ def _run_nufft(
306306def _compute_basis_visibilities (
307307 beam_evaluations : list ,
308308 flux_here : np .ndarray ,
309+ ant1_idxs : np .ndarray ,
310+ ant2_idxs : np .ndarray ,
311+ beam_coeffs : np .ndarray ,
309312 freqidx : int ,
310313 topo : np .ndarray ,
311314 uvw : np .ndarray ,
@@ -353,6 +356,10 @@ def _compute_basis_visibilities(
353356 flux_here : np.ndarray
354357 Source flux, shape ``(nsrc, nfreqs)`` for unpolarized sky or
355358 ``(nsrc, nfreqs, nfeeds, nfeeds)`` for polarized sky.
359+ ant1_idxs, ant2_idxs : np.ndarray
360+ Antenna indices for each baseline, each shape ``(nbls,)``.
361+ beam_coeffs : np.ndarray
362+ Per-antenna basis coefficients, shape ``(nants, nbasis)``; rows are gathered per baseline via ``ant1_idxs`` and ``ant2_idxs``.
356363 freqidx : int
357364 Frequency index into flux_here.
358365 topo : np.ndarray
@@ -391,7 +398,9 @@ def _compute_basis_visibilities(
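391398 Contracted visibilities, shape ``(nbls, nfeeds, nfeeds)``: the sum over basis pairs ``(k, l)`` of each basis visibility weighted by ``beam_coeffs[ant1_idxs, k] * beam_coeffs[ant2_idxs, l]``.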
392399 """
393400 nbasis = len (beam_evaluations )
394- vis_basis = np .zeros ((nbasis , nbasis , nbls , nfeeds , nfeeds ), dtype = complex_dtype )
401+
402+ # Output accumulator — only (nbls, nfeeds, nfeeds) instead of (K, K, nbls, nfeeds, nfeeds)
403+ vis_out = np .zeros ((nbls , nfeeds , nfeeds ), dtype = complex_dtype )
395404
396405 # No baseline flipping in the basis path — we run all baselines at once.
397406 flipped = np .zeros (nbls , dtype = bool )
@@ -404,6 +413,10 @@ def _compute_basis_visibilities(
404413 else :
405414 _apparent_buf = np .empty (nsim_sources , dtype = complex_dtype )
406415
416+ # Gather coefficients once, outside the loop
417+ ant1_c = beam_coeffs [ant1_idxs , :] # (nbls, K)
418+ ant2_c = beam_coeffs [ant2_idxs , :] # (nbls, K)
419+
407420 for k in range (nbasis ):
408421 for l in range (nbasis ):
409422 # Reuse the same optimized kernels as the standard path.
@@ -423,7 +436,7 @@ def _compute_basis_visibilities(
423436 if phi_kl is None : # pragma: no cover
424437 continue
425438
426- vis_basis [ k , l ] = _run_nufft (
439+ vis_kl = _run_nufft (
427440 apparent_coherency = phi_kl ,
428441 topo = topo ,
429442 uvw = uvw ,
@@ -441,7 +454,10 @@ def _compute_basis_visibilities(
441454 nfeeds = nfeeds ,
442455 )
443456
444- return vis_basis
457+ weight = ant1_c [:, k ] * ant2_c [:, l ] # (nbls,)
458+ vis_out += weight [:, None , None ] * vis_kl
459+
460+ return vis_out
445461
446462
447463def _postprocess_basis_visibilities (
@@ -975,6 +991,9 @@ def _evaluate_vis_chunk(
975991 vis_basis = _compute_basis_visibilities (
976992 beam_evaluations = beam_evaluations ,
977993 flux_here = flux ,
994+ ant1_idxs = ant1_idxs ,
995+ ant2_idxs = ant2_idxs ,
996+ beam_coeffs = beam_coeffs ,
978997 freqidx = freqidx ,
979998 topo = topo ,
980999 uvw = uvw if not use_type1 else None ,
@@ -995,16 +1014,7 @@ def _evaluate_vis_chunk(
9951014 polarized_sky_model = polarized_sky_model ,
9961015 )
9971016
998- # Contract (K, K, nbls, nfeeds, nfeeds) with per-antenna
999- # coefficients to produce (nbls, nfeeds, nfeeds)
1000- vis_contracted = _postprocess_basis_visibilities (
1001- vis_basis = vis_basis ,
1002- beam_coeffs = beam_coeffs ,
1003- ant1_idxs = ant1_idxs ,
1004- ant2_idxs = ant2_idxs ,
1005- )
1006-
1007- vis [time_index , :, :, :, freqidx ] += vis_contracted
1017+ vis [time_index , :, :, :, freqidx ] += vis_basis
10081018
10091019 # -------------------------------------------------------
10101020 # Standard beam-pair path
0 commit comments