3
3
"""
4
4
5
5
from unyt import unyt_quantity , unyt_array
6
- from numpy import logspace , linspace , log10 , logical_and , isnan , sqrt
6
+ from numpy import logspace , linspace , log10 , logical_and , isnan , sqrt , logical_or
7
7
from typing import Dict , Union , Tuple , List
8
8
from matplotlib .pyplot import Axes
9
9
from matplotlib .transforms import blended_transform_factory
@@ -338,21 +338,22 @@ def highlight_data_outside_domain(
338
338
# Arrow parameters
339
339
arrow_length = 0.07
340
340
distance_from_edge = 0.01
341
+ arrow_style = "->"
341
342
342
343
# Split data into three categories (along X axis)
343
344
below_x_range = x < x_lim [0 ]
344
345
above_x_range = x > x_lim [1 ]
345
- within_x_range = (x > x_lim [0 ]) * ( x < x_lim [1 ])
346
+ within_x_range = logical_and (x > x_lim [0 ], x < x_lim [1 ])
346
347
347
348
# Split data into three categories (along Y axis)
348
349
below_y_range = y < y_lim [0 ]
349
350
above_y_range = y > y_lim [1 ]
350
- within_y_range = (y > y_lim [0 ]) * ( y < y_lim [1 ])
351
+ within_y_range = logical_and (y > y_lim [0 ], y < y_lim [1 ])
351
352
352
353
# First, find all data points that are outside the Y-axis range and within
353
354
# X-axis range
354
- below_y_within_x = below_y_range * within_x_range
355
- above_y_within_x = above_y_range * within_x_range
355
+ below_y_within_x = logical_and ( below_y_range , within_x_range )
356
+ above_y_within_x = logical_and ( above_y_range , within_x_range )
356
357
357
358
# X coordinates of the data points whose Y coordinates are outside the
358
359
# Y-axis range
@@ -365,13 +366,17 @@ def highlight_data_outside_domain(
365
366
366
367
# Draw arrows pointing downwards
367
368
for x_down in x_down_list :
369
+ # We are using 'ax.annotate' instead of 'ax.arrow' because we want the
370
+ # arrow's head and tail to have the same size regardless of what the
371
+ # axes aspect ratio is or whether the plot is in logarithmic or linear
372
+ # scale.
368
373
ax .annotate (
369
374
"" ,
370
375
xytext = (x_down , arrow_length + distance_from_edge ),
371
376
textcoords = tform_x ,
372
377
xy = (x_down , distance_from_edge ),
373
378
xycoords = tform_x ,
374
- arrowprops = dict (color = color , arrowstyle = "->" ),
379
+ arrowprops = dict (color = color , arrowstyle = arrow_style ),
375
380
)
376
381
377
382
# Draw arrows pointing upwards
@@ -382,13 +387,13 @@ def highlight_data_outside_domain(
382
387
textcoords = tform_x ,
383
388
xy = (x_up , 1.0 - distance_from_edge ),
384
389
xycoords = tform_x ,
385
- arrowprops = dict (color = color , arrowstyle = "->" ),
390
+ arrowprops = dict (color = color , arrowstyle = arrow_style ),
386
391
)
387
392
388
393
# Next, find all data points that are outside the X-axis range and
389
394
# within Y-axis range
390
- below_x_within_y = below_x_range * within_y_range
391
- above_x_within_y = above_x_range * within_y_range
395
+ below_x_within_y = logical_and ( below_x_range , within_y_range )
396
+ above_x_within_y = logical_and ( above_x_range , within_y_range )
392
397
393
398
# Y coordinates of the data points whose X coordinates are outside the
394
399
# X-axis range
@@ -407,7 +412,7 @@ def highlight_data_outside_domain(
407
412
textcoords = tform_y ,
408
413
xy = (distance_from_edge , y_left ),
409
414
xycoords = tform_y ,
410
- arrowprops = dict (color = color , arrowstyle = "->" ),
415
+ arrowprops = dict (color = color , arrowstyle = arrow_style ),
411
416
)
412
417
413
418
# Draw arrows pointing rightwards
@@ -418,12 +423,13 @@ def highlight_data_outside_domain(
418
423
textcoords = tform_y ,
419
424
xy = (1.0 - distance_from_edge , y_right ),
420
425
xycoords = tform_y ,
421
- arrowprops = dict (color = color , arrowstyle = "->" ),
426
+ arrowprops = dict (color = color , arrowstyle = arrow_style ),
422
427
)
423
428
424
429
# Finally, handle the points that are both outside the X and Y axis range
425
- outside_plot = (below_y_range + above_y_range ) * (
426
- below_x_range + above_x_range
430
+ outside_plot = logical_and (
431
+ logical_or (below_y_range , above_y_range ),
432
+ logical_or (below_x_range , above_x_range ),
427
433
)
428
434
x_outside_list , y_outside_list = x [outside_plot ], y [outside_plot ]
429
435
@@ -458,7 +464,7 @@ def highlight_data_outside_domain(
458
464
textcoords = tform ,
459
465
xy = (arrow_end_x , arrow_end_y ),
460
466
xycoords = tform ,
461
- arrowprops = dict (color = color , arrowstyle = "->" ),
467
+ arrowprops = dict (color = color , arrowstyle = arrow_style ),
462
468
)
463
469
464
470
return
0 commit comments