Skip to content

Commit 9eb5cfc

Browse files
committed
Handle v2 errors (v1 fails)
1 parent 7ddabb1 commit 9eb5cfc

File tree

10 files changed

+286
-122
lines changed

10 files changed

+286
-122
lines changed

payjoin-cli/seen_inputs.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
["c93eb8f0c617f1150bdf311f594774c7c50a9518e954b83b5424753426d91a5e:1"][["c93eb8f0c617f1150bdf311f594774c7c50a9518e954b83b5424753426d91a5e:1"]

payjoin-cli/src/app.rs

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,16 +91,20 @@ impl App {
9191
&self,
9292
client: &reqwest::blocking::Client,
9393
enroll_context: &mut EnrollContext,
94-
) -> Result<UncheckedProposal, reqwest::Error> {
94+
) -> Result<UncheckedProposal> {
9595
loop {
96-
let (payjoin_get_body, context) = enroll_context.payjoin_get_body();
96+
let (payjoin_get_body, context) = enroll_context
97+
.payjoin_get_body()
98+
.map_err(|e| anyhow!("Failed to create payjoin GET body: {}", e))?;
9799
let ohttp_response =
98100
client.post(&self.config.ohttp_proxy).body(payjoin_get_body).send()?;
99101
let ohttp_response = ohttp_response.bytes()?;
100-
let proposal =
101-
enroll_context.parse_relay_response(ohttp_response.as_ref(), context).unwrap();
102+
let proposal = enroll_context
103+
.parse_relay_response(ohttp_response.as_ref(), context)
104+
.map_err(|e| anyhow!("parse error {}", e))?;
105+
log::debug!("got response");
102106
match proposal {
103-
Some(proposal) => return Ok(proposal),
107+
Some(proposal) => break Ok(proposal),
104108
None => std::thread::sleep(std::time::Duration::from_secs(5)),
105109
}
106110
}
@@ -229,17 +233,19 @@ impl App {
229233
.build()
230234
.with_context(|| "Failed to build reqwest http client")?;
231235
log::debug!("Awaiting request");
232-
let _enroll = client.post(&self.config.pj_endpoint).body(context.enroll_body()).send()?;
236+
let (body, _) = context.enroll_body().unwrap();
237+
let _enroll = client.post(&self.config.pj_endpoint).body(body).send()?;
233238

234239
log::debug!("Awaiting proposal");
235240
let res = self.long_poll_get(&client, &mut context)?;
236241
log::debug!("Received request");
237-
let payjoin_proposal = self
238-
.process_proposal(proposal)
239-
.map_err(|e| anyhow!("Failed to process UncheckedProposal {}", e))?;
240-
let payjoin_endpoint = format!("{}/{}/receive", self.config.pj_endpoint, pubkey_base64);
241-
let (body, ohttp_ctx) =
242-
payjoin_proposal.extract_v2_req(&self.config.ohttp_config, &payjoin_endpoint);
242+
let payjoin_proposal =
243+
self.process_proposal(res).map_err(|e| anyhow!("Failed to process proposal {}", e))?;
244+
log::debug!("Posting payjoin back");
245+
let receive_endpoint = format!("{}/{}", self.config.pj_endpoint, context.payjoin_subdir());
246+
let (body, ohttp_ctx) = payjoin_proposal
247+
.extract_v2_req(&self.config.ohttp_config, &receive_endpoint)
248+
.map_err(|e| anyhow!("v2 req extraction failed {}", e))?;
243249
let res = client
244250
.post(&self.config.ohttp_proxy)
245251
.body(body)

payjoin-relay/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ edition = "2021"
88
[dependencies]
99
hyper = { version = "0.14", features = ["full"] }
1010
anyhow = "1.0.71"
11-
payjoin = { path = "../payjoin", features = ["base64"] }
11+
payjoin = { path = "../payjoin", features = ["base64", "v2"] }
1212
# ohttp = "0.4.0"
1313
ohttp = { path = "../../ohttp/ohttp" }
1414
bhttp = { version = "0.4.0", features = ["http"] }

payjoin-relay/src/main.rs

Lines changed: 52 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
use std::env;
22
use std::net::SocketAddr;
3-
use std::str::FromStr;
43
use std::sync::Arc;
54

65
use anyhow::Result;
76
use hyper::service::{make_service_fn, service_fn};
87
use hyper::{Body, Method, Request, Response, StatusCode, Uri};
98
use payjoin::{base64, bitcoin};
10-
use tracing::{debug, info, trace};
9+
use tracing::{debug, error, info, trace};
1110
use tracing_subscriber::filter::LevelFilter;
1211
use tracing_subscriber::EnvFilter;
1312

@@ -72,7 +71,7 @@ fn init_ohttp() -> Result<ohttp::Server> {
7271
let server_config = ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC))?;
7372
let encoded_config = server_config.encode()?;
7473
let b64_config = base64::encode_config(
75-
encoded_config,
74+
&encoded_config,
7675
base64::Config::new(base64::CharacterSet::UrlSafe, false),
7776
);
7877
info!("ohttp server config base64 UrlSafe: {:?}", b64_config);
@@ -112,33 +111,41 @@ async fn handle_ohttp(
112111
) -> Result<Response<Body>, HandlerError> {
113112
// decapsulate
114113
let ohttp_body =
115-
hyper::body::to_bytes(body).await.map_err(|_| HandlerError::InternalServerError)?;
114+
hyper::body::to_bytes(body).await.map_err(|e| HandlerError::BadRequest(e.into()))?;
116115

117-
let (bhttp_req, res_ctx) = ohttp.decapsulate(&ohttp_body).unwrap();
116+
let (bhttp_req, res_ctx) =
117+
ohttp.decapsulate(&ohttp_body).map_err(|e| HandlerError::BadRequest(e.into()))?;
118118
let mut cursor = std::io::Cursor::new(bhttp_req);
119-
let req = bhttp::Message::read_bhttp(&mut cursor).unwrap();
119+
let req =
120+
bhttp::Message::read_bhttp(&mut cursor).map_err(|e| HandlerError::BadRequest(e.into()))?;
120121
let uri = Uri::builder()
121-
.scheme(req.control().scheme().unwrap())
122-
.authority(req.control().authority().unwrap())
123-
.path_and_query(req.control().path().unwrap())
124-
.build()
125-
.unwrap();
122+
.scheme(req.control().scheme().unwrap_or_default())
123+
.authority(req.control().authority().unwrap_or_default())
124+
.path_and_query(req.control().path().unwrap_or_default())
125+
.build()?;
126126
let body = req.content().to_vec();
127-
let mut http_req = Request::builder().uri(uri).method(req.control().method().unwrap());
127+
let mut http_req =
128+
Request::builder().uri(uri).method(req.control().method().unwrap_or_default());
128129
for header in req.header().fields() {
129130
http_req = http_req.header(header.name(), header.value())
130131
}
131-
let request = http_req.body(Body::from(body)).unwrap();
132+
let request = http_req.body(Body::from(body))?;
132133

133134
let response = handle_v2(pool, request).await?;
134135

135136
let (parts, body) = response.into_parts();
136137
let mut bhttp_res = bhttp::Message::response(parts.status.as_u16());
137-
let full_body = hyper::body::to_bytes(body).await.unwrap();
138+
let full_body = hyper::body::to_bytes(body)
139+
.await
140+
.map_err(|e| HandlerError::InternalServerError(e.into()))?;
138141
bhttp_res.write_content(&full_body);
139142
let mut bhttp_bytes = Vec::new();
140-
bhttp_res.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_bytes).unwrap();
141-
let ohttp_res = res_ctx.encapsulate(&bhttp_bytes).unwrap();
143+
bhttp_res
144+
.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_bytes)
145+
.map_err(|e| HandlerError::InternalServerError(e.into()))?;
146+
let ohttp_res = res_ctx
147+
.encapsulate(&bhttp_bytes)
148+
.map_err(|e| HandlerError::InternalServerError(e.into()))?;
142149
Ok(Response::new(Body::from(ohttp_res)))
143150
}
144151

