diff --git a/columnformers/models/layers.py b/columnformers/models/layers.py index 035084c..d0202a3 100644 --- a/columnformers/models/layers.py +++ b/columnformers/models/layers.py @@ -1,3 +1,4 @@ +import math import warnings from typing import Callable, List, Literal, Optional, Tuple @@ -375,6 +376,8 @@ def __init__( out_channels: int, kernel_size: int, height: int, + padding: int, + stride: int, depthwise: bool = False, bias: bool = True, blocksize: int = 32, @@ -385,6 +388,8 @@ def __init__( self.in_channels = in_channels self.out_channels = out_channels self.height = height + self.padding = padding + self.stride = stride self.kernel_size = kernel_size self.depthwise = depthwise self.blocksize = blocksize @@ -396,6 +401,7 @@ def __init__( out_channels, kernel_size, height, + stride, depthwise=depthwise, channels_last=self.channels_last, ) @@ -403,25 +409,58 @@ def __init__( connectivity=connectivity, bias=bias, blocksize=blocksize ) - def forward(self, input: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: + # TODO: figure out how to handle padding for flattened inputs needs_reshape = self.in_shape != "nd" if needs_reshape: in_pattern = "n (h w) c" if self.in_shape == "nlc" else "n c h w" out_pattern = "n (h w c)" if self.channels_last else "n (c h w)" - input = rearrange( - input, f"{in_pattern} -> {out_pattern}", h=self.height, w=self.height + x = rearrange( + x, f"{in_pattern} -> {out_pattern}", h=self.height, w=self.height + ) + padding = ( + ( + 0, + 0, + self.padding, + self.padding, + self.padding, + self.padding, + 0, + 0, + ) + if out_pattern == "n (h w c)" + else ( + self.padding, + self.padding, + self.padding, + self.padding, + 0, + 0, + 0, + 0, + ) ) - output = self.bsl(input) + x = F.pad(x, padding, "constant", 0) + + output = self.bsl(x) + + out_height = ( + math.floor( + (self.height - self.kernel_size + 2 * self.padding) / self.stride + ) + + 1 + ) # number of kernels that fit in one dimension of input if needs_reshape: output = rearrange( output, f"{out_pattern} -> {in_pattern}", c=self.out_channels, - h=self.height, - w=self.height, + h=out_height, + w=out_height, ) return output @@ -431,6 +470,7 @@ def _sparse_local_connectivity( out_channels: int, kernel_size: int, height: int, + stride: int, depthwise: bool = False, channels_last: bool = False, dtype: _dtype = None, @@ -438,32 +478,35 @@ def _sparse_local_connectivity( ) -> torch.Tensor: """ Construct sparse local connectivity matrix, shape - (out_channels * height * height, in_channels * height * height). The returned - connectivity will have sparse COO layout. + (out_channels * h_out * h_out, in_channels * height * height), where + h_out = ⌊(height + 2 * padding - 1) / stride + 1⌋. + The returned connectivity will have sparse COO layout. The connectivity pattern is equivalent to - `nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding="same")` + `nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)` If channels_last is True, the shape of connectivity in einops notation is "(h w cout) (h w cin)". Otherwise, it is "(cout h w) (cin h w)". The latter should be more efficient when depthwise is True (connectivity will be more block sparse). """ - assert kernel_size % 2 == 1, "kernel_size must be odd" assert ( not depthwise or out_channels == in_channels ), "in channels must match out channels for depthwise" - N = height * height + N = height**2 # ij indices of input grid # (h^2, 2) - col_indices = torch.cartesian_prod(torch.arange(height), torch.arange(height)) + col_indices = torch.cartesian_prod( + torch.arange(0, height, stride), torch.arange(0, height, stride) + ) + + num_rows = col_indices.shape[0] # conv kernel index offsets. note that the kernel width is required to be odd. # (k^2, 2) - kernel_half_width = (kernel_size - 1) // 2 kernel_indices = torch.cartesian_prod( - torch.arange(-kernel_half_width, kernel_half_width + 1), - torch.arange(-kernel_half_width, kernel_half_width + 1), + torch.arange(0, kernel_size), + torch.arange(0, kernel_size), ) # input edge indices for each output unit. these will be the column indices for the @@ -473,12 +516,19 @@ def _sparse_local_connectivity( # input edge row indices # (h^2, k^2) - row_indices = torch.arange(N).unsqueeze(1).repeat(1, kernel_size**2) + row_indices = torch.zeros(num_rows, kernel_size**2) # exclude edges falling outside grid mask = ((col_indices >= 0) & (col_indices < height)).all(axis=-1) + mask = mask.all(axis=1).unsqueeze(1).repeat(1, kernel_size**2) col_indices = col_indices[mask] - row_indices = row_indices[mask] + row_indices = row_indices[mask].reshape(-1, kernel_size**2) + + out_height = row_indices.shape[0] + + row_indices = ( + torch.arange(out_height).unsqueeze(1).repeat(1, kernel_size**2).flatten() + ) # rasterize column indices col_indices = height * col_indices[..., 0] + col_indices[..., 1] @@ -497,7 +547,7 @@ def _sparse_local_connectivity( row_indices = out_channels * row_indices.unsqueeze(1) + channel_indices[:, 0] col_indices = in_channels * col_indices.unsqueeze(1) + channel_indices[:, 1] else: - row_indices = N * channel_indices[:, 0].unsqueeze(1) + row_indices + row_indices = out_height * channel_indices[:, 0].unsqueeze(1) + row_indices col_indices = N * channel_indices[:, 1].unsqueeze(1) + col_indices row_indices = row_indices.flatten() @@ -508,7 +558,7 @@ def _sparse_local_connectivity( torch.sparse_coo_tensor( torch.stack([row_indices, col_indices]), torch.ones(len(row_indices), dtype=dtype), - size=(out_channels * N, in_channels * N), + size=(out_channels * out_height, in_channels * N), ) .coalesce() .to(device)