@@ -4,11 +4,14 @@ use std::borrow::Cow;
4
4
use std:: fmt:: Write as _;
5
5
use std:: future:: Future ;
6
6
use std:: mem:: ManuallyDrop ;
7
+ use std:: sync:: Arc ;
7
8
use std:: time:: Duration ;
8
9
9
10
use const_format:: formatcp;
10
11
use either:: { Either , Left , Right } ;
11
12
use ignore_result:: Ignore ;
13
+ use rustls:: pki_types:: { CertificateDer , PrivateKeyDer } ;
14
+ use rustls:: { ClientConfig , RootCertStore } ;
12
15
use thiserror:: Error ;
13
16
use tokio:: sync:: { mpsc, watch} ;
14
17
@@ -1528,6 +1531,9 @@ pub(crate) struct Version(u32, u32, u32);
1528
1531
/// Builder for [Client] with more options than [Client::connect].
1529
1532
#[ derive( Clone , Debug ) ]
1530
1533
pub struct ClientBuilder {
1534
+ tls : bool ,
1535
+ trusted_certs : RootCertStore ,
1536
+ client_certs : Option < ( Vec < CertificateDer < ' static > > , Arc < PrivateKeyDer < ' static > > ) > ,
1531
1537
authes : Vec < AuthPacket > ,
1532
1538
version : Version ,
1533
1539
session : Option < ( SessionId , Vec < u8 > ) > ,
@@ -1540,6 +1546,9 @@ pub struct ClientBuilder {
1540
1546
impl ClientBuilder {
1541
1547
fn new ( ) -> Self {
1542
1548
Self {
1549
+ tls : false ,
1550
+ trusted_certs : RootCertStore :: empty ( ) ,
1551
+ client_certs : None ,
1543
1552
authes : Default :: default ( ) ,
1544
1553
version : Version ( u32:: MAX , u32:: MAX , u32:: MAX ) ,
1545
1554
session : None ,
@@ -1584,6 +1593,43 @@ impl ClientBuilder {
1584
1593
self
1585
1594
}
1586
1595
1596
+ /// Assumes tls for server in connection string if no protocol specified individually.
1597
+ /// See [Self::connect] for syntax to specify protocol individually.
1598
+ pub fn assume_tls ( & mut self ) -> & mut Self {
1599
+ self . tls = true ;
1600
+ self
1601
+ }
1602
+
1603
+ /// Trusts certificates signed by given ca certificates.
1604
+ pub fn trust_ca_pem_certs ( & mut self , certs : & str ) -> Result < & mut Self > {
1605
+ for r in rustls_pemfile:: certs ( & mut certs. as_bytes ( ) ) {
1606
+ let cert = match r {
1607
+ Ok ( cert) => cert,
1608
+ Err ( err) => return Err ( Error :: other ( format ! ( "fail to read cert {}" , err) , err) ) ,
1609
+ } ;
1610
+ if let Err ( err) = self . trusted_certs . add ( cert) {
1611
+ return Err ( Error :: other ( format ! ( "fail to add cert {}" , err) , err) ) ;
1612
+ }
1613
+ }
1614
+ Ok ( self )
1615
+ }
1616
+
1617
+ /// Identifies client itself to server with given cert chain and private key.
1618
+ pub fn use_client_pem_cert ( & mut self , cert : & str , key : & str ) -> Result < & mut Self > {
1619
+ let r: std:: result:: Result < Vec < _ > , _ > = rustls_pemfile:: certs ( & mut cert. as_bytes ( ) ) . collect ( ) ;
1620
+ let certs = match r {
1621
+ Err ( err) => return Err ( Error :: other ( format ! ( "fail to read cert {}" , err) , err) ) ,
1622
+ Ok ( certs) => certs,
1623
+ } ;
1624
+ let key = match rustls_pemfile:: private_key ( & mut key. as_bytes ( ) ) {
1625
+ Err ( err) => return Err ( Error :: other ( format ! ( "fail to read client private key {err}" ) , err) ) ,
1626
+ Ok ( None ) => return Err ( Error :: BadArguments ( & "no client private key" ) ) ,
1627
+ Ok ( Some ( key) ) => key,
1628
+ } ;
1629
+ self . client_certs = Some ( ( certs, Arc :: new ( key) ) ) ;
1630
+ Ok ( self )
1631
+ }
1632
+
1587
1633
/// Specifies client assumed server version of ZooKeeper cluster.
1588
1634
///
1589
1635
/// Client will issue server compatible protocol to avoid [Error::Unimplemented] for some
@@ -1606,13 +1652,17 @@ impl ClientBuilder {
1606
1652
1607
1653
/// Connects to ZooKeeper cluster.
1608
1654
///
1655
+ /// Parameter `cluster` specifies connection string to ZooKeeper cluster. It has same syntax as
1656
+ /// Java client except that you can specifies protocol for server individually. For example,
1657
+ /// `tcp://server1,tcp+tls://server2:port,server3`. This claims that `server1` uses plaintext
1658
+ /// protocol, `server2` uses tls encrypted protocol while `server3` uses tls if
1659
+ /// [Self::assume_tls] is specified or plaintext otherwise.
1660
+ ///
1609
1661
/// # Notable errors
1610
1662
/// * [Error::NoHosts] if no host is available
1611
1663
/// * [Error::SessionExpired] if specified session expired
1612
1664
pub async fn connect ( & mut self , cluster : & str ) -> Result < Client > {
1613
- let ( hosts, chroot) = util:: parse_connect_string ( cluster) ?;
1614
- let mut buf = Vec :: with_capacity ( 4096 ) ;
1615
- let mut connecting_depot = Depot :: for_connecting ( ) ;
1665
+ let ( hosts, chroot) = util:: parse_connect_string ( cluster, self . tls ) ?;
1616
1666
if let Some ( ( id, password) ) = & self . session {
1617
1667
if id. 0 == 0 {
1618
1668
return Err ( Error :: BadArguments ( & "session id must not be 0" ) ) ;
@@ -1628,22 +1678,39 @@ impl ClientBuilder {
1628
1678
} else if self . connection_timeout < Duration :: ZERO {
1629
1679
return Err ( Error :: BadArguments ( & "connection timeout must not be negative" ) ) ;
1630
1680
}
1681
+ self . trusted_certs . extend ( webpki_roots:: TLS_SERVER_ROOTS . iter ( ) . cloned ( ) ) ;
1682
+ let tls_config = if let Some ( ( certs, private_key) ) = self . client_certs . take ( ) {
1683
+ match ClientConfig :: builder ( )
1684
+ . with_root_certificates ( std:: mem:: replace ( & mut self . trusted_certs , RootCertStore :: empty ( ) ) )
1685
+ . with_client_auth_cert ( certs, Arc :: try_unwrap ( private_key) . unwrap_or_else ( |k| k. clone_key ( ) ) )
1686
+ {
1687
+ Ok ( config) => config,
1688
+ Err ( err) => return Err ( Error :: other ( format ! ( "invalid client private key {err}" ) , err) ) ,
1689
+ }
1690
+ } else {
1691
+ ClientConfig :: builder ( )
1692
+ . with_root_certificates ( std:: mem:: replace ( & mut self . trusted_certs , RootCertStore :: empty ( ) ) )
1693
+ . with_no_client_auth ( )
1694
+ } ;
1631
1695
let ( mut session, state_receiver) = Session :: new (
1632
1696
self . session . take ( ) ,
1633
1697
& self . authes ,
1634
1698
self . readonly ,
1635
1699
self . detached ,
1700
+ tls_config,
1636
1701
self . session_timeout ,
1637
1702
self . connection_timeout ,
1638
1703
) ;
1639
1704
let mut hosts_iter = hosts. iter ( ) . copied ( ) ;
1640
- let sock = session. start ( & mut hosts_iter, & mut buf, & mut connecting_depot) . await ?;
1705
+ let mut buf = Vec :: with_capacity ( 4096 ) ;
1706
+ let mut connecting_depot = Depot :: for_connecting ( ) ;
1707
+ let conn = session. start ( & mut hosts_iter, & mut buf, & mut connecting_depot) . await ?;
1641
1708
let ( sender, receiver) = mpsc:: unbounded_channel ( ) ;
1642
1709
let servers = hosts. into_iter ( ) . map ( |addr| addr. to_value ( ) ) . collect ( ) ;
1643
1710
let session_info = ( session. session_id , session. session_password . clone ( ) ) ;
1644
1711
let session_timeout = session. session_timeout ;
1645
1712
tokio:: spawn ( async move {
1646
- session. serve ( servers, sock , buf, connecting_depot, receiver) . await ;
1713
+ session. serve ( servers, conn , buf, connecting_depot, receiver) . await ;
1647
1714
} ) ;
1648
1715
let client =
1649
1716
Client :: new ( chroot. to_owned ( ) , self . version , session_info, session_timeout, sender, state_receiver) ;
0 commit comments