Skip to content

Commit e4bd1df

Browse files
Copilot0xrinegade
andcommitted
Fix mutex poisoning and optimize metrics with dashmap, add derived traits to McpError
Co-authored-by: 0xrinegade <[email protected]>
1 parent 5bd56cf commit e4bd1df

File tree

3 files changed

+61
-26
lines changed

3 files changed

+61
-26
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ tracing = "0.1"
1717
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
1818
uuid = { version = "1.0", features = ["v4"] }
1919
once_cell = "1.19"
20+
dashmap = "6.1"
2021
solana-client = "1.17"
2122
solana-sdk = "1.17"
2223
solana-account-decoder = "1.17"

src/error.rs

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use uuid::Uuid;
77
/// This module defines a hierarchy of error types that provide
88
/// rich context for debugging and monitoring while maintaining
99
/// security by avoiding sensitive data exposure.
10-
#[derive(Error, Debug)]
10+
#[derive(Error, Debug, Clone, PartialEq)]
1111
pub enum McpError {
1212
/// Client-side errors (invalid input, malformed requests)
1313
#[error("Client error: {message}")]
@@ -23,7 +23,7 @@ pub enum McpError {
2323
message: String,
2424
request_id: Option<Uuid>,
2525
method: Option<String>,
26-
source: Option<Box<dyn std::error::Error + Send + Sync>>,
26+
source_message: Option<String>, // Store source error as string for Clone/PartialEq
2727
},
2828

2929
/// RPC-specific errors (Solana client failures)
@@ -33,7 +33,7 @@ pub enum McpError {
3333
request_id: Option<Uuid>,
3434
method: Option<String>,
3535
rpc_url: Option<String>,
36-
source: Option<Box<dyn std::error::Error + Send + Sync>>,
36+
source_message: Option<String>, // Store source error as string for Clone/PartialEq
3737
},
3838

3939
/// Validation errors (invalid parameters, security checks)
@@ -79,7 +79,7 @@ impl McpError {
7979
message: message.into(),
8080
request_id: None,
8181
method: None,
82-
source: None,
82+
source_message: None,
8383
}
8484
}
8585

@@ -90,7 +90,7 @@ impl McpError {
9090
request_id: None,
9191
method: None,
9292
rpc_url: None,
93-
source: None,
93+
source_message: None,
9494
}
9595
}
9696

@@ -176,9 +176,10 @@ impl McpError {
176176

177177
/// Adds source error context
178178
pub fn with_source(mut self, source: Box<dyn std::error::Error + Send + Sync>) -> Self {
179+
let source_message = source.to_string();
179180
match &mut self {
180-
McpError::Server { source: ref mut s, .. } => *s = Some(source),
181-
McpError::Rpc { source: ref mut s, .. } => *s = Some(source),
181+
McpError::Server { source_message: ref mut s, .. } => *s = Some(source_message),
182+
McpError::Rpc { source_message: ref mut s, .. } => *s = Some(source_message),
182183
_ => {}, // Other error types don't have source fields
183184
}
184185
self
@@ -253,19 +254,27 @@ impl McpError {
253254
log_data.insert("parameter".to_string(), Value::String(param.clone()));
254255
}
255256
},
256-
McpError::Rpc { rpc_url, .. } => {
257+
McpError::Rpc { rpc_url, source_message, .. } => {
257258
if let Some(url) = rpc_url {
258259
// Sanitize URL for logging
259260
let sanitized = crate::validation::sanitize_for_logging(url);
260261
log_data.insert("rpc_url".to_string(), Value::String(sanitized));
261262
}
263+
if let Some(source_msg) = source_message {
264+
log_data.insert("source_error".to_string(), Value::String(source_msg.clone()));
265+
}
262266
},
263267
McpError::Network { endpoint, .. } => {
264268
if let Some(ep) = endpoint {
265269
let sanitized = crate::validation::sanitize_for_logging(ep);
266270
log_data.insert("endpoint".to_string(), Value::String(sanitized));
267271
}
268272
},
273+
McpError::Server { source_message, .. } => {
274+
if let Some(source_msg) = source_message {
275+
log_data.insert("source_error".to_string(), Value::String(source_msg.clone()));
276+
}
277+
},
269278
_ => {}
270279
}
271280

@@ -352,4 +361,26 @@ mod tests {
352361
assert!(log_value.get("method").is_some());
353362
assert!(log_value.get("rpc_url").is_some());
354363
}
364+
365+
#[test]
366+
fn test_derived_traits() {
367+
let request_id = Uuid::new_v4();
368+
let error1 = McpError::validation("Invalid pubkey format")
369+
.with_request_id(request_id)
370+
.with_method("getBalance")
371+
.with_parameter("pubkey");
372+
373+
// Test Clone
374+
let error2 = error1.clone();
375+
assert_eq!(error1.request_id(), error2.request_id());
376+
assert_eq!(error1.method(), error2.method());
377+
assert_eq!(error1.error_type(), error2.error_type());
378+
379+
// Test PartialEq
380+
assert_eq!(error1, error2);
381+
382+
// Test that different errors are not equal
383+
let error3 = McpError::client("Different error");
384+
assert_ne!(error1, error3);
385+
}
355386
}

