Skip to content

Commit 90df9f5

Browse files
committed
refactor: add Builder for Session to avoid too many arguments
1 parent fea7703 commit 90df9f5

File tree

2 files changed

+116
-72
lines changed

2 files changed

+116
-72
lines changed

src/client/mod.rs

+12-28
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@ use either::{Either, Left, Right};
1111
use ignore_result::Ignore;
1212
use thiserror::Error;
1313
use tokio::sync::mpsc;
14-
use tracing::field::display;
15-
use tracing::{instrument, Span};
14+
use tracing::instrument;
1615

1716
pub use self::watcher::{OneshotWatcher, PersistentWatcher, StateWatcher};
1817
use super::session::{Depot, MarshalledRequest, Session, SessionOperation, WatchReceiver};
@@ -1649,32 +1648,17 @@ impl Connector {
16491648
#[instrument(name = "connect", skip_all, fields(session))]
16501649
async fn connect_internally(&mut self, secure: bool, cluster: &str) -> Result<Client> {
16511650
let (endpoints, chroot) = endpoint::parse_connect_string(cluster, secure)?;
1652-
if let Some(session) = self.session.as_ref() {
1653-
if session.is_readonly() {
1654-
return Err(Error::new_other(
1655-
format!("can't reestablish readonly and hence local session {}", session.id()),
1656-
None,
1657-
));
1658-
}
1659-
Span::current().record("session", display(session.id()));
1660-
}
1661-
if self.session_timeout < Duration::ZERO {
1662-
return Err(Error::BadArguments(&"session timeout must not be negative"));
1663-
} else if self.connection_timeout < Duration::ZERO {
1664-
return Err(Error::BadArguments(&"connection timeout must not be negative"));
1665-
}
1666-
let tls_config = self.tls.take().unwrap_or_default().into_config()?;
1667-
let (mut session, state_receiver) = Session::new(
1668-
self.session.take(),
1669-
&self.authes,
1670-
self.readonly,
1671-
self.detached,
1672-
tls_config,
1673-
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
1674-
self.sasl.take(),
1675-
self.session_timeout,
1676-
self.connection_timeout,
1677-
);
1651+
let builder = Session::builder()
1652+
.with_tls(self.tls.take())
1653+
.with_session(self.session.take())
1654+
.with_authes(&self.authes)
1655+
.with_readonly(self.readonly)
1656+
.with_detached(self.detached)
1657+
.with_session_timeout(self.session_timeout)
1658+
.with_connection_timeout(self.connection_timeout);
1659+
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
1660+
let builder = builder.with_sasl(self.sasl.take());
1661+
let (mut session, state_receiver) = builder.build()?;
16781662
let mut endpoints = IterableEndpoints::from(endpoints.as_slice());
16791663
endpoints.reset();
16801664
if !self.fail_eagerly {

src/session/mod.rs

+104-44
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ use std::io;
1010
use std::time::Duration;
1111

1212
use ignore_result::Ignore;
13-
use rustls::ClientConfig;
1413
use tokio::select;
1514
use tokio::sync::mpsc;
1615
use tokio::time::{self, Instant};
@@ -39,6 +38,7 @@ use crate::proto::{AuthPacket, ConnectRequest, ConnectResponse, ErrorCode, OpCod
3938
use crate::record;
4039
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
4140
use crate::sasl::{SaslInitiator, SaslOptions, SaslSession};
41+
use crate::tls::TlsOptions;
4242

4343
pub const PASSWORD_LEN: usize = 16;
4444
pub const DEFAULT_SESSION_TIMEOUT: Duration = Duration::from_secs(6);
@@ -59,59 +59,81 @@ impl RequestOperation for (WatcherId, StateResponser) {
5959
}
6060
}
6161

62-
pub struct Session {
62+
#[derive(Default)]
63+
pub struct Builder {
64+
tls: Option<TlsOptions>,
65+
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
66+
sasl: Option<SaslOptions>,
67+
authes: Vec<MarshalledRequest>,
6368
readonly: bool,
6469
detached: bool,
70+
session: Option<SessionInfo>,
71+
session_timeout: Duration,
72+
connection_timeout: Duration,
73+
}
6574

66-
connector: Connector,
75+
impl Builder {
76+
pub fn with_tls(self, tls: Option<TlsOptions>) -> Self {
77+
Self { tls, ..self }
78+
}
6779

68-
configured_connection_timeout: Duration,
80+
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
81+
pub fn with_sasl(self, sasl: Option<SaslOptions>) -> Self {
82+
Self { sasl, ..self }
83+
}
6984

70-
last_zxid: i64,
71-
last_recv: Instant,
72-
last_send: Instant,
73-
last_ping: Option<Instant>,
74-
tick_timeout: Duration,
75-
ping_timeout: Duration,
76-
session_expired_timeout: Duration,
85+
pub fn with_authes(self, authes: &[AuthPacket]) -> Self {
86+
Self { authes: authes.iter().map(|auth| MarshalledRequest::new(OpCode::Auth, auth)).collect(), ..self }
87+
}
7788

78-
pub session: SessionInfo,
79-
session_state: SessionState,
80-
pub session_timeout: Duration,
89+
pub fn with_readonly(self, readonly: bool) -> Self {
90+
Self { readonly, ..self }
91+
}
8192

82-
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
83-
sasl_options: Option<SaslOptions>,
84-
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
85-
sasl_session: Option<SaslSession>,
93+
pub fn with_detached(self, detached: bool) -> Self {
94+
Self { detached, ..self }
95+
}
8696

87-
pub authes: Vec<MarshalledRequest>,
88-
state_sender: tokio::sync::watch::Sender<SessionState>,
97+
pub fn with_session(self, session: Option<SessionInfo>) -> Self {
98+
Self { session, ..self }
99+
}
89100

90-
watch_manager: WatchManager,
91-
unwatch_receiver: Option<mpsc::UnboundedReceiver<(WatcherId, StateResponser)>>,
92-
}
101+
pub fn with_session_timeout(self, session_timeout: Duration) -> Self {
102+
Self { session_timeout, ..self }
103+
}
93104

94-
impl Session {
95-
#[allow(clippy::too_many_arguments)]
96-
pub fn new(
97-
session: Option<SessionInfo>,
98-
authes: &[AuthPacket],
99-
readonly: bool,
100-
detached: bool,
101-
tls_config: ClientConfig,
102-
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))] sasl_options: Option<SaslOptions>,
103-
session_timeout: Duration,
104-
connection_timeout: Duration,
105-
) -> (Session, tokio::sync::watch::Receiver<SessionState>) {
106-
let session = session.unwrap_or_else(|| SessionInfo::new(SessionId(0), Vec::with_capacity(PASSWORD_LEN)));
105+
pub fn with_connection_timeout(self, connection_timeout: Duration) -> Self {
106+
Self { connection_timeout, ..self }
107+
}
108+
109+
pub fn build(self) -> Result<(Session, tokio::sync::watch::Receiver<SessionState>), Error> {
110+
let session = match self.session {
111+
Some(session) => {
112+
if session.is_readonly() {
113+
return Err(Error::new_other(
114+
format!("can't reestablish readonly and hence local session {}", session.id()),
115+
None,
116+
));
117+
}
118+
Span::current().record("session", display(session.id()));
119+
session
120+
},
121+
None => SessionInfo::new(SessionId(0), Vec::with_capacity(PASSWORD_LEN)),
122+
};
123+
if self.session_timeout < Duration::ZERO {
124+
return Err(Error::BadArguments(&"session timeout must not be negative"));
125+
} else if self.connection_timeout < Duration::ZERO {
126+
return Err(Error::BadArguments(&"connection timeout must not be negative"));
127+
}
128+
let tls_config = self.tls.unwrap_or_default().into_config()?;
107129
let (state_sender, state_receiver) = tokio::sync::watch::channel(SessionState::Disconnected);
108130
let now = Instant::now();
109131
let (watch_manager, unwatch_receiver) = WatchManager::new();
110132
let mut session = Session {
111-
readonly,
112-
detached,
133+
readonly: self.readonly,
134+
detached: self.detached,
113135

114-
configured_connection_timeout: connection_timeout,
136+
configured_connection_timeout: self.connection_timeout,
115137

116138
last_zxid: session.last_zxid,
117139
last_recv: now,
@@ -122,22 +144,60 @@ impl Session {
122144
session_expired_timeout: Duration::ZERO,
123145
connector: Connector::new(tls_config),
124146
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
125-
sasl_options,
147+
sasl_options: self.sasl,
126148
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
127149
sasl_session: None,
128150

129151
session,
130-
session_timeout,
152+
session_timeout: self.session_timeout,
131153
session_state: SessionState::Disconnected,
132154

133-
authes: authes.iter().map(|auth| MarshalledRequest::new(OpCode::Auth, auth)).collect(),
155+
authes: self.authes,
134156
state_sender,
135157
watch_manager,
136158
unwatch_receiver: Some(unwatch_receiver),
137159
};
138-
let timeout = if session_timeout.is_zero() { DEFAULT_SESSION_TIMEOUT } else { session_timeout };
160+
let timeout = if self.session_timeout.is_zero() { DEFAULT_SESSION_TIMEOUT } else { self.session_timeout };
139161
session.reset_timeout(timeout);
140-
(session, state_receiver)
162+
Ok((session, state_receiver))
163+
}
164+
}
165+
166+
pub struct Session {
167+
readonly: bool,
168+
detached: bool,
169+
170+
connector: Connector,
171+
172+
configured_connection_timeout: Duration,
173+
174+
last_zxid: i64,
175+
last_recv: Instant,
176+
last_send: Instant,
177+
last_ping: Option<Instant>,
178+
tick_timeout: Duration,
179+
ping_timeout: Duration,
180+
session_expired_timeout: Duration,
181+
182+
pub session: SessionInfo,
183+
session_state: SessionState,
184+
pub session_timeout: Duration,
185+
186+
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
187+
sasl_options: Option<SaslOptions>,
188+
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
189+
sasl_session: Option<SaslSession>,
190+
191+
pub authes: Vec<MarshalledRequest>,
192+
state_sender: tokio::sync::watch::Sender<SessionState>,
193+
194+
watch_manager: WatchManager,
195+
unwatch_receiver: Option<mpsc::UnboundedReceiver<(WatcherId, StateResponser)>>,
196+
}
197+
198+
impl Session {
199+
pub fn builder() -> Builder {
200+
Builder::default()
141201
}
142202

143203
fn is_readonly_allowed(&self) -> bool {

0 commit comments

Comments
 (0)