-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Update config parameters for run_1. * Fix shape sampling. * Create bounding boxes for DES. * Remove file datum generation file. * Exploring DES outputs. * Clean up notebook. * Update inference notebook. * Modify prior to use SVA1 sources. * Modify prior to correctly calculate radii in pixels. * Change default image size to 2560. * Modify mean sources to match DES. * Modify default config. * Modify data gen to write catalogs on the fly. * Script for basic cluster counting. * Script for mapping SVA1 regions. Flake8 throws import error. * Modify prior to use each source redshift instead of single cluster redshift.
- Loading branch information
Showing
8 changed files
with
803 additions
and
211 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# flake8: noqa | ||
import pickle | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from astropy.table import Table | ||
|
||
SVA_PATH = "/data/scratch/des/sva1_gold_r1.0_catalog.fits" | ||
BOUNDING_BOX_PATH = "/data/scratch/des/bounding_coordinates.pickle" | ||
|
||
|
||
def _tiles_overlapping_catalog(catalog, bounding_boxes):
    """Return the names of every bounding box containing at least one catalog source.

    Args:
        catalog: DataFrame with "RA" and "DEC" columns (same units as the boxes).
        bounding_boxes: mapping of tile name -> dict with keys
            "RA_min", "RA_max", "DEC_min", "DEC_max".

    Returns:
        List of tile names whose box strictly contains at least one source.
    """
    # Hoist the column lookups out of the loop: they are invariant per box.
    ra, dec = catalog["RA"], catalog["DEC"]
    overlapping = []
    for name, box in bounding_boxes.items():
        in_ra = (box["RA_min"] < ra) & (ra < box["RA_max"])
        in_dec = (box["DEC_min"] < dec) & (dec < box["DEC_max"])
        # A tile overlaps when any single source satisfies both coordinate ranges.
        if (in_ra & in_dec).any():
            overlapping.append(name)
    return overlapping


def main():
    """Record which DES tiles contain at least one SVA1 source.

    Reads the SVA1 gold catalog and the per-tile bounding boxes, then pickles
    the list of overlapping tile names.
    """
    sva_catalog = Table.read(SVA_PATH).to_pandas()
    bounding_boxes = pd.read_pickle(BOUNDING_BOX_PATH)
    output_filename = "/data/scratch/des/sva_map.pickle"

    des_sva_intersection = _tiles_overlapping_catalog(sva_catalog, bounding_boxes)

    with open(output_filename, "wb") as handle:
        pickle.dump(des_sva_intersection, handle, protocol=pickle.HIGHEST_PROTOCOL)


if __name__ == "__main__":
    main()
35 changes: 35 additions & 0 deletions
35
case_studies/galaxy_clustering/data_generation/coordinates.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import os | ||
import pickle | ||
from pathlib import Path | ||
|
||
import pandas as pd | ||
from astropy.io import fits | ||
|
||
|
||
def _bounding_box(source_df):
    """Return the min/max RA and DEC of all sources in *source_df* as a dict."""
    return {
        "RA_min": source_df["RA"].min(),
        "RA_max": source_df["RA"].max(),
        "DEC_min": source_df["DEC"].min(),
        "DEC_max": source_df["DEC"].max(),
    }


