Skip to content

Commit 73276aa

Browse files
authored
Weak lensing: updated baseline shear estimator + manuscript figures (#1079)
* Encoder eval notebook: more metrics, make it easier to set use_mode=False * Add _dc2 ellipticities to catalog pickle file * Filter to relevant tracts in catalog generation scripts * Update catalog generation notebooks to match scripts * Update text in catalog generation notebooks * Use _dc2 ellipticities, update config, rerun dc2 ellipticity notebook * Update DC2 images path in lensing config * Manuscript: preliminary figure 1 * Manuscript: preliminary figure 3 * Manuscript: minor tweaks to figures 1 and 3 * Manuscript: minor update to figure 1 * Manuscript: preiminary figure 2 * Manuscript: preliminary figure 4 * Manuscript: minor tweaks to figs 1, 3, 4 * Update figure numbers * Manuscript: increase source brightness in figure 1 * Use LSST-measured second moments (ixx, ixy, iyy) in average ellip baseline * Swap Ixx and Iyy to align coordinate system for baseline estimator * Update avg ellip baseline tuning: use only training set, use R-squared * Update figure 4 after using LSST ellips in baseline estimator * Update plots in ellipticity tuning notebook * Rerun encoder evaluation notebook
1 parent bb0c5cb commit 73276aa

13 files changed

+2417
-1364
lines changed

case_studies/weak_lensing/generate_dc2_lensing_catalog.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
import pandas as pd
99
from GCRCatalogs import GCRQuery
10+
from GCRCatalogs.helpers.tract_catalogs import tract_filter
1011

1112
GCRCatalogs.set_root_dir("/data/scratch/dc2_nfs/")
1213

@@ -43,7 +44,8 @@
4344
"mag_i",
4445
"mag_z",
4546
"mag_y",
46-
]
47+
],
48+
native_filters=[tract_filter([3634, 3635, 3636, 3827, 3828, 3829, 3830, 4025, 4026, 4027])],
4749
)
4850
truth_df = pd.DataFrame(truth_df)
4951

@@ -100,14 +102,16 @@
100102
"psf_fwhm_i",
101103
"psf_fwhm_z",
102104
"psf_fwhm_y",
103-
]
105+
],
106+
filters=ra_dec_filters,
107+
native_filters=[tract_filter([3634, 3635, 3636, 3827, 3828, 3829, 3830, 4025, 4026, 4027])],
104108
)
105109
object_truth_df = pd.DataFrame(object_truth_df)
106110

107111

108112
print("Loading CosmoDC2...\n") # noqa: WPS421
109113

110-
config_overwrite = {"catalog_root_dir": "/data/scratch/dc2_nfs/cosmoDC2"}
114+
config_overwrite = {"catalog_root_dir": "/data/scratch/dc2_nfs/cosmoDC2_v1.1.4"}
111115

112116
cosmo_cat = GCRCatalogs.load_catalog("desc_cosmodc2", config_overwrite)
113117

@@ -118,6 +122,8 @@
118122
"dec",
119123
"ellipticity_1_true",
120124
"ellipticity_2_true",
125+
"ellipticity_1_true_dc2",
126+
"ellipticity_2_true_dc2",
121127
"shear_1",
122128
"shear_2",
123129
"convergence",

case_studies/weak_lensing/generate_dc2_lensing_catalog_objectmatch.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
import pandas as pd
99
from GCRCatalogs import GCRQuery
10+
from GCRCatalogs.helpers.tract_catalogs import tract_filter
1011

1112
GCRCatalogs.set_root_dir("/data/scratch/dc2_nfs/")
1213

@@ -72,6 +73,7 @@
7273
"psf_fwhm_z",
7374
"psf_fwhm_y",
7475
],
76+
native_filters=[tract_filter([3634, 3635, 3636, 3827, 3828, 3829, 3830, 4025, 4026, 4027])],
7577
)
7678
object_truth_df = pd.DataFrame(object_truth_df)
7779

@@ -96,7 +98,7 @@
9698

9799
print("Loading CosmoDC2...\n") # noqa: WPS421
98100

99-
config_overwrite = {"catalog_root_dir": "/data/scratch/dc2_nfs/cosmoDC2"}
101+
config_overwrite = {"catalog_root_dir": "/data/scratch/dc2_nfs/cosmoDC2_v1.1.4"}
100102
cosmo_cat = GCRCatalogs.load_catalog("desc_cosmodc2", config_overwrite)
101103

