Skip to content

Commit 17f1e16

Browse files
committed
feat(generic): add generic type parameters and simplify message architecture
This change replaces manual schema declarations with generic type parameters on AsyncJsonWebsocketConsumer, allowing developers to use AsyncJsonWebsocketConsumer[MyMessageType] instead of defining INCOMING_MESSAGE_SCHEMA. It eliminates complex container classes in favor of direct union types, uses Pydantic's TypeAdapter for validation, and provides proper typing for message handlers while reducing code complexity and improving type safety. BREAKING CHANGE:
1 parent f59664e commit 17f1e16

File tree

20 files changed

+277
-627
lines changed

20 files changed

+277
-627
lines changed

chanx/generic/websocket.py

Lines changed: 99 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,19 @@
2222
"""
2323

2424
import asyncio
25+
import sys
2526
import uuid
2627
from abc import ABC, abstractmethod
2728
from collections.abc import Iterable, Sequence
2829
from types import ModuleType
2930
from typing import (
31+
Annotated,
3032
Any,
3133
Generic,
3234
Literal,
3335
cast,
36+
get_args,
37+
get_origin,
3438
)
3539

3640
from channels.generic.websocket import (
@@ -47,15 +51,14 @@
4751

4852
import structlog
4953
from asgiref.typing import WebSocketConnectEvent, WebSocketDisconnectEvent
50-
from pydantic import ValidationError
51-
from typing_extensions import TypeVar
54+
from pydantic import Field, TypeAdapter, ValidationError
55+
from typing_extensions import TypeVar, get_original_bases
5256

5357
from chanx.constants import MISSING_PYHUMPS_ERROR
5458
from chanx.generic.authenticator import ChanxWebsocketAuthenticator, QuerysetLike
5559
from chanx.messages.base import (
56-
BaseIncomingMessage,
60+
BaseGroupMessage,
5761
BaseMessage,
58-
BaseOutgoingGroupMessage,
5962
)
6063
from chanx.messages.outgoing import (
6164
AuthenticationMessage,
@@ -75,10 +78,16 @@
7578
humps = cast(ModuleType, None) # pragma: no cover
7679

7780

78-
_M = TypeVar("_M", bound=Model, default=Model)
81+
IC = TypeVar("IC", bound=BaseMessage) # Incoming messages
82+
OG = TypeVar(
83+
"OG", bound=BaseGroupMessage | None, default=None
84+
) # Outgoing group messages
85+
M = TypeVar("M", bound=Model | None, default=None) # Object model
7986

8087

81-
class AsyncJsonWebsocketConsumer(Generic[_M], BaseAsyncJsonWebsocketConsumer, ABC):
88+
class AsyncJsonWebsocketConsumer(
89+
Generic[IC, OG, M], BaseAsyncJsonWebsocketConsumer, ABC
90+
):
8291
"""
8392
Base class for asynchronous JSON WebSocket consumers with authentication and permissions.
8493
@@ -102,9 +111,6 @@ class AsyncJsonWebsocketConsumer(Generic[_M], BaseAsyncJsonWebsocketConsumer, AB
102111
log_sent_message: Whether to log sent messages
103112
log_ignored_actions: Message actions that should not be logged
104113
send_authentication_message: Whether to send auth status after connection
105-
INCOMING_MESSAGE_SCHEMA: Pydantic model class for validating incoming messages
106-
OUTGOING_GROUP_MESSAGE_SCHEMA: Pydantic model class for validating group broadcast messages,
107-
required when using send_group_message with kind="message"
108114
"""
109115

110116
# Authentication attributes
@@ -128,8 +134,46 @@ class AsyncJsonWebsocketConsumer(Generic[_M], BaseAsyncJsonWebsocketConsumer, AB
128134
send_authentication_message: bool | None = None
129135

130136
# Message schemas
131-
INCOMING_MESSAGE_SCHEMA: type[BaseIncomingMessage]
132-
OUTGOING_GROUP_MESSAGE_SCHEMA: type[BaseOutgoingGroupMessage]
137+
_INCOMING_MESSAGE_SCHEMA: IC
138+
_OUTGOING_GROUP_MESSAGE_SCHEMA: OG
139+
140+
# Object instance
141+
obj: M
142+
143+
def __init_subclass__(cls, *args: Any, **kwargs: Any):
144+
super().__init_subclass__(*args, **kwargs)
145+
146+
# Extract the actual type from Generic parameters
147+
orig_bases = get_original_bases(cls)
148+
for base in orig_bases:
149+
if base is AsyncJsonWebsocketConsumer:
150+
raise ValueError(
151+
f"Class {cls.__name__!r} must specify at least the incoming message type as a generic parameter. "
152+
f"Hint: class {cls.__name__}(AsyncJsonWebsocketConsumer[YourMessageType])"
153+
)
154+
if get_origin(base) is AsyncJsonWebsocketConsumer:
155+
# Workaround for TypeVar defaults handling differences across Python versions:
156+
# - In Python 3.10, get_args() only returns non-default types
157+
# - In Python 3.11+, get_args() returns all type arguments including defaults
158+
# We create a fixed-size array and populate it with available type arguments
159+
if sys.version_info < (3, 11): # pragma: no cover
160+
# Generic part of AsyncJsonWebsocketConsumer
161+
generic_types = get_original_bases(AsyncJsonWebsocketConsumer)[0]
162+
type_var_vals: list[Any] = [None] * len(get_args(generic_types))
163+
for i, var in enumerate(get_args(base)):
164+
if i < len(type_var_vals) and var is not None:
165+
type_var_vals[i] = var
166+
167+
incoming_message_schema, outgoing_group_message_schema, _model = (
168+
type_var_vals
169+
)
170+
else:
171+
incoming_message_schema, outgoing_group_message_schema, _model = (
172+
get_args(base)
173+
)
174+
cls._INCOMING_MESSAGE_SCHEMA = incoming_message_schema
175+
cls._OUTGOING_GROUP_MESSAGE_SCHEMA = outgoing_group_message_schema
176+
break
133177

134178
def __init__(self, *args: Any, **kwargs: Any) -> None:
135179
"""
@@ -143,7 +187,22 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
143187
ValueError: If INCOMING_MESSAGE_SCHEMA is not set
144188
"""
145189
super().__init__(*args, **kwargs)
146-
# Initialize configuration from settings if not set
190+
191+
# Load configuration and validate
192+
self._load_configuration_from_settings()
193+
self._setup_message_adapters()
194+
195+
# Create authenticator
196+
self.authenticator = self._create_authenticator()
197+
198+
# Initialize instance attributes
199+
self._initialize_instance_attributes()
200+
201+
# Validate optional dependencies
202+
self._validate_optional_dependencies()
203+
204+
def _load_configuration_from_settings(self) -> None:
205+
"""Load configuration values from settings if not already set."""
147206
if self.send_completion is None:
148207
self.send_completion = chanx_settings.SEND_COMPLETION
149208

@@ -159,27 +218,40 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
159218
if self.log_ignored_actions is None:
160219
self.log_ignored_actions = chanx_settings.LOG_IGNORED_ACTIONS
161220

162-
self.ignore_actions: set[str] = (
163-
set(self.log_ignored_actions) if self.log_ignored_actions else set()
164-
)
165-
166221
if self.send_authentication_message is None:
167222
self.send_authentication_message = (
168223
chanx_settings.SEND_AUTHENTICATION_MESSAGE
169224
)
170225

171-
if not hasattr(self, "INCOMING_MESSAGE_SCHEMA"):
172-
raise ValueError("INCOMING_MESSAGE_SCHEMA attribute is required.")
226+
# Process ignored actions
227+
self.ignore_actions: set[str] = (
228+
set(self.log_ignored_actions) if self.log_ignored_actions else set()
229+
)
173230

174-
# Create authenticator
175-
self.authenticator = self._create_authenticator()
231+
def _setup_message_adapters(self) -> None:
232+
"""Set up Pydantic type adapters for message validation."""
233+
self.incoming_message_adapter: TypeAdapter[IC] = TypeAdapter(
234+
Annotated[
235+
self._INCOMING_MESSAGE_SCHEMA,
236+
Field(discriminator=chanx_settings.MESSAGE_ACTION_KEY),
237+
]
238+
)
176239

177-
# Initialize instance attributes
240+
self.outgoing_group_message_adapter: TypeAdapter[OG] = TypeAdapter(
241+
Annotated[
242+
self._OUTGOING_GROUP_MESSAGE_SCHEMA,
243+
Field(discriminator=chanx_settings.MESSAGE_ACTION_KEY),
244+
]
245+
)
246+
247+
def _initialize_instance_attributes(self) -> None:
248+
"""Initialize instance attributes to their default values."""
178249
self.user: User | AnonymousUser | None = None
179-
self.obj: _M | None = None
180250
self.group_name: str | None = None
181251
self.connecting: bool = False
182252

253+
def _validate_optional_dependencies(self) -> None:
254+
"""Validate that optional dependencies are available when needed."""
183255
if chanx_settings.CAMELIZE:
184256
if not humps:
185257
raise RuntimeError(MISSING_PYHUMPS_ERROR)
@@ -332,7 +404,7 @@ async def receive_json(self, content: dict[str, Any], **kwargs: Any) -> None:
332404
structlog.contextvars.reset_contextvars(**token)
333405

334406
@abstractmethod
335-
async def receive_message(self, message: BaseMessage, **kwargs: Any) -> None:
407+
async def receive_message(self, message: IC, **kwargs: Any) -> None:
336408
"""
337409
Process a validated received message.
338410
@@ -501,9 +573,8 @@ async def send_group_member(self, event: GroupMemberEvent) -> None:
501573
)
502574

