Skip to content

Commit

Permalink
proto: rewrite error handling, remove a todo
Browse files Browse the repository at this point in the history
Signed-off-by: Arthur Gautier <[email protected]>
  • Loading branch information
baloo committed Apr 4, 2024
1 parent e44431d commit 7a89143
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 38 deletions.
17 changes: 3 additions & 14 deletions src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use std::marker::Unpin;
use std::mem::size_of;

use super::error::AgentError;
use super::proto::message::Message;
use super::proto::{message::Message, ProtoError};

#[derive(Debug)]
pub struct MessageCodec;
Expand All @@ -40,11 +40,6 @@ impl Decoder for MessageCodec {
return Ok(None);
}

//use std::io::Write;
//let mut file = std::fs::File::create(uuid::Uuid::new_v4().to_string())?;
//file.write_all(bytes)?;
//drop(file);

let message: Message = Message::decode(&mut bytes)?;
src.advance(size_of::<u32>() + length);
Ok(Some(message))
Expand All @@ -58,16 +53,10 @@ impl Encoder<Message> for MessageCodec {
let mut bytes = Vec::new();

let len = item.encoded_len().unwrap() as u32;
len.encode(&mut bytes)?;
len.encode(&mut bytes).map_err(ProtoError::SshEncoding)?;

item.encode(&mut bytes)?;
item.encode(&mut bytes).map_err(ProtoError::SshEncoding)?;
dst.put(&*bytes);
//use std::io::Write;
//let mut file = std::fs::File::create(uuid::Uuid::new_v4().to_string())?;
//let mut bytes = Vec::new();
//item.encode(&mut bytes)?;
//file.write_all(&bytes)?;
//drop(file);

Ok(())
}
Expand Down
22 changes: 12 additions & 10 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
use std::io;

use crate::proto::ProtoError;

#[derive(Debug)]
pub enum AgentError {
Ssh(ssh_key::Error),
Proto(ssh_encoding::Error),
//Ssh(ssh_key::Error),
Proto(ProtoError),
IO(io::Error),
}

