@@ -33,8 +33,6 @@
 # use some distributed routines from torch harmonics
 from torch_harmonics.distributed import distributed_transpose_azimuth as distributed_transpose_w
 from torch_harmonics.distributed import distributed_transpose_polar as distributed_transpose_h
-from torch_harmonics.distributed import DistributedAttentionS2 as THDistributedAttentionS2
-
 
 class _DistMatmulHelper(torch.autograd.Function):
     @staticmethod
@@ -509,60 +507,3 @@ def forward(self, x): |
 
         return x
 
-
-class DistributedAttentionS2(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        num_heads: int,
-        in_shape: Tuple[int],
-        out_shape: Tuple[int],
-        grid_in: Optional[str] = "equiangular",
-        grid_out: Optional[str] = "equiangular",
-        scale: Optional[Union[torch.Tensor, float]] = None,
-        bias: Optional[bool] = True,
-        k_channels: Optional[int] = None,
-        out_channels: Optional[int] = None,
-        drop_rate: Optional[float]=0.0,
-    ):
-        super().__init__()
-
-        assert in_channels % num_heads == 0, "in_channels should be divisible by num_heads"
-        assert out_channels % num_heads == 0, "out_channels should be divisible by num_heads"
-
-        self.attn = THDistributedAttentionS2(
-            in_channels=in_channels,
-            num_heads=num_heads,
-            in_shape=in_shape,
-            out_shape=out_shape,
-            grid_in=grid_in,
-            grid_out=grid_out,
-            scale=scale,
-            bias=bias,
-            k_channels=k_channels,
-            out_channels=out_channels,
-            drop_rate=drop_rate,
-        )
-
-        # set up weight sharing
-        if comm.get_size("spatial") > 1:
-            self.attn.q_weights.is_shared_mp = ["spatial"]
-            self.attn.q_weights.sharded_dims_mp = [None, None, None, None]
-            self.attn.k_weights.is_shared_mp = ["spatial"]
-            self.attn.k_weights.sharded_dims_mp = [None, None, None, None]
-            self.attn.v_weights.is_shared_mp = ["spatial"]
-            self.attn.v_weights.sharded_dims_mp = [None, None, None, None]
-            self.attn.proj_weights.is_shared_mp = ["spatial"]
-            self.attn.proj_weights.sharded_dims_mp = [None, None, None, None]
-            if bias:
-                self.attn.q_bias.is_shared_mp = ["spatial"]
-                self.attn.q_bias.sharded_dims_mp = [None]
-                self.attn.k_bias.is_shared_mp = ["spatial"]
-                self.attn.k_bias.sharded_dims_mp = [None]
-                self.attn.v_bias.is_shared_mp = ["spatial"]
-                self.attn.v_bias.sharded_dims_mp = [None]
-                self.attn.proj_bias.is_shared_mp = ["spatial"]
-                self.attn.proj_bias.sharded_dims_mp = [None]
-
-    def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None) -> torch.Tensor:
-        return self.attn(query, key, value)
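
For context, the removed DistributedAttentionS2 was a thin wrapper that delegated to torch_harmonics' distributed attention module and tagged its q/k/v/proj parameters as replicated across the "spatial" model-parallel group. Below is a minimal, standalone sketch of that tagging pattern only; the nn.Linear stand-in and channel sizes are illustrative assumptions, and the attributes are assumed to be consumed later by the repository's model-parallel utilities (e.g. for gradient reduction), not by PyTorch itself.

    import torch.nn as nn

    # Hypothetical stand-in for a projection layer; the removed wrapper set the
    # same attributes on the q/k/v/proj weights and biases of the wrapped module.
    proj = nn.Linear(256, 256)

    # Mark the parameters as shared (replicated) across the "spatial" model-parallel
    # group rather than sharded along any dimension. PyTorch allows attaching
    # arbitrary Python attributes to tensors, so this is plain metadata.
    proj.weight.is_shared_mp = ["spatial"]
    proj.weight.sharded_dims_mp = [None, None]  # 2-D weight, no dimension sharded
    proj.bias.is_shared_mp = ["spatial"]
    proj.bias.sharded_dims_mp = [None]          # 1-D bias, not sharded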