Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ mast

- Improved robustness of PanSTARRS column metadata parsing. This prevents metadata-related query errors. [#3485]

- The ``select_cols`` parameter in ``MastMissions`` query functions now accepts an iterable of column names, a
  comma-delimited string of column names, or the special values ``'all'`` or ``'*'`` to return all available
  columns. [#3492]

jplspec
^^^^^^^

Expand Down
98 changes: 79 additions & 19 deletions astroquery/mast/missions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import difflib
import warnings
from collections.abc import Iterable
from json import JSONDecodeError
from pathlib import Path
from urllib.parse import quote
Expand All @@ -21,7 +22,7 @@
from astroquery import log
from astroquery.utils import commons, async_to_sync
from astroquery.utils.class_or_instance import class_or_instance
from astroquery.exceptions import InvalidQueryError, MaxResultsWarning, NoResultsWarning
from astroquery.exceptions import InputWarning, InvalidQueryError, MaxResultsWarning, NoResultsWarning

from astroquery.mast import utils
from astroquery.mast.core import MastQueryWithLogin
Expand All @@ -43,7 +44,7 @@ class MastMissionsClass(MastQueryWithLogin):
_list_products = 'post_list_products'

# Workaround so that observation_id is returned in ULLYSES queries that do not specify columns
_default_ullyses_cols = ['target_name_ulysses', 'target_classification', 'targ_ra', 'targ_dec', 'host_galaxy_name',
_default_ullyses_cols = ['target_name_ullyses', 'target_classification', 'targ_ra', 'targ_dec', 'host_galaxy_name',
'spectral_type', 'bmv0_mag', 'u_mag', 'b_mag', 'v_mag', 'gaia_g_mean_mag', 'star_mass',
'instrument', 'grating', 'filter', 'observation_id']

Expand Down Expand Up @@ -197,6 +198,71 @@ def _build_params_from_criteria(self, params, **criteria):
value = [value]
params[prop] = value

def _parse_select_cols(self, select_cols):
    """
    Normalize and validate the ``select_cols`` argument of a query.

    Parameters
    ----------
    select_cols : iterable or str or None
        The columns requested by the caller. May be an iterable of column names,
        a comma-separated string of column names, the special values ``'all'`` or
        ``'*'`` (every column reported by ``get_column_list``), or None.

    Returns
    -------
    list or None
        The validated column names, in the order requested, without duplicates and
        with the mission's dataset ID column appended when one is configured in
        ``dataset_kwds``. If ``select_cols`` is None, a copy of the default ULLYSES
        columns is returned for the ``'ullyses'`` mission and None otherwise.

    Raises
    ------
    InvalidQueryError
        If ``select_cols`` is neither a string nor an iterable, or if any
        individual column name is not a string.

    Warns
    -----
    InputWarning
        For each requested column that is not in the mission's column list.
        Such columns are dropped from the result.
    """
    if select_cols is None:
        if self.mission == 'ullyses':
            # Hand out a copy so callers cannot mutate the class-level default list
            return list(self._default_ullyses_cols)
        return None

    # Reject unusable input before asking for the mission's column list
    if not isinstance(select_cols, Iterable):
        raise InvalidQueryError(
            "`select_cols` must be an iterable of column names, a comma-separated string, "
            "'all', or '*'."
        )

    if isinstance(select_cols, str):
        keyword = select_cols.strip()
        if keyword.lower() == 'all' or keyword == '*':
            return self.get_column_list()['name'].value.tolist()
        # Comma-separated string
        requested = select_cols.split(',')
    else:
        # Materialize once so generators can be traversed more than once
        requested = list(select_cols)

    for col in requested:
        if not isinstance(col, str):
            raise InvalidQueryError(
                "`select_cols` must contain only strings (column names)."
            )

    all_columns = self.get_column_list()['name'].value.tolist()
    known_columns = set(all_columns)

    valid_select_cols = []
    seen = set()
    for col in requested:
        col = col.strip()
        if not col or col in seen:
            # Skip empty tokens (e.g. from a trailing comma) and repeated names
            continue
        seen.add(col)
        if col not in known_columns:
            closest_match = difflib.get_close_matches(col, all_columns, n=1)
            suggestion = f' Did you mean "{closest_match[0]}"?' if closest_match else ''
            warnings.warn(f"Column '{col}' not found.{suggestion}", InputWarning)
            continue
        valid_select_cols.append(col)

    # Dataset ID column should always be returned
    dataset_col = self.dataset_kwds.get(self.mission, None)
    if dataset_col and dataset_col not in valid_select_cols:
        valid_select_cols.append(dataset_col)
    return valid_select_cols

@class_or_instance
def query_region_async(self, coordinates, *, radius=3*u.arcmin, limit=5000, offset=0,
select_cols=None, **criteria):
Expand All @@ -217,9 +283,11 @@ def query_region_async(self, coordinates, *, radius=3*u.arcmin, limit=5000, offs
Default is 5000. The maximum number of dataset IDs in the results.
offset : int
Default is 0. The number of records you wish to skip before selecting records.
select_cols: list, optional
select_cols: iterable or str or None, optional
Default is None. Names of columns that will be included in the result table.
If None, a default set of columns will be returned.
Can either be an iterable of column names, a comma-separated string of column names,
or 'all'/'*' to return all available columns.
**criteria
Other mission-specific criteria arguments.
All valid filters can be found using `~astroquery.mast.missions.MastMissionsClass.get_column_list`
Expand Down Expand Up @@ -255,19 +323,13 @@ def query_region_async(self, coordinates, *, radius=3*u.arcmin, limit=5000, offs
f"Query radius too large. Must be ≤{self._max_query_radius}, got {radius}."
)

# Dataset ID column should always be returned
if select_cols:
select_cols.append(self.dataset_kwds.get(self.mission, None))
elif self.mission == 'ullyses':
select_cols = self._default_ullyses_cols

# Basic params
params = {'target': [f"{coordinates.ra.deg} {coordinates.dec.deg}"],
'radius': radius.arcsec,
'radius_units': 'arcseconds',
'limit': limit,
'offset': offset,
'select_cols': select_cols}
'select_cols': self._parse_select_cols(select_cols)}

self._build_params_from_criteria(params, **criteria)

Expand Down Expand Up @@ -295,9 +357,11 @@ def query_criteria_async(self, *, coordinates=None, objectname=None, radius=3*u.
Default is 5000. The maximum number of dataset IDs in the results.
offset : int
Default is 0. The number of records you wish to skip before selecting records.
select_cols: list, optional
select_cols: iterable or str or None, optional
Default is None. Names of columns that will be included in the result table.
If None, a default set of columns will be returned.
Can either be an iterable of column names, a comma-separated string of column names,
or 'all'/'*' to return all available columns.
resolver : str, optional
Default is None. The resolver to use when resolving a named target into coordinates. Valid options are
"SIMBAD" and "NED". If not specified, the default resolver order will be used. Please see the
Expand Down Expand Up @@ -344,14 +408,8 @@ def query_criteria_async(self, *, coordinates=None, objectname=None, radius=3*u.
f"Query radius too large. Must be ≤{self._max_query_radius}, got {radius}."
)

# Dataset ID column should always be returned
if select_cols:
select_cols.append(self.dataset_kwds.get(self.mission, None))
elif self.mission == 'ullyses':
select_cols = self._default_ullyses_cols

# build query
params = {"limit": self.limit, "offset": offset, 'select_cols': select_cols}
params = {"limit": self.limit, "offset": offset, 'select_cols': self._parse_select_cols(select_cols)}
if coordinates:
params["target"] = [f"{coordinates.ra.deg} {coordinates.dec.deg}"]
params["radius"] = radius.arcsec
Expand Down Expand Up @@ -382,9 +440,11 @@ def query_object_async(self, objectname, *, radius=3*u.arcmin, limit=5000, offse
Default is 5000. The maximum number of dataset IDs in the results.
offset : int
Default is 0. The number of records you wish to skip before selecting records.
select_cols: list, optional
select_cols: iterable or str or None, optional
Default is None. Names of columns that will be included in the result table.
If None, a default set of columns will be returned.
Can either be an iterable of column names, a comma-separated string of column names,
or 'all'/'*' to return all available columns.
resolver : str, optional
Default is None. The resolver to use when resolving a named target into coordinates. Valid options are
"SIMBAD" and "NED". If not specified, the default resolver order will be used. Please see the
Expand Down
38 changes: 34 additions & 4 deletions astroquery/mast/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ..query import BaseQuery
from ..utils import async_to_sync
from ..utils.class_or_instance import class_or_instance
from ..exceptions import InvalidQueryError, TimeoutError, NoResultsWarning
from ..exceptions import BlankResponseWarning, InvalidQueryError, TimeoutError, NoResultsWarning

from . import conf

Expand Down Expand Up @@ -97,13 +97,43 @@ def _json_to_table(json_obj, data_key='data'):
# no consistent way to make the mask because np.equal fails on ''
# and array == value fails with None
if col_type == 'str':
col_mask = (col_data == ignore_value)
ignore_mask = (col_data == ignore_value)
else:
col_mask = np.equal(col_data, ignore_value)
ignore_mask = np.equal(col_data, ignore_value)

# add the column if it does not exist already
if col_name not in data_table.colnames:
data_table.add_column(MaskedColumn(col_data.astype(col_type), name=col_name, mask=col_mask))
try:
# Try to coerce entire column at once
coerced = col_data.astype(col_type)
data_table.add_column(
MaskedColumn(coerced, name=col_name, mask=ignore_mask)
)
except (ValueError, TypeError):
# Fallback to coercing values one by one
out = np.empty(len(col_data), dtype=col_type)
fail_mask = np.zeros(len(col_data), dtype=bool)
for i, val in enumerate(col_data):
if val == ignore_value:
# Ignored values are already masked by ignore_mask
continue

try:
out[i] = col_type(val)
except (ValueError, TypeError, OverflowError):
# Could not coerce value, mask it
fail_mask[i] = True

# Combined mask of ignored values + failed coercions
combined_mask = ignore_mask | fail_mask
if np.any(fail_mask):
warnings.warn(
f"Column '{col_name}': {np.sum(fail_mask)} values could not be coerced to {col_type} "
"and were masked.", BlankResponseWarning
)
data_table.add_column(
MaskedColumn(out, name=col_name, mask=combined_mask)
)

return data_table

Expand Down
76 changes: 74 additions & 2 deletions astroquery/mast/tests/test_mast.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from astropy.utils.exceptions import AstropyDeprecationWarning
from astroquery.mast.services import _json_to_table
from astroquery.utils.mocks import MockResponse
from astroquery.exceptions import (InvalidQueryError, InputWarning, MaxResultsWarning, NoResultsWarning,
RemoteServiceError, ResolverError)
from astroquery.exceptions import (BlankResponseWarning, InvalidQueryError, InputWarning, MaxResultsWarning,
NoResultsWarning, RemoteServiceError, ResolverError)

from astroquery import mast

Expand Down Expand Up @@ -302,6 +302,56 @@ def test_missions_query_criteria(patch_post):
)


def test_missions_parse_select_cols(patch_post):
    """Exercise ``MastMissions._parse_select_cols`` with every supported input form."""
    parse = mast.MastMissions._parse_select_cols
    requested = ['sci_pep_id', 'sci_instrume']
    # The requested columns plus the dataset ID column that is always appended
    expected = requested + ['sci_data_set_name']

    # No selection on the default mission means "use the service defaults"
    assert parse(None) is None

    # 'all' expands to the full column list
    assert parse('all') == mast.MastMissions.get_column_list()['name'].value.tolist()

    # Comma-separated string, list, tuple and generator inputs must all be accepted
    selections = (
        'sci_pep_id, sci_instrume',
        ['sci_pep_id', 'sci_instrume'],
        ('sci_pep_id', 'sci_instrume'),
        (col for col in ['sci_pep_id', 'sci_instrume']),
    )
    for selection in selections:
        parsed = parse(selection)
        for name in expected:
            assert name in parsed

    # A non-iterable argument is rejected
    with pytest.raises(InvalidQueryError, match="`select_cols` must be an iterable of column names"):
        parse(123)

    # A non-string entry inside the iterable is rejected
    with pytest.raises(InvalidQueryError, match="`select_cols` must contain only strings"):
        parse(['sci_pep_id', 123])

    # Unknown column names produce a warning and are dropped
    with pytest.warns(InputWarning, match="Column 'invalid_column' not found."):
        kept = parse(['sci_pep_id', 'invalid_column'])
    assert 'sci_pep_id' in kept
    assert 'invalid_column' not in kept

    # The ULLYSES mission falls back to its own default column set
    ullyses = mast.MastMissions(mission='ullyses')
    ullyses_defaults = ullyses._parse_select_cols(None)
    for name in mast.MastMissions._default_ullyses_cols:
        assert name in ullyses_defaults


def test_missions_get_product_list_async(patch_post):
# String input
result = mast.MastMissions.get_product_list_async('Z14Z0104T')
Expand Down Expand Up @@ -1485,3 +1535,25 @@ def test_parse_input_location(patch_post):
with pytest.warns(InputWarning, match="Resolver is only used when resolving object names"):
loc = mast.utils.parse_input_location(coordinates=coord, resolver="SIMBAD")
assert isinstance(loc, SkyCoord)


def test_json_to_table_fallback_type_coercion(patch_post):
    """Check that values which cannot be coerced to the column type are masked with a warning."""
    rows = [['1'], ['2'], ['not_an_int'], ['3'], [-999]]
    payload = {'info': [{'name': 'test_int', 'type': 'int'}],
               'data': rows}

    with pytest.warns(BlankResponseWarning):
        result = _json_to_table(payload)

    # The column was created with an integer dtype
    assert 'test_int' in result.colnames
    column = result['test_int']
    assert column.dtype == np.int64

    # Values that could be coerced keep their numeric value
    for index, value in ((0, 1), (1, 2), (3, 3)):
        assert column[index] == value

    # Row 2 holds 'not_an_int' and row 4 holds the ignore value; both must be masked
    for index in (2, 4):
        assert column.mask[index]
5 changes: 3 additions & 2 deletions astroquery/mast/tests/test_mast_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,10 @@ def test_missions_query_criteria(self):
result = MastMissions.query_criteria(objectname='NGC6121',
radius=0.1,
sci_start_time='<2012',
sci_actual_duration='0..200'
)
sci_actual_duration='0..200',
select_cols='*')
assert len(result) == 3
assert result.colnames == MastMissions.get_column_list()['name'].value.tolist()
assert (result['ang_sep'].data.data.astype('float') < 0.1).all()
assert (result['sci_start_time'] < '2012').all()
assert ((result['sci_actual_duration'] >= 0) & (result['sci_actual_duration'] <= 200)).all()
Expand Down
4 changes: 3 additions & 1 deletion docs/mast/mast_missions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ Keyword arguments can also be used to refine results further. The following para
- ``sort_desc``: A boolean or list of booleans (one for each field specified in ``sort_by``),
describing if each field should be sorted in descending order (``True``) or ascending order (``False``).

- ``select_cols``: A list of columns to be returned in the response.
- ``select_cols``: Columns to include in the result table. If not specified, a default set of columns
is returned. This parameter may be given as an iterable of column names, a comma-separated string, or the special
values ``'all'`` or ``'*'`` to return all available columns.


Mission Positional Queries
Expand Down
Loading