Skip to content

Commit dfb3e00

Browse files
Show copy progress for CopyFileToMachine on Shell Service (#4918)
1 parent f12082c commit dfb3e00

File tree

3 files changed

+174
-11
lines changed

3 files changed

+174
-11
lines changed

cli/app.go

+10
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ const (
6969
generalFlagTags = "tags"
7070
generalFlagStart = "start"
7171
generalFlagEnd = "end"
72+
generalFlagNoProgress = "no-progress"
7273

7374
moduleFlagLanguage = "language"
7475
moduleFlagPublicNamespace = "public-namespace"
@@ -2349,6 +2350,11 @@ Copy multiple files from the machine to a local destination with recursion and k
23492350
// Note(erd): maybe support access time in the future if needed
23502351
Usage: "preserve modification times and file mode bits from the source files",
23512352
},
2353+
&cli.BoolFlag{
2354+
Name: generalFlagNoProgress,
2355+
Aliases: []string{"n"},
2356+
Usage: "hide progress of the file transfer",
2357+
},
23522358
}...),
23532359
Action: createCommandWithT[machinesPartCopyFilesArgs](MachinesPartCopyFilesAction),
23542360
},
@@ -2821,6 +2827,10 @@ This won't work unless you have an existing installation of our GitHub app on yo
28212827
Name: moduleFlagLocal,
28222828
Usage: "if the target machine is localhost, run the entrypoint directly rather than transferring a bundle",
28232829
},
2830+
&cli.BoolFlag{
2831+
Name: generalFlagNoProgress,
2832+
Usage: "hide progress of the file transfer",
2833+
},
28242834
&cli.StringFlag{
28252835
Name: moduleFlagHomeDir,
28262836
Usage: "remote user's home directory. only necessary if you're targeting a remote machine where $HOME is not /root",

cli/client.go

+160-8
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import (
77
"encoding/json"
88
"fmt"
99
"io"
10+
"io/fs"
11+
"math"
1012
"net"
1113
"net/http"
1214
"net/url"
@@ -1289,6 +1291,7 @@ type machinesPartCopyFilesArgs struct {
12891291
Part string
12901292
Recursive bool
12911293
Preserve bool
1294+
NoProgress bool
12921295
}
12931296

12941297
// MachinesPartCopyFilesAction is the corresponding Action for 'machines part cp'.
@@ -1395,6 +1398,7 @@ func (c *viamClient) machinesPartCopyFilesAction(
13951398
paths,
13961399
destination,
13971400
logger,
1401+
flagArgs.NoProgress,
13981402
)
13991403
}
14001404
if err := doCopy(); err != nil {
@@ -2471,12 +2475,13 @@ func (c *viamClient) copyFilesToMachine(
24712475
paths []string,
24722476
destination string,
24732477
logger logging.Logger,
2478+
noProgress bool,
24742479
) error {
24752480
shellSvc, closeClient, err := c.connectToShellService(orgStr, locStr, robotStr, partStr, debug, logger)
24762481
if err != nil {
24772482
return err
24782483
}
2479-
return c.copyFilesToMachineInner(shellSvc, closeClient, allowRecursion, preserve, paths, destination)
2484+
return c.copyFilesToMachineInner(shellSvc, closeClient, allowRecursion, preserve, paths, destination, noProgress)
24802485
}
24812486

24822487
// copyFilesToFqdn is a copyFilesToMachine variant that makes use of pre-fetched part FQDN.
@@ -2488,12 +2493,13 @@ func (c *viamClient) copyFilesToFqdn(
24882493
paths []string,
24892494
destination string,
24902495
logger logging.Logger,
2496+
noProgress bool,
24912497
) error {
24922498
shellSvc, closeClient, err := c.connectToShellServiceFqdn(fqdn, debug, logger)
24932499
if err != nil {
24942500
return err
24952501
}
2496-
return c.copyFilesToMachineInner(shellSvc, closeClient, allowRecursion, preserve, paths, destination)
2502+
return c.copyFilesToMachineInner(shellSvc, closeClient, allowRecursion, preserve, paths, destination, noProgress)
24972503
}
24982504

24992505
// copyFilesToMachineInner is the common logic for both copyFiles variants.
@@ -2504,16 +2510,81 @@ func (c *viamClient) copyFilesToMachineInner(
25042510
preserve bool,
25052511
paths []string,
25062512
destination string,
2513+
noProgress bool,
25072514
) error {
25082515
defer func() {
25092516
utils.UncheckedError(closeClient(c.c.Context))
25102517
}()
25112518

2512-
// prepare a factory that understands the file copying service (RPC or not).
2513-
copyFactory := shell.NewCopyFileToMachineFactory(destination, preserve, shellSvc)
2514-
// make a reader copier that just does the traversal and copy work for us. Think of
2515-
// this as a tee reader.
2516-
readCopier, err := shell.NewLocalFileReadCopier(paths, allowRecursion, false, copyFactory)
2519+
if noProgress {
2520+
// prepare a factory that understands the file copying service (RPC or not).
2521+
copyFactory := shell.NewCopyFileToMachineFactory(destination, preserve, shellSvc)
2522+
// make a reader copier that just does the traversal and copy work for us. Think of
2523+
// this as a tee reader.
2524+
readCopier, err := shell.NewLocalFileReadCopier(paths, allowRecursion, false, copyFactory)
2525+
if err != nil {
2526+
return err
2527+
}
2528+
defer func() {
2529+
if err := readCopier.Close(c.c.Context); err != nil {
2530+
utils.UncheckedError(err)
2531+
}
2532+
}()
2533+
2534+
// ReadAll the files into the copier.
2535+
return readCopier.ReadAll(c.c.Context)
2536+
}
2537+
2538+
// Calculate total size of all files to be copied
2539+
var totalSize int64
2540+
for _, path := range paths {
2541+
info, err := os.Stat(path)
2542+
if err != nil {
2543+
return err
2544+
}
2545+
if info.IsDir() && allowRecursion {
2546+
err := filepath.Walk(path, func(path string, info os.FileInfo, err error) error {
2547+
if err != nil {
2548+
return err
2549+
}
2550+
if !info.IsDir() {
2551+
totalSize += info.Size()
2552+
}
2553+
return nil
2554+
})
2555+
if err != nil {
2556+
return err
2557+
}
2558+
} else if !info.IsDir() {
2559+
totalSize += info.Size()
2560+
}
2561+
}
2562+
2563+
// Create a progress tracking function
2564+
var currentFile string
2565+
progressFunc := func(bytes int64, file string, fileSize int64) {
2566+
if file != currentFile {
2567+
if currentFile != "" {
2568+
//nolint:errcheck // progress display is non-critical
2569+
_, _ = os.Stdout.WriteString("\n")
2570+
}
2571+
currentFile = file
2572+
//nolint:errcheck // progress display is non-critical
2573+
_, _ = os.Stdout.WriteString(fmt.Sprintf("Copying %s...\n", file))
2574+
}
2575+
uploadPercent := int(math.Ceil(100 * float64(bytes) / float64(fileSize)))
2576+
//nolint:errcheck // progress display is non-critical
2577+
_, _ = os.Stdout.WriteString(fmt.Sprintf("\rProgress: %d%% (%d/%d bytes)", uploadPercent, bytes, fileSize))
2578+
}
2579+
2580+
// Wrap the copy factory to track progress
2581+
progressFactory := &progressTrackingFactory{
2582+
factory: shell.NewCopyFileToMachineFactory(destination, preserve, shellSvc),
2583+
onProgress: progressFunc,
2584+
}
2585+
2586+
// Create a new read copier with the progress tracking factory
2587+
readCopier, err := shell.NewLocalFileReadCopier(paths, allowRecursion, false, progressFactory)
25172588
if err != nil {
25182589
return err
25192590
}
@@ -2524,7 +2595,88 @@ func (c *viamClient) copyFilesToMachineInner(
25242595
}()
25252596

25262597
// ReadAll the files into the copier.
2527-
return readCopier.ReadAll(c.c.Context)
2598+
err = readCopier.ReadAll(c.c.Context)
2599+
return err
2600+
}
2601+
2602+
// progressTrackingFactory wraps a copy factory to track progress.
2603+
type progressTrackingFactory struct {
2604+
factory shell.FileCopyFactory
2605+
onProgress func(int64, string, int64)
2606+
}
2607+
2608+
func (ptf *progressTrackingFactory) MakeFileCopier(ctx context.Context, sourceType shell.CopyFilesSourceType) (shell.FileCopier, error) {
2609+
copier, err := ptf.factory.MakeFileCopier(ctx, sourceType)
2610+
if err != nil {
2611+
return nil, err
2612+
}
2613+
return &progressTrackingCopier{
2614+
copier: copier,
2615+
onProgress: ptf.onProgress,
2616+
}, nil
2617+
}
2618+
2619+
// progressTrackingCopier wraps a file copier to track progress.
2620+
type progressTrackingCopier struct {
2621+
copier shell.FileCopier
2622+
onProgress func(int64, string, int64)
2623+
}
2624+
2625+
func (ptc *progressTrackingCopier) Copy(ctx context.Context, file shell.File) error {
2626+
// Get file size
2627+
info, err := file.Data.Stat()
2628+
if err != nil {
2629+
return err
2630+
}
2631+
fileSize := info.Size()
2632+
2633+
// Create a progress tracking reader
2634+
progressReader := &progressReader{
2635+
reader: file.Data,
2636+
onProgress: ptc.onProgress,
2637+
fileName: file.RelativeName,
2638+
fileSize: fileSize,
2639+
}
2640+
2641+
// Create a new file with the progress tracking reader
2642+
progressFile := shell.File{
2643+
RelativeName: file.RelativeName,
2644+
Data: progressReader,
2645+
}
2646+
2647+
return ptc.copier.Copy(ctx, progressFile)
2648+
}
2649+
2650+
func (ptc *progressTrackingCopier) Close(ctx context.Context) error {
2651+
//nolint:errcheck // progress display is non-critical
2652+
_, _ = os.Stdout.WriteString("\n")
2653+
return ptc.copier.Close(ctx)
2654+
}
2655+
2656+
// progressReader wraps a reader to track progress.
2657+
type progressReader struct {
2658+
reader fs.File
2659+
onProgress func(int64, string, int64)
2660+
copied int64
2661+
fileName string
2662+
fileSize int64
2663+
}
2664+
2665+
func (pr *progressReader) Read(p []byte) (int, error) {
2666+
n, err := pr.reader.Read(p)
2667+
if n > 0 {
2668+
pr.copied += int64(n)
2669+
pr.onProgress(pr.copied, pr.fileName, pr.fileSize)
2670+
}
2671+
return n, err
2672+
}
2673+
2674+
func (pr *progressReader) Stat() (fs.FileInfo, error) {
2675+
return pr.reader.Stat()
2676+
}
2677+
2678+
func (pr *progressReader) Close() error {
2679+
return pr.reader.Close()
25282680
}
25292681

25302682
func (c *viamClient) copyFilesFromMachine(

cli/module_build.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,7 @@ type reloadModuleArgs struct {
488488
RestartOnly bool
489489
NoBuild bool
490490
Local bool
491+
NoProgress bool
491492
}
492493

493494
// ReloadModuleAction builds a module, configures it on a robot, and starts or restarts it.
@@ -558,14 +559,14 @@ func reloadModuleAction(c *cli.Context, vc *viamClient, args reloadModuleArgs, l
558559
return err
559560
}
560561
infof(c.App.Writer, "Copying %s to part %s", manifest.Build.Path, part.Part.Id)
561-
args, err := getGlobalArgs(c)
562+
globalArgs, err := getGlobalArgs(c)
562563
if err != nil {
563564
return err
564565
}
565566
dest := reloadingDestination(c, manifest)
566567
err = vc.copyFilesToFqdn(
567-
part.Part.Fqdn, args.Debug, false, false, []string{manifest.Build.Path},
568-
dest, logging.NewLogger("reload"))
568+
part.Part.Fqdn, globalArgs.Debug, false, false, []string{manifest.Build.Path},
569+
dest, logging.NewLogger("reload"), args.NoProgress)
569570
if err != nil {
570571
if s, ok := status.FromError(err); ok && s.Code() == codes.PermissionDenied {
571572
warningf(c.App.ErrWriter, "RDK couldn't write to the default file copy destination. "+

0 commit comments

Comments
 (0)