-
Notifications
You must be signed in to change notification settings - Fork 11
Add block sparse linear and locally connected layers #6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 7 commits
d0ec304
bc7e748
4489743
d320747
f38640d
6356947
f6581f4
b32e1fa
b5b9a02
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,9 +1,12 @@ | ||
| from typing import Callable, List | ||
| import warnings | ||
| from typing import Callable, List, Literal, Optional, Tuple | ||
|
|
||
| import torch | ||
| import torch.nn.functional as F | ||
| from einops import rearrange | ||
| from timm.layers import trunc_normal_ | ||
| from torch import nn | ||
| from torch.types import _device, _dtype | ||
|
|
||
| Layer = Callable[..., nn.Module] | ||
|
|
||
|
|
@@ -280,3 +283,232 @@ def init_weights(module: nn.Module): | |
| nn.init.zeros_(module.bias) | ||
| elif hasattr(module, "init_weights"): | ||
| module.init_weights() | ||
|
|
||
|
|
||
class BlockSparseLinear(nn.Module):
    """
    Linear layer whose weight is stored as a torch block sparse (BSR) tensor.

    The weight shares its block structure (crow/col indices) with the given
    connectivity. On every forward pass the stored weight values are multiplied in
    place by the connectivity values before the linear map is applied.

    Args:
        connectivity: a binary tensor of shape (out_features, in_features)
            representing the connectivity between input and output units.
        bias: whether to add a learnable bias of size out_features
        blocksize: sparse block size, e.g. 16, 32. Must divide each dimension of
            connectivity
    """

    connectivity: torch.Tensor

    def __init__(
        self, connectivity: torch.Tensor, bias: bool = True, blocksize: int = 32
    ):
        super().__init__()

        # Only warn (do not fail) when no CUDA device with capability >= 8.0 is found.
        capability = _cuda_get_device_capability()
        unsupported = capability is None or capability < (8, 0)
        if unsupported:
            warnings.warn(
                "BlockSparseLinear only supported for CUDA A100 or higher",
                RuntimeWarning,
            )

        self.in_features = connectivity.shape[1]
        self.out_features = connectivity.shape[0]
        self.blocksize = blocksize

        # Store the connectivity as a float BSR tensor buffer.
        bsr_mask = connectivity.to_sparse_bsr(blocksize).float()
        self.register_buffer("connectivity", bsr_mask)

        # Fraction of blocks of the dense (out_features, in_features) grid that are
        # not stored in the BSR representation.
        block_rows = self.out_features // blocksize
        block_cols = self.in_features // blocksize
        stored_blocks = bsr_mask.values().size(0)
        self.sparsity = 1 - (stored_blocks / (block_rows * block_cols))

        # Nb, this uses pytorch native block-sparse tensors following this blog:
        # https://pytorch.org/blog/speeding-up-vits/
        # A triton blocksparse matmul was used previously, but it seemed not very
        # stable. See here:
        # https://github.com/triton-lang/triton/pull/4156
        sparse_weight = torch.sparse_bsr_tensor(
            crow_indices=bsr_mask.crow_indices(),
            col_indices=bsr_mask.col_indices(),
            values=torch.empty_like(bsr_mask.values()),
            size=bsr_mask.size(),
        )
        self.weight = nn.Parameter(sparse_weight)

        if bias:
            self.bias = nn.Parameter(torch.empty(self.out_features))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self):
        """Truncated-normal init (std=0.02) for weight values, zeros for the bias."""
        trunc_normal_(self.weight.values(), std=0.02)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mask the stored weight values in place, then apply the linear map."""
        # apply sparse connectivity mask; entries inside a stored block that are
        # zero in the connectivity are zeroed in the weight as well
        stored_values = self.weight.values()
        stored_values.data.mul_(self.connectivity.values())
        return F.linear(x, self.weight, self.bias)

    def extra_repr(self) -> str:
        parts = [
            f"{self.in_features}",
            f"{self.out_features}",
            f"bias={self.bias is not None}",
            f"blocksize={self.blocksize}",
            f"sparsity={self.sparsity:.2f}",
        ]
        return ", ".join(parts)
|
|
||
|
|
||
class BlockSparseLocallyConnected(nn.Module):
    """
    A locally connected layer implemented using block sparse linear.

    The input is flattened to a single feature axis, passed through a
    BlockSparseLinear built from a local connectivity pattern, and reshaped back to
    the input layout with out_channels channels.

    Args:
        in_channels: number of input channels
        out_channels: number of output channels
        kernel_size: side length of the (square) local kernel
        height: side length of the (square) spatial grid; also used as the width
        depthwise: use depthwise connectivity between channels
        bias: passed through to the underlying BlockSparseLinear
        blocksize: sparse block size, passed through to BlockSparseLinear
        in_shape: input layout. "nlc" is read as "n (h w) c"; any other value is
            read as "n c h w"
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        height: int,
        depthwise: bool = False,
        bias: bool = True,
        blocksize: int = 32,
        in_shape: Literal["nlc", "nchw"] = "nchw",
    ):
        super().__init__()
        assert isinstance(kernel_size, int), "only square kernels supported"

        # The flattened feature axis puts channels last unless depthwise is used.
        channels_last = not depthwise

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.height = height
        self.kernel_size = kernel_size
        self.depthwise = depthwise
        self.blocksize = blocksize
        self.in_shape = in_shape
        self.channels_last = channels_last

        local_connectivity = _sparse_local_connectivity(
            in_channels,
            out_channels,
            kernel_size,
            height,
            depthwise=depthwise,
            channels_last=channels_last,
        )
        self.bsl = BlockSparseLinear(
            connectivity=local_connectivity, bias=bias, blocksize=blocksize
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the layer; the output uses the same layout as the input."""
        if self.in_shape == "nlc":
            spatial_pattern = "n (h w) c"
        else:
            spatial_pattern = "n c h w"

        if self.channels_last:
            flat_pattern = "n (h w c)"
        else:
            flat_pattern = "n (c h w)"

        side = self.height
        flattened = rearrange(
            input, f"{spatial_pattern} -> {flat_pattern}", h=side, w=side
        )
        projected = self.bsl(flattened)
        return rearrange(
            projected,
            f"{flat_pattern} -> {spatial_pattern}",
            c=self.out_channels,
            h=side,
            w=side,
        )
|
|
||
|
|
||
def _sparse_local_connectivity(
    in_channels: int,
    out_channels: int,
    kernel_size: int,
    height: int,
    depthwise: bool = False,
    channels_last: bool = False,
    dtype: Optional[_dtype] = None,
    device: Optional[_device] = None,
) -> torch.Tensor:
    """
    Construct sparse local connectivity matrix, shape
    (out_channels * height * height, in_channels * height * height). The returned
    connectivity will have sparse COO layout.

    The connectivity pattern is equivalent to
    `nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding="same")`

    Both odd and even kernel sizes are supported. For an even kernel size the window
    is not symmetric around the output unit: it extends `(kernel_size - 1) // 2`
    positions before the unit and one position more after it.

    If channels_last is True, the shape of connectivity in einops notation is
    "(h w cout) (h w cin)". Otherwise, it is "(cout h w) (cin h w)". The latter should
    be more efficient when depthwise is True (connectivity will be more block sparse).

    Args:
        in_channels: number of input channels
        out_channels: number of output channels
        kernel_size: side length of the square kernel, must be >= 1
        height: side length of the square spatial grid
        depthwise: connect each output channel only to the input channel with the
            same index (requires out_channels == in_channels)
        channels_last: place the channel index last in the flattened row/column index
        dtype: dtype of the connectivity values (torch default when None)
        device: device the result is moved to (left where it was built when None)

    Returns:
        Coalesced sparse COO tensor with a one for every connected
        (output unit, input unit) pair.

    Raises:
        AssertionError: if kernel_size < 1, or if depthwise is set and the channel
            counts differ.
    """
    assert kernel_size >= 1, "kernel_size must be positive"
    assert (
        not depthwise or out_channels == in_channels
    ), "in channels must match out channels for depthwise"
    N = height * height

    # ij indices of input grid
    # (h^2, 2)
    grid_indices = torch.cartesian_prod(torch.arange(height), torch.arange(height))

    # conv kernel index offsets.
    # Review follow-up: the kernel size used to be restricted to odd values. Even
    # sizes are handled by splitting the total "same" padding (kernel_size - 1) as
    # floor half before and the remainder after. For odd sizes both halves are equal
    # and the offsets are the same symmetric range as before.
    # (k^2, 2)
    pad_before = (kernel_size - 1) // 2
    pad_after = kernel_size - 1 - pad_before
    offsets = torch.arange(-pad_before, pad_after + 1)
    kernel_indices = torch.cartesian_prod(offsets, offsets)

    # input edge indices for each output unit. these will be the column indices for the
    # sparse COO connectivity.
    # (h^2, k^2, 2)
    col_indices = grid_indices.unsqueeze(1) + kernel_indices.unsqueeze(0)

    # input edge row indices
    # (h^2, k^2)
    row_indices = torch.arange(N).unsqueeze(1).repeat(1, kernel_size**2)

    # exclude edges falling outside grid
    mask = ((col_indices >= 0) & (col_indices < height)).all(dim=-1)
    col_indices = col_indices[mask]
    row_indices = row_indices[mask]

    # rasterize column indices
    col_indices = height * col_indices[..., 0] + col_indices[..., 1]

    # add channel blocks with full or depthwise (diagonal) connectivity
    if depthwise:
        channel_indices = torch.arange(out_channels).unsqueeze(1).repeat(1, 2)
    else:
        channel_indices = torch.cartesian_prod(
            torch.arange(out_channels), torch.arange(in_channels)
        )

    # we can insert the channels axis either at the front or the back
    # front is better for depthwise=True, back is better for depthwise=False
    if channels_last:
        row_indices = out_channels * row_indices.unsqueeze(1) + channel_indices[:, 0]
        col_indices = in_channels * col_indices.unsqueeze(1) + channel_indices[:, 1]
    else:
        row_indices = N * channel_indices[:, 0].unsqueeze(1) + row_indices
        col_indices = N * channel_indices[:, 1].unsqueeze(1) + col_indices

    row_indices = row_indices.flatten()
    col_indices = col_indices.flatten()

    # construct sparse connectivity tensor
    connectivity = torch.sparse_coo_tensor(
        torch.stack([row_indices, col_indices]),
        torch.ones(len(row_indices), dtype=dtype),
        size=(out_channels * N, in_channels * N),
    ).coalesce()
    if device is not None:
        connectivity = connectivity.to(device)
    return connectivity
|
|
||
|
|
||
def _cuda_get_device_capability() -> Optional[Tuple[int, int]]:
    """Return `torch.cuda.get_device_capability()`, or None if CUDA is unavailable."""
    if torch.cuda.is_available():
        return torch.cuda.get_device_capability()
    return None
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| import logging | ||
| import torch | ||
| from torch import nn | ||
|
|
||
| import columnformers.models.layers as L | ||
|
|
||
|
|
||
def test_block_sparse_locally_connected():
    """
    Check BlockSparseLocallyConnected against nn.Conv2d with padding="same".

    With all weights set to one and all biases set to zero, both layers sum the
    input over the same local neighborhood, so their outputs should agree.
    """
    # Fixed seed so the random input (and any failure) is reproducible.
    torch.manual_seed(0)

    loc = L.BlockSparseLocallyConnected(
        in_channels=8,
        out_channels=16,
        kernel_size=3,
        height=16,
        depthwise=False,
    )
    logging.info("%s", loc)

    conv = nn.Conv2d(
        in_channels=8,
        out_channels=16,
        kernel_size=3,
        stride=1,
        padding="same",
    )

    nn.init.ones_(loc.bsl.weight.values())
    nn.init.ones_(conv.weight)
    nn.init.zeros_(loc.bsl.bias)
    nn.init.zeros_(conv.bias)

    # TODO: finish testing on cuda
    input = torch.randn(2, 8, 16, 16)
    output_loc = loc(input)
    output_conv = conv(input)

    # The two layers accumulate float32 sums in different orders, so compare with
    # explicit tolerances; the default atol of 1e-8 is too tight for sums that
    # nearly cancel.
    assert torch.allclose(output_loc, output_conv, rtol=1e-4, atol=1e-4)
Uh oh!
There was an error while loading. Please reload this page.