Skip to content

Commit dc41888

Browse files
committed
fix(provisioning): address local review findings
1 parent 6592f1d commit dc41888

7 files changed

Lines changed: 252 additions & 25 deletions

File tree

api/auth.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ func newAuthConfig() authConfig {
119119
bearerTeamsPaths: splitListOrDefault(os.Getenv("SPRITZ_AUTH_BEARER_TEAMS_PATHS"), nil),
120120
bearerTypePaths: splitListOrDefault(os.Getenv("SPRITZ_AUTH_BEARER_TYPE_PATHS"), nil),
121121
bearerScopesPaths: splitListOrDefault(os.Getenv("SPRITZ_AUTH_BEARER_SCOPES_PATHS"), []string{"scope", "scopes", "scp"}),
122-
bearerDefaultType: normalizePrincipalType(envOrDefault("SPRITZ_AUTH_BEARER_DEFAULT_TYPE", string(principalTypeHuman)), principalTypeHuman),
122+
bearerDefaultType: normalizePrincipalType(envOrDefault("SPRITZ_AUTH_BEARER_DEFAULT_TYPE", string(principalTypeService)), principalTypeService),
123123
bearerAuthorizationHeader: envOrDefault("SPRITZ_AUTH_BEARER_HEADER", "Authorization"),
124124
bearerJWKSURL: strings.TrimSpace(os.Getenv("SPRITZ_AUTH_BEARER_JWKS_URL")),
125125
bearerJWKSIssuer: strings.TrimSpace(os.Getenv("SPRITZ_AUTH_BEARER_ISSUER")),

api/auth_middleware_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,49 @@ func TestBearerAuthParsesSpaceDelimitedScopes(t *testing.T) {
178178
t.Fatalf("expected two scopes, got %#v", payload["scopes"])
179179
}
180180
}
181+
182+
func TestBearerAuthDefaultsToServiceTypeWithoutTypeClaim(t *testing.T) {
183+
introspection := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
184+
_ = json.NewEncoder(w).Encode(map[string]any{
185+
"sub": "zenobot",
186+
"scope": "spritz.instances.create spritz.instances.assign_owner",
187+
})
188+
}))
189+
defer introspection.Close()
190+
191+
t.Setenv("SPRITZ_AUTH_MODE", "auto")
192+
t.Setenv("SPRITZ_AUTH_BEARER_INTROSPECTION_URL", introspection.URL)
193+
t.Setenv("SPRITZ_AUTH_BEARER_ID_PATHS", "sub")
194+
t.Setenv("SPRITZ_AUTH_BEARER_SCOPES_PATHS", "scope")
195+
196+
s := &server{auth: newAuthConfig()}
197+
e := echo.New()
198+
secured := e.Group("", s.authMiddleware())
199+
secured.GET("/api/spritzes", func(c echo.Context) error {
200+
p, ok := principalFromContext(c)
201+
if !ok {
202+
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "missing principal"})
203+
}
204+
return c.JSON(http.StatusOK, map[string]any{
205+
"type": p.Type,
206+
"scopes": p.Scopes,
207+
})
208+
})
209+
210+
req := httptest.NewRequest(http.MethodGet, "/api/spritzes", nil)
211+
req.Header.Set("Authorization", "Bearer test-token")
212+
rec := httptest.NewRecorder()
213+
e.ServeHTTP(rec, req)
214+
215+
if rec.Code != http.StatusOK {
216+
t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
217+
}
218+
219+
payload := map[string]any{}
220+
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
221+
t.Fatalf("failed to decode response: %v", err)
222+
}
223+
if payload["type"] != string(principalTypeService) {
224+
t.Fatalf("expected default bearer principal type to be service, got %#v", payload["type"])
225+
}
226+
}

api/main.go

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,26 +29,27 @@ import (
2929
)
3030

