diff --git a/src/proto/extension.rs b/src/proto/extension.rs new file mode 100644 index 0000000..76cda23 --- /dev/null +++ b/src/proto/extension.rs @@ -0,0 +1,28 @@ +use serde::{Deserialize, Serialize}; + +use super::recursive; +use super::signature::Signature; + +/// SSH key +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SshKey { + pub alg: String, + pub blob: Vec, +} + +/// session-bind@openssh.com extension +/// +/// This extension allows a ssh client to bind an agent connection to a +/// particular SSH session. +/// +/// Spec: +/// +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionBind { + #[serde(with = "recursive")] + pub host_key: SshKey, + pub session_id: Vec, + #[serde(with = "recursive")] + pub signature: Signature, + pub is_forwarding: bool, +} diff --git a/src/proto/message.rs b/src/proto/message.rs index 044169f..9f24086 100644 --- a/src/proto/message.rs +++ b/src/proto/message.rs @@ -1,4 +1,5 @@ use serde::de::{Deserializer, Visitor}; +use serde::ser::SerializeTuple; use serde::{Deserialize, Serialize}; use super::private_key::PrivateKey; @@ -65,7 +66,11 @@ impl Serialize for ExtensionContents { where S: serde::Serializer, { - serializer.serialize_bytes(&self.0) + let mut seq = serializer.serialize_tuple(self.0.len())?; + for i in &self.0 { + seq.serialize_element(i)?; + } + seq.end() } } diff --git a/src/proto/mod.rs b/src/proto/mod.rs index cade6d9..59a4663 100644 --- a/src/proto/mod.rs +++ b/src/proto/mod.rs @@ -4,6 +4,7 @@ pub mod ser; #[macro_use] pub mod key_type; pub mod error; +pub mod extension; pub mod message; pub mod private_key; pub mod public_key; @@ -40,3 +41,49 @@ impl<'a, T: Serialize + Deserialize<'a>> Blob for T { from_bytes(blob) } } + +pub mod recursive { + use super::{from_bytes, to_bytes}; + use serde::{ + de::{self, Deserializer, Visitor}, + ser::{Error, Serializer}, + Deserialize, Serialize, + }; + use std::{fmt, marker::PhantomData}; + + pub fn serialize(obj: &T, serializer: S) -> Result + where + T: Serialize, + S: Serializer, + { + serializer.serialize_bytes(&to_bytes(obj).map_err(S::Error::custom)?) + } + + pub fn deserialize<'de, T, D>(deserialize: D) -> Result + where + T: Deserialize<'de>, + D: Deserializer<'de>, + { + struct RecursiveVisitor(PhantomData); + + impl<'de, T> Visitor<'de> for RecursiveVisitor + where + T: Deserialize<'de>, + { + type Value = T; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("an integer between -2^31 and 2^31") + } + + fn visit_bytes(self, value: &[u8]) -> Result + where + E: de::Error, + { + from_bytes(value).map_err(E::custom) + } + } + + deserialize.deserialize_bytes(RecursiveVisitor(PhantomData::)) + } +} diff --git a/src/proto/ser.rs b/src/proto/ser.rs index dbb7894..f828501 100644 --- a/src/proto/ser.rs +++ b/src/proto/ser.rs @@ -27,8 +27,10 @@ impl<'a, W: io::Write> ser::Serializer for &'a mut Serializer { type SerializeStruct = Self; type SerializeStructVariant = Self; - fn serialize_bool(self, _v: bool) -> ProtoResult<()> { - unimplemented!() + fn serialize_bool(self, v: bool) -> ProtoResult<()> { + self.writer + .write_u8(if v { 1 } else { 0 }) + .map_err(Into::into) } fn serialize_i8(self, v: i8) -> ProtoResult<()> { @@ -76,11 +78,11 @@ impl<'a, W: io::Write> ser::Serializer for &'a mut Serializer { } fn serialize_str(self, v: &str) -> ProtoResult<()> { - (v.len() as u32).serialize(&mut *self)?; self.serialize_bytes(v.as_bytes()) } fn serialize_bytes(self, v: &[u8]) -> ProtoResult<()> { + (v.len() as u32).serialize(&mut *self)?; self.writer.write_all(v).map_err(Into::into) } diff --git a/src/proto/tests/mod.rs b/src/proto/tests/mod.rs index b83a1a5..3cdc5fa 100644 --- a/src/proto/tests/mod.rs +++ b/src/proto/tests/mod.rs @@ -1,3 +1,4 @@ +use super::extension::SessionBind; use super::message::{Extension, Identity, Message, SignRequest}; use super::private_key::*; use super::public_key::*; @@ -178,4 +179,9 @@ fn test_extension() { assert_eq!(extension.extension_type, "session-bind@openssh.com"); let out = to_bytes(&extension).unwrap(); assert_eq!(extension_bytes, out); + + let session_bind: SessionBind = from_bytes(&extension.extension_contents.0).unwrap(); + + let out = to_bytes(&session_bind).unwrap(); + assert_eq!(extension.extension_contents.0, out); }