10
10
from bliss .encoder .metrics import CatalogMatcher
11
11
12
12
13
+ def convert_nmgy_to_njymag (nmgy ):
14
+ """Convert from flux (nano-maggie) to mag (nano-jansky), which is the format used by DC2.
15
+
16
+ For the difference between mag (Pogson magnitude) and njymag (AB magnitude), please view
17
+ the "Flux units: maggies and nanomaggies" part of
18
+ https://www.sdss3.org/dr8/algorithms/magnitudes.php#nmgy
19
+ When we change the standard source to AB sources, we need to do the conversion
20
+ described in "2.10 AB magnitudes" at
21
+ https://pstn-001.lsst.io/fluxunits.pdf
22
+
23
+ Args:
24
+ nmgy: the fluxes in nanomaggies
25
+
26
+ Returns:
27
+ Tensor indicating fluxes in AB magnitude
28
+ """
29
+
30
+ return 22.5 - 2.5 * torch .log10 (nmgy / 3631 )
31
+
32
+
13
33
class MetricBin (Metric ):
14
34
def __init__ (
15
35
self ,
@@ -67,6 +87,7 @@ def __init__(
67
87
68
88
def update (self , true_cat , est_cat , matching ):
69
89
cutoffs = torch .tensor (self .bin_cutoffs , device = self .device )
90
+ on_fluxes = convert_nmgy_to_njymag (true_cat .on_fluxes )
70
91
for i in range (true_cat .batch_size ):
71
92
tcat_matches , ecat_matches = matching [i ]
72
93
@@ -80,7 +101,7 @@ def update(self, true_cat, est_cat, matching):
80
101
true_red = true_cat ["redshifts" ][i , tcat_matches , :].to (self .device )
81
102
est_red = est_cat ["redshifts" ][i , ecat_matches , :].to (self .device )
82
103
83
- true_mag = true_cat . on_fluxes [i ][..., self .mag_band ][tcat_matches ].to (self .device )
104
+ true_mag = on_fluxes [i ][..., self .mag_band ][tcat_matches ].to (self .device )
84
105
bin_indices = torch .bucketize (true_mag , cutoffs )
85
106
86
107
red_err = (true_red - est_red ).abs () ** 2
@@ -238,6 +259,7 @@ def __init__(
238
259
239
260
def update (self , true_cat , est_cat , matching ):
240
261
cutoffs = torch .tensor (self .bin_cutoffs , device = self .device )
262
+ on_fluxes = convert_nmgy_to_njymag (true_cat .on_fluxes )
241
263
for i in range (true_cat .batch_size ):
242
264
tcat_matches , ecat_matches = matching [i ]
243
265
@@ -251,7 +273,7 @@ def update(self, true_cat, est_cat, matching):
251
273
true_red = true_cat ["redshifts" ][i , tcat_matches , :].to (self .device )
252
274
est_red = est_cat ["redshifts" ][i , ecat_matches , :].to (self .device )
253
275
254
- true_mag = true_cat . on_fluxes [i ][..., self .mag_band ][tcat_matches ].to (self .device )
276
+ true_mag = on_fluxes [i ][..., self .mag_band ][tcat_matches ].to (self .device )
255
277
bin_indices = torch .bucketize (true_mag , cutoffs )
256
278
257
279
metric_outlier = torch .abs (true_red - est_red ) / (1 + true_red )
@@ -322,6 +344,7 @@ def __init__(
322
344
323
345
def update (self , true_cat , est_cat , matching ):
324
346
cutoffs = torch .tensor (self .bin_cutoffs , device = self .device )
347
+ on_fluxes = convert_nmgy_to_njymag (true_cat .on_fluxes )
325
348
for i in range (true_cat .batch_size ):
326
349
tcat_matches , ecat_matches = matching [i ]
327
350
@@ -335,7 +358,7 @@ def update(self, true_cat, est_cat, matching):
335
358
true_red = true_cat ["redshifts" ][i , tcat_matches , :].to (self .device )
336
359
est_red = est_cat ["redshifts" ][i , ecat_matches , :].to (self .device )
337
360
338
- true_mag = true_cat . on_fluxes [i ][..., self .mag_band ][tcat_matches ].to (self .device )
361
+ true_mag = on_fluxes [i ][..., self .mag_band ][tcat_matches ].to (self .device )
339
362
bin_indices = torch .bucketize (true_mag , cutoffs )
340
363
341
364
metric_outlier_cata = torch .abs (true_red - est_red )
@@ -399,6 +422,7 @@ def __init__(self, **kwargs):
399
422
400
423
def update (self , true_cat , est_cat , matching ):
401
424
cutoffs = torch .tensor (self .bin_cutoffs , device = self .device )
425
+ on_fluxes = convert_nmgy_to_njymag (true_cat .on_fluxes )
402
426
for i in range (true_cat .batch_size ):
403
427
tcat_matches , ecat_matches = matching [i ]
404
428
@@ -412,7 +436,7 @@ def update(self, true_cat, est_cat, matching):
412
436
true_red = true_cat ["redshifts" ][i , tcat_matches , :].to (self .device )
413
437
est_red = est_cat ["redshifts" ][i , ecat_matches , :].to (self .device )
414
438
415
- true_mag = true_cat . on_fluxes [i ][..., self .mag_band ][tcat_matches ].to (self .device )
439
+ true_mag = on_fluxes [i ][..., self .mag_band ][tcat_matches ].to (self .device )
416
440
bin_indices = torch .bucketize (true_mag , cutoffs )
417
441
418
442
metrics = (true_red - est_red ) / (1 + true_red )
@@ -492,6 +516,7 @@ def __init__(self, **kwargs):
492
516
493
517
def update (self , true_cat , est_cat , matching ):
494
518
cutoffs = torch .tensor (self .bin_cutoffs , device = self .device )
519
+ on_fluxes = convert_nmgy_to_njymag (true_cat .on_fluxes )
495
520
for i in range (true_cat .batch_size ):
496
521
tcat_matches , ecat_matches = matching [i ]
497
522
@@ -505,7 +530,7 @@ def update(self, true_cat, est_cat, matching):
505
530
true_red = true_cat ["redshifts" ][i , tcat_matches , :].to (self .device )
506
531
est_red = est_cat ["redshifts" ][i , ecat_matches , :].to (self .device )
507
532
508
- true_mag = true_cat . on_fluxes [i ][..., self .mag_band ][tcat_matches ].to (self .device )
533
+ true_mag = on_fluxes [i ][..., self .mag_band ][tcat_matches ].to (self .device )
509
534
bin_indices = torch .bucketize (true_mag , cutoffs )
510
535
511
536
metrics = (true_red - est_red ) / (1 + true_red )
@@ -577,6 +602,7 @@ def __init__(self, **kwargs):
577
602
578
603
def update (self , true_cat , est_cat , matching ):
579
604
cutoffs = torch .tensor (self .bin_cutoffs , device = self .device )
605
+ on_fluxes = convert_nmgy_to_njymag (true_cat .on_fluxes )
580
606
for i in range (true_cat .batch_size ):
581
607
tcat_matches , ecat_matches = matching [i ]
582
608
@@ -590,7 +616,7 @@ def update(self, true_cat, est_cat, matching):
590
616
true_red = true_cat ["redshifts" ][i , tcat_matches , :].to (self .device )
591
617
est_red = est_cat ["redshifts" ][i , ecat_matches , :].to (self .device )
592
618
593
- true_mag = true_cat . on_fluxes [i ][..., self .mag_band ][tcat_matches ].to (self .device )
619
+ true_mag = on_fluxes [i ][..., self .mag_band ][tcat_matches ].to (self .device )
594
620
bin_indices = torch .bucketize (true_mag , cutoffs )
595
621
596
622
metrics = torch .abs (true_red - est_red ) / (1 + true_red )
0 commit comments