3131
type server struct {
32-
client client.Client
33-
clientset *kubernetes.Clientset
34-
restConfig *rest.Config
35-
scheme *runtime.Scheme
36-
namespace string
37-
auth authConfig
38-
internalAuth internalAuthConfig
39-
ingressDefaults ingressDefaults
40-
terminal terminalConfig
41-
sshGateway sshGatewayConfig
42-
sshDefaults sshDefaults
43-
sshMintLimiter *sshMintLimiter
44-
acp acpConfig
45-
presets presetCatalog
46-
provisioners provisionerPolicy
47-
defaultMetadata map[string]string
48-
sharedMounts sharedMountsConfig
49-
sharedMountsStore *sharedMountsStore
50-
sharedMountsLive *sharedMountsLatestNotifier
51-
userConfigPolicy userConfigPolicy
32+
client client.Client
33+
clientset *kubernetes.Clientset
34+
restConfig *rest.Config
35+
scheme *runtime.Scheme
36+
namespace string
37+
auth authConfig
38+
internalAuth internalAuthConfig
39+
ingressDefaults ingressDefaults
40+
terminal terminalConfig
41+
sshGateway sshGatewayConfig
42+
sshDefaults sshDefaults
43+
sshMintLimiter *sshMintLimiter
44+
acp acpConfig
45+
presets presetCatalog
46+
provisioners provisionerPolicy
47+
defaultMetadata map[string]string
48+
sharedMounts sharedMountsConfig
49+
sharedMountsStore *sharedMountsStore
50+
sharedMountsLive *sharedMountsLatestNotifier
51+
userConfigPolicy userConfigPolicy
52+
nameGeneratorFactory func(context.Context, string, string) (func() string, error)
5253
}
5354

