Skip to content

Commit dc74204

Browse files
Merge pull request #105 from HERA-Team/allow_arbitrary_sky_model_type
Allow for an arbitrary dtype for flux model
2 parents be75a55 + 75ef9e3 commit dc74204

5 files changed

Lines changed: 41 additions & 10 deletions

File tree

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ install_requires =
2424
line-profiler
2525
numpy>=2.0
2626
psutil
27-
pyuvdata>=3.1.2
27+
pyuvdata>=3.2.0
2828
rich
2929
scipy
3030
python_requires = >=3.10

src/matvis/_test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pyuvdata import UVBeam
1212
from pyuvdata.analytic_beam import GaussianBeam
1313
from pyuvdata.beam_interface import BeamInterface
14-
from pyuvdata.telescopes import get_telescope
14+
from pyuvdata.telescopes import Telescope
1515
from pyuvsim import simsetup
1616
from pyuvsim.telescope import BeamList
1717

@@ -34,7 +34,7 @@ def get_standard_sim_params(
3434
first_source_antizenith=False,
3535
):
3636
"""Create some standard random simulation parameters for use in tests."""
37-
hera = get_telescope("hera")
37+
hera = Telescope.from_known_telescopes("hera")
3838
obstime = Time("2018-08-31T04:02:30.11", format="isot", scale="utc")
3939

4040
rng = np.random.default_rng(1)

src/matvis/core/coords.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,15 @@ def __init__(
5353
self.xp = cp if self.gpu else np
5454

5555
self.precision = precision
56-
self.rtype, _ = get_dtypes(precision)
57-
self.flux = self.xp.asarray(flux.astype(self.rtype))
56+
self.rtype, self.ctype = get_dtypes(precision)
57+
58+
# Check if the flux is complex and set the dtype accordingly.
59+
if self.xp.iscomplexobj(flux):
60+
self.sky_model_dtype = self.ctype
61+
else:
62+
self.sky_model_dtype = self.rtype
63+
64+
self.flux = self.xp.asarray(flux.astype(self.sky_model_dtype))
5865

5966
self.nsrc = len(flux)
6067
self.times = times
@@ -82,7 +89,9 @@ def setup(self):
8289
(3, self.nsrc_alloc), self.rtype(0.0), dtype=self.rtype
8390
)
8491
self.flux_above_horizon = self.xp.full(
85-
(self.nsrc_alloc,) + self.flux.shape[1:], self.rtype(0.0), dtype=self.rtype
92+
(self.nsrc_alloc,) + self.flux.shape[1:],
93+
self.sky_model_dtype(0.0),
94+
dtype=self.sky_model_dtype,
8695
)
8796

8897
def select_chunk(self, chunk: int):

tests/test_coordrot.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from astropy import units as un
77
from astropy.coordinates import SkyCoord
88
from astropy.time import Time
9-
from pyuvdata.telescopes import get_telescope
9+
from pyuvdata.telescopes import Telescope
1010

1111
from matvis import HAVE_GPU
1212
from matvis.core.coords import CoordinateRotation
@@ -35,10 +35,32 @@ def get_angles(x, y):
3535
return xp.arccos(ratio)
3636

3737

38+
def test_complex_flux():
39+
"""Test that using a complex flux works appropriately."""
40+
rng = np.random.default_rng(1234)
41+
n = 23
42+
location = Telescope.from_known_telescopes("hera").location
43+
skycoords = SkyCoord(
44+
ra=rng.uniform(0, 2 * np.pi, size=n) * un.rad,
45+
dec=rng.uniform(-np.pi / 2, np.pi / 2, size=n) * un.rad,
46+
frame="icrs",
47+
)
48+
49+
coords = CoordinateRotationAstropy(
50+
flux=rng.normal(100, 2, size=n) + 1j * rng.normal(100, 2, size=n),
51+
times=Time(np.array([2459863.0]), format="jd", scale="utc"),
52+
telescope_loc=location,
53+
skycoords=skycoords,
54+
gpu=False,
55+
precision=2,
56+
)
57+
assert coords.sky_model_dtype == coords.ctype == np.complex128
58+
59+
3860
def get_random_coordrot(n, method, gpu, seed, precision=2, setup: bool = True, **kw):
3961
"""Get a random coordinate rotation object."""
4062
rng = np.random.default_rng(seed)
41-
location = get_telescope("hera").location
63+
location = Telescope.from_known_telescopes("hera").location
4264
skycoords = SkyCoord(
4365
ra=rng.uniform(0, 2 * np.pi, size=n) * un.rad,
4466
dec=rng.uniform(-np.pi / 2, np.pi / 2, size=n) * un.rad,

tests/test_matvis_cpu.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
from astropy.time import Time
77
from pyuvdata.analytic_beam import GaussianBeam
8-
from pyuvdata.telescopes import get_telescope
8+
from pyuvdata.telescopes import Telescope
99

1010
from matvis import simulate_vis
1111

@@ -23,7 +23,7 @@
2323
def test_simulate_vis(polarized):
2424
"""Test basic operation of simple wrapper around matvis, `simulate_vis`."""
2525
# Point source equatorial coords (radians)
26-
hera = get_telescope("hera")
26+
hera = Telescope.from_known_telescopes("hera")
2727
ra = np.linspace(0.0, 2.0 * np.pi, NPTSRC)
2828
dec = np.linspace(-0.5 * np.pi, 0.5 * np.pi, NPTSRC)
2929

0 commit comments

Comments
 (0)