Skip to content

Commit fa48d70

Browse files
committed
feat: seek quorum for readonly session
Closes #30.
1 parent 376b21d commit fa48d70

File tree

9 files changed

+600
-313
lines changed

9 files changed

+600
-313
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,4 @@ assertor = "0.0.2"
4646
assert_matches = "1.5.0"
4747
tempfile = "3.6.0"
4848
rcgen = { version = "0.12.1", features = ["default", "x509-parser"] }
49+
serial_test = "3.0.0"

src/client/mod.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1577,7 +1577,7 @@ impl Connector {
15771577
self
15781578
}
15791579

1580-
/// Specifies whether readonly server is allowed.
1580+
/// Specifies whether readonly session is allowed.
15811581
pub fn readonly(&mut self, readonly: bool) -> &mut Self {
15821582
self.readonly = readonly;
15831583
self
@@ -1717,7 +1717,7 @@ impl ClientBuilder {
17171717
self
17181718
}
17191719

1720-
/// Specifies whether readonly server is allowed.
1720+
/// Specifies whether readonly session is allowed.
17211721
pub fn with_readonly(&mut self, readonly: bool) -> &mut ClientBuilder {
17221722
self.connector.readonly = readonly;
17231723
self

src/deadline.rs

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
use std::future::Future;
2+
use std::pin::Pin;
3+
use std::task::{Context, Poll};
4+
5+
use tokio::time::{self, Instant, Sleep};
6+
7+
pub struct Deadline {
8+
sleep: Option<Sleep>,
9+
}
10+
11+
impl Deadline {
12+
pub fn never() -> Self {
13+
Self { sleep: None }
14+
}
15+
16+
pub fn until(deadline: Instant) -> Self {
17+
Self { sleep: Some(time::sleep_until(deadline)) }
18+
}
19+
20+
pub fn elapsed(&self) -> bool {
21+
self.sleep.as_ref().map(|f| f.is_elapsed()).unwrap_or(false)
22+
}
23+
}
24+
25+
impl Future for Deadline {
26+
type Output = ();
27+
28+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
29+
if self.sleep.is_none() {
30+
return Poll::Pending;
31+
}
32+
let sleep = unsafe { self.map_unchecked_mut(|deadline| deadline.sleep.as_mut().unwrap_unchecked()) };
33+
sleep.poll(cx)
34+
}
35+
}

src/endpoint.rs

+12
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,10 @@ impl IterableEndpoints {
165165
Self { cycle: false, next: 0, endpoints: endpoints.into() }
166166
}
167167

168+
pub fn len(&self) -> usize {
169+
self.endpoints.len()
170+
}
171+
168172
/// Returns all endpoints as a slice, regardless of the current cursor position.
pub fn endpoints(&self) -> &[Endpoint] {
    &self.endpoints[..]
}
@@ -192,6 +196,14 @@ impl IterableEndpoints {
192196
self.next = 0;
193197
}
194198
}
199+
200+
/// Returns the endpoint at the current cursor without advancing it.
///
/// Yields `None` when the cursor is at or past the end of the endpoint list.
pub fn peek(&self) -> Option<EndpointRef<'_>> {
    self.endpoints.get(self.next).map(|endpoint| endpoint.to_ref())
}
195207
}
196208

197209
impl From<&[EndpointRef<'_>]> for IterableEndpoints {

src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
mod acl;
22
mod chroot;
33
mod client;
4+
mod deadline;
45
mod endpoint;
56
mod error;
67
mod proto;

src/proto/connect.rs

+8-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use bytes::BufMut;
2+
use derive_where::derive_where;
23

34
use crate::record::{
45
DeserializableRecord,
@@ -10,11 +11,13 @@ use crate::record::{
1011
UnsafeBuf,
1112
};
1213

14+
#[derive_where(Debug)]
1315
pub struct ConnectRequest<'a> {
1416
pub protocol_version: i32,
1517
pub last_zxid_seen: i64,
1618
pub timeout: i32,
1719
pub session_id: i64,
20+
#[derive_where(skip(Debug))]
1821
pub password: &'a [u8],
1922
pub readonly: bool,
2023
}
@@ -36,11 +39,13 @@ impl DynamicRecord for ConnectRequest<'_> {
3639
}
3740
}
3841

42+
#[derive_where(Debug)]
3943
pub struct ConnectResponse<'a> {
4044
#[allow(dead_code)]
4145
pub protocol_version: i32,
4246
pub session_timeout: i32,
4347
pub session_id: i64,
48+
#[derive_where(skip(Debug))]
4449
pub password: &'a [u8],
4550
pub readonly: bool,
4651
}
@@ -49,7 +54,7 @@ impl<'a> DeserializableRecord<'a> for ConnectResponse<'a> {
4954
type Error = DeserializeError;
5055