@@ -159,16 +166,22 @@ async fn handle_v2(pool: DbPool, req: Request<Body>) -> Result<Response<Body>, H
159166

160167
enum HandlerError {
161168
PayloadTooLarge,
162-
InternalServerError,
163-
BadRequest,
169+
InternalServerError(Box<dyn std::error::Error>),
170+
BadRequest(Box<dyn std::error::Error>),
164171
}
165172

166173
impl HandlerError {
167174
fn to_response(&self) -> Response<Body> {
168175
let status = match self {
169176
HandlerError::PayloadTooLarge => StatusCode::PAYLOAD_TOO_LARGE,
170-
HandlerError::BadRequest => StatusCode::BAD_REQUEST,
171-
_ => StatusCode::INTERNAL_SERVER_ERROR,
177+
Self::InternalServerError(e) => {
178+
error!("Internal server error: {}", e);
179+
StatusCode::INTERNAL_SERVER_ERROR
180+
}
181+
Self::BadRequest(e) => {
182+
error!("Bad request: {}", e);
183+
StatusCode::BAD_REQUEST
184+
}
172185
};
173186

174187
let mut res = Response::new(Body::empty());
@@ -178,17 +191,19 @@ impl HandlerError {
178191
}
179192

180193
impl From<hyper::http::Error> for HandlerError {
181-
fn from(_: hyper::http::Error) -> Self { HandlerError::InternalServerError }
194+
fn from(e: hyper::http::Error) -> Self { HandlerError::InternalServerError(e.into()) }
182195
}
183196

184197
async fn post_enroll(body: Body) -> Result<Response<Body>, HandlerError> {
185198
let b64_config = base64::Config::new(base64::CharacterSet::UrlSafe, false);
186-
let bytes = hyper::body::to_bytes(body).await.map_err(|_| HandlerError::BadRequest)?;
187-
let base64_id = String::from_utf8(bytes.to_vec()).map_err(|_| HandlerError::BadRequest)?;
188-
let pubkey_bytes: Vec<u8> =
189-
base64::decode_config(base64_id, b64_config).map_err(|_| HandlerError::BadRequest)?;
199+
let bytes =
200+
hyper::body::to_bytes(body).await.map_err(|e| HandlerError::BadRequest(e.into()))?;
201+
let base64_id =
202+
String::from_utf8(bytes.to_vec()).map_err(|e| HandlerError::BadRequest(e.into()))?;
203+
let pubkey_bytes: Vec<u8> = base64::decode_config(base64_id, b64_config)
204+
.map_err(|e| HandlerError::BadRequest(e.into()))?;
190205
let pubkey = bitcoin::secp256k1::PublicKey::from_slice(&pubkey_bytes)
191-
.map_err(|_| HandlerError::BadRequest)?;
206+
.map_err(|e| HandlerError::BadRequest(e.into()))?;
192207
tracing::info!("Enrolled valid pubkey: {:?}", pubkey);
193208
Ok(Response::builder().status(StatusCode::NO_CONTENT).body(Body::empty())?)
194209
}
@@ -223,20 +238,23 @@ async fn post_fallback(
223238
) -> Result<Response<Body>, HandlerError> {
224239
tracing::trace!("Post fallback");
225240
let id = shorten_string(id);
226-
let req = hyper::body::to_bytes(body).await.map_err(|_| HandlerError::InternalServerError)?;
241+
let req = hyper::body::to_bytes(body)
242+
.await
243+
.map_err(|e| HandlerError::InternalServerError(e.into()))?;
244+
227245
if req.len() > MAX_BUFFER_SIZE {
228246
return Err(HandlerError::PayloadTooLarge);
229247
}
230248

231249
match pool.push_req(&id, req.into()).await {
232250
Ok(_) => (),
233-
Err(_) => return Err(HandlerError::BadRequest),
251+
Err(e) => return Err(HandlerError::BadRequest(e.into())),
234252
};
235253

236254
match pool.peek_res(&id).await {
237255
Some(result) => match result {
238256
Ok(buffered_res) => Ok(Response::new(Body::from(buffered_res))),
239-
Err(_) => Err(HandlerError::BadRequest),
257+
Err(e) => Err(HandlerError::BadRequest(e.into())),
240258
},
241259
None => Ok(none_response),
242260
}
@@ -247,19 +265,21 @@ async fn get_fallback(id: &str, pool: DbPool) -> Result<Response<Body>, HandlerE
247265
match pool.peek_req(&id).await {
248266
Some(result) => match result {
249267
Ok(buffered_req) => Ok(Response::new(Body::from(buffered_req))),
250-
Err(_) => Err(HandlerError::BadRequest),
268+
Err(e) => Err(HandlerError::BadRequest(e.into())),
251269
},
252270
None => Ok(Response::builder().status(StatusCode::ACCEPTED).body(Body::empty())?),
253271
}
254272
}
255273

256274
async fn post_payjoin(id: &str, body: Body, pool: DbPool) -> Result<Response<Body>, HandlerError> {
257275
let id = shorten_string(id);
258-
let res = hyper::body::to_bytes(body).await.map_err(|_| HandlerError::InternalServerError)?;
276+
let res = hyper::body::to_bytes(body)
277+
.await
278+
.map_err(|e| HandlerError::InternalServerError(e.into()))?;
259279

260280
match pool.push_res(&id, res.into()).await {
261281
Ok(_) => Ok(Response::builder().status(StatusCode::NO_CONTENT).body(Body::empty())?),
262-
Err(_) => Err(HandlerError::BadRequest),
282+
Err(e) => Err(HandlerError::BadRequest(e.into())),
263283
}
264284
}
265285

payjoin-relay/tests/integration.rs

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ mod integration {
8282
// Enroll with relay
8383
let mut enroll_ctx =
8484
EnrollContext::from_relay_config(&RELAY_URL, &ohttp_config_base64, &RELAY_URL);
85-
let (enroll_body, _ohttp_req_ctx) = enroll_ctx.enroll_body();
85+
let (enroll_body, _ohttp_req_ctx) = enroll_ctx.enroll_body().expect("Failed to enroll");
8686
let _ohttp_response =
8787
http.post(RELAY_URL).body(enroll_body).send().await.expect("Failed to send request");
8888
log::debug!("Enrolled receiver");
@@ -150,7 +150,8 @@ mod integration {
150150
// **********************
151151
// Inside the Receiver:
152152
// GET fallback_psbt
153-
let (payjoin_get_body, ohttp_req_ctx) = enroll_ctx.payjoin_get_body();
153+
let (payjoin_get_body, ohttp_req_ctx) =
154+
enroll_ctx.payjoin_get_body().expect("Failed to get fallback");
154155
let ohttp_response = http
155156
.post(RELAY_URL)
156157
.body(payjoin_get_body)
@@ -162,18 +163,20 @@ mod integration {
162163
);
163164
let proposal = enroll_ctx.parse_relay_response(reader, ohttp_req_ctx).unwrap().unwrap();
164165
let payjoin_proposal = handle_proposal(proposal, receiver);
165-
166-
let (body, _ohttp_ctx) = payjoin_proposal.extract_v2_req(
167-
&ohttp_config_base64,
168-
&format!("{}/{}", &RELAY_URL, &enroll_ctx.payjoin_subdir()),
169-
);
166+
let (body, _ohttp_ctx) = payjoin_proposal
167+
.extract_v2_req(
168+
&ohttp_config_base64,
169+
&format!("{}/{}", &RELAY_URL, &enroll_ctx.payjoin_subdir()),
170+
)
171+
.expect("Failed to extract v2 req");
170172
let _ohttp_response =
171173
http.post(RELAY_URL).body(body).send().await.expect("Failed to post payjoin_psbt");
172174

173175
// **********************
174176
// Inside the Sender:
175177
// Sender checks, signs, finalizes, extracts, and broadcasts
176178
log::info!("replay POST fallback psbt for payjoin_psbt response");
179+
log::info!("Req body {:#?}", &req.body);
177180
let response = http
178181
.post(req.url.as_str())
179182
.body(req.body.clone())
@@ -256,7 +259,7 @@ mod integration {
256259
// Enroll with relay
257260
let mut enroll_ctx =
258261
EnrollContext::from_relay_config(&RELAY_URL, &ohttp_config_base64, &RELAY_URL);
259-
let (enroll_body, _ohttp_req_ctx) = enroll_ctx.enroll_body();
262+
let (enroll_body, _ohttp_req_ctx) = enroll_ctx.enroll_body().unwrap();
260263
let enroll =
261264
http.post(RELAY_URL).body(enroll_body).send().await.expect("Failed to send request");
262265

@@ -331,7 +334,7 @@ mod integration {
331334
.expect("Failed to build reqwest http client");
332335

333336
let proposal = loop {
334-
let (payjoin_get_body, ohttp_req_ctx) = enroll_ctx.payjoin_get_body();
337+
let (payjoin_get_body, ohttp_req_ctx) = enroll_ctx.payjoin_get_body().unwrap();
335338
let enc_response = http
336339
.post(RELAY_URL)
337340
.body(payjoin_get_body)
@@ -355,10 +358,12 @@ mod integration {
355358
debug!("handle relay response");
356359
let response = handle_proposal(proposal, receiver);
357360
debug!("Post payjoin_psbt to relay");
358-
let (body, _ohttp_ctx) = response.extract_v2_req(
359-
&ohttp_config_base64,
360-
&format!("{}/{}", &RELAY_URL, &enroll_ctx.payjoin_subdir()),
361-
);
361+
let (body, _ohttp_ctx) = response
362+
.extract_v2_req(
363+
&ohttp_config_base64,
364+
&format!("{}/{}", &RELAY_URL, &enroll_ctx.payjoin_subdir()),
365+
)
366+
.unwrap();
362367
// Respond with payjoin psbt within the time window the sender is willing to wait
363368
let response = http.post(RELAY_URL).body(body).send().await;
364369
debug!("POSTed with payjoin_psbt");

payjoin/src/receive/error.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,18 @@ pub enum Error {
77
BadRequest(RequestError),
88
// To be returned as HTTP 500
99
Server(Box<dyn error::Error>),
10+
// V2 d/encapsulation failed
11+
#[cfg(feature = "v2")]
12+
V2(crate::v2::Error),
1013
}
1114

1215
impl fmt::Display for Error {
1316
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1417
match &self {
1518
Self::BadRequest(e) => e.fmt(f),
1619
Self::Server(e) => write!(f, "Internal Server Error: {}", e),
20+
#[cfg(feature = "v2")]
21+
Self::V2(e) => e.fmt(f),
1722
}
1823
}
1924
}
@@ -23,6 +28,8 @@ impl error::Error for Error {
2328
match &self {
2429
Self::BadRequest(_) => None,
2530
Self::Server(e) => Some(e.as_ref()),
31+
#[cfg(feature = "v2")]
32+
Self::V2(e) => Some(e),
2633
}
2734
}
2835
}
@@ -31,6 +38,15 @@ impl From<RequestError> for Error {
3138
fn from(e: RequestError) -> Self { Error::BadRequest(e) }
3239
}
3340

41+
impl From<InternalRequestError> for Error {
42+
fn from(e: InternalRequestError) -> Self { Error::BadRequest(e.into()) }
43+
}
44+
45+
impl From<crate::v2::Error> for Error {
46+
#[cfg(feature = "v2")]
47+
fn from(e: crate::v2::Error) -> Self { Error::V2(e) }
48+
}
49+
3450
/// Error that may occur when the request from sender is malformed.
3551
///
3652
/// This is currently opaque type because we aren't sure which variants will stay.

0 commit comments

Comments
 (0)