Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 80 additions & 12 deletions backend/core/test/test_product_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from unittest import mock

import pandas as pd
from core.models import Product, ProductFile, ProductType, Release
from core.models import Product, ProductContent, ProductFile, ProductType, Release
from core.product_handle import ProductHandle
from core.product_steps import RegistryProduct
from core.table_data_collector import MainTableDataCollector
Expand Down Expand Up @@ -134,7 +134,7 @@ def create_multi_file_archive(self, extension="zip"):
archive.add(file_b, arcname=file_b.name)
return archive_path

def create_fake_hats_archive(self, properties_filename="hats.properties"):
def create_fake_hats_archive(self, properties_filename="hats.properties", data=None):
temp_dir = Path(tempfile.mkdtemp(prefix="pz_hats_archive_"))
hats_root = temp_dir / "mock_hats"
dataset_dir = hats_root / "dataset" / "Norder=0" / "Npix=0"
Expand All @@ -144,7 +144,14 @@ def create_fake_hats_archive(self, properties_filename="hats.properties"):
"catalog_name=mock_hats\ndataproduct_type=object\n",
encoding="utf-8",
)
(dataset_dir / "part-0.parquet").write_text("", encoding="utf-8")
if data is None:
data = pd.DataFrame(
[
{"ra": 0.0, "dec": 0.0, "z": 0.0},
{"ra": 1.0, "dec": -1.0, "z": 0.1},
]
)
data.to_parquet(dataset_dir / "part-0.parquet", index=False)

archive_path = temp_dir / "mock_hats.zip"
with zipfile.ZipFile(archive_path, "w") as archive:
Expand Down Expand Up @@ -176,8 +183,15 @@ def create_fake_hats_collection_archive(self):
"catalog_name=mock_margin\ndataproduct_type=margin\n",
encoding="utf-8",
)
(catalog_dataset_dir / "part-0.parquet").write_text("", encoding="utf-8")
(margin_dataset_dir / "part-0.parquet").write_text("", encoding="utf-8")
pd.DataFrame(
[
{"ra": 0.0, "dec": 0.0, "z": 0.0},
{"ra": 1.0, "dec": -1.0, "z": 0.1},
]
).to_parquet(catalog_dataset_dir / "part-0.parquet", index=False)
pd.DataFrame([{"margin_id": 1}]).to_parquet(
margin_dataset_dir / "part-0.parquet", index=False
)

archive_path = temp_dir / "mock_collection.zip"
with zipfile.ZipFile(archive_path, "w") as archive:
Expand All @@ -187,19 +201,43 @@ def create_fake_hats_collection_archive(self):
archive.write(file_path, arcname=arcname)
return archive_path

def create_fake_lsdb_module(self, n_rows=12, columns_as_index=False):
def create_fake_lsdb_module(
self,
n_rows=12,
columns_as_index=False,
columns=None,
all_columns=None,
preview_columns=None,
):
catalog_columns = columns or ["ra", "dec", "z"]
catalog_all_columns = (
all_columns if all_columns is not None else catalog_columns
)
catalog_preview_columns = preview_columns or catalog_columns

class FakeCatalog:
    """Minimal stand-in for the catalog object the fake ``lsdb`` module returns.

    All values come from the enclosing factory's arguments: ``columns`` and
    ``all_columns`` mirror the requested column names, ``len()`` reports
    ``n_rows`` and ``head()`` fabricates preview rows for
    ``catalog_preview_columns``.
    """

    # Exposed as a pandas Index or as a plain list, depending on
    # ``columns_as_index``, so callers can be exercised with both shapes.
    columns = pd.Index(catalog_columns) if columns_as_index else catalog_columns
    all_columns = catalog_all_columns

    def __len__(self):
        """Return the configured number of rows."""
        return n_rows

    def head(self, n=5):
        """Return the first ``n`` synthetic rows as a DataFrame.

        ``ra``, ``dec`` and ``z`` get numeric values derived from the row
        index; any other column gets the string ``"<column>_<index>"``.
        """
        numeric_values = {
            "ra": lambda i: float(i),
            "dec": lambda i: float(i) * -1.0,
            "z": lambda i: float(i) * 0.1,
        }
        rows = [
            {
                column: (
                    numeric_values[column](i)
                    if column in numeric_values
                    else f"{column}_{i}"
                )
                for column in catalog_preview_columns
            }
            for i in range(n_rows)
        ]
        # All rows are built before truncating so that ``DataFrame.head``
        # keeps its own semantics for every ``n`` (including negatives).
        return pd.DataFrame(rows).head(n)

fake_module = types.ModuleType("lsdb")
Expand Down Expand Up @@ -428,6 +466,36 @@ def test_registry_handles_lsdb_columns_as_pandas_index(self):

self.assertEqual(response.status_code, 200)

def test_registry_uses_lsdb_columns_for_hats_content_columns(self):
    """Registering a HATS archive stores the columns reported by lsdb.

    Uploads a fake HATS archive, posts to the ``products-registry`` endpoint
    with ``lsdb`` replaced by a fake module that reports a five-column
    catalog, and then checks that both the ``ProductContent`` rows (ordered
    by ``order``) and the ``"columns"`` entry of the table preview JSON
    match those five columns.
    """
    lsdb_columns = ["object_id", "ra", "dec", "z", "quality_flag"]

    product = self.create_product(product_type_name="validation_results")
    self.upload_main_file_from_path(product, self.create_fake_hats_archive())
    registry_url = reverse("products-registry", kwargs={"pk": product.pk})
    lsdb_stub = self.create_fake_lsdb_module(
        n_rows=1,
        columns=lsdb_columns,
        all_columns=lsdb_columns,
        preview_columns=lsdb_columns,
    )

    # Swap the real lsdb for the stub only while the request is processed.
    with mock.patch.dict(sys.modules, {"lsdb": lsdb_stub}):
        response = self.client.post(registry_url)

    self.assertEqual(response.status_code, 200)

    content_rows = ProductContent.objects.filter(product=product).order_by("order")
    stored_columns = list(content_rows.values_list("column_name", flat=True))
    self.assertEqual(stored_columns, lsdb_columns)

    preview_file = Path(
        settings.MEDIA_ROOT, product.path, RegistryProduct.TABLE_PREVIEW_FILENAME
    )
    preview = json.loads(preview_file.read_text(encoding="utf-8"))
    self.assertEqual(preview["columns"], lsdb_columns)

def test_registry_without_columns(self):

# Cria um novo produto.
Expand Down
Loading