5156
fn deserialize(buf: &mut ReadingBuf<'a>) -> Result<Self, Self::Error> {
52-
let min_record_len = 4 + 4 + 8 + 4 + 1;
57+
let min_record_len = 4 + 4 + 8 + 4;
5358
if buf.len() < min_record_len {
5459
return Err(DeserializeError::InsufficientBuf);
5560
}
@@ -64,12 +69,12 @@ impl<'a> DeserializableRecord<'a> for ConnectResponse<'a> {
6469
)));
6570
}
6671
let len = unsafe { buf.get_unchecked_i32() };
67-
if len <= 0 || len != (buf.len() - 1) as i32 {
72+
if len <= 0 || len >= buf.len() as i32 {
6873
return Err(DeserializeError::UnmarshalError(format!("invalid session password length {len}")));
6974
}
7075
let len = len as usize;
7176
let password = unsafe { buf.get_unchecked(..len) };
72-
let readonly = unsafe { *buf.get_unchecked(len) };
77+
let readonly = if buf.is_empty() { 0 } else { unsafe { *buf.get_unchecked(len) } };
7378
if readonly != 0 && readonly != 1 {
7479
return Err(DeserializeError::UnmarshalError(format!("invalid session readonly value {readonly}")));
7580
}

src/session/connection.rs

+98-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
1-
use std::io::{ErrorKind, IoSlice, Result};
1+
use std::io::{Error, ErrorKind, IoSlice, Result};
22
use std::pin::Pin;
33
use std::ptr;
4+
use std::sync::Arc;
45
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
6+
use std::time::Duration;
57

68
use bytes::buf::BufMut;
7-
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
9+
use ignore_result::Ignore;
10+
use rustls::pki_types::ServerName;
11+
use rustls::ClientConfig;
12+
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf};
813
use tokio::net::TcpStream;
14+
use tokio::{select, time};
915
use tokio_rustls::client::TlsStream;
16+
use tokio_rustls::TlsConnector;
17+
18+
use crate::deadline::Deadline;
19+
use crate::endpoint::{EndpointRef, IterableEndpoints};
1020

1121
/// Waker vtable whose operations do nothing.
///
/// Cloning yields another raw waker with a null data pointer and this same vtable; `wake`,
/// `wake_by_ref` and `drop` are all empty.
const NOOP_VTABLE: RawWakerVTable = RawWakerVTable::new(
    |_data| RawWaker::new(ptr::null(), &NOOP_VTABLE),
    |_data| {},
    |_data| {},
    |_data| {},
);
@@ -126,4 +136,90 @@ impl Connection {
126136
Poll::Ready(result) => result,
127137
}
128138
}
139+
140+
/// Sends `cmd` over this connection and returns one line of the reply.
///
/// The connection is consumed. The reply is read up to and including the first newline, or
/// until end of stream if there is none. The stream is then shut down, and any error from
/// that shutdown is ignored.
///
/// # Errors
///
/// Propagates I/O errors from writing, flushing or reading.
pub async fn command(self, cmd: &str) -> Result<String> {
    let mut io = BufStream::new(self);
    let request = cmd.as_bytes();
    io.write_all(request).await?;
    io.flush().await?;

    let mut reply = String::new();
    io.read_line(&mut reply).await?;

    // Best-effort close: the reply is already in hand, so a failed shutdown is not reported.
    io.shutdown().await.ignore();
    Ok(reply)
}
149+
150+
pub async fn command_isro(self) -> Result<bool> {
151+
let r = self.command("isro").await?;
152+
if r == "rw" {
153+
Ok(true)
154+
} else {
155+
Ok(false)
156+
}
157+
}
158+
}
159+
160+
#[derive(Clone)]
161+
pub struct Connector {
162+
tls: TlsConnector,
163+
timeout: Duration,
164+
}
165+
166+
impl Connector {
167+
pub fn new(config: impl Into<Arc<ClientConfig>>) -> Self {
168+
Self { tls: TlsConnector::from(config.into()), timeout: Duration::from_secs(10) }
169+
}
170+
171+
pub fn timeout(&self) -> Duration {
172+
self.timeout
173+
}
174+
175+
pub fn set_timeout(&mut self, timeout: Duration) {
176+
self.timeout = timeout;
177+
}
178+
179+
pub async fn connect(&self, endpoint: EndpointRef<'_>, deadline: &mut Deadline) -> Result<Connection> {
180+
select! {
181+
_ = unsafe { Pin::new_unchecked(deadline) } => Err(Error::new(ErrorKind::TimedOut, "deadline exceed")),
182+
_ = time::sleep(self.timeout) => Err(Error::new(ErrorKind::TimedOut, format!("connection timeout{:?} exceed", self.timeout))),
183+
r = TcpStream::connect((endpoint.host, endpoint.port)) => {
184+
match r {
185+
Err(err) => Err(err),
186+
Ok(sock) => {
187+
let connection = if endpoint.tls {
188+
let domain = ServerName::try_from(endpoint.host).unwrap().to_owned();
189+
let stream = self.tls.connect(domain, sock).await?;
190+
Connection::new_tls(stream)
191+
} else {
192+
Connection::new_raw(sock)
193+
};
194+
Ok(connection)
195+
},
196+
}
197+
},
198+
}
199+
}
200+
201+
pub async fn seek_for_writable(self, endpoints: &mut IterableEndpoints) -> Option<EndpointRef<'_>> {
202+
let n = endpoints.len();
203+
let max_timeout = Duration::from_secs(60);
204+
let mut i = 0;
205+
let mut timeout = Duration::from_millis(100);
206+
let mut deadline = Deadline::never();
207+
while let Some(endpoint) = endpoints.peek() {
208+
i += 1;
209+
if let Ok(conn) = self.connect(endpoint, &mut deadline).await {
210+
if let Ok(true) = conn.command_isro().await {
211+
return Some(unsafe { std::mem::transmute(endpoint) });
212+
}
213+
}
214+
endpoints.step();
215+
if i % n == 0 {
216+
log::debug!("ZooKeeper fails to contact writable server from endpoints {:?}", endpoints.endpoints());
217+
time::sleep(timeout).await;
218+
timeout = max_timeout.min(timeout * 2);
219+
} else {
220+
time::sleep(Duration::from_millis(5)).await;
221+
}
222+
}
223+
None
224+
}
129225
}

0 commit comments

Comments
 (0)