Skip to content

Commit

Permalink
fixup! feat(wasip3): add wasi:[email protected]
Browse files Browse the repository at this point in the history
  • Loading branch information
rvolosatovs committed Feb 3, 2025
1 parent 5e0c8d6 commit 251c2a1
Show file tree
Hide file tree
Showing 3 changed files with 332 additions and 162 deletions.
216 changes: 55 additions & 161 deletions crates/wasi/src/p3/sockets/host/types/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,14 @@ use core::net::SocketAddr;

use anyhow::{ensure, Context as _};
use rustix::io::Errno;
use rustix::net::sockopt;
use wasmtime::component::{for_any, stream, FutureReader, Resource, StreamReader};
use wasmtime::component::{for_any, FutureReader, Resource, StreamReader};
use wasmtime::StoreContextMut;

use crate::p3::bindings::sockets::types::{
Duration, ErrorCode, HostTcpSocket, IpAddressFamily, IpSocketAddress, TcpSocket,
};
use crate::p3::sockets::tcp::{bind, connect, TcpState};
use crate::p3::sockets::{SocketAddrUse, WasiSocketsImpl, WasiSocketsView};
use crate::p3::{
bindings::sockets::types::{
Duration, ErrorCode, HostTcpSocket, IpAddressFamily, IpSocketAddress, TcpSocket,
},
sockets::SocketAddressFamily,
};

impl<T> HostTcpSocket for WasiSocketsImpl<&mut T>
where
Expand Down Expand Up @@ -260,21 +256,7 @@ where
.table()
.get(&socket)
.context("failed to get socket resource from table")?;
match &sock.tcp_state {
TcpState::Bound(socket) => match socket.local_addr() {
Ok(addr) => Ok(Ok(addr.into())),
Err(err) => Ok(Err(err.into())),
},
TcpState::Connected(stream) => match stream.local_addr() {
Ok(addr) => Ok(Ok(addr.into())),
Err(err) => Ok(Err(err.into())),
},
TcpState::Listening(listener) => match listener.local_addr() {
Ok(addr) => Ok(Ok(addr.into())),
Err(err) => Ok(Err(err.into())),
},
_ => Ok(Err(ErrorCode::InvalidState)),
}
Ok(sock.local_address())
}

fn remote_address(
Expand All @@ -285,67 +267,35 @@ where
.table()
.get(&socket)
.context("failed to get socket resource from table")?;
match &sock.tcp_state {
TcpState::Connected(stream) => match stream.peer_addr() {
Ok(addr) => Ok(Ok(addr.into())),
Err(err) => Ok(Err(err.into())),
},
_ => Ok(Err(ErrorCode::InvalidState)),
}
Ok(sock.remote_address())
}

fn is_listening(&mut self, socket: Resource<TcpSocket>) -> wasmtime::Result<bool> {
let sock = self
.table()
.get(&socket)
.context("failed to get socket from table")?;
Ok(matches!(sock.tcp_state, TcpState::Listening { .. }))
Ok(sock.is_listening())
}

fn address_family(&mut self, socket: Resource<TcpSocket>) -> wasmtime::Result<IpAddressFamily> {
let sock = self
.table()
.get(&socket)
.context("failed to get socket from table")?;
match sock.family {
SocketAddressFamily::Ipv4 => Ok(IpAddressFamily::Ipv4),
SocketAddressFamily::Ipv6 => Ok(IpAddressFamily::Ipv6),
}
Ok(sock.address_family())
}

fn set_listen_backlog_size(
&mut self,
socket: Resource<TcpSocket>,
value: u64,
) -> wasmtime::Result<Result<(), ErrorCode>> {
const MAX_BACKLOG: u32 = i32::MAX as u32; // OS'es will most likely limit it down even further.

let sock = self
.table()
.get_mut(&socket)
.context("failed to get socket from table")?;
// Silently clamp backlog size. This is OK for us to do, because operating systems do this too.
if value == 0 {
return Ok(Err(ErrorCode::InvalidArgument));
}
let value = value.try_into().unwrap_or(MAX_BACKLOG).min(MAX_BACKLOG);
match &sock.tcp_state {
TcpState::Default(..) | TcpState::Bound(..) => {
// Socket not listening yet. Stash value for first invocation to `listen`.
sock.listen_backlog_size = value;
Ok(Ok(()))
}
TcpState::Listening(listener) => {
// Try to update the backlog by calling `listen` again.
// Not all platforms support this. We'll only update our own value if the OS supports changing the backlog size after the fact.
if rustix::net::listen(&listener, value.try_into().unwrap_or(i32::MAX)).is_err() {
return Ok(Err(ErrorCode::NotSupported));
}
sock.listen_backlog_size = value;
Ok(Ok(()))
}
_ => Ok(Err(ErrorCode::InvalidState.into())),
}
Ok(sock.set_listen_backlog_size(value))
}

