|
5 | 5 | from timm.layers import trunc_normal_ |
6 | 6 | from torch import nn |
7 | 7 |
|
| 8 | +try: |
| 9 | + from triton.ops.blocksparse import matmul as blocksparse_matmul # noqa |
| 10 | + |
| 11 | + triton_available = True |
| 12 | +except ImportError: |
| 13 | + triton_available = False |
| 14 | + |
8 | 15 | Layer = Callable[..., nn.Module] |
9 | 16 |
|
10 | 17 |
|
@@ -280,3 +287,77 @@ def init_weights(module: nn.Module): |
280 | 287 | nn.init.zeros_(module.bias) |
281 | 288 | elif hasattr(module, "init_weights"): |
282 | 289 | module.init_weights() |
| 290 | + |
| 291 | + |
| 292 | +class BlockSparseLinear(nn.Module): |
| 293 | + """ |
| 294 | + A linear layer with block sparse connectivity. |
| 295 | +
|
| 296 | + Args: |
| 297 | + connectivity: a binary tensor of shape (out_features, in_features) representing |
| 298 | + the connectivity between input and output units. |
| 299 | + bias: use bias |
| 300 | + blocksize: sparse block size, e.g. 16, 32. Must divide each dimension of |
| 301 | + connectivity |
| 302 | +
|
| 303 | + TODO: |
| 304 | + [ ] initialize weight and bias. weight should be masked by connectivity at init. |
| 305 | + think about what the appropriate init std should be. |
| 306 | + [ ] create a dsd sparse matmul kernel following xformers.BlockSparseAttention: |
| 307 | + https://github.com/facebookresearch/xformers/blob/fad50d49834ab18dd137acc727bd4d567ff17842/xformers/components/attention/blocksparse.py#L96 |
| 308 | + [ ] implement forward that should mask weight by connectivity and then call the |
| 309 | + blocksparse matmul kernel |
| 310 | + """ |
| 311 | + |
| 312 | + def __init__( |
| 313 | + self, connectivity: torch.Tensor, bias: bool = True, blocksize: int = 16 |
| 314 | + ): |
| 315 | + assert triton_available, "blocksparse linear requires triton" |
| 316 | + super().__init__() |
| 317 | + self.in_features = connectivity.shape[1] |
| 318 | + self.out_features = connectivity.shape[0] |
| 319 | + self.blocksize = blocksize |
| 320 | + |
| 321 | + # convert to torch blocksparse representation if not already |
| 322 | + connectivity = connectivity.to_sparse_bsr(blocksize) |
| 323 | + |
| 324 | + # block sparse layout as expected by triton |
| 325 | + # shape (1, out_features // block, in_features // block) |
| 326 | + # must be dtype int64 |
| 327 | + layout = torch.sparse_csr_tensor( |
| 328 | + connectivity.crow_indices(), |
| 329 | + connectivity.col_indices(), |
| 330 | + torch.ones_like(connectivity.col_indices()), |
| 331 | + ) |
| 332 | + layout = layout.to_dense().unsqueeze(0) |
| 333 | + |
| 334 | + # only keep raw values, don't need indices since we have layout |
| 335 | + # shape (nnz_blocks, block, block) |
| 336 | + connectivity = (connectivity.values() > 0).float() |
| 337 | + |
| 338 | + self.register_buffer("connectivity", connectivity) |
| 339 | + self.register_buffer("layout", layout) |
| 340 | + |
| 341 | + # TODO: initialize weight and bias |
| 342 | + |
| 343 | + def reset_parameters(self): |
| 344 | + raise NotImplementedError |
| 345 | + |
| 346 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 347 | + raise NotImplementedError |
| 348 | + |
| 349 | + def extra_repr(self) -> str: |
| 350 | + return ( |
| 351 | + f"{self.in_features}, {self.out_features}, " |
| 352 | + f"bias={self.bias is not None}, blocksize={self.blocksize}" |
| 353 | + ) |
| 354 | + |
| 355 | + |
| 356 | +class BlockSparseLocallyConnected(nn.Module): |
| 357 | + """ |
| 358 | + A locally connected layer implemented using block sparse linear. |
| 359 | +
|
| 360 | + TODO: main step is just computing the connectivity based on conv params. shape |
| 361 | + should be something like: (out_height * out_width * out_channels, in_height * |
| 362 | + in_width * in_channels). Then we just use BlockSparseLinear. |
| 363 | + """ |
0 commit comments