Skip to content

Commit 5d80128

Browse files
committed
fix(cursors/bytes): protect against empty chunks
1 parent 8c21f6a commit 5d80128

File tree

2 files changed

+34
-51
lines changed

2 files changed

+34
-51
lines changed

src/cursors/bytes.rs

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
use crate::{bytes_ext::BytesExt, cursors::RawCursor, error::Result, response::Response};
2-
use bytes::{Bytes, BytesMut};
1+
use crate::{cursors::RawCursor, error::Result, response::Response};
2+
use bytes::{Buf, Bytes, BytesMut};
33
use std::{
44
io::Result as IoResult,
55
pin::Pin,
@@ -33,7 +33,7 @@ use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf};
3333
/// [`Query::fetch_bytes`]: crate::query::Query::fetch_bytes
3434
pub struct BytesCursor {
3535
raw: RawCursor,
36-
bytes: BytesExt,
36+
bytes: Bytes,
3737
}
3838

3939
// TODO: what if any next/poll_* called AFTER error returned?
@@ -42,7 +42,7 @@ impl BytesCursor {
4242
pub(crate) fn new(response: Response) -> Self {
4343
Self {
4444
raw: RawCursor::new(response),
45-
bytes: BytesExt::default(),
45+
bytes: Bytes::default(),
4646
}
4747
}
4848

@@ -82,17 +82,19 @@ impl BytesCursor {
8282

8383
#[cold]
8484
fn poll_refill(&mut self, cx: &mut Context<'_>) -> Poll<IoResult<bool>> {
85-
debug_assert_eq!(self.bytes.remaining(), 0);
86-
87-
// TODO: should we repeat if `poll_next` returns an empty buffer?
88-
89-
match ready!(self.raw.poll_next(cx)?) {
90-
Some(chunk) => {
91-
self.bytes.extend(chunk);
92-
Poll::Ready(Ok(true))
85+
debug_assert_eq!(self.bytes.len(), 0);
86+
87+
// Theoretically, `self.raw.poll_next(cx)` can return empty chunks.
88+
// In this case, we should continue polling until we get a non-empty chunk or
89+
// end of stream in order to avoid false positive `Ok(0)` in I/O traits.
90+
while self.bytes.is_empty() {
91+
match ready!(self.raw.poll_next(cx)?) {
92+
Some(chunk) => self.bytes = chunk,
93+
None => return Poll::Ready(Ok(false)),
9394
}
94-
None => Poll::Ready(Ok(false)),
9595
}
96+
97+
Poll::Ready(Ok(true))
9698
}
9799

98100
/// Returns the total size in bytes received from the CH server since
@@ -125,9 +127,9 @@ impl AsyncRead for BytesCursor {
125127
break;
126128
}
127129

128-
let bytes = self.bytes.slice();
129-
let len = bytes.len().min(buf.remaining());
130-
buf.put_slice(&bytes[..len]);
130+
let len = self.bytes.len().min(buf.remaining());
131+
let bytes = self.bytes.slice(..len);
132+
buf.put_slice(&bytes[0..len]);
131133
self.bytes.advance(len);
132134
}
133135

@@ -138,17 +140,17 @@ impl AsyncRead for BytesCursor {
138140
impl AsyncBufRead for BytesCursor {
139141
#[inline]
140142
fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<&[u8]>> {
141-
if self.bytes.is_empty() && !ready!(self.poll_refill(cx)?) {
142-
return Poll::Ready(Ok(&[]));
143+
if self.bytes.is_empty() {
144+
ready!(self.poll_refill(cx)?);
143145
}
144146

145-
Poll::Ready(Ok(self.get_mut().bytes.slice()))
147+
Poll::Ready(Ok(&self.get_mut().bytes))
146148
}
147149

148150
#[inline]
149151
fn consume(mut self: Pin<&mut Self>, amt: usize) {
150152
assert!(
151-
amt <= self.bytes.remaining(),
153+
amt <= self.bytes.len(),
152154
"invalid `AsyncBufRead::consume` usage"
153155
);
154156
self.bytes.advance(amt);

src/response.rs

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ impl Chunks {
147147
fn new(stream: Incoming, compression: Compression) -> Self {
148148
let stream = IncomingStream(stream);
149149
let stream = Decompress::new(stream, compression);
150-
let stream = DetectDbException::new(stream);
150+
let stream = DetectDbException(stream);
151151
Self(Some(Box::new(stream)))
152152
}
153153
}
@@ -156,14 +156,12 @@ impl Stream for Chunks {
156156
type Item = Result<Chunk>;
157157

158158
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
159-
// `take()` prevents from use after caught panic.
159+
// We use `take()` to make the stream fused, including the case of panics.
160160
if let Some(mut stream) = self.0.take() {
161161
let res = Pin::new(&mut stream).poll_next(cx);
162162

163163
if matches!(res, Poll::Pending | Poll::Ready(Some(Ok(_)))) {
164164
self.0 = Some(stream);
165-
} else {
166-
assert!(self.0.is_none());
167165
}
168166

169167
res
@@ -244,16 +242,7 @@ where
244242

245243
// === DetectDbException ===
246244

247-
enum DetectDbException<S> {
248-
Stream(S),
249-
Exception(Option<Error>),
250-
}
251-
252-
impl<S> DetectDbException<S> {
253-
fn new(stream: S) -> Self {
254-
Self::Stream(stream)
255-
}
256-
}
245+
struct DetectDbException<S>(S);
257246

258247
impl<S> Stream for DetectDbException<S>
259248
where
@@ -262,30 +251,23 @@ where
262251
type Item = Result<Chunk>;
263252

264253
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
265-
match &mut *self {
266-
Self::Stream(stream) => {
267-
let mut res = Pin::new(stream).poll_next(cx);
268-
269-
if let Poll::Ready(Some(Ok(chunk))) = &mut res {
270-
if let Some(err) = extract_exception(&mut chunk.data) {
271-
*self = Self::Exception(Some(err));
254+
let res = Pin::new(&mut self.0).poll_next(cx);
272255

273-
// NOTE: `chunk` can be empty, but it's ok for callers.
274-
}
275-
}
276-
277-
res
256+
if let Poll::Ready(Some(Ok(chunk))) = &res {
257+
if let Some(err) = extract_exception(&chunk.data) {
258+
return Poll::Ready(Some(Err(err)));
278259
}
279-
Self::Exception(err) => Poll::Ready(err.take().map(Err)),
280260
}
261+
262+
res
281263
}
282264
}
283265

284266
// Format:
285267
// ```
286268
// <data>Code: <code>. DB::Exception: <desc> (version <version> (official build))\n
287269
// ```
288-
fn extract_exception(chunk: &mut Bytes) -> Option<Error> {
270+
fn extract_exception(chunk: &[u8]) -> Option<Error> {
289271
// `))\n` is very rare in real data, so it's fast dirty check.
290272
// In random data, it occurs with a probability of ~6*10^-8 only.
291273
if chunk.ends_with(b"))\n") {
@@ -297,14 +279,13 @@ fn extract_exception(chunk: &mut Bytes) -> Option<Error> {
297279

298280
#[cold]
299281
#[inline(never)]
300-
fn extract_exception_slow(chunk: &mut Bytes) -> Option<Error> {
282+
fn extract_exception_slow(chunk: &[u8]) -> Option<Error> {
301283
let index = chunk.rfind(b"Code:")?;
302284

303285
if !chunk[index..].contains_str(b"DB::Exception:") {
304286
return None;
305287
}
306288

307-
let exception = chunk.split_off(index);
308-
let exception = String::from_utf8_lossy(&exception[..exception.len() - 1]);
289+
let exception = String::from_utf8_lossy(&chunk[index..chunk.len() - 1]);
309290
Some(Error::BadResponse(exception.into()))
310291
}

0 commit comments

Comments
 (0)