fn keep_alive_enabled(
Expand All @@ -356,20 +306,7 @@ where
.table()
.get(&socket)
.context("failed to get socket from table")?;
match &sock.tcp_state {
TcpState::Default(socket) | TcpState::Bound(socket) => {
Ok(sockopt::get_socket_keepalive(socket).map_err(Into::into))
}
TcpState::Connected(stream) => {
Ok(sockopt::get_socket_keepalive(stream).map_err(Into::into))
}
TcpState::Listening(listener) => {
Ok(sockopt::get_socket_keepalive(listener).map_err(Into::into))
}
TcpState::BindStarted | TcpState::Connecting | TcpState::Closed => {
Ok(Err(ErrorCode::InvalidState))
}
}
Ok(sock.keep_alive_enabled())
}

fn set_keep_alive_enabled(
Expand All @@ -381,20 +318,7 @@ where
.table()
.get(&socket)
.context("failed to get socket from table")?;
match &sock.tcp_state {
TcpState::Default(socket) | TcpState::Bound(socket) => {
Ok(sockopt::set_socket_keepalive(socket, value).map_err(Into::into))
}
TcpState::Connected(stream) => {
Ok(sockopt::set_socket_keepalive(stream, value).map_err(Into::into))
}
TcpState::Listening(listener) => {
Ok(sockopt::set_socket_keepalive(listener, value).map_err(Into::into))
}
TcpState::BindStarted | TcpState::Connecting | TcpState::Closed => {
Ok(Err(ErrorCode::InvalidState))
}
}
Ok(sock.set_keep_alive_enabled(value))
}

fn keep_alive_idle_time(
Expand All @@ -405,58 +329,19 @@ where
.table()
.get(&socket)
.context("failed to get socket from table")?;
match match &sock.tcp_state {
TcpState::Default(socket) | TcpState::Bound(socket) => {
sockopt::get_tcp_keepidle(socket)
}
TcpState::Connected(stream) => sockopt::get_tcp_keepidle(stream),
TcpState::Listening(listener) => sockopt::get_tcp_keepidle(listener),
TcpState::BindStarted | TcpState::Connecting | TcpState::Closed => {
return Ok(Err(ErrorCode::InvalidState))
}
} {
Ok(t) => Ok(Ok(t.as_nanos().try_into().unwrap_or(u64::MAX))),
Err(err) => Ok(Err(err.into())),
}
Ok(sock.keep_alive_idle_time())
}

fn set_keep_alive_idle_time(
&mut self,
socket: Resource<TcpSocket>,
value: Duration,
) -> wasmtime::Result<Result<(), ErrorCode>> {
// Ensure that the value passed to the actual syscall never gets rounded down to 0.
const MIN_SECS: core::time::Duration = core::time::Duration::from_secs(1);

// Cap it at Linux' maximum, which appears to have the lowest limit across our supported platforms.
const MAX_SECS: core::time::Duration = core::time::Duration::from_secs(i16::MAX as u64);

if value == 0 {
// WIT: "If the provided value is 0, an `invalid-argument` error is returned."
return Ok(Err(ErrorCode::InvalidArgument));
}
let value = core::time::Duration::from_nanos(value).clamp(MIN_SECS, MAX_SECS);
let sock = self
.table()
.get_mut(&socket)
.context("failed to get socket from table")?;
match match &sock.tcp_state {
TcpState::Default(socket) | TcpState::Bound(socket) => {
sockopt::set_tcp_keepidle(socket, value)
}
TcpState::Connected(stream) => sockopt::set_tcp_keepidle(stream, value),
TcpState::Listening(listener) => sockopt::set_tcp_keepidle(listener, value),
_ => return Ok(Err(ErrorCode::InvalidState)),
} {
Ok(()) => {
#[cfg(target_os = "macos")]
{
sock.keep_alive_idle_time = Some(value);
}
Ok(Ok(()))
}
Err(err) => Ok(Err(err.into())),
}
Ok(sock.set_keep_alive_idle_time(value))
}

