Skip to content

Commit d791a1a

Browse files
authored
Merge pull request #11 from baloo/baloo/session-bind
Fixup extension serialization
2 parents 57d5fe5 + 58024c8 commit d791a1a

File tree

5 files changed

+92
-4
lines changed

5 files changed

+92
-4
lines changed

src/proto/extension.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
use serde::{Deserialize, Serialize};
2+
3+
use super::recursive;
4+
use super::signature::Signature;
5+
6+
/// SSH key
7+
#[derive(Debug, Clone, Serialize, Deserialize)]
8+
pub struct SshKey {
9+
pub alg: String,
10+
pub blob: Vec<u8>,
11+
}
12+
13+
/// [email protected] extension
14+
///
15+
/// This extension allows a ssh client to bind an agent connection to a
16+
/// particular SSH session.
17+
///
18+
/// Spec:
19+
/// <https://github.com/openssh/openssh-portable/blob/cbbdf868bce431a59e2fa36ca244d5739429408d/PROTOCOL.agent#L6>
20+
#[derive(Debug, Clone, Serialize, Deserialize)]
21+
pub struct SessionBind {
22+
#[serde(with = "recursive")]
23+
pub host_key: SshKey,
24+
pub session_id: Vec<u8>,
25+
#[serde(with = "recursive")]
26+
pub signature: Signature,
27+
pub is_forwarding: bool,
28+
}

src/proto/message.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use serde::de::{Deserializer, Visitor};
2+
use serde::ser::SerializeTuple;
23
use serde::{Deserialize, Serialize};
34

45
use super::private_key::PrivateKey;
@@ -65,7 +66,11 @@ impl Serialize for ExtensionContents {
6566
where
6667
S: serde::Serializer,
6768
{
68-
serializer.serialize_bytes(&self.0)
69+
let mut seq = serializer.serialize_tuple(self.0.len())?;
70+
for i in &self.0 {
71+
seq.serialize_element(i)?;
72+
}
73+
seq.end()
6974
}
7075
}
7176

src/proto/mod.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ pub mod ser;
44
#[macro_use]
55
pub mod key_type;
66
pub mod error;
7+
pub mod extension;
78
pub mod message;
89
pub mod private_key;
910
pub mod public_key;
@@ -40,3 +41,49 @@ impl<'a, T: Serialize + Deserialize<'a>> Blob for T {
4041
from_bytes(blob)
4142
}
4243
}
44+
45+
pub mod recursive {
46+
use super::{from_bytes, to_bytes};
47+
use serde::{
48+
de::{self, Deserializer, Visitor},
49+
ser::{Error, Serializer},
50+
Deserialize, Serialize,
51+
};
52+
use std::{fmt, marker::PhantomData};
53+
54+
pub fn serialize<T, S>(obj: &T, serializer: S) -> Result<S::Ok, S::Error>
55+
where
56+
T: Serialize,
57+
S: Serializer,
58+
{
59+
serializer.serialize_bytes(&to_bytes(obj).map_err(S::Error::custom)?)
60+
}
61+
62+
pub fn deserialize<'de, T, D>(deserialize: D) -> Result<T, D::Error>
63+
where
64+
T: Deserialize<'de>,
65+
D: Deserializer<'de>,
66+
{
67+
struct RecursiveVisitor<T>(PhantomData<T>);
68+
69+
impl<'de, T> Visitor<'de> for RecursiveVisitor<T>
70+
where
71+
T: Deserialize<'de>,
72+
{
73+
type Value = T;
74+
75+
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
76+
formatter.write_str("an integer between -2^31 and 2^31")
77+
}
78+
79+
fn visit_bytes<E>(self, value: &[u8]) -> Result<Self::Value, E>
80+
where
81+
E: de::Error,
82+
{
83+
from_bytes(value).map_err(E::custom)
84+
}
85+
}
86+
87+
deserialize.deserialize_bytes(RecursiveVisitor(PhantomData::<T>))
88+
}
89+
}

src/proto/ser.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@ impl<'a, W: io::Write> ser::Serializer for &'a mut Serializer<W> {
2727
type SerializeStruct = Self;
2828
type SerializeStructVariant = Self;
2929

30-
fn serialize_bool(self, _v: bool) -> ProtoResult<()> {
31-
unimplemented!()
30+
fn serialize_bool(self, v: bool) -> ProtoResult<()> {
31+
self.writer
32+
.write_u8(if v { 1 } else { 0 })
33+
.map_err(Into::into)
3234
}
3335

3436
fn serialize_i8(self, v: i8) -> ProtoResult<()> {
@@ -76,11 +78,11 @@ impl<'a, W: io::Write> ser::Serializer for &'a mut Serializer<W> {
7678
}
7779

7880
fn serialize_str(self, v: &str) -> ProtoResult<()> {
79-
(v.len() as u32).serialize(&mut *self)?;
8081
self.serialize_bytes(v.as_bytes())
8182
}
8283

8384
fn serialize_bytes(self, v: &[u8]) -> ProtoResult<()> {
85+
(v.len() as u32).serialize(&mut *self)?;
8486
self.writer.write_all(v).map_err(Into::into)
8587
}
8688

src/proto/tests/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use super::extension::SessionBind;
12
use super::message::{Extension, Identity, Message, SignRequest};
23
use super::private_key::*;
34
use super::public_key::*;
@@ -178,4 +179,9 @@ fn test_extension() {
178179
assert_eq!(extension.extension_type, "[email protected]");
179180
let out = to_bytes(&extension).unwrap();
180181
assert_eq!(extension_bytes, out);
182+
183+
let session_bind: SessionBind = from_bytes(&extension.extension_contents.0).unwrap();
184+
185+
let out = to_bytes(&session_bind).unwrap();
186+
assert_eq!(extension.extension_contents.0, out);
181187
}

0 commit comments

Comments
 (0)