Skip to content

Commit da5dfcd

Browse files
authored
Merge pull request #1 from pnadolny13/multiple_queries
Multiple concurrent queries
2 parents fa62751 + 2856c96 commit da5dfcd

File tree

7 files changed

+125
-55
lines changed

7 files changed

+125
-55
lines changed

mypy.ini

+3
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@ ignore_missing_imports = True
77

88
[mypy-boto3.*]
99
ignore_missing_imports = True
10+
11+
[mypy-pytz.*]
12+
ignore_missing_imports = True

tap_cloudwatch/client.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,5 @@ def get_records(self, context: Optional[dict]) -> Iterable[dict]:
2929
self.config.get("batch_increment_s"),
3030
)
3131
for batch in cloudwatch_iter:
32-
for record in batch.get("results"):
32+
for record in batch:
3333
yield {i["field"][1:]: i["value"] for i in record}

tap_cloudwatch/cloudwatch_api.py

+85-27
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22

33
import os
44
import time
5+
from collections import deque
56
from datetime import datetime, timezone
7+
from math import ceil
8+
9+
import boto3
610
import pytz
711

812
from tap_cloudwatch.exception import InvalidQueryException
913

10-
import boto3
11-
from math import ceil
1214

1315
class CloudwatchAPI:
1416
"""Cloudwatch class for interacting with the API."""
@@ -17,6 +19,8 @@ def __init__(self, logger):
1719
"""Initialize CloudwatchAPI."""
1820
self._client = None
1921
self.logger = logger
22+
self.limit = 10000
23+
self.max_concurrent_queries = 20
2024

2125
@property
2226
def client(self):
@@ -64,7 +68,7 @@ def _create_client(self, config):
6468
def _request_more_records():
6569
return True
6670

67-
def split_batch_into_windows(self, start_time, end_time, batch_increment_s):
71+
def _split_batch_into_windows(self, start_time, end_time, batch_increment_s):
6872
diff_s = end_time - start_time
6973
total_batches = ceil(diff_s / batch_increment_s)
7074
batch_windows = []
@@ -79,71 +83,125 @@ def split_batch_into_windows(self, start_time, end_time, batch_increment_s):
7983
batch_windows.append((query_start, query_end))
8084
return batch_windows
8185

82-
def validate_query(self, query):
86+
def _validate_query(self, query):
8387
if "|sort" in query.replace(" ", ""):
8488
raise InvalidQueryException("sort not allowed")
8589
if "|limit" in query.replace(" ", ""):
8690
raise InvalidQueryException("limit not allowed")
8791
if "stats" in query:
8892
raise InvalidQueryException("stats not allowed")
8993
if "@timestamp" not in query.split("|")[0]:
90-
raise InvalidQueryException("@timestamp field is used as the replication key so it must be selected")
94+
raise InvalidQueryException(
95+
"@timestamp field is used as the replication key so it must be selected"
96+
)
9197

9298
def get_records_iterator(self, bookmark, log_group, query, batch_increment_s):
9399
"""Retrieve records from Cloudwatch."""
94100
end_time = datetime.now(timezone.utc).timestamp()
95101
start_time = bookmark.timestamp()
96-
self.validate_query(query)
97-
batch_windows = self.split_batch_into_windows(start_time, end_time, batch_increment_s)
102+
self._validate_query(query)
103+
batch_windows = self._split_batch_into_windows(
104+
start_time, end_time, batch_increment_s
105+
)
98106

107+
queue = deque()
99108
for window in batch_windows:
100-
yield self.handle_batch_window(window[0], window[1], log_group, query)
101-
102-
def handle_limit_exceeded(self, response, log_group, query_start, query_end, query):
109+
if len(queue) < (self.max_concurrent_queries - 1):
110+
queue.append(
111+
(
112+
self._start_query(window[0], window[1], log_group, query),
113+
window[0],
114+
window[1],
115+
)
116+
)
117+
else:
118+
query_id, start, end = queue.popleft()
119+
queue.append(
120+
(
121+
self._start_query(window[0], window[1], log_group, query),
122+
window[0],
123+
window[1],
124+
)
125+
)
126+
results = self._get_results(log_group, start, end, query, query_id)
127+
yield results
128+
129+
while len(queue) > 0:
130+
query_id, start, end = queue.popleft()
131+
results = self._get_results(log_group, start, end, query, query_id)
132+
yield results
133+
134+
def _handle_limit_exceeded(
135+
self, response, log_group, query_start, query_end, query
136+
):
103137
results = response.get("results")
104138
last_record = results[-1]
105139

