forked from OpenHands/OpenHands
-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathtest_user_based_rate_limiter.py
More file actions
347 lines (281 loc) · 12.7 KB
/
Copy pathtest_user_based_rate_limiter.py
File metadata and controls
347 lines (281 loc) · 12.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
import asyncio
from datetime import datetime, timedelta
from unittest.mock import AsyncMock
import pytest
from openhands.server.middleware import UserBasedRateLimiter
from openhands.server.utils.ratelimit_storage import InMemoryRateLimiterStorage
class MockRateLimiterStorage:
    """Mock storage for testing rate limiter logic.

    Records every ``add_request`` / ``clean_old_requests`` call so tests can
    assert on storage interactions. Timestamps are kept per key, so a rate
    limiter that reads or cleans the wrong key is not masked by the mock.
    """

    def __init__(self):
        # key -> stored request timestamps, in insertion order.
        self.requests: dict[str, list[datetime]] = {}
        # (key, timestamp) for every add_request call, in call order.
        self.add_request_calls: list[tuple[str, datetime]] = []
        # (key, cutoff) for every clean_old_requests call, in call order.
        self.clean_old_requests_calls: list[tuple[str, datetime]] = []

    async def get_requests(self, key: str) -> list[datetime]:
        """Return a copy of the timestamps stored for ``key`` (empty if none)."""
        return list(self.requests.get(key, ()))

    async def add_request(self, key: str, timestamp: datetime) -> None:
        """Record the call and store ``timestamp`` under ``key``."""
        self.add_request_calls.append((key, timestamp))
        self.requests.setdefault(key, []).append(timestamp)

    async def clean_old_requests(self, key: str, cutoff: datetime) -> None:
        """Record the call and drop ``key``'s timestamps that are not after ``cutoff``."""
        self.clean_old_requests_calls.append((key, cutoff))
        if key in self.requests:
            self.requests[key] = [ts for ts in self.requests[key] if ts > cutoff]

    async def get_request_count(self, key: str) -> int:
        """Return how many timestamps are currently stored for ``key``."""
        return len(self.requests.get(key, ()))
