Skip to content

Commit 57208ab

Browse files
committed
test(test): improve websocket docstring and testing coverage
Add more test for websocket, almost cover it. Add is_alive to check the testing websocket still alive or not. Add more docstring for the websocket, too
1 parent 5eb9a4b commit 57208ab

File tree

5 files changed

+224
-79
lines changed

5 files changed

+224
-79
lines changed

chanx/generic/websocket.py

Lines changed: 80 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,15 @@
99
)
1010
from django.contrib.auth.base_user import AbstractBaseUser
1111
from django.contrib.auth.models import AnonymousUser
12-
from django.http import HttpRequest, HttpResponseBase
12+
from django.http import HttpRequest
1313
from rest_framework import status
1414
from rest_framework.authentication import BaseAuthentication
1515
from rest_framework.permissions import (
1616
BasePermission,
1717
OperandHolder,
1818
SingleOperandHolder,
1919
)
20+
from rest_framework.response import Response
2021
from rest_framework.settings import api_settings
2122
from rest_framework.views import APIView
2223

@@ -39,10 +40,22 @@
3940

4041
class AsyncJsonWebsocketConsumer(BaseAsyncJsonWebsocketConsumer, ABC): # type: ignore
4142
"""
42-
Base class for asynchronous JSON WebSocket consumers with authentication and permission handling.
43-
44-
This class extends Django Channels' AsyncJsonWebsocketConsumer to provide authentication,
45-
permission checking, structured message handling, and logging.
43+
Base class for asynchronous JSON WebSocket consumers with authentication and permissions.
44+
45+
Provides DRF-style authentication/permissions, structured message handling with
46+
Pydantic validation, logging, and error handling. Subclasses must implement
47+
`receive_message` and set `INCOMING_MESSAGE_SCHEMA`.
48+
49+
Attributes:
50+
permission_classes: DRF permission classes for connection authorization
51+
authentication_classes: DRF authentication classes for connection verification
52+
send_completion: Whether to send completion message after processing
53+
send_message_immediately: Whether to yield control after sending messages
54+
log_received_message: Whether to log received messages
55+
log_sent_message: Whether to log sent messages
56+
log_ignored_actions: Message actions that should not be logged
57+
send_authentication_message: Whether to send auth status after connection
58+
INCOMING_MESSAGE_SCHEMA: Pydantic model class for message validation
4659
"""
4760

4861
permission_classes: (
@@ -55,15 +68,20 @@ class AsyncJsonWebsocketConsumer(BaseAsyncJsonWebsocketConsumer, ABC): # type:
5568
log_received_message: bool | None = None
5669
log_sent_message: bool | None = None
5770
log_ignored_actions: Iterable[str] | None = None
58-
INCOMING_MESSAGE_SCHEMA: type[BaseIncomingMessage] | None = None
71+
send_authentication_message: bool | None = None
72+
73+
INCOMING_MESSAGE_SCHEMA: type[BaseIncomingMessage]
5974

6075
def __init__(self, *args: Any, **kwargs: Any) -> None:
6176
"""
62-
Initialize the WebSocket consumer with authentication and permission setup.
77+
Initialize with authentication and permission setup.
6378
6479
Args:
65-
*args: Variable length argument list.
66-
**kwargs: Arbitrary keyword arguments.
80+
*args: Variable length argument list
81+
**kwargs: Arbitrary keyword arguments
82+
83+
Raises:
84+
ValueError: If INCOMING_MESSAGE_SCHEMA is not set
6785
"""
6886
super().__init__(*args, **kwargs)
6987

@@ -94,8 +112,17 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
94112
if self.log_ignored_actions is None:
95113
self.log_ignored_actions = chanx_settings.LOG_IGNORED_ACTIONS
96114

97-
if self.INCOMING_MESSAGE_SCHEMA is None:
98-
self.INCOMING_MESSAGE_SCHEMA = chanx_settings.INCOMING_MESSAGE_SCHEMA
115+
self.ignore_actions: set[str] = (
116+
set(self.log_ignored_actions) if self.log_ignored_actions else set()
117+
)
118+
119+
if self.send_authentication_message is None:
120+
self.send_authentication_message = (
121+
chanx_settings.SEND_AUTHENTICATION_MESSAGE
122+
)
123+
124+
if not hasattr(self, "INCOMING_MESSAGE_SCHEMA"):
125+
raise ValueError("INCOMING_MESSAGE_SCHEMA attribute is required.")
99126

100127
self._v = APIView()
101128
self._v.authentication_classes = self.authentication_classes
@@ -108,11 +135,10 @@ async def websocket_connect(self, message: dict[str, Any]) -> None:
108135
"""
109136
Handle WebSocket connection request.
110137
111-
This method is called when a client attempts to establish a WebSocket connection.
112-
It accepts the connection and authenticates the user.
138+
Accepts the connection and authenticates the user.
113139
114140
Args:
115-
message: The connection message.
141+
message: The connection message from Channels
116142
"""
117143
await self.accept()
118144
await self._authenticate()
@@ -121,11 +147,10 @@ async def websocket_disconnect(self, message: dict[str, Any]) -> None:
121147
"""
122148
Handle WebSocket disconnection.
123149
124-
This method is called when a client disconnects from the WebSocket.
125-
It cleans up context variables and logs the disconnection.
150+
Cleans up context variables and logs the disconnection.
126151
127152
Args:
128-
message: The disconnection message.
153+
message: The disconnection message from Channels
129154
"""
130155
await logger.ainfo("Disconnecting websocket")
131156
structlog.contextvars.clear_contextvars()
@@ -135,14 +160,13 @@ async def websocket_disconnect(self, message: dict[str, Any]) -> None:
135160

136161
async def receive_json(self, content: dict[str, Any], **kwargs: Any) -> None:
137162
"""
138-
Receive and process JSON data from the WebSocket.
163+
Receive and process JSON data from WebSocket.
139164
140-
This method is called when a client sends a JSON message. It logs the received
141-
message and creates a task to handle it asynchronously.
165+
Logs messages, assigns ID, and creates task for async processing.
142166
143167
Args:
144-
content: The JSON content received from the client.
145-
**kwargs: Additional keyword arguments.
168+
content: The JSON content received from the client
169+
**kwargs: Additional keyword arguments
146170
"""
147171
message_action = content.get(chanx_settings.MESSAGE_ACTION_KEY)
148172

@@ -151,7 +175,7 @@ async def receive_json(self, content: dict[str, Any], **kwargs: Any) -> None:
151175
message_id=message_id, received_action=message_action
152176
)
153177

