Skip to content

Commit d0ec304

Browse files
committed
Initial blocksparse linear outline
1 parent d01a730 commit d0ec304

1 file changed

Lines changed: 81 additions & 0 deletions

File tree

columnformers/models/layers.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@
55
from timm.layers import trunc_normal_
66
from torch import nn
77

8+
try:
9+
from triton.ops.blocksparse import matmul as blocksparse_matmul # noqa
10+
11+
triton_available = True
12+
except ImportError:
13+
triton_available = False
14+
815
Layer = Callable[..., nn.Module]
916

1017

@@ -280,3 +287,77 @@ def init_weights(module: nn.Module):
280287
nn.init.zeros_(module.bias)
281288
elif hasattr(module, "init_weights"):
282289
module.init_weights()
290+
291+
292+
class BlockSparseLinear(nn.Module):
    """
    A linear layer with block sparse connectivity.

    Only the connectivity bookkeeping is implemented so far; the weight and
    bias parameters, ``reset_parameters`` and ``forward`` are still TODO.

    Args:
        connectivity: a binary tensor of shape (out_features, in_features)
            representing the connectivity between input and output units.
        bias: use bias. NOTE: not consumed yet, since parameters are not
            created (see TODO).
        blocksize: sparse block size, e.g. 16, 32. Must divide each dimension
            of connectivity.

    Buffers:
        connectivity: float 0/1 mask holding only the non-empty blocks, shape
            (nnz_blocks, blocksize, blocksize).
        layout: int64 block layout, shape
            (1, out_features // blocksize, in_features // blocksize), with a 1
            for every block stored in ``connectivity``.

    TODO:
        [ ] initialize weight and bias. weight should be masked by connectivity
            at init. think about what the appropriate init std should be.
        [ ] create a dsd sparse matmul kernel following
            xformers.BlockSparseAttention:
            https://github.com/facebookresearch/xformers/blob/fad50d49834ab18dd137acc727bd4d567ff17842/xformers/components/attention/blocksparse.py#L96
        [ ] implement forward that should mask weight by connectivity and then
            call the blocksparse matmul kernel
    """

    def __init__(
        self, connectivity: torch.Tensor, bias: bool = True, blocksize: int = 16
    ):
        assert triton_available, "blocksparse linear requires triton"
        super().__init__()
        self.in_features = connectivity.shape[1]
        self.out_features = connectivity.shape[0]
        self.blocksize = blocksize

        # convert to torch blocksparse representation if not already
        connectivity = connectivity.to_sparse_bsr((blocksize, blocksize))

        # block sparse layout as expected by triton
        layout = self._block_layout(connectivity, blocksize)

        # only keep raw values, don't need indices since we have layout
        # shape (nnz_blocks, block, block)
        connectivity = (connectivity.values() > 0).float()

        self.register_buffer("connectivity", connectivity)
        self.register_buffer("layout", layout)

        # TODO: initialize weight and bias

    @staticmethod
    def _block_layout(bsr: torch.Tensor, blocksize: int) -> torch.Tensor:
        """Return the dense int64 block layout of a BSR tensor.

        Args:
            bsr: a 2D tensor in BSR layout with square blocks of ``blocksize``.
            blocksize: the block size used to build ``bsr``.

        Returns:
            Tensor of shape (1, rows // blocksize, cols // blocksize) with a 1
            at every stored block and 0 elsewhere.
        """
        cols = bsr.col_indices()
        layout = torch.sparse_csr_tensor(
            bsr.crow_indices(),
            cols,
            # the layout is documented as "must be dtype int64"; force it
            # rather than inheriting whatever dtype the indices have
            torch.ones_like(cols, dtype=torch.int64),
            # Without an explicit size the column count is inferred from the
            # largest stored column index, which drops trailing block columns
            # that happen to be empty.
            size=(bsr.shape[0] // blocksize, bsr.shape[1] // blocksize),
        )
        return layout.to_dense().unsqueeze(0)

    def reset_parameters(self):
        """Initialize weight and bias. Not implemented yet (see class TODO)."""
        raise NotImplementedError

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block sparse linear map. Not implemented yet (see class TODO)."""
        raise NotImplementedError

    def extra_repr(self) -> str:
        """Return the summary shown inside ``repr(module)``."""
        # ``self.bias`` does not exist until the weight/bias TODO is done;
        # fall back to None so printing the module does not raise.
        has_bias = getattr(self, "bias", None) is not None
        return (
            f"{self.in_features}, {self.out_features}, "
            f"bias={has_bias}, blocksize={self.blocksize}"
        )
354+
355+
356+
class BlockSparseLocallyConnected(nn.Module):
    """
    A locally connected layer implemented using block sparse linear.

    Placeholder: the class currently has no constructor or forward of its own.

    TODO: the main step is computing the connectivity from the conv params. Its
    shape should be something like
    (out_height * out_width * out_channels, in_height * in_width * in_channels).
    Then we just use BlockSparseLinear.
    """

0 commit comments

Comments
 (0)