paddle.distribution.Categorical samples with the wrong probabilities #54733

Open
@loulouzny

Describe the Bug

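In paddle.distribution.Categorical, sample() draws from the softmax of the input, but probs() does not apply softmax; it uses the input divided by its sum instead. The two methods therefore describe different distributions. Is this a mistake?
Paddle source: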

class Categorical(distribution.Distribution):
    def __init__(self, logits, name=None):
        ...
        dist_sum = paddle.sum(self.logits, axis=-1, keepdim=True)
        self._prob = self.logits / dist_sum    # equivalent to `probs` in torch's Categorical

    def _logits_to_probs(self, logits, is_binary=False):
        return (
            paddle.nn.functional.sigmoid(logits)
            if is_binary
            else paddle.nn.functional.softmax(logits, axis=-1)
        )

    def sample(self, shape):
        ...
        sample_index = multinomial(self._logits_to_probs(logits), num_samples, True)  # but here the input is treated like `logits` in torch's Categorical: softmax is applied

    def probs(self, value):
        ...
        return paddle.gather(self._prob,
                             value.reshape([-1], name=name),
                             name=name).reshape(value.shape, name=name)
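
To make the mismatch concrete, here is a minimal sketch in plain Python (no Paddle required) that applies both normalizations from the excerpt above to the same input. The helper names `sum_normalize` and `softmax` and the input `[1.0, 2.0, 3.0]` are illustrative assumptions, not part of Paddle; they only mirror the arithmetic of `self.logits / dist_sum` and `_logits_to_probs` as quoted.

```python
import math


def sum_normalize(weights):
    """Divide each entry by the sum, mirroring `self.logits / dist_sum`."""
    total = sum(weights)
    return [w / total for w in weights]


def softmax(weights):
    """Numerically stable softmax, mirroring the non-binary branch of `_logits_to_probs`."""
    shift = max(weights)
    exps = [math.exp(w - shift) for w in weights]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical input passed as `logits` to Categorical.
logits = [1.0, 2.0, 3.0]

p_reported = sum_normalize(logits)  # what probs() would read from self._prob
p_sampled = softmax(logits)         # what sample() would hand to multinomial

print("sum-normalized:", p_reported)
print("softmax:       ", p_sampled)
```

For this input the sum-normalized values are 1/6, 1/3, and 1/2, while the softmax values are roughly 0.090, 0.245, and 0.665. If the excerpt reflects the actual code paths, the empirical frequencies from sample() would not match the values returned by probs().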

Additional Supplementary Information

No response
