Skip to content

Commit 96d2383

Browse files
committed
unify ApiCatalog with DataProduct
1 parent 8423cf2 commit 96d2383

2 files changed

Lines changed: 132 additions & 37 deletions

File tree

oda_api/data_products.py

Lines changed: 74 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -733,61 +733,100 @@ def from_file(
733733
def write_file(self, file_path, overwrite=True):
734734
self.write_fits_file(file_path, overwrite=overwrite)
735735

736-
class ApiCatalog(object):
737-
738-
739-
def __init__(self,cat_dict,name=None):
736+
class ApiCatalog(DataProduct):
737+
def __init__(self, catalog_data: dict | Table, name: str | None = None):
740738
self.name=name
741739
_skip_list=['meta_ID']
742-
meta = {}
743740

744-
lon_name = None
745-
if 'cat_lon_name' in cat_dict.keys():
746-
lon_name = cat_dict['cat_lon_name']
741+
if isinstance(catalog_data, dict):
742+
meta = {}
747743

748-
lat_name = None
749-
if 'cat_lat_name' in cat_dict.keys():
750-
lat_name = cat_dict['cat_lat_name']
744+
lon_name = None
745+
if 'cat_lon_name' in catalog_data.keys():
746+
lon_name = catalog_data['cat_lon_name']
751747

752-
frame = None
753-
if 'cat_frame' in cat_dict.keys():
754-
frame = cat_dict['cat_frame']
748+
lat_name = None
749+
if 'cat_lat_name' in catalog_data.keys():
750+
lat_name = catalog_data['cat_lat_name']
755751

756-
coord_units = None
757-
if 'cat_coord_units' in cat_dict.keys():
758-
coord_units = cat_dict['cat_coord_units']
752+
frame = None
753+
if 'cat_frame' in catalog_data.keys():
754+
frame = catalog_data['cat_frame']
759755

760-
if 'cat_meta' in cat_dict.keys():
761-
cat_meta_entry = cat_dict['cat_meta']
762-
meta.update(cat_meta_entry)
763-
764-
meta['FRAME'] = frame
765-
meta['COORD_UNIT'] = coord_units
766-
meta['LON_NAME'] = lon_name
767-
meta['LAT_NAME'] = lat_name
756+
coord_units = None
757+
if 'cat_coord_units' in catalog_data.keys():
758+
coord_units = catalog_data['cat_coord_units']
759+
760+
if 'cat_meta' in catalog_data.keys():
761+
cat_meta_entry = catalog_data['cat_meta']
762+
meta.update(cat_meta_entry)
763+
764+
meta['FRAME'] = frame
765+
meta['COORD_UNIT'] = coord_units
766+
meta['LON_NAME'] = lon_name
767+
meta['LAT_NAME'] = lat_name
768+
769+
self.table =Table(catalog_data['cat_column_list'], names=catalog_data['cat_column_names'],meta=meta)
768770

769-
self.table =Table(cat_dict['cat_column_list'], names=cat_dict['cat_column_names'],meta=meta)
771+
if coord_units is not None:
772+
self.table[lon_name]=Angle(self.table[lon_name],unit=coord_units)
773+
self.table[lat_name]=Angle(self.table[lat_name],unit=coord_units)
770774

771-
if coord_units is not None:
772-
self.table[lon_name]=Angle(self.table[lon_name],unit=coord_units)
773-
self.table[lat_name]=Angle(self.table[lat_name],unit=coord_units)
775+
self.lat_name=lat_name
776+
self.lon_name=lon_name
777+
else:
778+
self.table = catalog_data
779+
meta = getattr(self.table, 'meta', {})
780+
self.lat_name = meta.get('LAT_NAME')
781+
self.lon_name = meta.get('LON_NAME')
774782

775-
self.lat_name=lat_name
776-
self.lon_name=lon_name
777783

778-
def get_api_dictionary(self ):
784+
def get_api_dictionary(self, dump_string=True):
779785
column_lists = []
780786
for colname in self.table.colnames:
781787
column_lists.append([x if str(x) != 'nan' else None for x in self.table[colname]])
782788

783-
784-
return json.dumps(dict(cat_frame=self.table.meta['FRAME'], # pyright: ignore[reportOptionalSubscript]
789+
cat_dict = dict(cat_frame=self.table.meta['FRAME'], # pyright: ignore[reportOptionalSubscript]
785790
cat_coord_units=self.table.meta['COORD_UNIT'], # pyright: ignore[reportOptionalSubscript]
786791
cat_column_list=column_lists,
787792
cat_column_names=self.table.colnames,
788793
cat_column_descr=self.table.dtype.descr,
789794
cat_lat_name=self.lat_name,
790-
cat_lon_name=self.lon_name))
795+
cat_lon_name=self.lon_name)
796+
if dump_string:
797+
return json.dumps(cat_dict)
798+
else:
799+
return cat_dict
800+
801+
def encode(self):
802+
return self.get_api_dictionary(dump_string=False)
803+
804+
@classmethod
805+
def decode(cls, encoded_obj : str | dict[str, typing.Any], name: str | None = None) -> "ApiCatalog":
806+
if isinstance(encoded_obj, str):
807+
obj = json.loads(encoded_obj)
808+
else:
809+
obj = encoded_obj
810+
811+
return cls(obj, name = name)
812+
813+
def suggest_fn_extension(self) -> str:
814+
return 'ecsv'
815+
816+
def write_file(self, file_path, overwrite=True):
817+
# determine format from the file extension
818+
self.table.write(file_path, overwrite=overwrite)
819+
820+
@classmethod
821+
def from_file(cls, file_path: str, name: str | None = None, delimiter: str | None = None, format: str | None = None):
822+
allowed_formats=['ascii','ascii.ecsv','fits']
823+
if format in allowed_formats:
824+
kw = {'delimiter': delimiter} if format != 'fits' and delimiter is not None else {}
825+
table = Table.read(file_path, format=format, **kw)
826+
else:
827+
raise RuntimeError(f'Catalog file format not understood, allowed: {allowed_formats}')
828+
829+
return cls(table, name=name)
791830

792831
class GWEventContours:
793832
def __init__(self, event_contour_dict, name='') -> None:

tests/test_data_products.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,15 @@
1818
ODAAstropyTable,
1919
PictureProduct,
2020
BinaryProduct,
21-
TextLikeProduct)
21+
TextLikeProduct,
22+
ApiCatalog)
2223
from astropy import time as atime
2324
from astropy import units as u
2425
from astropy.table import Table
2526
from matplotlib import pyplot as plt
2627

