Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Weak lensing: updated baseline shear estimator + manuscript figures #1079

Merged
merged 22 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
eb702f3
Encoder eval notebook: more metrics, make it easier to set use_mode=F…
timwhite0 Nov 1, 2024
57bd0e6
Add _dc2 ellipticities to catalog pickle file
timwhite0 Dec 4, 2024
bf164c1
Filter to relevant tracts in catalog generation scripts
timwhite0 Dec 4, 2024
70c4f56
Update catalog generation notebooks to match scripts
timwhite0 Dec 4, 2024
ff604d2
Update text in catalog generation notebooks
timwhite0 Dec 4, 2024
93806e2
Use _dc2 ellipticities, update config, rerun dc2 ellipticity notebook
timwhite0 Dec 4, 2024
766fee9
Update DC2 images path in lensing config
timwhite0 Dec 18, 2024
7863da0
Manuscript: preliminary figure 1
timwhite0 Dec 18, 2024
35a077b
Manuscript: preliminary figure 3
timwhite0 Dec 19, 2024
78b089f
Manuscript: minor tweaks to figures 1 and 3
timwhite0 Dec 19, 2024
81322ae
Manuscript: minor update to figure 1
timwhite0 Dec 19, 2024
b3bc90e
Manuscript: preiminary figure 2
timwhite0 Dec 19, 2024
919255e
Manuscript: preliminary figure 4
timwhite0 Dec 19, 2024
01f9e3d
Manuscript: minor tweaks to figs 1, 3, 4
timwhite0 Dec 19, 2024
9170e67
Update figure numbers
timwhite0 Jan 3, 2025
a2a45cf
Manuscript: increase source brightness in figure 1
timwhite0 Jan 4, 2025
48ca510
Use LSST-measured second moments (ixx, ixy, iyy) in average ellip bas…
timwhite0 Jan 5, 2025
0e50ca3
Swap Ixx and Iyy to align coordinate system for baseline estimator
timwhite0 Jan 6, 2025
45b2bb1
Update avg ellip baseline tuning: use only training set, use R-squared
timwhite0 Jan 6, 2025
fe8b616
Update figure 4 after using LSST ellips in baseline estimator
timwhite0 Jan 6, 2025
817d334
Update plots in ellipticity tuning notebook
timwhite0 Jan 6, 2025
46fc831
Rerun encoder evaluation notebook
timwhite0 Jan 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions case_studies/weak_lensing/generate_dc2_lensing_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import pandas as pd
from GCRCatalogs import GCRQuery
from GCRCatalogs.helpers.tract_catalogs import tract_filter

GCRCatalogs.set_root_dir("/data/scratch/dc2_nfs/")

Expand Down Expand Up @@ -43,7 +44,8 @@
"mag_i",
"mag_z",
"mag_y",
]
],
native_filters=[tract_filter([3634, 3635, 3636, 3827, 3828, 3829, 3830, 4025, 4026, 4027])],
)
truth_df = pd.DataFrame(truth_df)

Expand Down Expand Up @@ -100,14 +102,16 @@
"psf_fwhm_i",
"psf_fwhm_z",
"psf_fwhm_y",
]
],
filters=ra_dec_filters,
native_filters=[tract_filter([3634, 3635, 3636, 3827, 3828, 3829, 3830, 4025, 4026, 4027])],
)
object_truth_df = pd.DataFrame(object_truth_df)


print("Loading CosmoDC2...\n") # noqa: WPS421

config_overwrite = {"catalog_root_dir": "/data/scratch/dc2_nfs/cosmoDC2"}
config_overwrite = {"catalog_root_dir": "/data/scratch/dc2_nfs/cosmoDC2_v1.1.4"}

cosmo_cat = GCRCatalogs.load_catalog("desc_cosmodc2", config_overwrite)

Expand All @@ -118,6 +122,8 @@
"dec",
"ellipticity_1_true",
"ellipticity_2_true",
"ellipticity_1_true_dc2",
"ellipticity_2_true_dc2",
"shear_1",
"shear_2",
"convergence",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import pandas as pd
from GCRCatalogs import GCRQuery
from GCRCatalogs.helpers.tract_catalogs import tract_filter

GCRCatalogs.set_root_dir("/data/scratch/dc2_nfs/")

Expand Down Expand Up @@ -72,6 +73,7 @@
"psf_fwhm_z",
"psf_fwhm_y",
],
native_filters=[tract_filter([3634, 3635, 3636, 3827, 3828, 3829, 3830, 4025, 4026, 4027])],
)
object_truth_df = pd.DataFrame(object_truth_df)

Expand All @@ -96,7 +98,7 @@

