Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/roman_datamodels/datamodels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,8 @@
from ._datamodels import * # noqa: F403

# rename rdm_open to open to match the current roman_datamodels API
from ._utils import FilenameMismatchWarning # noqa: F401
from ._utils import (
FilenameMismatchWarning, # noqa: F401
create_synchronized_table, # noqa: F401
)
from ._utils import rdm_open as open # noqa: F401
5 changes: 4 additions & 1 deletion src/roman_datamodels/datamodels/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,10 @@ def save(self, path, dir_path=None, *args, all_array_compression="lz4", all_arra
output_path, *args, all_array_compression=all_array_compression, all_array_storage=all_array_storage, **kwargs
)
elif ext == ".parquet" and hasattr(self, "to_parquet"):
self.to_parquet(output_path)
to_parquet_kwargs = {}
if "ivoa_compliant" in kwargs:
to_parquet_kwargs["ivoa_compliant"] = kwargs["ivoa_compliant"]
self.to_parquet(output_path, **to_parquet_kwargs)
else:
raise ValueError(f"unknown filetype {ext}")

Expand Down
53 changes: 37 additions & 16 deletions src/roman_datamodels/datamodels/_datamodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@
from collections import abc
from typing import TYPE_CHECKING

import astropy.table.meta
import numpy as np
from astropy import time as _time
from astropy.modeling import models

from ._core import DataModel
from ._utils import node_update, temporary_update_filedate, temporary_update_filename
from ._utils import (
create_synchronized_table,
node_update,
temporary_update_filedate,
temporary_update_filename,
)

if TYPE_CHECKING:
from typing import Any
Expand Down Expand Up @@ -84,16 +88,28 @@ class _ParquetMixin:

__slots__ = ()

def to_parquet(self, filepath):
def to_parquet(self, filepath, ivoa_compliant: bool = False):
"""
Save catalog in parquet format.
Save the catalog to a Parquet file preserving metadata.

Defers import of parquet to minimize import overhead for all other models.
Parameters
----------
filepath : str or Path
Path to the output Parquet file.
ivoa_compliant : bool, optional
If True, ensures units and metadata are formatted according to IVOA standards.

Notes
-----
- Validates the catalog before writing, as Parquet does not provide schema validation.
- Metadata is flattened and merged with table-level metadata for compatibility.
- Imports Parquet dependencies only when needed to minimize overhead.
- Optionally, column units and types can be synchronized for IVOA compliance.
"""
from roman_datamodels._stnode import DNode

# parquet does not provide validation so validate first with asdf
self.validate()
self.validate() # type: ignore[attr-defined]

global DTYPE_MAP
import pyarrow as pa
Expand All @@ -117,13 +133,14 @@ def to_parquet(self, filepath):
}
)

with temporary_update_filename(self, pathlib.Path(filepath).name), temporary_update_filedate(self, _time.Time.now()):
with temporary_update_filename(self, pathlib.Path(filepath).name), temporary_update_filedate(self, _time.Time.now()): # type: ignore[arg-type]
# Construct flat metadata dict
flat_meta = self.to_flat_dict()
flat_meta = self.to_flat_dict() # type: ignore[attr-defined]

# select only meta items
flat_meta = {k: str(v) for (k, v) in flat_meta.items() if k.startswith("roman.meta")}
# Extract table metadata
source_cat = self.source_catalog
source_cat = self.source_catalog # type: ignore[attr-defined]
scmeta = source_cat.meta
# Wrap it as a DNode so it can be flattened
dn_scmeta = DNode(scmeta)
Expand All @@ -136,14 +153,18 @@ def to_parquet(self, filepath):
keys = list(source_cat.columns.keys())
arrs = [np.array(source_cat[key]) for key in keys]
units = [str(source_cat[key].unit) for key in keys]
descriptions = [getattr(source_cat[key], "description", "") for key in keys]
dtypes = [DTYPE_MAP[np.array(source_cat[key]).dtype.name] for key in keys]
fields = [
pa.field(key, type=dtype, metadata={"unit": unit}) for (key, dtype, unit) in zip(keys, dtypes, units, strict=False)
]
extra_astropy_metadata = astropy.table.meta.get_yaml_from_table(source_cat)
flat_meta["table_meta_yaml"] = "\n".join(extra_astropy_metadata)
schema = pa.schema(fields, metadata=flat_meta)
table = pa.Table.from_arrays(arrs, schema=schema)
table = create_synchronized_table(
arrs,
keys,
units,
dtypes=dtypes,
global_meta=flat_meta,
ivoa_compliant=ivoa_compliant,
descriptions=descriptions,
table_meta=scmeta,
)
pq.write_table(table, filepath, compression=None)


Expand Down
137 changes: 136 additions & 1 deletion src/roman_datamodels/datamodels/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,15 @@
from roman_datamodels._stnode import DNode, LNode


__all__ = ["FilenameMismatchWarning", "node_update", "rdm_open", "temporary_update_filedate", "temporary_update_filename"]
__all__ = [
"FilenameMismatchWarning",
"create_synchronized_table",
"node_update",
"parse_units_to_ivoa",
"rdm_open",
"temporary_update_filedate",
"temporary_update_filename",
]


