diff --git a/README.md b/README.md index 78f0dcd..10d1453 100644 --- a/README.md +++ b/README.md @@ -14,15 +14,16 @@ processes requests. use async_trait::async_trait; use tokio::net::UnixListener; -use ssh_agent_lib::agent::Agent; +use ssh_agent_lib::agent::{Session, Agent}; use ssh_agent_lib::error::AgentError; use ssh_agent_lib::proto::message::{Message, SignRequest}; +#[derive(Default)] struct MyAgent; #[async_trait] -impl Agent for MyAgent { - async fn handle(&self, message: Message) -> Result { +impl Session for MyAgent { + async fn handle(&mut self, message: Message) -> Result { match message { Message::SignRequest(request) => { // get the signature by signing `request.data` diff --git a/examples/key_storage.rs b/examples/key_storage.rs index 4f9b974..88f455f 100644 --- a/examples/key_storage.rs +++ b/examples/key_storage.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use log::info; use tokio::net::UnixListener; -use ssh_agent_lib::agent::Agent; +use ssh_agent_lib::agent::{Agent, Session}; use ssh_agent_lib::error::AgentError; use ssh_agent_lib::proto::message::{self, Message, SignRequest}; use ssh_agent_lib::proto::private_key::{PrivateKey, RsaPrivateKey}; @@ -147,8 +147,8 @@ impl KeyStorage { } #[async_trait] -impl Agent for KeyStorage { - async fn handle(&self, message: Message) -> Result { +impl Session for KeyStorage { + async fn handle(&mut self, message: Message) -> Result { self.handle_message(message).or_else(|error| { println!("Error handling message - {:?}", error); Ok(Message::Failure) @@ -156,6 +156,12 @@ impl Agent for KeyStorage { } } +impl Agent for KeyStorage { + fn new_session(&mut self) -> impl Session { + KeyStorage::new() + } +} + fn rsa_openssl_from_ssh(ssh_rsa: &RsaPrivateKey) -> Result, Box> { let n = BigNum::from_slice(&ssh_rsa.n)?; let e = BigNum::from_slice(&ssh_rsa.e)?; diff --git a/src/agent.rs b/src/agent.rs index 51e5567..7cd2a33 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -11,13 +11,13 @@ use std::fmt; use std::io; use std::marker::Unpin; use std::mem::size_of; -use std::sync::Arc; use super::error::AgentError; use super::proto::message::Message; use super::proto::{from_bytes, to_bytes}; -struct MessageCodec; +#[derive(Debug)] +pub struct MessageCodec; impl Decoder for MessageCodec { type Item = Message; @@ -52,39 +52,6 @@ impl Encoder for MessageCodec { } } -struct Session { - agent: Arc, - adapter: Framed, -} - -impl Session -where - A: Agent, - S: AsyncRead + AsyncWrite + Unpin, -{ - fn new(agent: Arc, socket: S) -> Self { - let adapter = Framed::new(socket, MessageCodec); - Self { agent, adapter } - } - - async fn handle_socket(&mut self) -> Result<(), AgentError> { - loop { - if let Some(incoming_message) = self.adapter.try_next().await? { - let response = self.agent.handle(incoming_message).await.map_err(|e| { - error!("Error handling message; error = {:?}", e); - AgentError::User - })?; - - self.adapter.send(response).await?; - } else { - // Reached EOF of the stream (client disconnected), - // we can close the socket and exit the handler. - return Ok(()); - } - } - } -} - #[async_trait] pub trait ListeningSocket { type Stream: fmt::Debug + AsyncRead + AsyncWrite + Send + Unpin + 'static; @@ -109,24 +76,48 @@ impl ListeningSocket for TcpListener { } #[async_trait] -pub trait Agent: 'static + Sync + Send + Sized { - async fn handle(&self, message: Message) -> Result; +pub trait Session: 'static + Sync + Send + Sized { + async fn handle(&mut self, message: Message) -> Result; - async fn listen(self, socket: S) -> Result<(), AgentError> + async fn handle_socket( + &mut self, + mut adapter: Framed, + ) -> Result<(), AgentError> where S: ListeningSocket + fmt::Debug + Send, { - info!("Listening; socket = {:?}", socket); - let arc_self = Arc::new(self); + loop { + if let Some(incoming_message) = adapter.try_next().await? { + let response = self.handle(incoming_message).await.map_err(|e| { + error!("Error handling message; error = {:?}", e); + AgentError::User + })?; + + adapter.send(response).await?; + } else { + // Reached EOF of the stream (client disconnected), + // we can close the socket and exit the handler. + return Ok(()); + } + } + } +} +#[async_trait] +pub trait Agent: 'static + Sync + Send + Sized { + fn new_session(&mut self) -> impl Session; + async fn listen(mut self, socket: S) -> Result<(), AgentError> + where + S: ListeningSocket + fmt::Debug + Send, + { + info!("Listening; socket = {:?}", socket); loop { match socket.accept().await { Ok(socket) => { - let agent = arc_self.clone(); - let mut session = Session::new(agent, socket); - + let mut session = self.new_session(); tokio::spawn(async move { - if let Err(e) = session.handle_socket().await { + let adapter = Framed::new(socket, MessageCodec); + if let Err(e) = session.handle_socket::(adapter).await { error!("Agent protocol error; error = {:?}", e); } }); @@ -139,3 +130,12 @@ pub trait Agent: 'static + Sync + Send + Sized { } } } + +impl Agent for T +where + T: Default + Session, +{ + fn new_session(&mut self) -> impl Session { + Self::default() + } +}