Skip to content

Commit 58f8fb4

Browse files
committed
refactor!: add SessionInfo to box session id, password and readonly
We also use it to reject readonly session reestablish request as readonly sessions are local to connected node.
1 parent d7ce22c commit 58f8fb4

File tree

5 files changed

+95
-61
lines changed

5 files changed

+95
-61
lines changed

src/client/mod.rs

+21-29
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use thiserror::Error;
1313
use tokio::sync::{mpsc, watch};
1414

1515
pub use self::watcher::{OneshotWatcher, PersistentWatcher, StateWatcher};
16-
use super::session::{Depot, MarshalledRequest, Session, SessionOperation, WatchReceiver, PASSWORD_LEN};
16+
use super::session::{Depot, MarshalledRequest, Session, SessionOperation, WatchReceiver};
1717
use crate::acl::{Acl, Acls, AuthUser};
1818
use crate::chroot::{Chroot, ChrootPath, OwnedChroot};
1919
use crate::endpoint::{self, IterableEndpoints};
@@ -44,7 +44,7 @@ use crate::proto::{
4444
pub use crate::proto::{EnsembleUpdate, Stat};
4545
use crate::record::{self, Record, StaticRecord};
4646
use crate::session::StateReceiver;
47-
pub use crate::session::{EventType, SessionId, SessionState, WatchedEvent};
47+
pub use crate::session::{EventType, SessionId, SessionInfo, SessionState, WatchedEvent};
4848
use crate::tls::TlsOptions;
4949
use crate::util;
5050

@@ -215,7 +215,7 @@ impl CreateSequence {
215215
pub struct Client {
216216
chroot: OwnedChroot,
217217
version: Version,
218-
session: (SessionId, Vec<u8>),
218+
session: SessionInfo,
219219
session_timeout: Duration,
220220
requester: mpsc::UnboundedSender<SessionOperation>,
221221
state_watcher: StateWatcher,
@@ -243,7 +243,7 @@ impl Client {
243243
pub(crate) fn new(
244244
chroot: OwnedChroot,
245245
version: Version,
246-
session: (SessionId, Vec<u8>),
246+
session: SessionInfo,
247247
timeout: Duration,
248248
requester: mpsc::UnboundedSender<SessionOperation>,
249249
state_receiver: watch::Receiver<SessionState>,
@@ -265,18 +265,18 @@ impl Client {
265265
self.chroot.path()
266266
}
267267

268-
/// ZooKeeper session id.
269-
pub fn session_id(&self) -> SessionId {
270-
self.session.0
268+
/// ZooKeeper session info.
269+
pub fn session(&self) -> &SessionInfo {
270+
&self.session
271271
}
272272

273-
/// Session password.
274-
pub fn session_password(&self) -> &[u8] {
275-
self.session.1.as_slice()
273+
/// ZooKeeper session id.
274+
pub fn session_id(&self) -> SessionId {
275+
self.session().id()
276276
}
277277

278278
/// Consumes this instance into session info.
279-
pub fn into_session(self) -> (SessionId, Vec<u8>) {
279+
pub fn into_session(self) -> SessionInfo {
280280
self.session
281281
}
282282

@@ -1538,7 +1538,7 @@ pub(crate) struct Version(u32, u32, u32);
15381538
pub struct Connector {
15391539
tls: Option<TlsOptions>,
15401540
authes: Vec<AuthPacket>,
1541-
session: Option<(SessionId, Vec<u8>)>,
1541+
session: Option<SessionInfo>,
15421542
readonly: bool,
15431543
detached: bool,
15441544
server_version: Version,
@@ -1590,8 +1590,8 @@ impl Connector {
15901590
}
15911591

15921592
/// Specifies session to reestablish.
1593-
pub fn session(&mut self, id: SessionId, password: Vec<u8>) -> &mut Self {
1594-
self.session = Some((id, password));
1593+
pub fn session(&mut self, session: SessionInfo) -> &mut Self {
1594+
self.session = Some(session);
15951595
self
15961596
}
15971597

@@ -1623,14 +1623,12 @@ impl Connector {
16231623

16241624
async fn connect_internally(&mut self, secure: bool, cluster: &str) -> Result<Client> {
16251625
let (endpoints, chroot) = endpoint::parse_connect_string(cluster, secure)?;
1626-
if let Some((id, password)) = &self.session {
1627-
if id.0 == 0 {
1628-
return Err(Error::BadArguments(&"session id must not be 0"));
1629-
} else if password.is_empty() {
1630-
return Err(Error::BadArguments(&formatcp!(
1631-
"session password is empty, it should have length of {}",
1632-
PASSWORD_LEN
1633-
)));
1626+
if let Some(session) = self.session.as_ref() {
1627+
if session.is_readonly() {
1628+
return Err(Error::new_other(
1629+
format!("can't reestablish readonly and hence local session {}", session.id()),
1630+
None,
1631+
));
16341632
}
16351633
}
16361634
if self.session_timeout < Duration::ZERO {
@@ -1653,7 +1651,7 @@ impl Connector {
16531651
let mut connecting_depot = Depot::for_connecting();
16541652
let conn = session.start(&mut endpoints, &mut buf, &mut connecting_depot).await?;
16551653
let (sender, receiver) = mpsc::unbounded_channel();
1656-
let session_info = (session.session_id, session.session_password.clone());
1654+
let session_info = session.session.clone();
16571655
let session_timeout = session.session_timeout;
16581656
tokio::spawn(async move {
16591657
session.serve(endpoints, conn, buf, connecting_depot, receiver).await;
@@ -1729,12 +1727,6 @@ impl ClientBuilder {
17291727
self
17301728
}
17311729

1732-
/// Specifies session to reestablish.
1733-
pub fn with_session(&mut self, id: SessionId, password: Vec<u8>) -> &mut Self {
1734-
self.connector.session(id, password);
1735-
self
1736-
}
1737-
17381730
/// Specifies client assumed server version of ZooKeeper cluster.
17391731
///
17401732
/// Client will issue server compatible protocol to avoid [Error::Unimplemented] for some

src/error.rs

+9-2
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,19 @@ impl Error {
127127
}
128128
}
129129

130+
pub(crate) fn new_other(
131+
message: impl Into<Arc<String>>,
132+
source: Option<Arc<dyn std::error::Error + Send + Sync + 'static>>,
133+
) -> Self {
134+
Self::Other(OtherError { message: message.into(), source })
135+
}
136+
130137
pub(crate) fn other(message: impl Into<String>, source: impl std::error::Error + Send + Sync + 'static) -> Self {
131-
Self::Other(OtherError { message: Arc::new(message.into()), source: Some(Arc::new(source)) })
138+
Self::new_other(message.into(), Some(Arc::new(source)))
132139
}
133140

134141
pub(crate) fn other_from(source: impl std::error::Error + Send + Sync + 'static) -> Self {
135-
Self::Other(OtherError { message: Arc::new(source.to_string()), source: Some(Arc::new(source)) })
142+
Self::new_other(source.to_string(), Some(Arc::new(source)))
136143
}
137144
}
138145

src/session/mod.rs

+18-23
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ pub use self::request::{
2727
StateReceiver,
2828
StateResponser,
2929
};
30-
pub use self::types::{EventType, SessionId, SessionState, WatchedEvent};
30+
pub use self::types::{EventType, SessionId, SessionInfo, SessionState, WatchedEvent};
3131
pub use self::watch::{OneshotReceiver, PersistentReceiver, WatchReceiver};
3232
use self::watch::{WatchManager, WatcherId};
3333
use crate::deadline::Deadline;
@@ -71,11 +71,9 @@ pub struct Session {
7171
ping_timeout: Duration,
7272
session_expired_timeout: Duration,
7373

74-
pub session_id: SessionId,
74+
pub session: SessionInfo,
7575
session_state: SessionState,
7676
pub session_timeout: Duration,
77-
pub session_password: Vec<u8>,
78-
session_readonly: bool,
7977

8078
pub authes: Vec<MarshalledRequest>,
8179
state_sender: tokio::sync::watch::Sender<SessionState>,
@@ -86,16 +84,15 @@ pub struct Session {
8684

8785
impl Session {
8886
pub fn new(
89-
session: Option<(SessionId, Vec<u8>)>,
87+
session: Option<SessionInfo>,
9088
authes: &[AuthPacket],
9189
readonly: bool,
9290
detached: bool,
9391
tls_config: ClientConfig,
9492
session_timeout: Duration,
9593
connection_timeout: Duration,
9694
) -> (Session, tokio::sync::watch::Receiver<SessionState>) {
97-
let (session_id, session_password) =
98-
session.unwrap_or_else(|| (SessionId(0), Vec::with_capacity(PASSWORD_LEN)));
95+
let session = session.unwrap_or_else(|| SessionInfo::new(SessionId(0), Vec::with_capacity(PASSWORD_LEN)));
9996
let (state_sender, state_receiver) = tokio::sync::watch::channel(SessionState::Disconnected);
10097
let now = Instant::now();
10198
let (watch_manager, unwatch_receiver) = WatchManager::new();
@@ -114,11 +111,9 @@ impl Session {
114111
session_expired_timeout: Duration::ZERO,
115112
connector: Connector::new(tls_config),
116113

117-
session_id,
114+
session,
118115
session_timeout,
119116
session_state: SessionState::Disconnected,
120-
session_password,
121-
session_readonly: session_id.0 == 0,
122117

123118
authes: authes.iter().map(|auth| MarshalledRequest::new(OpCode::Auth, auth)).collect(),
124119
state_sender,
@@ -132,7 +127,7 @@ impl Session {
132127

133128
fn is_readonly_allowed(&self) -> bool {
134129
// Session downgrade is not allowed as partitioned session will expired finally by quorum.
135-
self.readonly && self.session_readonly
130+
self.readonly && self.session.readonly
136131
}
137132

138133
async fn close_requester<T: RequestOperation>(mut requester: mpsc::UnboundedReceiver<T>, err: &Error) {
@@ -215,7 +210,7 @@ impl Session {
215210
) {
216211
if let Err(err) = self.serve_session(endpoints, &mut conn, buf, depot, requester, unwatch_requester).await {
217212
self.resolve_serve_error(&err);
218-
log::info!("ZooKeeper session {} state {} error {}", self.session_id, self.session_state, err);
213+
log::info!("ZooKeeper session {} state {} error {}", self.session.id, self.session_state, err);
219214
depot.error(&err);
220215
} else {
221216
self.change_state(SessionState::Disconnected);
@@ -301,7 +296,7 @@ impl Session {
301296
depot.pop_ping()?;
302297
if let Some(last_ping) = self.last_ping.take() {
303298
let elapsed = Instant::now() - last_ping;
304-
log::debug!("ZooKeeper session {} got ping response after {}ms", self.session_id, elapsed.as_millis());
299+
log::debug!("ZooKeeper session {} got ping response after {}ms", self.session.id, elapsed.as_millis());
305300
}
306301
return Ok(());
307302
}
@@ -349,7 +344,7 @@ impl Session {
349344
}
350345

351346
fn complete_connect(&mut self) {
352-
let state = if self.session_readonly { SessionState::ConnectedReadOnly } else { SessionState::SyncConnected };
347+
let state = if self.session.readonly { SessionState::ConnectedReadOnly } else { SessionState::SyncConnected };
353348
self.change_state(state);
354349
}
355350

@@ -361,11 +356,11 @@ impl Session {
361356
} else if !self.is_readonly_allowed() && response.readonly {
362357
return Err(Error::ConnectionLoss);
363358
}
364-
self.session_id = SessionId(response.session_id);
359+
self.session.id = SessionId(response.session_id);
360+
self.session.password.clear();
361+
self.session.password.extend_from_slice(response.password);
362+
self.session.readonly = response.readonly;
365363
self.reset_timeout(Duration::from_millis(response.session_timeout as u64));
366-
self.session_password.clear();
367-
self.session_password.extend_from_slice(response.password);
368-
self.session_readonly = response.readonly;
369364
self.complete_connect();
370365
Ok(())
371366
}
@@ -450,7 +445,7 @@ impl Session {
450445
unwatch_requester: &mut mpsc::UnboundedReceiver<(WatcherId, StateResponser)>,
451446
) -> Result<(), Error> {
452447
let mut seek_for_writable =
453-
if self.session_readonly { Some(self.connector.clone().seek_for_writable(endpoints)) } else { None };
448+
if self.session.readonly { Some(self.connector.clone().seek_for_writable(endpoints)) } else { None };
454449
let mut tick = time::interval(self.tick_timeout);
455450
tick.set_missed_tick_behavior(time::MissedTickBehavior::Skip);
456451
let mut channel_closed = false;
@@ -519,8 +514,8 @@ impl Session {
519514
protocol_version: 0,
520515
last_zxid_seen: 0,
521516
timeout: self.session_timeout.as_millis() as i32,
522-
session_id: if self.session_readonly { 0 } else { self.session_id.0 },
523-
password: self.session_password.as_slice(),
517+
session_id: if self.session.readonly { 0 } else { self.session.id.0 },
518+
password: self.session.password.as_slice(),
524519
readonly: self.is_readonly_allowed(),
525520
};
526521
log::trace!("Sending connect request: {request:?}");
@@ -570,7 +565,7 @@ impl Session {
570565
Err(err)
571566
},
572567
_ => {
573-
log::info!("ZooKeeper succeeds to establish session({}) to {}", self.session_id, endpoint);
568+
log::info!("ZooKeeper succeeds to establish session({}) to {}", self.session.id, endpoint);
574569
Ok(conn)
575570
},
576571
}
@@ -582,7 +577,7 @@ impl Session {
582577
buf: &mut Vec<u8>,
583578
depot: &mut Depot,
584579
) -> Result<Connection, Error> {
585-
let session_timeout = if self.session_id.0 == 0 { self.session_timeout } else { self.session_expired_timeout };
580+
let session_timeout = if self.session.id.0 == 0 { self.session_timeout } else { self.session_expired_timeout };
586581
let mut deadline = Deadline::until(self.last_recv + session_timeout);
587582
let mut last_error = match self.start_once(endpoints, &mut deadline, buf, depot).await {
588583
Err(err) => err,

src/session/types.rs

+36-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use derive_where::derive_where;
12
use num_enum::{IntoPrimitive, TryFromPrimitive};
23
use strum::EnumIter;
34

@@ -6,7 +7,7 @@ use crate::proto::AddWatchMode;
67
use crate::util;
78

89
/// Thin wrapper for zookeeper session id. It prints in hex format headed with 0x.
9-
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
10+
#[derive(Copy, Clone, PartialEq, Eq)]
1011
pub struct SessionId(pub i64);
1112

1213
impl std::fmt::Display for SessionId {
@@ -15,6 +16,40 @@ impl std::fmt::Display for SessionId {
1516
}
1617
}
1718

19+
impl std::fmt::Debug for SessionId {
20+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21+
std::fmt::Display::fmt(self, f)
22+
}
23+
}
24+
25+
/// ZooKeeper session info.
26+
#[derive(Clone)]
27+
#[derive_where(Debug)]
28+
pub struct SessionInfo {
29+
pub(crate) id: SessionId,
30+
#[derive_where(skip(Debug))]
31+
pub(crate) password: Vec<u8>,
32+
pub(crate) readonly: bool,
33+
}
34+
35+
impl SessionInfo {
36+
pub(crate) fn new(id: SessionId, password: Vec<u8>) -> Self {
37+
Self { id, password, readonly: id.0 == 0 }
38+
}
39+
40+
/// Session id.
41+
pub fn id(&self) -> SessionId {
42+
self.id
43+
}
44+
45+
/// Is this an readonly session ?
46+
///
47+
/// Readonly sessions are local to connected server thus not eligible for session reestablishment.
48+
pub fn is_readonly(&self) -> bool {
49+
self.readonly
50+
}
51+
}
52+
1853
/// ZooKeeper session states.
1954
#[derive(Copy, Clone, Debug, PartialEq, Eq, strum::Display)]
2055
pub enum SessionState {

tests/zookeeper.rs

+11-6
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,11 @@ async fn test_connect_session_expired() {
164164
let cluster = Cluster::new().await;
165165
let client = cluster.custom_client(None, |connector| connector.detached()).await.unwrap();
166166
let timeout = client.session_timeout();
167-
let (id, password) = client.into_session();
167+
let session = client.into_session();
168168

169169
tokio::time::sleep(timeout * 2).await;
170170

171-
assert_that!(cluster.custom_client(None, |connector| connector.session(id, password)).await.unwrap_err())
171+
assert_that!(cluster.custom_client(None, |connector| connector.session(session)).await.unwrap_err())
172172
.is_equal_to(zk::Error::SessionExpired);
173173
}
174174

@@ -1701,10 +1701,10 @@ async fn test_client_drop() {
17011701

17021702
let mut state_watcher = client.state_watcher();
17031703
tokio::time::sleep(Duration::from_secs(20)).await;
1704-
let (id, password) = client.into_session();
1704+
let session = client.into_session();
17051705
assert_eq!(zk::SessionState::Closed, state_watcher.changed().await);
17061706

1707-
cluster.custom_client(None, |connector| connector.session(id, password)).await.unwrap_err();
1707+
cluster.custom_client(None, |connector| connector.session(session)).await.unwrap_err();
17081708
}
17091709

17101710
#[test_log::test(tokio::test)]
@@ -1713,10 +1713,10 @@ async fn test_client_detach() {
17131713
let client = cluster.custom_client(None, |connector| connector.detached()).await.unwrap();
17141714

17151715
let mut state_watcher = client.state_watcher();
1716-
let (id, password) = client.into_session();
1716+
let session = client.into_session();
17171717
assert_eq!(zk::SessionState::Closed, state_watcher.changed().await);
17181718

1719-
cluster.custom_client(None, |connector| connector.session(id, password)).await.unwrap();
1719+
cluster.custom_client(None, |connector| connector.session(session)).await.unwrap();
17201720
}
17211721

17221722
fn generate_ca_cert() -> (Certificate, String) {
@@ -1821,6 +1821,11 @@ async fn test_readonly(tls: bool) {
18211821
};
18221822
assert_that!(client.create("/y", b"", PERSISTENT_OPEN).await.unwrap_err()).is_equal_to(zk::Error::NotReadOnly);
18231823

1824+
let session = client.session().clone();
1825+
assert_eq!(session.is_readonly(), true);
1826+
assert_that!(zk::Client::connector().session(session).connect("localhost:4001").await.unwrap_err().to_string())
1827+
.contains("readonly");
1828+
18241829
let mut state_watcher = client.state_watcher();
18251830
assert_eq!(state_watcher.state(), zk::SessionState::ConnectedReadOnly);
18261831

0 commit comments

Comments
 (0)