Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions compio-ws/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ readme = { workspace = true }
license = { workspace = true }
repository = { workspace = true }

[package.metadata.docs.rs]
all-features = true
rustdoc-args = ["--cfg", "docsrs"]

[dependencies]
rustls = { workspace = true, optional = true, default-features = false }
Expand Down
20 changes: 20 additions & 0 deletions compio-ws/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@
//!
//! Each WebSocket stream implements message reading and writing.

#![cfg_attr(docsrs, feature(doc_cfg))]
#![warn(missing_docs)]

pub mod stream;

#[cfg(feature = "rustls")]
pub mod rustls;

use std::io::ErrorKind;

use compio_buf::IntoInner;
use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream};
use tungstenite::{
Error as WsError, HandshakeError, Message, WebSocket,
Expand All @@ -34,6 +38,7 @@ pub use crate::rustls::{
connect_async_with_tls_connector_and_config,
};

/// A WebSocket stream that works with compio.
pub struct WebSocketStream<S> {
inner: WebSocket<SyncStream<S>>,
}
Expand All @@ -42,6 +47,7 @@ impl<S> WebSocketStream<S>
where
S: AsyncRead + AsyncWrite + Unpin + std::fmt::Debug,
{
/// Send a message on the WebSocket stream.
pub async fn send(&mut self, message: Message) -> Result<(), WsError> {
// Send the message - this buffers it
// Since CompioStream::flush() now returns Ok, this should succeed on first try
Expand All @@ -57,6 +63,7 @@ where
Ok(())
}

/// Read a message from the WebSocket stream.
pub async fn read(&mut self) -> Result<Message, WsError> {
loop {
match self.inner.read() {
Expand Down Expand Up @@ -84,6 +91,7 @@ where
}
}

/// Close the WebSocket connection.
pub async fn close(&mut self, close_frame: Option<CloseFrame>) -> Result<(), WsError> {
loop {
match self.inner.close(close_frame.clone()) {
Expand All @@ -103,19 +111,31 @@ where
}
}

/// Get a reference to the underlying stream.
pub fn get_ref(&self) -> &S {
self.inner.get_ref().get_ref()
}

/// Get a mutable reference to the underlying stream.
pub fn get_mut(&mut self) -> &mut S {
self.inner.get_mut().get_mut()
}

/// Get the inner WebSocket.
#[deprecated = "Use IntoInner trait instead. This method will be removed in a future release."]
pub fn get_inner(self) -> WebSocket<SyncStream<S>> {
self.inner
}
}

impl<S> IntoInner for WebSocketStream<S> {
type Inner = WebSocket<SyncStream<S>>;

fn into_inner(self) -> Self::Inner {
self.inner
}
}

/// Accepts a new WebSocket connection with the provided stream.
///
/// This function will internally call `server::accept` to create a
Expand Down
5 changes: 5 additions & 0 deletions compio-ws/src/rustls.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//! Rustls support.

#[cfg(any(feature = "rustls-platform-verifier", feature = "webpki-roots"))]
use std::sync::Arc;

Expand All @@ -17,8 +19,10 @@ use crate::{
WebSocketConfig, WebSocketStream, client_async_with_config, domain, stream::MaybeTlsStream,
};

/// Type alias for a stream that can be either plain TCP or TLS-encrypted.
pub type AutoStream<S> = MaybeTlsStream<S>;

/// Type alias for a TLS connector.
pub type Connector = TlsConnector;

async fn wrap_stream<S>(
Expand Down Expand Up @@ -202,6 +206,7 @@ where
client_async_with_config(request, stream, config).await
}

/// Type alias for a connected stream.
pub type ConnectStream = AutoStream<TcpStream>;

/// Connect to a given URL.
Expand Down
5 changes: 5 additions & 0 deletions compio-ws/src/stream.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//! Provides [`MaybeTlsStream`].

#[cfg(feature = "rustls")]
use std::io::Result as IoResult;

Expand All @@ -22,15 +24,18 @@ pub enum MaybeTlsStream<S> {

#[cfg(feature = "rustls")]
impl<S> MaybeTlsStream<S> {
/// Create an unencrypted stream.
pub fn plain(stream: S) -> Self {
MaybeTlsStream::Plain(stream)
}

/// Create a TLS-encrypted stream.
#[cfg(feature = "rustls")]
pub fn tls(stream: TlsStream<S>) -> Self {
MaybeTlsStream::Tls(stream)
}

/// Whether the stream is TLS-encrypted.
pub fn is_tls(&self) -> bool {
#[cfg(feature = "rustls")]
{
Expand Down
Loading