154-
if self.log_received_message and message_action not in self.log_ignored_actions:
178+
if self.log_received_message and message_action not in self.ignore_actions:
155179
await logger.ainfo("Received websocket json")
156180

157181
create_task(self._handle_receive_json_and_signal_complete(content, **kwargs))
@@ -160,29 +184,24 @@ async def receive_json(self, content: dict[str, Any], **kwargs: Any) -> None:
160184
@abstractmethod
161185
async def receive_message(self, message: BaseMessage, **kwargs: Any) -> None:
162186
"""
163-
Process a received message.
187+
Process a validated received message.
164188
165-
This abstract method must be implemented by subclasses to handle
166-
received messages after they've been deserialized.
189+
Must be implemented by subclasses to handle messages after validation.
167190
168191
Args:
169-
message: The deserialized Message object.
170-
**kwargs: Additional keyword arguments.
171-
172-
Raises:
173-
NotImplementedError: If the subclass does not implement this method.
192+
message: The validated message object
193+
**kwargs: Additional keyword arguments
174194
"""
175-
raise NotImplementedError
176195

177196
async def send_json(self, content: dict[str, Any], close: bool = False) -> None:
178197
"""
179198
Send JSON data to the WebSocket client.
180199
181-
This method sends a JSON message to the client and optionally logs it.
200+
Sends data and optionally logs it.
182201
183202
Args:
184-
content: The JSON content to send.
185-
close: Whether to close the connection after sending.
203+
content: The JSON content to send
204+
close: Whether to close the connection after sending
186205
"""
187206
await super().send_json(content, close)
188207

@@ -191,17 +210,17 @@ async def send_json(self, content: dict[str, Any], close: bool = False) -> None:
191210

192211
message_action = content.get(chanx_settings.MESSAGE_ACTION_KEY)
193212

194-
if self.log_sent_message:
213+
if self.log_sent_message and message_action not in self.ignore_actions:
195214
await logger.ainfo("Sent websocket json", sent_action=message_action)
196215

197216
async def send_message(self, message: BaseMessage) -> None:
198217
"""
199218
Send a Message object to the WebSocket client.
200219
201-
This method serializes a Message object and sends it to the client.
220+
Serializes the message and sends it as JSON.
202221
203222
Args:
204-
message: The Message object to send.
223+
message: The Message object to send
205224
"""
206225
await self.send_json(message.model_dump())
207226

@@ -211,19 +230,19 @@ async def _authenticate(self) -> None:
211230
"""
212231
Authenticate the WebSocket connection.
213232
214-
This method authenticates the WebSocket connection using the configured
215-
authentication classes and sets the user attribute.
233+
Uses DRF authentication classes and sends status if configured.
234+
Closes connection on authentication failure.
216235
"""
217236
res, req = await self._perform_dispatch()
218237

219-
self.user = req.user if hasattr(req, "user") else None
238+
self.user = req.user
220239

221240
await logger.ainfo("Finished authenticating ws request")
222241

