Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions compio-ws/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ readme = { workspace = true }
license = { workspace = true }
repository = { workspace = true }

[package.metadata.docs.rs]
all-features = true
rustdoc-args = ["--cfg", "docsrs"]

[dependencies]
rustls = { workspace = true, optional = true, default-features = false }
Expand Down
20 changes: 20 additions & 0 deletions compio-ws/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@
//!
//! Each WebSocket stream implements message reading and writing.

#![cfg_attr(docsrs, feature(doc_cfg))]
#![warn(missing_docs)]

pub mod stream;

#[cfg(feature = "rustls")]
pub mod rustls;

use std::io::ErrorKind;

use compio_buf::IntoInner;
use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream};
use tungstenite::{
Error as WsError, HandshakeError, Message, WebSocket,
Expand All @@ -34,6 +38,7 @@ pub use crate::rustls::{
connect_async_with_tls_connector_and_config,
};

/// A WebSocket stream that works with compio.
pub struct WebSocketStream<S> {
inner: WebSocket<SyncStream<S>>,
}
Expand All @@ -42,6 +47,7 @@ impl<S> WebSocketStream<S>
where
S: AsyncRead + AsyncWrite + Unpin + std::fmt::Debug,
{
/// Send a message on the WebSocket stream.
pub async fn send(&mut self, message: Message) -> Result<(), WsError> {
// Send the message - this buffers it
// Since CompioStream::flush() now returns Ok, this should succeed on first try
Expand All @@ -57,6 +63,7 @@ where
Ok(())
}

/// Read a message from the WebSocket stream.
pub async fn read(&mut self) -> Result<Message, WsError> {
loop {
match self.inner.read() {
Expand Down Expand Up @@ -84,6 +91,7 @@ where
}
}

/// Close the WebSocket connection.
pub async fn close(&mut self, close_frame: Option<CloseFrame>) -> Result<(), WsError> {
loop {
match self.inner.close(close_frame.clone()) {
Expand All @@ -103,19 +111,31 @@ where
}
}

/// Get a reference to the underlying stream.
pub fn get_ref(&self) -> &S {
self.inner.get_ref().get_ref()
}

/// Get a mutable reference to the underlying stream.
pub fn get_mut(&mut self) -> &mut S {
self.inner.get_mut().get_mut()
}

/// Get the inner WebSocket.
#[deprecated = "Use IntoInner trait instead. This method will be removed in a future release."]
pub fn get_inner(self) -> WebSocket<SyncStream<S>> {
self.inner
}
}

impl<S> IntoInner for WebSocketStream<S> {
type Inner = WebSocket<SyncStream<S>>;

fn into_inner(self) -> Self::Inner {
self.inner
}
}

/// Accepts a new WebSocket connection with the provided stream.
///
/// This function will internally call `server::accept` to create a
Expand Down
5 changes: 5 additions & 0 deletions compio-ws/src/rustls.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//! Rustls support.

#[cfg(any(feature = "rustls-platform-verifier", feature = "webpki-roots"))]
use std::sync::Arc;

Expand All @@ -17,8 +19,10 @@ use crate::{
WebSocketConfig, WebSocketStream, client_async_with_config, domain, stream::MaybeTlsStream,
};

/// Type alias for a stream that can be either plain TCP or TLS-encrypted.
pub type AutoStream<S> = MaybeTlsStream<S>;

/// Type alias for a TLS connector.
pub type Connector = TlsConnector;

async fn wrap_stream<S>(
Expand Down Expand Up @@ -202,6 +206,7 @@ where
client_async_with_config(request, stream, config).await
}

/// Type alias for a connect stream.
Comment thread
Berrysoft marked this conversation as resolved.
Outdated
pub type ConnectStream = AutoStream<TcpStream>;

/// Connect to a given URL.
Expand Down
5 changes: 5 additions & 0 deletions compio-ws/src/stream.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//! Provides [`MaybeTlsStream`].

#[cfg(feature = "rustls")]
use std::io::Result as IoResult;

Expand All @@ -22,15 +24,18 @@ pub enum MaybeTlsStream<S> {

#[cfg(feature = "rustls")]
impl<S> MaybeTlsStream<S> {
/// Create an unencrypted stream.
pub fn plain(stream: S) -> Self {
MaybeTlsStream::Plain(stream)
}

/// Create a TLS-encrypted stream.
#[cfg(feature = "rustls")]
pub fn tls(stream: TlsStream<S>) -> Self {
MaybeTlsStream::Tls(stream)
}

/// Whether the stream is TLS-encrypted.
pub fn is_tls(&self) -> bool {
#[cfg(feature = "rustls")]
{
Expand Down
Loading