Skip to content

Commit dabe27d

Browse files
committed
use a concrete error type in async-h1
1 parent e513687 commit dabe27d

11 files changed

+189
-95
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ log = "0.4.11"
2828
pin-project = "1.0.2"
2929
async-channel = "1.5.1"
3030
async-dup = "1.2.2"
31+
thiserror = "1.0.22"
3132

3233
[dev-dependencies]
3334
pretty_assertions = "0.6.1"

src/client/decode.rs

+33-26
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
use async_std::io::{BufReader, Read};
22
use async_std::prelude::*;
3-
use http_types::{ensure, ensure_eq, format_err};
3+
use http_types::content::ContentLength;
44
use http_types::{
5-
headers::{CONTENT_LENGTH, DATE, TRANSFER_ENCODING},
5+
headers::{DATE, TRANSFER_ENCODING},
66
Body, Response, StatusCode,
77
};
88

99
use std::convert::TryFrom;
1010

11-
use crate::chunked::ChunkedDecoder;
1211
use crate::date::fmt_http_date;
12+
use crate::{chunked::ChunkedDecoder, Error};
1313
use crate::{MAX_HEADERS, MAX_HEAD_LENGTH};
1414

1515
const CR: u8 = b'\r';
1616
const LF: u8 = b'\n';
1717

