@@ -188,7 +188,7 @@ class BinomialProbs(Distribution):
188188 def __init__ (
189189 self ,
190190 probs : ArrayLike ,
191- total_count : int = 1 ,
191+ total_count : ArrayLike = 1 ,
192192 * ,
193193 validate_args : Optional [bool ] = None ,
194194 ):
@@ -266,7 +266,7 @@ class BinomialLogits(Distribution):
266266 def __init__ (
267267 self ,
268268 logits : ArrayLike ,
269- total_count : int = 1 ,
269+ total_count : ArrayLike = 1 ,
270270 * ,
271271 validate_args : Optional [bool ] = None ,
272272 ):
@@ -317,7 +317,7 @@ def support(self) -> constraints.Constraint:
317317
318318
319319def Binomial (
320- total_count : int = 1 ,
320+ total_count : ArrayLike = 1 ,
321321 probs : Optional [ArrayLike ] = None ,
322322 logits : Optional [ArrayLike ] = None ,
323323 * ,
@@ -575,7 +575,7 @@ class MultinomialProbs(Distribution):
575575 def __init__ (
576576 self ,
577577 probs : Array ,
578- total_count : int = 1 ,
578+ total_count : ArrayLike = 1 ,
579579 * ,
580580 total_count_max : Optional [int ] = None ,
581581 validate_args : Optional [bool ] = None ,
@@ -629,7 +629,7 @@ def support(self) -> constraints.Constraint:
629629
630630 @staticmethod
631631 def infer_shapes (
632- probs : Array , total_count : int
632+ probs : Array , total_count : ArrayLike
633633 ) -> tuple [tuple [int , ...], tuple [int , ...]]:
634634 batch_shape = lax .broadcast_shapes (probs [:- 1 ], total_count )
635635 event_shape = probs [- 1 :]
@@ -647,7 +647,7 @@ class MultinomialLogits(Distribution):
647647 def __init__ (
648648 self ,
649649 logits : Array ,
650- total_count : int = 1 ,
650+ total_count : ArrayLike = 1 ,
651651 * ,
652652 total_count_max : Optional [int ] = None ,
653653 validate_args : Optional [bool ] = None ,
@@ -707,7 +707,7 @@ def support(self) -> constraints.Constraint:
707707
708708 @staticmethod
709709 def infer_shapes (
710- logits : Array , total_count : int
710+ logits : Array , total_count : ArrayLike
711711 ) -> tuple [tuple [int , ...], tuple [int , ...]]:
712712 batch_shape = lax .broadcast_shapes (logits [:- 1 ], total_count )
713713 event_shape = logits [- 1 :]
0 commit comments