Skip to content

Commit 6592f1d

Browse files
committed
fix(provisioning): address service principal review findings
1 parent 7066adb commit 6592f1d

13 files changed

Lines changed: 470 additions & 69 deletions

api/acp_conversations.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func (s *server) listACPConversations(c echo.Context) error {
4848
if spritzName != "" {
4949
labels[acpConversationSpritzLabelKey] = spritzName
5050
}
51-
if s.auth.enabled() {
51+
if s.auth.enabled() && !principal.isAdminPrincipal() {
5252
labels[acpConversationOwnerLabelKey] = ownerLabelValue(principal.ID)
5353
}
5454
opts = append(opts, client.MatchingLabels(labels))

api/acp_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212

1313
"github.com/gorilla/websocket"
1414
"github.com/labstack/echo/v4"
15+
corev1 "k8s.io/api/core/v1"
1516
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
1617
"k8s.io/apimachinery/pkg/runtime"
1718
"sigs.k8s.io/controller-runtime/pkg/client"
@@ -26,6 +27,9 @@ func newACPTestScheme(t *testing.T) *runtime.Scheme {
2627
if err := spritzv1.AddToScheme(scheme); err != nil {
2728
t.Fatalf("failed to register spritz scheme: %v", err)
2829
}
30+
if err := corev1.AddToScheme(scheme); err != nil {
31+
t.Fatalf("failed to register core scheme: %v", err)
32+
}
2933
return scheme
3034
}
3135

@@ -428,6 +432,41 @@ func TestListAndPatchACPConversationsByID(t *testing.T) {
428432
}
429433
}
430434

435+
func TestListACPConversationsAllowsAdminToSeeAllOwners(t *testing.T) {
436+
now := metav1.Now()
437+
spritz := readyACPSpritz("tidy-otter", "user-1")
438+
ownerOne := conversationFor("tidy-otter-user-1", "tidy-otter", "user-1", "Owner one", now)
439+
ownerTwo := conversationFor("tidy-otter-user-2", "tidy-otter", "user-2", "Owner two", now)
440+
441+
s := newACPTestServer(t, spritz, ownerOne, ownerTwo)
442+
s.auth.adminIDs = map[string]struct{}{"admin-1": {}}
443+
e := echo.New()
444+
secured := e.Group("", s.authMiddleware())
445+
secured.GET("/api/acp/conversations", s.listACPConversations)
446+
447+
req := httptest.NewRequest(http.MethodGet, "/api/acp/conversations?spritz=tidy-otter", nil)
448+
req.Header.Set("X-Spritz-User-Id", "admin-1")
449+
rec := httptest.NewRecorder()
450+
e.ServeHTTP(rec, req)
451+
452+
if rec.Code != http.StatusOK {
453+
t.Fatalf("expected status 200, got %d: %s", rec.Code, rec.Body.String())
454+
}
455+
456+
var payload struct {
457+
Status string `json:"status"`
458+
Data struct {
459+
Items []spritzv1.SpritzConversation `json:"items"`
460+
} `json:"data"`
461+
}
462+
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
463+
t.Fatalf("failed to decode list response: %v", err)
464+
}
465+
if len(payload.Data.Items) != 2 {
466+
t.Fatalf("expected 2 visible conversations for admin, got %d", len(payload.Data.Items))
467+
}
468+
}
469+
431470
func TestPatchACPConversationRejectsSessionIDMutation(t *testing.T) {
432471
spritz := readyACPSpritz("tidy-otter", "user-1")
433472
conversation := conversationFor("tidy-otter-new", "tidy-otter", "user-1", "Latest", metav1.Now())

api/auth.go

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ func (a *authConfig) principal(r *http.Request) (principal, error) {
231231
id,
232232
"",
233233
normalizePrincipalType(r.Header.Get(a.headerType), a.headerDefaultType),
234-
splitList(r.Header.Get(a.headerScopes)),
234+
splitScopes(r.Header.Get(a.headerScopes)),
235235
a.isAdmin(id, teams),
236236
), nil
237237
case authModeAuto:
@@ -246,7 +246,7 @@ func (a *authConfig) principal(r *http.Request) (principal, error) {
246246
id,
247247
"",
248248
normalizePrincipalType(r.Header.Get(a.headerType), a.headerDefaultType),
249-
splitList(r.Header.Get(a.headerScopes)),
249+
splitScopes(r.Header.Get(a.headerScopes)),
250250
a.isAdmin(id, teams),
251251
), nil
252252
}
@@ -348,7 +348,7 @@ func (a *authConfig) introspectToken(ctx context.Context, token string) (princip
348348
firstStringPath(payload, []string{"sub"}),
349349
firstStringPath(payload, []string{"iss", "issuer"}),
350350
normalizePrincipalType(firstStringPath(payload, a.bearerTypePaths), a.bearerDefaultType),
351-
firstStringListPath(payload, a.bearerScopesPaths),
351+
firstScopeListPath(payload, a.bearerScopesPaths),
352352
a.isAdmin(id, teams),
353353
), nil
354354
}
@@ -434,7 +434,7 @@ func (a *authConfig) principalFromJWT(ctx context.Context, token string) (princi
434434
firstStringPath(claims, []string{"sub"}),
435435
firstStringPath(claims, []string{"iss", "issuer"}),
436436
normalizePrincipalType(firstStringPath(claims, a.bearerTypePaths), a.bearerDefaultType),
437-
firstStringListPath(claims, a.bearerScopesPaths),
437+
firstScopeListPath(claims, a.bearerScopesPaths),
438438
a.isAdmin(id, teams),
439439
), nil
440440
}
@@ -602,6 +602,23 @@ func splitList(value string) []string {
602602
return out
603603
}
604604

