Skip to content

Commit a881123

Browse files
committed
feat!: validate A2S remote address like AsyncA2S
1 parent c07eaa3 commit a881123

1 file changed

Lines changed: 29 additions & 19 deletions

File tree

src/little_a2s/client/sync.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,15 @@ class A2S:
6767
"""
6868

6969
buffer_size = 32768 # probably overkill?
70-
_protocols: dict[Address | None, A2SClientProtocol]
70+
_remote_addr: Address | None = None
71+
_protocols: dict[Address, A2SClientProtocol]
7172

7273
def __init__(self, sock: socket.socket) -> None:
7374
if sock.type != socket.SOCK_DGRAM:
7475
raise TypeError("Socket type must be SOCK_DGRAM")
7576

7677
self._sock = sock
78+
self._remote_addr = None # FIXME: can this be introspected from sock?
7779
self._protocols = {}
7880

7981
# Connection methods
@@ -129,7 +131,10 @@ def from_addr(
129131
sock = socket.socket(family, type, proto)
130132
sock.settimeout(timeout)
131133
sock.connect(addr)
132-
return cls(sock)
134+
135+
obj = cls(sock)
136+
obj._remote_addr = addr # Default address when user calls addr=None
137+
return obj
133138

134139
@classmethod
135140
def from_ipv4(cls, timeout: float | None = DEFAULT_TIMEOUT) -> Self:
@@ -190,6 +195,7 @@ def info(self, addr: Address | None = None) -> ClientEventInfo:
190195
:raises TimeoutError: The socket timed out.
191196
192197
"""
198+
addr = self._get_addr(addr)
193199
proto = self._get_protocol(addr)
194200
return self._send(ClientEventInfo, addr, proto.info)
195201

@@ -212,6 +218,7 @@ def players(self, addr: Address | None = None) -> ClientEventPlayers:
212218
:raises TimeoutError: The socket timed out.
213219
214220
"""
221+
addr = self._get_addr(addr)
215222
proto = self._get_protocol(addr)
216223
return self._send(ClientEventPlayers, addr, proto.players)
217224

@@ -234,10 +241,18 @@ def rules(self, addr: Address | None = None) -> ClientEventRules:
234241
:raises TimeoutError: The socket timed out.
235242
236243
"""
244+
addr = self._get_addr(addr)
237245
proto = self._get_protocol(addr)
238246
return self._send(ClientEventRules, addr, proto.rules)
239247

240-
def _get_protocol(self, addr: Address | None) -> A2SClientProtocol:
248+
def _get_addr(self, addr: Address | None) -> Address:
249+
if self._remote_addr and addr:
250+
raise TypeError("Socket has remote address, addr= is disallowed")
251+
elif not self._remote_addr and not addr:
252+
raise TypeError("Socket has no remote address, addr= is required")
253+
return addr or self._remote_addr # type: ignore
254+
255+
def _get_protocol(self, addr: Address) -> A2SClientProtocol:
241256
"""Get the A2S protocol for the given address, creating a new one
242257
if it doesn't already exist.
243258
@@ -264,7 +279,7 @@ def _create_protocol(self) -> A2SClientProtocol:
264279
def _send(
265280
self,
266281
t: Type[ClientEventT],
267-
addr: Address | None,
282+
addr: Address,
268283
payload: Callable[[], ClientPacket],
269284
) -> ClientEventT:
270285
"""Use the given request function to generate an outbound packet,
@@ -282,7 +297,7 @@ def _send(
282297
types = (t, ClientEventChallenge)
283298

284299
for _ in range(3):
285-
self._sendto(bytes(payload()), addr)
300+
self._sock.sendto(bytes(payload()), addr)
286301
events = list(filter_type(types, self._recv(addr)))
287302
if not events:
288303
# FIXME: not really a timeout, should be a custom exception
@@ -292,17 +307,11 @@ def _send(
292307

293308
raise ChallengeError("Server responded with too many challenges")
294309

295-
def _sendto(self, data: bytes, addr: Address | None) -> int:
296-
if addr is not None:
297-
return self._sock.sendto(data, addr)
298-
else:
299-
return self._sock.send(data)
300-
301-
def _recv(self, addr: Address | None) -> list[ClientEvent]:
310+
def _recv(self, addr: Address) -> list[ClientEvent]:
302311
"""Read one datagram from the socket and pass it to the protocol.
303312
304-
If address is not None, this may call :meth:`~socket.socket.recvfrom()`
305-
multiple times until a datagram from the given address is received.
313+
This may call :meth:`~socket.socket.recvfrom()` multiple times
314+
until a datagram from the given address is received.
306315
307316
:param addr: The address to wait for a datagram from.
308317
:raises PayloadError: The server sent a malformed packet.
@@ -311,15 +320,15 @@ def _recv(self, addr: Address | None) -> list[ClientEvent]:
311320
"""
312321
# NOTE: not thread-safe!
313322
data, recv_addr = self._sock.recvfrom(self.buffer_size)
314-
events = self._receive_datagram(data, addr and recv_addr)
323+
events = self._receive_datagram(data, recv_addr)
315324

316-
while not events or addr and addr != recv_addr:
325+
while not events or addr != recv_addr:
317326
data, recv_addr = self._sock.recvfrom(self.buffer_size)
318-
events = self._receive_datagram(data, addr and recv_addr)
327+
events = self._receive_datagram(data, recv_addr)
319328

320329
return events
321330

322-
def _receive_datagram(self, data: bytes, addr: Address | None) -> list[ClientEvent]:
331+
def _receive_datagram(self, data: bytes, addr: Address) -> list[ClientEvent]:
323332
"""Pass the datagram to the protocol and return any generated events.
324333
325334
:raises PayloadError: The server sent a malformed packet.
@@ -333,7 +342,7 @@ def _receive_datagram(self, data: bytes, addr: Address | None) -> list[ClientEve
333342

334343
proto.receive_datagram(data)
335344
for packet in proto.packets_to_send():
336-
self._sendto(bytes(packet), addr)
345+
self._sock.sendto(bytes(packet), addr)
337346

338347
return proto.events_received()
339348

@@ -342,6 +351,7 @@ class A2SGoldsource(A2S):
342351
"""A synchronous client for A2S Goldsource queries."""
343352

344353
def info(self, addr: Address | None = None) -> ClientEventGoldsourceInfo: # type: ignore
354+
addr = self._get_addr(addr)
345355
proto = self._get_protocol(addr)
346356
return self._send(ClientEventGoldsourceInfo, addr, proto.info)
347357

0 commit comments

Comments
 (0)