Skip to content

Commit b168a6b

Browse files
Merge pull request #68 from james-trayford/global_mask_bg_parts
Global masking
2 parents 19b143a + 508c5a3 commit b168a6b

File tree

5 files changed

+48
-53
lines changed

5 files changed

+48
-53
lines changed

velociraptor-plot

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ if __name__ == "__main__":
149149
registration_file_path=args.registration,
150150
)
151151
print_if_debug(f"Linking catalogue and AutoPlotter instance.")
152-
auto_plotter.link_catalogue(catalogue=catalogue)
152+
auto_plotter.link_catalogue(catalogue=catalogue, global_mask_tag=None)
153153

154154
print_if_debug(
155155
f"Creating figures with extension .{args.file_type} in {args.output}."

velociraptor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,5 +76,5 @@ def load(
7676

7777
if registration_file_path is not None:
7878
catalogue.register_derived_quantities(registration_file_path)
79-
79+
8080
return catalogue

velociraptor/autoplotter/objects.py

Lines changed: 41 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from unyt import unyt_quantity, unyt_array, matplotlib_support
1313
from unyt.exceptions import UnitConversionError
14-
from numpy import log10, linspace, logspace, array, logical_and
14+
from numpy import log10, linspace, logspace, array, logical_and, ones
1515
from matplotlib.pyplot import Axes, Figure, close
1616
from yaml import safe_load
1717
from typing import Union, List, Dict, Tuple
@@ -95,6 +95,8 @@ class VelociraptorPlot(object):
9595
observational_data_filenames: List[str]
9696
observational_data_bracket_width: float
9797
observational_data_directory: str
98+
# global mask
99+
global_mask: Union[None, array]
98100

99101
def __init__(
100102
self,
@@ -743,7 +745,7 @@ def _add_lines_to_axes(self, ax: Axes, x: unyt_array, y: unyt_array) -> None:
743745
return
744746

745747
def get_quantity_from_catalogue_with_mask(
746-
self, quantity: str, catalogue: VelociraptorCatalogue
748+
self, quantity: str, catalogue: VelociraptorCatalogue,
747749
) -> unyt_array:
748750
"""
749751
Get a quantity from the catalogue using the mask.
@@ -753,62 +755,48 @@ def get_quantity_from_catalogue_with_mask(
753755
# We give each dataset a custom name, that gets ruined when masking
754756
# in versions of unyt less than 2.6.0
755757
name = x.name
756-
758+
757759
if self.structure_mask is not None:
758-
x = x[self.structure_mask]
760+
# if structure_mask already set, mask and return
761+
x_mask = logical_and(self.global_mask, self.structure_mask)
762+
x = x[x_mask]
759763
x.name = name
760-
elif self.selection_mask is not None:
764+
return x
765+
766+
# allow all entries by default
767+
self.structure_mask = ones(x.shape).astype(bool)
768+
769+
if self.selection_mask is not None:
761770
# Create mask
762771
self.structure_mask = reduce(
763772
getattr, self.selection_mask.split("."), catalogue
764773
).astype(bool)
765-
766-
if self.select_structure_type is not None:
767-
if self.select_structure_type == self.exclude_structure_type:
768-
raise AutoPlotterError(
769-
f"Cannot simultaneously select and exclude structure"
770-
" type {self.select_structure_type}"
771-
)
772-
self.structure_mask = logical_and(
773-
self.structure_mask,
774-
catalogue.structure_type.structuretype
775-
== self.select_structure_type,
776-
)
777-
778-
elif self.exclude_structure_type is not None:
779-
self.structure_mask = logical_and(
780-
self.structure_mask,
781-
catalogue.structure_type.structuretype
782-
!= self.exclude_structure_type,
783-
)
784-
785-
x = x[self.structure_mask]
786-
x.name = name
787-
elif self.select_structure_type is not None:
774+
if self.select_structure_type is not None:
788775
if self.select_structure_type == self.exclude_structure_type:
789776
raise AutoPlotterError(
790777
f"Cannot simultaneously select and exclude structure"
791778
" type {self.select_structure_type}"
792779
)
793-
794-
# Need to create mask
795-
self.structure_mask = (
796-
catalogue.structure_type.structuretype == self.select_structure_type
780+
self.structure_mask = logical_and(
781+
self.structure_mask,
782+
catalogue.structure_type.structuretype
783+
== self.select_structure_type,
797784
)
798-
799-
x = x[self.structure_mask]
800-
x.name = name
801-
elif self.exclude_structure_type is not None:
802-
# Need to create mask
803-
self.structure_mask = (
804-
catalogue.structure_type.structuretype != self.exclude_structure_type
785+
if self.exclude_structure_type is not None:
786+
self.structure_mask = logical_and(
787+
self.structure_mask,
788+
catalogue.structure_type.structuretype
789+
!= self.exclude_structure_type,
805790
)
791+
792+
# combine global and structure masks
793+
x_mask = logical_and(self.global_mask, self.structure_mask)
806794

807-
x = x[self.structure_mask]
808-
x.name = name
809-
795+
# apply to the unyt array of values
796+
x = x[x_mask]
797+
x.name = name
810798
return x
811-
799+
812800
def _make_plot_scatter(
813801
self, catalogue: VelociraptorCatalogue
814802
) -> Tuple[Figure, Axes]:
@@ -974,7 +962,7 @@ def _make_plot_cumulative_histogram(
974962
return fig, ax
975963

976964
def make_plot(
977-
self, catalogue: VelociraptorCatalogue, directory: str, file_extension: str
965+
self, catalogue: VelociraptorCatalogue, directory: str, file_extension: str,
978966
):
979967
"""
980968
Federates out data parsing to individual functions based on the
@@ -1058,7 +1046,9 @@ class AutoPlotter(object):
10581046
observational_data_directory: str
10591047
# Whether or not the plots were created successfully.
10601048
created_successfully: List[bool]
1061-
1049+
# global mask
1050+
global_mask: Union[None, array]
1051+
10621052
def __init__(
10631053
self,
10641054
filename: Union[str, List[str]],
@@ -1123,14 +1113,18 @@ def parse_yaml(self):
11231113

11241114
return
11251115

1126-
def link_catalogue(self, catalogue: VelociraptorCatalogue):
1116+
def link_catalogue(self, catalogue: VelociraptorCatalogue, global_mask_tag: Union[None, str]):
11271117
"""
11281118
Links a catalogue with this object so that the plots
11291119
can actually be created.
11301120
"""
11311121

11321122
self.catalogue = catalogue
11331123

1124+
if global_mask_tag is not None:
1125+
self.global_mask = reduce(getattr, global_mask_tag.split("."), catalogue)
1126+
else:
1127+
self.global_mask = True
11341128
return
11351129

11361130
def create_plots(
@@ -1150,6 +1144,7 @@ def create_plots(
11501144

11511145
for plot in self.plots:
11521146
try:
1147+
plot.global_mask = self.global_mask
11531148
plot.make_plot(
11541149
catalogue=self.catalogue,
11551150
directory=directory,

velociraptor/catalogue/catalogue.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,4 +412,3 @@ def register_derived_quantities(self, registration_file_path: str) -> None:
412412
self.derived_quantities = DerivedQuantities(registration_file_path, self)
413413

414414
return
415-

velociraptor/catalogue/registration.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ def registration_fail_all(
3737
+ name: A fancy (possibly LaTeX'd) name for the field.
3838
+ snake_case: A correct snake_case name for the field.
3939
"""
40-
40+
41+
4142
if field_path == "ThisFieldPathWouldNeverExist":
4243
return (
4344
unit_system.length,
@@ -307,6 +308,8 @@ def registration_masses(
307308
full_name = "$M_{\\rm FOF}$"
308309
elif field_path == "Mass_tot":
309310
full_name = r"$M$"
311+
elif field_path == "Mass_interloper":
312+
full_name = "$M_{\\rm BG}$"
310313

311314
# General regex matching case.
312315

@@ -321,7 +324,6 @@ def registration_masses(
321324
)
322325
regex = cached_regex(match_string)
323326
match = regex.match(field_path)
324-
325327
if match and not full_name:
326328
mass = match.group(1)
327329
radius = match.group(2)
@@ -1441,8 +1443,7 @@ def registration_spherical_overdensities(
14411443
return unit, full_name, snake_case
14421444
else:
14431445
raise RegistrationDoesNotMatchError
1444-
1445-
1446+
14461447
# TODO
14471448
# lambda_B
14481449
# q

0 commit comments

Comments
 (0)