Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 69 additions & 19 deletions columnformers/models/layers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import warnings
from typing import Callable, List, Literal, Optional, Tuple

Expand Down Expand Up @@ -375,6 +376,8 @@ def __init__(
out_channels: int,
kernel_size: int,
height: int,
padding: int,
stride: int,
depthwise: bool = False,
bias: bool = True,
blocksize: int = 32,
Expand All @@ -385,6 +388,8 @@ def __init__(
self.in_channels = in_channels
self.out_channels = out_channels
self.height = height
self.padding = padding
self.stride = stride
self.kernel_size = kernel_size
self.depthwise = depthwise
self.blocksize = blocksize
Expand All @@ -396,32 +401,66 @@ def __init__(
out_channels,
kernel_size,
height,
stride,
depthwise=depthwise,
channels_last=self.channels_last,
)
self.bsl = BlockSparseLinear(
connectivity=connectivity, bias=bias, blocksize=blocksize
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Apply the block-sparse locally connected layer to ``x``.

    When ``self.in_shape`` is ``"nd"`` the input is passed to the block-sparse
    linear layer as is. Otherwise the input is read as ``"n (h w) c"`` (for
    ``"nlc"``) or ``"n c h w"``, zero-padded by ``self.padding`` on both spatial
    axes, flattened to ``"n (h w c)"`` / ``"n (c h w)"`` depending on
    ``self.channels_last``, run through ``self.bsl``, and reshaped back to the
    input layout using the strided output height.

    Returns:
        The layer output, in the same layout family as the input.
    """
    # TODO: figure out how to handle padding for flattened inputs
    needs_reshape = self.in_shape != "nd"

    if needs_reshape:
        in_pattern = "n (h w) c" if self.in_shape == "nlc" else "n c h w"
        out_pattern = "n (h w c)" if self.channels_last else "n (c h w)"
        # Padding has to happen while the spatial axes are still separate:
        # F.pad rejects a pad spec longer than twice the tensor rank, so the
        # flattened 2-D tensor cannot be padded per spatial axis.
        spatial_pattern = "n h w c" if self.channels_last else "n c h w"
        x = rearrange(
            x, f"{in_pattern} -> {spatial_pattern}", h=self.height, w=self.height
        )
        if self.padding > 0:
            p = self.padding
            # F.pad lists (left, right) pairs starting from the LAST dimension.
            pad = (0, 0, p, p, p, p) if self.channels_last else (p, p, p, p)
            x = F.pad(x, pad, "constant", 0)
        # NOTE(review): after padding, the flattened width is
        # channels * (height + 2 * padding) ** 2. This assumes self.bsl was
        # built for that padded grid -- confirm against __init__.
        x = rearrange(x, f"{spatial_pattern} -> {out_pattern}")

    output = self.bsl(x)

    if needs_reshape:
        # number of kernels that fit in one dimension of the (padded) input
        out_height = (
            self.height + 2 * self.padding - self.kernel_size
        ) // self.stride + 1
        output = rearrange(
            output,
            f"{out_pattern} -> {in_pattern}",
            c=self.out_channels,
            h=out_height,
            w=out_height,
        )
    return output

Expand All @@ -431,39 +470,43 @@ def _sparse_local_connectivity(
out_channels: int,
kernel_size: int,
height: int,
stride: int,
depthwise: bool = False,
channels_last: bool = False,
dtype: _dtype = None,
device: _device = None,
) -> torch.Tensor:
"""
Construct sparse local connectivity matrix, shape
(out_channels * height * height, in_channels * height * height). The returned
connectivity will have sparse COO layout.
(out_channels * h_out * h_out, in_channels * height * height), where
h_out = ⌊(height - kernel_size) / stride⌋ + 1 (kernel placements that would
extend past the grid are dropped; no padding is applied here).
The returned connectivity will have sparse COO layout.

The connectivity pattern is equivalent to
`nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding="same")`
`nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=0)`

If channels_last is True, the shape of connectivity in einops notation is
"(h w cout) (h w cin)". Otherwise, it is "(cout h w) (cin h w)". The latter should
be more efficient when depthwise is True (connectivity will be more block sparse).
"""
assert kernel_size % 2 == 1, "kernel_size must be odd"
assert (
not depthwise or out_channels == in_channels
), "in channels must match out channels for depthwise"
N = height * height
N = height**2

# ij indices of input grid
# (ceil(h / stride)^2, 2)
col_indices = torch.cartesian_prod(torch.arange(height), torch.arange(height))
col_indices = torch.cartesian_prod(
torch.arange(0, height, stride), torch.arange(0, height, stride)
)

num_rows = col_indices.shape[0]

# conv kernel index offsets. note that the kernel width is required to be odd.
# (k^2, 2)
kernel_half_width = (kernel_size - 1) // 2
kernel_indices = torch.cartesian_prod(
torch.arange(-kernel_half_width, kernel_half_width + 1),
torch.arange(-kernel_half_width, kernel_half_width + 1),
torch.arange(0, kernel_size),
torch.arange(0, kernel_size),
)

# input edge indices for each output unit. these will be the column indices for the
Expand All @@ -473,12 +516,19 @@ def _sparse_local_connectivity(

# input edge row indices
# (num_rows, k^2)
row_indices = torch.arange(N).unsqueeze(1).repeat(1, kernel_size**2)
row_indices = torch.zeros(num_rows, kernel_size**2)

# exclude edges falling outside grid
mask = ((col_indices >= 0) & (col_indices < height)).all(axis=-1)
mask = mask.all(axis=1).unsqueeze(1).repeat(1, kernel_size**2)
col_indices = col_indices[mask]
row_indices = row_indices[mask]
row_indices = row_indices[mask].reshape(-1, kernel_size**2)

out_height = row_indices.shape[0]

row_indices = (
torch.arange(out_height).unsqueeze(1).repeat(1, kernel_size**2).flatten()
)

# rasterize column indices
col_indices = height * col_indices[..., 0] + col_indices[..., 1]
Expand All @@ -497,7 +547,7 @@ def _sparse_local_connectivity(
row_indices = out_channels * row_indices.unsqueeze(1) + channel_indices[:, 0]
col_indices = in_channels * col_indices.unsqueeze(1) + channel_indices[:, 1]
else:
row_indices = N * channel_indices[:, 0].unsqueeze(1) + row_indices
row_indices = out_height * channel_indices[:, 0].unsqueeze(1) + row_indices
col_indices = N * channel_indices[:, 1].unsqueeze(1) + col_indices

row_indices = row_indices.flatten()
Expand All @@ -508,7 +558,7 @@ def _sparse_local_connectivity(
torch.sparse_coo_tensor(
torch.stack([row_indices, col_indices]),
torch.ones(len(row_indices), dtype=dtype),
size=(out_channels * N, in_channels * N),
size=(out_channels * out_height, in_channels * N),
)
.coalesce()
.to(device)
Expand Down