Skip to content

Commit db6e834

Browse files
committed
feat(event): add typed channel event system with static type checking
1 parent 17f1e16 commit db6e834

File tree

12 files changed

+295
-44
lines changed

12 files changed

+295
-44
lines changed

chanx/generic/websocket.py

Lines changed: 154 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
offering a robust framework for building real-time applications with Django
66
Channels and Django REST Framework. The AsyncJsonWebsocketConsumer serves as the
77
foundation for WebSocket connections with integrated authentication, permissions,
8-
structured message handling, and group messaging capabilities.
8+
structured message handling, group messaging capabilities, and typed channel events.
99
1010
Key features:
1111
- DRF-style authentication and permission checking
1212
- Structured message handling with Pydantic validation
1313
- Automatic group management for pub/sub messaging
14+
- Typed channel event system with discriminated unions
1415
- Comprehensive error handling and reporting
1516
- Configurable logging and message completion signals
1617
- Support for object-level permissions and retrieval
@@ -25,13 +26,14 @@
2526
import sys
2627
import uuid
2728
from abc import ABC, abstractmethod
28-
from collections.abc import Iterable, Sequence
29+
from collections.abc import Awaitable, Callable, Iterable, Sequence
2930
from types import ModuleType
3031
from typing import (
3132
Annotated,
3233
Any,
3334
Generic,
3435
Literal,
36+
TypedDict,
3537
cast,
3638
get_args,
3739
get_origin,
@@ -40,6 +42,7 @@
4042
from channels.generic.websocket import (
4143
AsyncJsonWebsocketConsumer as BaseAsyncJsonWebsocketConsumer,
4244
)
45+
from channels.layers import get_channel_layer
4346
from django.contrib.auth.models import AnonymousUser, User
4447
from django.db.models import Model
4548
from rest_framework.authentication import BaseAuthentication
@@ -50,13 +53,15 @@
5053
)
5154

5255
import structlog
56+
from asgiref.sync import async_to_sync
5357
from asgiref.typing import WebSocketConnectEvent, WebSocketDisconnectEvent
5458
from pydantic import Field, TypeAdapter, ValidationError
5559
from typing_extensions import TypeVar, get_original_bases
5660

5761
from chanx.constants import MISSING_PYHUMPS_ERROR
5862
from chanx.generic.authenticator import ChanxWebsocketAuthenticator, QuerysetLike
5963
from chanx.messages.base import (
64+
BaseChannelEvent,
6065
BaseGroupMessage,
6166
BaseMessage,
6267
)
@@ -83,10 +88,15 @@
8388
"OG", bound=BaseGroupMessage | None, default=None
8489
) # Outgoing group messages
8590
M = TypeVar("M", bound=Model | None, default=None) # Object model
91+
Event = TypeVar("Event", bound=BaseChannelEvent | None, default=None) # Channel Events
92+
93+
94+
class EventPayload(TypedDict):
95+
event_data: BaseChannelEvent
8696

8797

