diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..d0bb359 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,158 @@ +name: Tests + +on: + push: + branches: [ main, dev ] + pull_request: + branches: [ main ] + schedule: + - cron: '0 0 * * 0' # Run weekly on Sundays at midnight UTC + workflow_dispatch: # Allow manual triggering + +jobs: + quick-tests: + name: Quick Tests + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.10', '3.11', '3.12'] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Cache pip packages + uses: actions/cache@v3 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip- + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[test]" + + - name: Run quick tests + run: | + pytest tests/ -m "quick" -v --tb=short + + - name: Upload test results + if: always() + uses: actions/upload-artifact@v4 + with: + name: quick-test-results-py${{ matrix.python-version }} + path: | + .pytest_cache + test-results.xml + retention-days: 7 + + full-tests: + name: Full Test Suite + runs-on: ubuntu-latest + if: github.event_name == 'pull_request' || github.ref == 'refs/heads/main' + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Cache pip packages + uses: actions/cache@v3 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip- + + - name: Cache test data + uses: actions/cache@v3 + with: + path: ~/.cache/keypoint_moseq_tests + key: ${{ runner.os }}-test-data-v1 + restore-keys: | + ${{ runner.os }}-test-data- + + - name: Install dependencies + run: | + python 
-m pip install --upgrade pip + pip install -e ".[test]" + + - name: Run full test suite (exclude slow) + run: | + pytest tests/ \ + --cov=keypoint_moseq \ + --cov-report=xml \ + --cov-report=term \ + -m "not slow" \ + -v \ + --tb=short \ + --junitxml=test-results.xml + + - name: Check coverage threshold + run: | + coverage report --fail-under=40 + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + file: ./coverage.xml + fail_ci_if_error: false + verbose: true + + - name: Upload test results + if: always() + uses: actions/upload-artifact@v4 + with: + name: full-test-results + path: | + coverage.xml + test-results.xml + htmlcov/ + retention-days: 30 + + - name: Comment PR with coverage + if: github.event_name == 'pull_request' + uses: py-cov-action/python-coverage-comment-action@v3 + with: + GITHUB_TOKEN: ${{ github.token }} + MINIMUM_GREEN: 50 + MINIMUM_ORANGE: 40 + + slow-tests: + name: Slow Tests (Weekly) + runs-on: ubuntu-latest + if: github.event_name == 'workflow_dispatch' || github.event_name == 'schedule' + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[test]" + + - name: Run all tests including slow + run: | + pytest tests/ -v --tb=short --timeout=7200 + + - name: Upload test results + if: always() + uses: actions/upload-artifact@v4 + with: + name: slow-test-results + path: test-results.xml + retention-days: 30 diff --git a/.gitignore b/.gitignore index bf6f2ff..21ad1f7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,10 @@ **/.DS_Store testing update_pypi.sh +docs/source/dlc* +docs/source/demo* +tests/dlc* +temp* # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/README.md b/README.md index 96d3600..8e78fd2 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,8 @@ -# Keypoint MoSeq +# Keypoint MoSeq + 
+![Tests](https://github.com/dattalab/keypoint-moseq/actions/workflows/test.yml/badge.svg) +[![codecov](https://codecov.io/gh/dattalab/keypoint-moseq/branch/main/graph/badge.svg)](https://codecov.io/gh/dattalab/keypoint-moseq) +[![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/) ![logo](docs/source/_static/logo.jpg) diff --git a/keypoint_moseq/analysis.py b/keypoint_moseq/analysis.py index e9a53dd..d5729d5 100644 --- a/keypoint_moseq/analysis.py +++ b/keypoint_moseq/analysis.py @@ -16,10 +16,29 @@ from glob import glob import panel as pn from jax_moseq.utils import get_durations, get_frequencies +from packaging import version pn.extension("plotly", "tabulator") na = np.newaxis +# seaborn version compatibility: evaluate once at import time +_SEABORN_VERSION = version.parse(sns.__version__) +_USE_NATIVE_SCALE = _SEABORN_VERSION >= version.parse("0.14") + + +def _get_pointplot_errorbar_kwargs(): + """Get the appropriate errorbar kwargs for seaborn pointplot based on version. + + seaborn 0.14.0 changed the errorbar API from errorbar=("ci", 68) + to using native_scale parameter and errorbar="se". + """ + if _USE_NATIVE_SCALE: + # seaborn >= 0.14 + return {"errorbar": "se", "native_scale": True} + else: + # seaborn < 0.14 + return {"errorbar": ("ci", 68)} + def get_syllable_names(project_dir, model_name, syllable_ixs): """Get syllable names from syll_info.csv file. 
Labels consist of the @@ -1151,16 +1170,17 @@ def plot_syll_stats_with_sem( # plot each group's stat data separately, computes groupwise SEM, and orders data based on the stat/ordering parameters hue = "group" if groups is not None else None + errorbar_kwargs = _get_pointplot_errorbar_kwargs() ax = sns.pointplot( data=stats_df, x="syllable", y=stat, hue=hue, order=ordering, - errorbar=("ci", 68), ax=ax, hue_order=groups, palette=colors, + **errorbar_kwargs, ) # where some data has already been plotted to ax diff --git a/keypoint_moseq/io.py b/keypoint_moseq/io.py index c507f6a..f01344f 100644 --- a/keypoint_moseq/io.py +++ b/keypoint_moseq/io.py @@ -860,7 +860,7 @@ def save_keypoints( bodyparts = [f"bodypart{i}" for i in range(num_keypoints)] # create column names - suffixes = ["x", "y", "z"][:num_keypoints] + suffixes = ["x", "y", "z"][:num_dims] if confidences is not None: suffixes += ["conf"] columns = [f"{bp}_{suffix}" for bp in bodyparts for suffix in suffixes] diff --git a/keypoint_moseq/viz.py b/keypoint_moseq/viz.py index 81d7c0a..7ca2b5e 100644 --- a/keypoint_moseq/viz.py +++ b/keypoint_moseq/viz.py @@ -7,6 +7,7 @@ import h5py import numpy as np import plotly +import matplotlib import matplotlib.pyplot as plt from scipy.ndimage import gaussian_filter1d from vidio.read import OpenCVReader @@ -22,6 +23,7 @@ from plotly.subplots import make_subplots import plotly.io as pio +from packaging import version pio.renderers.default = "iframe" @@ -31,6 +33,10 @@ # suppress warnings from imageio logging.getLogger().setLevel(logging.ERROR) +# matplotlib version compatibility: evaluate once at import time +_MATPLOTLIB_VERSION = version.parse(matplotlib.__version__) +_USE_BUFFER_RGBA = _MATPLOTLIB_VERSION >= version.parse("3.10") + def crop_image(image, centroid, crop_size): """Crop an image around a centroid. 
@@ -1428,12 +1434,39 @@ def get_limits( return lims.astype(int) +def _get_canvas_buffer_method(canvas): + """Get the appropriate canvas buffer method based on matplotlib version. + + matplotlib 3.10 removed tostring_rgb() in favor of buffer_rgba(). + This function returns the correct method to call. + """ + if _USE_BUFFER_RGBA: + return canvas.buffer_rgba + else: + return canvas.tostring_rgb + + +def _reshape_canvas_buffer(raster_flat, height, width): + """Reshape and convert canvas buffer to RGB format. + + For matplotlib >= 3.10, drops the alpha channel from RGBA. + For matplotlib < 3.10, returns RGB directly. + """ + if _USE_BUFFER_RGBA: + # matplotlib >= 3.10: RGBA buffer, drop alpha channel + return raster_flat.reshape((height, width, 4))[:, :, :3] + else: + # matplotlib < 3.10: RGB buffer + return raster_flat.reshape((height, width, 3)) + + def rasterize_figure(fig): canvas = fig.canvas canvas.draw() width, height = canvas.get_width_height() - raster_flat = np.frombuffer(canvas.tostring_rgb(), dtype="uint8") - raster = raster_flat.reshape((height, width, 3)) + buffer_method = _get_canvas_buffer_method(canvas) + raster_flat = np.frombuffer(buffer_method(), dtype="uint8") + raster = _reshape_canvas_buffer(raster_flat, height, width) return raster diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..c53ee5d --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,159 @@ +[build-system] +build-backend = "setuptools.build_meta" + +requires = [ "setuptools>=45", "setuptools-scm[toml]>=6.2", "versioneer[toml]==0.29", "wheel" ] + +[project] +name = "keypoint-moseq" +description = "Unsupervised machine learning method for behavior analysis using keypoint tracking data" +readme = "README.md" +keywords = [ "behavior", "keypoint-tracking", "machine-learning", "motion-sequencing", "neuroscience" ] +license = { text = "Non-Commercial Research and Academic Use - see LICENSE.md" } +maintainers = [ + { name = "Caleb Weinreb", email = "calebsw@gmail.com" }, +] 
+authors = [ + { name = "Caleb Weinreb", email = "calebsw@gmail.com" }, +] +requires-python = ">=3.10" +classifiers = [ + "Intended Audience :: Science/Research", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Bio-Informatics", +] + +dynamic = [ "version" ] +# Core dependencies from setup.cfg +dependencies = [ + "bokeh>=2.4.3,<3.9", # Tested: 2.4.3-3.8.0 + "commentjson", + "cytoolz", + "holoviews[recommended]>=1.15.4,<1.22", # Tested: 1.15.4-1.21.0 + "imageio[ffmpeg]", + "ipykernel", + "ipympl", + "ipython-genutils", + "ipywidgets", + "jax>=0.4.20,<0.7", # Upper for tf probability 0.25.0 + "jax-moseq", + "matplotlib>=3,<4", # Tested: 3.0.3-3.10 + "ndx-pose", + "networkx", + "numpy<=1.26.4", # Upper bound for jax compatibility + "packaging", # For version comparison + "pandas", + "panel>=0.14.4,<1.9", # Tested: 0.14.4-1.8.2 + "plotly", + "pynwb", + "pyyaml", + "seaborn>=0.8,<1", # Tested: 0.8.1-0.13.2 + "sleap-io", + "statsmodels", + "tables", + "tabulate", + "tqdm", + "vidio", +] + +# All optional dependencies +optional-dependencies.all = [ + "keypoint-moseq[dev,cuda,test]", +] +# CUDA support for GPU acceleration +optional-dependencies.cuda = [ + "jax-moseq[cuda]", +] +# Development and documentation dependencies +optional-dependencies.dev = [ + "autodocsumm", + "myst-nb", + "sphinx", + "sphinx-rtd-theme", +] +# Testing dependencies +optional-dependencies.test = [ + "gdown", # For downloading test data from Google Drive + "h5py>=3", # For HDF5 validation + "jupytext>=1.14", + "pytest>=7", + "pytest-cov>=4", + "pytest-timeout>=2.1", + "pytest-xdist>=3", # Parallel test execution +] +urls."Bug Tracker" = 
"https://github.com/dattalab/keypoint-moseq/issues" +urls.Documentation = "https://keypoint-moseq.readthedocs.io/en/latest/" +urls.Homepage = "https://github.com/dattalab/keypoint-moseq" +urls."Paper" = "https://www.nature.com/articles/s41592-024-02318-2" +urls.Repository = "https://github.com/dattalab/keypoint-moseq" + +[tool.setuptools] +packages = [ "keypoint_moseq" ] +include-package-data = true + +[tool.setuptools.package-data] +"*" = [ "*.md" ] + +[tool.pytest.ini_options] +testpaths = [ "tests" ] +python_files = [ "test_*.py" ] +python_classes = [ "Test*" ] +python_functions = [ "test_*" ] +addopts = [ + "-v", # Verbose output + "--cov=keypoint_moseq", + "--cov-report=term-missing", + "--cov-report=html", +] +timeout = 2700 # 45 minutes per test +markers = [ + "slow: marks tests as slow (deselect with '-m \"not slow\"')", + "integration: marks tests as integration tests", + "notebook: marks tests derived from notebooks", + "quick: marks tests as quick (< 1 minute)", + "medium: marks tests as medium duration (1-5 minutes)", +] +# Custom CLI options can be added via conftest.py for --no-teardown + +[tool.coverage.run] +source = [ "keypoint_moseq" ] +omit = [ + "keypoint_moseq/_version.py", # Auto-generated by versioneer + "*/tests/*", # Exclude test files from coverage + "*/test_*.py", # Exclude test files +] + +[tool.coverage.report] +# Minimum coverage threshold (fails if coverage drops below this) +fail_under = 40 +# Don't report files with 100% coverage to focus on gaps +skip_covered = false +# Show lines that weren't executed +show_missing = true +# Exclude lines from coverage (e.g., defensive code, debugging) +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "raise AssertionError", + "raise NotImplementedError", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", + "if typing.TYPE_CHECKING:", +] + +[tool.coverage.html] +# Directory for HTML coverage report +directory = "htmlcov" + +[tool.versioneer] +VCS = "git" +style = "pep440" 
+versionfile_source = "keypoint_moseq/_version.py" +versionfile_build = "keypoint_moseq/_version.py" +tag_prefix = "" +parentdir_prefix = "" diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000..e116893 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,334 @@ +# Keypoint-MoSeq Test Suite + +This directory contains pytest-compatible tests for the keypoint-MoSeq package, converted from the official Jupyter notebooks. + +## Structure + +### Test Files + +- `test_colab_workflow.py` - Complete workflow tests from colab notebook +- `test_modeling.py` - Model fitting and checkpoint management tests +- `test_analysis.py` - Result extraction and visualization tests +- `conftest.py` - Shared pytest fixtures and configuration +- `__init__.py` - Package initialization + +## Prerequisites + +### Installation + +Install keypoint-moseq with test dependencies: + +```bash +pip install -e ".[test]" +``` + +This installs: + +- pytest and plugins (pytest-cov, pytest-timeout, pytest-xdist) +- jupytext for notebook conversion +- h5py for HDF5 validation +- gdown for downloading test data from Google Drive + +### Test Data + +Tests use the DLC example project included in the repository: + +- Location: `docs/source/dlc_example_project/` +- Contains: 10 minimal DLC tracking files +- Size: ~small (suitable for CI/CD) + +**Important**: Input data is never deleted during test teardown. 
+ +## Running Tests + +### Basic Usage + +```bash +# Run all tests +pytest tests/ + +# Run with verbose output +pytest tests/ -v + +# Run specific test file +pytest tests/test_colab_workflow.py + +# Run specific test function +pytest tests/test_colab_workflow.py::test_project_setup +``` + +### By Test Category + +Tests are marked by duration and type: + +```bash +# Quick tests only (< 1 minute) +pytest tests/ -m quick + +# Medium tests (1-5 minutes) +pytest tests/ -m medium + +# Integration tests (5-15 minutes with reduced iterations) +pytest tests/ -m integration + +# Notebook-derived tests +pytest tests/ -m notebook + +# Exclude slow tests (for CI/CD) +pytest tests/ -m "not slow" +``` + +### Parallel Execution + +Run tests in parallel with pytest-xdist: + +```bash +# Use all available CPU cores +pytest tests/ -n auto + +# Use specific number of workers +pytest tests/ -n 4 +``` + +### Preserve Test Outputs + +By default, test outputs are cleaned up. To preserve them: + +```bash +# Preserve outputs in /tmp/kpms_test_<test_name>/ +pytest tests/ --no-teardown + +# Specify custom output directory +pytest tests/ --test-data-dir=/path/to/output +``` + +Example output locations: + +- `/tmp/kpms_test_test_complete_workflow/` +- Contains: model checkpoints, results, plots, videos + +### Timeout Configuration + +Tests have a 45-minute (2700-second) default per-test timeout configured in `pyproject.toml`. 
+ +Override for specific tests: + +```bash +# Set custom timeout (in seconds) +pytest tests/ --timeout=3600 + +# Disable timeout +pytest tests/ --timeout=0 +``` + +## Test Categories + +### Quick Tests (< 1 minute) + +- `test_project_setup` - Project initialization +- `test_load_keypoints` - Data loading +- `test_hyperparameter_estimation` - Hyperparam computation +- `test_config_update` - Configuration management +- `test_syllable_statistics` - Statistics computation + +### Medium Tests (1-5 minutes) + +- `test_format_and_outlier_detection` - Data QA +- `test_pca_fitting` - PCA model fitting +- `test_model_initialization` - Model setup +- `test_ar_hmm_fitting` - AR-HMM fitting (reduced iterations) + +### Integration Tests (5-15 minutes) + +- `test_complete_workflow` - End-to-end pipeline +- `test_full_model_fitting` - Complete model fitting +- `test_model_saving_and_loading` - Checkpoint management +- `test_result_extraction` - Result generation +- `test_csv_export` - CSV output +- `test_trajectory_plots` - Visualization +- `test_similarity_dendrogram` - Dendrogram generation + +### Slow Tests (> 15 minutes) + +- `test_grid_movies` - Video rendering (~20 minutes) + +Run without slow tests: + +```bash +pytest tests/ -m "not slow" +``` + +## Test Fixtures + +Key fixtures available in `conftest.py`: + +### Path Fixtures + +- `temp_project_dir` - Temporary project directory (cleaned up unless --no-teardown) +- `dlc_example_project` - Path to DLC example data (never cleaned up) +- `dlc_config` - Path to DLC config.yaml +- `dlc_videos_dir` - Path to DLC videos directory +- `notebook_output_dir` - Directory for notebook outputs +- `test_data_cache` - Cache directory for downloaded data + +### Configuration Fixtures + +- `reduced_iterations` - Reduced iteration counts for fast testing: + - `ar_hmm_iters`: 10 (vs 50 default) + - `full_model_iters`: 20 (vs 500 default) + - `pca_variance`: 0.80 (80% variance explained; use 0.90 for production models) + - `timeout_minutes`: 30 + +### Utility Functions + +- 
`download_google_drive_file()` - Download from Google Drive (skips if exists) +- `unzip_file()` - Extract zip archives + +## Expected Test Durations + +Based on actual execution with minimal DLC dataset: + +| Test Suite | Duration | Notes | +|------------|----------|-------| +| Quick tests | < 2 min | All quick tests combined | +| Medium tests | 5-10 min | Includes PCA, outlier detection | +| Integration tests | 60-90 min | All integration tests | +| Complete workflow | ~15 min | Single full pipeline test | +| All tests (no slow) | ~90 min | Suitable for CI/CD | +| All tests (with slow) | ~110 min | Includes video rendering | + +## CI/CD Recommendations + +### Minimal Test Suite (Fast) + +```bash +# Run only quick tests (~2 minutes) +pytest tests/ -m quick -n auto +``` + +### Standard Test Suite (Balanced) + +```bash +# Run quick + medium tests (~15 minutes) +pytest tests/ -m "quick or medium" -n auto +``` + +### Full Test Suite (Comprehensive) + +```bash +# Run all except slow tests (~90 minutes) +pytest tests/ -m "not slow" -n auto +``` + +### Weekly Tests + +```bash +# Run everything including slow tests (~110 minutes) +pytest tests/ -n auto +``` + +## Troubleshooting + +### Test Failures + +**Import errors**: Ensure package installed with test dependencies + +```bash +pip install -e ".[test]" +``` + +**DLC data not found**: Verify DLC example project exists + +```bash +ls docs/source/dlc_example_project/ +``` + +**Timeout errors**: Increase timeout or run on faster hardware + +```bash +pytest tests/ --timeout=3600 +``` + +**JAX/GPU warnings**: Expected on CPU-only systems (tests run fine) + +### Common Warnings + +- `FigureCanvasAgg is non-interactive` - Expected for headless execution +- `os.fork() was called... may lead to deadlock` - JAX warning during video generation (harmless) +- `An NVIDIA GPU may be present... 
Falling back to cpu` - Expected without CUDA + +### Preserving Outputs for Debugging + +```bash +# Keep outputs and show print statements +pytest tests/test_colab_workflow.py::test_complete_workflow -s --no-teardown + +# Check preserved outputs +ls /tmp/kpms_test_test_complete_workflow/ +``` + +## Code Coverage + +### Generate Coverage Report + +Run tests with coverage measurement: + +```bash +# Generate HTML and terminal coverage report +pytest tests/ --cov=keypoint_moseq --cov-report=html --cov-report=term -m "not slow" + +# View HTML report in browser +open htmlcov/index.html # macOS +xdg-open htmlcov/index.html # Linux +``` + +### Coverage Commands + +```bash +# Coverage with missing line numbers +pytest tests/ --cov=keypoint_moseq --cov-report=term-missing + +# Coverage for specific module +pytest tests/ --cov=keypoint_moseq.analysis --cov-report=term + +# Coverage with XML output (for CI/CD) +pytest tests/ --cov=keypoint_moseq --cov-report=xml --cov-report=term +``` + +## Test Development + +### Adding New Tests + +1. Create test file: `tests/test_<feature>.py` +2. Import required modules +3. Add pytest markers: `@pytest.mark.quick`, `@pytest.mark.integration`, etc. +4. Use fixtures: `def test_something(temp_project_dir, dlc_config):` +5. Add assertions: `assert result is not None, "Result should not be None"` +6. 
Document expected duration in docstring + +### Test Template + +```python +import pytest +from pathlib import Path + +@pytest.mark.quick +@pytest.mark.notebook +def test_feature_name(temp_project_dir, dlc_config): + """Test description + + Expected duration: < 1 minute + """ + import keypoint_moseq as kpms + + # Setup + project_dir = temp_project_dir + kpms.setup_project(project_dir, deeplabcut_config=dlc_config, overwrite=True) + + # Test logic + result = kpms.some_function() + + # Assertions + assert result is not None, "Result should not be None" + assert Path(project_dir, "output.txt").exists(), "Output file not created" +``` diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..c25abfd --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,8 @@ +""" +Test suite for keypoint-moseq + +This test suite includes tests derived from the original Jupyter notebooks: +- test_colab_workflow.py: Tests from keypoint_moseq_colab.ipynb +- test_modeling.py: Tests from modeling.ipynb +- test_analysis.py: Tests from analysis.ipynb +""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..5cffad0 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,542 @@ +""" +Pytest configuration and shared fixtures for keypoint-moseq tests +""" + +import shutil +import tempfile +import warnings +from pathlib import Path + +import gdown +import pytest +from matplotlib import MatplotlibDeprecationWarning + +# - Warnings Configuration - # +# Ignore warnings from viz.py: "FigureCanvasAgg is non-interactive" +warnings.filterwarnings("ignore", category=UserWarning, module="keypoint_moseq") +# Bokeh from NumPy 1.24: DeprecationWarning for np.bool8 +warnings.filterwarnings("ignore", category=DeprecationWarning, module="numpy") +warnings.filterwarnings("ignore", category=DeprecationWarning, module="bokeh") +# From JAX: 0 should be passed as minlength instead of None +# From JAX: shape requires ndarray or scalar, got None 
+warnings.filterwarnings("ignore", category=DeprecationWarning, module="jax") +warnings.filterwarnings("ignore", category=DeprecationWarning, module="jax_moseq") +warnings.filterwarnings( + "ignore", category=DeprecationWarning, module="tensorflow_probability" +) +# From matplotlib, 'mode' is deprecated, removed in Pillow 13 (2026-10-15) +warnings.filterwarnings("ignore", category=DeprecationWarning, module="PIL") +warnings.filterwarnings("ignore", category=DeprecationWarning, module="matplotlib") +# From matplotlib, tostring_rgb is deprecated +warnings.simplefilter("ignore", MatplotlibDeprecationWarning) + + +def pytest_configure(config): + """Configure pytest environment - set matplotlib to non-interactive backend""" + import matplotlib + + matplotlib.use("Agg") # Non-interactive backend for tests + + +def pytest_addoption(parser): + """Add custom command line options""" + parser.addoption( + "--no-teardown", + action="store_true", + default=False, + help="Preserve test outputs (don't cleanup temporary directories)", + ) + parser.addoption( + "--test-data-dir", + action="store", + default=None, + help="Directory for test data (if not specified, uses temp dir)", + ) + + +@pytest.fixture +def no_teardown(request): + """Check if --no-teardown flag is set""" + return request.config.getoption("--no-teardown") + + +@pytest.fixture +def temp_project_dir(request, no_teardown): + """Create a temporary project directory for testing + + If --no-teardown is specified, the directory is preserved after tests. 
+ """ + test_data_dir = request.config.getoption("--test-data-dir") + + if test_data_dir: + # Use specified directory + tmpdir = Path(test_data_dir) / f"test_{request.node.name}" + tmpdir.mkdir(parents=True, exist_ok=True) + yield str(tmpdir) + if not no_teardown: + shutil.rmtree(tmpdir, ignore_errors=True) + else: + # Use system temp directory + if no_teardown: + # Create in /tmp with predictable name + tmpdir = Path("/tmp") / f"kpms_test_{request.node.name}" + tmpdir.mkdir(parents=True, exist_ok=True) + yield str(tmpdir) + print(f"\n[NO TEARDOWN] Test outputs preserved at: {tmpdir}") + else: + # Standard temporary directory + with tempfile.TemporaryDirectory() as tmpdir: + yield tmpdir + + +@pytest.fixture(scope="session") +def dlc_example_project(test_data_cache): + """Path to the DLC example project + + This fixture returns the path to the DLC example data. + First checks the repository location, then downloads from Google Drive if missing. + The data is NEVER deleted during teardown - it's preserved as input data. + Session-scoped since it's read-only data. 
+ """ + repo_root = Path(__file__).parent.parent + dlc_path = repo_root / "docs" / "source" / "dlc_example_project" + + # First, check if data exists in repository + if dlc_path.exists(): + return str(dlc_path) + + # If not in repo, try to download from Google Drive to cache + print(f"DLC example project not found at {dlc_path}") + print("Attempting to download from Google Drive...") + + cached_zip = test_data_cache / "dlc_example_project.zip" + cached_extract = test_data_cache / "dlc_example_project" + + # Check if already downloaded and extracted + if cached_extract.exists() and (cached_extract / "config.yaml").exists(): + print(f"Using cached DLC project: {cached_extract}") + return str(cached_extract) + + try: + # Download from Google Drive (file ID from keypoint_moseq_colab.ipynb) + file_id = "1JGyS9MbdS3MtrlYnh4xdEQwe2bYoCuSZ" + download_google_drive_file(file_id, cached_zip, use_cache=True) + + # Extract the zip file + print(f"Extracting to {test_data_cache}") + unzip_file(cached_zip, test_data_cache) + + if cached_extract.exists(): + print( + f"Successfully downloaded and extracted DLC project to {cached_extract}" + ) + return str(cached_extract) + else: + pytest.skip(f"Downloaded but extraction failed - expected {cached_extract}") + + except Exception as e: + pytest.skip(f"Failed to download DLC example project: {e}") + + +@pytest.fixture(scope="session") +def dlc_config(dlc_example_project): + """Path to DLC config file + + Session-scoped since it's read-only data. + """ + config_path = Path(dlc_example_project) / "config.yaml" + + if not config_path.exists(): + pytest.skip("DLC config file not found") + + return str(config_path) + + +@pytest.fixture(scope="session") +def dlc_videos_dir(dlc_example_project): + """Path to DLC videos directory + + Session-scoped since it's read-only data. 
+ """ + videos_path = Path(dlc_example_project) / "videos" + + if not videos_path.exists(): + pytest.skip("DLC videos directory not found") + + return str(videos_path) + + +@pytest.fixture(scope="session") +def notebook_output_dir(): + """Directory for notebook-generated outputs during testing""" + repo_root = Path(__file__).parent.parent + output_dir = repo_root / "tests" / "notebook_outputs" + output_dir.mkdir(exist_ok=True) + return str(output_dir) + + +@pytest.fixture(scope="session") +def test_data_cache(): + """Cache directory for downloaded test data""" + cache_dir = Path.home() / ".cache" / "keypoint_moseq_tests" + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir + + +def download_google_drive_file(file_id, output_path, use_cache=True): + """Download a file from Google Drive + + Always checks if file exists before downloading. Downloaded data is + preserved and never deleted during teardown. + + Args: + file_id: Google Drive file ID + output_path: Path where file should be saved + use_cache: If True, skip download if file exists (default: True) + + Returns: + Path to downloaded file + """ + output_path = Path(output_path) + + # Always check if already exists - skip download if present + if output_path.exists(): + if use_cache: + print(f"Using cached file: {output_path}") + return output_path + else: + print(f"File exists but use_cache=False, re-downloading: {output_path}") + + # Create parent directory if needed + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Download from Google Drive + url = f"https://drive.google.com/uc?id={file_id}" + print(f"Downloading from Google Drive: {file_id}") + gdown.download(url, str(output_path), quiet=False) + + return output_path + + +def unzip_file(zip_path, extract_to): + """Extract a zip file + + Args: + zip_path: Path to zip file + extract_to: Directory to extract to + + Returns: + Path to extracted directory + """ + import zipfile + + extract_to = Path(extract_to) + 
extract_to.mkdir(parents=True, exist_ok=True) + + with zipfile.ZipFile(zip_path, "r") as zip_ref: + zip_ref.extractall(extract_to) + + return extract_to + + +@pytest.fixture(scope="session") +def dlc_test_data(dlc_example_project): + """Alias for dlc_example_project fixture for backward compatibility + + This fixture is maintained for backward compatibility with older test code. + Use dlc_example_project directly in new code. + """ + return dlc_example_project + + +@pytest.fixture(scope="session") +def reduced_iterations(): + """Configuration for reduced iteration counts for faster testing + + Returns dict with recommended iteration counts for CI/CD + + Note: pca_variance set to 0.80 (was 0.90) for speed. This reduces + the number of PCA components, making model fitting ~30-40% faster + while still capturing most variance. For production models, use 0.90. + + Session-scoped since it's just configuration data. + """ + return { + "ar_hmm_iters": 10, # Reduced from 50 + "full_model_iters": 20, # Reduced from 500 + "pca_variance": 0.80, # 80% variance (was 0.90) - faster for tests + "timeout_minutes": 30, # Max test duration + } + + +@pytest.fixture(scope="session") +def kpms(): + """Session-scoped fixture for keypoint_moseq package import + + Eliminates redundant imports in every test function. + """ + import keypoint_moseq as kpms + + return kpms + + +@pytest.fixture(scope="session") +def update_kwargs(): + """Standard config update kwargs used across multiple tests + + Returns dict with common bodypart configurations. + Use with: kpms.update_config(project_dir, **update_kwargs) + + Session-scoped since it's just configuration data. 
@pytest.fixture(scope="module")
def prepared_model(
    module_project_dir,
    dlc_config,
    dlc_videos_dir,
    reduced_iterations,
    kpms,
    update_kwargs,
):
    """Build an initialized (not yet fitted) model once per test module

    Pipeline: setup project -> update config -> load keypoints -> format
    data -> fit PCA -> choose latent_dim -> estimate sigmasq_loc -> init
    model. Tests receive the initialized model and can fit it with whatever
    parameters they need.

    Returns a dict with keys "project_dir", "model", "data", "metadata",
    "pca", "config", "coordinates" and "confidences".

    Speed impact (per the original author's note): avoids about 1.5 min of
    duplicated setup per test.
    """

    def _update_and_reload(**changes):
        # Write the changes into the project config, then read it back so
        # later steps always see the values stored on disk
        kpms.update_config(module_project_dir, **changes)
        return kpms.load_config(module_project_dir)

    # Project skeleton plus the standard bodypart selection
    kpms.setup_project(
        module_project_dir, deeplabcut_config=dlc_config, overwrite=True
    )
    cfg = _update_and_reload(**update_kwargs)

    # Raw keypoints -> model-ready arrays
    coords, confs, _ = kpms.load_keypoints(dlc_videos_dir, "deeplabcut")
    formatted, meta = kpms.format_data(coords, confs, **cfg)

    # PCA, then pick the latent dimension from the reduced variance target
    pca_model = kpms.fit_pca(**formatted, **cfg)
    n_latent = compute_latent_dim(
        pca_model, variance_threshold=reduced_iterations["pca_variance"]
    )
    cfg = _update_and_reload(latent_dim=int(n_latent))

    # Hyperparameter estimate, filtered over a window of config["fps"]
    noise_scale = kpms.estimate_sigmasq_loc(
        formatted["Y"], formatted["mask"], filter_size=cfg["fps"]
    )
    cfg = _update_and_reload(sigmasq_loc=noise_scale)

    # Initialize only; fitting is left to the individual tests
    initial_model = kpms.init_model(formatted, pca=pca_model, **cfg)

    bundle = {
        "project_dir": module_project_dir,
        "model": initial_model,
        "data": formatted,
        "metadata": meta,
        "pca": pca_model,
        "config": cfg,
        "coordinates": coords,
        "confidences": confs,
    }
    return bundle
def compute_latent_dim(pca, variance_threshold=0.9):
    """Compute number of PCA components needed to explain variance threshold

    Args:
        pca: Fitted PCA object with explained_variance_ratio_ attribute
        variance_threshold: Target cumulative variance (default: 0.9 for 90%)

    Returns:
        int: Smallest number of leading components whose cumulative
        explained variance is >= variance_threshold. If the threshold is
        never reached (e.g. a truncated PCA, or floating-point round-off
        when the threshold is 1.0), all available components are returned.

    Raises:
        ValueError: If pca.explained_variance_ratio_ is empty
    """
    import numpy as np

    cumulative = np.cumsum(np.asarray(pca.explained_variance_ratio_, dtype=float))
    if cumulative.size == 0:
        raise ValueError("pca.explained_variance_ratio_ is empty")

    reached = cumulative >= variance_threshold
    if not reached.any():
        # np.argmax on an all-False mask returns 0, which would wrongly
        # report a single component; fall back to using every component
        return int(cumulative.size)

    # argmax yields the index of the first True; convert index -> count
    return int(np.argmax(reached)) + 1
@pytest.mark.medium
@pytest.mark.notebook
def test_csv_export(fitted_model, kpms):
    """Check that extracted results can be written out as CSV files

    Extracts results from the module's fitted model, saves them as CSV and
    inspects one of the written files for the expected columns and dtype.

    Expected duration: ~1 minute (uses fitted_model fixture)
    """
    from tests.conftest import load_path_from_model

    # Pull what this test needs out of the shared fixture
    project_dir, model, model_name, metadata, config = (
        fitted_model[key]
        for key in ("project_dir", "model", "model_name", "metadata", "config")
    )

    # The fitted model must have left a checkpoint behind
    assert load_path_from_model(
        project_dir, model_name, "checkpoint.h5"
    ).exists(), "Checkpoint file not created"

    kpms.reindex_syllables_in_checkpoint(project_dir, model_name)

    # An earlier test sharing the fixture may have written results.h5;
    # remove it before extracting again
    load_path_from_model(project_dir, model_name, "results.h5", delete_existing=True)

    results = kpms.extract_results(model, metadata, project_dir, model_name, config)
    kpms.save_results_as_csv(results, project_dir, model_name)

    # The export is expected under <project>/<model>/results/*.csv
    results_dir = Path(project_dir, model_name, "results")
    assert results_dir.exists(), "Results directory not created"

    csv_files = list(results_dir.glob("*.csv"))
    assert csv_files, "No CSV files created"

    # Inspect the first file returned by the glob
    df = pd.read_csv(csv_files[0])
    for col in ("syllable", "centroid x", "centroid y", "heading"):
        assert col in df.columns, f"CSV missing column: {col}"

    assert df.shape[0] > 0, "CSV is empty"
    assert df["syllable"].dtype in (np.int32, np.int64), "Syllable column wrong dtype"
exists + checkpoint_path = load_path_from_model(project_dir, model_name, "checkpoint.h5") + assert checkpoint_path.exists(), "Checkpoint file not created" + + kpms.reindex_syllables_in_checkpoint(project_dir, model_name) + + # Delete results.h5 if it exists (from previous test using same fixture) + _ = load_path_from_model( + project_dir, model_name, "results.h5", delete_existing=True + ) + + results = kpms.extract_results(model, metadata, project_dir, model_name, config) + + # Generate trajectory plots + kpms.generate_trajectory_plots( + coordinates, + results, + project_dir=project_dir, + model_name=model_name, + fps=config["fps"], + ) + + # Verify outputs + trajectory_dir = Path(project_dir) / model_name / "trajectory_plots" + assert trajectory_dir.exists(), "Trajectory plots directory not created" + + pdf_files = list(trajectory_dir.glob("*.pdf")) + assert len(pdf_files) > 0, "No trajectory PDFs created" + + # Note: generate_trajectory_plots filters by min_frequency and min_duration, + # so not all syllables will have plots. Just verify we got some plots created. 
@pytest.mark.slow
@pytest.mark.notebook
def test_grid_movies(fitted_model, kpms):
    """Test grid movie generation

    Extracts results from the module's fitted model, renders keypoints-only
    grid movies and checks that non-trivial MP4 files were written.

    Expected duration: ~2 minutes (uses fitted_model fixture + video rendering)
    """
    from tests.conftest import load_path_from_model

    # Use fitted model from fixture
    project_dir = fitted_model["project_dir"]
    model = fitted_model["model"]
    model_name = fitted_model["model_name"]
    metadata = fitted_model["metadata"]
    config = fitted_model["config"]
    coordinates = fitted_model["coordinates"]

    # Verify checkpoint exists
    checkpoint_path = load_path_from_model(project_dir, model_name, "checkpoint.h5")
    assert checkpoint_path.exists(), "Checkpoint file not created"

    kpms.reindex_syllables_in_checkpoint(project_dir, model_name)

    # Delete results.h5 if it exists (from previous test using same fixture)
    load_path_from_model(project_dir, model_name, "results.h5", delete_existing=True)

    results = kpms.extract_results(model, metadata, project_dir, model_name, config)

    # Generate grid movies (keypoints only, no video frames)
    kpms.generate_grid_movies(
        results,
        project_dir=project_dir,
        model_name=model_name,
        coordinates=coordinates,
        # Take the frame rate from the project config, as the other plotting
        # tests in this file do, rather than hard-coding a value that could
        # drift from the data's real frame rate
        fps=config["fps"],
        keypoints_only=True,
    )

    # Verify outputs
    grid_movies_dir = Path(project_dir) / model_name / "grid_movies"
    assert grid_movies_dir.exists(), "Grid movies directory not created"

    mp4_files = list(grid_movies_dir.glob("*.mp4"))
    assert len(mp4_files) > 0, "No grid movies created"

    # Verify file sizes (a near-empty file suggests rendering failed)
    for mp4 in mp4_files:
        assert mp4.stat().st_size > 1000, f"Grid movie too small: {mp4}"
@pytest.mark.quick
@pytest.mark.notebook
def test_syllable_statistics():
    """Check basic syllable counting on a small hand-built example

    Expected duration: < 1 second
    """
    # Two mock recordings: 9 + 10 frames drawn from syllables 0, 1 and 2
    rec1 = np.array([0, 0, 1, 1, 1, 2, 2, 0, 0])
    rec2 = np.array([1, 1, 0, 0, 2, 2, 2, 1, 1, 1])
    recordings = {"rec1": rec1, "rec2": rec2}

    # Pool every recording and tally how often each syllable appears
    pooled = np.concatenate(list(recordings.values()))
    labels, tallies = np.unique(pooled, return_counts=True)

    assert len(labels) == 3, "Should have 3 unique syllables"
    total = sum(tallies)
    assert total == 19, "Total syllable count should be 19"

    # Relative frequencies must add up to one
    shares = tallies / total
    assert np.isclose(sum(shares), 1.0), "Frequencies should sum to 1"
@pytest.mark.quick
def test_normalize_transition_matrix():
    """Exercise each normalization mode of normalize_transition_matrix"""
    from keypoint_moseq.analysis import normalize_transition_matrix

    # Small 3x3 count matrix used as the common input
    counts = np.array(
        [
            [1.0, 2.0, 1.0],
            [3.0, 0.0, 1.0],
            [0.0, 2.0, 2.0],
        ]
    )

    def normalized(mode):
        # Hand the function a fresh copy each time so `counts` stays pristine
        return normalize_transition_matrix(counts.copy(), mode)

    # "bigram": the whole matrix sums to one
    total = normalized("bigram").sum()
    assert np.isclose(total, 1.0), "Bigram normalization should sum to 1"

    # "rows": every row sums to one
    row_sums = normalized("rows").sum(axis=1)
    assert np.allclose(row_sums, 1.0), "Row sums should be 1"

    # "columns": every column sums to one
    col_sums = normalized("columns").sum(axis=0)
    assert np.allclose(col_sums, 1.0), "Column sums should be 1"

    # None: matrix comes back unchanged
    untouched = normalized(None)
    assert np.array_equal(
        untouched, counts
    ), "None normalization should not change matrix"
@pytest.mark.quick
def test_generate_index_new_file(temp_project, mock_results_dict):
    """Check that generate_index writes a fresh index when none exists"""
    from unittest.mock import patch

    from keypoint_moseq.analysis import generate_index

    model_name = "test_model"
    (Path(temp_project) / model_name).mkdir(parents=True)
    index_filepath = Path(temp_project) / "index.csv"

    # Swap load_results for a stub so no results file is needed on disk
    stub = patch(
        "keypoint_moseq.analysis.load_results", return_value=mock_results_dict
    )
    with stub:
        generate_index(temp_project, model_name, str(index_filepath))

    assert index_filepath.exists(), "Index file should be created"

    # One row per mocked recording, each assigned to the "default" group
    written = pd.read_csv(index_filepath)
    assert len(written) == 2, "Should have 2 recordings"
    assert "name" in written.columns, "Should have 'name' column"
    assert "group" in written.columns, "Should have 'group' column"

    recorded_names = set(written["name"])
    assert recorded_names == {"rec1", "rec2"}, "Should have correct recording names"
    assert (written["group"] == "default").all(), "All groups should be 'default'"
@pytest.mark.quick
def test_sort_syllables_by_stat_difference():
    """Check ordering of syllables by the between-group difference in a stat"""
    from keypoint_moseq.analysis import sort_syllables_by_stat_difference

    # One row per (syllable, group); control rows come from rec1 and
    # experimental rows from rec2
    rows = [
        (0, "control", 0.2, "rec1"),
        (0, "experimental", 0.5, "rec2"),
        (1, "control", 0.3, "rec1"),
        (1, "experimental", 0.1, "rec2"),
        (2, "control", 0.5, "rec1"),
        (2, "experimental", 0.4, "rec2"),
    ]
    stats_df = pd.DataFrame(rows, columns=["syllable", "group", "frequency", "name"])

    ordering = sort_syllables_by_stat_difference(
        stats_df, "control", "experimental", stat="frequency"
    )

    # experimental - control, per syllable:
    #   syllable 0: 0.5 - 0.2 = +0.3 (largest increase)
    #   syllable 2: 0.4 - 0.5 = -0.1 (small decrease)
    #   syllable 1: 0.1 - 0.3 = -0.2 (largest decrease)
    first, last = ordering[0], ordering[-1]
    assert first == 0, "Syllable 0 should be first (largest increase)"
    assert last == 1, "Syllable 1 should be last (largest decrease)"
@pytest.mark.quick
def test_compute_moseq_df_basic_structure(temp_project, mock_results_dict):
    """Check the shape, columns and dtypes of the compute_moseq_df output"""
    from unittest.mock import patch

    from keypoint_moseq.analysis import compute_moseq_df

    model_name = "test_model"
    project_root = Path(temp_project)
    (project_root / model_name).mkdir(parents=True)

    # Write an index file up front; per the original author's note,
    # compute_moseq_df raises UnboundLocalError without one
    pd.DataFrame(
        {"name": ["rec1", "rec2"], "group": ["default", "default"]}
    ).to_csv(project_root / "index.csv", index=False)

    stub = patch(
        "keypoint_moseq.analysis.load_results", return_value=mock_results_dict
    )
    with stub:
        moseq_df = compute_moseq_df(
            temp_project, model_name, fps=30, smooth_heading=False
        )

    # Overall structure
    assert isinstance(moseq_df, pd.DataFrame), "Should return DataFrame"
    assert len(moseq_df) == 16, "Should have 16 rows (8 frames × 2 recordings)"

    # Every expected column must be present
    for col in (
        "name",
        "centroid_x",
        "centroid_y",
        "heading",
        "angular_velocity",
        "velocity_px_s",
        "syllable",
        "frame_index",
        "group",
        "onset",
    ):
        assert col in moseq_df.columns, f"Missing column: {col}"

    # Column dtypes
    syllable_dtype = moseq_df["syllable"].dtype
    assert syllable_dtype in (np.int32, np.int64), "Syllables should be integers"
    assert moseq_df["onset"].dtype == bool, "Onset should be boolean"
@pytest.mark.quick
def test_compute_stats_df_basic(temp_project, mock_results_dict):
    """Check that compute_stats_df returns a DataFrame with the core columns"""
    from unittest.mock import patch

    from keypoint_moseq.analysis import compute_moseq_df, compute_stats_df

    model_name = "test_model"
    project_root = Path(temp_project)
    (project_root / model_name).mkdir(parents=True)

    # Index file placing the two mock recordings in different groups
    pd.DataFrame(
        {"name": ["rec1", "rec2"], "group": ["control", "experimental"]}
    ).to_csv(project_root / "index.csv", index=False)

    # Both calls run while load_results is replaced by the stub
    stub = patch(
        "keypoint_moseq.analysis.load_results", return_value=mock_results_dict
    )
    with stub:
        moseq_df = compute_moseq_df(
            temp_project, model_name, fps=30, smooth_heading=False
        )
        stats_df = compute_stats_df(
            temp_project, model_name, moseq_df, min_frequency=0.0, fps=30
        )

    assert isinstance(stats_df, pd.DataFrame), "Should return DataFrame"

    # Each core column paired with the message reported when it is missing
    column_checks = (
        ("syllable", "Should have syllable column"),
        ("frequency", "Should have frequency column"),
        ("duration", "Should have duration column"),
        ("group", "Should have group column"),
    )
    for column, message in column_checks:
        assert column in stats_df.columns, message
+ Expected duration: ~15 minutes + """ + from tests.conftest import compute_latent_dim, load_path_from_model + + project_dir = temp_project_dir + + # Step 1: Setup project + kpms.setup_project(project_dir, deeplabcut_config=dlc_config, overwrite=True) + assert Path(project_dir, "config.yml").exists(), "Config file not created" + + # Step 2: Update config + kpms.update_config( + project_dir, + use_bodyparts=[ + "spine4", + "spine3", + "spine2", + "spine1", + "head", + "nose", + "right ear", + "left ear", + ], + anterior_bodyparts=["head", "nose", "right ear", "left ear"], + posterior_bodyparts=["spine4", "spine3", "spine2", "spine1"], + seg_length=5, + ) + config = kpms.load_config(project_dir) + + # Step 3: Load keypoints + coordinates, confidences, bodyparts = kpms.load_keypoints( + dlc_videos_dir, "deeplabcut" + ) + assert len(coordinates) > 0, "No keypoints loaded" + assert len(bodyparts) == 9, f"Expected 9 bodyparts, got {len(bodyparts)}" + + # Step 4: Outlier removal (before formatting) + kpms.update_config(project_dir, outlier_scale_factor=6.0) + coordinates, confidences = kpms.outlier_removal( + coordinates, confidences, project_dir, overwrite=True, **config + ) + qa_dir = Path(project_dir) / "QA" / "plots" + assert qa_dir.exists(), "QA plots directory not created" + + # Step 5: Format data after outlier removal + data, metadata = kpms.format_data(coordinates, confidences, **config) + assert "Y" in data, "Formatted data missing Y" + assert "conf" in data, "Formatted data missing conf" + + # Step 6: Skip calibration (not needed for minimal dataset) + # Manual calibration widget would go here in interactive mode + + # Step 7: Fit PCA + pca = kpms.fit_pca(**data, **config) + kpms.save_pca(pca, project_dir) + pca_path = Path(project_dir) / "pca.p" + assert pca_path.exists(), "PCA model not saved" + + # Step 8: Update latent dimensions + latent_dim = compute_latent_dim(pca, variance_threshold=0.9) + assert latent_dim >= 3, f"Expected at least 3 PCs, got 
{latent_dim}" + kpms.update_config(project_dir, latent_dim=int(latent_dim)) + config = kpms.load_config(project_dir) + + # Step 9: Estimate hyperparameters + sigmasq_loc = kpms.estimate_sigmasq_loc( + data["Y"], data["mask"], filter_size=config["fps"] + ) + kpms.update_config(project_dir, sigmasq_loc=sigmasq_loc) + config = kpms.load_config(project_dir) + + # Step 10: Initialize model + model = kpms.init_model(data, pca=pca, **config) + assert model is not None, "Model initialization failed" + + # Step 11: Fit AR-HMM with reduced iterations + model, model_name = kpms.fit_model( + model, + data, + metadata, + project_dir, + ar_only=True, + num_iters=reduced_iterations["ar_hmm_iters"], + ) + + # Step 12: Fit full model with reduced iterations + model, _ = kpms.fit_model( + model, + data, + metadata, + project_dir, + ar_only=False, + num_iters=reduced_iterations["full_model_iters"], + ) + + # Step 13: Verify checkpoint was saved by fit_model + checkpoint_path = load_path_from_model(project_dir, model_name, "checkpoint.h5") + assert checkpoint_path.exists(), "Checkpoint file not created" + + # Step 14: Reindex syllables + kpms.reindex_syllables_in_checkpoint(project_dir, model_name) + + # Step 15: Extract results + results = kpms.extract_results(model, metadata, project_dir, model_name, config) + example_model = results[metadata[0][0]] + assert "syllable" in example_model, "Results missing syllable labels" + + results_h5_path = Path(project_dir) / model_name / "results.h5" + assert results_h5_path.exists(), "Results HDF5 not created" + + # Validate results structure + with h5py.File(results_h5_path, "r") as f: + recording_keys = list(f.keys()) + assert len(recording_keys) > 0, "No recordings in results" + + first_recording = f[recording_keys[0]] + # Verify required datasets are present + required_datasets = {"syllable", "centroid", "heading", "latent_state"} + actual_datasets = set(first_recording.keys()) + missing = required_datasets - actual_datasets + assert not 
missing, f"Results missing datasets: {missing}" + + # Step 16: Save as CSV + results_dir = load_path_from_model(project_dir, model_name, "results") + csv_files_before = list(results_dir.glob("*.csv")) if results_dir.exists() else [] + + kpms.save_results_as_csv(results, project_dir, model_name) + assert results_dir.exists(), "Results CSV directory not created" + + csv_files_after = list(results_dir.glob("*.csv")) + assert len(csv_files_after) > len(csv_files_before), "No new CSV files created" + + # Step 17: Generate visualizations + # Add video_dir to config for visualization functions + config["video_dir"] = dlc_videos_dir + + # Generate trajectory plots + kpms.generate_trajectory_plots( + coordinates=coordinates, + results=results, + project_dir=project_dir, + model_name=model_name, + **config, + ) + trajectory_dir = load_path_from_model(project_dir, model_name, "trajectory_plots") + assert trajectory_dir.exists(), "Trajectory plots directory not created" + + num_syllables = len(set(example_model["syllable"])) + assert num_syllables > 0, "No syllables identified" + + # Check for trajectory plots + pdf_plots = [f for f in trajectory_dir.glob("*.pdf")] + assert len(pdf_plots) > 0, "No trajectory PDFs created" + + # Generate grid movies + kpms.generate_grid_movies( + coordinates=coordinates, + results=results, + project_dir=project_dir, + model_name=model_name, + frame_path=None, + **config, + ) + grid_movies_dir = load_path_from_model(project_dir, model_name, "grid_movies") + assert grid_movies_dir.exists(), "Grid movies directory not created" + + mp4_files = [f for f in grid_movies_dir.glob("*.mp4")] + assert len(mp4_files) > 0, "No grid movies created" + + # Generate similarity dendrogram + kpms.plot_similarity_dendrogram( + coordinates=coordinates, + results=results, + project_dir=project_dir, + model_name=model_name, + **config, + ) + dendrogram_pdf = load_path_from_model( + project_dir, model_name, "similarity_dendrogram.pdf" + ) + assert 
dendrogram_pdf.exists(), "Similarity dendrogram not created" + + +@pytest.mark.quick +@pytest.mark.notebook +def test_project_setup(temp_project_dir, dlc_config, kpms): + """Test project setup and configuration + + Expected duration: < 1 second + """ + project_dir = temp_project_dir + + # Test setup + kpms.setup_project(project_dir, deeplabcut_config=dlc_config, overwrite=True) + + # Verify files created + config_path = Path(project_dir, "config.yml") + assert config_path.exists(), "Config file not created" + + # Update config with valid bodyparts before loading + # (setup_project creates placeholders that need to be updated) + kpms.update_config( + project_dir, + use_bodyparts=[ + "spine4", + "spine3", + "spine2", + "spine1", + "head", + "nose", + "right ear", + "left ear", + ], + anterior_bodyparts=["head", "nose", "right ear", "left ear"], + posterior_bodyparts=["spine4", "spine3", "spine2", "spine1"], + ) + + # Test config loading after update + config = kpms.load_config(project_dir) + expected_keys = {"bodyparts", "fps", "use_bodyparts"} + assert expected_keys.issubset( + config.keys() + ), f"Config missing keys: {expected_keys - config.keys()}" + assert len(config["use_bodyparts"]) == 8, "Wrong number of use_bodyparts" + + +@pytest.mark.quick +@pytest.mark.notebook +def test_load_keypoints(temp_project_dir, dlc_config, dlc_videos_dir, kpms): + """Test keypoint loading from DLC data + + Expected duration: < 1 second + """ + project_dir = temp_project_dir + kpms.setup_project(project_dir, deeplabcut_config=dlc_config, overwrite=True) + + # Load keypoints from DLC videos directory (not project_dir) + coordinates, confidences, bodyparts = kpms.load_keypoints( + dlc_videos_dir, "deeplabcut" + ) + + # Verify data structure + assert len(coordinates) > 0, "No coordinates loaded" + assert len(confidences) > 0, "No confidences loaded" + assert len(bodyparts) == 9, f"Expected 9 bodyparts, got {len(bodyparts)}" + + # Check data types + first_recording = 
next(iter(coordinates.keys())) + assert isinstance( + coordinates[first_recording], np.ndarray + ), "Coordinates not numpy array" + assert coordinates[first_recording].ndim == 3, "Coordinates wrong shape" + + +@pytest.mark.medium +@pytest.mark.notebook +def test_format_and_outlier_detection( + temp_project_dir, dlc_config, dlc_videos_dir, kpms, update_kwargs +): + """Test data formatting and outlier detection + + Expected duration: ~1 minute + """ + project_dir = temp_project_dir + + # Setup + kpms.setup_project(project_dir, deeplabcut_config=dlc_config, overwrite=True) + + # Update config using fixture + kpms.update_config(project_dir, **update_kwargs) + config = kpms.load_config(project_dir) + + # Load keypoints + coordinates, confidences, bodyparts = kpms.load_keypoints( + dlc_videos_dir, "deeplabcut" + ) + + # Format data + data, metadata = kpms.format_data(coordinates, confidences, **config) + assert "Y" in data, "Formatted data missing Y" + + # Test outlier removal (matches notebook API) + kpms.update_config(project_dir, outlier_scale_factor=6.0) + coordinates_clean, confidences_clean = kpms.outlier_removal( + coordinates, confidences, project_dir, overwrite=True, **config + ) + + # Verify outputs + assert len(coordinates_clean) > 0, "No coordinates after outlier removal" + assert len(confidences_clean) > 0, "No confidences after outlier removal" + + qa_dir = Path(project_dir) / "QA" / "plots" + assert qa_dir.exists(), "QA directory not created" + + +@pytest.mark.medium +@pytest.mark.notebook +def test_pca_fitting(temp_project_dir, dlc_config, dlc_videos_dir, kpms, update_kwargs): + """Test PCA model fitting + + Expected duration: ~5 seconds + """ + from tests.conftest import compute_latent_dim + + project_dir = temp_project_dir + + # Setup and load data + kpms.setup_project(project_dir, deeplabcut_config=dlc_config, overwrite=True) + + # Update config using fixture + kpms.update_config(project_dir, **update_kwargs) + config = kpms.load_config(project_dir) + 
+ coordinates, confidences, _ = kpms.load_keypoints(dlc_videos_dir, "deeplabcut") + data, metadata = kpms.format_data(coordinates, confidences, **config) + + # Fit PCA + pca = kpms.fit_pca(**data, **config) + kpms.save_pca(pca, project_dir) + + # Verify PCA + pca_path = Path(project_dir) / "pca.p" + assert pca_path.exists(), "PCA model not saved" + + # Test variance explained using helper + latent_dim = compute_latent_dim(pca, variance_threshold=0.9) + assert latent_dim >= 3, f"Expected at least 3 PCs, got {latent_dim}" + assert latent_dim <= 10, f"Too many PCs required: {latent_dim}" diff --git a/tests/test_fitting.py b/tests/test_fitting.py new file mode 100644 index 0000000..bb79fcd --- /dev/null +++ b/tests/test_fitting.py @@ -0,0 +1,774 @@ +""" +Unit tests for keypoint_moseq.fitting module + +Target coverage: 36% → 80% (from 241/671 statements to ~536/671) +Functions tested: +- _wrapped_resample() +- _set_parallel_flag() +- init_model() +- fit_model() +- apply_model() +- estimate_syllable_marginals() +- update_hypparams() +- expected_marginal_likelihoods() +""" + +import os +import tempfile +import warnings +from unittest.mock import Mock, patch + +import h5py +import jax.numpy as jnp +import numpy as np +import pytest + +from keypoint_moseq.fitting import ( + StopResampling, + _set_parallel_flag, + _wrapped_resample, + apply_model, + estimate_syllable_marginals, + expected_marginal_likelihoods, + fit_model, + init_model, + update_hypparams, +) + +# Suppress JAX/matplotlib warnings for clean test output +warnings.filterwarnings("ignore", category=UserWarning, message=".*os.fork.*") +warnings.filterwarnings("ignore", category=UserWarning, message=".*FigureCanvasAgg.*") +warnings.filterwarnings("ignore", category=UserWarning, message=".*NVIDIA GPU.*") + + +@pytest.mark.quick +class TestWrappedResample: + """Test _wrapped_resample function.""" + + def test_successful_resample(self): + """Test successful resampling without NaNs or interrupts.""" + # Mock resample 
function that returns updated model + mock_resample = Mock(return_value={"states": {"x": jnp.array([1.0])}}) + data = {"Y": jnp.array([1.0])} + model = {"states": {"x": jnp.array([0.5])}} + + with patch("keypoint_moseq.fitting.check_for_nans") as mock_check: + mock_check.return_value = (False, {}, []) + result = _wrapped_resample(mock_resample, data, model) + + assert "states" in result + mock_resample.assert_called_once_with(data, **model) + + def test_keyboard_interrupt(self): + """Test KeyboardInterrupt handling.""" + mock_resample = Mock(side_effect=KeyboardInterrupt) + data = {"Y": jnp.array([1.0])} + model = {"states": {"x": jnp.array([0.5])}} + + with pytest.raises(StopResampling): + _wrapped_resample(mock_resample, data, model) + + def test_nan_detection(self): + """Test NaN detection during resampling.""" + mock_resample = Mock(return_value={"states": {"x": jnp.array([np.nan])}}) + data = {"Y": jnp.array([1.0])} + model = {"states": {"x": jnp.array([0.5])}} + + with patch("keypoint_moseq.fitting.check_for_nans") as mock_check: + mock_check.return_value = ( + True, + {"x": np.nan}, + ["NaN found in states.x"], + ) + with pytest.warns(UserWarning, match="Early termination.*NaNs"): + with pytest.raises(StopResampling): + _wrapped_resample(mock_resample, data, model) + + def test_with_progress_bar(self): + """Test with progress bar parameter.""" + mock_resample = Mock(return_value={"states": {"x": jnp.array([1.0])}}) + mock_pbar = Mock() + data = {"Y": jnp.array([1.0])} + model = {"states": {"x": jnp.array([0.5])}} + + with patch("keypoint_moseq.fitting.check_for_nans") as mock_check: + mock_check.return_value = (False, {}, []) + result = _wrapped_resample(mock_resample, data, model, pbar=mock_pbar) + + assert "states" in result + + def test_nan_with_progress_bar_closes(self): + """Test that progress bar is closed when NaN detected.""" + mock_resample = Mock(return_value={"states": {"x": jnp.array([np.nan])}}) + mock_pbar = Mock() + data = {"Y": 
jnp.array([1.0])} + model = {"states": {"x": jnp.array([0.5])}} + + with patch("keypoint_moseq.fitting.check_for_nans") as mock_check: + mock_check.return_value = (True, {}, ["NaN detected"]) + with pytest.warns(UserWarning): + with pytest.raises(StopResampling): + _wrapped_resample(mock_resample, data, model, pbar=mock_pbar) + + mock_pbar.close.assert_called_once() + + +@pytest.mark.quick +class TestSetParallelFlag: + """Test _set_parallel_flag function.""" + + def test_force_true(self): + """Test force=True always returns True.""" + result = _set_parallel_flag("force") + assert result is True + + def test_none_with_gpu(self): + """Test None with GPU backend.""" + with patch("jax.default_backend", return_value="gpu"): + result = _set_parallel_flag(None) + assert result is True + + def test_none_with_cpu(self): + """Test None with CPU backend.""" + with patch("jax.default_backend", return_value="cpu"): + result = _set_parallel_flag(None) + assert result is False + + def test_explicit_true_with_cpu_warns(self): + """Test explicit True with CPU backend raises warning.""" + with patch("jax.default_backend", return_value="cpu"): + with pytest.warns(UserWarning, match="CPU-bound"): + result = _set_parallel_flag(True) + assert result is True + + def test_explicit_false(self): + """Test explicit False returns False.""" + result = _set_parallel_flag(False) + assert result is False + + +@pytest.mark.quick +class TestInitModel: + """Test init_model function.""" + + def test_standard_model(self): + """Test standard keypoint-SLDS model initialization.""" + data = {"Y": jnp.ones((10, 5, 2))} + + with patch("keypoint_moseq.fitting.keypoint_slds.init_model") as mock_init: + mock_init.return_value = {"model": "standard"} + result = init_model(data, location_aware=False) + + assert result == {"model": "standard"} + mock_init.assert_called_once() + + def test_location_aware_model(self): + """Test location-aware model initialization.""" + data = {"Y": jnp.ones((10, 5, 2))} + 
trans_hypparams = {"num_states": 50} + + with patch("keypoint_moseq.fitting.allo_keypoint_slds.init_model") as mock_init: + mock_init.return_value = {"model": "allow"} + result = init_model( + data, + location_aware=True, + trans_hypparams=trans_hypparams, + ) + + assert result == {"model": "allow"} + mock_init.assert_called_once() + + def test_location_aware_allo_hypparams(self): + """Test that location-aware model sets allo_hypparams.""" + data = {"Y": jnp.ones((10, 5, 2))} + trans_hypparams = {"num_states": 30} + + with patch("keypoint_moseq.fitting.allo_keypoint_slds.init_model") as mock_init: + mock_init.return_value = {"model": "allow"} + _ = init_model( + data, + location_aware=True, + trans_hypparams=trans_hypparams, + ) + + # Check that allo_hypparams were passed + call_kwargs = mock_init.call_args[1] + assert "allo_hypparams" in call_kwargs + assert call_kwargs["allo_hypparams"]["num_states"] == 30 + + +@pytest.mark.quick +class TestUpdateHypparams: + """Test update_hypparams function.""" + + def test_update_scalar_hyperparam(self): + """Test updating a scalar hyperparameter.""" + model = { + "hypparams": { + "trans_hypparams": {"kappa": 1e3}, + "ar_hypparams": {"nlags": 3}, + } + } + + result = update_hypparams(model, kappa=1e4) + + assert result["hypparams"]["trans_hypparams"]["kappa"] == 1e4 + + def test_update_multiple_hypparams(self): + """Test updating multiple hyperparameters.""" + model = { + "hypparams": { + "trans_hypparams": {"kappa": 1e3, "gamma": 1e2}, + "ar_hypparams": {"nlags": 3}, + } + } + + result = update_hypparams(model, kappa=5e3, gamma=5e2) + + assert result["hypparams"]["trans_hypparams"]["kappa"] == 5e3 + assert result["hypparams"]["trans_hypparams"]["gamma"] == 5e2 + + def test_type_conversion_warning(self): + """Test warning when type conversion is needed.""" + model = { + "hypparams": { + "trans_hypparams": {"kappa": 1000.0}, # float + } + } + + with pytest.warns(UserWarning, match="will be cast"): + result = 
update_hypparams(model, kappa=2000) # int + + assert result["hypparams"]["trans_hypparams"]["kappa"] == 2000.0 + + def test_non_scalar_hyperparam_not_updated(self): + """Test that non-scalar hyperparameters are not updated.""" + model = { + "hypparams": { + "trans_hypparams": { + "kappa": 1e3, + "matrix_param": np.array([[1, 2], [3, 4]]), + }, + } + } + + # Should print message but not raise error + result = update_hypparams(model, matrix_param=np.array([[5, 6], [7, 8]])) + + # Original matrix should be unchanged + np.testing.assert_array_equal( + result["hypparams"]["trans_hypparams"]["matrix_param"], + np.array([[1, 2], [3, 4]]), + ) + + def test_unknown_hyperparam_warns(self): + """Test warning for unknown hyperparameter.""" + model = { + "hypparams": { + "trans_hypparams": {"kappa": 1e3}, + } + } + + with pytest.warns(UserWarning, match="not found"): + _ = update_hypparams(model, unknown_param=999) + + def test_missing_hypparams_raises(self): + """Test error when model has no hypparams.""" + model = {"states": {}, "params": {}} + + with pytest.raises(AssertionError, match="does not contain any hyperparams"): + update_hypparams(model, kappa=1e4) + + +@pytest.mark.quick +class TestFitModelParameters: + """Test fit_model parameter validation and setup.""" + + def test_explicit_model_name_used(self): + """Test that explicit model name is used instead of auto-generation.""" + with tempfile.TemporaryDirectory() as tmpdir: + model = self._create_mock_model() + data = self._create_mock_data() + metadata = (["rec1"], np.array([[0, 100]])) + + # This simpler test just checks the directory is created with the right name + test_name = "my_custom_model" + + with patch("keypoint_moseq.fitting._wrapped_resample") as mock_resample: + mock_resample.return_value = model + with patch("keypoint_moseq.fitting.save_hdf5"): + with patch( + "keypoint_moseq.fitting.device_put_as_scalar" + ) as mock_device: + mock_device.return_value = model + _, returned_name = fit_model( + model, + 
data, + metadata, + project_dir=tmpdir, + model_name=test_name, + num_iters=0, # No iterations, just check setup + ) + + assert returned_name == test_name + assert os.path.exists(os.path.join(tmpdir, test_name)) + + def test_save_every_n_iters_none_no_save(self): + """Test save_every_n_iters=None disables saving.""" + with tempfile.TemporaryDirectory() as tmpdir: + model = self._create_mock_model() + data = self._create_mock_data() + metadata = (["rec1"], np.array([[0, 100]])) + + with patch("keypoint_moseq.fitting._wrapped_resample") as mock_resample: + mock_resample.return_value = model + with patch("keypoint_moseq.fitting.save_hdf5") as mock_save: + with patch( + "keypoint_moseq.fitting.device_put_as_scalar" + ) as mock_device: + mock_device.return_value = model + result, _ = fit_model( + model, + data, + metadata, + project_dir=tmpdir, + save_every_n_iters=None, + num_iters=2, + ) + + # save_hdf5 should not be called + mock_save.assert_not_called() + + def test_progress_plots_require_saving(self): + """Test warning when progress plots requested but saving disabled.""" + with tempfile.TemporaryDirectory() as tmpdir: + model = self._create_mock_model() + data = self._create_mock_data() + metadata = (["rec1"], np.array([[0, 100]])) + + with pytest.warns(UserWarning, match="Progress plots"): + with patch("keypoint_moseq.fitting._wrapped_resample") as mock_resample: + mock_resample.return_value = model + with patch( + "keypoint_moseq.fitting.device_put_as_scalar" + ) as mock_device: + mock_device.return_value = model + fit_model( + model, + data, + metadata, + project_dir=tmpdir, + save_every_n_iters=0, + generate_progress_plots=True, + num_iters=1, + ) + + def test_ar_only_mode(self): + """Test AR-only fitting mode.""" + with tempfile.TemporaryDirectory() as tmpdir: + model = self._create_mock_model() + data = self._create_mock_data() + metadata = (["rec1"], np.array([[0, 100]])) + + with patch("keypoint_moseq.fitting._wrapped_resample") as mock_resample: + 
mock_resample.return_value = model + with patch( + "keypoint_moseq.fitting.device_put_as_scalar" + ) as mock_device: + mock_device.return_value = model + fit_model( + model, + data, + metadata, + project_dir=tmpdir, + save_every_n_iters=None, + ar_only=True, + num_iters=1, + ) + + # Check ar_only was passed + call_kwargs = mock_resample.call_args[1] + assert call_kwargs["ar_only"] is True + + def test_location_aware_uses_allo_resample(self): + """Test location_aware mode uses allow resample function.""" + with tempfile.TemporaryDirectory() as tmpdir: + model = self._create_mock_model() + data = self._create_mock_data() + metadata = (["rec1"], np.array([[0, 100]])) + + with patch( + "keypoint_moseq.fitting.allo_keypoint_slds.resample_model" + ) as mock_allo: + mock_allo.return_value = model + with patch( + "keypoint_moseq.fitting.device_put_as_scalar" + ) as mock_device: + mock_device.return_value = model + fit_model( + model, + data, + metadata, + project_dir=tmpdir, + save_every_n_iters=None, + location_aware=True, + num_iters=1, + ) + + assert mock_allo.called + + # Helper methods + def _create_mock_model(self): + """Create a minimal mock model.""" + return { + "states": {"x": jnp.ones((10, 5)), "z": jnp.zeros(10, dtype=int)}, + "params": {"Ab": jnp.eye(5)}, + "hypparams": {"trans_hypparams": {"num_states": 10}}, + "seed": 0, + } + + def _create_mock_data(self): + """Create minimal mock data.""" + return { + "Y": jnp.ones((10, 5, 2)), + "mask": jnp.ones((10, 5), dtype=bool), + } + + +@pytest.mark.quick +class TestApplyModelBasics: + """Test apply_model basic functionality.""" + + def test_save_results_requires_params(self): + """Test that save_results=True requires project_dir and model_name.""" + model = self._create_mock_model() + data = self._create_mock_data() + metadata = (["rec1"], np.array([[0, 100]])) + + with pytest.raises(AssertionError, match="requires either"): + apply_model( + model, + data, + metadata, + save_results=True, + # Missing project_dir 
and model_name + ) + + def test_results_path_override(self): + """Test that results_path overrides project_dir/model_name.""" + model = self._create_mock_model() + data = self._create_mock_data() + metadata = (["rec1"], np.array([[0, 100]])) + + with tempfile.TemporaryDirectory() as tmpdir: + custom_path = os.path.join(tmpdir, "custom_results.h5") + + with patch("keypoint_moseq.fitting._wrapped_resample") as mock_resample: + mock_resample.return_value = model + with patch("keypoint_moseq.fitting.init_model") as mock_init: + mock_init.return_value = model + with patch( + "keypoint_moseq.fitting.extract_results" + ) as mock_extract: + mock_extract.return_value = {"rec1": {}} + with patch("jax.device_put") as mock_device: + mock_device.return_value = data + apply_model( + model, + data, + metadata, + save_results=True, + results_path=custom_path, + num_iters=1, + ) + + # Check extract_results was called - check positional args + call_args = mock_extract.call_args[0] + # extract_results(model, metadata, project_dir, model_name, save_results, results_path) + # Custom path should be the last positional arg + assert call_args[-1] == custom_path + + def test_return_model_option(self): + """Test return_model=True returns both results and model.""" + model = self._create_mock_model() + data = self._create_mock_data() + metadata = (["rec1"], np.array([[0, 100]])) + + with patch("keypoint_moseq.fitting._wrapped_resample") as mock_resample: + mock_resample.return_value = model + with patch("keypoint_moseq.fitting.init_model") as mock_init: + mock_init.return_value = model + with patch("keypoint_moseq.fitting.extract_results") as mock_extract: + mock_extract.return_value = {"rec1": {}} + with patch("jax.device_put") as mock_device: + mock_device.return_value = data + results, returned_model = apply_model( + model, + data, + metadata, + save_results=False, + return_model=True, + num_iters=1, + ) + + assert "rec1" in results + assert returned_model == model + + def 
test_location_aware_apply(self): + """Test location_aware mode in apply_model.""" + model = self._create_mock_model() + data = self._create_mock_data() + metadata = (["rec1"], np.array([[0, 100]])) + + with patch( + "keypoint_moseq.fitting.allo_keypoint_slds.resample_model" + ) as mock_allo: + mock_allo.return_value = model + with patch("keypoint_moseq.fitting.init_model") as mock_init: + mock_init.return_value = model + with patch("keypoint_moseq.fitting.extract_results") as mock_extract: + mock_extract.return_value = {"rec1": {}} + with patch("jax.device_put") as mock_device: + mock_device.return_value = data + apply_model( + model, + data, + metadata, + save_results=False, + location_aware=True, + num_iters=1, + ) + + assert mock_allo.called + + # Helper methods + def _create_mock_model(self): + """Create a minimal mock model.""" + return { + "states": {"x": jnp.ones((10, 5)), "z": jnp.zeros(10, dtype=int)}, + "params": {"Ab": jnp.eye(5)}, + "hypparams": {"trans_hypparams": {"num_states": 10}}, + "seed": 0, + } + + def _create_mock_data(self): + """Create minimal mock data.""" + return { + "Y": jnp.ones((10, 5, 2)), + "mask": jnp.ones((10, 5), dtype=bool), + } + + +@pytest.mark.quick +class TestEstimateSyllableMarginals: + """Test estimate_syllable_marginals function.""" + + def test_basic_marginal_estimation(self): + """Test basic marginal estimation.""" + model = self._create_mock_model() + data = self._create_mock_data() + # Bounds must match data shape + metadata = (["rec1"], np.array([[0, 100]])) + + with patch("keypoint_moseq.fitting._wrapped_resample") as mock_resample: + mock_resample.return_value = model + with patch("keypoint_moseq.fitting.init_model") as mock_init: + mock_init.return_value = model + with patch( + "keypoint_moseq.fitting.stateseq_marginals" + ) as mock_marginals: + # Return marginals for 10 states + mock_marginals.return_value = jnp.ones((100, 10)) + with patch("keypoint_moseq.fitting.get_nlags") as mock_nlags: + 
mock_nlags.return_value = 3 + with patch("keypoint_moseq.fitting.unbatch") as mock_unbatch: + # Return unbatched result directly + mock_unbatch.return_value = {"rec1": np.ones((97, 10))} + with patch("jax.device_put") as mock_device: + mock_device.return_value = data + result = estimate_syllable_marginals( + model, + data, + metadata, + burn_in_iters=2, + num_samples=2, + steps_per_sample=1, + ) + + assert "rec1" in result + assert result["rec1"].shape[1] == 10 # num_syllables + + def test_return_samples_option(self): + """Test return_samples=True returns both marginals and samples.""" + model = self._create_mock_model() + data = self._create_mock_data() + metadata = (["rec1"], np.array([[0, 100]])) + + with patch("keypoint_moseq.fitting._wrapped_resample") as mock_resample: + mock_resample.return_value = model + with patch("keypoint_moseq.fitting.init_model") as mock_init: + mock_init.return_value = model + with patch( + "keypoint_moseq.fitting.stateseq_marginals" + ) as mock_marginals: + mock_marginals.return_value = jnp.ones((100, 10)) + with patch("keypoint_moseq.fitting.get_nlags") as mock_nlags: + mock_nlags.return_value = 3 + with patch("keypoint_moseq.fitting.unbatch") as mock_unbatch: + # Return unbatched result directly (call count = 2, one for marginals, one for samples) + mock_unbatch.side_effect = [ + {"rec1": np.ones((97, 10))}, # marginals + {"rec1": np.ones((97, 2))}, # samples + ] + with patch("numpy.moveaxis") as mock_moveaxis: + # Mock moveaxis to avoid shape issues + mock_moveaxis.return_value = np.ones((100, 100, 2)) + with patch("jax.device_put") as mock_device: + mock_device.return_value = data + marginals, samples = estimate_syllable_marginals( + model, + data, + metadata, + burn_in_iters=1, + num_samples=2, + steps_per_sample=1, + return_samples=True, + ) + + assert "rec1" in marginals + assert "rec1" in samples + + def test_location_aware_marginals(self): + """Test location_aware mode in marginal estimation.""" + model = 
self._create_mock_model() + data = self._create_mock_data() + metadata = (["rec1"], np.array([[0, 100]])) + + with patch( + "keypoint_moseq.fitting.allo_keypoint_slds.resample_model" + ) as mock_allo: + mock_allo.return_value = model + with patch("keypoint_moseq.fitting.init_model") as mock_init: + mock_init.return_value = model + with patch( + "keypoint_moseq.fitting.stateseq_marginals" + ) as mock_marginals: + mock_marginals.return_value = jnp.ones((100, 10)) + with patch("keypoint_moseq.fitting.get_nlags") as mock_nlags: + mock_nlags.return_value = 3 + with patch("keypoint_moseq.fitting.unbatch") as mock_unbatch: + # Return unbatched result directly + mock_unbatch.return_value = {"rec1": np.ones((97, 10))} + with patch("jax.device_put") as mock_device: + mock_device.return_value = data + estimate_syllable_marginals( + model, + data, + metadata, + location_aware=True, + burn_in_iters=1, + num_samples=1, + ) + + assert mock_allo.called + + # Helper methods + def _create_mock_model(self): + """Create a minimal mock model.""" + return { + "states": {"x": jnp.ones((100, 5)), "z": jnp.zeros(100, dtype=int)}, + "params": {"Ab": jnp.eye(5)}, + "hypparams": {"trans_hypparams": {"num_states": 10}}, + "seed": 0, + } + + def _create_mock_data(self): + """Create minimal mock data.""" + return { + "Y": jnp.ones((100, 5, 2)), + "mask": jnp.ones((100, 5), dtype=bool), + } + + +@pytest.mark.quick +class TestExpectedMarginalLikelihoods: + """Test expected_marginal_likelihoods function.""" + + def test_with_checkpoint_paths(self): + """Test with explicit checkpoint paths.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create two mock checkpoints + checkpoint_paths = [] + for i in range(2): + path = os.path.join(tmpdir, f"checkpoint_{i}.h5") + checkpoint_paths.append(path) + self._create_mock_checkpoint(path) + + with patch("keypoint_moseq.fitting.load_checkpoint") as mock_load: + mock_load.return_value = self._create_mock_checkpoint_data() + with patch( + 
"keypoint_moseq.fitting.marginal_log_likelihood" + ) as mock_mll: + mock_mll.return_value = jnp.array(-100.0) + scores, std_errors = expected_marginal_likelihoods( + checkpoint_paths=checkpoint_paths + ) + + assert len(scores) == 2 + assert len(std_errors) == 2 + + def test_with_project_dir_and_names(self): + """Test with project_dir and model_names.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create model directories + model_names = ["model_1", "model_2"] + for name in model_names: + model_dir = os.path.join(tmpdir, name) + os.makedirs(model_dir) + checkpoint = os.path.join(model_dir, "checkpoint.h5") + self._create_mock_checkpoint(checkpoint) + + with patch("keypoint_moseq.fitting.load_checkpoint") as mock_load: + mock_load.return_value = self._create_mock_checkpoint_data() + with patch( + "keypoint_moseq.fitting.marginal_log_likelihood" + ) as mock_mll: + mock_mll.return_value = jnp.array(-100.0) + scores, std_errors = expected_marginal_likelihoods( + project_dir=tmpdir, + model_names=model_names, + ) + + assert len(scores) == 2 + assert len(std_errors) == 2 + + def test_requires_params(self): + """Test that function requires either checkpoint_paths or project_dir+model_names.""" + with pytest.raises(AssertionError, match="Must provide either"): + expected_marginal_likelihoods() + + # Helper methods + def _create_mock_checkpoint(self, path): + """Create a minimal HDF5 checkpoint file.""" + with h5py.File(path, "w") as f: + f.create_dataset("model/states/x", data=np.ones((10, 5))) + f.create_dataset("model/params/Ab", data=np.eye(5)) + + def _create_mock_checkpoint_data(self): + """Create mock data returned by load_checkpoint.""" + model = { + "states": {"x": jnp.ones((10, 5))}, + "params": { + "Ab": jnp.eye(5), + "Q": jnp.eye(5) * 0.1, + "pi": jnp.ones(10) / 10, + }, + } + data = {"mask": jnp.ones((10, 5), dtype=bool)} + metadata = (["rec1"], np.array([[0, 10]])) + iteration = 100 + return model, data, metadata, iteration + + +if __name__ == 
"__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_io_unit.py b/tests/test_io_unit.py new file mode 100644 index 0000000..dcce624 --- /dev/null +++ b/tests/test_io_unit.py @@ -0,0 +1,880 @@ +""" +Unit tests for keypoint_moseq.io module + +Target coverage: 55% → 75% (from 249/455 statements to ~341/455) +Priority functions tested: +- Configuration management (_build_yaml, generate_config, check_config_validity, update_config) +- Path utilities (_get_path, _name_from_path) +- HDF5 operations (save_hdf5, load_hdf5) +- PCA persistence (save_pca, load_pca) +- Result extraction (extract_results) +""" + +import os +import warnings +from pathlib import Path +from unittest.mock import patch + +import h5py +import numpy as np +import pytest +import yaml + +from keypoint_moseq.io import ( + _build_yaml, + _get_path, + _name_from_path, + check_config_validity, + extract_results, + generate_config, + load_checkpoint, + load_config, + load_hdf5, + load_pca, + load_results, + reindex_syllables_in_checkpoint, + save_hdf5, + save_keypoints, + save_pca, + save_results_as_csv, + setup_project, + update_config, +) + +# Suppress common warnings +warnings.filterwarnings("ignore", category=UserWarning, message=".*os.fork.*") +warnings.filterwarnings("ignore", category=UserWarning, message=".*FigureCanvasAgg.*") + + +@pytest.mark.quick +class TestBuildYaml: + """Test _build_yaml helper function.""" + + def test_basic_structure(self): + """Test basic YAML structure generation.""" + sections = [ + ("TEST SECTION", {"key1": "value1", "key2": 123}), + ] + comments = {} + + result = _build_yaml(sections, comments) + + assert "TEST SECTION" in result + assert "key1" in result + assert "value1" in result + assert "key2" in result + + def test_with_comments(self): + """Test YAML generation with comments.""" + sections = [ + ("TEST", {"setting": "value"}), + ] + comments = {"setting": "This is a test comment"} + + result = _build_yaml(sections, comments) + + assert "# This is a 
test comment" in result + assert "setting: value" in result + + def test_multiple_sections(self): + """Test multiple sections.""" + sections = [ + ("SECTION1", {"a": 1}), + ("SECTION2", {"b": 2}), + ] + comments = {} + + result = _build_yaml(sections, comments) + + assert "SECTION1" in result + assert "SECTION2" in result + assert "a: 1" in result + assert "b: 2" in result + + +@pytest.mark.quick +class TestGetPath: + """Test _get_path utility function.""" + + def test_with_explicit_path(self): + """Test when path is explicitly provided.""" + result = _get_path( + project_dir="/proj", + model_name="model", + path="/explicit/path.h5", + filename="default.h5", + ) + assert result == "/explicit/path.h5" + + def test_without_path_constructs_from_parts(self): + """Test path construction from project_dir and model_name.""" + result = _get_path( + project_dir="/project", + model_name="my_model", + path=None, + filename="results.h5", + ) + assert result == "/project/my_model/results.h5" + + def test_missing_params_raises_error(self): + """Test error when required params missing.""" + with pytest.raises(AssertionError, match="required"): + _get_path( + project_dir=None, + model_name="model", + path=None, + filename="file.h5", + ) + + +@pytest.mark.quick +class TestNameFromPath: + """Test _name_from_path utility function.""" + + def test_basename_only(self): + """Test extracting just the basename.""" + result = _name_from_path( + "/path/to/file.csv", + path_in_name=False, + path_sep="-", + remove_extension=True, + ) + assert result == "file" + + def test_full_path_with_separator(self): + """Test full path with custom separator.""" + result = _name_from_path( + "/path/to/file.csv", + path_in_name=True, + path_sep="-", + remove_extension=True, + ) + assert result == "-path-to-file" + + def test_keep_extension(self): + """Test keeping file extension.""" + result = _name_from_path( + "/path/to/file.csv", + path_in_name=False, + path_sep="-", + remove_extension=False, + ) + 
assert result == "file.csv" + + +@pytest.mark.quick +class TestGenerateConfig: + """Test generate_config function.""" + + def test_creates_config_file(self, tmp_path): + """Test that config file is created.""" + project_dir = str(tmp_path) + generate_config(project_dir) + + config_path = tmp_path / "config.yml" + assert config_path.exists() + + def test_config_has_required_sections(self, tmp_path): + """Test that config has all required sections.""" + project_dir = str(tmp_path) + generate_config(project_dir) + + with open(tmp_path / "config.yml") as f: + content = f.read() + + assert "ANATOMY" in content + assert "FITTING" in content + assert "HYPER PARAMS" in content + assert "OTHER" in content + + def test_custom_values_override_defaults(self, tmp_path): + """Test that custom values override defaults.""" + project_dir = str(tmp_path) + generate_config(project_dir, fps=60, verbose=True) + + config = yaml.safe_load(open(tmp_path / "config.yml")) + assert config["fps"] == 60 + assert config["verbose"] is True + + def test_bodyparts_in_config(self, tmp_path): + """Test that bodyparts are configured.""" + project_dir = str(tmp_path) + generate_config(project_dir, bodyparts=["nose", "tail"]) + + config = yaml.safe_load(open(tmp_path / "config.yml")) + assert config["bodyparts"] == ["nose", "tail"] + + +@pytest.mark.quick +class TestCheckConfigValidity: + """Test check_config_validity function.""" + + def test_valid_config_returns_true(self): + """Test that valid config returns True.""" + config = { + "bodyparts": ["bp1", "bp2", "bp3"], + "use_bodyparts": ["bp1", "bp2"], + "skeleton": [["bp1", "bp2"]], + "anterior_bodyparts": ["bp1"], + "posterior_bodyparts": ["bp2"], + } + assert check_config_validity(config) is True + + def test_invalid_use_bodyparts(self, capsys): + """Test detection of invalid use_bodyparts.""" + config = { + "bodyparts": ["bp1", "bp2"], + "use_bodyparts": ["bp1", "bp3"], # bp3 not in bodyparts + "skeleton": [], + "anterior_bodyparts": ["bp1"], + 
"posterior_bodyparts": ["bp1"], + } + result = check_config_validity(config) + assert result is False + captured = capsys.readouterr() + assert "bp3" in captured.out + + def test_invalid_skeleton_bodypart(self, capsys): + """Test detection of invalid skeleton bodypart.""" + config = { + "bodyparts": ["bp1", "bp2"], + "use_bodyparts": ["bp1", "bp2"], + "skeleton": [["bp1", "bp3"]], # bp3 not in bodyparts + "anterior_bodyparts": ["bp1"], + "posterior_bodyparts": ["bp2"], + } + result = check_config_validity(config) + assert result is False + captured = capsys.readouterr() + assert "bp3" in captured.out + + def test_anterior_not_in_use(self, capsys): + """Test detection of anterior bodypart not in use_bodyparts.""" + config = { + "bodyparts": ["bp1", "bp2", "bp3"], + "use_bodyparts": ["bp1", "bp2"], + "skeleton": [], + "anterior_bodyparts": ["bp3"], # bp3 not in use_bodyparts + "posterior_bodyparts": ["bp2"], + } + result = check_config_validity(config) + assert result is False + + +@pytest.mark.quick +class TestLoadConfig: + """Test load_config function.""" + + def test_loads_valid_config(self, tmp_path): + """Test loading a valid config file.""" + project_dir = str(tmp_path) + generate_config(project_dir) + + config = load_config(project_dir, check_if_valid=False) + + assert "bodyparts" in config + assert "fps" in config + assert "trans_hypparams" in config + + def test_builds_indexes(self, tmp_path): + """Test that anterior/posterior indexes are built.""" + project_dir = str(tmp_path) + generate_config( + project_dir, + bodyparts=["bp1", "bp2", "bp3"], + use_bodyparts=["bp1", "bp2", "bp3"], + anterior_bodyparts=["bp1"], + posterior_bodyparts=["bp3"], + ) + + config = load_config(project_dir, build_indexes=True) + + assert "anterior_idxs" in config + assert "posterior_idxs" in config + assert config["anterior_idxs"][0] == 0 # bp1 is at index 0 + assert config["posterior_idxs"][0] == 2 # bp3 is at index 2 + + def test_skip_validity_check(self, tmp_path): + """Test 
loading without validity check.""" + project_dir = str(tmp_path) + # Create invalid config + generate_config( + project_dir, + bodyparts=["bp1"], + use_bodyparts=["bp2"], # Invalid: bp2 not in bodyparts + ) + + # Should not raise error with check_if_valid=False + config = load_config(project_dir, check_if_valid=False, build_indexes=False) + assert config is not None + + +@pytest.mark.quick +class TestUpdateConfig: + """Test update_config function.""" + + def test_updates_top_level_key(self, tmp_path): + """Test updating a top-level config key.""" + project_dir = str(tmp_path) + generate_config(project_dir, fps=30) + + update_config(project_dir, fps=60) + + config = load_config(project_dir, check_if_valid=False) + assert config["fps"] == 60 + + def test_updates_hyperparam(self, tmp_path): + """Test updating a hyperparameter.""" + project_dir = str(tmp_path) + generate_config(project_dir) + + update_config(project_dir, kappa=1e5) + + config = load_config(project_dir, check_if_valid=False) + assert config["trans_hypparams"]["kappa"] == 1e5 + + def test_updates_multiple_keys(self, tmp_path): + """Test updating multiple keys at once.""" + project_dir = str(tmp_path) + generate_config(project_dir) + + update_config(project_dir, fps=45, verbose=True, kappa=1e4) + + config = load_config(project_dir, check_if_valid=False) + assert config["fps"] == 45 + assert config["verbose"] is True + assert config["trans_hypparams"]["kappa"] == 1e4 + + +@pytest.mark.quick +class TestPCAPersistence: + """Test PCA save/load functions.""" + + def test_save_and_load_pca(self, tmp_path): + """Test saving and loading PCA model.""" + from sklearn.decomposition import PCA + + project_dir = str(tmp_path) + + # Create real PCA object with fitted data + X = np.random.randn(100, 20) + pca = PCA(n_components=10) + pca.fit(X) + + # Save PCA + save_pca(pca, project_dir) + + # Load PCA + loaded_pca = load_pca(project_dir) + + # Verify loaded + assert loaded_pca is not None + 
np.testing.assert_array_almost_equal(loaded_pca.components_, pca.components_) + + def test_save_with_custom_path(self, tmp_path): + """Test saving PCA with custom path.""" + from sklearn.decomposition import PCA + + # Create real PCA object + X = np.random.randn(50, 10) + pca = PCA(n_components=5) + pca.fit(X) + + custom_path = str(tmp_path / "custom_pca.p") + save_pca(pca, str(tmp_path), pca_path=custom_path) + + assert Path(custom_path).exists() + + def test_load_nonexistent_raises_error(self, tmp_path): + """Test loading nonexistent PCA raises error.""" + with pytest.raises(AssertionError, match="No PCA model found"): + load_pca(str(tmp_path)) + + +@pytest.mark.quick +class TestHDF5Operations: + """Test HDF5 save/load functions.""" + + def test_save_and_load_simple_dict(self, tmp_path): + """Test saving and loading simple dictionary.""" + filepath = str(tmp_path / "test.h5") + data = { + "array": np.array([1, 2, 3]), + "scalar": 42, + "string": "test", + } + + save_hdf5(filepath, data) + loaded = load_hdf5(filepath) + + np.testing.assert_array_equal(loaded["array"], data["array"]) + assert loaded["scalar"] == data["scalar"] + assert loaded["string"] == data["string"] + + def test_save_nested_dict(self, tmp_path): + """Test saving nested dictionary structure.""" + filepath = str(tmp_path / "nested.h5") + data = { + "level1": { + "level2": { + "array": np.array([1, 2, 3]), + "value": 123, + } + } + } + + save_hdf5(filepath, data) + loaded = load_hdf5(filepath) + + assert "level1" in loaded + assert "level2" in loaded["level1"] + np.testing.assert_array_equal( + loaded["level1"]["level2"]["array"], + data["level1"]["level2"]["array"], + ) + + def test_save_with_datapath(self, tmp_path): + """Test saving to specific path within HDF5.""" + filepath = str(tmp_path / "datapath.h5") + data = {"value": 42} + + save_hdf5(filepath, data, datapath="custom/path") + + with h5py.File(filepath, "r") as f: + assert "custom" in f + assert "path" in f["custom"] + + def 
test_exist_ok_false_prevents_overwrite(self, tmp_path): + """Test that exist_ok=False prevents overwriting.""" + filepath = str(tmp_path / "exists.h5") + + save_hdf5(filepath, {"data": 1}) + + with pytest.raises(AssertionError, match="already exists"): + save_hdf5(filepath, {"data": 2}, exist_ok=False) + + def test_exist_ok_true_allows_append(self, tmp_path): + """Test that exist_ok=True allows appending.""" + filepath = str(tmp_path / "append.h5") + + save_hdf5(filepath, {"data1": 1}) + save_hdf5(filepath, {"data2": 2}, exist_ok=True) + + loaded = load_hdf5(filepath) + assert "data1" in loaded + assert "data2" in loaded + + +@pytest.mark.quick +class TestExtractResults: + """Test extract_results function.""" + + def test_extract_results_structure(self, tmp_path): + """Test that extract_results creates correct structure.""" + # Mock model with states + model = { + "states": { + "x": np.random.randn(10, 5), + "z": np.zeros((10, 2), dtype=int), + "v": np.random.randn(10, 2), + "h": np.random.randn(10), + } + } + metadata = (["recording1"], np.array([[0, 10]])) + + with patch("jax.device_get", side_effect=lambda x: x): + with patch("keypoint_moseq.io.unbatch") as mock_unbatch: + # Mock unbatch to return simple dict + mock_unbatch.return_value = {"recording1": np.random.randn(10, 5)} + + results = extract_results( + model, + metadata, + save_results=False, + ) + + assert "recording1" in results + assert "syllable" in results["recording1"] + assert "latent_state" in results["recording1"] + assert "centroid" in results["recording1"] + assert "heading" in results["recording1"] + + def test_save_results_to_file(self, tmp_path): + """Test saving results to file.""" + model = { + "states": { + "x": np.random.randn(10, 5), + "z": np.zeros((10, 2), dtype=int), + "v": np.random.randn(10, 2), + "h": np.random.randn(10), + } + } + metadata = (["rec1"], np.array([[0, 10]])) + project_dir = str(tmp_path) + model_name = "test_model" + + # Create model directory + 
os.makedirs(os.path.join(project_dir, model_name)) + + with patch("jax.device_get", side_effect=lambda x: x): + with patch("keypoint_moseq.io.unbatch") as mock_unbatch: + mock_unbatch.return_value = {"rec1": np.random.randn(10, 5)} + + extract_results( + model, + metadata, + project_dir=project_dir, + model_name=model_name, + save_results=True, + ) + + results_path = Path(project_dir) / model_name / "results.h5" + assert results_path.exists() + + +@pytest.mark.quick +class TestLoadResults: + """Test load_results function.""" + + def test_load_results_from_default_path(self, tmp_path): + """Test loading results from default path.""" + project_dir = str(tmp_path) + model_name = "test_model" + results_path = tmp_path / model_name / "results.h5" + results_path.parent.mkdir(parents=True) + + # Create mock results file + test_data = {"rec1": {"syllable": np.array([0, 1, 2])}} + save_hdf5(str(results_path), test_data) + + loaded = load_results(project_dir=project_dir, model_name=model_name) + + assert "rec1" in loaded + np.testing.assert_array_equal( + loaded["rec1"]["syllable"], test_data["rec1"]["syllable"] + ) + + +@pytest.mark.quick +class TestSaveResultsAsCsv: + """Test save_results_as_csv function.""" + + def test_creates_csv_files(self, tmp_path): + """Test that CSV files are created.""" + results = { + "recording1": { + "syllable": np.array([0, 1, 2, 1, 0]), + "centroid": np.array( + [[1.0, 2.0], [1.1, 2.1], [1.2, 2.2], [1.3, 2.3], [1.4, 2.4]] + ), + "heading": np.array([0.1, 0.2, 0.3, 0.4, 0.5]), + } + } + + save_dir = str(tmp_path / "csv_results") + save_results_as_csv(results, save_dir=save_dir) + + csv_path = Path(save_dir) / "recording1.csv" + assert csv_path.exists() + + def test_csv_contains_correct_columns(self, tmp_path): + """Test that CSV has correct column structure.""" + import pandas as pd + + results = { + "rec1": { + "syllable": np.array([0, 1, 2]), + "centroid": np.array([[1.0, 2.0], [1.1, 2.1], [1.2, 2.2]]), + "heading": np.array([0.1, 0.2, 
0.3]), + "latent_state": np.random.randn(3, 5), + } + } + + save_dir = str(tmp_path / "csv_test") + save_results_as_csv(results, save_dir=save_dir) + + df = pd.read_csv(Path(save_dir) / "rec1.csv") + + assert "syllable" in df.columns + assert "centroid x" in df.columns + assert "centroid y" in df.columns + assert "heading" in df.columns + assert "latent_state 0" in df.columns + + def test_path_separator_replacement(self, tmp_path): + """Test that path separators are replaced.""" + results = { + "path/to/recording": { + "syllable": np.array([0, 1, 2]), + } + } + + save_dir = str(tmp_path / "csv_pathsep") + save_results_as_csv(results, save_dir=save_dir, path_sep="_") + + csv_path = Path(save_dir) / "path_to_recording.csv" + assert csv_path.exists() + + +@pytest.mark.quick +class TestSaveKeypoints: + """Test save_keypoints function.""" + + def test_saves_coordinates_only(self, tmp_path): + """Test saving coordinates without confidences.""" + import pandas as pd + + coordinates = { + "rec1": np.random.randn(10, 3, 2), # 10 frames, 3 keypoints, 2D + } + bodyparts = ["bp1", "bp2", "bp3"] + + save_dir = str(tmp_path / "keypoints") + save_keypoints(save_dir, coordinates, bodyparts=bodyparts) + + csv_path = Path(save_dir) / "rec1.csv" + assert csv_path.exists() + + df = pd.read_csv(csv_path) + assert "bp1_x" in df.columns + assert "bp1_y" in df.columns + assert "bp2_x" in df.columns + + def test_saves_with_confidences(self, tmp_path): + """Test saving coordinates with confidences.""" + import pandas as pd + + coordinates = { + "rec1": np.random.randn(10, 2, 2), + } + confidences = { + "rec1": np.random.rand(10, 2), + } + bodyparts = ["bp1", "bp2"] + + save_dir = str(tmp_path / "keypoints_conf") + save_keypoints( + save_dir, coordinates, confidences=confidences, bodyparts=bodyparts + ) + + df = pd.read_csv(Path(save_dir) / "rec1.csv") + assert "bp1_conf" in df.columns + assert "bp2_conf" in df.columns + + def test_3d_coordinates(self, tmp_path): + """Test saving 3D 
coordinates.""" + import pandas as pd + + coordinates = { + "rec1": np.random.randn(5, 2, 3), # 5 frames, 2 keypoints, 3D + } + bodyparts = ["bp1", "bp2"] + + save_dir = str(tmp_path / "keypoints_3d") + save_keypoints(save_dir, coordinates, bodyparts=bodyparts) + + df = pd.read_csv(Path(save_dir) / "rec1.csv") + assert "bp1_x" in df.columns + assert "bp1_y" in df.columns + assert "bp1_z" in df.columns + + +@pytest.mark.quick +class TestSetupProject: + """Test setup_project function.""" + + def test_creates_project_directory(self, tmp_path): + """Test that project directory is created.""" + project_dir = str(tmp_path / "new_project") + setup_project(project_dir) + + assert Path(project_dir).exists() + assert (Path(project_dir) / "config.yml").exists() + + def test_existing_directory_no_overwrite(self, tmp_path, capsys): + """Test that existing directory is not overwritten without flag.""" + project_dir = str(tmp_path / "existing") + setup_project(project_dir) + + # Try to setup again without overwrite + setup_project(project_dir, overwrite=False) + + captured = capsys.readouterr().out + # Prev failed bc capured was "already\nexists" + assert "already" in captured and "exists" in captured + + def test_existing_directory_with_overwrite(self, tmp_path): + """Test that existing directory can be overwritten with flag.""" + project_dir = str(tmp_path / "existing") + setup_project(project_dir, fps=30) + + # Setup again with overwrite and different fps + setup_project(project_dir, fps=60, overwrite=True) + + config = load_config(project_dir, check_if_valid=False) + assert config["fps"] == 60 + + def test_with_custom_options(self, tmp_path): + """Test setup with custom configuration options.""" + project_dir = str(tmp_path / "custom") + setup_project( + project_dir, + fps=45, + bodyparts=["nose", "tail", "back"], + verbose=True, + ) + + config = load_config(project_dir, check_if_valid=False) + assert config["fps"] == 45 + assert config["bodyparts"] == ["nose", "tail", 
"back"] + assert config["verbose"] is True + + +@pytest.mark.quick +class TestCheckpointOperations: + """Test checkpoint loading and reindexing.""" + + def test_load_checkpoint_with_explicit_path(self, tmp_path): + """Test loading checkpoint from explicit path.""" + checkpoint_path = str(tmp_path / "checkpoint.h5") + + # Create mock checkpoint + model_data = { + "params": {"pi": np.eye(3)}, + "states": {"z": np.array([0, 1, 2])}, + } + data = {"Y": np.random.randn(100, 10)} + metadata = {"keys": ["rec1"], "bounds": np.array([[0, 100]])} + + save_hdf5( + checkpoint_path, + {"model_snapshots": {"50": model_data}}, + exist_ok=True, + ) + save_hdf5(checkpoint_path, {"data": data}, exist_ok=True) + save_hdf5(checkpoint_path, {"metadata": metadata}, exist_ok=True) + + model, loaded_data, loaded_metadata, iteration = load_checkpoint( + path=checkpoint_path + ) + + assert iteration == 50 + assert "params" in model + assert "states" in model + + def test_load_checkpoint_from_project_dir(self, tmp_path): + """Test loading checkpoint using project_dir and model_name.""" + project_dir = str(tmp_path) + model_name = "test_model" + checkpoint_path = tmp_path / model_name / "checkpoint.h5" + checkpoint_path.parent.mkdir(parents=True) + + # Create minimal checkpoint + model_data = { + "params": {"pi": np.eye(2)}, + "states": {"z": np.array([0, 1])}, + } + + save_hdf5( + str(checkpoint_path), + {"model_snapshots": {"100": model_data}}, + exist_ok=True, + ) + save_hdf5( + str(checkpoint_path), + {"data": {"Y": np.random.randn(10, 5)}}, + exist_ok=True, + ) + save_hdf5( + str(checkpoint_path), + {"metadata": {"keys": ["rec1"], "bounds": np.array([[0, 10]])}}, + exist_ok=True, + ) + + model, data, metadata, iteration = load_checkpoint( + project_dir=project_dir, model_name=model_name + ) + + assert iteration == 100 + + def test_load_checkpoint_specific_iteration(self, tmp_path): + """Test loading specific iteration from checkpoint.""" + checkpoint_path = str(tmp_path / 
"multi_snapshot.h5") + + # Create checkpoint with multiple snapshots + for it in [10, 20, 30]: + model_data = { + "params": {"value": it}, + "states": {"z": np.array([it])}, + } + save_hdf5( + checkpoint_path, + {f"model_snapshots/{it}": model_data}, + exist_ok=True, + ) + + save_hdf5( + checkpoint_path, + {"data": {"Y": np.random.randn(10, 5)}}, + exist_ok=True, + ) + save_hdf5( + checkpoint_path, + {"metadata": {"keys": ["rec1"], "bounds": np.array([[0, 10]])}}, + exist_ok=True, + ) + + # Load iteration 20 + model, _, _, iteration = load_checkpoint(path=checkpoint_path, iteration=20) + + assert iteration == 20 + assert model["params"]["value"] == 20 + + def test_reindex_syllables_modifies_checkpoint(self, tmp_path): + """Test that reindex_syllables modifies checkpoint in place.""" + checkpoint_path = str(tmp_path / "reindex.h5") + num_states = 5 + + # Create checkpoint with model snapshot + model_data = { + "params": { + "betas": np.arange(num_states), + "pi": np.eye(num_states), + "Ab": np.arange(num_states), + "Q": np.arange(num_states), + }, + "states": { + "z": np.array([0, 1, 2, 3, 4, 0, 1]), + }, + } + + save_hdf5( + checkpoint_path, + {"model_snapshots": {"50": model_data}}, + exist_ok=True, + ) + save_hdf5( + checkpoint_path, + {"data": {"mask": np.ones(7, dtype=bool)}}, + exist_ok=True, + ) + + # Reindex with custom index (reverse order) + custom_index = np.array([4, 3, 2, 1, 0]) + returned_index = reindex_syllables_in_checkpoint( + path=checkpoint_path, index=custom_index + ) + + np.testing.assert_array_equal(returned_index, custom_index) + + # Load and verify reindexing happened + reindexed_model = load_hdf5(checkpoint_path, "model_snapshots/50") + + # betas should be reordered + np.testing.assert_array_equal( + reindexed_model["params"]["betas"], np.array([4, 3, 2, 1, 0]) + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_modeling.py b/tests/test_modeling.py new file mode 100644 index 0000000..79f6662 --- 
/dev/null +++ b/tests/test_modeling.py @@ -0,0 +1,181 @@ +""" +Test suite for keypoint-MoSeq modeling functionality + +Tests model initialization, fitting, and checkpoint management. +""" + +from pathlib import Path + +import h5py +import numpy as np +import pytest + + +@pytest.mark.medium +@pytest.mark.notebook +def test_model_initialization(prepared_model): + """Test model initialization with hyperparameters + + Expected duration: <5 seconds (uses prepared_model fixture) + """ + # Get prepared model from fixture + model = prepared_model["model"] + + # Verify model was initialized + assert model is not None, "Model initialization returned None" + + # Verify model structure (model is a dict, not an object) + assert "states" in model, "Model missing states key" + assert "params" in model, "Model missing params key" + assert "hypparams" in model, "Model missing hypparams key" + + +@pytest.mark.integration +@pytest.mark.notebook +def test_model_fitting_sequence(prepared_model, reduced_iterations, kpms): + """Test sequential model fitting: AR-HMM → full model + + Expected duration: ~10 minutes (uses prepared_model fixture) + """ + # Get prepared model from fixture + model = prepared_model["model"] + data = prepared_model["data"] + metadata = prepared_model["metadata"] + project_dir = prepared_model["project_dir"] + + # Test AR-HMM fitting + model, model_name = kpms.fit_model( + model, + data, + metadata, + project_dir, + ar_only=True, + num_iters=reduced_iterations["ar_hmm_iters"], + ) + assert model is not None, "AR-HMM fitting failed" + assert model_name is not None, "Model name is None" + + # Test full model fitting + model_fitted, _ = kpms.fit_model( + model, + data, + metadata, + project_dir, + ar_only=False, + num_iters=reduced_iterations["full_model_iters"], + ) + assert model_fitted is not None, "Full model fitting failed" + + +@pytest.mark.medium +@pytest.mark.notebook +def test_model_saving_and_loading(prepared_model, kpms): + """Test model checkpoint saving 
and loading + + Expected duration: ~2 minutes (uses prepared_model fixture) + """ + # Get prepared model from fixture + model = prepared_model["model"] + data = prepared_model["data"] + metadata = prepared_model["metadata"] + project_dir = prepared_model["project_dir"] + + # Quick fit - fit_model automatically saves checkpoint + model, model_name = kpms.fit_model( + model, + data, + metadata, + project_dir, + ar_only=True, + num_iters=5, # Very short for speed + ) + + assert model_name is not None, "Model name is None" + + # Check checkpoint file was created by fit_model + checkpoint_path = Path(project_dir) / model_name / "checkpoint.h5" + assert checkpoint_path.exists(), "Checkpoint not saved" + + # Verify checkpoint structure + with h5py.File(checkpoint_path, "r") as f: + assert "model_snapshots" in f, "Checkpoint missing model_snapshots group" + assert "data" in f, "Checkpoint missing data group" + assert "metadata" in f, "Checkpoint missing metadata group" + + # Test reindexing + kpms.reindex_syllables_in_checkpoint(project_dir, model_name) + + # Checkpoint should still exist after reindexing + assert checkpoint_path.exists(), "Checkpoint removed after reindexing" + + +@pytest.mark.quick +@pytest.mark.notebook +def test_hyperparameter_estimation( + temp_project_dir, dlc_config, dlc_videos_dir, kpms, update_kwargs +): + """Test hyperparameter estimation (sigmasq_loc) + + Expected duration: < 5 seconds + """ + project_dir = temp_project_dir + + # Setup - use update_kwargs fixture for standard config + kpms.setup_project(project_dir, deeplabcut_config=dlc_config, overwrite=True) + + # Use different anterior/posterior for this test (testing edge case) + kpms.update_config( + project_dir, + use_bodyparts=update_kwargs["use_bodyparts"], + anterior_bodyparts=["head", "nose", "right ear", "left ear"], + posterior_bodyparts=["spine4", "spine3", "spine2", "spine1"], + ) + + # Prepare data + coordinates, confidences, _ = kpms.load_keypoints(dlc_videos_dir, "deeplabcut") 
+ config = kpms.load_config(project_dir) + data, metadata = kpms.format_data(coordinates, confidences, **config) + + # Fit PCA + _ = kpms.fit_pca(**data, **config) + + # Estimate sigmasq_loc hyperparameter (this is what keypoint_moseq provides) + sigmasq_loc = kpms.estimate_sigmasq_loc( + data["Y"], data["mask"], filter_size=config["fps"] + ) + + # Verify estimate is reasonable + assert isinstance( + sigmasq_loc, (int, float, np.number) + ), "sigmasq_loc should be numeric" + assert sigmasq_loc > 0, "sigmasq_loc should be positive" + assert sigmasq_loc < 100, "sigmasq_loc should be reasonable (< 100)" + + +@pytest.mark.quick +def test_config_update(temp_project_dir, dlc_config, kpms, update_kwargs): + """Test configuration update and persistence + + Expected duration: < 1 second + """ + project_dir = temp_project_dir + + # Setup + kpms.setup_project(project_dir, deeplabcut_config=dlc_config, overwrite=True) + + # Update config with required bodyparts first (using standard config) + kpms.update_config( + project_dir, + use_bodyparts=update_kwargs["use_bodyparts"], + anterior_bodyparts=["head", "nose", "right ear", "left ear"], + posterior_bodyparts=["spine4", "spine3", "spine2", "spine1"], + ) + + # Update config with a real parameter (latent_dim) + test_value = 4 + kpms.update_config(project_dir, latent_dim=test_value) + + # Verify update persisted + config = kpms.load_config(project_dir) + assert "latent_dim" in config["ar_hypparams"], "Config update not persisted" + assert config["ar_hypparams"]["latent_dim"] == test_value, "Config value mismatch" diff --git a/tests/test_util.py b/tests/test_util.py new file mode 100644 index 0000000..ca13019 --- /dev/null +++ b/tests/test_util.py @@ -0,0 +1,863 @@ +"""Unit tests for keypoint_moseq.util module. + +This module tests utility functions for data manipulation, validation, +and processing in the keypoint-moseq package. 
+""" + +import warnings +from unittest.mock import MagicMock, Mock, patch + +import numpy as np +import pytest + +from keypoint_moseq.util import ( + _find_optimal_segment_length, + _get_percent_padding, + apply_syllable_mapping, + check_nan_proportions, + check_video_paths, + downsample_timepoints, + estimate_sigmasq_loc, + filter_angle, + filtered_derivative, + find_matching_videos, + find_medoid_distance_outliers, + generate_syllable_mapping, + get_distance_to_medoid, + get_edges, + get_syllable_instances, + interpolate_along_axis, + interpolate_keypoints, + list_files_with_exts, + pad_along_axis, + permute_cyclic, + print_dims_to_explain_variance, + reindex_by_bodyparts, +) + + +class TestPadAlongAxis: + """Test pad_along_axis function.""" + + def test_pad_axis_0(self): + """Test padding along axis 0.""" + arr = np.ones((3, 4)) + result = pad_along_axis(arr, (1, 2), axis=0, value=0) + assert result.shape == (6, 4) + assert result[0, 0] == 0 # First row padded + assert result[-1, 0] == 0 # Last row padded + assert result[1, 0] == 1 # Original data + + def test_pad_axis_1(self): + """Test padding along axis 1.""" + arr = np.ones((3, 4)) + result = pad_along_axis(arr, (2, 1), axis=1, value=5) + assert result.shape == (3, 7) + assert result[0, 0] == 5 # First column padded + assert result[0, -1] == 5 # Last column padded + assert result[0, 2] == 1 # Original data + + def test_pad_custom_value(self): + """Test padding with custom value.""" + arr = np.zeros((2, 2)) + result = pad_along_axis(arr, (1, 1), axis=0, value=99) + assert result[0, 0] == 99 + assert result[-1, 0] == 99 + + +class TestFilterAngle: + """Test filter_angle function.""" + + def test_median_filter(self): + """Test median filtering of angles.""" + # Create angles with some noise + np.random.seed(42) + angles = np.linspace(0, 2 * np.pi, 100) + noisy_angles = angles + np.random.randn(100) * 0.5 + + result = filter_angle(noisy_angles, size=9, axis=0, method="median") + assert result.shape == 
noisy_angles.shape + # Filtered angles should be smoother (higher noise for more obvious effect) + assert np.std(np.diff(result)) < np.std(np.diff(noisy_angles)) + + def test_gaussian_filter(self): + """Test Gaussian filtering of angles.""" + angles = np.linspace(0, 2 * np.pi, 100) + result = filter_angle(angles, size=5, axis=0, method="gaussian") + assert result.shape == angles.shape + + def test_filter_2d_array(self): + """Test filtering 2D array of angles along axis.""" + angles = np.random.randn(50, 3) # Multiple angle sequences + result = filter_angle(angles, size=7, axis=0, method="median") + assert result.shape == angles.shape + + +class TestGetEdges: + """Test get_edges function.""" + + def test_edges_from_indices(self): + """Test edge list from index pairs.""" + use_bodyparts = ["nose", "left_ear", "right_ear"] + skeleton = [[0, 1], [0, 2]] + edges = get_edges(use_bodyparts, skeleton) + assert edges == [[0, 1], [0, 2]] + + def test_edges_from_names(self): + """Test edge list from bodypart names.""" + use_bodyparts = ["nose", "left_ear", "right_ear", "neck"] + skeleton = [ + ("nose", "left_ear"), + ("nose", "right_ear"), + ("nose", "neck"), + ] + edges = get_edges(use_bodyparts, skeleton) + assert len(edges) == 3 + assert [0, 1] in edges + assert [0, 2] in edges + assert [0, 3] in edges + + def test_edges_partial_skeleton(self): + """Test edges when some bodyparts not in use_bodyparts.""" + use_bodyparts = ["nose", "left_ear"] + skeleton = [ + ("nose", "left_ear"), + ("nose", "tail"), + ] # tail not in use_bodyparts + edges = get_edges(use_bodyparts, skeleton) + assert len(edges) == 1 + assert [0, 1] in edges + + def test_empty_skeleton(self): + """Test with empty skeleton.""" + edges = get_edges(["nose"], []) + assert edges == [] + + +class TestReindexByBodyparts: + """Test reindex_by_bodyparts function.""" + + def test_reindex_array(self): + """Test reindexing a single array.""" + data = np.arange(12).reshape(3, 4) # 3 frames, 4 bodyparts + bodyparts = 
["a", "b", "c", "d"] + use_bodyparts = ["d", "b", "a"] + + result = reindex_by_bodyparts(data, bodyparts, use_bodyparts, axis=1) + assert result.shape == (3, 3) + assert np.array_equal(result[:, 0], data[:, 3]) # d + assert np.array_equal(result[:, 1], data[:, 1]) # b + assert np.array_equal(result[:, 2], data[:, 0]) # a + + def test_reindex_dict(self): + """Test reindexing a dictionary of arrays.""" + data = { + "rec1": np.arange(8).reshape(2, 4), + "rec2": np.arange(8, 16).reshape(2, 4), + } + bodyparts = ["a", "b", "c", "d"] + use_bodyparts = ["c", "a"] + + result = reindex_by_bodyparts(data, bodyparts, use_bodyparts, axis=1) + assert isinstance(result, dict) + assert result["rec1"].shape == (2, 2) + assert np.array_equal(result["rec1"][:, 0], data["rec1"][:, 2]) # c + assert np.array_equal(result["rec1"][:, 1], data["rec1"][:, 0]) # a + + +class TestInterpolateAlongAxis: + """Test interpolate_along_axis function.""" + + def test_linear_interpolation(self): + """Test linear interpolation along axis.""" + xp = np.array([0, 2, 4]) + fp = np.array([[0, 0], [10, 10], [20, 20]]) + x = np.array([0, 1, 2, 3, 4]) + + result = interpolate_along_axis(x, xp, fp, axis=0) + assert result.shape == (5, 2) + assert np.allclose(result[1], [5, 5]) # Midpoint between 0 and 10 + assert np.allclose(result[3], [15, 15]) # Midpoint between 10 and 20 + + def test_extrapolation(self): + """Test that interpolation extrapolates beyond data range.""" + xp = np.array([1, 2]) + fp = np.array([10, 20]) + x = np.array([0, 1, 2, 3]) + + result = interpolate_along_axis(x, xp, fp, axis=0) + assert result[0] == 10 # Extrapolates to first value + assert result[-1] == 20 # Extrapolates to last value + + def test_empty_datapoints_raises(self): + """Test that empty datapoints raises assertion.""" + xp = np.array([]) + fp = np.array([]).reshape(0, 2) + x = np.array([0, 1, 2]) + + with pytest.raises( + AssertionError, match="cannot interpolate without datapoints" + ): + interpolate_along_axis(x, xp, fp, 
axis=0) + + +class TestInterpolateKeypoints: + """Test interpolate_keypoints function.""" + + def test_no_outliers(self): + """Test interpolation with no outliers.""" + coordinates = np.random.randn(10, 3, 2) # 10 frames, 3 keypoints, 2D + outliers = np.zeros((10, 3), dtype=bool) + + result = interpolate_keypoints(coordinates, outliers) + assert np.allclose(result, coordinates) + + def test_single_outlier(self): + """Test interpolation of single outlier frame.""" + coordinates = np.array( + [ + [[0, 0], [1, 1]], + [[5, 5], [6, 6]], # Outlier frame + [[2, 2], [3, 3]], + ] + ) + outliers = np.array( + [ + [False, False], + [True, True], + [False, False], + ] + ) + + result = interpolate_keypoints(coordinates, outliers) + # Frame 1 should be interpolated between frames 0 and 2 + assert np.allclose(result[1, 0], [1, 1]) + assert np.allclose(result[1, 1], [2, 2]) + + def test_all_outliers_for_keypoint(self): + """Test when all frames are outliers for a keypoint.""" + coordinates = np.random.randn(5, 2, 2) + outliers = np.zeros((5, 2), dtype=bool) + outliers[:, 1] = True # All frames outliers for keypoint 1 + + result = interpolate_keypoints(coordinates, outliers) + # Keypoint 1 should be all zeros (no valid data to interpolate) + assert np.allclose(result[:, 1], 0) + + +class TestFilteredDerivative: + """Test filtered_derivative function.""" + + def test_constant_signal(self): + """Test derivative of constant signal is zero.""" + Y = np.ones((100, 3)) + dY = filtered_derivative(Y, ksize=5, axis=0) + assert dY.shape == Y.shape + assert np.allclose(dY, 0, atol=1e-10) + + def test_linear_signal(self): + """Test derivative of linear signal is constant.""" + Y = np.arange(100).reshape(-1, 1).astype(float) + dY = filtered_derivative(Y, ksize=3, axis=0) + # The filtered derivative algorithm uses forward - backward convolution + # For linear signal, derivative should be constant (but value depends on kernel) + # Just check that variance is low (derivative is relatively 
constant) + assert np.std(dY[10:-10]) < 0.5 + + def test_axis_parameter(self): + """Test derivative along different axis.""" + Y = np.arange(20).reshape(4, 5).astype(float) + dY_axis0 = filtered_derivative(Y, ksize=1, axis=0) + dY_axis1 = filtered_derivative(Y, ksize=1, axis=1) + assert dY_axis0.shape == Y.shape + assert dY_axis1.shape == Y.shape + + +class TestPermuteCyclic: + """Test permute_cyclic function.""" + + def test_permutation_shape(self): + """Test permutation preserves shape.""" + arr = np.arange(20).reshape(10, 2) + result = permute_cyclic(arr, axis=0) + assert result.shape == arr.shape + + def test_permutation_with_mask(self): + """Test permutation with mask.""" + np.random.seed(42) + arr = np.arange(10) + mask = np.zeros(10, dtype=int) + mask[:5] = 1 # Only permute first 5 elements + + result = permute_cyclic(arr, mask=mask, axis=0) + # Last 5 elements should be zeros (not permuted, kept as masked) + assert np.all(result[5:] == 0) + + def test_permutation_preserves_values(self): + """Test permutation preserves values (just reorders).""" + arr = np.array([1, 2, 3, 4, 5]) + result = permute_cyclic(arr, axis=0) + assert set(result) == set(arr) + + +class TestDownsampleTimepoints: + """Test downsample_timepoints function.""" + + def test_downsample_array(self): + """Test downsampling an array.""" + data = np.arange(100).reshape(100, 1) + downsampled, indexes = downsample_timepoints(data, downsample_rate=2) + + assert downsampled.shape == (50, 1) + assert np.array_equal(indexes, np.arange(50) * 2) + assert np.array_equal(downsampled[:, 0], data[::2, 0]) + + def test_downsample_dict(self): + """Test downsampling a dictionary.""" + data = { + "rec1": np.arange(10).reshape(10, 1), + "rec2": np.arange(20).reshape(20, 1), + } + downsampled, indexes = downsample_timepoints(data, downsample_rate=3) + + assert isinstance(downsampled, dict) + assert downsampled["rec1"].shape == (4, 1) + assert downsampled["rec2"].shape == (7, 1) + assert indexes["rec1"][0] == 0 + 
assert indexes["rec1"][1] == 3 + + +class TestGetPercentPadding: + """Test _get_percent_padding function.""" + + def test_no_padding_needed(self): + """Test when sequences are exact multiples of segment length.""" + sequence_lengths = np.array([10, 20, 30]) + seg_length = 10 + percent = _get_percent_padding(sequence_lengths, seg_length) + assert percent == 0.0 + + def test_padding_needed(self): + """Test when padding is needed.""" + sequence_lengths = np.array([8, 15, 4]) + seg_length = 10 + # 8 needs 2, 15 needs 5, 4 needs 6 = 13 total padding + # Total length = 27, so 13/27 * 100 = 48.15% + percent = _get_percent_padding(sequence_lengths, seg_length) + assert np.isclose(percent, 48.148, atol=0.01) + + def test_single_sequence(self): + """Test with single sequence.""" + sequence_lengths = np.array([23]) + seg_length = 10 + # 23 needs 7 padding to reach 30 + # 7/23 * 100 = 30.43% + percent = _get_percent_padding(sequence_lengths, seg_length) + assert np.isclose(percent, 30.43, atol=0.01) + + +class TestFindOptimalSegmentLength: + """Test _find_optimal_segment_length function.""" + + def test_optimal_length_exact_match(self): + """Test when sequence lengths are available options.""" + sequence_lengths = np.array([100, 200, 150]) + seg_length = _find_optimal_segment_length( + sequence_lengths, + max_seg_length=200, + max_percent_padding=50, + min_fragment_length=4, + ) + assert seg_length <= 200 + assert seg_length >= 5 # Must be > min_fragment_length + + def test_respects_min_fragment_length(self): + """Test that result respects min_fragment_length.""" + sequence_lengths = np.array([100, 103, 107]) + seg_length = _find_optimal_segment_length( + sequence_lengths, + max_seg_length=100, + max_percent_padding=50, + min_fragment_length=10, + ) + # All remainders should be >= 10 or == 0 + remainders = sequence_lengths % seg_length + assert np.all((remainders >= 10) | (remainders == 0)) + + def test_short_sequences_raise(self): + """Test that sequences shorter than 
min_fragment_length raise.""" + sequence_lengths = np.array([10, 3, 8]) # 3 is too short + with pytest.raises(AssertionError, match="at least"): + _find_optimal_segment_length(sequence_lengths, min_fragment_length=4) + + +class TestGetDistanceToMedoid: + """Test get_distance_to_medoid function.""" + + def test_2d_coordinates(self): + """Test distance calculation with 2D coordinates.""" + # Simple case: 3 keypoints arranged in line + coordinates = np.array( + [ + [[0, 0], [1, 0], [2, 0]], # Frame 1 + [[0, 1], [1, 1], [2, 1]], # Frame 2 + ] + ) + distances = get_distance_to_medoid(coordinates) + + assert distances.shape == (2, 3) + # Median is (1, 0) for frame 1, so distances should be [1, 0, 1] + assert np.allclose(distances[0], [1, 0, 1]) + + def test_3d_coordinates(self): + """Test distance calculation with 3D coordinates.""" + coordinates = np.array( + [ + [[0, 0, 0], [1, 1, 1], [2, 2, 2]], + ] + ) + distances = get_distance_to_medoid(coordinates) + assert distances.shape == (1, 3) + # Medoid is (1, 1, 1), distances are sqrt(3), 0, sqrt(3) + assert np.allclose(distances[0, 1], 0) + + +class TestFindMedoidDistanceOutliers: + """Test find_medoid_distance_outliers function.""" + + def test_no_outliers(self): + """Test with normally distributed keypoints (no outliers).""" + np.random.seed(42) + coordinates = np.random.randn(100, 5, 2) * 0.1 # Small variance + + result = find_medoid_distance_outliers(coordinates, outlier_scale_factor=6.0) + assert "mask" in result + assert "thresholds" in result + assert result["mask"].shape == (100, 5) + assert result["thresholds"].shape == (5,) + # With scale factor 6, few outliers expected + assert np.sum(result["mask"]) < 50 # Fewer than 50 of the 500 mask entries (10%) + + def test_with_outliers(self): + """Test with injected outliers.""" + np.random.seed(42) + coordinates = np.random.randn(50, 3, 2) * 0.1 + # Add clear outliers + coordinates[10, 0] = [100, 100] # Far from others + coordinates[20, 1] = [-100, -100] + + result = 
find_medoid_distance_outliers(coordinates, outlier_scale_factor=3.0) + # Should detect at least the injected outliers + assert result["mask"][10, 0] + assert result["mask"][20, 1] + + def test_scale_factor_effect(self): + """Test that higher scale factor yields fewer outliers.""" + np.random.seed(42) + coordinates = np.random.randn(100, 4, 2) + + result_low = find_medoid_distance_outliers( + coordinates, outlier_scale_factor=2.0 + ) + result_high = find_medoid_distance_outliers( + coordinates, outlier_scale_factor=10.0 + ) + + # Higher scale factor should have fewer outliers + assert np.sum(result_high["mask"]) < np.sum(result_low["mask"]) + + +class TestGenerateSyllableMapping: + """Test generate_syllable_mapping function.""" + + def test_simple_grouping(self): + """Test basic syllable grouping.""" + results = { + "rec1": {"syllable": np.array([0, 0, 1, 1, 2, 2, 3, 3])}, + "rec2": {"syllable": np.array([0, 1, 2, 3])}, + } + syllable_grouping = [[0, 1], [2, 3]] + + mapping = generate_syllable_mapping(results, syllable_grouping) + + # All syllables should be mapped + assert 0 in mapping and 1 in mapping and 2 in mapping and 3 in mapping + # Grouped syllables should map to same index + assert mapping[0] == mapping[1] + assert mapping[2] == mapping[3] + + def test_frequency_based_ordering(self): + """Test that groups are ordered by frequency.""" + results = { + "rec1": {"syllable": np.array([0] * 100 + [1] * 10 + [2] * 50)}, + } + syllable_grouping = [[0, 2]] # Group high-frequency syllables + + mapping = generate_syllable_mapping(results, syllable_grouping) + # Group [0, 2] (150 occurrences) should get lower index than singleton [1] (10 occurrences) + assert mapping[0] < mapping[1] + assert mapping[2] < mapping[1] + + def test_no_grouping(self): + """Test with empty grouping (all syllables separate).""" + results = { + "rec1": {"syllable": np.array([0, 1, 2, 0, 1, 2])}, + } + syllable_grouping = [] + + mapping = generate_syllable_mapping(results, syllable_grouping) + 
# Should create identity-like mapping based on frequency + assert len(mapping) == 3 + assert set(mapping.values()) == {0, 1, 2} + + +class TestApplySyllableMapping: + """Test apply_syllable_mapping function.""" + + def test_simple_remapping(self): + """Test basic syllable remapping.""" + results = { + "rec1": { + "syllable": np.array([0, 1, 2, 3]), + "centroid": np.array([[0, 0], [1, 1], [2, 2], [3, 3]]), + } + } + mapping = {0: 5, 1: 6, 2: 7, 3: 8} + + remapped = apply_syllable_mapping(results, mapping) + + assert np.array_equal(remapped["rec1"]["syllable"], [5, 6, 7, 8]) + # Other fields should be copied unchanged + assert np.array_equal(remapped["rec1"]["centroid"], results["rec1"]["centroid"]) + + def test_collapsing_syllables(self): + """Test mapping multiple syllables to same index.""" + results = { + "rec1": { + "syllable": np.array([0, 1, 2, 1, 0]), + "heading": np.array([1.0, 2.0, 3.0, 4.0, 5.0]), + } + } + mapping = {0: 0, 1: 0, 2: 1} # Collapse 0 and 1 to 0 + + remapped = apply_syllable_mapping(results, mapping) + + assert np.array_equal(remapped["rec1"]["syllable"], [0, 0, 1, 0, 0]) + assert np.array_equal(remapped["rec1"]["heading"], results["rec1"]["heading"]) + + def test_multiple_recordings(self): + """Test remapping across multiple recordings.""" + results = { + "rec1": {"syllable": np.array([0, 1])}, + "rec2": {"syllable": np.array([1, 2])}, + } + mapping = {0: 10, 1: 11, 2: 12} + + remapped = apply_syllable_mapping(results, mapping) + + assert np.array_equal(remapped["rec1"]["syllable"], [10, 11]) + assert np.array_equal(remapped["rec2"]["syllable"], [11, 12]) + + +class TestListFilesWithExts: + """Test list_files_with_exts function.""" + + def test_single_file_match(self, tmp_path): + """Test finding single file with extension.""" + test_file = tmp_path / "test.txt" + test_file.write_text("content") + + result = list_files_with_exts(str(tmp_path), [".txt"], recursive=False) + assert len(result) == 1 + assert test_file.name in result[0] + + def 
test_multiple_extensions(self, tmp_path): + """Test finding files with multiple extensions.""" + (tmp_path / "file1.txt").write_text("a") + (tmp_path / "file2.csv").write_text("b") + (tmp_path / "file3.json").write_text("c") + + result = list_files_with_exts(str(tmp_path), [".txt", ".csv"], recursive=False) + assert len(result) == 2 + + def test_recursive_search(self, tmp_path): + """Test recursive file search.""" + subdir = tmp_path / "subdir" + subdir.mkdir() + (tmp_path / "file1.txt").write_text("a") + (subdir / "file2.txt").write_text("b") + + result = list_files_with_exts(str(tmp_path), [".txt"], recursive=True) + assert len(result) == 2 + + def test_extension_normalization(self, tmp_path): + """Test that extensions are normalized (case, leading dot).""" + (tmp_path / "file.TXT").write_text("a") + + result = list_files_with_exts(str(tmp_path), ["txt"], recursive=False) + assert len(result) == 1 + + +class TestFindMatchingVideos: + """Test find_matching_videos function.""" + + def test_exact_match(self, tmp_path): + """Test exact video name matching.""" + (tmp_path / "video1.mp4").write_text("fake video") + (tmp_path / "video2.avi").write_text("fake video") + + keys = ["video1", "video2"] + result = find_matching_videos( + keys, str(tmp_path), as_dict=True, recursive=False + ) + + assert "video1" in result + assert "video2" in result + assert "video1.mp4" in result["video1"] + + def test_prefix_match(self, tmp_path): + """Test prefix matching (recording names have more text).""" + (tmp_path / "vid.mp4").write_text("fake") + + keys = ["vid_2024_session1"] + result = find_matching_videos( + keys, str(tmp_path), as_dict=False, recursive=False + ) + + assert len(result) == 1 + assert "vid.mp4" in result[0] + + def test_longest_match(self, tmp_path): + """Test that longest matching video name is used.""" + (tmp_path / "video.mp4").write_text("fake") + (tmp_path / "video_long.mp4").write_text("fake") + + keys = ["video_long_session"] + result = find_matching_videos( 
+ keys, str(tmp_path), as_dict=False, recursive=False + ) + + # Should match "video_long" not "video" + assert "video_long.mp4" in result[0] + + def test_no_match_raises(self, tmp_path): + """Test that missing video raises assertion.""" + keys = ["nonexistent"] + + with pytest.raises(AssertionError, match="No matching videos"): + find_matching_videos(keys, str(tmp_path), as_dict=False, recursive=False) + + +class TestCheckVideoPaths: + """Test check_video_paths function.""" + + def test_valid_paths(self, tmp_path): + """Test with valid video paths.""" + video1 = tmp_path / "video1.mp4" + video1.write_bytes(b"fake video data") + + with patch("keypoint_moseq.util.OpenCVReader") as mock_reader: + mock_instance = MagicMock() + mock_instance.nframes = 100 + mock_reader.return_value = mock_instance + + video_paths = {"rec1": str(video1)} + keys = ["rec1"] + + # Should not raise + check_video_paths(video_paths, keys) + + def test_missing_key_raises(self): + """Test that missing key raises ValueError.""" + video_paths = {"rec1": "/path/to/video.mp4"} + keys = ["rec1", "rec2"] # rec2 missing + + with pytest.raises(ValueError, match="require a video path"): + check_video_paths(video_paths, keys) + + def test_nonexistent_video_raises(self): + """Test that nonexistent video raises ValueError.""" + video_paths = {"rec1": "/nonexistent/path/video.mp4"} + keys = ["rec1"] + + with pytest.raises(ValueError, match="do not exist"): + check_video_paths(video_paths, keys) + + def test_unreadable_video_raises(self, tmp_path): + """Test that unreadable video raises ValueError.""" + video1 = tmp_path / "corrupted.mp4" + video1.write_bytes(b"corrupted") + + with patch("keypoint_moseq.util.OpenCVReader") as mock_reader: + mock_reader.side_effect = Exception("Cannot read video") + + video_paths = {"rec1": str(video1)} + keys = ["rec1"] + + with pytest.raises(ValueError, match="not readable"): + check_video_paths(video_paths, keys) + + +class TestCheckNanProportions: + """Test 
check_nan_proportions function.""" + + def test_no_warnings_low_nans(self): + """Test no warnings when NaN proportion is low.""" + coordinates = { + "rec1": np.random.randn(100, 5, 2), + } + # Add few NaNs (< 50% threshold) + coordinates["rec1"][0:10, 0, :] = np.nan + + bodyparts = ["bp1", "bp2", "bp3", "bp4", "bp5"] + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + check_nan_proportions(coordinates, bodyparts, warning_threshold=0.5) + assert len(w) == 0 + + def test_warning_high_nans(self): + """Test warning when NaN proportion exceeds threshold.""" + coordinates = { + "rec1": np.random.randn(100, 3, 2), + } + # Add many NaNs to bodypart 1 (> 50%) + coordinates["rec1"][:, 1, :] = np.nan + + bodyparts = ["bp1", "bp2", "bp3"] + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + check_nan_proportions(coordinates, bodyparts, warning_threshold=0.5) + assert len(w) >= 1 + assert "bp2" in str(w[0].message) + + +class TestGetSyllableInstances: + """Test get_syllable_instances function.""" + + def test_basic_instances(self): + """Test extraction of syllable instances.""" + stateseqs = { + "rec1": np.array( + [0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 0, 0] * 10 + ), # Longer sequence + } + + instances = get_syllable_instances( + stateseqs, + min_duration=3, + pre=5, + post=70, + min_frequency=0, + min_instances=0, + ) + + # Should find instances of syllables (with enough boundary space) + assert len(instances) > 0 + # Each instance is a tuple (name, start, end) + for syllable, instance_list in instances.items(): + for name, start, end in instance_list: + assert name == "rec1" + assert start >= 5 # Respects pre + assert start < len(stateseqs["rec1"]) - 70 # Respects post + + def test_min_duration_filter(self): + """Test filtering by minimum duration.""" + stateseqs = { + "rec1": np.array( + [5, 5, 5, 5] + [0, 0, 1, 1, 1, 0] * 10 + [5, 5, 5, 5] * 20 + ), # Padding + repeats + } + + instances = 
get_syllable_instances( + stateseqs, + min_duration=3, # Syllable 1 always occurs in runs of exactly 3 frames + pre=3, + post=80, + ) + + # Syllable 1 (duration 3) should be included + assert 1 in instances + assert len(instances[1]) >= 1 + + def test_boundary_filtering(self): + """Test filtering instances near sequence boundaries.""" + stateseqs = { + "rec1": np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]), + } + + instances = get_syllable_instances( + stateseqs, + min_duration=3, + pre=3, # Exclude instances starting before frame 3 + post=3, # Exclude instances ending after frame len-3 + ) + + # Syllable 0 starts at frame 0 (excluded) + # Syllable 1 starts at frame 3, ends at 6 (included) + # Syllable 2 starts at frame 6 (excluded, too close to end) + assert 1 in instances + assert len(instances[1]) == 1 + + +class TestPrintDimsToExplainVariance: + """Test print_dims_to_explain_variance function.""" + + def test_sufficient_variance(self, capsys): + """Test printing when sufficient components exist.""" + mock_pca = Mock() + mock_pca.explained_variance_ratio_ = np.array([0.5, 0.3, 0.15, 0.05]) + + print_dims_to_explain_variance(mock_pca, 0.8) + captured = capsys.readouterr() + + # Should find that some components explain >=80% + # The function uses f">={f*100}% of variance explained by..." 
(typo "explained") + assert ">=80" in captured.out or "components" in captured.out + + def test_insufficient_variance(self, capsys): + """Test printing when components don't explain enough variance.""" + mock_pca = Mock() + mock_pca.explained_variance_ratio_ = np.array([0.3, 0.2, 0.15, 0.1]) + + print_dims_to_explain_variance(mock_pca, 0.9) + captured = capsys.readouterr() + + # Should indicate that all components together explain < 90% + assert "All components" in captured.out or "75%" in captured.out + + +class TestEstimateSigmasqLoc: + """Test estimate_sigmasq_loc function.""" + + def test_basic_estimation(self): + """Test basic sigmasq_loc estimation.""" + # Create simple trajectory with known movement + Y = np.zeros((2, 100, 5, 2)) # 2 batches, 100 frames, 5 keypoints, 2D + # Add linear motion + for i in range(100): + Y[:, i, :, 0] = i * 0.1 # Move in x direction + + mask = np.ones((2, 100)) + + result = estimate_sigmasq_loc(Y, mask, filter_size=5) + + # Should return a positive float + assert isinstance(result, float) + assert result > 0 + + def test_with_nans(self): + """Test estimation with masked frames.""" + Y = np.random.randn(3, 50, 4, 2) + mask = np.ones((3, 50)) + mask[:, 20:30] = 0 # Mask out middle frames + + result = estimate_sigmasq_loc(Y, mask, filter_size=10) + + assert isinstance(result, float) + assert not np.isnan(result) + + +# Mark all tests as quick tests +pytestmark = pytest.mark.quick