Skip to content

Commit 55fcbbc

Browse files
committed
Fix total_count dtype in binomial-derived distributions
1 parent 606a2dd commit 55fcbbc

2 files changed

Lines changed: 12 additions & 12 deletions

File tree

numpyro/distributions/conjugate.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(
5656
self,
5757
concentration1: ArrayLike,
5858
concentration0: ArrayLike,
59-
total_count: int = 1,
59+
total_count: ArrayLike = 1,
6060
*,
6161
validate_args: Optional[bool] = None,
6262
):
@@ -263,7 +263,7 @@ class DirichletMultinomial(Distribution):
263263
def __init__(
264264
self,
265265
concentration: ArrayLike,
266-
total_count: int = 1,
266+
total_count: ArrayLike = 1,
267267
*,
268268
total_count_max: Optional[int] = None,
269269
validate_args: Optional[bool] = None,
@@ -431,7 +431,7 @@ def cdf(self, value: ArrayLike) -> ArrayLike:
431431

432432

433433
def NegativeBinomial(
434-
total_count: int,
434+
total_count: ArrayLike,
435435
probs: Optional[ArrayLike] = None,
436436
logits: Optional[ArrayLike] = None,
437437
*,
@@ -473,7 +473,7 @@ class NegativeBinomialProbs(GammaPoisson):
473473

474474
def __init__(
475475
self,
476-
total_count: int,
476+
total_count: ArrayLike,
477477
probs: ArrayLike,
478478
*,
479479
validate_args: Optional[bool] = None,
@@ -502,7 +502,7 @@ class NegativeBinomialLogits(GammaPoisson):
502502

503503
def __init__(
504504
self,
505-
total_count: int,
505+
total_count: ArrayLike,
506506
logits: ArrayLike,
507507
*,
508508
validate_args: Optional[bool] = None,

numpyro/distributions/discrete.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ class BinomialProbs(Distribution):
188188
def __init__(
189189
self,
190190
probs: ArrayLike,
191-
total_count: int = 1,
191+
total_count: ArrayLike = 1,
192192
*,
193193
validate_args: Optional[bool] = None,
194194
):
@@ -266,7 +266,7 @@ class BinomialLogits(Distribution):
266266
def __init__(
267267
self,
268268
logits: ArrayLike,
269-
total_count: int = 1,
269+
total_count: ArrayLike = 1,
270270
*,
271271
validate_args: Optional[bool] = None,
272272
):
@@ -317,7 +317,7 @@ def support(self) -> constraints.Constraint:
317317

318318

319319
def Binomial(
320-
total_count: int = 1,
320+
total_count: ArrayLike = 1,
321321
probs: Optional[ArrayLike] = None,
322322
logits: Optional[ArrayLike] = None,
323323
*,
@@ -575,7 +575,7 @@ class MultinomialProbs(Distribution):
575575
def __init__(
576576
self,
577577
probs: Array,
578-
total_count: int = 1,
578+
total_count: ArrayLike = 1,
579579
*,
580580
total_count_max: Optional[int] = None,
581581
validate_args: Optional[bool] = None,
@@ -629,7 +629,7 @@ def support(self) -> constraints.Constraint:
629629

630630
@staticmethod
631631
def infer_shapes(
632-
probs: Array, total_count: int
632+
probs: Array, total_count: ArrayLike
633633
) -> tuple[tuple[int, ...], tuple[int, ...]]:
634634
batch_shape = lax.broadcast_shapes(probs[:-1], total_count)
635635
event_shape = probs[-1:]
@@ -647,7 +647,7 @@ class MultinomialLogits(Distribution):
647647
def __init__(
648648
self,
649649
logits: Array,
650-
total_count: int = 1,
650+
total_count: ArrayLike = 1,
651651
*,
652652
total_count_max: Optional[int] = None,
653653
validate_args: Optional[bool] = None,
@@ -707,7 +707,7 @@ def support(self) -> constraints.Constraint:
707707

708708
@staticmethod
709709
def infer_shapes(
710-
logits: Array, total_count: int
710+
logits: Array, total_count: ArrayLike
711711
) -> tuple[tuple[int, ...], tuple[int, ...]]:
712712
batch_shape = lax.broadcast_shapes(logits[:-1], total_count)
713713
event_shape = logits[-1:]

0 commit comments

Comments
 (0)