223242
# We need to check status_code attribute which exists on both HttpResponse and Response
224243
status_code = getattr(res, "status_code", 500)
225-
data = getattr(res, "data", {})
226-
if chanx_settings.SEND_AUTHENTICATION_MESSAGE:
244+
data = getattr(res, "data", {}) if status_code != status.HTTP_200_OK else "OK"
245+
if self.send_authentication_message:
227246
await self.send_message(
228247
AuthenticationMessage(
229248
payload=AuthenticationPayload(status_code=status_code, data=data)
@@ -233,26 +252,25 @@ async def _authenticate(self) -> None:
233252
await self.close()
234253

235254
@sync_to_async
236-
def _perform_dispatch(self) -> tuple[HttpResponseBase, HttpRequest]:
255+
def _perform_dispatch(self) -> tuple[Response, HttpRequest]:
237256
"""
238257
Perform authentication dispatch synchronously.
239258
240-
This method creates a request from the WebSocket scope, binds logging context,
241-
and dispatches the request through the authentication pipeline.
259+
Creates request from WebSocket scope and runs it through
260+
the DRF authentication pipeline.
242261
243262
Returns:
244-
A tuple containing the response and request objects.
263+
Tuple of (response, request) objects
245264
"""
246265
raw_request = request_from_scope(self.scope)
247266
self._bind_structlog_request_context(raw_request)
248267

249268
logger.info("Start to authenticate ws request")
250269

251-
res = self._v.dispatch(raw_request)
270+
res = cast(Response, self._v.dispatch(raw_request))
252271

253272
# Assuming res has a render method (it does if it's a DRF Response)
254-
if hasattr(res, "render"):
255-
res.render()
273+
res.render()
256274

257275
# For DRF Response objects, renderer_context would be available
258276
if hasattr(res, "renderer_context"):
@@ -267,21 +285,16 @@ def _bind_structlog_request_context(self, raw_request: HttpRequest) -> None:
267285
"""
268286
Bind structured logging context variables from request.
269287
270-
This method extracts and binds request metadata to the structured logging context.
288+
Extracts request ID, path and IP for consistent logging.
271289
272290
Args:
273-
raw_request: The HTTP request object.
291+
raw_request: The HTTP request object
274292
"""
275293
request_id = get_request_header(
276294
raw_request, "x-request-id", "HTTP_X_REQUEST_ID"
277295
) or str(uuid.uuid4())
278-
correlation_id = get_request_header(
279-
raw_request, "x-correlation-id", "HTTP_X_CORRELATION_ID"
280-
)
281-
structlog.contextvars.bind_contextvars(request_id=request_id)
282296

283-
if correlation_id:
284-
structlog.contextvars.bind_contextvars(correlation_id=correlation_id)
297+
structlog.contextvars.bind_contextvars(request_id=request_id)
285298

286299
structlog.contextvars.bind_contextvars(path=raw_request.path)
287300
structlog.contextvars.bind_contextvars(ip=self.scope.get("client", [None])[0])
@@ -292,22 +305,21 @@ async def _handle_receive_json_and_signal_complete(
292305
self, content: dict[str, Any], **kwargs: Any
293306
) -> None:
294307
"""
295-
Handle received JSON content and signal completion.
308+
Handle received JSON and signal completion.
296309
297-
This method deserializes the JSON content into a Message object, calls the
298-
receive_message method, and optionally sends a completion signal.
310+
Validates JSON against schema, processes it, handles exceptions,
311+
and optionally sends completion message.
299312
300313
Args:
301-
content: The JSON content to handle.
302-
**kwargs: Additional keyword arguments.
314+
content: The JSON content to handle
315+
**kwargs: Additional keyword arguments
303316
"""
304317
try:
305-
if self.INCOMING_MESSAGE_SCHEMA is not None:
306-
message = self.INCOMING_MESSAGE_SCHEMA.model_validate(
307-
{"message": content}
308-
).message
318+
message = self.INCOMING_MESSAGE_SCHEMA.model_validate(
319+
{"message": content}
320+
).message
309321

310-
await self.receive_message(message, **kwargs)
322+
await self.receive_message(message, **kwargs)
311323
except ValidationError as e:
312324
await self.send_message(
313325
ErrorMessage(

chanx/settings.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
from rest_framework.authentication import BaseAuthentication
99
from rest_framework.settings import APISettings
1010

11-
from chanx.messages.base import BaseIncomingMessage
12-
1311

1412
@dataclass
1513
class MySetting:
@@ -23,10 +21,6 @@ class MySetting:
2321
LOG_SENT_MESSAGE: bool = True
2422
LOG_IGNORED_ACTIONS: Iterable[str] = dataclasses.field(default_factory=list)
2523

26-
INCOMING_MESSAGE_SCHEMA: type[BaseIncomingMessage] = (
27-
"chanx.messages.incoming.IncomingMessage" # type: ignore
28-
)
29-
3024
# Add this field to satisfy the type checker
3125
# It will be used by APISettings but isn't part of the real dataclass structure
3226
user_settings: dict[str, Any] = dataclasses.field(default_factory=dict)

0 commit comments

Comments
 (0)