3
3
"""
4
4
5
5
from unyt import unyt_quantity , unyt_array
6
- from numpy import logspace , linspace , log10 , logical_and
6
+ from numpy import logspace , linspace , log10 , logical_and , isnan , sqrt , logical_or
7
7
from typing import Dict , Union , Tuple , List
8
8
from matplotlib .pyplot import Axes
9
+ from matplotlib .transforms import blended_transform_factory
9
10
10
11
import velociraptor .tools .lines as lines
11
12
from velociraptor .tools .mass_functions import (
@@ -187,7 +188,7 @@ def create_line(
187
188
188
189
x: unyt_array
189
190
Horizontal axis data
190
-
191
+
191
192
y: unyt_array
192
193
Vertical axis data
193
194
@@ -290,8 +291,192 @@ def create_line(
290
291
291
292
return self .output
292
293
294
+ def highlight_data_outside_domain (
295
+ self ,
296
+ ax : Axes ,
297
+ x : unyt_array ,
298
+ y : unyt_array ,
299
+ color : str ,
300
+ x_lim : List ,
301
+ y_lim : List ,
302
+ ) -> None :
303
+
304
+ """
305
+ Add arrows to the plot for each data point residing outside the plot's domain.
306
+ The arrows indicate where the missing points are. For a given missing data point
307
+ with its Y(X) coordinate outside the Y(X)-axis range, the corresponding arrow
308
+ will have the same X(Y) coordinate and point to the direction where the missing
309
+ point is. If a data point happens to lie outside both the X-axis range and
310
+ Y-axis range, then a diagonal arrow is drawn.
311
+
312
+ Parameters
313
+ ----------
314
+
315
+ ax: Axes
316
+ An object of axes where to draw the arrows
317
+
318
+ x: unyt_array
319
+ Horizontal axis data
320
+
321
+ y: unyt_array
322
+ Vertical axis data
323
+
324
+ color: str
325
+ Color of the arrows that this function will draw. The color should be the
326
+ same as the color of the (missing) data points.
327
+
328
+ x_lim: List
329
+ A 2-length list containing the lower and upper limits of the X-axis range.
330
+
331
+ y_lim: List
332
+ A 2-length list containing the lower and upper limits of the Y-axis range.
333
+ """
334
+
335
+ # Additional check to ensure all provided data points are good
336
+ if not isnan (x ).any () and not isnan (y ).any ():
337
+
338
+ # Arrow parameters
339
+ arrow_length = 0.07
340
+ distance_from_edge = 0.01
341
+ arrow_style = "->"
342
+
343
+ # Split data into three categories (along X axis)
344
+ below_x_range = x < x_lim [0 ]
345
+ above_x_range = x > x_lim [1 ]
346
+ within_x_range = logical_and (x >= x_lim [0 ], x <= x_lim [1 ])
347
+
348
+ # Split data into three categories (along Y axis)
349
+ below_y_range = y < y_lim [0 ]
350
+ above_y_range = y > y_lim [1 ]
351
+ within_y_range = logical_and (y >= y_lim [0 ], y <= y_lim [1 ])
352
+
353
+ # First, find all data points that are outside the Y-axis range and within
354
+ # X-axis range
355
+ below_y_within_x = logical_and (below_y_range , within_x_range )
356
+ above_y_within_x = logical_and (above_y_range , within_x_range )
357
+
358
+ # X coordinates of the data points whose Y coordinates are outside the
359
+ # Y-axis range
360
+ x_down_list = x [below_y_within_x ]
361
+ x_up_list = x [above_y_within_x ]
362
+
363
+ # Use figure's data coordinates along the X axis and relative coordinates
364
+ # along the Y axis.
365
+ tform_x = blended_transform_factory (ax .transData , ax .transAxes )
366
+
367
+ # Draw arrows pointing downwards
368
+ for x_down in x_down_list :
369
+ # We are using 'ax.annotate' instead of 'ax.arrow' because we want the
370
+ # arrow's head and tail to have the same size regardless of what the
371
+ # axes aspect ratio is or whether the plot is in logarithmic or linear
372
+ # scale.
373
+ ax .annotate (
374
+ "" ,
375
+ xytext = (x_down , arrow_length + distance_from_edge ),
376
+ textcoords = tform_x ,
377
+ xy = (x_down , distance_from_edge ),
378
+ xycoords = tform_x ,
379
+ arrowprops = dict (color = color , arrowstyle = arrow_style ),
380
+ )
381
+
382
+ # Draw arrows pointing upwards
383
+ for x_up in x_up_list :
384
+ ax .annotate (
385
+ "" ,
386
+ xytext = (x_up , 1.0 - arrow_length - distance_from_edge ),
387
+ textcoords = tform_x ,
388
+ xy = (x_up , 1.0 - distance_from_edge ),
389
+ xycoords = tform_x ,
390
+ arrowprops = dict (color = color , arrowstyle = arrow_style ),
391
+ )
392
+
393
+ # Next, find all data points that are outside the X-axis range and
394
+ # within Y-axis range
395
+ below_x_within_y = logical_and (below_x_range , within_y_range )
396
+ above_x_within_y = logical_and (above_x_range , within_y_range )
397
+
398
+ # Y coordinates of the data points whose X coordinates are outside the
399
+ # X-axis range
400
+ y_left_list = y [below_x_within_y ]
401
+ y_right_list = y [above_x_within_y ]
402
+
403
+ # Use figure's data coordinates along the Y axis and relative coordinates
404
+ # along the X axis.
405
+ tform_y = blended_transform_factory (ax .transAxes , ax .transData )
406
+
407
+ # Draw arrows pointing leftwards
408
+ for y_left in y_left_list :
409
+ ax .annotate (
410
+ "" ,
411
+ xytext = (arrow_length + distance_from_edge , y_left ),
412
+ textcoords = tform_y ,
413
+ xy = (distance_from_edge , y_left ),
414
+ xycoords = tform_y ,
415
+ arrowprops = dict (color = color , arrowstyle = arrow_style ),
416
+ )
417
+
418
+ # Draw arrows pointing rightwards
419
+ for y_right in y_right_list :
420
+ ax .annotate (
421
+ "" ,
422
+ xytext = (1.0 - arrow_length - distance_from_edge , y_right ),
423
+ textcoords = tform_y ,
424
+ xy = (1.0 - distance_from_edge , y_right ),
425
+ xycoords = tform_y ,
426
+ arrowprops = dict (color = color , arrowstyle = arrow_style ),
427
+ )
428
+
429
+ # Finally, handle the points that are both outside the X and Y axis range
430
+ outside_plot = logical_and (
431
+ logical_or (below_y_range , above_y_range ),
432
+ logical_or (below_x_range , above_x_range ),
433
+ )
434
+ x_outside_list , y_outside_list = x [outside_plot ], y [outside_plot ]
435
+
436
+ for x_outside , y_outside in zip (x_outside_list , y_outside_list ):
437
+
438
+ # Unlike vertical and horizontal arrows, diagonal arrows extend both
439
+ # in X and Y directions. We account for it by dividing the length of
440
+ # diagonal arrow along each dimension by \sqrt(2).
441
+ arrow_proj_length = arrow_length / sqrt (2.0 )
442
+
443
+ # Find the correct position of the arrow on the plot
444
+ if x_lim [0 ] > x_outside :
445
+ arrow_start_x = arrow_proj_length + distance_from_edge
446
+ arrow_end_x = distance_from_edge
447
+ else :
448
+ arrow_start_x = 1.0 - arrow_proj_length - distance_from_edge
449
+ arrow_end_x = 1.0 - distance_from_edge
450
+
451
+ if y_lim [0 ] > y_outside :
452
+ arrow_start_y = arrow_proj_length + distance_from_edge
453
+ arrow_end_y = distance_from_edge
454
+ else :
455
+ arrow_start_y = 1.0 - arrow_proj_length - distance_from_edge
456
+ arrow_end_y = 1.0 - distance_from_edge
457
+
458
+ # Use figure's relative coordinates along the X and Y axis.
459
+ tform = blended_transform_factory (ax .transAxes , ax .transAxes )
460
+
461
+ ax .annotate (
462
+ "" ,
463
+ xytext = (arrow_start_x , arrow_start_y ),
464
+ textcoords = tform ,
465
+ xy = (arrow_end_x , arrow_end_y ),
466
+ xycoords = tform ,
467
+ arrowprops = dict (color = color , arrowstyle = arrow_style ),
468
+ )
469
+
470
+ return
471
+
293
472
def plot_line (
294
- self , ax : Axes , x : unyt_array , y : unyt_array , label : Union [str , None ] = None
473
+ self ,
474
+ ax : Axes ,
475
+ x : unyt_array ,
476
+ y : unyt_array ,
477
+ label : Union [str , None ] = None ,
478
+ x_lim : Union [List , None ] = None ,
479
+ y_lim : Union [List , None ] = None ,
295
480
):
296
481
"""
297
482
Plot a line using these parameters on some axes, x against y.
@@ -304,14 +489,20 @@ def plot_line(
304
489
305
490
x: unyt_array
306
491
Horizontal axis data
307
-
492
+
308
493
y: unyt_array
309
494
Vertical axis data
310
495
311
496
label: str
312
497
Label associated with this data that will be included in the
313
498
legend.
314
499
500
+ x_lim: Union[List, None]
501
+ A 2-length list containing the lower and upper limits of the X-axis range.
502
+
503
+ y_lim: Union[List, None]
504
+ A 2-length list containing the lower and upper limits of the Y-axis range.
505
+
315
506
Notes
316
507
-----
317
508
@@ -364,6 +555,22 @@ def plot_line(
364
555
365
556
try :
366
557
ax .scatter (additional_x .value , additional_y .value , color = line .get_color ())
558
+
559
+ # Enter only if the plot has a valid X-axis and Y-axis ranges and there are
560
+ # any additional data points.
561
+ if x_lim is not None and y_lim is not None and len (additional_x ) > 0 :
562
+
563
+ # Add arrows to the plot for each data point beyond X- or/and Y- axis
564
+ # range
565
+ self .highlight_data_outside_domain (
566
+ ax ,
567
+ additional_x .value ,
568
+ additional_y .value ,
569
+ line .get_color (),
570
+ (x_lim [0 ].value , x_lim [1 ].value ),
571
+ (y_lim [0 ].value , y_lim [1 ].value ),
572
+ )
573
+
367
574
# In case the line object is undefined
368
575
except NameError :
369
576
ax .scatter (additional_x .value , additional_y .value )
0 commit comments