Skip to content

Commit b3e31f0

Browse files
s6junchengcopybara-github
authored andcommitted
Improve transcript strand arrow plotting.
Arrows are less dominant compared with before. When plotting with a large interval and many transcripts, shrink the size of arrows dynamically based on the actual CDS hight per transcript. Also improved the alignment of arrows with intron line. PiperOrigin-RevId: 878942570 Change-Id: I06cb38d5d291ae9c757017c1514be28375237b3c
1 parent c9311d7 commit b3e31f0

2 files changed

Lines changed: 554 additions & 30 deletions

File tree

src/alphagenome/visualization/plot_transcripts.py

Lines changed: 116 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@
1717
from collections.abc import Sequence
1818
import dataclasses
1919
import enum
20-
import math
2120
from typing import Any
2221

2322
from alphagenome.data import genome
2423
from alphagenome.data import transcript as transcript_utils
2524
import intervaltree
2625
import matplotlib as mpl
26+
import matplotlib.figure
27+
import matplotlib.path
2728
import matplotlib.pyplot as plt
2829

2930

@@ -156,6 +157,7 @@ def plot_transcripts(
156157
label=None
157158
if (label in labels_already_drawn and plot_labels_once)
158159
else label,
160+
num_transcripts=len(transcripts),
159161
**kwargs,
160162
)
161163

@@ -178,6 +180,7 @@ def draw_transcript(
178180
shift: int = 0,
179181
label: str | None = None,
180182
label_color: str = '#7f7f7f',
183+
num_transcripts: int = 1,
181184
**kwargs,
182185
) -> None:
183186
"""Draw an individual transcript as rectangular components on an axis.
@@ -202,6 +205,8 @@ def draw_transcript(
202205
shift: X-axis shift.
203206
label: Optional label to draw next to the transcript.
204207
label_color: Label color.
208+
num_transcripts: Total number of transcripts being drawn, used for dynamic
209+
arrow sizing.
205210
**kwargs: Additional keyword arguments passed to matplotlib plotting
206211
functions.
207212
"""
@@ -228,35 +233,6 @@ def draw_exons_and_introns(exons, color, exon_height):
228233
# 2. Draw all introns.
229234
for intron in transcript_utils.Transcript(exons).introns:
230235
ax.plot([intron.start, intron.end], [y, y], color=color, linewidth=0.5)
231-
# Draw max_num_arrows arrows, at least one arrow per intron.
232-
intron_to_interval_fraction = intron.width / interval.width
233-
# If this fraction is too small we do not draw arrow
234-
if intron_to_interval_fraction < 0.01:
235-
continue
236-
max_num_arrows = 10
237-
num_arrows = min(
238-
math.ceil(intron_to_interval_fraction * max_num_arrows),
239-
max_num_arrows,
240-
)
241-
space = intron.width // num_arrows
242-
243-
for i in range(num_arrows):
244-
arrow_start = intron.start - space // 2 + i * space
245-
# Arrows are at most 10bp from the exons.
246-
if (
247-
arrow_start + space - 10 > interval.end
248-
or arrow_start < interval.start + 10
249-
):
250-
continue
251-
style = '<-' if transcript.is_negative_strand else '->'
252-
ax.annotate(
253-
'',
254-
xy=(arrow_start + space, y),
255-
xytext=(arrow_start + space - 0.001, y),
256-
arrowprops=dict(
257-
arrowstyle=f'{style},head_width=0.3', linewidth=0.5, color=color
258-
),
259-
)
260236

261237
# First draw all exons and introns with UTR height.
262238
draw_exons_and_introns(
@@ -298,6 +274,116 @@ def draw_exons_and_introns(exons, color, exon_height):
298274
transcript.utr3, color=utr3_color, exon_height=utr_height
299275
)
300276

