Weak lensing DC2 updates #1067

Merged · 17 commits · Aug 26, 2024
bliss/surveys/dc2.py (53 changes: 29 additions & 24 deletions)
@@ -221,14 +221,7 @@ def load_image_and_catalog(self, image_index):
},
}

def generate_cached_data(self, image_index):
result_dict = self.load_image_and_catalog(image_index)

image = result_dict["inputs"]["image"]
tile_dict = result_dict["tile_dict"]
wcs_header_str = result_dict["other_info"]["wcs_header_str"]
psf_params = result_dict["inputs"]["psf_params"]

def split_image_and_tile_cat(self, image, tile_cat, tile_cat_keys_to_split, psf_params):
# split image
split_lim = self.image_lim[0] // self.n_image_split
image_splits = split_tensor(image, split_lim, 1, 2)
@@ -237,6 +230,31 @@ def generate_cached_data(self, image_index):

# split tile cat
tile_cat_splits = {}
for param_name in tile_cat_keys_to_split:
tile_cat_splits[param_name] = split_tensor(
tile_cat[param_name], split_lim // self.tile_slen, 0, 1
)

return {
"tile_catalog": unpack_dict(tile_cat_splits),
"images": image_splits,
"image_height_index": (
torch.arange(0, len(image_splits)) // split_image_num_on_width
).tolist(),
"image_width_index": (
torch.arange(0, len(image_splits)) % split_image_num_on_width
).tolist(),
"psf_params": [psf_params for _ in range(self.n_image_split**2)],
}
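
The split_tensor and unpack_dict helpers are BLISS utilities not shown in this diff. A minimal sketch of the tiling behavior the new method appears to rely on (the function below is a hypothetical stand-in, assuming row-major blocks along two dimensions, which matches the height/width index arithmetic above):

import torch

def split_tensor_sketch(t, split_size, dim_h, dim_w):
    # Hypothetical stand-in for BLISS's split_tensor: cut `t` into
    # split_size-wide blocks along dim_h and dim_w, returned in
    # row-major order (so idx // n_on_width gives the row and
    # idx % n_on_width gives the column, as in the method above).
    rows = torch.split(t, split_size, dim=dim_h)
    return [blk for row in rows for blk in torch.split(row, split_size, dim=dim_w)]

image = torch.randn(1, 4000, 4000)  # a single-band image, for illustration
splits = split_tensor_sketch(image, 2000, 1, 2)
print(len(splits), splits[0].shape)  # 4 blocks, each torch.Size([1, 2000, 2000])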

def generate_cached_data(self, image_index):
result_dict = self.load_image_and_catalog(image_index)

image = result_dict["inputs"]["image"]
tile_dict = result_dict["tile_dict"]
wcs_header_str = result_dict["other_info"]["wcs_header_str"]
psf_params = result_dict["inputs"]["psf_params"]

param_list = [
"locs",
"n_sources",
@@ -252,24 +270,11 @@
"two_sources_mask",
"more_than_two_sources_mask",
]
for param_name in param_list:
tile_cat_splits[param_name] = split_tensor(
tile_dict[param_name], split_lim // self.tile_slen, 0, 1
)

data_splits = {
"tile_catalog": unpack_dict(tile_cat_splits),
"images": image_splits,
"image_height_index": (
torch.arange(0, len(image_splits)) // split_image_num_on_width
).tolist(),
"image_width_index": (
torch.arange(0, len(image_splits)) % split_image_num_on_width
).tolist(),
"psf_params": [psf_params for _ in range(self.n_image_split**2)],
}
splits = self.split_image_and_tile_cat(image, tile_dict, param_list, psf_params)

data_splits = split_list(
unpack_dict(data_splits),
unpack_dict(splits),
sub_list_len=self.data_in_one_cached_file,
)
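
unpack_dict and split_list are likewise imported helpers. A minimal sketch of the assumed conventions: the dict of per-split lists is transposed into a list of per-split dicts, which is then chunked so that each group of data_in_one_cached_file splits lands in one cache file:

def unpack_dict_sketch(d):
    # Assumed convention: dict of equal-length lists -> list of per-split dicts.
    n = len(next(iter(d.values())))
    return [{k: v[i] for k, v in d.items()} for i in range(n)]

def split_list_sketch(lst, sub_list_len):
    # Assumed convention: chunk the flat list; each chunk becomes one cache file.
    return [lst[i:i + sub_list_len] for i in range(0, len(lst), sub_list_len)]

items = unpack_dict_sketch({"images": [0, 1, 2, 3], "psf_params": ["p"] * 4})
print(split_list_sketch(items, sub_list_len=2))  # 2 cache files x 2 splits each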

case_studies/weak_lensing/generate_dc2_lensing_catalog.py (133 changes: 69 additions & 64 deletions)
@@ -1,3 +1,4 @@
# pylint: disable=R0801
import os
import pickle as pkl

@@ -17,32 +18,61 @@
raise FileExistsError(f"{file_path} already exists.")


print("Loading truth...\n") # noqa: WPS421

truth_cat = GCRCatalogs.load_catalog("desc_dc2_run2.2i_dr6_truth")

truth_df = truth_cat.get_quantities(
quantities=[
"cosmodc2_id",
"id",
"match_objectId",
"truth_type",
"ra",
"dec",
"redshift",
"flux_u",
"flux_g",
"flux_r",
"flux_i",
"flux_z",
"flux_y",
"mag_u",
"mag_g",
"mag_r",
"mag_i",
"mag_z",
"mag_y",
]
)
truth_df = pd.DataFrame(truth_df)

truth_df = truth_df[truth_df["truth_type"] == 1]

truth_df = truth_df[truth_df["flux_r"] >= 200]

max_ra = np.nanmax(truth_df["ra"])
min_ra = np.nanmin(truth_df["ra"])
max_dec = np.nanmax(truth_df["dec"])
min_dec = np.nanmin(truth_df["dec"])
ra_dec_filters = [f"ra >= {min_ra}", f"ra <= {max_ra}", f"dec >= {min_dec}", f"dec <= {max_dec}"]

vertices = hp.ang2vec(
np.array([min_ra, max_ra, max_ra, min_ra]),
np.array([min_dec, min_dec, max_dec, max_dec]),
lonlat=True,
)
ipix = hp.query_polygon(32, vertices, inclusive=True)
healpix_filter = GCRQuery((lambda h: np.isin(h, ipix, assume_unique=True), "healpix_pixel"))
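
The bounding box of the filtered truth sample is converted to a spherical polygon, and hp.query_polygon returns every nside=32 HEALPix pixel overlapping it; the GCRQuery built from that pixel list restricts subsequent catalog reads to the relevant sky area. A self-contained toy run of the same pre-filter (coordinates are illustrative, not the actual field bounds):

import healpy as hp
import numpy as np

# Toy pixel pre-filter: a 1-degree box on the sky.
min_ra, max_ra, min_dec, max_dec = 60.0, 61.0, -37.0, -36.0
vertices = hp.ang2vec(
    np.array([min_ra, max_ra, max_ra, min_ra]),
    np.array([min_dec, min_dec, max_dec, max_dec]),
    lonlat=True,
)
# inclusive=True keeps any pixel that merely touches the polygon,
# so no sources on the border are lost.
ipix = hp.query_polygon(32, vertices, inclusive=True)
print(ipix)  # nside=32 pixels (~1.8 deg across) overlapping the box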


print("Loading object-with-truth-match...\n") # noqa: WPS421

object_truth_cat = GCRCatalogs.load_catalog("desc_dc2_run2.2i_dr6_object_with_truth_match")

object_truth_df = object_truth_cat.get_quantities(
quantities=[
"cosmodc2_id_truth",
"id_truth",
"objectId",
"match_objectId",
"truth_type",
"ra_truth",
"dec_truth",
"redshift_truth",
"flux_u_truth",
"flux_g_truth",
"flux_r_truth",
"flux_i_truth",
"flux_z_truth",
"flux_y_truth",
"mag_u_truth",
"mag_g_truth",
"mag_r_truth",
"mag_i_truth",
"mag_z_truth",
"mag_y_truth",
"Ixx_pixel",
"Iyy_pixel",
"Ixy_pixel",
@@ -70,32 +100,15 @@
"psf_fwhm_i",
"psf_fwhm_z",
"psf_fwhm_y",
],
]
)
object_truth_df = pd.DataFrame(object_truth_df)

max_ra = np.nanmax(object_truth_df["ra_truth"])
min_ra = np.nanmin(object_truth_df["ra_truth"])
max_dec = np.nanmax(object_truth_df["dec_truth"])
min_dec = np.nanmin(object_truth_df["dec_truth"])
ra_dec_filters = [f"ra >= {min_ra}", f"ra <= {max_ra}", f"dec >= {min_dec}", f"dec <= {max_dec}"]

vertices = hp.ang2vec(
np.array([min_ra, max_ra, max_ra, min_ra]),
np.array([min_dec, min_dec, max_dec, max_dec]),
lonlat=True,
)
ipix = hp.query_polygon(32, vertices, inclusive=True)
healpix_filter = GCRQuery((lambda h: np.isin(h, ipix, assume_unique=True), "healpix_pixel"))

object_truth_df = object_truth_df[object_truth_df["truth_type"] == 1]

object_truth_df.drop_duplicates(subset=["cosmodc2_id_truth"], inplace=True)


print("Loading CosmoDC2...\n") # noqa: WPS421

config_overwrite = {"catalog_root_dir": "/data/scratch/dc2_nfs/cosmoDC2"}

cosmo_cat = GCRCatalogs.load_catalog("desc_cosmodc2", config_overwrite)

cosmo_df = cosmo_cat.get_quantities(
@@ -115,39 +128,31 @@
cosmo_df = pd.DataFrame(cosmo_df)
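
The collapsed lines of this hunk are not shown; presumably the filters constructed earlier are applied here. As a hypothetical sketch only, assuming the standard GCRCatalogs get_quantities signature, where filters act on loaded columns and native_filters prune which files are read at all (the quantity names below are illustrative, not the actual hidden list):

# Hypothetical sketch of the collapsed call.
cosmo_df = cosmo_cat.get_quantities(
    quantities=["galaxy_id", "ra", "dec", "shear_1", "shear_2", "convergence"],
    filters=ra_dec_filters,          # string cuts evaluated on loaded columns
    native_filters=[healpix_filter],  # prunes which healpix files are read
)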


print("Merging...\n") # noqa: WPS421
print("Merging truth with object-with-truth-match...\n") # noqa: WPS421

merge_df = object_truth_df.merge(
cosmo_df, left_on="cosmodc2_id_truth", right_on="galaxy_id", how="left"
merge_df1 = truth_df.merge(
object_truth_df, left_on="cosmodc2_id", right_on="cosmodc2_id_truth", how="left"
)

merge_df = merge_df[~merge_df["galaxy_id"].isna()]

merge_df.drop(columns=["ra_truth", "dec_truth"], inplace=True)

merge_df.rename(
columns={
"redshift_truth": "redshift",
"flux_u_truth": "flux_u",
"flux_g_truth": "flux_g",
"flux_r_truth": "flux_r",
"flux_i_truth": "flux_i",
"flux_z_truth": "flux_z",
"flux_y_truth": "flux_y",
"mag_u_truth": "mag_u",
"mag_g_truth": "mag_g",
"mag_r_truth": "mag_r",
"mag_i_truth": "mag_i",
"mag_z_truth": "mag_z",
"mag_y_truth": "mag_y",
},
inplace=True,
)
merge_df1.drop_duplicates(subset=["cosmodc2_id"], inplace=True)

merge_df1.drop(columns=["cosmodc2_id_truth"], inplace=True)


print("Merging with CosmoDC2...\n") # noqa: WPS421

merge_df2 = merge_df1.merge(cosmo_df, left_on="cosmodc2_id", right_on="galaxy_id", how="left")

merge_df2 = merge_df2[~merge_df2["galaxy_id"].isna()]

merge_df2.drop(columns=["ra_y", "dec_y"], inplace=True)

merge_df2.rename(columns={"ra_x": "ra", "dec_x": "dec"}, inplace=True)
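
The _x/_y suffixes arise because both inputs carry ra and dec columns: pandas suffixes overlapping non-key columns on merge, and the rename keeps the truth-catalog coordinates. A toy illustration:

import pandas as pd

# Both frames carry ra/dec, so the merge suffixes the overlapping columns.
truth = pd.DataFrame({"cosmodc2_id": [1], "ra": [60.10], "dec": [-36.50]})
cosmo = pd.DataFrame({"galaxy_id": [1], "ra": [60.10], "dec": [-36.50]})
merged = truth.merge(cosmo, left_on="cosmodc2_id", right_on="galaxy_id", how="left")
print(list(merged.columns))
# ['cosmodc2_id', 'ra_x', 'dec_x', 'galaxy_id', 'ra_y', 'dec_y']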


print("Saving...\n") # noqa: WPS421

with open(file_path, "wb") as f:
pkl.dump(merge_df, f)
pkl.dump(merge_df2, f)

print(f"Catalog has been saved at {file_path}") # noqa: WPS421