Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions internal/dutagent/broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ func (b *Broker) Start(ctx context.Context, s Stream) (module.Session, <-chan er

workerCtx, workerCancel := context.WithCancel(ctx)

b.session.done = workerCtx.Done()

b.wg.Add(numWorkers)
b.toClient(workerCtx, workerCancel)
b.fromClient(workerCtx, workerCancel)
Expand Down
87 changes: 87 additions & 0 deletions internal/dutagent/broker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"context"
"errors"
"io"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -276,3 +277,89 @@ func TestBroker_DualErrors(t *testing.T) {
// WANT both errors; order unspecified.
}
}

// TestSession_PrintNotBlockingOnShutdown verifies that Print, Printf, and Println
// return promptly when the broker workers have shut down (done is closed).
func TestSession_PrintNotBlockingOnShutdown(t *testing.T) {
b := &Broker{}
// recvErrs with a real error triggers fromClientWorker to cancel the context,
// which closes done and causes toClientWorker to exit.
stream := &testStream{recvErrs: []error{errors.New("recv died")}}
sess, errCh := b.Start(context.Background(), stream)

// Wait for workers to shut down.
collectErrors(t, errCh, 300*time.Millisecond)

// All three Print variants must return without blocking.
done := make(chan struct{})
go func() {
defer close(done)
sess.Print("a")
sess.Printf("b %s", "c")
sess.Println("d")
}()

select {
case <-done:
// ok
case <-time.After(200 * time.Millisecond):
t.Fatal("Print/Printf/Println blocked after broker shutdown")
}
}

// TestSession_RequestFileErrorOnShutdown verifies that RequestFile returns an error
// when done is closed while the module is waiting for a file to be handed over.
func TestSession_RequestFileErrorOnShutdown(t *testing.T) {
b := &Broker{}
// Block Receive so fromClientWorker doesn't exit on its own; we cancel manually.
stream := &testStream{recvBlock: true}
ctx, cancel := context.WithCancel(context.Background())
sess, errCh := b.Start(ctx, stream)

done := make(chan error, 1)
go func() {
_, err := sess.(*session).RequestFile("firmware.bin")
done <- err
}()

// Give the goroutine time to block on fileReqCh/fileCh.
time.Sleep(20 * time.Millisecond)

// Cancel to shut down workers (closes done channel on session).
cancel()
if stream.unblockCh != nil {
close(stream.unblockCh)
}

collectErrors(t, errCh, 300*time.Millisecond)

select {
case err := <-done:
if err == nil {
t.Fatal("expected RequestFile to return an error on shutdown, got nil")
}
case <-time.After(200 * time.Millisecond):
t.Fatal("RequestFile blocked after broker shutdown")
}
}

// TestSession_SendFileErrorOnShutdown verifies that SendFile returns an error
// when done is closed before the broker's toClientWorker picks up the file channel.
func TestSession_SendFileErrorOnShutdown(t *testing.T) {
b := &Broker{}
// Use a send error so toClientWorker exits as soon as it tries to send,
// which cancels the context and closes done before SendFile can hand off the file.
stream := &testStream{
sendErr: errors.New("send died"),
recvErrs: []error{errors.New("recv died")},
}
sess, errCh := b.Start(context.Background(), stream)

// Wait for workers to exit.
collectErrors(t, errCh, 300*time.Millisecond)

err := sess.(*session).SendFile("result.bin", strings.NewReader("data"))
if err == nil {
t.Fatal("expected SendFile to return an error on shutdown, got nil")
}
}
40 changes: 34 additions & 6 deletions internal/dutagent/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

// session implements the module.Session interface.
type session struct {
done <-chan struct{} // closed when broker workers shut down; unblocks pending session calls
printCh chan string
stdinCh chan []byte
stdoutCh chan []byte
Expand All @@ -28,15 +29,24 @@ type session struct {
}

func (s *session) Print(a ...any) {
s.printCh <- fmt.Sprint(a...)
select {
case s.printCh <- fmt.Sprint(a...):
case <-s.done:
}
Comment thread
RiSKeD marked this conversation as resolved.
}

func (s *session) Printf(format string, a ...any) {
s.printCh <- fmt.Sprintf(format, a...)
select {
case s.printCh <- fmt.Sprintf(format, a...):
case <-s.done:
}
}

func (s *session) Println(a ...any) {
s.printCh <- fmt.Sprintln(a...)
select {
case s.printCh <- fmt.Sprintln(a...):
case <-s.done:
}
}

//nolint:nonamedreturns
Expand Down Expand Up @@ -72,9 +82,19 @@ func (s *session) RequestFile(name string) (io.Reader, error) {

log.Printf("Module issued file request for: %q", name)

s.fileReqCh <- name // Send the file request to the client.
select {
case s.fileReqCh <- name:
case <-s.done:
return nil, fmt.Errorf("session closed before file request %q could be sent", name)
}

var file chan []byte

file := <-s.fileCh // This will block until the client sends the file.
select {
case file = <-s.fileCh:
case <-s.done:
return nil, fmt.Errorf("session closed while waiting for file %q", name)
}
Comment thread
RiSKeD marked this conversation as resolved.
Comment thread
RiSKeD marked this conversation as resolved.

r, err := chanio.NewChanReader(file)
if err != nil {
Expand All @@ -99,7 +119,15 @@ func (s *session) SendFile(name string, r io.Reader) error {
s.currentFile = name

file := make(chan []byte, 1)
s.fileCh <- file

select {
case s.fileCh <- file:
case <-s.done:
s.currentFile = ""

Comment on lines +123 to +127
Copy link

Copilot AI Apr 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SendFile now returns an error when done is closed, but the method still log.Fatals if s.currentFile != "" before it ever checks done. With the new shutdown-unblocking behavior, currentFile can remain set during cancellation (e.g., an in-flight RequestFile), which can turn a shutdown into a process exit. Consider checking done (and returning an error) before the currentFile != "" fatal, or ensuring currentFile is always cleared on shutdown.

Copilot uses AI. Check for mistakes.
return fmt.Errorf("session closed before file %q could be sent", name)
}

file <- content

close(file) // indicate EOF.
Expand Down
8 changes: 7 additions & 1 deletion internal/dutagent/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,13 @@ func fromClientWorker(ctx context.Context, stream Stream, s *session) error {
log.Printf("Server received file %q from client", path)

file := make(chan []byte, 1)
s.fileCh <- file

select {
case <-ctx.Done():
Copy link

Copilot AI Apr 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If ctx.Done() fires here, fromClientWorker returns without clearing s.currentFile. In cancellation scenarios where toClientWorker has already set currentFile for an outstanding request, this can leave the session in an inconsistent state and may cause later calls (notably SendFile, which checks currentFile) to log.Fatal. Consider resetting s.currentFile before returning on the cancellation path.

Suggested change
case <-ctx.Done():
case <-ctx.Done():
s.currentFile = ""

Copilot uses AI. Check for mistakes.
return nil
case s.fileCh <- file:
}

file <- content

close(file)
Expand Down