networks with ModuleList  #327

Open
@caiuspetronius

Describe the bug
The model structure and the total number of parameters are reported incorrectly for a network containing several ModuleLists whose elements themselves contain ModuleLists. I don't know whether this is a bug or a missing feature.

To Reproduce
Steps to reproduce the behavior:
Run summary on an instance of the EncoderVNet network defined under Additional context.

Expected behavior
The network structure and the total number of parameters should be shown correctly. Computed directly as num_params = sum( p.numel() for p in net.parameters() if p.requires_grad ), the total number of trainable parameters is 10088490.
[Screenshot 2024-10-26 at 11 18 09 PM]
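To narrow down where a summary diverges from the direct count, the same sum can be taken per leaf module. The following is a minimal sketch on a toy network, not the model from this issue: `Inner`, `Outer`, and `count_trainable` are hypothetical names introduced only for illustration, and the layer sizes are arbitrary.

```python
import torch.nn as nn


class Inner(nn.Module):
    """Toy block (hypothetical) that keeps its layers in a ModuleList."""

    def __init__(self, channels, nsteps):
        super().__init__()
        self.convs = nn.ModuleList(
            [nn.Conv2d(channels, channels, kernel_size=3, padding=1) for _ in range(nsteps)]
        )


class Outer(nn.Module):
    """Toy network (hypothetical) that keeps Inner blocks in a second ModuleList."""

    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([Inner(4, nsteps=s + 1) for s in range(3)])


def count_trainable(module):
    # Same expression as in the issue: sum over all trainable parameters.
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


net = Outer()
total = count_trainable(net)

# Break the total down by leaf module (modules without children).
per_leaf = {
    name: count_trainable(m)
    for name, m in net.named_modules()
    if len(list(m.children())) == 0
}

for name, n in per_leaf.items():
    print(f"{name:20s} {n:8d}")
print(f"{'total':20s} {total:8d}")
```

Because no parameters are shared in this toy network, the per-leaf counts add up to the total; comparing such a breakdown line by line with the summary output shows which nested layers are missing or double counted.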

Desktop:

  • OS: CentOS Linux
  • Browser: N/A
  • Version: 7

Additional context
The model is a VNet encoder, so it has multiple stages with a residual block at each stage. The residual blocks have multiple steps inside. Both steps and residual blocks are implemented via ModuleLists as follows:

import torch.nn as nn

class ResBlock( nn.Module ) :
    def __init__( self, stage, channels, kernel, padstyle, paddings, activation, deep_res = False, nsteps = None, **kwargs ) :
        super( ResBlock, self ).__init__()
        if nsteps is None :
            self.nsteps = 3 if stage > 3 else stage  # default: as many steps as the stage index, capped at 3
        else :
            self.nsteps = nsteps  # use the explicitly requested number of steps
        self.activation = activation
        self.deep_res = deep_res
        self.convs = nn.ModuleList( [ nn.Conv2d( channels, channels, kernel_size = kernel, padding_mode = padstyle, padding = paddings ) for _ in range( self.nsteps ) ] )
        self.norms = nn.ModuleList( [ nn.BatchNorm2d( channels ) for _ in range( self.nsteps ) ] )
        self.norm_out = nn.BatchNorm2d( channels )

    def forward( self, x ) :
        inp = [ x.clone() ]
        for i in range( self.nsteps ) :
            x = self.convs[ i ]( x )
            if self.deep_res :
                inp.append( x.clone() )
                for j in range( i + 1 ) :  # add output from all previous steps
                    x = x + inp[ j ]
            x = self.activation( self.norms[ i ]( x ) )
        return self.activation( self.norm_out( x + inp[ 0 ] ) )  # residual connection over the whole block

class EncoderVNet( nn.Module ) :
    def __init__( self, channels, kernel, padstyle, activation, dropout, deep_res = False, nstages = 5, nsteps = None, **kwargs ) :
        super( EncoderVNet, self ).__init__()
        paddings = ( kernel[ 0 ] // 2, kernel[ 1 ] // 2 )
        self.channels = channels  # channels starting from the number of input image channels and then for each stage
        self.kernel = kernel
        self.padstyle = padstyle
        self.activation = activation
        self.deep_res = deep_res
        self.nstages = nstages
        self.nsteps = nsteps
        self.conv_inp = nn.Conv2d( channels[ 0 ], channels[ 1 ], kernel_size = kernel, padding_mode = padstyle, padding = paddings )
        self.norm_inp = nn.BatchNorm2d( channels[ 1 ] )
        self.drop = nn.Dropout( dropout )
        self.res_blocks = nn.ModuleList( [ ResBlock( s + 1, channels[ s + 1 ], kernel, padstyle, paddings, activation, deep_res, nsteps ) for s in range( nstages ) ] )
        self.convs_down = nn.ModuleList( [ nn.Conv2d( channels[ s + 1 ], channels[ s + 2 ], kernel_size = 2, stride = 2, padding = 'valid' ) for s in range( nstages - 1 ) ] )
        self.norms = nn.ModuleList( [ nn.BatchNorm2d( channels[ s + 2 ] ) for s in range( nstages - 1 ) ] )
        if channels[ -1 ] is not None :  # bottleneck layer (e.g., for autoencoder)
            self.conv_out = nn.Conv2d( channels[ -2 ], channels[ -1 ], kernel_size = 1, padding = 'valid' )

    def forward( self, x ) :
        x = self.activation( self.norm_inp( self.conv_inp( x ) ) )  # this matches the number of image channels to the first stage residual sum
        for s in range( self.nstages ) :
            x = self.drop( x )
            x = self.res_blocks[ s ]( x )
            if s < self.nstages - 1 :
                x = self.activation( self.norms[ s ]( self.convs_down[ s ]( x ) ) )
        if self.channels[ -1 ] is not None :  # make a certain number of channels at the output
            x = self.conv_out( x )
        return x
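As an independent cross-check of the expected total, the parameter count of the architecture above can also be worked out by hand, without building the model. The sketch below assumes PyTorch defaults (Conv2d with bias, BatchNorm2d with affine weight and bias), an activation without trainable parameters, and that an explicit nsteps applies to every stage. The helper names and the example channels list are hypothetical and are not the configuration that produced 10088490.

```python
def conv2d_params(c_in, c_out, kernel):
    # Weight tensor (c_out, c_in, kh, kw) plus one bias per output channel.
    kh, kw = kernel
    return c_out * c_in * kh * kw + c_out


def batchnorm2d_params(channels):
    # Trainable scale and shift; running statistics are buffers, not parameters.
    return 2 * channels


def steps_per_stage(nstages, nsteps=None):
    # Mirrors ResBlock: stage index capped at 3 unless nsteps is given.
    return [
        (3 if stage > 3 else stage) if nsteps is None else nsteps
        for stage in range(1, nstages + 1)
    ]


def res_block_params(channels, kernel, nsteps):
    per_step = conv2d_params(channels, channels, kernel) + batchnorm2d_params(channels)
    return nsteps * per_step + batchnorm2d_params(channels)  # plus norm_out


def encoder_params(channels, kernel, nstages=5, nsteps=None):
    # conv_inp and norm_inp
    total = conv2d_params(channels[0], channels[1], kernel) + batchnorm2d_params(channels[1])
    # one residual block per stage
    for s, n in enumerate(steps_per_stage(nstages, nsteps)):
        total += res_block_params(channels[s + 1], kernel, n)
    # 2x2 strided downsampling convolutions and their norms
    for s in range(nstages - 1):
        total += conv2d_params(channels[s + 1], channels[s + 2], (2, 2))
        total += batchnorm2d_params(channels[s + 2])
    # optional 1x1 bottleneck convolution
    if channels[-1] is not None:
        total += conv2d_params(channels[-2], channels[-1], (1, 1))
    return total


# Hypothetical configuration, for illustration only.
example_channels = [1, 16, 32, 64, 128, 256, None]
print(steps_per_stage(5))
print(encoder_params(example_channels, (3, 3)))
```

If this hand count agrees with the sum over net.parameters() for the actual configuration but not with the summary output, the discrepancy lies in how the summary traverses the nested ModuleLists rather than in the model definition.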
