|
7 | 7 | import matplotlib |
8 | 8 | import matplotlib.pyplot as plt |
9 | 9 | from matplotlib.markers import MarkerStyle |
| 10 | +from matplotlib.lines import Line2D |
10 | 11 |
|
11 | 12 | from palantir.plot import ( |
12 | 13 | density_2d, |
|
27 | 28 | plot_gene_trend_heatmaps, |
28 | 29 | plot_gene_trend_clusters, |
29 | 30 | gene_score_histogram, |
| 31 | + plot_trajectories, |
| 32 | + plot_trajectory, |
30 | 33 | ) |
31 | 34 | from palantir.presults import PResults |
32 | 35 |
|
@@ -676,3 +679,92 @@ def test_gene_score_histogram_errors(mock_anndata): |
676 | 679 | # Test with invalid quantile |
677 | 680 | with pytest.raises(ValueError): |
678 | 681 | gene_score_histogram(mock_anndata, "gene_score", quantile=1.5) |
| 682 | + |
| 683 | + |
| 684 | +def test_plot_trajectory(mock_anndata): |
| 685 | + # Test with basic parameters |
| 686 | + ax = plot_trajectory(mock_anndata, "a") |
| 687 | + assert isinstance(ax, plt.Axes) |
| 688 | + assert ax.get_title() == "Branch: a" |
| 689 | + plt.close() |
| 690 | + |
| 691 | + # Test with custom parameters |
| 692 | + fig, custom_ax = plt.subplots(figsize=(6, 6)) |
| 693 | + ax = plot_trajectory( |
| 694 | + mock_anndata, |
| 695 | + "a", |
| 696 | + ax=custom_ax, |
| 697 | + cell_color="palantir_pseudotime", |
| 698 | + smoothness=2.0, |
| 699 | + n_arrows=3, |
| 700 | + arrowprops={"color": "red"}, |
| 701 | + pseudotime_interval=(0.2, 0.8) |
| 702 | + ) |
| 703 | + assert ax is custom_ax |
| 704 | + plt.close() |
| 705 | + |
| 706 | + |
| 707 | +def test_plot_trajectory_errors(mock_anndata): |
| 708 | + # Test with invalid branch - note the error is actually a TypeError in the implementation |
| 709 | + with pytest.raises(TypeError): |
| 710 | + plot_trajectory(mock_anndata, "invalid_branch") |
| 711 | + |
| 712 | + # Test with invalid keys |
| 713 | + with pytest.raises(KeyError): |
| 714 | + plot_trajectory(mock_anndata, "a", pseudo_time_key="invalid_key") |
| 715 | + |
| 716 | + with pytest.raises(KeyError): |
| 717 | + plot_trajectory(mock_anndata, "a", masks_key="invalid_key") |
| 718 | + |
| 719 | + with pytest.raises(KeyError): |
| 720 | + plot_trajectory(mock_anndata, "a", embedding_basis="invalid_basis") |
| 721 | + |
| 722 | + # Test with invalid pseudotime_interval |
| 723 | + with pytest.raises(ValueError): |
| 724 | + plot_trajectory(mock_anndata, "a", pseudotime_interval=[0.1, 0.2, 0.3]) |
| 725 | + |
| 726 | + |
| 727 | +def test_plot_trajectories(mock_anndata): |
| 728 | + # Test with default parameters |
| 729 | + ax = plot_trajectories(mock_anndata) |
| 730 | + assert isinstance(ax, plt.Axes) |
| 731 | + plt.close() |
| 732 | + |
| 733 | + # Test with specific branches |
| 734 | + ax = plot_trajectories(mock_anndata, groups=["a", "b"]) |
| 735 | + assert isinstance(ax, plt.Axes) |
| 736 | + plt.close() |
| 737 | + |
| 738 | + # Test with custom parameters |
| 739 | + ax = plot_trajectories( |
| 740 | + mock_anndata, |
| 741 | + cell_color="palantir_pseudotime", |
| 742 | + smoothness=2.0, |
| 743 | + n_arrows=3, |
| 744 | + arrowprops={"color": "blue"}, |
| 745 | + outline_arrowprops={"color": "black", "lw": 3}, |
| 746 | + show_legend=True, |
| 747 | + legend_kwargs={"loc": "upper left", "frameon": True} |
| 748 | + ) |
| 749 | + assert isinstance(ax, plt.Axes) |
| 750 | + plt.close() |
| 751 | + |
| 752 | + |
| 753 | +def test_plot_trajectories_errors(mock_anndata): |
| 754 | + # Test with invalid branches |
| 755 | + with pytest.raises(ValueError): |
| 756 | + plot_trajectories(mock_anndata, groups=["invalid_branch"]) |
| 757 | + |
| 758 | + # Test with invalid keys |
| 759 | + with pytest.raises(KeyError): |
| 760 | + plot_trajectories(mock_anndata, pseudo_time_key="invalid_key") |
| 761 | + |
| 762 | + with pytest.raises(KeyError): |
| 763 | + plot_trajectories(mock_anndata, masks_key="invalid_key") |
| 764 | + |
| 765 | + with pytest.raises(KeyError): |
| 766 | + plot_trajectories(mock_anndata, embedding_basis="invalid_basis") |
| 767 | + |
| 768 | + # Test with invalid pseudotime_interval |
| 769 | + with pytest.raises(ValueError): |
| 770 | + plot_trajectories(mock_anndata, pseudotime_interval=[0.1, 0.2, 0.3]) |
0 commit comments