Skip to content

Commit 9f269fa

Browse files
Balandatmeta-codesync[bot]
authored andcommitted
Expand Self usage for better fluent APIs (#3194)
Summary: Pull Request resolved: #3194 Replace explicit class return types with `Self` for methods that return `self`, enabling better type inference for subclasses. Files updated: - botorch/models/relevance_pursuit.py - RelevancePursuitMixin: to_sparse, to_dense, expand_support, contract_support, full_support, remove_support now return Self - botorch/models/robust_relevance_pursuit_model.py - RobustRelevancePursuitMixin.load_standard_model now returns Self - botorch/models/model.py - Model.eval, Model.train now return Self Reviewed By: hvarfner Differential Revision: D91641969 fbshipit-source-id: 6bac972ad533ebb0c23a86e9de36cb6cfb99eaf9
1 parent bad715c commit 9f269fa

3 files changed

Lines changed: 11 additions & 11 deletions

File tree

botorch/models/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,12 +243,12 @@ def _revert_to_original_inputs(self) -> None:
243243
self.set_train_data(self._original_train_inputs, strict=False)
244244
self._has_transformed_inputs = False
245245

246-
def eval(self) -> Model:
246+
def eval(self) -> Self:
247247
r"""Puts the model in ``eval`` mode and sets the transformed inputs."""
248248
self._set_transformed_inputs()
249249
return super().eval()
250250

251-
def train(self, mode: bool = True) -> Model:
251+
def train(self, mode: bool = True) -> Self:
252252
r"""Put the model in ``train`` mode. Reverts to the original inputs if
253253
in ``train`` mode (``mode=True``) or sets transformed inputs if in
254254
``eval`` mode (``mode=False``).

botorch/models/relevance_pursuit.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from collections.abc import Callable, Sequence
2323
from copy import copy, deepcopy
2424
from functools import partial
25-
from typing import Any, cast
25+
from typing import Any, cast, Self
2626
from warnings import warn
2727

2828
import torch
@@ -117,7 +117,7 @@ def inactive_indices(self) -> Tensor:
117117
device = self.sparse_parameter.device
118118
return torch.arange(self.dim, device=device)[~self.is_active]
119119

120-
def to_sparse(self) -> RelevancePursuitMixin:
120+
def to_sparse(self) -> Self:
121121
"""Converts the sparse parameter to its sparse (< dim) representation.
122122
123123
Returns:
@@ -130,7 +130,7 @@ def to_sparse(self) -> RelevancePursuitMixin:
130130
self._is_sparse = True
131131
return self
132132

133-
def to_dense(self) -> RelevancePursuitMixin:
133+
def to_dense(self) -> Self:
134134
"""Converts the sparse parameter to its dense, length-``dim`` representation.
135135
136136
Returns:
@@ -157,7 +157,7 @@ def to_dense(self) -> RelevancePursuitMixin:
157157
self._is_sparse = False
158158
return self
159159

160-
def expand_support(self, indices: list[int]) -> RelevancePursuitMixin:
160+
def expand_support(self, indices: list[int]) -> Self:
161161
"""Expands the support by a number of indices.
162162
163163
Args:
@@ -186,7 +186,7 @@ def expand_support(self, indices: list[int]) -> RelevancePursuitMixin:
186186
)
187187
return self
188188

189-
def contract_support(self, indices: list[int]) -> RelevancePursuitMixin:
189+
def contract_support(self, indices: list[int]) -> Self:
190190
"""Contracts the support by a number of indices.
191191
192192
Args:
@@ -216,7 +216,7 @@ def contract_support(self, indices: list[int]) -> RelevancePursuitMixin:
216216
return self
217217

218218
# support initialization helpers
219-
def full_support(self) -> RelevancePursuitMixin:
219+
def full_support(self) -> Self:
220220
"""Initializes the RelevancePursuitMixin with a full, size-``dim`` support.
221221
222222
Returns:
@@ -226,7 +226,7 @@ def full_support(self) -> RelevancePursuitMixin:
226226
self.to_dense() # no reason to be sparse with full support
227227
return self
228228

229-
def remove_support(self) -> RelevancePursuitMixin:
229+
def remove_support(self) -> Self:
230230
"""Initializes the RelevancePursuitMixin with an empty, size-zero support.
231231
232232
Returns:

botorch/models/robust_relevance_pursuit_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
from abc import ABC, abstractmethod
3636
from collections.abc import Callable, Mapping, Sequence
37-
from typing import Any
37+
from typing import Any, Self
3838

3939
import torch
4040
from botorch.exceptions.errors import UnsupportedError
@@ -140,7 +140,7 @@ def to_standard_model(self) -> Model:
140140
A standard model.
141141
"""
142142

143-
def load_standard_model(self, standard_model: Model) -> RobustRelevancePursuitMixin:
143+
def load_standard_model(self, standard_model: Model) -> Self:
144144
"""Loads the state dict of a model into the ``RobustRelevancePursuitMixin``.
145145
146146
Args:

0 commit comments

Comments
 (0)