Skip to content

Show copy progress for CopyFileToMachine on Shell Service #4918

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions cli/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ const (
generalFlagTags = "tags"
generalFlagStart = "start"
generalFlagEnd = "end"
generalFlagNoProgress = "no-progress"

moduleFlagLanguage = "language"
moduleFlagPublicNamespace = "public-namespace"
Expand Down Expand Up @@ -2343,6 +2344,11 @@ Copy multiple files from the machine to a local destination with recursion and k
// Note(erd): maybe support access time in the future if needed
Usage: "preserve modification times and file mode bits from the source files",
},
&cli.BoolFlag{
Name: generalFlagNoProgress,
Aliases: []string{"n"},
Usage: "hide progress of the file transfer",
},
}...),
Action: createCommandWithT[machinesPartCopyFilesArgs](MachinesPartCopyFilesAction),
},
Expand Down Expand Up @@ -2815,6 +2821,10 @@ This won't work unless you have an existing installation of our GitHub app on yo
Name: moduleFlagLocal,
Usage: "if the target machine is localhost, run the entrypoint directly rather than transferring a bundle",
},
&cli.BoolFlag{
Name: generalFlagNoProgress,
Usage: "hide progress of the file transfer",
},
&cli.StringFlag{
Name: moduleFlagHomeDir,
Usage: "remote user's home directory. only necessary if you're targeting a remote machine where $HOME is not /root",
Expand Down
168 changes: 160 additions & 8 deletions cli/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"encoding/json"
"fmt"
"io"
"io/fs"
"math"
"net"
"net/http"
"net/url"
Expand Down Expand Up @@ -1238,6 +1240,7 @@ type machinesPartCopyFilesArgs struct {
Part string
Recursive bool
Preserve bool
NoProgress bool
}

// MachinesPartCopyFilesAction is the corresponding Action for 'machines part cp'.
Expand Down Expand Up @@ -1344,6 +1347,7 @@ func (c *viamClient) machinesPartCopyFilesAction(
paths,
destination,
logger,
flagArgs.NoProgress,
)
}
if err := doCopy(); err != nil {
Expand Down Expand Up @@ -2420,12 +2424,13 @@ func (c *viamClient) copyFilesToMachine(
paths []string,
destination string,
logger logging.Logger,
noProgress bool,
) error {
shellSvc, closeClient, err := c.connectToShellService(orgStr, locStr, robotStr, partStr, debug, logger)
if err != nil {
return err
}
return c.copyFilesToMachineInner(shellSvc, closeClient, allowRecursion, preserve, paths, destination)
return c.copyFilesToMachineInner(shellSvc, closeClient, allowRecursion, preserve, paths, destination, noProgress)
}

// copyFilesToFqdn is a copyFilesToMachine variant that makes use of pre-fetched part FQDN.
Expand All @@ -2437,12 +2442,13 @@ func (c *viamClient) copyFilesToFqdn(
paths []string,
destination string,
logger logging.Logger,
noProgress bool,
) error {
shellSvc, closeClient, err := c.connectToShellServiceFqdn(fqdn, debug, logger)
if err != nil {
return err
}
return c.copyFilesToMachineInner(shellSvc, closeClient, allowRecursion, preserve, paths, destination)
return c.copyFilesToMachineInner(shellSvc, closeClient, allowRecursion, preserve, paths, destination, noProgress)
}

// copyFilesToMachineInner is the common logic for both copyFiles variants.
Expand All @@ -2453,16 +2459,81 @@ func (c *viamClient) copyFilesToMachineInner(
preserve bool,
paths []string,
destination string,
noProgress bool,
) error {
defer func() {
utils.UncheckedError(closeClient(c.c.Context))
}()

// prepare a factory that understands the file copying service (RPC or not).
copyFactory := shell.NewCopyFileToMachineFactory(destination, preserve, shellSvc)
// make a reader copier that just does the traversal and copy work for us. Think of
// this as a tee reader.
readCopier, err := shell.NewLocalFileReadCopier(paths, allowRecursion, false, copyFactory)
if noProgress {
// prepare a factory that understands the file copying service (RPC or not).
copyFactory := shell.NewCopyFileToMachineFactory(destination, preserve, shellSvc)
// make a reader copier that just does the traversal and copy work for us. Think of
// this as a tee reader.
readCopier, err := shell.NewLocalFileReadCopier(paths, allowRecursion, false, copyFactory)
if err != nil {
return err
}
defer func() {
if err := readCopier.Close(c.c.Context); err != nil {
utils.UncheckedError(err)
}
}()

// ReadAll the files into the copier.
return readCopier.ReadAll(c.c.Context)
}

// Calculate total size of all files to be copied
var totalSize int64
for _, path := range paths {
info, err := os.Stat(path)
if err != nil {
return err
}
if info.IsDir() && allowRecursion {
err := filepath.Walk(path, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() {
totalSize += info.Size()
}
return nil
})
if err != nil {
return err
}
} else if !info.IsDir() {
totalSize += info.Size()
}
}

// Create a progress tracking function
var currentFile string
progressFunc := func(bytes int64, file string, fileSize int64) {
if file != currentFile {
if currentFile != "" {
//nolint:errcheck // progress display is non-critical
_, _ = os.Stdout.WriteString("\n")
}
currentFile = file
//nolint:errcheck // progress display is non-critical
_, _ = os.Stdout.WriteString(fmt.Sprintf("Copying %s...\n", file))
}
uploadPercent := int(math.Ceil(100 * float64(bytes) / float64(fileSize)))
//nolint:errcheck // progress display is non-critical
_, _ = os.Stdout.WriteString(fmt.Sprintf("\rProgress: %d%% (%d/%d bytes)", uploadPercent, bytes, fileSize))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note this will work differently in headless shells

would not worry too much about it, but maybe add a --no-progress field for reload + cp, so that people can skip this behavior if it's messing with their shell?

(or assign this to me as afterwork so this PR can get merged; I feel bad that I took so long to start this review)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. Was easy.

}

// Wrap the copy factory to track progress
progressFactory := &progressTrackingFactory{
factory: shell.NewCopyFileToMachineFactory(destination, preserve, shellSvc),
onProgress: progressFunc,
}

// Create a new read copier with the progress tracking factory
readCopier, err := shell.NewLocalFileReadCopier(paths, allowRecursion, false, progressFactory)
if err != nil {
return err
}
Expand All @@ -2473,7 +2544,88 @@ func (c *viamClient) copyFilesToMachineInner(
}()

// ReadAll the files into the copier.
return readCopier.ReadAll(c.c.Context)
err = readCopier.ReadAll(c.c.Context)
return err
}