106-
latest_ts_str = [i["value"] for i in last_record if i["field"] == "@timestamp"][0]
140+
latest_ts_str = [i["value"] for i in last_record if i["field"] == "@timestamp"][
141+
0
142+
]
107143
# Include latest ts in query, this could cause duplicates but
108144
# without it we might miss ties
109-
query_start = int(datetime.fromisoformat(latest_ts_str).replace(tzinfo=pytz.UTC).timestamp())
110-
self.handle_batch_window(query_start, query_end, log_group, query, prev_start=query_start)
145+
new_query_start = int(
146+
datetime.fromisoformat(latest_ts_str).replace(tzinfo=pytz.UTC).timestamp()
147+
)
148+
new_query_id = self._start_query(new_query_start, query_end, log_group, query)
149+
return self._get_results(
150+
log_group, new_query_start, query_end, query, new_query_id
151+
)
111152

112-
def alter_query(self, query):
153+
def _alter_query(self, query):
113154
query += " | sort @timestamp asc"
114155
return query
115156

116-
def handle_batch_window(self, query_start, query_end, log_group, query, prev_start=None):
157+
def _start_query(self, query_start, query_end, log_group, query, prev_start=None):
117158
self.logger.info(
118159
(
119-
"Retrieving batch from:"
160+
"Submitting query for batch from:"
120161
f" `{datetime.utcfromtimestamp(query_start).isoformat()} UTC` -"
121162
f" `{datetime.utcfromtimestamp(query_end).isoformat()} UTC`"
122163
)
123164
)
124-
limit = 10000
125-
query = self.alter_query(query)
165+
query = self._alter_query(query)
126166
start_query_response = self.client.start_query(
127167
logGroupName=log_group,
128168
startTime=query_start,
129169
endTime=query_end,
130170
queryString=query,
131-
limit=limit,
171+
limit=self.limit,
132172
)
173+
return start_query_response["queryId"]
133174

134-
query_id = start_query_response["queryId"]
135-
response = None
175+
def _get_results(
176+
self, log_group, query_start, query_end, query, query_id, prev_start=None
177+
):
178+
self.logger.info(
179+
(
180+
"Retrieving results for batch from:"
181+
f" `{datetime.utcfromtimestamp(query_start).isoformat()} UTC` -"
182+
f" `{datetime.utcfromtimestamp(query_end).isoformat()} UTC`"
183+
)
184+
)
185+
response = self.client.get_query_results(queryId=query_id)
136186
while response is None or response["status"] == "Running":
137-
time.sleep(1)
187+
time.sleep(0.5)
138188
response = self.client.get_query_results(queryId=query_id)
139189
if response.get("ResponseMetadata", {}).get("HTTPStatusCode") != 200:
140190
raise Exception(f"Failed: {response}")
141191
result_size = response.get("statistics", {}).get("recordsMatched")
142-
if result_size > limit:
192+
results = response["results"]
193+
self.logger.info(f"Result set size '{int(result_size)}' received.")
194+
if result_size > self.limit:
143195
if prev_start == query_start:
144-
raise Exception("Stuck in a loop, smaller batch still exceeds limit. Reduce batch window.")
196+
raise Exception(
197+
"Stuck in a loop, smaller batch still exceeds limit."
198+
"Reduce batch window."
199+
)
145200
self.logger.info(
146-
f"Result set size '{int(result_size)}' exceeded limit '{limit}'. Re-running sub-batch..."
201+
f"Result set size '{int(result_size)}' exceeded limit "
202+
f"'{self.limit}'. Re-running sub-batch..."
203+
)
204+
results += self._handle_limit_exceeded(
205+
response, log_group, query_start, query_end, query
147206
)
148-
self.handle_limit_exceeded(response, log_group, query_start, query_end, query)
149-
return response
207+
return results

