Skip to content

Commit a603511

Browse files
committed
updated product registry tests
1 parent f78f7b7 commit a603511

1 file changed

Lines changed: 80 additions & 12 deletions

File tree

backend/core/test/test_product_registry.py

Lines changed: 80 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from unittest import mock
1010

1111
import pandas as pd
12-
from core.models import Product, ProductFile, ProductType, Release
12+
from core.models import Product, ProductContent, ProductFile, ProductType, Release
1313
from core.product_handle import ProductHandle
1414
from core.product_steps import RegistryProduct
1515
from core.table_data_collector import MainTableDataCollector
@@ -134,7 +134,7 @@ def create_multi_file_archive(self, extension="zip"):
134134
archive.add(file_b, arcname=file_b.name)
135135
return archive_path
136136

137-
def create_fake_hats_archive(self, properties_filename="hats.properties"):
137+
def create_fake_hats_archive(self, properties_filename="hats.properties", data=None):
138138
temp_dir = Path(tempfile.mkdtemp(prefix="pz_hats_archive_"))
139139
hats_root = temp_dir / "mock_hats"
140140
dataset_dir = hats_root / "dataset" / "Norder=0" / "Npix=0"
@@ -144,7 +144,14 @@ def create_fake_hats_archive(self, properties_filename="hats.properties"):
144144
"catalog_name=mock_hats\ndataproduct_type=object\n",
145145
encoding="utf-8",
146146
)
147-
(dataset_dir / "part-0.parquet").write_text("", encoding="utf-8")
147+
if data is None:
148+
data = pd.DataFrame(
149+
[
150+
{"ra": 0.0, "dec": 0.0, "z": 0.0},
151+
{"ra": 1.0, "dec": -1.0, "z": 0.1},
152+
]
153+
)
154+
data.to_parquet(dataset_dir / "part-0.parquet", index=False)
148155

149156
archive_path = temp_dir / "mock_hats.zip"
150157
with zipfile.ZipFile(archive_path, "w") as archive:
@@ -176,8 +183,15 @@ def create_fake_hats_collection_archive(self):
176183
"catalog_name=mock_margin\ndataproduct_type=margin\n",
177184
encoding="utf-8",
178185
)
179-
(catalog_dataset_dir / "part-0.parquet").write_text("", encoding="utf-8")
180-
(margin_dataset_dir / "part-0.parquet").write_text("", encoding="utf-8")
186+
pd.DataFrame(
187+
[
188+
{"ra": 0.0, "dec": 0.0, "z": 0.0},
189+
{"ra": 1.0, "dec": -1.0, "z": 0.1},
190+
]
191+
).to_parquet(catalog_dataset_dir / "part-0.parquet", index=False)
192+
pd.DataFrame([{"margin_id": 1}]).to_parquet(
193+
margin_dataset_dir / "part-0.parquet", index=False
194+
)
181195

182196
archive_path = temp_dir / "mock_collection.zip"
183197
with zipfile.ZipFile(archive_path, "w") as archive:
@@ -187,19 +201,43 @@ def create_fake_hats_collection_archive(self):
187201
archive.write(file_path, arcname=arcname)
188202
return archive_path
189203

190-
def create_fake_lsdb_module(self, n_rows=12, columns_as_index=False):
204+
def create_fake_lsdb_module(
205+
self,
206+
n_rows=12,
207+
columns_as_index=False,
208+
columns=None,
209+
all_columns=None,
210+
preview_columns=None,
211+
):
212+
catalog_columns = columns or ["ra", "dec", "z"]
213+
catalog_all_columns = (
214+
all_columns if all_columns is not None else catalog_columns
215+
)
216+
catalog_preview_columns = preview_columns or catalog_columns
217+
191218
class FakeCatalog:
192-
columns = pd.Index(["ra", "dec", "z"]) if columns_as_index else ["ra", "dec", "z"]
193-
all_columns = ["ra", "dec", "z"]
219+
columns = (
220+
pd.Index(catalog_columns) if columns_as_index else catalog_columns
221+
)
222+
all_columns = catalog_all_columns
194223

195224
def __len__(self):
196225
return n_rows
197226

198227
def head(self, n=5):
199-
rows = [
200-
{"ra": float(i), "dec": float(i) * -1.0, "z": float(i) * 0.1}
201-
for i in range(n_rows)
202-
]
228+
rows = []
229+
for i in range(n_rows):
230+
row = {}
231+
for column in catalog_preview_columns:
232+
if column == "ra":
233+
row[column] = float(i)
234+
elif column == "dec":
235+
row[column] = float(i) * -1.0
236+
elif column == "z":
237+
row[column] = float(i) * 0.1
238+
else:
239+
row[column] = f"{column}_{i}"
240+
rows.append(row)
203241
return pd.DataFrame(rows).head(n)
204242

205243
fake_module = types.ModuleType("lsdb")
@@ -428,6 +466,36 @@ def test_registry_handles_lsdb_columns_as_pandas_index(self):
428466

429467
self.assertEqual(response.status_code, 200)
430468

469+
def test_registry_uses_lsdb_columns_for_hats_content_columns(self):
470+
product = self.create_product(product_type_name="validation_results")
471+
expected_columns = ["object_id", "ra", "dec", "z", "quality_flag"]
472+
archive_path = self.create_fake_hats_archive()
473+
self.upload_main_file_from_path(product, archive_path)
474+
url = reverse("products-registry", kwargs={"pk": product.pk})
475+
fake_lsdb = self.create_fake_lsdb_module(
476+
n_rows=1,
477+
columns=expected_columns,
478+
all_columns=expected_columns,
479+
preview_columns=expected_columns,
480+
)
481+
482+
with mock.patch.dict(sys.modules, {"lsdb": fake_lsdb}):
483+
response = self.client.post(url)
484+
485+
self.assertEqual(response.status_code, 200)
486+
registered_columns = list(
487+
ProductContent.objects.filter(product=product)
488+
.order_by("order")
489+
.values_list("column_name", flat=True)
490+
)
491+
self.assertEqual(registered_columns, expected_columns)
492+
493+
preview_path = Path(
494+
settings.MEDIA_ROOT, product.path, RegistryProduct.TABLE_PREVIEW_FILENAME
495+
)
496+
payload = json.loads(preview_path.read_text(encoding="utf-8"))
497+
self.assertEqual(payload["columns"], expected_columns)
498+
431499
def test_registry_without_columns(self):
432500

433501
# Cria um novo produto.

0 commit comments

Comments
 (0)