33import numpy as np
44from abc import ABC , abstractmethod
55from copy import deepcopy
6+ from dataclasses import replace
67from pyuvdata import UVBeam
8+ from pyuvdata .analytic_beam import AnalyticBeam
9+ from pyuvdata .beam_interface import BeamInterface
710from pyuvdata .utils .pol import polstr2num
811from typing import Any , Literal
912
1013
11- def prepare_beam_unpolarized (uvbeam , use_pol : Literal ["xy" , "xx" , "yx" , "yy" ] = "xx" ):
12- """Given a UVBeam or AnalyticBeam, prepare it for an un-polarized simulation."""
13- if uvbeam .beam_type == "power" and getattr (uvbeam , "Npols" , 1 ) == 1 :
14- return uvbeam
14+ def prepare_beam_unpolarized (
15+ beam : BeamInterface ,
16+ use_feed : Literal ["x" , "y" ] = "x" ,
17+ allow_beam_mutation : bool = False ,
18+ ) -> BeamInterface :
19+ """Given a BeamInterface, prepare it for an un-polarized simulation."""
20+ if beam .beam_type == "power" and beam .Npols == 1 :
21+ return beam
1522
16- uvbeam_ = uvbeam .copy () if isinstance (uvbeam , UVBeam ) else deepcopy (uvbeam )
17-
18- if uvbeam_ .beam_type == "efield" :
19- if isinstance (uvbeam , UVBeam ):
20- uvbeam_ .efield_to_power (calc_cross_pols = False )
21- else :
22- uvbeam_ .efield_to_power ()
23-
24- if getattr (uvbeam_ , "Npols" , 1 ) > 1 :
25- pol = polstr2num (use_pol )
23+ if beam .beam_type == "efield" :
24+ beam = beam .as_power_beam (
25+ include_cross_pols = False , allow_beam_mutation = allow_beam_mutation
26+ )
2627
27- if pol not in uvbeam_ .polarization_array :
28- raise ValueError (
29- f"You want to use { use_pol } pol, but it does not exist in the UVBeam"
30- )
31- uvbeam_ .select (polarizations = [pol ])
28+ if beam .Npols > 1 :
29+ beam = beam .with_feeds ([use_feed ])
3230
33- return uvbeam_
31+ return beam
3432
3533
3634def _wrangle_beams (
3735 beam_idx : np .ndarray | None ,
38- beam_list : list [UVBeam ],
36+ beam_list : list [BeamInterface | UVBeam | AnalyticBeam ],
3937 polarized : bool ,
4038 nant : int ,
4139 freq : float ,
42- ) -> tuple [list [UVBeam ], int , np .ndarray ]:
40+ ) -> tuple [list [BeamInterface ], int , np .ndarray ]:
4341 """Perform all the operations and checks on the input beams.
4442
4543 Checks that the beam indices match the number of antennas, pre-interpolates to the
@@ -61,6 +59,10 @@ def _wrangle_beams(
6159 """
6260 # Get the number of unique beams
6361 nbeam = len (beam_list )
62+ beam_list = [BeamInterface (beam ) for beam in beam_list ]
63+
64+ if not polarized :
65+ beam_list = [prepare_beam_unpolarized (beam ) for beam in beam_list ]
6466
6567 # Check the beam indices
6668 if beam_idx is None and nbeam not in (1 , nant ):
@@ -78,8 +80,12 @@ def _wrangle_beams(
7880 # make sure we interpolate to the right frequency first.
7981 beam_list = [
8082 (
81- bm .interp (freq_array = np .array ([freq ]), new_object = True , run_check = False )
82- if isinstance (bm , UVBeam )
83+ bm .clone (
84+ beam = bm .beam .interp (
85+ freq_array = np .array ([freq ]), new_object = True , run_check = False
86+ )
87+ )
88+ if bm ._isuvbeam
8389 else bm
8490 )
8591 for bm in beam_list
@@ -89,19 +95,6 @@ def _wrangle_beams(
8995 if any (b .beam_type != "efield" for b in beam_list ):
9096 raise ValueError ("beam type must be efield if using polarized=True" )
9197
92- # The following applies if we're not polarized
93- elif any (
94- (
95- b .beam_type != "power"
96- or getattr (b , "Npols" , 1 ) > 1
97- or b .polarization_array [0 ] not in [- 5 , - 6 ]
98- )
99- for b in beam_list
100- ):
101- raise ValueError (
102- "beam type must be power and have only one pol (either xx or yy) if polarized=False"
103- )
104-
10598 return beam_list , nbeam , beam_idx
10699
107100
@@ -132,7 +125,7 @@ class BeamInterpolator(ABC):
132125
133126 def __init__ (
134127 self ,
135- beam_list : list [UVBeam ],
128+ beam_list : list [BeamInterface ],
136129 beam_idx : np .ndarray ,
137130 polarized : bool ,
138131 nant : int ,
@@ -142,6 +135,7 @@ def __init__(
142135 precision : int = 1 ,
143136 ):
144137 self .polarized = polarized
138+
145139 self .beam_list , self .nbeam , self .beam_idx = _wrangle_beams (
146140 beam_idx = beam_idx ,
147141 beam_list = beam_list ,
0 commit comments