Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 56 additions & 15 deletions commands/dial_stdio.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package commands

import (
"context"
"io"
"net"
"os"
"sync"

"github.com/containerd/platforms"
"github.com/docker/buildx/build"
Expand Down Expand Up @@ -79,27 +81,66 @@ func runDialStdio(dockerCli command.Cli, opts stdioOptions) error {
return err
}

defer conn.Close()
return proxyConn(ctx, conn, os.Stdin, os.Stdout)
})
}

var errStdinProxyCanceled = errors.New("stdin proxy canceled")

go func() {
<-ctx.Done()
closeWrite(conn)
}()
func proxyConn(ctx context.Context, conn net.Conn, stdin io.Reader, stdout io.Writer) error {
defer conn.Close()

var eg errgroup.Group
cancelableStdin := newCancelableReader(stdin)
defer cancelableStdin.cancel(errStdinProxyCanceled)

stopAfterFunc := context.AfterFunc(ctx, func() {
cancelableStdin.cancel(context.Cause(ctx))
closeWrite(conn)
})
defer stopAfterFunc()

eg.Go(func() error {
_, err := io.Copy(conn, os.Stdin)
closeWrite(conn)
var eg errgroup.Group
eg.Go(func() error {
_, err := io.Copy(conn, cancelableStdin)
closeWrite(conn)
if err != nil && !errors.Is(err, errStdinProxyCanceled) {
return err
})
eg.Go(func() error {
_, err := io.Copy(os.Stdout, conn)
closeRead(conn)
}
return nil
})
eg.Go(func() error {
_, err := io.Copy(stdout, conn)
cancelableStdin.cancel(errStdinProxyCanceled)
closeRead(conn)
if err != nil && !errors.Is(err, io.EOF) {
return err
})
return eg.Wait()
}
return nil
})
return eg.Wait()
}

type cancelableReader struct {
io.Reader
cancel func(error)
}

func newCancelableReader(r io.Reader) *cancelableReader {
pr, pw := io.Pipe()
var once sync.Once
closePipe := func(err error) {
once.Do(func() {
_ = pw.CloseWithError(err)
})
}
go func() {
_, err := io.Copy(pw, r)
closePipe(err)
}()
return &cancelableReader{
Reader: pr,
cancel: closePipe,
}
}

func closeRead(conn net.Conn) error {
Expand Down
56 changes: 56 additions & 0 deletions commands/dial_stdio_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package commands

import (
"bytes"
"context"
"io"
"net"
"sync"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestProxyConnRemoteClose(t *testing.T) {
clientConn, serverConn := net.Pipe()
defer serverConn.Close()

stdin := &blockingReader{waitCh: make(chan struct{})}
defer stdin.Close()

var stdout bytes.Buffer
errCh := make(chan error, 1)
go func() {
errCh <- proxyConn(context.Background(), clientConn, stdin, &stdout)
}()

go func() {
_, _ = serverConn.Write([]byte("hello"))
_ = serverConn.Close()
}()

select {
case err := <-errCh:
require.NoError(t, err)
require.Equal(t, "hello", stdout.String())
case <-time.After(2 * time.Second):
t.Fatal("proxyConn did not return after the remote side closed")
}
}

type blockingReader struct {
waitCh chan struct{}
closeOnce sync.Once
}

func (r *blockingReader) Read([]byte) (int, error) {
<-r.waitCh
return 0, io.EOF
}

func (r *blockingReader) Close() {
r.closeOnce.Do(func() {
close(r.waitCh)
})
}
Loading