Skip to content

Commit af06e81

Browse files
authored
Add UdpSocket::try_recv (#255)
Also, add ::connect capabilities to UdpSocket.
1 parent 61891a3 commit af06e81

File tree

3 files changed

+136
-1
lines changed

3 files changed

+136
-1
lines changed

src/host.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ pub(crate) struct Udp {
212212

213213
struct UdpBind {
214214
bind_addr: SocketAddr,
215+
target_addr: Option<SocketAddr>,
215216
broadcast: bool,
216217
multicast_loop: bool,
217218
queue: mpsc::Sender<(Datagram, SocketAddr)>,
@@ -259,6 +260,7 @@ impl Udp {
259260
let (tx, rx) = mpsc::channel(self.capacity);
260261
let bind = UdpBind {
261262
bind_addr: addr,
263+
target_addr: None,
262264
broadcast: DEFAULT_BROADCAST,
263265
multicast_loop: DEFAULT_MULTICAST_LOOP,
264266
queue: tx,
@@ -275,9 +277,23 @@ impl Udp {
275277

276278
Ok(UdpSocket::new(addr, rx))
277279
}
280+
pub(crate) fn connect(&mut self, src: SocketAddr, dst: SocketAddr) {
281+
let Some(bind) = self.binds.get_mut(&src.port()) else {
282+
panic!("Connect failed (no matching bind) for {src}");
283+
};
284+
285+
bind.target_addr = Some(dst);
286+
}
278287

279288
fn receive_from_network(&mut self, src: SocketAddr, dst: SocketAddr, datagram: Datagram) {
280289
if let Some(bind) = self.binds.get_mut(&dst.port()) {
290+
if let Some(target) = bind.target_addr {
291+
if !matches(target, src) {
292+
tracing::trace!(target: TRACING_TARGET, ?src, ?dst, protocol = %Protocol::Udp(datagram), "Dropped (Connect Addr not matching)");
293+
return;
294+
}
295+
}
296+
281297
if !matches(bind.bind_addr, dst) {
282298
tracing::trace!(target: TRACING_TARGET, ?src, ?dst, protocol = %Protocol::Udp(datagram), "Dropped (Addr not bound)");
283299
return;

src/net/udp.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ impl Rx {
129129
/// will continue to return immediately until the readiness event is
130130
/// consumed by an attempt to read that fails with `WouldBlock` or
131131
/// `Poll::Pending`.
132-
pub async fn readable(&mut self) -> Result<()> {
132+
async fn readable(&mut self) -> Result<()> {
133133
if self.buffer.is_some() {
134134
return Ok(());
135135
}
@@ -156,6 +156,14 @@ impl UdpSocket {
156156
}),
157157
}
158158
}
159+
pub async fn connect<A: ToSocketAddrs>(&self, addr: A) {
160+
World::current(|world| {
161+
let addr = addr.to_socket_addr(&world.dns);
162+
let host = world.current_host_mut();
163+
164+
host.udp.connect(self.local_addr, addr);
165+
})
166+
}
159167

160168
/// This function will create a new UDP socket and attempt to bind it to
161169
/// the `addr` provided.
@@ -292,6 +300,19 @@ impl UdpSocket {
292300

293301
Ok((limit, origin))
294302
}
303+
/// Tries to receive a single datagram message on the socket from the remote
304+
/// address to which it is connected. On success, returns the number of
305+
/// bytes read.
306+
///
307+
/// This method must be called with valid byte array `buf` of sufficient size
308+
/// to hold the message bytes. If a message is too long to fit in the
309+
/// supplied buffer, excess bytes may be discarded.
310+
///
311+
/// When there is no pending data, `Err(io::ErrorKind::WouldBlock)` is
312+
/// returned. This function is usually paired with `readable()`.
313+
pub fn try_recv(&self, buf: &mut [u8]) -> io::Result<usize> {
314+
self.try_recv_from(buf).map(|(size, _)| size)
315+
}
295316

296317
/// Waits for the socket to become readable.
297318
///

tests/udp.rs

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,3 +1007,101 @@ fn socket_to_nonexistent_node() -> Result {
10071007
});
10081008
sim.run()
10091009
}
1010+
1011+
#[test]
1012+
fn try_recv_not_connected() -> Result {
1013+
let mut sim = Builder::new().build();
1014+
1015+
sim.client("server1", async move {
1016+
let sock = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 1234)).await?;
1017+
1018+
sock.readable().await?;
1019+
1020+
let mut buf = [0; 3];
1021+
1022+
// When not connected try_recv takes any incoming Datagram.
1023+
let x = sock.try_recv(&mut buf)?;
1024+
1025+
assert_eq!(x, 3);
1026+
assert_eq!(buf, [1, 2, 3]);
1027+
1028+
Ok(())
1029+
});
1030+
1031+
sim.client("server2", async move {
1032+
let sock = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 5678)).await?;
1033+
1034+
sock.send_to(&[1, 2, 3], ("server1", 1234)).await?;
1035+
1036+
Ok(())
1037+
});
1038+
1039+
sim.run()
1040+
}
1041+
1042+
#[test]
1043+
fn try_recv_connected() -> Result {
1044+
let mut sim = Builder::new().build();
1045+
1046+
sim.client("server1", async move {
1047+
let sock = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 1234)).await?;
1048+
1049+
sock.connect(("server2", 5678)).await;
1050+
1051+
sock.readable().await?;
1052+
1053+
let mut buf = [0; 3];
1054+
1055+
let x = sock.try_recv(&mut buf)?;
1056+
1057+
assert_eq!(x, 3);
1058+
assert_eq!(buf, [1, 2, 3]);
1059+
1060+
Ok(())
1061+
});
1062+
1063+
sim.client("server2", async move {
1064+
let sock = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 5678)).await?;
1065+
1066+
sock.send_to(&[1, 2, 3], ("server1", 1234)).await?;
1067+
1068+
Ok(())
1069+
});
1070+
1071+
sim.run()
1072+
}
1073+
1074+
#[test]
1075+
fn try_recv_connected_wrong_host() -> Result {
1076+
let mut sim = Builder::new().build();
1077+
1078+
sim.client("server1", async move {
1079+
let sock = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 1234)).await?;
1080+
1081+
// Connect to server3, which doesn't send anything.
1082+
sock.connect(("server3", 6781)).await;
1083+
1084+
// Datagram from server2 is filtered so sock never becomes readable.
1085+
tokio::time::timeout(Duration::from_secs(5), sock.readable())
1086+
.await
1087+
.unwrap_err();
1088+
1089+
Ok(())
1090+
});
1091+
1092+
sim.client("server2", async move {
1093+
let sock = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 5678)).await?;
1094+
1095+
sock.send_to(&[1, 2, 9], ("server1", 1234)).await?;
1096+
1097+
Ok(())
1098+
});
1099+
1100+
sim.client("server3", async move {
1101+
_ = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 6789)).await?;
1102+
1103+
Ok(())
1104+
});
1105+
1106+
sim.run()
1107+
}

0 commit comments

Comments
 (0)