@@ -460,10 +460,10 @@ def __init__(
 
         # output transform
         if self.big_skip:
-            self.residual_transform = nn.Conv2d(self.out_chans, self.out_chans, 1, bias=False)
+            self.residual_transform = nn.Conv2d(self.inp_chans, self.out_chans, 1, bias=False)
             self.residual_transform.weight.is_shared_mp = ["spatial"]
             self.residual_transform.weight.sharded_dims_mp = [None, None, None, None]
-            scale = math.sqrt(0.5 / self.out_chans)
+            scale = math.sqrt(0.5 / self.inp_chans)
             nn.init.normal_(self.residual_transform.weight, mean=0.0, std=scale)
 
         # learned position embedding
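Why this hunk: with the big skip now fed by the full input, the 1x1 residual convolution has to map all `inp_chans` input channels onto the `out_chans` predicted channels, so the weight-init scale must follow the new fan-in (`inp_chans` instead of `out_chans`). A minimal standalone sketch of the corrected setup, with placeholder channel counts (the real values come from the model config, not from this commit):

```python
import math

import torch
import torch.nn as nn

# placeholder channel counts, not taken from the commit
inp_chans, out_chans = 26, 20

# 1x1 conv that projects all input channels onto the predicted channels
residual_transform = nn.Conv2d(inp_chans, out_chans, 1, bias=False)

# init std follows the fan-in of the conv, which is now inp_chans
scale = math.sqrt(0.5 / inp_chans)
nn.init.normal_(residual_transform.weight, mean=0.0, std=scale)

# quick shape check: the projected skip matches the prediction's channel count
x = torch.randn(1, inp_chans, 8, 16)
assert residual_transform(x).shape == (1, out_chans, 8, 16)
```

The 0.5 factor in the init scale plausibly keeps the skip branch at roughly half the variance of a unit-variance input, so the sum of skip and trunk stays near unit scale; the commit itself only changes which channel count the fan-in refers to.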
@@ -591,15 +591,15 @@ def forward(self, x):
         if self.out_shape != self.inp_shape:
             xtype = x.dtype
             # only take the predicted channels as residual
-            residual = x[..., : self.out_chans, :, :].to(torch.float32)
+            residual = x.to(torch.float32)
             with amp.autocast(enabled=False):
                 residual = self.trans_down(residual)
                 residual = residual.contiguous()
                 residual = self.itrans_up(residual)
                 residual = residual.to(dtype=xtype)
         else:
             # only take the predicted channels
-            residual = x[..., : self.out_chans, :, :].contiguous()
+            residual = x
 
         if comm.get_size("fin") > 1:
             x = scatter_to_parallel_region(x, 1, "fin")
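Downstream of this hunk (outside the diff), the transformed residual is presumably added back to the network output; the old slicing `x[..., : self.out_chans, :, :]` is gone because `residual_transform` now performs the channel reduction. A method-style sketch of the updated big-skip path, assuming `trans_down`/`itrans_up` are the spectral resampling transforms and that the 1x1 projection is applied where the skip is consumed (both assumptions, not shown in the diff):

```python
import torch

def big_skip_residual(self, x: torch.Tensor) -> torch.Tensor:
    """Sketch of the residual path after this commit (names assumed from context)."""
    if self.out_shape != self.inp_shape:
        xtype = x.dtype
        # keep every input channel; residual_transform projects them later
        residual = x.to(torch.float32)
        with torch.amp.autocast("cuda", enabled=False):
            residual = self.trans_down(residual)   # resample to the output grid
            residual = residual.contiguous()
            residual = self.itrans_up(residual)
            residual = residual.to(dtype=xtype)
    else:
        residual = x
    # 1x1 conv: inp_chans -> out_chans, replacing the old channel slicing
    return self.residual_transform(residual)
```

Running the spectral transforms in float32 under a disabled autocast region, as the diff does, avoids precision loss in the transform pair before casting back to the network's working dtype.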