Skip to content

Commit f7f2c39

Browse files
authored
docs: use Literal for strict typing of z-scoring options in Public API (#1744)
* refactor: use Literal for strict z-score typing * fix the parser function * revert internal modules to str to avoid type purging issues * revert type hint to str to fix CI
1 parent 723c2d5 commit f7f2c39

File tree

2 files changed

+43
-15
lines changed

2 files changed

+43
-15
lines changed

sbi/neural_nets/factory.py

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,12 @@ class ZukoFlowType(Enum):
7676

7777
def classifier_nn(
7878
model: str,
79-
z_score_theta: Optional[str] = "independent",
80-
z_score_x: Optional[str] = "independent",
79+
z_score_theta: Optional[
80+
Literal["independent", "structured", "transform_to_unconstrained", "none"]
81+
] = "independent",
82+
z_score_x: Optional[
83+
Literal["independent", "structured", "transform_to_unconstrained", "none"]
84+
] = "independent",
8185
hidden_features: int = 50,
8286
embedding_net_theta: nn.Module = nn.Identity(),
8387
embedding_net_x: nn.Module = nn.Identity(),
@@ -151,8 +155,12 @@ def build_fn(batch_theta, batch_x):
151155

152156
def likelihood_nn(
153157
model: str,
154-
z_score_theta: Optional[str] = "independent",
155-
z_score_x: Optional[str] = "independent",
158+
z_score_theta: Optional[
159+
Literal["independent", "structured", "transform_to_unconstrained", "none"]
160+
] = "independent",
161+
z_score_x: Optional[
162+
Literal["independent", "structured", "transform_to_unconstrained", "none"]
163+
] = "independent",
156164
hidden_features: int = 50,
157165
num_transforms: int = 5,
158166
num_bins: int = 10,
@@ -226,8 +234,12 @@ def build_fn(batch_theta, batch_x):
226234

227235
def posterior_nn(
228236
model: str,
229-
z_score_theta: Optional[str] = "independent",
230-
z_score_x: Optional[str] = "independent",
237+
z_score_theta: Optional[
238+
Literal["independent", "structured", "transform_to_unconstrained", "none"]
239+
] = "independent",
240+
z_score_x: Optional[
241+
Literal["independent", "structured", "transform_to_unconstrained", "none"]
242+
] = "independent",
231243
hidden_features: int = 50,
232244
num_transforms: int = 5,
233245
num_bins: int = 10,
@@ -334,8 +346,12 @@ def posterior_score_nn(
334346
VectorFieldNet,
335347
] = "mlp",
336348
sde_type: str = "ve",
337-
z_score_theta: Optional[str] = "independent",
338-
z_score_x: Optional[str] = "independent",
349+
z_score_theta: Optional[
350+
Literal["independent", "structured", "transform_to_unconstrained", "none"]
351+
] = "independent",
352+
z_score_x: Optional[
353+
Literal["independent", "structured", "transform_to_unconstrained", "none"]
354+
] = "independent",
339355
hidden_features: int = 100,
340356
num_layers: int = 5,
341357
embedding_net: nn.Module = nn.Identity(),
@@ -436,8 +452,12 @@ def build_fn(batch_theta, batch_x):
436452
# TODO: remove this function on next release
437453
def flowmatching_nn(
438454
model: str,
439-
z_score_theta: Optional[str] = "independent",
440-
z_score_x: Optional[str] = "independent",
455+
z_score_theta: Optional[
456+
Literal["independent", "structured", "transform_to_unconstrained", "none"]
457+
] = "independent",
458+
z_score_x: Optional[
459+
Literal["independent", "structured", "transform_to_unconstrained", "none"]
460+
] = "independent",
441461
hidden_features: int = 64,
442462
num_layers: int = 5,
443463
num_blocks: int = 5,
@@ -510,8 +530,12 @@ def posterior_flow_nn(
510530
Literal["mlp", "ada_mlp", "transformer", "transformer_cross_attn"],
511531
VectorFieldNet,
512532
] = "mlp",
513-
z_score_theta: Optional[str] = None,
514-
z_score_x: Optional[str] = "independent",
533+
z_score_theta: Optional[
534+
Literal["independent", "structured", "transform_to_unconstrained", "none"]
535+
] = None,
536+
z_score_x: Optional[
537+
Literal["independent", "structured", "transform_to_unconstrained", "none"]
538+
] = "independent",
515539
hidden_features: int = 100,
516540
num_layers: int = 5,
517541
embedding_net: nn.Module = nn.Identity(),
@@ -592,7 +616,9 @@ def build_fn(batch_theta, batch_x):
592616

593617
def marginal_nn(
594618
model: ZukoFlowType,
595-
z_score_x: Optional[str] = "independent",
619+
z_score_x: Optional[
620+
Literal["independent", "structured", "transform_to_unconstrained", "none"]
621+
] = "independent",
596622
hidden_features: int = 50,
597623
num_transforms: int = 5,
598624
num_bins: int = 10,

sbi/utils/restriction_estimator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from copy import deepcopy
55
from math import floor
6-
from typing import Any, Callable, Optional, Tuple, Union
6+
from typing import Any, Callable, Literal, Optional, Tuple, Union
77

88
import torch
99
import torch.nn.functional as F
@@ -74,7 +74,9 @@ def __init__(
7474
hidden_features: int = 100,
7575
num_blocks: int = 2,
7676
dropout_probability: float = 0.5,
77-
z_score: Optional[str] = "independent",
77+
z_score: Optional[
78+
Literal["independent", "structured", "transform_to_unconstrained", "none"]
79+
] = "independent",
7880
embedding_net: nn.Module = nn.Identity(),
7981
) -> None:
8082
r"""

0 commit comments

Comments
 (0)