Skip to content

Commit

Permalink
Merge pull request #11 from baloo/baloo/session-bind
Browse files Browse the repository at this point in the history
Fixup extension serialization
  • Loading branch information
wiktor-k authored Feb 15, 2024
2 parents 57d5fe5 + 58024c8 commit d791a1a
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 4 deletions.
28 changes: 28 additions & 0 deletions src/proto/extension.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use serde::{Deserialize, Serialize};

use super::recursive;
use super::signature::Signature;

/// SSH key
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SshKey {
pub alg: String,
pub blob: Vec<u8>,
}

/// [email protected] extension
///
/// This extension allows a ssh client to bind an agent connection to a
/// particular SSH session.
///
/// Spec:
/// <https://github.com/openssh/openssh-portable/blob/cbbdf868bce431a59e2fa36ca244d5739429408d/PROTOCOL.agent#L6>
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionBind {
#[serde(with = "recursive")]
pub host_key: SshKey,
pub session_id: Vec<u8>,
#[serde(with = "recursive")]
pub signature: Signature,
pub is_forwarding: bool,
}
7 changes: 6 additions & 1 deletion src/proto/message.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use serde::de::{Deserializer, Visitor};
use serde::ser::SerializeTuple;
use serde::{Deserialize, Serialize};

use super::private_key::PrivateKey;
Expand Down Expand Up @@ -65,7 +66,11 @@ impl Serialize for ExtensionContents {
where
S: serde::Serializer,
{
serializer.serialize_bytes(&self.0)
let mut seq = serializer.serialize_tuple(self.0.len())?;
for i in &self.0 {
seq.serialize_element(i)?;
}
seq.end()
}
}

Expand Down
47 changes: 47 additions & 0 deletions src/proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod ser;
#[macro_use]
pub mod key_type;
pub mod error;
pub mod extension;
pub mod message;
pub mod private_key;
pub mod public_key;
Expand Down Expand Up @@ -40,3 +41,49 @@ impl<'a, T: Serialize + Deserialize<'a>> Blob for T {
from_bytes(blob)
}
}

pub mod recursive {
use super::{from_bytes, to_bytes};
use serde::{
de::{self, Deserializer, Visitor},
ser::{Error, Serializer},
Deserialize, Serialize,
};
use std::{fmt, marker::PhantomData};

pub fn serialize<T, S>(obj: &T, serializer: S) -> Result<S::Ok, S::Error>
where
T: Serialize,
S: Serializer,
{
serializer.serialize_bytes(&to_bytes(obj).map_err(S::Error::custom)?)
}

pub fn deserialize<'de, T, D>(deserialize: D) -> Result<T, D::Error>
where
T: Deserialize<'de>,
D: Deserializer<'de>,
{
struct RecursiveVisitor<T>(PhantomData<T>);

impl<'de, T> Visitor<'de> for RecursiveVisitor<T>
where
T: Deserialize<'de>,
{
type Value = T;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("an integer between -2^31 and 2^31")
}

fn visit_bytes<E>(self, value: &[u8]) -> Result<Self::Value, E>
where
E: de::Error,
{
from_bytes(value).map_err(E::custom)
}
}

deserialize.deserialize_bytes(RecursiveVisitor(PhantomData::<T>))
}
}
8 changes: 5 additions & 3 deletions src/proto/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ impl<'a, W: io::Write> ser::Serializer for &'a mut Serializer<W> {
type SerializeStruct = Self;
type SerializeStructVariant = Self;

fn serialize_bool(self, _v: bool) -> ProtoResult<()> {
unimplemented!()
fn serialize_bool(self, v: bool) -> ProtoResult<()> {
self.writer
.write_u8(if v { 1 } else { 0 })
.map_err(Into::into)
}

fn serialize_i8(self, v: i8) -> ProtoResult<()> {
Expand Down Expand Up @@ -76,11 +78,11 @@ impl<'a, W: io::Write> ser::Serializer for &'a mut Serializer<W> {
}

fn serialize_str(self, v: &str) -> ProtoResult<()> {
(v.len() as u32).serialize(&mut *self)?;
self.serialize_bytes(v.as_bytes())
}

fn serialize_bytes(self, v: &[u8]) -> ProtoResult<()> {
(v.len() as u32).serialize(&mut *self)?;
self.writer.write_all(v).map_err(Into::into)
}

Expand Down
6 changes: 6 additions & 0 deletions src/proto/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::extension::SessionBind;
use super::message::{Extension, Identity, Message, SignRequest};
use super::private_key::*;
use super::public_key::*;
Expand Down Expand Up @@ -178,4 +179,9 @@ fn test_extension() {
assert_eq!(extension.extension_type, "[email protected]");
let out = to_bytes(&extension).unwrap();
assert_eq!(extension_bytes, out);

let session_bind: SessionBind = from_bytes(&extension.extension_contents.0).unwrap();

let out = to_bytes(&session_bind).unwrap();
assert_eq!(extension.extension_contents.0, out);
}

0 comments on commit d791a1a

Please sign in to comment.