Skip to content

Commit 41b18f2

Browse files
committed
save_all_data() and related tests
1 parent 2893582 commit 41b18f2

3 files changed

Lines changed: 74 additions & 9 deletions

File tree

oda_api/api.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from . import __version__
3636
from . import colors as C
3737
from . import custom_formatters
38-
from .data_products import (ApiCatalog, BinaryData, BinaryProduct,
38+
from .data_products import (ApiCatalog, BinaryData, BinaryProduct, DataProduct,
3939
GWContoursDataProduct, NumpyDataProduct,
4040
ODAAstropyTable, PictureProduct, TextLikeProduct)
4141

@@ -1196,11 +1196,17 @@ def __repr__(self):
11961196

11971197
class DataCollection(object):
11981198

1199-
def __init__(self, data_list, add_meta_to_name=None, instrument=None, product=None, request_job_id=None):
1199+
def __init__(self,
1200+
data_list: list[DataProduct],
1201+
add_meta_to_name: list[str] | None = None,
1202+
instrument: str | None = None,
1203+
product: str | None =None,
1204+
request_job_id: str | None=None
1205+
):
12001206
if add_meta_to_name is None:
12011207
add_meta_to_name = ['src_name', 'product']
1202-
self._p_list = []
1203-
self._n_list = []
1208+
self._p_list: list[DataProduct] = []
1209+
self._n_list: list[str] = []
12041210
self.request_job_id = request_job_id
12051211
for ID, data in enumerate(data_list):
12061212

@@ -1260,15 +1266,21 @@ def _build_prod_name(self, prod, name, add_meta_to_name):
12601266
name += '_' + s.strip()
12611267
return name, oda_api.misc_helpers.clean_var_name(name)
12621268

1263-
def save_all_data(self, prenpend_name=None):
1269+
def save_all_data(self, prenpend_name = None, overwrite=True):
1270+
# NOTE: prepend_name also determines file path
12641271
for pname, prod in zip(self._n_list, self._p_list, strict=False):
1272+
if not isinstance(prod, DataProduct):
1273+
logger.warning(f"Writing on disk is not implemented for product {pname} of type {pname.__class__.__name__}, skipping.")
1274+
continue
12651275
if prenpend_name is not None:
12661276
file_name = prenpend_name + '_' + pname
12671277
else:
12681278
file_name = pname
12691279

1270-
file_name = file_name + '.fits'
1271-
prod.write_fits_file(file_name)
1280+
fn_extension = prod.suggest_fn_extension()
1281+
file_name = f"{file_name}.{fn_extension}"
1282+
1283+
prod.write_file(file_name, overwrite=overwrite)
12721284

12731285
def save(self, file_name):
12741286
pickle.dump(self, open(file_name, 'wb'),
@@ -1363,7 +1375,7 @@ def from_response_json(cls, res_json, instrument, product):
13631375
d = cls(data, instrument=instrument, product=product, request_job_id=request_job_id)
13641376
for p in d._p_list:
13651377
if hasattr(p, 'meta_data') is False and hasattr(p, 'meta') is True:
1366-
p.meta_data = p.meta
1378+
p.meta_data = p.meta # type:ignore
13671379

13681380
return d
13691381

oda_api/data_products.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ def _chekc_enc_data(data):
4949

5050

5151
class DataProduct(ABC):
52+
name: str | None
53+
meta_data: dict
54+
5255
@abstractmethod
5356
def encode(self, *args, **kwargs) -> str | dict[str, typing.Any]: ...
5457

@@ -192,7 +195,7 @@ def __init__(self, bin_data: bytes, name: str | None = None):
192195
def suggest_fn_extension(self) -> str:
193196
suggested_extension = puremagic.from_string(self.bin_data)
194197
if suggested_extension:
195-
return suggested_extension
198+
return suggested_extension.strip('.')
196199
return 'bin'
197200

198201
def encode(self):

tests/test_data_products.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from datetime import datetime
2+
import logging
23

34
import pytest
45
import json
@@ -8,6 +9,7 @@
89
import time
910
import os
1011
import typing
12+
from oda_api.api import DataCollection
1113
from oda_api.json import CustomJSONEncoder
1214
import filecmp
1315

@@ -317,3 +319,51 @@ def test_lightcurve_data_product_write_file_roundtrip(tmp_path):
317319
loaded = NumpyDataProduct.from_file(str(out_fn))
318320
assert np.array_equal(loaded.data_unit[1].data['FLUX'], np.array(values))
319321
assert np.array_equal(loaded.data_unit[1].data['ERROR'], np.array(errors))
322+
323+
324+
def test_save_all_data_mixed_collection(tmp_path):
325+
table = Table({'a': [1, 2], 'b': [3, 4]})
326+
astable = ODAAstropyTable(table, name='table')
327+
bin_prod = BinaryProduct.from_file('tests/test_data/lc.fits', name='binprd')
328+
329+
plot_fn = tmp_path / 'plot.png'
330+
plt.plot([1, 2], [1, 0])
331+
plt.savefig(str(plot_fn))
332+
plt.close()
333+
picture = PictureProduct.from_file(str(plot_fn), name='pic')
334+
text = TextLikeProduct('hello', name='text')
335+
lc = LightCurveDataProduct.from_arrays(['2022-02-20T13:45:34', '2022-02-20T14:45:34'], fluxes=[2, 3], errors=[0.1, 0.2])
336+
dc = DataCollection([astable, bin_prod, picture, text, lc])
337+
338+
out_dir = tmp_path / 'saved'
339+
os.makedirs(out_dir, exist_ok=True)
340+
dc.save_all_data(prenpend_name=str(out_dir / 'mixed'))
341+
342+
expected_files = [
343+
out_dir / 'mixed_table_0.fits',
344+
out_dir / 'mixed_binprd_1.fits',
345+
out_dir / 'mixed_pic_2.png',
346+
out_dir / 'mixed_text_3.txt',
347+
out_dir / 'mixed_prod_4.fits',
348+
]
349+
350+
for expected in expected_files:
351+
assert expected.exists()
352+
353+
354+
def test_save_all_data_empty_collection(tmp_path):
355+
dc = DataCollection([])
356+
dc.save_all_data(prenpend_name=str(tmp_path / 'empty'))
357+
assert not any(tmp_path.iterdir())
358+
359+
360+
def test_save_all_data_skips_unsupported_product(caplog, tmp_path):
361+
caplog.set_level(logging.WARNING)
362+
text = TextLikeProduct('hello', name='text')
363+
unsupported = object()
364+
dc = DataCollection([text, unsupported])
365+
366+
dc.save_all_data(prenpend_name=str(tmp_path / 'skipped'))
367+
368+
assert 'Writing on disk is not implemented for' in caplog.text
369+
assert (tmp_path / 'skipped_text_0.txt').exists()

0 commit comments

Comments
 (0)