Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
502 changes: 494 additions & 8 deletions example/workbook-host/Cargo.lock

Large diffs are not rendered by default.

14 changes: 13 additions & 1 deletion source/postcard-rpc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ package = "embassy-usb"
version = "0.4"
optional = true

[dependencies.embassy-net]
package = "embassy-net"
version = "0.7"
optional = true
features = [ "medium-ip", "proto-ipv4", "proto-ipv6", "tcp" ]

[dependencies.embassy-usb-driver]
version = "0.1"
optional = true
Expand Down Expand Up @@ -164,7 +170,7 @@ features = ["std"]


[features]
default = []
default = ["embassy-usb-0_4-server", "embassy-net-tcp-server", "tcp"]
test-utils = ["use-std", "postcard-schema/use-std"]
use-std = [
"dep:maitake-sync",
Expand Down Expand Up @@ -210,6 +216,7 @@ webusb = [
"dep:js-sys",
"use-std",
]
tcp = ["use-std", "tokio/net", "dep:cobs", "cobs/use_std"]
embassy-usb-0_3-server = [
"dep:embassy-usb-0_3",
"dep:embassy-sync",
Expand All @@ -230,6 +237,11 @@ embassy-usb-0_4-server = [
"dep:embassy-futures",
]

embassy-net-tcp-server = [
"dep:embassy-net",
"dep:cobs",
]

# NOTE: This exists because `embassy-usb` indirectly relies on ssmarshal
# which doesn't work on `std` builds without the `std` feature. This causes
# `cargo doc --all-features` (and docs.rs builds) to fail. Sneakily re-activate
Expand Down
3 changes: 3 additions & 0 deletions source/postcard-rpc/src/host_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ mod serial;
#[cfg(all(feature = "webusb", target_family = "wasm"))]
pub mod webusb;

#[cfg(all(feature = "tcp", not(target_family = "wasm")))]
pub mod tcp;

pub(crate) mod util;

#[cfg(feature = "test-utils")]
Expand Down
172 changes: 172 additions & 0 deletions source/postcard-rpc/src/host_client/tcp.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
//! TCP client
use std::error::Error;
use std::fmt::{Debug, Display};
use std::future::Future;
use std::net::SocketAddr;

use postcard_schema::Schema;
use serde::de::DeserializeOwned;
use tokio::io::{split, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
use tokio::net::TcpStream;

use crate::header;
use crate::standard_icd::ERROR_PATH;

use super::{HostClient, WireRx, WireSpawn, WireTx};

/// Error during TCP RX
pub enum TcpCommsRxError {
/// Rx buffer overflow
RxOverflow,
/// General connection error
ConnError,
}

impl Debug for TcpCommsRxError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("oops")
}
}

impl Display for TcpCommsRxError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("oops")
}
}

impl Error for TcpCommsRxError {}

struct TcpCommsRx<T: AsyncRead + Send + 'static> {
addr: SocketAddr,
buf: Vec<u8>,
rx: ReadHalf<T>,
}

impl<T: AsyncRead + Send + 'static> TcpCommsRx<T> {
async fn receive_inner(&mut self) -> Result<Vec<u8>, TcpCommsRxError> {
let mut rx_buf = [0u8; 1024];
'frame: loop {
if self.buf.len() > (1024 * 1024) {
tracing::warn!(?self.addr, "Refusing to collect >1MiB, terminating");
self.buf.clear();
return Err(TcpCommsRxError::RxOverflow);
}

// Do we have a message already?
if let Some(pos) = self.buf.iter().position(|b| *b == 0) {
// we found the end of a message, attempt to decode it
let mut split = self.buf.split_off(pos + 1);
core::mem::swap(&mut self.buf, &mut split);

// Can we decode the cobs?
let res = cobs::decode_vec(&split);
let Ok(msg) = res else {
tracing::warn!(?self.addr, discarded = split.len(), "Discarding bad message (cobs)");
continue 'frame;
};

return Ok(msg);
}

// No message yet, let's try and receive some data
let Ok(used) = self.rx.read(&mut rx_buf).await else {
tracing::warn!(?self.addr, "Closing");
return Err(TcpCommsRxError::ConnError);
};
if used == 0 {
tracing::warn!(?self.addr, "Closing");
return Err(TcpCommsRxError::ConnError);
}
self.buf.extend_from_slice(&rx_buf[..used]);
}
}
}

impl<T: AsyncRead + Send + 'static> WireRx for TcpCommsRx<T> {
type Error = TcpCommsRxError;

fn receive(&mut self) -> impl Future<Output = Result<Vec<u8>, Self::Error>> + Send {
self.receive_inner()
}
}

/// An error during TX
pub enum TcpCommsTxError {
/// A general tx comms error
CommsError,
}

impl Debug for TcpCommsTxError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("oops")
}
}

impl Display for TcpCommsTxError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("oops")
}
}

impl Error for TcpCommsTxError {}

struct TcpCommsTx<T: AsyncWrite + Send + 'static> {
tx: WriteHalf<T>,
}

impl<T: AsyncWrite + Send + 'static> TcpCommsTx<T> {
async fn send_inner(&mut self, data: Vec<u8>) -> Result<(), TcpCommsTxError> {
//let mut data = cobs::encode_vec(&data);
//data.push(0);
self.tx
.write_all(&data)
.await
.map_err(|_| TcpCommsTxError::CommsError)
}
}

impl<T: AsyncWrite + Send + 'static> WireTx for TcpCommsTx<T> {
type Error = TcpCommsTxError;

fn send(&mut self, data: Vec<u8>) -> impl Future<Output = Result<(), Self::Error>> + Send {
self.send_inner(data)
}
}

// ---

struct TcpSpawn;

impl WireSpawn for TcpSpawn {
fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(fut);
}
}

impl<WireErr> HostClient<WireErr>
where
WireErr: DeserializeOwned + Schema,
{
/// Connect to a server via TCP
pub async fn connect_tcp<T>(addr: T) -> Self
where
T: tokio::net::ToSocketAddrs + Debug,
{
println!("connecting to {:?}", addr);
let stream = TcpStream::connect(addr).await.unwrap();
let addr = stream.peer_addr().unwrap();
let (rx, tx) = split(stream);
HostClient::new_with_wire(
TcpCommsTx { tx },
TcpCommsRx {
rx,
addr,
buf: vec![],
},
TcpSpawn,
header::VarSeqKind::Seq4,
ERROR_PATH,
64,
)
}
}
Loading