Commit: Routine Weekly Merge (#1055)
* Update config parameters for run_1.

* Fix shape sampling.

* Create bounding boxes for DES.

* Remove file datum generation file.

* Exploring DES outputs.

* Clean up notebook.

* Update inference notebook.

* Modify prior to use SVA1 sources.

* Modify prior to correctly calculate radii in pixels.

* Change default image size to 2560.

* Modify mean sources to match DES.

* Modify default config.

* Modify data gen to write catalogs on the fly.

* Script for basic cluster counting.

* Script for mapping SVA1 regions. Flake8 throws import error.

* Modify prior to use each source redshift instead of single cluster redshift.
kapnadak authored Aug 9, 2024
1 parent da6f2b5 commit cfb8f1c
Showing 8 changed files with 803 additions and 211 deletions.
14 changes: 7 additions & 7 deletions case_studies/galaxy_clustering/conf/config.yaml
@@ -5,10 +5,10 @@ defaults:
   - override hydra/job_logging: stdout

 data_gen:
-  data_dir: /nfs/turbo/lsa-regier/scratch/kapnadak/new_data
-  image_size: 1280
-  tile_size: 128
-  nfiles: 5000
+  data_dir: /data/scratch/kapnadak/data-08-08/
+  image_size: 2560
+  tile_size: 256
+  nfiles: 2000
   n_catalogs_per_file: 500
   bands: ["g", "r", "i", "z"]
   min_flux_for_loss: 0
@@ -37,7 +37,7 @@ cached_simulator:
   batch_size: 2
   splits: 0:60/60:90/90:100 # train/val/test splits as percent ranges
   num_workers: 8
-  cached_data_path: /nfs/turbo/lsa-regier/scratch/kapnadak/new_data/file_data
+  cached_data_path: /data/scratch/kapnadak/new_data/file_data
   train_transforms: []

 train:
@@ -81,7 +81,7 @@ encoder:

 predict:
   cached_dataset:
-    _target_: case_studies.galaxy_clustering.cached_dataset.CachedDESModule
+    _target_: case_studies.galaxy_clustering.inference.cached_dataset.CachedDESModule
     cached_data_path: /nfs/turbo/lsa-regier/scratch/gapatron/desdr-server.ncsa.illinois.edu/despublic/dr2_tiles
     tiles_per_img: 64
     batch_size: 2
@@ -100,6 +100,6 @@ predict:
         output_dir: "/data/scratch/des/dr2_detection_output/run_1"
         write_interval: "batch"
   encoder: ${encoder}
-  weight_save_path: /nfs/turbo/lsa-regier/scratch/gapatron/best_encoder.ckpt
+  weight_save_path: /data/scratch/des/best_encoder.ckpt
   device: "cuda:0"
   output_save_path: "/data/scratch/des/dr2_detection_output/run_0"
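
A quick sanity check of the geometry implied by the new data_gen values (a minimal sketch, assuming square images cut into non-overlapping square tiles; the numbers are copied from the diff above):

# Geometry implied by the updated data_gen block (assumption: square, non-overlapping tiles).
image_size = 2560
tile_size = 256
tiles_per_side = image_size // tile_size
print(tiles_per_side, tiles_per_side ** 2)  # 10 tiles per side, 100 tiles per generated image
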
33 changes: 33 additions & 0 deletions case_studies/galaxy_clustering/data_generation/SVA_map.py
@@ -0,0 +1,33 @@
# flake8: noqa
import pickle

import numpy as np
import pandas as pd
from astropy.table import Table

SVA_PATH = "/data/scratch/des/sva1_gold_r1.0_catalog.fits"
BOUNDING_BOX_PATH = "/data/scratch/des/bounding_coordinates.pickle"


def main():
    sva_catalog = Table.read(SVA_PATH).to_pandas()
    bounding_boxes = pd.read_pickle(BOUNDING_BOX_PATH)
    des_sva_intersection = []
    output_filename = "/data/scratch/des/sva_map.pickle"

    for k, v in bounding_boxes.items():
        ra_min, ra_max, dec_min, dec_max = v["RA_min"], v["RA_max"], v["DEC_min"], v["DEC_max"]
        ra_intersection = np.logical_and((ra_min < sva_catalog["RA"]), (sva_catalog["RA"] < ra_max))
        dec_intersection = np.logical_and(
            (dec_min < sva_catalog["DEC"]), (sva_catalog["DEC"] < dec_max)
        )
        full_intersection = np.logical_and(ra_intersection, dec_intersection)
        if full_intersection.any():
            des_sva_intersection.append(k)

    with open(output_filename, "wb") as handle:
        pickle.dump(des_sva_intersection, handle, protocol=pickle.HIGHEST_PROTOCOL)


if __name__ == "__main__":
    main()
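
The loop above keeps a DES tile whenever any SVA1 source falls inside its axis-aligned RA/Dec box. A minimal self-contained sketch of the same predicate on made-up coordinates (illustrative only, not SVA1 data; note that a box computed for a tile straddling RA = 0/360 would span nearly the full RA range):

import numpy as np

# Same box test as above, on synthetic coordinates.
ra = np.array([10.2, 55.0, 71.3])
dec = np.array([-45.1, -60.0, -44.0])
ra_min, ra_max, dec_min, dec_max = 70.0, 72.0, -45.0, -43.0
inside = (ra_min < ra) & (ra < ra_max) & (dec_min < dec) & (dec < dec_max)
print(inside.any())  # True: the third source sits in the box, so this tile would be kept
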
35 changes: 35 additions & 0 deletions case_studies/galaxy_clustering/data_generation/coordinates.py
@@ -0,0 +1,35 @@
import os
import pickle
from pathlib import Path

import pandas as pd
from astropy.io import fits


def main():
    des_dir = Path(
        "/nfs/turbo/lsa-regier/scratch/gapatron/desdr-server.ncsa.illinois.edu/despublic/dr2_tiles/"
    )
    des_subdirs = [d for d in os.listdir(des_dir) if d.startswith("DES")]
    bounding_coordinates = {}
    output_filename = "/data/scratch/des/bounding_coordinates.pickle"
    for des_subdir in des_subdirs:
        catalog_path = des_dir / Path(des_subdir) / Path(f"{des_subdir}_dr2_main.fits")
        catalog_data = fits.getdata(catalog_path)
        source_df = pd.DataFrame(catalog_data)
        ra_min, ra_max = source_df["RA"].min(), source_df["RA"].max()
        dec_min, dec_max = source_df["DEC"].min(), source_df["DEC"].max()
        bounding_box = {
            "RA_min": ra_min,
            "RA_max": ra_max,
            "DEC_min": dec_min,
            "DEC_max": dec_max,
        }
        bounding_coordinates[des_subdir] = bounding_box

    with open(output_filename, "wb") as handle:
        pickle.dump(bounding_coordinates, handle, protocol=pickle.HIGHEST_PROTOCOL)


if __name__ == "__main__":
    main()
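
The pickle written here is the same file SVA_map.py reads back as BOUNDING_BOX_PATH. A minimal sketch of inspecting one entry (which tile prints first depends on the os.listdir order):

import pandas as pd

# Read the per-tile bounding boxes written above and show one entry.
bounding_boxes = pd.read_pickle("/data/scratch/des/bounding_coordinates.pickle")
tile, box = next(iter(bounding_boxes.items()))
print(tile, box["RA_min"], box["RA_max"], box["DEC_min"], box["DEC_max"])
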
10 changes: 4 additions & 6 deletions case_studies/galaxy_clustering/data_generation/data_gen.py
@@ -48,16 +48,14 @@ def catalog_gen(cfg):
     cluster_prior = ClusterPrior(image_size=image_size)
     background_prior = BackgroundPrior(image_size=image_size)

-    combined_catalogs = []
-    for _ in range(nfiles):
+    for i in range(nfiles):
         background_catalog = background_prior.sample_background()
         if np.random.uniform() < 0.5:
             cluster_catalog = cluster_prior.sample_cluster()
-            combined_catalogs.append(pd.concat([cluster_catalog, background_catalog]))
+            catalog = pd.concat([cluster_catalog, background_catalog])
         else:
-            combined_catalogs.append(background_catalog)
-
-    for i, catalog in enumerate(combined_catalogs):
+            catalog = background_catalog
         print(f"Writing catalog {i} ...")
         file_name = f"{catalogs_path}/{file_prefix}_{i:03}.dat"
         catalog_table = Table.from_pandas(catalog)
         astro_ascii.write(catalog_table, file_name, format="no_header", overwrite=True)
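
With this change each catalog is written inside the same iteration that samples it, so only one catalog is held in memory at a time instead of all nfiles of them. The 0.5 threshold gives each generated image an even chance of containing an injected cluster; a small illustrative sketch of the expected split under the new nfiles value (not part of the commit):

import numpy as np

# Roughly half of the nfiles = 2000 catalogs receive a cluster on average,
# mirroring the np.random.uniform() < 0.5 branch above.
draws = np.random.uniform(size=2000) < 0.5
print(draws.sum())  # ~1000, varying run to run
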

This file was deleted.

59 changes: 59 additions & 0 deletions case_studies/galaxy_clustering/data_generation/inference_stats.py
@@ -0,0 +1,59 @@
import os
import pickle

import torch

DES_DIR = (
    "/nfs/turbo/lsa-regier/scratch/gapatron/desdr-server.ncsa.illinois.edu/despublic/dr2_tiles"
)
DES_BANDS = ("g", "r", "i", "z")
DES_SUBDIRS = (d for d in os.listdir(DES_DIR) if d.startswith("DES"))
OUTPUT_DIR = "/data/scratch/des/dr2_detection_output/run_1"


def convert_to_global_idx(tile_idx, gpu_idx):
    num_gpus = 2
    tiles_per_img = 64
    batch_size = 2
    dir_idx = int(num_gpus * (tile_idx // (tiles_per_img / batch_size)) + gpu_idx)
    subimage_idx = [(batch_size * tile_idx + i) % tiles_per_img for i in range(batch_size)]
    return dir_idx, subimage_idx


def convert_to_tile_idx(dir_idx):
    num_gpus = 2
    tiles_per_img = 64
    batch_size = 2
    gpu_idx = dir_idx % num_gpus
    tile_starting_idx = (tiles_per_img / batch_size) * (dir_idx // num_gpus)
    return int(tile_starting_idx), int(gpu_idx)


def count_num_clusters(dir_idx):
    memberships = torch.empty((0, 10, 10))
    tile_starting_idx, gpu_idx = convert_to_tile_idx(dir_idx)
    for tile in range(tile_starting_idx, tile_starting_idx + 32):
        file = torch.load(
            f"{OUTPUT_DIR}/rank_{gpu_idx}_batchIdx_{tile}_dataloaderIdx_0.pt",
            map_location=torch.device("cpu"),
        )
        memberships = torch.cat((memberships, file["mode_cat"]["membership"].squeeze()), dim=0)
    memberships = torch.repeat_interleave(memberships, repeats=128, dim=1)
    memberships = torch.repeat_interleave(memberships, repeats=128, dim=2)
    return torch.any(memberships.view(memberships.shape[0], -1), dim=1).sum()


def main():
    num_clusters = {}
    output_filename = "/data/scratch/des/num_clusters.pickle"
    for dir_idx, des_dir in enumerate(DES_SUBDIRS):
        if dir_idx > 10167:
            break
        num_clusters[des_dir] = count_num_clusters(dir_idx)

    with open(output_filename, "wb") as handle:
        pickle.dump(num_clusters, handle, protocol=pickle.HIGHEST_PROTOCOL)


if __name__ == "__main__":
    main()
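
The two index helpers encode how predictions were sharded at inference time: with num_gpus = 2, tiles_per_img = 64 and batch_size = 2, each DES image contributes 32 batches, alternating across the two GPU ranks, which is exactly the range count_num_clusters loads. A self-contained worked example of the same arithmetic (constants copied from the functions above):

num_gpus, tiles_per_img, batch_size = 2, 64, 2

# Which batches hold DES image number 5? (mirrors convert_to_tile_idx)
dir_idx = 5
gpu_idx = dir_idx % num_gpus                                                 # 1
tile_starting_idx = (tiles_per_img // batch_size) * (dir_idx // num_gpus)    # 32 * 2 = 64
print(gpu_idx, tile_starting_idx)  # rank 1, batches 64..95

# And back again: the first of those batches maps to image 5 (mirrors convert_to_global_idx).
recovered = num_gpus * (tile_starting_idx // (tiles_per_img // batch_size)) + gpu_idx
print(recovered)  # 5
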