8898
class AsyncJsonWebsocketConsumer(
89-
Generic[IC, OG, M], BaseAsyncJsonWebsocketConsumer, ABC
99+
Generic[IC, Event, OG, M], BaseAsyncJsonWebsocketConsumer, ABC
90100
):
91101
"""
92102
Base class for asynchronous JSON WebSocket consumers with authentication and permissions.
@@ -99,6 +109,9 @@ class AsyncJsonWebsocketConsumer(
99109
`OUTGOING_GROUP_MESSAGE_SCHEMA` to enable proper validation and handling
100110
of group message broadcasts.
101111
112+
For typed channel events, subclasses can define a union type of channel events
113+
and use the generic type parameter Event to enable type-safe channel event handling.
114+
102115
Attributes:
103116
authentication_classes: DRF authentication classes for connection verification
104117
permission_classes: DRF permission classes for connection authorization
@@ -135,6 +148,7 @@ class AsyncJsonWebsocketConsumer(
135148

136149
# Message schemas
137150
_INCOMING_MESSAGE_SCHEMA: IC
151+
_EVENT_SCHEMA: Event
138152
_OUTGOING_GROUP_MESSAGE_SCHEMA: OG
139153

140154
# Object instance
@@ -164,14 +178,21 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any):
164178
if i < len(type_var_vals) and var is not None:
165179
type_var_vals[i] = var
166180

167-
incoming_message_schema, outgoing_group_message_schema, _model = (
168-
type_var_vals
169-
)
181+
(
182+
incoming_message_schema,
183+
event_schema,
184+
outgoing_group_message_schema,
185+
_model,
186+
) = type_var_vals
170187
else:
171-
incoming_message_schema, outgoing_group_message_schema, _model = (
172-
get_args(base)
173-
)
188+
(
189+
incoming_message_schema,
190+
event_schema,
191+
outgoing_group_message_schema,
192+
_model,
193+
) = get_args(base)
174194
cls._INCOMING_MESSAGE_SCHEMA = incoming_message_schema
195+
cls._EVENT_SCHEMA = event_schema
175196
cls._OUTGOING_GROUP_MESSAGE_SCHEMA = outgoing_group_message_schema
176197
break
177198

@@ -237,6 +258,13 @@ def _setup_message_adapters(self) -> None:
237258
]
238259
)
239260

261+
self.event_adapter: TypeAdapter[Event] = TypeAdapter(
262+
Annotated[
263+
self._EVENT_SCHEMA,
264+
Field(discriminator="handler"),
265+
]
266+
)
267+
240268
self.outgoing_group_message_adapter: TypeAdapter[OG] = TypeAdapter(
241269
Annotated[
242270
self._OUTGOING_GROUP_MESSAGE_SCHEMA,
@@ -582,6 +610,123 @@ async def send_group_member(self, event: GroupMemberEvent) -> None:
582610
if self.send_completion:
583611
await self.send_message(GroupCompleteMessage())
584612

613+
# Channel event system methods
614+
@classmethod
615+
async def asend_channel_event(
616+
cls,
617+
group_name: str,
618+
event: Event,
619+
) -> None:
620+
"""
621+
Send a typed channel event to one or more channel groups.
622+
623+
This is a class method that provides a type-safe way to send events through
624+
the channel layer to consumers. It can be called from tasks, views, or other
625+
places where you don't have a consumer instance.
626+
627+
Args:
628+
event: The typed event to send (constrained by the consumer's Event type)
629+
group_name: Group name to send to (required)
630+
631+
Example:
632+
```python
633+
# From a Django task or view:
634+
await ChatDetailConsumer.send_channel_event(
635+
MemberAddedEvent(
636+
type="member_added",
637+
payload={"member_id": 123, "email": "user@example.com"}
638+
),
639+
groups=["chat_room_1"],
640+
from_user_pk=request.user.pk
641+
)
642+
```
643+
"""
644+
channel_layer = get_channel_layer()
645+
assert channel_layer is not None
646+
647+
assert event is not None
648+
await channel_layer.group_send(
649+
group_name,
650+
{
651+
"type": "handle_channel_event",
652+
"event_data": event.model_dump(),
653+
},
654+
)
655+
656+
@classmethod
657+
def send_channel_event(
658+
cls,
659+
group_name: str,
660+
event: Event,
661+
) -> None:
662+
"""
663+
Synchronous version of send_channel_event for use in Django tasks/views.
664+
665+
This method provides the same functionality as send_channel_event but
666+
can be called from synchronous code like Django tasks, views, or signals.
667+
668+
Args:
669+
group_name: Group name to send to (required)
670+
event: The typed event to send (constrained by the consumer's Event type)
671+
672+
Example:
673+
```python
674+
# From a Django task:
675+
ChatDetailConsumer.send_channel_event_sync(
676+
"chat_room_1",
677+
MemberAddedEvent(
678+
type="member_added",
679+
payload={"member_id": 123, "email": "user@example.com"}
680+
),
681+
)
682+
```
683+
"""
684+
async_to_sync(cls.asend_channel_event)(group_name, event)
685+
686+
async def handle_channel_event(self, event_payload: EventPayload) -> None:
687+
"""
688+
Internal dispatcher for channel events with completion signal.
689+
690+
This method is called by the channel layer when an event is sent to a group
691+
this consumer belongs to. It extracts the event data, checks exclusion rules,
692+
finds the appropriate handler method, and calls it with proper error handling.
693+
694+
Args:
695+
event_payload: The message from the channel layer containing event data
696+
and metadata about the sender
697+
"""
698+
event_data_dict = event_payload.get("event_data", {})
699+
event_data = self.event_adapter.validate_python(event_data_dict)
700+
701+
assert event_data is not None
702+
703+
try:
704+
handler_name = event_data.handler
705+
if not handler_name:
706+
await logger.awarning("Received channel event without handler field")
707+
return
708+
709+
# Find and call the handler method
710+
handler_method: Callable[[Event], Awaitable[None]] | None = getattr(
711+
self, handler_name, None
712+
)
713+
if not callable(handler_method):
714+
await logger.ainfo(
715+
f"Handler method '{handler_name}' is not available for sending event"
716+
)
717+
return
718+
719+
# Handler is async, await it
720+
await handler_method(event_data)
721+
722+
except Exception as e:
723+
await logger.aexception(f"Failed to process channel event: {str(e)}")
724+
# Don't re-raise to avoid breaking the channel layer
725+
finally:
726+
# Send completion signal if configured
727+
if self.send_completion:
728+
await self.send_message(CompleteMessage())
729+
585730
# Helper methods
586731

587732
async def _handle_receive_json_and_signal_complete(

chanx/messages/base.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,7 @@
2121
"""
2222

2323
import abc
24-
from typing import (
25-
Any,
26-
Literal,
27-
get_origin,
28-
)
24+
from typing import Any, Literal, get_origin
2925

3026
from pydantic import BaseModel, ConfigDict
3127
from typing_extensions import Unpack
@@ -93,3 +89,8 @@ class BaseGroupMessage(BaseMessage, abc.ABC):
9389

9490
is_mine: bool = False
9591
is_current: bool = False
92+
93+
94+
class BaseChannelEvent(BaseModel, abc.ABC):
95+
handler: Any
96+
payload: Any

docs/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

sandbox/assistants/consumers.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from typing import Any
22

3-
from rest_framework.permissions import IsAuthenticated
4-
53
from chanx.generic.websocket import AsyncJsonWebsocketConsumer
64
from chanx.messages.incoming import PingMessage
75
from chanx.messages.outgoing import PongMessage
@@ -17,7 +15,8 @@
1715
class AssistantConsumer(AsyncJsonWebsocketConsumer[AssistantIncomingMessage]):
1816
"""Websocket to chat with server, like chat with chatbot system"""
1917

20-
permission_classes = [IsAuthenticated]
18+
permission_classes = []
19+
authentication_classes = []
2120

2221
async def receive_message(
2322
self, message: AssistantIncomingMessage, **kwargs: Any

sandbox/assistants/tests/test_consumers.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from unittest.mock import patch
44
from uuid import uuid4
55

6+
from django.conf import settings
67
from rest_framework import status
78

89
import pytest
@@ -288,3 +289,26 @@ async def test_ignore_actions(self) -> None:
288289

289290
# Should not log received message for silent action
290291
assert "ping" not in str(mock_logger.call_args_list)
292+
293+
async def test_unauthorized_connection(self) -> None:
294+
# Test anonymous connection
295+
anonymous_communicator = self.create_communicator(
296+
headers=[
297+
(b"origin", settings.SERVER_URL.encode()),
298+
]
299+
)
300+
await anonymous_communicator.connect()
301+
await anonymous_communicator.assert_authenticated_status_ok()
302+
303+
# Test chat functionality
304+
message_content = "My message content"
305+
await anonymous_communicator.send_message(
306+
NewMessage(payload=MessagePayload(content=message_content))
307+
)
308+
309+
all_messages = await anonymous_communicator.receive_all_json()
310+
assert all_messages == [
311+
ReplyMessage(
312+
payload=MessagePayload(content=f"Reply: {message_content}")
313+
).model_dump()
314+
]

sandbox/chat/consumers/chat_detail.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717

1818

1919
class ChatDetailConsumer(
20-
AsyncJsonWebsocketConsumer[ChatIncomingMessage, OutgoingGroupMessage, GroupChat]
20+
AsyncJsonWebsocketConsumer[
21+
ChatIncomingMessage, None, OutgoingGroupMessage, GroupChat
22+
]
2123
):
2224
permission_classes = [IsGroupChatMember]
2325
queryset = GroupChat.objects.get_queryset()

sandbox/config/urls.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from django.urls import include, path
2222
from django.views.generic import RedirectView
2323

24+
from discussion.views import NotifyView
2425
from drf_spectacular.views import (
2526
SpectacularAPIView,
2627
SpectacularRedocView,
@@ -31,6 +32,7 @@
3132
path("accounts/", include("accounts.urls")),
3233
path("chat/", include("chat.urls")),
3334
path("playground/", include("chanx.playground.urls")),
35+
path("notify/", NotifyView.as_view(), name="notify_people"),
3436
path("schema/", SpectacularAPIView.as_view(), name="schema"),
3537
path(
3638
"schema/swg/",

0 commit comments

Comments
 (0)