Skip to content

Commit 84fa231

Browse files
Jason Chowfacebook-github-bot
authored andcommitted
Temporary attribute context manager for factories
Summary: Utility context manager that allows us to temporarily modify the class attributes of a factory to make a module for a subclass. Now used in the pairwise factory. Differential Revision: D74192292
1 parent 0697272 commit 84fa231

2 files changed

Lines changed: 27 additions & 20 deletions

File tree

aepsych/factory/pairwise.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
default_mean_covar_factory,
1818
DefaultMeanCovarFactory,
1919
)
20+
from aepsych.factory.utils import temporary_attributes
2021
from aepsych.kernels.pairwisekernel import PairwiseKernel
2122

2223

@@ -84,30 +85,21 @@ def _make_covar_module(self) -> gpytorch.kernels.Kernel:
8485
)
8586

8687
# Temporarily modify attributes to make base pair covariance module
87-
original_dim = self.dim
88-
self.dim = len(active_dims) // 2
89-
self.stimuli_per_trial = 1
90-
91-
base_cov = super()._make_covar_module() # TODO: This is really awkward
92-
93-
self.dim = original_dim
94-
self.stimuli_per_trial = 1
88+
with temporary_attributes(self, dim=len(active_dims) // 2, stimuli_per_trial=1):
89+
base_cov = super()._make_covar_module()
9590

9691
if len(self.shared_dims) == 0:
9792
return PairwiseKernel(base_cov)
9893

9994
else: # Some paired dims
100-
# Need to make an extra shared dim covariance module
101-
self.dim = len(self.shared_dims)
102-
orig_active_dims = self.active_dims
103-
self.active_dims = self.shared_dims
104-
self.stimuli_per_trial = 1
105-
106-
shared_cov = super()._make_covar_module()
107-
108-
self.dim = original_dim
109-
self.stimuli_per_trial = 1
110-
self.active_dims = orig_active_dims
95+
# Again temporary attributes
96+
with temporary_attributes(
97+
self,
98+
dim=len(self.shared_dims),
99+
active_dims=self.shared_dims,
100+
stimuli_per_trial=1,
101+
):
102+
shared_cov = super()._make_covar_module()
111103

112104
return PairwiseKernel(base_cov, active_dims=active_dims) * shared_cov
113105

aepsych/factory/utils.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,24 @@
44

55
# This source code is licensed under the license found in the
66
# LICENSE file in the root directory of this source tree.
7-
7+
from contextlib import contextmanager
88

99
__default_invgamma_concentration = 4.6
1010
__default_invgamma_rate = 1.0
1111
DEFAULT_INVGAMMA_CONC = 4.6
1212
DEFAULT_INVGAMMA_RATE = 1.0
13+
14+
15+
@contextmanager
16+
def temporary_attributes(obj, **kwargs):
17+
"""Temporarily sets attributes on an object, and restores them when the context exits."""
18+
19+
try:
20+
old_attrs = {}
21+
for attr, val in kwargs.items():
22+
old_attrs[attr] = getattr(obj, attr)
23+
setattr(obj, attr, val)
24+
yield obj
25+
finally:
26+
for attr, val in old_attrs.items():
27+
setattr(obj, attr, val)

0 commit comments

Comments
 (0)