Skip to content

Commit 40a0164

Browse files
authored
Merge pull request #3 from chmod77/feat/rewrite-connect-websocket
feat - rewrite WS connection
2 parents 9b41da7 + 9d19219 commit 40a0164

File tree

5 files changed

+312
-255
lines changed

5 files changed

+312
-255
lines changed

README.md

+83
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,53 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
7171
}
7272
```
7373

74+
75+
### Connecting
76+
77+
Connect to Pusher:
78+
79+
```rust
80+
client.connect().await?;
81+
```
82+
83+
### Subscribing to Channels
84+
85+
Subscribe to a public channel:
86+
87+
```rust
88+
client.subscribe("my-channel").await?;
89+
```
90+
91+
Subscribe to a private channel:
92+
93+
```rust
94+
client.subscribe("private-my-channel").await?;
95+
```
96+
97+
Subscribe to a presence channel:
98+
99+
```rust
100+
client.subscribe("presence-my-channel").await?;
101+
```
102+
103+
### Unsubscribing from Channels
104+
105+
```rust
106+
client.unsubscribe("my-channel").await?;
107+
```
108+
109+
### Binding to Events
110+
111+
Bind to a specific event on a channel:
112+
113+
```rust
114+
use pusher_rs::Event;
115+
116+
client.bind("my-event", |event: Event| {
117+
println!("Received event: {:?}", event);
118+
}).await?;
119+
```
120+
74121
### Subscribing to a channel
75122

76123
```rust
@@ -133,6 +180,15 @@ The library supports four types of channels:
133180

134181
Each channel type has specific features and authentication requirements.
135182

183+
### Handling Connection State
184+
185+
Get the current connection state:
186+
187+
```rust
188+
let state = client.get_connection_state().await;
189+
println!("Current connection state: {:?}", state);
190+
```
191+
136192
## Error Handling
137193

138194
The library uses a custom `PusherError` type for error handling. You can match on different error variants to handle specific error cases:
@@ -148,6 +204,14 @@ match client.connect().await {
148204
}
149205
```
150206

207+
### Disconnecting
208+
209+
When you're done, disconnect from Pusher:
210+
211+
```rust
212+
client.disconnect().await?;
213+
```
214+
151215
## Advanced Usage
152216

153217
### Custom Configuration
@@ -186,6 +250,25 @@ if let Some(channel) = channel_list.get("my-channel") {
186250
}
187251
```
188252

253+
### Presence Channels
254+
255+
When subscribing to a presence channel, you can provide user information:
256+
257+
```rust
258+
use serde_json::json;
259+
260+
let channel = "presence-my-channel";
261+
let socket_id = client.get_socket_id().await?;
262+
let user_id = "user_123";
263+
let user_info = json!({
264+
"name": "John Doe",
265+
"email": "[email protected]"
266+
});
267+
268+
let auth = client.authenticate_presence_channel(&socket_id, channel, user_id, Some(&user_info))?;
269+
client.subscribe_with_auth(channel, &auth).await?;
270+
```
271+
189272
### Tests
190273

191274
Integration tests live under `tests/integration_tests`

src/events.rs

+20
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use serde_json::Value;
55
pub struct Event {
66
pub event: String,
77
pub channel: Option<String>,
8+
#[serde(with = "json_string")]
89
pub data: Value,
910
}
1011

@@ -113,6 +114,25 @@ impl SystemEvent {
113114
}
114115
}
115116

