Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 25 additions & 20 deletions smb/src/client/smb_client.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{collections::HashMap, str::FromStr};

use std::sync::{Arc};
use maybe_async::maybe_async;

use tokio::sync::Mutex;
use crate::{
Connection, Error, FileCreateArgs, Resource, Session, Tree,
packets::{
Expand All @@ -18,7 +18,7 @@ use super::{config::ClientConfig, unc_path::UncPath};
pub struct Client {
config: ClientConfig,

connections: HashMap<UncPath, OpenedConnectionInfo>,
connections: Mutex<HashMap<UncPath, Arc<OpenedConnectionInfo>>>,
}

struct OpenedConnectionInfo {
Expand All @@ -33,13 +33,13 @@ impl Client {
pub fn new(config: ClientConfig) -> Self {
Client {
config,
connections: HashMap::new(),
connections: Mutex::new(HashMap::new()),
}
}

#[maybe_async]
pub async fn close(&mut self) -> crate::Result<()> {
self.connections.clear();
pub async fn close(&self) -> crate::Result<()> {
self.connections.lock().await.clear();
Ok(())
}

Expand Down Expand Up @@ -67,7 +67,7 @@ impl Client {

#[maybe_async]
pub async fn share_connect(
&mut self,
&self,
unc: &UncPath,
user_name: &str,
password: String,
Expand All @@ -80,7 +80,8 @@ impl Client {

let share_unc = unc.clone().with_no_path();

if self.connections.contains_key(&share_unc) {
let mut connections = self.connections.lock().await;
if connections.contains_key(&share_unc) {
log::warn!("Connection already exists for this UNC path. Reusing it.");
return Ok(());
}
Expand All @@ -102,14 +103,15 @@ impl Client {
}

log::debug!("Connected to share {share_unc} with user {user_name}");
self.connections.insert(share_unc, opened_conn_info);
connections.insert(share_unc, Arc::new(opened_conn_info));

Ok(())
}

fn get_opened_conn_for_path(&self, unc: &UncPath) -> crate::Result<&OpenedConnectionInfo> {
if let Some(cst) = self.connections.get(&unc.clone().with_no_path()) {
Ok(cst)
async fn get_opened_conn_for_path(&self, unc: &UncPath) -> crate::Result<Arc<OpenedConnectionInfo>> {
let connections = self.connections.lock().await;
if let Some(cst) = connections.get(&unc.clone().with_no_path()) {
Ok(cst.clone())
} else {
Err(crate::Error::InvalidArgument(format!(
"No connection found for {unc}. Use `share_connect` to create one.",
Expand All @@ -123,7 +125,7 @@ impl Client {
path: &UncPath,
args: &FileCreateArgs,
) -> crate::Result<Resource> {
let conn_info = self.get_opened_conn_for_path(path)?;
let conn_info = self.get_opened_conn_for_path(path).await?;
conn_info
.tree
.create(path.path.as_deref().unwrap_or(""), args)
Expand All @@ -132,7 +134,7 @@ impl Client {

#[maybe_async]
pub async fn create_file(
&mut self,
&self,
path: &UncPath,
args: &FileCreateArgs,
) -> crate::Result<Resource> {
Expand Down Expand Up @@ -173,10 +175,10 @@ impl Client {
}
}

struct DfsResolver<'a>(&'a mut Client);
struct DfsResolver<'a>(&'a Client);

impl<'a> DfsResolver<'a> {
fn new(client: &'a mut Client) -> Self {
fn new(client: &'a Client) -> Self {
DfsResolver(client)
}

Expand All @@ -192,7 +194,7 @@ impl<'a> DfsResolver<'a> {
// Re-use the same credentials for the DFS referral.
let dfs_creds = self
.0
.get_opened_conn_for_path(dfs_path)?
.get_opened_conn_for_path(dfs_path).await?
.creds
.clone()
.ok_or_else(|| {
Expand Down Expand Up @@ -232,14 +234,17 @@ impl<'a> DfsResolver<'a> {
log::debug!("Resolving DFS referral for {unc}");
let dfs_path_string = unc.to_string();

let dfs_root = self.0.get_opened_conn_for_path(unc)?.tree.as_dfs_tree()?;

let dfs_refs = dfs_root.dfs_get_referrals(&dfs_path_string).await?;
let dfs_refs = {
let conn = &self.0.get_opened_conn_for_path(unc).await?;
let dfs_root = conn.tree.as_dfs_tree()?;
dfs_root.dfs_get_referrals(&dfs_path_string).await?
};
if !dfs_refs.referral_header_flags.storage_servers() {
return Err(Error::InvalidMessage(
"DFS referral does not contain storage servers".to_string(),
));
}

let mut paths = vec![];
// Resolve the DFS referral entries.
for (indx, curr_referral) in dfs_refs.referral_entries.iter().enumerate() {
Expand Down
Loading