Skip to content

Commit e66f80d

Browse files
committed
Impl traits for session persistence
1 parent f23175c commit e66f80d

File tree

3 files changed

+140
-2
lines changed

3 files changed

+140
-2
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

payjoin/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ exclude = ["tests"]
1818
send = []
1919
receive = ["rand"]
2020
base64 = ["bitcoin/base64"]
21-
v2 = ["bitcoin/rand-std", "chacha20poly1305", "ohttp", "bhttp"]
21+
v2 = ["bitcoin/rand-std", "chacha20poly1305", "ohttp", "bhttp", "serde"]
2222

2323
[dependencies]
2424
bitcoin = { version = "0.30.0", features = ["base64"] }
@@ -28,6 +28,7 @@ log = { version = "0.4.14"}
2828
ohttp = { version = "0.4.0", optional = true }
2929
bhttp = { version = "0.4.0", optional = true }
3030
rand = { version = "0.8.4", optional = true }
31+
serde = { version = "1.0.186", default-features = false, optional = true }
3132
url = "2.2.2"
3233

3334
[dev-dependencies]

payjoin/src/receive/v2.rs

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ use std::collections::HashMap;
22

33
use bitcoin::psbt::Psbt;
44
use bitcoin::{base64, Amount, FeeRate, OutPoint, Script, TxOut};
5+
use serde::ser::SerializeStruct;
6+
use serde::{Deserialize, Serialize, Serializer};
57

68
use super::{Error, InternalRequestError, RequestError, SelectionError};
79
use crate::psbt::PsbtExt;
@@ -107,14 +109,146 @@ fn subdirectory(pubkey: &bitcoin::secp256k1::PublicKey) -> String {
107109
pubkey_base64
108110
}
109111

110-
#[derive(Debug)]
112+
#[derive(Debug, Clone, PartialEq)]
111113
pub struct Enrolled {
112114
relay_url: url::Url,
113115
ohttp_config: Vec<u8>,
114116
ohttp_proxy: url::Url,
115117
s: bitcoin::secp256k1::KeyPair,
116118
}
117119

120+
impl Serialize for Enrolled {
121+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
122+
where
123+
S: Serializer,
124+
{
125+
let mut state = serializer.serialize_struct("Enrolled", 4)?;
126+
state.serialize_field("relay_url", &self.relay_url.to_string())?;
127+
state.serialize_field("ohttp_config", &self.ohttp_config)?;
128+
state.serialize_field("ohttp_proxy", &self.ohttp_proxy.to_string())?;
129+
state.serialize_field("s", &self.s.secret_key().secret_bytes())?;
130+
131+
state.end()
132+
}
133+
}
134+
135+
use std::fmt;
136+
use std::str::FromStr;
137+
138+
use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
139+
140+
impl<'de> Deserialize<'de> for Enrolled {
141+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
142+
where
143+
D: Deserializer<'de>,
144+
{
145+
enum Field {
146+
RelayUrl,
147+
OhttpConfig,
148+
OhttpProxy,
149+
S,
150+
}
151+
152+
impl<'de> Deserialize<'de> for Field {
153+
fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
154+
where
155+
D: Deserializer<'de>,
156+
{
157+
struct FieldVisitor;
158+
159+
impl<'de> Visitor<'de> for FieldVisitor {
160+
type Value = Field;
161+
162+
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
163+
formatter.write_str("`relay_url`, `ohttp_config`, `ohttp_proxy`, or `s`")
164+
}
165+
166+
fn visit_str<E>(self, value: &str) -> Result<Field, E>
167+
where
168+
E: de::Error,
169+
{
170+
match value {
171+
"relay_url" => Ok(Field::RelayUrl),
172+
"ohttp_config" => Ok(Field::OhttpConfig),
173+
"ohttp_proxy" => Ok(Field::OhttpProxy),
174+
"s" => Ok(Field::S),
175+
_ => Err(de::Error::unknown_field(value, FIELDS)),
176+
}
177+
}
178+
}
179+
180+
deserializer.deserialize_identifier(FieldVisitor)
181+
}
182+
}
183+
184+
struct EnrolledVisitor;
185+
186+
impl<'de> Visitor<'de> for EnrolledVisitor {
187+
type Value = Enrolled;
188+
189+
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
190+
formatter.write_str("struct Enrolled")
191+
}
192+
193+
fn visit_map<V>(self, mut map: V) -> Result<Enrolled, V::Error>
194+
where
195+
V: MapAccess<'de>,
196+
{
197+
let mut relay_url = None;
198+
let mut ohttp_config = None;
199+
let mut ohttp_proxy = None;
200+
let mut s = None;
201+
while let Some(key) = map.next_key()? {
202+
match key {
203+
Field::RelayUrl => {
204+
if relay_url.is_some() {
205+
return Err(de::Error::duplicate_field("relay_url"));
206+
}
207+
let url_str: String = map.next_value()?;
208+
relay_url = Some(url::Url::parse(&url_str).map_err(de::Error::custom)?);
209+
}
210+
Field::OhttpConfig => {
211+
if ohttp_config.is_some() {
212+
return Err(de::Error::duplicate_field("ohttp_config"));
213+
}
214+
ohttp_config = Some(map.next_value()?);
215+
}
216+
Field::OhttpProxy => {
217+
if ohttp_proxy.is_some() {
218+
return Err(de::Error::duplicate_field("ohttp_proxy"));
219+
}
220+
let proxy_str: String = map.next_value()?;
221+
ohttp_proxy =
222+
Some(url::Url::parse(&proxy_str).map_err(de::Error::custom)?);
223+
}
224+
Field::S => {
225+
if s.is_some() {
226+
return Err(de::Error::duplicate_field("s"));
227+
}
228+
let s_bytes: Vec<u8> = map.next_value()?;
229+
let secp = bitcoin::secp256k1::Secp256k1::new();
230+
s = Some(
231+
bitcoin::secp256k1::KeyPair::from_seckey_slice(&secp, &s_bytes)
232+
.map_err(de::Error::custom)?,
233+
);
234+
}
235+
}
236+
}
237+
let relay_url = relay_url.ok_or_else(|| de::Error::missing_field("relay_url"))?;
238+
let ohttp_config =
239+
ohttp_config.ok_or_else(|| de::Error::missing_field("ohttp_config"))?;
240+
let ohttp_proxy =
241+
ohttp_proxy.ok_or_else(|| de::Error::missing_field("ohttp_proxy"))?;
242+
let s = s.ok_or_else(|| de::Error::missing_field("s"))?;
243+
Ok(Enrolled { relay_url, ohttp_config, ohttp_proxy, s })
244+
}
245+
}
246+
247+
const FIELDS: &'static [&'static str] = &["relay_url", "ohttp_config", "ohttp_proxy", "s"];
248+
deserializer.deserialize_struct("Enrolled", FIELDS, EnrolledVisitor)
249+
}
250+
}
251+
118252
impl Enrolled {
119253
pub fn extract_req(&self) -> Result<(Request, ohttp::ClientResponse), Error> {
120254
let (body, ohttp_ctx) = self.fallback_req_body()?;
@@ -174,6 +308,8 @@ impl Enrolled {
174308
crate::v2::ohttp_encapsulate(&self.ohttp_config, "GET", &self.fallback_target(), None)
175309
}
176310

311+
pub fn pubkey(&self) -> [u8; 33] { self.s.public_key().serialize() }
312+
177313
pub fn fallback_target(&self) -> String {
178314
let pubkey = &self.s.public_key().serialize();
179315
let b64_config = base64::Config::new(base64::CharacterSet::UrlSafe, false);

0 commit comments

Comments
 (0)