Weak lensing: updated baseline shear estimator + manuscript figures (#1079)

* Encoder eval notebook: more metrics, make it easier to set use_mode=False

* Add _dc2 ellipticities to catalog pickle file

* Filter to relevant tracts in catalog generation scripts

* Update catalog generation notebooks to match scripts

* Update text in catalog generation notebooks

* Use _dc2 ellipticities, update config, rerun dc2 ellipticity notebook

* Update DC2 images path in lensing config

* Manuscript: preliminary figure 1

* Manuscript: preliminary figure 3

* Manuscript: minor tweaks to figures 1 and 3

* Manuscript: minor update to figure 1

* Manuscript: preliminary figure 2

* Manuscript: preliminary figure 4

* Manuscript: minor tweaks to figs 1, 3, 4

* Update figure numbers

* Manuscript: increase source brightness in figure 1

* Use LSST-measured second moments (ixx, ixy, iyy) in average ellip baseline

* Swap Ixx and Iyy to align coordinate system for baseline estimator

* Update avg ellip baseline tuning: use only training set, use R-squared

* Update figure 4 after using LSST ellips in baseline estimator

* Update plots in ellipticity tuning notebook

* Rerun encoder evaluation notebook
timwhite0 authored Jan 7, 2025
1 parent bb0c5cb commit 73276aa
Showing 13 changed files with 2,417 additions and 1,364 deletions.
case_studies/weak_lensing/generate_dc2_lensing_catalog.py: 12 changes (9 additions & 3 deletions)

@@ -7,6 +7,7 @@
import numpy as np
import pandas as pd
from GCRCatalogs import GCRQuery
from GCRCatalogs.helpers.tract_catalogs import tract_filter

GCRCatalogs.set_root_dir("/data/scratch/dc2_nfs/")

@@ -43,7 +44,8 @@
"mag_i",
"mag_z",
"mag_y",
]
],
native_filters=[tract_filter([3634, 3635, 3636, 3827, 3828, 3829, 3830, 4025, 4026, 4027])],
)
truth_df = pd.DataFrame(truth_df)

@@ -100,14 +102,16 @@
"psf_fwhm_i",
"psf_fwhm_z",
"psf_fwhm_y",
]
],
filters=ra_dec_filters,
native_filters=[tract_filter([3634, 3635, 3636, 3827, 3828, 3829, 3830, 4025, 4026, 4027])],
)
object_truth_df = pd.DataFrame(object_truth_df)


print("Loading CosmoDC2...\n") # noqa: WPS421

config_overwrite = {"catalog_root_dir": "/data/scratch/dc2_nfs/cosmoDC2"}
config_overwrite = {"catalog_root_dir": "/data/scratch/dc2_nfs/cosmoDC2_v1.1.4"}

cosmo_cat = GCRCatalogs.load_catalog("desc_cosmodc2", config_overwrite)

@@ -118,6 +122,8 @@
"dec",
"ellipticity_1_true",
"ellipticity_2_true",
"ellipticity_1_true_dc2",
"ellipticity_2_true_dc2",
"shear_1",
"shear_2",
"convergence",
(next changed file: a second catalog-generation script; filename not shown in this excerpt)

@@ -7,6 +7,7 @@
import numpy as np
import pandas as pd
from GCRCatalogs import GCRQuery
from GCRCatalogs.helpers.tract_catalogs import tract_filter

GCRCatalogs.set_root_dir("/data/scratch/dc2_nfs/")

@@ -72,6 +73,7 @@
"psf_fwhm_z",
"psf_fwhm_y",
],
native_filters=[tract_filter([3634, 3635, 3636, 3827, 3828, 3829, 3830, 4025, 4026, 4027])],
)
object_truth_df = pd.DataFrame(object_truth_df)

@@ -96,7 +98,7 @@

print("Loading CosmoDC2...\n") # noqa: WPS421

config_overwrite = {"catalog_root_dir": "/data/scratch/dc2_nfs/cosmoDC2"}
config_overwrite = {"catalog_root_dir": "/data/scratch/dc2_nfs/cosmoDC2_v1.1.4"}
cosmo_cat = GCRCatalogs.load_catalog("desc_cosmodc2", config_overwrite)

cosmo_df = cosmo_cat.get_quantities(
@@ -106,6 +108,8 @@
"dec",
"ellipticity_1_true",
"ellipticity_2_true",
"ellipticity_1_true_dc2",
"ellipticity_2_true_dc2",
"shear_1",
"shear_2",
"convergence",
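For reference, both catalog-generation scripts above now pass native_filters=[tract_filter([...])] so that GCR only reads the ten DC2 tracts covered by the lensing images. A minimal, self-contained sketch of that query pattern (the catalog name below is a placeholder for illustration; the actual scripts load their own truth and object catalogs in addition to desc_cosmodc2):

import GCRCatalogs
import pandas as pd
from GCRCatalogs.helpers.tract_catalogs import tract_filter

GCRCatalogs.set_root_dir("/data/scratch/dc2_nfs/")

# Placeholder catalog name, for illustration only.
cat = GCRCatalogs.load_catalog("desc_dc2_object_catalog")

# Restrict the query to the ten tracts used in this case study so GCR
# skips the rest of the DC2 footprint when reading from disk.
lensing_tracts = [3634, 3635, 3636, 3827, 3828, 3829, 3830, 4025, 4026, 4027]
df = pd.DataFrame(
    cat.get_quantities(
        ["ra", "dec", "mag_i"],
        native_filters=[tract_filter(lensing_tracts)],
    )
)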
case_studies/weak_lensing/lensing_config_dc2.yaml: 8 changes (4 additions & 4 deletions)

@@ -71,25 +71,25 @@ encoder:
surveys:
dc2:
_target_: case_studies.weak_lensing.lensing_dc2.LensingDC2DataModule
dc2_image_dir: ${paths.dc2}/run2.2i-dr6-v4/coadd-t3828-t3829/deepCoadd-results/
dc2_image_dir: /data/scratch/dc2_nfs/run2.2i-dr6-v4/coadd-t3828-t3829/deepCoadd-results/
dc2_cat_path: ${paths.dc2}/dc2_lensing_catalog.pkl
image_slen: 4096
n_image_split: 2 # split into n_image_split**2 subimages
tile_slen: 256
splits: 0:80/80:90/90:100
avg_ellip_kernel_size: 15 # needs to be odd
avg_ellip_kernel_sigma: 4
avg_ellip_kernel_sigma: 5
batch_size: 1
num_workers: 1
cached_data_path: ${paths.dc2}/dc2_lensing_splits_radec
cached_data_path: ${paths.dc2}/dc2_lensing_splits
train_transforms:
- _target_: case_studies.weak_lensing.lensing_data_augmentation.LensingRotateFlipTransform

train:
trainer:
logger:
name: weak_lensing_experiments_dc2
version: october18
version: december4
max_epochs: 250
devices: 1
use_distributed_sampler: false
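The config change above widens the Gaussian used for the average-ellipticity baseline (avg_ellip_kernel_sigma goes from 4 to 5) while keeping the 15x15 kernel, which must stay odd so the kernel has a central tile. The implementation of compute_weighted_avg_ellip is not part of this diff; the sketch below is only an illustration, under the assumption that the kernel is a normalized 2-D Gaussian convolved with the per-tile ellipticity map:

import torch
import torch.nn.functional as F

def gaussian_kernel(size: int, sigma: float) -> torch.Tensor:
    # Normalized 2-D Gaussian; size must be odd (see the config comment above).
    coords = torch.arange(size, dtype=torch.float32) - (size - 1) / 2
    g1d = torch.exp(-(coords**2) / (2 * sigma**2))
    kernel = torch.outer(g1d, g1d)
    return kernel / kernel.sum()

def smooth_tile_ellipticities(ellip: torch.Tensor, size: int = 15, sigma: float = 5.0) -> torch.Tensor:
    # ellip: (n_tiles_h, n_tiles_w, 2) per-tile average ellipticities.
    kernel = gaussian_kernel(size, sigma).repeat(2, 1, 1, 1)  # (2, 1, size, size)
    x = ellip.permute(2, 0, 1).unsqueeze(0)  # (1, 2, H, W)
    smoothed = F.conv2d(x, kernel, padding=size // 2, groups=2)
    return smoothed.squeeze(0).permute(1, 2, 0)  # back to (H, W, 2)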
case_studies/weak_lensing/lensing_dc2.py: 30 changes (23 additions & 7 deletions)

@@ -3,6 +3,7 @@
import sys
from typing import List

import numpy as np
import pandas as pd
import torch

@@ -102,10 +103,10 @@ def to_tile_catalog(self, full_catalog, height, width):
v_count = torch.zeros(self.batch_size, num_tiles, v.shape[-1], dtype=v.dtype)
add_pos = torch.ones_like(v)

if k == "psf":
psf_nanmask = ~torch.isnan(v).any(dim=-1, keepdim=True).expand(-1, -1, v.shape[-1])
v = torch.where(psf_nanmask, v, torch.tensor(0.0))
add_pos = psf_nanmask.float().to(v.dtype)
if k in {"psf", "ellip1_lsst", "ellip2_lsst"}:
nanmask = ~torch.isnan(v).any(dim=-1, keepdim=True).expand(-1, -1, v.shape[-1])
v = torch.where(nanmask, v, torch.tensor(0.0))
add_pos = nanmask.float().to(v.dtype)
source_to_tile_indices = source_to_tile_indices_default.expand(-1, -1, v.shape[-1])
else:
source_to_tile_indices = source_to_tile_indices_default
@@ -167,6 +168,9 @@ def generate_cached_data(self, image_index):
ellip1_lensed = tile_dict["ellip1_lensed_sum"] / tile_dict["ellip1_lensed_count"]
ellip2_lensed = tile_dict["ellip2_lensed_sum"] / tile_dict["ellip2_lensed_count"]
ellip_lensed = torch.stack((ellip1_lensed.squeeze(-1), ellip2_lensed.squeeze(-1)), dim=-1)
ellip1_lsst = tile_dict["ellip1_lsst_sum"] / tile_dict["ellip1_lsst_count"]
ellip2_lsst = tile_dict["ellip2_lsst_sum"] / tile_dict["ellip2_lsst_count"]
ellip_lsst = torch.stack((ellip1_lsst.squeeze(-1), ellip2_lsst.squeeze(-1)), dim=-1)
redshift = tile_dict["redshift_sum"] / tile_dict["redshift_count"]
ra = tile_dict["ra_sum"] / tile_dict["ra_count"]
dec = tile_dict["dec_sum"] / tile_dict["dec_count"]
@@ -175,7 +179,8 @@
tile_dict["shear_2"] = shear2
tile_dict["convergence"] = convergence
tile_dict["ellip_lensed"] = ellip_lensed
tile_dict["ellip_lensed_wavg"] = compute_weighted_avg_ellip(
tile_dict["ellip_lsst"] = ellip_lsst
tile_dict["ellip_lsst_wavg"] = compute_weighted_avg_ellip(
tile_dict, self.avg_ellip_kernel_size, self.avg_ellip_kernel_sigma
)
tile_dict["redshift"] = redshift
@@ -211,15 +216,22 @@ def from_file(cls, cat_path, wcs, height, width, **kwargs):
convergence = torch.from_numpy(catalog["convergence"].values)
reduced_shear = complex_shear / (1.0 - convergence)

ellip1_intrinsic = torch.from_numpy(catalog["ellipticity_1_true"].values)
ellip2_intrinsic = torch.from_numpy(catalog["ellipticity_2_true"].values)
ellip1_intrinsic = torch.from_numpy(catalog["ellipticity_1_true_dc2"].values)
ellip2_intrinsic = torch.from_numpy(catalog["ellipticity_2_true_dc2"].values)
complex_ellip_intrinsic = ellip1_intrinsic + ellip2_intrinsic * 1j
complex_ellip_lensed = (complex_ellip_intrinsic + reduced_shear) / (
1.0 + reduced_shear.conj() * complex_ellip_intrinsic
)
ellip1_lensed = torch.view_as_real(complex_ellip_lensed)[..., 0]
ellip2_lensed = torch.view_as_real(complex_ellip_lensed)[..., 1]

ixx = torch.from_numpy(catalog["Iyy_pixel"].values) # align coordinate system
iyy = torch.from_numpy(catalog["Ixx_pixel"].values) # align coordinate system
ixy = torch.from_numpy(catalog["Ixy_pixel"].values)
ellip_lsst = (ixx - iyy + 2j * ixy) / (ixx + iyy + 2 * np.sqrt(ixx * iyy - (ixy**2)))
ellip1_lsst = torch.view_as_real(ellip_lsst)[..., 0]
ellip2_lsst = torch.view_as_real(ellip_lsst)[..., 1]

redshift = torch.from_numpy(catalog["redshift"].values)

_, psf_params = cls.get_bands_flux_and_psf(kwargs["bands"], catalog, median=False)
@@ -240,6 +252,8 @@
convergence = convergence[plocs_mask]
ellip1_lensed = ellip1_lensed[plocs_mask]
ellip2_lensed = ellip2_lensed[plocs_mask]
ellip1_lsst = ellip1_lsst[plocs_mask]
ellip2_lsst = ellip2_lsst[plocs_mask]

redshift = redshift[plocs_mask]

@@ -257,6 +271,8 @@
"convergence": convergence.reshape(1, nobj, 1),
"ellip1_lensed": ellip1_lensed.reshape(1, nobj, 1),
"ellip2_lensed": ellip2_lensed.reshape(1, nobj, 1),
"ellip1_lsst": ellip1_lsst.reshape(1, nobj, 1),
"ellip2_lsst": ellip2_lsst.reshape(1, nobj, 1),
"redshift": redshift.reshape(1, nobj, 1),
"psf": psf_params.reshape(1, nobj, -1),
}
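For context on the baseline change in this file: the LSST-measured second moments are converted to a complex ellipticity via the standard moment-based definition e = (Ixx - Iyy + 2i*Ixy) / (Ixx + Iyy + 2*sqrt(Ixx*Iyy - Ixy^2)), with Ixx and Iyy swapped to align the pixel coordinate system with the shear convention used elsewhere in the catalog. For randomly oriented sources, the average of this ellipticity roughly traces the reduced shear g = shear / (1 - convergence), which is why a smoothed per-tile average of it serves as a simple baseline shear estimate. A standalone version of the conversion, mirroring the lines added to from_file above:

import torch

def moments_to_ellipticity(ixx_pixel: torch.Tensor, iyy_pixel: torch.Tensor, ixy_pixel: torch.Tensor):
    # Swap Ixx and Iyy to align the coordinate system, as in from_file above.
    ixx, iyy, ixy = iyy_pixel, ixx_pixel, ixy_pixel
    denom = ixx + iyy + 2.0 * torch.sqrt(ixx * iyy - ixy**2)
    ellip = (ixx - iyy + 2j * ixy) / denom
    return torch.view_as_real(ellip)[..., 0], torch.view_as_real(ellip)[..., 1]

# Sanity check: a round source (Ixx == Iyy, Ixy == 0) has zero ellipticity.
e1, e2 = moments_to_ellipticity(torch.tensor([4.0]), torch.tensor([4.0]), torch.tensor([0.0]))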
case_studies/weak_lensing/lensing_metrics.py: 8 changes (2 additions & 6 deletions)

@@ -29,12 +29,8 @@ def update(self, true_cat, est_cat, matching) -> None:
pred_shear2 = est_cat["shear_2"].flatten(1, 2)
zero_baseline_pred_shear1 = torch.zeros_like(true_shear1)
zero_baseline_pred_shear2 = torch.zeros_like(true_shear2)
ellip_baseline_pred_shear1 = (
true_cat["ellip_lensed_wavg"][..., 0].unsqueeze(-1).flatten(1, 2)
)
ellip_baseline_pred_shear2 = (
true_cat["ellip_lensed_wavg"][..., 1].unsqueeze(-1).flatten(1, 2)
)
ellip_baseline_pred_shear1 = true_cat["ellip_lsst_wavg"][..., 0].unsqueeze(-1).flatten(1, 2)
ellip_baseline_pred_shear2 = true_cat["ellip_lsst_wavg"][..., 1].unsqueeze(-1).flatten(1, 2)

if "convergence" not in est_cat:
true_convergence = torch.zeros_like(true_shear1).flatten(1, 2)
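Per the commit message, the smoothing kernel for this average-ellipticity baseline is now tuned on the training split only, scored by R-squared (presumably against the true shear); the tuning itself lives in the ellipticity notebook, whose diff is not rendered here. A minimal illustration of such a scoring step, with placeholder names rather than the notebook's actual code:

import torch

def shear_r_squared(pred: torch.Tensor, true: torch.Tensor) -> float:
    # pred, true: per-tile baseline shear estimates and true shear (e.g. shear_1),
    # restricted to the training images.
    ss_res = torch.sum((true - pred) ** 2)
    ss_tot = torch.sum((true - true.mean()) ** 2)
    return (1.0 - ss_res / ss_tot).item()

# One could sweep avg_ellip_kernel_sigma, recompute the smoothed baseline, and
# keep the value with the highest R-squared averaged over both shear components.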
case_studies/weak_lensing/notebooks/dc2/ellipticity.ipynb: 669 changes (477 additions & 192 deletions)

Large diffs are not rendered by default.

