Skip to content

Commit c7f522b

Browse files
refactor(infer.elbo): add type hints to elbo module (#2028)
* initial exploration of typing for infer.elbo * continue adding types, cleanup code dead ends * add _typing module, add ANN rule, finish elbo module * refactor(infer.elbo): bind paramspecs to methods, add type aliases * refactor(infer.elbo): ParticleT -> LossT and update trace elbo annotations * fix: import TypeAlias from typing_extensions * exclude python 3.9 from lint/typecheck, MessageT -> Message
1 parent 17cf5a7 commit c7f522b

13 files changed

Lines changed: 291 additions & 694 deletions

File tree

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ jobs:
3636
pip install -r docs/requirements.txt
3737
pip freeze
3838
- name: Lint with mypy and ruff
39+
if: matrix.python-version != '3.9'
3940
run: |
4041
make lint
4142
- name: Build documentation

numpyro/_typing.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright Contributors to the Pyro project.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from collections import OrderedDict
5+
from collections.abc import Callable
6+
from typing import Any
7+
8+
from typing_extensions import ParamSpec, TypeAlias
9+
10+
P = ParamSpec("P")
11+
ModelT: TypeAlias = Callable[P, Any]
12+
13+
Message: TypeAlias = dict[str, Any]
14+
TraceT: TypeAlias = OrderedDict[str, Message]

numpyro/contrib/stochastic_support/dcc.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import jax
99
from jax import random
1010
import jax.numpy as jnp
11-
from jax.typing import ArrayLike
1211

1312
import numpyro.distributions as dist
1413
from numpyro.handlers import condition, seed, trace
@@ -61,7 +60,7 @@ def __init__(self, model: Callable, num_slp_samples: int, max_slps: int) -> None
6160
self.max_slps: int = max_slps
6261

6362
def _find_slps(
64-
self, rng_key: ArrayLike, *args: Any, **kwargs: Any
63+
self, rng_key: jax.Array, *args: Any, **kwargs: Any
6564
) -> dict[str, OrderedDictType]:
6665
"""
6766
Discover the straight-line programs (SLPs) in the model by sampling from the prior.
@@ -109,7 +108,7 @@ def _get_branching_trace(self, tr: dict[str, Any]) -> OrderedDictType:
109108
@abstractmethod
110109
def _run_inference(
111110
self,
112-
rng_key: ArrayLike,
111+
rng_key: jax.Array,
113112
branching_trace: OrderedDictType,
114113
*args: Any,
115114
**kwargs: Any,
@@ -119,7 +118,7 @@ def _run_inference(
119118
@abstractmethod
120119
def _combine_inferences(
121120
self,
122-
rng_key: ArrayLike,
121+
rng_key: jax.Array,
123122
inferences: dict[str, Any],
124123
branching_traces: dict[str, OrderedDictType],
125124
*args: Any,
@@ -128,7 +127,7 @@ def _combine_inferences(
128127
raise NotImplementedError
129128

130129
def run(
131-
self, rng_key: ArrayLike, *args: Any, **kwargs: Any
130+
self, rng_key: jax.Array, *args: Any, **kwargs: Any
132131
) -> Union[DCCResult, SDVIResult]:
133132
"""
134133
Run inference on each SLP separately and combine the results.
@@ -209,7 +208,7 @@ def __init__(
209208

210209
def _run_inference(
211210
self,
212-
rng_key: ArrayLike,
211+
rng_key: jax.Array,
213212
branching_trace: OrderedDictType,
214213
*args: Any,
215214
**kwargs: Any,
@@ -226,7 +225,7 @@ def _run_inference(
226225

227226
def _combine_inferences( # type: ignore[override]
228227
self,
229-
rng_key: ArrayLike,
228+
rng_key: jax.Array,
230229
samples: dict[str, Any],
231230
branching_traces: dict[str, OrderedDictType],
232231
*args: Any,
@@ -244,7 +243,7 @@ def _combine_inferences( # type: ignore[override]
244243
"""
245244

246245
def log_weight(
247-
rng_key: ArrayLike,
246+
rng_key: jax.Array,
248247
i: int,
249248
slp_model: Callable,
250249
slp_samples: dict[str, Any],

numpyro/contrib/stochastic_support/sdvi.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import jax
77
import jax.numpy as jnp
8-
from jax.typing import ArrayLike
98

109
from numpyro.contrib.stochastic_support.dcc import (
1110
RunInferenceResult,
@@ -98,7 +97,7 @@ def __init__(
9897

9998
def _run_inference(
10099
self,
101-
rng_key: ArrayLike,
100+
rng_key: jax.Array,
102101
branching_trace: OrderedDictType,
103102
*args: Any,
104103
**kwargs: Any,
@@ -120,17 +119,17 @@ def _run_inference(
120119

121120
def _combine_inferences( # type: ignore[override]
122121
self,
123-
rng_key: ArrayLike,
122+
rng_key: jax.Array,
124123
guides: dict[str, tuple[AutoGuide, dict[str, Any]]],
125124
branching_traces: dict[str, OrderedDictType],
126125
*args: Any,
127126
**kwargs: Any,
128127
) -> SDVIResult:
129128
"""Weight each SLP proportional to its estimated ELBO."""
130-
elbos = {}
129+
elbos: dict[str, jax.Array] = {}
131130
for bt, (guide, param_map) in guides.items():
132131
slp_model = condition(self.model, branching_traces[bt])
133-
elbos[bt] = -Trace_ELBO(num_particles=self.combine_elbo_particles).loss(
132+
elbos[bt] = -Trace_ELBO(num_particles=self.combine_elbo_particles).loss( # type: ignore
134133
rng_key, param_map, slp_model, guide, *args, **kwargs
135134
)
136135

numpyro/handlers.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ def seeded_model(data):
9191
-874.89813
9292
"""
9393

94+
from __future__ import annotations
95+
9496
from collections import OrderedDict
9597
from types import TracebackType
9698
from typing import Callable, Optional, Union
@@ -103,12 +105,12 @@ def seeded_model(data):
103105
from jax.typing import ArrayLike
104106

105107
import numpyro
108+
from numpyro._typing import Message, TraceT
106109
from numpyro.distributions.distribution import COERCIONS
107110
from numpyro.primitives import (
108111
_PYRO_STACK,
109112
CondIndepStackFrame,
110113
DistributionLike,
111-
Message,
112114
Messenger,
113115
apply_stack,
114116
plate,
@@ -163,9 +165,9 @@ class trace(Messenger):
163165
'value': Array(-0.20584235, dtype=float32)})])
164166
"""
165167

166-
def __enter__(self) -> OrderedDict[str, Message]: # type: ignore [override]
168+
def __enter__(self) -> TraceT: # type: ignore [override]
167169
super(trace, self).__enter__()
168-
self.trace: OrderedDict[str, Message] = OrderedDict()
170+
self.trace: TraceT = OrderedDict()
169171
return self.trace
170172

171173
def postprocess_message(self, msg: Message) -> None:
@@ -180,7 +182,7 @@ def postprocess_message(self, msg: Message) -> None:
180182
)
181183
self.trace[msg["name"]] = msg.copy()
182184

183-
def get_trace(self, *args, **kwargs) -> OrderedDict[str, Message]:
185+
def get_trace(self, *args, **kwargs) -> TraceT:
184186
"""
185187
Run the wrapped callable and return the recorded trace.
186188
@@ -225,7 +227,7 @@ class replay(Messenger):
225227
def __init__(
226228
self,
227229
fn: Optional[Callable] = None,
228-
trace: Optional[OrderedDict[str, Message]] = None,
230+
trace: Optional[TraceT] = None,
229231
) -> None:
230232
assert trace is not None
231233
self.trace = trace
@@ -357,7 +359,7 @@ def process_message(self, msg: Message) -> None:
357359
if isinstance(msg["fn"], Funsor) or isinstance(msg["value"], (str, Funsor)):
358360
msg["stop"] = True
359361

360-
def __enter__(self) -> OrderedDict[str, Message]: # type: ignore [override]
362+
def __enter__(self) -> TraceT: # type: ignore [override]
361363
self.preserved_plates = frozenset(
362364
h.name for h in _PYRO_STACK if isinstance(h, plate)
363365
)
@@ -451,7 +453,7 @@ def __init__(
451453
raise ValueError("Only one of `data` or `condition_fn` should be provided.")
452454
super(condition, self).__init__(fn)
453455

454-
def process_message(self, msg):
456+
def process_message(self, msg: Message) -> None:
455457
if (msg["type"] != "sample") or msg.get("_control_flow_done", False):
456458
if msg["type"] == "control_flow":
457459
if self.data is not None:
@@ -465,6 +467,7 @@ def process_message(self, msg):
465467
if self.data is not None:
466468
value = self.data.get(msg["name"])
467469
else:
470+
assert self.condition_fn is not None
468471
value = self.condition_fn(msg)
469472

470473
if value is not None:
@@ -804,9 +807,9 @@ class seed(Messenger):
804807

805808
def __init__(
806809
self,
807-
fn: Optional[Callable] = None,
808-
rng_seed: Optional[Array] = None,
809-
hide_types: Optional[list[str]] = None,
810+
fn: Callable | None = None,
811+
rng_seed: Array | int | None = None,
812+
hide_types: list[str] | None = None,
810813
) -> None:
811814
if rng_seed is not None:
812815
if not is_prng_key(rng_seed) and (

0 commit comments

Comments
 (0)