277+
# Draw strand arrows across the full transcript span.
278+
draw_strand_arrows(
279+
ax=ax,
280+
transcript=transcript,
281+
interval=interval,
282+
y=y,
283+
color=cds_color,
284+
cds_height=cds_height,
285+
num_transcripts=num_transcripts,
286+
)
287+
288+
289+
def draw_strand_arrows(
290+
ax: plt.Axes,
291+
transcript: transcript_utils.Transcript,
292+
interval: genome.Interval,
293+
y: float,
294+
color: str,
295+
*,
296+
cds_height: float = 0.22,
297+
num_transcripts: int = 1,
298+
max_arrows_per_intron: int = 5,
299+
) -> None:
300+
"""Draw strand direction arrows on intron lines.
301+
302+
Arrow count per intron is computed dynamically based on the intron's width
303+
relative to the visible interval. Marker size is derived from the UTR height
304+
so arrows are always visually smaller than UTR exons.
305+
306+
Args:
307+
ax: Matplotlib axis.
308+
transcript: The transcript being drawn.
309+
interval: The visible genomic interval.
310+
y: Vertical position of the transcript.
311+
color: Arrow color.
312+
cds_height: CDS height in data coordinates, used to scale arrows.
313+
num_transcripts: Total number of transcripts being drawn.
314+
max_arrows_per_intron: Maximum number of arrows per intron.
315+
"""
316+
introns = transcript_utils.Transcript(transcript.exons).introns
317+
if not introns:
318+
return
319+
320+
fig = ax.get_figure()
321+
if fig is not None:
322+
_, fig_height_inches = (
323+
fig.get_size_inches() # pytype: disable=attribute-error
324+
)
325+
ax_height_inches = ax.get_position().height * fig_height_inches
326+
y_range = num_transcripts + 2
327+
if y_range > 0:
328+
pts_per_data = (ax_height_inches * 72) / y_range
329+
markersize = min(4.0, cds_height * pts_per_data * 2)
330+
else:
331+
markersize = 4.0
332+
else:
333+
markersize = 4.0
334+
335+
# Custom chevron path: two line segments forming > or < shape.
336+
if transcript.is_negative_strand:
337+
chevron = matplotlib.path.Path(
338+
[(0.5, 0.5), (-0.5, 0.0), (0.5, -0.5)],
339+
[
340+
matplotlib.path.Path.MOVETO,
341+
matplotlib.path.Path.LINETO,
342+
matplotlib.path.Path.LINETO,
343+
],
344+
)
345+
else:
346+
chevron = matplotlib.path.Path(
347+
[(-0.5, 0.5), (0.5, 0.0), (-0.5, -0.5)],
348+
[
349+
matplotlib.path.Path.MOVETO,
350+
matplotlib.path.Path.LINETO,
351+
matplotlib.path.Path.LINETO,
352+
],
353+
)
354+
355+
arrow_positions = []
356+
for intron in introns:
357+
intron_to_interval_fraction = intron.width / interval.width
358+
# Skip arrows for introns that are too small.
359+
if intron_to_interval_fraction < 0.01:
360+
continue
361+
# Use sqrt scaling so large introns don't get overwhelmed with arrows.
362+
num_arrows = min(
363+
max(1, round(intron_to_interval_fraction**0.5 * max_arrows_per_intron)),
364+
max_arrows_per_intron,
365+
)
366+
space = intron.width / (num_arrows + 1)
367+
for i in range(1, num_arrows + 1):
368+
arrow_pos = intron.start + i * space
369+
# Skip arrows too close to interval edges.
370+
if arrow_pos < interval.start + 10 or arrow_pos > interval.end - 10:
371+
continue
372+
arrow_positions.append(arrow_pos)
373+
374+
if arrow_positions:
375+
ax.plot(
376+
arrow_positions,
377+
[y] * len(arrow_positions),
378+
marker=chevron,
379+
markersize=markersize,
380+
color=color,
381+
fillstyle='none',
382+
markeredgewidth=0.8,
383+
linestyle='none',
384+
clip_on=True,
385+
)
386+
301387

302388
def draw_interval(
303389
ax: plt.Axes,

0 commit comments

Comments
 (0)