class FilenameMismatchWarning(UserWarning):
Expand Down Expand Up @@ -302,3 +310,130 @@ def rdm_open(init, memmap=False, **kwargs):
if not isinstance(init, asdf.AsdfFile):
asdf_file.close()
raise TypeError(f"Unknown datamodel type: {model_type}, please use asdf.open for non-roman_datamodels files")


def parse_units_to_ivoa(unit_strings: list[str]) -> list[str]:
    """
    Map a list of unit strings onto IVOA-compliant unit strings.

    Parameters
    ----------
    unit_strings : list of str
        Unit strings to convert. Entries may be None, "", "none", or
        "unitless"; all of these are treated as dimensionless.

    Returns
    -------
    list of str
        One IVOA-compliant unit string per input. Dimensionless and
        unparseable entries are rendered as "1".
    """
    from astropy import units as u

    # Every spelling of "no unit" collapses to the IVOA dimensionless token.
    dimensionless_aliases = ("none", "", "unitless")

    result: list[str] = []
    for raw in unit_strings:
        if raw is None or str(raw).lower() in dimensionless_aliases:
            result.append("1")
            continue
        try:
            parsed = u.Unit(raw)
            if isinstance(parsed, u.function.core.FunctionUnitBase):
                # Function units (e.g. magnitudes) have no VOUnit form;
                # collapse any magnitude-based unit to plain "mag".
                generic = parsed.to_string("generic")
                result.append("mag" if generic.startswith("mag") else generic)
            else:
                result.append(parsed.to_string(format="vounit", deprecations="convert"))
        except Exception as e:
            warnings.warn(
                f"Could not parse unit '{raw}' to IVOA format: {e}. Using dimensionless unit '1'.", UserWarning, stacklevel=2
            )
            result.append("1")

    return result


def create_synchronized_table(
    arrs: list,
    names: list[str],
    units: list[str],
    dtypes: list | None,
    global_meta: dict | None,
    ivoa_compliant: bool = False,
    descriptions: list[str] | None = None,
    table_meta: Mapping | None = None,
):
    """
    Build a PyArrow table whose field-level unit metadata stays in sync with
    an embedded Astropy YAML sidecar.

    Parameters
    ----------
    arrs : list
        Column data: numpy arrays or PyArrow columns.
    names : list of str
        Column names, parallel to ``arrs``.
    units : list of str
        Unit string for each column (may contain None entries).
    dtypes : list, optional
        PyArrow data types per column; when None, each column's own
        ``.type`` attribute is used instead.
    global_meta : dict, optional
        Pre-existing schema-level metadata to carry over.
    ivoa_compliant : bool, optional
        If True, units are converted via ``parse_units_to_ivoa``;
        otherwise they are used as provided.
    descriptions : list of str, optional
        Per-column description strings.
    table_meta : Mapping, optional
        Table-level metadata to embed in the Astropy YAML sidecar so it is
        restored when the Parquet file is read back into an Astropy Table.

    Returns
    -------
    pyarrow.Table
        Table whose schema carries per-field unit metadata plus the Astropy
        YAML sidecar under the ``table_meta_yaml`` key.
    """
    import astropy.table.meta
    import pyarrow as pa
    from astropy.table import Table

    # Resolve the unit strings once, up front.
    if ivoa_compliant:
        # Route everything through the single IVOA gatekeeper.
        resolved_units = parse_units_to_ivoa(units)
    else:
        # Pass units through as-is, coercing to str so .encode() below
        # cannot fail on a None entry.
        resolved_units = ["" if value is None else str(value) for value in units]

    # Attach the unit to each pyarrow field as field-level metadata.
    schema_fields = []
    for idx, (column_name, unit_text) in enumerate(zip(names, resolved_units, strict=False)):
        field_type = dtypes[idx] if dtypes else arrs[idx].type
        # Empty unit strings get no metadata entry at all.
        field_meta = {b"unit": unit_text.encode()} if unit_text else {}
        schema_fields.append(pa.field(column_name, type=field_type, metadata=field_meta))

    # Mirror the columns into a throwaway Astropy table so its YAML
    # serializer emits a sidecar consistent with the pyarrow fields above.
    mirror = Table()
    for idx, column_name in enumerate(names):
        data = arrs[idx]
        # Astropy needs numpy; convert pyarrow columns on the fly.
        mirror[column_name] = data if isinstance(data, np.ndarray) else data.to_numpy()
        mirror[column_name].unit = resolved_units[idx] or None  # type: ignore[attr-defined]
        if descriptions and descriptions[idx]:
            mirror[column_name].description = descriptions[idx]  # type: ignore[attr-defined]

    # Fold in table-level metadata (e.g. aperture_radii, ee_fractions) so it
    # is serialized into the YAML sidecar and restored on read.
    if table_meta:
        mirror.meta.update(table_meta)

    # Schema-level metadata: preserved global entries plus the fresh sidecar.
    sidecar_yaml = astropy.table.meta.get_yaml_from_table(mirror)
    schema_meta = dict(global_meta) if global_meta else {}
    schema_meta[b"table_meta_yaml"] = "\n".join(sidecar_yaml).encode()

    return pa.Table.from_arrays(arrs, schema=pa.schema(schema_fields, metadata=schema_meta))
Loading
Loading