Skip to content

Commit 959e65b

Browse files
committed
WIP
1 parent 3c86b58 commit 959e65b

File tree

25 files changed

+377
-288
lines changed

25 files changed

+377
-288
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888
### Feat
8989

9090
- **python310**: add support for python version 310
91-
- **websocket**: add kind to send_group_message to handle both pydantic message and json case
91+
- **websocket**: add kind to broadcast_message to handle both pydantic message and json case
9292

9393
### Fix
9494

README.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,9 @@ You can use these parameters in different combinations:
113113
# With group messaging
114114
class GroupConsumer(AsyncJsonWebsocketConsumer[ChatMessage]):
115115
async def receive_message(self, message: ChatMessage, **kwargs: Any) -> None:
116-
# Send typed group messages using send_group_message
116+
# Send typed group messages using broadcast_message
117117
group_msg = MemberMessage(payload={"content": "Hello group!"})
118-
await self.send_group_message(group_msg)
118+
await self.broadcast_message(group_msg)
119119
120120
# Complete example with all generic parameters
121121
class ChatConsumer(AsyncJsonWebsocketConsumer[ChatMessage, ChatEvent, Room]):

chanx/asyncapi/generator.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -135,23 +135,22 @@ def build_operations(self) -> None:
135135
for route in self.routes:
136136
consumer = route.consumer
137137

138-
for action, handler_info in consumer._MESSAGE_HANDLER_INFO_MAP.items():
138+
for handler_info in consumer._MESSAGE_HANDLER_INFO_MAP.values():
139139
self._build_single_operation(
140-
action, consumer, route, handler_info, is_event=False
140+
handler_info, consumer, route, is_event=False
141141
)
142142

143143
# Build operations from event handlers (send operations)
144144
for action, handler_info in consumer._EVENT_HANDLER_INFO_MAP.items():
145145
self._build_single_operation(
146-
action, consumer, route, handler_info, is_event=True
146+
handler_info, consumer, route, is_event=True
147147
)
148148

149149
def _build_single_operation(
150150
self,
151-
action: str,
151+
handler_info: AsyncAPIHandlerInfo,
152152
consumer: AsyncJsonWebsocketConsumer,
153153
route: RouteInfo,
154-
handler_info: AsyncAPIHandlerInfo,
155154
is_event=False,
156155
) -> None:
157156
"""
@@ -164,9 +163,9 @@ def _build_single_operation(
164163
Returns:
165164
AsyncAPI operation definition.
166165
"""
167-
action_name = action
168-
if action in self._operation_names:
169-
action_name = "_".join((consumer.snake_name, action))
166+
action_name = handler_info["action"]
167+
if action_name in self._operation_names:
168+
action_name = "_".join((consumer.snake_name, action_name))
170169

