Skip to content

Commit 656d94b

Browse files
authored
nsurlsession fix deadlock caused by race conditions in InputStream (#45)
1 parent d6933c4 commit 656d94b

3 files changed

Lines changed: 176 additions & 144 deletions

File tree

backends/nsurlsession/src/stream.rs

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
1-
#[cfg(any(feature = "async-stream", feature = "blocking-stream"))]
2-
mod ns_stream;
3-
#[cfg(any(feature = "async-stream", feature = "blocking-stream"))]
4-
mod writer;
5-
#[cfg(not(any(feature = "async-stream", feature = "blocking-stream")))]
6-
#[path = "stream/dummy_writer.rs"]
7-
mod writer;
8-
9-
#[cfg(any(feature = "async-stream", feature = "blocking-stream"))]
10-
pub(crate) use ns_stream::InputStream;
111
pub(crate) use writer::StreamWriter;
122

3+
cfg_if::cfg_if! {
4+
if #[cfg(any(feature = "async-stream", feature = "blocking-stream"))] {
5+
mod ns_stream;
6+
mod writer;
7+
8+
pub(crate) use ns_stream::InputStream;
9+
10+
#[cfg(target_os = "macos")]
11+
const STREAM_BUFFER_SIZE: usize = 1024 * 32;
12+
#[cfg(not(target_os = "macos"))]
13+
const STREAM_BUFFER_SIZE: usize = 1024 * 8;
14+
} else {
15+
#[path = "stream/dummy_writer.rs"]
16+
mod writer;
17+
}
18+
}
19+
1320
pub enum DataOrStream<S> {
1421
Data(Vec<u8>),
1522
Stream(S),

backends/nsurlsession/src/stream/ns_stream.rs

Lines changed: 67 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#![allow(non_snake_case)]
22

3-
use std::io::{self, Cursor};
4-
use std::ops::ControlFlow;
3+
use std::collections::VecDeque;
54
use std::ptr::{null_mut, NonNull};
65
use std::sync::atomic::{AtomicBool, Ordering};
76
use std::sync::Mutex;
@@ -13,23 +12,19 @@ use objc2::runtime::{AnyObject, ProtocolObject};
1312
use objc2::{define_class, msg_send, AllocAnyThread, ClassType, DefinedClass, Message as _};
1413
use objc2_foundation::{
1514
NSArray, NSError, NSInputStream, NSInteger, NSObjectProtocol, NSRunLoop, NSRunLoopMode,
16-
NSStreamDelegate, NSStreamEvent, NSStreamPropertyKey, NSStreamStatus, NSString, NSUInteger,
15+
NSStreamDelegate, NSStreamEvent, NSStreamPropertyKey, NSStreamStatus, NSUInteger,
1716
};
1817

1918
use crate::datatask::GenericWaker;
2019
use crate::retained_ext::SwappableRetained;
21-
22-
#[cfg(target_os = "macos")]
23-
const STREAM_BUFFER_SIZE: usize = 1024 * 64;
24-
#[cfg(not(target_os = "macos"))]
25-
const STREAM_BUFFER_SIZE: usize = 1024 * 16;
20+
use crate::stream::STREAM_BUFFER_SIZE;
2621

2722
pub(crate) struct InputStreamIvars {
2823
waker: GenericWaker,
2924
delegate: ArcSwapAny<Option<SwappableRetained<ProtocolObject<dyn NSStreamDelegate>>>>,
3025
run_loop: ArcSwapAny<Option<SwappableRetained<NSRunLoop>>>,
3126
run_loop_mode: ArcSwapAny<Option<SwappableRetained<NSRunLoopMode>>>,
32-
stream_buffer: Mutex<Result<Cursor<Vec<u8>>, Retained<NSError>>>,
27+
stream_buffer: Mutex<Result<VecDeque<u8>, Retained<NSError>>>,
3328
is_open: AtomicBool,
3429
eof: AtomicBool,
3530
}
@@ -128,65 +123,41 @@ impl InputStream {
128123
delegate: ArcSwapAny::new(None),
129124
run_loop: ArcSwapAny::new(None),
130125
run_loop_mode: ArcSwapAny::new(None),
131-
stream_buffer: Mutex::new(Ok(Cursor::new(vec![0; STREAM_BUFFER_SIZE]))),
126+
stream_buffer: Mutex::new(Ok(VecDeque::with_capacity(STREAM_BUFFER_SIZE))),
132127
is_open: AtomicBool::new(false),
133128
eof: AtomicBool::new(false),
134129
});
135130

