Skip to content

Commit 3526795

Browse files
authored
fix: add nans parameter to ballet_centroid for handling NaN values (#6)
fix: add nans parameter to ballet_centroid for handling NaN values (#6)
2 parents 5f1a115 + 305998f commit 3526795

5 files changed

Lines changed: 35 additions & 7 deletions

File tree

noxfile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
@nox.session(python=ALL_PYTHON_VS)
99
def test(session):
10-
session.install(".[test]")
10+
session.install(".[test, jax]")
1111
session.run("pytest", "-n", "auto", *session.posargs)
1212

1313

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ classifiers = [
1414
"Programming Language :: Python",
1515
"Programming Language :: Python :: 3",
1616
]
17-
version = "0.0.10b"
17+
version = "0.0.11b"
1818
dependencies = [
1919
"numpy",
2020
"scikit-image",

src/eloy/centroid.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def photutils_centroid(data, coords, cutout=21, centroid_fun=None):
5757
return centroid_coords
5858

5959

60-
def ballet_centroid(data, coords, cnn):
60+
def ballet_centroid(data, coords, cnn, nans=False):
6161
"""
6262
Compute centroids for sources using a CNN-based model.
6363
@@ -69,11 +69,19 @@ def ballet_centroid(data, coords, cnn):
6969
Array of (x, y) coordinates for sources.
7070
cnn : object
7171
CNN model with a `centroid` method that accepts cutouts.
72+
nans : bool, optional
73+
If True, NaN values in the output will be replaced with the original coordinates.
7274
7375
Returns
7476
-------
7577
np.ndarray
7678
Array of refined centroid coordinates.
7779
"""
7880
cutouts = utils.cutout(data, coords, (15, 15), fill_value=np.median(data))
79-
return coords - 15 / 2 + cnn.centroid(cutouts)
81+
centroids = np.array(coords - 15 / 2 + cnn.centroid(cutouts))
82+
if not nans:
83+
is_nan = np.isnan(centroids).any(axis=1)
84+
centroids[is_nan] = coords[is_nan]
85+
return centroids
86+
else:
87+
return centroids

src/eloy/utils.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from astropy.nddata import Cutout2D
99
import numpy as np
1010
from astropy.nddata import Cutout2D
11+
from astropy.nddata.utils import NoOverlapError
1112

1213

1314
def cutout(data, coords, shape, wcs=None, fill_value=np.nan):
@@ -31,10 +32,14 @@ def cutout(data, coords, shape, wcs=None, fill_value=np.nan):
3132
Array of cutout images.
3233
"""
3334
values = []
35+
dummy = np.zeros(shape, dtype=data.dtype)
3436
for coords in coords:
35-
cutout = Cutout2D(
36-
data, coords, shape, wcs=wcs, fill_value=fill_value, mode="partial"
37-
)
37+
try:
38+
cutout = Cutout2D(
39+
data, coords, shape, wcs=wcs, fill_value=fill_value, mode="partial"
40+
)
41+
except NoOverlapError:
42+
cutout = dummy
3843
values.append(cutout.data)
3944
return np.array(values)
4045

tests/test_centroids.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from eloy import utils, centroid, alignment
22
from photutils.centroids import centroid_com
33
import numpy as np
4+
import pytest
45

56

67
def test_centroid():
@@ -19,3 +20,17 @@ def test_centroid_out():
1920
data = np.random.rand(50, 50)
2021
coords = np.array([[-1, 1], [20, 20]])
2122
centroid.photutils_centroid(data, coords)
23+
24+
25+
def test_ballet_nans():
26+
"""Test ballet centroid with nans"""
27+
pytest.importorskip("jax")
28+
from eloy.centroid import ballet_centroid, Ballet
29+
30+
cnn = Ballet()
31+
32+
im = np.random.rand(20, 20)
33+
coords = np.array([[10.0, 10.0], [-20.0, -20.0]])
34+
35+
assert not np.any(np.isnan(ballet_centroid(im, coords, cnn, nans=False)))
36+
assert np.any(np.isnan(ballet_centroid(im, coords, cnn, nans=True)))

0 commit comments

Comments
 (0)