src/logging.rs

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@ use tracing_subscriber::{
88
use uuid::Uuid;
99
use serde_json::Value;
1010
use std::sync::atomic::{AtomicU64, Ordering};
11-
use std::collections::HashMap;
12-
use std::sync::Mutex;
11+
use dashmap::DashMap;
1312

1413
/// Metrics collection for monitoring RPC call outcomes
1514
#[derive(Debug, Default)]
@@ -19,9 +18,9 @@ pub struct Metrics {
1918
/// Number of successful RPC calls
2019
pub successful_calls: AtomicU64,
2120
/// Number of failed RPC calls by error type
22-
pub failed_calls_by_type: Mutex<HashMap<String, u64>>,
21+
pub failed_calls_by_type: DashMap<String, u64>,
2322
/// Number of failed RPC calls by method
24-
pub failed_calls_by_method: Mutex<HashMap<String, u64>>,
23+
pub failed_calls_by_method: DashMap<String, u64>,
2524
}
2625

2726
impl Metrics {
@@ -37,21 +36,27 @@ impl Metrics {
3736

3837
/// Increment failed calls counter by error type
3938
pub fn increment_failed_calls(&self, error_type: &str, method: Option<&str>) {
40-
// Increment by error type
41-
let mut failed_by_type = self.failed_calls_by_type.lock().unwrap();
42-
*failed_by_type.entry(error_type.to_string()).or_insert(0) += 1;
39+
// Increment by error type using dashmap for concurrent access
40+
*self.failed_calls_by_type.entry(error_type.to_string()).or_insert(0) += 1;
4341

4442
// Increment by method if available
4543
if let Some(method) = method {
46-
let mut failed_by_method = self.failed_calls_by_method.lock().unwrap();
47-
*failed_by_method.entry(method.to_string()).or_insert(0) += 1;
44+
*self.failed_calls_by_method.entry(method.to_string()).or_insert(0) += 1;
4845
}
4946
}
5047

5148
/// Get current metrics as JSON value
5249
pub fn to_json(&self) -> Value {
53-
let failed_by_type = self.failed_calls_by_type.lock().unwrap().clone();
54-
let failed_by_method = self.failed_calls_by_method.lock().unwrap().clone();
50+
// Convert DashMap to HashMap for JSON serialization
51+
let failed_by_type: std::collections::HashMap<String, u64> = self.failed_calls_by_type
52+
.iter()
53+
.map(|entry| (entry.key().clone(), *entry.value()))
54+
.collect();
55+
56+
let failed_by_method: std::collections::HashMap<String, u64> = self.failed_calls_by_method
57+
.iter()
58+
.map(|entry| (entry.key().clone(), *entry.value()))
59+
.collect();
5560

5661
serde_json::json!({
5762
"total_calls": self.total_calls.load(Ordering::Relaxed),
@@ -65,8 +70,8 @@ impl Metrics {
6570
pub fn reset(&self) {
6671
self.total_calls.store(0, Ordering::Relaxed);
6772
self.successful_calls.store(0, Ordering::Relaxed);
68-
self.failed_calls_by_type.lock().unwrap().clear();
69-
self.failed_calls_by_method.lock().unwrap().clear();
73+
self.failed_calls_by_type.clear();
74+
self.failed_calls_by_method.clear();
7075
}
7176
}
7277

@@ -378,11 +383,9 @@ mod tests {
378383
assert_eq!(metrics.total_calls.load(Ordering::Relaxed), 1);
379384
assert_eq!(metrics.successful_calls.load(Ordering::Relaxed), 1);
380385

381-
let failed_by_type = metrics.failed_calls_by_type.lock().unwrap();
382-
assert_eq!(failed_by_type.get("validation"), Some(&1));
383-
384-
let failed_by_method = metrics.failed_calls_by_method.lock().unwrap();
385-
assert_eq!(failed_by_method.get("getBalance"), Some(&1));
386+
// Test dashmap access
387+
assert_eq!(metrics.failed_calls_by_type.get("validation").map(|v| *v), Some(1));
388+
assert_eq!(metrics.failed_calls_by_method.get("getBalance").map(|v| *v), Some(1));
386389
}
387390

388391
#[test]

0 commit comments

Comments
 (0)