136131
unsafe { msg_send![super(this), init] }
137132
}
138-
#[cfg(any(feature = "async-stream", feature = "blocking-stream"))]
139-
pub(crate) fn update_buffer(
140-
&self,
141-
cb: impl FnOnce(&mut [u8]) -> ControlFlow<(), io::Result<usize>>,
142-
) -> io::Result<()> {
133+
pub(crate) fn write(&self, read_result: Result<&[u8], Retained<NSError>>) -> usize {
143134
let ivars = self.ivars();
144-
if ivars.eof.load(Ordering::SeqCst) {
145-
return Ok(());
146-
}
135+
let data = match read_result {
136+
Ok(data) => data,
137+
Err(e) => {
138+
let mut stream_buffer = ivars.stream_buffer.lock().unwrap();
139+
*stream_buffer = Err(e);
140+
drop(stream_buffer);
141+
self.notify_stream_state();
142+
return 0;
143+
}
144+
};
147145
let mut stream_buffer = ivars.stream_buffer.lock().unwrap();
148-
let Ok(cursor) = &mut *stream_buffer else {
149-
return Ok(());
146+
let Ok(stream_buffer) = stream_buffer.as_mut() else {
147+
return 0;
150148
};
151-
152-
let pos = cursor.position() as usize;
153-
let buffer = &mut cursor.get_mut()[pos..];
154-
if !buffer.is_empty() {
155-
let read_res = cb(buffer);
156-
match read_res {
157-
ControlFlow::Break(()) => {
158-
return Ok(());
159-
}
160-
ControlFlow::Continue(Ok(0)) => {
161-
ivars.eof.store(true, Ordering::SeqCst);
162-
if pos == 0 {
163-
self.notify_stream_state(NSStreamEvent::EndEncountered);
164-
return Ok(());
165-
}
166-
}
167-
ControlFlow::Continue(Ok(read_len)) => {
168-
cursor.set_position((pos + read_len) as u64);
169-
}
170-
ControlFlow::Continue(Err(e)) => {
171-
let ns_err = NSError::new(
172-
e.raw_os_error().unwrap_or_default() as _,
173-
&NSString::from_str(&e.to_string()),
174-
);
175-
*stream_buffer = Err(ns_err);
176-
// Release the lock before notifying delegate
177-
drop(stream_buffer);
178-
self.notify_stream_state(NSStreamEvent::ErrorOccurred);
179-
return Err(e);
180-
}
149+
if data.is_empty() {
150+
if !ivars.eof.swap(true, Ordering::SeqCst) && stream_buffer.len() == 0 {
151+
self.notify_stream_state();
181152
}
153+
return 0;
182154
}
183-
184-
if cursor.position() > 0 {
185-
// Release the lock before notifying delegate
186-
drop(stream_buffer);
187-
self.notify_stream_state(NSStreamEvent::HasBytesAvailable);
188-
}
189-
Ok(())
155+
let to_write = data
156+
.len()
157+
.min(stream_buffer.capacity() - stream_buffer.len());
158+
stream_buffer.extend(&data[..to_write]);
159+
self.notify_stream_state();
160+
to_write
190161
}
191162

192163
#[cfg(any(feature = "async-stream", feature = "blocking-stream"))]
@@ -195,7 +166,7 @@ impl InputStream {
195166
}
196167

197168
#[cfg(any(feature = "async-stream", feature = "blocking-stream"))]
198-
fn notify_stream_state(&self, event: NSStreamEvent) {
169+
fn notify_stream_state(&self) {
199170
let ivars = self.ivars();
200171
let Some(delegate) = ivars.delegate.load_full() else {
201172
return;
@@ -207,11 +178,24 @@ impl InputStream {
207178
return;
208179
};
209180
unsafe {
210-
let stream = self.as_super().retain();
181+
let this = self.retain();
211182
run_loop.performInModes_block(
212183
&NSArray::from_retained_slice(std::slice::from_ref(&*run_loop_mode)),
213184
&RcBlock::new(move || {
185+
let event = {
186+
let ivars = this.ivars();
187+
let stream_buffer = ivars.stream_buffer.lock().unwrap();
188+
match &*stream_buffer {
189+
Ok(buffer) if buffer.len() > 0 => NSStreamEvent::HasBytesAvailable,
190+
Ok(_) if ivars.eof.load(Ordering::SeqCst) => {
191+
NSStreamEvent::EndEncountered
192+
}
193+
Ok(_) => return,
194+
Err(_) => NSStreamEvent::ErrorOccurred,
195+
}
196+
};
214197
let delegate: &ProtocolObject<dyn NSStreamDelegate> = &*delegate;
198+
let stream = this.as_super();
215199
delegate.stream_handleEvent(&stream, event);
216200
}),
217201
);
@@ -229,8 +213,7 @@ impl InputStream {
229213
let Ok(stream_buffer) = stream_buffer.as_mut() else {
230214
return;
231215
};
232-
stream_buffer.set_position(0);
233-
*stream_buffer.get_mut() = vec![];
216+
*stream_buffer = Default::default();
234217
}
235218
fn callback_setDelegate(&self, delegate: Option<&ProtocolObject<dyn NSStreamDelegate>>) {
236219
let delegate = delegate.map(|d| d.retain().into());
@@ -251,25 +234,17 @@ impl InputStream {
251234
if !ivars.is_open.load(Ordering::Relaxed) {
252235
return NSStreamStatus::NotOpen;
253236
}
254-
if let Ok(stream_buffer) = ivars.stream_buffer.try_lock() {
255-
let eof = ivars.eof.load(Ordering::SeqCst);
256-
match &*stream_buffer {
257-
Ok(cursor) if cursor.position() == 0 && eof => {
258-
return NSStreamStatus::AtEnd;
259-
}
260-
Err(_) => {
261-
return NSStreamStatus::Error;
262-
}
263-
_ => {}
264-
}
237+
let stream_buffer = ivars.stream_buffer.lock().unwrap();
238+
let eof = ivars.eof.load(Ordering::SeqCst);
239+
match &*stream_buffer {
240+
Ok(buf) if buf.is_empty() && eof => NSStreamStatus::AtEnd,
241+
Ok(_) => NSStreamStatus::Open,
242+
Err(_) => NSStreamStatus::Error,
265243
}
266-
NSStreamStatus::Open
267244
}
268245
fn callback_streamError(&self) -> *mut NSError {
269246
let ivars = self.ivars();
270-
let Ok(stream_buffer) = ivars.stream_buffer.try_lock() else {
271-
return null_mut();
272-
};
247+
let stream_buffer = ivars.stream_buffer.lock().unwrap();
273248
match &*stream_buffer {
274249
Ok(_) => std::ptr::null_mut(),
275250
Err(error) => Retained::into_raw(error.clone()),
@@ -282,34 +257,31 @@ impl InputStream {
282257
};
283258

284259
match &mut *stream_buffer {
285-
Ok(cursor) => {
260+
Ok(buf) => {
286261
ivars.waker.wake();
287-
let read_len = (cursor.position() as usize).min(len);
288-
if read_len == 0 {
289-
return 0;
290-
}
291-
unsafe {
292-
std::ptr::copy_nonoverlapping(
293-
cursor.get_ref().as_ptr(),
294-
buffer.as_ptr(),
295-
read_len,
296-
);
297-
cursor.get_mut().drain(..read_len);
298-
// Safety: the underlying buffer will always have STREAM_BUFFER_SIZE initialized bytes
299-
cursor.get_mut().set_len(STREAM_BUFFER_SIZE);
262+
let (slice1, slice2) = buf.as_slices();
263+
let buffer = unsafe { std::slice::from_raw_parts_mut(buffer.as_ptr(), len) };
264+
let mut read_len = {
265+
let read_len = slice1.len().min(buffer.len());
266+
buffer[..read_len].copy_from_slice(&slice1[..read_len]);
267+
read_len
268+
};
269+
if read_len < buffer.len() {
270+
let read_len2 = slice2.len().min(buffer.len() - read_len);
271+
buffer[read_len..][..read_len2].copy_from_slice(&slice2[..read_len2]);
272+
read_len += read_len2;
300273
}
301-
cursor.set_position(cursor.position() - read_len as u64);
274+
275+
buf.drain(..read_len);
302276
read_len as NSInteger
303277
}
304278
Err(_) => -1,
305279
}
306280
}
307281
fn callback_hasBytesAvailable(&self) -> bool {
308282
let ivars = self.ivars();
309-
let Ok(stream_buffer) = ivars.stream_buffer.try_lock() else {
310-
return false;
311-
};
312-
matches!(&*stream_buffer, Ok(cursor) if cursor.position() > 0)
283+
let stream_buffer = ivars.stream_buffer.lock().unwrap();
284+
matches!(&*stream_buffer, Ok(buf) if !buf.is_empty())
313285
}
314286
}
315287

0 commit comments

Comments
 (0)