1
1
from __future__ import annotations
2
2
3
3
from abc import ABC
4
- from typing import Any , Literal , get_args , get_type_hints
4
+ from typing import Annotated , Any , Generic , Literal , get_args , get_origin
5
5
6
- from pydantic import BaseModel , ConfigDict
6
+ from pydantic import BaseModel , ConfigDict , GetPydanticSchema
7
+ from pydantic ._internal ._generics import get_origin as get_model_origin # type: ignore[import]
8
+ from pydantic_core import core_schema
7
9
from typing_extensions import ( # noqa: UP035
8
10
LiteralString ,
9
- TypeIs ,
11
+ TypeAlias ,
12
+ TypeGuard ,
13
+ TypeVar ,
10
14
override ,
11
15
)
12
16
@@ -33,12 +37,58 @@ def model_dump_json(self, **kwargs: Any) -> str:
33
37
return super ().model_dump_json (** kwargs )
34
38
35
39
36
- def is_literal_str_type (value : object | None ) -> TypeIs [LiteralString ]:
37
- """Check if a type is a Literal type with string values."""
40
+ # We do this to get the typing module's _LiteralGenericAlias type, which is not formally exported.
41
+ _LiteralStrGenericAlias : TypeAlias = type (Literal ["whatever" ]) # type: ignore[valid-type] # noqa: UP040
42
+ """A generic alias for a Literal type used for internal mechanisms of this module.
43
+
44
+ This is opposed to LiteralStrGenericAlias which is used for typing.
45
+ """
46
+
47
+
48
+ # Set this variable here to call the function just once.
49
+ _pydantic_str_schema = core_schema .str_schema ()
50
+
51
+ GetPydanticStrSchema = GetPydanticSchema (lambda _ts , handler : handler (_pydantic_str_schema ))
52
+ """A function that returns a Pydantic schema for a string type."""
53
+
54
+ PydanticLiteralStrGenericAlias : TypeAlias = Annotated [ # type: ignore[valid-type] # noqa: UP040
55
+ _LiteralStrGenericAlias ,
56
+ GetPydanticStrSchema ,
57
+ ]
58
+ """A Pydantic-compatible generic alias for a Literal type.
59
+
60
+ Pydantic will treat a field of this type as a string schema, while static type checkers
61
+ will still treat it as a _LiteralGenericAlias type.
62
+
63
+ Even if a subclass of EventBase uses a Literal with multiple string values,
64
+ an event message will only ever have one of those values in the event field,
65
+ and so we don't need to handle this with a more complex Pydantic schema.
66
+ """
67
+
68
+
69
+ # This type alias is used to handle static type checking accurately while still conveying that
70
+ # a value is expected to be a Literal with string type args.
71
+ LiteralStrGenericAlias : TypeAlias = Annotated [ # noqa: UP040
72
+ LiteralString ,
73
+ GetPydanticStrSchema ,
74
+ ]
75
+ """Type alias for a generic literal string type that is compatible with Pydantic."""
76
+
77
+
78
+ # covariant=True is used to allow subclasses of EventBase to be used in place of the base class.
79
+ LiteralEventName_co = TypeVar ("LiteralEventName_co" , bound = PydanticLiteralStrGenericAlias , default = PydanticLiteralStrGenericAlias , covariant = True )
80
+ """Type variable for a Literal type with string args."""
81
+
82
+
83
+ def is_literal_str_generic_alias_type (value : object | None ) -> TypeGuard [LiteralStrGenericAlias ]:
84
+ """Check if a type is a concrete Literal type with string args."""
38
85
if value is None :
39
86
return False
40
87
41
- event_field_base_type = getattr (value , "__origin__" , None )
88
+ if isinstance (value , TypeVar ):
89
+ return False
90
+
91
+ event_field_base_type = get_origin (value )
42
92
43
93
if event_field_base_type is not Literal :
44
94
return False
@@ -48,12 +98,10 @@ def is_literal_str_type(value: object | None) -> TypeIs[LiteralString]:
48
98
49
99
## EventBase implementation model of the Stream Deck Plugin SDK events.
50
100
51
- class EventBase (ConfiguredBaseModel , ABC ):
101
+ class EventBase (ConfiguredBaseModel , ABC , Generic [ LiteralEventName_co ] ):
52
102
"""Base class for event models that represent Stream Deck Plugin SDK events."""
53
- # Configure to use the docstrings of the fields as the field descriptions.
54
- model_config = ConfigDict (use_attribute_docstrings = True , serialize_by_alias = True )
55
103
56
- event : str
104
+ event : LiteralEventName_co
57
105
"""Name of the event used to identify what occurred.
58
106
59
107
Subclass models must define this field as a Literal type with the event name string that the model represents.
@@ -63,25 +111,30 @@ def __init_subclass__(cls, **kwargs: Any) -> None:
63
111
"""Validate that the event field is a Literal[str] type."""
64
112
super ().__init_subclass__ (** kwargs )
65
113
66
- model_event_type = cls .get_event_type_annotations ()
114
+ # This is a GenericAlias (likely used in the subclass definition, i.e. `class ConcreteEvent(EventBase[Literal["event_name"]]):`) which is technically a subclass.
115
+ # We can safely ignore this case, as we only want to validate the concrete subclass itself (`ConscreteEvent`).
116
+ if get_model_origin (cls ) is None :
117
+ return
118
+
119
+ model_event_type = cls .__event_type__ ()
67
120
68
- if not is_literal_str_type (model_event_type ):
121
+ if not is_literal_str_generic_alias_type (model_event_type ):
69
122
msg = f"The event field annotation must be a Literal[str] type. Given type: { model_event_type } "
70
123
raise TypeError (msg )
71
124
72
125
@classmethod
73
- def get_event_type_annotations (cls ) -> type [object ]:
126
+ def __event_type__ (cls ) -> type [object ]:
74
127
"""Get the type annotations of the subclass model's event field."""
75
- return get_type_hints ( cls ) ["event" ]
128
+ return cls . model_fields ["event" ]. annotation # type: ignore[index ]
76
129
77
130
@classmethod
78
- def get_model_event_name (cls ) -> tuple [str , ...]:
131
+ def get_model_event_names (cls ) -> tuple [str , ...]:
79
132
"""Get the value of the subclass model's event field Literal annotation."""
80
- model_event_type = cls .get_event_type_annotations ()
133
+ model_event_type = cls .__event_type__ ()
81
134
82
135
# Ensure that the event field annotation is a Literal type.
83
- if not is_literal_str_type (model_event_type ):
84
- msg = "The ` event` field annotation of an Event model must be a Literal[str] type."
136
+ if not is_literal_str_generic_alias_type (model_event_type ):
137
+ msg = f "The event field annotation of an Event model must be a Literal[str] type. Given type: { model_event_type } "
85
138
raise TypeError (msg )
86
139
87
140
return get_args (model_event_type )
0 commit comments