NotImplementedError: mean is not implemented: Categorical #685

@nbro

I get the following error

NotImplementedError: mean is not implemented: Categorical

if I execute the following code

import tensorflow_probability as tfp

prob_dist = tfp.distributions.Categorical(probs=[1.0])
print(prob_dist.mean())

I get a similar error if I call stddev(). I am using TFP 0.8.0 with TF 2.

The mean and standard deviation are not usually defined for a categorical distribution, since its outcomes are category labels rather than numeric values. The error occurs because tfp.distributions.Categorical inherits from tfp.distributions.Distribution, which defines the mean and stddev methods. PyCharm even offers these two methods in auto-completion for tfp.distributions.Categorical, which is misleading at best. Maybe tfp.distributions.Categorical should not inherit from tfp.distributions.Distribution, which should not have those methods in the first place, because not all distributions have a well-defined mean or standard deviation.
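
To make the inheritance point concrete, here is a minimal pure-Python sketch of the pattern I believe is at play. It is not TFP's source code: the classes below are hypothetical stand-ins, and expected_index and try_mean are names I made up for illustration. The base class exposes a public mean() that delegates to a hook which raises by default, so every subclass appears to have the method even when it cannot compute it.

```python
class Distribution:
    """Hypothetical, simplified stand-in for a distribution base class."""

    def mean(self):
        # The public method exists on every subclass, so IDEs auto-complete it.
        return self._mean()

    def _mean(self):
        # Default hook: a subclass that does not override it fails only at call time.
        raise NotImplementedError(
            "mean is not implemented: {}".format(type(self).__name__))


class Categorical(Distribution):
    """Toy categorical distribution over the indices 0..K-1 (not TFP's class)."""

    def __init__(self, probs):
        self.probs = list(probs)

    def expected_index(self):
        # Only meaningful if the integer labels are treated as numeric values.
        return sum(i * p for i, p in enumerate(self.probs))


def try_mean(dist):
    """Return dist.mean(), or the error message if mean is not implemented."""
    try:
        return dist.mean()
    except NotImplementedError as err:
        return str(err)


print(try_mean(Categorical([1.0])))
print(Categorical([0.5, 0.5]).expected_index())
```

With this design the failure surfaces only when mean() is called, which matches what I see. If the class indices really are meant as numbers, a probability-weighted sum like expected_index above is one possible workaround, but that is my assumption about the use case, not something the library provides under that name.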