5455
func main() {
@@ -534,6 +535,9 @@ func (s *server) createSpritz(c echo.Context) error {
534535
if existing != nil && strings.TrimSpace(existing.Annotations[idempotencyHashAnnotationKey]) == strings.TrimSpace(annotations[idempotencyHashAnnotationKey]) {
535536
return writeJSON(c, http.StatusOK, summarizeCreateResponse(existing, principal, body.PresetID, provisionerSource(&body), body.IdempotencyKey, true))
536537
}
538+
if !nameProvided {
539+
continue
540+
}
537541
return writeError(c, http.StatusConflict, "idempotencyKey already used with a different request")
538542
}
539543
if !nameProvided && apierrors.IsAlreadyExists(err) {

api/main_create_owner_test.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"bytes"
5+
"context"
56
"encoding/json"
67
"net/http"
78
"net/http/httptest"
@@ -11,7 +12,10 @@ import (
1112

1213
"github.com/labstack/echo/v4"
1314
corev1 "k8s.io/api/core/v1"
15+
apierrors "k8s.io/apimachinery/pkg/api/errors"
1416
"k8s.io/apimachinery/pkg/runtime"
17+
"k8s.io/apimachinery/pkg/runtime/schema"
18+
"sigs.k8s.io/controller-runtime/pkg/client"
1519
"sigs.k8s.io/controller-runtime/pkg/client/fake"
1620

1721
spritzv1 "spritz.sh/operator/api/v1"
@@ -49,6 +53,20 @@ func newCreateSpritzTestServer(t *testing.T) *server {
4953
}
5054
}
5155

56+
type createInterceptClient struct {
57+
client.Client
58+
onCreate func(context.Context, client.Object) error
59+
}
60+
61+
func (c *createInterceptClient) Create(ctx context.Context, obj client.Object, opts ...client.CreateOption) error {
62+
if c.onCreate != nil {
63+
if err := c.onCreate(ctx, obj); err != nil {
64+
return err
65+
}
66+
}
67+
return c.Client.Create(ctx, obj, opts...)
68+
}
69+
5270
func configureProvisionerTestServer(s *server) {
5371
s.presets = presetCatalog{
5472
byID: []runtimePreset{{
@@ -392,3 +410,63 @@ func TestCreateSpritzReplaysIdempotentProvisionerRequestBeforeQuotaCheck(t *test
392410
t.Fatalf("expected replay status 200, got %d: %s", rec2.Code, rec2.Body.String())
393411
}
394412
}
413+
414+
func TestCreateSpritzRetriesGeneratedServiceNameAfterAlreadyExists(t *testing.T) {
415+
s := newCreateSpritzTestServer(t)
416+
configureProvisionerTestServer(s)
417+
baseClient := s.client
418+
s.client = &createInterceptClient{
419+
Client: baseClient,
420+
onCreate: func(_ context.Context, obj client.Object) error {
421+
spritz, ok := obj.(*spritzv1.Spritz)
422+
if !ok {
423+
return nil
424+
}
425+
if spritz.Name == "openclaw-first" {
426+
return apierrors.NewAlreadyExists(schema.GroupResource{
427+
Group: spritzv1.GroupVersion.Group,
428+
Resource: "spritzes",
429+
}, spritz.Name)
430+
}
431+
return nil
432+
},
433+
}
434+
s.nameGeneratorFactory = func(context.Context, string, string) (func() string, error) {
435+
names := []string{"openclaw-first", "openclaw-second"}
436+
index := 0
437+
return func() string {
438+
name := names[index]
439+
if index < len(names)-1 {
440+
index++
441+
}
442+
return name
443+
}, nil
444+
}
445+
446+
e := echo.New()
447+
secured := e.Group("", s.authMiddleware())
448+
secured.POST("/api/spritzes", s.createSpritz)
449+
450+
body := []byte(`{"presetId":"openclaw","ownerId":"user-123","idempotencyKey":"discord-race"}`)
451+
req := httptest.NewRequest(http.MethodPost, "/api/spritzes", bytes.NewReader(body))
452+
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
453+
req.Header.Set("X-Spritz-User-Id", "zenobot")
454+
req.Header.Set("X-Spritz-Principal-Type", "service")
455+
req.Header.Set("X-Spritz-Principal-Scopes", "spritz.instances.create,spritz.instances.assign_owner")
456+
rec := httptest.NewRecorder()
457+
458+
e.ServeHTTP(rec, req)
459+
460+
if rec.Code != http.StatusCreated {
461+
t.Fatalf("expected status 201 after autogenerated name retry, got %d: %s", rec.Code, rec.Body.String())
462+
}
463+
464+
var payload map[string]any
465+
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
466+
t.Fatalf("failed to decode response: %v", err)
467+
}
468+
name := payload["data"].(map[string]any)["spritz"].(map[string]any)["metadata"].(map[string]any)["name"]
469+
if name != "openclaw-second" {
470+
t.Fatalf("expected second generated name after race, got %#v", name)
471+
}
472+
}

api/random_name.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,9 @@ func randomSuffix(length int) string {
296296
}
297297

298298
func (s *server) newSpritzNameGenerator(ctx context.Context, namespace string, prefix string) (func() string, error) {
299+
if s.nameGeneratorFactory != nil {
300+
return s.nameGeneratorFactory(ctx, namespace, prefix)
301+
}
299302
list := &spritzv1.SpritzList{}
300303
opts := []client.ListOption{client.InNamespace(namespace)}
301304
if err := s.client.List(ctx, list, opts...); err != nil {

api/terminal.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ func (s *server) openTerminal(c echo.Context) error {
165165
if usingZmx {
166166
log.Printf("spritz terminal: zmx attach name=%s namespace=%s session=%s user_id=%s", name, namespace, resolvedSession, principal.ID)
167167
}
168-
if err := s.streamTerminal(c.Request().Context(), pod, conn, command); err != nil {
168+
if err := s.streamTerminal(c.Request().Context(), namespace, name, pod, conn, command); err != nil {
169169
if errors.Is(err, context.Canceled) {
170170
return nil
171171
}
@@ -202,7 +202,7 @@ func (s *server) findRunningPod(ctx context.Context, namespace, name, container
202202
return nil, fmt.Errorf("spritz not ready")
203203
}
204204

205-
func (s *server) streamTerminal(ctx context.Context, pod *corev1.Pod, conn *websocket.Conn, command []string) error {
205+
func (s *server) streamTerminal(ctx context.Context, namespace, name string, pod *corev1.Pod, conn *websocket.Conn, command []string) error {
206206
if len(command) == 0 {
207207
return errors.New("terminal command missing")
208208
}
@@ -236,7 +236,13 @@ func (s *server) streamTerminal(ctx context.Context, pod *corev1.Pod, conn *webs
236236

237237
readErr := make(chan error, 1)
238238
go func() {
239-
readErr <- readTerminalInput(ctx, conn, stdinWriter, sizeQueue)
239+
readErr <- readTerminalInput(ctx, conn, stdinWriter, sizeQueue, func() {
240+
refreshCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
241+
defer cancel()
242+
if err := s.markSpritzActivity(refreshCtx, namespace, name, time.Now()); err != nil {
243+
log.Printf("spritz terminal: failed to refresh activity name=%s namespace=%s pod=%s err=%v", name, namespace, pod.Name, err)
244+
}
245+
})
240246
}()
241247

242248
streamErr := executor.StreamWithContext(ctx, remotecommand.StreamOptions{
@@ -268,7 +274,7 @@ type resizeMessage struct {
268274
Rows int `json:"rows"`
269275
}
270276

271-
func readTerminalInput(ctx context.Context, conn *websocket.Conn, stdin *io.PipeWriter, sizeQueue *terminalSizeQueue) error {
277+
func readTerminalInput(ctx context.Context, conn *websocket.Conn, stdin *io.PipeWriter, sizeQueue *terminalSizeQueue, onInput func()) error {
272278
for {
273279
select {
274280
case <-ctx.Done():
@@ -290,6 +296,9 @@ func readTerminalInput(ctx context.Context, conn *websocket.Conn, stdin *io.Pipe
290296
if _, err := stdin.Write(payload); err != nil {
291297
return err
292298
}
299+
if onInput != nil {
300+
onInput()
301+
}
293302
}
294303
}
295304
}

api/terminal_test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"io"
6+
"net/http"
7+
"net/http/httptest"
8+
"net/url"
9+
"sync/atomic"
10+
"testing"
11+
"time"
12+
13+
"github.com/gorilla/websocket"
14+
)
15+
16+
func TestReadTerminalInputInvokesActivityCallbackOnInput(t *testing.T) {
17+
upgrader := websocket.Upgrader{}
18+
serverConn := make(chan *websocket.Conn, 1)
19+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
20+
conn, err := upgrader.Upgrade(w, r, nil)
21+
if err != nil {
22+
t.Fatalf("upgrade failed: %v", err)
23+
}
24+
serverConn <- conn
25+
}))
26+
defer srv.Close()
27+
28+
wsURL, err := url.Parse(srv.URL)
29+
if err != nil {
30+
t.Fatalf("failed to parse server url: %v", err)
31+
}
32+
wsURL.Scheme = "ws"
33+
clientConn, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil)
34+
if err != nil {
35+
t.Fatalf("failed to dial websocket: %v", err)
36+
}
37+
defer clientConn.Close()
38+
39+
conn := <-serverConn
40+
defer conn.Close()
41+
42+
reader, writer := io.Pipe()
43+
defer reader.Close()
44+
45+
ctx, cancel := context.WithCancel(context.Background())
46+
defer cancel()
47+
48+
var callbacks atomic.Int32
49+
done := make(chan error, 1)
50+
go func() {
51+
done <- readTerminalInput(ctx, conn, writer, newTerminalSizeQueue(), func() {
52+
callbacks.Add(1)
53+
})
54+
}()
55+
56+
if err := clientConn.WriteMessage(websocket.TextMessage, []byte(`{"type":"resize","cols":80,"rows":24}`)); err != nil {
57+
t.Fatalf("failed to send resize message: %v", err)
58+
}
59+
if callbacks.Load() != 0 {
60+
t.Fatalf("expected resize message to skip activity callback, got %d", callbacks.Load())
61+
}
62+
63+
if err := clientConn.WriteMessage(websocket.TextMessage, []byte("ls\n")); err != nil {
64+
t.Fatalf("failed to send terminal input: %v", err)
65+
}
66+
67+
buf := make([]byte, 3)
68+
if _, err := io.ReadFull(reader, buf); err != nil {
69+
t.Fatalf("failed to read stdin payload: %v", err)
70+
}
71+
deadline := time.Now().Add(2 * time.Second)
72+
for callbacks.Load() != 1 && time.Now().Before(deadline) {
73+
time.Sleep(10 * time.Millisecond)
74+
}
75+
if callbacks.Load() != 1 {
76+
t.Fatalf("expected one activity callback for terminal input, got %d", callbacks.Load())
77+
}
78+
79+
cancel()
80+
_ = clientConn.Close()
81+
82+
select {
83+
case <-done:
84+
case <-time.After(2 * time.Second):
85+
t.Fatal("timed out waiting for terminal reader to exit")
86+
}
87+
}

0 commit comments

Comments
 (0)