impl From<ssh_encoding::Error> for AgentError {
fn from(e: ssh_encoding::Error) -> AgentError {
impl From<ProtoError> for AgentError {
fn from(e: ProtoError) -> AgentError {
AgentError::Proto(e)
}
}

impl From<ssh_key::Error> for AgentError {
fn from(e: ssh_key::Error) -> AgentError {
AgentError::Ssh(e)
}
}
//impl From<ssh_key::Error> for AgentError {
// fn from(e: ssh_key::Error) -> AgentError {
// AgentError::Ssh(e)
// }
//}

impl From<io::Error> for AgentError {
fn from(e: io::Error) -> AgentError {
Expand All @@ -28,7 +30,7 @@ impl From<io::Error> for AgentError {
impl std::fmt::Display for AgentError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AgentError::Ssh(e) => write!(f, "Agent: Ssh key error: {e}"),
//AgentError::Ssh(e) => write!(f, "Agent: Ssh key error: {e}"),
AgentError::Proto(proto) => write!(f, "Agent: Protocol error: {}", proto),
AgentError::IO(error) => write!(f, "Agent: I/O error: {}", error),
}
Expand Down
5 changes: 5 additions & 0 deletions src/proto/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub enum ProtoError {
IO(io::Error),
SshEncoding(ssh_encoding::Error),
SshKey(ssh_key::Error),
UnsupportedCommand { command: u8 },
}

impl From<ProtoError> for () {
Expand Down Expand Up @@ -48,6 +49,7 @@ impl std::error::Error for ProtoError {
ProtoError::IO(e) => Some(e),
ProtoError::SshEncoding(e) => Some(e),
ProtoError::SshKey(e) => Some(e),
ProtoError::UnsupportedCommand { .. } => None,
}
}
}
Expand All @@ -61,6 +63,9 @@ impl std::fmt::Display for ProtoError {
ProtoError::IO(_) => f.write_str("I/O Error"),
ProtoError::SshEncoding(_) => f.write_str("SSH encoding Error"),
ProtoError::SshKey(e) => write!(f, "SSH key Error: {e}"),
ProtoError::UnsupportedCommand { command } => {
write!(f, "Command not supported ({command})")
}
}
}
}
Expand Down
35 changes: 21 additions & 14 deletions src/proto/message.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use ssh_encoding::{CheckedSum, Decode, Encode, Error as EncodingError, Reader, Writer};
use ssh_key::{private::KeypairData, public::KeyData, Error, Result, Signature};
use ssh_key::{private::KeypairData, public::KeyData, Error, Signature};

use super::ProtoError;

type Result<T> = core::result::Result<T, ProtoError>;

#[derive(Clone, PartialEq, Debug)]
pub struct Identity {
Expand All @@ -21,7 +25,7 @@ impl Identity {
}

impl Decode for Identity {
type Error = Error;
type Error = ProtoError;

fn decode(reader: &mut impl Reader) -> Result<Self> {
let pubkey = reader.read_prefixed(KeyData::decode)?;
Expand Down Expand Up @@ -56,7 +60,7 @@ pub struct SignRequest {
}

impl Decode for SignRequest {
type Error = Error;
type Error = ProtoError;

fn decode(reader: &mut impl Reader) -> Result<Self> {
let pubkey = reader.read_prefixed(KeyData::decode)?;
Expand Down Expand Up @@ -97,7 +101,7 @@ pub struct AddIdentity {
}

impl Decode for AddIdentity {
type Error = Error;
type Error = ProtoError;

fn decode(reader: &mut impl Reader) -> Result<Self> {
let privkey = KeypairData::decode(reader)?;
Expand Down Expand Up @@ -126,7 +130,7 @@ pub struct AddIdentityConstrained {
}

impl Decode for AddIdentityConstrained {
type Error = Error;
type Error = ProtoError;

fn decode(reader: &mut impl Reader) -> Result<Self> {
let identity = AddIdentity::decode(reader)?;
Expand Down Expand Up @@ -168,7 +172,7 @@ pub struct RemoveIdentity {
}

impl Decode for RemoveIdentity {
type Error = Error;
type Error = ProtoError;

fn decode(reader: &mut impl Reader) -> Result<Self> {
let pubkey = reader.read_prefixed(KeyData::decode)?;
Expand All @@ -194,7 +198,7 @@ pub struct SmartcardKey {
}

impl Decode for SmartcardKey {
type Error = Error;
type Error = ProtoError;

fn decode(reader: &mut impl Reader) -> Result<Self> {
let id = String::decode(reader)?;
Expand Down Expand Up @@ -225,7 +229,7 @@ pub enum KeyConstraint {
}

impl Decode for KeyConstraint {
type Error = Error;
type Error = ProtoError;

fn decode(reader: &mut impl Reader) -> Result<Self> {
let constraint_type = u8::decode(reader)?;
Expand All @@ -239,7 +243,7 @@ impl Decode for KeyConstraint {
reader.read(&mut details)?;
KeyConstraint::Extension(name, details.into())
}
_ => return Err(Error::AlgorithmUnknown), // FIXME: it should be our own type
_ => return Err(Error::AlgorithmUnknown)?, // FIXME: it should be our own type
})
}
}
Expand Down Expand Up @@ -282,7 +286,7 @@ pub struct AddSmartcardKeyConstrained {
}

impl Decode for AddSmartcardKeyConstrained {
type Error = Error;
type Error = ProtoError;

fn decode(reader: &mut impl Reader) -> Result<Self> {
let key = SmartcardKey::decode(reader)?;
Expand Down Expand Up @@ -321,7 +325,7 @@ pub struct Extension {
}

impl Decode for Extension {
type Error = Error;
type Error = ProtoError;

fn decode(reader: &mut impl Reader) -> Result<Self> {
let name = String::decode(reader)?;
Expand Down Expand Up @@ -425,7 +429,7 @@ impl Message {
}

impl Decode for Message {
type Error = Error;
type Error = ProtoError;

fn decode(reader: &mut impl Reader) -> Result<Self> {
let message_type = u8::decode(reader)?;
Expand All @@ -436,7 +440,10 @@ impl Decode for Message {
11 => Ok(Self::RequestIdentities),
12 => Identity::decode_vec(reader).map(Self::IdentitiesAnswer),
13 => SignRequest::decode(reader).map(Self::SignRequest),
14 => reader.read_prefixed(|reader| Signature::decode(reader).map(Self::SignResponse)),
14 => {
Ok(reader
.read_prefixed(|reader| Signature::decode(reader).map(Self::SignResponse))?)
}
17 => AddIdentity::decode(reader).map(Self::AddIdentity),
18 => RemoveIdentity::decode(reader).map(Self::RemoveIdentity),
19 => Ok(Self::RemoveAllIdentities),
Expand All @@ -448,7 +455,7 @@ impl Decode for Message {
26 => AddSmartcardKeyConstrained::decode(reader).map(Self::AddSmartcardKeyConstrained),
27 => Extension::decode(reader).map(Self::Extension),
28 => Ok(Self::ExtensionFailure),
_ => todo!(),
command => Err(ProtoError::UnsupportedCommand { command }),
}
}
}
Expand Down

0 comments on commit 7a89143

Please sign in to comment.