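import numpy as np
from arviz_stats.base import array_stats
rng = np.random.default_rng()
# shape (chains=4, draws=1, variables=2): four chains with a single draw each
samples = rng.normal(size=(4, 1, 2))
axis = {"chain_axis": 0, "draw_axis": 1}
# superchain ids: chains 0 and 1 form superchain 0, chains 2 and 3 form superchain 1
print(array_stats.rhat_nested(samples, (0, 0, 1, 1), **axis))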
This outputs:
INFO:arviz_stats.base.stats_utils:Shape validation failed: input_shape: (4, 1), minimum_shape: (chains=2, draws=4)
INFO:arviz_stats.base.stats_utils:Shape validation failed: input_shape: (4, 1), minimum_shape: (chains=2, draws=4)
[nan nan]
According to the original paper on nested R-hat (https://arxiv.org/abs/2110.13017), there should be no such restriction on the minimum number of draws. Is there a reason for this restriction in arviz_stats?
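For reference, here is a minimal NumPy sketch of the nested R-hat estimator as I read it from the paper. It is not the arviz_stats implementation, and `nested_rhat_sketch` is a hypothetical helper. It assumes equal numbers of chains per superchain and draws per chain, and it uses the paper's convention (as I understand it) that the within-chain variance is taken as 0 when there is only one draw per chain. Under those assumptions the estimate stays finite for the 4 chains x 1 draw case above.

```python
import numpy as np


def nested_rhat_sketch(x, superchain_ids):
    """Hypothetical helper: nested R-hat for one scalar quantity.

    x has shape (chains, draws); superchain_ids gives one label per chain.
    Assumes equal chains per superchain and equal draws per chain.
    """
    x = np.asarray(x, dtype=float)
    ids = np.asarray(superchain_ids)
    n_draws = x.shape[1]
    chain_means = x.mean(axis=1)
    if n_draws > 1:
        chain_vars = x.var(axis=1, ddof=1)
    else:
        # Assumed convention: within-chain variance is 0 when N = 1
        chain_vars = np.zeros(x.shape[0])
    super_means, w_terms = [], []
    for k in np.unique(ids):
        members = ids == k
        means_k = chain_means[members]
        # Between-chain variance inside superchain k (0 if it has one chain)
        b_k = means_k.var(ddof=1) if means_k.size > 1 else 0.0
        # Mean within-chain variance inside superchain k
        w_k = chain_vars[members].mean()
        super_means.append(means_k.mean())
        w_terms.append(b_k + w_k)
    b_nu = np.var(super_means, ddof=1)  # between-superchain variance
    w_nu = np.mean(w_terms)  # average within-superchain variance
    return np.sqrt(1.0 + b_nu / w_nu)


# Same setup as above: 4 chains, 1 draw, 2 variables
rng = np.random.default_rng()
samples = rng.normal(size=(4, 1, 2))
print([nested_rhat_sketch(samples[:, :, v], (0, 0, 1, 1)) for v in range(2)])
```

With a single draw per chain, the within-superchain variance reduces to the spread of chain means inside each superchain, so the ratio is still defined as long as there are at least two chains per superchain and at least two superchains.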
Thanks!