Skip to content
Draft
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 233 additions & 1 deletion columnformers/models/layers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Callable, List
import warnings
from typing import Callable, List, Literal, Optional, Tuple

import torch
import torch.nn.functional as F
from einops import rearrange
from timm.layers import trunc_normal_
from torch import nn
from torch.types import _device, _dtype

Layer = Callable[..., nn.Module]

Expand Down Expand Up @@ -280,3 +283,232 @@ def init_weights(module: nn.Module):
nn.init.zeros_(module.bias)
elif hasattr(module, "init_weights"):
module.init_weights()


class BlockSparseLinear(nn.Module):
    """
    Linear layer whose weight is stored as a block sparse (BSR) tensor.

    Only the blocks that contain at least one connection are materialized. Within
    those blocks, the connectivity is re-applied as a multiplicative mask on every
    forward pass.

    Args:
        connectivity: a binary tensor of shape (out_features, in_features) representing
            the connectivity between input and output units.
        bias: use bias
        blocksize: sparse block size, e.g. 16, 32. Must divide each dimension of
            connectivity
    """

    connectivity: torch.Tensor

    def __init__(
        self, connectivity: torch.Tensor, bias: bool = True, blocksize: int = 32
    ):
        super().__init__()
        capability = _cuda_get_device_capability()
        if capability is None or capability < (8, 0):
            warnings.warn(
                "BlockSparseLinear only supported for CUDA A100 or higher",
                RuntimeWarning,
            )

        self.out_features = connectivity.shape[0]
        self.in_features = connectivity.shape[1]
        self.blocksize = blocksize

        # convert to torch blocksparse representation if not already
        mask = connectivity.to_sparse_bsr(blocksize).float()
        self.register_buffer("connectivity", mask)

        # fraction of blocks that are not stored at all
        total_blocks = (self.out_features // blocksize) * (
            self.in_features // blocksize
        )
        stored_blocks = mask.values().size(0)
        self.sparsity = 1 - (stored_blocks / total_blocks)

        # Nb, we are using pytorch native block-sparse tensors following this blog:
        # https://pytorch.org/blog/speeding-up-vits/
        # We were previously using triton blocksparse matmul, but it seems not very
        # stable. See here:
        # https://github.com/triton-lang/triton/pull/4156
        sparse_weight = torch.sparse_bsr_tensor(
            crow_indices=mask.crow_indices(),
            col_indices=mask.col_indices(),
            values=torch.empty_like(mask.values()),
            size=mask.size(),
        )
        self.weight = nn.Parameter(sparse_weight)

        if not bias:
            self.register_parameter("bias", None)
        else:
            self.bias = nn.Parameter(torch.empty(self.out_features))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the stored weight blocks and zero the bias (if any)."""
        trunc_normal_(self.weight.values(), std=0.02)
        if self.bias is None:
            return
        nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mask the weight in place with the connectivity, then apply the layer."""
        # apply sparse connectivity mask
        mask_values = self.connectivity.values()
        self.weight.values().data.mul_(mask_values)
        return F.linear(x, self.weight, self.bias)

    def extra_repr(self) -> str:
        """Summarize the layer dimensions, bias flag, block size and sparsity."""
        summary = (
            f"{self.in_features}, {self.out_features}, "
            f"bias={self.bias is not None}, blocksize={self.blocksize}, "
            f"sparsity={self.sparsity:.2f}"
        )
        return summary


class BlockSparseLocallyConnected(nn.Module):
    """
    Locally connected layer over a square grid, backed by `BlockSparseLinear`.

    The input is flattened to a single feature axis, passed through a block sparse
    linear layer whose connectivity comes from `_sparse_local_connectivity`, and
    then unflattened back to the input layout with `out_channels` channels.

    Args:
        in_channels: number of input channels
        out_channels: number of output channels
        kernel_size: side length of the (square) local kernel
        height: side length of the (square) spatial grid
        depthwise: use depthwise connectivity across channels
        bias: use bias
        blocksize: sparse block size passed to `BlockSparseLinear`
        in_shape: layout of the input, "nlc" for (n, h * w, c); any other value
            is treated as "nchw"
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        height: int,
        depthwise: bool = False,
        bias: bool = True,
        blocksize: int = 32,
        in_shape: Literal["nlc", "nchw"] = "nchw",
    ):
        super().__init__()
        assert isinstance(kernel_size, int), "only square kernels supported"

        # the channel axis is flattened last unless the layer is depthwise
        channels_last = not depthwise

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.height = height
        self.depthwise = depthwise
        self.channels_last = channels_last
        self.blocksize = blocksize
        self.in_shape = in_shape

        self.bsl = BlockSparseLinear(
            connectivity=_sparse_local_connectivity(
                in_channels,
                out_channels,
                kernel_size,
                height,
                depthwise=depthwise,
                channels_last=channels_last,
            ),
            bias=bias,
            blocksize=blocksize,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Flatten `input`, apply the block sparse linear, and restore the layout."""
        if self.in_shape == "nlc":
            grid_pattern = "n (h w) c"
        else:
            grid_pattern = "n c h w"

        if self.channels_last:
            flat_pattern = "n (h w c)"
        else:
            flat_pattern = "n (c h w)"

        side = self.height
        flat = rearrange(input, f"{grid_pattern} -> {flat_pattern}", h=side, w=side)
        flat = self.bsl(flat)
        return rearrange(
            flat,
            f"{flat_pattern} -> {grid_pattern}",
            c=self.out_channels,
            h=side,
            w=side,
        )


