Skip to content

Commit 7dc9037

Browse files
authored
feat: support integer, boolean, and string button types (#253)
1 parent c2b9d7b commit 7dc9037

File tree

3 files changed

+118
-36
lines changed

3 files changed

+118
-36
lines changed

aidial_sdk/chat_completion/form.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
_T = TypeVar("_T")
2222

2323

24+
_SUPPORTED_BUTTON_TYPES = ["number", "integer", "boolean", "string"]
25+
26+
2427
@dataclass
2528
class Button(Generic[_T]):
2629
const: _T
@@ -118,17 +121,11 @@ def _handle_buttons_extension(schema: Dict[str, Any]) -> None:
118121
prop["dial:widget"] = "buttons"
119122
prop["oneOf"] = button_schemas
120123

121-
# NOTE: The meta schema of the DIAL forms only supports
122-
# 'number' type, so we convert 'integer' to 'number'.
123-
# Could be removed once this restriction is lifted.
124-
if prop["type"] == "integer":
125-
prop["type"] = "number"
126-
127-
# NOTE: The meta schema of the DIAL forms only supports 'number' type.
128-
# Could be removed once this restriction is lifted.
129-
if prop["type"] != "number":
124+
if prop["type"] not in _SUPPORTED_BUTTON_TYPES:
125+
ts = ", ".join(f"{ty!r}" for ty in _SUPPORTED_BUTTON_TYPES)
130126
raise ValueError(
131-
f"Button value must be a number. However, field {schema['title']}.{prop_name} has type {prop['type']!r}."
127+
f"Button value must be a one of the following types: {ts}. "
128+
f"However, field {schema['title']}.{prop_name} has type {prop['type']!r}."
132129
)
133130

134131

tests/examples/test_tic_tac_toe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def test_ttt_configuration():
1919
"title": "Player",
2020
"description": "Select tic-tac-toe player",
2121
"enum": [1, 2],
22-
"type": "number",
22+
"type": "integer",
2323
"dial:widget": "buttons",
2424
"oneOf": [
2525
{

tests/test_form.py

Lines changed: 110 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def test_configuration_optional_button_schema():
7676
"properties": {
7777
"int_button_field": {
7878
"title": "Int Button Field",
79-
"type": "number",
79+
"type": "integer",
8080
"dial:widget": "buttons",
8181
"oneOf": [
8282
{
@@ -134,7 +134,7 @@ def test_configuration_one_button_schema():
134134
"int_button_field": {
135135
"title": "Integer Button field",
136136
"description": "Pick a button",
137-
"type": "number",
137+
"type": "integer",
138138
"dial:widget": "buttons",
139139
"oneOf": [
140140
{
@@ -186,6 +186,40 @@ def test_configuration_parsing_success():
186186
)
187187

188188

189+
def test_configuration_string_type_parsing_success():
190+
class _Conf(BaseModel, metaclass=FormMetaclass):
191+
button_field: str = Field(
192+
buttons=[
193+
Button(const="10", title="Title1"),
194+
Button(const="20", title="Title2"),
195+
],
196+
)
197+
198+
assert _Conf.parse_obj({"button_field": "10"}) == _Conf(button_field="10")
199+
200+
201+
def test_configuration_string_type_parsing_fail():
202+
class _Conf(BaseModel, metaclass=FormMetaclass):
203+
button_field: str = Field(
204+
buttons=[
205+
Button(const="10", title="Title1"),
206+
Button(const="20", title="Title2"),
207+
],
208+
)
209+
210+
with pytest.raises(ValidationError) as e:
211+
_Conf.parse_obj({"button_field": "30"})
212+
213+
assert e.value.errors() == [
214+
{
215+
"ctx": {"given": "30", "permitted": ("10", "20")},
216+
"loc": ("button_field",),
217+
"msg": "unexpected value; permitted: '10', '20'",
218+
"type": "value_error.const",
219+
}
220+
]
221+
222+
189223
def test_configuration_parsing_one_button_fail():
190224
conf = {
191225
"int_field": 10,
@@ -332,7 +366,7 @@ class Conf(BaseModel):
332366
},
333367
],
334368
"title": "Buttons Field",
335-
"type": "number",
369+
"type": "integer",
336370
},
337371
"int_field": {
338372
"title": "Int Field",
@@ -397,7 +431,7 @@ class Conf(BaseModel):
397431
},
398432
],
399433
"title": "Buttons Field",
400-
"type": "number",
434+
"type": "integer",
401435
},
402436
"int_field": {
403437
"title": "Int Field",
@@ -488,10 +522,12 @@ class Conf(BaseModel):
488522
assert parsed_conf.buttons_field == 10
489523

490524

491-
def test_dynamic_configuration_two_buttons():
525+
def test_dynamic_configuration_four_buttons():
492526
class Conf(BaseModel):
493527
int_button_field: int
494528
float_button_field: float
529+
bool_button_field: bool
530+
str_button_field: str
495531

496532
conf = form(
497533
chat_message_input_disabled=True,
@@ -509,6 +545,20 @@ class Conf(BaseModel):
509545
Button(const=40.1, title="Title4"),
510546
],
511547
),
548+
bool_button_field=Field(
549+
description="Switch",
550+
buttons=[
551+
Button(const=True, title="On"),
552+
Button(const=False, title="Off"),
553+
],
554+
),
555+
str_button_field=Field(
556+
description="Flavour",
557+
buttons=[
558+
Button(const="vanilla", title="Vanilla flavour"),
559+
Button(const="chocolate", title="Chocolate flavour"),
560+
],
561+
),
512562
)(Conf)
513563

514564
actual_schema = conf.schema()
@@ -541,7 +591,7 @@ class Conf(BaseModel):
541591
},
542592
],
543593
"title": "Int Button Field",
544-
"type": "number",
594+
"type": "integer",
545595
},
546596
"float_button_field": {
547597
"dial:widget": "buttons",
@@ -569,35 +619,70 @@ class Conf(BaseModel):
569619
"title": "Float Button Field",
570620
"type": "number",
571621
},
622+
"bool_button_field": {
623+
"dial:widget": "buttons",
624+
"description": "Switch",
625+
"oneOf": [
626+
{
627+
"const": True,
628+
"dial:widgetOptions": {
629+
"confirmationMessage": None,
630+
"populateText": None,
631+
"submit": False,
632+
},
633+
"title": "On",
634+
},
635+
{
636+
"const": False,
637+
"dial:widgetOptions": {
638+
"confirmationMessage": None,
639+
"populateText": None,
640+
"submit": False,
641+
},
642+
"title": "Off",
643+
},
644+
],
645+
"title": "Bool Button Field",
646+
"type": "boolean",
647+
},
648+
"str_button_field": {
649+
"dial:widget": "buttons",
650+
"description": "Flavour",
651+
"oneOf": [
652+
{
653+
"const": "vanilla",
654+
"dial:widgetOptions": {
655+
"confirmationMessage": None,
656+
"populateText": None,
657+
"submit": False,
658+
},
659+
"title": "Vanilla flavour",
660+
},
661+
{
662+
"const": "chocolate",
663+
"dial:widgetOptions": {
664+
"confirmationMessage": None,
665+
"populateText": None,
666+
"submit": False,
667+
},
668+
"title": "Chocolate flavour",
669+
},
670+
],
671+
"title": "Str Button Field",
672+
"type": "string",
673+
},
572674
},
573675
"required": [
574676
"int_button_field",
575677
"float_button_field",
678+
"bool_button_field",
679+
"str_button_field",
576680
],
577681
"title": "_Conf",
578682
"type": "object",
579683
}
580684

581685

582-
def test_configuration_invalid_button_type():
583-
with pytest.raises(ValueError) as e:
584-
585-
class _Conf(BaseModel, metaclass=FormMetaclass):
586-
button_field: str = Field(
587-
buttons=[
588-
Button(const="10", title="Title1"),
589-
Button(const="20", title="Title2"),
590-
],
591-
)
592-
593-
_Conf.schema()
594-
595-
assert (
596-
str(e.value)
597-
== "Button value must be a number. However, field _Conf.button_field has type 'string'."
598-
)
599-
600-
601686
def test_configuration_missing_buttons():
602687
with pytest.raises(ValueError) as e:
603688

0 commit comments

Comments
 (0)