Description
I get the following error:
NotImplementedError: mean is not implemented: Categorical
when I execute the following code:
import tensorflow_probability as tfp
prob_dist = tfp.distributions.Categorical(probs=[1.0])
# Raises NotImplementedError: mean is not implemented: Categorical
print(prob_dist.mean())
I get a similar error if I call stddev(). I am using TFP 0.8.0 with TF 2.
The mean and standard deviation are not usually defined for a categorical distribution, because its outcomes are class labels rather than values on a numeric scale. The error arises because tfp.distributions.Categorical inherits from tfp.distributions.Distribution, which defines the mean and stddev methods. PyCharm even suggests auto-completing these two methods for tfp.distributions.Categorical, which is misleading at best.
Maybe tfp.distributions.Categorical should not inherit from tfp.distributions.Distribution; arguably, tfp.distributions.Distribution should not have those methods in the first place, because not all distributions have a well-defined mean or standard deviation.
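A minimal plain-Python sketch of the inheritance problem and of the alternative suggested above. This is not TFP's source code: all class names here (`Distribution`, `Categorical`, `BaseDistribution`, `MomentDistribution`, `Normal`, `CategoricalAlt`) are simplified, hypothetical stand-ins. The first half assumes, roughly, that the base class's public `mean()` delegates to a private hook whose default raises `NotImplementedError` with the message format seen in the error above. The second half puts `mean()` and `stddev()` only on a subclass meant for distributions with numeric outcomes.

```python
from abc import ABC, abstractmethod
from typing import Sequence


# Current pattern (simplified, hypothetical): every subclass inherits mean(),
# and the default hook raises unless the subclass overrides it.
class Distribution:
    def mean(self) -> float:
        return self._mean()

    def _mean(self) -> float:
        raise NotImplementedError(
            "mean is not implemented: {}".format(type(self).__name__))


class Categorical(Distribution):
    # Does not override _mean, so mean() is visible (and auto-completed) but always raises.
    def __init__(self, probs: Sequence[float]) -> None:
        self.probs = list(probs)


# Alternative (hypothetical): moments exist only on a subclass for
# distributions whose outcomes are numbers.
class BaseDistribution:
    """Only what every distribution supports; no mean() or stddev()."""


class MomentDistribution(BaseDistribution, ABC):
    @abstractmethod
    def mean(self) -> float: ...

    @abstractmethod
    def stddev(self) -> float: ...


class Normal(MomentDistribution):
    def __init__(self, loc: float, scale: float) -> None:
        self.loc = loc
        self.scale = scale

    def mean(self) -> float:
        return self.loc

    def stddev(self) -> float:
        return self.scale


class CategoricalAlt(BaseDistribution):
    # No mean() or stddev() at all, so an IDE has nothing to suggest.
    def __init__(self, probs: Sequence[float]) -> None:
        self.probs = list(probs)


if __name__ == "__main__":
    try:
        Categorical(probs=[1.0]).mean()
    except NotImplementedError as err:
        print(err)  # mean is not implemented: Categorical
    print(hasattr(CategoricalAlt([1.0]), "mean"))  # False
    print(Normal(0.0, 1.0).mean())  # 0.0
```

With the second design, calling mean() on CategoricalAlt fails with AttributeError and static tools have no method to offer; the trade-off is a split class hierarchy that code accepting "any distribution" has to account for.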