Skip to content

Commit d23f6e0

Browse files
committed
fix: xticklabels centered and refactor
1 parent 2edaab6 commit d23f6e0

File tree

1 file changed

+76
-33
lines changed

1 file changed

+76
-33
lines changed

grassp/plotting/qc.py

Lines changed: 76 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,73 @@ def bait_volcano_plots(
221221
return axs
222222

223223

224+
def _prepare_marker_profile_data(
    adata: AnnData,
    marker_column: str,
    plot_nan: bool,
    replicate_column: str | None,
):
    """Internal function to prepare data for marker profile plotting.

    Validates the requested columns, resolves the marker categories to plot,
    looks up their colors and locates the positions at which the replicate
    annotation changes.

    Parameters
    ----------
    adata
        AnnData object with proteins in `.var` and samples/fractions in `.obs`.
    marker_column
        Column name in ``adata.obs`` containing marker annotations.
    plot_nan
        If ``True``, all declared categories of the marker column are used and
        a trailing ``NaN`` entry is added when the column has missing values.
        If ``False``, only the non-missing values present in the column are
        used.
    replicate_column
        Column name in ``adata.var`` indicating replicate groups, or ``None``
        to skip the replicate boundary search.

    Returns
    -------
    marker_series
        Series containing marker annotations from ``adata.obs``.
    categories
        Sorted list of marker categories to plot; ``NaN`` (if included) is
        always last.
    palette
        Color lookup for the marker column as returned by scanpy's
        ``_get_palette`` helper.
    replicate_boundaries
        List of integer positions along ``adata.var`` marking the last entry
        before the replicate value changes. Empty when ``replicate_column`` is
        ``None``.

    Raises
    ------
    ValueError
        If ``marker_column`` is not a column of ``adata.obs`` or
        ``replicate_column`` is given but is not a column of ``adata.var``.

    Notes
    -----
    If ``adata.obs[marker_column]`` is not categorical it is converted to a
    categorical column in place.
    """
    # Validation
    if marker_column not in adata.obs.columns:
        raise ValueError(f"Column '{marker_column}' not found in adata.obs")

    if replicate_column is not None and replicate_column not in adata.var.columns:
        raise ValueError(f"Column '{replicate_column}' not found in adata.var")

    # The ``.cat`` accessor used below requires a categorical column.
    if not isinstance(adata.obs[marker_column].dtype, pd.CategoricalDtype):
        adata.obs[marker_column] = adata.obs[marker_column].astype('category')

    marker_series = adata.obs[marker_column]
    if plot_nan:
        # All declared categories, including ones with no observations.
        candidates = marker_series.cat.categories
    else:
        candidates = marker_series.dropna().unique()

    # NaN cannot be ordered, so sort the real categories first and append NaN
    # afterwards when it was requested and actually occurs in the data.
    categories = sorted(cat for cat in candidates if pd.notna(cat))
    if plot_nan and marker_series.isna().any():
        categories.append(np.nan)

    # Get colors for each marker category (private scanpy helper).
    palette = scanpy.plotting._tools.scatterplots._get_palette(adata, marker_column)

    # Find replicate boundaries if replicate_column is provided
    replicate_boundaries = []
    if replicate_column is not None:
        replicate_series = adata.var[replicate_column]
        # Compare each entry with its successor; record the position of the
        # earlier one (the last entry of its replicate) whenever they differ.
        neighbours = zip(replicate_series.iloc[:-1], replicate_series.iloc[1:])
        replicate_boundaries = [
            position
            for position, (current, following) in enumerate(neighbours)
            if current != following
        ]

    return marker_series, categories, palette, replicate_boundaries
289+
290+
224291
def marker_profiles_split(
225292
adata: AnnData,
226293
marker_column: str,
@@ -273,39 +340,15 @@ def marker_profiles_split(
273340
Returns
274341
-------
275342
Returns the array of Axes if ``show`` is ``False``, otherwise ``None``.
276-
"""
277-
if marker_column not in adata.obs.columns:
278-
raise ValueError(f"Column '{marker_column}' not found in adata.obs")
279-
280-
if replicate_column is not None and replicate_column not in adata.var.columns:
281-
raise ValueError(f"Column '{replicate_column}' not found in adata.var")
282-
283-
# Get marker categories
284-
if not isinstance(adata.obs[marker_column].dtype, pd.CategoricalDtype):
285-
adata.obs[marker_column] = adata.obs[marker_column].astype('category')
286-
287-
marker_series = adata.obs[marker_column]
288-
if plot_nan:
289-
categories = marker_series.cat.categories
290-
categories = categories[pd.notna(categories)] if not plot_nan else categories
291-
else:
292-
categories = marker_series.dropna().unique()
293-
294-
categories = sorted([cat for cat in categories if pd.notna(cat)])
295-
if plot_nan and marker_series.isna().any():
296-
categories.append(np.nan)
297343
298-
# Get colors for each marker category
299-
palette = scanpy.plotting._tools.scatterplots._get_palette(adata, marker_column)
300-
301-
# Find replicate boundaries if replicate_column is provided
302-
replicate_boundaries = []
303-
if replicate_column is not None:
304-
replicate_series = adata.var[replicate_column]
305-
# Find indices where replicate changes (last index of each replicate)
306-
for i in range(len(replicate_series) - 1):
307-
if replicate_series.iloc[i] != replicate_series.iloc[i + 1]:
308-
replicate_boundaries.append(i + 0.5)
344+
See Also
345+
--------
346+
marker_profiles : Plot mean profiles with error bands in a single plot.
347+
"""
348+
# Prepare data
349+
marker_series, categories, palette, replicate_boundaries = _prepare_marker_profile_data(
350+
adata, marker_column, plot_nan, replicate_column
351+
)
309352

310353
# Create figure
311354
n_categories = len(categories)
@@ -354,7 +397,7 @@ def marker_profiles_split(
354397
# Set x-tick labels if requested
355398
if xticklabels:
356399
ax.set_xticks(range(n_proteins))
357-
ax.set_xticklabels(adata.var_names, rotation=90, ha="right")
400+
ax.set_xticklabels(adata.var_names, rotation=90, ha="center")
358401

359402
# Add legend to the last subplot with data
360403
if plot_mean and idx == len(categories) - 1:

0 commit comments

Comments
 (0)