print("Loading CosmoDC2...\n") # noqa: WPS421

config_overwrite = {"catalog_root_dir": "/data/scratch/dc2_nfs/cosmoDC2"}
config_overwrite = {"catalog_root_dir": "/data/scratch/dc2_nfs/cosmoDC2_v1.1.4"}
cosmo_cat = GCRCatalogs.load_catalog("desc_cosmodc2", config_overwrite)

cosmo_df = cosmo_cat.get_quantities(
Expand All @@ -106,6 +108,8 @@
"dec",
"ellipticity_1_true",
"ellipticity_2_true",
"ellipticity_1_true_dc2",
"ellipticity_2_true_dc2",
"shear_1",
"shear_2",
"convergence",
Expand Down
8 changes: 4 additions & 4 deletions case_studies/weak_lensing/lensing_config_dc2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -71,25 +71,25 @@ encoder:
surveys:
dc2:
_target_: case_studies.weak_lensing.lensing_dc2.LensingDC2DataModule
dc2_image_dir: ${paths.dc2}/run2.2i-dr6-v4/coadd-t3828-t3829/deepCoadd-results/
dc2_image_dir: /data/scratch/dc2_nfs/run2.2i-dr6-v4/coadd-t3828-t3829/deepCoadd-results/
dc2_cat_path: ${paths.dc2}/dc2_lensing_catalog.pkl
image_slen: 4096
n_image_split: 2 # split into n_image_split**2 subimages
tile_slen: 256
splits: 0:80/80:90/90:100
avg_ellip_kernel_size: 15 # needs to be odd
avg_ellip_kernel_sigma: 4
avg_ellip_kernel_sigma: 5
batch_size: 1
num_workers: 1
cached_data_path: ${paths.dc2}/dc2_lensing_splits_radec
cached_data_path: ${paths.dc2}/dc2_lensing_splits
train_transforms:
- _target_: case_studies.weak_lensing.lensing_data_augmentation.LensingRotateFlipTransform

train:
trainer:
logger:
name: weak_lensing_experiments_dc2
version: october18
version: december4
max_epochs: 250
devices: 1
use_distributed_sampler: false
Expand Down
30 changes: 23 additions & 7 deletions case_studies/weak_lensing/lensing_dc2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
from typing import List

import numpy as np
import pandas as pd
import torch

Expand Down Expand Up @@ -102,10 +103,10 @@ def to_tile_catalog(self, full_catalog, height, width):
v_count = torch.zeros(self.batch_size, num_tiles, v.shape[-1], dtype=v.dtype)
add_pos = torch.ones_like(v)

if k == "psf":
psf_nanmask = ~torch.isnan(v).any(dim=-1, keepdim=True).expand(-1, -1, v.shape[-1])
v = torch.where(psf_nanmask, v, torch.tensor(0.0))
add_pos = psf_nanmask.float().to(v.dtype)
if k in {"psf", "ellip1_lsst", "ellip2_lsst"}:
nanmask = ~torch.isnan(v).any(dim=-1, keepdim=True).expand(-1, -1, v.shape[-1])
v = torch.where(nanmask, v, torch.tensor(0.0))
add_pos = nanmask.float().to(v.dtype)
source_to_tile_indices = source_to_tile_indices_default.expand(-1, -1, v.shape[-1])
else:
source_to_tile_indices = source_to_tile_indices_default
Expand Down Expand Up @@ -167,6 +168,9 @@ def generate_cached_data(self, image_index):
ellip1_lensed = tile_dict["ellip1_lensed_sum"] / tile_dict["ellip1_lensed_count"]
ellip2_lensed = tile_dict["ellip2_lensed_sum"] / tile_dict["ellip2_lensed_count"]
ellip_lensed = torch.stack((ellip1_lensed.squeeze(-1), ellip2_lensed.squeeze(-1)), dim=-1)
ellip1_lsst = tile_dict["ellip1_lsst_sum"] / tile_dict["ellip1_lsst_count"]
ellip2_lsst = tile_dict["ellip2_lsst_sum"] / tile_dict["ellip2_lsst_count"]
ellip_lsst = torch.stack((ellip1_lsst.squeeze(-1), ellip2_lsst.squeeze(-1)), dim=-1)
redshift = tile_dict["redshift_sum"] / tile_dict["redshift_count"]
ra = tile_dict["ra_sum"] / tile_dict["ra_count"]
dec = tile_dict["dec_sum"] / tile_dict["dec_count"]
Expand All @@ -175,7 +179,8 @@ def generate_cached_data(self, image_index):
tile_dict["shear_2"] = shear2
tile_dict["convergence"] = convergence
tile_dict["ellip_lensed"] = ellip_lensed
tile_dict["ellip_lensed_wavg"] = compute_weighted_avg_ellip(
tile_dict["ellip_lsst"] = ellip_lsst
tile_dict["ellip_lsst_wavg"] = compute_weighted_avg_ellip(
tile_dict, self.avg_ellip_kernel_size, self.avg_ellip_kernel_sigma
)
tile_dict["redshift"] = redshift
Expand Down Expand Up @@ -211,15 +216,22 @@ def from_file(cls, cat_path, wcs, height, width, **kwargs):
convergence = torch.from_numpy(catalog["convergence"].values)
reduced_shear = complex_shear / (1.0 - convergence)

ellip1_intrinsic = torch.from_numpy(catalog["ellipticity_1_true"].values)
ellip2_intrinsic = torch.from_numpy(catalog["ellipticity_2_true"].values)
ellip1_intrinsic = torch.from_numpy(catalog["ellipticity_1_true_dc2"].values)
ellip2_intrinsic = torch.from_numpy(catalog["ellipticity_2_true_dc2"].values)
complex_ellip_intrinsic = ellip1_intrinsic + ellip2_intrinsic * 1j
complex_ellip_lensed = (complex_ellip_intrinsic + reduced_shear) / (
1.0 + reduced_shear.conj() * complex_ellip_intrinsic
)
ellip1_lensed = torch.view_as_real(complex_ellip_lensed)[..., 0]
ellip2_lensed = torch.view_as_real(complex_ellip_lensed)[..., 1]

ixx = torch.from_numpy(catalog["Iyy_pixel"].values) # align coordinate system
iyy = torch.from_numpy(catalog["Ixx_pixel"].values) # align coordinate system
ixy = torch.from_numpy(catalog["Ixy_pixel"].values)
ellip_lsst = (ixx - iyy + 2j * ixy) / (ixx + iyy + 2 * np.sqrt(ixx * iyy - (ixy**2)))
ellip1_lsst = torch.view_as_real(ellip_lsst)[..., 0]
ellip2_lsst = torch.view_as_real(ellip_lsst)[..., 1]

redshift = torch.from_numpy(catalog["redshift"].values)

_, psf_params = cls.get_bands_flux_and_psf(kwargs["bands"], catalog, median=False)
Expand All @@ -240,6 +252,8 @@ def from_file(cls, cat_path, wcs, height, width, **kwargs):
convergence = convergence[plocs_mask]
ellip1_lensed = ellip1_lensed[plocs_mask]
ellip2_lensed = ellip2_lensed[plocs_mask]
ellip1_lsst = ellip1_lsst[plocs_mask]
ellip2_lsst = ellip2_lsst[plocs_mask]

redshift = redshift[plocs_mask]

Expand All @@ -257,6 +271,8 @@ def from_file(cls, cat_path, wcs, height, width, **kwargs):
"convergence": convergence.reshape(1, nobj, 1),
"ellip1_lensed": ellip1_lensed.reshape(1, nobj, 1),
"ellip2_lensed": ellip2_lensed.reshape(1, nobj, 1),
"ellip1_lsst": ellip1_lsst.reshape(1, nobj, 1),
"ellip2_lsst": ellip2_lsst.reshape(1, nobj, 1),
"redshift": redshift.reshape(1, nobj, 1),
"psf": psf_params.reshape(1, nobj, -1),
}
Expand Down
8 changes: 2 additions & 6 deletions case_studies/weak_lensing/lensing_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,8 @@ def update(self, true_cat, est_cat, matching) -> None:
pred_shear2 = est_cat["shear_2"].flatten(1, 2)
zero_baseline_pred_shear1 = torch.zeros_like(true_shear1)
zero_baseline_pred_shear2 = torch.zeros_like(true_shear2)
ellip_baseline_pred_shear1 = (
true_cat["ellip_lensed_wavg"][..., 0].unsqueeze(-1).flatten(1, 2)
)
ellip_baseline_pred_shear2 = (
true_cat["ellip_lensed_wavg"][..., 1].unsqueeze(-1).flatten(1, 2)
)
ellip_baseline_pred_shear1 = true_cat["ellip_lsst_wavg"][..., 0].unsqueeze(-1).flatten(1, 2)
ellip_baseline_pred_shear2 = true_cat["ellip_lsst_wavg"][..., 1].unsqueeze(-1).flatten(1, 2)

if "convergence" not in est_cat:
true_convergence = torch.zeros_like(true_shear1).flatten(1, 2)
Expand Down
669 changes: 477 additions & 192 deletions case_studies/weak_lensing/notebooks/dc2/ellipticity.ipynb

Large diffs are not rendered by default.

Loading