tap_cloudwatch/exception.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""Custom exceptions."""
2+
3+
14
class InvalidQueryException(Exception):
2-
"Raised when the input value is less than 18"
3-
pass
5+
"""Raised when the input query is invalid."""
6+
7+
pass

tap_cloudwatch/streams.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,12 @@ def schema(self):
2323
# | parse @message "[*] *" as loggingType, loggingMessage
2424
properties.append(
2525
th.Property(
26-
"ptr",
27-
th.StringType(),
28-
description="The identifier for the log record."
26+
"ptr", th.StringType(), description="The identifier for the log record."
2927
)
3028
)
3129
properties.append(
3230
th.Property(
33-
"timestamp",
34-
th.DateTimeType(),
35-
description="The timestamp of the log."
31+
"timestamp", th.DateTimeType(), description="The timestamp of the log."
3632
)
3733
)
3834
for prop in self.config.get("query").split("|")[0].split(","):

tap_cloudwatch/tests/test_cloudwatch_api.py

+26-17
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,43 @@
11
"""Tests cloudwatch api module."""
22

3-
from tap_cloudwatch.cloudwatch_api import CloudwatchAPI
4-
from tap_cloudwatch.exception import InvalidQueryException
5-
import pytest
6-
from unittest.mock import patch
7-
from datetime import datetime, timezone
3+
import logging
4+
from contextlib import nullcontext as does_not_raise
85

96
import boto3
7+
import pytest
108
from botocore.stub import Stubber
119
from freezegun import freeze_time
12-
import logging
13-
from contextlib import nullcontext as does_not_raise
10+
11+
from tap_cloudwatch.cloudwatch_api import CloudwatchAPI
12+
from tap_cloudwatch.exception import InvalidQueryException
1413

1514

1615
@pytest.mark.parametrize(
17-
'start,end,batch,expected',
16+
"start,end,batch,expected",
1817
[
1918
[1672272000, 1672275600, 3600, [(1672272000, 1672275600)]],
20-
[1672272000, 1672275601, 3600, [(1672272000, 1672275600), (1672275601, 1672279200)]],
19+
[
20+
1672272000,
21+
1672275601,
22+
3600,
23+
[(1672272000, 1672275600), (1672275601, 1672279200)],
24+
],
2125
],
2226
)
2327
def test_split_batch_into_windows(start, end, batch, expected):
2428
"""Run standard tap tests from the SDK."""
2529
api = CloudwatchAPI(None)
26-
batches = api.split_batch_into_windows(start, end, batch)
30+
batches = api._split_batch_into_windows(start, end, batch)
2731
assert batches == expected
2832

2933

30-
3134
@pytest.mark.parametrize(
32-
'query,expectation',
35+
"query,expectation",
3336
[
34-
["fields @timestamp, @message | sort @timestamp desc", pytest.raises(InvalidQueryException)],
37+
[
38+
"fields @timestamp, @message | sort @timestamp desc",
39+
pytest.raises(InvalidQueryException),
40+
],
3541
["fields @timestamp, @message | limit 5", pytest.raises(InvalidQueryException)],
3642
["stats count(*) by duration as time", pytest.raises(InvalidQueryException)],
3743
["fields @message", pytest.raises(InvalidQueryException)],
@@ -42,7 +48,8 @@ def test_validate_query(query, expectation):
4248
"""Run standard tap tests from the SDK."""
4349
api = CloudwatchAPI(None)
4450
with expectation:
45-
api.validate_query(query)
51+
api._validate_query(query)
52+
4653

4754
@freeze_time("2022-12-30")
4855
def test_handle_batch_window():
@@ -66,7 +73,7 @@ def test_handle_batch_window():
6673
]
6774
],
6875
"ResponseMetadata": {"HTTPStatusCode": 200},
69-
"statistics": {"recordsMatched": 10000}
76+
"statistics": {"recordsMatched": 10000},
7077
}
7178
stubber.add_response(
7279
"start_query",
@@ -86,5 +93,7 @@ def test_handle_batch_window():
8693
)
8794
stubber.activate()
8895

89-
output = api.handle_batch_window(query_start, query_end, log_group, in_query)
90-
assert response == output
96+
query_id = api._start_query(query_start, query_end, log_group, in_query)
97+
output = api._get_results(log_group, query_start, query_end, in_query, query_id)
98+
99+
assert response["results"] == output

tap_cloudwatch/tests/test_core.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"query": "fields @timestamp, @message",
1616
"aws_region_name": "us-east-1",
1717
"start_date": "2022-12-29",
18-
"batch_increment_s": 86400
18+
"batch_increment_s": 86400,
1919
}
2020

2121
client = boto3.client("logs", region_name="us-east-1")
@@ -49,7 +49,7 @@ def test_standard_tap_tests(patch_client):
4949
]
5050
],
5151
"ResponseMetadata": {"HTTPStatusCode": 200},
52-
"statistics": {"recordsMatched": 0}
52+
"statistics": {"recordsMatched": 0},
5353
},
5454
{"queryId": "123"},
5555
)

0 commit comments

Comments (0)