1717from collections .abc import Sequence
1818import dataclasses
1919import enum
20- import math
2120from typing import Any
2221
2322from alphagenome .data import genome
2423from alphagenome .data import transcript as transcript_utils
2524import intervaltree
2625import matplotlib as mpl
26+ import matplotlib .figure
27+ import matplotlib .path
2728import matplotlib .pyplot as plt
2829
2930
@@ -156,6 +157,7 @@ def plot_transcripts(
156157 label = None
157158 if (label in labels_already_drawn and plot_labels_once )
158159 else label ,
160+ num_transcripts = len (transcripts ),
159161 ** kwargs ,
160162 )
161163
@@ -178,6 +180,7 @@ def draw_transcript(
178180 shift : int = 0 ,
179181 label : str | None = None ,
180182 label_color : str = '#7f7f7f' ,
183+ num_transcripts : int = 1 ,
181184 ** kwargs ,
182185) -> None :
183186 """Draw an individual transcript as rectangular components on an axis.
@@ -202,6 +205,8 @@ def draw_transcript(
202205 shift: X-axis shift.
203206 label: Optional label to draw next to the transcript.
204207 label_color: Label color.
208+ num_transcripts: Total number of transcripts being drawn, used for dynamic
209+ arrow sizing.
205210 **kwargs: Additional keyword arguments passed to matplotlib plotting
206211 functions.
207212 """
@@ -228,35 +233,6 @@ def draw_exons_and_introns(exons, color, exon_height):
228233 # 2. Draw all introns.
229234 for intron in transcript_utils .Transcript (exons ).introns :
230235 ax .plot ([intron .start , intron .end ], [y , y ], color = color , linewidth = 0.5 )
231- # Draw max_num_arrows arrows, at least one arrow per intron.
232- intron_to_interval_fraction = intron .width / interval .width
233- # If this fraction is too small we do not draw arrow
234- if intron_to_interval_fraction < 0.01 :
235- continue
236- max_num_arrows = 10
237- num_arrows = min (
238- math .ceil (intron_to_interval_fraction * max_num_arrows ),
239- max_num_arrows ,
240- )
241- space = intron .width // num_arrows
242-
243- for i in range (num_arrows ):
244- arrow_start = intron .start - space // 2 + i * space
245- # Arrows are at most 10bp from the exons.
246- if (
247- arrow_start + space - 10 > interval .end
248- or arrow_start < interval .start + 10
249- ):
250- continue
251- style = '<-' if transcript .is_negative_strand else '->'
252- ax .annotate (
253- '' ,
254- xy = (arrow_start + space , y ),
255- xytext = (arrow_start + space - 0.001 , y ),
256- arrowprops = dict (
257- arrowstyle = f'{ style } ,head_width=0.3' , linewidth = 0.5 , color = color
258- ),
259- )
260236
261237 # First draw all exons and introns with UTR height.
262238 draw_exons_and_introns (
@@ -298,6 +274,116 @@ def draw_exons_and_introns(exons, color, exon_height):
298274 transcript .utr3 , color = utr3_color , exon_height = utr_height
299275 )
300276
277+ # Draw strand arrows across the full transcript span.
278+ draw_strand_arrows (
279+ ax = ax ,
280+ transcript = transcript ,
281+ interval = interval ,
282+ y = y ,
283+ color = cds_color ,
284+ cds_height = cds_height ,
285+ num_transcripts = num_transcripts ,
286+ )
287+
288+
289+ def draw_strand_arrows (
290+ ax : plt .Axes ,
291+ transcript : transcript_utils .Transcript ,
292+ interval : genome .Interval ,
293+ y : float ,
294+ color : str ,
295+ * ,
296+ cds_height : float = 0.22 ,
297+ num_transcripts : int = 1 ,
298+ max_arrows_per_intron : int = 5 ,
299+ ) -> None :
300+ """Draw strand direction arrows on intron lines.
301+
302+ Arrow count per intron is computed dynamically based on the intron's width
303+ relative to the visible interval. Marker size is derived from the UTR height
304+ so arrows are always visually smaller than UTR exons.
305+
306+ Args:
307+ ax: Matplotlib axis.
308+ transcript: The transcript being drawn.
309+ interval: The visible genomic interval.
310+ y: Vertical position of the transcript.
311+ color: Arrow color.
312+ cds_height: CDS height in data coordinates, used to scale arrows.
313+ num_transcripts: Total number of transcripts being drawn.
314+ max_arrows_per_intron: Maximum number of arrows per intron.
315+ """
316+ introns = transcript_utils .Transcript (transcript .exons ).introns
317+ if not introns :
318+ return
319+
320+ fig = ax .get_figure ()
321+ if fig is not None :
322+ _ , fig_height_inches = (
323+ fig .get_size_inches () # pytype: disable=attribute-error
324+ )
325+ ax_height_inches = ax .get_position ().height * fig_height_inches
326+ y_range = num_transcripts + 2
327+ if y_range > 0 :
328+ pts_per_data = (ax_height_inches * 72 ) / y_range
329+ markersize = min (4.0 , cds_height * pts_per_data * 2 )
330+ else :
331+ markersize = 4.0
332+ else :
333+ markersize = 4.0
334+
335+ # Custom chevron path: two line segments forming > or < shape.
336+ if transcript .is_negative_strand :
337+ chevron = matplotlib .path .Path (
338+ [(0.5 , 0.5 ), (- 0.5 , 0.0 ), (0.5 , - 0.5 )],
339+ [
340+ matplotlib .path .Path .MOVETO ,
341+ matplotlib .path .Path .LINETO ,
342+ matplotlib .path .Path .LINETO ,
343+ ],
344+ )
345+ else :
346+ chevron = matplotlib .path .Path (
347+ [(- 0.5 , 0.5 ), (0.5 , 0.0 ), (- 0.5 , - 0.5 )],
348+ [
349+ matplotlib .path .Path .MOVETO ,
350+ matplotlib .path .Path .LINETO ,
351+ matplotlib .path .Path .LINETO ,
352+ ],
353+ )
354+
355+ arrow_positions = []
356+ for intron in introns :
357+ intron_to_interval_fraction = intron .width / interval .width
358+ # Skip arrows for introns that are too small.
359+ if intron_to_interval_fraction < 0.01 :
360+ continue
361+ # Use sqrt scaling so large introns don't get overwhelmed with arrows.
362+ num_arrows = min (
363+ max (1 , round (intron_to_interval_fraction ** 0.5 * max_arrows_per_intron )),
364+ max_arrows_per_intron ,
365+ )
366+ space = intron .width / (num_arrows + 1 )
367+ for i in range (1 , num_arrows + 1 ):
368+ arrow_pos = intron .start + i * space
369+ # Skip arrows too close to interval edges.
370+ if arrow_pos < interval .start + 10 or arrow_pos > interval .end - 10 :
371+ continue
372+ arrow_positions .append (arrow_pos )
373+
374+ if arrow_positions :
375+ ax .plot (
376+ arrow_positions ,
377+ [y ] * len (arrow_positions ),
378+ marker = chevron ,
379+ markersize = markersize ,
380+ color = color ,
381+ fillstyle = 'none' ,
382+ markeredgewidth = 0.8 ,
383+ linestyle = 'none' ,
384+ clip_on = True ,
385+ )
386+
301387
302388def draw_interval (
303389 ax : plt .Axes ,
0 commit comments