Skip to content

Commit 8d2aaa4

Browse files
authored
Add test for stdout scanner race with runner.Wait() (#300)
1 parent 557fdfb commit 8d2aaa4

File tree

1 file changed

+84
-4
lines changed

1 file changed

+84
-4
lines changed

client_test.go

+84-4
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,12 @@ import (
2828

2929
func TestClient(t *testing.T) {
3030
process := helperProcess("mock")
31+
logger := &trackingLogger{Logger: hclog.Default()}
3132
c := NewClient(&ClientConfig{
3233
Cmd: process,
3334
HandshakeConfig: testHandshake,
3435
Plugins: testPluginMap,
36+
Logger: logger,
3537
})
3638
defer c.Kill()
3739

@@ -61,6 +63,9 @@ func TestClient(t *testing.T) {
6163
if !c.killed() {
6264
t.Fatal("Client should have failed")
6365
}
66+
67+
// One error for connection refused, one for plugin exited.
68+
assertLines(t, logger.errorLogs, 2)
6469
}
6570

6671
// This tests a bug where Kill would start
@@ -112,19 +117,19 @@ func TestClient_killStart(t *testing.T) {
112117
}
113118

114119
func TestClient_testCleanup(t *testing.T) {
115-
// Create a temporary dir to store the result file
116-
td := t.TempDir()
117-
defer os.RemoveAll(td)
120+
t.Parallel()
118121

119122
// Create a path that the helper process will write on cleanup
120-
path := filepath.Join(td, "output")
123+
path := filepath.Join(t.TempDir(), "output")
121124

122125
// Test the cleanup
123126
process := helperProcess("cleanup", path)
127+
logger := &trackingLogger{Logger: hclog.Default()}
124128
c := NewClient(&ClientConfig{
125129
Cmd: process,
126130
HandshakeConfig: testHandshake,
127131
Plugins: testPluginMap,
132+
Logger: logger,
128133
})
129134

130135
// Grab the client so the process starts
@@ -140,6 +145,61 @@ func TestClient_testCleanup(t *testing.T) {
140145
if _, err := os.Stat(path); err != nil {
141146
t.Fatalf("err: %s", err)
142147
}
148+
149+
assertLines(t, logger.errorLogs, 0)
150+
}
151+
152+
func TestClient_noStdoutScannerRace(t *testing.T) {
153+
t.Parallel()
154+
155+
process := helperProcess("test-grpc")
156+
logger := &trackingLogger{Logger: hclog.Default()}
157+
c := NewClient(&ClientConfig{
158+
RunnerFunc: func(l hclog.Logger, cmd *exec.Cmd, tmpDir string) (runner.Runner, error) {
159+
process.Env = append(process.Env, cmd.Env...)
160+
concreteRunner, err := cmdrunner.NewCmdRunner(l, process)
161+
if err != nil {
162+
return nil, err
163+
}
164+
// Inject a delay before calling .Read() method on the command's
165+
// stdout reader. This ensures that if there is a race between the
166+
// stdout scanner loop reading stdout and runner.Wait() closing
167+
// stdout, .Wait() will win and trigger a scanner error in the logs.
168+
return &delayedStdoutCmdRunner{concreteRunner}, nil
169+
},
170+
HandshakeConfig: testHandshake,
171+
Plugins: testGRPCPluginMap,
172+
AllowedProtocols: []Protocol{ProtocolGRPC},
173+
Logger: logger,
174+
})
175+
176+
// Grab the client so the process starts
177+
if _, err := c.Client(); err != nil {
178+
c.Kill()
179+
t.Fatalf("err: %s", err)
180+
}
181+
182+
// Kill it gracefully
183+
c.Kill()
184+
185+
assertLines(t, logger.errorLogs, 0)
186+
}
187+
188+
type delayedStdoutCmdRunner struct {
189+
*cmdrunner.CmdRunner
190+
}
191+
192+
func (m *delayedStdoutCmdRunner) Stdout() io.ReadCloser {
193+
return &delayedReader{m.CmdRunner.Stdout()}
194+
}
195+
196+
type delayedReader struct {
197+
io.ReadCloser
198+
}
199+
200+
func (d *delayedReader) Read(p []byte) (n int, err error) {
201+
time.Sleep(100 * time.Millisecond)
202+
return d.ReadCloser.Read(p)
143203
}
144204

145205
func TestClient_testInterface(t *testing.T) {
@@ -1563,3 +1623,23 @@ func TestClient_logStderrParseJSON(t *testing.T) {
15631623
}
15641624
}
15651625
}
1626+
1627+
type trackingLogger struct {
1628+
hclog.Logger
1629+
errorLogs []string
1630+
}
1631+
1632+
func (l *trackingLogger) Error(msg string, args ...interface{}) {
1633+
l.errorLogs = append(l.errorLogs, fmt.Sprintf("%s: %v", msg, args))
1634+
l.Logger.Error(msg, args...)
1635+
}
1636+
1637+
func assertLines(t *testing.T, lines []string, expected int) {
1638+
t.Helper()
1639+
if len(lines) != expected {
1640+
t.Errorf("expected %d, got %d", expected, len(lines))
1641+
for _, log := range lines {
1642+
t.Error(log)
1643+
}
1644+
}
1645+
}

0 commit comments

Comments
 (0)