class TestUserBasedRateLimiter:
    """Test suite for UserBasedRateLimiter.

    Most limiters built here use ``sleep_seconds=0``; the individual tests
    expect such a limiter to reject an over-limit request rather than delay it.
    """

    @pytest.fixture
    def mock_storage(self):
        """Create a mock storage for testing."""
        return MockRateLimiterStorage()

    @pytest.fixture
    def rate_limiter_with_mock_storage(self, mock_storage):
        """Create a rate limiter (5 requests / 60 s, no sleep) with mock storage."""
        return UserBasedRateLimiter(
            requests=5, seconds=60, sleep_seconds=0, storage=mock_storage
        )

    @pytest.fixture
    def rate_limiter_with_memory_storage(self):
        """Create a rate limiter (5 requests / 60 s, no sleep) with in-memory storage."""
        return UserBasedRateLimiter(
            requests=5,
            seconds=60,
            sleep_seconds=0,
            storage=InMemoryRateLimiterStorage(),
        )

    @pytest.mark.asyncio
    async def test_init_default_storage(self):
        """Test rate limiter initialization with default storage."""
        rate_limiter = UserBasedRateLimiter()
        assert isinstance(rate_limiter.storage, InMemoryRateLimiterStorage)
        assert rate_limiter.requests == 60
        assert rate_limiter.seconds == 60
        assert rate_limiter.sleep_seconds == 1

    @pytest.mark.asyncio
    async def test_init_custom_storage(self, mock_storage):
        """Test rate limiter initialization with custom storage."""
        rate_limiter = UserBasedRateLimiter(
            requests=10, seconds=30, sleep_seconds=2, storage=mock_storage
        )
        assert rate_limiter.storage is mock_storage
        assert rate_limiter.requests == 10
        assert rate_limiter.seconds == 30
        assert rate_limiter.sleep_seconds == 2

    @pytest.mark.asyncio
    async def test_is_allowed_empty_user_id(self, rate_limiter_with_mock_storage):
        """Test rate limiting with empty user ID."""
        is_allowed = await rate_limiter_with_mock_storage.is_allowed('')
        assert is_allowed is False
        is_allowed = await rate_limiter_with_mock_storage.is_allowed(None)
        assert is_allowed is False

    @pytest.mark.asyncio
    async def test_is_allowed_within_limit(
        self, rate_limiter_with_mock_storage, mock_storage
    ):
        """Test rate limiting within the allowed limit."""
        user_id = 'test_user'
        # First request should be allowed
        is_allowed = await rate_limiter_with_mock_storage.is_allowed(user_id)
        assert is_allowed is True
        # Verify storage interactions
        assert len(mock_storage.add_request_calls) == 1
        assert len(mock_storage.clean_old_requests_calls) == 1
        assert mock_storage.add_request_calls[0][0] == user_id

    @pytest.mark.asyncio
    async def test_is_allowed_at_limit(self, rate_limiter_with_memory_storage):
        """Test rate limiting at the exact limit."""
        user_id = 'test_user'
        now = datetime.now()
        # Add exactly the limit number of requests (5)
        for i in range(5):
            await rate_limiter_with_memory_storage.storage.add_request(
                user_id, now + timedelta(seconds=i)
            )
        # Next request will be added (making it 6), then checked.
        # Since 6 > 5 and sleep_seconds=0, it should be rejected.
        is_allowed = await rate_limiter_with_memory_storage.is_allowed(user_id)
        assert is_allowed is False
        # Count should now be 6 (over the limit of 5)
        count = await rate_limiter_with_memory_storage.storage.get_request_count(
            user_id
        )
        assert count == 6

    @pytest.mark.asyncio
    async def test_is_allowed_over_limit_no_sleep(
        self, rate_limiter_with_memory_storage
    ):
        """Test rate limiting over the limit with no sleep."""
        user_id = 'test_user'
        now = datetime.now()
        # Add more than the limit (5) but less than 2x limit (10)
        for i in range(7):
            await rate_limiter_with_memory_storage.storage.add_request(
                user_id, now + timedelta(seconds=i)
            )
        # Next request should trigger sleep logic, but since sleep_seconds=0
        # it should be rejected.
        is_allowed = await rate_limiter_with_memory_storage.is_allowed(user_id)
        assert is_allowed is False

    @pytest.mark.asyncio
    async def test_is_allowed_over_limit_with_sleep(self):
        """Test rate limiting over the limit with sleep."""
        storage = InMemoryRateLimiterStorage()
        sleep_seconds = 0.1  # Short sleep for testing
        rate_limiter = UserBasedRateLimiter(
            requests=3,
            seconds=60,
            sleep_seconds=sleep_seconds,
            storage=storage,
        )
        user_id = 'test_user'
        now = datetime.now()
        # Add more than the limit (3) but less than 2x limit (6)
        for i in range(4):
            await storage.add_request(user_id, now + timedelta(seconds=i))
        # Next request should trigger sleep and then be allowed.
        # Time it with the event loop's monotonic clock: wall-clock
        # datetime.now() can jump (NTP adjustments) and is not meant for
        # measuring durations.
        loop = asyncio.get_running_loop()
        start = loop.time()
        is_allowed = await rate_limiter.is_allowed(user_id)
        elapsed = loop.time() - start
        assert is_allowed is True
        # Verify that the sleep actually happened. Only a lower bound is
        # asserted: on a busy CI machine the sleep can overshoot by far more
        # than 10 ms, so a tight upper bound makes the test flaky. The small
        # slack absorbs event-loop timer granularity.
        assert elapsed >= sleep_seconds * 0.9

    @pytest.mark.asyncio
    async def test_is_allowed_way_over_limit(self, rate_limiter_with_memory_storage):
        """Test rate limiting way over the limit (2x limit)."""
        user_id = 'test_user'
        now = datetime.now()
        # Add more than 2x the limit (12 > 2*5)
        for i in range(12):
            await rate_limiter_with_memory_storage.storage.add_request(
                user_id, now + timedelta(seconds=i)
            )
        # Should be immediately rejected
        is_allowed = await rate_limiter_with_memory_storage.is_allowed(user_id)
        assert is_allowed is False

    @pytest.mark.asyncio
    async def test_old_requests_cleaned(self, rate_limiter_with_memory_storage):
        """Test that old requests are properly cleaned."""
        user_id = 'test_user'
        now = datetime.now()
        # Add old requests (outside the 60-second window)
        old_time = now - timedelta(seconds=120)
        for i in range(5):
            await rate_limiter_with_memory_storage.storage.add_request(
                user_id, old_time + timedelta(seconds=i)
            )
        # Add recent request
        await rate_limiter_with_memory_storage.storage.add_request(user_id, now)
        # Make a new request - this should clean old requests
        is_allowed = await rate_limiter_with_memory_storage.is_allowed(user_id)
        assert is_allowed is True
        # Verify old requests were cleaned (should only have 2 requests: recent + new)
        count = await rate_limiter_with_memory_storage.storage.get_request_count(
            user_id
        )
        assert count == 2

    @pytest.mark.asyncio
    async def test_multiple_users_isolation(self, rate_limiter_with_memory_storage):
        """Test that different users have isolated rate limits."""
        user1 = 'user1'
        user2 = 'user2'
        now = datetime.now()
        # Fill up user1's limit (5 requests)
        for i in range(5):
            await rate_limiter_with_memory_storage.storage.add_request(
                user1, now + timedelta(seconds=i)
            )
        # User1 should be over limit (5 + 1 = 6 > 5, sleep_seconds=0)
        is_allowed = await rate_limiter_with_memory_storage.is_allowed(user1)
        assert is_allowed is False
        # User2 should still be allowed (independent limit, first request)
        is_allowed = await rate_limiter_with_memory_storage.is_allowed(user2)
        assert is_allowed is True
        # Verify counts
        count1 = await rate_limiter_with_memory_storage.storage.get_request_count(user1)
        count2 = await rate_limiter_with_memory_storage.storage.get_request_count(user2)
        assert count1 == 6  # 5 pre-existing + 1 from is_allowed call
        assert count2 == 1  # 1 from is_allowed call

    @pytest.mark.asyncio
    async def test_concurrent_requests(self, rate_limiter_with_memory_storage):
        """Test concurrent requests for the same user."""
        user_id = 'test_user'
        total = 10
        # Make multiple concurrent requests
        results = await asyncio.gather(
            *(
                asyncio.create_task(
                    rate_limiter_with_memory_storage.is_allowed(user_id)
                )
                for _ in range(total)
            )
        )
        # Some should be allowed, some rejected based on the limit
        allowed_count = sum(1 for result in results if result)
        rejected_count = sum(1 for result in results if not result)
        assert allowed_count > 0
        assert rejected_count > 0
        assert allowed_count + rejected_count == total

    @pytest.mark.asyncio
    async def test_time_window_reset(self):
        """Test that rate limit resets after time window."""
        user_id = 'test_user'
        # Create a rate limiter with short time window for testing
        storage = InMemoryRateLimiterStorage()
        rate_limiter = UserBasedRateLimiter(
            requests=2,
            seconds=1,  # 1 second window
            sleep_seconds=0,
            storage=storage,
        )
        # Make requests to hit the limit
        await rate_limiter.is_allowed(user_id)  # 1st request
        await rate_limiter.is_allowed(user_id)  # 2nd request
        await rate_limiter.is_allowed(user_id)  # 3rd request (over limit)
        # Should be over limit now
        is_allowed = await rate_limiter.is_allowed(user_id)
        assert is_allowed is False
        # Wait for time window to pass
        await asyncio.sleep(1.1)
        # Should be allowed again (old requests cleaned)
        is_allowed = await rate_limiter.is_allowed(user_id)
        assert is_allowed is True

    @pytest.mark.asyncio
    async def test_storage_error_handling(self):
        """Test rate limiter behavior when storage operations fail."""
        # Create a mock storage that raises exceptions
        mock_storage = AsyncMock()
        mock_storage.clean_old_requests = AsyncMock(
            side_effect=Exception('Storage error')
        )
        mock_storage.add_request = AsyncMock(side_effect=Exception('Storage error'))
        mock_storage.get_request_count = AsyncMock(
            side_effect=Exception('Storage error')
        )
        rate_limiter = UserBasedRateLimiter(
            requests=5, seconds=60, sleep_seconds=0, storage=mock_storage
        )
        # Should handle errors gracefully and not crash
        try:
            result = await rate_limiter.is_allowed('test_user')
        except Exception as e:
            pytest.fail(f'Rate limiter should handle storage errors gracefully: {e}')
        # The behavior when storage fails is implementation-dependent,
        # but it should not raise an exception and must return a bool.
        assert isinstance(result, bool)

    @pytest.mark.asyncio
    async def test_edge_case_zero_requests(self):
        """Test rate limiter with zero allowed requests."""
        storage = InMemoryRateLimiterStorage()
        rate_limiter = UserBasedRateLimiter(
            requests=0, seconds=60, sleep_seconds=0, storage=storage
        )
        # Should immediately be over limit
        is_allowed = await rate_limiter.is_allowed('test_user')
        assert is_allowed is False

    @pytest.mark.asyncio
    async def test_edge_case_zero_time_window(self):
        """Test rate limiter with zero time window."""
        storage = InMemoryRateLimiterStorage()
        rate_limiter = UserBasedRateLimiter(
            requests=5, seconds=0, sleep_seconds=0, storage=storage
        )
        # With zero time window, all requests should be considered old
        is_allowed = await rate_limiter.is_allowed('test_user')
        # Should still be allowed as the request is added after cleaning
        assert is_allowed is True
if __name__ == '__main__':
    # pytest.main() returns the session's exit code rather than exiting;
    # propagate it so running this file directly fails when tests fail.
    raise SystemExit(pytest.main([__file__]))