|
1 |
| -use std::io::{ErrorKind, IoSlice, Result}; |
| 1 | +use std::io::{Error, ErrorKind, IoSlice, Result}; |
2 | 2 | use std::pin::Pin;
|
3 | 3 | use std::ptr;
|
| 4 | +use std::sync::Arc; |
4 | 5 | use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
|
| 6 | +use std::time::Duration; |
5 | 7 |
|
6 | 8 | use bytes::buf::BufMut;
|
7 |
| -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
| 9 | +use ignore_result::Ignore; |
| 10 | +use rustls::pki_types::ServerName; |
| 11 | +use rustls::ClientConfig; |
| 12 | +use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf}; |
8 | 13 | use tokio::net::TcpStream;
|
| 14 | +use tokio::{select, time}; |
9 | 15 | use tokio_rustls::client::TlsStream;
|
| 16 | +use tokio_rustls::TlsConnector; |
| 17 | + |
| 18 | +use crate::deadline::Deadline; |
| 19 | +use crate::endpoint::{EndpointRef, IterableEndpoints}; |
10 | 20 |
|
11 | 21 | const NOOP_VTABLE: RawWakerVTable =
|
12 | 22 | RawWakerVTable::new(|_| RawWaker::new(ptr::null(), &NOOP_VTABLE), |_| {}, |_| {}, |_| {});
|
@@ -126,4 +136,90 @@ impl Connection {
|
126 | 136 | Poll::Ready(result) => result,
|
127 | 137 | }
|
128 | 138 | }
|
| 139 | + |
| 140 | + pub async fn command(self, cmd: &str) -> Result<String> { |
| 141 | + let mut stream = BufStream::new(self); |
| 142 | + stream.write_all(cmd.as_bytes()).await?; |
| 143 | + stream.flush().await?; |
| 144 | + let mut line = String::new(); |
| 145 | + stream.read_line(&mut line).await?; |
| 146 | + stream.shutdown().await.ignore(); |
| 147 | + Ok(line) |
| 148 | + } |
| 149 | + |
| 150 | + pub async fn command_isro(self) -> Result<bool> { |
| 151 | + let r = self.command("isro").await?; |
| 152 | + if r == "rw" { |
| 153 | + Ok(true) |
| 154 | + } else { |
| 155 | + Ok(false) |
| 156 | + } |
| 157 | + } |
| 158 | +} |
| 159 | + |
| 160 | +#[derive(Clone)] |
| 161 | +pub struct Connector { |
| 162 | + tls: TlsConnector, |
| 163 | + timeout: Duration, |
| 164 | +} |
| 165 | + |
| 166 | +impl Connector { |
| 167 | + pub fn new(config: impl Into<Arc<ClientConfig>>) -> Self { |
| 168 | + Self { tls: TlsConnector::from(config.into()), timeout: Duration::from_secs(10) } |
| 169 | + } |
| 170 | + |
| 171 | + pub fn timeout(&self) -> Duration { |
| 172 | + self.timeout |
| 173 | + } |
| 174 | + |
| 175 | + pub fn set_timeout(&mut self, timeout: Duration) { |
| 176 | + self.timeout = timeout; |
| 177 | + } |
| 178 | + |
| 179 | + pub async fn connect(&self, endpoint: EndpointRef<'_>, deadline: &mut Deadline) -> Result<Connection> { |
| 180 | + select! { |
| 181 | + _ = unsafe { Pin::new_unchecked(deadline) } => Err(Error::new(ErrorKind::TimedOut, "deadline exceed")), |
| 182 | + _ = time::sleep(self.timeout) => Err(Error::new(ErrorKind::TimedOut, format!("connection timeout{:?} exceed", self.timeout))), |
| 183 | + r = TcpStream::connect((endpoint.host, endpoint.port)) => { |
| 184 | + match r { |
| 185 | + Err(err) => Err(err), |
| 186 | + Ok(sock) => { |
| 187 | + let connection = if endpoint.tls { |
| 188 | + let domain = ServerName::try_from(endpoint.host).unwrap().to_owned(); |
| 189 | + let stream = self.tls.connect(domain, sock).await?; |
| 190 | + Connection::new_tls(stream) |
| 191 | + } else { |
| 192 | + Connection::new_raw(sock) |
| 193 | + }; |
| 194 | + Ok(connection) |
| 195 | + }, |
| 196 | + } |
| 197 | + }, |
| 198 | + } |
| 199 | + } |
| 200 | + |
| 201 | + pub async fn seek_for_writable(self, endpoints: &mut IterableEndpoints) -> Option<EndpointRef<'_>> { |
| 202 | + let n = endpoints.len(); |
| 203 | + let max_timeout = Duration::from_secs(60); |
| 204 | + let mut i = 0; |
| 205 | + let mut timeout = Duration::from_millis(100); |
| 206 | + let mut deadline = Deadline::never(); |
| 207 | + while let Some(endpoint) = endpoints.peek() { |
| 208 | + i += 1; |
| 209 | + if let Ok(conn) = self.connect(endpoint, &mut deadline).await { |
| 210 | + if let Ok(true) = conn.command_isro().await { |
| 211 | + return Some(unsafe { std::mem::transmute(endpoint) }); |
| 212 | + } |
| 213 | + } |
| 214 | + endpoints.step(); |
| 215 | + if i % n == 0 { |
| 216 | + log::debug!("ZooKeeper fails to contact writable server from endpoints {:?}", endpoints.endpoints()); |
| 217 | + time::sleep(timeout).await; |
| 218 | + timeout = max_timeout.min(timeout * 2); |
| 219 | + } else { |
| 220 | + time::sleep(Duration::from_millis(5)).await; |
| 221 | + } |
| 222 | + } |
| 223 | + None |
| 224 | + } |
129 | 225 | }
|
0 commit comments