fn keep_alive_interval(
Expand All @@ -467,27 +352,19 @@ where
.table()
.get(&socket)
.context("failed to get socket from table")?;
match match &sock.tcp_state {
TcpState::Default(socket) | TcpState::Bound(socket) => {
sockopt::get_tcp_keepintvl(socket)
}
TcpState::Connected(stream) => sockopt::get_tcp_keepintvl(stream),
TcpState::Listening(listener) => sockopt::get_tcp_keepintvl(listener),
TcpState::BindStarted | TcpState::Connecting | TcpState::Closed => {
return Ok(Err(ErrorCode::InvalidState))
}
} {
Ok(t) => Ok(Ok(t.as_nanos().try_into().unwrap_or(u64::MAX))),
Err(err) => Ok(Err(err.into())),
}
Ok(sock.keep_alive_interval())
}

fn set_keep_alive_interval(
&mut self,
socket: Resource<TcpSocket>,
value: Duration,
) -> wasmtime::Result<Result<(), ErrorCode>> {
todo!()
let sock = self
.table()
.get(&socket)
.context("failed to get socket from table")?;
Ok(sock.set_keep_alive_interval(value))
}

fn keep_alive_count(
Expand All @@ -498,71 +375,88 @@ where
.table()
.get(&socket)
.context("failed to get socket from table")?;
match &sock.tcp_state {
TcpState::Default(socket) | TcpState::Bound(socket) => {
Ok(sockopt::get_tcp_keepcnt(socket).map_err(Into::into))
}
TcpState::Connected(stream) => Ok(sockopt::get_tcp_keepcnt(stream).map_err(Into::into)),
TcpState::Listening(listener) => {
Ok(sockopt::get_tcp_keepcnt(listener).map_err(Into::into))
}
TcpState::BindStarted | TcpState::Connecting | TcpState::Closed => {
Ok(Err(ErrorCode::InvalidState))
}
}
Ok(sock.keep_alive_count())
}

fn set_keep_alive_count(
&mut self,
socket: Resource<TcpSocket>,
value: u32,
) -> wasmtime::Result<Result<(), ErrorCode>> {
todo!()
let sock = self
.table()
.get(&socket)
.context("failed to get socket from table")?;
Ok(sock.set_keep_alive_count(value))
}

fn hop_limit(
&mut self,
socket: Resource<TcpSocket>,
) -> wasmtime::Result<Result<u8, ErrorCode>> {
todo!()
let sock = self
.table()
.get(&socket)
.context("failed to get socket from table")?;
Ok(sock.hop_limit())
}

fn set_hop_limit(
&mut self,
socket: Resource<TcpSocket>,
value: u8,
) -> wasmtime::Result<Result<(), ErrorCode>> {
todo!()
let sock = self
.table()
.get(&socket)
.context("failed to get socket from table")?;
Ok(sock.set_hop_limit(value))
}

fn receive_buffer_size(
&mut self,
socket: Resource<TcpSocket>,
) -> wasmtime::Result<Result<u64, ErrorCode>> {
todo!()
let sock = self
.table()
.get(&socket)
.context("failed to get socket from table")?;
Ok(sock.receive_buffer_size())
}

fn set_receive_buffer_size(
&mut self,
socket: Resource<TcpSocket>,
value: u64,
) -> wasmtime::Result<Result<(), ErrorCode>> {
todo!()
let sock = self
.table()
.get_mut(&socket)
.context("failed to get socket from table")?;
Ok(sock.set_receive_buffer_size(value))
}

fn send_buffer_size(
&mut self,
socket: Resource<TcpSocket>,
) -> wasmtime::Result<Result<u64, ErrorCode>> {
todo!()
let sock = self
.table()
.get(&socket)
.context("failed to get socket from table")?;
Ok(sock.send_buffer_size())
}

fn set_send_buffer_size(
&mut self,
socket: Resource<TcpSocket>,
value: u64,
) -> wasmtime::Result<Result<(), ErrorCode>> {
todo!()
let sock = self
.table()
.get_mut(&socket)
.context("failed to get socket from table")?;
Ok(sock.set_send_buffer_size(value))
}

fn drop(&mut self, rep: Resource<TcpSocket>) -> wasmtime::Result<()> {
Expand Down
Loading

0 comments on commit 251c2a1

Please sign in to comment.