|
5 | 5 |
|
6 | 6 | import astropy.units as u |
7 | 7 | import numpy as np |
8 | | -from astropy.coordinates import (Latitude, Longitude, SkyCoord, |
9 | | - SphericalRepresentation, |
| 8 | +from astropy.coordinates import (BaseCoordinateFrame, Latitude, Longitude, |
| 9 | + SkyCoord, SphericalRepresentation, |
10 | 10 | UnitSphericalRepresentation, |
11 | | - cartesian_to_spherical) |
| 11 | + cartesian_to_spherical, frame_transform_graph) |
12 | 12 |
|
13 | 13 | __all__ = [] |
14 | 14 |
|
15 | 15 |
|
| 16 | +def get_astropy_frame_class(frame): |
| 17 | + """ |
| 18 | + Get a frame class from the input `frame`, which could be a frame name |
| 19 | + string, or frame class. |
| 20 | +
|
| 21 | + Direct port of `_get_frame_class()` from |
| 22 | + https://github.com/astropy/astropy/blob/v7.2.0/astropy/coordinates/sky_coordinate_parsers.py |
| 23 | + to avoid importing private methods. |
| 24 | +
|
| 25 | + Parameters |
| 26 | + ---------- |
| 27 | + frame : str or `~astropy.coordinates.BaseCoordinateFrame` instance |
| 28 | + The frame as a string or astropy frame class. |
| 29 | +
|
| 30 | + Returns |
| 31 | + ------- |
| 32 | + frame_cls : A (sub)class of `~astropy.coordinates.BaseCoordinateFrame` |
| 33 | + The frame as an astropy frame class. |
| 34 | + """ |
| 35 | + if isinstance(frame, str): |
| 36 | + frame_names = frame_transform_graph.get_names() |
| 37 | + if frame not in frame_names: |
| 38 | + raise ValueError( |
| 39 | + f'Coordinate frame name "{frame}" is not a known ' |
| 40 | + f"coordinate frame ({sorted(frame_names)})", |
| 41 | + ) |
| 42 | + frame_cls = frame_transform_graph.lookup_name(frame) |
| 43 | + |
| 44 | + elif isinstance(frame, type) and issubclass(frame, BaseCoordinateFrame): |
| 45 | + frame_cls = frame |
| 46 | + |
| 47 | + else: |
| 48 | + raise ValueError( |
| 49 | + 'Coordinate frame must be a frame name or frame class, not a' |
| 50 | + f" '{frame.__class__.__name__}'", |
| 51 | + ) |
| 52 | + |
| 53 | + return frame_cls |
| 54 | + |
| 55 | + |
16 | 56 | def cross_product_skycoord2skycoord(c1, c2): |
17 | 57 | """ |
18 | 58 | Compute cross product of two sky coordinates (from a spherical |
|
0 commit comments