Skip to content

Commit f5c735a

Browse files
authored
Merge pull request #46 from SWIFTSIM/arrows_to_missing_data_points
Show arrows where the data is outside the plot's domain
2 parents c84c6b9 + e8e9e29 commit f5c735a

File tree

3 files changed

+235
-10
lines changed

3 files changed

+235
-10
lines changed

velociraptor/autoplotter/compare.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,14 @@ def load_yaml_line_data(
4444
----------
4545
paths: Union[str, List[str]]
4646
Paths to yaml data files to load.
47-
47+
4848
names: Union[str, List[str]]
4949
Names of the simulations that correspond to the yaml data files.
5050
Will be placed in the legends of the plots.
51-
51+
5252
Returns
5353
-------
54-
54+
5555
data: Dict[str, Dict]
5656
Dictionary of line data read directly from the files.
5757
"""
@@ -139,7 +139,7 @@ def recreate_single_figure(
139139
"""
140140
Recreates a single figure using the data in ``line_data`` and the metadata in
141141
``plot``.
142-
142+
143143
Parameters
144144
----------
145145
plot: VelociraptorPlot
@@ -230,6 +230,20 @@ def recreate_single_figure(
230230

231231
ax.scatter(additional_x, additional_y, c=color_name)
232232

233+
# Enter only if the plot has a valid Y-axis range and there are any
234+
# additional data points.
235+
if plot.y_lim is not None and len(additional_x) > 0:
236+
237+
# Draw arrows for each data point beyond X- or/and Y- axis range
238+
line.highlight_data_outside_domain(
239+
ax,
240+
additional_x.value,
241+
additional_y.value,
242+
color_name,
243+
(plot.x_lim[0].value, plot.x_lim[1].value),
244+
(plot.y_lim[0].value, plot.y_lim[1].value),
245+
)
246+
233247
# Add observational data second to allow for colour precedence
234248
# to go to runs
235249
observational_data_scale_factor_bracket = [

velociraptor/autoplotter/lines.py

+211-4
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
"""
44

55
from unyt import unyt_quantity, unyt_array
6-
from numpy import logspace, linspace, log10, logical_and
6+
from numpy import logspace, linspace, log10, logical_and, isnan, sqrt, logical_or
77
from typing import Dict, Union, Tuple, List
88
from matplotlib.pyplot import Axes
9+
from matplotlib.transforms import blended_transform_factory
910

1011
import velociraptor.tools.lines as lines
1112
from velociraptor.tools.mass_functions import (
@@ -187,7 +188,7 @@ def create_line(
187188
188189
x: unyt_array
189190
Horizontal axis data
190-
191+
191192
y: unyt_array
192193
Vertical axis data
193194
@@ -290,8 +291,192 @@ def create_line(
290291

291292
return self.output
292293

294+
def highlight_data_outside_domain(
295+
self,
296+
ax: Axes,
297+
x: unyt_array,
298+
y: unyt_array,
299+
color: str,
300+
x_lim: List,
301+
y_lim: List,
302+
) -> None:
303+
304+
"""
305+
Add arrows to the plot for each data point residing outside the plot's domain.
306+
The arrows indicate where the missing points are. For a given missing data point
307+
with its Y(X) coordinate outside the Y(X)-axis range, the corresponding arrow
308+
will have the same X(Y) coordinate and point to the direction where the missing
309+
point is. If a data point happens to lie outside both the X-axis range and
310+
Y-axis range, then a diagonal arrow is drawn.
311+
312+
Parameters
313+
----------
314+
315+
ax: Axes
316+
An object of axes where to draw the arrows
317+
318+
x: unyt_array
319+
Horizontal axis data
320+
321+
y: unyt_array
322+
Vertical axis data
323+
324+
color: str
325+
Color of the arrows that this function will draw. The color should be the
326+
same as the color of the (missing) data points.
327+
328+
x_lim: List
329+
A 2-length list containing the lower and upper limits of the X-axis range.
330+
331+
y_lim: List
332+
A 2-length list containing the lower and upper limits of the Y-axis range.
333+
"""
334+
335+
# Additional check to ensure all provided data points are good
336+
if not isnan(x).any() and not isnan(y).any():
337+
338+
# Arrow parameters
339+
arrow_length = 0.07
340+
distance_from_edge = 0.01
341+
arrow_style = "->"
342+
343+
# Split data into three categories (along X axis)
344+
below_x_range = x < x_lim[0]
345+
above_x_range = x > x_lim[1]
346+
within_x_range = logical_and(x >= x_lim[0], x <= x_lim[1])
347+
348+
# Split data into three categories (along Y axis)
349+
below_y_range = y < y_lim[0]
350+
above_y_range = y > y_lim[1]
351+
within_y_range = logical_and(y >= y_lim[0], y <= y_lim[1])
352+
353+
# First, find all data points that are outside the Y-axis range and within
354+
# X-axis range
355+
below_y_within_x = logical_and(below_y_range, within_x_range)
356+
above_y_within_x = logical_and(above_y_range, within_x_range)
357+
358+
# X coordinates of the data points whose Y coordinates are outside the
359+
# Y-axis range
360+
x_down_list = x[below_y_within_x]
361+
x_up_list = x[above_y_within_x]
362+
363+
# Use figure's data coordinates along the X axis and relative coordinates
364+
# along the Y axis.
365+
tform_x = blended_transform_factory(ax.transData, ax.transAxes)
366+
367+
# Draw arrows pointing downwards
368+
for x_down in x_down_list:
369+
# We are using 'ax.annotate' instead of 'ax.arrow' because we want the
370+
# arrow's head and tail to have the same size regardless of what the
371+
# axes aspect ratio is or whether the plot is in logarithmic or linear
372+
# scale.
373+
ax.annotate(
374+
"",
375+
xytext=(x_down, arrow_length + distance_from_edge),
376+
textcoords=tform_x,
377+
xy=(x_down, distance_from_edge),
378+
xycoords=tform_x,
379+
arrowprops=dict(color=color, arrowstyle=arrow_style),
380+
)
381+
382+
# Draw arrows pointing upwards
383+
for x_up in x_up_list:
384+
ax.annotate(
385+
"",
386+
xytext=(x_up, 1.0 - arrow_length - distance_from_edge),
387+
textcoords=tform_x,
388+
xy=(x_up, 1.0 - distance_from_edge),
389+
xycoords=tform_x,
390+
arrowprops=dict(color=color, arrowstyle=arrow_style),
391+
)
392+
393+
# Next, find all data points that are outside the X-axis range and
394+
# within Y-axis range
395+
below_x_within_y = logical_and(below_x_range, within_y_range)
396+
above_x_within_y = logical_and(above_x_range, within_y_range)
397+
398+
# Y coordinates of the data points whose X coordinates are outside the
399+
# X-axis range
400+
y_left_list = y[below_x_within_y]
401+
y_right_list = y[above_x_within_y]
402+
403+
# Use figure's data coordinates along the Y axis and relative coordinates
404+
# along the X axis.
405+
tform_y = blended_transform_factory(ax.transAxes, ax.transData)
406+
407+
# Draw arrows pointing leftwards
408+
for y_left in y_left_list:
409+
ax.annotate(
410+
"",
411+
xytext=(arrow_length + distance_from_edge, y_left),
412+
textcoords=tform_y,
413+
xy=(distance_from_edge, y_left),
414+
xycoords=tform_y,
415+
arrowprops=dict(color=color, arrowstyle=arrow_style),
416+
)
417+
418+
# Draw arrows pointing rightwards
419+
for y_right in y_right_list:
420+
ax.annotate(
421+
"",
422+
xytext=(1.0 - arrow_length - distance_from_edge, y_right),
423+
textcoords=tform_y,
424+
xy=(1.0 - distance_from_edge, y_right),
425+
xycoords=tform_y,
426+
arrowprops=dict(color=color, arrowstyle=arrow_style),
427+
)
428+
429+
# Finally, handle the points that are both outside the X and Y axis range
430+
outside_plot = logical_and(
431+
logical_or(below_y_range, above_y_range),
432+
logical_or(below_x_range, above_x_range),
433+
)
434+
x_outside_list, y_outside_list = x[outside_plot], y[outside_plot]
435+
436+
for x_outside, y_outside in zip(x_outside_list, y_outside_list):
437+
438+
# Unlike vertical and horizontal arrows, diagonal arrows extend both
439+
# in X and Y directions. We account for it by dividing the length of
440+
# diagonal arrow along each dimension by \sqrt(2).
441+
arrow_proj_length = arrow_length / sqrt(2.0)
442+
443+
# Find the correct position of the arrow on the plot
444+
if x_lim[0] > x_outside:
445+
arrow_start_x = arrow_proj_length + distance_from_edge
446+
arrow_end_x = distance_from_edge
447+
else:
448+
arrow_start_x = 1.0 - arrow_proj_length - distance_from_edge
449+
arrow_end_x = 1.0 - distance_from_edge
450+
451+
if y_lim[0] > y_outside:
452+
arrow_start_y = arrow_proj_length + distance_from_edge
453+
arrow_end_y = distance_from_edge
454+
else:
455+
arrow_start_y = 1.0 - arrow_proj_length - distance_from_edge
456+
arrow_end_y = 1.0 - distance_from_edge
457+
458+
# Use figure's relative coordinates along the X and Y axis.
459+
tform = blended_transform_factory(ax.transAxes, ax.transAxes)
460+
461+
ax.annotate(
462+
"",
463+
xytext=(arrow_start_x, arrow_start_y),
464+
textcoords=tform,
465+
xy=(arrow_end_x, arrow_end_y),
466+
xycoords=tform,
467+
arrowprops=dict(color=color, arrowstyle=arrow_style),
468+
)
469+
470+
return
471+
293472
def plot_line(
294-
self, ax: Axes, x: unyt_array, y: unyt_array, label: Union[str, None] = None
473+
self,
474+
ax: Axes,
475+
x: unyt_array,
476+
y: unyt_array,
477+
label: Union[str, None] = None,
478+
x_lim: Union[List, None] = None,
479+
y_lim: Union[List, None] = None,
295480
):
296481
"""
297482
Plot a line using these parameters on some axes, x against y.
@@ -304,14 +489,20 @@ def plot_line(
304489
305490
x: unyt_array
306491
Horizontal axis data
307-
492+
308493
y: unyt_array
309494
Vertical axis data
310495
311496
label: str
312497
Label associated with this data that will be included in the
313498
legend.
314499
500+
x_lim: Union[List, None]
501+
A 2-length list containing the lower and upper limits of the X-axis range.
502+
503+
y_lim: Union[List, None]
504+
A 2-length list containing the lower and upper limits of the Y-axis range.
505+
315506
Notes
316507
-----
317508
@@ -364,6 +555,22 @@ def plot_line(
364555

365556
try:
366557
ax.scatter(additional_x.value, additional_y.value, color=line.get_color())
558+
559+
# Enter only if the plot has a valid X-axis and Y-axis ranges and there are
560+
# any additional data points.
561+
if x_lim is not None and y_lim is not None and len(additional_x) > 0:
562+
563+
# Add arrows to the plot for each data point beyond X- or/and Y- axis
564+
# range
565+
self.highlight_data_outside_domain(
566+
ax,
567+
additional_x.value,
568+
additional_y.value,
569+
line.get_color(),
570+
(x_lim[0].value, x_lim[1].value),
571+
(y_lim[0].value, y_lim[1].value),
572+
)
573+
367574
# In case the line object is undefined
368575
except NameError:
369576
ax.scatter(additional_x.value, additional_y.value)

velociraptor/autoplotter/objects.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -663,9 +663,13 @@ def _add_lines_to_axes(self, ax: Axes, x: unyt_array, y: unyt_array) -> None:
663663
"""
664664

665665
if self.median_line is not None:
666-
self.median_line.plot_line(ax=ax, x=x, y=y, label="Median")
666+
self.median_line.plot_line(
667+
ax=ax, x=x, y=y, label="Median", x_lim=self.x_lim, y_lim=self.y_lim
668+
)
667669
if self.mean_line is not None:
668-
self.mean_line.plot_line(ax=ax, x=x, y=y, label="Mean")
670+
self.mean_line.plot_line(
671+
ax=ax, x=x, y=y, label="Mean", x_lim=self.x_lim, y_lim=self.y_lim
672+
)
669673

670674
return
671675

0 commit comments

Comments
 (0)