Skip to content

Commit aca9982

Browse files
committed
Handle v2 errors (v1 fails)
1 parent 5243442 commit aca9982

File tree

9 files changed

+281
-118
lines changed

9 files changed

+281
-118
lines changed

payjoin-cli/seen_inputs.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
["c93eb8f0c617f1150bdf311f594774c7c50a9518e954b83b5424753426d91a5e:1"][["c93eb8f0c617f1150bdf311f594774c7c50a9518e954b83b5424753426d91a5e:1"]

payjoin-cli/src/app.rs

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,16 +91,20 @@ impl App {
9191
&self,
9292
client: &reqwest::blocking::Client,
9393
enroll_context: &mut EnrollContext,
94-
) -> Result<UncheckedProposal, reqwest::Error> {
94+
) -> Result<UncheckedProposal> {
9595
loop {
96-
let (payjoin_get_body, context) = enroll_context.payjoin_get_body();
96+
let (payjoin_get_body, context) = enroll_context
97+
.payjoin_get_body()
98+
.map_err(|e| anyhow!("Failed to create payjoin GET body: {}", e))?;
9799
let ohttp_response =
98100
client.post(&self.config.ohttp_proxy).body(payjoin_get_body).send()?;
99101
let ohttp_response = ohttp_response.bytes()?;
100-
let proposal =
101-
enroll_context.parse_relay_response(ohttp_response.as_ref(), context).unwrap();
102+
let proposal = enroll_context
103+
.parse_relay_response(ohttp_response.as_ref(), context)
104+
.map_err(|e| anyhow!("parse error {}", e))?;
105+
log::debug!("got response");
102106
match proposal {
103-
Some(proposal) => return Ok(proposal),
107+
Some(proposal) => break Ok(proposal),
104108
None => std::thread::sleep(std::time::Duration::from_secs(5)),
105109
}
106110
}
@@ -229,17 +233,19 @@ impl App {
229233
.build()
230234
.with_context(|| "Failed to build reqwest http client")?;
231235
log::debug!("Awaiting request");
232-
let _enroll = client.post(&self.config.pj_endpoint).body(context.enroll_body()).send()?;
236+
let (body, _) = context.enroll_body().unwrap();
237+
let _enroll = client.post(&self.config.pj_endpoint).body(body).send()?;
233238

234239
log::debug!("Awaiting proposal");
235240
let res = self.long_poll_get(&client, &mut context)?;
236241
log::debug!("Received request");
237-
let payjoin_proposal = self
238-
.process_proposal(proposal)
239-
.map_err(|e| anyhow!("Failed to process UncheckedProposal {}", e))?;
240-
let payjoin_endpoint = format!("{}/{}/receive", self.config.pj_endpoint, pubkey_base64);
241-
let (body, ohttp_ctx) =
242-
payjoin_proposal.extract_v2_req(&self.config.ohttp_config, &payjoin_endpoint);
242+
let payjoin_proposal =
243+
self.process_proposal(res).map_err(|e| anyhow!("Failed to process proposal {}", e))?;
244+
log::debug!("Posting payjoin back");
245+
let receive_endpoint = format!("{}/{}", self.config.pj_endpoint, context.payjoin_subdir());
246+
let (body, ohttp_ctx) = payjoin_proposal
247+
.extract_v2_req(&self.config.ohttp_config, &receive_endpoint)
248+
.map_err(|e| anyhow!("v2 req extraction failed {}", e))?;
243249
let res = client
244250
.post(&self.config.ohttp_proxy)
245251
.body(body)

payjoin-relay/src/main.rs

Lines changed: 48 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
use std::env;
22
use std::net::SocketAddr;
3-
use std::str::FromStr;
43
use std::sync::Arc;
54

65
use anyhow::Result;
@@ -73,7 +72,7 @@ fn init_ohttp() -> Result<ohttp::Server> {
7372
let server_config = ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC))?;
7473
let encoded_config = server_config.encode()?;
7574
let b64_config = base64::encode_config(
76-
encoded_config,
75+
&encoded_config,
7776
base64::Config::new(base64::CharacterSet::UrlSafe, false),
7877
);
7978
info!("ohttp server config base64 UrlSafe: {:?}", b64_config);
@@ -119,29 +118,36 @@ async fn handle_ohttp(
119118
let (bhttp_req, res_ctx) =
120119
ohttp_locked.decapsulate(&ohttp_body).map_err(|e| HandlerError::BadRequest(e.into()))?;
121120
let mut cursor = std::io::Cursor::new(bhttp_req);
122-
let req = bhttp::Message::read_bhttp(&mut cursor).unwrap();
121+
let req =
122+
bhttp::Message::read_bhttp(&mut cursor).map_err(|e| HandlerError::BadRequest(e.into()))?;
123123
let uri = Uri::builder()
124-
.scheme(req.control().scheme().unwrap())
125-
.authority(req.control().authority().unwrap())
126-
.path_and_query(req.control().path().unwrap())
127-
.build()
128-
.unwrap();
124+
.scheme(req.control().scheme().unwrap_or_default())
125+
.authority(req.control().authority().unwrap_or_default())
126+
.path_and_query(req.control().path().unwrap_or_default())
127+
.build()?;
129128
let body = req.content().to_vec();
130-
let mut http_req = Request::builder().uri(uri).method(req.control().method().unwrap());
129+
let mut http_req =
130+
Request::builder().uri(uri).method(req.control().method().unwrap_or_default());
131131
for header in req.header().fields() {
132132
http_req = http_req.header(header.name(), header.value())
133133
}
134-
let request = http_req.body(Body::from(body)).unwrap();
134+
let request = http_req.body(Body::from(body))?;
135135

136136
let response = handle_v2(pool, request).await?;
137137

138138
let (parts, body) = response.into_parts();
139139
let mut bhttp_res = bhttp::Message::response(parts.status.as_u16());
140-
let full_body = hyper::body::to_bytes(body).await.unwrap();
140+
let full_body = hyper::body::to_bytes(body)
141+
.await
142+
.map_err(|e| HandlerError::InternalServerError(e.into()))?;
141143
bhttp_res.write_content(&full_body);
142144
let mut bhttp_bytes = Vec::new();
143-
bhttp_res.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_bytes).unwrap();
144-
let ohttp_res = res_ctx.encapsulate(&bhttp_bytes).unwrap();
145+
bhttp_res
146+
.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_bytes)
147+
.map_err(|e| HandlerError::InternalServerError(e.into()))?;
148+
let ohttp_res = res_ctx
149+
.encapsulate(&bhttp_bytes)
150+
.map_err(|e| HandlerError::InternalServerError(e.into()))?;
145151
Ok(Response::new(Body::from(ohttp_res)))
146152
}
147153

