From efc1d5a22e12a40d2584626f49e42930e9725902 Mon Sep 17 00:00:00 2001 From: Zach Heylmun Date: Thu, 7 May 2026 17:38:32 +0200 Subject: [PATCH 1/7] feat!: add Codec abstraction with JSON, MessagePack, and Raw codecs Introduces a `Codec` trait that owns a connection's `Tx` and `Rx` types and controls how each maps to a WebSocket frame. This replaces the hard-wired `serde_json`/Text-frame path with a pluggable interface, lets a single connection use asymmetric framing on send vs receive (e.g. IBKR's prefixed text out, JSON in), and supports MessagePack out of the box behind the new `msgpack` cargo feature. Breaking changes: - `Socketeer` becomes `Socketeer`; message types live on the codec - `ConnectionHandler` is now generic over the codec - `HandshakeContext` is codec-aware: codec-driven `send`/`recv` plus raw `send_text`/`send_binary`/`recv_text`/`recv_raw` escape hatches; the old `send_json`/`recv_json` helpers are gone - `ConnectOptions::custom_keepalive_message` widened from `Option` to `Option` to support binary keepalives - `Error::SerializationError(serde_json::Error)` replaced with `Error::Codec(Box)` - New `connect_with_codec` constructor; `connect`/`connect_with` retained as shortcuts when the codec is `Default` --- Cargo.lock | 35 ++++++ Cargo.toml | 2 + README.md | 42 +++++-- examples/echo_chat.rs | 5 +- src/codec.rs | 203 ++++++++++++++++++++++++++++++++++ src/config.rs | 9 +- src/error.rs | 6 +- src/handler.rs | 123 ++++++++++++++------- src/lib.rs | 249 +++++++++++++++++++++--------------------- src/mock_server.rs | 50 +++++++++ 10 files changed, 542 insertions(+), 182 deletions(-) create mode 100644 src/codec.rs diff --git a/Cargo.lock b/Cargo.lock index 0052614..9139edf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,12 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + [[package]] name = "bitflags" version = "2.11.0" @@ -530,6 +536,15 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "once_cell" version = "1.21.4" @@ -685,6 +700,25 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "rmp" +version = "0.8.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ba8be72d372b2c9b35542551678538b562e7cf86c3315773cae48dfbfe7790c" +dependencies = [ + "num-traits", +] + +[[package]] +name = "rmp-serde" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72f81bee8c8ef9b577d1681a70ebbc962c232461e397b22c208c43c04b67a155" +dependencies = [ + "rmp", + "serde", +] + [[package]] name = "rustix" version = "1.1.4" @@ -834,6 +868,7 @@ dependencies = [ "bytes", "futures", "futures-util", + "rmp-serde", "serde", "serde_json", "thiserror", diff --git a/Cargo.toml b/Cargo.toml index 75be439..aec5284 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,12 +10,14 @@ repository = "https://github.com/zheylmun/socketeer" [features] mocking = [] +msgpack = ["dep:rmp-serde"] tracing = ["dep:tracing"] [dependencies] bytes = "1" futures = "0.3" futures-util = "0.3" +rmp-serde = { version = "1", optional = true } serde = { version = "1", features = ["derive"] } serde_json = "1" thiserror = "2" diff --git a/README.md b/README.md index 89e6c4d..28411ec 100644 --- a/README.md +++ b/README.md @@ -8,8 +8,7 @@ `socketeer` is a simplified async WebSocket client built on tokio-tungstenite. It manages the underlying connection and exposes a clean API for sending and receiving messages, with support for: - Automatic connection management with configurable keepalive -- Type-safe JSON message serialization/deserialization via serde -- Raw message support for non-JSON protocols +- Pluggable codec for typed messages (`JsonCodec`, `MsgPackCodec`, `RawCodec`, or your own) - Custom HTTP headers on the WebSocket upgrade request - Connection lifecycle hooks for auth handshakes and subscriptions - Transparent handling of WebSocket protocol messages (ping/pong/close) @@ -20,7 +19,7 @@ ### Simple JSON messages ```rust no_run -use socketeer::Socketeer; +use socketeer::{JsonCodec, Socketeer}; #[derive(Debug, serde::Serialize, serde::Deserialize)] struct SocketMessage { @@ -29,7 +28,7 @@ struct SocketMessage { #[tokio::main] async fn main() { - let mut socketeer: Socketeer = + let mut socketeer: Socketeer> = Socketeer::connect("ws://127.0.0.1:80") .await .unwrap(); @@ -45,10 +44,30 @@ async fn main() { } ``` +### `MessagePack` + +Enable the `msgpack` feature and use [`MsgPackCodec`] in place of [`JsonCodec`]: + +```rust no_run +# #[cfg(feature = "msgpack")] +# { +use socketeer::{MsgPackCodec, Socketeer}; + +#[derive(Debug, serde::Serialize, serde::Deserialize)] +struct SocketMessage { message: String } + +# #[tokio::main] +# async fn main() { +let mut socketeer: Socketeer> = + Socketeer::connect("ws://127.0.0.1:80").await.unwrap(); +# } +# } +``` + ### Custom headers and connection options ```rust no_run -use socketeer::{Socketeer, ConnectOptions}; +use socketeer::{ConnectOptions, JsonCodec, Socketeer}; use std::time::Duration; # #[derive(Debug, serde::Serialize, serde::Deserialize)] @@ -59,7 +78,7 @@ let mut options = ConnectOptions::default(); options.extra_headers.insert("Authorization", "Bearer my-token".parse().unwrap()); options.keepalive_interval = Some(Duration::from_secs(10)); -let socketeer: Socketeer = +let socketeer: Socketeer> = Socketeer::connect_with("wss://api.example.com/ws", options) .await .unwrap(); @@ -69,12 +88,12 @@ let socketeer: Socketeer = ### Connection lifecycle hooks ```rust no_run -use socketeer::{Socketeer, ConnectOptions, ConnectionHandler, HandshakeContext, Error}; +use socketeer::{Codec, ConnectOptions, ConnectionHandler, Error, HandshakeContext, JsonCodec, Socketeer}; struct MyAuthHandler { api_key: String } -impl ConnectionHandler for MyAuthHandler { - async fn on_connected(&mut self, ctx: &mut HandshakeContext<'_>) -> Result<(), Error> { +impl ConnectionHandler for MyAuthHandler { + async fn on_connected(&mut self, ctx: &mut HandshakeContext<'_, C>) -> Result<(), Error> { ctx.send_text(&format!(r#"{{"action":"auth","key":"{}"}}"#, self.api_key)).await?; let _response = ctx.recv_text().await?; Ok(()) @@ -86,10 +105,11 @@ impl ConnectionHandler for MyAuthHandler { # #[tokio::main] # async fn main() { let handler = MyAuthHandler { api_key: "secret".into() }; -let socketeer: Socketeer = - Socketeer::connect_with_handler( +let socketeer: Socketeer, MyAuthHandler> = + Socketeer::connect_with_codec( "wss://stream.example.com", ConnectOptions::default(), + JsonCodec::new(), handler, ) .await diff --git a/examples/echo_chat.rs b/examples/echo_chat.rs index cacfbe8..4f201d1 100644 --- a/examples/echo_chat.rs +++ b/examples/echo_chat.rs @@ -1,4 +1,4 @@ -use socketeer::{EchoControlMessage, Socketeer, echo_server, get_mock_address}; +use socketeer::{EchoControlMessage, JsonCodec, Socketeer, echo_server, get_mock_address}; use tracing_subscriber::fmt::Subscriber; #[tokio::main] @@ -14,7 +14,8 @@ async fn main() { let server_address = get_mock_address(echo_server).await; // Next, we create a Socketeer instance that connects to the mock server. - let mut socketeer: Socketeer = + // The codec parameter declares both the wire format and the message types. + let mut socketeer: Socketeer> = Socketeer::connect(&format!("ws://{server_address}",)) .await .unwrap(); diff --git a/src/codec.rs b/src/codec.rs new file mode 100644 index 0000000..288d813 --- /dev/null +++ b/src/codec.rs @@ -0,0 +1,203 @@ +//! Codec abstraction for serializing outgoing messages and deserializing incoming +//! messages on a [`crate::Socketeer`] connection. +//! +//! A [`Codec`] owns the `Tx` and `Rx` types for a connection and decides how each +//! is mapped to a WebSocket [`Message`]. This lets a single connection use different +//! framing on the send and receive sides — useful for protocols like Interactive +//! Brokers' Client Portal stream, which sends prefixed text strings (`smd+...`) but +//! receives JSON. +//! +//! Three stock codecs are provided: +//! +//! - [`JsonCodec`] — `serde_json`, sends as `Message::Text`. Decodes Text or Binary. +//! - [`MsgPackCodec`] — `rmp-serde`, sends as `Message::Binary`. Behind the +//! `msgpack` cargo feature. +//! - [`RawCodec`] — `Tx = Rx = Message`, no transformation. Useful when you want +//! the typed [`crate::Socketeer::send`] / [`crate::Socketeer::next_message`] +//! path but don't want any (de)serialization. +//! +//! Custom codecs implement the [`Codec`] trait directly. + +use std::marker::PhantomData; + +use serde::{Serialize, de::DeserializeOwned}; +use tokio_tungstenite::tungstenite::Message; + +use crate::Error; + +/// Encodes outgoing values into WebSocket messages and decodes incoming messages +/// into typed values. +/// +/// The `Tx` and `Rx` associated types are the values surfaced to users via +/// [`crate::Socketeer::send`] and [`crate::Socketeer::next_message`]. A codec is +/// free to use the same type on both sides (most common) or to use different +/// types when a protocol's send and receive shapes differ. +pub trait Codec: Send + Sync + 'static { + /// The type accepted by [`crate::Socketeer::send`]. + type Tx; + /// The type returned by [`crate::Socketeer::next_message`]. + type Rx; + + /// Encode a value of [`Self::Tx`] into a WebSocket [`Message`]. + /// # Errors + /// - If the value cannot be encoded. + fn encode(&self, value: &Self::Tx) -> Result; + + /// Decode a WebSocket [`Message`] into a value of [`Self::Rx`]. + /// # Errors + /// - If the message cannot be decoded. + fn decode(&self, frame: &Message) -> Result; +} + +/// JSON codec backed by `serde_json`. +/// +/// Encodes outgoing values as [`Message::Text`]. Decodes incoming `Text` +/// frames and, for compatibility with servers that send JSON in binary frames, +/// also accepts [`Message::Binary`]. +pub struct JsonCodec(PhantomData (Rx, Tx)>); + +impl JsonCodec { + /// Construct a new [`JsonCodec`]. + #[must_use] + pub const fn new() -> Self { + Self(PhantomData) + } +} + +impl Default for JsonCodec { + fn default() -> Self { + Self::new() + } +} + +impl std::fmt::Debug for JsonCodec { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("JsonCodec") + } +} + +impl Clone for JsonCodec { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for JsonCodec {} + +impl Codec for JsonCodec +where + Rx: DeserializeOwned + Send + 'static, + Tx: Serialize + Send + 'static, +{ + type Tx = Tx; + type Rx = Rx; + + fn encode(&self, value: &Self::Tx) -> Result { + let text = serde_json::to_string(value).map_err(|e| Error::Codec(Box::new(e)))?; + Ok(Message::Text(text.into())) + } + + fn decode(&self, frame: &Message) -> Result { + match frame { + Message::Text(text) => { + serde_json::from_str(text).map_err(|e| Error::Codec(Box::new(e))) + } + Message::Binary(bytes) => { + serde_json::from_slice(bytes).map_err(|e| Error::Codec(Box::new(e))) + } + other => Err(Error::UnexpectedMessageType(Box::new(other.clone()))), + } + } +} + +/// `MessagePack` codec backed by `rmp-serde`. +/// +/// Encodes outgoing values as [`Message::Binary`]. Decodes only `Binary` frames. +#[cfg(feature = "msgpack")] +pub struct MsgPackCodec(PhantomData (Rx, Tx)>); + +#[cfg(feature = "msgpack")] +impl MsgPackCodec { + /// Construct a new [`MsgPackCodec`]. + #[must_use] + pub const fn new() -> Self { + Self(PhantomData) + } +} + +#[cfg(feature = "msgpack")] +impl Default for MsgPackCodec { + fn default() -> Self { + Self::new() + } +} + +#[cfg(feature = "msgpack")] +impl std::fmt::Debug for MsgPackCodec { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("MsgPackCodec") + } +} + +#[cfg(feature = "msgpack")] +impl Clone for MsgPackCodec { + fn clone(&self) -> Self { + *self + } +} + +#[cfg(feature = "msgpack")] +impl Copy for MsgPackCodec {} + +#[cfg(feature = "msgpack")] +impl Codec for MsgPackCodec +where + Rx: DeserializeOwned + Send + 'static, + Tx: Serialize + Send + 'static, +{ + type Tx = Tx; + type Rx = Rx; + + fn encode(&self, value: &Self::Tx) -> Result { + let bytes = rmp_serde::to_vec_named(value).map_err(|e| Error::Codec(Box::new(e)))?; + Ok(Message::Binary(bytes.into())) + } + + fn decode(&self, frame: &Message) -> Result { + match frame { + Message::Binary(bytes) => { + rmp_serde::from_slice(bytes).map_err(|e| Error::Codec(Box::new(e))) + } + other => Err(Error::UnexpectedMessageType(Box::new(other.clone()))), + } + } +} + +/// Identity codec — `Tx` and `Rx` are both [`Message`], no (de)serialization. +/// +/// Useful when you want to drive the typed [`crate::Socketeer::send`] / +/// [`crate::Socketeer::next_message`] API without any encoding step, e.g. for +/// protocols where you assemble frame bodies by hand. +#[derive(Debug, Default, Clone, Copy)] +pub struct RawCodec; + +impl RawCodec { + /// Construct a new [`RawCodec`]. + #[must_use] + pub const fn new() -> Self { + Self + } +} + +impl Codec for RawCodec { + type Tx = Message; + type Rx = Message; + + fn encode(&self, value: &Self::Tx) -> Result { + Ok(value.clone()) + } + + fn decode(&self, frame: &Message) -> Result { + Ok(frame.clone()) + } +} diff --git a/src/config.rs b/src/config.rs index 713d1cd..1e10aca 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,5 +1,5 @@ use std::time::Duration; -use tokio_tungstenite::tungstenite::{client::IntoClientRequest, http}; +use tokio_tungstenite::tungstenite::{Message, client::IntoClientRequest, http}; use url::Url; use crate::Error; @@ -15,9 +15,10 @@ pub struct ConnectOptions { /// Idle timeout before sending a keepalive. `None` disables keepalives entirely. /// Defaults to 2 seconds. pub keepalive_interval: Option, - /// If set, send this text string as keepalive instead of a WebSocket ping frame. - /// Useful for APIs like Interactive Brokers that expect a custom keepalive message. - pub custom_keepalive_message: Option, + /// If set, send this message as the keepalive instead of a WebSocket Ping frame. + /// Useful for APIs that expect a custom keepalive payload — e.g. Interactive + /// Brokers' literal text `tic`, or a binary heartbeat in a msgpack protocol. + pub custom_keepalive_message: Option, } impl Default for ConnectOptions { diff --git a/src/error.rs b/src/error.rs index d964fc2..4a782b1 100644 --- a/src/error.rs +++ b/src/error.rs @@ -24,9 +24,9 @@ pub enum Error { /// Error thrown if a message type not handled by `socketeer` is received. #[error("Unexpected Message type: {0}")] UnexpectedMessageType(Box), - /// Error thrown if the message received fails to serialize or deserialize. - #[error("Serialization Error: {0}")] - SerializationError(#[from] serde_json::Error), + /// Error thrown if a [`crate::Codec`] fails to encode or decode a message. + #[error("Codec error: {0}")] + Codec(Box), /// Error thrown if socketeer is dropped without closing the connection. /// This error will be removed once async destructors are stabilized. /// See [issue](https://github.com/rust-lang/rust/issues/126482) diff --git a/src/handler.rs b/src/handler.rs index bccacbe..aa45468 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,28 +1,62 @@ +use bytes::Bytes; use futures::{SinkExt, StreamExt, stream::SplitSink, stream::SplitStream}; -use serde::{Deserialize, Serialize}; use tokio_tungstenite::tungstenite::Message; -use crate::{Error, WebSocketStreamType}; +use crate::{Codec, Error, WebSocketStreamType}; /// Context available during the WebSocket handshake phase. /// /// Provides methods to send and receive messages before the main socket loop starts. /// Use this in [`ConnectionHandler::on_connected`] for authentication handshakes /// and initial subscriptions. -pub struct HandshakeContext<'a> { +/// +/// The context is generic over the connection's [`Codec`], so the typed +/// [`Self::send`] and [`Self::recv`] methods speak the same encoding the rest of +/// the connection will use. Raw text/binary helpers are provided for protocols +/// (e.g. IBKR's `api={session}` step) that don't fit the codec's `Tx` shape. +pub struct HandshakeContext<'a, C: Codec> { sink: &'a mut SplitSink, stream: &'a mut SplitStream, + codec: &'a C, } -impl<'a> HandshakeContext<'a> { +impl<'a, C: Codec> HandshakeContext<'a, C> { pub(crate) fn new( sink: &'a mut SplitSink, stream: &'a mut SplitStream, + codec: &'a C, ) -> Self { - Self { sink, stream } + Self { + sink, + stream, + codec, + } } - /// Send a text message during the handshake. + /// Encode and send a value using the connection's [`Codec`]. + /// # Errors + /// - If the codec fails to encode the value + /// - If the WebSocket connection fails + pub async fn send(&mut self, value: &C::Tx) -> Result<(), Error> { + let message = self.codec.encode(value)?; + self.sink.send(message).await.map_err(Error::from) + } + + /// Receive and decode the next message using the connection's [`Codec`]. + /// + /// Skips ping/pong protocol frames. Returns an error on close or codec failure. + /// # Errors + /// - If the WebSocket connection is closed + /// - If the codec fails to decode the frame + pub async fn recv(&mut self) -> Result { + let frame = self.recv_raw().await?; + self.codec.decode(&frame) + } + + /// Send a raw text frame during the handshake. + /// + /// Useful for protocols whose handshake messages don't fit the codec's `Tx` + /// type (e.g. IBKR's literal `api={session_id}` auth string). /// # Errors /// - If the WebSocket connection fails pub async fn send_text(&mut self, text: &str) -> Result<(), Error> { @@ -32,43 +66,47 @@ impl<'a> HandshakeContext<'a> { .map_err(Error::from) } + /// Send a raw binary frame during the handshake. + /// # Errors + /// - If the WebSocket connection fails + pub async fn send_binary(&mut self, bytes: impl Into) -> Result<(), Error> { + self.sink + .send(Message::Binary(bytes.into())) + .await + .map_err(Error::from) + } + /// Receive the next text message during the handshake. /// - /// Skips non-text protocol messages (ping/pong). Returns an error if a - /// binary or close frame is received, or if the connection closes. + /// Skips ping/pong protocol messages. Returns an error if a binary or close + /// frame is received, or if the connection closes. /// # Errors /// - If the WebSocket connection is closed /// - If a non-text message is received pub async fn recv_text(&mut self) -> Result { + match self.recv_raw().await? { + Message::Text(text) => Ok(text.to_string()), + other => Err(Error::UnexpectedMessageType(Box::new(other))), + } + } + + /// Receive the next data frame during the handshake without decoding. + /// + /// Skips ping/pong protocol messages, returning the next `Text`, `Binary`, + /// or `Close` frame as-is. + /// # Errors + /// - If the WebSocket connection is closed + pub async fn recv_raw(&mut self) -> Result { loop { let Some(msg) = self.stream.next().await else { return Err(Error::WebsocketClosed); }; match msg.map_err(Error::WebsocketError)? { - Message::Text(text) => return Ok(text.to_string()), Message::Ping(_) | Message::Pong(_) => {} - other => return Err(Error::UnexpectedMessageType(Box::new(other))), + other => return Ok(other), } } } - - /// Serialize and send a JSON message during the handshake. - /// # Errors - /// - If the message cannot be serialized - /// - If the WebSocket connection fails - pub async fn send_json(&mut self, msg: &T) -> Result<(), Error> { - let text = serde_json::to_string(msg)?; - self.send_text(&text).await - } - - /// Receive and deserialize a JSON message during the handshake. - /// # Errors - /// - If the WebSocket connection is closed - /// - If the message cannot be deserialized - pub async fn recv_json Deserialize<'de>>(&mut self) -> Result { - let text = self.recv_text().await?; - serde_json::from_str(&text).map_err(Error::from) - } } /// Trait for handling WebSocket connection lifecycle events. @@ -77,32 +115,36 @@ impl<'a> HandshakeContext<'a> { /// handshakes, subscriptions) and teardown. The handler is called both on initial /// connection and on reconnect. /// +/// The trait is generic over the connection's [`Codec`] so that +/// [`Self::on_connected`] sees a [`HandshakeContext`] specialized to that codec. +/// A handler that doesn't care about the codec can implement +/// `ConnectionHandler` for all `C: Codec` (see [`NoopHandler`]). +/// /// # Example /// /// ```rust,no_run -/// use socketeer::{ConnectionHandler, HandshakeContext, Error}; +/// use socketeer::{Codec, ConnectionHandler, HandshakeContext, Error}; /// /// struct AlpacaAuth { /// api_key: String, /// api_secret: String, /// } /// -/// impl ConnectionHandler for AlpacaAuth { -/// async fn on_connected(&mut self, ctx: &mut HandshakeContext<'_>) -> Result<(), Error> { -/// // Wait for "connected" message -/// let _ = ctx.recv_text().await?; -/// // Send auth +/// impl ConnectionHandler for AlpacaAuth { +/// async fn on_connected(&mut self, ctx: &mut HandshakeContext<'_, C>) -> Result<(), Error> { +/// // Wait for the welcome frame +/// let _ = ctx.recv_raw().await?; +/// // Send auth as plain text — works regardless of the negotiated codec /// ctx.send_text(&format!( /// r#"{{"action":"auth","key":"{}","secret":"{}"}}"#, /// self.api_key, self.api_secret /// )).await?; -/// // Wait for "authenticated" response -/// let _ = ctx.recv_text().await?; +/// let _ = ctx.recv_raw().await?; /// Ok(()) /// } /// } /// ``` -pub trait ConnectionHandler: Send + 'static { +pub trait ConnectionHandler: Send + 'static { /// Called after the WebSocket upgrade completes, before the socket loop starts. /// /// Use this for authentication handshakes, initial subscriptions, or any @@ -110,7 +152,7 @@ pub trait ConnectionHandler: Send + 'static { /// This is also called after a reconnect. fn on_connected( &mut self, - ctx: &mut HandshakeContext<'_>, + ctx: &mut HandshakeContext<'_, C>, ) -> impl std::future::Future> + Send { let _ = ctx; async { Ok(()) } @@ -124,8 +166,9 @@ pub trait ConnectionHandler: Send + 'static { /// Default no-op connection handler. /// -/// Used when no lifecycle hooks are needed (the simple case). +/// Used when no lifecycle hooks are needed (the simple case). Implements +/// [`ConnectionHandler`] for every codec. #[derive(Debug, Clone, Copy)] pub struct NoopHandler; -impl ConnectionHandler for NoopHandler {} +impl ConnectionHandler for NoopHandler {} diff --git a/src/lib.rs b/src/lib.rs index 9edf8c0..a966ba3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,21 +1,26 @@ #![doc = include_str!("../README.md")] #![deny(missing_docs)] +mod codec; mod config; mod error; mod handler; #[cfg(feature = "mocking")] mod mock_server; +#[cfg(feature = "msgpack")] +pub use codec::MsgPackCodec; +pub use codec::{Codec, JsonCodec, RawCodec}; pub use config::ConnectOptions; pub use error::Error; pub use handler::{ConnectionHandler, HandshakeContext, NoopHandler}; +#[cfg(all(feature = "mocking", feature = "msgpack"))] +pub use mock_server::msgpack_echo_server; #[cfg(feature = "mocking")] pub use mock_server::{EchoControlMessage, auth_echo_server, echo_server, get_mock_address}; use bytes::Bytes; use futures::{SinkExt, StreamExt, stream::SplitSink, stream::SplitStream}; -use serde::{Deserialize, Serialize}; -use std::{fmt::Debug, time::Duration}; +use std::time::Duration; use tokio::{ net::TcpStream, select, @@ -42,25 +47,31 @@ struct TxChannelPayload { /// /// # Type Parameters /// -/// - `RxMessage`: The type of message that the client will receive from the server. -/// - `TxMessage`: The type of message that the client will send to the server. +/// - `C`: A [`Codec`] that defines the connection's `Tx` and `Rx` types and how +/// they map to WebSocket frames. Use [`JsonCodec`] for the common case, +/// [`MsgPackCodec`] (behind the `msgpack` feature) for `MessagePack`, or +/// [`RawCodec`] for direct [`Message`] access. /// - `Handler`: A [`ConnectionHandler`] for lifecycle hooks (auth, subscriptions). /// Defaults to [`NoopHandler`] for the simple case. /// - `CHANNEL_SIZE`: The size of the internal channels used to communicate between /// the task managing the WebSocket connection and the client. -pub struct Socketeer { +pub struct Socketeer +where + Handler: ConnectionHandler, +{ url: Url, options: ConnectOptions, + codec: C, handler: Handler, receiver: mpsc::Receiver, sender: mpsc::Sender, socket_handle: tokio::task::JoinHandle>, - _rx_message: std::marker::PhantomData, - _tx_message: std::marker::PhantomData, } -impl std::fmt::Debug - for Socketeer +impl std::fmt::Debug + for Socketeer +where + Handler: ConnectionHandler, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Socketeer") @@ -69,11 +80,9 @@ impl std::fmt::Debug } } -impl< - RxMessage: for<'a> Deserialize<'a> + Debug, - TxMessage: Serialize + Debug, - const CHANNEL_SIZE: usize, -> Socketeer +impl Socketeer +where + C: Codec + Default, { /// Create a `Socketeer` connected to the provided URL with default options. /// Once connected, Socketeer manages the underlying WebSocket connection, transparently handling protocol messages. @@ -91,18 +100,16 @@ impl< /// - If the WebSocket connection to the requested URL fails #[cfg_attr(feature = "tracing", instrument(skip(options)))] pub async fn connect_with(url: &str, options: ConnectOptions) -> Result { - Socketeer::connect_with_handler(url, options, NoopHandler).await + Socketeer::connect_with_codec(url, options, C::default(), NoopHandler).await } } -impl< - RxMessage: for<'a> Deserialize<'a> + Debug, - TxMessage: Serialize + Debug, - Handler: ConnectionHandler, - const CHANNEL_SIZE: usize, -> Socketeer +impl Socketeer +where + C: Codec, + Handler: ConnectionHandler, { - /// Create a `Socketeer` with a custom [`ConnectionHandler`] for lifecycle hooks. + /// Create a `Socketeer` with an explicit codec and [`ConnectionHandler`]. /// /// The handler's [`ConnectionHandler::on_connected`] method is called after the /// WebSocket upgrade completes, before the socket loop starts. This is where @@ -111,10 +118,11 @@ impl< /// - If the URL cannot be parsed /// - If the WebSocket connection to the requested URL fails /// - If the handler's `on_connected` returns an error - #[cfg_attr(feature = "tracing", instrument(skip(options, handler)))] - pub async fn connect_with_handler( + #[cfg_attr(feature = "tracing", instrument(skip(options, codec, handler)))] + pub async fn connect_with_codec( url: &str, options: ConnectOptions, + codec: C, mut handler: Handler, ) -> Result { let url = Url::parse(url).map_err(|source| Error::UrlParse { @@ -130,7 +138,7 @@ impl< let (mut sink, mut stream) = socket.split(); { - let mut ctx = HandshakeContext::new(&mut sink, &mut stream); + let mut ctx = HandshakeContext::new(&mut sink, &mut stream, &codec); handler.on_connected(&mut ctx).await?; } @@ -154,75 +162,49 @@ impl< Ok(Socketeer { url, options, + codec, handler, receiver: rx_rx, sender: tx_tx, socket_handle, - _rx_message: std::marker::PhantomData, - _tx_message: std::marker::PhantomData, }) } - /// Wait for the next parsed message from the WebSocket connection. + /// Wait for the next message from the WebSocket connection, decoded by the + /// connection's [`Codec`]. /// /// # Errors /// /// - If the WebSocket connection is closed or otherwise errored - /// - If the message cannot be deserialized + /// - If the codec fails to decode the frame #[cfg_attr(feature = "tracing", instrument(skip(self)))] - pub async fn next_message(&mut self) -> Result { + pub async fn next_message(&mut self) -> Result { let Some(message) = self.receiver.recv().await else { return Err(Error::WebsocketClosed); }; - match message { - Message::Text(text) => { - #[cfg(feature = "tracing")] - trace!("Received text message: {:?}", text); - let message = serde_json::from_str(&text)?; - Ok(message) - } - Message::Binary(message) => { - #[cfg(feature = "tracing")] - trace!("Received binary message: {:?}", message); - let message = serde_json::from_slice(&message)?; - Ok(message) - } - _ => Err(Error::UnexpectedMessageType(Box::new(message))), - } + #[cfg(feature = "tracing")] + trace!("Received message: {:?}", message); + self.codec.decode(&message) } - /// Send a message to the WebSocket connection. + /// Encode and send a message via the connection's [`Codec`]. /// This function will wait for the message to be sent before returning. /// /// # Errors /// - /// - If the message cannot be serialized + /// - If the codec fails to encode the value /// - If the WebSocket connection is closed, or otherwise errored - #[cfg_attr(feature = "tracing", instrument(skip(self)))] - pub async fn send(&self, message: TxMessage) -> Result<(), Error> { - #[cfg(feature = "tracing")] - trace!("Sending message: {:?}", message); - - let (tx, rx) = oneshot::channel::>(); - let message = serde_json::to_string(&message)?; - - self.sender - .send(TxChannelPayload { - message: Message::Text(message.into()), - response_tx: tx, - }) - .await - .map_err(|_| Error::WebsocketClosed)?; - // We'll ensure that we always respond before dropping the tx channel - match rx.await { - Ok(result) => result, - Err(_) => unreachable!("Socket loop always sends response before dropping one-shot"), - } + #[cfg_attr(feature = "tracing", instrument(skip(self, message)))] + pub async fn send(&self, message: C::Tx) -> Result<(), Error> { + let encoded = self.codec.encode(&message)?; + self.send_raw(encoded).await } - /// Receive the next raw [`Message`] from the WebSocket connection without deserialization. + /// Receive the next raw [`Message`] from the WebSocket connection without + /// running the codec. /// - /// This is useful for protocols that don't use JSON or need to inspect the raw message. + /// Useful when you need to inspect the underlying frame type or handle a + /// message that the codec would reject. /// /// # Errors /// @@ -231,10 +213,9 @@ impl< self.receiver.recv().await.ok_or(Error::WebsocketClosed) } - /// Send a raw [`Message`] to the WebSocket connection without serialization. + /// Send a raw [`Message`] to the WebSocket connection without running the codec. /// - /// This is useful for sending non-JSON messages (e.g., plain text keepalives) - /// or binary data that is already encoded. + /// Useful for sending control frames or pre-encoded payloads. /// /// # Errors /// @@ -267,6 +248,7 @@ impl< pub async fn reconnect(self) -> Result { let url = self.url.as_str().to_owned(); let options = self.options.clone(); + let codec = self.codec; let mut handler = self.handler; #[cfg(feature = "tracing")] info!("Reconnecting"); @@ -280,7 +262,7 @@ impl< error!("Socket Loop already stopped: {}", e); } } - Self::connect_with_handler(&url, options, handler).await + Self::connect_with_codec(&url, options, codec, handler).await } /// Close the WebSocket connection gracefully. @@ -339,7 +321,7 @@ async fn socket_loop_split( mut sink: SocketSink, mut stream: SocketStream, keepalive_interval: Option, - keepalive_message: Option, + keepalive_message: Option, ) -> Result<(), Error> { let mut state = LoopState::Running; while matches!(state, LoopState::Running) { @@ -347,7 +329,7 @@ async fn socket_loop_split( select! { outgoing_message = receiver.recv() => send_socket_message(outgoing_message, &mut sink).await, incoming_message = stream.next() => socket_message_received(incoming_message, &mut sender, &mut sink).await, - () = sleep(interval) => send_keepalive(&mut sink, keepalive_message.as_deref()).await, + () = sleep(interval) => send_keepalive(&mut sink, keepalive_message.as_ref()).await, } } else { select! { @@ -444,11 +426,11 @@ async fn socket_message_received( } #[cfg_attr(feature = "tracing", instrument)] -async fn send_keepalive(sink: &mut SocketSink, custom_message: Option<&str>) -> LoopState { - let message = if let Some(text) = custom_message { +async fn send_keepalive(sink: &mut SocketSink, custom_message: Option<&Message>) -> LoopState { + let message = if let Some(custom) = custom_message { #[cfg(feature = "tracing")] info!("Timeout waiting for message, sending custom keepalive"); - Message::Text(text.into()) + custom.clone() } else { #[cfg(feature = "tracing")] info!("Timeout waiting for message, sending Ping"); @@ -470,6 +452,8 @@ mod tests { use super::*; use tokio::time::sleep; + type EchoJson = JsonCodec; + #[tokio::test] async fn test_server_startup() { let _server_address = get_mock_address(echo_server).await; @@ -478,24 +462,22 @@ mod tests { #[tokio::test] async fn test_connection() { let server_address = get_mock_address(echo_server).await; - let _socketeer: Socketeer = - Socketeer::connect(&format!("ws://{server_address}",)) - .await - .unwrap(); + let _socketeer: Socketeer = Socketeer::connect(&format!("ws://{server_address}")) + .await + .unwrap(); } #[tokio::test] async fn test_bad_url() { - let error: Result, Error> = - Socketeer::connect("Not a URL").await; + let error: Result, Error> = Socketeer::connect("Not a URL").await; assert!(matches!(error.unwrap_err(), Error::UrlParse { .. })); } #[tokio::test] async fn test_send_receive() { let server_address = get_mock_address(echo_server).await; - let mut socketeer: Socketeer = - Socketeer::connect(&format!("ws://{server_address}",)) + let mut socketeer: Socketeer = + Socketeer::connect(&format!("ws://{server_address}")) .await .unwrap(); let message = EchoControlMessage::Message("Hello".to_string()); @@ -507,8 +489,8 @@ mod tests { #[tokio::test] async fn test_ping_request() { let server_address = get_mock_address(echo_server).await; - let mut socketeer: Socketeer = - Socketeer::connect(&format!("ws://{server_address}",)) + let mut socketeer: Socketeer = + Socketeer::connect(&format!("ws://{server_address}")) .await .unwrap(); let ping_request = EchoControlMessage::SendPing; @@ -527,8 +509,8 @@ mod tests { #[tokio::test] async fn test_reconnection() { let server_address = get_mock_address(echo_server).await; - let mut socketeer: Socketeer = - Socketeer::connect(&format!("ws://{server_address}",)) + let mut socketeer: Socketeer = + Socketeer::connect(&format!("ws://{server_address}")) .await .unwrap(); let message = EchoControlMessage::Message("Hello".to_string()); @@ -546,8 +528,8 @@ mod tests { #[tokio::test] async fn test_closed_socket() { let server_address = get_mock_address(echo_server).await; - let mut socketeer: Socketeer = - Socketeer::connect(&format!("ws://{server_address}",)) + let mut socketeer: Socketeer = + Socketeer::connect(&format!("ws://{server_address}")) .await .unwrap(); let close_request = EchoControlMessage::Close; @@ -564,17 +546,16 @@ mod tests { #[tokio::test] async fn test_close_request() { let server_address = get_mock_address(echo_server).await; - let socketeer: Socketeer = - Socketeer::connect(&format!("ws://{server_address}",)) - .await - .unwrap(); + let socketeer: Socketeer = Socketeer::connect(&format!("ws://{server_address}")) + .await + .unwrap(); socketeer.close_connection().await.unwrap(); } #[tokio::test] async fn test_connect_with_default_options() { let server_address = get_mock_address(echo_server).await; - let mut socketeer: Socketeer = + let mut socketeer: Socketeer = Socketeer::connect_with(&format!("ws://{server_address}"), ConnectOptions::default()) .await .unwrap(); @@ -587,16 +568,16 @@ mod tests { #[tokio::test] async fn test_send_raw_receive_raw() { let server_address = get_mock_address(echo_server).await; - let mut socketeer: Socketeer = + let mut socketeer: Socketeer = Socketeer::connect(&format!("ws://{server_address}")) .await .unwrap(); let raw_text = r#"{"Message":"raw hello"}"#; socketeer - .send_raw(Message::Text(raw_text.into())) + .send(Message::Text(raw_text.into())) .await .unwrap(); - let received = socketeer.next_raw_message().await.unwrap(); + let received = socketeer.next_message().await.unwrap(); assert_eq!(received, Message::Text(raw_text.into())); } @@ -607,7 +588,7 @@ mod tests { keepalive_interval: None, ..ConnectOptions::default() }; - let mut socketeer: Socketeer = + let mut socketeer: Socketeer = Socketeer::connect_with(&format!("ws://{server_address}"), options) .await .unwrap(); @@ -632,11 +613,15 @@ mod tests { connected_count: Arc>, } - impl ConnectionHandler for TestAuthHandler { - async fn on_connected(&mut self, ctx: &mut HandshakeContext<'_>) -> Result<(), Error> { + impl ConnectionHandler for TestAuthHandler { + async fn on_connected( + &mut self, + ctx: &mut HandshakeContext<'_, C>, + ) -> Result<(), Error> { ctx.send_text(r#"{"action":"auth","token":"test-token"}"#) .await?; - let response: AuthResponse = ctx.recv_json().await?; + let text = ctx.recv_text().await?; + let response: AuthResponse = serde_json::from_str(&text).unwrap(); assert_eq!(response.status, "authenticated"); let mut count = self.connected_count.lock().await; *count += 1; @@ -650,14 +635,14 @@ mod tests { }; let server_address = get_mock_address(auth_echo_server).await; - let mut socketeer: Socketeer = - Socketeer::connect_with_handler( - &format!("ws://{server_address}"), - ConnectOptions::default(), - handler, - ) - .await - .unwrap(); + let mut socketeer: Socketeer = Socketeer::connect_with_codec( + &format!("ws://{server_address}"), + ConnectOptions::default(), + JsonCodec::new(), + handler, + ) + .await + .unwrap(); assert_eq!(*connected_count.lock().await, 1); @@ -677,8 +662,11 @@ mod tests { disconnected_count: Arc>, } - impl ConnectionHandler for ReconnectHandler { - async fn on_connected(&mut self, ctx: &mut HandshakeContext<'_>) -> Result<(), Error> { + impl ConnectionHandler for ReconnectHandler { + async fn on_connected( + &mut self, + ctx: &mut HandshakeContext<'_, C>, + ) -> Result<(), Error> { ctx.send_text(r#"{"action":"auth","token":"test-token"}"#) .await?; let _response = ctx.recv_text().await?; @@ -701,14 +689,14 @@ mod tests { }; let server_address = get_mock_address(auth_echo_server).await; - let mut socketeer = - Socketeer::::connect_with_handler( - &format!("ws://{server_address}"), - ConnectOptions::default(), - handler, - ) - .await - .unwrap(); + let mut socketeer = Socketeer::::connect_with_codec( + &format!("ws://{server_address}"), + ConnectOptions::default(), + JsonCodec::new(), + handler, + ) + .await + .unwrap(); assert_eq!(*connected_count.lock().await, 1); assert_eq!(*disconnected_count.lock().await, 0); @@ -733,4 +721,21 @@ mod tests { socketeer.close_connection().await.unwrap(); } + + #[cfg(feature = "msgpack")] + #[tokio::test] + async fn test_msgpack_send_receive() { + type EchoMsgPack = MsgPackCodec; + + let server_address = get_mock_address(msgpack_echo_server).await; + let mut socketeer: Socketeer = + Socketeer::connect(&format!("ws://{server_address}")) + .await + .unwrap(); + let message = EchoControlMessage::Message("msgpack hello".to_string()); + socketeer.send(message.clone()).await.unwrap(); + let received = socketeer.next_message().await.unwrap(); + assert_eq!(message, received); + socketeer.close_connection().await.unwrap(); + } } diff --git a/src/mock_server.rs b/src/mock_server.rs index 4bb5faa..31c5bed 100644 --- a/src/mock_server.rs +++ b/src/mock_server.rs @@ -79,6 +79,56 @@ pub async fn echo_server(ws: WebSocketStreamType) -> Result Result { + let (mut sink, mut stream) = ws.split(); + let mut shutting_down = false; + while let Some(message) = stream.next().await { + match message { + Ok(Message::Binary(bytes)) => { + let control_message: EchoControlMessage = rmp_serde::from_slice(&bytes).unwrap(); + match control_message { + EchoControlMessage::Message(_) => { + sink.send(Message::Binary(bytes)).await?; + } + EchoControlMessage::SendPing => { + sink.send(Message::Ping(Bytes::new())).await?; + } + EchoControlMessage::Close => { + sink.send(Message::Close(Some(CloseFrame { + code: CloseCode::Normal, + reason: "".into(), + }))) + .await?; + shutting_down = true; + } + } + } + Ok(Message::Ping(_)) => { + sink.send(Message::Pong(Bytes::new())).await.unwrap(); + } + Ok(Message::Close(_)) => { + if !shutting_down { + sink.close().await.unwrap(); + drop(stream); + } + break; + } + _ => {} + } + } + Ok(shutting_down) +} + /// Echo server that requires an auth handshake before echoing. /// /// Expects the client to send `{"action":"auth","token":"test-token"}` as the first message. From aeb3c8964ae9a8d21fc67a4b31fb4a9700c7c525 Mon Sep 17 00:00:00 2001 From: Zach Heylmun Date: Thu, 7 May 2026 18:03:52 +0200 Subject: [PATCH 2/7] test: add coverage for Codec trait and new HandshakeContext paths - Inline unit tests for JsonCodec, MsgPackCodec, and RawCodec covering frame-type assertions, round-trips, the back-compat Binary decode path on JsonCodec, frame-type rejections, and Error::Codec mapping for malformed payloads - Integration test for codec-driven HandshakeContext::send / recv (the typed handshake path; existing handler tests only exercised send_text / recv_text) - Integration test for the widened Option custom keepalive, using Message::Binary to confirm the connection survives a binary keepalive cycle --- src/codec.rs | 123 +++++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 63 ++++++++++++++++++++++++++ 2 files changed, 186 insertions(+) diff --git a/src/codec.rs b/src/codec.rs index 288d813..1c56891 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -201,3 +201,126 @@ impl Codec for RawCodec { Ok(frame.clone()) } } + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + use serde::{Deserialize, Serialize}; + + #[derive(Debug, PartialEq, Serialize, Deserialize)] + struct TestMsg { + id: u32, + name: String, + } + + fn sample() -> TestMsg { + TestMsg { + id: 42, + name: "hello".into(), + } + } + + #[test] + fn json_codec_encodes_to_text() { + let codec: JsonCodec = JsonCodec::new(); + let frame = codec.encode(&sample()).unwrap(); + let Message::Text(text) = frame else { + panic!("expected Text frame, got {frame:?}"); + }; + assert!(text.contains("\"id\":42")); + assert!(text.contains("\"name\":\"hello\"")); + } + + #[test] + fn json_codec_round_trips_text() { + let codec: JsonCodec = JsonCodec::new(); + let frame = codec.encode(&sample()).unwrap(); + assert_eq!(codec.decode(&frame).unwrap(), sample()); + } + + #[test] + fn json_codec_decodes_binary_for_back_compat() { + // Some servers (e.g. legacy Socketeer behavior) ship JSON inside Binary frames. + let codec: JsonCodec = JsonCodec::new(); + let bytes = serde_json::to_vec(&sample()).unwrap(); + let decoded = codec.decode(&Message::Binary(bytes.into())).unwrap(); + assert_eq!(decoded, sample()); + } + + #[test] + fn json_codec_rejects_ping_frame() { + let codec: JsonCodec = JsonCodec::new(); + let result = codec.decode(&Message::Ping(Bytes::new())); + assert!(matches!( + result.unwrap_err(), + Error::UnexpectedMessageType(_) + )); + } + + #[test] + fn json_codec_surfaces_decode_failure_as_codec_error() { + let codec: JsonCodec = JsonCodec::new(); + let result = codec.decode(&Message::Text("not json".into())); + assert!(matches!(result.unwrap_err(), Error::Codec(_))); + } + + #[cfg(feature = "msgpack")] + #[test] + fn msgpack_codec_encodes_to_binary() { + let codec: MsgPackCodec = MsgPackCodec::new(); + let frame = codec.encode(&sample()).unwrap(); + assert!(matches!(frame, Message::Binary(_))); + } + + #[cfg(feature = "msgpack")] + #[test] + fn msgpack_codec_round_trips_binary() { + let codec: MsgPackCodec = MsgPackCodec::new(); + let frame = codec.encode(&sample()).unwrap(); + assert_eq!(codec.decode(&frame).unwrap(), sample()); + } + + #[cfg(feature = "msgpack")] + #[test] + fn msgpack_codec_rejects_text_frame() { + let codec: MsgPackCodec = MsgPackCodec::new(); + let result = codec.decode(&Message::Text("not msgpack".into())); + assert!(matches!( + result.unwrap_err(), + Error::UnexpectedMessageType(_) + )); + } + + #[cfg(feature = "msgpack")] + #[test] + fn msgpack_codec_surfaces_decode_failure_as_codec_error() { + let codec: MsgPackCodec = MsgPackCodec::new(); + let result = codec.decode(&Message::Binary(Bytes::from_static(b"not msgpack"))); + assert!(matches!(result.unwrap_err(), Error::Codec(_))); + } + + #[test] + fn raw_codec_round_trips_text() { + let codec = RawCodec::new(); + let frame = Message::Text("raw text".into()); + assert_eq!(codec.encode(&frame).unwrap(), frame); + assert_eq!(codec.decode(&frame).unwrap(), frame); + } + + #[test] + fn raw_codec_round_trips_binary() { + let codec = RawCodec::new(); + let frame = Message::Binary(Bytes::from_static(b"raw bytes")); + assert_eq!(codec.encode(&frame).unwrap(), frame); + assert_eq!(codec.decode(&frame).unwrap(), frame); + } + + #[test] + fn raw_codec_passes_protocol_frames_through() { + // RawCodec doesn't filter — Ping/Pong/Close round-trip unchanged. + let codec = RawCodec::new(); + let frame = Message::Ping(Bytes::from_static(b"ping")); + assert_eq!(codec.decode(&frame).unwrap(), frame); + } +} diff --git a/src/lib.rs b/src/lib.rs index a966ba3..280dbcf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -738,4 +738,67 @@ mod tests { assert_eq!(message, received); socketeer.close_connection().await.unwrap(); } + + #[tokio::test] + async fn test_handler_uses_codec_driven_send_recv() { + // Exercises HandshakeContext::send / recv (the codec-driven path). + // Other handler tests only cover the raw send_text / recv_text helpers. + struct TypedHandshakeHandler; + + impl ConnectionHandler for TypedHandshakeHandler { + async fn on_connected( + &mut self, + ctx: &mut HandshakeContext<'_, EchoJson>, + ) -> Result<(), Error> { + ctx.send(&EchoControlMessage::Message("handshake".into())) + .await?; + let echoed = ctx.recv().await?; + assert_eq!(echoed, EchoControlMessage::Message("handshake".into())); + Ok(()) + } + } + + let server_address = get_mock_address(echo_server).await; + let mut socketeer: Socketeer = + Socketeer::connect_with_codec( + &format!("ws://{server_address}"), + ConnectOptions::default(), + JsonCodec::new(), + TypedHandshakeHandler, + ) + .await + .unwrap(); + + // Confirm normal traffic still flows after the typed handshake. + let message = EchoControlMessage::Message("after handshake".into()); + socketeer.send(message.clone()).await.unwrap(); + assert_eq!(socketeer.next_message().await.unwrap(), message); + socketeer.close_connection().await.unwrap(); + } + + #[tokio::test] + async fn test_binary_custom_keepalive() { + // The widening of custom_keepalive_message from Option to + // Option is otherwise unexercised. echo_server silently + // ignores Binary frames, so the receive queue stays clean and we can + // verify the connection survives a binary keepalive cycle. + let server_address = get_mock_address(echo_server).await; + let options = ConnectOptions { + keepalive_interval: Some(Duration::from_millis(100)), + custom_keepalive_message: Some(Message::Binary(Bytes::from_static(b"keepalive"))), + ..ConnectOptions::default() + }; + let mut socketeer: Socketeer = + Socketeer::connect_with(&format!("ws://{server_address}"), options) + .await + .unwrap(); + + // Wait long enough for at least a couple of keepalive ticks to fire. + sleep(Duration::from_millis(350)).await; + + let message = EchoControlMessage::Message("post-keepalive".into()); + socketeer.send(message.clone()).await.unwrap(); + assert_eq!(socketeer.next_message().await.unwrap(), message); + socketeer.close_connection().await.unwrap(); + } } From 09d7364f9e0ee018c5d111335c912544a96eabed Mon Sep 17 00:00:00 2001 From: Zach Heylmun Date: Thu, 7 May 2026 15:48:53 -0400 Subject: [PATCH 3/7] test: push coverage to 98% lines / 96% regions Adds focused tests for the gaps that remained after the initial Codec test pass: - Debug/Default/Clone impls on JsonCodec, MsgPackCodec, RawCodec - Socketeer::Debug formatter and Socketeer::next_raw_message - HandshakeContext::send_binary (round-tripped via msgpack_echo_server) - HandshakeContext::recv_text rejecting a Binary frame as Error::UnexpectedMessageType - ConnectOptions::build_request's extra_headers loop body - auth_echo_server's bad-token branch - msgpack_echo_server's SendPing and Close arms Remaining uncovered regions are unreachable!() arms (by design) and a couple of mock-server protocol-frame edges that need server-side oddities to trigger. --- src/codec.rs | 28 ++++++++ src/lib.rs | 186 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 214 insertions(+) diff --git a/src/codec.rs b/src/codec.rs index 1c56891..bd1ae53 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -323,4 +323,32 @@ mod tests { let frame = Message::Ping(Bytes::from_static(b"ping")); assert_eq!(codec.decode(&frame).unwrap(), frame); } + + // The codecs are `Copy`, so `let _ = c.clone();` would normally be + // `clippy::clone_on_copy`. We explicitly want to exercise the manual Clone + // impls so the codec module reaches full coverage. + #[allow(clippy::clone_on_copy)] + #[test] + fn json_codec_debug_default_clone() { + let codec: JsonCodec = JsonCodec::default(); + let _cloned = codec.clone(); + assert_eq!(format!("{codec:?}"), "JsonCodec"); + } + + #[cfg(feature = "msgpack")] + #[allow(clippy::clone_on_copy)] + #[test] + fn msgpack_codec_debug_default_clone() { + let codec: MsgPackCodec = MsgPackCodec::default(); + let _cloned = codec.clone(); + assert_eq!(format!("{codec:?}"), "MsgPackCodec"); + } + + #[allow(clippy::clone_on_copy)] + #[test] + fn raw_codec_debug_clone() { + let codec = RawCodec::new(); + let _cloned = codec.clone(); + assert_eq!(format!("{codec:?}"), "RawCodec"); + } } diff --git a/src/lib.rs b/src/lib.rs index 280dbcf..4f11a42 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -776,6 +776,192 @@ mod tests { socketeer.close_connection().await.unwrap(); } + #[tokio::test] + async fn test_extra_headers_used() { + // Cover ConnectOptions::build_request's loop body that copies + // `extra_headers` onto the upgrade request. + let server_address = get_mock_address(echo_server).await; + let mut headers = tokio_tungstenite::tungstenite::http::HeaderMap::new(); + headers.insert("X-Test-Header", "socketeer".parse().unwrap()); + let options = ConnectOptions { + extra_headers: headers, + ..ConnectOptions::default() + }; + let mut socketeer: Socketeer = + Socketeer::connect_with(&format!("ws://{server_address}"), options) + .await + .unwrap(); + let message = EchoControlMessage::Message("hi".into()); + socketeer.send(message.clone()).await.unwrap(); + assert_eq!(socketeer.next_message().await.unwrap(), message); + socketeer.close_connection().await.unwrap(); + } + + #[tokio::test] + async fn test_auth_handler_bad_token() { + // Covers auth_echo_server's bad-token branch (sends {"status":"error"} + // and shuts down). The handler observes the error response, returns + // Ok, then a subsequent send fails because the server has closed. + struct BadTokenHandler; + + impl ConnectionHandler for BadTokenHandler { + async fn on_connected( + &mut self, + ctx: &mut HandshakeContext<'_, C>, + ) -> Result<(), Error> { + ctx.send_text(r#"{"action":"auth","token":"WRONG"}"#) + .await?; + let resp = ctx.recv_text().await?; + assert!(resp.contains("error")); + Ok(()) + } + } + + let server_address = get_mock_address(auth_echo_server).await; + let _socketeer: Socketeer = Socketeer::connect_with_codec( + &format!("ws://{server_address}"), + ConnectOptions::default(), + JsonCodec::new(), + BadTokenHandler, + ) + .await + .unwrap(); + } + + #[cfg(feature = "msgpack")] + #[tokio::test] + async fn test_msgpack_send_ping() { + // Covers the SendPing arm of msgpack_echo_server. + type EchoMsgPack = MsgPackCodec; + + let server_address = get_mock_address(msgpack_echo_server).await; + let mut socketeer: Socketeer = + Socketeer::connect(&format!("ws://{server_address}")) + .await + .unwrap(); + socketeer.send(EchoControlMessage::SendPing).await.unwrap(); + // Server replies with a Ping; Socketeer auto-Pongs. Round-trip a real + // message to confirm the connection is still alive. + let message = EchoControlMessage::Message("after ping".into()); + socketeer.send(message.clone()).await.unwrap(); + assert_eq!(socketeer.next_message().await.unwrap(), message); + socketeer.close_connection().await.unwrap(); + } + + #[cfg(feature = "msgpack")] + #[tokio::test] + async fn test_msgpack_close_request() { + // Covers the Close arm of msgpack_echo_server. + type EchoMsgPack = MsgPackCodec; + + let server_address = get_mock_address(msgpack_echo_server).await; + let mut socketeer: Socketeer = + Socketeer::connect(&format!("ws://{server_address}")) + .await + .unwrap(); + socketeer.send(EchoControlMessage::Close).await.unwrap(); + let result = socketeer.next_message().await; + assert!(matches!(result.unwrap_err(), Error::WebsocketClosed)); + } + + #[tokio::test] + async fn test_socketeer_debug_format() { + let server_address = get_mock_address(echo_server).await; + let socketeer: Socketeer = Socketeer::connect(&format!("ws://{server_address}")) + .await + .unwrap(); + let formatted = format!("{socketeer:?}"); + assert!(formatted.starts_with("Socketeer")); + assert!(formatted.contains("url")); + } + + #[tokio::test] + async fn test_next_raw_message() { + // Cover Socketeer::next_raw_message (the raw-receive escape hatch). + let server_address = get_mock_address(echo_server).await; + let mut socketeer: Socketeer = + Socketeer::connect(&format!("ws://{server_address}")) + .await + .unwrap(); + let message = EchoControlMessage::Message("raw recv".into()); + socketeer.send(message).await.unwrap(); + let frame = socketeer.next_raw_message().await.unwrap(); + assert!(matches!(frame, Message::Text(_))); + socketeer.close_connection().await.unwrap(); + } + + #[cfg(feature = "msgpack")] + #[tokio::test] + async fn test_handshake_send_binary_recv_raw() { + // Cover HandshakeContext::send_binary by sending a pre-encoded + // msgpack frame from on_connected and reading the binary echo back + // via recv_raw. + struct BinaryHandshake; + + type EchoMsgPack = MsgPackCodec; + + impl ConnectionHandler for BinaryHandshake { + async fn on_connected( + &mut self, + ctx: &mut HandshakeContext<'_, EchoMsgPack>, + ) -> Result<(), Error> { + let payload = + rmp_serde::to_vec_named(&EchoControlMessage::Message("binary".into())).unwrap(); + ctx.send_binary(payload).await?; + let echo = ctx.recv_raw().await?; + assert!(matches!(echo, Message::Binary(_))); + Ok(()) + } + } + + let server_address = get_mock_address(msgpack_echo_server).await; + let socketeer: Socketeer = Socketeer::connect_with_codec( + &format!("ws://{server_address}"), + ConnectOptions::default(), + MsgPackCodec::new(), + BinaryHandshake, + ) + .await + .unwrap(); + socketeer.close_connection().await.unwrap(); + } + + #[cfg(feature = "msgpack")] + #[tokio::test] + async fn test_handshake_recv_text_rejects_binary() { + // Cover the non-Text branch of HandshakeContext::recv_text by pointing + // it at a server that only speaks binary frames. + struct ExpectsTextOnBinary; + + type EchoMsgPack = MsgPackCodec; + + impl ConnectionHandler for ExpectsTextOnBinary { + async fn on_connected( + &mut self, + ctx: &mut HandshakeContext<'_, EchoMsgPack>, + ) -> Result<(), Error> { + let payload = + rmp_serde::to_vec_named(&EchoControlMessage::Message("hi".into())).unwrap(); + ctx.send_binary(payload).await?; + // recv_text must reject the echoed Binary frame. + let err = ctx.recv_text().await.unwrap_err(); + assert!(matches!(err, Error::UnexpectedMessageType(_))); + Ok(()) + } + } + + let server_address = get_mock_address(msgpack_echo_server).await; + let socketeer: Socketeer = Socketeer::connect_with_codec( + &format!("ws://{server_address}"), + ConnectOptions::default(), + MsgPackCodec::new(), + ExpectsTextOnBinary, + ) + .await + .unwrap(); + socketeer.close_connection().await.unwrap(); + } + #[tokio::test] async fn test_binary_custom_keepalive() { // The widening of custom_keepalive_message from Option to From 1516b0ba8ddb571a751efe3fc85c7425ee3d7e41 Mon Sep 17 00:00:00 2001 From: Zach Date: Fri, 8 May 2026 09:08:36 -0400 Subject: [PATCH 4/7] Clearly document test focused intent, justifying panics --- src/mock_server.rs | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/src/mock_server.rs b/src/mock_server.rs index 31c5bed..648614b 100644 --- a/src/mock_server.rs +++ b/src/mock_server.rs @@ -28,11 +28,19 @@ pub enum EchoControlMessage { /// Basic echo server that sends back messages it receives. /// It will also respond to pings and close the connection upon request. +/// +/// This is a test harness exposed under the `mocking` feature flag and is **not** +/// intended for production use. It deliberately panics on any unexpected +/// condition (bad payloads, send/close failures on auxiliary frames) so that +/// protocol violations surface loudly as test failures rather than being +/// silently swallowed. /// # Errors /// - If the socket is closed unexpectedly -/// - If the server cannot send a message +/// - If echoing or sending a control frame to the client fails /// # Panics -/// - If a received message fails to deserialize +/// - If a received message fails to deserialize as an [`EchoControlMessage`] +/// - If sending a Pong reply fails +/// - If closing the sink in response to a peer-initiated close fails pub async fn echo_server(ws: WebSocketStreamType) -> Result { let (mut sink, mut stream) = ws.split(); let mut shutting_down = false; @@ -83,11 +91,19 @@ pub async fn echo_server(ws: WebSocketStreamType) -> Result Result { let (mut sink, mut stream) = ws.split(); @@ -134,11 +150,16 @@ pub async fn msgpack_echo_server(ws: WebSocketStreamType) -> Result Result { let (mut sink, mut stream) = ws.split(); From ffdca32e38aed73c8b656590264da4337f83d848 Mon Sep 17 00:00:00 2001 From: Zach Heylmun Date: Fri, 8 May 2026 09:47:23 -0400 Subject: [PATCH 5/7] fix: mark Error::Codec inner as #[source] So downstream consumers can walk the error chain via std::error::Error::source() and reach the underlying codec error. --- src/error.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/error.rs b/src/error.rs index 4a782b1..442d8db 100644 --- a/src/error.rs +++ b/src/error.rs @@ -26,7 +26,7 @@ pub enum Error { UnexpectedMessageType(Box), /// Error thrown if a [`crate::Codec`] fails to encode or decode a message. #[error("Codec error: {0}")] - Codec(Box), + Codec(#[source] Box), /// Error thrown if socketeer is dropped without closing the connection. /// This error will be removed once async destructors are stabilized. /// See [issue](https://github.com/rust-lang/rust/issues/126482) From cd6728804ae8229a0f7a62172c460022c59acb35 Mon Sep 17 00:00:00 2001 From: Zach Heylmun Date: Fri, 8 May 2026 10:00:56 -0400 Subject: [PATCH 6/7] fix: surface peer close as WebsocketClosed in HandshakeContext::recv MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit With RawCodec, recv_raw returns Ok(Message::Close(_)) and RawCodec::decode is the identity, so a peer-initiated close used to leak through recv() as Ok(Close) instead of Err(WebsocketClosed) — contradicting the documented behavior. Intercept Close before delegating to the codec, and point users at recv_raw if they need to observe close frames directly. --- src/handler.rs | 12 ++++++++---- src/lib.rs | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/src/handler.rs b/src/handler.rs index aa45468..012fbee 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -44,13 +44,17 @@ impl<'a, C: Codec> HandshakeContext<'a, C> { /// Receive and decode the next message using the connection's [`Codec`]. /// - /// Skips ping/pong protocol frames. Returns an error on close or codec failure. + /// Skips ping/pong protocol frames. A peer-initiated close is surfaced as + /// [`Error::WebsocketClosed`] regardless of the negotiated codec — use + /// [`Self::recv_raw`] if you need to observe the close frame directly. /// # Errors - /// - If the WebSocket connection is closed + /// - If the WebSocket connection is closed (including a received close frame) /// - If the codec fails to decode the frame pub async fn recv(&mut self) -> Result { - let frame = self.recv_raw().await?; - self.codec.decode(&frame) + match self.recv_raw().await? { + Message::Close(_) => Err(Error::WebsocketClosed), + frame => self.codec.decode(&frame), + } } /// Send a raw text frame during the handshake. diff --git a/src/lib.rs b/src/lib.rs index 4f11a42..180399a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -776,6 +776,38 @@ mod tests { socketeer.close_connection().await.unwrap(); } + #[tokio::test] + async fn test_handshake_recv_close_with_raw_codec() { + // Regression: with RawCodec, recv_raw returns Ok(Message::Close(_)) and + // RawCodec::decode is the identity, so a peer-initiated close used to + // surface as Ok(Close) instead of Err(WebsocketClosed). recv must + // intercept Close before delegating to the codec. + struct CloseExpecting; + + impl ConnectionHandler for CloseExpecting { + async fn on_connected( + &mut self, + ctx: &mut HandshakeContext<'_, RawCodec>, + ) -> Result<(), Error> { + // Ask the echo server to close (JSON unit-variant for EchoControlMessage::Close). + ctx.send(&Message::Text(r#""Close""#.into())).await?; + let err = ctx.recv().await.unwrap_err(); + assert!(matches!(err, Error::WebsocketClosed)); + Ok(()) + } + } + + let server_address = get_mock_address(echo_server).await; + let _socketeer: Socketeer = Socketeer::connect_with_codec( + &format!("ws://{server_address}"), + ConnectOptions::default(), + RawCodec::new(), + CloseExpecting, + ) + .await + .unwrap(); + } + #[tokio::test] async fn test_extra_headers_used() { // Cover ConnectOptions::build_request's loop body that copies From b10d98707a3a140daba6dc0fd8d0d41a3cd64ef7 Mon Sep 17 00:00:00 2001 From: Zach Heylmun Date: Fri, 8 May 2026 10:04:03 -0400 Subject: [PATCH 7/7] test: clarify raw-API test names and cover send_raw directly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rename test_send_raw_receive_raw to test_raw_codec_message_roundtrip — it exercises the typed send/next_message path with RawCodec, not the raw API. Update test_next_raw_message to call send_raw alongside next_raw_message (rename to test_send_raw_next_raw_message), giving both escape-hatch methods explicit coverage. --- src/lib.rs | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 180399a..40baac1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -566,7 +566,9 @@ mod tests { } #[tokio::test] - async fn test_send_raw_receive_raw() { + async fn test_raw_codec_message_roundtrip() { + // Typed send/next_message round-trip when the codec is RawCodec — the + // codec is identity, so frames pass through unchanged. let server_address = get_mock_address(echo_server).await; let mut socketeer: Socketeer = Socketeer::connect(&format!("ws://{server_address}")) @@ -908,17 +910,23 @@ mod tests { } #[tokio::test] - async fn test_next_raw_message() { - // Cover Socketeer::next_raw_message (the raw-receive escape hatch). + async fn test_send_raw_next_raw_message() { + // Cover the raw send/receive escape hatches on a typed (non-RawCodec) + // connection: send_raw bypasses encoding, next_raw_message bypasses + // decoding, so we can speak frames the codec wouldn't otherwise + // produce or accept. let server_address = get_mock_address(echo_server).await; let mut socketeer: Socketeer = Socketeer::connect(&format!("ws://{server_address}")) .await .unwrap(); - let message = EchoControlMessage::Message("raw recv".into()); - socketeer.send(message).await.unwrap(); + let raw_text = r#"{"Message":"raw recv"}"#; + socketeer + .send_raw(Message::Text(raw_text.into())) + .await + .unwrap(); let frame = socketeer.next_raw_message().await.unwrap(); - assert!(matches!(frame, Message::Text(_))); + assert_eq!(frame, Message::Text(raw_text.into())); socketeer.close_connection().await.unwrap(); }