Skip to content

Commit ed9a0c3

Browse files
committed
update checks to pass when beam_coefs is provided
1 parent bf9ea68 commit ed9a0c3

5 files changed

Lines changed: 63 additions & 142 deletions

File tree

src/fftvis/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from .cpu.cpu_simulate import CPUSimulationEngine
1111
from .wrapper import create_simulation_engine, simulate_vis
1212

13+
# Import beam basis decomposition utility
14+
from .core.beam_basis import compute_beam_basis
15+
1316
# Import utility modules
1417
from . import utils, logutils
1518

@@ -23,4 +26,6 @@
2326
"CPUSimulationEngine",
2427
"create_simulation_engine",
2528
"simulate_vis",
29+
# Beam basis decomposition
30+
"compute_beam_basis",
2631
]

src/fftvis/core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22

33
from .beams import BeamEvaluator
44
from .simulate import SimulationEngine, default_accuracy_dict
5+
from .beam_basis import compute_beam_basis

src/fftvis/core/beam_basis.py

Lines changed: 15 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -14,140 +14,56 @@
1414

1515

1616
def compute_beam_basis(
17-
beam_list: list[UVBeam],
17+
beam_list,
1818
freq: float,
1919
threshold: float = 1e-12,
20-
polarized: bool = False,
21-
) -> tuple[list[UVBeam], np.ndarray]:
22-
"""Decompose a set of UVBeams into an SVD eigenbeam basis.
23-
24-
Each beam is first interpolated onto a common frequency grid to keep memory
25-
usage predictable, then stacked into a matrix and SVD'd. The basis is
26-
truncated at the first component whose normalised singular value drops below
27-
``threshold``.
28-
29-
The decomposition satisfies::
30-
31-
beam_i(az, za, freq) ≈ sum_k coeffs[i, k] * eigenbeams[k](az, za, freq)
32-
33-
so ``eigenbeams`` can be passed directly as ``beam_list`` to
34-
``CPUSimulationEngine.simulate`` and ``coeffs`` as ``beam_coeffs``.
35-
36-
Parameters
37-
----------
38-
beam_list : list of UVBeam
39-
Input beams to decompose. All beams must share the same spatial grid
40-
(``Nax1``, ``Nax2``) and ``beam_type`` (``efield`` or ``power``).
41-
freq : float
42-
Frequency in Hz at which every beam is evaluated before the SVD.
43-
threshold : float, optional
44-
Normalised singular value cutoff. Basis components ``k`` with
45-
``s[k] / s[0] < threshold`` are discarded. Default ``1e-3``.
46-
polarized : bool, optional
47-
If ``True``, the full Jones matrix (all ``Naxes_vec`` and ``Nfeeds``
48-
components) is included in the SVD. If ``False``, only the first
49-
component ``data_array[0, 0, 0, ...]`` is used, matching the scalar
50-
beam path in the simulator. Default ``False``.
51-
52-
Returns
53-
-------
54-
eigenbeams : list of UVBeam
55-
The ``K`` truncated eigenbeam objects. Each has the same metadata and
56-
spatial/frequency grid as the interpolated input beams.
57-
coeffs : np.ndarray
58-
Per-antenna coefficients of shape ``(N_beams, K)``. These are the
59-
values to pass as ``beam_coeffs`` to the simulator.
60-
61-
Notes
62-
-----
63-
The function allocates two large intermediate arrays simultaneously: the
64-
full stacked beam matrix ``B`` of shape ``(N_beams, N_flat)`` and the ``Vh``
65-
factor of the same shape. Peak memory is therefore roughly
66-
``2 * N_beams * N_flat * itemsize``. Choosing a sparse ``freq_grid``
67-
directly reduces ``N_flat``.
68-
69-
Only ``efield`` beams have been tested; ``power`` beams should work but
70-
the resulting eigenbeams will have real-valued ``data_array`` only if all
71-
inputs are real.
20+
):
21+
"""
7222
"""
7323
if len(beam_list) == 0:
7424
raise ValueError("beam_list must contain at least one beam.")
7525

7626
n_beams = len(beam_list)
7727
freq_grid = np.atleast_1d(freq)
7828

79-
# ------------------------------------------------------------------
80-
# Step 1: Interpolate every beam to the common frequency grid.
81-
# ------------------------------------------------------------------
82-
logger.info(
83-
f"Interpolating {n_beams} beams to {len(freq_grid)} frequencies "
84-
f"({freq_grid[0]/1e6:.1f}{freq_grid[-1]/1e6:.1f} MHz)."
85-
)
8629
interp_beams = []
8730
for idx, beam in enumerate(beam_list):
8831
interp_beams.append(
8932
beam.interp(freq_array=freq_grid, new_object=True)
9033
)
91-
logger.debug(f" Interpolated beam {idx + 1}/{n_beams}.")
9234

93-
# ------------------------------------------------------------------
94-
# Step 2: Extract the relevant slice of data_array and flatten.
95-
#
96-
# data_array shape: (Naxes_vec, 1, Nfeeds, Nfreqs, Nax1, Nax2)
97-
#
98-
# Polarized → use axes_vec and feeds: data_array[:, 0, :, ...]
99-
# flat shape per beam: Naxes_vec * Nfeeds * Nfreqs * Nax1 * Nax2
100-
# Unpolarized → match evaluate_beam scalar path: data_array[0, 0, 0, ...]
101-
# flat shape per beam: Nfreqs * Nax1 * Nax2
102-
# ------------------------------------------------------------------
10335
ref = interp_beams[0]
10436

105-
if polarized:
106-
# Shape of the slice we care about: (Naxes_vec, Nfeeds, Nfreqs, Nax1, Nax2)
107-
slice_shape = ref.data_array[:, 0, :, :, :, :].shape
108-
flat_vecs = np.array(
109-
[b.data_array[:, 0, :, :, :, :].ravel() for b in interp_beams]
110-
) # (N_beams, Naxes_vec * Nfeeds * Nfreqs * Nax1 * Nax2)
111-
else:
112-
# Shape: (Nfreqs, Nax1, Nax2)
113-
slice_shape = ref.data_array[0, 0, 0, :, :, :].shape
114-
flat_vecs = np.array(
115-
[b.data_array[0, 0, 0, :, :, :].ravel() for b in interp_beams]
116-
) # (N_beams, Nfreqs * Nax1 * Nax2)
117-
118-
logger.info(
119-
f"Beam matrix shape: {flat_vecs.shape} "
120-
f"({flat_vecs.nbytes / 1024**2:.1f} MB)."
121-
)
37+
# Shape of the slice we care about:
38+
# (Naxes_vec, Nfeeds, Nfreqs, Nax1) for healpix beams
39+
# (Naxes_vec, Nfeeds, Nfreqs, Nax1, Nax2) for gridded beams
40+
slice_shape = ref.data_array[:, :, 0].shape
41+
flat_vecs = np.array(
42+
[b.data_array[:, :, 0].ravel() for b in interp_beams]
43+
) # (N_beams, Naxes_vec * Nfeeds * Nfreqs * Nax1 * Nax2)
12244

12345
# ------------------------------------------------------------------
12446
# Step 3: SVD.
12547
# B = U @ diag(s) @ Vh → beam_i ≈ sum_k (U[i,k]*s[k]) * Vh[k]
12648
# ------------------------------------------------------------------
127-
logger.info("Computing SVD...")
12849
U, s, Vh = np.linalg.svd(flat_vecs, full_matrices=False)
12950

13051
# ------------------------------------------------------------------
13152
# Step 4: Truncate at normalised singular value threshold.
13253
# ------------------------------------------------------------------
13354
s_norm = s / s[0]
13455
K = int(np.sum(s_norm >= threshold))
135-
logger.info(
136-
f"Retaining {K}/{len(s)} basis components "
137-
f"(threshold={threshold}, "
138-
f"min retained s_norm={s_norm[K-1]:.3e})."
139-
)
14056

14157
U_k = U[:, :K] # (N_beams, K)
14258
s_k = s[:K] # (K,)
14359
Vh_k = Vh[:K, :] # (K, N_flat)
14460

14561
# ------------------------------------------------------------------
14662
# Step 5: Compute per-antenna coefficients.
147-
# coeffs[i, k] = U[i, k] * s[k] so that flat_vecs ≈ coeffs @ Vh_k
63+
# beam_coefs[i, k] = U[i, k] * s[k] so that flat_vecs ≈ beam_coefs @ Vh_k
14864
# ------------------------------------------------------------------
149-
coeffs = U_k * s_k[None, :] # (N_beams, K)
150-
65+
beam_coefs = U_k * s_k[None, :] # (N_beams, K)
66+
15167
# ------------------------------------------------------------------
15268
# Step 6: Build eigenbeam UVBeam objects by copying reference metadata
15369
# and replacing data_array with the reshaped Vh rows.
@@ -156,16 +72,7 @@ def compute_beam_basis(
15672
for k in range(K):
15773
eb = ref.copy()
15874
eigenbeam_slice = Vh_k[k].reshape(slice_shape)
159-
160-
if polarized:
161-
# Restore the dropped size-1 axis: (Naxes_vec, 1, Nfeeds, Nfreqs, Nax1, Nax2)
162-
eb.data_array = eigenbeam_slice[:, np.newaxis, :, :, :, :]
163-
else:
164-
# Restore to full data_array shape; fill non-primary components with zeros.
165-
eb.data_array = np.zeros_like(ref.data_array)
166-
eb.data_array[0, 0, 0, :, :, :] = eigenbeam_slice
167-
75+
eb.data_array = eigenbeam_slice[:, :, np.newaxis]
16876
eigenbeams.append(eb)
16977

170-
logger.info(f"Basis computation complete: {K} eigenbeams.")
171-
return eigenbeams, coeffs
78+
return eigenbeams, beam_coefs

src/fftvis/cpu/cpu_simulate.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def _compute_basis_visibilities(
308308
flux_here: np.ndarray,
309309
ant1_idxs: np.ndarray,
310310
ant2_idxs: np.ndarray,
311-
beam_coeffs: np.ndarray,
311+
beam_coefs: np.ndarray,
312312
freqidx: int,
313313
topo: np.ndarray,
314314
uvw: np.ndarray,
@@ -358,7 +358,7 @@ def _compute_basis_visibilities(
358358
``(nsrc, nfreqs, nfeeds, nfeeds)`` for polarized sky.
359359
ant1_idxs, ant2_idxs : np.ndarray
360360
Antenna indices for each baseline, each shape ``(nbls,)``.
361-
beam_coeffs : np.ndarray
361+
beam_coefs : np.ndarray
362362
Coefficients for each baseline, shape ``(nbls, nbasis)``.
363363
freqidx : int
364364
Frequency index into flux_here.
@@ -416,8 +416,8 @@ def _compute_basis_visibilities(
416416
# Gather coefficients once, outside the loop.
417417
# The measurement equation is V_ij = A_i^H C A_j, so the left (ant1)
418418
# coefficients are conjugated and the right (ant2) are not.
419-
ant1_c = beam_coeffs[ant1_idxs, :].conj() # C_ik^* (nbls, K)
420-
ant2_c = beam_coeffs[ant2_idxs, :] # C_jl (nbls, K)
419+
ant1_c = beam_coefs[ant1_idxs, :, freqidx].conj() # C_ik^* (nbls, K)
420+
ant2_c = beam_coefs[ant2_idxs, :, freqidx] # C_jl (nbls, K)
421421

422422
# Only iterate over the upper triangle (k <= l) and use the conjugate
423423
# symmetry V_tilde[l, k] = V_tilde[k, l]^* to handle the lower triangle
@@ -499,7 +499,7 @@ def _evaluate_vis_chunk_remote(
499499
type1_n_modes: int = None,
500500
trace_mem: bool = False,
501501
nchunks: int = 1,
502-
beam_coeffs: np.ndarray = None,
502+
beam_coefs: np.ndarray = None,
503503
):
504504
"""Ray-compatible remote version of _evaluate_vis_chunk."""
505505
engine = CPUSimulationEngine() # pragma: no cover
@@ -529,7 +529,7 @@ def _evaluate_vis_chunk_remote(
529529
type1_n_modes=type1_n_modes,
530530
trace_mem=trace_mem,
531531
nchunks=nchunks,
532-
beam_coeffs=beam_coeffs,
532+
beam_coefs=beam_coefs,
533533
)
534534

535535

@@ -567,18 +567,18 @@ def simulate(
567567
enable_memory_monitor: bool = False,
568568
nchunks: int = 1,
569569
source_buffer=1.0,
570-
beam_coeffs: np.ndarray = None,
570+
beam_coefs: np.ndarray = None,
571571
) -> np.ndarray:
572572
"""
573573
Simulate visibilities using CPU implementation.
574574
575575
Parameters
576576
----------
577-
beam_coeffs : np.ndarray, optional
577+
beam_coefs : np.ndarray, optional
578578
Per-antenna SVD coefficients of shape (N_ant, K). When provided,
579579
beam_list is interpreted as K basis beams rather than per-antenna
580580
beams. Visibilities are computed over all K^2 basis pairs and then
581-
contracted with beam_coeffs in post-processing.
581+
contracted with beam_coefs in post-processing.
582582
583583
See base class for all other parameter descriptions.
584584
"""
@@ -743,8 +743,8 @@ def simulate(
743743
freqs = ray.put(freqs)
744744
beam_list = ray.put(beam_list)
745745
coord_mgr = ray.put(coord_mgr)
746-
if beam_coeffs is not None:
747-
beam_coeffs = ray.put(beam_coeffs)
746+
if beam_coefs is not None:
747+
beam_coefs = ray.put(beam_coefs)
748748
if trace_mem:
749749
os.system("ray memory --units MB > after-puts.txt")
750750

@@ -803,7 +803,7 @@ def simulate(
803803
type1_n_modes=n_modes if is_gridded else None,
804804
trace_mem=(nprocesses > 1 or force_use_ray) and trace_mem,
805805
nchunks=nchunks,
806-
beam_coeffs=beam_coeffs,
806+
beam_coefs=beam_coefs,
807807
)
808808
)
809809
if trace_mem:
@@ -856,14 +856,14 @@ def _evaluate_vis_chunk(
856856
type1_n_modes: int = None,
857857
trace_mem: bool = False,
858858
nchunks: int = 1,
859-
beam_coeffs: np.ndarray = None,
859+
beam_coefs: np.ndarray = None,
860860
) -> np.ndarray:
861861
"""
862862
Evaluate a chunk of visibility data using CPU.
863863
864864
Parameters
865865
----------
866-
beam_coeffs : np.ndarray, optional
866+
beam_coefs : np.ndarray, optional
867867
Per-antenna SVD coefficients, shape (N_ant, K). When provided,
868868
beam_list contains the K basis beams and the standard beam-pair
869869
loop is replaced by the basis visibility path.
@@ -888,11 +888,11 @@ def _evaluate_vis_chunk(
888888

889889
coord_mgr.setup()
890890

891-
use_basis = beam_coeffs is not None
891+
use_basis = beam_coefs is not None
892892

893893
if use_basis:
894894
# Pre-compute the per-baseline antenna index arrays once.
895-
# These are used in post-processing to gather the right rows of beam_coeffs.
895+
# These are used in post-processing to gather the right rows of beam_coefs.
896896
ant1_idxs = np.array([antnums.index(bl[0]) for bl in baselines])
897897
ant2_idxs = np.array([antnums.index(bl[1]) for bl in baselines])
898898
else:
@@ -972,7 +972,7 @@ def _evaluate_vis_chunk(
972972
flux_here=flux,
973973
ant1_idxs=ant1_idxs,
974974
ant2_idxs=ant2_idxs,
975-
beam_coeffs=beam_coeffs,
975+
beam_coefs=beam_coefs,
976976
freqidx=freqidx,
977977
topo=topo,
978978
uvw=uvw if not use_type1 else None,

0 commit comments

Comments
 (0)