Skip to content

Commit 7a71ca0

Browse files
committed
Implement stand-alone port _get_frame_class
1 parent 6d3df5f commit 7a71ca0

2 files changed

Lines changed: 46 additions & 6 deletions

File tree

regions/_utils/spherical_helpers.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,54 @@
55

66
import astropy.units as u
77
import numpy as np
8-
from astropy.coordinates import (Latitude, Longitude, SkyCoord,
9-
SphericalRepresentation,
8+
from astropy.coordinates import (BaseCoordinateFrame, Latitude, Longitude,
9+
SkyCoord, SphericalRepresentation,
1010
UnitSphericalRepresentation,
11-
cartesian_to_spherical)
11+
cartesian_to_spherical, frame_transform_graph)
1212

1313
__all__ = []
1414

1515

16+
def get_astropy_frame_class(frame):
17+
"""
18+
Get a frame class from the input `frame`, which could be a frame name
19+
string, or frame class.
20+
21+
Direct port of `_get_frame_class()` from
22+
https://github.com/astropy/astropy/blob/v7.2.0/astropy/coordinates/sky_coordinate_parsers.py
23+
to avoid importing private methods.
24+
25+
Parameters
26+
----------
27+
frame : str or `~astropy.coordinates.BaseCoordinateFrame` instance
28+
The frame as a string or astropy frame class.
29+
30+
Returns
31+
-------
32+
frame_cls : A (sub)class of `~astropy.coordinates.BaseCoordinateFrame`
33+
The frame as an astropy frame class.
34+
"""
35+
if isinstance(frame, str):
36+
frame_names = frame_transform_graph.get_names()
37+
if frame not in frame_names:
38+
raise ValueError(
39+
f'Coordinate frame name "{frame}" is not a known '
40+
f"coordinate frame ({sorted(frame_names)})",
41+
)
42+
frame_cls = frame_transform_graph.lookup_name(frame)
43+
44+
elif isinstance(frame, type) and issubclass(frame, BaseCoordinateFrame):
45+
frame_cls = frame
46+
47+
else:
48+
raise ValueError(
49+
'Coordinate frame must be a frame name or frame class, not a'
50+
f" '{frame.__class__.__name__}'",
51+
)
52+
53+
return frame_cls
54+
55+
1656
def cross_product_skycoord2skycoord(c1, c2):
1757
"""
1858
Compute cross product of two sky coordinates (from a spherical

regions/core/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
import astropy.units as u
77
import numpy as np
88
from astropy.coordinates import BaseCoordinateFrame, SkyCoord
9-
from astropy.coordinates.sky_coordinate_parsers import _get_frame_class
109

11-
from regions._utils.spherical_helpers import bounding_lonlat_poles_processing
10+
from regions._utils.spherical_helpers import (bounding_lonlat_poles_processing,
11+
get_astropy_frame_class)
1212
from regions.core.metadata import RegionMeta, RegionVisual
1313
from regions.core.pixcoord import PixCoord
1414
from regions.core.registry import RegionsRegistry
@@ -791,7 +791,7 @@ def union(self, other):
791791
@staticmethod
792792
def _standardize_frame(frame):
793793
# Standardize frame format: get as an astropy coordinate frame class
794-
frame = _get_frame_class(frame)
794+
frame = get_astropy_frame_class(frame)
795795
return frame
796796

797797
def _validate_frame(self, frame):

0 commit comments

Comments
 (0)