8
8
9
9
from __future__ import annotations
10
10
11
- from dataclasses import dataclass , fields , MISSING
12
- from itertools import chain , count , repeat
11
+ import warnings
12
+ from itertools import count , repeat
13
13
from typing import Any , Dict , Hashable , Iterable , Optional , TypeVar , Union
14
14
15
15
from botorch .utils .containers import BotorchContainer , DenseContainer , SliceContainer
16
16
from torch import long , ones , Tensor
17
- from typing_extensions import get_type_hints
18
17
19
18
T = TypeVar ("T" )
20
19
ContainerLike = Union [BotorchContainer , Tensor ]
21
20
MaybeIterable = Union [T , Iterable [T ]]
22
21
23
22
24
- @dataclass
25
- class BotorchDataset :
26
- # TODO: Once v3.10 becomes standard, expose `validate_init` as a kw_only InitVar
27
- def __post_init__ (self , validate_init : bool = True ) -> None :
28
- if validate_init :
29
- self ._validate ()
23
+ class SupervisedDataset :
24
+ r"""Base class for datasets consisting of labelled pairs `(X, Y)`
25
+ and an optional `Yvar` that stipulates observations variances so
26
+ that `Y[i] ~ N(f(X[i]), Yvar[i])`.
30
27
31
- def _validate (self ) -> None :
32
- pass
33
-
34
-
35
- class SupervisedDatasetMeta (type ):
36
- def __call__ (cls , * args : Any , ** kwargs : Any ):
37
- r"""Converts Tensor-valued fields to DenseContainer under the assumption
38
- that said fields house collections of feature vectors."""
39
- hints = get_type_hints (cls )
40
- fields_iter = (item for item in fields (cls ) if item .init is not None )
41
- f_dict = {}
42
- for value , field in chain (
43
- zip (args , fields_iter ),
44
- ((kwargs .pop (field .name , MISSING ), field ) for field in fields_iter ),
45
- ):
46
- if value is MISSING :
47
- if field .default is not MISSING :
48
- value = field .default
49
- elif field .default_factory is not MISSING :
50
- value = field .default_factory ()
51
- else :
52
- raise RuntimeError (f"Missing required field `{ field .name } `." )
53
-
54
- if issubclass (hints [field .name ], BotorchContainer ):
55
- if isinstance (value , Tensor ):
56
- value = DenseContainer (value , event_shape = value .shape [- 1 :])
57
- elif not isinstance (value , BotorchContainer ):
58
- raise TypeError (
59
- "Expected <BotorchContainer | Tensor> for field "
60
- f"`{ field .name } ` but was { type (value )} ."
61
- )
62
- f_dict [field .name ] = value
63
-
64
- return super ().__call__ (** f_dict , ** kwargs )
65
-
66
-
67
- @dataclass
68
- class SupervisedDataset (BotorchDataset , metaclass = SupervisedDatasetMeta ):
69
- r"""Base class for datasets consisting of labelled pairs `(x, y)`.
70
-
71
- This class object's `__call__` method converts Tensors `src` to
28
+ This class object's `__init__` method converts Tensors `src` to
72
29
DenseContainers under the assumption that `event_shape=src.shape[-1:]`.
73
30
74
31
Example:
@@ -87,6 +44,29 @@ class SupervisedDataset(BotorchDataset, metaclass=SupervisedDatasetMeta):
87
44
88
45
X : BotorchContainer
89
46
Y : BotorchContainer
47
+ Yvar : Optional [BotorchContainer ]
48
+
49
+ def __init__ (
50
+ self ,
51
+ X : ContainerLike ,
52
+ Y : ContainerLike ,
53
+ Yvar : Optional [ContainerLike ] = None ,
54
+ validate_init : bool = True ,
55
+ ) -> None :
56
+ r"""Constructs a `SupervisedDataset`.
57
+
58
+ Args:
59
+ X: A `Tensor` or `BotorchContainer` representing the input features.
60
+ Y: A `Tensor` or `BotorchContainer` representing the outcomes.
61
+ Yvar: An optional `Tensor` or `BotorchContainer` representing
62
+ the observation noise.
63
+ validate_init: If `True`, validates the input shapes.
64
+ """
65
+ self .X = _containerize (X )
66
+ self .Y = _containerize (Y )
67
+ self .Yvar = None if Yvar is None else _containerize (Yvar )
68
+ if validate_init :
69
+ self ._validate ()
90
70
91
71
def _validate (self ) -> None :
92
72
shape_X = self .X .shape
@@ -95,12 +75,15 @@ def _validate(self) -> None:
95
75
shape_Y = shape_Y [: len (shape_Y ) - len (self .Y .event_shape )]
96
76
if shape_X != shape_Y :
97
77
raise ValueError ("Batch dimensions of `X` and `Y` are incompatible." )
78
+ if self .Yvar is not None and self .Yvar .shape != self .Y .shape :
79
+ raise ValueError ("Shapes of `Y` and `Yvar` are incompatible." )
98
80
99
81
@classmethod
100
82
def dict_from_iter (
101
83
cls ,
102
84
X : MaybeIterable [ContainerLike ],
103
85
Y : MaybeIterable [ContainerLike ],
86
+ Yvar : Optional [MaybeIterable [ContainerLike ]] = None ,
104
87
* ,
105
88
keys : Optional [Iterable [Hashable ]] = None ,
106
89
) -> Dict [Hashable , SupervisedDataset ]:
@@ -111,40 +94,46 @@ def dict_from_iter(
111
94
X = (X ,) if single_Y else repeat (X )
112
95
if single_Y :
113
96
Y = (Y ,) if single_X else repeat (Y )
114
- return {key : cls (x , y ) for key , x , y in zip (keys or count (), X , Y )}
97
+ Yvar = repeat (Yvar ) if isinstance (Yvar , (Tensor , BotorchContainer )) else Yvar
98
+
99
+ # Pass in Yvar only if it is not None.
100
+ iterables = (X , Y ) if Yvar is None else (X , Y , Yvar )
101
+ return {
102
+ elements [0 ]: cls (* elements [1 :])
103
+ for elements in zip (keys or count (), * iterables )
104
+ }
105
+
106
+ def __eq__ (self , other : Any ) -> bool :
107
+ return (
108
+ type (other ) is type (self )
109
+ and self .X == other .X
110
+ and self .Y == other .Y
111
+ and self .Yvar == other .Yvar
112
+ )
115
113
116
114
117
- @dataclass
118
115
class FixedNoiseDataset (SupervisedDataset ):
119
116
r"""A SupervisedDataset with an additional field `Yvar` that stipulates
120
- observations variances so that `Y[i] ~ N(f(X[i]), Yvar[i])`."""
117
+ observations variances so that `Y[i] ~ N(f(X[i]), Yvar[i])`.
121
118
122
- X : BotorchContainer
123
- Y : BotorchContainer
124
- Yvar : BotorchContainer
125
-
126
- @classmethod
127
- def dict_from_iter (
128
- cls ,
129
- X : MaybeIterable [ContainerLike ],
130
- Y : MaybeIterable [ContainerLike ],
131
- Yvar : Optional [MaybeIterable [ContainerLike ]] = None ,
132
- * ,
133
- keys : Optional [Iterable [Hashable ]] = None ,
134
- ) -> Dict [Hashable , SupervisedDataset ]:
135
- r"""Returns a dictionary of `FixedNoiseDataset` from iterables."""
136
- single_X = isinstance (X , (Tensor , BotorchContainer ))
137
- single_Y = isinstance (Y , (Tensor , BotorchContainer ))
138
- if single_X :
139
- X = (X ,) if single_Y else repeat (X )
140
- if single_Y :
141
- Y = (Y ,) if single_X else repeat (Y )
119
+ NOTE: This is deprecated. Use `SupervisedDataset` instead.
120
+ """
142
121
143
- Yvar = repeat (Yvar ) if isinstance (Yvar , (Tensor , BotorchContainer )) else Yvar
144
- return {key : cls (x , y , c ) for key , x , y , c in zip (keys or count (), X , Y , Yvar )}
122
+ def __init__ (
123
+ self ,
124
+ X : ContainerLike ,
125
+ Y : ContainerLike ,
126
+ Yvar : ContainerLike ,
127
+ validate_init : bool = True ,
128
+ ) -> None :
129
+ r"""Initialize a `FixedNoiseDataset` -- deprecated!"""
130
+ warnings .warn (
131
+ "`FixedNoiseDataset` is deprecated. Use `SupervisedDataset` instead." ,
132
+ DeprecationWarning ,
133
+ )
134
+ super ().__init__ (X = X , Y = Y , Yvar = Yvar , validate_init = validate_init )
145
135
146
136
147
- @dataclass
148
137
class RankingDataset (SupervisedDataset ):
149
138
r"""A SupervisedDataset whose labelled pairs `(x, y)` consist of m-ary combinations
150
139
`x ∈ Z^{m}` of elements from a ground set `Z = (z_1, ...)` and ranking vectors
@@ -173,6 +162,18 @@ class RankingDataset(SupervisedDataset):
173
162
X : SliceContainer
174
163
Y : BotorchContainer
175
164
165
+ def __init__ (
166
+ self , X : SliceContainer , Y : ContainerLike , validate_init : bool = True
167
+ ) -> None :
168
+ r"""Construct a `RankingDataset`.
169
+
170
+ Args:
171
+ X: A `SliceContainer` representing the input features being ranked.
172
+ Y: A `Tensor` or `BotorchContainer` representing the rankings.
173
+ validate_init: If `True`, validates the input shapes.
174
+ """
175
+ super ().__init__ (X = X , Y = Y , Yvar = None , validate_init = validate_init )
176
+
176
177
def _validate (self ) -> None :
177
178
super ()._validate ()
178
179
@@ -201,3 +202,13 @@ def _validate(self) -> None:
201
202
202
203
# Same as: torch.where(y_diff == 0, y_incr + 1, 1)
203
204
y_incr = y_incr - y_diff + 1
205
+
206
+
207
+ def _containerize (value : ContainerLike ) -> BotorchContainer :
208
+ r"""Converts Tensor-valued arguments to DenseContainer under the assumption
209
+ that said arguments house collections of feature vectors.
210
+ """
211
+ if isinstance (value , Tensor ):
212
+ return DenseContainer (value , event_shape = value .shape [- 1 :])
213
+ else :
214
+ return value
0 commit comments