|
5 | 5 | from timm.layers import trunc_normal_ |
6 | 6 | from torch import nn |
7 | 7 |
|
| 8 | +try: |
| 9 | + from triton.ops.blocksparse import matmul as blocksparse_matmul # noqa |
| 10 | + |
| 11 | + triton_available = True |
| 12 | +except ImportError: |
| 13 | + triton_available = False |
| 14 | + |
8 | 15 | Layer = Callable[..., nn.Module] |
9 | 16 |
|
10 | 17 |
|
@@ -236,3 +243,77 @@ def init_weights(module: nn.Module): |
236 | 243 | nn.init.zeros_(module.bias) |
237 | 244 | elif hasattr(module, "init_weights"): |
238 | 245 | module.init_weights() |
| 246 | + |
| 247 | + |
| 248 | +class BlockSparseLinear(nn.Module): |
| 249 | + """ |
| 250 | + A linear layer with block sparse connectivity. |
| 251 | +
|
| 252 | + Args: |
| 253 | + connectivity: a binary tensor of shape (out_features, in_features) representing |
| 254 | + the connectivity between input and output units. |
| 255 | + bias: use bias |
| 256 | + blocksize: sparse block size, e.g. 16, 32. Must divide each dimension of |
| 257 | + connectivity |
| 258 | +
|
| 259 | + TODO: |
| 260 | + [ ] initialize weight and bias. weight should be masked by connectivity at init. |
| 261 | + think about what the appropriate init std should be. |
| 262 | + [ ] create a dsd sparse matmul kernel following xformers.BlockSparseAttention: |
| 263 | + https://github.com/facebookresearch/xformers/blob/fad50d49834ab18dd137acc727bd4d567ff17842/xformers/components/attention/blocksparse.py#L96 |
| 264 | + [ ] implement forward that should mask weight by connectivity and then call the |
| 265 | + blocksparse matmul kernel |
| 266 | + """ |
| 267 | + |
| 268 | + def __init__( |
| 269 | + self, connectivity: torch.Tensor, bias: bool = True, blocksize: int = 16 |
| 270 | + ): |
| 271 | + assert triton_available, "blocksparse linear requires triton" |
| 272 | + super().__init__() |
| 273 | + self.in_features = connectivity.shape[1] |
| 274 | + self.out_features = connectivity.shape[0] |
| 275 | + self.blocksize = blocksize |
| 276 | + |
| 277 | + # convert to torch blocksparse representation if not already |
| 278 | + connectivity = connectivity.to_sparse_bsr(blocksize) |
| 279 | + |
| 280 | + # block sparse layout as expected by triton |
| 281 | + # shape (1, out_features // block, in_features // block) |
| 282 | + # must be dtype int64 |
| 283 | + layout = torch.sparse_csr_tensor( |
| 284 | + connectivity.crow_indices(), |
| 285 | + connectivity.col_indices(), |
| 286 | + torch.ones_like(connectivity.col_indices()), |
| 287 | + ) |
| 288 | + layout = layout.to_dense().unsqueeze(0) |
| 289 | + |
| 290 | + # only keep raw values, don't need indices since we have layout |
| 291 | + # shape (nnz_blocks, block, block) |
| 292 | + connectivity = (connectivity.values() > 0).float() |
| 293 | + |
| 294 | + self.register_buffer("connectivity", connectivity) |
| 295 | + self.register_buffer("layout", layout) |
| 296 | + |
| 297 | + # TODO: initialize weight and bias |
| 298 | + |
| 299 | + def reset_parameters(self): |
| 300 | + raise NotImplementedError |
| 301 | + |
| 302 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 303 | + raise NotImplementedError |
| 304 | + |
| 305 | + def extra_repr(self) -> str: |
| 306 | + return ( |
| 307 | + f"{self.in_features}, {self.out_features}, " |
| 308 | + f"bias={self.bias is not None}, blocksize={self.blocksize}" |
| 309 | + ) |
| 310 | + |
| 311 | + |
| 312 | +class BlockSparseLocallyConnected(nn.Module): |
| 313 | + """ |
| 314 | + A locally connected layer implemented using block sparse linear. |
| 315 | +
|
| 316 | + TODO: main step is just computing the connectivity based on conv params. shape |
| 317 | + should be something like: (out_height * out_width * out_channels, in_height * |
| 318 | + in_width * in_channels). Then we just use BlockSparseLinear. |
| 319 | + """ |
0 commit comments