
Issue with torchinfo.summary() Failing on Models with torch.distributions.Categorical #329

Open
@Naeemkh

Description

Thanks for building such an amazing package and maintaining it.

When I call torchinfo.summary() on a model whose forward pass returns a torch.distributions.Categorical, torchinfo fails to complete the summary and raises an error. The cause appears to be that Categorical is a distribution object rather than a tensor, so it lacks the tensor-like properties (such as a .shape attribute) that torchinfo seems to rely on when recording each layer's output size.

This affects anyone who returns a Categorical from the forward pass of a probabilistic model, so torchinfo currently cannot summarize such architectures.
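A quick check makes the difference concrete. This is a small sketch that uses only PyTorch and does not reproduce torchinfo's internals. Unlike a tensor, a Categorical instance has no .shape and instead describes its dimensions through batch_shape and event_shape:

```python
import torch
from torch.distributions import Categorical

# Build a Categorical over 5 classes for a batch of 1, as in the model below
dist = Categorical(logits=torch.randn(1, 5))

print(isinstance(dist, torch.Tensor))      # False: it is a Distribution, not a tensor
print(hasattr(dist, "shape"))              # False: no .shape attribute to read
print(dist.batch_shape, dist.event_shape)  # torch.Size([1]) torch.Size([])
```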

Reproducible example:

```python
import torch
import torch.nn as nn
import torchinfo
from torch.distributions import Categorical

# Define a simple model whose forward pass returns a Categorical distribution
class ProbabilisticModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.final = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        logits = self.final(self.linear(x))
        return Categorical(logits=logits)

# Initialize the model and input
model = ProbabilisticModel(input_dim=10, output_dim=5)
input_data = torch.randn(1, 10)

# Attempt to run torchinfo summary
try:
    torchinfo.summary(model, input_data=input_data)
except Exception as e:
    print(f"Encountered an error: {type(e).__name__}: {e}")
```

If the model returns the logits directly instead of wrapping them in Categorical, torchinfo.summary() completes without errors.
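Until this is supported, one workaround consistent with that observation is to keep forward() tensor-only and build the Categorical in a separate method. This is a sketch assuming the calling code can construct the distribution itself; the method name distribution is hypothetical and not part of PyTorch or torchinfo:

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical


class ProbabilisticModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.final = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        # Return a plain tensor of log-probabilities so that every module
        # output (including the top-level one) is a tensor.
        return self.final(self.linear(x))

    def distribution(self, x):
        # Hypothetical helper: wrap the forward output in a Categorical.
        return Categorical(logits=self(x))


model = ProbabilisticModel(input_dim=10, output_dim=5)
dist = model.distribution(torch.randn(1, 10))
sample = dist.sample()  # one class index per batch element
```

With this split, torchinfo.summary(model, input_size=(1, 10)) only sees tensor outputs, while sampling code calls model.distribution(x) instead of model(x).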

Supporting this case would let torchinfo report layer and parameter information for these models as well.

I don't have a solution in mind, but one option might be for torchinfo to fall back to a user-defined attribute when it cannot extract an output shape. For example, if shape extraction fails, torchinfo could look for self._output_shape, which users could define in such custom models.
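The suggested fallback could look roughly like this. It is only a sketch of the idea, not torchinfo's API: resolve_output_shape is a hypothetical helper, _output_shape is the user-declared attribute proposed above (hard-coded here with a batch size of 1), and plain PyTorch forward hooks stand in for torchinfo's own hook machinery:

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical


def resolve_output_shape(module, output):
    """Hypothetical helper: tensor shape if available, else a declared shape."""
    if isinstance(output, torch.Tensor):
        return list(output.shape)
    # Proposed fallback: a shape the user declares on the custom module.
    declared = getattr(module, "_output_shape", None)
    if declared is None:
        raise TypeError(f"cannot determine output shape of {type(output).__name__}")
    return list(declared)


class ProbabilisticModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.final = nn.LogSoftmax(dim=-1)
        # Hypothetical user-declared output shape (batch of 1, output_dim classes)
        self._output_shape = (1, output_dim)

    def forward(self, x):
        return Categorical(logits=self.final(self.linear(x)))


model = ProbabilisticModel(input_dim=10, output_dim=5)
shapes = {}

def record_shape(module, inputs, output):
    # Record each module's output shape, keyed by class name
    shapes[type(module).__name__] = resolve_output_shape(module, output)

for m in model.modules():
    m.register_forward_hook(record_shape)

model(torch.randn(1, 10))
print(shapes)  # the Categorical-returning model falls back to its declared shape
```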
