Skip to content

Commit b7bb25b

Browse files
committed
Initial blocksparse linear outline
1 parent 9e686cb commit b7bb25b

1 file changed

Lines changed: 81 additions & 0 deletions

File tree

columnformers/models/layers.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@
55
from timm.layers import trunc_normal_
66
from torch import nn
77

8+
try:
9+
from triton.ops.blocksparse import matmul as blocksparse_matmul # noqa
10+
11+
triton_available = True
12+
except ImportError:
13+
triton_available = False
14+
815
Layer = Callable[..., nn.Module]
916

1017

@@ -236,3 +243,77 @@ def init_weights(module: nn.Module):
236243
nn.init.zeros_(module.bias)
237244
elif hasattr(module, "init_weights"):
238245
module.init_weights()
246+
247+
248+
class BlockSparseLinear(nn.Module):
    """
    A linear layer with block sparse connectivity.

    Args:
        connectivity: a binary tensor of shape (out_features, in_features) representing
            the connectivity between input and output units.
        bias: use bias
        blocksize: sparse block size, e.g. 16, 32. Must divide each dimension of
            connectivity

    Buffers:
        connectivity: float 0/1 mask holding only the non-zero blocks, shape
            (nnz_blocks, blocksize, blocksize).
        layout: int64 block layout, shape
            (1, out_features // blocksize, in_features // blocksize).

    TODO:
        [ ] initialize weight. weight should be masked by connectivity at init.
            think about what the appropriate init std should be. (bias is currently
            allocated as zeros so that the module can at least be printed.)
        [ ] create a dsd sparse matmul kernel following xformers.BlockSparseAttention:
            https://github.com/facebookresearch/xformers/blob/fad50d49834ab18dd137acc727bd4d567ff17842/xformers/components/attention/blocksparse.py#L96
        [ ] implement forward that should mask weight by connectivity and then call the
            blocksparse matmul kernel
    """

    def __init__(
        self, connectivity: torch.Tensor, bias: bool = True, blocksize: int = 16
    ):
        assert triton_available, "blocksparse linear requires triton"
        super().__init__()
        self.in_features = connectivity.shape[1]
        self.out_features = connectivity.shape[0]
        self.blocksize = blocksize

        # convert to torch blocksparse representation if not already
        connectivity = connectivity.to_sparse_bsr(blocksize)

        # block sparse layout as expected by triton
        # shape (1, out_features // block, in_features // block)
        # must be dtype int64
        #
        # The size is passed explicitly: when it is omitted, torch infers the number
        # of columns from the largest column index present, so a connectivity whose
        # trailing block-columns are empty would yield a layout that is too narrow.
        layout = torch.sparse_csr_tensor(
            connectivity.crow_indices(),
            connectivity.col_indices(),
            torch.ones_like(connectivity.col_indices()),
            size=(self.out_features // blocksize, self.in_features // blocksize),
        )
        layout = layout.to_dense().unsqueeze(0)

        # only keep raw values, don't need indices since we have layout
        # shape (nnz_blocks, block, block)
        connectivity = (connectivity.values() > 0).float()

        self.register_buffer("connectivity", connectivity)
        self.register_buffer("layout", layout)

        # `bias` must always exist as an attribute (possibly None): extra_repr reads
        # it, and nn.Module.__repr__ calls extra_repr.
        if bias:
            self.bias = nn.Parameter(
                torch.zeros(self.out_features, device=connectivity.device)
            )
        else:
            self.register_parameter("bias", None)

        # TODO: initialize weight

    def reset_parameters(self):
        """Re-initialize the learnable parameters. Not implemented yet."""
        raise NotImplementedError

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block sparse linear map to ``x``. Not implemented yet."""
        raise NotImplementedError

    def extra_repr(self) -> str:
        """Return the summary string shown inside the module's ``repr``."""
        return (
            f"{self.in_features}, {self.out_features}, "
            f"bias={self.bias is not None}, blocksize={self.blocksize}"
        )
310+
311+
312+
class BlockSparseLocallyConnected(nn.Module):
    """
    Locally connected layer built on top of the block sparse linear layer.

    TODO: the main step is deriving the connectivity from the conv params. Its shape
    should be something like:
    (out_height * out_width * out_channels, in_height * in_width * in_channels).
    Once the connectivity exists, we just use BlockSparseLinear.
    """

0 commit comments

Comments
 (0)