Skip to content

Commit 4edbcfb

Browse files
committed
implement test for plot_trajectories
1 parent a22e628 commit 4edbcfb

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed

tests/test_plot.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import matplotlib
88
import matplotlib.pyplot as plt
99
from matplotlib.markers import MarkerStyle
10+
from matplotlib.lines import Line2D
1011

1112
from palantir.plot import (
1213
density_2d,
@@ -27,6 +28,8 @@
2728
plot_gene_trend_heatmaps,
2829
plot_gene_trend_clusters,
2930
gene_score_histogram,
31+
plot_trajectories,
32+
plot_trajectory,
3033
)
3134
from palantir.presults import PResults
3235

@@ -676,3 +679,92 @@ def test_gene_score_histogram_errors(mock_anndata):
676679
# Test with invalid quantile
677680
with pytest.raises(ValueError):
678681
gene_score_histogram(mock_anndata, "gene_score", quantile=1.5)
682+
683+
684+
def test_plot_trajectory(mock_anndata):
685+
# Test with basic parameters
686+
ax = plot_trajectory(mock_anndata, "a")
687+
assert isinstance(ax, plt.Axes)
688+
assert ax.get_title() == "Branch: a"
689+
plt.close()
690+
691+
# Test with custom parameters
692+
fig, custom_ax = plt.subplots(figsize=(6, 6))
693+
ax = plot_trajectory(
694+
mock_anndata,
695+
"a",
696+
ax=custom_ax,
697+
cell_color="palantir_pseudotime",
698+
smoothness=2.0,
699+
n_arrows=3,
700+
arrowprops={"color": "red"},
701+
pseudotime_interval=(0.2, 0.8)
702+
)
703+
assert ax is custom_ax
704+
plt.close()
705+
706+
707+
def test_plot_trajectory_errors(mock_anndata):
708+
# Test with invalid branch - note the error is actually a TypeError in the implementation
709+
with pytest.raises(TypeError):
710+
plot_trajectory(mock_anndata, "invalid_branch")
711+
712+
# Test with invalid keys
713+
with pytest.raises(KeyError):
714+
plot_trajectory(mock_anndata, "a", pseudo_time_key="invalid_key")
715+
716+
with pytest.raises(KeyError):
717+
plot_trajectory(mock_anndata, "a", masks_key="invalid_key")
718+
719+
with pytest.raises(KeyError):
720+
plot_trajectory(mock_anndata, "a", embedding_basis="invalid_basis")
721+
722+
# Test with invalid pseudotime_interval
723+
with pytest.raises(ValueError):
724+
plot_trajectory(mock_anndata, "a", pseudotime_interval=[0.1, 0.2, 0.3])
725+
726+
727+
def test_plot_trajectories(mock_anndata):
728+
# Test with default parameters
729+
ax = plot_trajectories(mock_anndata)
730+
assert isinstance(ax, plt.Axes)
731+
plt.close()
732+
733+
# Test with specific branches
734+
ax = plot_trajectories(mock_anndata, groups=["a", "b"])
735+
assert isinstance(ax, plt.Axes)
736+
plt.close()
737+
738+
# Test with custom parameters
739+
ax = plot_trajectories(
740+
mock_anndata,
741+
cell_color="palantir_pseudotime",
742+
smoothness=2.0,
743+
n_arrows=3,
744+
arrowprops={"color": "blue"},
745+
outline_arrowprops={"color": "black", "lw": 3},
746+
show_legend=True,
747+
legend_kwargs={"loc": "upper left", "frameon": True}
748+
)
749+
assert isinstance(ax, plt.Axes)
750+
plt.close()
751+
752+
753+
def test_plot_trajectories_errors(mock_anndata):
754+
# Test with invalid branches
755+
with pytest.raises(ValueError):
756+
plot_trajectories(mock_anndata, groups=["invalid_branch"])
757+
758+
# Test with invalid keys
759+
with pytest.raises(KeyError):
760+
plot_trajectories(mock_anndata, pseudo_time_key="invalid_key")
761+
762+
with pytest.raises(KeyError):
763+
plot_trajectories(mock_anndata, masks_key="invalid_key")
764+
765+
with pytest.raises(KeyError):
766+
plot_trajectories(mock_anndata, embedding_basis="invalid_basis")
767+
768+
# Test with invalid pseudotime_interval
769+
with pytest.raises(ValueError):
770+
plot_trajectories(mock_anndata, pseudotime_interval=[0.1, 0.2, 0.3])

0 commit comments

Comments
 (0)