Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/run-unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ jobs:
flags: unittests
name: codecov-umbrella
token: ${{ secrets.CODECOV_TOKEN }}
verbose: true
verbose: true
2 changes: 1 addition & 1 deletion docs/roman_catalog_process.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ Basic usage::

Notes
-----
This module follows the Roman Space Telescope data specifications v1.2.
This module follows the Roman Space Telescope data specifications v1.2.
6 changes: 3 additions & 3 deletions roman_photoz/create_simulated_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,11 +335,11 @@ def add_error(
):
"""
Add a Gaussian error to each magnitude column in the catalog.

For each magnitude column, this method adds:

+ a Gaussian noise with a mean equal to the original value and a standard deviation of `mag_noise`

+ an error column (`<magnitude_column>_err`) with values sampled from a Gaussian distribution with a mean of 0 and a standard deviation of `mag_err`.

Parameters
Expand Down
119 changes: 91 additions & 28 deletions roman_photoz/roman_catalog_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,31 @@ class RomanCatalogProcess:
Informer stage for creating the library of SEDs.
estimated : RailStage
Estimator stage for finding the best fits from the library.
model_filename : str
Name of the pickle model file.
"""

def __init__(self, config_filename: Union[dict, str] = ""):
def __init__(
self,
config_filename: Union[dict, str] = "",
model_filename: str = "roman_model.pkl",
):
"""
Initialize the RomanCatalogProcess instance.

Parameters
----------
config_filename : Union[dict, str], optional
Path to the configuration file or a configuration dictionary.
model_filename : str, optional
Name of the pickle model file (default: "roman_model.pkl").
"""
self.data: dict = OrderedDict()
# set configuration file (roman will have its own)
self.set_config_file(config_filename)
# set model filename
self.model_filename = model_filename
self.informer_model_path = Path(LEPHAREWORK, self.model_filename).as_posix()
# set attributes used for determining the redshift
self.flux_cols: list = []
self.flux_err_cols: list = []
Expand Down Expand Up @@ -147,7 +158,7 @@ def format_data(self, cat_data: Table):
"""
# get information about Roman filters
bands = self.config["FILTER_LIST"].split(",")
print(len(bands))
print(f"Processing {len(bands)} bands")

# loop over the filters we want to keep to get
# the number of the filter, n, and the name, b
Expand All @@ -171,7 +182,7 @@ def create_informer_stage(self):
"""
# use the inform stage to create the library of SEDs with
# various redshifts, extinction parameters, and reddening values.
# -> Informer will produce as output a generic model,
# -> Informer will produce as output a generic "model",
# the details of which depend on the sub-class.
# |we use rail's interface here to create the informer stage
# |https://rail-hub.readthedocs.io/en/latest/api/rail.estimation.informer.html
Expand All @@ -188,7 +199,7 @@ def create_informer_stage(self):
self.inform_stage = LephareInformer.make_stage(
name="inform_roman",
nondetect_val=np.nan,
model=f"{Path(LEPHAREWORK, 'roman_model.pkl').as_posix()}",
model=self.informer_model_path,
hdf5_groupname="",
lephare_config=self.config,
star_config=None,
Expand All @@ -209,14 +220,18 @@ def create_estimator_stage(self):
# take the synthetic test data, and find the best
# fits from the library. This results in a PDF, zmode,
# and zmean for each input test data.
# -> Estimators use a generic model, apply the photo-z estimation
# and provide as output a QPEnsemble, with per-object p(z).
# -> Estimators use a generic "model", apply the photo-z estimation
# and provide as "output" a QPEnsemble, with per-object p(z).
# |we use rail's interface here to create the estimator stage
# |https://rail-hub.readthedocs.io/en/latest/api/rail.estimation.estimator.html
if self.informer_model_exists:
model = self.informer_model_path
else:
model = self.inform_stage.get_handle("model")
estimate_lephare = LephareEstimator.make_stage(
name="estimate_roman",
nondetect_val=np.nan,
model=self.inform_stage.get_handle("model"),
model=model,
hdf5_groupname="",
aliases=dict(input="test_data", output="lephare_estim"),
bands=self.flux_cols,
Expand Down Expand Up @@ -254,6 +269,7 @@ def save_results(
if self.estimated is not None:
ancil_data = self.estimated.data.ancil
else:
print("Error: No results to save.")
raise ValueError("No results to save.")

tree = {"roman_photoz_results": ancil_data}
Expand Down Expand Up @@ -289,26 +305,45 @@ def process(
cat_data = self.get_data(input_filename=input_filename, input_path=input_path)

self.format_data(cat_data)
self.create_informer_stage()
if not self.informer_model_exists:
print(
"Warning: The informer model file does not exist. Creating a new one..."
)
self.create_informer_stage()
self.create_estimator_stage()

if save_results:
self.save_results(output_filename=output_filename, output_path=output_path)

@property
def informer_model_exists(self):
"""
Check if the informer model file exists.

def main(argv=None):
"""
Main function to process Roman catalog data.
Returns
-------
bool
True if the model file exists, False otherwise.
"""
if os.path.exists(self.informer_model_path):
print(
f"The informer model file {self.informer_model_path} exists. Using it..."
)
return True
return False

Parameters
----------
argv : list, optional
List of command-line arguments.

def _get_parser():
"""
if argv is None:
# skip the first argument (script name)
argv = sys.argv[1:]
Create and return the argument parser for the roman_photoz command-line interface.

This function is used by both the main function and the Sphinx documentation.

Returns
-------
argparse.ArgumentParser
The configured argument parser
"""
parser = argparse.ArgumentParser(description="Process Roman catalog data.")
parser.add_argument(
"--config_filename",
Expand All @@ -317,6 +352,13 @@ def main(argv=None):
help="Path to the configuration file (default: use default Roman config).",
required=False,
)
parser.add_argument(
"--model_filename",
type=str,
default="roman_model.pkl",
help="Name of the pickle model file (default: roman_model.pkl).",
required=False,
)
parser.add_argument(
"--input_path",
type=str,
Expand Down Expand Up @@ -347,17 +389,38 @@ def main(argv=None):
default=True,
help="Save results? (default: True).",
)
return parser

args = parser.parse_args(argv)

rcp = RomanCatalogProcess(config_filename=args.config_filename)
def main(argv=None):
"""
Main function to process Roman catalog data.

rcp.process(
input_filename=args.input_filename,
input_path=args.input_path,
output_filename=args.output_filename,
output_path=args.output_path,
save_results=args.save_results,
)
Parameters
----------
argv : list, optional
List of command-line arguments.
"""
print("Starting Roman catalog processing...")
if argv is None:
# skip the first argument (script name)
argv = sys.argv[1:]

parser = _get_parser()
args = parser.parse_args(argv)

print("Done.")
try:
rcp = RomanCatalogProcess(
config_filename=args.config_filename, model_filename=args.model_filename
)
rcp.process(
input_filename=args.input_filename,
input_path=args.input_path,
output_filename=args.output_filename,
output_path=args.output_path,
save_results=args.save_results,
)
print("Processing completed successfully.")
except Exception as e:
print(f"An error occurred during processing: {str(e)}")
sys.exit(1)
82 changes: 60 additions & 22 deletions roman_photoz/tests/test_create_simulated_catalog.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import pytest
from unittest.mock import MagicMock, patch

import numpy as np
from pathlib import Path
from unittest.mock import patch, MagicMock
import pytest
from numpy.lib.recfunctions import merge_arrays

from roman_photoz.create_simulated_catalog import SimulatedCatalog
from roman_photoz.default_config_file import default_roman_config
from numpy.lib.recfunctions import merge_arrays

FILTER_LIST = (
default_roman_config.get("FILTER_LIST", "")
Expand All @@ -13,14 +14,18 @@
.split(",")
)


@pytest.fixture
def simulated_catalog():
    """Build a default-constructed ``SimulatedCatalog`` for the requesting test."""
    catalog = SimulatedCatalog()
    return catalog


def test_is_folder_not_empty(simulated_catalog):
with patch("pathlib.Path.exists", return_value=True), \
patch("pathlib.Path.is_dir", return_value=True), \
patch("pathlib.Path.glob", return_value=["file1", "file2"]):
with (
patch("pathlib.Path.exists", return_value=True),
patch("pathlib.Path.is_dir", return_value=True),
patch("pathlib.Path.glob", return_value=["file1", "file2"]),
):
assert simulated_catalog.is_folder_not_empty("dummy_path", "file") is True

with patch("pathlib.Path.exists", return_value=False):
Expand All @@ -30,12 +35,14 @@ def test_is_folder_not_empty(simulated_catalog):
with patch("pathlib.Path.glob", return_value=[]):
assert simulated_catalog.is_folder_not_empty("dummy_path", "file") is False


def test_add_ids(simulated_catalog):
    """Check that ``add_ids`` adds an ``id`` column to the catalog.

    A two-row structured array is passed in; the returned array must expose
    an ``id`` field whose values are ``[1, 2]``.
    """
    fields = [("col1", "f8"), ("col2", "f8")]
    rows = [(1.0, 2.0), (3.0, 4.0)]
    with_ids = simulated_catalog.add_ids(np.array(rows, dtype=fields))

    assert "id" in with_ids.dtype.names
    assert np.array_equal(with_ids["id"], [1, 2])


@pytest.mark.parametrize(
"params",
[
Expand All @@ -52,49 +59,80 @@ def test_add_error(simulated_catalog, params):
# ensure that the new columns are added and values are within the expected range
assert "mag1_err" in updated_catalog.dtype.names
assert "mag2_err" in updated_catalog.dtype.names
assert (updated_catalog["mag1_err"][0] > 0) & (updated_catalog["mag1_err"][0] <= params["mag_err"])
assert (updated_catalog["mag2_err"][0] > 0) & (updated_catalog["mag2_err"][0] <= params["mag_err"])
assert (updated_catalog["mag1_err"][0] > 0) & (
updated_catalog["mag1_err"][0] <= params["mag_err"]
)
assert (updated_catalog["mag2_err"][0] > 0) & (
updated_catalog["mag2_err"][0] <= params["mag_err"]
)
# ensure that noise has been added to the original magnitudes
assert np.all(updated_catalog["mag1"] != catalog["mag1"])
assert np.all(updated_catalog["mag2"] != catalog["mag2"])


def test_pick_random_lines(simulated_catalog):
simulated_catalog.simulated_data = np.array(
[(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)],
dtype=[("col1", "f8"), ("col2", "f8")]
[(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)], dtype=[("col1", "f8"), ("col2", "f8")]
)
random_lines = simulated_catalog.pick_random_lines(2)
assert len(random_lines) == 2


def test_create_header(simulated_catalog):
# create a mock header as generated by lephare.prepare()
mock_file_content = "# model ext_law E(B-V) age N_filt magnitude_vector kcorr_vector"
mock_file_content = (
"# model ext_law E(B-V) age N_filt magnitude_vector kcorr_vector"
)
with patch("builtins.open", new_callable=MagicMock) as mock_open:
mock_open.return_value.__enter__.return_value.readline.return_value = mock_file_content
mock_open.return_value.__enter__.return_value.readline.return_value = (
mock_file_content
)
colnames = simulated_catalog.create_header("dummy_catalog")
# check that we have expanded the columns with _vector suffix into multiple columns (one for each filter)
assert all(f"magnitude_{filter}" in colnames for filter in FILTER_LIST)
assert all(f"kcorr_{filter}" in colnames for filter in FILTER_LIST)
# check that "#" and "age" have been removed
assert all(x not in colnames for x in ["#", "age"])


def test_update_roman_catalog_template(simulated_catalog):
# Create test data for fluxes and catalog
flux = np.array([(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)], dtype=[(f"magnitude_{f}", "f8") for f in FILTER_LIST])
flux_err = np.array([(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)], dtype=[(f"magnitude_{f}_err", "f8") for f in FILTER_LIST])
extra = np.array([(1, 2.0, 3.0, 3.0)], dtype=[("id", "i4"), ("context", "f8"), ("zspec", "f8"), ("z_true", "f8")])

flux = np.array(
[(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)],
dtype=[(f"magnitude_{f}", "f8") for f in FILTER_LIST],
)
flux_err = np.array(
[(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)],
dtype=[(f"magnitude_{f}_err", "f8") for f in FILTER_LIST],
)
extra = np.array(
[(1, 2.0, 3.0, 3.0)],
dtype=[("id", "i4"), ("context", "f8"), ("zspec", "f8"), ("z_true", "f8")],
)

# Merge fluxes into the catalog
catalog = merge_arrays([extra, flux, flux_err], flatten=True)
simulated_catalog.update_roman_catalog_template(catalog)
# check that the LePhare-required columns were added to the roman_catalog_template.source_catalog
assert all(x in simulated_catalog.roman_catalog_template.source_catalog.colnames for x in ["id", "context", "zspec", "string_data"])
assert all(
x in simulated_catalog.roman_catalog_template.source_catalog.colnames
for x in ["id", "context", "zspec", "string_data"]
)


def test_process(simulated_catalog):
with patch.object(simulated_catalog, "get_filters") as mock_get_filters, \
patch.object(simulated_catalog, "create_simulated_data") as mock_create_simulated_data, \
patch.object(simulated_catalog, "create_simulated_input_catalog") as mock_create_simulated_input_catalog:
simulated_catalog.process(output_path="dummy_path", output_filename="dummy_file")
with (
patch.object(simulated_catalog, "get_filters") as mock_get_filters,
patch.object(
simulated_catalog, "create_simulated_data"
) as mock_create_simulated_data,
patch.object(
simulated_catalog, "create_simulated_input_catalog"
) as mock_create_simulated_input_catalog,
):
simulated_catalog.process(
output_path="dummy_path", output_filename="dummy_file"
)
mock_get_filters.assert_called_once()
mock_create_simulated_data.assert_called_once()
mock_create_simulated_input_catalog.assert_called_once()
Loading
Loading