Skip to content

Commit 0c49dab

Browse files
authored
fix: generate StrEnum types for enums (#134)
1 parent 4d658a5 commit 0c49dab

File tree

4 files changed

+16
-5
lines changed

4 files changed

+16
-5
lines changed

.github/workflows/update-a2a-types.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ jobs:
5858
--use-default-kwarg \
5959
--use-one-literal-as-default \
6060
--class-name A2A \
61-
--use-standard-collections
61+
--use-standard-collections \
62+
--use-subclass-enum
6263
echo "Codegen finished."
6364
6465
- name: Create Pull Request with Updates

development.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,6 @@ uv run datamodel-codegen \
1717
--use-default-kwarg \
1818
--use-one-literal-as-default \
1919
--class-name A2A \
20-
--use-standard-collections
20+
--use-standard-collections \
21+
--use-subclass-enum
2122
```

src/a2a/types.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class A2A(RootModel[Any]):
1313
root: Any
1414

1515

16-
class In(Enum):
16+
class In(str, Enum):
1717
"""
1818
The location of the API key. Valid values are "query", "header", or "cookie".
1919
"""
@@ -484,7 +484,7 @@ class JSONRPCSuccessResponse(BaseModel):
484484
"""
485485

486486

487-
class Role(Enum):
487+
class Role(str, Enum):
488488
"""
489489
Message sender's role
490490
"""
@@ -731,7 +731,7 @@ class TaskResubscriptionRequest(BaseModel):
731731
"""
732732

733733

734-
class TaskState(Enum):
734+
class TaskState(str, Enum):
735735
"""
736736
Represents the possible states of a Task.
737737
"""

tests/test_types.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1487,3 +1487,12 @@ def test_a2a_error_validation_and_serialization() -> None:
14871487
invalid_data: dict[str, Any] = {'code': -99999, 'message': 'Unknown error'}
14881488
with pytest.raises(ValidationError):
14891489
A2AError.model_validate(invalid_data)
1490+
1491+
1492+
def test_subclass_enums() -> None:
1493+
"""validate subtype enum types"""
1494+
assert "cookie" == In.cookie
1495+
1496+
assert "user" == Role.user
1497+
1498+
assert "working" == TaskState.working

0 commit comments

Comments
 (0)