@@ -57,41 +57,80 @@ def test_ThompsonSampler(self) -> None:
5757 self .assertEqual (len (gen_metadata ["arms_to_weights" ]), 4 )
5858 self .assertEqual (gen_metadata ["best_x" ], arms [0 ])
5959
60+ def test_ThompsonSamplerWeightConfigs (self ) -> None :
61+ for label , min_weight , uniform_weights , expected_arms , expected_weights in [
62+ (
63+ "min_weight=0.01" ,
64+ 0.01 ,
65+ False ,
66+ [[4 , 4 ], [3 , 3 ], [2 , 2 ]],
67+ [3 * i for i in [0.725 , 0.225 , 0.05 ]],
68+ ),
69+ (
70+ "uniform_weights" ,
71+ 0.0 ,
72+ True ,
73+ [[4 , 4 ], [3 , 3 ], [2 , 2 ]],
74+ [1.0 , 1.0 , 1.0 ],
75+ ),
76+ ]:
77+ with self .subTest (config = label ):
78+ np .random .seed (0 )
79+ generator = ThompsonSampler (
80+ min_weight = min_weight , uniform_weights = uniform_weights
81+ )
82+ generator .fit (
83+ Xs = self .Xs ,
84+ Ys = self .Ys ,
85+ Yvars = self .Yvars ,
86+ parameter_values = self .parameter_values ,
87+ outcome_names = self .outcome_names ,
88+ )
89+ arms , weights , _ = generator .gen (
90+ n = 3 ,
91+ parameter_values = self .parameter_values ,
92+ objective_weights = np .ones (1 ),
93+ )
94+ self .assertEqual (arms , expected_arms )
95+ for weight , expected_weight in zip (weights , expected_weights ):
96+ self .assertAlmostEqual (weight , expected_weight , 1 )
97+
6098 def test_ThompsonSamplerValidation (self ) -> None :
6199 generator = ThompsonSampler (min_weight = 0.01 )
62100
63- # all Xs are not the same
64- with self .assertRaises (ValueError ):
65- generator .fit (
66- Xs = [[[1 , 1 ], [2 , 2 ], [3 , 3 ], [4 , 4 ]], [[1 , 1 ], [2 , 2 ], [4 , 4 ]]],
67- Ys = self .Ys ,
68- Yvars = self .Yvars ,
69- parameter_values = self .parameter_values ,
70- outcome_names = self .outcome_names ,
71- )
101+ with self . subTest ( case = "mismatched_Xs" ):
102+ with self .assertRaises (ValueError ):
103+ generator .fit (
104+ Xs = [[[1 , 1 ], [2 , 2 ], [3 , 3 ], [4 , 4 ]], [[1 , 1 ], [2 , 2 ], [4 , 4 ]]],
105+ Ys = self .Ys ,
106+ Yvars = self .Yvars ,
107+ parameter_values = self .parameter_values ,
108+ outcome_names = self .outcome_names ,
109+ )
72110
73- # multiple observations per parameterization
74- with self .assertRaises (ValueError ):
111+ with self .subTest (case = "duplicate_parameterizations" ):
112+ with self .assertRaises (ValueError ):
113+ generator .fit (
114+ Xs = [[[1 , 1 ], [2 , 2 ], [2 , 2 ]]],
115+ Ys = self .Ys ,
116+ Yvars = self .Yvars ,
117+ parameter_values = self .parameter_values ,
118+ outcome_names = self .outcome_names ,
119+ )
120+
121+ with self .subTest (case = "similar_but_different_observations" ):
122+ # these are not the same observations, so should not error
75123 generator .fit (
76- Xs = [[[1 , 1 ], [2 , 2 ], [2 , 2 ]]],
124+ Xs = [[[1 , 1 ], [2.0 , 2 ], [2 , 2 ]]],
77125 Ys = self .Ys ,
78126 Yvars = self .Yvars ,
79127 parameter_values = self .parameter_values ,
80128 outcome_names = self .outcome_names ,
81129 )
82130
83- # these are not the same observations, so should not error
84- generator .fit (
85- Xs = [[[1 , 1 ], [2.0 , 2 ], [2 , 2 ]]],
86- Ys = self .Ys ,
87- Yvars = self .Yvars ,
88- parameter_values = self .parameter_values ,
89- outcome_names = self .outcome_names ,
90- )
91-
92- # requires objective weights
93- with self .assertRaises (ValueError ):
94- generator .gen (5 , self .parameter_values , objective_weights = None )
131+ with self .subTest (case = "missing_objective_weights" ):
132+ with self .assertRaises (ValueError ):
133+ generator .gen (5 , self .parameter_values , objective_weights = None )
95134
96135 def test_ThompsonSamplerTopKError (self ) -> None :
97136 generator = ThompsonSampler (topk = 5 )
@@ -156,45 +195,6 @@ def test_TopTwo_alters_weights_vs_TopOne(self) -> None:
156195 # 4) Monotonicity in the final TTTS distribution still holds
157196 self .assertTrue (full_w2 [3 ] > full_w2 [2 ] > full_w2 [1 ] > full_w2 [0 ])
158197
159- def test_ThompsonSamplerMinWeight (self ) -> None :
160- np .random .seed (0 )
161- generator = ThompsonSampler (min_weight = 0.01 )
162- generator .fit (
163- Xs = self .Xs ,
164- Ys = self .Ys ,
165- Yvars = self .Yvars ,
166- parameter_values = self .parameter_values ,
167- outcome_names = self .outcome_names ,
168- )
169- arms , weights , _ = generator .gen (
170- n = 3 ,
171- parameter_values = self .parameter_values ,
172- objective_weights = np .ones (1 ),
173- )
174- self .assertEqual (arms , [[4 , 4 ], [3 , 3 ], [2 , 2 ]])
175- for weight , expected_weight in zip (
176- weights , [3 * i for i in [0.725 , 0.225 , 0.05 ]]
177- ):
178- self .assertAlmostEqual (weight , expected_weight , 1 )
179-
180- def test_ThompsonSamplerUniformWeights (self ) -> None :
181- generator = ThompsonSampler (min_weight = 0.0 , uniform_weights = True )
182- generator .fit (
183- Xs = self .Xs ,
184- Ys = self .Ys ,
185- Yvars = self .Yvars ,
186- parameter_values = self .parameter_values ,
187- outcome_names = self .outcome_names ,
188- )
189- arms , weights , _ = generator .gen (
190- n = 3 ,
191- parameter_values = self .parameter_values ,
192- objective_weights = np .ones (1 ),
193- )
194- self .assertEqual (arms , [[4 , 4 ], [3 , 3 ], [2 , 2 ]])
195- for weight , expected_weight in zip (weights , [1.0 , 1.0 , 1.0 ]):
196- self .assertAlmostEqual (weight , expected_weight , 1 )
197-
198198 def test_ThompsonSamplerInfeasible (self ) -> None :
199199 generator = ThompsonSampler (min_weight = 0.9 )
200200 generator .fit (
@@ -302,9 +302,12 @@ def test_ThompsonSamplerNonPositiveN(self) -> None:
302302 outcome_names = self .outcome_names ,
303303 )
304304 for n in (- 1 , 0 ):
305- with self .assertRaisesRegex (ValueError , "ThompsonSampler requires n > 0" ):
306- generator .gen (
307- n = n ,
308- parameter_values = self .parameter_values ,
309- objective_weights = np .ones (1 ),
310- )
305+ with self .subTest (n = n ):
306+ with self .assertRaisesRegex (
307+ ValueError , "ThompsonSampler requires n > 0"
308+ ):
309+ generator .gen (
310+ n = n ,
311+ parameter_values = self .parameter_values ,
312+ objective_weights = np .ones (1 ),
313+ )
0 commit comments