Skip to content

Commit 75f0afd

Browse files
authored
Merge branch 'master' into TT-16767
2 parents 41b4354 + 9792d07 commit 75f0afd

40 files changed

+1551
-281
lines changed

cli/linter/schema.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1644,6 +1644,12 @@
16441644
}
16451645
}
16461646
},
1647+
"rate_limit_response_headers": {
1648+
"description": "Determines the type of data that will be returned in the rate limit headers",
1649+
"type": ["string"],
1650+
"enum": ["", "quotas", "rate_limits"],
1651+
"default": "quotas"
1652+
},
16471653
"allow_unsafe_policy_ids": {
16481654
"type": ["boolean", "null"],
16491655
"additionalProperties": false

config/rate_limit.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,20 @@ type RateLimit struct {
3737

3838
// Controls which algorthm to use as a fallback when your distributed rate limiter can't be used.
3939
DRLEnableSentinelRateLimiter bool `json:"drl_enable_sentinel_rate_limiter"`
40+
41+
// RateLimitResponseHeaders specifies the data source for rate limit headers in HTTP responses.
42+
// This controls whether rate limit headers (X-RateLimit-Limit, X-RateLimit-Remaining, etc.)
43+
// are populated from quota data or rate limit data. Valid values: "quotas", "rate_limits".
44+
RateLimitResponseHeaders RateLimitSource `json:"rate_limit_response_headers"`
4045
}
4146

47+
type RateLimitSource string
48+
49+
const (
50+
SourceQuotas RateLimitSource = "quotas"
51+
SourceRateLimits RateLimitSource = "rate_limits"
52+
)
53+
4254
// String returns a readable setting for the rate limiter in effect.
4355
func (r *RateLimit) String() string {
4456
info := "using transactions"

gateway/middleware_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ func TestSessionLimiter_RedisQuotaExceeded_ExpiredAtReset(t *testing.T) {
420420
}
421421

422422
beforeTime := time.Now()
423-
blocked := limiter.RedisQuotaExceeded(req, session, quotaKey, "", limit, g.Gw.GlobalSessionManager.Store(), false)
423+
blocked := limiter.RedisQuotaExceeded(req, session, quotaKey, "", limit, false, false)
424424
afterTime := time.Now()
425425

426426
assert.Equal(t, quotaMax-1, session.QuotaRemaining, "Quota remaining should be quotaMax - 1 after increment")
@@ -475,7 +475,7 @@ func TestSessionLimiter_RedisQuotaExceeded_ExpiredAtReset(t *testing.T) {
475475
}
476476

477477
beforeTime := time.Now()
478-
blocked := limiter.RedisQuotaExceeded(req, session, quotaKey, scope, limit, g.Gw.GlobalSessionManager.Store(), false)
478+
blocked := limiter.RedisQuotaExceeded(req, session, quotaKey, scope, limit, false, false)
479479
afterTime := time.Now()
480480

481481
accessDef := session.AccessRights["api1"]

gateway/model.go

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,20 @@ import (
1212

1313
type EventMetaDefault = model.EventMetaDefault
1414

15+
type CtxData = map[string]any
16+
17+
const (
18+
ctxDataKeyRateLimitLimit = "rate_limit_limit"
19+
ctxDataKeyRateLimitRemaining = "rate_limit_remaining"
20+
ctxDataKeyRateLimitReset = "rate_limit_reset"
21+
22+
ctxDataKeyQuotaLimit = "quota_limit"
23+
ctxDataKeyQuotaRemaining = "quota_remaining"
24+
ctxDataKeyQuotaReset = "quota_reset"
25+
)
26+
1527
var (
16-
ctxData = httpctx.NewValue[map[string]any](ctx.ContextData)
28+
ctxData = httpctx.NewValue[CtxData](ctx.ContextData)
1729

1830
ctxGetData = ctxData.Get
1931
ctxSetData = ctxData.Set
@@ -28,3 +40,14 @@ var (
2840

2941
EncodeRequestToEvent = event.EncodeRequestToEvent
3042
)
43+
44+
func ctxGetOrCreateData(r *http.Request) CtxData {
45+
data := ctxGetData(r)
46+
47+
if data == nil {
48+
data = CtxData{}
49+
ctxSetData(r, data)
50+
}
51+
52+
return data
53+
}

gateway/model_apispec.go

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"net/http"
66
"net/url"
7-
"strconv"
87
"strings"
98
"sync"
109
"sync/atomic"
@@ -18,15 +17,13 @@ import (
1817
"github.com/TykTechnologies/tyk/apidef/oas"
1918
"github.com/TykTechnologies/tyk/config"
2019
"github.com/TykTechnologies/tyk/ctx"
21-
"github.com/TykTechnologies/tyk/header"
2220
"github.com/TykTechnologies/tyk/internal/agentprotocol"
2321
"github.com/TykTechnologies/tyk/internal/certcheck"
2422
"github.com/TykTechnologies/tyk/internal/errors"
2523
"github.com/TykTechnologies/tyk/internal/graphengine"
2624
"github.com/TykTechnologies/tyk/internal/httpctx"
2725
"github.com/TykTechnologies/tyk/internal/httputil"
2826
"github.com/TykTechnologies/tyk/internal/jsonrpc"
29-
"github.com/TykTechnologies/tyk/user"
3027

3128
_ "github.com/TykTechnologies/tyk/internal/mcp" // registers MCP VEM prefixes
3229
)
@@ -379,19 +376,3 @@ func (a *APISpec) APIType() string {
379376
return "classic"
380377
}
381378
}
382-
383-
func (a *APISpec) sendRateLimitHeaders(session *user.SessionState, dest *http.Response) {
384-
quotaMax, quotaRemaining, quotaRenews := int64(0), int64(0), int64(0)
385-
386-
if session != nil {
387-
quotaMax, quotaRemaining, _, quotaRenews = session.GetQuotaLimitByAPIID(a.APIID)
388-
}
389-
390-
if dest.Header == nil {
391-
dest.Header = http.Header{}
392-
}
393-
394-
dest.Header.Set(header.XRateLimitLimit, strconv.Itoa(int(quotaMax)))
395-
dest.Header.Set(header.XRateLimitRemaining, strconv.Itoa(int(quotaRemaining)))
396-
dest.Header.Set(header.XRateLimitReset, strconv.Itoa(int(quotaRenews)))
397-
}

gateway/model_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package gateway
2+
3+
import (
4+
"net/http"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func Test_ctxGetOrCreateData(t *testing.T) {
12+
t.Run("returns data if already exists", func(t *testing.T) {
13+
req, err := http.NewRequestWithContext(t.Context(), "GET", "/", nil)
14+
require.NoError(t, err)
15+
16+
ctxSetData(req, CtxData{"hello": "world0"})
17+
18+
assert.Equal(t, CtxData{"hello": "world0"}, ctxGetOrCreateData(req))
19+
})
20+
21+
t.Run("create new data if not exists", func(t *testing.T) {
22+
req, err := http.NewRequestWithContext(t.Context(), "GET", "/", nil)
23+
require.NoError(t, err)
24+
25+
data1 := ctxGetOrCreateData(req)
26+
data1["hello"] = "world1"
27+
28+
data2 := ctxGetOrCreateData(req)
29+
30+
assert.Equal(t, data1, data2)
31+
})
32+
}

gateway/mw_api_rate_limit.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"time"
99

1010
"github.com/TykTechnologies/tyk/ctx"
11+
"github.com/TykTechnologies/tyk/header"
1112
tykerrors "github.com/TykTechnologies/tyk/internal/errors"
1213
"github.com/TykTechnologies/tyk/internal/event"
1314
"github.com/TykTechnologies/tyk/storage"
@@ -94,30 +95,38 @@ func (k *RateLimitForAPI) EnabledForSpec() bool {
9495
}
9596

9697
// ProcessRequest will run any checks on the request on the way through the system, return an error to have the chain fail
97-
func (k *RateLimitForAPI) ProcessRequest(_ http.ResponseWriter, r *http.Request, _ interface{}) (error, int) {
98+
//
99+
//nolint:staticcheck
100+
func (k *RateLimitForAPI) ProcessRequest(rw http.ResponseWriter, r *http.Request, _ interface{}) (error, int) {
98101
// Skip rate limiting and quotas for looping
99102
if !ctxCheckLimits(r) {
100103
return nil, http.StatusOK
101104
}
102105

103-
storeRef := k.Gw.GlobalSessionManager.Store()
106+
session := k.getSession(r)
107+
108+
limitHeaderSender := k.Gw.limitHeaderFactory(rw.Header())
109+
// Only inject API-level rate limit headers if personal rate limit headers
110+
// haven't already been injected by RateLimitAndQuotaCheck.
111+
if rw.Header().Get(header.XRateLimitLimit) != "" {
112+
limitHeaderSender = nil
113+
}
104114

105115
reason := k.Gw.SessionLimiter.ForwardMessage(
106116
r,
107-
k.getSession(r),
117+
session,
108118
k.keyName,
109119
k.quotaKey,
110-
storeRef,
111120
true,
112121
false,
113122
k.Spec,
114123
false,
124+
limitHeaderSender,
115125
)
116126

117127
k.emitRateLimitEvents(r, k.keyName)
118128

119129
if reason == sessionFailRateLimit {
120-
// Set error classification for access logs
121130
ctx.SetErrorClassification(r, tykerrors.ClassifyRateLimitError(tykerrors.ErrTypeAPIRateLimit, k.Name()))
122131
return k.handleRateLimitFailure(r, event.RateLimitExceeded, "API Rate Limit Exceeded", k.keyName)
123132
}

gateway/mw_api_rate_limit_test.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
package gateway
22

33
import (
4+
"fmt"
45
"net/http"
56
"net/http/httptest"
67
"net/url"
78
"testing"
89
"time"
910

1011
"github.com/TykTechnologies/tyk/apidef"
12+
"github.com/TykTechnologies/tyk/config"
13+
"github.com/TykTechnologies/tyk/header"
1114

1215
"github.com/stretchr/testify/assert"
1316

@@ -75,6 +78,82 @@ func TestRateLimitForAPI_EnabledForSpec(t *testing.T) {
7578
assert.False(t, rlDisabled.EnabledForSpec())
7679
}
7780

81+
func TestAPIRateLimitResponseHeaders(t *testing.T) {
82+
limiters := []string{"Redis", "Sentinel", "DRL", "FixedWindow"}
83+
84+
for _, limiter := range limiters {
85+
t.Run("API Rate limit headers for "+limiter, func(t *testing.T) {
86+
ts := StartTest(func(globalConf *config.Config) {
87+
globalConf.RateLimitResponseHeaders = config.SourceRateLimits
88+
89+
switch limiter {
90+
case "Redis":
91+
globalConf.EnableRedisRollingLimiter = true
92+
case "Sentinel":
93+
globalConf.EnableSentinelRateLimiter = true
94+
case "DRL":
95+
globalConf.DRLEnableSentinelRateLimiter = true
96+
case "FixedWindow":
97+
globalConf.EnableFixedWindowRateLimiter = true
98+
}
99+
})
100+
defer ts.Close()
101+
102+
var (
103+
rateLimitRate float64 = 2
104+
rateLimitPer float64 = 10
105+
)
106+
107+
_ = ts.Gw.BuildAndLoadAPI(func(spec *APISpec) {
108+
spec.APIID = "api-rate-limit-headers-test-" + limiter
109+
spec.Proxy.ListenPath = "/api-rate-limit-headers-test"
110+
spec.UseKeylessAccess = true
111+
spec.GlobalRateLimit = apidef.GlobalRateLimit{
112+
Disabled: false,
113+
Rate: rateLimitRate,
114+
Per: rateLimitPer,
115+
}
116+
})[0]
117+
118+
expectedRemaining1 := fmt.Sprintf("%d", int(rateLimitRate)-1)
119+
expectedRemaining2 := fmt.Sprintf("%d", int(rateLimitRate)-2)
120+
121+
headersMatch1 := map[string]string{
122+
header.XRateLimitLimit: fmt.Sprintf("%d", int(rateLimitRate)),
123+
}
124+
headersMatch2 := map[string]string{
125+
header.XRateLimitLimit: fmt.Sprintf("%d", int(rateLimitRate)),
126+
}
127+
128+
// For limiters that don't support Remaining (Sentinel, FixedWindow), it should be assigned to 0.
129+
if limiter == "Redis" || limiter == "DRL" {
130+
headersMatch1[header.XRateLimitRemaining] = expectedRemaining1
131+
headersMatch2[header.XRateLimitRemaining] = expectedRemaining2
132+
} else {
133+
headersMatch1[header.XRateLimitRemaining] = "0"
134+
headersMatch2[header.XRateLimitRemaining] = "0"
135+
}
136+
137+
_, _ = ts.Run(t, []test.TestCase{
138+
{
139+
Path: "/api-rate-limit-headers-test",
140+
Code: http.StatusOK,
141+
HeadersMatch: headersMatch1,
142+
},
143+
{
144+
Path: "/api-rate-limit-headers-test",
145+
Code: http.StatusOK,
146+
HeadersMatch: headersMatch2,
147+
},
148+
{
149+
Path: "/api-rate-limit-headers-test",
150+
Code: http.StatusTooManyRequests,
151+
},
152+
}...)
153+
})
154+
}
155+
}
156+
78157
func TestRLOpen(t *testing.T) {
79158
ts := StartTest(nil)
80159
defer ts.Close()

gateway/mw_mock_response.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ func (m *mockResponseMiddleware) mockResponse(r *http.Request) (
170170
res.Header.Set(header.Connection, "close")
171171
}
172172

173-
m.Spec.sendRateLimitHeaders(ctxGetSession(r), res)
173+
m.Gw.limitHeaderFactory(res.Header).SendQuotas(ctxGetSession(r), m.Spec.APIID)
174174

175175
return res, internal, nil
176176
}

gateway/mw_organisation_activity.go

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,13 @@ func (k *OrganizationMonitor) refreshOrgSession(orgID string) {
120120
}
121121

122122
// ProcessRequest will run any checks on the request on the way through the system, return an error to have the chain fail
123-
func (k *OrganizationMonitor) ProcessRequestLive(r *http.Request, orgSession *user.SessionState) (error, int) {
123+
//
124+
//nolint:staticcheck
125+
func (k *OrganizationMonitor) ProcessRequestLive(
126+
r *http.Request,
127+
orgSession *user.SessionState,
128+
) (error, int) {
129+
124130
logger := k.Logger()
125131

126132
if orgSession.IsInactive {
@@ -135,11 +141,11 @@ func (k *OrganizationMonitor) ProcessRequestLive(r *http.Request, orgSession *us
135141
orgSession,
136142
k.Spec.OrgID,
137143
"",
138-
k.Spec.OrgSessionManager.Store(),
139144
orgSession.Per > 0 && orgSession.Rate > 0,
140145
true,
141146
k.Spec,
142147
false,
148+
nil,
143149
)
144150

145151
sessionLifeTime := orgSession.Lifetime(k.Spec.GetSessionLifetimeRespectsKeyExpiration(), k.Spec.SessionLifetime, k.Gw.GetConfig().ForceGlobalSessionLifetime, k.Gw.GetConfig().GlobalSessionLifetime)
@@ -211,7 +217,12 @@ func (k *OrganizationMonitor) SetOrgSentinel(orgChan chan bool, orgId string) {
211217
}
212218
}
213219

214-
func (k *OrganizationMonitor) ProcessRequestOffThread(r *http.Request, orgSession *user.SessionState) (error, int) {
220+
//nolint:staticcheck
221+
func (k *OrganizationMonitor) ProcessRequestOffThread(
222+
r *http.Request,
223+
orgSession *user.SessionState,
224+
) (error, int) {
225+
215226
orgChanMap.Lock()
216227
orgChan, ok := orgChanMap.channels[k.Spec.OrgID]
217228
if !ok {
@@ -251,7 +262,8 @@ func (k *OrganizationMonitor) AllowAccessNext(
251262
path string,
252263
IP string,
253264
r *http.Request,
254-
session *user.SessionState) {
265+
session *user.SessionState,
266+
) {
255267

256268
// Is it active?
257269
logEntry := k.Gw.getExplicitLogEntryForRequest(k.Logger(), path, IP, k.Spec.OrgID, nil)
@@ -269,11 +281,11 @@ func (k *OrganizationMonitor) AllowAccessNext(
269281
session,
270282
k.Spec.OrgID,
271283
customQuotaKey,
272-
k.Spec.OrgSessionManager.Store(),
273284
session.Per > 0 && session.Rate > 0,
274285
true,
275286
k.Spec,
276287
false,
288+
nil,
277289
)
278290

279291
sessionLifeTime := session.Lifetime(k.Spec.GetSessionLifetimeRespectsKeyExpiration(), k.Spec.SessionLifetime, k.Gw.GetConfig().ForceGlobalSessionLifetime, k.Gw.GetConfig().GlobalSessionLifetime)

0 commit comments

Comments
 (0)