We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9f8e318 commit 1a7a60cCopy full SHA for 1a7a60c
makani/models/common/layers.py
@@ -51,12 +51,12 @@ def forward(self, x):
51
return drop_path(x, self.drop_prob, self.training)
52
53
54
-class SeededDropout(nn.Module):
+class SeededDropout2d(nn.Module):
55
def __init__(self, drop_prob=0.0, seed=333):
56
- super(SeededDropout, self).__init__()
+ super(SeededDropout2d, self).__init__()
57
self.drop_prob = drop_prob
58
self.seed = seed
59
- self.drop = nn.Dropout(p=self.drop_prob)
+ self.drop = nn.Dropout2d(p=self.drop_prob)
60
61
# set RNG states
62
self.rng_cpu = torch.Generator(device=torch.device("cpu"))
0 commit comments