Skip to content

Commit 9c03e1b

Browse files
committed
Address PR review feedback
oauthcallback: - Validate auth code is non-empty before accepting callback - Accept pre-bound net.Listener for ephemeral port tests - Non-blocking channel sends to handle duplicate requests - Tests use 127.0.0.1:0 instead of fixed ports output: - AsError handles nil error without panicking - writeCount returns 0 for nil data, propagates write errors profile: Sort names before calling Picker for deterministic order surface: Sort Diff results for stable CI output credstore: - Random probe key to avoid collisions with real credentials - FallbackWarning exported instead of writing to stderr actions: - rubric-check: Verify exit code is exactly 1, not just non-zero - surface-compat: Sort baseline before comm for locale safety - sync-skills: Preserve subdirectory structure during copy
1 parent c805e0a commit 9c03e1b

File tree

11 files changed

+159
-56
lines changed

11 files changed

+159
-56
lines changed

actions/rubric-check/action.yml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,11 @@ runs:
6464
6565
# 1B.1: Exit codes - bad usage should exit 1
6666
"$BINARY" --nonexistent-flag >/dev/null 2>&1
67-
if [ $? -ne 0 ]; then
68-
check "1B.1" "Non-zero exit on bad usage" "pass"
67+
EXIT_CODE=$?
68+
if [ "$EXIT_CODE" -eq 1 ]; then
69+
check "1B.1" "Exit code 1 on bad usage" "pass"
6970
else
70-
check "1B.1" "Non-zero exit on bad usage" "fail"
71+
check "1B.1" "Exit code 1 on bad usage (got $EXIT_CODE)" "fail"
7172
fi
7273
7374
if [ "$PROFILE" = "api-cli" ]; then

actions/surface-compat/action.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ runs:
7373
BASELINE="${{ inputs.baseline }}"
7474
CURRENT="/tmp/surface-current.txt"
7575
76+
# comm requires sorted input — sort both with consistent locale
77+
LC_ALL=C sort "$BASELINE" -o "$BASELINE"
78+
LC_ALL=C sort "$CURRENT" -o "$CURRENT"
79+
7680
REMOVED=$(comm -23 "$BASELINE" "$CURRENT" | wc -l | tr -d ' ')
7781
ADDED=$(comm -13 "$BASELINE" "$CURRENT" | wc -l | tr -d ' ')
7882

actions/sync-skills/action.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,11 @@ runs:
6767
dest="${TARGET_DIR}/${CLI_NAME}/${skill_name}"
6868
echo " Syncing ${skill_name}..."
6969
mkdir -p "$dest"
70-
find "$skill_dir" -type f ! -name '*.go' ! -name '.*' -exec cp {} "$dest/" \;
70+
# Preserve subdirectory structure; exclude Go sources and dotfiles
71+
(cd "$skill_dir" && find . -type f ! -name '*.go' ! -name '.*' | while read -r f; do
72+
mkdir -p "$dest/$(dirname "$f")"
73+
cp "$f" "$dest/$f"
74+
done)
7175
MANAGED_SKILLS+=("${CLI_NAME}/${skill_name}")
7276
done
7377

credstore/store.go

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package credstore
22

