Skip to content

Commit dfa47bf

Browse files
Improvements following Josh's comments
1 parent c6cf77f commit dfa47bf

File tree

1 file changed

+20
-14
lines changed

1 file changed

+20
-14
lines changed

velociraptor/autoplotter/lines.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44

55
from unyt import unyt_quantity, unyt_array
6-
from numpy import logspace, linspace, log10, logical_and, isnan, sqrt
6+
from numpy import logspace, linspace, log10, logical_and, isnan, sqrt, logical_or
77
from typing import Dict, Union, Tuple, List
88
from matplotlib.pyplot import Axes
99
from matplotlib.transforms import blended_transform_factory
@@ -338,21 +338,22 @@ def highlight_data_outside_domain(
338338
# Arrow parameters
339339
arrow_length = 0.07
340340
distance_from_edge = 0.01
341+
arrow_style = "->"
341342

342343
# Split data into three categories (along X axis)
343344
below_x_range = x < x_lim[0]
344345
above_x_range = x > x_lim[1]
345-
within_x_range = (x > x_lim[0]) * (x < x_lim[1])
346+
within_x_range = logical_and(x > x_lim[0], x < x_lim[1])
346347

347348
# Split data into three categories (along Y axis)
348349
below_y_range = y < y_lim[0]
349350
above_y_range = y > y_lim[1]
350-
within_y_range = (y > y_lim[0]) * (y < y_lim[1])
351+
within_y_range = logical_and(y > y_lim[0], y < y_lim[1])
351352

352353
# First, find all data points that are outside the Y-axis range and within
353354
# X-axis range
354-
below_y_within_x = below_y_range * within_x_range
355-
above_y_within_x = above_y_range * within_x_range
355+
below_y_within_x = logical_and(below_y_range, within_x_range)
356+
above_y_within_x = logical_and(above_y_range, within_x_range)
356357

357358
# X coordinates of the data points whose Y coordinates are outside the
358359
# Y-axis range
@@ -365,13 +366,17 @@ def highlight_data_outside_domain(
365366

366367
# Draw arrows pointing downwards
367368
for x_down in x_down_list:
369+
# We are using 'ax.annotate' instead of 'ax.arrow' because we want the
370+
# arrow's head and tail to have the same size regardless of what the
371+
# axes aspect ratio is or whether the plot is in logarithmic or linear
372+
# scale.
368373
ax.annotate(
369374
"",
370375
xytext=(x_down, arrow_length + distance_from_edge),
371376
textcoords=tform_x,
372377
xy=(x_down, distance_from_edge),
373378
xycoords=tform_x,
374-
arrowprops=dict(color=color, arrowstyle="->"),
379+
arrowprops=dict(color=color, arrowstyle=arrow_style),
375380
)
376381

377382
# Draw arrows pointing upwards
@@ -382,13 +387,13 @@ def highlight_data_outside_domain(
382387
textcoords=tform_x,
383388
xy=(x_up, 1.0 - distance_from_edge),
384389
xycoords=tform_x,
385-
arrowprops=dict(color=color, arrowstyle="->"),
390+
arrowprops=dict(color=color, arrowstyle=arrow_style),
386391
)
387392

388393
# Next, find all data points that are outside the X-axis range and
389394
# within Y-axis range
390-
below_x_within_y = below_x_range * within_y_range
391-
above_x_within_y = above_x_range * within_y_range
395+
below_x_within_y = logical_and(below_x_range, within_y_range)
396+
above_x_within_y = logical_and(above_x_range, within_y_range)
392397

393398
# Y coordinates of the data points whose X coordinates are outside the
394399
# X-axis range
@@ -407,7 +412,7 @@ def highlight_data_outside_domain(
407412
textcoords=tform_y,
408413
xy=(distance_from_edge, y_left),
409414
xycoords=tform_y,
410-
arrowprops=dict(color=color, arrowstyle="->"),
415+
arrowprops=dict(color=color, arrowstyle=arrow_style),
411416
)
412417

413418
# Draw arrows pointing rightwards
@@ -418,12 +423,13 @@ def highlight_data_outside_domain(
418423
textcoords=tform_y,
419424
xy=(1.0 - distance_from_edge, y_right),
420425
xycoords=tform_y,
421-
arrowprops=dict(color=color, arrowstyle="->"),
426+
arrowprops=dict(color=color, arrowstyle=arrow_style),
422427
)
423428

424429
# Finally, handle the points that are both outside the X and Y axis range
425-
outside_plot = (below_y_range + above_y_range) * (
426-
below_x_range + above_x_range
430+
outside_plot = logical_and(
431+
logical_or(below_y_range, above_y_range),
432+
logical_or(below_x_range, above_x_range),
427433
)
428434
x_outside_list, y_outside_list = x[outside_plot], y[outside_plot]
429435

@@ -458,7 +464,7 @@ def highlight_data_outside_domain(
458464
textcoords=tform,
459465
xy=(arrow_end_x, arrow_end_y),
460466
xycoords=tform,
461-
arrowprops=dict(color=color, arrowstyle="->"),
467+
arrowprops=dict(color=color, arrowstyle=arrow_style),
462468
)
463469

464470
return

0 commit comments

Comments
 (0)