|
12 | 12 | from torch import Tensor
|
13 | 13 |
|
14 | 14 |
|
| 15 | +def convert_mag_to_nmgy(mag): |
| 16 | + return 10 ** ((22.5 - mag) / 2.5) |
| 17 | + |
| 18 | + |
| 19 | +def convert_nmgy_to_mag(nmgy): |
| 20 | + return 22.5 - 2.5 * torch.log10(nmgy) |
| 21 | + |
| 22 | + |
15 | 23 | class SourceType(IntEnum):
|
16 | 24 | STAR = 0
|
17 | 25 | GALAXY = 1
|
@@ -112,6 +120,20 @@ def galaxy_bools(self) -> Tensor:
|
112 | 120 | is_galaxy = self["source_type"] == SourceType.GALAXY
|
113 | 121 | return is_galaxy * self.is_on_mask.unsqueeze(-1)
|
114 | 122 |
|
| 123 | + @property |
| 124 | + def on_fluxes(self): |
| 125 | + """Gets fluxes of "on" sources based on whether the source is a star or galaxy. |
| 126 | +
|
| 127 | + Returns: |
| 128 | + Tensor: a tensor of fluxes of size (b x nth x ntw x max_sources x 1) |
| 129 | + """ |
| 130 | + fluxes = torch.where(self.galaxy_bools, self["galaxy_fluxes"], self["star_fluxes"]) |
| 131 | + return torch.where(self.is_on_mask[..., None], fluxes, torch.zeros_like(fluxes)) |
| 132 | + |
| 133 | + @property |
| 134 | + def magnitudes(self): |
| 135 | + return convert_nmgy_to_mag(self.on_fluxes) |
| 136 | + |
115 | 137 | @property
|
116 | 138 | def device(self):
|
117 | 139 | return self.locs.device
|
@@ -242,18 +264,9 @@ def gather_param_at_tiles(self, param_name: str, indices: Tensor) -> Tensor:
|
242 | 264 | idx_to_gather = repeat(indices, "... -> ... k", k=param.size(-1))
|
243 | 265 | return torch.gather(param, dim=1, index=idx_to_gather)
|
244 | 266 |
|
245 |
| - def get_fluxes_of_on_sources(self): |
246 |
| - """Gets fluxes of "on" sources based on whether the source is a star or galaxy. |
247 |
| -
|
248 |
| - Returns: |
249 |
| - Tensor: a tensor of fluxes of size (b x nth x ntw x max_sources x 1) |
250 |
| - """ |
251 |
| - fluxes = torch.where(self.galaxy_bools, self["galaxy_fluxes"], self["star_fluxes"]) |
252 |
| - return torch.where(self.is_on_mask[..., None], fluxes, torch.zeros_like(fluxes)) |
253 |
| - |
254 | 267 | def _sort_sources_by_flux(self, band=2):
|
255 | 268 | # sort by fluxes of "on" sources to get brightest source per tile
|
256 |
| - on_fluxes = self.get_fluxes_of_on_sources()[..., band] # shape n x nth x ntw x d |
| 269 | + on_fluxes = self.on_fluxes[..., band] # shape n x nth x ntw x d |
257 | 270 | top_indexes = on_fluxes.argsort(dim=3, descending=True)
|
258 | 271 |
|
259 | 272 | d = {"n_sources": self.n_sources}
|
@@ -310,7 +323,7 @@ def filter_tile_catalog_by_flux(self, min_flux=0, max_flux=torch.inf, band=2):
|
310 | 323 | sorted_self = self._sort_sources_by_flux(band=band)
|
311 | 324 |
|
312 | 325 | # get fluxes of "on" sources to mask by
|
313 |
| - on_fluxes = sorted_self.get_fluxes_of_on_sources()[..., band] |
| 326 | + on_fluxes = sorted_self.on_fluxes[..., band] |
314 | 327 | flux_mask = (on_fluxes > min_flux) & (on_fluxes < max_flux)
|
315 | 328 |
|
316 | 329 | d = {}
|
@@ -443,25 +456,42 @@ def to(self, device):
|
443 | 456 | def device(self):
|
444 | 457 | return self.plocs.device
|
445 | 458 |
|
446 |
| - def get_is_on_mask(self) -> Tensor: |
| 459 | + @property |
| 460 | + def is_on_mask(self) -> Tensor: |
447 | 461 | arange = torch.arange(self.max_sources, device=self.device)
|
448 | 462 | return arange.view(1, -1) < self.n_sources.view(-1, 1)
|
449 | 463 |
|
450 | 464 | @property
|
451 | 465 | def star_bools(self) -> Tensor:
|
452 | 466 | is_star = self["source_type"] == SourceType.STAR
|
453 | 467 | assert is_star.size(1) == self.max_sources
|
454 |
| - is_on_mask = self.get_is_on_mask() |
455 | 468 | assert is_star.size(2) == 1
|
456 |
| - return is_star * is_on_mask.unsqueeze(2) |
| 469 | + return is_star * self.is_on_mask.unsqueeze(2) |
457 | 470 |
|
458 | 471 | @property
|
459 | 472 | def galaxy_bools(self) -> Tensor:
|
460 | 473 | is_galaxy = self["source_type"] == SourceType.GALAXY
|
461 | 474 | assert is_galaxy.size(1) == self.max_sources
|
462 |
| - is_on_mask = self.get_is_on_mask() |
463 | 475 | assert is_galaxy.size(2) == 1
|
464 |
| - return is_galaxy * is_on_mask.unsqueeze(2) |
| 476 | + return is_galaxy * self.is_on_mask.unsqueeze(2) |
| 477 | + |
| 478 | + @property |
| 479 | + def on_fluxes(self) -> Tensor: |
| 480 | + """Gets fluxes of "on" sources based on whether the source is a star or galaxy. |
| 481 | +
|
| 482 | + Returns: |
| 483 | + Tensor: a tensor of fluxes of size (b x nth x ntw x max_sources x 1) |
| 484 | + """ |
| 485 | + if "fluxes" in self: |
| 486 | + # ideally we'd always store fluxes rather than star_fluxes and galaxy_fluxes |
| 487 | + fluxes = self.get("fluxes") |
| 488 | + else: |
| 489 | + fluxes = torch.where(self.galaxy_bools, self["galaxy_fluxes"], self["star_fluxes"]) |
| 490 | + return torch.where(self.is_on_mask[..., None], fluxes, torch.zeros_like(fluxes)) |
| 491 | + |
| 492 | + @property |
| 493 | + def magnitudes(self): |
| 494 | + return convert_nmgy_to_mag(self.on_fluxes) |
465 | 495 |
|
466 | 496 | def one_source(self, b: int, s: int):
|
467 | 497 | """Return a dict containing all parameter for one specified light source."""
|
@@ -613,12 +643,12 @@ def to_astropy_table(self, encoder_survey_bands: Tuple[str]) -> Table:
|
613 | 643 |
|
614 | 644 | # Convert dictionary of tensors to list of dictionaries
|
615 | 645 | on_vals = {}
|
616 |
| - is_on_mask = self.get_is_on_mask() |
| 646 | + is_on_mask = self.is_on_mask |
617 | 647 | for k, v in self.to_dict().items():
|
618 | 648 | if k == "n_sources":
|
619 | 649 | continue
|
620 | 650 | if k == "galaxy_params":
|
621 |
| - # reshape get_is_on_mask() to have same last dimension as galaxy_params |
| 651 | + # reshape is_on_mask to have same last dimension as galaxy_params |
622 | 652 | galaxy_params_mask = is_on_mask.unsqueeze(-1).expand_as(v)
|
623 | 653 | on_vals[k] = v[galaxy_params_mask].reshape(-1, v.shape[-1]).cpu()
|
624 | 654 | else:
|
|
0 commit comments