Skip to content

Commit 8074f5b

Browse files
detect ip version earlier
1 parent b45f418 commit 8074f5b

File tree

3 files changed

+50
-33
lines changed

3 files changed

+50
-33
lines changed

src/pyartnet/base/base_node.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,8 @@ async def __aenter__(self) -> Self:
194194
if self._socket is not None:
195195
return self
196196

197-
ip_v6 = await self._network.is_ip_v6()
198-
self._socket = self._network.create_socket(ip_v6=ip_v6)
197+
await self._network.resolve_hostname()
198+
self._socket = self._network.create_socket()
199199

200200
self._refresh_task.start()
201201
return self

src/pyartnet/base/network.py

Lines changed: 42 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -70,21 +70,30 @@ def validate_ip_address(host: str) -> IPv4Address | IPv6Address:
7070

7171

7272
class NetworkTargetBase:
73+
def __init__(self, *, ip_v6: bool | None = None) -> None:
74+
self._ip_v6: bool | None = ip_v6
7375

74-
def create_socket(self, *, ip_v6: bool) -> socket.socket:
76+
def create_socket(self) -> socket.socket:
7577
# create nonblocking UDP socket
76-
sock: Final = socket.socket(AF_INET6 if ip_v6 else AF_INET, SOCK_DGRAM)
78+
sock: Final = socket.socket(AF_INET6 if self.ip_v6 else AF_INET, SOCK_DGRAM)
7779
sock.setblocking(False)
7880

7981
return sock
8082

81-
async def is_ip_v6(self) -> bool:
83+
@property
84+
def ip_v6(self) -> bool:
85+
if self._ip_v6 is None:
86+
msg = 'Host not yet resolved!'
87+
raise RuntimeError(msg)
88+
return self._ip_v6
89+
90+
async def resolve_hostname(self) -> None:
8291
raise NotImplementedError()
8392

8493

8594
class UnicastNetworkTarget(NetworkTargetBase):
86-
def __init__(self, dst: tuple[str, int], src: tuple[str, int] | None = None) -> None:
87-
super().__init__()
95+
def __init__(self, dst: tuple[str, int], src: tuple[str, int] | None = None, *, ip_v6: bool | None = None) -> None:
96+
super().__init__(ip_v6=ip_v6)
8897
self.dst: Final = dst
8998
self.src: Final = src
9099

@@ -94,8 +103,8 @@ def __repr__(self) -> str:
94103
return f'{self.__class__.__name__:s}(dst={ip:s}:{port:d}, source={src:s})'
95104

96105
@override
97-
def create_socket(self, *, ip_v6: bool) -> socket.socket:
98-
sock: Final = super().create_socket(ip_v6=ip_v6)
106+
def create_socket(self) -> socket.socket:
107+
sock: Final = super().create_socket()
99108

100109
# option to set source port/ip
101110
if (src := self.src) is not None:
@@ -116,45 +125,51 @@ def create(cls, host: str, port: int, source_ip: str | None = None, source_port:
116125
validate_port(source_port, allow_0=True)
117126
source = (source_ip, source_port)
118127

119-
return cls(dst=(host, port), src=source)
120-
121-
@override
122-
async def is_ip_v6(self) -> bool:
128+
# if host is an IP address, determine IP version now
129+
ip_v6: bool | None = None
123130
try:
124-
dst_ip = validate_ip_address(self.dst[0])
131+
dst_ip = validate_ip_address(host)
125132
except AddressValueError:
126133
pass
127134
else:
128-
if self.dst is not None:
135+
if source_ip is not None:
129136
# destination and source IP version must match
130137
try:
131-
dst_ip.__class__(self.dst[0])
138+
dst_ip.__class__(source_ip)
132139
except AddressValueError:
133-
msg = f'Source IP "{self.dst[0]}" is not a valid IPv{dst_ip.version}!'
140+
msg = f'Source IP "{source_ip}" is not a valid IPv{dst_ip.version}!'
134141
raise ValueError(msg) from None
135142

136-
return dst_ip.version == 6
143+
ip_v6 = dst_ip.version == 6
144+
145+
return cls(dst=(host, port), src=source, ip_v6=ip_v6)
146+
147+
@override
148+
async def resolve_hostname(self) -> None:
149+
if self._ip_v6 is not None:
150+
return None
137151

138152
# source ip can be used to set the mode for resolution
139153
mode: RESOLVE_TO_IP_TYPE = 'auto'
140154
if self.src is not None:
141155
mode = 'v6' if validate_ip_address(self.src[0]).version == 6 else 'v4'
142156

143157
info = await resolve_hostname(self.dst[0], self.dst[1], mode=mode)
144-
return info[0].version == 6
158+
self._ip_v6 = info[0].version == 6
159+
return None
145160

146161

147162
class MulticastNetworkTarget(NetworkTargetBase):
148-
def __init__(self, src: tuple[str, int]) -> None:
149-
super().__init__()
163+
def __init__(self, src: tuple[str, int], *, ip_v6: bool | None = None) -> None:
164+
super().__init__(ip_v6=ip_v6)
150165
self.src: Final = src
151166

152167
def __repr__(self) -> str:
153168
return f'{self.__class__.__name__:s}(source={self.src[0]:s})'
154169

155170
@override
156-
def create_socket(self, *, ip_v6: bool) -> socket.socket:
157-
sock: Final = super().create_socket(ip_v6=ip_v6)
171+
def create_socket(self) -> socket.socket:
172+
sock: Final = super().create_socket()
158173

159174
# set source port/ip
160175
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@@ -163,18 +178,18 @@ def create_socket(self, *, ip_v6: bool) -> socket.socket:
163178
# setup socket for multicast
164179
sock.setsockopt(
165180
socket.IPPROTO_IP,
166-
socket.IPV6_MULTICAST_IF if ip_v6 else socket.IP_MULTICAST_IF,
167-
socket.inet_pton(AF_INET6 if ip_v6 else AF_INET, self.src[0])
181+
socket.IPV6_MULTICAST_IF if self.ip_v6 else socket.IP_MULTICAST_IF,
182+
socket.inet_pton(AF_INET6 if self.ip_v6 else AF_INET, self.src[0])
168183
)
169184

170185
return sock
171186

172187
@classmethod
173188
def create(cls, source_ip: str, source_port: int = 0) -> Self:
174-
validate_ip_address(source_ip)
189+
ip_obj = validate_ip_address(source_ip)
175190
validate_port(source_port, allow_0=True)
176-
return cls(src=(source_ip, source_port))
191+
return cls(src=(source_ip, source_port), ip_v6=ip_obj.version == 6)
177192

178193
@override
179-
async def is_ip_v6(self) -> bool:
180-
return validate_ip_address(self.src[0]).version == 6
194+
async def resolve_hostname(self) -> None:
195+
return None

tests/helper.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,13 @@ def __exit__(self, exc_type: type[BaseException] | None,
5555

5656
class UnicastNetworkTestingTarget(UnicastNetworkTarget):
5757
@override
58-
async def is_ip_v6(self) -> bool:
59-
return False
58+
async def resolve_hostname(self) -> None:
59+
self._ip_v6 = False
60+
return None
6061

6162

6263
class MulticastTestingNetworkTarget(MulticastNetworkTarget):
6364
@override
64-
async def is_ip_v6(self) -> bool:
65-
return False
65+
async def resolve_hostname(self) -> None:
66+
self._ip_v6 = False
67+
return None

0 commit comments

Comments
 (0)