diff --git a/grpc-sys/bindings/x86_64-unknown-linux-gnu-bindings.rs b/grpc-sys/bindings/x86_64-unknown-linux-gnu-bindings.rs index 4987ef506..ec848804c 100644 --- a/grpc-sys/bindings/x86_64-unknown-linux-gnu-bindings.rs +++ b/grpc-sys/bindings/x86_64-unknown-linux-gnu-bindings.rs @@ -5948,6 +5948,11 @@ extern "C" { extern "C" { pub fn grpcwrap_slice_length(slice: *const grpc_slice) -> usize; } +extern "C" { + pub fn grpcwrap_batch_context_take_send_message( + ctx: *mut grpcwrap_batch_context, + ) -> *mut grpc_byte_buffer; +} extern "C" { pub fn grpcwrap_batch_context_take_recv_message( ctx: *mut grpcwrap_batch_context, diff --git a/grpc-sys/grpc_wrap.cc b/grpc-sys/grpc_wrap.cc index 2ba578e6b..b703ca995 100644 --- a/grpc-sys/grpc_wrap.cc +++ b/grpc-sys/grpc_wrap.cc @@ -307,6 +307,16 @@ GPR_EXPORT size_t GPR_CALLTYPE grpcwrap_slice_length(const grpc_slice* slice) { return GRPC_SLICE_LENGTH(*slice); } +GPR_EXPORT grpc_byte_buffer* GPR_CALLTYPE +grpcwrap_batch_context_take_send_message(grpcwrap_batch_context* ctx) { + grpc_byte_buffer* buf = nullptr; + if (ctx->send_message) { + buf = ctx->send_message; + ctx->send_message = nullptr; + } + return buf; +} + GPR_EXPORT grpc_byte_buffer* GPR_CALLTYPE grpcwrap_batch_context_take_recv_message(grpcwrap_batch_context* ctx) { grpc_byte_buffer* buf = nullptr; diff --git a/src/call/client.rs b/src/call/client.rs index 0dad49095..2a9f44b75 100644 --- a/src/call/client.rs +++ b/src/call/client.rs @@ -13,7 +13,7 @@ use crate::channel::Channel; use crate::codec::{DeserializeFn, SerializeFn}; use crate::error::{Error, Result}; use crate::metadata::Metadata; -use crate::task::{BatchFuture, BatchType, SpinLock}; +use crate::task::{unref_raw_tag, BatchFuture, BatchType, CallTag, SpinLock}; /// Update the flag bit in res. #[inline] @@ -104,7 +104,7 @@ impl Call { let call = channel.create_call(method, &opt)?; let mut payload = vec![]; (method.req_ser())(req, &mut payload); - let cq_f = check_run(BatchType::CheckRead, |ctx, tag| unsafe { + let (cq_f, _) = check_run(BatchType::CheckRead, |ctx, tag| unsafe { grpc_sys::grpcwrap_call_start_unary( call.call, ctx, @@ -127,7 +127,7 @@ impl Call { mut opt: CallOption, ) -> Result<(ClientCStreamSender, ClientCStreamReceiver)> { let call = channel.create_call(method, &opt)?; - let cq_f = check_run(BatchType::CheckRead, |ctx, tag| unsafe { + let (cq_f, _) = check_run(BatchType::CheckRead, |ctx, tag| unsafe { grpc_sys::grpcwrap_call_start_client_streaming( call.call, ctx, @@ -158,7 +158,7 @@ impl Call { let call = channel.create_call(method, &opt)?; let mut payload = vec![]; (method.req_ser())(req, &mut payload); - let cq_f = check_run(BatchType::Finish, |ctx, tag| unsafe { + let (cq_f, _) = check_run(BatchType::Finish, |ctx, tag| unsafe { grpc_sys::grpcwrap_call_start_server_streaming( call.call, ctx, @@ -187,7 +187,7 @@ impl Call { mut opt: CallOption, ) -> Result<(ClientDuplexSender, ClientDuplexReceiver)> { let call = channel.create_call(method, &opt)?; - let cq_f = check_run(BatchType::Finish, |ctx, tag| unsafe { + let (cq_f, _) = check_run(BatchType::Finish, |ctx, tag| unsafe { grpc_sys::grpcwrap_call_start_duplex_streaming( call.call, ctx, @@ -410,8 +410,11 @@ struct ResponseStreamImpl { read_done: bool, finished: bool, resp_de: DeserializeFn, + tag: *mut CallTag, } +unsafe impl Send for ResponseStreamImpl {} + impl ResponseStreamImpl { fn new(call: H, resp_de: DeserializeFn) -> ResponseStreamImpl { ResponseStreamImpl { @@ -420,6 +423,7 @@ impl ResponseStreamImpl { read_done: false, finished: false, resp_de, + tag: ptr::null_mut(), } } @@ -457,7 +461,8 @@ impl ResponseStreamImpl { // so msg_f must be either stale or not initialised yet. self.msg_f.take(); - let msg_f = self.call.call(|c| c.call.start_recv_message())?; + let tag = &mut self.tag; + let msg_f = self.call.call(|c| c.call.start_recv_message(tag))?; self.msg_f = Some(msg_f); if let Some(data) = bytes { let msg = (self.resp_de)(data)?; @@ -475,6 +480,12 @@ impl ResponseStreamImpl { } } +impl Drop for ResponseStreamImpl { + fn drop(&mut self) { + unsafe { unref_raw_tag(self.tag) } + } +} + /// A receiver for server streaming call. #[must_use = "if unused the ClientSStreamReceiver may immediately cancel the RPC"] pub struct ClientSStreamReceiver { diff --git a/src/call/mod.rs b/src/call/mod.rs index 5a147fa00..2a71e690c 100644 --- a/src/call/mod.rs +++ b/src/call/mod.rs @@ -15,7 +15,7 @@ use crate::buf::{GrpcByteBuffer, GrpcByteBufferReader}; use crate::codec::{DeserializeFn, Marshaller, SerializeFn}; use crate::error::{Error, Result}; use crate::grpc_sys::grpc_status_code::*; -use crate::task::{self, BatchFuture, BatchType, CallTag, SpinLock}; +use crate::task::{self, unref_raw_tag, BatchFuture, BatchType, CallTag, SpinLock}; // By default buffers in `SinkBase` will be shrink to 4K size. const BUF_SHRINK_SIZE: usize = 4 * 1024; @@ -182,6 +182,15 @@ impl BatchContext { } } + pub fn take_send_message(&self) -> Option { + let ptr = unsafe { grpc_sys::grpcwrap_batch_context_take_send_message(self.ctx) }; + if ptr.is_null() { + None + } else { + Some(unsafe { GrpcByteBuffer::from_raw(ptr) }) + } + } + /// Get the status of the rpc call. pub fn rpc_status(&self) -> RpcStatus { let status = RpcStatusCode(unsafe { @@ -228,7 +237,7 @@ fn box_batch_tag(tag: CallTag) -> (*mut grpcwrap_batch_context, *mut c_void) { } /// A helper function that runs the batch call and checks the result. -fn check_run(bt: BatchType, f: F) -> BatchFuture +fn check_run(bt: BatchType, f: F) -> (BatchFuture, *mut CallTag) where F: FnOnce(*mut grpcwrap_batch_context, *mut c_void) -> grpc_call_error, { @@ -236,12 +245,28 @@ where let (batch_ptr, tag_ptr) = box_batch_tag(tag); let code = f(batch_ptr, tag_ptr); if code != grpc_call_error::GRPC_CALL_OK { - unsafe { - Box::from_raw(tag_ptr); - } + drop(unsafe { Box::from_raw(tag_ptr) }); panic!("create call fail: {:?}", code); } - cq_f + (cq_f, tag_ptr as *mut CallTag) +} + +fn check_run_with_tag(tag: *mut CallTag, f: F) -> (BatchFuture, *mut CallTag) +where + F: FnOnce(*mut grpcwrap_batch_context, *mut c_void) -> grpc_call_error, +{ + unsafe { + let cq_f = match &*tag { + CallTag::Batch(promise) => promise.cq_future(), + _ => unreachable!(), + }; + let ctx = (*tag).batch_ctx().unwrap().as_ptr(); + let code = f(ctx, tag as *mut c_void); + if code != grpc_call_error::GRPC_CALL_OK { + panic!("create call fail: {:?}", code); + } + (cq_f, tag) + } } /// A Call represents an RPC. @@ -268,38 +293,55 @@ impl Call { msg: &[u8], write_flags: u32, initial_meta: bool, + batch: &mut *mut CallTag, ) -> Result { let _cq_ref = self.cq.borrow()?; + let ptr = msg.as_ptr() as _; + let len = msg.len(); let i = if initial_meta { 1 } else { 0 }; - let f = check_run(BatchType::Finish, |ctx, tag| unsafe { - grpc_sys::grpcwrap_call_send_message( - self.call, - ctx, - msg.as_ptr() as _, - msg.len(), - write_flags, - i, - tag, - ) - }); + let send_message = |ctx, tag| unsafe { + match *(tag as *mut CallTag) { + CallTag::Batch(ref prom) => prom.ref_batch(), + _ => unreachable!(), + } + grpc_sys::grpcwrap_call_send_message(self.call, ctx, ptr, len, write_flags, i, tag) + }; + + let (f, tag) = if !batch.is_null() { + check_run_with_tag(*batch, send_message) + } else { + check_run(BatchType::Finish, send_message) + }; + *batch = tag; Ok(f) } /// Finish the rpc call from client. pub fn start_send_close_client(&mut self) -> Result { let _cq_ref = self.cq.borrow()?; - let f = check_run(BatchType::Finish, |_, tag| unsafe { + let (f, _) = check_run(BatchType::Finish, |_, tag| unsafe { grpc_sys::grpcwrap_call_send_close_from_client(self.call, tag) }); Ok(f) } /// Receive a message asynchronously. - pub fn start_recv_message(&mut self) -> Result { + pub fn start_recv_message(&mut self, batch: &mut *mut CallTag) -> Result { let _cq_ref = self.cq.borrow()?; - let f = check_run(BatchType::Read, |ctx, tag| unsafe { + let recv_message = |ctx, tag| unsafe { + match *(tag as *mut CallTag) { + CallTag::Batch(ref prom) => prom.ref_batch(), + _ => unreachable!(), + } grpc_sys::grpcwrap_call_recv_message(self.call, ctx, tag) - }); + }; + + let (f, tag) = if !batch.is_null() { + check_run_with_tag(*batch, recv_message) + } else { + check_run(BatchType::Read, recv_message) + }; + *batch = tag; Ok(f) } @@ -308,7 +350,7 @@ impl Call { /// Future will finish once close is received by the server. pub fn start_server_side(&mut self) -> Result { let _cq_ref = self.cq.borrow()?; - let f = check_run(BatchType::Finish, |ctx, tag| unsafe { + let (f, _) = check_run(BatchType::Finish, |ctx, tag| unsafe { grpc_sys::grpcwrap_call_start_serverside(self.call, ctx, tag) }); Ok(f) @@ -327,7 +369,7 @@ impl Call { let (payload_ptr, payload_len) = payload .as_ref() .map_or((ptr::null(), 0), |b| (b.as_ptr(), b.len())); - let f = check_run(BatchType::Finish, |ctx, tag| unsafe { + let (f, _) = check_run(BatchType::Finish, |ctx, tag| unsafe { let details_ptr = status .details .as_ref() @@ -487,14 +529,21 @@ struct StreamingBase { close_f: Option, msg_f: Option, read_done: bool, + + // `tag` can be reused during the stream's lifetime. + tag: *mut CallTag, } +// Because it carrys a `CallTag`. +unsafe impl Send for StreamingBase {} + impl StreamingBase { fn new(close_f: Option) -> StreamingBase { StreamingBase { close_f, msg_f: None, read_done: false, + tag: ptr::null_mut(), } } @@ -539,7 +588,7 @@ impl StreamingBase { // so msg_f must be either stale or not initialised yet. self.msg_f.take(); - let msg_f = call.call(|c| c.call.start_recv_message())?; + let msg_f = call.call(|c| c.call.start_recv_message(&mut self.tag))?; self.msg_f = Some(msg_f); if bytes.is_none() { self.poll(call, true) @@ -557,6 +606,12 @@ impl StreamingBase { } } +impl Drop for StreamingBase { + fn drop(&mut self) { + unsafe { unref_raw_tag(self.tag) } + } +} + /// Flags for write operations. #[derive(Default, Clone, Copy)] pub struct WriteFlags { @@ -603,14 +658,20 @@ struct SinkBase { batch_f: Option, buf: Vec, send_metadata: bool, + + tag: *mut CallTag, } +// Because it carrys a `CallTag`. +unsafe impl Send for SinkBase {} + impl SinkBase { fn new(send_metadata: bool) -> SinkBase { SinkBase { batch_f: None, buf: Vec::new(), send_metadata, + tag: ptr::null_mut(), } } @@ -637,7 +698,7 @@ impl SinkBase { } let write_f = call.call(|c| { c.call - .start_send_message(&self.buf, flags.flags, self.send_metadata) + .start_send_message(&self.buf, flags.flags, self.send_metadata, &mut self.tag) })?; // NOTE: Content of `self.buf` is copied into grpc internal. if self.buf.capacity() > BUF_SHRINK_SIZE { @@ -658,3 +719,9 @@ impl SinkBase { Ok(Async::Ready(())) } } + +impl Drop for SinkBase { + fn drop(&mut self) { + unsafe { unref_raw_tag(self.tag) } + } +} diff --git a/src/env.rs b/src/env.rs index 3f5b6e14f..68f99d46b 100644 --- a/src/env.rs +++ b/src/env.rs @@ -8,7 +8,7 @@ use std::thread::{Builder as ThreadBuilder, JoinHandle}; use crate::grpc_sys; use crate::cq::{CompletionQueue, CompletionQueueHandle, EventType, WorkQueue}; -use crate::task::CallTag; +use crate::task::{self, CallTag}; // event loop fn poll_queue(tx: mpsc::Sender) { @@ -26,8 +26,7 @@ fn poll_queue(tx: mpsc::Sender) { } let tag: Box = unsafe { Box::from_raw(e.tag as _) }; - - tag.resolve(&cq, e.success != 0); + task::resolve(tag, &cq, e.success != 0); while let Some(work) = unsafe { cq.worker.pop_work() } { work.finish(&cq); } diff --git a/src/task/mod.rs b/src/task/mod.rs index cdce8f231..6d2701514 100644 --- a/src/task/mod.rs +++ b/src/task/mod.rs @@ -13,7 +13,7 @@ use futures::{Async, Future, Poll}; use self::callback::{Abort, Request as RequestCallback, UnaryRequest as UnaryRequestCallback}; use self::executor::SpawnTask; -use self::promise::{Batch as BatchPromise, Shutdown as ShutdownPromise}; +use self::promise::Shutdown as ShutdownPromise; use crate::call::server::RequestContext; use crate::call::{BatchContext, Call, MessageReader}; use crate::cq::CompletionQueue; @@ -22,7 +22,7 @@ use crate::server::RequestCallContext; pub(crate) use self::executor::{Executor, Kicker, UnfinishedWork}; pub use self::lock::SpinLock; -pub use self::promise::BatchType; +pub use self::promise::{Batch as BatchPromise, BatchType}; /// A handle that is used to notify future that the task finishes. pub struct NotifyHandle { @@ -46,6 +46,12 @@ impl NotifyHandle { self.task.take() } + + fn reset(&mut self) { + debug_assert!(self.task.is_none()); + self.result = None; + self.stale = false; + } } type Inner = SpinLock>; @@ -168,18 +174,39 @@ impl CallTag { _ => None, } } +} - /// Resolve the CallTag with given status. - pub fn resolve(self, cq: &CompletionQueue, success: bool) { - match self { - CallTag::Batch(prom) => prom.resolve(success), - CallTag::Request(cb) => cb.resolve(cq, success), - CallTag::UnaryRequest(cb) => cb.resolve(cq, success), - CallTag::Abort(_) => {} - CallTag::Shutdown(prom) => prom.resolve(success), - CallTag::Spawn(notify) => self::executor::resolve(cq, notify, success), +/// Resolve the CallTag with given status. +pub fn resolve(tag: Box, cq: &CompletionQueue, success: bool) { + let raw = Box::into_raw(tag); + if let CallTag::Batch(ref mut prom) = unsafe { &mut *raw } { + if prom.resolve(success) { + // Return directly, skip to drop the `CallTag`. + return; } } + match unsafe { *Box::from_raw(raw) } { + CallTag::Batch(_) => {} // Already handled on above. + CallTag::Request(cb) => cb.resolve(cq, success), + CallTag::UnaryRequest(cb) => cb.resolve(cq, success), + CallTag::Abort(_) => {} + CallTag::Shutdown(prom) => prom.resolve(success), + CallTag::Spawn(notify) => self::executor::resolve(cq, notify, success), + } +} + +/// Unref a `CallTag` which must be on heap. +pub(crate) unsafe fn unref_raw_tag(tag: *mut CallTag) { + if tag.is_null() { + return; + } + let in_resolving = match *tag { + CallTag::Batch(ref prom) => prom.unref_batch(), + _ => false, + }; + if !in_resolving { + drop(Box::from_raw(tag)); + } } impl Debug for CallTag { @@ -218,11 +245,11 @@ mod tests { }); assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty); - tag1.resolve(&env.pick_cq(), true); + resolve(Box::new(tag1), &env.pick_cq(), true); assert!(rx.recv().unwrap().is_ok()); assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty); - tag2.resolve(&env.pick_cq(), false); + resolve(Box::new(tag2), &env.pick_cq(), false); match rx.recv() { Ok(Err(Error::ShutdownFailed)) => {} res => panic!("expect shutdown failed, but got {:?}", res), diff --git a/src/task/promise.rs b/src/task/promise.rs index 9399f3e56..6923b2320 100644 --- a/src/task/promise.rs +++ b/src/task/promise.rs @@ -1,11 +1,13 @@ // Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0. use std::fmt::{self, Debug, Formatter}; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use super::Inner; use crate::call::{BatchContext, MessageReader, RpcStatusCode}; use crate::error::Error; +use crate::task::CqFuture; /// Batch job type. #[derive(PartialEq, Debug)] @@ -23,6 +25,7 @@ pub struct Batch { ty: BatchType, ctx: BatchContext, inner: Arc>>, + ref_count: AtomicUsize, } impl Batch { @@ -31,6 +34,7 @@ impl Batch { ty, ctx: BatchContext::new(), inner, + ref_count: AtomicUsize::new(1), } } @@ -38,6 +42,13 @@ impl Batch { &self.ctx } + /// Create a future which will be notified after the batch is resolved. + pub fn cq_future(&self) -> CqFuture> { + let mut guard = self.inner.lock(); + guard.reset(); + CqFuture::new(self.inner.clone()) + } + fn read_one_msg(&mut self, success: bool) { let task = { let mut guard = self.inner.lock(); @@ -82,7 +93,8 @@ impl Batch { task.map(|t| t.notify()); } - pub fn resolve(mut self, success: bool) { + /// Return `true` means the tag can be reused. + pub fn resolve(&mut self, success: bool) -> bool { match self.ty { BatchType::CheckRead => { assert!(success); @@ -90,11 +102,25 @@ impl Batch { } BatchType::Finish => { self.finish_response(success); + drop(self.ctx.take_send_message()); + return self.unref_batch(); } BatchType::Read => { self.read_one_msg(success); + return self.unref_batch(); } } + false + } + + /// Ref the `Batch` before call `grpc_call_start_batch`. + pub fn ref_batch(&self) { + self.ref_count.fetch_add(1, Ordering::Release); + } + + /// Return `true` means the tag can be reused. + pub fn unref_batch(&self) -> bool { + self.ref_count.fetch_sub(1, Ordering::AcqRel) > 1 } }