Skip to content

Commit 575f9d0

Browse files
authored
Improvements to usb-gadget server impl (#151)
* refactor(usb-gdaget): use tokio::watch for simpler wait_connection for tx/rx * refactor(usb-gadget): remove explicit endpoint size + simplify tx/rx logic send_async / recv_async already do framing and packetization, apparently...
1 parent 4ba78ee commit 575f9d0

2 files changed

Lines changed: 56 additions & 145 deletions

File tree

example/server-usb-gadget/src/bin/comms-01.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@ use postcard_rpc::{
22
define_dispatch,
33
header::VarHeader,
44
server::{
5-
impls::usb_gadget::{
6-
dispatch_impl::{WireRxBuf, WireRxImpl, WireSpawnImpl, WireStorage, WireTxImpl},
7-
USB_FS_MAX_PACKET_SIZE,
5+
impls::usb_gadget::dispatch_impl::{
6+
WireRxBuf, WireRxImpl, WireSpawnImpl, WireStorage, WireTxImpl,
87
},
98
Dispatch, Server,
109
},
@@ -63,7 +62,7 @@ async fn main() {
6362
let tx_buf = TX_BUF.init([0u8; 1024]);
6463

6564
let (_reg, tx_impl, rx_impl) = STORAGE
66-
.init(gadget, tx_buf.as_mut_slice(), USB_FS_MAX_PACKET_SIZE)
65+
.init(gadget, tx_buf.as_mut_slice())
6766
.expect("Failed to init");
6867
let dispatcher = Dispatcher::new(context, tokio::runtime::Handle::current().into());
6968

source/postcard-rpc/src/server/impls/usb_gadget.rs

Lines changed: 53 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
#![allow(missing_docs)]
22

3-
use core::{
4-
sync::atomic::{AtomicBool, Ordering},
5-
time::Duration,
6-
};
3+
use core::time::Duration;
74
use std::sync::Arc;
85

96
use crate::{
@@ -14,7 +11,7 @@ use crate::{
1411
};
1512

1613
use bytes::Bytes;
17-
use tokio::sync::Mutex;
14+
use tokio::sync::{watch, Mutex};
1815
use usb_gadget::function::custom::{EndpointReceiver, EndpointSender};
1916

2017
/// Default time in milliseconds to wait for the completion of sending
@@ -26,12 +23,7 @@ pub const USB_HS_MAX_PACKET_SIZE: usize = 512;
2623

2724
/// A collection of types and aliases useful for importing the correct types
2825
pub mod dispatch_impl {
29-
use core::sync::atomic::{AtomicBool, Ordering};
30-
31-
use std::{
32-
io::{self, Error, ErrorKind},
33-
sync::Arc,
34-
};
26+
use std::io::{self, Error, ErrorKind};
3527

3628
use usb_gadget::{
3729
function::{
@@ -69,12 +61,10 @@ pub mod dispatch_impl {
6961
&'static self,
7062
gadget: Gadget,
7163
tx_buf: &'static mut [u8],
72-
max_usb_frame_size: usize,
7364
) -> Result<(RegGadget, WireTxImpl, WireRxImpl), io::Error> {
7465
let udc = usb_gadget::default_udc()?;
7566

76-
let ((gadget, handle), wtx, wrx) =
77-
self.init_without_build(gadget, tx_buf, max_usb_frame_size);
67+
let ((gadget, handle), wtx, wrx) = self.init_without_build(gadget, tx_buf);
7868
let reg = gadget
7969
.with_config(Config::new("config").with_function(handle))
8070
.bind(&udc)?;
@@ -86,59 +76,46 @@ pub mod dispatch_impl {
8676
&'static self,
8777
gadget: Gadget,
8878
tx_buf: &'static mut [u8],
89-
max_usb_frame_size: usize,
9079
) -> ((Gadget, Handle), WireTxImpl, WireRxImpl) {
91-
assert!(max_usb_frame_size.is_power_of_two());
92-
9380
let (ep_tx, ep_tx_dir) = EndpointDirection::device_to_host();
9481
let (ep_rx, ep_rx_dir) = EndpointDirection::host_to_device();
9582

9683
let (mut custom, handle) = Custom::builder()
9784
.with_interface(
9885
Interface::new(Class::vendor_specific(0, 0), "postcard-rpc")
99-
.with_endpoint({
100-
let mut ep = Endpoint::bulk(ep_tx_dir);
101-
ep.max_packet_size_hs = max_usb_frame_size as u16;
102-
ep
103-
})
104-
.with_endpoint({
105-
let mut ep = Endpoint::bulk(ep_rx_dir);
106-
ep.max_packet_size_hs = max_usb_frame_size as u16;
107-
ep
108-
}),
86+
.with_endpoint(Endpoint::bulk(ep_tx_dir))
87+
.with_endpoint(Endpoint::bulk(ep_rx_dir)),
10988
)
11089
.build();
11190

11291
let gadget = gadget
11392
.with_os_descriptor(OsDescriptor::microsoft())
11493
.with_web_usb(WebUsb::new(0xf1, "http://webusb.org"));
11594

116-
let rx_enabled = Arc::new(AtomicBool::new(false));
117-
let tx_enabled = Arc::new(AtomicBool::new(false));
118-
119-
{
120-
let rx_enabled = rx_enabled.clone();
121-
let tx_enabled = tx_enabled.clone();
122-
123-
// Listen to events on the custom function
124-
// The device will be unbound/removed when the `custom` interface is dropped
125-
tokio::spawn(async move {
126-
while let Ok(_) = custom.wait_event().await {
127-
match custom.event()? {
128-
Event::Enable => {
129-
tx_enabled.store(true, Ordering::Release);
130-
rx_enabled.store(true, Ordering::Release);
131-
}
132-
_ => {}
95+
let (enabled_tx, enabled_rx) = tokio::sync::watch::channel(false);
96+
97+
// Listen to events on the custom function
98+
// The device will be unbound/removed when the `custom` interface is dropped
99+
tokio::spawn(async move {
100+
while let Ok(_) = custom.wait_event().await {
101+
let event = custom.event()?;
102+
103+
match event {
104+
Event::Enable => {
105+
let _ = enabled_tx.send(true);
106+
}
107+
Event::Disable | Event::Suspend => {
108+
let _ = enabled_tx.send(false);
133109
}
110+
_ => {}
134111
}
112+
}
135113

136-
Err::<(), io::Error>(Error::from(ErrorKind::BrokenPipe))
137-
});
138-
}
114+
Err::<(), io::Error>(Error::from(ErrorKind::BrokenPipe))
115+
});
139116

140-
let wtx = UsbGadgetWireTx::new(ep_tx, tx_enabled, tx_buf);
141-
let wrx = UsbGadgetWireRx::new(ep_rx, rx_enabled);
117+
let wtx = UsbGadgetWireTx::new(ep_tx, enabled_rx.clone(), tx_buf);
118+
let wrx = UsbGadgetWireRx::new(ep_rx, enabled_rx.clone());
142119

143120
((gadget, handle), wtx, wrx)
144121
}
@@ -153,46 +130,41 @@ pub mod dispatch_impl {
153130
#[derive(Debug, Clone)]
154131
pub struct UsbGadgetWireTx {
155132
inner: Arc<Mutex<UsbGadgetWireTxInner>>,
133+
ep_enabled: watch::Receiver<bool>,
156134
}
157135

158136
impl UsbGadgetWireTx {
159137
pub fn new(
160138
ep_tx: EndpointSender,
161-
ep_enabled: Arc<AtomicBool>,
139+
ep_enabled: watch::Receiver<bool>,
162140
tx_buf: &'static mut [u8],
163141
) -> Self {
164142
let inner = UsbGadgetWireTxInner {
165143
ep_tx,
166-
ep_enabled,
167144
log_seq: 0,
168145
tx_buf,
169-
pending_frame: false,
170146
};
171147

172148
Self {
173149
inner: Arc::new(Mutex::new(inner)),
150+
ep_enabled,
174151
}
175152
}
176153
}
177154

178155
#[derive(Debug)]
179156
struct UsbGadgetWireTxInner {
180157
ep_tx: EndpointSender,
181-
ep_enabled: Arc<AtomicBool>,
182158
log_seq: u16,
183159
tx_buf: &'static mut [u8],
184-
pending_frame: bool,
185160
}
186161

187162
impl WireTx for UsbGadgetWireTx {
188163
type Error = WireTxErrorKind;
189164

190165
async fn wait_connection(&self) {
191-
let inner = self.inner.lock().await;
192-
193-
while !inner.ep_enabled.load(Ordering::Acquire) {
194-
tokio::time::sleep(Duration::from_millis(2)).await;
195-
}
166+
let mut ep_enabled = self.ep_enabled.clone();
167+
let _ = ep_enabled.wait_for(|&enabled| enabled).await;
196168
}
197169

198170
async fn send<T: serde::Serialize + ?Sized>(
@@ -218,56 +190,19 @@ impl WireTx for UsbGadgetWireTx {
218190

219191
async fn send_raw(&self, buf: &[u8]) -> Result<(), Self::Error> {
220192
let mut inner = self.inner.lock().await;
221-
let UsbGadgetWireTxInner {
222-
ep_tx,
223-
pending_frame,
224-
..
225-
} = &mut *inner;
193+
let UsbGadgetWireTxInner { ep_tx, .. } = &mut *inner;
226194

227-
let chunk_size = ep_tx
195+
let packet_size = ep_tx
228196
.max_packet_size()
229197
.or(Err(WireTxErrorKind::ConnectionClosed))?;
230-
231-
let timeout_ms_per_frame = DEFAULT_TIMEOUT_MS_PER_FRAME;
232-
233-
// Calculate an estimated timeout based on the number of frames we need to send
234-
// For now, we use 2ms/frame by default, rounded UP
235-
let frames = (buf.len() + (chunk_size - 1)) / chunk_size;
236-
let timeout = Duration::from_millis((frames * timeout_ms_per_frame) as u64);
198+
let num_packets = buf.len().div_ceil(packet_size);
199+
let timeout = Duration::from_millis((num_packets * DEFAULT_TIMEOUT_MS_PER_FRAME) as u64);
237200

238201
let send = async {
239-
// If we left off a pending frame, send one now so we don't leave an unterminated message
240-
if *pending_frame {
241-
ep_tx
242-
.send_async(Bytes::new())
243-
.await
244-
.or(Err(WireTxErrorKind::ConnectionClosed))?
245-
}
246-
247-
*pending_frame = true;
248-
249-
let mut bytes = Bytes::copy_from_slice(buf);
250-
251-
while !bytes.is_empty() {
252-
let ch = bytes.split_to(chunk_size.min(bytes.len()));
253-
254-
ep_tx
255-
.send_async(ch)
256-
.await
257-
.or(Err(WireTxErrorKind::ConnectionClosed))?;
258-
}
259-
260-
// If the total we sent was a multiple of packet size, send an
261-
// empty message to "flush" the transaction. We already checked
262-
// above that the len != 0.
263-
if (buf.len() & (chunk_size - 1)) == 0 {
264-
ep_tx
265-
.send_async(Bytes::new())
266-
.await
267-
.or(Err(WireTxErrorKind::ConnectionClosed))?
268-
}
269-
270-
*pending_frame = false;
202+
ep_tx
203+
.send_async(Bytes::copy_from_slice(buf))
204+
.await
205+
.or(Err(WireTxErrorKind::ConnectionClosed))?;
271206

272207
Ok::<(), WireTxErrorKind>(())
273208
};
@@ -486,11 +421,11 @@ fn actual_varint_max_len(largest: usize) -> usize {
486421
#[derive(Debug, Clone)]
487422
pub struct UsbGadgetWireRx {
488423
ep_rx: Arc<Mutex<EndpointReceiver>>,
489-
ep_enabled: Arc<AtomicBool>,
424+
ep_enabled: watch::Receiver<bool>,
490425
}
491426

492427
impl UsbGadgetWireRx {
493-
pub fn new(ep_rx: EndpointReceiver, ep_enabled: Arc<AtomicBool>) -> Self {
428+
pub fn new(ep_rx: EndpointReceiver, ep_enabled: watch::Receiver<bool>) -> Self {
494429
Self {
495430
ep_rx: Arc::new(Mutex::new(ep_rx)),
496431
ep_enabled,
@@ -502,50 +437,27 @@ impl WireRx for UsbGadgetWireRx {
502437
type Error = WireRxErrorKind;
503438

504439
async fn wait_connection(&mut self) {
505-
let Self { ep_enabled, .. } = self;
506-
507-
while !ep_enabled.load(Ordering::Acquire) {
508-
tokio::time::sleep(Duration::from_millis(2)).await;
509-
}
440+
let _ = self.ep_enabled.wait_for(|&enabled| enabled).await;
510441
}
511442

512443
async fn receive<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a mut [u8], Self::Error> {
513444
let mut ep_rx = self.ep_rx.lock().await;
514445

515-
let packet_size = ep_rx
516-
.max_packet_size()
446+
let data = ep_rx
447+
.recv_async(bytes::BytesMut::with_capacity(buf.len()))
448+
.await
517449
.or(Err(WireRxErrorKind::ConnectionClosed))?;
518450

519-
let buflen = buf.len();
520-
let mut window = &mut buf[..];
521-
522-
while !window.is_empty() {
523-
let data = ep_rx
524-
.recv_async(bytes::BytesMut::with_capacity(packet_size))
525-
.await
526-
.or(Err(WireRxErrorKind::ConnectionClosed))?;
527-
528-
match data {
529-
Some(data) => {
530-
let n = data.len();
531-
window[0..n].copy_from_slice(&data);
532-
533-
let (_now, later) = window.split_at_mut(n);
534-
window = later;
535-
if n != packet_size {
536-
// We now have a full frame! Great!
537-
let wlen = window.len();
538-
let len = buflen - wlen;
539-
let frame = &mut buf[..len];
540-
541-
return Ok(frame);
542-
}
451+
match data {
452+
Some(data) => {
453+
if data.len() > buf.len() {
454+
return Err(WireRxErrorKind::ReceivedMessageTooLarge);
543455
}
544-
None => return Ok(&mut buf[0..0]),
456+
457+
buf[..data.len()].copy_from_slice(&data);
458+
Ok(&mut buf[..data.len()])
545459
}
460+
None => Ok(&mut buf[0..0]),
546461
}
547-
548-
// Ran out of space...?
549-
Err(WireRxErrorKind::Other)
550462
}
551463
}

0 commit comments

Comments
 (0)