1
1
use async_std:: io:: { BufReader , Read } ;
2
2
use async_std:: prelude:: * ;
3
- use http_types:: { ensure , ensure_eq , format_err } ;
3
+ use http_types:: content :: ContentLength ;
4
4
use http_types:: {
5
- headers:: { CONTENT_LENGTH , DATE , TRANSFER_ENCODING } ,
5
+ headers:: { DATE , TRANSFER_ENCODING } ,
6
6
Body , Response , StatusCode ,
7
7
} ;
8
8
9
9
use std:: convert:: TryFrom ;
10
10
11
- use crate :: chunked:: ChunkedDecoder ;
12
11
use crate :: date:: fmt_http_date;
12
+ use crate :: { chunked:: ChunkedDecoder , Error } ;
13
13
use crate :: { MAX_HEADERS , MAX_HEAD_LENGTH } ;
14
14
15
15
const CR : u8 = b'\r' ;
16
16
const LF : u8 = b'\n' ;
17
17
18
18
/// Decode an HTTP response on the client.
19
- pub async fn decode < R > ( reader : R ) -> http_types :: Result < Response >
19
+ pub async fn decode < R > ( reader : R ) -> crate :: Result < Option < Response > >
20
20
where
21
21
R : Read + Unpin + Send + Sync + ' static ,
22
22
{
@@ -29,13 +29,14 @@ where
29
29
loop {
30
30
let bytes_read = reader. read_until ( LF , & mut buf) . await ?;
31
31
// No more bytes are yielded from the stream.
32
- assert ! ( bytes_read != 0 , "Empty response" ) ; // TODO: ensure?
32
+ if bytes_read == 0 {
33
+ return Ok ( None ) ;
34
+ }
33
35
34
36
// Prevent CWE-400 DDOS with large HTTP Headers.
35
- ensure ! (
36
- buf. len( ) < MAX_HEAD_LENGTH ,
37
- "Head byte length should be less than 8kb"
38
- ) ;
37
+ if buf. len ( ) >= MAX_HEAD_LENGTH {
38
+ return Err ( Error :: HeadersTooLong ) ;
39
+ }
39
40
40
41
// We've hit the end delimiter of the stream.
41
42
let idx = buf. len ( ) - 1 ;
@@ -49,17 +50,23 @@ where
49
50
50
51
// Convert our header buf into an httparse instance, and validate.
51
52
let status = httparse_res. parse ( & buf) ?;
52
- ensure ! ( !status. is_partial( ) , "Malformed HTTP head" ) ;
53
+ if status. is_partial ( ) {
54
+ return Err ( Error :: PartialHead ) ;
55
+ }
53
56
54
- let code = httparse_res. code ;
55
- let code = code. ok_or_else ( || format_err ! ( "No status code found" ) ) ?;
57
+ let code = httparse_res. code . ok_or ( Error :: MissingStatusCode ) ?;
56
58
57
59
// Convert httparse headers + body into a `http_types::Response` type.
58
- let version = httparse_res. version ;
59
- let version = version. ok_or_else ( || format_err ! ( "No version found" ) ) ?;
60
- ensure_eq ! ( version, 1 , "Unsupported HTTP version" ) ;
60
+ let version = httparse_res. version . ok_or ( Error :: MissingVersion ) ?;
61
+
62
+ if version != 1 {
63
+ return Err ( Error :: UnsupportedVersion ( version) ) ;
64
+ }
65
+
66
+ let status_code =
67
+ StatusCode :: try_from ( code) . map_err ( |_| Error :: UnrecognizedStatusCode ( code) ) ?;
68
+ let mut res = Response :: new ( status_code) ;
61
69
62
- let mut res = Response :: new ( StatusCode :: try_from ( code) ?) ;
63
70
for header in httparse_res. headers . iter ( ) {
64
71
res. append_header ( header. name , std:: str:: from_utf8 ( header. value ) ?) ;
65
72
}
@@ -69,13 +76,13 @@ where
69
76
res. insert_header ( DATE , & format ! ( "date: {}\r \n " , date) [ ..] ) ;
70
77
}
71
78
72
- let content_length = res. header ( CONTENT_LENGTH ) ;
79
+ let content_length =
80
+ ContentLength :: from_headers ( & res) . map_err ( |_| Error :: MalformedHeader ( "content-length" ) ) ?;
73
81
let transfer_encoding = res. header ( TRANSFER_ENCODING ) ;
74
82
75
- ensure ! (
76
- content_length. is_none( ) || transfer_encoding. is_none( ) ,
77
- "Unexpected Content-Length header"
78
- ) ;
83
+ if content_length. is_some ( ) && transfer_encoding. is_some ( ) {
84
+ return Err ( Error :: UnexpectedHeader ( "content-length" ) ) ;
85
+ }
79
86
80
87
if let Some ( encoding) = transfer_encoding {
81
88
if encoding. last ( ) . as_str ( ) == "chunked" {
@@ -84,16 +91,16 @@ where
84
91
res. set_body ( Body :: from_reader ( reader, None ) ) ;
85
92
86
93
// Return the response.
87
- return Ok ( res) ;
94
+ return Ok ( Some ( res) ) ;
88
95
}
89
96
}
90
97
91
98
// Check for Content-Length.
92
- if let Some ( len ) = content_length {
93
- let len = len . last ( ) . as_str ( ) . parse :: < usize > ( ) ? ;
94
- res. set_body ( Body :: from_reader ( reader. take ( len as u64 ) , Some ( len) ) ) ;
99
+ if let Some ( content_length ) = content_length {
100
+ let len = content_length . len ( ) ;
101
+ res. set_body ( Body :: from_reader ( reader. take ( len) , Some ( len as usize ) ) ) ;
95
102
}
96
103
97
104
// Return the response.
98
- Ok ( res)
105
+ Ok ( Some ( res) )
99
106
}
0 commit comments