
rhat_nested returns nan when the number of draws is one (or fewer than four) #354

@pipme

```python
import numpy as np
from arviz_stats.base import array_stats

rng = np.random.default_rng()
samples = rng.normal(size=(4, 1, 2))
axis = {"chain_axis": 0, "draw_axis": 1}
print(array_stats.rhat_nested(samples, (0, 0, 1, 1), **axis))
```

This outputs:

```
INFO:arviz_stats.base.stats_utils:Shape validation failed: input_shape: (4, 1), minimum_shape: (chains=2, draws=4)
INFO:arviz_stats.base.stats_utils:Shape validation failed: input_shape: (4, 1), minimum_shape: (chains=2, draws=4)
[nan nan]
```

According to the original nested R-hat paper (https://arxiv.org/abs/2110.13017), which targets the setting of many short chains, there should be no such restriction on the minimum number of draws per chain. Is there a reason for this restriction in arviz_stats?
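For reference, here is a minimal NumPy sketch of the plain (not rank-normalized) nested R-hat, written from my reading of the paper's definitions. It illustrates that nothing in the formula itself needs four draws per chain. `nested_rhat_sketch` is a hypothetical helper, not the arviz_stats API. It assumes every superchain holds the same number of chains (at least two), and it sets the within-chain variance to zero when each chain has a single draw, which is how I read the paper's treatment of that case.

```python
import numpy as np


def nested_rhat_sketch(samples, superchain_ids):
    """Plain nested R-hat, R = sqrt(1 + B_nu / W_nu), per arXiv:2110.13017.

    Hypothetical helper for illustration, not part of arviz_stats.
    samples: array of shape (chains, draws, ...); superchain_ids: one id per chain.
    Assumes all superchains contain the same number of chains (M >= 2).
    """
    samples = np.asarray(samples, dtype=float)
    ids = np.asarray(superchain_ids)
    n_draws = samples.shape[1]

    # Group chains by superchain: shape (K, M, N, ...)
    grouped = np.stack([samples[ids == k] for k in np.unique(ids)])

    chain_means = grouped.mean(axis=2)      # (K, M, ...)
    super_means = chain_means.mean(axis=1)  # (K, ...)

    # Between-superchain variance
    b_nu = super_means.var(axis=0, ddof=1)

    # Between-chain variance inside each superchain
    b_tilde = chain_means.var(axis=1, ddof=1)  # (K, ...)

    # Within-chain variance, taken as 0 when each chain has a single draw
    if n_draws > 1:
        w_tilde = grouped.var(axis=2, ddof=1).mean(axis=1)
    else:
        w_tilde = np.zeros_like(b_tilde)

    # Within-superchain variance, averaged over superchains
    w_nu = (b_tilde + w_tilde).mean(axis=0)
    return np.sqrt(1.0 + b_nu / w_nu)


rng = np.random.default_rng()
samples = rng.normal(size=(4, 1, 2))  # same shape as in the example above
print(nested_rhat_sketch(samples, (0, 0, 1, 1)))
```

With 4 chains, 1 draw and 2 superchains, this returns two finite values (each at least 1) instead of `[nan nan]`; the between-chain term inside each superchain keeps the denominator positive even with a single draw per chain.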

Thanks!
