diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml
new file mode 100644
index 0000000..d5c70eb
--- /dev/null
+++ b/.github/workflows/run-tests.yml
@@ -0,0 +1,37 @@
+name: Run tests
+on:
+  push:
+    branches: [ "main" ]
+  pull_request:
+    branches: [ "main" ]
+  workflow_dispatch:
+
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10", "3.11", "3.12"]
+    permissions:
+      contents: read
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+
+      - name: Test with pytest
+        run: |
+          pip install pytest pytest-cov
+          pytest --cov=. --cov-report=xml tests/test_validate.py
+          pytest --cov=. --cov-report=xml tests/test_score.py
diff --git a/README.md b/README.md
index e4b6a05..0eb59fc 100644
--- a/README.md
+++ b/README.md
@@ -66,8 +66,9 @@ in Python. R support TBD.
 > for each task.
 
 > [!IMPORTANT]
-> Modifying the `main()` function is highly discouraged. This function has
-> specifically been written to interact with ORCA.
+> The `main()` function is specifically designed for seamless interaction with
+> ORCA. **Modifying this function is strongly discouraged** and could lead to
+> compatibility issues. Unit tests are in place to ensure its proper integration.
 
 3. Update `requirements.txt` with any additional libraries/packages used by
    the script.
@@ -117,8 +118,7 @@ in Python. R support TBD.
 > for each task.
 
 > [!IMPORTANT]
-> Modifying the `main()` function is highly discouraged. This function has
-> specifically been written to interact with ORCA.
+> Like `validate.py`, modifying `main()` is strongly discouraged.
 
 3. Update `requirements.txt` with any additional libraries/packages used by
    the script.
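For orientation, the tests added below invoke `main()` through Typer with `-p`, `-g`, `-t`, and `-o` flags, expect a results JSON containing `validation_status`/`validation_errors` (and, for scoring, `score_status`/`score_errors`), and read the final status from STDOUT. The following is only a hedged sketch of that interface to make the tests easier to follow; the parameter and option names here are assumptions, not the repository's actual `main()` signature.

```python
# Hedged sketch of the main() interface exercised by the tests below.
# NOT the repository's implementation; option/parameter names are assumed.
import json

import typer


def main(
    predictions_file: str = typer.Option(..., "-p", "--predictions_file"),
    groundtruth_folder: str = typer.Option(..., "-g", "--groundtruth_folder"),
    task_number: int = typer.Option(1, "-t", "--task_number"),
    output_file: str = typer.Option("results.json", "-o", "--output_file"),
) -> None:
    """Validate or score a submission, update the results JSON, print the status."""
    results = {"validation_status": "", "validation_errors": ""}
    # ... validation/scoring would happen here ...
    with open(output_file, "w") as f:
        json.dump(results, f)
    # ORCA (and the tests) read the final status from STDOUT.
    print(results["validation_status"])
```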
diff --git a/sample_data/groundtruth/truth_data.csv b/sample_data/groundtruth/truth_data.csv
index c3702b7..0a09fef 100644
--- a/sample_data/groundtruth/truth_data.csv
+++ b/sample_data/groundtruth/truth_data.csv
@@ -1,6 +1,6 @@
 id,disease
-1,1
-2,1
-3,0
-4,1
-5,1
+A_01,1
+A_02,1
+A_03,0
+A_04,1
+A_05,1
diff --git a/sample_data/invalid_pred.csv b/sample_data/invalid_pred.csv
index e706af0..dfa7f8e 100644
--- a/sample_data/invalid_pred.csv
+++ b/sample_data/invalid_pred.csv
@@ -1,7 +1,7 @@
 id,probability,note
-1,-0.011,invalid value (bad range)
-2,NA,invalid value (null)
-3,0.092,
-3,0.913,invalid ID (duplicate)
-5,0.543,
-10,0.290,invalid ID (unkonwn)
\ No newline at end of file
+A_01,-0.011,invalid value (bad range)
+A_02,NA,invalid value (null)
+A_03,0.092,
+A_03,0.913,invalid ID (duplicate)
+A_05,0.543,
+A_10,0.290,invalid ID (unknown)
\ No newline at end of file
diff --git a/sample_data/valid_pred.csv b/sample_data/valid_pred.csv
index 289624b..881e648 100644
--- a/sample_data/valid_pred.csv
+++ b/sample_data/valid_pred.csv
@@ -1,6 +1,6 @@
 id,probability
-1,0.011
-2,0.765
-3,0.092
-4,0.913
-5,0.543
\ No newline at end of file
+A_01,0.011
+A_02,0.765
+A_03,0.092
+A_04,0.913
+A_05,0.543
\ No newline at end of file
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..3e0cb93
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,78 @@
+import json
+import os
+import sys
+import tempfile
+
+import pandas as pd
+import pytest
+
+
+def pytest_configure(config):
+    """Allow test scripts to import scripts from parent folder."""
+    src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+    sys.path.insert(0, src_path)
+
+
+@pytest.fixture(scope="module")
+def temp_dir():
+    """Creates a temporary directory for test files."""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        yield tmpdir
+
+
+@pytest.fixture(scope="module")
+def groundtruth_dir(temp_dir):
+    """Creates a temporary groundtruth directory."""
+    gt_dir = os.path.join(temp_dir, "groundtruth")
+    os.makedirs(gt_dir)
+    return gt_dir
+
+
+@pytest.fixture(scope="module")
+def gt_file(groundtruth_dir):
+    """Creates a dummy groundtruth file."""
+    truth = pd.DataFrame(
+        {
+            "id": ["A_01", "A_02", "A_03"],
+            "disease": [1, 0, 1],
+        }
+    )
+    gt_file = os.path.join(groundtruth_dir, "truth.csv")
+    truth.to_csv(gt_file, index=False)
+    return gt_file
+
+
+@pytest.fixture(scope="module")
+def pred_file(temp_dir):
+    """Creates a dummy prediction file."""
+    pred = pd.DataFrame(
+        {
+            "id": ["A_01", "A_02", "A_03"],
+            "probability": [0.5, 0.5, 0.5],
+        }
+    )
+    pred_file = os.path.join(temp_dir, "predictions.csv")
+    pred.to_csv(pred_file, index=False)
+    return pred_file
+
+
+@pytest.fixture(scope="module")
+def valid_predictions_json():
+    """Creates a dummy valid results JSON."""
+    return json.dumps(
+        {
+            "validation_status": "VALIDATED",
+            "validation_errors": "",
+        }
+    )
+
+
+@pytest.fixture(scope="module")
+def invalid_predictions_json():
+    """Creates a dummy invalid results JSON."""
+    return json.dumps(
+        {
+            "validation_status": "INVALID",
+            "validation_errors": "Something went wrong.",
+        }
+    )
diff --git a/tests/test_score.py b/tests/test_score.py
new file mode 100644
index 0000000..1c3f8ef
--- /dev/null
+++ b/tests/test_score.py
@@ -0,0 +1,258 @@
+"""Test for score script.
+
+These tests are designed to test the general functionality for
+interacting with ORCA, NOT the actual scoring logic itself.
+"""
+
+import json
+import os
+from unittest.mock import patch
+
+import pytest
+import typer
+from typer.testing import CliRunner
+
+from score import main, score
+
+app = typer.Typer()
+app.command()(main)
+
+runner = CliRunner()
+
+
+# ----- Tests for score() function -----
+def test_score_valid_task_number(gt_file, pred_file):
+    """Test: score() returns a dict for valid task number."""
+    task_number = 1
+    res = score(
+        task_number=task_number,
+        gt_file=gt_file,
+        pred_file=pred_file,
+    )
+    assert isinstance(res, dict)
+
+
+def test_score_invalid_task_number():
+    """Test: score() raises KeyError for invalid task number."""
+    task_number = 99999
+    with pytest.raises(KeyError):
+        score(
+            task_number=task_number,
+            gt_file="dummy_gt.csv",
+            pred_file="dummy_pred.csv",
+        )
+
+
+# ----- Tests for main() function -----
+@patch("score.extract_gt_file")
+@patch("score.score")
+def test_main_invalid_task_number(
+    mock_score, mock_extract_gt_file, gt_file, pred_file, temp_dir
+):
+    """Test: final results should be INVALID for invalid task number."""
+    task_number = 99999
+    mock_extract_gt_file.return_value = gt_file
+    mock_score.side_effect = KeyError
+
+    groundtruth_dir = os.path.dirname(gt_file)
+    output_file = os.path.join(temp_dir, "results.json")
+    with open(output_file, "w") as f:
+        pass
+    result = runner.invoke(
+        app,
+        [
+            "-p",
+            pred_file,
+            "-g",
+            groundtruth_dir,
+            "-t",
+            str(task_number),
+            "-o",
+            output_file,
+        ],
+    )
+    assert result.exit_code == 0
+    assert result.stdout.strip() == "INVALID"
+    with open(output_file, "r") as f:
+        output_data = json.load(f)
+    assert output_data["score_status"] == "INVALID"
+    assert (
+        output_data["score_errors"]
+        == f"Invalid challenge task number specified: `{task_number}`"
+    )
+    mock_extract_gt_file.assert_called_once_with(groundtruth_dir)
+    mock_score.assert_called_once_with(
+        task_number=task_number, gt_file=gt_file, pred_file=pred_file
+    )
+
+
+@patch("score.extract_gt_file")
+@patch("score.score")
+def test_main_prior_validation_failed(
+    mock_score,
+    mock_extract_gt_file,
+    temp_dir,
+    invalid_predictions_json,
+    groundtruth_dir,
+):
+    """Test: final results should be INVALID for invalid predictions file."""
+    output_file = os.path.join(temp_dir, "results.json")
+    with open(output_file, "w") as f:
+        f.write(invalid_predictions_json)
+    result = runner.invoke(
+        app,
+        [
+            "-p",
+            "dummy_pred.csv",
+            "-g",
+            groundtruth_dir,
+            "-t",
+            "1",
+            "-o",
+            output_file,
+        ],
+    )
+    assert result.exit_code == 0
+    assert result.stdout.strip() == "INVALID"
+    with open(output_file, "r") as f:
+        output_data = json.load(f)
+    assert output_data["score_status"] == "INVALID"
+    assert (
+        output_data["score_errors"]
+        == "Submission could not be evaluated due to validation errors."
+    )
+    mock_extract_gt_file.assert_not_called()
+    mock_score.assert_not_called()
+
+
+@patch("score.extract_gt_file")
+@patch("score.score")
+def test_main_no_prior_validations(
+    mock_score,
+    mock_extract_gt_file,
+    gt_file,
+    pred_file,
+    temp_dir,
+):
+    """Test: notice about no prior validation results should be given."""
+    mock_extract_gt_file.return_value = gt_file
+    groundtruth_dir = os.path.dirname(gt_file)
+    output_file = os.path.join(temp_dir, "dummy_results.json")
+    result = runner.invoke(
+        app,
+        [
+            "-p",
+            pred_file,
+            "-g",
+            groundtruth_dir,
+            "-t",
+            "1",
+            "-o",
+            output_file,
+        ],
+    )
+    assert result.exit_code == 0
+    assert result.stdout.strip() in {"SCORED", "INVALID"}
+    with open(output_file) as f:
+        output_data = json.load(f)
+    assert output_data["validation_status"] == ""
+    assert output_data["validation_errors"] == (
+        "Validation results not found. Proceeding with scoring but it "
+        "may fail or results may be inaccurate."
+    )
+    mock_extract_gt_file.assert_called_once_with(groundtruth_dir)
+    mock_score.assert_called_once_with(
+        task_number=1, gt_file=gt_file, pred_file=pred_file
+    )
+
+
+@patch("score.extract_gt_file")
+@patch("score.score")
+def test_main_valid_predictions_cannot_score(
+    mock_score,
+    mock_extract_gt_file,
+    valid_predictions_json,
+    gt_file,
+    pred_file,
+    temp_dir,
+):
+    """
+    Test: final results should be INVALID when predictions cannot be scored
+    (indicated by ValueError).
+    """
+    mock_extract_gt_file.return_value = gt_file
+    mock_score.side_effect = ValueError
+
+    groundtruth_dir = os.path.dirname(gt_file)
+    output_file = os.path.join(temp_dir, "results.json")
+    with open(output_file, "w") as f:
+        f.write(valid_predictions_json)
+    result = runner.invoke(
+        app,
+        [
+            "-p",
+            pred_file,
+            "-g",
+            groundtruth_dir,
+            "-t",
+            "1",
+            "-o",
+            output_file,
+        ],
+    )
+    assert result.exit_code == 0
+    assert result.stdout.strip() == "INVALID"
+    with open(output_file) as f:
+        output_data = json.load(f)
+    assert output_data["score_status"] == "INVALID"
+    assert (
+        output_data["score_errors"]
+        == "Error encountered during scoring; submission not evaluated."
+    )
+    mock_extract_gt_file.assert_called_once_with(groundtruth_dir)
+    mock_score.assert_called_once_with(
+        task_number=1, gt_file=gt_file, pred_file=pred_file
+    )
+
+
+@patch("score.extract_gt_file")
+@patch("score.score")
+def test_main_valid_predictions_can_score(
+    mock_score,
+    mock_extract_gt_file,
+    valid_predictions_json,
+    gt_file,
+    pred_file,
+    temp_dir,
+):
+    """
+    Test: final results should be SCORED for successful scoring.
+    """
+    mock_extract_gt_file.return_value = gt_file
+    groundtruth_dir = os.path.dirname(gt_file)
+    output_file = os.path.join(temp_dir, "results.json")
+    with open(output_file, "w") as f:
+        f.write(valid_predictions_json)
+    result = runner.invoke(
+        app,
+        [
+            "-p",
+            pred_file,
+            "-g",
+            groundtruth_dir,
+            "-t",
+            "1",
+            "-o",
+            output_file,
+        ],
+    )
+    assert result.exit_code == 0
+    assert result.stdout.strip() == "SCORED"
+    with open(output_file) as f:
+        output_data = json.load(f)
+    assert output_data["score_status"] == "SCORED"
+    assert output_data["score_errors"] == ""
+    mock_extract_gt_file.assert_called_once_with(groundtruth_dir)
+    mock_score.assert_called_once_with(
+        task_number=1, gt_file=gt_file, pred_file=pred_file
+    )
diff --git a/tests/test_validate.py b/tests/test_validate.py
new file mode 100644
index 0000000..7768b1a
--- /dev/null
+++ b/tests/test_validate.py
@@ -0,0 +1,166 @@
+"""Test for validation script.
+
+These tests are designed to test the general functionality for
+interacting with ORCA, NOT the actual validation logic itself.
+"""
+
+import json
+import os
+from unittest.mock import patch
+
+import typer
+from typer.testing import CliRunner
+
+from validate import main, validate
+
+app = typer.Typer()
+app.command()(main)
+
+runner = CliRunner()
+
+
+# ----- Tests for validate() function -----
+def test_validate_valid_task_number(gt_file, pred_file):
+    """
+    Test: validate() returns a list, filter, or tuple for valid
+    task number.
+    """
+    task_number = 1
+    errors = validate(
+        task_number=task_number,
+        gt_file=gt_file,
+        pred_file=pred_file,
+    )
+    assert isinstance(errors, (list, filter, tuple))
+
+
+def test_validate_invalid_task_number():
+    """Test: validate() notifies about invalid task number."""
+    task_number = 99999
+    errors = validate(
+        task_number=task_number,
+        gt_file="dummy_gt.csv",
+        pred_file="dummy_pred.csv",
+    )
+    assert f"Invalid challenge task number specified: `{task_number}`" in errors
+
+
+# ----- Tests for main() function -----
+@patch("validate.extract_gt_file")
+@patch("validate.validate")
+def test_main_valid_submission_type(
+    mock_validate, mock_extract_gt_file, gt_file, pred_file, temp_dir
+):
+    """Test: final results should be INVALID or VALIDATED."""
+    mock_extract_gt_file.return_value = gt_file
+    mock_validate.return_value = []
+    groundtruth_dir = os.path.dirname(gt_file)
+    output_file = os.path.join(temp_dir, "results.json")
+    result = runner.invoke(
+        app,
+        [
+            "-p",
+            pred_file,
+            "-g",
+            groundtruth_dir,
+            "-t",
+            "1",
+            "-o",
+            output_file,
+        ],
+    )
+    assert result.exit_code == 0
+
+    # Ensure that STDOUT is one of `VALIDATED` or `INVALID`
+    assert result.stdout.strip() in {"VALIDATED", "INVALID"}
+    with open(output_file) as f:
+        output_data = json.load(f)
+
+    # If status is `VALIDATED`, there should be no error messages.
+    if result.stdout.strip() == "VALIDATED":
+        assert output_data["validation_status"] == "VALIDATED"
+        assert output_data["validation_errors"] == ""
+
+    # Otherwise, there should be error message(s) for `INVALID` predictions.
+    else:
+        assert output_data["validation_status"] == "INVALID"
+        assert output_data["validation_errors"]
+
+    mock_extract_gt_file.assert_called_once_with(groundtruth_dir)
+    mock_validate.assert_called_once_with(
+        task_number=1, gt_file=gt_file, pred_file=pred_file
+    )
+
+
+@patch("validate.extract_gt_file")
+@patch("validate.validate")
+def test_main_invalid_submission_type(
+    mock_validate, mock_extract_gt_file, groundtruth_dir, temp_dir
+):
+    """Test: final results should be INVALID for incorrect submission type."""
+    invalid_file = os.path.join(temp_dir, "INVALID_predictions.txt")
+    with open(invalid_file, "w") as f:
+        f.write("foo")
+    output_file = os.path.join(temp_dir, "results.json")
+    result = runner.invoke(
+        app,
+        [
+            "-p",
+            invalid_file,
+            "-g",
+            groundtruth_dir,
+            "-t",
+            "1",
+            "-o",
+            output_file,
+        ],
+    )
+    assert result.exit_code == 0
+    assert result.stdout.strip() == "INVALID"
+    with open(output_file, "r") as f:
+        output_data = json.load(f)
+    assert output_data["validation_status"] == "INVALID"
+    mock_extract_gt_file.assert_not_called()
+    mock_validate.assert_not_called()
+
+
+@patch("validate.extract_gt_file")
+@patch("validate.validate")
+def test_main_long_error_message(
+    mock_validate, mock_extract_gt_file, gt_file, pred_file, temp_dir
+):
+    """Test: validation errors should never exceed 500 characters."""
+    mock_extract_gt_file.return_value = gt_file
+
+    # Create a dummy string longer than 500 characters.
+    long_error_message = "foo" * 500
+    mock_validate.return_value = [long_error_message]
+
+    groundtruth_dir = os.path.dirname(gt_file)
+    output_file = os.path.join(temp_dir, "results.json")
+
+    result = runner.invoke(
+        app,
+        [
+            "-p",
+            pred_file,
+            "-g",
+            groundtruth_dir,
+            "-t",
+            "1",
+            "-o",
+            output_file,
+        ],
+    )
+    assert result.exit_code == 0
+    assert result.stdout.strip() == "INVALID"
+    with open(output_file, "r") as f:
+        output_data = json.load(f)
+    assert output_data["validation_status"] == "INVALID"
+    assert len(output_data["validation_errors"]) < 500
+    assert output_data["validation_errors"].endswith("...")
+
+    mock_extract_gt_file.assert_called_once_with(groundtruth_dir)
+    mock_validate.assert_called_once_with(
+        task_number=1, gt_file=gt_file, pred_file=pred_file
+    )
diff --git a/validate.py b/validate.py
index 96c881a..4b24457 100644
--- a/validate.py
+++ b/validate.py
@@ -40,7 +40,7 @@
 }
 
 
-def validate_task1(gt_file: str, pred_file: str) -> list[str]:
+def validate_task1(gt_file: str, pred_file: str) -> list[str] | filter:
     """Sample validation function.
 
     Checks include:
@@ -101,7 +101,7 @@ def validate_task1(gt_file: str, pred_file: str) -> list[str]:
     # return []
 
 
-def validate(task_number: int, gt_file: str, pred_file: str) -> list[str]:
+def validate(task_number: int, gt_file: str, pred_file: str) -> list[str] | filter:
     """
     Routes validation to the appropriate task-specific function.
    """
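The widened `list[str] | filter` annotations above line up with the `isinstance(errors, (list, filter, tuple))` check in `tests/test_validate.py`: a task validator may return a lazy `filter` object instead of a plain list. Below is a minimal, self-contained sketch of that pattern, assuming illustrative checks and error messages rather than the repository's actual `validate_task1()` logic.

```python
# Minimal sketch of why the return type is `list[str] | filter`:
# building the error list with filter() yields a `filter` object,
# which test_validate.py accepts alongside list and tuple.
# The checks and messages here are illustrative assumptions only.
import pandas as pd


def validate_task1_sketch(gt_file: str, pred_file: str) -> list[str] | filter:
    gt = pd.read_csv(gt_file)
    pred = pd.read_csv(pred_file)

    # Each check yields "" when it passes, or an error message when it fails.
    checks = [
        "" if set(pred["id"]) == set(gt["id"])
        else "Prediction IDs do not match the ground truth IDs.",
        "" if pred["probability"].between(0, 1).all()
        else "Probabilities must be between 0 and 1.",
    ]
    # filter(None, ...) lazily drops empty strings and returns a `filter`.
    return filter(None, checks)
```

Calling `list(...)` on the result is equally valid; the union annotation simply allows either form.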