Description
Thanks for building such an amazing package and maintaining it.
When attempting to use `torchinfo.summary()` on models whose forward pass produces a `torch.distributions.Categorical` object, torchinfo fails to complete the summary and raises an error. The issue appears to stem from `Categorical` not having certain tensor-like properties that torchinfo expects, which makes it difficult for torchinfo to process.
This affects anyone using `Categorical` within the forward pass of probabilistic models and limits torchinfo's usefulness for such architectures.
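To illustrate what "tensor-like properties" most likely refers to: a `Categorical` is a Python object that wraps tensors, not a tensor itself. It has no `.shape`, but it does expose shape information through distribution-specific attributes such as `batch_shape`, `event_shape` and `logits`. Which of these torchinfo would need is an open question; this snippet only shows what the object offers.

```python
import torch
from torch.distributions import Categorical

dist = Categorical(logits=torch.randn(1, 5))

# A distribution is not a tensor, so tensor-based shape logic does not apply.
is_tensor = isinstance(dist, torch.Tensor)  # False
has_shape = hasattr(dist, "shape")          # False

# Shape information is available through distribution-specific attributes.
batch_shape = tuple(dist.batch_shape)    # (1,)
event_shape = tuple(dist.event_shape)    # ()
logits_shape = tuple(dist.logits.shape)  # (1, 5)
```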
Reproducible example:
```python
import torch
import torch.nn as nn
import torchinfo
from torch.distributions import Categorical

# Define a simple model whose forward pass returns a Categorical distribution
class ProbabilisticModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.final = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        logits = self.final(self.linear(x))
        return Categorical(logits=logits)

# Initialize the model and input
model = ProbabilisticModel(input_dim=10, output_dim=5)
input_data = torch.randn(1, 10)

# Attempt to run torchinfo summary
try:
    torchinfo.summary(model, input_size=(1, 10))
except Exception as e:
    print(f"Encountered an error: {e}")
```
If you drop `Categorical` and simply return `logits`, the summary works fine.
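Building on that observation, one possible workaround (a sketch, not an official torchinfo recommendation) is to keep `forward` tensor-only and build the distribution in a separate method. The method name `distribution` is hypothetical; the layers mirror the reproducible example.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical


class ProbabilisticModel(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.final = nn.LogSoftmax(dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Tensor-only output, the case reported above to work with torchinfo.
        return self.final(self.linear(x))

    def distribution(self, x: torch.Tensor) -> Categorical:
        # Hypothetical helper: wrap the logits in a Categorical outside forward().
        return Categorical(logits=self(x))
```

With this split, `torchinfo.summary(model, input_size=(1, 10))` only runs `forward`, while training and sampling code calls `model.distribution(x)`. The trade-off is that callers have to use the right entry point, so it sidesteps the problem rather than fixing it in torchinfo.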
Addressing this issue would let torchinfo report the model information for these architectures as well.
I don't have a solution in mind, but one option might be a fallback for when torchinfo runs into an error extracting an output shape: check for a user-definable attribute or method, for example `self._output_shape`, that users could set on such custom models.
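A rough sketch of how such a fallback could work, using standard `register_forward_hook` hooks to collect each module's output shape. The helpers `resolve_output_shape` and `collect_output_shapes` are hypothetical, and `_output_shape` is the attribute proposed above; none of them exist in torchinfo today. Tensors report their own shape; only when that fails does the module's `_output_shape` override get consulted.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical


def resolve_output_shape(module: nn.Module, output: object) -> list[int]:
    """Hypothetical fallback, not part of torchinfo's API."""
    # Normal case: tensors report their own shape.
    if isinstance(output, torch.Tensor):
        return list(output.shape)
    # Fallback: a user-defined `_output_shape` attribute on the module.
    override = getattr(module, "_output_shape", None)
    if override is not None:
        return list(override)
    raise TypeError(f"cannot determine output shape for {type(output).__name__}")


class ProbabilisticModel(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.final = nn.LogSoftmax(dim=-1)
        # Hypothetical override: the shape to report for the Categorical output
        # (batch size fixed at 1 here purely for illustration).
        self._output_shape = (1, output_dim)

    def forward(self, x: torch.Tensor) -> Categorical:
        return Categorical(logits=self.final(self.linear(x)))


def collect_output_shapes(model: nn.Module, x: torch.Tensor) -> dict[str, list[int]]:
    """Record each module's output shape with forward hooks."""
    shapes: dict[str, list[int]] = {}
    handles = []
    for name, module in model.named_modules():
        key = name or type(module).__name__  # the root module's name is ""

        def hook(mod, inputs, output, key=key):
            shapes[key] = resolve_output_shape(mod, output)

        handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(x)
    finally:
        # Always detach the hooks, even if shape resolution fails.
        for handle in handles:
            handle.remove()
    return shapes
```

Without the `_output_shape` attribute, the root module's hook would raise `TypeError` on the `Categorical` output, which loosely mirrors the failure described in this issue.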