|
7 | 7 | Any, |
8 | 8 | Callable, |
9 | 9 | Dict, |
10 | | - Iterable, |
11 | 10 | Literal, |
12 | 11 | Optional, |
13 | 12 | Union, |
|
17 | 16 | ) |
18 | 17 |
|
19 | 18 | from sbi.inference.posteriors.vi_posterior import VIPosterior |
20 | | -from sbi.sbi_types import PyroTransformedDistribution, TorchTransform |
| 19 | +from sbi.sbi_types import TorchTransform, VariationalDistribution |
21 | 20 | from sbi.utils.typechecks import ( |
22 | 21 | is_nonnegative_int, |
23 | 22 | is_positive_float, |
@@ -334,61 +333,66 @@ def validate(self): |
334 | 333 | @dataclass(frozen=True) |
335 | 334 | class VIPosteriorParameters(PosteriorParameters): |
336 | 335 | """ |
337 | | - Parameters for initializing VIPosterior. |
| 336 | + Parameters for VIPosterior, supporting both single-x and amortized VI. |
338 | 337 |
|
339 | 338 | Fields: |
340 | | - q: Variational distribution, either string, `TransformedDistribution`, or a |
341 | | - `VIPosterior` object. This specifies a parametric class of distribution |
342 | | - over which the best possible posterior approximation is searched. For |
343 | | - string input, we currently support [nsf, scf, maf, mcf, gaussian, |
344 | | - gaussian_diag]. You can also specify your own variational family by |
345 | | - passing a pyro `TransformedDistribution`. |
346 | | - Additionally, we allow a `Callable`, which allows you the pass a |
347 | | - `builder` function, which if called returns a distribution. This may be |
348 | | - useful for setting the hyperparameters e.g. `num_transfroms` within the |
349 | | - `get_flow_builder` method specifying the number of transformations |
350 | | - within a normalizing flow. If q is already a `VIPosterior`, then the |
351 | | - arguments will be copied from it (relevant for multi-round training). |
352 | | - vi_method: This specifies the variational methods which are used to fit q to |
353 | | - the posterior. We currently support [rKL, fKL, IW, alpha]. Note that |
354 | | - some of the divergences are `mode seeking` i.e. they underestimate |
355 | | - variance and collapse on multimodal targets (`rKL`, `alpha` for alpha > |
356 | | - 1) and some are `mass covering` i.e. they overestimate variance but |
357 | | - typically cover all modes (`fKL`, `IW`, `alpha` for alpha < 1). |
358 | | - parameters: List of parameters of the variational posterior. This is only |
359 | | - required for user-defined q i.e. if q does not have a `parameters` |
360 | | - attribute. |
361 | | - modules: List of modules of the variational posterior. This is only |
362 | | - required for user-defined q i.e. if q does not have a `modules` |
363 | | - attribute. |
| 339 | + q: Variational distribution. Either a string specifying the flow type |
| 340 | + [nsf, maf, naf, unaf, nice, sospf, gaussian, gaussian_diag], a |
| 341 | + `TransformedDistribution`, a `VIPosterior` object, or a `Callable` |
| 342 | + builder function. For amortized VI, use string flow types only. |
| 343 | + If q is already a `VIPosterior`, arguments are copied from it |
| 344 | + (relevant for multi-round training). |
| 345 | + vi_method: Variational method for fitting q to the posterior. Options: |
| 346 | + [rKL, fKL, IW, alpha]. Some are "mode seeking" (rKL, alpha > 1) and |
| 347 | + some are "mass covering" (fKL, IW, alpha < 1). Currently only used |
| 348 | + for single-x VI; amortized VI uses ELBO (rKL). |
| 349 | + num_transforms: Number of transforms in the normalizing flow. |
| 350 | + hidden_features: Hidden layer size in the flow networks. |
| 351 | + z_score_theta: Method for z-scoring θ (the parameters being modeled). |
| 352 | + One of "none", "independent", "structured". Use "structured" for |
| 353 | + parameters with correlations. |
| 354 | + z_score_x: Method for z-scoring x (the conditioning variable, amortized |
| 355 | + VI only). One of "none", "independent", "structured". Use |
| 356 | + "structured" for structured data like images. |
| 357 | +
|
| 358 | + Note: |
| 359 | + For custom distributions that lack `parameters()` and `modules()` methods, |
| 360 | + pass these via `VIPosterior.set_q(q, parameters=..., modules=...)` instead. |
364 | 361 | """ |
365 | 362 |
|
366 | 363 | q: Union[ |
367 | | - Literal["nsf", "scf", "maf", "mcf", "gaussian", "gaussian_diag"], |
368 | | - PyroTransformedDistribution, |
| 364 | + Literal[ |
| 365 | + "nsf", "maf", "naf", "unaf", "nice", "sospf", "gaussian", "gaussian_diag" |
| 366 | + ], |
| 367 | + VariationalDistribution, |
369 | 368 | "VIPosterior", |
370 | 369 | Callable, |
371 | 370 | ] = "maf" |
372 | 371 | vi_method: Literal["rKL", "fKL", "IW", "alpha"] = "rKL" |
373 | | - parameters: Optional[Iterable] = None |
374 | | - modules: Optional[Iterable] = None |
| 372 | + num_transforms: int = 5 |
| 373 | + hidden_features: int = 50 |
| 374 | + z_score_theta: Literal["none", "independent", "structured"] = "independent" |
| 375 | + z_score_x: Literal["none", "independent", "structured"] = "independent" |
375 | 376 |
|
376 | 377 | def validate(self): |
377 | 378 | """Validate VIPosteriorParameters fields.""" |
378 | | - |
379 | | - valid_q = {"nsf", "scf", "maf", "mcf", "gaussian", "gaussian_diag"} |
| 379 | + valid_q = { |
| 380 | + "nsf", "maf", "naf", "unaf", "nice", "sospf", "gaussian", "gaussian_diag" |
| 381 | + } |
380 | 382 |
|
381 | 383 | if isinstance(self.q, str) and self.q not in valid_q: |
382 | 384 | raise ValueError(f"If `q` is a string, it must be one of {valid_q}") |
383 | 385 | elif not isinstance( |
384 | | - self.q, (PyroTransformedDistribution, VIPosterior, Callable, str) |
| 386 | + self.q, (VariationalDistribution, VIPosterior, Callable, str) |
385 | 387 | ): |
386 | 388 | raise TypeError( |
387 | | - "q must be either of typr PyroTransformedDistribution," |
388 | | - " VIPosterioror or Callable" |
| 389 | + "q must be either of type VariationalDistribution," |
| 390 | + " VIPosterior or Callable" |
389 | 391 | ) |
390 | 392 |
|
391 | | - if self.parameters is not None and not isinstance(self.parameters, Iterable): |
392 | | - raise TypeError("parameters must be iterable or None.") |
393 | | - if self.modules is not None and not isinstance(self.modules, Iterable): |
394 | | - raise TypeError("modules must be iterable or None.") |
| 393 | + if self.num_transforms < 1: |
| 394 | + raise ValueError(f"num_transforms must be >= 1, got {self.num_transforms}") |
| 395 | + if self.hidden_features < 1: |
| 396 | + raise ValueError( |
| 397 | + f"hidden_features must be >= 1, got {self.hidden_features}" |
| 398 | + ) |
0 commit comments