102104
cosmo_df = cosmo_cat.get_quantities(
@@ -106,6 +108,8 @@
106108
"dec",
107109
"ellipticity_1_true",
108110
"ellipticity_2_true",
111+
"ellipticity_1_true_dc2",
112+
"ellipticity_2_true_dc2",
109113
"shear_1",
110114
"shear_2",
111115
"convergence",

case_studies/weak_lensing/lensing_config_dc2.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,25 +71,25 @@ encoder:
7171
surveys:
7272
dc2:
7373
_target_: case_studies.weak_lensing.lensing_dc2.LensingDC2DataModule
74-
dc2_image_dir: ${paths.dc2}/run2.2i-dr6-v4/coadd-t3828-t3829/deepCoadd-results/
74+
dc2_image_dir: /data/scratch/dc2_nfs/run2.2i-dr6-v4/coadd-t3828-t3829/deepCoadd-results/
7575
dc2_cat_path: ${paths.dc2}/dc2_lensing_catalog.pkl
7676
image_slen: 4096
7777
n_image_split: 2 # split into n_image_split**2 subimages
7878
tile_slen: 256
7979
splits: 0:80/80:90/90:100
8080
avg_ellip_kernel_size: 15 # needs to be odd
81-
avg_ellip_kernel_sigma: 4
81+
avg_ellip_kernel_sigma: 5
8282
batch_size: 1
8383
num_workers: 1
84-
cached_data_path: ${paths.dc2}/dc2_lensing_splits_radec
84+
cached_data_path: ${paths.dc2}/dc2_lensing_splits
8585
train_transforms:
8686
- _target_: case_studies.weak_lensing.lensing_data_augmentation.LensingRotateFlipTransform
8787

8888
train:
8989
trainer:
9090
logger:
9191
name: weak_lensing_experiments_dc2
92-
version: october18
92+
version: december4
9393
max_epochs: 250
9494
devices: 1
9595
use_distributed_sampler: false

case_studies/weak_lensing/lensing_dc2.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys
44
from typing import List
55

6+
import numpy as np
67
import pandas as pd
78
import torch
89

@@ -102,10 +103,10 @@ def to_tile_catalog(self, full_catalog, height, width):
102103
v_count = torch.zeros(self.batch_size, num_tiles, v.shape[-1], dtype=v.dtype)
103104
add_pos = torch.ones_like(v)
104105

105-
if k == "psf":
106-
psf_nanmask = ~torch.isnan(v).any(dim=-1, keepdim=True).expand(-1, -1, v.shape[-1])
107-
v = torch.where(psf_nanmask, v, torch.tensor(0.0))
108-
add_pos = psf_nanmask.float().to(v.dtype)
106+
if k in {"psf", "ellip1_lsst", "ellip2_lsst"}:
107+
nanmask = ~torch.isnan(v).any(dim=-1, keepdim=True).expand(-1, -1, v.shape[-1])
108+
v = torch.where(nanmask, v, torch.tensor(0.0))
109+
add_pos = nanmask.float().to(v.dtype)
109110
source_to_tile_indices = source_to_tile_indices_default.expand(-1, -1, v.shape[-1])
110111
else:
111112
source_to_tile_indices = source_to_tile_indices_default
@@ -167,6 +168,9 @@ def generate_cached_data(self, image_index):
167168
ellip1_lensed = tile_dict["ellip1_lensed_sum"] / tile_dict["ellip1_lensed_count"]
168169
ellip2_lensed = tile_dict["ellip2_lensed_sum"] / tile_dict["ellip2_lensed_count"]
169170
ellip_lensed = torch.stack((ellip1_lensed.squeeze(-1), ellip2_lensed.squeeze(-1)), dim=-1)
171+
ellip1_lsst = tile_dict["ellip1_lsst_sum"] / tile_dict["ellip1_lsst_count"]
172+
ellip2_lsst = tile_dict["ellip2_lsst_sum"] / tile_dict["ellip2_lsst_count"]
173+
ellip_lsst = torch.stack((ellip1_lsst.squeeze(-1), ellip2_lsst.squeeze(-1)), dim=-1)
170174
redshift = tile_dict["redshift_sum"] / tile_dict["redshift_count"]
171175
ra = tile_dict["ra_sum"] / tile_dict["ra_count"]
172176
dec = tile_dict["dec_sum"] / tile_dict["dec_count"]
@@ -175,7 +179,8 @@ def generate_cached_data(self, image_index):
175179
tile_dict["shear_2"] = shear2
176180
tile_dict["convergence"] = convergence
177181
tile_dict["ellip_lensed"] = ellip_lensed
178-
tile_dict["ellip_lensed_wavg"] = compute_weighted_avg_ellip(
182+
tile_dict["ellip_lsst"] = ellip_lsst
183+
tile_dict["ellip_lsst_wavg"] = compute_weighted_avg_ellip(
179184
tile_dict, self.avg_ellip_kernel_size, self.avg_ellip_kernel_sigma
180185
)
181186
tile_dict["redshift"] = redshift
@@ -211,15 +216,22 @@ def from_file(cls, cat_path, wcs, height, width, **kwargs):
211216
convergence = torch.from_numpy(catalog["convergence"].values)
212217
reduced_shear = complex_shear / (1.0 - convergence)
213218

214-
ellip1_intrinsic = torch.from_numpy(catalog["ellipticity_1_true"].values)
215-
ellip2_intrinsic = torch.from_numpy(catalog["ellipticity_2_true"].values)
219+
ellip1_intrinsic = torch.from_numpy(catalog["ellipticity_1_true_dc2"].values)
220+
ellip2_intrinsic = torch.from_numpy(catalog["ellipticity_2_true_dc2"].values)
216221
complex_ellip_intrinsic = ellip1_intrinsic + ellip2_intrinsic * 1j
217222
complex_ellip_lensed = (complex_ellip_intrinsic + reduced_shear) / (
218223
1.0 + reduced_shear.conj() * complex_ellip_intrinsic
219224
)
220225
ellip1_lensed = torch.view_as_real(complex_ellip_lensed)[..., 0]
221226
ellip2_lensed = torch.view_as_real(complex_ellip_lensed)[..., 1]
222227

228+
ixx = torch.from_numpy(catalog["Iyy_pixel"].values) # align coordinate system
229+
iyy = torch.from_numpy(catalog["Ixx_pixel"].values) # align coordinate system
230+
ixy = torch.from_numpy(catalog["Ixy_pixel"].values)
231+
ellip_lsst = (ixx - iyy + 2j * ixy) / (ixx + iyy + 2 * np.sqrt(ixx * iyy - (ixy**2)))
232+
ellip1_lsst = torch.view_as_real(ellip_lsst)[..., 0]
233+
ellip2_lsst = torch.view_as_real(ellip_lsst)[..., 1]
234+
223235
redshift = torch.from_numpy(catalog["redshift"].values)
224236

225237
_, psf_params = cls.get_bands_flux_and_psf(kwargs["bands"], catalog, median=False)
@@ -240,6 +252,8 @@ def from_file(cls, cat_path, wcs, height, width, **kwargs):
240252
convergence = convergence[plocs_mask]
241253
ellip1_lensed = ellip1_lensed[plocs_mask]
242254
ellip2_lensed = ellip2_lensed[plocs_mask]
255+
ellip1_lsst = ellip1_lsst[plocs_mask]
256+
ellip2_lsst = ellip2_lsst[plocs_mask]
243257

244258
redshift = redshift[plocs_mask]
245259

@@ -257,6 +271,8 @@ def from_file(cls, cat_path, wcs, height, width, **kwargs):
257271
"convergence": convergence.reshape(1, nobj, 1),
258272
"ellip1_lensed": ellip1_lensed.reshape(1, nobj, 1),
259273
"ellip2_lensed": ellip2_lensed.reshape(1, nobj, 1),
274+
"ellip1_lsst": ellip1_lsst.reshape(1, nobj, 1),
275+
"ellip2_lsst": ellip2_lsst.reshape(1, nobj, 1),
260276
"redshift": redshift.reshape(1, nobj, 1),
261277
"psf": psf_params.reshape(1, nobj, -1),
262278
}

case_studies/weak_lensing/lensing_metrics.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,8 @@ def update(self, true_cat, est_cat, matching) -> None:
2929
pred_shear2 = est_cat["shear_2"].flatten(1, 2)
3030
zero_baseline_pred_shear1 = torch.zeros_like(true_shear1)
3131
zero_baseline_pred_shear2 = torch.zeros_like(true_shear2)
32-
ellip_baseline_pred_shear1 = (
33-
true_cat["ellip_lensed_wavg"][..., 0].unsqueeze(-1).flatten(1, 2)
34-
)
35-
ellip_baseline_pred_shear2 = (
36-
true_cat["ellip_lensed_wavg"][..., 1].unsqueeze(-1).flatten(1, 2)
37-
)
32+
ellip_baseline_pred_shear1 = true_cat["ellip_lsst_wavg"][..., 0].unsqueeze(-1).flatten(1, 2)
33+
ellip_baseline_pred_shear2 = true_cat["ellip_lsst_wavg"][..., 1].unsqueeze(-1).flatten(1, 2)
3834

3935
if "convergence" not in est_cat:
4036
true_convergence = torch.zeros_like(true_shear1).flatten(1, 2)

case_studies/weak_lensing/notebooks/dc2/ellipticity.ipynb

Lines changed: 477 additions & 192 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)