Skip to content

Commit 866c227

Browse files
committed
Add turmoil::net::lookup_host
... mirroring `tokio::net::lookup_host`. The tokio method returns an error if the lookup fails, whereas turmoil's existing `ToSocketAddrs` implementation panics. To ensure consistent behavior, this commit thus also changes `ToSocketAddrs::to_socket_addr` to return an error as well.
1 parent af06e81 commit 866c227

File tree

6 files changed

+63
-37
lines changed

6 files changed

+63
-37
lines changed

src/dns.rs

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use indexmap::IndexMap;
22
#[cfg(feature = "regex")]
33
use regex::Regex;
4+
use std::io;
45
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
56

67
use crate::ip::IpVersionAddrIter;
@@ -30,7 +31,7 @@ pub trait ToIpAddrs: sealed::Sealed {
3031
/// A simulated version of `tokio::net::ToSocketAddrs`.
3132
pub trait ToSocketAddrs: sealed::Sealed {
3233
#[doc(hidden)]
33-
fn to_socket_addr(&self, dns: &Dns) -> SocketAddr;
34+
fn to_socket_addr(&self, dns: &Dns) -> io::Result<SocketAddr>;
3435
}
3536

3637
impl Dns {
@@ -116,73 +117,76 @@ impl ToIpAddrs for Regex {
116117

117118
// Hostname and port
118119
impl ToSocketAddrs for (String, u16) {
119-
fn to_socket_addr(&self, dns: &Dns) -> SocketAddr {
120+
fn to_socket_addr(&self, dns: &Dns) -> io::Result<SocketAddr> {
120121
(&self.0[..], self.1).to_socket_addr(dns)
121122
}
122123
}
123124

124125
impl ToSocketAddrs for (&str, u16) {
125-
fn to_socket_addr(&self, dns: &Dns) -> SocketAddr {
126+
fn to_socket_addr(&self, dns: &Dns) -> io::Result<SocketAddr> {
126127
// When IP address is passed directly as a str.
127128
if let Ok(ip) = self.0.parse::<IpAddr>() {
128-
return (ip, self.1).into();
129+
return Ok((ip, self.1).into());
129130
}
130131

131132
match dns.names.get(self.0) {
132-
Some(ip) => (*ip, self.1).into(),
133-
None => panic!("no ip address found for a hostname: {}", self.0),
133+
Some(ip) => Ok((*ip, self.1).into()),
134+
None => Err(io::Error::other(format!(
135+
"no ip address found for a hostname: {}",
136+
self.0
137+
))),
134138
}
135139
}
136140
}
137141

138142
impl ToSocketAddrs for SocketAddr {
139-
fn to_socket_addr(&self, _: &Dns) -> SocketAddr {
140-
*self
143+
fn to_socket_addr(&self, _: &Dns) -> io::Result<SocketAddr> {
144+
Ok(*self)
141145
}
142146
}
143147

144148
impl ToSocketAddrs for SocketAddrV4 {
145-
fn to_socket_addr(&self, _: &Dns) -> SocketAddr {
146-
SocketAddr::V4(*self)
149+
fn to_socket_addr(&self, _: &Dns) -> io::Result<SocketAddr> {
150+
Ok(SocketAddr::V4(*self))
147151
}
148152
}
149153

150154
impl ToSocketAddrs for SocketAddrV6 {
151-
fn to_socket_addr(&self, _: &Dns) -> SocketAddr {
152-
SocketAddr::V6(*self)
155+
fn to_socket_addr(&self, _: &Dns) -> io::Result<SocketAddr> {
156+
Ok(SocketAddr::V6(*self))
153157
}
154158
}
155159

156160
impl ToSocketAddrs for (IpAddr, u16) {
157-
fn to_socket_addr(&self, _: &Dns) -> SocketAddr {
158-
(*self).into()
161+
fn to_socket_addr(&self, _: &Dns) -> io::Result<SocketAddr> {
162+
Ok((*self).into())
159163
}
160164
}
161165

162166
impl ToSocketAddrs for (Ipv4Addr, u16) {
163-
fn to_socket_addr(&self, _: &Dns) -> SocketAddr {
164-
(*self).into()
167+
fn to_socket_addr(&self, _: &Dns) -> io::Result<SocketAddr> {
168+
Ok((*self).into())
165169
}
166170
}
167171

168172
impl ToSocketAddrs for (Ipv6Addr, u16) {
169-
fn to_socket_addr(&self, _: &Dns) -> SocketAddr {
170-
(*self).into()
173+
fn to_socket_addr(&self, _: &Dns) -> io::Result<SocketAddr> {
174+
Ok((*self).into())
171175
}
172176
}
173177

174178
impl<T: ToSocketAddrs + ?Sized> ToSocketAddrs for &T {
175-
fn to_socket_addr(&self, dns: &Dns) -> SocketAddr {
179+
fn to_socket_addr(&self, dns: &Dns) -> io::Result<SocketAddr> {
176180
(**self).to_socket_addr(dns)
177181
}
178182
}
179183

180184
impl ToSocketAddrs for str {
181-
fn to_socket_addr(&self, dns: &Dns) -> SocketAddr {
185+
fn to_socket_addr(&self, dns: &Dns) -> io::Result<SocketAddr> {
182186
let socketaddr: Result<SocketAddr, _> = self.parse();
183187

184188
if let Ok(s) = socketaddr {
185-
return s;
189+
return Ok(s);
186190
}
187191

188192
// Borrowed from std
@@ -191,7 +195,7 @@ impl ToSocketAddrs for str {
191195
($e:expr, $msg:expr) => {
192196
match $e {
193197
Some(r) => r,
194-
None => panic!("Unable to parse dns: {}", $msg),
198+
None => return Err(io::Error::new(io::ErrorKind::InvalidInput, $msg)),
195199
}
196200
};
197201
}
@@ -205,7 +209,7 @@ impl ToSocketAddrs for str {
205209
}
206210

207211
impl ToSocketAddrs for String {
208-
fn to_socket_addr(&self, dns: &Dns) -> SocketAddr {
212+
fn to_socket_addr(&self, dns: &Dns) -> io::Result<SocketAddr> {
209213
self.as_str().to_socket_addr(dns)
210214
}
211215
}
@@ -227,16 +231,22 @@ mod tests {
227231
let mut dns = Dns::new(IpVersionAddrIter::default());
228232
let generated_addr = dns.lookup("foo");
229233

230-
let hostname_port = "foo:5000".to_socket_addr(&dns);
234+
let hostname_port = "foo:5000";
231235
let ipv4_port = "127.0.0.1:5000";
232236
let ipv6_port = "[::1]:5000";
233237

234238
assert_eq!(
235-
hostname_port,
239+
hostname_port.to_socket_addr(&dns).unwrap(),
236240
format!("{generated_addr}:5000").parse().unwrap()
237241
);
238-
assert_eq!(ipv4_port.to_socket_addr(&dns), ipv4_port.parse().unwrap());
239-
assert_eq!(ipv6_port.to_socket_addr(&dns), ipv6_port.parse().unwrap());
242+
assert_eq!(
243+
ipv4_port.to_socket_addr(&dns).unwrap(),
244+
ipv4_port.parse().unwrap()
245+
);
246+
assert_eq!(
247+
ipv6_port.to_socket_addr(&dns).unwrap(),
248+
ipv6_port.parse().unwrap()
249+
);
240250
}
241251

242252
#[test]
@@ -251,7 +261,7 @@ mod tests {
251261
let addr = dns.lookup("192.168.3.3");
252262
assert_eq!(addr, Ipv4Addr::new(192, 168, 3, 3));
253263

254-
let addr = "192.168.3.3:0".to_socket_addr(&dns);
264+
let addr = "192.168.3.3:0".to_socket_addr(&dns).unwrap();
255265
assert_eq!(addr.ip(), Ipv4Addr::new(192, 168, 3, 3));
256266
}
257267
}

src/net/mod.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,28 @@
44
//! a high fidelity implementation.
55
66
use std::net::SocketAddr;
7+
use std::{io, iter};
8+
9+
use crate::world::World;
10+
use crate::ToSocketAddrs;
711

812
pub mod tcp;
913
pub use tcp::{listener::TcpListener, stream::TcpStream};
1014

1115
pub(crate) mod udp;
1216
pub use udp::UdpSocket;
1317

18+
/// Performs a DNS resolution.
19+
///
20+
/// Must be called from within a turmoil simulation context.
21+
pub async fn lookup_host<T>(host: T) -> io::Result<impl Iterator<Item = SocketAddr>>
22+
where
23+
T: ToSocketAddrs,
24+
{
25+
let addr = World::current(|world| host.to_socket_addr(&world.dns))?;
26+
Ok(iter::once(addr))
27+
}
28+
1429
#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)]
1530
pub(crate) struct SocketPair {
1631
pub(crate) local: SocketAddr,

src/net/tcp/listener.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ impl TcpListener {
3636
/// Only `0.0.0.0`, `::`, or localhost are currently supported.
3737
pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<TcpListener> {
3838
World::current(|world| {
39-
let mut addr = addr.to_socket_addr(&world.dns);
39+
let mut addr = addr.to_socket_addr(&world.dns)?;
4040
let host = world.current_host_mut();
4141

4242
if !addr.ip().is_unspecified() && !addr.ip().is_loopback() {

src/net/tcp/stream.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ impl TcpStream {
6363
let (ack, syn_ack) = oneshot::channel();
6464

6565
let (pair, rx) = World::current(|world| {
66-
let dst = addr.to_socket_addr(&world.dns);
66+
let dst = addr.to_socket_addr(&world.dns)?;
6767

6868
let host = world.current_host_mut();
6969
let mut local_addr = SocketAddr::new(host.addr, host.assign_ephemeral_port());

src/net/udp.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,13 @@ impl UdpSocket {
156156
}),
157157
}
158158
}
159-
pub async fn connect<A: ToSocketAddrs>(&self, addr: A) {
159+
pub async fn connect<A: ToSocketAddrs>(&self, addr: A) -> Result<()> {
160160
World::current(|world| {
161-
let addr = addr.to_socket_addr(&world.dns);
161+
let addr = addr.to_socket_addr(&world.dns)?;
162162
let host = world.current_host_mut();
163163

164164
host.udp.connect(self.local_addr, addr);
165+
Ok(())
165166
})
166167
}
167168

@@ -175,7 +176,7 @@ impl UdpSocket {
175176
/// Only `0.0.0.0`, `::`, or localhost are currently supported.
176177
pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<UdpSocket> {
177178
World::current(|world| {
178-
let mut addr = addr.to_socket_addr(&world.dns);
179+
let mut addr = addr.to_socket_addr(&world.dns)?;
179180
let host = world.current_host_mut();
180181

181182
verify_ipv4_bind_interface(addr.ip(), host.addr)?;
@@ -209,7 +210,7 @@ impl UdpSocket {
209210
/// completes first, then it is guaranteed that the message was not sent.
210211
pub async fn send_to<A: ToSocketAddrs>(&self, buf: &[u8], target: A) -> Result<usize> {
211212
World::current(|world| {
212-
let dst = target.to_socket_addr(&world.dns);
213+
let dst = target.to_socket_addr(&world.dns)?;
213214
self.send(world, dst, Datagram(Bytes::copy_from_slice(buf)))?;
214215
Ok(buf.len())
215216
})
@@ -231,7 +232,7 @@ impl UdpSocket {
231232
/// [`ErrorKind::WouldBlock`]: std::io::ErrorKind::WouldBlock
232233
pub fn try_send_to<A: ToSocketAddrs>(&self, buf: &[u8], target: A) -> Result<usize> {
233234
World::current(|world| {
234-
let dst = target.to_socket_addr(&world.dns);
235+
let dst = target.to_socket_addr(&world.dns)?;
235236
self.send(world, dst, Datagram(Bytes::copy_from_slice(buf)))?;
236237
Ok(buf.len())
237238
})

tests/udp.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,7 +1046,7 @@ fn try_recv_connected() -> Result {
10461046
sim.client("server1", async move {
10471047
let sock = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 1234)).await?;
10481048

1049-
sock.connect(("server2", 5678)).await;
1049+
sock.connect(("server2", 5678)).await?;
10501050

10511051
sock.readable().await?;
10521052

@@ -1079,7 +1079,7 @@ fn try_recv_connected_wrong_host() -> Result {
10791079
let sock = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 1234)).await?;
10801080

10811081
// Connect to server3, which doesn't send anything.
1082-
sock.connect(("server3", 6781)).await;
1082+
sock.connect(("server3", 6781)).await?;
10831083

10841084
// Datagram from server2 is filtered so sock never becomes readable.
10851085
tokio::time::timeout(Duration::from_secs(5), sock.readable())

0 commit comments

Comments
 (0)