Skip to content

Commit a64f51d

Browse files
authored
Merge pull request #129 from AvivNaaman/feature/zero-copy
2 parents 648d2f6 + 0e3ffc3 commit a64f51d

31 files changed

+709
-165
lines changed

Cargo.lock

Lines changed: 69 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/smb-cli/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ tokio = { workspace = true, optional = true }
1212
futures-util = { workspace = true, optional = true }
1313

1414
# CLI & logging
15+
indicatif = "0.18.0"
1516
clap = { version = "4.5.27", features = ["derive"] }
1617
env_logger = "0.11.6"
1718
log = { workspace = true }

crates/smb-cli/src/cli.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ pub struct Cli {
2222
#[arg(long)]
2323
pub no_dfs: bool,
2424

25+
/// Opts-in to use SMB compression if the server supports it.
26+
#[arg(long)]
27+
pub compress: bool,
28+
2529
/// Disables NTLM authentication.
2630
#[arg(long)]
2731
pub no_ntlm: bool,
@@ -83,6 +87,7 @@ impl Cli {
8387
kerberos: !self.no_kerberos,
8488
},
8589
allow_unsigned_guest_access: self.disable_message_signing,
90+
compression_enabled: self.compress,
8691
..Default::default()
8792
},
8893
}

crates/smb-cli/src/copy.rs

Lines changed: 87 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
use crate::{Cli, path::*};
22
use clap::Parser;
3+
use indicatif::{ProgressBar, ProgressStyle};
34
use maybe_async::*;
45
use smb::sync_helpers::*;
56
use smb::{Client, CreateOptions, FileAccessMask, FileAttributes, resource::*};
67
use std::error::Error;
78
#[cfg(not(feature = "async"))]
89
use std::fs;
10+
#[cfg(not(feature = "single_threaded"))]
11+
use std::sync::Arc;
12+
#[cfg(feature = "multi_threaded")]
13+
use std::thread::sleep;
914

1015
#[cfg(feature = "async")]
11-
use tokio::fs;
16+
use tokio::{fs, time::sleep};
1217

1318
#[derive(Parser, Debug)]
1419
pub struct CopyCmd {
@@ -91,10 +96,10 @@ impl CopyFile {
9196
match self.value {
9297
Local(from_local) => match to.value {
9398
Local(_) => unreachable!(),
94-
Remote(to_remote) => block_copy(from_local, to_remote, 16).await?,
99+
Remote(to_remote) => Self::do_copy(from_local, to_remote, 16).await?,
95100
},
96101
Remote(from_remote) => match to.value {
97-
Local(to_local) => block_copy(from_remote, to_local, 16).await?,
102+
Local(to_local) => Self::do_copy(from_remote, to_local, 16).await?,
98103
Remote(to_remote) => {
99104
if to.path.as_remote().unwrap().server()
100105
== self.path.as_remote().unwrap().server()
@@ -104,19 +109,96 @@ impl CopyFile {
104109
// Use server-side copy if both files are on the same server
105110
to_remote.srv_copy(&from_remote).await?
106111
} else {
107-
block_copy(from_remote, to_remote, 8).await?
112+
Self::do_copy(from_remote, to_remote, 8).await?
108113
}
109114
}
110115
},
111116
}
112117
Ok(())
113118
}
119+
120+
#[maybe_async]
121+
#[cfg(not(feature = "single_threaded"))]
122+
pub async fn do_copy<
123+
F: ReadAt + GetLen + Send + Sync + 'static,
124+
T: WriteAt + SetLen + Send + Sync + 'static,
125+
>(
126+
from: F,
127+
to: T,
128+
jobs: usize,
129+
) -> smb::Result<()> {
130+
let state = prepare_parallel_copy(&from, &to, jobs).await?;
131+
let state = Arc::new(state);
132+
let progress_handle = Self::progress(state.clone());
133+
start_parallel_copy(from, to, state).await?;
134+
135+
#[cfg(feature = "async")]
136+
progress_handle.await.unwrap();
137+
#[cfg(not(feature = "async"))]
138+
progress_handle.join().unwrap();
139+
Ok(())
140+
}
141+
142+
/// Single-threaded copy implementation.
143+
#[cfg(feature = "single_threaded")]
144+
pub fn do_copy<F: ReadAt + GetLen, T: WriteAt + SetLen>(
145+
from: F,
146+
to: T,
147+
_jobs: usize,
148+
) -> smb::Result<()> {
149+
let progress = Self::make_progress_bar(from.get_len()?);
150+
block_copy_progress(
151+
from,
152+
to,
153+
Some(&move |current| {
154+
progress.set_position(current);
155+
}),
156+
)
157+
}
158+
159+
/// Async progress bar task starter.
160+
#[cfg(feature = "async")]
161+
fn progress(state: Arc<CopyState>) -> tokio::task::JoinHandle<()> {
162+
tokio::task::spawn(async move { Self::progress_loop(state).await })
163+
}
164+
165+
/// Thread progress bar task starter.
166+
#[cfg(feature = "multi_threaded")]
167+
fn progress(state: Arc<CopyState>) -> std::thread::JoinHandle<()> {
168+
std::thread::spawn(move || {
169+
Self::progress_loop(state);
170+
})
171+
}
172+
173+
/// Thread/task entrypoint for measuring and displaying copy progress.
174+
#[cfg(not(feature = "single_threaded"))]
175+
#[maybe_async]
176+
async fn progress_loop(state: Arc<CopyState>) {
177+
let progress_bar = Self::make_progress_bar(state.total_size());
178+
loop {
179+
let bytes_copied = state.bytes_copied();
180+
progress_bar.set_position(bytes_copied);
181+
if bytes_copied >= state.total_size() {
182+
break;
183+
}
184+
sleep(std::time::Duration::from_millis(100)).await;
185+
}
186+
progress_bar.finish_with_message("Copy complete");
187+
}
188+
189+
/// Returns a new progress bar instance for copying files.
190+
fn make_progress_bar(len: u64) -> ProgressBar {
191+
let progress = ProgressBar::new(len);
192+
progress.set_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")
193+
.unwrap().progress_chars("#>-"));
194+
progress
195+
}
114196
}
115197

116198
#[maybe_async]
117199
pub async fn copy(cmd: &CopyCmd, cli: &Cli) -> Result<(), Box<dyn Error>> {
118200
if matches!(cmd.from, Path::Local(_)) && matches!(cmd.to, Path::Local(_)) {
119-
return Err("Copying between two local files is not supported".into());
201+
return Err("Copying between two local files is not supported. Use `cp` or `copy` shell commands instead :)".into());
120202
}
121203

122204
let client = Client::new(cli.make_smb_client_config());

crates/smb-cli/src/path.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ impl FromStr for Path {
3232
type Err = smb::Error;
3333

3434
fn from_str(input: &str) -> Result<Self, Self::Err> {
35+
let input = input.replace('/', r"\");
3536
if input.starts_with(r"\\") {
3637
Ok(Path::Remote(input.parse()?))
3738
} else {

0 commit comments

Comments
 (0)