605+
func splitScopes(value string) []string {
606+
if value == "" {
607+
return nil
608+
}
609+
raw := strings.FieldsFunc(value, func(r rune) bool {
610+
return r == ',' || r == ';' || r == ' ' || r == '\n' || r == '\r' || r == '\t'
611+
})
612+
out := make([]string, 0, len(raw))
613+
for _, item := range raw {
614+
item = strings.TrimSpace(item)
615+
if item != "" {
616+
out = append(out, item)
617+
}
618+
}
619+
return out
620+
}
621+
605622
func splitListOrDefault(value string, fallback []string) []string {
606623
items := splitList(value)
607624
if len(items) == 0 {
@@ -724,6 +741,32 @@ func firstStringListPath(payload map[string]any, paths []string) []string {
724741
return nil
725742
}
726743

744+
func firstScopeListPath(payload map[string]any, paths []string) []string {
745+
for _, path := range paths {
746+
value, ok := lookupPath(payload, path)
747+
if !ok {
748+
continue
749+
}
750+
switch typed := value.(type) {
751+
case []string:
752+
return typed
753+
case []any:
754+
items := make([]string, 0, len(typed))
755+
for _, item := range typed {
756+
if s, ok := item.(string); ok && s != "" {
757+
items = append(items, s)
758+
}
759+
}
760+
if len(items) > 0 {
761+
return items
762+
}
763+
case string:
764+
return splitScopes(typed)
765+
}
766+
}
767+
return nil
768+
}
769+
727770
func lookupPath(payload map[string]any, path string) (any, bool) {
728771
path = strings.TrimSpace(path)
729772
if path == "" {

api/auth_middleware_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,3 +126,55 @@ func TestAuthMiddlewareSetsPrincipalTypeAndScopes(t *testing.T) {
126126
t.Fatalf("expected two scopes, got %#v", payload["scopes"])
127127
}
128128
}
129+
130+
func TestBearerAuthParsesSpaceDelimitedScopes(t *testing.T) {
131+
introspection := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
132+
_ = json.NewEncoder(w).Encode(map[string]any{
133+
"sub": "zenobot",
134+
"type": "service",
135+
"scope": "spritz.instances.create spritz.instances.assign_owner",
136+
})
137+
}))
138+
defer introspection.Close()
139+
140+
t.Setenv("SPRITZ_AUTH_MODE", "bearer")
141+
t.Setenv("SPRITZ_AUTH_BEARER_INTROSPECTION_URL", introspection.URL)
142+
t.Setenv("SPRITZ_AUTH_BEARER_ID_PATHS", "sub")
143+
t.Setenv("SPRITZ_AUTH_BEARER_TYPE_PATHS", "type")
144+
t.Setenv("SPRITZ_AUTH_BEARER_SCOPES_PATHS", "scope")
145+
146+
s := &server{auth: newAuthConfig()}
147+
e := echo.New()
148+
secured := e.Group("", s.authMiddleware())
149+
secured.GET("/api/spritzes", func(c echo.Context) error {
150+
p, ok := principalFromContext(c)
151+
if !ok {
152+
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "missing principal"})
153+
}
154+
return c.JSON(http.StatusOK, map[string]any{
155+
"type": p.Type,
156+
"scopes": p.Scopes,
157+
})
158+
})
159+
160+
req := httptest.NewRequest(http.MethodGet, "/api/spritzes", nil)
161+
req.Header.Set("Authorization", "Bearer test-token")
162+
rec := httptest.NewRecorder()
163+
e.ServeHTTP(rec, req)
164+
165+
if rec.Code != http.StatusOK {
166+
t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String())
167+
}
168+
169+
payload := map[string]any{}
170+
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
171+
t.Fatalf("failed to decode response: %v", err)
172+
}
173+
if payload["type"] != string(principalTypeService) {
174+
t.Fatalf("expected service principal type, got %#v", payload["type"])
175+
}
176+
scopes, _ := payload["scopes"].([]any)
177+
if len(scopes) != 2 {
178+
t.Fatalf("expected two scopes, got %#v", payload["scopes"])
179+
}
180+
}

