11
11
12
12
from unyt import unyt_quantity , unyt_array , matplotlib_support
13
13
from unyt .exceptions import UnitConversionError
14
- from numpy import log10 , linspace , logspace , array , logical_and
14
+ from numpy import log10 , linspace , logspace , array , logical_and , ones
15
15
from matplotlib .pyplot import Axes , Figure , close
16
16
from yaml import safe_load
17
17
from typing import Union , List , Dict , Tuple
@@ -95,6 +95,8 @@ class VelociraptorPlot(object):
95
95
observational_data_filenames : List [str ]
96
96
observational_data_bracket_width : float
97
97
observational_data_directory : str
98
+ # global mask
99
+ global_mask : Union [None , array ]
98
100
99
101
def __init__ (
100
102
self ,
@@ -743,7 +745,7 @@ def _add_lines_to_axes(self, ax: Axes, x: unyt_array, y: unyt_array) -> None:
743
745
return
744
746
745
747
def get_quantity_from_catalogue_with_mask (
746
- self , quantity : str , catalogue : VelociraptorCatalogue
748
+ self , quantity : str , catalogue : VelociraptorCatalogue ,
747
749
) -> unyt_array :
748
750
"""
749
751
Get a quantity from the catalogue using the mask.
@@ -753,62 +755,48 @@ def get_quantity_from_catalogue_with_mask(
753
755
# We give each dataset a custom name, that gets ruined when masking
754
756
# in versions of unyt less than 2.6.0
755
757
name = x .name
756
-
758
+
757
759
if self .structure_mask is not None :
758
- x = x [self .structure_mask ]
760
+ # if structure_mask already set, mask and return
761
+ x_mask = logical_and (self .global_mask , self .structure_mask )
762
+ x = x [x_mask ]
759
763
x .name = name
760
- elif self .selection_mask is not None :
764
+ return x
765
+
766
+ # allow all entries by default
767
+ self .structure_mask = ones (x .shape ).astype (bool )
768
+
769
+ if self .selection_mask is not None :
761
770
# Create mask
762
771
self .structure_mask = reduce (
763
772
getattr , self .selection_mask .split ("." ), catalogue
764
773
).astype (bool )
765
-
766
- if self .select_structure_type is not None :
767
- if self .select_structure_type == self .exclude_structure_type :
768
- raise AutoPlotterError (
769
- f"Cannot simultaneously select and exclude structure"
770
- " type {self.select_structure_type}"
771
- )
772
- self .structure_mask = logical_and (
773
- self .structure_mask ,
774
- catalogue .structure_type .structuretype
775
- == self .select_structure_type ,
776
- )
777
-
778
- elif self .exclude_structure_type is not None :
779
- self .structure_mask = logical_and (
780
- self .structure_mask ,
781
- catalogue .structure_type .structuretype
782
- != self .exclude_structure_type ,
783
- )
784
-
785
- x = x [self .structure_mask ]
786
- x .name = name
787
- elif self .select_structure_type is not None :
774
+ if self .select_structure_type is not None :
788
775
if self .select_structure_type == self .exclude_structure_type :
789
776
raise AutoPlotterError (
790
777
f"Cannot simultaneously select and exclude structure"
791
778
" type {self.select_structure_type}"
792
779
)
793
-
794
- # Need to create mask
795
- self . structure_mask = (
796
- catalogue . structure_type . structuretype == self .select_structure_type
780
+ self . structure_mask = logical_and (
781
+ self . structure_mask ,
782
+ catalogue . structure_type . structuretype
783
+ == self .select_structure_type ,
797
784
)
798
-
799
- x = x [self .structure_mask ]
800
- x .name = name
801
- elif self .exclude_structure_type is not None :
802
- # Need to create mask
803
- self .structure_mask = (
804
- catalogue .structure_type .structuretype != self .exclude_structure_type
785
+ if self .exclude_structure_type is not None :
786
+ self .structure_mask = logical_and (
787
+ self .structure_mask ,
788
+ catalogue .structure_type .structuretype
789
+ != self .exclude_structure_type ,
805
790
)
791
+
792
+ # combine global and structure masks
793
+ x_mask = logical_and (self .global_mask , self .structure_mask )
806
794
807
- x = x [ self . structure_mask ]
808
- x . name = name
809
-
795
+ # apply to the unyt array of values
796
+ x = x [ x_mask ]
797
+ x . name = name
810
798
return x
811
-
799
+
812
800
def _make_plot_scatter (
813
801
self , catalogue : VelociraptorCatalogue
814
802
) -> Tuple [Figure , Axes ]:
@@ -974,7 +962,7 @@ def _make_plot_cumulative_histogram(
974
962
return fig , ax
975
963
976
964
def make_plot (
977
- self , catalogue : VelociraptorCatalogue , directory : str , file_extension : str
965
+ self , catalogue : VelociraptorCatalogue , directory : str , file_extension : str ,
978
966
):
979
967
"""
980
968
Federates out data parsing to individual functions based on the
@@ -1058,7 +1046,9 @@ class AutoPlotter(object):
1058
1046
observational_data_directory : str
1059
1047
# Whether or not the plots were created successfully.
1060
1048
created_successfully : List [bool ]
1061
-
1049
+ # global mask
1050
+ global_mask : Union [None , array ]
1051
+
1062
1052
def __init__ (
1063
1053
self ,
1064
1054
filename : Union [str , List [str ]],
@@ -1123,14 +1113,18 @@ def parse_yaml(self):
1123
1113
1124
1114
return
1125
1115
1126
- def link_catalogue (self , catalogue : VelociraptorCatalogue ):
1116
+ def link_catalogue (self , catalogue : VelociraptorCatalogue , global_mask_tag : Union [ None , str ] ):
1127
1117
"""
1128
1118
Links a catalogue with this object so that the plots
1129
1119
can actually be created.
1130
1120
"""
1131
1121
1132
1122
self .catalogue = catalogue
1133
1123
1124
+ if global_mask_tag is not None :
1125
+ self .global_mask = reduce (getattr , global_mask_tag .split ("." ), catalogue )
1126
+ else :
1127
+ self .global_mask = True
1134
1128
return
1135
1129
1136
1130
def create_plots (
@@ -1150,6 +1144,7 @@ def create_plots(
1150
1144
1151
1145
for plot in self .plots :
1152
1146
try :
1147
+ plot .global_mask = self .global_mask
1153
1148
plot .make_plot (
1154
1149
catalogue = self .catalogue ,
1155
1150
directory = directory ,
0 commit comments