Commit 8f30e0f

trying to incorporate bias in distributed attention
1 parent 0c134ef commit 8f30e0f

File tree: 3 files changed, +4 −63 lines changed

makani/models/stepper.py

Lines changed: 3 additions & 3 deletions

@@ -153,11 +153,11 @@ def _forward_eval(self, inp, update_state=True, replace_state=True):
 
         return y
 
-    def forward(self, inp, replace_state=True):
+    def forward(self, inp, update_state=True, replace_state=True):
         # decide which routine to call
         if self.training:
-            y = self._forward_train(inp, update_state=True, replace_state=replace_state)
+            y = self._forward_train(inp, update_state=update_state, replace_state=replace_state)
         else:
-            y = self._forward_eval(inp, update_state=True, replace_state=replace_state)
+            y = self._forward_eval(inp, update_state=update_state, replace_state=replace_state)
 
         return y
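The stepper change simply threads `update_state` through `forward` instead of hard-coding it to `True`, with defaults that preserve the old behaviour. A minimal rollout sketch of how a caller might use the new keyword; `rollout`, `stepper`, and the loop structure are hypothetical, not makani's actual training or inference code:

```python
import torch

def rollout(stepper, inp: torch.Tensor, steps: int) -> torch.Tensor:
    """Hypothetical autoregressive rollout using the patched signature."""
    y = inp
    for step in range(steps):
        # update_state is now forwarded to _forward_train / _forward_eval
        # rather than being fixed to True inside forward()
        y = stepper(y, update_state=True, replace_state=(step == 0))
    return y
```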

makani/mpu/layers.py

Lines changed: 0 additions & 59 deletions

@@ -33,8 +33,6 @@
 # use some distributed routines from torch harmonics
 from torch_harmonics.distributed import distributed_transpose_azimuth as distributed_transpose_w
 from torch_harmonics.distributed import distributed_transpose_polar as distributed_transpose_h
-from torch_harmonics.distributed import DistributedAttentionS2 as THDistributedAttentionS2
-
 
 class _DistMatmulHelper(torch.autograd.Function):
     @staticmethod
@@ -509,60 +507,3 @@ def forward(self, x):
 
         return x
 
-
-class DistributedAttentionS2(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        num_heads: int,
-        in_shape: Tuple[int],
-        out_shape: Tuple[int],
-        grid_in: Optional[str] = "equiangular",
-        grid_out: Optional[str] = "equiangular",
-        scale: Optional[Union[torch.Tensor, float]] = None,
-        bias: Optional[bool] = True,
-        k_channels: Optional[int] = None,
-        out_channels: Optional[int] = None,
-        drop_rate: Optional[float]=0.0,
-    ):
-        super().__init__()
-
-        assert in_channels % num_heads == 0, "in_channels should be divisible by num_heads"
-        assert out_channels % num_heads == 0, "out_channels should be divisible by num_heads"
-
-        self.attn = THDistributedAttentionS2(
-            in_channels=in_channels,
-            num_heads=num_heads,
-            in_shape=in_shape,
-            out_shape=out_shape,
-            grid_in=grid_in,
-            grid_out=grid_out,
-            scale=scale,
-            bias=bias,
-            k_channels=k_channels,
-            out_channels=out_channels,
-            drop_rate=drop_rate,
-        )
-
-        # set up weight sharing
-        if comm.get_size("spatial") > 1:
-            self.attn.q_weights.is_shared_mp = ["spatial"]
-            self.attn.q_weights.sharded_dims_mp = [None, None, None, None]
-            self.attn.k_weights.is_shared_mp = ["spatial"]
-            self.attn.k_weights.sharded_dims_mp = [None, None, None, None]
-            self.attn.v_weights.is_shared_mp = ["spatial"]
-            self.attn.v_weights.sharded_dims_mp = [None, None, None, None]
-            self.attn.proj_weights.is_shared_mp = ["spatial"]
-            self.attn.proj_weights.sharded_dims_mp = [None, None, None, None]
-            if bias:
-                self.attn.q_bias.is_shared_mp = ["spatial"]
-                self.attn.q_bias.sharded_dims_mp = [None]
-                self.attn.k_bias.is_shared_mp = ["spatial"]
-                self.attn.k_bias.sharded_dims_mp = [None]
-                self.attn.v_bias.is_shared_mp = ["spatial"]
-                self.attn.v_bias.sharded_dims_mp = [None]
-                self.attn.proj_bias.is_shared_mp = ["spatial"]
-                self.attn.proj_bias.sharded_dims_mp = [None]
-
-    def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None) -> torch.Tensor:
-        return self.attn(query, key, value)
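Beyond delegating to torch_harmonics, the removed wrapper only stamped model-parallel metadata onto the attention weights and biases. A condensed sketch of that annotation pattern; `mark_shared` is a hypothetical helper, and the attribute names simply mirror the deleted code:

```python
import torch.nn as nn

def mark_shared(param: nn.Parameter, num_dims: int) -> None:
    # Replicate the parameter across the "spatial" model-parallel group;
    # no dimension is sharded (mirrors is_shared_mp / sharded_dims_mp above).
    param.is_shared_mp = ["spatial"]
    param.sharded_dims_mp = [None] * num_dims

# Usage mirroring the removed __init__ body, assuming attn is a
# torch_harmonics DistributedAttentionS2 exposing these parameters:
# for w in (attn.q_weights, attn.k_weights, attn.v_weights, attn.proj_weights):
#     mark_shared(w, num_dims=4)
# for b in (attn.q_bias, attn.k_bias, attn.v_bias, attn.proj_bias):
#     mark_shared(b, num_dims=1)
```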

tests/test_constraints.py

Lines changed: 1 addition & 1 deletion

@@ -111,7 +111,7 @@ def test_hydrostatic_balance_constraint_wrapper_era5(self):
         # check the hb loss
         hb_loss_tens = hbloss(data_map, None)
 
-        # average over batch and sum over channels
+        # average over batch and sum over channels
         hb_loss_val = torch.mean(torch.sum(hb_loss_tens, dim=1)).item()
 
         self.assertTrue(hb_loss_val <= 1e-6)