def _sparse_local_connectivity(
in_channels: int,
out_channels: int,
kernel_size: int,
height: int,
depthwise: bool = False,
channels_last: bool = False,
dtype: _dtype = None,
device: _device = None,
) -> torch.Tensor:
"""
Construct sparse local connectivity matrix, shape
(out_channels * height * height, in_channels * height * height). The returned
connectivity will have sparse COO layout.

The connectivity pattern is equivalent to
`nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding="same")`

If channels_last is True, the shape of connectivity in einops notation is
"(h w cout) (h w cin)". Otherwise, it is "(cout h w) (cin h w)". The latter should
be more efficient when depthwise is True (connectivity will be more block sparse).
"""
assert kernel_size % 2 == 1, "kernel_size must be odd"
assert (
not depthwise or out_channels == in_channels
), "in channels must match out channels for depthwise"
N = height * height

# ij indices of input grid
# (h^2, 2)
col_indices = torch.cartesian_prod(torch.arange(height), torch.arange(height))

# conv kernel index offsets. note that the kernel width is required to be odd.
# (k^2, 2)
kernel_half_width = (kernel_size - 1) // 2
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit restrictive, can we adapt this to also include even kernel size by doing something like this?

if kernel_size % 2 == 0:
    kernel_half_width = kernel_size // 2
    kernel_indices = torch.cartesian_prod(
        torch.arange(-kernel_half_width, kernel_half_width),
        torch.arange(-kernel_half_width, kernel_half_width),
    )
else:
    kernel_half_width = (kernel_size - 1) // 2
    kernel_indices = torch.cartesian_prod(
        torch.arange(-kernel_half_width, kernel_half_width + 1),
        torch.arange(-kernel_half_width, kernel_half_width + 1),
    )

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I think you're right, something like this would probably be better. Although it feels like it should be possible to make the code shorter.

More generally, it would probably be best to have exactly the same interface and behavior as Conv2d. What I have now takes a few shortcuts.

kernel_indices = torch.cartesian_prod(
torch.arange(-kernel_half_width, kernel_half_width + 1),
torch.arange(-kernel_half_width, kernel_half_width + 1),
)

# input edge indices for each output unit. these will be the column indices for the
# sparse COO connectivity.
# (h^2, k^2, 2)
col_indices = col_indices.unsqueeze(1) + kernel_indices.unsqueeze(0)

# input edge row indices
# (h^2, k^2)
row_indices = torch.arange(N).unsqueeze(1).repeat(1, kernel_size**2)

# exclude edges falling outside grid
mask = ((col_indices >= 0) & (col_indices < height)).all(axis=-1)
col_indices = col_indices[mask]
row_indices = row_indices[mask]

# rasterize column indices
col_indices = height * col_indices[..., 0] + col_indices[..., 1]

# add channel blocks with full or depthwise (diagonal) connectivity
if depthwise:
channel_indices = torch.arange(out_channels).unsqueeze(1).repeat(1, 2)
else:
channel_indices = torch.cartesian_prod(
torch.arange(out_channels), torch.arange(in_channels)
)

# we can insert the channels axis either at the front or the back
# front is better for depthwise=True, back is better for depthwise=False
if channels_last:
row_indices = out_channels * row_indices.unsqueeze(1) + channel_indices[:, 0]
col_indices = in_channels * col_indices.unsqueeze(1) + channel_indices[:, 1]
else:
row_indices = N * channel_indices[:, 0].unsqueeze(1) + row_indices
col_indices = N * channel_indices[:, 1].unsqueeze(1) + col_indices

row_indices = row_indices.flatten()
col_indices = col_indices.flatten()

# construct sparse connectivity tensor
connectivity = (
torch.sparse_coo_tensor(
torch.stack([row_indices, col_indices]),
torch.ones(len(row_indices), dtype=dtype),
size=(out_channels * N, in_channels * N),
)
.coalesce()
.to(device)
)
return connectivity


def _cuda_get_device_capability() -> Optional[Tuple[int, int]]:
if not torch.cuda.is_available():
return None
return torch.cuda.get_device_capability()
35 changes: 35 additions & 0 deletions tests/test_models/test_layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import logging
import torch
from torch import nn

import columnformers.models.layers as L


def test_block_sparse_locally_connected():
    """Check the locally connected layer against an equivalent all-ones Conv2d."""
    in_channels, out_channels, kernel_size, height = 8, 16, 3, 16

    loc = L.BlockSparseLocallyConnected(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        height=height,
        depthwise=False,
    )
    logging.info("%s", loc)

    conv = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=1,
        padding="same",
    )

    # with all-ones weights and zero biases both layers reduce to a plain sum
    # over each local neighborhood, so their outputs should agree
    for weight in (loc.bsl.weight.values(), conv.weight):
        nn.init.ones_(weight)
    for bias in (loc.bsl.bias, conv.bias):
        nn.init.zeros_(bias)

    # TODO: finish testing on cuda
    x = torch.randn(2, in_channels, height, height)
    output_loc = loc(x)
    output_conv = conv(x)
    assert torch.allclose(output_loc, output_conv)