33
import (
4+
"crypto/rand"
5+
"encoding/hex"
46
"fmt"
57
"os"
68
"path/filepath"
@@ -28,26 +30,38 @@ type Store struct {
2830
fallbackDir string
2931
}
3032

33+
// FallbackWarning is set when NewStore falls back to file storage.
34+
// The caller can check this and display the warning as appropriate.
35+
var FallbackWarning string
36+
3137
// NewStore creates a credential store. It probes the system keyring
3238
// and falls back to file storage if unavailable.
3339
func NewStore(opts StoreOptions) *Store {
40+
FallbackWarning = ""
41+
3442
if opts.DisableEnvVar != "" && os.Getenv(opts.DisableEnvVar) != "" {
3543
return &Store{serviceName: opts.ServiceName, useKeyring: false, fallbackDir: opts.FallbackDir}
3644
}
3745

38-
// Test if keyring is available
39-
testKey := opts.ServiceName + "::test"
40-
err := keyring.Set(opts.ServiceName, testKey, "test")
46+
// Probe keyring with a random key to avoid collisions.
47+
probeKey := probeKeyName()
48+
err := keyring.Set(opts.ServiceName, probeKey, "probe")
4149
if err == nil {
42-
_ = keyring.Delete(opts.ServiceName, testKey)
50+
_ = keyring.Delete(opts.ServiceName, probeKey)
4351
return &Store{serviceName: opts.ServiceName, useKeyring: true, fallbackDir: opts.FallbackDir}
4452
}
4553

46-
fmt.Fprintf(os.Stderr, "warning: system keyring unavailable, credentials stored in plaintext at %s\n",
54+
FallbackWarning = fmt.Sprintf("system keyring unavailable, credentials stored in plaintext at %s",
4755
filepath.Join(opts.FallbackDir, "credentials.json"))
4856
return &Store{serviceName: opts.ServiceName, useKeyring: false, fallbackDir: opts.FallbackDir}
4957
}
5058

59+
func probeKeyName() string {
60+
b := make([]byte, 8)
61+
_, _ = rand.Read(b)
62+
return "__probe_" + hex.EncodeToString(b)
63+
}
64+
5165
func (s *Store) key(name string) string {
5266
return fmt.Sprintf("%s::%s", s.serviceName, name)
5367
}

oauthcallback/server.go

Lines changed: 49 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,39 @@ import (
99
"time"
1010
)
1111

12-
// WaitForCallback starts a local HTTP server and waits for an OAuth callback.
13-
// It returns the authorization code from the callback.
14-
func WaitForCallback(ctx context.Context, expectedState, authURL, listenAddr string) (string, error) {
15-
lc := net.ListenConfig{}
16-
listener, err := lc.Listen(ctx, "tcp", listenAddr)
17-
if err != nil {
18-
return "", fmt.Errorf("failed to start callback server: %w", err)
12+
// WaitForCallback starts a local HTTP server on listener and waits for an
13+
// OAuth callback. It returns the authorization code from the callback.
14+
//
15+
// If listener is nil, one is created on listenAddr. Passing a pre-bound
16+
// listener (e.g., from net.Listen("tcp", "127.0.0.1:0")) is preferred
17+
// for tests to avoid port conflicts.
18+
func WaitForCallback(ctx context.Context, expectedState string, listener net.Listener, listenAddr string) (string, error) {
19+
if listener == nil {
20+
lc := net.ListenConfig{}
21+
var err error
22+
listener, err = lc.Listen(ctx, "tcp", listenAddr)
23+
if err != nil {
24+
return "", fmt.Errorf("failed to start callback server: %w", err)
25+
}
1926
}
2027
defer listener.Close()
2128

2229
codeCh := make(chan string, 1)
2330
errCh := make(chan error, 1)
24-
var shutdownOnce sync.Once
31+
var once sync.Once
32+
33+
send := func(ch chan<- string, val string) {
34+
select {
35+
case ch <- val:
36+
default:
37+
}
38+
}
39+
sendErr := func(ch chan<- error, val error) {
40+
select {
41+
case ch <- val:
42+
default:
43+
}
44+
}
2545

2646
server := &http.Server{
2747
ReadHeaderTimeout: 10 * time.Second,
@@ -30,37 +50,42 @@ func WaitForCallback(ctx context.Context, expectedState, authURL, listenAddr str
3050
IdleTimeout: 30 * time.Second,
3151
}
3252

53+
shutdown := func() {
54+
once.Do(func() {
55+
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
56+
go func() { defer cancel(); server.Shutdown(shutdownCtx) }()
57+
})
58+
}
59+
3360
server.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
3461
state := r.URL.Query().Get("state")
3562
code := r.URL.Query().Get("code")
3663
errParam := r.URL.Query().Get("error")
3764

3865
if errParam != "" {
39-
errCh <- fmt.Errorf("OAuth error: %s", errParam)
66+
sendErr(errCh, fmt.Errorf("OAuth error: %s", errParam))
4067
fmt.Fprint(w, "<html><body><h1>Authentication failed</h1><p>You can close this window.</p></body></html>")
41-
shutdownOnce.Do(func() {
42-
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
43-
go func() { defer cancel(); server.Shutdown(shutdownCtx) }()
44-
})
68+
shutdown()
4569
return
4670
}
4771

4872
if state != expectedState {
49-
errCh <- fmt.Errorf("state mismatch: CSRF protection failed")
73+
sendErr(errCh, fmt.Errorf("state mismatch: CSRF protection failed"))
5074
fmt.Fprint(w, "<html><body><h1>Authentication failed</h1><p>State mismatch.</p></body></html>")
51-
shutdownOnce.Do(func() {
52-
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
53-
go func() { defer cancel(); server.Shutdown(shutdownCtx) }()
54-
})
75+
shutdown()
76+
return
77+
}
78+
79+
if code == "" {
80+
sendErr(errCh, fmt.Errorf("OAuth callback missing authorization code"))
81+
fmt.Fprint(w, "<html><body><h1>Authentication failed</h1><p>Missing authorization code.</p></body></html>")
82+
shutdown()
5583
return
5684
}
5785

58-
codeCh <- code
86+
send(codeCh, code)
5987
fmt.Fprint(w, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
60-
shutdownOnce.Do(func() {
61-
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
62-
go func() { defer cancel(); server.Shutdown(shutdownCtx) }()
63-
})
88+
shutdown()
6489
})
6590

6691
go server.Serve(listener)
@@ -73,6 +98,6 @@ func WaitForCallback(ctx context.Context, expectedState, authURL, listenAddr str
7398
case <-ctx.Done():
7499
return "", ctx.Err()
75100
case <-time.After(5 * time.Minute):
76-
return "", fmt.Errorf("authentication timeout waiting for callback on %s", listenAddr)
101+
return "", fmt.Errorf("authentication timeout waiting for callback on %s", listener.Addr())
77102
}
78103
}

oauthcallback/server_test.go

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package oauthcallback
22

33
import (
44
"context"
5+
"fmt"
6+
"net"
57
"net/http"
68
"testing"
79
"time"
@@ -10,30 +12,37 @@ import (
1012
"github.com/stretchr/testify/require"
1113
)
1214

15+
func listen(t *testing.T) net.Listener {
16+
t.Helper()
17+
ln, err := net.Listen("tcp", "127.0.0.1:0")
18+
require.NoError(t, err)
19+
t.Cleanup(func() { ln.Close() })
20+
return ln
21+
}
22+
1323
func TestWaitForCallback_Success(t *testing.T) {
1424
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
1525
defer cancel()
1626

27+
ln := listen(t)
28+
addr := ln.Addr().String()
1729
state := "test-state-123"
18-
listenAddr := "127.0.0.1:18976"
1930

2031
codeCh := make(chan string, 1)
2132
errCh := make(chan error, 1)
2233

2334
go func() {
24-
code, err := WaitForCallback(ctx, state, "", listenAddr)
35+
code, err := WaitForCallback(ctx, state, ln, "")
2536
if err != nil {
2637
errCh <- err
2738
} else {
2839
codeCh <- code
2940
}
3041
}()
3142

32-
// Give server time to start
3343
time.Sleep(100 * time.Millisecond)
3444

35-
// Simulate callback
36-
resp, err := http.Get("http://127.0.0.1:18976/callback?state=test-state-123&code=auth-code-456")
45+
resp, err := http.Get(fmt.Sprintf("http://%s/callback?state=test-state-123&code=auth-code-456", addr))
3746
require.NoError(t, err)
3847
resp.Body.Close()
3948

@@ -47,21 +56,49 @@ func TestWaitForCallback_Success(t *testing.T) {
4756
}
4857
}
4958

59+
func TestWaitForCallback_MissingCode(t *testing.T) {
60+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
61+
defer cancel()
62+
63+
ln := listen(t)
64+
addr := ln.Addr().String()
65+
errCh := make(chan error, 1)
66+
67+
go func() {
68+
_, err := WaitForCallback(ctx, "state", ln, "")
69+
errCh <- err
70+
}()
71+
72+
time.Sleep(100 * time.Millisecond)
73+
74+
resp, err := http.Get(fmt.Sprintf("http://%s/callback?state=state", addr))
75+
require.NoError(t, err)
76+
resp.Body.Close()
77+
78+
select {
79+
case err := <-errCh:
80+
assert.Contains(t, err.Error(), "missing authorization code")
81+
case <-time.After(3 * time.Second):
82+
t.Fatal("timeout")
83+
}
84+
}
85+
5086
func TestWaitForCallback_StateMismatch(t *testing.T) {
5187
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
5288
defer cancel()
5389

54-
listenAddr := "127.0.0.1:18977"
90+
ln := listen(t)
91+
addr := ln.Addr().String()
5592
errCh := make(chan error, 1)
5693

5794
go func() {
58-
_, err := WaitForCallback(ctx, "expected-state", "", listenAddr)
95+
_, err := WaitForCallback(ctx, "expected-state", ln, "")
5996
errCh <- err
6097
}()
6198

6299
time.Sleep(100 * time.Millisecond)
63100

64-
resp, err := http.Get("http://127.0.0.1:18977/callback?state=wrong-state&code=abc")
101+
resp, err := http.Get(fmt.Sprintf("http://%s/callback?state=wrong-state&code=abc", addr))
65102
require.NoError(t, err)
66103
resp.Body.Close()
67104

@@ -77,17 +114,18 @@ func TestWaitForCallback_OAuthError(t *testing.T) {
77114
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
78115
defer cancel()
79116

80-
listenAddr := "127.0.0.1:18978"
117+
ln := listen(t)
118+
addr := ln.Addr().String()
81119
errCh := make(chan error, 1)
82120

83121
go func() {
84-
_, err := WaitForCallback(ctx, "state", "", listenAddr)
122+
_, err := WaitForCallback(ctx, "state", ln, "")
85123
errCh <- err
86124
}()
87125

88126
time.Sleep(100 * time.Millisecond)
89127

90-
resp, err := http.Get("http://127.0.0.1:18978/callback?error=access_denied")
128+
resp, err := http.Get(fmt.Sprintf("http://%s/callback?error=access_denied", addr))
91129
require.NoError(t, err)
92130
resp.Body.Close()
93131

@@ -102,11 +140,11 @@ func TestWaitForCallback_OAuthError(t *testing.T) {
102140
func TestWaitForCallback_ContextCancellation(t *testing.T) {
103141
ctx, cancel := context.WithCancel(context.Background())
104142

105-
listenAddr := "127.0.0.1:18979"
143+
ln := listen(t)
106144
errCh := make(chan error, 1)
107145

108146
go func() {
109-
_, err := WaitForCallback(ctx, "state", "", listenAddr)
147+
_, err := WaitForCallback(ctx, "state", ln, "")
110148
errCh <- err
111149
}()
112150

output/envelope.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -280,12 +280,17 @@ func (w *Writer) writeCount(v any) error {
280280
data := NormalizeData(resp.Data)
281281

282282
switch d := data.(type) {
283+
case nil:
284+
_, err := fmt.Fprintln(w.opts.Writer, 0)
285+
return err
283286
case []any:
284-
fmt.Fprintln(w.opts.Writer, len(d))
287+
_, err := fmt.Fprintln(w.opts.Writer, len(d))
288+
return err
285289
case []map[string]any:
286-
fmt.Fprintln(w.opts.Writer, len(d))
290+
_, err := fmt.Fprintln(w.opts.Writer, len(d))
291+
return err
287292
default:
288-
fmt.Fprintln(w.opts.Writer, 1)
293+
_, err := fmt.Fprintln(w.opts.Writer, 1)
294+
return err
289295
}
290-
return nil
291296
}

output/errors.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@ func ErrAmbiguous(resource string, matches []string) *Error {
127127

128128
// AsError attempts to convert an error to an *Error.
129129
func AsError(err error) *Error {
130+
if err == nil {
131+
return &Error{Code: CodeAPI, Message: "unknown error"}
132+
}
130133
var e *Error
131134
if errors.As(err, &e) {
132135
return e

profile/resolve.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package profile
22

3-
import "fmt"
3+
import (
4+
"fmt"
5+
"sort"
6+
)
47

58
// ResolveOptions controls how profile resolution behaves.
69
type ResolveOptions struct {
@@ -78,6 +81,7 @@ func Resolve(opts ResolveOptions) (string, error) {
7881
for name := range opts.Profiles {
7982
names = append(names, name)
8083
}
84+
sort.Strings(names)
8185
return opts.Picker(names)
8286
}
8387

0 commit comments

Comments
 (0)