// progressTrackingFactory wraps a copy factory to track progress.
type progressTrackingFactory struct {
factory shell.FileCopyFactory
onProgress func(int64, string, int64)
}

func (ptf *progressTrackingFactory) MakeFileCopier(ctx context.Context, sourceType shell.CopyFilesSourceType) (shell.FileCopier, error) {
copier, err := ptf.factory.MakeFileCopier(ctx, sourceType)
if err != nil {
return nil, err
}
return &progressTrackingCopier{
copier: copier,
onProgress: ptf.onProgress,
}, nil
}

// progressTrackingCopier wraps a file copier to track progress.
type progressTrackingCopier struct {
copier shell.FileCopier
onProgress func(int64, string, int64)
}

func (ptc *progressTrackingCopier) Copy(ctx context.Context, file shell.File) error {
// Get file size
info, err := file.Data.Stat()
if err != nil {
return err
}
fileSize := info.Size()

// Create a progress tracking reader
progressReader := &progressReader{
reader: file.Data,
onProgress: ptc.onProgress,
fileName: file.RelativeName,
fileSize: fileSize,
}

// Create a new file with the progress tracking reader
progressFile := shell.File{
RelativeName: file.RelativeName,
Data: progressReader,
}

return ptc.copier.Copy(ctx, progressFile)
}

func (ptc *progressTrackingCopier) Close(ctx context.Context) error {
//nolint:errcheck // progress display is non-critical
_, _ = os.Stdout.WriteString("\n")
return ptc.copier.Close(ctx)
}

// progressReader wraps a reader to track progress.
type progressReader struct {
reader fs.File
onProgress func(int64, string, int64)
copied int64
fileName string
fileSize int64
}

func (pr *progressReader) Read(p []byte) (int, error) {
n, err := pr.reader.Read(p)
if n > 0 {
pr.copied += int64(n)
pr.onProgress(pr.copied, pr.fileName, pr.fileSize)
}
return n, err
}

func (pr *progressReader) Stat() (fs.FileInfo, error) {
return pr.reader.Stat()
}

func (pr *progressReader) Close() error {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider adding final \n somewhere (potentially in Close)

as-is, the next shell prompt appears on the same line as the progress output

~/repo/rdk$ go run ./cli/viam machine part cp --part PART README.md machine:README.md
Copying README.md...
Progress: 100% (5577/5577 bytes)~/repo/rdk$

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops. Good catch. I had this but then I refactored to handle multiple files.

return pr.reader.Close()
}

func (c *viamClient) copyFilesFromMachine(
Expand Down
7 changes: 4 additions & 3 deletions cli/module_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ type reloadModuleArgs struct {
RestartOnly bool
NoBuild bool
Local bool
NoProgress bool
}

// ReloadModuleAction builds a module, configures it on a robot, and starts or restarts it.
Expand Down Expand Up @@ -558,14 +559,14 @@ func reloadModuleAction(c *cli.Context, vc *viamClient, args reloadModuleArgs, l
return err
}
infof(c.App.Writer, "Copying %s to part %s", manifest.Build.Path, part.Part.Id)
args, err := getGlobalArgs(c)
globalArgs, err := getGlobalArgs(c)
if err != nil {
return err
}
dest := reloadingDestination(c, manifest)
err = vc.copyFilesToFqdn(
part.Part.Fqdn, args.Debug, false, false, []string{manifest.Build.Path},
dest, logging.NewLogger("reload"))
part.Part.Fqdn, globalArgs.Debug, false, false, []string{manifest.Build.Path},
dest, logging.NewLogger("reload"), args.NoProgress)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.PermissionDenied {
warningf(c.App.ErrWriter, "RDK couldn't write to the default file copy destination. "+
Expand Down
Loading