11use std:: collections:: HashMap ;
2+ use std:: sync:: Arc ;
23
34use crate :: error:: MutinyError ;
45use crate :: storage:: MutinyStorage ;
56use core:: time:: Duration ;
7+ use gloo_net:: websocket:: futures:: WebSocket ;
68use hex_conservative:: DisplayHex ;
79use once_cell:: sync:: Lazy ;
810use payjoin:: receive:: v2:: Enrolled ;
@@ -69,16 +71,73 @@ impl<S: MutinyStorage> PayjoinStorage for S {
6971 }
7072}
7173
72- pub async fn fetch_ohttp_keys ( _ohttp_relay : Url , directory : Url ) -> Result < OhttpKeys , Error > {
73- let http_client = reqwest :: Client :: builder ( ) . build ( ) ? ;
74+ pub async fn fetch_ohttp_keys ( ohttp_relay : Url , directory : Url ) -> Result < OhttpKeys , Error > {
75+ use futures_util :: { AsyncReadExt , AsyncWriteExt } ;
7476
75- let ohttp_keys_res = http_client
76- . get ( format ! ( "{}/ohttp-keys" , directory. as_ref( ) ) )
77- . send ( )
78- . await ?
79- . bytes ( )
80- . await ?;
81- Ok ( OhttpKeys :: decode ( ohttp_keys_res. as_ref ( ) ) . map_err ( |_| Error :: OhttpDecodeFailed ) ?)
77+ let tls_connector = {
78+ let root_store = futures_rustls:: rustls:: RootCertStore {
79+ roots : webpki_roots:: TLS_SERVER_ROOTS . to_vec ( ) ,
80+ } ;
81+ let config = futures_rustls:: rustls:: ClientConfig :: builder ( )
82+ . with_root_certificates ( root_store)
83+ . with_no_client_auth ( ) ;
84+ futures_rustls:: TlsConnector :: from ( Arc :: new ( config) )
85+ } ;
86+ let directory_host = directory. host_str ( ) . ok_or ( Error :: BadDirectoryHost ) ?;
87+ let domain = futures_rustls:: rustls:: pki_types:: ServerName :: try_from ( directory_host)
88+ . map_err ( |_| Error :: BadDirectoryHost ) ?
89+ . to_owned ( ) ;
90+
91+ let ws = WebSocket :: open ( & format ! (
92+ "wss://{}:443" ,
93+ ohttp_relay. host_str( ) . ok_or( Error :: BadOhttpWsHost ) ?
94+ ) )
95+ . map_err ( |_| Error :: BadOhttpWsHost ) ?;
96+
97+ let mut tls_stream = tls_connector
98+ . connect ( domain, ws)
99+ . await
100+ . map_err ( |e| Error :: RequestFailed ( e. to_string ( ) ) ) ?;
101+ let ohttp_keys_req = format ! (
102+ "GET /ohttp-keys HTTP/1.1\r \n Host: {}\r \n Connection: close\r \n \r \n " ,
103+ directory_host
104+ ) ;
105+ tls_stream
106+ . write_all ( ohttp_keys_req. as_bytes ( ) )
107+ . await
108+ . map_err ( |e| Error :: RequestFailed ( e. to_string ( ) ) ) ?;
109+ tls_stream
110+ . flush ( )
111+ . await
112+ . map_err ( |e| Error :: RequestFailed ( e. to_string ( ) ) ) ?;
113+ let mut response_bytes = Vec :: new ( ) ;
114+ tls_stream
115+ . read_to_end ( & mut response_bytes)
116+ . await
117+ . map_err ( |e| Error :: RequestFailed ( e. to_string ( ) ) ) ?;
118+ let ( _headers, res_body) = separate_headers_and_body ( & response_bytes) ?;
119+ payjoin:: OhttpKeys :: decode ( res_body) . map_err ( |_| Error :: OhttpDecodeFailed )
120+ }
121+
122+ fn separate_headers_and_body ( response_bytes : & [ u8 ] ) -> Result < ( & [ u8 ] , & [ u8 ] ) , Error > {
123+ let separator = b"\r \n \r \n " ;
124+
125+ // Search for the separator
126+ if let Some ( position) = response_bytes
127+ . windows ( separator. len ( ) )
128+ . position ( |window| window == separator)
129+ {
130+ // The body starts immediately after the separator
131+ let body_start_index = position + separator. len ( ) ;
132+ let headers = & response_bytes[ ..position] ;
133+ let body = & response_bytes[ body_start_index..] ;
134+
135+ Ok ( ( headers, body) )
136+ } else {
137+ Err ( Error :: RequestFailed (
138+ "No header-body separator found in the response" . to_string ( ) ,
139+ ) )
140+ }
82141}
83142
84143#[ derive( Debug ) ]
@@ -89,6 +148,9 @@ pub enum Error {
89148 OhttpDecodeFailed ,
90149 Shutdown ,
91150 SessionExpired ,
151+ BadDirectoryHost ,
152+ BadOhttpWsHost ,
153+ RequestFailed ( String ) ,
92154}
93155
94156impl std:: error:: Error for Error { }
@@ -102,6 +164,9 @@ impl std::fmt::Display for Error {
102164 Error :: OhttpDecodeFailed => write ! ( f, "Failed to decode ohttp keys" ) ,
103165 Error :: Shutdown => write ! ( f, "Payjoin stopped by application shutdown" ) ,
104166 Error :: SessionExpired => write ! ( f, "Payjoin session expired. Create a new payment request and have the sender try again." ) ,
167+ Error :: BadDirectoryHost => write ! ( f, "Bad directory host" ) ,
168+ Error :: BadOhttpWsHost => write ! ( f, "Bad ohttp ws host" ) ,
169+ Error :: RequestFailed ( e) => write ! ( f, "Request failed: {}" , e) ,
105170 }
106171 }
107172}
0 commit comments