@@ -91,6 +91,8 @@ def seeded_model(data):
9191 -874.89813
9292"""
9393
94+ from __future__ import annotations
95+
9496from collections import OrderedDict
9597from types import TracebackType
9698from typing import Callable , Optional , Union
@@ -103,12 +105,12 @@ def seeded_model(data):
103105from jax .typing import ArrayLike
104106
105107import numpyro
108+ from numpyro ._typing import Message , TraceT
106109from numpyro .distributions .distribution import COERCIONS
107110from numpyro .primitives import (
108111 _PYRO_STACK ,
109112 CondIndepStackFrame ,
110113 DistributionLike ,
111- Message ,
112114 Messenger ,
113115 apply_stack ,
114116 plate ,
@@ -163,9 +165,9 @@ class trace(Messenger):
163165 'value': Array(-0.20584235, dtype=float32)})])
164166 """
165167
166- def __enter__ (self ) -> OrderedDict [ str , Message ] : # type: ignore [override]
168+ def __enter__ (self ) -> TraceT : # type: ignore [override]
167169 super (trace , self ).__enter__ ()
168- self .trace : OrderedDict [ str , Message ] = OrderedDict ()
170+ self .trace : TraceT = OrderedDict ()
169171 return self .trace
170172
171173 def postprocess_message (self , msg : Message ) -> None :
@@ -180,7 +182,7 @@ def postprocess_message(self, msg: Message) -> None:
180182 )
181183 self .trace [msg ["name" ]] = msg .copy ()
182184
183- def get_trace (self , * args , ** kwargs ) -> OrderedDict [ str , Message ] :
185+ def get_trace (self , * args , ** kwargs ) -> TraceT :
184186 """
185187 Run the wrapped callable and return the recorded trace.
186188
@@ -225,7 +227,7 @@ class replay(Messenger):
225227 def __init__ (
226228 self ,
227229 fn : Optional [Callable ] = None ,
228- trace : Optional [OrderedDict [ str , Message ] ] = None ,
230+ trace : Optional [TraceT ] = None ,
229231 ) -> None :
230232 assert trace is not None
231233 self .trace = trace
@@ -357,7 +359,7 @@ def process_message(self, msg: Message) -> None:
357359 if isinstance (msg ["fn" ], Funsor ) or isinstance (msg ["value" ], (str , Funsor )):
358360 msg ["stop" ] = True
359361
360- def __enter__ (self ) -> OrderedDict [ str , Message ] : # type: ignore [override]
362+ def __enter__ (self ) -> TraceT : # type: ignore [override]
361363 self .preserved_plates = frozenset (
362364 h .name for h in _PYRO_STACK if isinstance (h , plate )
363365 )
@@ -451,7 +453,7 @@ def __init__(
451453 raise ValueError ("Only one of `data` or `condition_fn` should be provided." )
452454 super (condition , self ).__init__ (fn )
453455
454- def process_message (self , msg ) :
456+ def process_message (self , msg : Message ) -> None :
455457 if (msg ["type" ] != "sample" ) or msg .get ("_control_flow_done" , False ):
456458 if msg ["type" ] == "control_flow" :
457459 if self .data is not None :
@@ -465,6 +467,7 @@ def process_message(self, msg):
465467 if self .data is not None :
466468 value = self .data .get (msg ["name" ])
467469 else :
470+ assert self .condition_fn is not None
468471 value = self .condition_fn (msg )
469472
470473 if value is not None :
@@ -804,9 +807,9 @@ class seed(Messenger):
804807
805808 def __init__ (
806809 self ,
807- fn : Optional [ Callable ] = None ,
808- rng_seed : Optional [ Array ] = None ,
809- hide_types : Optional [ list [str ]] = None ,
810+ fn : Callable | None = None ,
811+ rng_seed : Array | int | None = None ,
812+ hide_types : list [str ] | None = None ,
810813 ) -> None :
811814 if rng_seed is not None :
812815 if not is_prng_key (rng_seed ) and (
0 commit comments