Skip to content

Commit 617d5a4

Browse files
committed
adapt types; update posterior parameters.
1 parent c53a83a commit 617d5a4

File tree

5 files changed

+165
-206
lines changed

5 files changed

+165
-206
lines changed

sbi/inference/posteriors/posterior_parameters.py

Lines changed: 44 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
Any,
88
Callable,
99
Dict,
10-
Iterable,
1110
Literal,
1211
Optional,
1312
Union,
@@ -17,7 +16,7 @@
1716
)
1817

1918
from sbi.inference.posteriors.vi_posterior import VIPosterior
20-
from sbi.sbi_types import PyroTransformedDistribution, TorchTransform
19+
from sbi.sbi_types import TorchTransform, VariationalDistribution
2120
from sbi.utils.typechecks import (
2221
is_nonnegative_int,
2322
is_positive_float,
@@ -334,61 +333,66 @@ def validate(self):
334333
@dataclass(frozen=True)
335334
class VIPosteriorParameters(PosteriorParameters):
336335
"""
337-
Parameters for initializing VIPosterior.
336+
Parameters for VIPosterior, supporting both single-x and amortized VI.
338337
339338
Fields:
340-
q: Variational distribution, either string, `TransformedDistribution`, or a
341-
`VIPosterior` object. This specifies a parametric class of distribution
342-
over which the best possible posterior approximation is searched. For
343-
string input, we currently support [nsf, scf, maf, mcf, gaussian,
344-
gaussian_diag]. You can also specify your own variational family by
345-
passing a pyro `TransformedDistribution`.
346-
Additionally, we allow a `Callable`, which allows you the pass a
347-
`builder` function, which if called returns a distribution. This may be
348-
useful for setting the hyperparameters e.g. `num_transfroms` within the
349-
`get_flow_builder` method specifying the number of transformations
350-
within a normalizing flow. If q is already a `VIPosterior`, then the
351-
arguments will be copied from it (relevant for multi-round training).
352-
vi_method: This specifies the variational methods which are used to fit q to
353-
the posterior. We currently support [rKL, fKL, IW, alpha]. Note that
354-
some of the divergences are `mode seeking` i.e. they underestimate
355-
variance and collapse on multimodal targets (`rKL`, `alpha` for alpha >
356-
1) and some are `mass covering` i.e. they overestimate variance but
357-
typically cover all modes (`fKL`, `IW`, `alpha` for alpha < 1).
358-
parameters: List of parameters of the variational posterior. This is only
359-
required for user-defined q i.e. if q does not have a `parameters`
360-
attribute.
361-
modules: List of modules of the variational posterior. This is only
362-
required for user-defined q i.e. if q does not have a `modules`
363-
attribute.
339+
q: Variational distribution. Either a string specifying the flow type
340+
[nsf, maf, naf, unaf, nice, sospf, gaussian, gaussian_diag], a
341+
`TransformedDistribution`, a `VIPosterior` object, or a `Callable`
342+
builder function. For amortized VI, use string flow types only.
343+
If q is already a `VIPosterior`, arguments are copied from it
344+
(relevant for multi-round training).
345+
vi_method: Variational method for fitting q to the posterior. Options:
346+
[rKL, fKL, IW, alpha]. Some are "mode seeking" (rKL, alpha > 1) and
347+
some are "mass covering" (fKL, IW, alpha < 1). Currently only used
348+
for single-x VI; amortized VI uses ELBO (rKL).
349+
num_transforms: Number of transforms in the normalizing flow.
350+
hidden_features: Hidden layer size in the flow networks.
351+
z_score_theta: Method for z-scoring θ (the parameters being modeled).
352+
One of "none", "independent", "structured". Use "structured" for
353+
parameters with correlations.
354+
z_score_x: Method for z-scoring x (the conditioning variable, amortized
355+
VI only). One of "none", "independent", "structured". Use
356+
"structured" for structured data like images.
357+
358+
Note:
359+
For custom distributions that lack `parameters()` and `modules()` methods,
360+
pass these via `VIPosterior.set_q(q, parameters=..., modules=...)` instead.
364361
"""
365362

366363
q: Union[
367-
Literal["nsf", "scf", "maf", "mcf", "gaussian", "gaussian_diag"],
368-
PyroTransformedDistribution,
364+
Literal[
365+
"nsf", "maf", "naf", "unaf", "nice", "sospf", "gaussian", "gaussian_diag"
366+
],
367+
VariationalDistribution,
369368
"VIPosterior",
370369
Callable,
371370
] = "maf"
372371
vi_method: Literal["rKL", "fKL", "IW", "alpha"] = "rKL"
373-
parameters: Optional[Iterable] = None
374-
modules: Optional[Iterable] = None
372+
num_transforms: int = 5
373+
hidden_features: int = 50
374+
z_score_theta: Literal["none", "independent", "structured"] = "independent"
375+
z_score_x: Literal["none", "independent", "structured"] = "independent"
375376

376377
def validate(self):
377378
"""Validate VIPosteriorParameters fields."""
378-
379-
valid_q = {"nsf", "scf", "maf", "mcf", "gaussian", "gaussian_diag"}
379+
valid_q = {
380+
"nsf", "maf", "naf", "unaf", "nice", "sospf", "gaussian", "gaussian_diag"
381+
}
380382

381383
if isinstance(self.q, str) and self.q not in valid_q:
382384
raise ValueError(f"If `q` is a string, it must be one of {valid_q}")
383385
elif not isinstance(
384-
self.q, (PyroTransformedDistribution, VIPosterior, Callable, str)
386+
self.q, (VariationalDistribution, VIPosterior, Callable, str)
385387
):
386388
raise TypeError(
387-
"q must be either of typr PyroTransformedDistribution,"
388-
" VIPosterioror or Callable"
389+
"q must be either of type VariationalDistribution,"
390+
" VIPosterior or Callable"
389391
)
390392

391-
if self.parameters is not None and not isinstance(self.parameters, Iterable):
392-
raise TypeError("parameters must be iterable or None.")
393-
if self.modules is not None and not isinstance(self.modules, Iterable):
394-
raise TypeError("modules must be iterable or None.")
393+
if self.num_transforms < 1:
394+
raise ValueError(f"num_transforms must be >= 1, got {self.num_transforms}")
395+
if self.hidden_features < 1:
396+
raise ValueError(
397+
f"hidden_features must be >= 1, got {self.hidden_features}"
398+
)

0 commit comments

Comments
 (0)