503575
if kind == "message":
504-
message = self.OUTGOING_GROUP_MESSAGE_SCHEMA.model_validate(
505-
{"group_message": content}
506-
).group_message
576+
message = self.outgoing_group_message_adapter.validate_python(content)
577+
assert message is not None
507578
await self.send_message(message)
508579
else:
509580
await self.send_json(content)
@@ -527,9 +598,8 @@ async def _handle_receive_json_and_signal_complete(
527598
**kwargs: Additional keyword arguments
528599
"""
529600
try:
530-
message = self.INCOMING_MESSAGE_SCHEMA.model_validate(
531-
{"message": content}
532-
).message
601+
602+
message = self.incoming_message_adapter.validate_python(content)
533603

534604
await self.receive_message(message, **kwargs)
535605
except ValidationError as e:

chanx/messages/base.py

Lines changed: 1 addition & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,13 @@
2121
"""
2222

2323
import abc
24-
from types import UnionType
2524
from typing import (
2625
Any,
27-
ClassVar,
2826
Literal,
29-
TypeVar,
30-
Union,
31-
get_args,
3227
get_origin,
3328
)
3429

35-
from pydantic import BaseModel, ConfigDict, Field
30+
from pydantic import BaseModel, ConfigDict
3631
from typing_extensions import Unpack
3732

3833

@@ -98,128 +93,3 @@ class BaseGroupMessage(BaseMessage, abc.ABC):
9893

9994
is_mine: bool = False
10095
is_current: bool = False
101-
102-
103-
# TypeVar for the message base class type
104-
T = TypeVar("T", bound=BaseMessage)
105-
106-
107-
class MessageContainerMixin(BaseModel, abc.ABC):
108-
"""
109-
Mixin for message container classes that wrap a message type.
110-
111-
This mixin provides common validation logic for classes that contain
112-
a field with a union of message types using a discriminator.
113-
114-
Attributes:
115-
_message_field_name: Name of the field containing the message
116-
_message_base_class: Base class that all message types must inherit from
117-
"""
118-
119-
_message_field_name: ClassVar[str]
120-
_message_base_class: ClassVar[type[BaseMessage]]
121-
122-
def __init_subclass__(cls, **kwargs: Unpack[ConfigDict]):
123-
"""
124-
Validates that subclasses properly define a message field that uses
125-
a union of specified base message types for type discrimination.
126-
127-
Args:
128-
**kwargs: Configuration options for Pydantic model
129-
130-
Raises:
131-
TypeError: If required message field is missing or not of correct type
132-
"""
133-
super().__init_subclass__(**kwargs)
134-
135-
if abc.ABC in cls.__bases__:
136-
return
137-
138-
field_name = cls._message_field_name
139-
base_class = cls._message_base_class
140-
141-
try:
142-
message_field = cls.__annotations__[field_name]
143-
except (KeyError, AttributeError) as e:
144-
raise TypeError(
145-
f"Class {cls.__name__!r} must define a '{field_name}' field"
146-
) from e
147-
148-
# Check if it's a Union type
149-
origin = get_origin(message_field)
150-
if origin is Union or origin is UnionType:
151-
# Get all union members
152-
args = get_args(message_field)
153-
154-
# Validate all union members are correct base class subclasses
155-
for arg in args:
156-
if not issubclass(arg, base_class):
157-
raise TypeError(
158-
f"All union members in '{field_name}' field of {cls.__name__!r} must be "
159-
f"subclasses of {base_class.__name__}, got {arg}"
160-
)
161-
# Or a direct subclass of the base class
162-
elif not (
163-
message_field is base_class
164-
or (
165-
isinstance(message_field, type)
166-
and issubclass(message_field, base_class)
167-
)
168-
):
169-
raise TypeError(
170-
f"The '{field_name}' field of {cls.__name__!r} must be {base_class.__name__} "
171-
f"or a union of {base_class.__name__} subclasses, got {message_field}"
172-
)
173-
174-
# Check if discriminator is already explicitly set
175-
has_discriminator = False
176-
if (
177-
hasattr(cls, field_name)
178-
and getattr(getattr(cls, field_name, None), "discriminator", None)
179-
is not None
180-
):
181-
has_discriminator = True
182-
183-
# Add discriminator automatically if not explicitly set
184-
if not has_discriminator:
185-
# Check if there's a settings module with MESSAGE_ACTION_KEY
186-
from chanx.settings import chanx_settings
187-
188-
# Update the field with discriminator
189-
cls.model_fields[field_name] = Field(
190-
discriminator=chanx_settings.MESSAGE_ACTION_KEY
191-
)
192-
193-
194-
class BaseIncomingMessage(MessageContainerMixin, abc.ABC):
195-
"""
196-
Base WebSocket incoming message wrapper.
197-
198-
This class serves as a container for incoming WebSocket messages,
199-
allowing for a discriminated union pattern where the 'message' field
200-
can contain any message type derived from BaseMessage.
201-
202-
Attributes:
203-
message: The wrapped message object, using action as discriminator field
204-
"""
205-
206-
_message_field_name: ClassVar[str] = "message"
207-
_message_base_class: ClassVar[type[BaseMessage]] = BaseMessage
208-
209-
message: Any
210-
211-
212-
class BaseOutgoingGroupMessage(MessageContainerMixin, abc.ABC):
213-
"""
214-
Base WebSocket outgoing group message wrapper.
215-
216-
Similar to BaseIncomingMessage, but for group messages being sent out.
217-
218-
Attributes:
219-
group_message: The wrapped group message
220-
"""
221-
222-
_message_field_name: ClassVar[str] = "group_message"
223-
_message_base_class: ClassVar[type[BaseMessage]] = BaseGroupMessage
224-
225-
group_message: Any

0 commit comments

Comments
 (0)