28+
from astropy.utils.diff import report_diff_values
29+
2730
import base64
2831
import pickle
2932

@@ -320,6 +323,47 @@ def test_lightcurve_data_product_write_file_roundtrip(tmp_path):
320323
assert np.array_equal(loaded.data_unit[1].data['FLUX'], np.array(values))
321324
assert np.array_equal(loaded.data_unit[1].data['ERROR'], np.array(errors))
322325

326+
def test_catalog_data_product_write_file_roundtrip(tmp_path):
327+
catalog_ecsv = """\
328+
# %ECSV 1.0
329+
# ---
330+
# datatype:
331+
# - {name: meta_ID, datatype: int64}
332+
# - {name: src_names, datatype: string}
333+
# - {name: significance, datatype: float32}
334+
# - {name: ra, datatype: float32}
335+
# - {name: dec, datatype: float32}
336+
# - {name: NEW_SOURCE, datatype: uint16}
337+
# - {name: ISGRI_FLAG, datatype: int64}
338+
# - {name: FLAG, datatype: int64}
339+
# - {name: ERR_RAD, datatype: float64}
340+
# meta: !!omap
341+
# - {FRAME: fk5}
342+
# - {LAT_NAME: dec}
343+
# - {COORD_UNIT: deg}
344+
# - {LON_NAME: ra}
345+
# schema: astropy-2.0
346+
meta_ID src_names significance ra dec NEW_SOURCE ISGRI_FLAG FLAG ERR_RAD
347+
0 "IGR J15311-3737" 0.0 0.0 0.0 0 1 0 0.0750000029802
348+
1 "IGR J15409-4057" 0.804126 235.259 -40.9678 0 1 0 0.00860999990255
349+
"""
350+
ecsv_file_in = tmp_path / 'catalog.ecsv'
351+
with open(ecsv_file_in, 'w') as fd:
352+
fd.write(catalog_ecsv)
353+
354+
ecsv_file_out = tmp_path / 'catalog.ecsv'
355+
fits_file_out = tmp_path / 'catalog.fits'
356+
357+
catalog = ApiCatalog.from_file(ecsv_file_in, name = 'example catalog', format = 'ascii.ecsv')
358+
catalog.write_file(ecsv_file_out)
359+
catalog.write_file(fits_file_out)
360+
361+
with open(ecsv_file_out, 'r') as fd:
362+
assert fd.read() == catalog_ecsv
363+
364+
from_fits = ApiCatalog.from_file(fits_file_out, name = 'example catalog 1', format = 'fits')
365+
assert report_diff_values(from_fits.table, from_fits.table)
366+
323367

324368
def test_save_all_data_mixed_collection(tmp_path):
325369
table = Table({'a': [1, 2], 'b': [3, 4]})
@@ -333,7 +377,18 @@ def test_save_all_data_mixed_collection(tmp_path):
333377
picture = PictureProduct.from_file(str(plot_fn), name='pic')
334378
text = TextLikeProduct('hello', name='text')
335379
lc = LightCurveDataProduct.from_arrays(['2022-02-20T13:45:34', '2022-02-20T14:45:34'], fluxes=[2, 3], errors=[0.1, 0.2])
336-
dc = DataCollection([astable, bin_prod, picture, text, lc])
380+
ascii_catalog = """\
381+
meta_ID src_names significance ra dec NEW_SOURCE ISGRI_FLAG FLAG ERR_RAD
382+
0 "IGR J15311-3737" 0.0 0.0 0.0 0 1 0 0.0750000029802
383+
1 "IGR J15409-4057" 0.804126 235.259 -40.9678 0 1 0 0.00860999990255
384+
"""
385+
catalog_fn_in = tmp_path / 'catalog.csv'
386+
with open(catalog_fn_in, 'w') as fd:
387+
fd.write(ascii_catalog)
388+
389+
catalog = ApiCatalog.from_file(catalog_fn_in, name='catalog', format='ascii', delimiter=' ')
390+
391+
dc = DataCollection([astable, bin_prod, picture, text, lc, catalog])
337392

338393
out_dir = tmp_path / 'saved'
339394
os.makedirs(out_dir, exist_ok=True)
@@ -345,6 +400,7 @@ def test_save_all_data_mixed_collection(tmp_path):
345400
out_dir / 'mixed_pic_2.png',
346401
out_dir / 'mixed_text_3.txt',
347402
out_dir / 'mixed_prod_4.fits',
403+
out_dir / 'mixed_catalog_5.ecsv',
348404
]
349405

350406
for expected in expected_files:

0 commit comments

Comments
 (0)