|
| 1 | +use std::fmt::{self, Display, Formatter}; |
| 2 | + |
| 3 | +use crate::chroot::Chroot; |
| 4 | +use crate::error::Error; |
| 5 | +use crate::util::{Ref, ToRef}; |
| 6 | + |
| 7 | +#[derive(Debug, Clone, PartialEq, Eq)] |
| 8 | +pub struct Endpoint { |
| 9 | + pub host: String, |
| 10 | + pub port: u16, |
| 11 | + pub tls: bool, |
| 12 | +} |
| 13 | + |
| 14 | +#[derive(Debug, Copy, Clone, PartialEq, Eq)] |
| 15 | +pub struct EndpointRef<'a> { |
| 16 | + pub host: &'a str, |
| 17 | + pub port: u16, |
| 18 | + pub tls: bool, |
| 19 | +} |
| 20 | + |
| 21 | +impl Endpoint { |
| 22 | + pub fn new(host: impl Into<String>, port: u16, tls: bool) -> Self { |
| 23 | + Self { host: host.into(), port, tls } |
| 24 | + } |
| 25 | +} |
| 26 | + |
| 27 | +impl<'a> EndpointRef<'a> { |
| 28 | + pub fn new(host: &'a str, port: u16, tls: bool) -> Self { |
| 29 | + Self { host, port, tls } |
| 30 | + } |
| 31 | +} |
| 32 | + |
| 33 | +impl Display for Endpoint { |
| 34 | + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { |
| 35 | + self.to_ref().fmt(f) |
| 36 | + } |
| 37 | +} |
| 38 | + |
| 39 | +impl<'a> From<(&'a str, u16, bool)> for EndpointRef<'a> { |
| 40 | + fn from(v: (&'a str, u16, bool)) -> Self { |
| 41 | + Self::new(v.0, v.1, v.2) |
| 42 | + } |
| 43 | +} |
| 44 | + |
| 45 | +impl Display for EndpointRef<'_> { |
| 46 | + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { |
| 47 | + let proto = if self.tls { "tcp" } else { "tcp+tls" }; |
| 48 | + write!(f, "{}://{}:{}", proto, self.host, self.port) |
| 49 | + } |
| 50 | +} |
| 51 | + |
| 52 | +impl PartialEq<(&str, u16, bool)> for EndpointRef<'_> { |
| 53 | + fn eq(&self, other: &(&str, u16, bool)) -> bool { |
| 54 | + self.host == other.0 && self.port == other.1 && self.tls == other.2 |
| 55 | + } |
| 56 | +} |
| 57 | + |
| 58 | +impl<'a> ToRef<'a, EndpointRef<'a>> for Endpoint { |
| 59 | + fn to_ref(&'a self) -> EndpointRef<'a> { |
| 60 | + return EndpointRef::new(self.host.as_str(), self.port, self.tls); |
| 61 | + } |
| 62 | +} |
| 63 | + |
| 64 | +impl<'a> Ref<'a> for EndpointRef<'a> { |
| 65 | + type Value = Endpoint; |
| 66 | + |
| 67 | + fn to_value(&self) -> Endpoint { |
| 68 | + Endpoint::new(self.host.to_owned(), self.port, self.tls) |
| 69 | + } |
| 70 | +} |
| 71 | + |
| 72 | +#[derive(Clone, Copy, Debug, PartialEq, Eq)] |
| 73 | +struct InvalidAddress(&'static &'static str); |
| 74 | + |
| 75 | +impl From<InvalidAddress> for Error { |
| 76 | + fn from(_: InvalidAddress) -> Error { |
| 77 | + Error::BadArguments(&"invalid address") |
| 78 | + } |
| 79 | +} |
| 80 | + |
| 81 | +fn parse_host_port(host: &str, port: &str) -> Result<u16, InvalidAddress> { |
| 82 | + if host.is_empty() { |
| 83 | + return Err(InvalidAddress(&"empty host")); |
| 84 | + } |
| 85 | + if port.is_empty() { |
| 86 | + return Ok(2181); |
| 87 | + } |
| 88 | + let port = match port.parse::<u16>() { |
| 89 | + Err(_) => return Err(InvalidAddress(&"invalid port")), |
| 90 | + Ok(port) => port, |
| 91 | + }; |
| 92 | + if port == 0 { |
| 93 | + return Err(InvalidAddress(&"invalid port number")); |
| 94 | + } |
| 95 | + Ok(port) |
| 96 | +} |
| 97 | + |
| 98 | +fn parse_address(s: &str) -> Result<(&str, u16), InvalidAddress> { |
| 99 | + let (host, port_str) = if s.starts_with('[') { |
| 100 | + let i = match s.rfind(']') { |
| 101 | + None => return Err(InvalidAddress(&"invalid address")), |
| 102 | + Some(i) => i, |
| 103 | + }; |
| 104 | + let host = &s[1..i]; |
| 105 | + let mut remains = &s[i + 1..]; |
| 106 | + if !remains.is_empty() { |
| 107 | + if remains.as_bytes()[0] != b':' { |
| 108 | + return Err(InvalidAddress(&"invalid address")); |
| 109 | + } |
| 110 | + remains = &remains[1..]; |
| 111 | + } |
| 112 | + (host, remains) |
| 113 | + } else { |
| 114 | + match s.rfind(':') { |
| 115 | + None => (s, Default::default()), |
| 116 | + Some(i) => (&s[..i], &s[i + 1..]), |
| 117 | + } |
| 118 | + }; |
| 119 | + let port = parse_host_port(host, port_str)?; |
| 120 | + Ok((host, port)) |
| 121 | +} |
| 122 | + |
| 123 | +/// Parses connection string to host port pairs and chroot. |
| 124 | +pub fn parse_connect_string(s: &str, tls: bool) -> Result<(Vec<EndpointRef<'_>>, Chroot<'_>), Error> { |
| 125 | + let mut chroot = None; |
| 126 | + let mut endpoints = Vec::with_capacity(10); |
| 127 | + for s in s.rsplit(',') { |
| 128 | + let (mut hostport, tls) = if let Some(s) = s.strip_prefix("tcp://") { |
| 129 | + (s, false) |
| 130 | + } else if let Some(s) = s.strip_prefix("tcp+tls://") { |
| 131 | + (s, true) |
| 132 | + } else if s.is_empty() { |
| 133 | + let err = if chroot.is_none() { |
| 134 | + Error::BadArguments(&"empty connect string") |
| 135 | + } else { |
| 136 | + Error::BadArguments(&"invalid address") |
| 137 | + }; |
| 138 | + return Err(err); |
| 139 | + } else { |
| 140 | + (s, tls) |
| 141 | + }; |
| 142 | + if chroot.is_none() { |
| 143 | + chroot = Some(Chroot::default()); |
| 144 | + if let Some(i) = hostport.find('/') { |
| 145 | + chroot = Some(Chroot::new(&hostport[i..])?); |
| 146 | + hostport = &hostport[..i]; |
| 147 | + } |
| 148 | + } |
| 149 | + let (host, port) = parse_address(hostport)?; |
| 150 | + endpoints.push(EndpointRef::new(host, port, tls)); |
| 151 | + } |
| 152 | + endpoints.reverse(); |
| 153 | + Ok((endpoints, chroot.unwrap())) |
| 154 | +} |
| 155 | + |
| 156 | +#[derive(Clone, Debug)] |
| 157 | +pub struct IterableEndpoints { |
| 158 | + cycle: bool, |
| 159 | + next: usize, |
| 160 | + endpoints: Vec<Endpoint>, |
| 161 | +} |
| 162 | + |
| 163 | +impl IterableEndpoints { |
| 164 | + pub fn new(endpoints: impl Into<Vec<Endpoint>>) -> Self { |
| 165 | + Self { cycle: false, next: 0, endpoints: endpoints.into() } |
| 166 | + } |
| 167 | + |
| 168 | + pub fn endpoints(&self) -> &[Endpoint] { |
| 169 | + &self.endpoints |
| 170 | + } |
| 171 | + |
| 172 | + pub fn cycle(&mut self) { |
| 173 | + if self.next >= self.endpoints.len() { |
| 174 | + self.next = 0; |
| 175 | + } |
| 176 | + self.cycle = true; |
| 177 | + } |
| 178 | + |
| 179 | + pub fn next(&mut self) -> Option<EndpointRef<'_>> { |
| 180 | + let next = self.next; |
| 181 | + if next >= self.endpoints.len() { |
| 182 | + return None; |
| 183 | + } |
| 184 | + self.step(); |
| 185 | + let host = &self.endpoints[next]; |
| 186 | + Some(host.to_ref()) |
| 187 | + } |
| 188 | + |
| 189 | + pub fn step(&mut self) { |
| 190 | + self.next += 1; |
| 191 | + if self.cycle && self.next >= self.endpoints.len() { |
| 192 | + self.next = 0; |
| 193 | + } |
| 194 | + } |
| 195 | +} |
| 196 | + |
| 197 | +impl From<&[EndpointRef<'_>]> for IterableEndpoints { |
| 198 | + fn from(endpoints: &[EndpointRef<'_>]) -> Self { |
| 199 | + let endpoints: Vec<_> = endpoints.iter().map(|endpoint| endpoint.to_value()).collect(); |
| 200 | + Self::new(endpoints) |
| 201 | + } |
| 202 | +} |
| 203 | + |
| 204 | +#[cfg(test)] |
| 205 | +mod tests { |
| 206 | + use pretty_assertions::assert_eq; |
| 207 | + |
| 208 | + use crate::chroot::Chroot; |
| 209 | + use crate::error::Error; |
| 210 | + |
| 211 | + #[test] |
| 212 | + fn test_parse_address_v4() { |
| 213 | + use super::{parse_address, InvalidAddress}; |
| 214 | + assert_eq!(parse_address("fasl:0").unwrap_err(), InvalidAddress(&"invalid port number")); |
| 215 | + assert_eq!(parse_address(":1234").unwrap_err(), InvalidAddress(&"empty host")); |
| 216 | + assert_eq!(parse_address("fasl:a234").unwrap_err(), InvalidAddress(&"invalid port")); |
| 217 | + assert_eq!(parse_address("fasl:1234").unwrap(), ("fasl", 1234)); |
| 218 | + assert_eq!(parse_address("fasl:2181").unwrap(), ("fasl", 2181)); |
| 219 | + assert_eq!(parse_address("fasl").unwrap(), ("fasl", 2181)); |
| 220 | + } |
| 221 | + |
| 222 | + #[test] |
| 223 | + fn test_parse_address_v6() { |
| 224 | + use super::{parse_address, InvalidAddress}; |
| 225 | + assert_eq!(parse_address("[fasl").unwrap_err(), InvalidAddress(&"invalid address")); |
| 226 | + assert_eq!(parse_address("[fasl]:0").unwrap_err(), InvalidAddress(&"invalid port number")); |
| 227 | + assert_eq!(parse_address("[]:1234").unwrap_err(), InvalidAddress(&"empty host")); |
| 228 | + assert_eq!(parse_address("[fasl]:a234").unwrap_err(), InvalidAddress(&"invalid port")); |
| 229 | + assert_eq!(parse_address("[fasl]:1234").unwrap(), ("fasl", 1234)); |
| 230 | + assert_eq!(parse_address("[fasl]").unwrap(), ("fasl", 2181)); |
| 231 | + assert_eq!(parse_address("[::1]:2181").unwrap(), ("::1", 2181)); |
| 232 | + } |
| 233 | + |
| 234 | + #[test] |
| 235 | + fn test_parse_connect_string() { |
| 236 | + use super::parse_connect_string; |
| 237 | + |
| 238 | + assert_eq!(parse_connect_string("", false).unwrap_err(), Error::BadArguments(&"empty connect string")); |
| 239 | + assert_eq!(parse_connect_string("host1:abc", false).unwrap_err(), Error::BadArguments(&"invalid address")); |
| 240 | + assert_eq!( |
| 241 | + parse_connect_string("host1/abc/", true).unwrap_err(), |
| 242 | + Error::BadArguments(&"path must not end with '/'") |
| 243 | + ); |
| 244 | + |
| 245 | + assert_eq!( |
| 246 | + parse_connect_string("host1", false).unwrap(), |
| 247 | + (vec![("host1", 2181, false).into()], Chroot::default()) |
| 248 | + ); |
| 249 | + assert_eq!( |
| 250 | + parse_connect_string("host1", true).unwrap(), |
| 251 | + (vec![("host1", 2181, true).into()], Chroot::default()) |
| 252 | + ); |
| 253 | + assert_eq!( |
| 254 | + parse_connect_string("tcp+tls://host1", false).unwrap(), |
| 255 | + (vec![("host1", 2181, true).into()], Chroot::default()) |
| 256 | + ); |
| 257 | + assert_eq!( |
| 258 | + parse_connect_string("host1,host2:2222/", false).unwrap(), |
| 259 | + (vec![("host1", 2181, false).into(), ("host2", 2222, false).into()], Chroot::default()) |
| 260 | + ); |
| 261 | + assert_eq!( |
| 262 | + parse_connect_string("host1,host2:2222/abc", false).unwrap(), |
| 263 | + (vec![("host1", 2181, false).into(), ("host2", 2222, false).into()], Chroot::new("/abc").unwrap()) |
| 264 | + ); |
| 265 | + assert_eq!( |
| 266 | + parse_connect_string("host1,tcp+tls://host2:2222,tcp://host3/abc", true).unwrap(), |
| 267 | + ( |
| 268 | + vec![("host1", 2181, true).into(), ("host2", 2222, true).into(), ("host3", 2181, false).into()], |
| 269 | + Chroot::new("/abc").unwrap() |
| 270 | + ) |
| 271 | + ); |
| 272 | + } |
| 273 | + |
| 274 | + #[test] |
| 275 | + fn test_iterable_endpoints() { |
| 276 | + use super::{parse_connect_string, EndpointRef, IterableEndpoints}; |
| 277 | + let (endpoints, _) = parse_connect_string("host1:2181,tcp://host2,tcp+tls://host3:2182", true).unwrap(); |
| 278 | + let mut endpoints = IterableEndpoints::from(endpoints.as_slice()); |
| 279 | + assert_eq!(endpoints.next(), Some(EndpointRef::new("host1", 2181, true))); |
| 280 | + assert_eq!(endpoints.next(), Some(EndpointRef::new("host2", 2181, false))); |
| 281 | + assert_eq!(endpoints.next(), Some(EndpointRef::new("host3", 2182, true))); |
| 282 | + assert_eq!(endpoints.next(), None); |
| 283 | + |
| 284 | + endpoints.cycle(); |
| 285 | + assert_eq!(endpoints.next(), Some(EndpointRef::new("host1", 2181, true))); |
| 286 | + assert_eq!(endpoints.next(), Some(EndpointRef::new("host2", 2181, false))); |
| 287 | + assert_eq!(endpoints.next(), Some(EndpointRef::new("host3", 2182, true))); |
| 288 | + assert_eq!(endpoints.next(), Some(EndpointRef::new("host1", 2181, true))); |
| 289 | + } |
| 290 | +} |
0 commit comments