Skip to content

Commit 7d637c3

Browse files
authored
Merge pull request #38 from AvivNaaman/33-use-a-common-connection-info-struct
#33 common Arc<ConnectionInfo> instead of handler derefs, and adding …
2 parents 8858a01 + 1dea7fa commit 7d637c3

File tree

14 files changed

+240
-174
lines changed

14 files changed

+240
-174
lines changed

smb/src/connection.rs

Lines changed: 59 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
pub mod config;
2-
pub mod negotiation_state;
2+
pub mod connection_info;
33
pub mod netbios_client;
44
pub mod preauth_hash;
55
pub mod transformer;
@@ -22,8 +22,8 @@ use crate::{
2222
};
2323
use binrw::prelude::*;
2424
pub use config::*;
25+
use connection_info::{ConnectionInfo, NegotiatedProperties};
2526
use maybe_async::*;
26-
use negotiation_state::{ConnectionInfo, NegotiatedProperties};
2727
use netbios_client::NetBiosClient;
2828
use std::cmp::max;
2929
use std::sync::atomic::{AtomicU16, AtomicU64};
@@ -38,14 +38,18 @@ pub struct Connection {
3838
}
3939

4040
impl Connection {
41+
/// Creates a new SMB connection, specifying a server configuration, without connecting to a server.
42+
/// Use the [`connect`](Connection::connect) method to establish a connection.
4143
pub fn build(config: ConnectionConfig) -> crate::Result<Connection> {
4244
config.validate()?;
45+
let client_guid = config.client_guid.unwrap_or_else(Guid::gen);
4346
Ok(Connection {
44-
handler: HandlerReference::new(ConnectionMessageHandler::new()),
47+
handler: HandlerReference::new(ConnectionMessageHandler::new(client_guid)),
4548
config,
4649
})
4750
}
4851

52+
/// Sets operations timeout for the connection.
4953
#[maybe_async]
5054
pub async fn set_timeout(&mut self, timeout: Option<Duration>) -> crate::Result<()> {
5155
self.config.timeout = timeout;
@@ -55,8 +59,13 @@ impl Connection {
5559
Ok(())
5660
}
5761

62+
/// Connects to the specified server, if it is not already connected, and negotiates the connection.
5863
#[maybe_async]
5964
pub async fn connect(&mut self, address: &str) -> crate::Result<()> {
65+
if self.handler.worker().is_some() {
66+
return Err(Error::InvalidState("Already connected".into()));
67+
}
68+
6069
let mut netbios_client = NetBiosClient::new(self.config.timeout);
6170

6271
log::debug!("Connecting to {}...", address);
@@ -76,6 +85,7 @@ impl Connection {
7685
}
7786
}
7887

88+
/// This method switches the netbios client to SMB2 and starts the worker.
7989
#[maybe_async]
8090
async fn negotiate_switch_to_smb2(
8191
&mut self,
@@ -125,10 +135,11 @@ impl Connection {
125135
Ok(WorkerImpl::start(netbios_client, self.config.timeout).await?)
126136
}
127137

138+
/// This method perofrms the SMB2 negotiation.
128139
#[maybe_async]
129-
async fn negotiate_smb2(&mut self) -> crate::Result<()> {
140+
async fn negotiate_smb2(&mut self) -> crate::Result<ConnectionInfo> {
130141
// Confirm that we're not already negotiated.
131-
if self.handler.negotiate_info().is_some() {
142+
if self.handler.conn_info.get().is_some() {
132143
return Err(Error::InvalidState("Already negotiated".into()));
133144
}
134145

@@ -157,12 +168,17 @@ impl Connection {
157168

158169
// Send SMB2 negotiate request
159170
let client_guid = self.handler.client_guid;
171+
let hostname = self
172+
.config
173+
.client_name
174+
.clone()
175+
.unwrap_or_else(|| "smb-client".to_string());
160176
let response = self
161177
.handler
162178
.send_recv(Content::NegotiateRequest(NegotiateRequest::new(
163-
"AVIV-MBP".to_string(),
179+
hostname,
164180
client_guid,
165-
dialects.clone(),
181+
dialects,
166182
encryption_algos,
167183
crypto::ENCRYPTING_ALGOS.to_vec(),
168184
compression::SUPPORTED_ALGORITHMS.to_vec(),
@@ -178,21 +194,21 @@ impl Connection {
178194
))?;
179195

180196
// well, only 3.1 is supported for starters.
181-
if !dialects.contains(&smb2_negotiate_response.dialect_revision.try_into()?) {
197+
let dialect_rev = smb2_negotiate_response.dialect_revision.try_into()?;
198+
if dialect_rev > max_dialect || dialect_rev < min_dialect {
182199
return Err(Error::NegotiationError(
183200
"Server selected an unsupported dialect.".into(),
184201
));
185202
}
186203

187-
let dialect_rev = smb2_negotiate_response.dialect_revision.try_into()?;
188204
let dialect_impl = DialectImpl::new(dialect_rev);
189-
let mut state = NegotiatedProperties {
205+
let mut negotiation = NegotiatedProperties {
190206
server_guid: smb2_negotiate_response.server_guid,
191207
caps: smb2_negotiate_response.capabilities.clone(),
192208
max_transact_size: smb2_negotiate_response.max_transact_size,
193209
max_read_size: smb2_negotiate_response.max_read_size,
194210
max_write_size: smb2_negotiate_response.max_write_size,
195-
gss_token: smb2_negotiate_response.buffer.clone(),
211+
auth_buffer: smb2_negotiate_response.buffer.clone(),
196212
signing_algo: None,
197213
encryption_cipher: None,
198214
compression: None,
@@ -201,11 +217,11 @@ impl Connection {
201217

202218
dialect_impl.process_negotiate_request(
203219
&smb2_negotiate_response,
204-
&mut state,
220+
&mut negotiation,
205221
&self.config,
206222
)?;
207223
if ((!u32::from_le_bytes(dialect_impl.get_negotiate_caps_mask().into_bytes()))
208-
& u32::from_le_bytes(state.caps.into_bytes()))
224+
& u32::from_le_bytes(negotiation.caps.into_bytes()))
209225
!= 0
210226
{
211227
return Err(Error::NegotiationError(
@@ -216,19 +232,14 @@ impl Connection {
216232
log::trace!(
217233
"Negotiated SMB results: dialect={:?}, state={:?}",
218234
dialect_rev,
219-
&state
235+
&negotiation
220236
);
221237

222-
self.handler
223-
.negotiate_info
224-
.set(ConnectionInfo {
225-
state,
226-
dialect: dialect_impl,
227-
config: self.config.clone(),
228-
})
229-
.unwrap();
230-
231-
Ok(())
238+
Ok(ConnectionInfo {
239+
negotiation,
240+
dialect: dialect_impl,
241+
config: self.config.clone(),
242+
})
232243
}
233244

234245
/// Send negotiate messages, potentially
@@ -238,7 +249,7 @@ impl Connection {
238249
netbios_client: NetBiosClient,
239250
multi_protocol: bool,
240251
) -> crate::Result<()> {
241-
if self.handler.negotiate_info().is_some() {
252+
if self.handler.conn_info.get().is_some() {
242253
return Err(Error::InvalidState("Already negotiated".into()));
243254
}
244255

@@ -250,14 +261,18 @@ impl Connection {
250261
self.handler.worker.set(worker).unwrap();
251262

252263
// Negotiate SMB2
253-
self.negotiate_smb2().await?;
264+
let info = self.negotiate_smb2().await?;
265+
254266
self.handler
255267
.worker
256268
.get()
257269
.ok_or("Worker is uninitialized")
258270
.unwrap()
259-
.negotaite_complete(&self.handler.negotiate_info().unwrap())
271+
.negotaite_complete(&info)
260272
.await;
273+
274+
self.handler.conn_info.set(Arc::new(info)).unwrap();
275+
261276
log::info!("Negotiation successful");
262277
Ok(())
263278
}
@@ -268,11 +283,13 @@ impl Connection {
268283
user_name: &str,
269284
password: String,
270285
) -> crate::Result<Session> {
271-
let mut session = Session::new(self.handler.clone());
272-
273-
session.setup(user_name, password).await?;
274-
275-
Ok(session)
286+
Session::setup(
287+
user_name,
288+
password,
289+
self.handler.clone(),
290+
self.handler.conn_info.get().unwrap(),
291+
)
292+
.await
276293
}
277294
}
278295

@@ -286,7 +303,7 @@ pub struct ConnectionMessageHandler {
286303
worker: OnceCell<Arc<WorkerImpl>>,
287304

288305
// Negotiation-related state.
289-
negotiate_info: OnceCell<ConnectionInfo>,
306+
conn_info: OnceCell<Arc<ConnectionInfo>>,
290307

291308
/// Number of credits available to the client at the moment, for the next requests.
292309
curr_credits: Semaphore,
@@ -297,22 +314,18 @@ pub struct ConnectionMessageHandler {
297314
}
298315

299316
impl ConnectionMessageHandler {
300-
fn new() -> ConnectionMessageHandler {
317+
fn new(client_guid: Guid) -> ConnectionMessageHandler {
301318
ConnectionMessageHandler {
302-
client_guid: Guid::gen(),
319+
client_guid,
303320
worker: OnceCell::new(),
304-
negotiate_info: OnceCell::new(),
321+
conn_info: OnceCell::new(),
305322
extra_credits_to_request: 4,
306323
curr_credits: Semaphore::new(1),
307324
curr_msg_id: AtomicU64::new(1),
308325
credit_pool: AtomicU16::new(1),
309326
}
310327
}
311328

312-
pub fn negotiate_info(&self) -> Option<&ConnectionInfo> {
313-
self.negotiate_info.get()
314-
}
315-
316329
pub fn worker(&self) -> Option<&Arc<WorkerImpl>> {
317330
self.worker.get()
318331
}
@@ -328,8 +341,8 @@ impl ConnectionMessageHandler {
328341

329342
#[maybe_async]
330343
async fn process_sequence_outgoing(&self, msg: &mut OutgoingMessage) -> crate::Result<()> {
331-
if let Some(neg) = self.negotiate_info() {
332-
if neg.state.dialect_rev > Dialect::Smb0202 && neg.state.caps.large_mtu() {
344+
if let Some(neg) = self.conn_info.get() {
345+
if neg.negotiation.dialect_rev > Dialect::Smb0202 && neg.negotiation.caps.large_mtu() {
333346
// Calculate the cost of the message (charge).
334347
let cost = if Self::SET_CREDIT_CHARGE_CMDS
335348
.iter()
@@ -379,8 +392,8 @@ impl ConnectionMessageHandler {
379392

380393
#[maybe_async]
381394
async fn process_sequence_incoming(&self, msg: &IncomingMessage) -> crate::Result<()> {
382-
if let Some(neg) = self.negotiate_info() {
383-
if neg.state.dialect_rev > Dialect::Smb0202 && neg.state.caps.large_mtu() {
395+
if let Some(neg) = self.conn_info.get() {
396+
if neg.negotiation.dialect_rev > Dialect::Smb0202 && neg.negotiation.caps.large_mtu() {
384397
let granted_credits = msg.message.header.credit_request;
385398
let charged_credits = msg.message.header.credit_charge;
386399
// Update the pool size - return how many EXTRA credits were granted.
@@ -409,8 +422,8 @@ impl MessageHandler for ConnectionMessageHandler {
409422
#[maybe_async]
410423
async fn sendo(&self, mut msg: OutgoingMessage) -> crate::Result<SendMessageResult> {
411424
// TODO: Add assertion in the struct regarding the selected dialect!
412-
let priority_value = match self.negotiate_info.get() {
413-
Some(neg_info) => match neg_info.state.dialect_rev {
425+
let priority_value = match self.conn_info.get() {
426+
Some(neg_info) => match neg_info.negotiation.dialect_rev {
414427
Dialect::Smb0311 => 1,
415428
_ => 0,
416429
},

smb/src/connection/config.rs

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
//! Connection configuration settings.
2+
13
use std::time::Duration;
24

3-
use crate::packets::smb2::Dialect;
5+
use crate::packets::{guid::Guid, smb2::Dialect};
46

7+
/// Specifies the encryption mode for the connection.
8+
/// Use this as part of the [ConnectionConfig] to specify the encryption mode for the connection.
59
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
610
pub enum EncryptionMode {
711
/// Encryption is allowed but not required, it's up to the server to decide.
@@ -28,15 +32,39 @@ impl EncryptionMode {
2832
/// Specifies the configuration for a connection.
2933
#[derive(Debug, Default, Clone)]
3034
pub struct ConnectionConfig {
35+
/// Specifies the timeout for the connection.
3136
pub timeout: Option<Duration>,
37+
38+
/// Specifies the minimum and maximum dialects to be used in the connection.
39+
///
40+
/// Note, that if set, the minimum dialect must be less than or equal to the maximum dialect.
3241
pub min_dialect: Option<Dialect>,
42+
43+
/// Specifies the minimum and maximum dialects to be used in the connection.
44+
///
45+
/// Note, that if set, the minimum dialect must be less than or equal to the maximum dialect.
3346
pub max_dialect: Option<Dialect>,
47+
48+
/// Sets the encryption mode for the connection.
49+
/// See [EncryptionMode] for more information.
3450
pub encryption_mode: EncryptionMode,
51+
52+
/// Whether to enable compression, if supported by the server and specified connection dialects.
53+
///
54+
/// Note: you must also have compression features enabled when building the crate, otherwise compression
55+
/// would not be available. *The compression feature is enabled by default.*
3556
pub compression_enabled: bool,
57+
58+
/// Specifies the client name to be used in the SMB2 negotiate request.
59+
pub client_name: Option<String>,
60+
61+
/// Specifies the GUID of the client to be used in the SMB2 negotiate request.
62+
/// If not set, a random GUID will be generated.
63+
pub client_guid: Option<Guid>,
3664
}
3765

3866
impl ConnectionConfig {
39-
/// Validates the configuration.
67+
/// Validates common configuration settings.
4068
pub fn validate(&self) -> crate::Result<()> {
4169
// Make sure dialects min <= max.
4270
if let (Some(min), Some(max)) = (self.min_dialect, self.max_dialect) {
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
use std::sync::Arc;
2+
3+
use crate::dialects::DialectImpl;
4+
use crate::packets::guid::Guid;
5+
use crate::packets::smb2::*;
6+
use binrw::prelude::*;
7+
8+
use super::ConnectionConfig;
9+
10+
/// Contains important information from the negotiation process,
11+
/// to be used during connection operations.
12+
#[derive(Debug)]
13+
pub struct NegotiatedProperties {
14+
/// From the server's negotiation response.
15+
pub server_guid: Guid,
16+
17+
/// From the server's negotiation response.
18+
pub caps: GlobalCapabilities,
19+
20+
/// From the server's negotiation response.
21+
pub max_transact_size: u32,
22+
/// From the server's negotiation response.
23+
pub max_read_size: u32,
24+
/// From the server's negotiation response.
25+
pub max_write_size: u32,
26+
27+
/// From the server's negotiation response.
28+
pub auth_buffer: Vec<u8>,
29+
30+
/// Signing algorithm used for the connection, and specified by the server
31+
/// using negotiation context. This is irrelevant for dialects below 3.1.1,
32+
/// and if not specified, this property is not set, but the connection may still be
33+
/// signed using the default algorithm, as specified in the spec.
34+
pub signing_algo: Option<SigningAlgorithmId>,
35+
/// Encryption cipher used for the connection, and specified by the server
36+
/// using negotiation context. This is irrelevant for dialects below 3.1.1,
37+
/// and if not specified, this property is not set, but the connection may still be
38+
/// encrypted using the default cipher, as specified in the spec.
39+
pub encryption_cipher: Option<EncryptionCipher>,
40+
/// Compression capabilities used for the connection, and specified by the server
41+
/// using negotiation context.
42+
pub compression: Option<CompressionCaps>,
43+
44+
/// The selected dialect revision for the connection.
45+
/// Use [ConnectionInfo::dialect] to get the implementation of the selected dialect.
46+
pub dialect_rev: Dialect,
47+
}
48+
49+
/// This struct is initalized once a connection is established and negotiated.
50+
/// It contains all the information about the connection.
51+
#[derive(Debug)]
52+
pub struct ConnectionInfo {
53+
/// Contains negotiated properties of the connection.
54+
pub negotiation: NegotiatedProperties,
55+
/// Contains the implementation of the selected dialect.
56+
pub dialect: Arc<DialectImpl>,
57+
/// Contains the configuration of the connection, as specified by the user when the connection was established.
58+
pub config: ConnectionConfig,
59+
}

0 commit comments

Comments
 (0)