@@ -2130,6 +2130,7 @@ def plot_trajectory(
21302130 embedding_basis : str = "X_umap" ,
21312131 cell_color : str = "branch_selection" ,
21322132 smoothness : float = 1.0 ,
2133+ pseudotime_interval : Optional [Union [Tuple [float , float ], List [float ], np .ndarray ]] = None ,
21332134 n_arrows : int = 5 ,
21342135 arrowprops : Optional [dict ] = dict (),
21352136 scanpy_kwargs : Optional [dict ] = dict (),
@@ -2158,6 +2159,8 @@ def plot_trajectory(
21582159 If None, no coloring is applied. Defaults to 'branch_selection'.
21592160 smoothness : float, optional
21602161 Smoothness of fitted trajectory. Higher value means smoother. Defaults to 1.
2162+ pseudotime_interval : tuple, list, np.ndarray, optional
2163+ Interval for pseudotime values. If None, it is automatically determined.
21612164 n_arrows : int, optional
21622165 Number of arrows to plot. Defaults to 5.
21632166 arrowprops : dict, optional
@@ -2200,7 +2203,12 @@ def plot_trajectory(
22002203 mask = fate_mask [branch ].astype (bool )
22012204
22022205 pseudotime = pt [mask ]
2203- pseudotime_grid = np .linspace (np .min (pseudotime ), np .max (pseudotime ), 200 )
2206+ if pseudotime_interval is None :
2207+ pseudotime_interval = (np .min (pseudotime ), np .max (pseudotime ))
2208+ else :
2209+ if len (pseudotime_interval ) != 2 :
2210+ raise ValueError ("pseudotime_interval must be a tuple of two values." )
2211+ pseudotime_grid = np .linspace (pseudotime_interval [0 ], pseudotime_interval [1 ], 200 )
22042212 ls = smoothness * np .sqrt (np .sum ((np .max (umap , axis = 0 ) - np .min (umap , axis = 0 )) ** 2 )) / 20
22052213 umap_est = mellon .FunctionEstimator (ls = ls , sigma = ls , n_landmarks = 50 )
22062214 umap_trajectory = umap_est .fit_predict (pseudotime , umap [mask , :], pseudotime_grid )
0 commit comments