Skip to content

Commit 1a7a60c

Browse files
committed
dbg
1 parent 9f8e318 commit 1a7a60c

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

makani/models/common/layers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,12 @@ def forward(self, x):
5151
return drop_path(x, self.drop_prob, self.training)
5252

5353

54-
class SeededDropout(nn.Module):
54+
class SeededDropout2d(nn.Module):
5555
def __init__(self, drop_prob=0.0, seed=333):
56-
super(SeededDropout, self).__init__()
56+
super(SeededDropout2d, self).__init__()
5757
self.drop_prob = drop_prob
5858
self.seed = seed
59-
self.drop = nn.Dropout(p=self.drop_prob)
59+
self.drop = nn.Dropout2d(p=self.drop_prob)
6060

6161
# set RNG states
6262
self.rng_cpu = torch.Generator(device=torch.device("cpu"))

0 commit comments

Comments
 (0)