Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions africanus/rime/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,40 @@

from africanus.constants import minus_two_pi_over_c
from africanus.util.docs import DocstringTemplate
from africanus.util.numba import generated_jit
from africanus.util.numba import generated_jit, njit, is_numba_type_none
from africanus.util.type_inference import infer_complex_dtype


def _out_factory(out_present):
if out_present:
def impl(out, shape, dtype):
# TODO(sjperkins) Check the dtype too?
if out.shape != shape:
raise ValueError("out.shape does not match expected shape")

return out
else:
def impl(out, shape, dtype):
return np.zeros(shape, dtype)

return njit(nogil=True, cache=True)(impl)


@generated_jit(nopython=True, nogil=True, cache=True)
def phase_delay(lm, uvw, frequency):
def phase_delay(lm, uvw, frequency, out=None):
have_out = not is_numba_type_none(out)

# Bake constants in with the correct type
one = lm.dtype(1.0)
neg_two_pi_over_c = lm.dtype(minus_two_pi_over_c)

out_dtype = infer_complex_dtype(lm, uvw, frequency)

create_output = _out_factory(have_out)

@wraps(phase_delay)
def _phase_delay_impl(lm, uvw, frequency):
def _phase_delay_impl(lm, uvw, frequency, out=None):
shape = (lm.shape[0], uvw.shape[0], frequency.shape[0])
complex_phase = np.zeros(shape, dtype=out_dtype)
complex_phase = create_output(out, shape, out_dtype)

# For each source
for source in range(lm.shape[0]):
Expand Down Expand Up @@ -87,6 +105,9 @@ def _phase_delay_impl(lm, uvw, frequency):
U, V and W components in the last dimension.
frequency : $(array_type)
frequencies of shape :code:`(chan,)`
out : $(array_type), optional
Array holding the output results. Should have the
same shape as the returned `complex_phase`.

Returns
-------
Expand Down
10 changes: 10 additions & 0 deletions africanus/rime/tests/test_rime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""Tests for `codex-africanus` package."""

import numpy as np
from numpy.testing import assert_array_equal

import pytest

Expand Down Expand Up @@ -45,6 +46,15 @@ def test_phase_delay():
phase = minus_two_pi_over_c*(u*l + v*m + w*n)*freq
assert np.all(np.exp(1j*phase) == complex_phase[lm_i, uvw_i, freq_i])

# Test that we can supply an out parameter
out = np.zeros_like(complex_phase)
complex_phase_2 = phase_delay(lm, uvw, frequency, out=out)

# Result matches first version
assert_array_equal(complex_phase, complex_phase_2)
# Check that the result is in the original variable we passed in
assert out is complex_phase_2


def test_feed_rotation():
import numpy as np
Expand Down