diff --git a/commands/dial_stdio.go b/commands/dial_stdio.go index e848efba5100..73c3b9a24304 100644 --- a/commands/dial_stdio.go +++ b/commands/dial_stdio.go @@ -1,9 +1,11 @@ package commands import ( + "context" "io" "net" "os" + "sync" "github.com/containerd/platforms" "github.com/docker/buildx/build" @@ -79,27 +81,66 @@ func runDialStdio(dockerCli command.Cli, opts stdioOptions) error { return err } - defer conn.Close() + return proxyConn(ctx, conn, os.Stdin, os.Stdout) + }) +} + +var errStdinProxyCanceled = errors.New("stdin proxy canceled") - go func() { - <-ctx.Done() - closeWrite(conn) - }() +func proxyConn(ctx context.Context, conn net.Conn, stdin io.Reader, stdout io.Writer) error { + defer conn.Close() - var eg errgroup.Group + cancelableStdin := newCancelableReader(stdin) + defer cancelableStdin.cancel(errStdinProxyCanceled) + + stopAfterFunc := context.AfterFunc(ctx, func() { + cancelableStdin.cancel(context.Cause(ctx)) + closeWrite(conn) + }) + defer stopAfterFunc() - eg.Go(func() error { - _, err := io.Copy(conn, os.Stdin) - closeWrite(conn) + var eg errgroup.Group + eg.Go(func() error { + _, err := io.Copy(conn, cancelableStdin) + closeWrite(conn) + if err != nil && !errors.Is(err, errStdinProxyCanceled) { return err - }) - eg.Go(func() error { - _, err := io.Copy(os.Stdout, conn) - closeRead(conn) + } + return nil + }) + eg.Go(func() error { + _, err := io.Copy(stdout, conn) + cancelableStdin.cancel(errStdinProxyCanceled) + closeRead(conn) + if err != nil && !errors.Is(err, io.EOF) { return err - }) - return eg.Wait() + } + return nil }) + return eg.Wait() +} + +type cancelableReader struct { + io.Reader + cancel func(error) +} + +func newCancelableReader(r io.Reader) *cancelableReader { + pr, pw := io.Pipe() + var once sync.Once + closePipe := func(err error) { + once.Do(func() { + _ = pw.CloseWithError(err) + }) + } + go func() { + _, err := io.Copy(pw, r) + closePipe(err) + }() + return &cancelableReader{ + Reader: pr, + cancel: closePipe, + } } func closeRead(conn net.Conn) error { diff --git a/commands/dial_stdio_test.go b/commands/dial_stdio_test.go new file mode 100644 index 000000000000..a304da393d99 --- /dev/null +++ b/commands/dial_stdio_test.go @@ -0,0 +1,56 @@ +package commands + +import ( + "bytes" + "context" + "io" + "net" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestProxyConnRemoteClose(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer serverConn.Close() + + stdin := &blockingReader{waitCh: make(chan struct{})} + defer stdin.Close() + + var stdout bytes.Buffer + errCh := make(chan error, 1) + go func() { + errCh <- proxyConn(context.Background(), clientConn, stdin, &stdout) + }() + + go func() { + _, _ = serverConn.Write([]byte("hello")) + _ = serverConn.Close() + }() + + select { + case err := <-errCh: + require.NoError(t, err) + require.Equal(t, "hello", stdout.String()) + case <-time.After(2 * time.Second): + t.Fatal("proxyConn did not return after the remote side closed") + } +} + +type blockingReader struct { + waitCh chan struct{} + closeOnce sync.Once +} + +func (r *blockingReader) Read([]byte) (int, error) { + <-r.waitCh + return 0, io.EOF +} + +func (r *blockingReader) Close() { + r.closeOnce.Do(func() { + close(r.waitCh) + }) +}