Skip to content

Commit

Permalink
Batched to_tile_catalog (#1057)
Browse files: browse the repository at this point in the history
batched to_tile_catalog

Co-authored-by: Yicun Duan <[email protected]>
  • Loading branch information
YicunDuanUMich and Yicun Duan authored Aug 10, 2024
1 parent 017a85c commit 295902b
Show file tree
Hide file tree
Showing 2 changed files with 981 additions and 124 deletions.
245 changes: 133 additions & 112 deletions bliss/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,9 @@ def to_tile_catalog(

# TODO: a FullCatalog only needs to "know" its height and width to convert itself to a
# TileCatalog. So those parameters should be passed on conversion, not initialization.
tile_coords = torch.div(self["plocs"], tile_slen, rounding_mode="trunc").to(torch.int)
source_tile_coords = torch.div(self["plocs"], tile_slen, rounding_mode="trunc").to(
torch.int
) # (b, bm, 2)
n_tiles_h = math.ceil(self.height / tile_slen)
n_tiles_w = math.ceil(self.width / tile_slen)

Expand All @@ -587,124 +589,143 @@ def to_tile_catalog(
tile_params[k] = torch.zeros(size, dtype=v.dtype, device=self.device)
tile_params["locs"] = torch.zeros((*tile_cat_shape, 2), device=self.device)

for batch_index in range(self.batch_size):
n_sources = int(self["n_sources"][batch_index].item())
plocs = self["plocs"][batch_index, :n_sources] # (n_sources, 2)
filter_sources = n_sources
source_tile_coords = tile_coords[batch_index, :n_sources]

plocs_start_point = torch.tensor([0, 0], dtype=plocs.dtype, device=plocs.device).view(
1, -1
)
plocs_end_point = torch.tensor(
[self.height, self.width], dtype=plocs.dtype, device=plocs.device
).view(1, -1)
plocs_mask = ((plocs > plocs_start_point) & (plocs < plocs_end_point)).all(dim=-1)
if filter_oob:
filter_sources = plocs_mask.sum()
source_tile_coords = source_tile_coords[plocs_mask]
else:
assert plocs_mask.all(), "find sources that are outside boundary"

if filter_sources == 0:
continue

source_to_tile_indices = (
source_tile_coords[:, 0] * n_tiles_w + source_tile_coords[:, 1]
).to(
torch.int64
) # (n_sources, )
num_sources_on_per_tile = torch.zeros(
n_tiles_h * n_tiles_w,
dtype=inter_int_type,
device=self.device,
) # (h x w, )
num_sources_on_per_tile.scatter_add_(
dim=0,
index=source_to_tile_indices,
src=torch.ones_like(source_to_tile_indices, dtype=inter_int_type),
)
num_shared_tiles_per_source = torch.gather(
num_sources_on_per_tile, dim=0, index=source_to_tile_indices
) # (n_sources, )

# note that this doesn't test overflow
max_shared_tiles = num_shared_tiles_per_source.max().item()
if max_shared_tiles > max_sources_per_tile:
if not ignore_extra_sources:
raise ValueError( # noqa: WPS220
"# of sources per tile exceeds `max_sources_per_tile`."
)
if source_tile_coords.shape[1] == 0:
tile_params["n_sources"] = tile_n_sources
return TileCatalog(tile_params)

for tile_k, tile_v in tile_params.items():
tile_params[tile_k] = self._pad_along_max_sources(
tile_v, target_m=max_shared_tiles
)
# from full cat tensor to tiled cat tensor
batch_size = self["n_sources"].shape[0]
plocs = self["plocs"] # (b, bm, 2)
is_on_mask = self.is_on_mask # (b, bm)

if max_shared_tiles > 1:
source_cum = torch.zeros(
source_to_tile_indices.shape[0], dtype=inter_int_type, device=self.device
plocs_start_point = torch.tensor([0, 0], dtype=plocs.dtype, device=plocs.device)
plocs_start_point = plocs_start_point.view(1, 1, -1) # (1, 1, 2)
plocs_end_point = torch.tensor(
[self.height, self.width], dtype=plocs.dtype, device=plocs.device
)
plocs_end_point = plocs_end_point.view(1, 1, -1) # (1, 1, 2)
plocs_mask = ((plocs >= plocs_start_point) & (plocs <= plocs_end_point)).all(dim=-1)
plocs_mask &= is_on_mask # (b, bm)
if filter_oob and plocs_mask.sum() == 0:
tile_params["n_sources"] = tile_n_sources
return TileCatalog(tile_params)
if not filter_oob:
assert torch.masked_select(
plocs_mask, mask=is_on_mask
).all(), "find sources that are outside boundary"

source_to_tile_indices = (
source_tile_coords[:, :, 0] * n_tiles_w + source_tile_coords[:, :, 1]
) # (b, bm)
source_to_tile_indices = source_to_tile_indices.to(dtype=torch.int64)
source_to_tile_indices = torch.where(
plocs_mask,
source_to_tile_indices,
n_tiles_h * n_tiles_w,
)
num_sources_on_per_tile = torch.zeros(
batch_size,
n_tiles_h * n_tiles_w + 1,
dtype=inter_int_type,
device=self.device,
) # (b, h * w + 1)
num_sources_on_per_tile.scatter_add_(
dim=1,
index=source_to_tile_indices,
src=torch.ones_like(source_to_tile_indices, dtype=inter_int_type),
)
num_sources_on_per_tile[:, -1] = 0
num_shared_tiles_per_source = torch.gather(
num_sources_on_per_tile, dim=1, index=source_to_tile_indices
) # (b, bm)
assert (torch.masked_select(num_shared_tiles_per_source, mask=~plocs_mask) == 0).all()

# note that this doesn't test overflow
max_shared_tiles = num_shared_tiles_per_source.max().item()
if max_shared_tiles > max_sources_per_tile:
if not ignore_extra_sources:
raise ValueError( # noqa: WPS220
"# of sources per tile exceeds `max_sources_per_tile`."
)
for max_s in range(2, max_shared_tiles + 1):
max_s_mask = num_shared_tiles_per_source == max_s
max_s_sum = max_s_mask.sum().item()
assert max_s_sum % max_s == 0
masked_s_to_t_indices = torch.masked_select(
source_to_tile_indices, mask=max_s_mask
)
pos_tensor = torch.arange(
0, max_s, dtype=inter_int_type, device=self.device
).repeat(max_s_sum // max_s)
pos_tensor = torch.scatter(
torch.zeros_like(pos_tensor),
dim=0,
index=torch.argsort(masked_s_to_t_indices, dim=0, stable=stable),
src=pos_tensor,
)
source_cum.masked_scatter_(mask=max_s_mask, source=pos_tensor)
source_to_tile_indices += (
source_cum.to(dtype=source_to_tile_indices.dtype) * n_tiles_h * n_tiles_w
) # (n_sources, )

# get n_sources for each tile
tile_n_sources[batch_index] = rearrange(
num_sources_on_per_tile, "(nth ntw) -> nth ntw", nth=n_tiles_h, ntw=n_tiles_w
).to(dtype=tile_n_sources.dtype)

for tile_k, tile_v in tile_params.items():
if tile_k == "plocs":
raise KeyError("plocs should not be in tile_params")
if tile_k == "n_sources":
raise KeyError("n_sources should not be in tile_params")
if tile_k == "locs":
k = "plocs"
else:
k = tile_k
full_cat_v = self[k][batch_index, :n_sources] # (n_sources, k)
if filter_oob:
full_cat_v = full_cat_v[plocs_mask]

m = tile_v.shape[-2]
transposed_v = rearrange(tile_v[batch_index], "nth ntw m k -> (m nth ntw) k")
repeated_source_to_tile_indices = repeat(
source_to_tile_indices, "n_sources -> n_sources k", k=transposed_v.shape[-1]
)
transposed_v.scatter_(
tile_params[tile_k] = self._pad_along_max_sources(tile_v, target_m=max_shared_tiles)

if max_shared_tiles > 1:
source_cum = torch.zeros_like(source_to_tile_indices, dtype=inter_int_type) # (b, bm)
s_to_t_indices_offset = torch.cumsum(source_to_tile_indices.amax(dim=-1) + 1, dim=0)
s_to_t_indices_offset = s_to_t_indices_offset.unsqueeze(-1) # (b, 1)
s_to_t_indices_w_offset = source_to_tile_indices + s_to_t_indices_offset # (b, bm)
for max_s in range(2, max_shared_tiles + 1):
max_s_mask = num_shared_tiles_per_source == max_s # (b, bm)
max_s_sum = max_s_mask.sum().item()
assert max_s_sum % max_s == 0
masked_s_to_t_indices = torch.masked_select(
s_to_t_indices_w_offset, mask=max_s_mask
) # a 1D tensor
pos_tensor = torch.arange(
0, max_s, dtype=inter_int_type, device=self.device
).repeat(max_s_sum // max_s)
pos_tensor = torch.scatter(
torch.zeros_like(pos_tensor),
dim=0,
index=repeated_source_to_tile_indices,
src=full_cat_v.to(dtype=transposed_v.dtype),
)
target_v = rearrange(
transposed_v, "(m nth ntw) k -> nth ntw m k", m=m, nth=n_tiles_h, ntw=n_tiles_w
index=torch.argsort(masked_s_to_t_indices, dim=0, stable=stable),
src=pos_tensor,
)
if ignore_extra_sources:
target_v = target_v.clone()
tile_v[batch_index] = target_v

# modify tile location
tile_params["locs"][batch_index] = (
tile_params["locs"][batch_index] % tile_slen
) / tile_slen
source_cum.masked_scatter_(mask=max_s_mask, source=pos_tensor)
assert (torch.masked_select(source_cum, mask=~plocs_mask) == 0).all()
source_cum = source_cum.to(dtype=source_to_tile_indices.dtype)
source_to_tile_indices += source_cum * n_tiles_h * n_tiles_w

# get n_sources for each tile
tile_n_sources = rearrange(
num_sources_on_per_tile[:, :-1],
"b (nth ntw) -> b nth ntw",
nth=n_tiles_h,
ntw=n_tiles_w,
).to(dtype=tile_n_sources.dtype)

for tile_k, tile_v in tile_params.items():
if tile_k == "plocs":
raise KeyError("plocs should not be in tile_params")
if tile_k == "n_sources":
raise KeyError("n_sources should not be in tile_params")
if tile_k == "locs":
k = "plocs"
else:
k = tile_k
full_cat_v = self[k] # (b, bm, k)
if filter_oob:
full_cat_v = torch.where(plocs_mask.unsqueeze(-1), full_cat_v, 0)

m = tile_v.shape[-2]
transposed_v = rearrange(tile_v, "b nth ntw m k -> b (m nth ntw) k")
pad = torch.zeros_like(transposed_v)[:, 0:1, :] # (b 1 k)
transposed_v = torch.cat((transposed_v, pad), dim=1) # (b (m nth ntw + 1) k)
s_to_t_indices_w_offset = torch.where(
plocs_mask,
source_to_tile_indices,
n_tiles_h * n_tiles_w * m,
)
repeated_source_to_tile_indices = repeat(
s_to_t_indices_w_offset, "b bm -> b bm k", k=transposed_v.shape[-1]
)
transposed_v.scatter_(
dim=1,
index=repeated_source_to_tile_indices,
src=full_cat_v.to(dtype=transposed_v.dtype),
)
target_v = rearrange(
transposed_v[:, :-1, :],
"b (m nth ntw) k -> b nth ntw m k",
m=m,
nth=n_tiles_h,
ntw=n_tiles_w,
)
tile_params[tile_k] = target_v

# modify tile location
tile_params["locs"] = (tile_params["locs"] % tile_slen) / tile_slen

if ignore_extra_sources:
for tile_k, tile_v in tile_params.items():
Expand Down
Loading

0 comments on commit 295902b

Please sign in to comment.