117+
mod json_string {
118+
use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer};
119+
use serde_json::Value;
120+
121+
pub fn serialize<S>(value: &Value, serializer: S) -> Result<S::Ok, S::Error>
122+
where
123+
S: Serializer,
124+
{
125+
value.to_string().serialize(serializer)
126+
}
127+
128+
pub fn deserialize<'de, D>(deserializer: D) -> Result<Value, D::Error>
129+
where
130+
D: Deserializer<'de>,
131+
{
132+
let s = String::deserialize(deserializer)?;
133+
serde_json::from_str(&s).map_err(D::Error::custom)
134+
}
135+
}
116136
#[cfg(test)]
117137
mod tests {
118138
use super::*;

src/lib.rs

+65-44
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use cbc::{Decryptor, Encryptor};
1515
use hmac::{Hmac, Mac};
1616
use log::info;
1717
use rand::Rng;
18-
use serde_json::json;
18+
use serde_json::{json, Value};
1919
use sha2::Sha256;
2020
use std::collections::HashMap;
2121
use std::sync::Arc;
@@ -28,14 +28,15 @@ pub use config::PusherConfig;
2828
pub use error::{PusherError, PusherResult};
2929
pub use events::{Event, SystemEvent};
3030

31-
use websocket::WebSocketClient;
31+
use websocket::{WebSocketClient, WebSocketCommand};
3232

3333
/// This struct provides methods for connecting to Pusher, subscribing to channels,
3434
/// triggering events, and handling incoming events.
3535
pub struct PusherClient {
3636
config: PusherConfig,
3737
auth: PusherAuth,
38-
websocket: Option<WebSocketClient>,
38+
// websocket: Option<WebSocketClient>,
39+
websocket_command_tx: Option<mpsc::Sender<WebSocketCommand>>,
3940
channels: Arc<RwLock<HashMap<String, Channel>>>,
4041
event_handlers: Arc<RwLock<HashMap<String, Vec<Box<dyn Fn(Event) + Send + Sync + 'static>>>>>,
4142
state: Arc<RwLock<ConnectionState>>,
@@ -73,29 +74,44 @@ impl PusherClient {
7374
let auth = PusherAuth::new(&config.app_key, &config.app_secret);
7475
let (event_tx, event_rx) = mpsc::channel(100);
7576
let state = Arc::new(RwLock::new(ConnectionState::Disconnected));
76-
let event_handlers = Arc::new(RwLock::new(HashMap::new()));
77-
let encrypted_channels = Arc::new(RwLock::new(HashMap::new()));
77+
let event_handlers = Arc::new(RwLock::new(std::collections::HashMap::new()));
78+
let encrypted_channels = Arc::new(RwLock::new(std::collections::HashMap::new()));
7879

7980
let client = Self {
8081
config,
8182
auth,
82-
websocket: None,
83-
channels: Arc::new(RwLock::new(HashMap::new())),
83+
websocket_command_tx: None,
84+
channels: Arc::new(RwLock::new(std::collections::HashMap::new())),
8485
event_handlers: event_handlers.clone(),
8586
state: state.clone(),
8687
event_tx,
8788
encrypted_channels,
8889
};
8990

90-
// Spawn the event handling task
9191
tokio::spawn(Self::handle_events(event_rx, event_handlers));
9292

9393
Ok(client)
9494
}
95+
96+
async fn send(&self, message: String) -> PusherResult<()> {
97+
if let Some(tx) = &self.websocket_command_tx {
98+
tx.send(WebSocketCommand::Send(message))
99+
.await
100+
.map_err(|e| {
101+
PusherError::WebSocketError(format!("Failed to send command: {}", e))
102+
})?;
103+
Ok(())
104+
} else {
105+
Err(PusherError::ConnectionError("Not connected".into()))
106+
}
107+
}
108+
95109
async fn handle_events(
96110
mut event_rx: mpsc::Receiver<Event>,
97111
event_handlers: Arc<
98-
RwLock<HashMap<String, Vec<Box<dyn Fn(Event) + Send + Sync + 'static>>>>,
112+
RwLock<
113+
std::collections::HashMap<String, Vec<Box<dyn Fn(Event) + Send + Sync + 'static>>>,
114+
>,
99115
>,
100116
) {
101117
while let Some(event) = event_rx.recv().await {
@@ -115,18 +131,24 @@ impl PusherClient {
115131
/// A `PusherResult` indicating success or failure.
116132
pub async fn connect(&mut self) -> PusherResult<()> {
117133
let url = self.get_websocket_url()?;
118-
let mut websocket =
119-
WebSocketClient::new(url.clone(), Arc::clone(&self.state), self.event_tx.clone());
134+
let (command_tx, command_rx) = mpsc::channel(100);
135+
136+
let mut websocket = WebSocketClient::new(
137+
url.clone(),
138+
Arc::clone(&self.state),
139+
self.event_tx.clone(),
140+
command_rx,
141+
);
142+
120143
log::info!("Connecting to Pusher using URL: {}", url);
121144
websocket.connect().await?;
122-
self.websocket = Some(websocket);
123145

124-
// Start the WebSocket event loop
125-
let mut ws = self.websocket.take().unwrap();
126146
tokio::spawn(async move {
127-
ws.run().await;
147+
websocket.run().await;
128148
});
129149

150+
self.websocket_command_tx = Some(command_tx);
151+
130152
Ok(())
131153
}
132154

@@ -136,11 +158,12 @@ impl PusherClient {
136158
///
137159
/// A `PusherResult` indicating success or failure.
138160
pub async fn disconnect(&mut self) -> PusherResult<()> {
139-
if let Some(websocket) = &self.websocket {
140-
websocket.close().await?;
161+
if let Some(tx) = self.websocket_command_tx.take() {
162+
tx.send(WebSocketCommand::Close).await.map_err(|e| {
163+
PusherError::WebSocketError(format!("Failed to send close command: {}", e))
164+
})?;
141165
}
142166
*self.state.write().await = ConnectionState::Disconnected;
143-
self.websocket = None;
144167
Ok(())
145168
}
146169

@@ -158,21 +181,17 @@ impl PusherClient {
158181
let mut channels = self.channels.write().await;
159182
channels.insert(channel_name.to_string(), channel);
160183

161-
if let Some(websocket) = &self.websocket {
162-
let data = json!({
163-
"event": "pusher:subscribe",
164-
"data": {
165-
"channel": channel_name
166-
}
167-
});
168-
websocket.send(serde_json::to_string(&data)?).await?;
169-
} else {
170-
return Err(PusherError::ConnectionError("Not connected".into()));
171-
}
184+
let data = json!({
185+
"event": "pusher:subscribe",
186+
"data": {
187+
"channel": channel_name
188+
}
189+
});
172190

173-
Ok(())
191+
self.send(serde_json::to_string(&data)?).await
174192
}
175193

194+
176195
/// Subscribes to an encrypted channel.
177196
///
178197
/// # Arguments
@@ -208,6 +227,7 @@ impl PusherClient {
208227
/// # Returns
209228
///
210229
/// A `PusherResult` indicating success or failure.
230+
///
211231
pub async fn unsubscribe(&mut self, channel_name: &str) -> PusherResult<()> {
212232
{
213233
let mut channels = self.channels.write().await;
@@ -219,19 +239,14 @@ impl PusherClient {
219239
encrypted_channels.remove(channel_name);
220240
}
221241

222-
if let Some(websocket) = &self.websocket {
223-
let data = json!({
224-
"event": "pusher:unsubscribe",
225-
"data": {
226-
"channel": channel_name
227-
}
228-
});
229-
websocket.send(serde_json::to_string(&data)?).await?;
230-
} else {
231-
return Err(PusherError::ConnectionError("Not connected".into()));
232-
}
242+
let data = json!({
243+
"event": "pusher:unsubscribe",
244+
"data": {
245+
"channel": channel_name
246+
}
247+
});
233248

234-
Ok(())
249+
self.send(serde_json::to_string(&data)?).await
235250
}
236251

237252
/// Triggers an event on a channel.
@@ -251,10 +266,14 @@ impl PusherClient {
251266
self.config.cluster, self.config.app_id
252267
);
253268

269+
// Validate that the data is valid JSON, but keep it as a string
270+
serde_json::from_str::<serde_json::Value>(data)
271+
.map_err(|e| PusherError::JsonError(e))?;
272+
254273
let body = json!({
255274
"name": event,
256275
"channel": channel,
257-
"data": data
276+
"data": data, // Keep data as a string
258277
});
259278
let path = format!("/apps/{}/events", self.config.app_id);
260279
let auth_params = self.auth.authenticate_request("POST", &path, &body)?;
@@ -371,6 +390,7 @@ impl PusherClient {
371390
/// # Returns
372391
///
373392
/// A `PusherResult` indicating success or failure.
393+
///
374394
pub async fn bind<F>(&self, event_name: &str, callback: F) -> PusherResult<()>
375395
where
376396
F: Fn(Event) + Send + Sync + 'static,
@@ -535,7 +555,8 @@ mod tests {
535555

536556
#[tokio::test]
537557
async fn test_trigger_batch() {
538-
let config = PusherConfig::from_env().expect("Failed to load Pusher configuration from environment");
558+
let config =
559+
PusherConfig::from_env().expect("Failed to load Pusher configuration from environment");
539560
let client = PusherClient::new(config).unwrap();
540561

541562
let batch_events = vec![

0 commit comments

Comments
 (0)