diff --git a/.github/workflows/run-unit-tests.yml b/.github/workflows/run-unit-tests.yml index c179b3d..10a3324 100644 --- a/.github/workflows/run-unit-tests.yml +++ b/.github/workflows/run-unit-tests.yml @@ -1,4 +1,4 @@ -name: Run Unit Tests +name: Run Unit Tests with Coverage on: push: @@ -8,6 +8,11 @@ on: branches: - main +permissions: + contents: read + checks: write + pull-requests: write + jobs: run-unit-tests: runs-on: ubuntu-latest @@ -31,5 +36,7 @@ jobs: python -m pip install --upgrade pip pip install ".[dev]" - - name: Run unit tests - run: pytest --cov=. + - name: Run unit tests with coverage + run: | + coverage run -m pytest + coverage report --fail-under=70 diff --git a/pyproject.toml b/pyproject.toml index 042408b..2f5d365 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ test = [ "pytest-astropy >= 0.11.0", "deepdiff", "stpreview>=0.5.1", + "pytest-cov", ] dev = ["roman_photoz[docs,test]", "tox > 4", "pre-commit > 3"] sdp = ["stpreview>=0.5.1"] @@ -128,3 +129,22 @@ archs = ["x86_64", "arm64"] [tool.cibuildwheel.linux] archs = ["auto", "aarch64"] + +[tool.coverage.run] +source = ["roman_photoz"] +omit = [ + "*/tests/*", + "*/docs/*", +] + +[tool.coverage.report] +show_missing = true +skip_covered = true +precision = 2 +fail_under = 70 + +[tool.coverage.html] +directory = "htmlcov" + +[tool.coverage.xml] +output = "coverage.xml" diff --git a/roman_photoz/create_roman_filters.py b/roman_photoz/create_roman_filters.py index b7fa99f..0e7ab5d 100644 --- a/roman_photoz/create_roman_filters.py +++ b/roman_photoz/create_roman_filters.py @@ -92,7 +92,7 @@ def create_files(data: pd.DataFrame, filepath: str = "") -> None: for col in data.columns[1:]: output_data = data[[wave, col]] # convert wavelength from um to A - output_data[wave] = output_data[wave] * 1e4 + output_data.loc[:, wave] = output_data[wave] * 1e4 filename = "roman" + "_".join(col.split(" ")).strip() + ".pb" first_line = f"# {col} (Roman filter info obtained from {BASE_URL.format(DEFAULT_FILE_DATE)})" fq_path = path / filename diff --git a/roman_photoz/tests/test_create_roman_filters.py b/roman_photoz/tests/test_create_roman_filters.py new file mode 100644 index 0000000..1d09a15 --- /dev/null +++ b/roman_photoz/tests/test_create_roman_filters.py @@ -0,0 +1,265 @@ +import os +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest + +from roman_photoz.create_roman_filters import ( + BASE_URL, + create_files, + create_path, + create_roman_phot_par_file, + download_file, + read_effarea_file, + run, + run_filter_command, +) + + +@pytest.fixture +def mock_response(): + """Create a mock requests response object.""" + mock = MagicMock() + mock.raise_for_status = MagicMock() + mock.content = b"Test content" + return mock + + +@pytest.fixture +def sample_dataframe(): + """Create a sample DataFrame for testing filter creation.""" + # Create a basic dataframe with wavelength and filter data + data = { + "wavelength": [0.1, 0.2, 0.3], + "F087": [0.8, 0.9, 0.7], + "F106": [0.7, 0.8, 0.6], + } + return pd.DataFrame(data) + + +def test_download_file(mock_response): + """Test the download_file function correctly downloads and saves a file.""" + with ( + patch("builtins.open", MagicMock()) as mock_open, + patch("requests.get", return_value=mock_response) as mock_get, + ): + url = "http://test.url/file.xlsx" + dest = "test_file.xlsx" + download_file(url, dest) + + # Verify the function called the correct URL with timeout + mock_get.assert_called_once_with(url, timeout=30) + # Verify response was checked for errors + mock_response.raise_for_status.assert_called_once() + # Verify file was opened and written with correct content + mock_open.assert_called_once_with(dest, "wb") + mock_open.return_value.__enter__.return_value.write.assert_called_once_with( + mock_response.content + ) + + +@pytest.mark.parametrize( + "file_exists, test_file, expected_kwargs, expected_download", + [ + # Case 1: File exists, no download needed + ( + True, + "test_file.xlsx", + {"header": 1}, + None, + ), + # Case 2: File doesn't exist, download needed + ( + False, + "test_Roman_effarea_20220101.xlsx", + {}, + { + "url": BASE_URL.format("20220101"), + "dest": Path("test_Roman_effarea_20220101.xlsx").resolve().as_posix(), + }, + ), + ], + ids=["file_exists", "file_download_needed"], +) +def test_read_effarea_file(file_exists, test_file, expected_kwargs, expected_download): + """Test reading an efficiency area file, with or without downloading it first.""" + mock_df = pd.DataFrame({"wavelength": [1, 2, 3]}) + + with patch("pathlib.Path.exists", return_value=file_exists) as mock_exists: + # Set up additional mocks based on whether download is expected + if file_exists: + # Simple case: file exists, just mock pandas.read_excel + with patch("pandas.read_excel", return_value=mock_df) as mock_read_excel: + result = read_effarea_file(test_file, **expected_kwargs) + + # Verify the function returns the expected DataFrame + assert result is mock_df + # Verify read_excel was called with expected parameters + file_path = Path(test_file).resolve() + mock_read_excel.assert_called_once_with(file_path, **expected_kwargs) + else: + # Complex case: file doesn't exist, mock download_file and pandas.read_excel + with ( + patch( + "roman_photoz.create_roman_filters.download_file" + ) as mock_download, + patch("pandas.read_excel", return_value=mock_df) as mock_read_excel, + ): + # Create a resolved path for testing + test_path = Path(test_file).resolve() + result = read_effarea_file(test_path.as_posix()) + + # Verify download was called with correct URL and destination + mock_download.assert_called_once_with( + expected_download["url"], expected_download["dest"] + ) + # Verify read_excel was called and the function returns the expected DataFrame + mock_read_excel.assert_called_once() + assert result is mock_df + + +def test_create_files(sample_dataframe): + """Test creating filter files from a DataFrame.""" + test_path = Path("test_path") + + with ( + patch( + "roman_photoz.create_roman_filters.create_path", return_value=test_path + ) as mock_create_path, + patch("builtins.open", MagicMock()) as mock_open, + patch( + "roman_photoz.create_roman_filters.create_roman_phot_par_file" + ) as mock_create_par, + ): + create_files(sample_dataframe, "test_path") + + # Verify create_path was called with the correct path + mock_create_path.assert_called_once_with("test_path") + + # Verify files were created for each filter column + assert ( + mock_open.call_count == 2 + ) # One for each filter column (excluding wavelength) + + # Verify create_roman_phot_par_file was called with correct filter list + mock_create_par.assert_called_once_with( + ["romanF087.pb", "romanF106.pb"], test_path + ) + + +def test_create_roman_phot_par_file(): + """Test creating the roman_phot.par file.""" + filter_list = ["filter1.pb", "filter2.pb"] + filter_rep = Path("test_filter_path") + + with patch("builtins.open", MagicMock()) as mock_open: + create_roman_phot_par_file(filter_list, filter_rep) + + # Verify file was opened at the correct path + mock_open.assert_called_once_with(filter_rep / "roman_phot.par", "w") + + # Verify file content contains the filter list and repository path + file_content = mock_open.return_value.__enter__.return_value.write.call_args[0][ + 0 + ] + assert "FILTER_LIST filter1.pb,filter2.pb" in file_content + assert f"FILTER_REP {filter_rep.as_posix()}" in file_content + assert "FILTER_CALIB 0,0" in file_content + + +@pytest.mark.parametrize( + "input_path, env_vars, expected_path", + [ + # Case 1: Default path (no arguments) + ( + None, + {}, + Path("/current/dir"), + ), + # Case 2: LEPHAREDIR environment variable is set + ( + None, + {"LEPHAREDIR": "/lephare/dir"}, + Path("/lephare/dir/filt/roman"), + ), + # Case 3: Custom path provided + ( + "/custom/path", + {}, + Path("/custom/path"), + ), + ], + ids=["default_path", "lepharedir_path", "custom_path"], +) +def test_create_path(input_path, env_vars, expected_path): + """Test create_path function with different scenarios using parametrization.""" + with ( + patch.dict(os.environ, env_vars, clear=True), + patch("pathlib.Path.resolve", return_value=expected_path), + patch("pathlib.Path.mkdir") as mock_mkdir, + ): + # Call the function with or without input path + path = create_path(input_path) if input_path is not None else create_path() + + # Verify the path matches the expected path + assert path == expected_path + + # Verify directory was created + mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) + + +def test_run_filter_command(): + """Test running the filter command.""" + with ( + patch( + "roman_photoz.create_roman_filters.create_path", + return_value=Path("/test/path"), + ) as mock_create_path, + patch("roman_photoz.create_roman_filters.Filter") as mock_filter, + ): + # Configure the mock Filter instance + mock_filter_instance = MagicMock() + mock_filter.return_value = mock_filter_instance + + run_filter_command("/test/config/path") + + # Verify create_path was called with the config path + mock_create_path.assert_called_once_with(filepath="/test/config/path") + + # Verify Filter was instantiated with the correct config file path + mock_filter.assert_called_once_with(config_file="/test/path/roman_phot.par") + + # Verify run method was called on the Filter instance + mock_filter_instance.run.assert_called_once() + + +def test_run(): + """Test the main run function.""" + with ( + patch("roman_photoz.create_roman_filters.read_effarea_file") as mock_read, + patch("roman_photoz.create_roman_filters.get_auxiliary_data") as mock_get_data, + patch("roman_photoz.create_roman_filters.create_files") as mock_create_files, + patch( + "roman_photoz.create_roman_filters.run_filter_command" + ) as mock_run_filter, + ): + # Configure mock read_effarea_file to return a dataframe + mock_df = pd.DataFrame({"wavelength": [1, 2, 3]}) + mock_read.return_value = mock_df + + run("test_input.xlsx", "/test/output/path") + + # Verify read_effarea_file was called with the correct parameters + mock_read.assert_called_once_with(filename="test_input.xlsx", header=1) + + # Verify get_auxiliary_data was called + mock_get_data.assert_called_once() + + # Verify create_files was called with the dataframe and output path + mock_create_files.assert_called_once_with( + data=mock_df, filepath="/test/output/path" + ) + + # Verify run_filter_command was called with the output path + mock_run_filter.assert_called_once_with(config_file_path="/test/output/path") diff --git a/roman_photoz/tests/test_default_config_file.py b/roman_photoz/tests/test_default_config_file.py new file mode 100644 index 0000000..234a6f9 --- /dev/null +++ b/roman_photoz/tests/test_default_config_file.py @@ -0,0 +1,114 @@ +import os +from pathlib import Path +from unittest.mock import patch + +from roman_photoz.default_config_file import LEPHAREDIR, default_roman_config + + +def test_lepharedir_environment_variable(): + """Test that LEPHAREDIR is correctly set from environment or fallback.""" + # We need to ensure the module is reloaded after environment changes + import sys + + # Remove the module if it's already imported to ensure a fresh import + if "roman_photoz.default_config_file" in sys.modules: + del sys.modules["roman_photoz.default_config_file"] + + # Test when environment variable is set + test_lepharedir = "/test/lephare/dir" + with patch.dict(os.environ, {"LEPHAREDIR": test_lepharedir}, clear=True): + import roman_photoz.default_config_file + + assert roman_photoz.default_config_file.LEPHAREDIR == test_lepharedir + + # Clean up for next test + del sys.modules["roman_photoz.default_config_file"] + + # Test fallback to lp.LEPHAREDIR + with ( + patch.dict(os.environ, {}, clear=True), + patch("lephare.LEPHAREDIR", "/fallback/lephare/dir"), + ): + import roman_photoz.default_config_file + + assert roman_photoz.default_config_file.LEPHAREDIR == "/fallback/lephare/dir" + + +def test_cwd_is_current_directory(): + """Test that CWD is the current working directory.""" + with patch("os.getcwd", return_value="/mock/current/dir"): + from importlib import reload + + import roman_photoz.default_config_file + + reload(roman_photoz.default_config_file) + assert roman_photoz.default_config_file.CWD == "/mock/current/dir" + + +def test_default_config_contains_required_keys(): + """Test that default_roman_config contains all required configuration keys.""" + required_keys = [ + "FILTER_LIST", + "FILTER_REP", + "FILTER_FILE", + "CAT_IN", + "CAT_OUT", + "PARA_OUT", + "GAL_LIB", + "GAL_LIB_IN", + "GAL_LIB_OUT", + "ZPHOTLIB", + "Z_INTERP", + "Z_METHOD", + "Z_RANGE", + "Z_STEP", + ] + + for key in required_keys: + assert ( + key in default_roman_config + ), f"Required key '{key}' missing from default_roman_config" + + +def test_para_out_path_is_valid(): + """Test that PARA_OUT path is a valid path relative to the module.""" + para_out_path = default_roman_config["PARA_OUT"] + assert isinstance(para_out_path, str) + + # Verify it points to an expected location under the package + path_obj = Path(para_out_path) + assert "roman_photoz" in path_obj.parts + assert "data" in path_obj.parts + assert path_obj.name == "default_roman_output.para" + + +def test_filter_list_format(): + """Test that FILTER_LIST is properly formatted.""" + filter_list = default_roman_config["FILTER_LIST"] + + # Should be a comma-separated list of filter paths + filters = filter_list.split(",") + assert len(filters) == 8, "Expected 8 filters in FILTER_LIST" + + # Each filter should follow the expected pattern + for filter_path in filters: + assert filter_path.startswith( + "roman/roman_F" + ), f"Filter {filter_path} doesn't follow naming convention" + assert filter_path.endswith( + ".pb" + ), f"Filter {filter_path} doesn't have .pb extension" + + +def test_filter_rep_uses_lepharedir(): + """Test that FILTER_REP uses the LEPHAREDIR value.""" + filter_rep = default_roman_config["FILTER_REP"] + assert LEPHAREDIR in filter_rep + assert filter_rep == f"{LEPHAREDIR}/filt" + + +def test_all_variable_contains_default_roman_config(): + """Test that __all__ contains 'default_roman_config'.""" + from roman_photoz.default_config_file import __all__ + + assert "default_roman_config" in __all__ diff --git a/roman_photoz/tests/test_roman_catalog_handler.py b/roman_photoz/tests/test_roman_catalog_handler.py new file mode 100644 index 0000000..441f3aa --- /dev/null +++ b/roman_photoz/tests/test_roman_catalog_handler.py @@ -0,0 +1,174 @@ +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +from roman_photoz.roman_catalog_handler import RomanCatalogHandler + + +@pytest.fixture +def mock_catalog_data(roman_catalog_handler): + """Create mock catalog data for testing""" + # Get the actual filter names used by the handler + filter_names = roman_catalog_handler.filter_names + + # Create fields dynamically based on actual filter names + field_list = [("id", "i4")] + for name in filter_names: + field_name = f"{name.replace('roman_', '')}_flux_psf" + field_err_name = f"{name.replace('roman_', '')}_flux_psf_err" + field_list.append((field_name, "f8")) + field_list.append((field_err_name, "f8")) + + field_list.extend([("context", "i4"), ("zspec", "f8"), ("string_data", "S20")]) + + dt = np.dtype(field_list) + + # Create sample data + data = np.zeros(3, dtype=dt) + data["id"] = [1, 2, 3] + data["context"] = [1, 1, 2] + data["zspec"] = [0.5, 1.0, 1.5] + data["string_data"] = [b"source1", b"source2", b"source3"] + + # Add test values for each filter field + for i, name in enumerate(filter_names): + field_name = f"{name.replace('roman_', '')}_flux_psf" + field_err_name = f"{name.replace('roman_', '')}_flux_psf_err" + data[field_name] = [100.0 + i * 10, 150.0 + i * 10, 200.0 + i * 10] + data[field_err_name] = [5.0 + i * 0.5, 7.5 + i * 0.5, 10.0 + i * 0.5] + + return data + + +@pytest.fixture +def roman_catalog_handler(): + """Create a basic RomanCatalogHandler instance for testing""" + return RomanCatalogHandler("test_catalog.asdf") + + +class TestRomanCatalogHandler: + """Test class for the RomanCatalogHandler""" + + def test_init_default(self): + """Test initialization with default parameters""" + handler = RomanCatalogHandler() + assert handler.cat_name == "" + assert handler.cat_array is None + assert handler.catalog is None + assert handler.cat_temp_filename == "cat_temp_file.csv" + assert isinstance(handler.filter_names, list) + assert len(handler.filter_names) > 0 + + def test_init_with_catalog(self): + """Test initialization with a catalog name""" + catalog_name = "test_catalog.asdf" + handler = RomanCatalogHandler(catalog_name) + assert handler.cat_name == catalog_name + + def test_init_filter_list_none(self): + """Test initialization when filter list is None""" + with patch( + "roman_photoz.roman_catalog_handler.default_roman_config" + ) as mock_config: + mock_config.get.return_value = None + with pytest.raises( + ValueError, match="Filter list not found in default config file" + ): + RomanCatalogHandler() + + @patch("roman_photoz.roman_catalog_handler.rdm.open") + def test_read_catalog(self, mock_open, mock_catalog_data): + """Test reading a catalog file""" + # Setup mock + mock_dm = MagicMock() + mock_dm.source_catalog.as_array.return_value = mock_catalog_data + mock_open.return_value = mock_dm + + # Initialize handler and read catalog + handler = RomanCatalogHandler("test_catalog.asdf") + handler.read_catalog() + + # Check that the catalog was read correctly + mock_open.assert_called_once_with("test_catalog.asdf") + assert handler.cat_array is not None + assert len(handler.cat_array) == 3 + assert handler.cat_array["id"][0] == 1 + + def test_format_catalog(self, roman_catalog_handler, mock_catalog_data): + """Test formatting a catalog""" + # Setup + roman_catalog_handler.cat_array = mock_catalog_data + + # Execute + roman_catalog_handler.format_catalog() + + # Check that the catalog was formatted correctly + assert roman_catalog_handler.catalog is not None + assert "id" in roman_catalog_handler.catalog.dtype.names + + # Check that filter fields were added correctly + for filter_name in roman_catalog_handler.filter_names: + flux_field = f"flux_psf_{filter_name}" + flux_err_field = f"flux_psf_err_{filter_name}" + assert flux_field in roman_catalog_handler.catalog.dtype.names + assert flux_err_field in roman_catalog_handler.catalog.dtype.names + + # Check that additional required fields were added + assert "context" in roman_catalog_handler.catalog.dtype.names + assert "zspec" in roman_catalog_handler.catalog.dtype.names + assert "string_data" in roman_catalog_handler.catalog.dtype.names + + # Check that data was copied correctly for a sample field + assert roman_catalog_handler.catalog["id"][0] == mock_catalog_data["id"][0] + assert ( + roman_catalog_handler.catalog["zspec"][1] == mock_catalog_data["zspec"][1] + ) + + @patch("roman_photoz.roman_catalog_handler.RomanCatalogHandler.read_catalog") + @patch("roman_photoz.roman_catalog_handler.RomanCatalogHandler.format_catalog") + def test_process( + self, + mock_format_catalog, + mock_read_catalog, + roman_catalog_handler, + mock_catalog_data, + ): + """Test processing a catalog""" + # Setup + roman_catalog_handler.catalog = mock_catalog_data + + # Execute + result = roman_catalog_handler.process() + + # Check that methods were called + mock_read_catalog.assert_called_once() + mock_format_catalog.assert_called_once() + + # Check that process returns the catalog + assert result is roman_catalog_handler.catalog + + @patch("roman_photoz.roman_catalog_handler.rdm.open") + def test_end_to_end_process(self, mock_open, mock_catalog_data): + """Test the entire process flow from read to format""" + # Setup mock + mock_dm = MagicMock() + mock_dm.source_catalog.as_array.return_value = mock_catalog_data + mock_open.return_value = mock_dm + + # Initialize handler + handler = RomanCatalogHandler("test_catalog.asdf") + + # Process the catalog + result = handler.process() + + # Check that the catalog was processed correctly + assert result is not None + assert "id" in result.dtype.names + assert "context" in result.dtype.names + assert "zspec" in result.dtype.names + assert "string_data" in result.dtype.names + + +if __name__ == "__main__": + pytest.main(["-v", "test_roman_catalog_handler.py"]) diff --git a/tox.ini b/tox.ini index 68cf212..b2792ed 100644 --- a/tox.ini +++ b/tox.ini @@ -1,4 +1,5 @@ [tox] +isolated_build = true env_list = check-{style,dependencies,build} test{,-alldeps,-devdeps}{,-pyargs,-warnings,-regtests,-cov,-webbpsf}{-nolegacypath}