Skip to content

Commit 0805469

Browse files
abelabajan and janfb authored
refactor: dataclasses for LossArgs and TrainConfig (#1668)
* refactor: add training dataclasses and update get_start_index method argument * refactor: update train methods to use TrainConfig dataclass * chore: move num_atoms default value to dataclass * chore: add docstrings to _get_losses method and add error message * Update sbi/inference/trainers/_contracts.py Co-authored-by: Jan <janfb@users.noreply.github.com> * chore: set global variables in NeuralInference * chore: remove verbose comments * chore: rename ctx to context * docs: update warning message for nre_a loss_kwargs argument * nre: update error message displayed when not passing a value to loss_kwargs * refactor: add validations for training dataclasses * chore: add license header * Update sbi/inference/trainers/_contracts.py Co-authored-by: Jan <janfb@users.noreply.github.com> * Update sbi/inference/trainers/base.py Co-authored-by: Jan <janfb@users.noreply.github.com> * Update sbi/inference/trainers/base.py Co-authored-by: Jan <janfb@users.noreply.github.com> * Update sbi/inference/trainers/base.py Co-authored-by: Jan <janfb@users.noreply.github.com> * Update sbi/inference/trainers/base.py Co-authored-by: Jan <janfb@users.noreply.github.com> * chore: update formatting * chore: update training dataclass imports * docs(bnre): add comment for LossArgsBNRE instantiation * refactor: add _get_losses overload methods and update type hints * test: update training dataclass tests * docs: update LossArgsVF times field docstring * Update sbi/inference/trainers/nre/nre_base.py Co-authored-by: Jan <janfb@users.noreply.github.com> * chore: update formatting --------- Co-authored-by: Jan <janfb@users.noreply.github.com>
1 parent 68fb53b commit 0805469

File tree

13 files changed

+839
-194
lines changed

13 files changed

+839
-194
lines changed
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
2+
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
3+
4+
from __future__ import annotations
5+
6+
from dataclasses import dataclass, field
7+
from typing import Callable, Optional, TypeVar, Union
8+
9+
from torch import Tensor
10+
from torch.distributions import Distribution
11+
12+
from sbi.inference.posteriors.base_posterior import NeuralPosterior
13+
from sbi.utils.typechecks import (
14+
validate_bool,
15+
validate_float_range,
16+
validate_optional,
17+
validate_positive_float,
18+
validate_positive_int,
19+
)
20+
21+
22+
@dataclass(frozen=True)
class StartIndexContext:
    """Inputs needed to determine where training-data consumption starts.

    Bundles the parameters that previously varied between trainer subclasses,
    so the base class can expose one signature:
    `_get_start_index(context: StartIndexContext) -> int`. Fields that apply
    only to some method families are optional; each subclass reads only the
    fields relevant to it.
    """

    # Shared across method families (e.g., NLE/NRE).
    discard_prior_samples: bool

    # Only meaningful for SNPE-style trainers.
    force_first_round_loss: Optional[bool] = None

    # Generic training-loop state.
    resume_training: Optional[bool] = None

    def __post_init__(self):
        # Frozen dataclass: validation runs exactly once, at construction.
        validate_bool(self.discard_prior_samples, "discard_prior_samples")
        for optional_flag in ("force_first_round_loss", "resume_training"):
            validate_optional(getattr(self, optional_flag), optional_flag, bool)
46+
47+
48+
@dataclass
class TrainConfig:
    """Loop-level training hyperparameters shared by all estimator families.

    Captures the toggles and hyperparameters that are independent of any
    specific estimator. Subclass `train(...kwargs)` wrappers translate user
    kwargs into this config and delegate to the base training core.
    """

    # Data & optimization
    training_batch_size: int
    learning_rate: float

    # Loop controls
    validation_fraction: float
    stop_after_epochs: int
    max_num_epochs: int

    # Lifecycle
    resume_training: bool
    retrain_from_scratch: bool

    # UX
    show_train_summary: bool

    # Regularization / safety; `None` disables gradient clipping.
    clip_max_norm: Optional[float] = None

    def __post_init__(self):
        validate_positive_int(self.training_batch_size, "training_batch_size")
        validate_positive_float(self.learning_rate, "learning_rate")
        # The validation split must lie strictly between 0 and 1.
        validate_float_range(
            self.validation_fraction,
            "validation_fraction",
            min_val=0,
            max_val=1,
            range_inclusive=False,
        )
        for epoch_field in ("stop_after_epochs", "max_num_epochs"):
            validate_positive_int(getattr(self, epoch_field), epoch_field)
        for flag in ("resume_training", "retrain_from_scratch", "show_train_summary"):
            validate_bool(getattr(self, flag), flag)
        if self.clip_max_norm is not None:
            validate_positive_float(self.clip_max_norm, "clip_max_norm")
93+
94+
95+
@dataclass(frozen=True)
class LossArgsNRE:
    """Typed arguments for ratio-estimation (NRE-family) losses.

    Fields:
        num_atoms: Number of atoms (classes) used for classification.
    """

    num_atoms: int = 10

    def __post_init__(self):
        # The atom count must be a strictly positive integer.
        validate_positive_int(self.num_atoms, "num_atoms")
108+
109+
110+
@dataclass(frozen=True)
class LossArgsNRE_A(LossArgsNRE):
    """
    Typed args for NRE_A (AALR).

    Fields:
        num_atoms: Number of atoms to use for classification. AALR is only
            defined for `num_atoms=2`, so the field is pinned and excluded
            from `__init__` (`init=False`).
    """

    # Pinned: callers cannot pass a different num_atoms at construction.
    num_atoms: int = field(init=False, default=2)

    def __post_init__(self):
        # Defensive invariant check (e.g. against a subclass redefining the
        # field): the AALR loss is only defined for exactly two atoms.
        # Fix: error message previously misspelled the acronym as "AARL".
        if self.num_atoms != 2:
            raise ValueError("In AALR / NRE-A, num_atoms must always be 2")
125+
126+
127+
@dataclass(frozen=True, kw_only=True)
class LossArgsBNRE(LossArgsNRE_A):
    r"""
    Typed args for balanced neural ratio estimation losses (BNRE).

    Fields:
        regularization_strength: The multiplicative coefficient applied to the
            balancing regularizer ($\lambda$).
    """

    regularization_strength: float

    def __post_init__(self):
        # Run the parent checks first (enforces the AALR `num_atoms == 2`
        # invariant); overriding __post_init__ without delegating to the
        # parent previously skipped that validation entirely.
        super().__post_init__()
        validate_positive_float(self.regularization_strength, "regularization_strength")
141+
142+
143+
@dataclass(frozen=True, kw_only=True)
class LossArgsNRE_C(LossArgsNRE):
    r"""
    Typed args for NRE_C losses.

    Fields:
        gamma: Determines the relative weight of the sum of all $K$ dependently
            drawn classes against the marginally drawn one. Specifically,
            $p(y=k) :=p_K$, $p(y=0) := p_0$, $p_0 = 1 - K p_K$, and finally
            $\gamma := K p_K / p_0$.
    """

    gamma: float

    def __post_init__(self):
        # Validate the inherited (user-settable) `num_atoms` too; overriding
        # __post_init__ without calling the parent previously allowed e.g.
        # `LossArgsNRE_C(num_atoms=0, gamma=1.0)` to pass validation.
        super().__post_init__()
        validate_positive_float(self.gamma, "gamma")
159+
160+
161+
@dataclass(frozen=True)
class LossArgsNPE:
    """Typed arguments for posterior-estimation (NPE-family) losses.

    Fields:
        proposal: The proposal the parameters were drawn from; either a
            `torch.distributions.Distribution` or a `NeuralPosterior`.
        calibration_kernel: Optional function calibrating the loss with
            respect to the simulations `x`. See Lueckmann, Gonçalves et al.,
            NeurIPS 2017. If `None`, no calibration is used.
        force_first_round_loss: If `True`, train with maximum likelihood,
            i.e., potentially ignoring the correction for using a proposal
            distribution different from the prior.
    """

    proposal: Optional[Union["Distribution", "NeuralPosterior"]] = None
    calibration_kernel: Optional[Callable[..., "Tensor"]] = None
    force_first_round_loss: bool = False

    def __post_init__(self):
        # Optional fields may be None; otherwise each must match its type(s).
        validate_optional(self.proposal, "proposal", Distribution, NeuralPosterior)
        validate_optional(self.calibration_kernel, "calibration_kernel", Callable)
        validate_bool(self.force_first_round_loss, "force_first_round_loss")
184+
185+
186+
@dataclass(frozen=True)
class LossArgsVF:
    """Typed arguments for vector-field estimation (VF-family) losses.

    Fields:
        proposal: A `torch.distributions.Distribution` or a `NeuralPosterior`.
        calibration_kernel: Optional function calibrating the loss with
            respect to the simulations `x`. See Lueckmann, Gonçalves et al.,
            NeurIPS 2017. If `None`, no calibration is used.
        times: Time steps to compute the loss at.
        force_first_round_loss: If `True`, train with maximum likelihood,
            i.e., potentially ignoring the correction for using a proposal
            distribution different from the prior.
    """

    proposal: Optional[Union["Distribution", "NeuralPosterior"]] = None
    calibration_kernel: Optional[Callable[..., "Tensor"]] = None
    times: Optional["Tensor"] = None
    force_first_round_loss: bool = False

    def __post_init__(self):
        # Optional fields may be None; otherwise each must match its type(s).
        validate_optional(self.proposal, "proposal", Distribution, NeuralPosterior)
        validate_optional(self.calibration_kernel, "calibration_kernel", Callable)
        validate_optional(self.times, "times", Tensor)
        validate_bool(self.force_first_round_loss, "force_first_round_loss")
212+
213+
214+
# Union of the top-level loss-argument families, for APIs that accept any of
# them (subclasses like LossArgsBNRE are covered via LossArgsNRE).
LossArgs = Union[LossArgsNRE, LossArgsNPE, LossArgsVF]
# Constrained TypeVar for generic helpers that must preserve the concrete
# loss-args family between their input and output types.
LossArgsT = TypeVar("LossArgsT", LossArgsNRE, LossArgsNPE, LossArgsVF)
216+
217+
218+
__all__ = [
219+
"StartIndexContext",
220+
"TrainConfig",
221+
"LossArgsNRE",
222+
"LossArgsNPE",
223+
"LossArgsVF",
224+
"LossArgs",
225+
"LossArgsT",
226+
]

0 commit comments

Comments
 (0)