@@ -162,16 +168,22 @@ async fn handle_v2(pool: DbPool, req: Request<Body>) -> Result<Response<Body>, H
162168

163169
enum HandlerError {
164170
PayloadTooLarge,
165-
InternalServerError,
166-
BadRequest,
171+
InternalServerError(Box<dyn std::error::Error>),
172+
BadRequest(Box<dyn std::error::Error>),
167173
}
168174

169175
impl HandlerError {
170176
fn to_response(&self) -> Response<Body> {
171177
let status = match self {
172178
HandlerError::PayloadTooLarge => StatusCode::PAYLOAD_TOO_LARGE,
173-
HandlerError::BadRequest => StatusCode::BAD_REQUEST,
174-
_ => StatusCode::INTERNAL_SERVER_ERROR,
179+
Self::InternalServerError(e) => {
180+
error!("Internal server error: {}", e);
181+
StatusCode::INTERNAL_SERVER_ERROR
182+
}
183+
Self::BadRequest(e) => {
184+
error!("Bad request: {}", e);
185+
StatusCode::BAD_REQUEST
186+
}
175187
};
176188

177189
let mut res = Response::new(Body::empty());
@@ -181,17 +193,19 @@ impl HandlerError {
181193
}
182194

183195
impl From<hyper::http::Error> for HandlerError {
184-
fn from(_: hyper::http::Error) -> Self { HandlerError::InternalServerError }
196+
fn from(e: hyper::http::Error) -> Self { HandlerError::InternalServerError(e.into()) }
185197
}
186198

187199
async fn post_enroll(body: Body) -> Result<Response<Body>, HandlerError> {
188200
let b64_config = base64::Config::new(base64::CharacterSet::UrlSafe, false);
189-
let bytes = hyper::body::to_bytes(body).await.map_err(|_| HandlerError::BadRequest)?;
190-
let base64_id = String::from_utf8(bytes.to_vec()).map_err(|_| HandlerError::BadRequest)?;
191-
let pubkey_bytes: Vec<u8> =
192-
base64::decode_config(base64_id, b64_config).map_err(|_| HandlerError::BadRequest)?;
201+
let bytes =
202+
hyper::body::to_bytes(body).await.map_err(|e| HandlerError::BadRequest(e.into()))?;
203+
let base64_id =
204+
String::from_utf8(bytes.to_vec()).map_err(|e| HandlerError::BadRequest(e.into()))?;
205+
let pubkey_bytes: Vec<u8> = base64::decode_config(base64_id, b64_config)
206+
.map_err(|e| HandlerError::BadRequest(e.into()))?;
193207
let pubkey = bitcoin::secp256k1::PublicKey::from_slice(&pubkey_bytes)
194-
.map_err(|_| HandlerError::BadRequest)?;
208+
.map_err(|e| HandlerError::BadRequest(e.into()))?;
195209
tracing::info!("Enrolled valid pubkey: {:?}", pubkey);
196210
Ok(Response::builder().status(StatusCode::NO_CONTENT).body(Body::empty())?)
197211
}
@@ -226,20 +240,23 @@ async fn post_fallback(
226240
) -> Result<Response<Body>, HandlerError> {
227241
tracing::trace!("Post fallback");
228242
let id = shorten_string(id);
229-
let req = hyper::body::to_bytes(body).await.map_err(|_| HandlerError::InternalServerError)?;
243+
let req = hyper::body::to_bytes(body)
244+
.await
245+
.map_err(|e| HandlerError::InternalServerError(e.into()))?;
246+
230247
if req.len() > MAX_BUFFER_SIZE {
231248
return Err(HandlerError::PayloadTooLarge);
232249
}
233250

234251
match pool.push_req(&id, req.into()).await {
235252
Ok(_) => (),
236-
Err(_) => return Err(HandlerError::BadRequest),
253+
Err(e) => return Err(HandlerError::BadRequest(e.into())),
237254
};
238255

239256
match pool.peek_res(&id).await {
240257
Some(result) => match result {
241258
Ok(buffered_res) => Ok(Response::new(Body::from(buffered_res))),
242-
Err(_) => Err(HandlerError::BadRequest),
259+
Err(e) => Err(HandlerError::BadRequest(e.into())),
243260
},
244261
None => Ok(none_response),
245262
}
@@ -250,19 +267,21 @@ async fn get_fallback(id: &str, pool: DbPool) -> Result<Response<Body>, HandlerE
250267
match pool.peek_req(&id).await {
251268
Some(result) => match result {
252269
Ok(buffered_req) => Ok(Response::new(Body::from(buffered_req))),
253-
Err(_) => Err(HandlerError::BadRequest),
270+
Err(e) => Err(HandlerError::BadRequest(e.into())),
254271
},
255272
None => Ok(Response::builder().status(StatusCode::ACCEPTED).body(Body::empty())?),
256273
}
257274
}
258275

259276
async fn post_payjoin(id: &str, body: Body, pool: DbPool) -> Result<Response<Body>, HandlerError> {
260277
let id = shorten_string(id);
261-
let res = hyper::body::to_bytes(body).await.map_err(|_| HandlerError::InternalServerError)?;
278+
let res = hyper::body::to_bytes(body)
279+
.await
280+
.map_err(|e| HandlerError::InternalServerError(e.into()))?;
262281

263282
match pool.push_res(&id, res.into()).await {
264283
Ok(_) => Ok(Response::builder().status(StatusCode::NO_CONTENT).body(Body::empty())?),
265-
Err(_) => Err(HandlerError::BadRequest),
284+
Err(e) => Err(HandlerError::BadRequest(e.into())),
266285
}
267286
}
268287

payjoin-relay/tests/integration.rs

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ mod integration {
8282
// Enroll with relay
8383
let mut enroll_ctx =
8484
EnrollContext::from_relay_config(&RELAY_URL, &ohttp_config_base64, &RELAY_URL);
85-
let (enroll_body, _ohttp_req_ctx) = enroll_ctx.enroll_body();
85+
let (enroll_body, _ohttp_req_ctx) = enroll_ctx.enroll_body().expect("Failed to enroll");
8686
let _ohttp_response =
8787
http.post(RELAY_URL).body(enroll_body).send().await.expect("Failed to send request");
8888
log::debug!("Enrolled receiver");
@@ -150,7 +150,8 @@ mod integration {
150150
// **********************
151151
// Inside the Receiver:
152152
// GET fallback_psbt
153-
let (payjoin_get_body, ohttp_req_ctx) = enroll_ctx.payjoin_get_body();
153+
let (payjoin_get_body, ohttp_req_ctx) =
154+
enroll_ctx.payjoin_get_body().expect("Failed to get fallback");
154155
let ohttp_response = http
155156
.post(RELAY_URL)
156157
.body(payjoin_get_body)
@@ -162,18 +163,20 @@ mod integration {
162163
);
163164
let proposal = enroll_ctx.parse_relay_response(reader, ohttp_req_ctx).unwrap().unwrap();
164165
let payjoin_proposal = handle_proposal(proposal, receiver);
165-
166-
let (body, _ohttp_ctx) = payjoin_proposal.extract_v2_req(
167-
&ohttp_config_base64,
168-
&format!("{}/{}", &RELAY_URL, &enroll_ctx.payjoin_subdir()),
169-
);
166+
let (body, _ohttp_ctx) = payjoin_proposal
167+
.extract_v2_req(
168+
&ohttp_config_base64,
169+
&format!("{}/{}", &RELAY_URL, &enroll_ctx.payjoin_subdir()),
170+
)
171+
.expect("Failed to extract v2 req");
170172
let _ohttp_response =
171173
http.post(RELAY_URL).body(body).send().await.expect("Failed to post payjoin_psbt");
172174

173175
// **********************
174176
// Inside the Sender:
175177
// Sender checks, signs, finalizes, extracts, and broadcasts
176178
log::info!("replay POST fallback psbt for payjoin_psbt response");
179+
log::info!("Req body {:#?}", &req.body);
177180
let response = http
178181
.post(req.url.as_str())
179182
.body(req.body.clone())
@@ -256,7 +259,7 @@ mod integration {
256259
// Enroll with relay
257260
let mut enroll_ctx =
258261
EnrollContext::from_relay_config(&RELAY_URL, &ohttp_config_base64, &RELAY_URL);
259-
let (enroll_body, _ohttp_req_ctx) = enroll_ctx.enroll_body();
262+
let (enroll_body, _ohttp_req_ctx) = enroll_ctx.enroll_body().unwrap();
260263
let enroll =
261264
http.post(RELAY_URL).body(enroll_body).send().await.expect("Failed to send request");
262265

@@ -331,7 +334,7 @@ mod integration {
331334
.expect("Failed to build reqwest http client");
332335

333336
let proposal = loop {
334-
let (payjoin_get_body, ohttp_req_ctx) = enroll_ctx.payjoin_get_body();
337+
let (payjoin_get_body, ohttp_req_ctx) = enroll_ctx.payjoin_get_body().unwrap();
335338
let enc_response = http
336339
.post(RELAY_URL)
337340
.body(payjoin_get_body)
@@ -355,10 +358,12 @@ mod integration {
355358
debug!("handle relay response");
356359
let response = handle_proposal(proposal, receiver);
357360
debug!("Post payjoin_psbt to relay");
358-
let (body, _ohttp_ctx) = response.extract_v2_req(
359-
&ohttp_config_base64,
360-
&format!("{}/{}", &RELAY_URL, &enroll_ctx.payjoin_subdir()),
361-
);
361+
let (body, _ohttp_ctx) = response
362+
.extract_v2_req(
363+
&ohttp_config_base64,
364+
&format!("{}/{}", &RELAY_URL, &enroll_ctx.payjoin_subdir()),
365+
)
366+
.unwrap();
362367
// Respond with payjoin psbt within the time window the sender is willing to wait
363368
let response = http.post(RELAY_URL).body(body).send().await;
364369
debug!("POSTed with payjoin_psbt");

payjoin/src/receive/error.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,18 @@ pub enum Error {
77
BadRequest(RequestError),
88
// To be returned as HTTP 500
99
Server(Box<dyn error::Error>),
10+
// V2 d/encapsulation failed
11+
#[cfg(feature = "v2")]
12+
V2(crate::v2::Error),
1013
}
1114

1215
impl fmt::Display for Error {
1316
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1417
match &self {
1518
Self::BadRequest(e) => e.fmt(f),
1619
Self::Server(e) => write!(f, "Internal Server Error: {}", e),
20+
#[cfg(feature = "v2")]
21+
Self::V2(e) => e.fmt(f),
1722
}
1823
}
1924
}
@@ -23,6 +28,8 @@ impl error::Error for Error {
2328
match &self {
2429
Self::BadRequest(_) => None,
2530
Self::Server(e) => Some(e.as_ref()),
31+
#[cfg(feature = "v2")]
32+
Self::V2(e) => Some(e),
2633
}
2734
}
2835
}
@@ -31,6 +38,15 @@ impl From<RequestError> for Error {
3138
fn from(e: RequestError) -> Self { Error::BadRequest(e) }
3239
}
3340

41+
impl From<InternalRequestError> for Error {
42+
fn from(e: InternalRequestError) -> Self { Error::BadRequest(e.into()) }
43+
}
44+
45+
impl From<crate::v2::Error> for Error {
46+
#[cfg(feature = "v2")]
47+
fn from(e: crate::v2::Error) -> Self { Error::V2(e) }
48+
}
49+
3450
/// Error that may occur when the request from sender is malformed.
3551
///
3652
/// This is currently opaque type because we aren't sure which variants will stay.

0 commit comments

Comments
 (0)