1818
/// Decode an HTTP response on the client.
19-
pub async fn decode<R>(reader: R) -> http_types::Result<Response>
19+
pub async fn decode<R>(reader: R) -> crate::Result<Option<Response>>
2020
where
2121
R: Read + Unpin + Send + Sync + 'static,
2222
{
@@ -29,13 +29,14 @@ where
2929
loop {
3030
let bytes_read = reader.read_until(LF, &mut buf).await?;
3131
// No more bytes are yielded from the stream.
32-
assert!(bytes_read != 0, "Empty response"); // TODO: ensure?
32+
if bytes_read == 0 {
33+
return Ok(None);
34+
}
3335

3436
// Prevent CWE-400 DDOS with large HTTP Headers.
35-
ensure!(
36-
buf.len() < MAX_HEAD_LENGTH,
37-
"Head byte length should be less than 8kb"
38-
);
37+
if buf.len() >= MAX_HEAD_LENGTH {
38+
return Err(Error::HeadersTooLong);
39+
}
3940

4041
// We've hit the end delimiter of the stream.
4142
let idx = buf.len() - 1;
@@ -49,17 +50,23 @@ where
4950

5051
// Convert our header buf into an httparse instance, and validate.
5152
let status = httparse_res.parse(&buf)?;
52-
ensure!(!status.is_partial(), "Malformed HTTP head");
53+
if status.is_partial() {
54+
return Err(Error::PartialHead);
55+
}
5356

54-
let code = httparse_res.code;
55-
let code = code.ok_or_else(|| format_err!("No status code found"))?;
57+
let code = httparse_res.code.ok_or(Error::MissingStatusCode)?;
5658

5759
// Convert httparse headers + body into a `http_types::Response` type.
58-
let version = httparse_res.version;
59-
let version = version.ok_or_else(|| format_err!("No version found"))?;
60-
ensure_eq!(version, 1, "Unsupported HTTP version");
60+
let version = httparse_res.version.ok_or(Error::MissingVersion)?;
61+
62+
if version != 1 {
63+
return Err(Error::UnsupportedVersion(version));
64+
}
65+
66+
let status_code =
67+
StatusCode::try_from(code).map_err(|_| Error::UnrecognizedStatusCode(code))?;
68+
let mut res = Response::new(status_code);
6169

62-
let mut res = Response::new(StatusCode::try_from(code)?);
6370
for header in httparse_res.headers.iter() {
6471
res.append_header(header.name, std::str::from_utf8(header.value)?);
6572
}
@@ -69,13 +76,13 @@ where
6976
res.insert_header(DATE, &format!("date: {}\r\n", date)[..]);
7077
}
7178

72-
let content_length = res.header(CONTENT_LENGTH);
79+
let content_length =
80+
ContentLength::from_headers(&res).map_err(|_| Error::MalformedHeader("content-length"))?;
7381
let transfer_encoding = res.header(TRANSFER_ENCODING);
7482

75-
ensure!(
76-
content_length.is_none() || transfer_encoding.is_none(),
77-
"Unexpected Content-Length header"
78-
);
83+
if content_length.is_some() && transfer_encoding.is_some() {
84+
return Err(Error::UnexpectedHeader("content-length"));
85+
}
7986

8087
if let Some(encoding) = transfer_encoding {
8188
if encoding.last().as_str() == "chunked" {
@@ -84,16 +91,16 @@ where
8491
res.set_body(Body::from_reader(reader, None));
8592

8693
// Return the response.
87-
return Ok(res);
94+
return Ok(Some(res));
8895
}
8996
}
9097

9198
// Check for Content-Length.
92-
if let Some(len) = content_length {
93-
let len = len.last().as_str().parse::<usize>()?;
94-
res.set_body(Body::from_reader(reader.take(len as u64), Some(len)));
99+
if let Some(content_length) = content_length {
100+
let len = content_length.len();
101+
res.set_body(Body::from_reader(reader.take(len), Some(len as usize)));
95102
}
96103

97104
// Return the response.
98-
Ok(res)
105+
Ok(Some(res))
99106
}

src/client/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ pub use decode::decode;
1010
pub use encode::Encoder;
1111

1212
/// Opens an HTTP/1.1 connection to a remote host.
13-
pub async fn connect<RW>(mut stream: RW, req: Request) -> http_types::Result<Response>
13+
pub async fn connect<RW>(mut stream: RW, req: Request) -> crate::Result<Option<Response>>
1414
where
1515
RW: Read + Write + Send + Sync + Unpin + 'static,
1616
{

src/error.rs

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
use std::str::Utf8Error;
2+
3+
use http_types::url;
4+
use thiserror::Error;
5+
6+
/// Concrete errors that occur within async-h1
7+
#[derive(Error, Debug)]
8+
#[non_exhaustive]
9+
pub enum Error {
10+
/// [`std::io::Error`]
11+
#[error(transparent)]
12+
IO(#[from] std::io::Error),
13+
14+
/// [`url::ParseError`]
15+
#[error(transparent)]
16+
Url(#[from] url::ParseError),
17+
18+
/// this error describes a malformed request with a path that does
19+
/// not start with / or http:// or https://
20+
#[error("unexpected uri format")]
21+
UnexpectedURIFormat,
22+
23+
/// this error describes a http 1.1 request that is missing a Host
24+
/// header
25+
#[error("mandatory host header missing")]
26+
HostHeaderMissing,
27+
28+
/// this error describes a request that does not specify a path
29+
#[error("request path missing")]
30+
RequestPathMissing,
31+
32+
/// [`httparse::Error`]
33+
#[error(transparent)]
34+
Httparse(#[from] httparse::Error),
35+
36+
/// an incomplete http head
37+
#[error("partial http head")]
38+
PartialHead,
39+
40+
/// we were unable to parse a header
41+
#[error("malformed http header {0}")]
42+
MalformedHeader(&'static str),
43+
44+
/// async-h1 doesn't speak this http version
45+
/// this error is deprecated
46+
#[error("unsupported http version 1.{0}")]
47+
UnsupportedVersion(u8),
48+
49+
/// we were unable to parse this http method
50+
#[error("unsupported http method {0}")]
51+
UnrecognizedMethod(String),
52+
53+
/// this request did not have a method
54+
#[error("missing method")]
55+
MissingMethod,
56+
57+
/// this request did not have a status code
58+
#[error("missing status code")]
59+
MissingStatusCode,
60+
61+
/// we were unable to parse this http method
62+
#[error("unrecognized http status code {0}")]
63+
UnrecognizedStatusCode(u16),
64+
65+
/// this request did not have a version, but we expect one
66+
/// this error is deprecated
67+
#[error("missing version")]
68+
MissingVersion,
69+
70+
/// we expected utf8, but there was an encoding error
71+
#[error(transparent)]
72+
EncodingError(#[from] Utf8Error),
73+
74+
/// we received a header that does not make sense in context
75+
#[error("unexpected header: {0}")]
76+
UnexpectedHeader(&'static str),
77+
78+
/// for security reasons, we do not allow request headers beyond 8kb.
79+
#[error("Head byte length should be less than 8kb")]
80+
HeadersTooLong,
81+
}
82+
83+
/// this crate's result type
84+
pub type Result<T> = std::result::Result<T, Error>;

src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ use async_std::io::Cursor;
116116
use body_encoder::BodyEncoder;
117117
pub use client::connect;
118118
pub use server::{accept, accept_with_opts, ServerOptions};
119+
mod error;
120+
pub use error::{Error, Result};
119121

120122
#[derive(Debug)]
121123
pub(crate) enum EncoderState {

0 commit comments

Comments
 (0)