Skip to content

Commit 453bac8

Browse files
committed
pseudotime_interval in palantir.plot.plot_trajectory
1 parent b44130c commit 453bac8

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ Release Notes
8282
* Expanded and standardized documentation with NumPy-style docstrings throughout the codebase
8383
* Added comprehensive type hints to improve code quality and IDE support
8484
* Remove dependency from `_` methods in scanpy for plotting.
85+
* add `pseudotime_interval` argument to control path length in `palantir.plot.plot_trajectory`
8586

8687
#### Testing and Quality Improvements
8788
* Added comprehensive tests for optional pygam dependency

src/palantir/plot.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2130,6 +2130,7 @@ def plot_trajectory(
21302130
embedding_basis: str = "X_umap",
21312131
cell_color: str = "branch_selection",
21322132
smoothness: float = 1.0,
2133+
pseudotime_interval: Optional[Union[Tuple[float, float], List[float], np.ndarray]] = None,
21332134
n_arrows: int = 5,
21342135
arrowprops: Optional[dict] = dict(),
21352136
scanpy_kwargs: Optional[dict] = dict(),
@@ -2158,6 +2159,8 @@ def plot_trajectory(
21582159
If None, no coloring is applied. Defaults to 'branch_selection'.
21592160
smoothness : float, optional
21602161
Smoothness of fitted trajectory. Higher value means smoother. Defaults to 1.
2162+
pseudotime_interval : tuple, list, np.ndarray, optional
2163+
Interval for pseudotime values. If None, it is automatically determined.
21612164
n_arrows : int, optional
21622165
Number of arrows to plot. Defaults to 5.
21632166
arrowprops : dict, optional
@@ -2200,7 +2203,12 @@ def plot_trajectory(
22002203
mask = fate_mask[branch].astype(bool)
22012204

22022205
pseudotime = pt[mask]
2203-
pseudotime_grid = np.linspace(np.min(pseudotime), np.max(pseudotime), 200)
2206+
if pseudotime_interval is None:
2207+
pseudotime_interval = (np.min(pseudotime), np.max(pseudotime))
2208+
else:
2209+
if len(pseudotime_interval) != 2:
2210+
raise ValueError("pseudotime_interval must be a tuple of two values.")
2211+
pseudotime_grid = np.linspace(pseudotime_interval[0], pseudotime_interval[1], 200)
22042212
ls = smoothness * np.sqrt(np.sum((np.max(umap, axis=0) - np.min(umap, axis=0)) ** 2)) / 20
22052213
umap_est = mellon.FunctionEstimator(ls=ls, sigma=ls, n_landmarks=50)
22062214
umap_trajectory = umap_est.fit_predict(pseudotime, umap[mask, :], pseudotime_grid)

0 commit comments

Comments
 (0)