2222"""
2323
2424import asyncio
25+ import sys
2526import uuid
2627from abc import ABC , abstractmethod
2728from collections .abc import Iterable , Sequence
2829from types import ModuleType
2930from typing import (
31+ Annotated ,
3032 Any ,
3133 Generic ,
3234 Literal ,
3335 cast ,
36+ get_args ,
37+ get_origin ,
3438)
3539
3640from channels .generic .websocket import (
4751
4852import structlog
4953from asgiref .typing import WebSocketConnectEvent , WebSocketDisconnectEvent
50- from pydantic import ValidationError
51- from typing_extensions import TypeVar
54+ from pydantic import Field , TypeAdapter , ValidationError
55+ from typing_extensions import TypeVar , get_original_bases
5256
5357from chanx .constants import MISSING_PYHUMPS_ERROR
5458from chanx .generic .authenticator import ChanxWebsocketAuthenticator , QuerysetLike
5559from chanx .messages .base import (
56- BaseIncomingMessage ,
60+ BaseGroupMessage ,
5761 BaseMessage ,
58- BaseOutgoingGroupMessage ,
5962)
6063from chanx .messages .outgoing import (
6164 AuthenticationMessage ,
7578 humps = cast (ModuleType , None ) # pragma: no cover
7679
7780
78- _M = TypeVar ("_M" , bound = Model , default = Model )
81+ IC = TypeVar ("IC" , bound = BaseMessage ) # Incoming messages
82+ OG = TypeVar (
83+ "OG" , bound = BaseGroupMessage | None , default = None
84+ ) # Outgoing group messages
85+ M = TypeVar ("M" , bound = Model | None , default = None ) # Object model
7986
8087
81- class AsyncJsonWebsocketConsumer (Generic [_M ], BaseAsyncJsonWebsocketConsumer , ABC ):
88+ class AsyncJsonWebsocketConsumer (
89+ Generic [IC , OG , M ], BaseAsyncJsonWebsocketConsumer , ABC
90+ ):
8291 """
8392 Base class for asynchronous JSON WebSocket consumers with authentication and permissions.
8493
@@ -102,9 +111,6 @@ class AsyncJsonWebsocketConsumer(Generic[_M], BaseAsyncJsonWebsocketConsumer, AB
102111 log_sent_message: Whether to log sent messages
103112 log_ignored_actions: Message actions that should not be logged
104113 send_authentication_message: Whether to send auth status after connection
105- INCOMING_MESSAGE_SCHEMA: Pydantic model class for validating incoming messages
106- OUTGOING_GROUP_MESSAGE_SCHEMA: Pydantic model class for validating group broadcast messages,
107- required when using send_group_message with kind="message"
108114 """
109115
110116 # Authentication attributes
@@ -128,8 +134,46 @@ class AsyncJsonWebsocketConsumer(Generic[_M], BaseAsyncJsonWebsocketConsumer, AB
128134 send_authentication_message : bool | None = None
129135
130136 # Message schemas
131- INCOMING_MESSAGE_SCHEMA : type [BaseIncomingMessage ]
132- OUTGOING_GROUP_MESSAGE_SCHEMA : type [BaseOutgoingGroupMessage ]
137+ _INCOMING_MESSAGE_SCHEMA : IC
138+ _OUTGOING_GROUP_MESSAGE_SCHEMA : OG
139+
140+ # Object instance
141+ obj : M
142+
143+ def __init_subclass__ (cls , * args : Any , ** kwargs : Any ):
144+ super ().__init_subclass__ (* args , ** kwargs )
145+
146+ # Extract the actual type from Generic parameters
147+ orig_bases = get_original_bases (cls )
148+ for base in orig_bases :
149+ if base is AsyncJsonWebsocketConsumer :
150+ raise ValueError (
151+ f"Class { cls .__name__ !r} must specify at least the incoming message type as a generic parameter. "
152+ f"Hint: class { cls .__name__ } (AsyncJsonWebsocketConsumer[YourMessageType])"
153+ )
154+ if get_origin (base ) is AsyncJsonWebsocketConsumer :
155+ # Workaround for TypeVar defaults handling differences across Python versions:
156+ # - In Python 3.10, get_args() only returns non-default types
157+ # - In Python 3.11+, get_args() returns all type arguments including defaults
158+ # We create a fixed-size array and populate it with available type arguments
159+ if sys .version_info < (3 , 11 ): # pragma: no cover
160+ # Generic part of AsyncJsonWebsocketConsumer
161+ generic_types = get_original_bases (AsyncJsonWebsocketConsumer )[0 ]
162+ type_var_vals : list [Any ] = [None ] * len (get_args (generic_types ))
163+ for i , var in enumerate (get_args (base )):
164+ if i < len (type_var_vals ) and var is not None :
165+ type_var_vals [i ] = var
166+
167+ incoming_message_schema , outgoing_group_message_schema , _model = (
168+ type_var_vals
169+ )
170+ else :
171+ incoming_message_schema , outgoing_group_message_schema , _model = (
172+ get_args (base )
173+ )
174+ cls ._INCOMING_MESSAGE_SCHEMA = incoming_message_schema
175+ cls ._OUTGOING_GROUP_MESSAGE_SCHEMA = outgoing_group_message_schema
176+ break
133177
134178 def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
135179 """
@@ -143,7 +187,22 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
143187 ValueError: If INCOMING_MESSAGE_SCHEMA is not set
144188 """
145189 super ().__init__ (* args , ** kwargs )
146- # Initialize configuration from settings if not set
190+
191+ # Load configuration and validate
192+ self ._load_configuration_from_settings ()
193+ self ._setup_message_adapters ()
194+
195+ # Create authenticator
196+ self .authenticator = self ._create_authenticator ()
197+
198+ # Initialize instance attributes
199+ self ._initialize_instance_attributes ()
200+
201+ # Validate optional dependencies
202+ self ._validate_optional_dependencies ()
203+
204+ def _load_configuration_from_settings (self ) -> None :
205+ """Load configuration values from settings if not already set."""
147206 if self .send_completion is None :
148207 self .send_completion = chanx_settings .SEND_COMPLETION
149208
@@ -159,27 +218,40 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
159218 if self .log_ignored_actions is None :
160219 self .log_ignored_actions = chanx_settings .LOG_IGNORED_ACTIONS
161220
162- self .ignore_actions : set [str ] = (
163- set (self .log_ignored_actions ) if self .log_ignored_actions else set ()
164- )
165-
166221 if self .send_authentication_message is None :
167222 self .send_authentication_message = (
168223 chanx_settings .SEND_AUTHENTICATION_MESSAGE
169224 )
170225
171- if not hasattr (self , "INCOMING_MESSAGE_SCHEMA" ):
172- raise ValueError ("INCOMING_MESSAGE_SCHEMA attribute is required." )
226+ # Process ignored actions
227+ self .ignore_actions : set [str ] = (
228+ set (self .log_ignored_actions ) if self .log_ignored_actions else set ()
229+ )
173230
174- # Create authenticator
175- self .authenticator = self ._create_authenticator ()
231+ def _setup_message_adapters (self ) -> None :
232+ """Set up Pydantic type adapters for message validation."""
233+ self .incoming_message_adapter : TypeAdapter [IC ] = TypeAdapter (
234+ Annotated [
235+ self ._INCOMING_MESSAGE_SCHEMA ,
236+ Field (discriminator = chanx_settings .MESSAGE_ACTION_KEY ),
237+ ]
238+ )
176239
177- # Initialize instance attributes
240+ self .outgoing_group_message_adapter : TypeAdapter [OG ] = TypeAdapter (
241+ Annotated [
242+ self ._OUTGOING_GROUP_MESSAGE_SCHEMA ,
243+ Field (discriminator = chanx_settings .MESSAGE_ACTION_KEY ),
244+ ]
245+ )
246+
247+ def _initialize_instance_attributes (self ) -> None :
248+ """Initialize instance attributes to their default values."""
178249 self .user : User | AnonymousUser | None = None
179- self .obj : _M | None = None
180250 self .group_name : str | None = None
181251 self .connecting : bool = False
182252
253+ def _validate_optional_dependencies (self ) -> None :
254+ """Validate that optional dependencies are available when needed."""
183255 if chanx_settings .CAMELIZE :
184256 if not humps :
185257 raise RuntimeError (MISSING_PYHUMPS_ERROR )
@@ -332,7 +404,7 @@ async def receive_json(self, content: dict[str, Any], **kwargs: Any) -> None:
332404 structlog .contextvars .reset_contextvars (** token )
333405
334406 @abstractmethod
335- async def receive_message (self , message : BaseMessage , ** kwargs : Any ) -> None :
407+ async def receive_message (self , message : IC , ** kwargs : Any ) -> None :
336408 """
337409 Process a validated received message.
338410
@@ -501,9 +573,8 @@ async def send_group_member(self, event: GroupMemberEvent) -> None:
501573 )
502574
503575 if kind == "message" :
504- message = self .OUTGOING_GROUP_MESSAGE_SCHEMA .model_validate (
505- {"group_message" : content }
506- ).group_message
576+ message = self .outgoing_group_message_adapter .validate_python (content )
577+ assert message is not None
507578 await self .send_message (message )
508579 else :
509580 await self .send_json (content )
@@ -527,9 +598,8 @@ async def _handle_receive_json_and_signal_complete(
527598 **kwargs: Additional keyword arguments
528599 """
529600 try :
530- message = self .INCOMING_MESSAGE_SCHEMA .model_validate (
531- {"message" : content }
532- ).message
601+
602+ message = self .incoming_message_adapter .validate_python (content )
533603
534604 await self .receive_message (message , ** kwargs )
535605 except ValidationError as e :
0 commit comments