Skip to content

Commit 1fe8a99

Browse files
committed
Implement Sink for WsClient
1 parent 77fa91a commit 1fe8a99

File tree

1 file changed

+38
-1
lines changed

1 file changed

+38
-1
lines changed

src/test.rs

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ use std::future::Future;
8888
use std::net::SocketAddr;
8989
#[cfg(feature = "websocket")]
9090
use std::pin::Pin;
91+
use std::task::Context;
9192
#[cfg(feature = "websocket")]
9293
use std::task::{self, Poll};
9394

@@ -106,10 +107,11 @@ use serde_json;
106107
use tokio::sync::oneshot;
107108

108109
use crate::filter::Filter;
110+
use crate::filters::ws::Message;
109111
use crate::reject::IsReject;
110112
use crate::reply::Reply;
111113
use crate::route::{self, Route};
112-
use crate::Request;
114+
use crate::{Request, Sink};
113115

114116
use self::inner::OneOrTuple;
115117

@@ -600,6 +602,11 @@ impl WsClient {
600602
Ok(())
601603
})
602604
}
605+
606+
fn pinned_tx(self: Pin<&mut Self>) -> Pin<&mut mpsc::UnboundedSender<crate::ws::Message>> {
607+
let this = Pin::into_inner(self);
608+
Pin::new(&mut this.tx)
609+
}
603610
}
604611

605612
#[cfg(feature = "websocket")]
@@ -609,6 +616,36 @@ impl fmt::Debug for WsClient {
609616
}
610617
}
611618

619+
#[cfg(feature = "websocket")]
620+
impl Sink<crate::ws::Message> for WsClient {
621+
type Error = ();
622+
623+
fn poll_ready(
624+
self: Pin<&mut Self>,
625+
context: &mut Context<'_>,
626+
) -> Poll<Result<(), Self::Error>> {
627+
self.pinned_tx().poll_ready(context).map_err(|_| ())
628+
}
629+
630+
fn start_send(self: Pin<&mut Self>, message: Message) -> Result<(), Self::Error> {
631+
self.pinned_tx().start_send(message).map_err(|_| ())
632+
}
633+
634+
fn poll_flush(
635+
self: Pin<&mut Self>,
636+
context: &mut Context<'_>,
637+
) -> Poll<Result<(), Self::Error>> {
638+
self.pinned_tx().poll_flush(context).map_err(|_| ())
639+
}
640+
641+
fn poll_close(
642+
self: Pin<&mut Self>,
643+
context: &mut Context<'_>,
644+
) -> Poll<Result<(), Self::Error>> {
645+
self.pinned_tx().poll_close(context).map_err(|_| ())
646+
}
647+
}
648+
612649
// ===== impl WsError =====
613650

614651
#[cfg(feature = "websocket")]

0 commit comments

Comments
 (0)