
Commit cfb8f1c

Routine Weekly Merge (#1055)
* Update config parameters for run_1.
* Fix shape sampling.
* Create bounding boxes for DES.
* Remove file datum generation file.
* Exploring DES outputs.
* Clean up notebook.
* Update inference notebook.
* Modify prior to use SVA1 sources.
* Modify prior to correctly calculate radii in pixels.
* Change default image size to 2560.
* Modify mean sources to match DES.
* Modify default config.
* Modify data gen to write catalogs on the fly.
* Script for basic cluster counting.
* Script for mapping SVA1 regions. Flake8 throws import error.
* Modify prior to use each source redshift instead of single cluster redshift.
1 parent da6f2b5 commit cfb8f1c

File tree

8 files changed (+803, -211 lines)

case_studies/galaxy_clustering/conf/config.yaml

Lines changed: 7 additions & 7 deletions

@@ -5,10 +5,10 @@ defaults:
   - override hydra/job_logging: stdout
 
 data_gen:
-  data_dir: /nfs/turbo/lsa-regier/scratch/kapnadak/new_data
-  image_size: 1280
-  tile_size: 128
-  nfiles: 5000
+  data_dir: /data/scratch/kapnadak/data-08-08/
+  image_size: 2560
+  tile_size: 256
+  nfiles: 2000
   n_catalogs_per_file: 500
   bands: ["g", "r", "i", "z"]
   min_flux_for_loss: 0
@@ -37,7 +37,7 @@ cached_simulator:
   batch_size: 2
   splits: 0:60/60:90/90:100 # train/val/test splits as percent ranges
   num_workers: 8
-  cached_data_path: /nfs/turbo/lsa-regier/scratch/kapnadak/new_data/file_data
+  cached_data_path: /data/scratch/kapnadak/new_data/file_data
   train_transforms: []
 
 train:
@@ -81,7 +81,7 @@ encoder:
 
 predict:
   cached_dataset:
-    _target_: case_studies.galaxy_clustering.cached_dataset.CachedDESModule
+    _target_: case_studies.galaxy_clustering.inference.cached_dataset.CachedDESModule
     cached_data_path: /nfs/turbo/lsa-regier/scratch/gapatron/desdr-server.ncsa.illinois.edu/despublic/dr2_tiles
     tiles_per_img: 64
     batch_size: 2
@@ -100,6 +100,6 @@ predict:
     output_dir: "/data/scratch/des/dr2_detection_output/run_1"
    write_interval: "batch"
     encoder: ${encoder}
-    weight_save_path: /nfs/turbo/lsa-regier/scratch/gapatron/best_encoder.ckpt
+    weight_save_path: /data/scratch/des/best_encoder.ckpt
     device: "cuda:0"
     output_save_path: "/data/scratch/des/dr2_detection_output/run_0"
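
For orientation, the effect of the data_gen changes can be sanity-checked by loading the file directly. A minimal sketch, assuming the raw YAML is read with OmegaConf (Hydra would normally compose it with the defaults list first):

from omegaconf import OmegaConf

# Load the raw YAML without Hydra composition.
cfg = OmegaConf.load("case_studies/galaxy_clustering/conf/config.yaml")

# Values updated in this commit:
assert cfg.data_gen.image_size == 2560  # was 1280
assert cfg.data_gen.tile_size == 256  # was 128
assert cfg.data_gen.nfiles == 2000  # was 5000

# image_size and tile_size doubled together, so each image still
# splits into a 10 x 10 grid of tiles.
assert cfg.data_gen.image_size // cfg.data_gen.tile_size == 10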
Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+# flake8: noqa
+import pickle
+
+import numpy as np
+import pandas as pd
+from astropy.table import Table
+
+SVA_PATH = "/data/scratch/des/sva1_gold_r1.0_catalog.fits"
+BOUNDING_BOX_PATH = "/data/scratch/des/bounding_coordinates.pickle"
+
+
+def main():
+    sva_catalog = Table.read(SVA_PATH).to_pandas()
+    bounding_boxes = pd.read_pickle(BOUNDING_BOX_PATH)
+    des_sva_intersection = []
+    output_filename = "/data/scratch/des/sva_map.pickle"
+
+    for k, v in bounding_boxes.items():
+        ra_min, ra_max, dec_min, dec_max = v["RA_min"], v["RA_max"], v["DEC_min"], v["DEC_max"]
+        ra_intersection = np.logical_and((ra_min < sva_catalog["RA"]), (sva_catalog["RA"] < ra_max))
+        dec_intersection = np.logical_and(
+            (dec_min < sva_catalog["DEC"]), (sva_catalog["DEC"] < dec_max)
+        )
+        full_intersection = np.logical_and(ra_intersection, dec_intersection)
+        if full_intersection.any():
+            des_sva_intersection.append(k)
+
+    with open(output_filename, "wb") as handle:
+        pickle.dump(des_sva_intersection, handle, protocol=pickle.HIGHEST_PROTOCOL)
+
+
+if __name__ == "__main__":
+    main()
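
A minimal sketch of consuming the resulting map, assuming the script above has already been run:

import pickle

# sva_map.pickle holds the list of DES tile names whose RA/DEC bounding
# box contains at least one SVA1 gold source.
with open("/data/scratch/des/sva_map.pickle", "rb") as handle:
    des_sva_tiles = pickle.load(handle)

print(f"{len(des_sva_tiles)} DES tiles intersect the SVA1 footprint")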
Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+import os
+import pickle
+from pathlib import Path
+
+import pandas as pd
+from astropy.io import fits
+
+
+def main():
+    des_dir = Path(
+        "/nfs/turbo/lsa-regier/scratch/gapatron/desdr-server.ncsa.illinois.edu/despublic/dr2_tiles/"
+    )
+    des_subdirs = [d for d in os.listdir(des_dir) if d.startswith("DES")]
+    bounding_coordinates = {}
+    output_filename = "/data/scratch/des/bounding_coordinates.pickle"
+    for des_subdir in des_subdirs:
+        catalog_path = des_dir / Path(des_subdir) / Path(f"{des_subdir}_dr2_main.fits")
+        catalog_data = fits.getdata(catalog_path)
+        source_df = pd.DataFrame(catalog_data)
+        ra_min, ra_max = source_df["RA"].min(), source_df["RA"].max()
+        dec_min, dec_max = source_df["DEC"].min(), source_df["DEC"].max()
+        bounding_box = {
+            "RA_min": ra_min,
+            "RA_max": ra_max,
+            "DEC_min": dec_min,
+            "DEC_max": dec_max,
+        }
+        bounding_coordinates[des_subdir] = bounding_box
+
+    with open(output_filename, "wb") as handle:
+        pickle.dump(bounding_coordinates, handle, protocol=pickle.HIGHEST_PROTOCOL)
+
+
+if __name__ == "__main__":
+    main()
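
The pickle written above also supports reverse lookups. A minimal sketch (the helper name and test coordinates are hypothetical):

import pandas as pd

boxes = pd.read_pickle("/data/scratch/des/bounding_coordinates.pickle")


def tiles_containing(ra, dec):
    """Return the DES tile names whose bounding box contains (ra, dec)."""
    return [
        name
        for name, box in boxes.items()
        if box["RA_min"] < ra < box["RA_max"] and box["DEC_min"] < dec < box["DEC_max"]
    ]


print(tiles_containing(ra=52.5, dec=-55.0))

One caveat grounded in the code above: a min/max bounding box assumes a tile's RA range does not wrap through 0/360 degrees; a tile straddling that boundary would get a near-full RA span and match spuriously.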

case_studies/galaxy_clustering/data_generation/data_gen.py

Lines changed: 4 additions & 6 deletions

@@ -48,16 +48,14 @@ def catalog_gen(cfg):
     cluster_prior = ClusterPrior(image_size=image_size)
     background_prior = BackgroundPrior(image_size=image_size)
 
-    combined_catalogs = []
-    for _ in range(nfiles):
+    for i in range(nfiles):
         background_catalog = background_prior.sample_background()
         if np.random.uniform() < 0.5:
             cluster_catalog = cluster_prior.sample_cluster()
-            combined_catalogs.append(pd.concat([cluster_catalog, background_catalog]))
+            catalog = pd.concat([cluster_catalog, background_catalog])
         else:
-            combined_catalogs.append(background_catalog)
-
-    for i, catalog in enumerate(combined_catalogs):
+            catalog = background_catalog
+        print(f"Writing catalog {i} ...")
         file_name = f"{catalogs_path}/{file_prefix}_{i:03}.dat"
         catalog_table = Table.from_pandas(catalog)
         astro_ascii.write(catalog_table, file_name, format="no_header", overwrite=True)
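
Writing each catalog as soon as it is sampled keeps at most one catalog in memory, rather than accumulating all nfiles DataFrames before a separate write loop. A minimal read-back sketch for one generated file (the path and prefix are hypothetical; the format mirrors the headerless ASCII write above):

from astropy.io import ascii as astro_ascii

# Columns come back as col1..colN because the files are written headerless.
catalog = astro_ascii.read(
    "/data/scratch/kapnadak/data-08-08/catalogs/galclust_000.dat",
    format="no_header",
)
print(len(catalog), "sources")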

case_studies/galaxy_clustering/data_generation/file_datum_generation.py

Lines changed: 0 additions & 110 deletions
This file was deleted.
Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
+import os
+import pickle
+
+import torch
+
+DES_DIR = (
+    "/nfs/turbo/lsa-regier/scratch/gapatron/desdr-server.ncsa.illinois.edu/despublic/dr2_tiles"
+)
+DES_BANDS = ("g", "r", "i", "z")
+DES_SUBDIRS = (d for d in os.listdir(DES_DIR) if d.startswith("DES"))
+OUTPUT_DIR = "/data/scratch/des/dr2_detection_output/run_1"
+
+
+def convert_to_global_idx(tile_idx, gpu_idx):
+    num_gpus = 2
+    tiles_per_img = 64
+    batch_size = 2
+    dir_idx = int(num_gpus * (tile_idx // (tiles_per_img / batch_size)) + gpu_idx)
+    subimage_idx = [(batch_size * tile_idx + i) % tiles_per_img for i in range(batch_size)]
+    return dir_idx, subimage_idx
+
+
+def convert_to_tile_idx(dir_idx):
+    num_gpus = 2
+    tiles_per_img = 64
+    batch_size = 2
+    gpu_idx = dir_idx % num_gpus
+    tile_starting_idx = (tiles_per_img / batch_size) * (dir_idx // num_gpus)
+    return int(tile_starting_idx), int(gpu_idx)
+
+
+def count_num_clusters(dir_idx):
+    memberships = torch.empty((0, 10, 10))
+    tile_starting_idx, gpu_idx = convert_to_tile_idx(dir_idx)
+    for tile in range(tile_starting_idx, tile_starting_idx + 32):
+        file = torch.load(
+            f"{OUTPUT_DIR}/rank_{gpu_idx}_batchIdx_{tile}_dataloaderIdx_0.pt",
+            map_location=torch.device("cpu"),
+        )
+        memberships = torch.cat((memberships, file["mode_cat"]["membership"].squeeze()), dim=0)
+    memberships = torch.repeat_interleave(memberships, repeats=128, dim=1)
+    memberships = torch.repeat_interleave(memberships, repeats=128, dim=2)
+    return torch.any(memberships.view(memberships.shape[0], -1), dim=1).sum()
+
+
+def main():
+    num_clusters = {}
+    output_filename = "/data/scratch/des/num_clusters.pickle"
+    for dir_idx, des_dir in enumerate(DES_SUBDIRS):
+        if dir_idx > 10167:
+            break
+        num_clusters[des_dir] = count_num_clusters(dir_idx)
+
+    with open(output_filename, "wb") as handle:
+        pickle.dump(num_clusters, handle, protocol=pickle.HIGHEST_PROTOCOL)
+
+
+if __name__ == "__main__":
+    main()
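
Under the constants baked into both helpers (num_gpus=2, tiles_per_img=64, batch_size=2), each DES tile directory maps to tiles_per_img / batch_size = 32 consecutive prediction batches on a single GPU rank, and the two functions invert each other. A worked round trip (directory index 5 is an arbitrary example):

tile_starting_idx, gpu_idx = convert_to_tile_idx(5)
# dir 5 -> batches 64..95 on rank 1
assert (tile_starting_idx, gpu_idx) == (64, 1)

dir_idx, subimage_idx = convert_to_global_idx(64, 1)
# batch 64 on rank 1 -> dir 5, sub-images [0, 1] of its 64 tiles
assert dir_idx == 5 and subimage_idx == [0, 1]

Incidentally, the repeat_interleave upsampling in count_num_clusters does not affect the torch.any reduction, since repeating entries cannot make an all-zero map nonzero; the count is simply the number of 10 x 10 membership maps with at least one nonzero entry.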
