Skip to content

Commit 3506538

Browse files
saitcakmak authored and facebook-github-bot committed
Merge SupervisedDataset & FixedNoiseDataset (#1945)
Summary: Pull Request resolved: #1945 This diff deprecates `FixedNoiseDataset` and merges it into `SupervisedDataset` with Yvar becoming an optional field. This also simplifies the class hierarchy a bit, removing `SupervisedDatasetMeta` in favor of an `__init__` method. I plan to follow up on this by adding optional metric names to datasets and introducing a MultiTaskDataset, which will simplify some of the planned work in Ax MBM. Reviewed By: esantorella Differential Revision: D47729430 fbshipit-source-id: 551cd78a02755505573b10ea1f075aa21f838ab7
1 parent 633d9c0 commit 3506538

11 files changed

+121
-187
lines changed

botorch/acquisition/input_constructors.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@
9999
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
100100
from botorch.utils.constraints import get_outcome_constraint_transforms
101101
from botorch.utils.containers import BotorchContainer
102-
from botorch.utils.datasets import BotorchDataset, SupervisedDataset
102+
from botorch.utils.datasets import SupervisedDataset
103103
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
104104
FastNondominatedPartitioning,
105105
NondominatedPartitioning,
@@ -114,7 +114,7 @@
114114

115115

116116
def _field_is_shared(
117-
datasets: Union[Iterable[BotorchDataset], Dict[Hashable, BotorchDataset]],
117+
datasets: Union[Iterable[SupervisedDataset], Dict[Hashable, SupervisedDataset]],
118118
fieldname: Hashable,
119119
) -> bool:
120120
r"""Determines whether or not a given field is shared by all datasets."""
@@ -136,7 +136,7 @@ def _field_is_shared(
136136

137137

138138
def _get_dataset_field(
139-
dataset: MaybeDict[BotorchDataset],
139+
dataset: MaybeDict[SupervisedDataset],
140140
fieldname: str,
141141
transform: Optional[Callable[[BotorchContainer], Any]] = None,
142142
join_rule: Optional[Callable[[Sequence[Any]], Any]] = None,

botorch/models/gp_regression_mixed.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def construct_inputs(
185185
likelihood: Optional[Likelihood] = None,
186186
**kwargs: Any,
187187
) -> Dict[str, Any]:
188-
r"""Construct `Model` keyword arguments from a dict of `BotorchDataset`.
188+
r"""Construct `Model` keyword arguments from a dict of `SupervisedDataset`.
189189
190190
Args:
191191
training_data: A `SupervisedDataset` containing the training data.

botorch/models/model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from botorch.posteriors import Posterior, PosteriorList
3939
from botorch.sampling.base import MCSampler
4040
from botorch.sampling.list_sampler import ListSampler
41-
from botorch.utils.datasets import BotorchDataset
41+
from botorch.utils.datasets import SupervisedDataset
4242
from botorch.utils.transforms import is_fully_bayesian
4343
from torch import Tensor
4444
from torch.nn import Module, ModuleDict, ModuleList
@@ -169,10 +169,10 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Mode
169169
@classmethod
170170
def construct_inputs(
171171
cls,
172-
training_data: Union[BotorchDataset, Dict[Hashable, BotorchDataset]],
172+
training_data: Union[SupervisedDataset, Dict[Hashable, SupervisedDataset]],
173173
**kwargs: Any,
174174
) -> Dict[str, Any]:
175-
r"""Construct `Model` keyword arguments from a dict of `BotorchDataset`."""
175+
r"""Construct `Model` keyword arguments from a dict of `SupervisedDataset`."""
176176
from botorch.models.utils.parse_training_data import parse_training_data
177177

178178
return parse_training_data(cls, training_data, **kwargs)

botorch/models/utils/parse_training_data.py

+9-22
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,7 @@
1616
from botorch.models.model import Model
1717
from botorch.models.multitask import FixedNoiseMultiTaskGP, MultiTaskGP
1818
from botorch.models.pairwise_gp import PairwiseGP
19-
from botorch.utils.datasets import (
20-
BotorchDataset,
21-
FixedNoiseDataset,
22-
RankingDataset,
23-
SupervisedDataset,
24-
)
19+
from botorch.utils.datasets import RankingDataset, SupervisedDataset
2520
from botorch.utils.dispatcher import Dispatcher
2621
from torch import cat, Tensor
2722
from torch.nn.functional import pad
@@ -37,13 +32,13 @@ def _encoder(arg: Any) -> Type:
3732

3833
def parse_training_data(
3934
consumer: Any,
40-
training_data: Union[BotorchDataset, Dict[Hashable, BotorchDataset]],
35+
training_data: Union[SupervisedDataset, Dict[Hashable, SupervisedDataset]],
4136
**kwargs: Any,
4237
) -> Dict[Hashable, Tensor]:
4338
r"""Prepares a (collection of) datasets for consumption by a given object.
4439
4540
Args:
46-
training_datas: A BoTorchDataset or dictionary thereof.
41+
training_datas: A SupervisedDataset or dictionary thereof.
4742
consumer: The object that will consume the parsed data, or type thereof.
4843
4944
Returns:
@@ -56,18 +51,10 @@ def parse_training_data(
5651
def _parse_model_supervised(
5752
consumer: Model, dataset: SupervisedDataset, **ignore: Any
5853
) -> Dict[Hashable, Tensor]:
59-
return {"train_X": dataset.X(), "train_Y": dataset.Y()}
60-
61-
62-
@dispatcher.register(Model, FixedNoiseDataset)
63-
def _parse_model_fixedNoise(
64-
consumer: Model, dataset: FixedNoiseDataset, **ignore: Any
65-
) -> Dict[Hashable, Tensor]:
66-
return {
67-
"train_X": dataset.X(),
68-
"train_Y": dataset.Y(),
69-
"train_Yvar": dataset.Yvar(),
70-
}
54+
parsed_data = {"train_X": dataset.X(), "train_Y": dataset.Y()}
55+
if dataset.Yvar is not None:
56+
parsed_data["train_Yvar"] = dataset.Yvar()
57+
return parsed_data
7158

7259

7360
@dispatcher.register(PairwiseGP, RankingDataset)
@@ -88,7 +75,7 @@ def _parse_pairwiseGP_ranking(
8875
@dispatcher.register(Model, dict)
8976
def _parse_model_dict(
9077
consumer: Model,
91-
training_data: Dict[Hashable, BotorchDataset],
78+
training_data: Dict[Hashable, SupervisedDataset],
9279
**kwargs: Any,
9380
) -> Dict[Hashable, Tensor]:
9481
if len(training_data) != 1:
@@ -102,7 +89,7 @@ def _parse_model_dict(
10289
@dispatcher.register((MultiTaskGP, FixedNoiseMultiTaskGP), dict)
10390
def _parse_multitask_dict(
10491
consumer: Model,
105-
training_data: Dict[Hashable, BotorchDataset],
92+
training_data: Dict[Hashable, SupervisedDataset],
10693
*,
10794
task_feature: int = 0,
10895
task_feature_container: Hashable = "train_X",

botorch/utils/datasets.py

+87-76
Original file line numberDiff line numberDiff line change
@@ -8,67 +8,24 @@
88

99
from __future__ import annotations
1010

11-
from dataclasses import dataclass, fields, MISSING
12-
from itertools import chain, count, repeat
11+
import warnings
12+
from itertools import count, repeat
1313
from typing import Any, Dict, Hashable, Iterable, Optional, TypeVar, Union
1414

1515
from botorch.utils.containers import BotorchContainer, DenseContainer, SliceContainer
1616
from torch import long, ones, Tensor
17-
from typing_extensions import get_type_hints
1817

1918
T = TypeVar("T")
2019
ContainerLike = Union[BotorchContainer, Tensor]
2120
MaybeIterable = Union[T, Iterable[T]]
2221

2322

24-
@dataclass
25-
class BotorchDataset:
26-
# TODO: Once v3.10 becomes standard, expose `validate_init` as a kw_only InitVar
27-
def __post_init__(self, validate_init: bool = True) -> None:
28-
if validate_init:
29-
self._validate()
23+
class SupervisedDataset:
24+
r"""Base class for datasets consisting of labelled pairs `(X, Y)`
25+
and an optional `Yvar` that stipulates observations variances so
26+
that `Y[i] ~ N(f(X[i]), Yvar[i])`.
3027
31-
def _validate(self) -> None:
32-
pass
33-
34-
35-
class SupervisedDatasetMeta(type):
36-
def __call__(cls, *args: Any, **kwargs: Any):
37-
r"""Converts Tensor-valued fields to DenseContainer under the assumption
38-
that said fields house collections of feature vectors."""
39-
hints = get_type_hints(cls)
40-
fields_iter = (item for item in fields(cls) if item.init is not None)
41-
f_dict = {}
42-
for value, field in chain(
43-
zip(args, fields_iter),
44-
((kwargs.pop(field.name, MISSING), field) for field in fields_iter),
45-
):
46-
if value is MISSING:
47-
if field.default is not MISSING:
48-
value = field.default
49-
elif field.default_factory is not MISSING:
50-
value = field.default_factory()
51-
else:
52-
raise RuntimeError(f"Missing required field `{field.name}`.")
53-
54-
if issubclass(hints[field.name], BotorchContainer):
55-
if isinstance(value, Tensor):
56-
value = DenseContainer(value, event_shape=value.shape[-1:])
57-
elif not isinstance(value, BotorchContainer):
58-
raise TypeError(
59-
"Expected <BotorchContainer | Tensor> for field "
60-
f"`{field.name}` but was {type(value)}."
61-
)
62-
f_dict[field.name] = value
63-
64-
return super().__call__(**f_dict, **kwargs)
65-
66-
67-
@dataclass
68-
class SupervisedDataset(BotorchDataset, metaclass=SupervisedDatasetMeta):
69-
r"""Base class for datasets consisting of labelled pairs `(x, y)`.
70-
71-
This class object's `__call__` method converts Tensors `src` to
28+
This class object's `__init__` method converts Tensors `src` to
7229
DenseContainers under the assumption that `event_shape=src.shape[-1:]`.
7330
7431
Example:
@@ -87,6 +44,29 @@ class SupervisedDataset(BotorchDataset, metaclass=SupervisedDatasetMeta):
8744

8845
X: BotorchContainer
8946
Y: BotorchContainer
47+
Yvar: Optional[BotorchContainer]
48+
49+
def __init__(
50+
self,
51+
X: ContainerLike,
52+
Y: ContainerLike,
53+
Yvar: Optional[ContainerLike] = None,
54+
validate_init: bool = True,
55+
) -> None:
56+
r"""Constructs a `SupervisedDataset`.
57+
58+
Args:
59+
X: A `Tensor` or `BotorchContainer` representing the input features.
60+
Y: A `Tensor` or `BotorchContainer` representing the outcomes.
61+
Yvar: An optional `Tensor` or `BotorchContainer` representing
62+
the observation noise.
63+
validate_init: If `True`, validates the input shapes.
64+
"""
65+
self.X = _containerize(X)
66+
self.Y = _containerize(Y)
67+
self.Yvar = None if Yvar is None else _containerize(Yvar)
68+
if validate_init:
69+
self._validate()
9070

9171
def _validate(self) -> None:
9272
shape_X = self.X.shape
@@ -95,12 +75,15 @@ def _validate(self) -> None:
9575
shape_Y = shape_Y[: len(shape_Y) - len(self.Y.event_shape)]
9676
if shape_X != shape_Y:
9777
raise ValueError("Batch dimensions of `X` and `Y` are incompatible.")
78+
if self.Yvar is not None and self.Yvar.shape != self.Y.shape:
79+
raise ValueError("Shapes of `Y` and `Yvar` are incompatible.")
9880

9981
@classmethod
10082
def dict_from_iter(
10183
cls,
10284
X: MaybeIterable[ContainerLike],
10385
Y: MaybeIterable[ContainerLike],
86+
Yvar: Optional[MaybeIterable[ContainerLike]] = None,
10487
*,
10588
keys: Optional[Iterable[Hashable]] = None,
10689
) -> Dict[Hashable, SupervisedDataset]:
@@ -111,40 +94,46 @@ def dict_from_iter(
11194
X = (X,) if single_Y else repeat(X)
11295
if single_Y:
11396
Y = (Y,) if single_X else repeat(Y)
114-
return {key: cls(x, y) for key, x, y in zip(keys or count(), X, Y)}
97+
Yvar = repeat(Yvar) if isinstance(Yvar, (Tensor, BotorchContainer)) else Yvar
98+
99+
# Pass in Yvar only if it is not None.
100+
iterables = (X, Y) if Yvar is None else (X, Y, Yvar)
101+
return {
102+
elements[0]: cls(*elements[1:])
103+
for elements in zip(keys or count(), *iterables)
104+
}
105+
106+
def __eq__(self, other: Any) -> bool:
107+
return (
108+
type(other) is type(self)
109+
and self.X == other.X
110+
and self.Y == other.Y
111+
and self.Yvar == other.Yvar
112+
)
115113

116114

117-
@dataclass
118115
class FixedNoiseDataset(SupervisedDataset):
119116
r"""A SupervisedDataset with an additional field `Yvar` that stipulates
120-
observations variances so that `Y[i] ~ N(f(X[i]), Yvar[i])`."""
117+
observations variances so that `Y[i] ~ N(f(X[i]), Yvar[i])`.
121118
122-
X: BotorchContainer
123-
Y: BotorchContainer
124-
Yvar: BotorchContainer
125-
126-
@classmethod
127-
def dict_from_iter(
128-
cls,
129-
X: MaybeIterable[ContainerLike],
130-
Y: MaybeIterable[ContainerLike],
131-
Yvar: Optional[MaybeIterable[ContainerLike]] = None,
132-
*,
133-
keys: Optional[Iterable[Hashable]] = None,
134-
) -> Dict[Hashable, SupervisedDataset]:
135-
r"""Returns a dictionary of `FixedNoiseDataset` from iterables."""
136-
single_X = isinstance(X, (Tensor, BotorchContainer))
137-
single_Y = isinstance(Y, (Tensor, BotorchContainer))
138-
if single_X:
139-
X = (X,) if single_Y else repeat(X)
140-
if single_Y:
141-
Y = (Y,) if single_X else repeat(Y)
119+
NOTE: This is deprecated. Use `SupervisedDataset` instead.
120+
"""
142121

143-
Yvar = repeat(Yvar) if isinstance(Yvar, (Tensor, BotorchContainer)) else Yvar
144-
return {key: cls(x, y, c) for key, x, y, c in zip(keys or count(), X, Y, Yvar)}
122+
def __init__(
123+
self,
124+
X: ContainerLike,
125+
Y: ContainerLike,
126+
Yvar: ContainerLike,
127+
validate_init: bool = True,
128+
) -> None:
129+
r"""Initialize a `FixedNoiseDataset` -- deprecated!"""
130+
warnings.warn(
131+
"`FixedNoiseDataset` is deprecated. Use `SupervisedDataset` instead.",
132+
DeprecationWarning,
133+
)
134+
super().__init__(X=X, Y=Y, Yvar=Yvar, validate_init=validate_init)
145135

146136

147-
@dataclass
148137
class RankingDataset(SupervisedDataset):
149138
r"""A SupervisedDataset whose labelled pairs `(x, y)` consist of m-ary combinations
150139
`x ∈ Z^{m}` of elements from a ground set `Z = (z_1, ...)` and ranking vectors
@@ -173,6 +162,18 @@ class RankingDataset(SupervisedDataset):
173162
X: SliceContainer
174163
Y: BotorchContainer
175164

165+
def __init__(
166+
self, X: SliceContainer, Y: ContainerLike, validate_init: bool = True
167+
) -> None:
168+
r"""Construct a `RankingDataset`.
169+
170+
Args:
171+
X: A `SliceContainer` representing the input features being ranked.
172+
Y: A `Tensor` or `BotorchContainer` representing the rankings.
173+
validate_init: If `True`, validates the input shapes.
174+
"""
175+
super().__init__(X=X, Y=Y, Yvar=None, validate_init=validate_init)
176+
176177
def _validate(self) -> None:
177178
super()._validate()
178179

@@ -201,3 +202,13 @@ def _validate(self) -> None:
201202

202203
# Same as: torch.where(y_diff == 0, y_incr + 1, 1)
203204
y_incr = y_incr - y_diff + 1
205+
206+
207+
def _containerize(value: ContainerLike) -> BotorchContainer:
208+
r"""Converts Tensor-valued arguments to DenseContainer under the assumption
209+
that said arguments house collections of feature vectors.
210+
"""
211+
if isinstance(value, Tensor):
212+
return DenseContainer(value, event_shape=value.shape[-1:])
213+
else:
214+
return value

test/models/test_fully_bayesian.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
from botorch.models.transforms import Normalize, Standardize
4848
from botorch.posteriors.fully_bayesian import batched_bisect, FullyBayesianPosterior
4949
from botorch.sampling.get_sampler import get_sampler
50-
from botorch.utils.datasets import FixedNoiseDataset, SupervisedDataset
50+
from botorch.utils.datasets import SupervisedDataset
5151
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
5252
NondominatedPartitioning,
5353
)
@@ -550,10 +550,7 @@ def test_construct_inputs(self):
550550
X, Y, Yvar, model = self._get_data_and_model(
551551
infer_noise=infer_noise, **tkwargs
552552
)
553-
if infer_noise:
554-
training_data = SupervisedDataset(X, Y)
555-
else:
556-
training_data = FixedNoiseDataset(X, Y, Yvar)
553+
training_data = SupervisedDataset(X, Y, Yvar)
557554

558555
data_dict = model.construct_inputs(training_data)
559556
self.assertTrue(X.equal(data_dict["train_X"]))

test/models/test_gp_regression.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from botorch.models.utils import add_output_dim
2121
from botorch.posteriors import GPyTorchPosterior
2222
from botorch.sampling import SobolQMCNormalSampler
23-
from botorch.utils.datasets import FixedNoiseDataset, SupervisedDataset
23+
from botorch.utils.datasets import SupervisedDataset
2424
from botorch.utils.sampling import manual_seed
2525
from botorch.utils.testing import _get_random_data, BotorchTestCase
2626
from gpytorch.kernels import MaternKernel, RBFKernel, ScaleKernel
@@ -450,7 +450,7 @@ def test_construct_inputs(self):
450450
X = model_kwargs["train_X"]
451451
Y = model_kwargs["train_Y"]
452452
Yvar = model_kwargs["train_Yvar"]
453-
training_data = FixedNoiseDataset(X, Y, Yvar)
453+
training_data = SupervisedDataset(X, Y, Yvar)
454454
data_dict = model.construct_inputs(training_data)
455455
self.assertTrue(X.equal(data_dict["train_X"]))
456456
self.assertTrue(Y.equal(data_dict["train_Y"]))

0 commit comments

Comments (0)