|
6 | 6 | # pyre-strict |
7 | 7 |
|
8 | 8 | import torch |
9 | | -from ax.core.search_space import SearchSpaceDigest |
10 | 9 | from ax.generators.torch.botorch_modular.input_constructors.outcome_transform import ( |
11 | 10 | outcome_transform_argparse, |
12 | 11 | ) |
|
18 | 17 | ) |
19 | 18 | from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset |
20 | 19 | from pyre_extensions import assert_is_instance |
21 | | -from torch import Tensor |
22 | 20 |
|
23 | 21 |
|
24 | 22 | class DummyOutcomeTransform(OutcomeTransform): |
@@ -72,52 +70,39 @@ def test_argparse_stratified_standardize(self) -> None: |
72 | 70 | X = self.dataset.X |
73 | 71 | X[:5, 3] = 0 |
74 | 72 | X[5:, 3] = 1 |
75 | | - ssd = SearchSpaceDigest( |
76 | | - feature_names=self.dataset.feature_names, |
77 | | - bounds=[(0.0, 1.0)] * 3 + [(0.0, 2.0)], |
78 | | - task_features=[3], |
79 | | - target_values={3: 1}, |
80 | | - ) |
81 | 73 | mt_dataset = MultiTaskDataset.from_joint_dataset( |
82 | 74 | dataset=self.dataset, |
83 | 75 | task_feature_index=3, |
84 | 76 | target_task_value=1, |
85 | 77 | ) |
86 | 78 | outcome_transform_kwargs_a = outcome_transform_argparse( |
87 | | - StratifiedStandardize, |
88 | | - dataset=mt_dataset, |
89 | | - search_space_digest=ssd, |
| 79 | + StratifiedStandardize, dataset=mt_dataset |
90 | 80 | ) |
91 | | - options_b = {"stratification_idx": 2, "default_task_value": 4} |
| 81 | + options_b = { |
| 82 | + "stratification_idx": 2, |
| 83 | + "task_values": torch.tensor([0, 3]), |
| 84 | + } |
92 | 85 | outcome_transform_kwargs_b = outcome_transform_argparse( |
93 | 86 | StratifiedStandardize, |
94 | 87 | dataset=mt_dataset, |
95 | 88 | outcome_transform_options=options_b, |
96 | | - search_space_digest=ssd, |
97 | 89 | ) |
98 | 90 | expected_options_a = { |
99 | 91 | "stratification_idx": 3, |
100 | | - "observed_task_values": torch.tensor([0, 1], dtype=torch.long), |
101 | | - "all_task_values": torch.tensor([0, 1, 2], dtype=torch.long), |
102 | | - "default_task_value": 1, |
103 | | - } |
104 | | - expected_options_b = { |
105 | | - "stratification_idx": 2, |
106 | | - "observed_task_values": torch.tensor([0, 1], dtype=torch.long), |
107 | | - "all_task_values": torch.tensor([0, 1, 2], dtype=torch.long), |
108 | | - "default_task_value": 4, |
| 92 | + "task_values": torch.tensor([0, 1]), |
109 | 93 | } |
110 | 94 | for expected_options, actual_options in zip( |
111 | | - (expected_options_a, expected_options_b), |
| 95 | + (expected_options_a, options_b), |
112 | 96 | (outcome_transform_kwargs_a, outcome_transform_kwargs_b), |
113 | 97 | ): |
114 | | - self.assertEqual(len(actual_options), 4) |
115 | | - for k in ("stratification_idx", "stratification_idx"): |
116 | | - self.assertEqual(actual_options[k], expected_options[k]) |
117 | | - for k in ("observed_task_values", "all_task_values"): |
118 | | - self.assertTrue( |
119 | | - torch.equal( |
120 | | - actual_options[k], |
121 | | - assert_is_instance(expected_options[k], Tensor), |
122 | | - ) |
| 98 | + self.assertEqual(len(actual_options), 2) |
| 99 | + self.assertEqual( |
| 100 | + actual_options["stratification_idx"], |
| 101 | + expected_options["stratification_idx"], |
| 102 | + ) |
| 103 | + self.assertTrue( |
| 104 | + torch.equal( |
| 105 | + actual_options["task_values"], |
| 106 | + assert_is_instance(expected_options["task_values"], torch.Tensor), |
123 | 107 | ) |
| 108 | + ) |
0 commit comments