@@ -733,61 +733,100 @@ def from_file(
733733 def write_file (self , file_path , overwrite = True ):
734734 self .write_fits_file (file_path , overwrite = overwrite )
735735
736- class ApiCatalog (object ):
737-
738-
739- def __init__ (self ,cat_dict ,name = None ):
736+ class ApiCatalog (DataProduct ):
737+ def __init__ (self , catalog_data : dict | Table , name : str | None = None ):
740738 self .name = name
741739 _skip_list = ['meta_ID' ]
742- meta = {}
743740
744- lon_name = None
745- if 'cat_lon_name' in cat_dict .keys ():
746- lon_name = cat_dict ['cat_lon_name' ]
741+ if isinstance (catalog_data , dict ):
742+ meta = {}
747743
748- lat_name = None
749- if 'cat_lat_name ' in cat_dict .keys ():
750- lat_name = cat_dict [ 'cat_lat_name ' ]
744+ lon_name = None
745+ if 'cat_lon_name ' in catalog_data .keys ():
746+ lon_name = catalog_data [ 'cat_lon_name ' ]
751747
752- frame = None
753- if 'cat_frame ' in cat_dict .keys ():
754- frame = cat_dict [ 'cat_frame ' ]
748+ lat_name = None
749+ if 'cat_lat_name ' in catalog_data .keys ():
750+ lat_name = catalog_data [ 'cat_lat_name ' ]
755751
756- coord_units = None
757- if 'cat_coord_units ' in cat_dict .keys ():
758- coord_units = cat_dict [ 'cat_coord_units ' ]
752+ frame = None
753+ if 'cat_frame ' in catalog_data .keys ():
754+ frame = catalog_data [ 'cat_frame ' ]
759755
760- if 'cat_meta' in cat_dict .keys ():
761- cat_meta_entry = cat_dict ['cat_meta' ]
762- meta .update (cat_meta_entry )
763-
764- meta ['FRAME' ] = frame
765- meta ['COORD_UNIT' ] = coord_units
766- meta ['LON_NAME' ] = lon_name
767- meta ['LAT_NAME' ] = lat_name
756+ coord_units = None
757+ if 'cat_coord_units' in catalog_data .keys ():
758+ coord_units = catalog_data ['cat_coord_units' ]
759+
760+ if 'cat_meta' in catalog_data .keys ():
761+ cat_meta_entry = catalog_data ['cat_meta' ]
762+ meta .update (cat_meta_entry )
763+
764+ meta ['FRAME' ] = frame
765+ meta ['COORD_UNIT' ] = coord_units
766+ meta ['LON_NAME' ] = lon_name
767+ meta ['LAT_NAME' ] = lat_name
768+
769+ self .table = Table (catalog_data ['cat_column_list' ], names = catalog_data ['cat_column_names' ],meta = meta )
768770
769- self .table = Table (cat_dict ['cat_column_list' ], names = cat_dict ['cat_column_names' ],meta = meta )
771+ if coord_units is not None :
772+ self .table [lon_name ]= Angle (self .table [lon_name ],unit = coord_units )
773+ self .table [lat_name ]= Angle (self .table [lat_name ],unit = coord_units )
770774
771- if coord_units is not None :
772- self .table [lon_name ]= Angle (self .table [lon_name ],unit = coord_units )
773- self .table [lat_name ]= Angle (self .table [lat_name ],unit = coord_units )
775+ self .lat_name = lat_name
776+ self .lon_name = lon_name
777+ else :
778+ self .table = catalog_data
779+ meta = getattr (self .table , 'meta' , {})
780+ self .lat_name = meta .get ('LAT_NAME' )
781+ self .lon_name = meta .get ('LON_NAME' )
774782
775- self .lat_name = lat_name
776- self .lon_name = lon_name
777783
778- def get_api_dictionary (self ):
784+ def get_api_dictionary (self , dump_string = True ):
779785 column_lists = []
780786 for colname in self .table .colnames :
781787 column_lists .append ([x if str (x ) != 'nan' else None for x in self .table [colname ]])
782788
783-
784- return json .dumps (dict (cat_frame = self .table .meta ['FRAME' ], # pyright: ignore[reportOptionalSubscript]
789+ cat_dict = dict (cat_frame = self .table .meta ['FRAME' ], # pyright: ignore[reportOptionalSubscript]
785790 cat_coord_units = self .table .meta ['COORD_UNIT' ], # pyright: ignore[reportOptionalSubscript]
786791 cat_column_list = column_lists ,
787792 cat_column_names = self .table .colnames ,
788793 cat_column_descr = self .table .dtype .descr ,
789794 cat_lat_name = self .lat_name ,
790- cat_lon_name = self .lon_name ))
795+ cat_lon_name = self .lon_name )
796+ if dump_string :
797+ return json .dumps (cat_dict )
798+ else :
799+ return cat_dict
800+
801+ def encode (self ):
802+ return self .get_api_dictionary (dump_string = False )
803+
804+ @classmethod
805+ def decode (cls , encoded_obj : str | dict [str , typing .Any ], name : str | None = None ) -> "ApiCatalog" :
806+ if isinstance (encoded_obj , str ):
807+ obj = json .loads (encoded_obj )
808+ else :
809+ obj = encoded_obj
810+
811+ return cls (obj , name = name )
812+
813+ def suggest_fn_extension (self ) -> str :
814+ return 'ecsv'
815+
816+ def write_file (self , file_path , overwrite = True ):
817+ # determine format from the file extension
818+ self .table .write (file_path , overwrite = overwrite )
819+
820+ @classmethod
821+ def from_file (cls , file_path : str , name : str | None = None , delimiter : str | None = None , format : str | None = None ):
822+ allowed_formats = ['ascii' ,'ascii.ecsv' ,'fits' ]
823+ if format in allowed_formats :
824+ kw = {'delimiter' : delimiter } if format != 'fits' and delimiter is not None else {}
825+ table = Table .read (file_path , format = format , ** kw )
826+ else :
827+ raise RuntimeError (f'Catalog file format not understood, allowed: { allowed_formats } ' )
828+
829+ return cls (table , name = name )
791830
792831class GWEventContours :
793832 def __init__ (self , event_contour_dict , name = '' ) -> None :
0 commit comments