171170
channel_name = self._route_channel_mapping[route.path]
172171
operation: dict[str, Any] = {

chanx/core/adapter.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,18 @@
22

33
IS_USING_DJANGO = os.environ.get("DJANGO_SETTINGS_MODULE", False)
44

5-
if IS_USING_DJANGO:
6-
from channels.generic.websocket import AsyncJsonWebsocketConsumer
7-
from channels.layers import get_channel_layer
8-
from channels.testing import WebsocketCommunicator
9-
10-
from asgiref.typing import WebSocketConnectEvent, WebSocketDisconnectEvent
11-
12-
else:
13-
from fast_channels.consumer import AsyncJsonWebsocketConsumer
14-
from fast_channels.layers import get_channel_layer
15-
from fast_channels.testing import WebsocketCommunicator
16-
from fast_channels.type_defs import WebSocketConnectEvent, WebSocketDisconnectEvent
17-
5+
# if IS_USING_DJANGO:
6+
# from channels.generic.websocket import AsyncJsonWebsocketConsumer
7+
# from channels.layers import get_channel_layer
8+
# from channels.testing import WebsocketCommunicator
9+
#
10+
# from asgiref.typing import WebSocketConnectEvent, WebSocketDisconnectEvent
11+
#
12+
# else:
13+
from fast_channels.consumer import AsyncJsonWebsocketConsumer
14+
from fast_channels.layers import get_channel_layer
15+
from fast_channels.testing import WebsocketCommunicator
16+
from fast_channels.type_defs import WebSocketConnectEvent, WebSocketDisconnectEvent
1817

1918
__all__ = [
2019
"get_channel_layer",

chanx/core/decorators.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
_P = ParamSpec("_P")
2121
_R = TypeVar("_R")
22+
_T = TypeVar("_T")
2223

2324

2425
@overload
@@ -118,6 +119,7 @@ def decorator(func: Callable) -> Callable: # noqa C901
118119
consumer_name = func.__qualname__.split(".")[0]
119120
handler_info: AsyncAPIHandlerInfo = {
120121
"action": final_action,
122+
"message_action": final_input_type.model_fields["action"].default,
121123
"input_type": final_input_type,
122124
"output_type": final_output_type,
123125
"method_name": func.__name__,
@@ -291,7 +293,7 @@ def channel(
291293
name: str | None = None,
292294
description: str | None = None,
293295
tags: list[str] | None = None,
294-
) -> Callable[[type], type]:
296+
) -> Callable[[_T], _T]:
295297
"""
296298
Decorator for WebSocket consumer classes to add AsyncAPI channel metadata.
297299
@@ -311,7 +313,7 @@ class AssistantConsumer(AsyncJsonWebsocketConsumer):
311313
pass
312314
"""
313315

314-
def decorator(cls: type) -> type:
316+
def decorator(cls: _T) -> _T:
315317
# Store channel information on the class
316318
channel_info: ChannelInfo = {
317319
"name": name,

chanx/core/websocket.py

Lines changed: 54 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -106,13 +106,17 @@ def _process_handlers(cls) -> None:
106106

107107
# Process WebSocket handlers
108108
if hasattr(attr, "_ws_handler_info"):
109-
handler_info = attr._ws_handler_info
110-
cls._MESSAGE_HANDLER_INFO_MAP[handler_info["action"]] = handler_info
109+
handler_info: AsyncAPIHandlerInfo = attr._ws_handler_info
110+
cls._MESSAGE_HANDLER_INFO_MAP[handler_info["message_action"]] = (
111+
handler_info
112+
)
111113

112114
# Process event handlers
113115
if hasattr(attr, "_event_handler_info"):
114-
handler_info = attr._event_handler_info
115-
cls._EVENT_HANDLER_INFO_MAP[handler_info["action"]] = handler_info
116+
handler_info: AsyncAPIHandlerInfo = attr._event_handler_info
117+
cls._EVENT_HANDLER_INFO_MAP[handler_info["message_action"]] = (
118+
handler_info
119+
)
116120

117121
# Extract types and build unions/adapters
118122
cls._build_adapters()
@@ -264,14 +268,23 @@ async def websocket_connect(self, message: WebSocketConnectEvent) -> None:
264268
await self.accept()
265269

266270
# Authenticate the connection
267-
auth_result = None
268271
if self.authenticator:
269272
auth_result = await self.authenticator.authenticate(self.scope)
270273

271274
if not auth_result:
272275
await self.close()
273276
return
274277

278+
try:
279+
for group in self.groups:
280+
channel_layer = get_channel_layer(self.channel_layer_alias)
281+
if channel_layer:
282+
await channel_layer.group_add(group, self.channel_name)
283+
except AttributeError:
284+
raise ValueError(
285+
"BACKEND is unconfigured or doesn't support groups"
286+
) from None
287+
275288
await self.post_authentication()
276289

277290
async def websocket_disconnect(self, message: WebSocketDisconnectEvent) -> None:
@@ -367,10 +380,10 @@ async def receive_message(self, message: BaseMessage) -> None:
367380
message: The validated message object (BaseMessage instance)
368381
"""
369382
# Extract the action from the discriminator field
370-
action = getattr(message, self.discriminator_field)
383+
message_action = getattr(message, self.discriminator_field)
371384

372-
# Find the handler for this action
373-
handler_info = self.__class__._MESSAGE_HANDLER_INFO_MAP.get(action)
385+
# Find the handler for this message_action
386+
handler_info = self.__class__._MESSAGE_HANDLER_INFO_MAP.get(message_action)
374387

375388
# Get the handler method by name
376389
method_name = handler_info["method_name"]
@@ -385,7 +398,7 @@ async def receive_message(self, message: BaseMessage) -> None:
385398
await self.handle_result(result)
386399

387400
except Exception as e:
388-
await self.handle_error(e, action, message)
401+
await self.handle_error(e, message_action, message)
389402

390403
async def handle_result(self, result: BaseMessage) -> None:
391404
"""
@@ -458,12 +471,12 @@ async def handle_error(
458471
await logger.aexception(f"Handler error for action '{action}': {error}")
459472

460473
# Group operations methods
461-
async def send_group_message(
474+
async def broadcast_message(
462475
self,
463476
message: BaseMessage | dict[str, Any],
464477
groups: list[str] | None = None,
465478
*,
466-
exclude_current: bool = True,
479+
exclude_current: bool = False,
467480
) -> None:
468481
"""
469482
Send a BaseMessage object to one or more channel groups.
@@ -479,24 +492,27 @@ async def send_group_message(
479492
exclude_current: Whether to exclude the sending consumer from receiving
480493
the broadcast (prevents echo effects)
481494
"""
495+
channel_layer = get_channel_layer(self.channel_layer_alias)
496+
assert channel_layer
497+
482498
if groups is None:
483499
groups = self.groups or []
484500

485501
if isinstance(message, BaseMessage):
486502
message = message.model_dump(mode="json")
487503

488504
for group in groups:
489-
await self.channel_layer.group_send(
505+
await channel_layer.group_send(
490506
group,
491507
{
492-
"type": "send_group_member",
508+
"type": "handle_group_message",
493509
"message": message,
494510
"exclude_current": exclude_current,
495511
"from_channel": self.channel_name,
496512
},
497513
)
498514

499-
async def send_group_member(self, event: BaseMessage) -> None:
515+
async def handle_group_message(self, event: BaseMessage) -> None:
500516
"""
501517
Handle incoming group message and relay to client.
502518
@@ -534,7 +550,7 @@ async def send_group_member(self, event: BaseMessage) -> None:
534550

535551
# Channel event system methods
536552
@classmethod
537-
async def asend_event(
553+
async def send_event(
538554
cls,
539555
channel_name: str,
540556
event: ReceiveEvent,
@@ -562,13 +578,13 @@ async def asend_event(
562578
)
563579

564580
@classmethod
565-
def send_event(
581+
def send_event_sync(
566582
cls,
567583
channel_name: str,
568584
event: ReceiveEvent,
569585
) -> None:
570586
"""
571-
Synchronous version of asend_event for use in Django tasks/views.
587+
Synchronous version of send_event for use in Django tasks/views.
572588
573589
This method provides the same functionality as asend_channel_event but
574590
can be called from synchronous code like Django tasks, views, or signals.
@@ -577,13 +593,13 @@ def send_event(
577593
channel_name: Channel name to send to
578594
event: The typed event to send (BaseMessage subclass)
579595
"""
580-
async_to_sync(cls.asend_event)(channel_name, event)
596+
async_to_sync(cls.send_event)(channel_name, event)
581597

582598
@classmethod
583-
async def abroadcast_event(
599+
async def broadcast_event(
584600
cls,
585-
group_name: str,
586601
event: ReceiveEvent,
602+
groups: Iterable[str] | None = None,
587603
) -> None:
588604
"""
589605
Broadcast a typed channel event to a channel group.
@@ -593,37 +609,41 @@ async def abroadcast_event(
593609
places where you don't have a consumer instance.
594610
595611
Args:
596-
group_name: Group name to broadcast the event to
597612
event: The typed event to broadcast (BaseMessage subclass)
613+
groups: Groups to broadcast the event to
598614
"""
599615
channel_layer = get_channel_layer(cls.channel_layer_alias)
600616
assert channel_layer is not None
601617

602-
await channel_layer.group_send(
603-
group_name,
604-
{
605-
"type": "handle_channel_event",
606-
"event_data": event.model_dump(mode="json"),
607-
},
608-
)
618+
if groups is None:
619+
groups = cls.groups or []
620+
621+
for group in groups:
622+
await channel_layer.group_send(
623+
group,
624+
{
625+
"type": "handle_channel_event",
626+
"event_data": event.model_dump(mode="json"),
627+
},
628+
)
609629

610630
@classmethod
611-
def broadcast_event(
631+
def broadcast_event_sync(
612632
cls,
613-
group_name: str,
614633
event: ReceiveEvent,
634+
groups: Iterable[str] | None = None,
615635
) -> None:
616636
"""
617-
Synchronous version of abroadcast_event for use in Django tasks/views.
637+
Synchronous version of broadcast_event for use in sync tasks/views.
618638
619-
This method provides the same functionality as abroadcast_event but
639+
This method provides the same functionality as broadcast_event but
620640
can be called from synchronous code like Django tasks, views, or signals.
621641
622642
Args:
623-
group_name: Group name to broadcast to
624643
event: The typed event to broadcast (BaseMessage subclass)
644+
groups: Groups to broadcast to
625645
"""
626-
async_to_sync(cls.abroadcast_event)(group_name, event)
646+
async_to_sync(cls.broadcast_event)(event, groups)
627647

628648
async def handle_channel_event(self, event_payload: EventPayload) -> None:
629649
"""

chanx/type_defs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class HandlerInfo(TypedDict):
1515
"""Information stored on handler functions by the @ws_handler decorator."""
1616

1717
action: str
18+
message_action: str
1819
input_type: type[BaseMessage] | None
1920
output_type: type[BaseMessage] | None
2021
method_name: str

docs/examples/chat.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ Now we'll create our chat consumer with proper pattern matching:
142142
user = self.user
143143
144144
# Send joined notification to the group
145-
await self.send_group_message(
145+
await self.broadcast_message(
146146
UserJoinedPayload(
147147
payload={
148148
"username": user.username,
@@ -221,7 +221,7 @@ Now we'll create our chat consumer with proper pattern matching:
221221
chat_message = ChatMessagePayload(payload=text)
222222
223223
# Broadcast to the group
224-
await self.send_group_message(chat_message)
224+
await self.broadcast_message(chat_message)
225225
226226
async def save_message_to_db(self, user: User, room: ChatRoom, text: str) -> None:
227227
"""Save chat message to database."""
@@ -247,7 +247,7 @@ Now we'll create our chat consumer with proper pattern matching:
247247
room = self.obj
248248
249249
# Send user left notification
250-
await self.send_group_message(
250+
await self.broadcast_message(
251251
UserLeftPayload(
252252
payload={
253253
"username": user.username,

docs/introduction.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ Chanx fills these gaps with a cohesive framework that provides:
6868
return [f"room_{self.room_id}"]
6969
7070
async def receive_message(self, message: ChatMessage, **kwargs: Any) -> None:
71-
await self.send_group_message(message)
71+
await self.broadcast_message(message)
7272
7373
4. **Enhanced URL Routing**: Django-style routing utilities for WebSocket endpoints with type hints support
7474

docs/quick-start.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ Now let's enhance our consumer to support group messaging. First, we need to add
267267
username = getattr(self.user, 'username', 'Anonymous')
268268
269269
# Send to the whole group
270-
await self.send_group_message(
270+
await self.broadcast_message(
271271
ChatGroupMessage(
272272
payload=MessagePayload(content=f"{username}: {payload.content}")
273273
)

0 commit comments

Comments
 (0)