Skip to content

Commit e85a928

Browse files
jrhyness and claude authored
refactor: simplify TokenRateLimitPolicy by trusting AuthPolicy validation (#543)
## Simplify TokenRateLimitPolicy by trusting AuthPolicy validation **Jira**: https://redhat.atlassian.net/browse/RHOAIENG-53680 **jira**: https://redhat.atlassian.net/browse/RHOAIENG-53951 ## Description This PR removes ~80 lines of duplicate subscription validation logic from the TokenRateLimitPolicy (TRLP) controller. The TRLP now trusts the validated `auth.identity.selected_subscription` from AuthPolicy instead of re-implementing membership checks, header validation, and deny rules. We also now auto-select the subscription by model if the model is available in only one subscription. ### Problem The TRLP controller was duplicating validation that AuthPolicy already performs: - Checking if users belong to subscriptions (groups/users) - Validating the `x-maas-subscription` header - Creating 3+ deny rules per model ### Solution AuthPolicy already validates subscriptions via `/v1/subscriptions/select` and injects `auth.identity.selected_subscription`. TRLP now simply uses this value. **Before:** ```yaml limits: sub-a-model-tokens: when: - predicate: auth.identity.groups_str.split(",").exists(g, g == "team-a") - predicate: request.headers["x-maas-subscription"] == "sub-a" || (!request.headers.exists(...) && !(exclusions...)) deny-not-member-sub-a-model: [...] deny-unsubscribed-model: [...] deny-invalid-header-model: [...] ``` **After:** ```yaml limits: sub-a-model-tokens: when: - predicate: auth.identity.selected_subscription == "sub-a" ``` ### Changes - Removed `buildMembershipCheck` function and validation logic (~80 lines) - Simplified `subInfo` struct from 6 fields to 3 - Removed all deny rules - Added 2 unit tests verifying simplified structure - Updated E2E tests to expect 403 (not 429) for invalid subscriptions - Enhanced logging with subscription count **Result:** 75% fewer TRLP limit entries, 32% smaller function, better error codes (403 vs 429) ## How Has This Been Tested? 
### Unit Tests ```bash cd maas-controller go test ./pkg/controller/maas -run TestMaaSSubscription -v ``` All 7 tests pass (5 existing + 2 new): - New: `TestMaaSSubscriptionReconciler_SimplifiedTRLP` - verifies single predicate, no deny rules - New: `TestMaaSSubscriptionReconciler_MultipleSubscriptionsSimplified` - verifies no exclusion logic ### E2E Tests ```bash cd test/e2e pytest tests/test_subscription.py -v ``` Updated tests to expect 403 Forbidden (from AuthPolicy) instead of 429 (from TRLP) for subscription validation failures. ### Linter ```bash golangci-lint run --new-from-rev=HEAD~1 ``` No issues. ## Merge criteria: - [x] The commits are squashed in a cohesive manner and have meaningful messages. - [x] Testing instructions have been added in the PR body (for PRs involving changes that are not immediately obvious). - [x] The developer has manually tested the changes and verified that the changes work <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Invalid, missing, or ambiguous subscription validation now returns 403 Forbidden (applied before rate limiting). * **New Features** * Explicit model selection in subscription flows with clear "model_not_in_subscription" responses and per-model isolation for policies. * Namespace-scoped subscription isolation to prevent cross-tenant collisions. * Simplified per-subscription rate-limit generation relying on auth-selected subscription keys. * **Tests** * Expanded unit and e2e coverage for simplified rate limits, model-scoped behavior, cross-namespace isolation, 403 expectations, and transient command retries. * **Diagnostics** * Enhanced debug/reporting (test user, subscription→model mappings, model listings, configuration summary). <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent 4807f87 commit e85a928

14 files changed

Lines changed: 1515 additions & 235 deletions

File tree

maas-api/internal/handlers/models.go

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ func (h *ModelsHandler) selectSubscriptionsForListing(
7878

7979
// Single subscription selection (existing behavior)
8080
if h.subscriptionSelector != nil {
81-
result, err := h.subscriptionSelector.Select(userContext.Groups, userContext.Username, requestedSubscription)
81+
//nolint:unqueryvet,nolintlint // Select is a method, not a SQL query
82+
result, err := h.subscriptionSelector.Select(userContext.Groups, userContext.Username, requestedSubscription, "")
8283
if err != nil {
8384
h.handleSubscriptionSelectionError(c, err)
8485
return nil, true
@@ -99,6 +100,7 @@ func (h *ModelsHandler) selectSubscriptionsForListing(
99100
// handleSubscriptionSelectionError handles errors from subscription selection and sends appropriate HTTP responses.
100101
func (h *ModelsHandler) handleSubscriptionSelectionError(c *gin.Context, err error) {
101102
var multipleSubsErr *subscription.MultipleSubscriptionsError
103+
var ambiguousErr *subscription.SubscriptionAmbiguousError
102104
var accessDeniedErr *subscription.AccessDeniedError
103105
var notFoundErr *subscription.SubscriptionNotFoundError
104106
var noSubErr *subscription.NoSubscriptionError
@@ -117,6 +119,16 @@ func (h *ModelsHandler) handleSubscriptionSelectionError(c *gin.Context, err err
117119
return
118120
}
119121

122+
if errors.As(err, &ambiguousErr) {
123+
h.logger.Debug("Subscription name is ambiguous")
124+
c.JSON(http.StatusForbidden, gin.H{
125+
"error": gin.H{
126+
"message": err.Error(),
127+
"type": "permission_error",
128+
}})
129+
return
130+
}
131+
120132
if errors.As(err, &accessDeniedErr) {
121133
h.logger.Debug("Access denied to subscription")
122134
c.JSON(http.StatusForbidden, gin.H{
@@ -255,7 +267,7 @@ func (h *ModelsHandler) ListLLMs(c *gin.Context) {
255267
} else {
256268
// User has zero accessible subscriptions - return empty list
257269
h.logger.Debug("User has zero accessible subscriptions, returning empty model list")
258-
// modelList is already initialized to empty slice at line 235
270+
// modelList is already initialized to empty slice above
259271
}
260272
} else {
261273
// Filter models by subscription(s) and aggregate subscriptions
@@ -268,8 +280,22 @@ func (h *ModelsHandler) ListLLMs(c *gin.Context) {
268280
modelsByKey := make(map[modelKey]*models.Model)
269281

270282
for _, sub := range subscriptionsToUse {
271-
h.logger.Debug("Filtering models by subscription", "subscription", sub.Name, "modelCount", len(list))
272-
filteredModels := h.modelMgr.FilterModelsByAccess(c.Request.Context(), list, authHeader, sub.Name)
283+
// Pre-filter by modelRefs if available (optimization to reduce HTTP calls)
284+
modelsToCheck := list
285+
if len(sub.ModelRefs) > 0 {
286+
h.logger.Debug("Pre-filtering models by subscription modelRefs",
287+
"subscription", sub.Name,
288+
"totalModels", len(list),
289+
"modelRefsCount", len(sub.ModelRefs),
290+
)
291+
modelsToCheck = filterModelsBySubscription(list, sub.ModelRefs)
292+
h.logger.Debug("After modelRef filtering", "modelsToCheck", len(modelsToCheck))
293+
}
294+
295+
// Use qualified "namespace/name" format for accurate authorization checks
296+
qualifiedSubName := sub.Namespace + "/" + sub.Name
297+
h.logger.Debug("Filtering models by subscription", "subscription", qualifiedSubName, "modelCount", len(modelsToCheck))
298+
filteredModels := h.modelMgr.FilterModelsByAccess(c.Request.Context(), modelsToCheck, authHeader, qualifiedSubName)
273299

274300
for _, model := range filteredModels {
275301
subInfo := models.SubscriptionInfo{
@@ -323,3 +349,29 @@ func (h *ModelsHandler) ListLLMs(c *gin.Context) {
323349
Data: modelList,
324350
})
325351
}
352+
353+
// filterModelsBySubscription filters models to only those matching the subscription's modelRefs.
354+
func filterModelsBySubscription(modelList []models.Model, modelRefs []subscription.ModelRefInfo) []models.Model {
355+
if len(modelRefs) == 0 {
356+
return modelList
357+
}
358+
359+
// Build map of allowed models for fast lookup
360+
allowed := make(map[string]bool)
361+
for _, ref := range modelRefs {
362+
key := ref.Namespace + "/" + ref.Name
363+
allowed[key] = true
364+
}
365+
366+
// Filter models
367+
filtered := make([]models.Model, 0, len(modelList))
368+
for _, model := range modelList {
369+
// Models from MaaSModelRefLister have OwnedBy set to namespace
370+
modelKey := model.OwnedBy + "/" + model.ID
371+
if allowed[modelKey] {
372+
filtered = append(filtered, model)
373+
}
374+
}
375+
376+
return filtered
377+
}

maas-api/internal/handlers/models_test.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -402,8 +402,9 @@ func TestListingModelsWithSubscriptionHeader(t *testing.T) {
402402
testLogger := logger.Development()
403403

404404
// Create mock servers that require specific subscription headers
405-
premiumModelServer := createMockModelServerWithSubscriptionCheck(t, "premium-model", "premium")
406-
freeModelServer := createMockModelServerWithSubscriptionCheck(t, "free-model", "free")
405+
// Use qualified names (namespace/name) to match the format sent by the handler
406+
premiumModelServer := createMockModelServerWithSubscriptionCheck(t, "premium-model", "test-namespace/premium")
407+
freeModelServer := createMockModelServerWithSubscriptionCheck(t, "free-model", "test-namespace/free")
407408

408409
// Build MaaSModelRef unstructured list
409410
maasModelRefItems := []*unstructured.Unstructured{
@@ -573,9 +574,10 @@ func TestListModels_ReturnAllModels(t *testing.T) {
573574
testLogger := logger.Development()
574575

575576
// Create mock servers for models
576-
model1Server := createMockModelServerWithSubscriptionCheck(t, "model-1", "sub-a")
577-
model2Server := createMockModelServerWithSubscriptionCheck(t, "model-2", "sub-b")
578-
model3Server := createMockModelServerWithSubscriptionCheck(t, "model-3", "sub-a")
577+
// Use qualified names (namespace/name) to match the format sent by the handler
578+
model1Server := createMockModelServerWithSubscriptionCheck(t, "model-1", "test-namespace/sub-a")
579+
model2Server := createMockModelServerWithSubscriptionCheck(t, "model-2", "test-namespace/sub-b")
580+
model3Server := createMockModelServerWithSubscriptionCheck(t, "model-3", "test-namespace/sub-a")
579581

580582
// Setup MaaSModelRef lister with three models
581583
lister := fakeMaaSModelRefLister{
@@ -1097,8 +1099,9 @@ func TestListModels_DifferentModelRefsWithSameModelIDAndDifferentSubscriptions(t
10971099

10981100
// Create two mock servers that both return the same model ID "gpt-4"
10991101
// One accessible via sub-a, one via sub-b
1100-
modelServerA := createMockModelServerWithSubscriptionCheck(t, "gpt-4", "sub-a")
1101-
modelServerB := createMockModelServerWithSubscriptionCheck(t, "gpt-4", "sub-b")
1102+
// Use qualified names (namespace/name) to match the format sent by the handler
1103+
modelServerA := createMockModelServerWithSubscriptionCheck(t, "gpt-4", "test-namespace/sub-a")
1104+
modelServerB := createMockModelServerWithSubscriptionCheck(t, "gpt-4", "test-namespace/sub-b")
11021105

11031106
// Setup MaaSModelRef lister with two different MaaSModelRefs in different namespaces
11041107
lister := fakeMaaSModelRefLister{

maas-api/internal/models/discovery.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,8 @@ func discoveredToModels(discovered []openai.Model, original Model) []Model {
141141
if d.ID == "" {
142142
continue
143143
}
144-
ownedBy := d.OwnedBy
145-
if ownedBy == "" {
146-
ownedBy = original.OwnedBy
147-
}
144+
// Always use the trusted namespace from MaaSModelRef (original.OwnedBy)
145+
// Never trust backend-returned OwnedBy to prevent namespace spoofing
148146
created := d.Created
149147
if created == 0 {
150148
created = original.Created
@@ -154,7 +152,7 @@ func discoveredToModels(discovered []openai.Model, original Model) []Model {
154152
ID: d.ID,
155153
Object: "model",
156154
Created: created,
157-
OwnedBy: ownedBy,
155+
OwnedBy: original.OwnedBy,
158156
},
159157
Kind: original.Kind,
160158
URL: original.URL,

maas-api/internal/subscription/handler.go

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,17 @@ func (h *Handler) SelectSubscription(c *gin.Context) {
6464
"username", req.Username,
6565
"groups", req.Groups,
6666
"requestedSubscription", req.RequestedSubscription,
67+
"requestedModel", req.RequestedModel,
6768
)
6869

69-
response, err := h.selector.Select(req.Groups, req.Username, req.RequestedSubscription)
70+
response, err := h.selector.Select(req.Groups, req.Username, req.RequestedSubscription, req.RequestedModel)
7071
if err != nil {
7172
var noSubErr *NoSubscriptionError
7273
var notFoundErr *SubscriptionNotFoundError
7374
var accessDeniedErr *AccessDeniedError
7475
var multipleSubsErr *MultipleSubscriptionsError
76+
var ambiguousErr *SubscriptionAmbiguousError
77+
var modelNotInSubErr *ModelNotInSubscriptionError
7578

7679
if errors.As(err, &noSubErr) {
7780
h.logger.Debug("No subscription found for user",
@@ -120,6 +123,29 @@ func (h *Handler) SelectSubscription(c *gin.Context) {
120123
return
121124
}
122125

126+
if errors.As(err, &ambiguousErr) {
127+
h.logger.Debug("Subscription name is ambiguous",
128+
"username", req.Username,
129+
)
130+
c.JSON(http.StatusOK, SelectResponse{
131+
Error: "ambiguous_subscription",
132+
Message: err.Error(),
133+
})
134+
return
135+
}
136+
137+
if errors.As(err, &modelNotInSubErr) {
138+
h.logger.Debug("Model not included in subscription",
139+
"subscription", modelNotInSubErr.Subscription,
140+
"model", modelNotInSubErr.Model,
141+
)
142+
c.JSON(http.StatusOK, SelectResponse{
143+
Error: "model_not_in_subscription",
144+
Message: err.Error(),
145+
})
146+
return
147+
}
148+
123149
// All other errors are internal server errors
124150
h.logger.Error("Subscription selection failed",
125151
"error", err.Error(),

0 commit comments

Comments
 (0)