Skip to content

Commit 0218658

Browse files
authored
Merge pull request #5 from NVIDIA/bbonev/0.1.0
Addressing logic in the big skip connection
2 parents 94d9e45 + 6cc4b10 commit 0218658

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

makani/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
__version__ = "0.1.0a1"
16+
__version__ = "0.1.0"
1717

1818
from .utils.trainer import Trainer
1919
from .utils.inferencer import Inferencer

makani/models/networks/sfnonet.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -460,10 +460,10 @@ def __init__(
460460

461461
# output transform
462462
if self.big_skip:
463-
self.residual_transform = nn.Conv2d(self.out_chans, self.out_chans, 1, bias=False)
463+
self.residual_transform = nn.Conv2d(self.inp_chans, self.out_chans, 1, bias=False)
464464
self.residual_transform.weight.is_shared_mp = ["spatial"]
465465
self.residual_transform.weight.sharded_dims_mp = [None, None, None, None]
466-
scale = math.sqrt(0.5 / self.out_chans)
466+
scale = math.sqrt(0.5 / self.inp_chans)
467467
nn.init.normal_(self.residual_transform.weight, mean=0.0, std=scale)
468468

469469
# learned position embedding
@@ -591,15 +591,15 @@ def forward(self, x):
591591
if self.out_shape != self.inp_shape:
592592
xtype = x.dtype
593593
# only take the predicted channels as residual
594-
residual = x[..., : self.out_chans, :, :].to(torch.float32)
594+
residual = x.to(torch.float32)
595595
with amp.autocast(enabled=False):
596596
residual = self.trans_down(residual)
597597
residual = residual.contiguous()
598598
residual = self.itrans_up(residual)
599599
residual = residual.to(dtype=xtype)
600600
else:
601601
# only take the predicted channels
602-
residual = x[..., : self.out_chans, :, :].contiguous()
602+
residual = x
603603

604604
if comm.get_size("fin") > 1:
605605
x = scatter_to_parallel_region(x, 1, "fin")

0 commit comments

Comments
 (0)