Skip to content

Commit 45dec8d

Browse files
committed
refactor: replace typeless HostPort with Endpoint struct
1 parent 8657d58 commit 45dec8d

File tree

5 files changed

+334
-205
lines changed

5 files changed

+334
-205
lines changed

src/client/mod.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ pub use self::watcher::{OneshotWatcher, PersistentWatcher, StateWatcher};
1616
use super::session::{Depot, MarshalledRequest, Session, SessionOperation, WatchReceiver, PASSWORD_LEN};
1717
use crate::acl::{Acl, Acls, AuthUser};
1818
use crate::chroot::{Chroot, ChrootPath, OwnedChroot};
19+
use crate::endpoint::{self, IterableEndpoints};
1920
use crate::error::Error;
2021
use crate::proto::{
2122
self,
@@ -45,7 +46,7 @@ use crate::record::{self, Record, StaticRecord};
4546
use crate::session::StateReceiver;
4647
pub use crate::session::{EventType, SessionId, SessionState, WatchedEvent};
4748
use crate::tls::TlsOptions;
48-
use crate::util::{self, Ref as _};
49+
use crate::util;
4950

5051
pub(crate) type Result<T, E = Error> = std::result::Result<T, E>;
5152

@@ -1621,7 +1622,7 @@ impl Connector {
16211622
}
16221623

16231624
async fn connect_internally(&mut self, secure: bool, cluster: &str) -> Result<Client> {
1624-
let (hosts, chroot) = util::parse_connect_string(cluster, secure)?;
1625+
let (endpoints, chroot) = endpoint::parse_connect_string(cluster, secure)?;
16251626
if let Some((id, password)) = &self.session {
16261627
if id.0 == 0 {
16271628
return Err(Error::BadArguments(&"session id must not be 0"));
@@ -1647,16 +1648,15 @@ impl Connector {
16471648
self.session_timeout,
16481649
self.connection_timeout,
16491650
);
1650-
let mut hosts_iter = hosts.iter().copied();
1651+
let mut endpoints = IterableEndpoints::from(endpoints.as_slice());
16511652
let mut buf = Vec::with_capacity(4096);
16521653
let mut connecting_depot = Depot::for_connecting();
1653-
let conn = session.start(&mut hosts_iter, &mut buf, &mut connecting_depot).await?;
1654+
let conn = session.start(&mut endpoints, &mut buf, &mut connecting_depot).await?;
16541655
let (sender, receiver) = mpsc::unbounded_channel();
1655-
let servers = hosts.into_iter().map(|addr| addr.to_value()).collect();
16561656
let session_info = (session.session_id, session.session_password.clone());
16571657
let session_timeout = session.session_timeout;
16581658
tokio::spawn(async move {
1659-
session.serve(servers, conn, buf, connecting_depot, receiver).await;
1659+
session.serve(endpoints, conn, buf, connecting_depot, receiver).await;
16601660
});
16611661
let client =
16621662
Client::new(chroot.to_owned(), self.server_version, session_info, session_timeout, sender, state_receiver);

src/endpoint.rs

+290
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
use std::fmt::{self, Display, Formatter};
2+
3+
use crate::chroot::Chroot;
4+
use crate::error::Error;
5+
use crate::util::{Ref, ToRef};
6+
7+
#[derive(Debug, Clone, PartialEq, Eq)]
8+
pub struct Endpoint {
9+
pub host: String,
10+
pub port: u16,
11+
pub tls: bool,
12+
}
13+
14+
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
15+
pub struct EndpointRef<'a> {
16+
pub host: &'a str,
17+
pub port: u16,
18+
pub tls: bool,
19+
}
20+
21+
impl Endpoint {
22+
pub fn new(host: impl Into<String>, port: u16, tls: bool) -> Self {
23+
Self { host: host.into(), port, tls }
24+
}
25+
}
26+
27+
impl<'a> EndpointRef<'a> {
28+
pub fn new(host: &'a str, port: u16, tls: bool) -> Self {
29+
Self { host, port, tls }
30+
}
31+
}
32+
33+
impl Display for Endpoint {
34+
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
35+
self.to_ref().fmt(f)
36+
}
37+
}
38+
39+
impl<'a> From<(&'a str, u16, bool)> for EndpointRef<'a> {
40+
fn from(v: (&'a str, u16, bool)) -> Self {
41+
Self::new(v.0, v.1, v.2)
42+
}
43+
}
44+
45+
impl Display for EndpointRef<'_> {
46+
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
47+
let proto = if self.tls { "tcp" } else { "tcp+tls" };
48+
write!(f, "{}://{}:{}", proto, self.host, self.port)
49+
}
50+
}
51+
52+
impl PartialEq<(&str, u16, bool)> for EndpointRef<'_> {
53+
fn eq(&self, other: &(&str, u16, bool)) -> bool {
54+
self.host == other.0 && self.port == other.1 && self.tls == other.2
55+
}
56+
}
57+
58+
impl<'a> ToRef<'a, EndpointRef<'a>> for Endpoint {
59+
fn to_ref(&'a self) -> EndpointRef<'a> {
60+
return EndpointRef::new(self.host.as_str(), self.port, self.tls);
61+
}
62+
}
63+
64+
impl<'a> Ref<'a> for EndpointRef<'a> {
65+
type Value = Endpoint;
66+
67+
fn to_value(&self) -> Endpoint {
68+
Endpoint::new(self.host.to_owned(), self.port, self.tls)
69+
}
70+
}
71+
72+
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
73+
struct InvalidAddress(&'static &'static str);
74+
75+
impl From<InvalidAddress> for Error {
76+
fn from(_: InvalidAddress) -> Error {
77+
Error::BadArguments(&"invalid address")
78+
}
79+
}
80+
81+
fn parse_host_port(host: &str, port: &str) -> Result<u16, InvalidAddress> {
82+
if host.is_empty() {
83+
return Err(InvalidAddress(&"empty host"));
84+
}
85+
if port.is_empty() {
86+
return Ok(2181);
87+
}
88+
let port = match port.parse::<u16>() {
89+
Err(_) => return Err(InvalidAddress(&"invalid port")),
90+
Ok(port) => port,
91+
};
92+
if port == 0 {
93+
return Err(InvalidAddress(&"invalid port number"));
94+
}
95+
Ok(port)
96+
}
97+
98+
fn parse_address(s: &str) -> Result<(&str, u16), InvalidAddress> {
99+
let (host, port_str) = if s.starts_with('[') {
100+
let i = match s.rfind(']') {
101+
None => return Err(InvalidAddress(&"invalid address")),
102+
Some(i) => i,
103+
};
104+
let host = &s[1..i];
105+
let mut remains = &s[i + 1..];
106+
if !remains.is_empty() {
107+
if remains.as_bytes()[0] != b':' {
108+
return Err(InvalidAddress(&"invalid address"));
109+
}
110+
remains = &remains[1..];
111+
}
112+
(host, remains)
113+
} else {
114+
match s.rfind(':') {
115+
None => (s, Default::default()),
116+
Some(i) => (&s[..i], &s[i + 1..]),
117+
}
118+
};
119+
let port = parse_host_port(host, port_str)?;
120+
Ok((host, port))
121+
}
122+
123+
/// Parses connection string to host port pairs and chroot.
124+
pub fn parse_connect_string(s: &str, tls: bool) -> Result<(Vec<EndpointRef<'_>>, Chroot<'_>), Error> {
125+
let mut chroot = None;
126+
let mut endpoints = Vec::with_capacity(10);
127+
for s in s.rsplit(',') {
128+
let (mut hostport, tls) = if let Some(s) = s.strip_prefix("tcp://") {
129+
(s, false)
130+
} else if let Some(s) = s.strip_prefix("tcp+tls://") {
131+
(s, true)
132+
} else if s.is_empty() {
133+
let err = if chroot.is_none() {
134+
Error::BadArguments(&"empty connect string")
135+
} else {
136+
Error::BadArguments(&"invalid address")
137+
};
138+
return Err(err);
139+
} else {
140+
(s, tls)
141+
};
142+
if chroot.is_none() {
143+
chroot = Some(Chroot::default());
144+
if let Some(i) = hostport.find('/') {
145+
chroot = Some(Chroot::new(&hostport[i..])?);
146+
hostport = &hostport[..i];
147+
}
148+
}
149+
let (host, port) = parse_address(hostport)?;
150+
endpoints.push(EndpointRef::new(host, port, tls));
151+
}
152+
endpoints.reverse();
153+
Ok((endpoints, chroot.unwrap()))
154+
}
155+
156+
#[derive(Clone, Debug)]
157+
pub struct IterableEndpoints {
158+
cycle: bool,
159+
next: usize,
160+
endpoints: Vec<Endpoint>,
161+
}
162+
163+
impl IterableEndpoints {
164+
pub fn new(endpoints: impl Into<Vec<Endpoint>>) -> Self {
165+
Self { cycle: false, next: 0, endpoints: endpoints.into() }
166+
}
167+
168+
pub fn endpoints(&self) -> &[Endpoint] {
169+
&self.endpoints
170+
}
171+
172+
pub fn cycle(&mut self) {
173+
if self.next >= self.endpoints.len() {
174+
self.next = 0;
175+
}
176+
self.cycle = true;
177+
}
178+
179+
pub fn next(&mut self) -> Option<EndpointRef<'_>> {
180+
let next = self.next;
181+
if next >= self.endpoints.len() {
182+
return None;
183+
}
184+
self.step();
185+
let host = &self.endpoints[next];
186+
Some(host.to_ref())
187+
}
188+
189+
pub fn step(&mut self) {
190+
self.next += 1;
191+
if self.cycle && self.next >= self.endpoints.len() {
192+
self.next = 0;
193+
}
194+
}
195+
}
196+
197+
impl From<&[EndpointRef<'_>]> for IterableEndpoints {
198+
fn from(endpoints: &[EndpointRef<'_>]) -> Self {
199+
let endpoints: Vec<_> = endpoints.iter().map(|endpoint| endpoint.to_value()).collect();
200+
Self::new(endpoints)
201+
}
202+
}
203+
204+
#[cfg(test)]
205+
mod tests {
206+
use pretty_assertions::assert_eq;
207+
208+
use crate::chroot::Chroot;
209+
use crate::error::Error;
210+
211+
#[test]
212+
fn test_parse_address_v4() {
213+
use super::{parse_address, InvalidAddress};
214+
assert_eq!(parse_address("fasl:0").unwrap_err(), InvalidAddress(&"invalid port number"));
215+
assert_eq!(parse_address(":1234").unwrap_err(), InvalidAddress(&"empty host"));
216+
assert_eq!(parse_address("fasl:a234").unwrap_err(), InvalidAddress(&"invalid port"));
217+
assert_eq!(parse_address("fasl:1234").unwrap(), ("fasl", 1234));
218+
assert_eq!(parse_address("fasl:2181").unwrap(), ("fasl", 2181));
219+
assert_eq!(parse_address("fasl").unwrap(), ("fasl", 2181));
220+
}
221+
222+
#[test]
223+
fn test_parse_address_v6() {
224+
use super::{parse_address, InvalidAddress};
225+
assert_eq!(parse_address("[fasl").unwrap_err(), InvalidAddress(&"invalid address"));
226+
assert_eq!(parse_address("[fasl]:0").unwrap_err(), InvalidAddress(&"invalid port number"));
227+
assert_eq!(parse_address("[]:1234").unwrap_err(), InvalidAddress(&"empty host"));
228+
assert_eq!(parse_address("[fasl]:a234").unwrap_err(), InvalidAddress(&"invalid port"));
229+
assert_eq!(parse_address("[fasl]:1234").unwrap(), ("fasl", 1234));
230+
assert_eq!(parse_address("[fasl]").unwrap(), ("fasl", 2181));
231+
assert_eq!(parse_address("[::1]:2181").unwrap(), ("::1", 2181));
232+
}
233+
234+
#[test]
235+
fn test_parse_connect_string() {
236+
use super::parse_connect_string;
237+
238+
assert_eq!(parse_connect_string("", false).unwrap_err(), Error::BadArguments(&"empty connect string"));
239+
assert_eq!(parse_connect_string("host1:abc", false).unwrap_err(), Error::BadArguments(&"invalid address"));
240+
assert_eq!(
241+
parse_connect_string("host1/abc/", true).unwrap_err(),
242+
Error::BadArguments(&"path must not end with '/'")
243+
);
244+
245+
assert_eq!(
246+
parse_connect_string("host1", false).unwrap(),
247+
(vec![("host1", 2181, false).into()], Chroot::default())
248+
);
249+
assert_eq!(
250+
parse_connect_string("host1", true).unwrap(),
251+
(vec![("host1", 2181, true).into()], Chroot::default())
252+
);
253+
assert_eq!(
254+
parse_connect_string("tcp+tls://host1", false).unwrap(),
255+
(vec![("host1", 2181, true).into()], Chroot::default())
256+
);
257+
assert_eq!(
258+
parse_connect_string("host1,host2:2222/", false).unwrap(),
259+
(vec![("host1", 2181, false).into(), ("host2", 2222, false).into()], Chroot::default())
260+
);
261+
assert_eq!(
262+
parse_connect_string("host1,host2:2222/abc", false).unwrap(),
263+
(vec![("host1", 2181, false).into(), ("host2", 2222, false).into()], Chroot::new("/abc").unwrap())
264+
);
265+
assert_eq!(
266+
parse_connect_string("host1,tcp+tls://host2:2222,tcp://host3/abc", true).unwrap(),
267+
(
268+
vec![("host1", 2181, true).into(), ("host2", 2222, true).into(), ("host3", 2181, false).into()],
269+
Chroot::new("/abc").unwrap()
270+
)
271+
);
272+
}
273+
274+
#[test]
275+
fn test_iterable_endpoints() {
276+
use super::{parse_connect_string, EndpointRef, IterableEndpoints};
277+
let (endpoints, _) = parse_connect_string("host1:2181,tcp://host2,tcp+tls://host3:2182", true).unwrap();
278+
let mut endpoints = IterableEndpoints::from(endpoints.as_slice());
279+
assert_eq!(endpoints.next(), Some(EndpointRef::new("host1", 2181, true)));
280+
assert_eq!(endpoints.next(), Some(EndpointRef::new("host2", 2181, false)));
281+
assert_eq!(endpoints.next(), Some(EndpointRef::new("host3", 2182, true)));
282+
assert_eq!(endpoints.next(), None);
283+
284+
endpoints.cycle();
285+
assert_eq!(endpoints.next(), Some(EndpointRef::new("host1", 2181, true)));
286+
assert_eq!(endpoints.next(), Some(EndpointRef::new("host2", 2181, false)));
287+
assert_eq!(endpoints.next(), Some(EndpointRef::new("host3", 2182, true)));
288+
assert_eq!(endpoints.next(), Some(EndpointRef::new("host1", 2181, true)));
289+
}
290+
}

src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
mod acl;
22
mod chroot;
33
mod client;
4+
mod endpoint;
45
mod error;
56
mod proto;
67
mod record;

0 commit comments

Comments
 (0)