Skip to content

Commit cd9fac8

Browse files
Steve Mandalafacebook-github-bot
Steve Mandala
authored andcommitted
Create helper for defining arm
Summary: Currently, when users want to define status quo arms in Ax, they have to enummerate every parameter key and, assuming they want defaults, set it as None (e.g. https://fburl.com/wiki/zis37qp0). To simplify this, we introduce a helper utility that constructs a null parameter arm by default, but allows users to override specific params if necessary. Reviewed By: sdsingh Differential Revision: D20491298 fbshipit-source-id: 1aecbf6a5f15b27eb0967a78d35f02ffbb365cf0
1 parent 7d1e308 commit cd9fac8

File tree

2 files changed

+49
-4
lines changed

2 files changed

+49
-4
lines changed

ax/core/search_space.py

+24-4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
SumConstraint,
1818
)
1919
from ax.core.types import TParameterization
20+
from ax.utils.common.typeutils import not_none
2021

2122

2223
class SearchSpace(Base):
@@ -238,10 +239,29 @@ def out_of_design_arm(self) -> Arm:
238239
Returns:
239240
New arm w/ null parameter values.
240241
"""
241-
parameters = {}
242-
for p_name in self.parameters.keys():
243-
parameters[p_name] = None
244-
return Arm(parameters)
242+
return self.construct_arm()
243+
244+
def construct_arm(
245+
self, parameters: Optional[TParameterization] = None, name: Optional[str] = None
246+
) -> Arm:
247+
"""Construct new arm using given parameters and name. Any
248+
missing parameters fallback to the experiment defaults,
249+
represented as None
250+
"""
251+
final_parameters: TParameterization = {k: None for k in self.parameters.keys()}
252+
if parameters is not None:
253+
# Validate the param values
254+
for p_name, p_value in parameters.items():
255+
if p_name not in self.parameters:
256+
raise ValueError(f"`{p_name}` does not exist in search space.")
257+
if p_value is not None and not self.parameters[p_name].validate(
258+
p_value
259+
):
260+
raise ValueError(
261+
f"`{p_value}` is not a valid value for parameter {p_name}."
262+
)
263+
final_parameters.update(not_none(parameters))
264+
return Arm(parameters=final_parameters, name=name)
245265

246266
def clone(self) -> "SearchSpace":
247267
return SearchSpace(

ax/core/tests/test_search_space.py

+25
Original file line numberDiff line numberDiff line change
@@ -302,3 +302,28 @@ def testOutOfDesignArm(self):
302302
arm1_nones = [p is None for p in arm1.parameters.values()]
303303
self.assertTrue(all(arm1_nones))
304304
self.assertTrue(arm1 == arm2)
305+
306+
def testConstructArm(self):
307+
# Test constructing an arm of default values
308+
arm = self.ss1.construct_arm(name="test")
309+
self.assertEqual(arm.name, "test")
310+
for p_name in self.ss1.parameters.keys():
311+
self.assertTrue(p_name in arm.parameters)
312+
self.assertEqual(arm.parameters[p_name], None)
313+
314+
# Test constructing an arm with a custom value
315+
arm = self.ss1.construct_arm({"a": 1.0})
316+
for p_name in self.ss1.parameters.keys():
317+
self.assertTrue(p_name in arm.parameters)
318+
if p_name == "a":
319+
self.assertEqual(arm.parameters[p_name], 1.0)
320+
else:
321+
self.assertEqual(arm.parameters[p_name], None)
322+
323+
# Test constructing an arm with a bad param name
324+
with self.assertRaises(ValueError):
325+
self.ss1.construct_arm({"IDONTEXIST_a": 1.0})
326+
327+
# Test constructing an arm with a bad param name
328+
with self.assertRaises(ValueError):
329+
self.ss1.construct_arm({"a": "notafloat"})

0 commit comments

Comments
 (0)