diff --git a/src/hostnet/hostnet_http.ml b/src/hostnet/hostnet_http.ml
index 3f1909bf1..f28e8e8d7 100644
--- a/src/hostnet/hostnet_http.ml
+++ b/src/hostnet/hostnet_http.ml
@@ -294,23 +294,149 @@ module Make
       b_t remote ~incoming ~outgoing
     ]
 
-  let rec proxy_body_request_exn ~reader ~writer =
-    let open Cohttp.Transfer in
-    Incoming.Request.read_body_chunk reader >>= function
-    | Done -> Lwt.return_unit
-    | Final_chunk x -> Outgoing.Request.write_body writer x
-    | Chunk x ->
-      Outgoing.Request.write_body writer x >>= fun () ->
-      proxy_body_request_exn ~reader ~writer
-
-  let rec proxy_body_response_exn ~reader ~writer =
-    let open Cohttp.Transfer in
-    Outgoing.Response.read_body_chunk reader >>= function
-    | Done -> Lwt.return_unit
-    | Final_chunk x -> Incoming.Response.write_body writer x
-    | Chunk x ->
-      Incoming.Response.write_body writer x >>= fun () ->
-      proxy_body_response_exn ~reader ~writer
+  (* Parse the hexadecimal size at the start of a chunk header line, ignoring
+     any chunk extensions after the first ';'. Returns [None] if the size is
+     not valid hex, or if it wraps to a negative [int64]: [Int64.of_string]
+     accepts hex literals above [Int64.max_int], and [copy_bytes] below would
+     never count a negative size down to zero.
+     Adapted from cohttp/transfer_io.ml *)
+  let parse_chunksize chunk_size_hex =
+    let hex =
+      (* chunk size is optionally delimited by ; *)
+      match String.cut ~sep:";" chunk_size_hex with
+      | None -> chunk_size_hex
+      | Some (chunk_size_hex, _extensions) -> chunk_size_hex in
+    match Int64.of_string ("0x" ^ hex) with
+    | size when Int64.compare size 0L >= 0 -> Some size
+    | _negative -> None
+    | exception Failure _ -> None
+
+  (* Copy an HTTP message body from one buffered channel to another without
+     decoding it: chunked bodies keep their original chunk sizes, extensions
+     and trailer headers. *)
+  module Copy_body(From: Mirage_channel_lwt.S)(To: Mirage_channel_lwt.S) = struct
+    (* [copy_exn ~description ~direction ~encoding ~incoming ~outgoing] copies
+       one body with the given [encoding] from [incoming] to [outgoing].
+       [description direction] is used only as a prefix for log messages.
+       The promise fails with [End_of_file] on premature EOF, on a chunk size
+       that cannot be parsed and on any other read or write error. *)
+    let copy_exn ~description ~direction ~encoding ~incoming ~outgoing =
+      Log.debug (fun f -> f "%s: copy_exn" (description direction));
+      let open Cohttp.Transfer in
+      let max_buffer_size = 0x8000L in (* same as cohttp *)
+      let rec copy_bytes remaining =
+        (* If we know the total amount of data then [remaining = Some x]. If [remaining = None]
+           then we copy up to EOF *)
+        Log.debug (fun f -> f "%s: copy_bytes remaining = %s" (description direction) (match remaining with None -> "None" | Some x -> Int64.to_string x));
+        if remaining = Some 0L then begin
+          Log.debug (fun f -> f "%s: copy_bytes complete" (description direction));
+          Lwt.return_unit
+        end else begin
+          let len = match remaining with None -> max_buffer_size | Some x -> min max_buffer_size x in
+          From.read_some incoming ~len:(Int64.to_int len)
+          >>= function
+          | Ok `Eof ->
+            if remaining = None then begin
+              (* EOF is the end of the data *)
+              Log.debug (fun f -> f "%s: copy_bytes/read_some hit EOF" (description direction));
+              Lwt.return_unit
+            end else begin
+              Log.warn (fun f -> f "%s: copy_bytes/read_some encountered premature EOF" (description direction));
+              Lwt.fail End_of_file
+            end
+          | Error e ->
+            Log.warn (fun f -> f "%s: copy_bytes/read_some failed with %a" (description direction) From.pp_error e);
+            Lwt.fail End_of_file
+          | Ok (`Data buf) ->
+            Log.debug (fun f -> f "%s: copy_bytes/read_some read %d bytes" (description direction) (Cstruct.len buf));
+            To.write_buffer outgoing buf;
+            To.flush outgoing >>= function
+            | Ok () ->
+              Log.debug (fun f -> f "%s: copy_bytes/flush wrote %d bytes" (description direction) (Cstruct.len buf));
+              copy_bytes (match remaining with None -> None | Some x -> Some(Int64.sub x (Int64.of_int @@ Cstruct.len buf)))
+            | Error `Closed ->
+              Log.warn (fun f -> f "%s: copy_bytes/flush encountered premature EOF" (description direction));
+              Lwt.fail End_of_file
+            | Error e ->
+              Log.warn (fun f -> f "%s: copy_bytes/flush failed with %a" (description direction) To.pp_write_error e);
+              Lwt.fail End_of_file
+        end in
+      match encoding with
+      | Fixed len ->
+        copy_bytes (Some len)
+      | Unknown ->
+        copy_bytes None
+      | Chunked ->
+        let rec copy_chunks () =
+          From.read_line incoming
+          >>= function
+          | Ok `Eof ->
+            Log.warn (fun f -> f "%s: copy_chunks/read_line encountered premature EOF" (description direction));
+            Lwt.fail End_of_file
+          | Error e ->
+            Log.warn (fun f -> f "%s: copy_chunks/read_line failed with %a" (description direction) From.pp_error e);
+            Lwt.fail End_of_file
+          | Ok (`Data bufs) ->
+            let size_and_parameters = Cstruct.copyv bufs in
+            Log.debug (fun f -> f "%s: copy_chunks size = '%s'" (description direction) size_and_parameters);
+            begin match parse_chunksize size_and_parameters with
+            | None ->
+              Log.warn (fun f -> f "%s: copy_chunks failed to parse chunk size: %s" (description direction) size_and_parameters);
+              Lwt.fail End_of_file
+            | Some 0L ->
+              Log.debug (fun f -> f "%s: copy_chunks length = 0" (description direction));
+              List.iter (To.write_buffer outgoing) bufs;
+              (* The line from [read_line] carries no terminator, so restore
+                 the CRLF after the last-chunk line just as the non-zero case
+                 does. Otherwise the "0" runs straight into the first trailer
+                 line and the body ends one CRLF short. *)
+              To.write_line outgoing "\r";
+              let rec trailer_headers () =
+                From.read_line incoming
+                >>= function
+                | Ok `Eof ->
+                  Log.warn (fun f -> f "%s: trailer_headers/read_line encountered premature EOF" (description direction));
+                  Lwt.fail End_of_file
+                | Error e ->
+                  Log.warn (fun f -> f "%s: trailer_headers/read_line failed with %a" (description direction) From.pp_error e);
+                  Lwt.fail End_of_file
+                | Ok (`Data bufs) ->
+                  Log.debug (fun f -> f "%s: trailer_headers/read_line succeed with %s" (description direction) (Cstruct.copyv bufs));
+                  List.iter (To.write_buffer outgoing) bufs;
+                  To.write_line outgoing "\r";
+                  To.flush outgoing
+                  >>= function
+                  | Error `Closed ->
+                    Log.warn (fun f -> f "%s: trailer_headers/flush encountered premature EOF" (description direction));
+                    Lwt.fail End_of_file
+                  | Error e ->
+                    Log.warn (fun f -> f "%s: trailer_headers/flush failed with %a" (description direction) To.pp_write_error e);
+                    Lwt.fail End_of_file
+                  | Ok () ->
+                    Log.debug (fun f -> f "%s: trailer_headers/flush wrote %d bytes" (description direction) (List.fold_left (+) 2 (List.map Cstruct.len bufs)));
+                    if Cstruct.copyv bufs = "" then begin
+                      Log.debug (fun f -> f "%s: trailer_headers complete" (description direction));
+                      Lwt.return_unit (* end of headers *)
+                    end else trailer_headers () in
+              trailer_headers ()
+            | Some count ->
+              (* chunk size (and original parameters) *)
+              List.iter (To.write_buffer outgoing) bufs;
+              To.write_line outgoing "\r";
+              Log.debug (fun f -> f "%s: copy_chunks queued %d bytes" (description direction) (List.fold_left (+) 2 (List.map Cstruct.len bufs)));
+              (* chunk data *)
+              copy_bytes (Some count)
+              >>= fun () ->
+              (* CRLF at the end of the chunk *)
+              copy_bytes (Some 2L)
+              >>= fun () ->
+              copy_chunks ()
+            end in
+        copy_chunks ()
+  end
+
+  module Copy_request_body = Copy_body(Incoming.C)(Outgoing.C)
+  module Copy_response_body = Copy_body(Outgoing.C)(Incoming.C)
 
   (* Take a request and a pair (incoming, outgoing) of channels, send
      the request to the outgoing channel and then proxy back any response.
@@ -319,11 +445,9 @@ module Make
     (* Cohttp can fail promises so we catch them here *)
     Lwt.catch
       (fun () ->
-        let reader = Incoming.Request.make_body_reader req incoming in
-        Log.info (fun f -> f "Outgoing.Request.write");
-        Outgoing.Request.write ~flush:true (fun writer ->
+        Outgoing.Request.write ~flush:true (fun _writer ->
             match Incoming.Request.has_body req with
-            | `Yes -> proxy_body_request_exn ~reader ~writer
+            | `Yes -> Copy_request_body.copy_exn ~description ~direction:true ~encoding:req.encoding ~incoming ~outgoing
             | `No -> Lwt.return_unit
             | `Unknown ->
               Log.warn (fun f ->
@@ -332,8 +456,6 @@ module Make
               Lwt.return_unit
           ) req outgoing
         >>= fun () ->
-        Log.info (fun f -> f "Outgoing.Response.read");
-
         Outgoing.Response.read outgoing >>= function
         | `Eof ->
           Log.warn (fun f -> f "%s: EOF" (description false));
@@ -371,8 +493,7 @@ module Make
             Lwt.return false
           | _, _ ->
             (* Otherwise stay in HTTP mode *)
-            let reader = Outgoing.Response.make_body_reader res outgoing in
-            Incoming.Response.write ~flush:true (fun writer ->
+            Incoming.Response.write ~flush:true (fun _writer ->
                 match Cohttp.Request.meth req, Incoming.Response.has_body res with
                 | `HEAD, `Yes ->
                   (* Bug in cohttp.1.0.2: according to Section 9.4 of RFC2616
@@ -383,17 +504,17 @@ module Make
                   Log.debug (fun f -> f "%s: HEAD requests MUST NOT have response bodies" (description false));
                   Lwt.return_unit
                 | _, `Yes ->
-                  Log.info (fun f -> f "%s: proxying body" (description false));
-                  proxy_body_response_exn ~reader ~writer
+                  Log.debug (fun f -> f "%s: proxying body" (description false));
+                  Copy_response_body.copy_exn ~description ~direction:false ~encoding:res.encoding ~incoming:outgoing ~outgoing:incoming
                   >>= fun () ->
                   Lwt.return_unit
                 | _, `No ->
-                  Log.info (fun f -> f "%s: no body to proxy" (description false));
+                  Log.debug (fun f -> f "%s: no body to proxy" (description false));
                   Lwt.return_unit
                 | _, `Unknown when connection_close ->
                   (* There may be a body between here and the EOF *)
-                  Log.info (fun f -> f "%s: proxying until EOF" (description false));
-                  proxy_body_response_exn ~reader ~writer
+                  Log.debug (fun f -> f "%s: proxying until EOF" (description false));
+                  Copy_response_body.copy_exn ~description ~direction:false ~encoding:res.encoding ~incoming:outgoing ~outgoing:incoming
                 | _, `Unknown ->
                   Log.warn (fun f ->
                       f "Response.has_body returned `Unknown: not sure \
@@ -667,7 +788,7 @@ module Make
           | `Proxy, `CONNECT -> host_and_port
           | `Proxy, _ -> Uri.with_scheme (Uri.with_host (Uri.with_port uri (Some port)) (Some host)) (Some "http") |> Uri.to_string in
         let req = { req with Cohttp.Request.headers; resource } in
-        Log.debug (fun f -> f "%s: sending %s"
+        Log.info (fun f -> f "%s: sending %s"
                       (description false)
                       (Sexplib.Sexp.to_string_hum
                          (Cohttp.Request.sexp_of_t req))
diff --git a/src/hostnet_test/test_http.ml b/src/hostnet_test/test_http.ml
index 47cfeaad5..2665e81c5 100644
--- a/src/hostnet_test/test_http.ml
+++ b/src/hostnet_test/test_http.ml
@@ -889,6 +889,142 @@ let test_http_connect_tunnel proxy () =
     )
   end
 
+let test_http_proxy_chunked () =
+  Host.Main.run begin
+    let results, results_u = Lwt.task () in
+    let payloads = ref [] in
+    Slirp_stack.with_stack ~pcap:"test_http_proxy_chunked.pcap" (fun _ stack ->
+      with_server (fun flow ->
+        (* Expect 2 requests: one chunked and one with a fixed content-length *)
+        (* Note the proxy sets connection: close on external requests *)
+        let ic = Incoming.C.create flow in
+        let read_one () =
+          Incoming.Request.read ic >>= function
+          | `Eof ->
+            Log.err (fun f -> f "EOF reading request");
+            failwith "EOF reading request"
+          | `Invalid x ->
+            Log.err (fun f -> f "Failed to parse request: %s" x);
+            failwith ("Failed to parse request: " ^ x)
+          | `Ok req ->
+            Log.info (fun f -> f "HTTP server received %s" (Sexplib.Sexp.to_string_hum (Cohttp.Request.sexp_of_t req)));
+            let reader = Incoming.Request.make_body_reader req ic in
+            let rec read_body acc =
+              Incoming.Request.read_body_chunk reader >>= function
+              | Done ->
+                Log.info (fun f -> f "Chunk done");
+                Lwt.return acc
+              | Final_chunk x ->
+                Log.info (fun f -> f "Final_chunk '%s'" x);
+                Lwt.return (acc ^ x)
+              | Chunk x ->
+                Log.info (fun f -> f "Chunk [%s]" (String.escaped x));
+                read_body (acc ^ x) in
+            Lwt.catch
+              (fun () ->
+                Log.info (fun f -> f "HTTP server reading request body");
+                read_body ""
+              )
+              (fun e -> Log.err (fun f -> f "read_body caught %s" (Printexc.to_string e)); Lwt.fail e)
+        in
+        let write_ok () =
+          Log.info (fun f -> f "HTTP server responding with 200 OK");
+          let headers =
+            let h = Cohttp.Header.init () in
+            Cohttp.Header.add_list h [
+              "content-length", "0";
+              "connection", "keep-alive";
+            ] in
+          let response = Cohttp.Response.make ~flush:true ~headers ~status:`OK () in
+          Incoming.Response.write ~flush:true (fun _writer -> Lwt.return_unit) response ic in
+        read_one ()
+        >>= fun one ->
+        payloads := one :: !payloads;
+        if List.length !payloads = 2 then Lwt.wakeup_later results_u (List.rev !payloads);
+        Log.info (fun f -> f "got: %s" (String.escaped one));
+        write_ok ()
+      ) (fun server ->
+        Lwt.catch (fun () ->
+          let host = "localhost" in
+          let port = server.Server.port in
+          let open Slirp_stack in
+          let _with_connection f =
+            Client.TCPV4.create_connection (Client.tcpv4 stack.t) (primary_dns_ip, 3128)
+            >>= function
+            | Error _ ->
+              Log.err (fun f -> f "Failed to connect to %s:3128" (Ipaddr.V4.to_string primary_dns_ip));
+              failwith "test_proxy_get: connect failed"
+            | Ok flow ->
+              Log.info (fun f -> f "Connected to %s:3128" (Ipaddr.V4.to_string primary_dns_ip));
+              let oc = Outgoing.C.create flow in
+              Lwt.finalize
+                (fun () -> f oc)
+                (fun () -> Client.TCPV4.close flow) in
+          Client.TCPV4.create_connection (Client.tcpv4 stack.t) (primary_dns_ip, 3128)
+          >>= function
+          | Error _ ->
+            Log.err (fun f -> f "Failed to connect to %s:3128" (Ipaddr.V4.to_string primary_dns_ip));
+            failwith "test_proxy_get: connect failed"
+          | Ok flow ->
+            Log.info (fun f -> f "Connected to %s:3128" (Ipaddr.V4.to_string primary_dns_ip));
+            let oc = Outgoing.C.create flow in
+            let read_ok () =
+              Outgoing.Response.read oc
+              >>= function
+              | `Ok res ->
+                Log.info (fun f -> f "client received %s" (Sexplib.Sexp.to_string_hum @@ Cohttp.Response.sexp_of_t res));
+                Lwt.return_unit
+              | _ -> failwith "Failed to read response to chunked request" in
+            let request = Cohttp.Request.make ~meth:`POST (Uri.make ~host ~port ()) in
+            let headers =
+              Cohttp.Header.(add_list @@ init ()) [
+                "host", "localhost:" ^ (string_of_int port);
+              ] in
+            Log.info (fun f -> f "sending one");
+            Outgoing.Request.write ~flush:true
+              (fun _writer ->
+                (* Example from https://en.wikipedia.org/wiki/Chunked_transfer_encoding *)
+                let example = "4\r\nWiki\r\n5\r\npedia\r\nE\r\n in\r\n\r\nchunks.\r\n0\r\n\r\n" in
+                let _example = "0\r\n\r\n" in
+                Outgoing.C.write_string oc example 0 (String.length example);
+                Outgoing.C.flush oc
+                >>= fun _ ->
+                Lwt.return_unit
+              ) { request with Cohttp.Request.headers = Cohttp.Header.add_list headers [
+                "transfer-encoding", "chunked";
+                "connection", "keep-alive" ] } oc
+            >>= fun () ->
+            read_ok ()
+            >>= fun () ->
+            Log.info (fun f -> f "sending two");
+
+            Outgoing.Request.write ~flush:true
+              (fun _writer ->
+                Outgoing.C.write_string oc "hello" 0 5;
+                Outgoing.C.flush oc
+                >>= fun _ ->
+                Lwt.return_unit
+              ) { request with Cohttp.Request.headers = Cohttp.Header.add_list headers [
+                "content-length", "5";
+                "connection", "keep-alive" ] } oc
+            >>= fun () ->
+            read_ok ()
+            >>= fun () ->
+            Client.TCPV4.close flow
+            >>= fun () ->
+            results
+            >>= fun result ->
+            let expected = [
+              "Wikipedia in\r\n\r\nchunks.";
+              "hello";
+            ] in
+            Alcotest.check Alcotest.(list string) "body" expected result;
+            Lwt.return_unit
+        ) (fun e -> Log.err (fun f -> f "HTTP client raised %s" (Printexc.to_string e)); Lwt.fail e)
+      )
+    )
+  end
+
 let test_http_proxy_head () =
   Host.Main.run begin
 
@@ -1043,6 +1179,9 @@ let tests = [
 
   "HTTP proxy: respect HTTP/1.0 implicit connection: close",
   [ "check that the transparent proxy will respect HTTP/1.0 implicit connection: close headers from origin servers", `Quick, test_connection_close true ];
+  "HTTP proxy: check transfer-encodings",
+  [ "check that the proxy understands transfer encodings", `Quick, test_http_proxy_chunked ];
+
 ] @ (List.map (fun name ->
     "HTTP proxy: GET to localhost",
     [ "check that HTTP GET to localhost via hostname", `Quick, test_http_proxy_localhost (Dns.Name.to_string name) ]