Skip to content

Commit 28cef51

Browse files
committed
Impl traits for session persistence
1 parent 8e5a98c commit 28cef51

File tree

3 files changed

+131
-2
lines changed

3 files changed

+131
-2
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

payjoin/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ exclude = ["tests"]
1818
send = []
1919
receive = ["rand"]
2020
base64 = ["bitcoin/base64"]
21-
v2 = ["bitcoin/rand-std", "chacha20poly1305", "ohttp", "bhttp"]
21+
v2 = ["bitcoin/rand-std", "chacha20poly1305", "ohttp", "bhttp", "serde"]
2222

2323
[dependencies]
2424
bitcoin = { version = "0.30.0", features = ["base64"] }
@@ -28,6 +28,7 @@ log = { version = "0.4.14"}
2828
ohttp = { version = "0.4.0", optional = true }
2929
bhttp = { version = "0.4.0", optional = true }
3030
rand = { version = "0.8.4", optional = true }
31+
serde = { version = "1.0.186", default-features = false, optional = true }
3132
url = "2.2.2"
3233

3334
[dev-dependencies]

payjoin/src/receive/v2.rs

Lines changed: 128 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ use std::collections::HashMap;
22

33
use bitcoin::psbt::Psbt;
44
use bitcoin::{base64, Amount, FeeRate, OutPoint, Script, TxOut};
5+
use serde::ser::SerializeStruct;
6+
use serde::{Serialize, Deserialize, Serializer};
57

68
use super::{Error, InternalRequestError, RequestError, SelectionError};
79
use crate::psbt::PsbtExt;
@@ -107,14 +109,135 @@ fn subdirectory(pubkey: &bitcoin::secp256k1::PublicKey) -> String {
107109
pubkey_base64
108110
}
109111

110-
#[derive(Debug)]
112+
#[derive(Debug, Clone, PartialEq)]
111113
pub struct Enrolled {
112114
relay_url: url::Url,
113115
ohttp_config: Vec<u8>,
114116
ohttp_proxy: url::Url,
115117
s: bitcoin::secp256k1::KeyPair,
116118
}
117119

120+
impl Serialize for Enrolled {
121+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
122+
where
123+
S: Serializer,
124+
{
125+
let mut state = serializer.serialize_struct("Enrolled", 4)?;
126+
state.serialize_field("relay_url", &self.relay_url.to_string())?;
127+
state.serialize_field("ohttp_config", &self.ohttp_config)?;
128+
state.serialize_field("ohttp_proxy", &self.ohttp_proxy.to_string())?;
129+
state.serialize_field("s", &self.s.secret_key().secret_bytes())?;
130+
131+
state.end()
132+
}
133+
}
134+
135+
use serde::de::{self, Deserializer, Visitor, SeqAccess, MapAccess};
136+
use std::fmt;
137+
use std::str::FromStr;
138+
139+
impl<'de> Deserialize<'de> for Enrolled {
140+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
141+
where
142+
D: Deserializer<'de>,
143+
{
144+
enum Field { RelayUrl, OhttpConfig, OhttpProxy, S }
145+
146+
impl<'de> Deserialize<'de> for Field {
147+
fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
148+
where
149+
D: Deserializer<'de>,
150+
{
151+
struct FieldVisitor;
152+
153+
impl<'de> Visitor<'de> for FieldVisitor {
154+
type Value = Field;
155+
156+
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
157+
formatter.write_str("`relay_url`, `ohttp_config`, `ohttp_proxy`, or `s`")
158+
}
159+
160+
fn visit_str<E>(self, value: &str) -> Result<Field, E>
161+
where
162+
E: de::Error,
163+
{
164+
match value {
165+
"relay_url" => Ok(Field::RelayUrl),
166+
"ohttp_config" => Ok(Field::OhttpConfig),
167+
"ohttp_proxy" => Ok(Field::OhttpProxy),
168+
"s" => Ok(Field::S),
169+
_ => Err(de::Error::unknown_field(value, FIELDS)),
170+
}
171+
}
172+
}
173+
174+
deserializer.deserialize_identifier(FieldVisitor)
175+
}
176+
}
177+
178+
struct EnrolledVisitor;
179+
180+
impl<'de> Visitor<'de> for EnrolledVisitor {
181+
type Value = Enrolled;
182+
183+
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
184+
formatter.write_str("struct Enrolled")
185+
}
186+
187+
fn visit_map<V>(self, mut map: V) -> Result<Enrolled, V::Error>
188+
where
189+
V: MapAccess<'de>,
190+
{
191+
let mut relay_url = None;
192+
let mut ohttp_config = None;
193+
let mut ohttp_proxy = None;
194+
let mut s = None;
195+
while let Some(key) = map.next_key()? {
196+
match key {
197+
Field::RelayUrl => {
198+
if relay_url.is_some() {
199+
return Err(de::Error::duplicate_field("relay_url"));
200+
}
201+
let url_str: String = map.next_value()?;
202+
relay_url = Some(url::Url::parse(&url_str).map_err(de::Error::custom)?);
203+
},
204+
Field::OhttpConfig => {
205+
if ohttp_config.is_some() {
206+
return Err(de::Error::duplicate_field("ohttp_config"));
207+
}
208+
ohttp_config = Some(map.next_value()?);
209+
},
210+
Field::OhttpProxy => {
211+
if ohttp_proxy.is_some() {
212+
return Err(de::Error::duplicate_field("ohttp_proxy"));
213+
}
214+
let proxy_str: String = map.next_value()?;
215+
ohttp_proxy = Some(url::Url::parse(&proxy_str).map_err(de::Error::custom)?);
216+
},
217+
Field::S => {
218+
if s.is_some() {
219+
return Err(de::Error::duplicate_field("s"));
220+
}
221+
let s_bytes: Vec<u8> = map.next_value()?;
222+
let secp = bitcoin::secp256k1::Secp256k1::new();
223+
s = Some(bitcoin::secp256k1::KeyPair::from_seckey_slice(&secp, &s_bytes)
224+
.map_err(de::Error::custom)?);
225+
}
226+
}
227+
}
228+
let relay_url = relay_url.ok_or_else(|| de::Error::missing_field("relay_url"))?;
229+
let ohttp_config = ohttp_config.ok_or_else(|| de::Error::missing_field("ohttp_config"))?;
230+
let ohttp_proxy = ohttp_proxy.ok_or_else(|| de::Error::missing_field("ohttp_proxy"))?;
231+
let s = s.ok_or_else(|| de::Error::missing_field("s"))?;
232+
Ok(Enrolled { relay_url, ohttp_config, ohttp_proxy, s })
233+
}
234+
}
235+
236+
const FIELDS: &'static [&'static str] = &["relay_url", "ohttp_config", "ohttp_proxy", "s"];
237+
deserializer.deserialize_struct("Enrolled", FIELDS, EnrolledVisitor)
238+
}
239+
}
240+
118241
impl Enrolled {
119242
pub fn extract_req(&self) -> Result<(Request, ohttp::ClientResponse), Error> {
120243
let (body, ohttp_ctx) = self.fallback_req_body()?;
@@ -174,6 +297,10 @@ impl Enrolled {
174297
crate::v2::ohttp_encapsulate(&self.ohttp_config, "GET", &self.fallback_target(), None)
175298
}
176299

300+
pub fn pubkey(&self) -> [u8; 33] {
301+
self.s.public_key().serialize()
302+
}
303+
177304
pub fn fallback_target(&self) -> String {
178305
let pubkey = &self.s.public_key().serialize();
179306
let b64_config = base64::Config::new(base64::CharacterSet::UrlSafe, false);

0 commit comments

Comments
 (0)