api/main.go

Lines changed: 58 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -379,15 +379,45 @@ func (s *server) createSpritz(c echo.Context) error {
379379
}
380380
body.Spec.Owner = owner
381381

382+
nameProvided := body.Name != ""
383+
var nameGenerator func() string
384+
namePrefix := resolveSpritzNamePrefix(body.NamePrefix, body.Spec.Image)
385+
if namePrefix == "" && preset != nil {
386+
namePrefix = resolveSpritzNamePrefix(preset.NamePrefix, preset.Image)
387+
}
388+
if !nameProvided {
389+
generator, err := s.newSpritzNameGenerator(c.Request().Context(), namespace, namePrefix)
390+
if err != nil {
391+
return writeError(c, http.StatusInternalServerError, "failed to generate spritz name")
392+
}
393+
nameGenerator = generator
394+
body.Name = nameGenerator()
395+
}
396+
if body.Name == "" {
397+
return writeError(c, http.StatusInternalServerError, "failed to generate spritz name")
398+
}
399+
382400
if principal.isService() {
383-
fingerprint, err := s.validateProvisionerCreate(c.Request().Context(), principal, namespace, &body, normalizedUserConfig, requestedImage, requestedRepo, requestedNamespace)
401+
fingerprintName := body.Name
402+
if !nameProvided {
403+
fingerprintName = ""
404+
}
405+
fingerprint, err := s.validateProvisionerCreate(c.Request().Context(), principal, namespace, &body, normalizedUserConfig, requestedImage, requestedRepo, requestedNamespace, fingerprintName)
384406
if err != nil {
385407
if errors.Is(err, errForbidden) {
386408
return writeError(c, http.StatusForbidden, "forbidden")
387409
}
388410
return writeError(c, http.StatusBadRequest, err.Error())
389411
}
390-
existing, err := findIdempotentSpritz(c.Request().Context(), s.client, namespace, principal.ID, body.IdempotencyKey)
412+
reservedName, completed, err := s.reserveIdempotentCreateName(c.Request().Context(), namespace, principal, body.IdempotencyKey, fingerprint, body.Name)
413+
if err != nil {
414+
if strings.Contains(err.Error(), "idempotencyKey already used") {
415+
return writeError(c, http.StatusConflict, err.Error())
416+
}
417+
return writeError(c, http.StatusInternalServerError, err.Error())
418+
}
419+
body.Name = reservedName
420+
existing, err := s.findReservedSpritz(c.Request().Context(), namespace, reservedName)
391421
if err != nil {
392422
return writeError(c, http.StatusInternalServerError, err.Error())
393423
}
@@ -397,12 +427,18 @@ func (s *server) createSpritz(c echo.Context) error {
397427
}
398428
return writeJSON(c, http.StatusOK, summarizeCreateResponse(existing, principal, body.PresetID, provisionerSource(&body), body.IdempotencyKey, true))
399429
}
430+
if completed {
431+
return writeError(c, http.StatusConflict, "idempotencyKey already used")
432+
}
433+
if err := s.enforceProvisionerQuotas(c.Request().Context(), namespace, principal, body.Spec.Owner.ID); err != nil {
434+
return writeError(c, http.StatusBadRequest, err.Error())
435+
}
400436
body.Annotations = mergeStringMap(body.Annotations, map[string]string{
401-
actorIDAnnotationKey: principal.ID,
402-
actorTypeAnnotationKey: string(principal.Type),
403-
sourceAnnotationKey: provisionerSource(&body),
404-
requestIDAnnotationKey: body.RequestID,
405-
idempotencyKeyAnnotationKey: body.IdempotencyKey,
437+
actorIDAnnotationKey: principal.ID,
438+
actorTypeAnnotationKey: string(principal.Type),
439+
sourceAnnotationKey: provisionerSource(&body),
440+
requestIDAnnotationKey: body.RequestID,
441+
idempotencyKeyAnnotationKey: body.IdempotencyKey,
406442
idempotencyHashAnnotationKey: fingerprint,
407443
})
408444
} else if s.auth.enabled() && !principal.isAdminPrincipal() && owner.ID != principal.ID {
@@ -413,24 +449,6 @@ func (s *server) createSpritz(c echo.Context) error {
413449
return writeError(c, http.StatusBadRequest, err.Error())
414450
}
415451

416-
nameProvided := body.Name != ""
417-
var nameGenerator func() string
418-
if !nameProvided {
419-
namePrefix := resolveSpritzNamePrefix(body.NamePrefix, body.Spec.Image)
420-
if namePrefix == "" && preset != nil {
421-
namePrefix = resolveSpritzNamePrefix(preset.NamePrefix, preset.Image)
422-
}
423-
generator, err := s.newSpritzNameGenerator(c.Request().Context(), namespace, namePrefix)
424-
if err != nil {
425-
return writeError(c, http.StatusInternalServerError, "failed to generate spritz name")
426-
}
427-
nameGenerator = generator
428-
body.Name = nameGenerator()
429-
}
430-
if body.Name == "" {
431-
return writeError(c, http.StatusInternalServerError, "failed to generate spritz name")
432-
}
433-
434452
labels := map[string]string{
435453
ownerLabelKey: ownerLabelValue(owner.ID),
436454
}
@@ -508,11 +526,26 @@ func (s *server) createSpritz(c echo.Context) error {
508526
return writeError(c, http.StatusBadRequest, err.Error())
509527
}
510528
if err := s.client.Create(c.Request().Context(), spritz); err != nil {
529+
if principal.isService() && apierrors.IsAlreadyExists(err) {
530+
existing, getErr := s.findReservedSpritz(c.Request().Context(), namespace, name)
531+
if getErr != nil {
532+
return writeError(c, http.StatusInternalServerError, getErr.Error())
533+
}
534+
if existing != nil && strings.TrimSpace(existing.Annotations[idempotencyHashAnnotationKey]) == strings.TrimSpace(annotations[idempotencyHashAnnotationKey]) {
535+
return writeJSON(c, http.StatusOK, summarizeCreateResponse(existing, principal, body.PresetID, provisionerSource(&body), body.IdempotencyKey, true))
536+
}
537+
return writeError(c, http.StatusConflict, "idempotencyKey already used with a different request")
538+
}
511539
if !nameProvided && apierrors.IsAlreadyExists(err) {
512540
continue
513541
}
514542
return writeError(c, http.StatusInternalServerError, err.Error())
515543
}
544+
if principal.isService() {
545+
if err := s.completeIdempotencyReservation(c.Request().Context(), namespace, principal.ID, body.IdempotencyKey, spritz); err != nil {
546+
return writeError(c, http.StatusInternalServerError, err.Error())
547+
}
548+
}
516549
return writeJSON(c, http.StatusCreated, summarizeCreateResponse(spritz, principal, body.PresetID, provisionerSource(&body), body.IdempotencyKey, false))
517550
}
518551

0 commit comments

Comments
 (0)