def main():
    """Compute a bounding box for every DES DR2 tile catalog and pickle the mapping.

    Scans each DES* subdirectory's main FITS catalog, derives the RA/DEC extent
    of its sources, and writes {subdir_name: bounding_box_dict} to a pickle.
    """
    des_dir = Path(
        "/nfs/turbo/lsa-regier/scratch/gapatron/desdr-server.ncsa.illinois.edu/despublic/dr2_tiles/"
    )
    des_subdirs = [d for d in os.listdir(des_dir) if d.startswith("DES")]
    bounding_coordinates = {}
    output_filename = "/data/scratch/des/bounding_coordinates.pickle"
    for des_subdir in des_subdirs:
        # Path's "/" operator accepts plain strings; no need to re-wrap in Path().
        catalog_path = des_dir / des_subdir / f"{des_subdir}_dr2_main.fits"
        catalog_data = fits.getdata(catalog_path)
        source_df = pd.DataFrame(catalog_data)
        bounding_coordinates[des_subdir] = _bounding_box(source_df)

    with open(output_filename, "wb") as handle:
        pickle.dump(bounding_coordinates, handle, protocol=pickle.HIGHEST_PROTOCOL)


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
110 changes: 0 additions & 110 deletions
110
case_studies/galaxy_clustering/data_generation/file_datum_generation.py
This file was deleted.
Oops, something went wrong.
59 changes: 59 additions & 0 deletions
59
case_studies/galaxy_clustering/data_generation/inference_stats.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import os | ||
import pickle | ||
|
||
import torch | ||
|
||
# Root directory holding one subdirectory per DES DR2 tile.
DES_DIR = (
    "/nfs/turbo/lsa-regier/scratch/gapatron/desdr-server.ncsa.illinois.edu/despublic/dr2_tiles"
)
DES_BANDS = ("g", "r", "i", "z")
# Materialize as a tuple: the original generator expression could be iterated
# only once, silently yielding nothing on any later pass over DES_SUBDIRS.
DES_SUBDIRS = tuple(d for d in os.listdir(DES_DIR) if d.startswith("DES"))
OUTPUT_DIR = "/data/scratch/des/dr2_detection_output/run_1"
|
||
|
||
def convert_to_global_idx(tile_idx, gpu_idx, num_gpus=2, tiles_per_img=64, batch_size=2):
    """Map a (batch index, GPU index) pair to its source directory and sub-images.

    Inverse companion of ``convert_to_tile_idx``. The defaults match the run
    configuration the original hard-coded (2 GPUs, 64 tiles per image,
    batch size 2); pass other values to reuse the mapping for a different run.

    Args:
        tile_idx: per-GPU batch index of the detection output file.
        gpu_idx: rank of the GPU that produced the batch.
        num_gpus: number of GPUs used for inference.
        tiles_per_img: number of sub-images each full image is split into.
        batch_size: number of sub-images per saved batch.

    Returns:
        Tuple ``(dir_idx, subimage_idx)`` — the global directory index and the
        list of sub-image indices covered by this batch.
    """
    batches_per_img = tiles_per_img // batch_size
    # Integer floor division replaces the original float "/" + int() round-trip.
    dir_idx = num_gpus * (tile_idx // batches_per_img) + gpu_idx
    subimage_idx = [(batch_size * tile_idx + i) % tiles_per_img for i in range(batch_size)]
    return dir_idx, subimage_idx
|
||
|
||
def convert_to_tile_idx(dir_idx, num_gpus=2, tiles_per_img=64, batch_size=2):
    """Map a global directory index to its first per-GPU batch index and GPU rank.

    Inverse companion of ``convert_to_global_idx``; defaults match the same
    hard-coded run configuration and may be overridden for other runs.

    Args:
        dir_idx: global index of the DES tile directory.
        num_gpus: number of GPUs used for inference.
        tiles_per_img: number of sub-images each full image is split into.
        batch_size: number of sub-images per saved batch.

    Returns:
        Tuple ``(tile_starting_idx, gpu_idx)`` — the first batch index for this
        directory and the GPU rank that processed it.
    """
    gpu_idx = dir_idx % num_gpus
    # Integer floor division keeps everything int; original used float "/" + int().
    tile_starting_idx = (tiles_per_img // batch_size) * (dir_idx // num_gpus)
    return tile_starting_idx, gpu_idx
|
||
|
||
def count_num_clusters(dir_idx):
    """Count how many saved membership maps for *dir_idx* contain any detection.

    Loads the 32 per-batch detection outputs written by the GPU responsible for
    this directory, stacks their 10x10 membership grids, upsamples each grid by
    128x along both spatial axes (presumably to pixel resolution — confirm
    against the data-generation config), and returns the number of maps with at
    least one nonzero entry, as a 0-dim tensor.
    """
    start_tile, gpu_idx = convert_to_tile_idx(dir_idx)
    grids = [torch.empty((0, 10, 10))]
    for batch_idx in range(start_tile, start_tile + 32):
        checkpoint = torch.load(
            f"{OUTPUT_DIR}/rank_{gpu_idx}_batchIdx_{batch_idx}_dataloaderIdx_0.pt",
            map_location=torch.device("cpu"),
        )
        grids.append(checkpoint["mode_cat"]["membership"].squeeze())
    memberships = torch.cat(grids, dim=0)
    memberships = memberships.repeat_interleave(128, dim=1)
    memberships = memberships.repeat_interleave(128, dim=2)
    # A map "has a cluster" when any of its flattened entries is truthy.
    return torch.any(memberships.view(memberships.shape[0], -1), dim=1).sum()
|
||
|
||
def main():
    """Tally detected clusters per DES tile directory and pickle the result.

    Iterates the DES subdirectories in order, counting clusters for directory
    indices 0..10167, and writes {subdir_name: count} to a pickle file.
    """
    output_filename = "/data/scratch/des/num_clusters.pickle"
    tally = {}
    for dir_idx, subdir in enumerate(DES_SUBDIRS):
        # Stop once directory indices 0..10167 have been processed.
        if dir_idx >= 10168:
            break
        tally[subdir] = count_num_clusters(dir_idx)

    with open(output_filename, "wb") as handle:
        pickle.dump(tally, handle, protocol=pickle.HIGHEST_PROTOCOL)


if __name__ == "__main__":
    main()
Oops, something went wrong.