Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
name: Tests

on:
push:
branches: [ main, dev ]
pull_request:
branches: [ main ]
schedule:
- cron: '0 0 * * 0' # Run weekly on Sundays at midnight UTC
workflow_dispatch: # Allow manual triggering

jobs:
quick-tests:
name: Quick Tests
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.10', '3.11', '3.12']

steps:
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Cache pip packages
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}
restore-keys: |
${{ runner.os }}-pip-

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e ".[test]"

- name: Run quick tests
run: |
pytest tests/ -m "quick" -v --tb=short

- name: Upload test results
if: always()
uses: actions/upload-artifact@v4
with:
name: quick-test-results-py${{ matrix.python-version }}
path: |
.pytest_cache
test-results.xml
retention-days: 7

full-tests:
name: Full Test Suite
runs-on: ubuntu-latest
if: github.event_name == 'pull_request' || github.ref == 'refs/heads/main'

steps:
- uses: actions/checkout@v4

- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: '3.10'

- name: Cache pip packages
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}
restore-keys: |
${{ runner.os }}-pip-

- name: Cache test data
uses: actions/cache@v3
with:
path: ~/.cache/keypoint_moseq_tests
key: ${{ runner.os }}-test-data-v1
restore-keys: |
${{ runner.os }}-test-data-

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e ".[test]"

- name: Run full test suite (exclude slow)
run: |
pytest tests/ \
--cov=keypoint_moseq \
--cov-report=xml \
--cov-report=term \
-m "not slow" \
-v \
--tb=short \
--junitxml=test-results.xml

- name: Check coverage threshold
run: |
coverage report --fail-under=40

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
fail_ci_if_error: false
verbose: true

- name: Upload test results
if: always()
uses: actions/upload-artifact@v4
with:
name: full-test-results
path: |
coverage.xml
test-results.xml
htmlcov/
retention-days: 30

- name: Comment PR with coverage
if: github.event_name == 'pull_request'
uses: py-cov-action/python-coverage-comment-action@v3
with:
GITHUB_TOKEN: ${{ github.token }}
MINIMUM_GREEN: 50
MINIMUM_ORANGE: 40

slow-tests:
name: Slow Tests (Weekly)
runs-on: ubuntu-latest
if: github.event_name == 'workflow_dispatch' || github.event_name == 'schedule'

steps:
- uses: actions/checkout@v4

- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: '3.10'

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e ".[test]"

- name: Run all tests including slow
run: |
pytest tests/ -v --tb=short --timeout=7200

- name: Upload test results
if: always()
uses: actions/upload-artifact@v4
with:
name: slow-test-results
path: test-results.xml
retention-days: 30
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
**/.DS_Store
testing
update_pypi.sh
docs/source/dlc*
docs/source/demo*
tests/dlc*
temp*

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# Keypoint MoSeq
# Keypoint MoSeq

![Tests](https://github.com/dattalab/keypoint-moseq/actions/workflows/test.yml/badge.svg)
[![codecov](https://codecov.io/gh/dattalab/keypoint-moseq/branch/main/graph/badge.svg)](https://codecov.io/gh/dattalab/keypoint-moseq)
[![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/)

![logo](docs/source/_static/logo.jpg)

Expand Down
22 changes: 21 additions & 1 deletion keypoint_moseq/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,29 @@
from glob import glob
import panel as pn
from jax_moseq.utils import get_durations, get_frequencies
from packaging import version

pn.extension("plotly", "tabulator")
na = np.newaxis

# seaborn version compatibility: evaluate once at import time
_SEABORN_VERSION = version.parse(sns.__version__)
_USE_NATIVE_SCALE = _SEABORN_VERSION >= version.parse("0.14")


def _get_pointplot_errorbar_kwargs():
"""Get the appropriate errorbar kwargs for seaborn pointplot based on version.

seaborn 0.14.0 changed the errorbar API from errorbar=("ci", 68)
to using native_scale parameter and errorbar="se".
"""
if _USE_NATIVE_SCALE:
# seaborn >= 0.14
return {"errorbar": "se", "native_scale": True}
else:
# seaborn < 0.14
return {"errorbar": ("ci", 68)}


def get_syllable_names(project_dir, model_name, syllable_ixs):
"""Get syllable names from syll_info.csv file. Labels consist of the
Expand Down Expand Up @@ -1151,16 +1170,17 @@ def plot_syll_stats_with_sem(

# plot each group's stat data separately, computes groupwise SEM, and orders data based on the stat/ordering parameters
hue = "group" if groups is not None else None
errorbar_kwargs = _get_pointplot_errorbar_kwargs()
ax = sns.pointplot(
data=stats_df,
x="syllable",
y=stat,
hue=hue,
order=ordering,
errorbar=("ci", 68),
ax=ax,
hue_order=groups,
palette=colors,
**errorbar_kwargs,
)

# where some data has already been plotted to ax
Expand Down
2 changes: 1 addition & 1 deletion keypoint_moseq/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,7 @@ def save_keypoints(
bodyparts = [f"bodypart{i}" for i in range(num_keypoints)]

# create column names
suffixes = ["x", "y", "z"][:num_keypoints]
suffixes = ["x", "y", "z"][:num_dims]
if confidences is not None:
suffixes += ["conf"]
columns = [f"{bp}_{suffix}" for bp in bodyparts for suffix in suffixes]
Expand Down
37 changes: 35 additions & 2 deletions keypoint_moseq/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import h5py
import numpy as np
import plotly
import matplotlib
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d
from vidio.read import OpenCVReader
Expand All @@ -22,6 +23,7 @@

from plotly.subplots import make_subplots
import plotly.io as pio
from packaging import version

pio.renderers.default = "iframe"

Expand All @@ -31,6 +33,10 @@
# suppress warnings from imageio
logging.getLogger().setLevel(logging.ERROR)

# matplotlib version compatibility: evaluate once at import time
_MATPLOTLIB_VERSION = version.parse(matplotlib.__version__)
_USE_BUFFER_RGBA = _MATPLOTLIB_VERSION >= version.parse("3.10")


def crop_image(image, centroid, crop_size):
"""Crop an image around a centroid.
Expand Down Expand Up @@ -1428,12 +1434,39 @@ def get_limits(
return lims.astype(int)


def _get_canvas_buffer_method(canvas):
"""Get the appropriate canvas buffer method based on matplotlib version.

matplotlib 3.10 removed tostring_rgb() in favor of buffer_rgba().
This function returns the correct method to call.
"""
if _USE_BUFFER_RGBA:
return canvas.buffer_rgba
else:
return canvas.tostring_rgb


def _reshape_canvas_buffer(raster_flat, height, width):
"""Reshape and convert canvas buffer to RGB format.

For matplotlib >= 3.10, drops the alpha channel from RGBA.
For matplotlib < 3.10, returns RGB directly.
"""
if _USE_BUFFER_RGBA:
# matplotlib >= 3.10: RGBA buffer, drop alpha channel
return raster_flat.reshape((height, width, 4))[:, :, :3]
else:
# matplotlib < 3.10: RGB buffer
return raster_flat.reshape((height, width, 3))


def rasterize_figure(fig):
canvas = fig.canvas
canvas.draw()
width, height = canvas.get_width_height()
raster_flat = np.frombuffer(canvas.tostring_rgb(), dtype="uint8")
raster = raster_flat.reshape((height, width, 3))
buffer_method = _get_canvas_buffer_method(canvas)
raster_flat = np.frombuffer(buffer_method(), dtype="uint8")
raster = _reshape_canvas_buffer(raster_flat, height, width)
return raster


Expand Down
Loading