Skip to content

Commit 17e8e65

Browse files
ntessorepre-commit-ci[bot]paddyroddyconnoraird
authored
gh-910: consistent definition of displacement (#911)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Patrick J. Roddy <patrickjamesroddy@gmail.com> Co-authored-by: connoraird <c.aird@ucl.ac.uk>
1 parent 89a5c97 commit 17e8e65

File tree

3 files changed

+151
-35
lines changed

3 files changed

+151
-35
lines changed

glass/points.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,7 @@ def displace(
510510

511511
d = xp.atan2(sa * sg, st * ca - ct * sa * cg)
512512

513-
return lon - d / math.pi * 180, tp / math.pi * 180
513+
return lon + d / math.pi * 180, tp / math.pi * 180
514514

515515

516516
def displacement(
@@ -528,9 +528,13 @@ def displacement(
528528
529529
Parameters
530530
----------
531-
from_lon, from_lat
531+
from_lon
532532
Points before displacement.
533-
to_lon, to_lat
533+
from_lat
534+
Points before displacement.
535+
to_lon
536+
Points after displacement.
537+
to_lat
534538
Points after displacement.
535539
536540
Returns
@@ -550,16 +554,14 @@ def displacement(
550554
use_compat=False,
551555
)
552556

553-
a = (90.0 - to_lat) / 180 * math.pi
554-
b = (90.0 - from_lat) / 180 * math.pi
555-
g = (from_lon - to_lon) / 180 * math.pi
557+
a = uxpx.radians(from_lat)
558+
b = uxpx.radians(to_lat)
559+
g = uxpx.radians(to_lon - from_lon)
556560

557561
sa, ca = xp.sin(a), xp.cos(a)
558562
sb, cb = xp.sin(b), xp.cos(b)
559563
sg, cg = xp.sin(g), xp.cos(g)
560564

561-
r = xp.atan2(xp.hypot(sa * cb - ca * sb * cg, sb * sg), ca * cb + sa * sb * cg)
562-
x = sb * ca - cb * sa * cg
563-
y = sa * sg
564-
z = xp.hypot(x, y)
565-
return r * (x / z + 1j * y / z)
565+
r = xp.atan2(xp.hypot(cb * sg, ca * sb - sa * cb * cg), sa * sb + ca * cb * cg)
566+
x = xp.atan2(cb * sg, ca * sb - sa * cb * cg)
567+
return r * xp.exp(1j * x)

tests/benchmarks/test_points.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
from glass._types import UnifiedGenerator
1818
from tests.fixtures.helper_classes import (
19-
Compare,
2019
DataTransformer,
2120
GeneratorConsumer,
2221
)
@@ -103,31 +102,22 @@ def function_to_benchmark() -> list[Any]:
103102

104103

105104
@pytest.mark.parametrize(
106-
("r_to_alpha", "expected_lon", "expected_lat"),
105+
("r_to_alpha"),
107106
[
108107
# Complex
109-
(lambda r: r + 0j, 0.0, 5.0),
110-
(lambda r: -r + 0j, 0.0, -5.0),
111-
(lambda r: 1j * r, -5.0, 0.0),
112-
(lambda r: -1j * r, 5.0, 0.0),
108+
(lambda r: r + 0j),
113109
# Real
114-
(lambda r: [r, 0], 0.0, 5.0),
115-
(lambda r: [-r, 0], 0.0, -5.0),
116-
(lambda r: [0, r], -5.0, 0.0),
117-
(lambda r: [0, -r], 5.0, 0.0),
110+
(lambda r: [r, 0]),
118111
],
119112
)
120113
@pytest.mark.skipif(
121114
not hasattr(glass, "displace"),
122115
reason="test requires glass.displace",
123116
)
124-
def test_displace( # noqa: PLR0913
117+
def test_displace(
125118
benchmark: BenchmarkFixture,
126-
compare: type[Compare],
127119
xpb: ModuleType,
128120
r_to_alpha: Callable[[float], complex | list[float]],
129-
expected_lon: float,
130-
expected_lat: float,
131121
) -> None:
132122
"""Benchmark for glass.displace with complex values."""
133123
scale_length = 100_000
@@ -146,8 +136,8 @@ def test_displace( # noqa: PLR0913
146136
lat0,
147137
alpha,
148138
)
149-
compare.assert_allclose(lon, expected_lon, atol=1e-15)
150-
compare.assert_allclose(lat, expected_lat, atol=1e-15)
139+
assert lon.shape == (scale_length,)
140+
assert lat.shape == (scale_length,)
151141

152142

153143
@pytest.mark.stable

tests/core/test_points.py

Lines changed: 132 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import glass
1111
import glass.healpix as hp
12+
from glass._array_api_utils import xp_additions as uxpx
1213

1314
if TYPE_CHECKING:
1415
from types import ModuleType
@@ -356,11 +357,11 @@ def test_displace_arg_complex(compare: type[Compare], xp: ModuleType) -> None:
356357

357358
# east
358359
lon, lat = glass.displace(lon0, lat0, xp.asarray(1j * r))
359-
compare.assert_allclose([lon, lat], [-d, 0.0], atol=1e-15)
360+
compare.assert_allclose([lon, lat], [d, 0.0], atol=1e-15)
360361

361362
# west
362363
lon, lat = glass.displace(lon0, lat0, xp.asarray(-1j * r))
363-
compare.assert_allclose([lon, lat], [d, 0.0], atol=1e-15)
364+
compare.assert_allclose([lon, lat], [-d, 0.0], atol=1e-15)
364365

365366

366367
def test_displace_arg_real(compare: type[Compare], xp: ModuleType) -> None:
@@ -382,11 +383,11 @@ def test_displace_arg_real(compare: type[Compare], xp: ModuleType) -> None:
382383

383384
# east
384385
lon, lat = glass.displace(lon0, lat0, xp.asarray([0, r]))
385-
compare.assert_allclose([lon, lat], [-d, 0.0], atol=1e-15)
386+
compare.assert_allclose([lon, lat], [d, 0.0], atol=1e-15)
386387

387388
# west
388389
lon, lat = glass.displace(lon0, lat0, xp.asarray([0, -r]))
389-
compare.assert_allclose([lon, lat], [d, 0.0], atol=1e-15)
390+
compare.assert_allclose([lon, lat], [-d, 0.0], atol=1e-15)
390391

391392

392393
def test_displace_abs(
@@ -434,14 +435,14 @@ def test_displacement(
434435
data = [
435436
# equator
436437
(zero, zero, zero, five, deg5 * north),
437-
(zero, zero, -five, zero, deg5 * east),
438+
(zero, zero, five, zero, deg5 * east),
438439
(zero, zero, zero, -five, deg5 * south),
439-
(zero, zero, five, zero, deg5 * west),
440+
(zero, zero, -five, zero, deg5 * west),
440441
# pole
441442
(zero, ninety, ninety * 2, ninety - five, deg5 * north),
442-
(zero, ninety, -ninety, ninety - five, deg5 * east),
443+
(zero, ninety, ninety, ninety - five, deg5 * east),
443444
(zero, ninety, zero, ninety - five, deg5 * south),
444-
(zero, ninety, ninety, ninety - five, deg5 * west),
445+
(zero, ninety, -ninety, ninety - five, deg5 * west),
445446
]
446447

447448
# test each displacement individually
@@ -457,3 +458,126 @@ def test_displacement(
457458
urng.uniform(-90.0, 90.0, size=5),
458459
)
459460
assert alpha.shape == (20, 5)
461+
462+
463+
def test_displacement_zerodist(
464+
compare: type[Compare],
465+
urng: UnifiedGenerator,
466+
xp: ModuleType,
467+
) -> None:
468+
"""Check that zero displacement is computed correctly."""
469+
lon = urng.uniform(-180.0, 180.0, size=100)
470+
lat = urng.uniform(-90.0, 90.0, size=100)
471+
472+
compare.assert_allclose(
473+
glass.displacement(lon, lat, lon, lat),
474+
xp.zeros(100),
475+
)
476+
477+
478+
def test_displacement_consistent(
479+
compare: type[Compare],
480+
urng: UnifiedGenerator,
481+
xp: ModuleType,
482+
) -> None:
483+
"""Check displacement is consistent with displace."""
484+
n = 1_000
485+
486+
# magnitude and angle of displacement we want to achieve
487+
r = xp.acos(urng.uniform(-1.0, 1.0, size=n))
488+
x = urng.uniform(-math.pi, math.pi, size=n)
489+
490+
# displace at random positions on the sphere
491+
from_lon = urng.uniform(-180.0, 180.0, size=n)
492+
from_lat = xp.asin(urng.uniform(-1.0, 1.0, size=n)) / math.pi * 180.0
493+
494+
# compute the intended displacement
495+
alpha_in = r * xp.exp(1j * x)
496+
497+
# displace random points
498+
to_lon, to_lat = glass.displace(from_lon, from_lat, alpha_in)
499+
500+
# measure displacement
501+
alpha_out = glass.displacement(from_lon, from_lat, to_lon, to_lat)
502+
503+
compare.assert_allclose(alpha_out, alpha_in, atol=0.0, rtol=1e-10)
504+
505+
506+
def test_displacement_random(
507+
compare: type[Compare],
508+
urng: UnifiedGenerator,
509+
xp: ModuleType,
510+
) -> None:
511+
"""Check displacement for random points."""
512+
n = 1_000
513+
514+
# magnitude and angle of displacement we want to achieve
515+
r = xp.acos(urng.uniform(-1.0, 1.0, size=n))
516+
x = urng.uniform(-math.pi, math.pi, size=n)
517+
518+
# displacement at random positions on the sphere
519+
theta = xp.acos(urng.uniform(-1.0, 1.0, size=n))
520+
phi = urng.uniform(-math.pi, math.pi, size=n)
521+
522+
# rotation matrix that moves (0, 0, 1) to theta and phi
523+
zero = xp.zeros(n)
524+
one = xp.ones(n)
525+
rot_y = xp.stack(
526+
[
527+
xp.cos(theta), zero, xp.sin(theta),
528+
zero, one, zero,
529+
-xp.sin(theta), zero, xp.cos(theta),
530+
],
531+
axis=1,
532+
) # fmt: skip
533+
rot_z = xp.stack(
534+
[
535+
xp.cos(phi), -xp.sin(phi), zero,
536+
xp.sin(phi), xp.cos(phi), zero,
537+
zero, zero, one,
538+
],
539+
axis=1,
540+
) # fmt: skip
541+
rot = xp.reshape(rot_z, (n, 3, 3)) @ xp.reshape(rot_y, (n, 3, 3))
542+
543+
# meta-check that rotation works by rotating (0, 0, 1) to theta and phi
544+
u = xp.stack(
545+
[
546+
xp.sin(theta) * xp.cos(phi),
547+
xp.sin(theta) * xp.sin(phi),
548+
xp.cos(theta),
549+
],
550+
axis=1,
551+
)
552+
compare.assert_allclose(rot @ xp.asarray([0.0, 0.0, 1.0]), u)
553+
554+
# meta-check that recovering theta and phi from vector works
555+
compare.assert_allclose(xp.atan2(xp.hypot(u[:, 0], u[:, 1]), u[:, 2]), theta)
556+
compare.assert_allclose(xp.atan2(u[:, 1], u[:, 0]), phi)
557+
558+
# build the displaced points near (0, 0, 1) and rotate near theta and phi
559+
v = xp.stack(
560+
[
561+
xp.sin(r) * xp.cos(math.pi - x),
562+
xp.sin(r) * xp.sin(math.pi - x),
563+
xp.cos(r),
564+
],
565+
axis=1,
566+
)
567+
v = rot @ xp.reshape(v, (n, 3, 1))
568+
v = xp.reshape(v, (n, 3))
569+
570+
# compute displaced theta and phi
571+
theta_d = xp.atan2(xp.hypot(v[:, 0], v[:, 1]), v[:, 2])
572+
phi_d = xp.atan2(v[:, 1], v[:, 0])
573+
574+
# compute longitude and latitude
575+
from_lon = uxpx.degrees(phi)
576+
from_lat = 90.0 - uxpx.degrees(theta)
577+
to_lon = uxpx.degrees(phi_d)
578+
to_lat = 90.0 - uxpx.degrees(theta_d)
579+
580+
# compute displacement and compare to input
581+
alpha_in = r * xp.exp(1j * x)
582+
alpha_out = glass.displacement(from_lon, from_lat, to_lon, to_lat)
583+
compare.assert_allclose(alpha_out, alpha_in, atol=0.0, rtol=1e-10)

0 commit comments

Comments
 (0)